Skip to content

Commit

Permalink
Remove the redundant PartitionSpec(*...) wrapping around get_spec() results. The value returned by get_spec() is now always a PartitionSpec
Browse files Browse the repository at this point in the history
Signed-off-by: Frederic Bastien <[email protected]>
  • Loading branch information
nouiz committed Feb 25, 2024
1 parent 91b5149 commit fc97c70
Showing 1 changed file with 28 additions and 28 deletions.
56 changes: 28 additions & 28 deletions transformer_engine/jax/cpp_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,7 +1130,7 @@ def forward_partition(impl, scale_factor, mesh, arg_infos, result_infos):
softmax_forward partitioning
"""
del result_infos
logits_spec = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[0])))
logits_spec = NamedSharding(mesh, get_spec(arg_infos[0]))
out_spec = logits_spec
arg_shardings = (logits_spec,)
out_shardings = out_spec
Expand Down Expand Up @@ -1223,8 +1223,8 @@ def backward_partition(impl, scale_factor, mesh, arg_infos, result_infos):
softmax_backward partition
"""
del result_infos
dz_spec = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[0])))
softmax_out_spec = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[1])))
dz_spec = NamedSharding(mesh, get_spec(arg_infos[0]))
softmax_out_spec = NamedSharding(mesh, get_spec(arg_infos[1]))
dx_spec = softmax_out_spec
arg_shardings = (dz_spec, softmax_out_spec)
out_shardings = dx_spec
Expand Down Expand Up @@ -1509,8 +1509,8 @@ def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
@staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos):
del result_infos
logits_spec = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[0])))
mask_spec = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[1])))
logits_spec = NamedSharding(mesh, get_spec(arg_infos[0]))
mask_spec = NamedSharding(mesh, get_spec(arg_infos[1]))
arg_shardings = (logits_spec, mask_spec)
out_shardings = logits_spec
impl = partial(ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor=scale_factor)
Expand Down Expand Up @@ -3297,7 +3297,7 @@ def partition(mesh, arg_infos, result_infos):
dgelu partition
"""
del result_infos
dx_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[1])))
dx_sharding = NamedSharding(mesh, get_spec(arg_infos[1]))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = dx_sharding
impl = DGeluPrimitive.impl
Expand Down Expand Up @@ -3529,7 +3529,7 @@ def partition(mesh, arg_infos, result_infos):
dgated_gelu partition
"""
del result_infos
dx_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[1])))
dx_sharding = NamedSharding(mesh, get_spec(arg_infos[1]))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = dx_sharding
impl = DgatedGeluPrimitive.impl
Expand Down Expand Up @@ -3707,7 +3707,7 @@ def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[1])))
amax_sharding = NamedSharding(mesh, get_spec(arg_infos[1]))
return (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

@staticmethod
Expand All @@ -3718,7 +3718,7 @@ def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, ar
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[1])))
amax_sharding = NamedSharding(mesh, get_spec(arg_infos[1]))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

Expand Down Expand Up @@ -3848,15 +3848,15 @@ def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
del out_dtype, result_infos
x_spec = get_spec(arg_infos[0])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[1])))
amax_sharding = NamedSharding(mesh, get_spec(arg_infos[1]))
return (casted_x_sharding, amax_sharding)

@staticmethod
def partition(out_dtype, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_spec(arg_infos[0])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[1])))
amax_sharding = NamedSharding(mesh, get_spec(arg_infos[1]))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (casted_x_sharding, amax_sharding)

Expand Down Expand Up @@ -4201,7 +4201,7 @@ def infer_sharding_from_operands(out_dtype, zero_centered_gamma, epsilon, mesh,

out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[3])))
amax_sharding = NamedSharding(mesh, get_spec(arg_infos[3]))
return (out_sharding, mu_sharding, rsigma_sharding, amax_sharding)

@staticmethod
Expand All @@ -4215,12 +4215,12 @@ def partition(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, result_i
f"and hurt performance."
)
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[1])))
b_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[2])))
g_sharding = NamedSharding(mesh, get_spec(arg_infos[1]))
b_sharding = NamedSharding(mesh, get_spec(arg_infos[2]))
out_sharding = x_sharding
mu_sharding = rsigma_sharding = NamedSharding(
mesh, PartitionSpec(*get_spec(arg_infos[0])[:-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[3])))
mesh, get_spec(arg_infos[0])[:-1])
amax_sharding = NamedSharding(mesh, get_spec(arg_infos[3]))
fp8_meta_sharding = amax_sharding
arg_shardings = (x_sharding, g_sharding, b_sharding) + (fp8_meta_sharding,) * 3
out_shardings = (out_sharding, mu_sharding, rsigma_sharding, amax_sharding)
Expand Down Expand Up @@ -4433,7 +4433,7 @@ def infer_sharding_from_operands(out_dtype, epsilon, mesh, arg_infos, result_inf
)
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[2])))
amax_sharding = NamedSharding(mesh, get_spec(arg_infos[2]))
return (out_sharding, rsigma_sharding, amax_sharding)

@staticmethod
Expand All @@ -4447,10 +4447,10 @@ def partition(out_dtype, epsilon, mesh, arg_infos, result_infos):
f"and hurt performance."
)
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[1])))
g_sharding = NamedSharding(mesh, get_spec(arg_infos[1]))
out_sharding = x_sharding
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[0])[:-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[2])))
rsigma_sharding = NamedSharding(mesh, get_spec(arg_infos[0])[:-1])
amax_sharding = NamedSharding(mesh, get_spec(arg_infos[2]))
fp8_meta_sharding = amax_sharding
arg_shardings = (x_sharding, g_sharding) + (fp8_meta_sharding,) * 3
out_shardings = (out_sharding, rsigma_sharding, amax_sharding)
Expand Down Expand Up @@ -4585,15 +4585,15 @@ def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
del out_dtype, result_infos
x_spec = get_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[1])))
amax_sharding = NamedSharding(mesh, get_spec(arg_infos[1]))
return (out_sharding, amax_sharding)

@staticmethod
def partition(out_dtype, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[1])))
amax_sharding = NamedSharding(mesh, get_spec(arg_infos[1]))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (out_sharding, amax_sharding)

Expand Down Expand Up @@ -4789,7 +4789,7 @@ def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis
tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
dbias_shaprding = NamedSharding(
mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[2])))
amax_sharding = NamedSharding(mesh, get_spec(arg_infos[2]))
return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding)

@staticmethod
Expand All @@ -4804,7 +4804,7 @@ def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, ar
dbias_shaprding = NamedSharding(
mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))

amax_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[2])))
amax_sharding = NamedSharding(mesh, get_spec(arg_infos[2]))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (casted_x_sharding, casted_transposed_x_sharding, dbias_shaprding,
amax_sharding)
Expand Down Expand Up @@ -4966,15 +4966,15 @@ def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
del out_dtype, result_infos
x_spec = get_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[1])))
amax_sharding = NamedSharding(mesh, get_spec(arg_infos[1]))
return (out_sharding, amax_sharding)

@staticmethod
def partition(out_dtype, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[1])))
amax_sharding = NamedSharding(mesh, get_spec(arg_infos[1]))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (out_sharding, amax_sharding)

Expand Down Expand Up @@ -5129,7 +5129,7 @@ def infer_sharding_from_operands(out_dtype, static_axis_boundary, mesh, arg_info
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2)
tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[2])))
amax_sharding = NamedSharding(mesh, get_spec(arg_infos[2]))
return (out_sharding, tranposed_out_sharding, amax_sharding)

@staticmethod
Expand All @@ -5140,7 +5140,7 @@ def partition(out_dtype, static_axis_boundary, mesh, arg_infos, result_infos):
xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2)
casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))

amax_sharding = NamedSharding(mesh, PartitionSpec(*get_spec(arg_infos[2])))
amax_sharding = NamedSharding(mesh, get_spec(arg_infos[2]))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

Expand Down

0 comments on commit fc97c70

Please sign in to comment.