Skip to content

Commit

Permalink
fix lint errors
Browse files Browse the repository at this point in the history
Signed-off-by: Keshav <[email protected]>
  • Loading branch information
keshavb96 committed Mar 12, 2024
1 parent 253f62a commit 338670d
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions transformer_engine/jax/cpp_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,8 @@ def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos,
g_b_spec = get_padded_spec(arg_infos[4])
if g_b_spec[-1] is not None:
warnings.warn(
f"{LayerNormBwdPrimitive.name} does not support sharding of gradients of gamma and beta of Layernorm" \
f"{LayerNormBwdPrimitive.name} does not support sharding of gradients " \
f"of gamma and beta of Layernorm " \
f"Enforcing no sharding of parameters hidden dim! " \
)

Expand All @@ -663,10 +664,11 @@ def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
g_b_spec = get_padded_spec(arg_infos[4])
if g_b_spec[-1] is not None:
warnings.warn(
f"{LayerNormBwdPrimitive.name} does not support sharding of gradients of gamma and beta of Layernorm" \
f"{LayerNormBwdPrimitive.name} does not support sharding of gradients " \
f"of gamma and beta of Layernorm " \
f"Enforcing no sharding of parameters hidden dim! " \
)

dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(None))
out_shardings = dx_sharding, dgamma_sharding, dbeta_sharding
Expand Down Expand Up @@ -857,7 +859,7 @@ def partition(epsilon, mesh, arg_infos, result_infos):
f"{RmsNormFwdPrimitive.name} does not support sharding of parameter gamma " \
f"Enforcing no sharding of parameters hidden dim! " \
)

x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(None))
out_sharding = x_sharding
Expand Down

0 comments on commit 338670d

Please sign in to comment.