[PyTorch] Non-reentrant mode for activation recompute #670

Merged · 12 commits · Feb 24, 2024
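For context: PyTorch's activation checkpointing comes in two flavors, the original reentrant implementation (a custom autograd Function that reruns the forward pass during backward) and the newer non-reentrant one (built on saved-tensor hooks), selected with the use_reentrant flag. The sketch below uses plain torch.utils.checkpoint rather than Transformer Engine's te_checkpoint wrapper; the toy module and shapes are placeholders, shown only to illustrate the distinction this PR brings to TE's recompute path.

```python
# Minimal sketch of the two checkpointing modes in stock PyTorch.
# Uses torch.utils.checkpoint directly, not Transformer Engine's
# te_checkpoint; the toy block and tensor shapes are placeholders.
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.GELU(),
    torch.nn.Linear(64, 64),
).cuda()

x = torch.randn(8, 64, device="cuda", requires_grad=True)

# Reentrant mode: backward re-invokes the whole forward inside a custom
# autograd Function and needs at least one input that requires grad.
y = checkpoint(block, x, use_reentrant=True)
y.sum().backward()

# Non-reentrant mode: saved-tensor hooks recompute activations on demand,
# so inputs that do not require grad are also supported.
x2 = torch.randn(8, 64, device="cuda")  # no requires_grad needed here
y2 = checkpoint(block, x2, use_reentrant=False)
y2.sum().backward()
```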
67 changes: 49 additions & 18 deletions tests/pytorch/test_numerics.py
@@ -4,6 +4,7 @@

import math
import os
import sys
from typing import List, Optional
import pytest
import copy
@@ -72,22 +73,27 @@ def get_causal_attn_mask(sq: int) -> torch.Tensor:
return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()


def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor]) -> bool:
def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool:
"""Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs."
for t1, t2 in zip(l1, l2):
assert torch.equal(t1, t2), "Output mismatch."
failed = False
failed_tensors = ""
for i, (t1, t2) in enumerate(zip(l1, l2)):
if not torch.equal(t1, t2):
failed = True
failed_tensors += f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n"
assert not failed, "Output mismatches in:\n" + failed_tensors


def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float) -> bool:
"""Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs."
for t1, t2 in zip(l1, l2):
for i, (t1, t2) in enumerate(zip(l1, l2)):
result = torch.allclose(t1, t2, atol=atol)
if not result:
diff = torch.abs(t1 - t2).flatten()
m = torch.argmax(diff)
msg = (f"Outputs not close enough."
msg = (f"Outputs not close enough in tensor at idx={i}. "
f"Location of the maximum difference: {m.item()} "
f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} "
f"(diff {diff[m].item()})."
@@ -457,7 +463,12 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par
assert_all_equal(outputs, outputs_recompute)


def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
def _test_e2e_full_recompute(
bs, dtype, config, fp8,
fp8_model_params=False,
recompute=False,
use_reentrant=True
):
reset_rng_states()
FP8GlobalStateManager.reset()

@@ -494,21 +505,23 @@ def get_dummy_cuda_rng_tracker():
)

te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=use_reentrant
).cuda()
te_inp_hidden_states.retain_grad()
if use_reentrant:
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)

with fp8_autocast(enabled=fp8):
if recompute:
te_out = te_checkpoint(
block,
False, # distribute_saved_activations
get_dummy_cuda_rng_tracker,
None, # tp_group
te_inp_hidden_states,
attention_mask=te_inp_attn_mask,
checkpoint_core_attention=False,
distribute_saved_activations=False,
tp_group=None,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
use_reentrant=use_reentrant,
)
else:
te_out = block(
@@ -520,27 +533,45 @@ def get_dummy_cuda_rng_tracker():
loss.backward()
torch.cuda.synchronize()

outputs = [te_out, te_inp_hidden_states.grad]
for p in block.parameters():
outputs = [te_out]
names = ["output"]
if use_reentrant:
outputs.append(te_inp_hidden_states.grad)
names.append("input")
for name, p in block.named_parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
names.append(name)

return outputs, names


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params):
@pytest.mark.parametrize("use_reentrant", all_boolean)
def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params, use_reentrant):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)

config = model_configs[model]

outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False)
outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True)
assert_all_equal(outputs, outputs_recompute)
if not use_reentrant:
# Non-reentrant checkpoint becomes non-deterministic with bias+GELU fusion
os.environ["NVTE_BIAS_GELU_NVFUSION"] = "0"

outputs, names = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params,
recompute=False, use_reentrant=use_reentrant)
outputs_recompute, _ = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params,
recompute=True, use_reentrant=use_reentrant)

if not use_reentrant:
# Reset bias+GELU fusion flag to avoid contaminating other tests
del os.environ["NVTE_BIAS_GELU_NVFUSION"]

assert_all_equal(outputs, outputs_recompute, names=names)


def _test_e2e_checkpointing_get_model(config, dtype):
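A side note on the NVTE_BIAS_GELU_NVFUSION handling in the test above: the variable is set and later deleted by hand, so an assertion failure in between would leak the setting into subsequent tests. A possible alternative, not something this PR implements, is pytest's built-in monkeypatch fixture, which rolls back environment changes at teardown whether the test passes or fails:

```python
# Sketch only: pytest's monkeypatch fixture scopes an environment change
# to a single test, so NVTE_BIAS_GELU_NVFUSION=0 cannot leak into later
# tests even if an assertion fails before manual cleanup would run.
import os

def test_env_flag_is_scoped(monkeypatch):
    monkeypatch.setenv("NVTE_BIAS_GELU_NVFUSION", "0")
    assert os.environ["NVTE_BIAS_GELU_NVFUSION"] == "0"
    # monkeypatch restores the previous value (or unsets it) at teardown.
```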
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/attention.py
@@ -2472,9 +2472,9 @@ def custom_forward(*input_args, **input_kwargs):

hidden_states = checkpoint(
custom_forward,
False,
self.get_rng_state_tracker,
self.tp_group,
distribute_saved_activations=False,
get_rng_state_tracker=self.get_rng_state_tracker,
tp_group=self.tp_group,
*forward_args,
**forward_kwargs,
)
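Taken together with the test changes, the call sites above show the checkpoint wrapper now accepting distribute_saved_activations, get_rng_state_tracker, and tp_group as keyword arguments, with use_reentrant selecting the new mode. A hedged usage sketch follows; the keyword names are taken from this diff, while the te.checkpoint import path and the TransformerLayer configuration are illustrative assumptions:

```python
# Usage sketch based on the call sites changed in this PR; the layer
# configuration and import path are illustrative assumptions.
import torch
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
).cuda()

inp = torch.randn(128, 2, 1024, device="cuda", requires_grad=True)

out = te.checkpoint(
    layer,
    inp,
    distribute_saved_activations=False,
    tp_group=None,
    get_rng_state_tracker=None,
    use_reentrant=False,  # the non-reentrant mode added by this PR
)
out.sum().backward()
```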