[Pytorch] Check gradient in test numerics (#1229)
* update test numerics

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update test numerics

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update test numerics

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update tests/pytorch/test_numerics.py

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Paweł Gadziński <[email protected]>

* Test fixes

Signed-off-by: Pawel Gadzinski <[email protected]>

* Fixes for failing CI

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixes for failing CI

Signed-off-by: Pawel Gadzinski <[email protected]>

* Fix key

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* fixes

Signed-off-by: Pawel Gadzinski <[email protected]>

---------

Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Paweł Gadziński <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Pawel Gadzinski <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
5 people authored Oct 24, 2024
1 parent 7a5fd0c commit 7b284fe
Showing 1 changed file with 143 additions and 45 deletions.
188 changes: 143 additions & 45 deletions tests/pytorch/test_numerics.py
@@ -64,6 +64,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq


 model_configs = {
+    "small": ModelConfig(128, 1e-5, 8, 36, 4, 128),
     "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
 }

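For orientation, the positional values of the new "small" entry line up with the ModelConfig constructor visible in the hunk header above. A hedged reconstruction with named fields (the final parameter name is truncated to "seq" in the header, so "seq_len" is a guess):

```python
from dataclasses import dataclass

# Hypothetical spelled-out version of the new config entry; field names follow
# the constructor signature in the hunk header, with the truncated last name guessed.
@dataclass
class ModelConfigSketch:
    hidden_size: int
    eps: float
    num_attention_heads: int
    embed: int
    num_layers: int
    seq_len: int  # assumption: the header shows only "seq"

small = ModelConfigSketch(
    hidden_size=128, eps=1e-5, num_attention_heads=8, embed=36, num_layers=4, seq_len=128
)
```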
@@ -110,23 +111,30 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:


 def assert_allclose(
-    l1: List[torch.Tensor],
-    l2: List[torch.Tensor],
-    atol: float,
+    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
 ) -> bool:
     """Ensures two lists are equal."""
     assert len(l1) == len(l2), "Unequal number of outputs."
     for i, (t1, t2) in enumerate(zip(l1, l2)):
-        result = torch.allclose(t1, t2, atol=atol)
+        tols = dict(atol=atol)
+        if rtol is not None:
+            tols["rtol"] = rtol
+        result = torch.allclose(t1, t2, **tols)
         if not result:
-            diff = torch.abs(t1 - t2).flatten()
-            m = torch.argmax(diff)
-            msg = (
-                f"Outputs not close enough in tensor at idx={i}. "
-                f"Location of the maximum difference: {m.item()} "
-                f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} "
-                f"(diff {diff[m].item()})."
-            )
+            diff = torch.abs(t1 - t2)
+            tol = atol + (rtol * torch.abs(t2))
+            exceed_mask = diff > tol
+            if exceed_mask.any():
+                indices = torch.nonzero(exceed_mask, as_tuple=True)
+                max_diff = diff[exceed_mask].max()
+                max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
+                max_location = [idx[max_idx].item() for idx in indices]
+                msg = (
+                    f"Outputs not close enough in tensor at idx={i}. "
+                    f"Maximum difference at location {max_location} "
+                    f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
+                    f"(diff {max_diff.item()})."
+                )
             raise AssertionError(msg)

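The closeness rule the updated helper applies mirrors torch.allclose: two tensors agree wherever |t1 - t2| <= atol + rtol * |t2|. A minimal standalone sketch of that rule and of the worst-offender reporting (helper name hypothetical; unlike the committed failure path, it guards the rtol=None case explicitly):

```python
import torch

def max_violation(a: torch.Tensor, b: torch.Tensor, atol: float, rtol: float = None):
    """Return (worst_diff, location) where |a - b| exceeds atol + rtol*|b|, else None."""
    rtol = 0.0 if rtol is None else rtol  # assumption: missing rtol degenerates to a pure-atol bound
    diff = torch.abs(a - b)
    tol = atol + rtol * torch.abs(b)
    exceed = diff > tol
    if not exceed.any():
        return None
    # Zero out in-tolerance entries, then locate the worst offender.
    masked = torch.where(exceed, diff, torch.zeros_like(diff))
    flat_idx = torch.argmax(masked)
    location = [int(i) for i in torch.nonzero(masked == masked.flatten()[flat_idx])[0]]
    return masked.flatten()[flat_idx].item(), location
```

For example, max_violation(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 2.5]), atol=0.1, rtol=0.01) returns (0.5, [1]), the same diff/location pair the test helper formats into its assertion message.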
@@ -526,7 +534,7 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False

 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("model", ["126m"])
 @pytest.mark.parametrize("fp8", all_boolean)
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_params):
@@ -631,7 +639,7 @@ def _test_e2e_full_recompute(

 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("model", ["126m"])
 @pytest.mark.parametrize("fp8", all_boolean)
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 @pytest.mark.parametrize("use_reentrant", all_boolean)
@@ -764,7 +772,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=

 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("model", ["126m"])
 def test_gpt_checkpointing(dtype, bs, model):
     config = model_configs[model]
     outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
@@ -809,7 +817,7 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):

 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("model", ["small"])
 @pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
 def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
     config = model_configs[model]
@@ -868,11 +876,25 @@ def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
     te_outputs = _test_e2e_gpt_accuracy(te_gpt, bs, dtype, config)
     torch_outputs = _test_e2e_gpt_accuracy(torch_gpt, bs, dtype, config)

+    atol = {
+        torch.float32: 5e-3,
+        torch.half: 5e-2,
+        torch.bfloat16: 1e-1,
+    }
+
     # Check output.
-    if dtype == torch.float32:
-        assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
-    else:
-        assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
+    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
+
+    # Check gradients, only for small model
+    if model == "small":
+        atol[torch.float32] = 5e-2
+        rtol = {
+            torch.float32: 1e-2,
+            torch.half: 1e-2,
+            torch.bfloat16: 1e-2,
+        }
+        for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
+            assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


 def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
@@ -906,7 +928,7 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):

 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("model", ["small"])
 @pytest.mark.parametrize("mask_type", mask_types)
 def test_mha_accuracy(dtype, bs, model, mask_type):
     config = model_configs[model]
@@ -947,6 +969,21 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
     else:
         assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)

+    # Check gradients, only for small model
+    if model == "small":
+        atol = {
+            torch.float32: 5e-2,
+            torch.half: 5e-2,
+            torch.bfloat16: 5e-2,
+        }
+        rtol = {
+            torch.float32: 1e-2,
+            torch.half: 1e-2,
+            torch.bfloat16: 1e-2,
+        }
+        for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
+            assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
+

 def _test_granular_accuracy(block, bs, dtype, config):
     reset_rng_states()
@@ -1002,7 +1039,7 @@ def _test_dpa_accuracy(block, bs, dtype, config):

 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("model", ["126m"])
 def test_dpa_accuracy(dtype, bs, model):
     config = model_configs[model]

@@ -1034,10 +1071,13 @@ def test_dpa_accuracy(dtype, bs, model):
     else:
         assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)

+    for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
+        assert_allclose(te_output, torch_output, atol=5e-2, rtol=1e-2)
+

 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("model", ["small"])
 def test_linear_accuracy(dtype, bs, model):
     config = model_configs[model]

@@ -1066,15 +1106,20 @@ def test_linear_accuracy(dtype, bs, model):
     torch_outputs = _test_granular_accuracy(torch_linear, bs, dtype, config)

     # Check output.
-    if dtype == torch.float32:
-        assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
-    else:
-        assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
+    if model == "small":
+        tolerance = 5e-3 if dtype == torch.float32 else 5e-2
+        rtol = {
+            torch.float32: 1.3e-6,
+            torch.half: 1e-2,
+            torch.bfloat16: 2e-2,
+        }
+        for te_output, torch_output in zip(te_outputs, torch_outputs):
+            assert_allclose(te_output, torch_output, tolerance, rtol[dtype])


 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("model", ["126m"])
 @pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
@@ -1102,18 +1147,29 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
     te_outputs = _test_granular_accuracy(te_rmsnorm, bs, dtype, config)
     torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config)

-    # Check output.
     atol = {
         torch.float32: 1e-7,
         torch.half: 2e-3,
         torch.bfloat16: 2e-2,
     }
+
+    # Check output.
     assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])

+    atol[torch.float32] = 2e-3
+    rtol = {
+        torch.float32: 1.3e-6,
+        torch.half: 1e-3,
+        torch.bfloat16: 1.6e-2,
+    }
+    # Check gradients
+    for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
+        assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
+

 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("model", ["126m"])
 @pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
@@ -1142,18 +1198,29 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
     te_outputs = _test_granular_accuracy(te_layernorm, bs, dtype, config)
     torch_outputs = _test_granular_accuracy(torch_layernorm, bs, dtype, config)

-    # Check output.
     atol = {
         torch.float32: 1e-7,
         torch.half: 2e-3,
         torch.bfloat16: 2e-2,
     }
+
+    # Check output.
     assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])

+    rtol = {
+        torch.float32: 1.3e-6,
+        torch.half: 1e-3,
+        torch.bfloat16: 1.6e-2,
+    }
+    atol[torch.float32] = 1e-4
+    # Check gradients
+    for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
+        assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
+

 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("model", ["small"])
 @pytest.mark.parametrize("normalization", all_normalizations)
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma):
@@ -1195,18 +1262,34 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma):
     te_outputs = _test_granular_accuracy(te_ln_linear, bs, dtype, config)
     torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)

-    # Check output.
     atol = {
         torch.float32: 2.5e-4,
         torch.half: 2e-3,
         torch.bfloat16: 2e-2,
     }
+
+    # Check output.
     assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])

+    if model == "small":
+        atol = {
+            torch.float32: 1e-3,
+            torch.half: 5e-2,
+            torch.bfloat16: 5e-2,
+        }
+        rtol = {
+            torch.float32: 1e-3,
+            torch.half: 4e-2,
+            torch.bfloat16: 4e-2,
+        }
+        # Check gradients
+        for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
+            assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
+

 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("model", ["small"])
 @pytest.mark.parametrize("activation", all_activations)
 @pytest.mark.parametrize("normalization", all_normalizations)
 def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
@@ -1246,11 +1329,26 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
     te_outputs = _test_granular_accuracy(te_ln_mlp, bs, dtype, config)
     torch_outputs = _test_granular_accuracy(torch_ln_mlp, bs, dtype, config)

+    atol = {
+        torch.float32: 2e-2,
+        torch.half: 5e-2,
+        torch.bfloat16: 5e-2,
+    }
+
     # Check output.
-    if dtype == torch.float32:
-        assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
-    else:
-        assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
+    assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
+
+    # Check gradients, only for small model
+    rtol = {
+        torch.float32: 1e-3,
+        torch.half: 1e-2,
+        torch.bfloat16: 4e-2,
+    }
+    atol[torch.half] = 2e-1
+    atol[torch.bfloat16] = 2e-1
+    if model == "small":
+        for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
+            assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


 def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False):
@@ -1301,7 +1399,7 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("num_gemms", [3, 6])
 @pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("model", ["126m"])
 @pytest.mark.parametrize("fp8", all_boolean)
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 def test_grouped_linear_accuracy(
@@ -1361,7 +1459,7 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
         dtype=torch.float32,
         num_gemms=6,
         bs=2,
-        model=list(model_configs.keys())[0],
+        model="126m",
         fp8=True,
         fp8_model_params=True,
         parallel_mode=parallel_mode,
@@ -1374,7 +1472,7 @@ def test_grouped_linear_accuracy_single_gemm():
         dtype=torch.float32,
         num_gemms=1,
         bs=2,
-        model=list(model_configs.keys())[0],
+        model="126m",
         fp8=True,
         fp8_model_params=True,
     )
@@ -1475,7 +1573,7 @@ def _generate_random_numbers(n, total_sum):
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("num_gemms", [3, 6])
 @pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("model", ["126m"])
 @pytest.mark.parametrize("fp8", [True])
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 def test_padding_grouped_linear_accuracy(
@@ -1594,7 +1692,7 @@ def train_step():

 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("model", ["126m"])
 def test_gpt_cuda_graph(dtype, bs, model):
     config = model_configs[model]

@@ -1686,7 +1784,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params):

 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("model", ["126m"])
 def test_gpt_fp8_parameters(dtype, bs, model):
     if not fp8_available:
         pytest.skip(reason_for_no_fp8)
@@ -1710,7 +1808,7 @@ def test_gpt_fp8_parameters(dtype, bs, model):

 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("model", ["126m"])
 def test_transformer_layer_hidden_states_format(dtype, bs, model):
     config = model_configs[model]


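Taken together, the tests follow one pattern: the forward output is compared with an absolute tolerance only, while gradients (entries 1 and onward of the output lists, per the indexing in the diff) also get a relative tolerance, and gradient checks run only on the cheap "small" config where noted. A condensed sketch of that pattern (function name hypothetical; tolerance values illustrative, since the committed tests tune them per operation and dtype):

```python
import torch

# Illustrative per-dtype tolerance tables in the style used above.
ATOL = {torch.float32: 5e-3, torch.half: 5e-2, torch.bfloat16: 1e-1}
RTOL = {torch.float32: 1e-2, torch.half: 1e-2, torch.bfloat16: 1e-2}

def check_outputs_and_grads(te_outputs, ref_outputs, dtype, check_grads=True):
    # Assumed convention, inferred from te_outputs[0] / te_outputs[1:] in the
    # diff: index 0 is the forward output, the remaining entries are gradients.
    out_te, *grads_te = te_outputs
    out_ref, *grads_ref = ref_outputs
    assert torch.allclose(out_te, out_ref, atol=ATOL[dtype])  # output: atol only
    if check_grads:
        for g_te, g_ref in zip(grads_te, grads_ref):  # grads: atol + rtol
            assert torch.allclose(g_te, g_ref, atol=ATOL[dtype], rtol=RTOL[dtype])
```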