Ln force no weight sharding (#715)
* disallow sharding of layernorm learnable parameters; force duplication

Signed-off-by: Keshav <[email protected]>

* fix tests and support tensors for gamma/beta in layernorms

Signed-off-by: Keshav <[email protected]>

* reverting

Signed-off-by: Keshav <[email protected]>

* added tests for rank-1 gamma/beta sharding

Signed-off-by: Keshav <[email protected]>

* fix lint errors

Signed-off-by: Keshav <[email protected]>

---------

Signed-off-by: Keshav <[email protected]>
keshavb96 authored Mar 14, 2024
1 parent 2d0ab27 commit ffa2447
Showing 3 changed files with 129 additions and 36 deletions.
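
The core of this change, repeated across the forward and backward layernorm/rmsnorm primitives in transformer_engine/jax/cpp_extensions.py below, is to ignore any requested sharding on the learnable parameters (gamma/beta), replicate them instead, and emit a warning. A minimal sketch of that pattern follows; the helper name _force_replicated_params is illustrative and does not appear in the commit, which inlines the same logic in each partition rule.

import warnings
from jax.sharding import NamedSharding, PartitionSpec

def _force_replicated_params(primitive_name, mesh, g_spec, b_spec=None):
    # g_spec / b_spec are the padded PartitionSpecs recovered from arg_infos.
    if g_spec[-1] is not None:
        warnings.warn(f"{primitive_name} does not support sharding of parameter gamma "
                      f"Enforcing no sharding of parameters hidden dim! ")
    if b_spec is not None and b_spec[-1] is not None:
        warnings.warn(f"{primitive_name} does not support sharding of parameter beta "
                      f"Enforcing no sharding of parameters hidden dim! ")
    # Replicate the learnable parameters instead of honoring the requested layout.
    g_sharding = NamedSharding(mesh, PartitionSpec(None))
    b_sharding = None if b_spec is None else NamedSharding(mesh, PartitionSpec(None))
    return g_sharding, b_sharding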
2 changes: 1 addition & 1 deletion tests/jax/distributed_test_base.py
@@ -18,7 +18,7 @@ def generate_configs():
if is_devices_enough(2):
configs.append([2, (2,), ('dp'), MeshResource(dp_resource='dp')])
configs.append([2, (2,), ('tp'), MeshResource(tp_resource='tp')])

if is_devices_enough(4):
TP_size = 2
DP_size = 2
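For context on how these configs are consumed: a hedged sketch of the 2-device 'dp' case, showing how the tests below materialize the mesh and place gamma on the data-parallel axis when shard_weights=True. The array size is illustrative and the snippet assumes at least two visible JAX devices.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.asarray(jax.devices()[:2]).reshape(2)   # device_count=2, mesh_shape=(2,)
mesh = Mesh(devices, ('dp',))                        # mesh_axes=('dp',)
gamma = jnp.ones((1024,), dtype=jnp.bfloat16)
g_pspec = PartitionSpec('dp')                        # what shard_weights=True requests
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))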
77 changes: 55 additions & 22 deletions tests/jax/test_distributed_layernorm.py
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.

import warnings
import pytest

import jax
@@ -20,7 +21,7 @@

class TestDistributedLayernorm:

def generate_inputs(self, shape, mesh_resource, dtype):
def generate_inputs(self, shape, mesh_resource, dtype, shard_weights):
weight_shape = (shape[-1],)

x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
@@ -34,7 +35,7 @@ def generate_inputs(self, shape, mesh_resource, dtype):
else:
raise NotImplementedError

g_pspec = b_pspec = PartitionSpec(None)
g_pspec = b_pspec = PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)

return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)

@@ -54,8 +55,9 @@ def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype):
@pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('zero_centered_gamma', [False, True])
@pytest.mark.parametrize('shard_weights', [False, True])
def test_layernorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype,
zero_centered_gamma):
zero_centered_gamma, shard_weights):
epsilon = 1e-6
ln_type = 'layernorm'

@@ -74,7 +76,7 @@ def ref_func(x, gamma, beta):
return jnp.mean(output)

(x, gamma, beta), (x_pspec, g_pspec, b_pspec) = \
self.generate_inputs(data_shape, mesh_resource, dtype)
self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights)
collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type,
data_shape, dtype)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
@@ -84,19 +86,35 @@ def ref_func(x, gamma, beta):
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))

compare_ops(target_func,
ref_func, [x_, gamma_, beta_],
collective_count_ref,
grad_args=(0, 1, 2),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec, b_pspec),
out_shardings=(None, (x_pspec, g_pspec, b_pspec)))
with warnings.catch_warnings(record=True) as warns:
try:
compare_ops(target_func,
ref_func, [x_, gamma_, beta_],
collective_count_ref,
grad_args=(0, 1, 2),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec, b_pspec),
out_shardings=(None, (x_pspec, g_pspec, b_pspec)))
except AssertionError as err:
# Layernorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same
# when XLA is forced to unshard gamma and/or beta. We can catch
# and ignore that specific error here.
if (g_pspec[-1] is None and b_pspec[-1] is None) or "Expected collective count" not in str(err):
raise err
finally:
for w in warns:
assert "Enforcing no sharding of parameters hidden dim!" in str(w), (
"Layernorm primitive did not raise the correct warning for "
"unsupported sharding of gamma and/or beta"
)

@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('dtype', DTYPES)
def test_rmsnorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype):
@pytest.mark.parametrize('shard_weights', [False, True])
def test_rmsnorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights):
epsilon = 1e-6
ln_type = 'rmsnorm'

@@ -111,7 +129,7 @@ def ref_func(x, gamma):
return jnp.mean(output)

(x, gamma, _), (x_pspec, g_pspec, _) = \
self.generate_inputs(data_shape, mesh_resource, dtype)
self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights)
collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type,
data_shape, dtype)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
@@ -120,11 +138,26 @@ def ref_func(x, gamma):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))

compare_ops(target_func,
ref_func, [x_, gamma_],
collective_count_ref,
grad_args=(0, 1),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec),
out_shardings=(None, (x_pspec, g_pspec)))
with warnings.catch_warnings(record=True) as warns:
try:
compare_ops(target_func,
ref_func, [x_, gamma_],
collective_count_ref,
grad_args=(0, 1),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec),
out_shardings=(None, (x_pspec, g_pspec)))
except AssertionError as err:
# RmsNorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same
# when XLA is forced to unshard gamma. We can catch
# and ignore that specific error here.
if g_pspec[-1] is None or "Expected collective count" not in str(err):
raise err
finally:
for w in warns:
assert "Enforcing no sharding of parameters hidden dim!" in str(w), (
"RmsNorm primitive did not raise the correct warning for "
"unsupported sharding of gamma and/or beta"
)
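
The rationale in the comments inside these tests (sharding gamma/beta may change the collective count but not the numerics) can be sanity-checked with plain JAX, independent of the TE kernels. A standalone sketch, assuming at least two visible devices; ln here is a reference layernorm, not the TE primitive.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def ln(x, gamma, beta, eps=1e-6):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) * jax.lax.rsqrt(var + eps) * gamma + beta

mesh = Mesh(np.asarray(jax.devices()[:2]), ('dp',))
x = jax.device_put(jax.random.normal(jax.random.PRNGKey(0), (32, 1024)),
                   NamedSharding(mesh, PartitionSpec('dp', None)))
beta = jax.device_put(jnp.zeros((1024,)), NamedSharding(mesh, PartitionSpec(None)))
gamma = jnp.full((1024,), 1.5)
g_sharded = jax.device_put(gamma, NamedSharding(mesh, PartitionSpec('dp')))
g_replicated = jax.device_put(gamma, NamedSharding(mesh, PartitionSpec(None)))
# Same math either way; only the layout (and hence XLA's collectives) differs.
assert jnp.allclose(ln(x, g_sharded, beta), ln(x, g_replicated, beta))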
86 changes: 73 additions & 13 deletions transformer_engine/jax/cpp_extensions.py
@@ -453,9 +453,21 @@ def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
f"and hurt performance."
)
if g_spec[-1] is not None:
warnings.warn(
f"{LayerNormFwdPrimitive.name} does not support sharding of parameter gamma " \
f"Enforcing no sharding of parameters hidden dim! " \
)
if b_spec[-1] is not None:
warnings.warn(
f"{LayerNormFwdPrimitive.name} does not support sharding of parameter beta " \
f"Enforcing no sharding of parameters hidden dim! " \
)


x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
b_sharding = NamedSharding(mesh, PartitionSpec(*b_spec))
g_sharding = NamedSharding(mesh, PartitionSpec(None))
b_sharding = NamedSharding(mesh, PartitionSpec(None))
out_sharding = x_sharding
mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
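
For readers unfamiliar with where these staticmethods plug in: partition() and infer_sharding_from_operands() are the callbacks expected by jax.experimental.custom_partitioning, which TE wires up in its primitive-registration helpers (not shown in this diff). The toy below is not TE code; it is a self-contained sketch of the same mechanism, forcing a parameter to a replicated layout inside a partition rule. It assumes two or more devices to actually exercise the SPMD path.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.experimental.custom_partitioning import custom_partitioning

@custom_partitioning
def scale(x, gamma):
    return x * gamma

def infer_sharding_from_operands(mesh, arg_infos, result_infos):
    # The output follows x's sharding.
    return NamedSharding(mesh, arg_infos[0].sharding.spec)

def partition(mesh, arg_infos, result_infos):
    x_sharding = NamedSharding(mesh, arg_infos[0].sharding.spec)
    # Force gamma to be replicated regardless of what the caller requested,
    # mirroring the PartitionSpec(None) enforcement in the diff above.
    arg_shardings = (x_sharding, NamedSharding(mesh, PartitionSpec(None)))
    def sharded_impl(x, gamma):
        return x * gamma
    return mesh, sharded_impl, x_sharding, arg_shardings

scale.def_partition(infer_sharding_from_operands=infer_sharding_from_operands,
                    partition=partition)

mesh = Mesh(np.asarray(jax.devices()[:2]), ('dp',))
x = jax.device_put(jnp.ones((4, 8)), NamedSharding(mesh, PartitionSpec('dp', None)))
g = jax.device_put(jnp.ones((8,)), NamedSharding(mesh, PartitionSpec('dp')))
y = jax.jit(scale)(x, g)   # gamma is gathered back to a replicated layout before sharded_impl runs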

@@ -628,8 +640,15 @@ def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos,
f"and hurt performance."
)
g_b_spec = get_padded_spec(arg_infos[4])
if g_b_spec[-1] is not None:
warnings.warn(
f"{LayerNormBwdPrimitive.name} does not support sharding of gradients " \
f"of gamma and beta of Layernorm " \
f"Enforcing no sharding of parameters hidden dim! " \
)

dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(*g_b_spec))
dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(None))
return dx_sharding, dgamma_sharding, dbeta_sharding

@staticmethod
@@ -643,12 +662,19 @@ def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
f"and hurt performance."
)
g_b_spec = get_padded_spec(arg_infos[4])
if g_b_spec[-1] is not None:
warnings.warn(
f"{LayerNormBwdPrimitive.name} does not support sharding of gradients " \
f"of gamma and beta of Layernorm " \
f"Enforcing no sharding of parameters hidden dim! " \
)

dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(*g_b_spec))
dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(None))
out_shardings = dx_sharding, dgamma_sharding, dbeta_sharding
x_shardings = (dx_sharding,) * 2 # dz and x should have the same sharding.
mu_shardings = (NamedSharding(mesh, PartitionSpec(*x_spec[:-1])),) * 2
arg_shardings = (*x_shardings, *mu_shardings, NamedSharding(mesh, PartitionSpec(*g_b_spec)))
arg_shardings = (*x_shardings, *mu_shardings, NamedSharding(mesh, PartitionSpec(None)))

def sharded_impl(dz, x, mu, rsigma, gamma):
local_dx, local_dgamma, local_dbeta = \
@@ -828,8 +854,14 @@ def partition(epsilon, mesh, arg_infos, result_infos):
f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
f"and hurt performance."
)
if g_spec[-1] is not None:
warnings.warn(
f"{RmsNormFwdPrimitive.name} does not support sharding of parameter gamma " \
f"Enforcing no sharding of parameters hidden dim! " \
)

x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
g_sharding = NamedSharding(mesh, PartitionSpec(None))
out_sharding = x_sharding
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
arg_shardings = (x_sharding, g_sharding)
@@ -982,8 +1014,13 @@ def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos):
f"and hurt performance."
)
g_spec = get_padded_spec(arg_infos[3])
if g_spec[-1] is not None:
warnings.warn(
f"{RmsNormBwdPrimitive.name} does not support sharding of parameter gamma " \
f"Enforcing no sharding of parameters hidden dim! " \
)
dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
dgamma_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
dgamma_sharding = NamedSharding(mesh, PartitionSpec(None))
return dx_sharding, dgamma_sharding

@staticmethod
@@ -997,12 +1034,17 @@ def partition(epsilon, mesh, arg_infos, result_infos):
f"and hurt performance."
)
g_spec = get_padded_spec(arg_infos[3])
if g_spec[-1] is not None:
warnings.warn(
f"{RmsNormBwdPrimitive.name} does not support sharding of parameter gamma " \
f"Enforcing no sharding of parameters hidden dim! " \
)
dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
dgamma_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
dgamma_sharding = NamedSharding(mesh, PartitionSpec(None))
out_shardings = dx_sharding, dgamma_sharding
x_shardings = (dx_sharding,) * 2 # dz and x should have the same sharding.
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
arg_shardings = (*x_shardings, rsigma_sharding, NamedSharding(mesh, PartitionSpec(*g_spec)))
arg_shardings = (*x_shardings, rsigma_sharding, NamedSharding(mesh, PartitionSpec(None)))

def sharded_impl(dz, x, rsigma, gamma):
local_dx, local_dgamma = \
@@ -4336,15 +4378,27 @@ def infer_sharding_from_operands(out_dtype, zero_centered_gamma, epsilon, mesh,
def partition(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[0])
g_spec = get_padded_spec(arg_infos[1])
b_spec = get_padded_spec(arg_infos[2])
if x_spec[-1] is not None:
warnings.warn(
f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " \
f"Does not support to shard hidden dim in {LayerNormFwdFp8Primitive.name}! " \
f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
f"and hurt performance."
)
if g_spec[-1] is not None:
warnings.warn(
f"{LayerNormFwdFp8Primitive.name} does not support sharding of parameter gamma " \
f"Enforcing no sharding of parameters hidden dim! " \
)
if b_spec[-1] is not None:
warnings.warn(
f"{LayerNormFwdFp8Primitive.name} does not support sharding of parameter beta " \
f"Enforcing no sharding of parameters hidden dim! " \
)
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
b_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
g_sharding = NamedSharding(mesh, PartitionSpec(None))
b_sharding = NamedSharding(mesh, PartitionSpec(None))
out_sharding = x_sharding
mu_sharding = rsigma_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]))
@@ -4568,14 +4622,20 @@ def infer_sharding_from_operands(out_dtype, epsilon, mesh, arg_infos, result_inf
def partition(out_dtype, epsilon, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[0])
g_spec = get_padded_spec(arg_infos[1])
if x_spec[-1] is not None:
warnings.warn(
f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! " \
f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
f"and hurt performance."
)
if g_spec[-1] is not None:
warnings.warn(
f"{RmsNormFwdFp8Primitive.name} does not support sharding of parameter gamma " \
f"Enforcing no sharding of parameters hidden dim! " \
)
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
g_sharding = NamedSharding(mesh, PartitionSpec(None))
out_sharding = x_sharding
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
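
Finally, the performance note in the warnings above (replicating a sharded parameter costs extra collectives) can be observed directly in the compiled HLO with plain JAX. The snippet is a hedged stand-in for the TE primitive: with_sharding_constraint plays the role of the enforced PartitionSpec(None), and an all-gather appears only when two or more devices are in the mesh.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.asarray(jax.devices()[:2]), ('dp',))
replicated = NamedSharding(mesh, PartitionSpec(None))

def scale(x, gamma):
    # Mimic the enforced replication of gamma inside the primitive's partition rule.
    gamma = jax.lax.with_sharding_constraint(gamma, replicated)
    return x * gamma

fn = jax.jit(scale, in_shardings=(NamedSharding(mesh, PartitionSpec('dp', None)),
                                  NamedSharding(mesh, PartitionSpec('dp'))))
hlo = fn.lower(jnp.ones((32, 1024)), jnp.ones((1024,))).compile().as_text()
print('all-gather' in hlo)   # expect True on a multi-device mesh: the replication shows up as a collective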
