diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 9b291e6d0a..c9504c20af 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -41,4 +41,6 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.onnx_export +.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables + .. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 0b94a8b77e..50f54cd714 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -9,9 +9,10 @@ set -e pip install pytest==6.2.5 onnxruntime==1.13.1 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py -PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py pytest -v -s $TE_PATH/tests/pytorch/test_jit.py -pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py +NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 65c3b8269b..b2c8f69ef3 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -5,7 +5,6 @@ import functools from importlib.metadata import version import os -import math from typing import Any, Dict, List, Tuple, Union from pkg_resources import packaging @@ -28,15 +27,9 @@ fused_attn_bwd, fused_attn_fwd, ) -from transformer_engine.pytorch.distributed import ( - _set_cuda_rng_state, - CudaRNGStatesTracker, -) +from transformer_engine.pytorch.distributed import CudaRNGStatesTracker import transformer_engine.pytorch.fp8 as fp8 -from transformer_engine.pytorch.module.base import ( - TransformerEngineBaseModule, - _prepare_backward, -) +from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.utils import ( get_device_compute_capability, init_method_normal, @@ -58,10 +51,18 @@ _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) + def reset_rng_states() -> None: """Revert back to initial RNG state""" torch.set_rng_state(_cpu_rng_state) - _set_cuda_rng_state(_cuda_rng_state) + torch.cuda.set_rng_state(_cuda_rng_state) + + +@pytest.fixture(autouse=True) +def reset_global_fp8_state(): + yield + fp8.FP8GlobalStateManager.reset() + @functools.cache def _cudnn_version() -> Tuple[int, int, int]: @@ -71,6 +72,7 @@ def _cudnn_version() -> Tuple[int, int, int]: minor, patch = divmod(encoded_version, 100) return (major, minor, patch) + class ModelConfig: def __init__( self, @@ -103,6 +105,7 @@ def __init__( self.num_layers = num_layers self.bias_shape = bias_shape + def _is_fused_attention_supported( config: ModelConfig, dtype: torch.dtype, @@ -151,24 +154,28 @@ def _is_fused_attention_supported( return True, backends return False, backends + @functools.cache def _is_flash_attention_2_available() -> bool: """Check if flash-attn 2.0+ is available""" Version = packaging.version.Version return Version(version("flash-attn")) >= 
Version("2") + @functools.cache def _is_flash_attention_2_1() -> bool: """Check if flash-attn 2.1+ is available""" Version = packaging.version.Version return Version(version("flash-attn")) >= Version("2.1") + @functools.cache def _is_flash_attention_2_3() -> bool: """Check if flash-attn 2.3+ is available""" Version = packaging.version.Version return Version(version("flash-attn")) >= Version("2.3") + def _is_flash_attention_supported(config: ModelConfig) -> bool: """Check if FlashAttention supports a model configuration""" if get_device_compute_capability() < (8, 0): @@ -184,6 +191,7 @@ def _is_flash_attention_supported(config: ModelConfig) -> bool: return False return True + def _is_unfused_attention_supported(config: ModelConfig) -> bool: """Check if UnfusedDotProductAttention supports a model configuration""" if ("padding" in config.attn_mask_type): @@ -192,6 +200,7 @@ def _is_unfused_attention_supported(config: ModelConfig) -> bool: return False return True + model_configs_base = { # test: b, h, hg, d, sq, skv, p, mask, bias # attn , backend "base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0 @@ -200,11 +209,13 @@ def _is_unfused_attention_supported(config: ModelConfig) -> bool: "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1 } + param_types = [torch.float16] if is_bf16_compatible(): # bf16 requires sm_80 or higher param_types.append(torch.bfloat16) param_types_lean = [torch.bfloat16] + def get_swa(seq_q, seq_kv, w=None): """Generate a random sliding window size (left, right) if w is None, and create its equivalent attention mask in [seq_q, seq_kv] shape""" @@ -216,6 +227,7 @@ def get_swa(seq_q, seq_kv, w=None): ml = ~ ml return w, ml + @pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("model_configs", [model_configs_base]) @@ -313,6 +325,7 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace for i,_ in enumerate(fused_attn_bwd): torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols) + @pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("model_configs", [model_configs_base]) @@ -321,6 +334,7 @@ def test_dpa_checkpoint(dtype, model_configs, model): """Test DotProductAttention module with checkpointing""" test_dot_product_attention(dtype, model_configs, model, True, True, None, False) + model_configs_mask = { # test: b, h, hg, d, sq, skv, p, mask, bias "mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"), @@ -337,6 +351,7 @@ def test_dpa_checkpoint(dtype, model_configs, model): "mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), } + @pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_mask]) @@ -345,6 +360,7 @@ def test_dpa_mask(dtype, model_configs, model): """Test DotProductAttention module with different mask types""" test_dot_product_attention(dtype, model_configs, model, False, True, None, False) + model_configs_bias = { # test: b, h, hg, d, sq, skv, p, mask, bias "bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"), @@ -373,6 +389,7 @@ def test_dpa_mask(dtype, model_configs, model): "bias_4_5": ModelConfig(2, 24, 24, 128, 
2048, 4096, 0.0, "padding_causal", "alibi"), # skipped } + @pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_bias]) @@ -381,6 +398,7 @@ def test_dpa_bias(dtype, model_configs, model): """Test DotProductAttention module with different bias types""" test_dot_product_attention(dtype, model_configs, model, False, True, None, False) + model_configs_bias_shapes = { # test: b, h, hg, d, sq, skv, p, "bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, @@ -398,6 +416,7 @@ def test_dpa_bias(dtype, model_configs, model): "causal", "alibi", bias_shape='bhss', alibi_type='custom'), } + @pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_bias_shapes]) @@ -413,6 +432,8 @@ def test_dpa_bias_shapes(dtype, model_configs, model): "swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), } + + @pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_swa]) @@ -428,6 +449,8 @@ def test_dpa_sliding_window(dtype, model_configs, model): "alibi_2_0": ModelConfig(2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type= "custom"), "alibi_2_1": ModelConfig(1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type= "custom"), } + + @pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes]) @@ -436,6 +459,7 @@ def test_dpa_alibi_slopes(dtype, model_configs, model): """Test DotProductAttention module with ALiBi slopes""" test_dot_product_attention(dtype, model_configs, model, False, True, None, False) + qkv_layouts = [ 'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd', 'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd', @@ -443,6 +467,7 @@ def test_dpa_alibi_slopes(dtype, model_configs, model): #'t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd', ] + model_configs_layout = { # test: b, h, hg, d, sq, skv, p, mask, bias "layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), @@ -455,6 +480,7 @@ def test_dpa_alibi_slopes(dtype, model_configs, model): "layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"), } + @pytest.mark.skipif(_cudnn_version() < (8,9,5), reason="cuDNN 8.9.5+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_layout]) @@ -464,6 +490,7 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): """Test DotProductAttention module with different QKV layouts""" test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False) + def _run_dot_product_attention( dtype: torch.dtype, config: ModelConfig, @@ -646,6 +673,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: return out, (inp[0].grad, inp[1].grad, inp[2].grad) + model_configs_te_layer = { # test: b, h, hg, d, sq, skv, p, mask, bias "te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"), @@ -658,6 +686,7 @@ def get_dummy_cuda_rng_tracker() -> 
CudaRNGStatesTracker: "te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"), } + @pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("model_configs", [model_configs_te_layer]) @@ -742,6 +771,7 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols) torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols) + @pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_te_layer]) @@ -755,6 +785,7 @@ def test_te_layer_misc(dtype, model_configs, model, qkv_format): test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE) + @pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_te_layer]) @@ -780,6 +811,7 @@ def find_factors(x): test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE) + def _run_transformer_layer( dtype: torch.dtype, config: ModelConfig, @@ -912,8 +944,10 @@ def _run_transformer_layer( "fp8_1": ModelConfig(1, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), "fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), } + param_types_fp8 = [torch.float16] + @pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.") @@ -946,6 +980,7 @@ def test_dpa_fp8(dtype, model): torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols) + def _run_dpa_fp8(dtype, config, backend): """Run FusedAttention FP8 backend, i.e. fused_attn_fwd/bwd_qkvpacked from cpp_extensions""" @@ -989,6 +1024,7 @@ def _run_dpa_fp8(dtype, config, backend): dqkv.view(config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim).transpose(0,1).contiguous()) + def _run_dpa_fp8_ref(dtype, config, backend): """Run UnfusedDotProductAttention as a reference, i.e. plain PyTorch implementation in FP16 and inputs/outputs @@ -1188,8 +1224,7 @@ def forward( def backward( ctx, grad_output: torch.Tensor ) -> Tuple[Union[torch.Tensor, None], ...]: - - with _prepare_backward(True, ctx.fp8_meta, None, 1, name="_DPA"): + with torch.cuda.nvtx.range("_DPA"): ( inputmat_t, qkv_weight_t_fp8, @@ -1298,6 +1333,7 @@ def backward( None, None) + class DPA_FP8(TransformerEngineBaseModule): def __init__( self, diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py new file mode 100644 index 0000000000..2b1dcb3aa3 --- /dev/null +++ b/tests/pytorch/test_cuda_graphs.py @@ -0,0 +1,215 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
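+
+# Sanity tests for make_graphed_callables: each case runs a short training
+# loop eagerly and with CUDA graphs (capturing either the whole module stack
+# or each module individually) and checks that outputs, parameters, and
+# gradients match exactly between the eager and graphed runs.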
+ +from typing import List, Tuple +import pytest + +import torch +from transformer_engine.pytorch import ( + DotProductAttention, LayerNormLinear, LayerNormMLP, Linear, make_graphed_callables, + MultiheadAttention, TransformerLayer, fp8_autocast, fp8_model_init, +) +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.utils import is_bf16_compatible + + +# Only run FP8 tests on H100. +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + + +seed = 1234 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) +# Record initial RNG state from script run. +_cpu_rng_state = torch.get_rng_state() +_cuda_rng_state = torch.cuda.get_rng_state() + + +class ModelConfig: + def __init__(self, hidden_size, nheads, kv, seq_len): + self.h = hidden_size + self.nheads = nheads + self.kv = kv + self.s = seq_len + +model_configs = { + "small": ModelConfig(64, 2, 32, 32), +} + +modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"] + +optimizers = [torch.optim.SGD, torch.optim.Adam] + +all_boolean = [True, False] + +dtypes = [torch.float32, torch.float16] +if is_bf16_compatible(): # bf16 requires sm_80 or higher + dtypes.append(torch.bfloat16) + + +def reset_rng_states() -> None: + """revert back to initial RNG state.""" + torch.set_rng_state(_cpu_rng_state) + torch.cuda.set_rng_state(_cuda_rng_state) + + +@pytest.fixture(autouse=True) +def reset_global_fp8_state(): + yield + FP8GlobalStateManager.reset() + + +def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool: + """Ensures two lists are equal.""" + assert len(l1) == len(l2), "Unequal number of outputs." + failed = False + failed_tensors = "" + for i, (t1, t2) in enumerate(zip(l1, l2)): + with torch.no_grad(): + t1.masked_fill_(t1.isnan(), 1.0) + t2.masked_fill_(t2.isnan(), 1.0) + if not torch.equal(t1, t2): + failed = True + failed_tensors += f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n" + assert not failed, "Output mismatches in:\n" + failed_tensors + + +def generate_data( + s: int, b: int, h: int, nheads: int, kv: int, dtype: torch.dtype, + dpa: bool = False, warmup: bool = False, gen_labels: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Generate synthetic data.""" + gen_func = torch.ones if warmup else torch.randn + if dpa: + inputs = [gen_func(s, b, nheads, kv, device="cuda", requires_grad=True, dtype=dtype) for _ in range(3)] + else: + inputs = [gen_func(s, b, h, device="cuda", requires_grad=True, dtype=dtype)] + + if not gen_labels: + return inputs + + target = torch.randn(s, b, h, device="cuda", dtype=dtype) + return inputs, target + + +def get_outputs(model, output): + """Return grads and params for comparsion.""" + values = [] + for param in model.parameters(): + values.append(param) + if param.grad is not None: + values.append(param.grad) + values.append(output) + return values + + +def _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, graph, module, optimizer, graph_mode=""): + """Helper function for test.""" + reset_rng_states() + FP8GlobalStateManager.reset() + dpa = module == "dpa" + + with fp8_model_init(enabled=fp8_params): + # Create modules. 
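+        # Build num_layers instances of the requested TE module, with dropout
+        # disabled and parameters in the requested dtype (DPA, which has no
+        # parameters of its own, is limited to a single instance).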
+ if module == "transformer": + modules = [TransformerLayer( + config.h, + config.h, + config.nheads, + hidden_dropout=0.0, + attention_dropout=0.0, + fuse_qkv_params=True, + params_dtype=dtype, + ) for _ in range(num_layers)] + elif module == "layernorm_mlp": + modules = [LayerNormMLP( + config.h, config.h, params_dtype=dtype + ) for _ in range(num_layers)] + elif module == "layernorm_linear": + modules = [LayerNormLinear( + config.h, config.h, params_dtype=dtype + ) for _ in range(num_layers)] + elif module == "mha": + modules = [MultiheadAttention( + config.h, + config.nheads, + attention_dropout=0.0, + params_dtype=dtype, + fuse_qkv_params=True, + ) for _ in range(num_layers)] + elif dpa: + assert config.h % config.nheads == 0, "Err." + assert num_layers == 1, "Err." + modules = [DotProductAttention( + config.nheads, config.kv, attention_dropout=0.0 + ) for _ in range(num_layers)] + else: + modules = [Linear( + config.h, config.h, device="cuda", params_dtype=dtype + ) for _ in range(num_layers)] + + # Generate model and wrap API to return graphed version. + if graph: + # Graph entire module at once. + if graph_mode == "full": + model = modules[0] if dpa else torch.nn.Sequential(*modules) + model = make_graphed_callables( + model, + generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, warmup=True), + num_warmup_iters=10, + fp8_enabled=fp8) + else: + modules = [make_graphed_callables( + module, + generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, warmup=True), + num_warmup_iters=10, + fp8_enabled=fp8) for module in modules] + model = modules[0] if dpa else torch.nn.Sequential(*modules) + else: + model = modules[0] if dpa else torch.nn.Sequential(*modules) + + # Loss function and optimizer. + loss_fn = torch.nn.MSELoss() + if not dpa: + optimizer = optimizer(model.parameters(), lr=0.001) + + # Launch. 
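+    # Short training loop: forward under fp8_autocast, backward, and an
+    # optimizer step (skipped for DPA, which has nothing to optimize).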
+ for _ in range(10): + inputs, target = generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, gen_labels=True) + with fp8_autocast(enabled=fp8): + output = model(*inputs) + loss = loss_fn(output, target) + loss.backward() + if not dpa: + optimizer.step() + optimizer.zero_grad() + + return get_outputs(model, output) + + +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("bs", [1, 2]) +@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("num_layers", [1, 10]) +@pytest.mark.parametrize("fp8", all_boolean) +@pytest.mark.parametrize("fp8_params", all_boolean) +@pytest.mark.parametrize("module", modules) +@pytest.mark.parametrize("optimizer", optimizers) +def test_gpt_make_graphed_callables(dtype, bs, model, num_layers, fp8, fp8_params, module, optimizer): + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8_params and not fp8: + pytest.skip("FP8 needed for FP8 parameters.") + if module == "dpa" and num_layers > 1: + pytest.skip("Max 1 layer for DPA.") + + config = model_configs[model] + + outputs = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, False, module, optimizer) + graph_outputs_mode1 = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, True, module, optimizer, graph_mode="full") + graph_outputs_mode2 = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, True, module, optimizer, graph_mode="individual") + + # Check that results match + assert_all_equal(outputs, graph_outputs_mode1) + assert_all_equal(outputs, graph_outputs_mode2) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 935519ca84..c4c39f9309 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -257,12 +257,10 @@ def test_inplace_ops( with pytest.raises(AssertionError): torch.testing.assert_close(x_fp8, x_ref, **tols) - @pytest.mark.parametrize("dims", [[33, 41], [5, 7, 11]]) - @pytest.mark.parametrize("transpose_dims", [(0, 1), (-2, -1), (0, 0)]) + @pytest.mark.parametrize("dims", [[33, 41], [7, 11]]) def test_transpose( self, dims: DimsType, - transpose_dims: Tuple[int, int], fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, scale: float = 0.5, dtype: torch.dtype = torch.float32, @@ -271,74 +269,44 @@ def test_transpose( # Initialize random data dims = _to_list(dims) - x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 x_fp8 = Float8Tensor.to_float8( - x_ref, + x, fp8_dtype=fp8_dtype, scale=torch.full([1], scale), ) - x_ref = x_fp8.from_float8() + x = x_fp8.from_float8() # Perform transpose - y_fp8 = x_fp8.transpose(*transpose_dims) - y_ref = x_ref.transpose(*transpose_dims) + x_fp8_t = x_fp8.transpose_2d() + x_t = x.transpose(0, 1) + x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8_t) # Check results tols = dict(rtol=0, atol=0) - torch.testing.assert_close(y_fp8, y_ref, **tols) + torch.testing.assert_close(x_fp8_t, x_t, **tols) # Make sure we are not trivially passing the test - if transpose_dims[0] != transpose_dims[1]: - with pytest.raises(AssertionError): - torch.testing.assert_close( - y_fp8, - x_ref, - **tols, - ) - - # Check transpose caching - if x_fp8.dim() == 2 and transpose_dims[0] != transpose_dims[1]: - - # Check that cached transpose is returned when expected - # Note: Sneakily destroy data so that recalculating - # transpose would give wrong answer. 
- x_fp8 += 0.5 - x_ref = x_fp8.from_float8() - torch.testing.assert_close( - x_fp8.transpose(*transpose_dims, update_cache="lazy"), - x_ref.transpose(*transpose_dims), - **tols, - ) - x_fp8_data = x_fp8._data.clone() - x_fp8._data.zero_() - torch.testing.assert_close( - x_fp8.transpose(*transpose_dims), - x_ref.transpose(*transpose_dims), - **tols, - ) - torch.testing.assert_close( - x_fp8.transpose(*transpose_dims, update_cache="lazy"), - x_ref.transpose(*transpose_dims), - **tols, - ) - torch.testing.assert_close( - x_fp8.transpose(*transpose_dims, update_cache="force"), - torch.zeros_like(x_ref.transpose(*transpose_dims)), - rtol=0, - atol=0, - ) - x_fp8._data.copy_(x_fp8_data) - x_fp8._reset_caches() - - # Make sure cache is reset after in-place operation - x_fp8.transpose(*transpose_dims, update_cache="force") - x_fp8 += 0.5 - x_ref = x_fp8.from_float8() - torch.testing.assert_close( - x_fp8.transpose(*transpose_dims), - x_ref.transpose(*transpose_dims), - **tols, - ) + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8_t, x, **tols) + + # Caching test. + assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching." + x_fp8 += 0.5 + x = x_fp8.from_float8() + x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(cache=True)) + x_t = x.transpose(0, 1) + torch.testing.assert_close(x_fp8_t, x_t, **tols) + assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." + + # Inplace update test. + x_fp8 += 0.5 + assert x_fp8._transpose_invalid, "Transpose cache not invalidated properly." + x = x_fp8.from_float8() + x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(cache=True)) + x_t = x.transpose(0, 1) + torch.testing.assert_close(x_fp8_t, x_t, **tols) + assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." def test_serialization( self, diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index c2eb2c01a5..ddb3ecf49f 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -4,7 +4,6 @@ import math import os -import sys from typing import List, Optional import pytest import copy @@ -25,7 +24,6 @@ MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm, InferenceParams ) from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint -from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker # Only run FP8 tests on H100. 
@@ -54,6 +52,14 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048), } +model_configs_inference = { + # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len + "126m": ModelConfig(768, 1e-5, 12, 64, 12, 16), +} +backends_inference = ["FlashAttention", "UnfusedAttention"] +module_inference = ["TransformerLayer", "MultiheadAttention"] +input_formats_inference = ["sbhd", "bshd"] + param_types = [torch.float32, torch.float16] if is_bf16_compatible(): # bf16 requires sm_80 or higher param_types.append(torch.bfloat16) @@ -104,7 +110,13 @@ def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float) def reset_rng_states() -> None: """revert back to initial RNG state.""" torch.set_rng_state(_cpu_rng_state) - _set_cuda_rng_state(_cuda_rng_state) + torch.cuda.set_rng_state(_cuda_rng_state) + + +@pytest.fixture(autouse=True) +def reset_global_fp8_state(): + yield + FP8GlobalStateManager.reset() class TorchScaledMaskedSoftmax(nn.Module): @@ -373,10 +385,10 @@ def __init__(self, hidden_size: int, eps: float, num_attention_heads: int, paral def forward( self, x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: a = self.ln(x) - b = self.causal_attn(a, attn_mask) + b = self.causal_attn(a, attention_mask) if self.parallel_attention_mlp: n = self.ln_mlp(x) x = x + nn.functional.dropout(b + n, p=0.1, training=self.training) @@ -396,13 +408,6 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() - _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) - - def get_dummy_cuda_rng_tracker(): - """Get cuda rng tracker.""" - return _DUMMY_CUDA_RNG_STATE_TRACKER - with fp8_model_init(enabled=fp8 and fp8_model_params): block = ( TransformerLayer( @@ -417,7 +422,6 @@ def get_dummy_cuda_rng_tracker(): kv_channels=config.embed, apply_residual_connection_post_layernorm=False, output_layernorm=False, - get_rng_state_tracker=get_dummy_cuda_rng_tracker, params_dtype=dtype, fuse_qkv_params=True, ) @@ -476,13 +480,6 @@ def _test_e2e_full_recompute( init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() - _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) - - def get_dummy_cuda_rng_tracker(): - """Get cuda rng tracker.""" - return _DUMMY_CUDA_RNG_STATE_TRACKER - with fp8_model_init(enabled=fp8 and fp8_model_params): block = ( TransformerLayer( @@ -497,7 +494,6 @@ def get_dummy_cuda_rng_tracker(): kv_channels=config.embed, apply_residual_connection_post_layernorm=False, output_layernorm=False, - get_rng_state_tracker=get_dummy_cuda_rng_tracker, params_dtype=dtype, fuse_qkv_params=True, ) @@ -520,7 +516,6 @@ def get_dummy_cuda_rng_tracker(): checkpoint_core_attention=False, distribute_saved_activations=False, tp_group=None, - get_rng_state_tracker=get_dummy_cuda_rng_tracker, use_reentrant=use_reentrant, ) else: @@ -683,7 +678,7 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): inp_hidden_states.retain_grad() inp_attn_mask = get_causal_attn_mask(config.seq_len) - out = block(inp_hidden_states, inp_attn_mask) + out = block(inp_hidden_states, attention_mask=inp_attn_mask) loss = out.sum() 
loss.backward() @@ -1261,13 +1256,6 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() - _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) - - def get_dummy_cuda_rng_tracker(): - """Get cuda rng tracker.""" - return _DUMMY_CUDA_RNG_STATE_TRACKER - with fp8_model_init(enabled=fp8_model_params): block = ( TransformerLayer( @@ -1282,7 +1270,6 @@ def get_dummy_cuda_rng_tracker(): kv_channels=config.embed, apply_residual_connection_post_layernorm=False, output_layernorm=False, - get_rng_state_tracker=get_dummy_cuda_rng_tracker, params_dtype=dtype, fuse_qkv_params=True, ) @@ -1321,6 +1308,7 @@ def test_gpt_fp8_parameters(dtype, bs, model): outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True) assert_all_equal(outputs, outputs_fp8_params) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) @@ -1399,14 +1387,6 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): assert_all_equal([y_bshd], [y_sbhd.transpose(0,1).contiguous()]) -model_configs_inference = { - # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len - "126m": ModelConfig(768, 1e-5, 12, 64, 12, 16), -} -backends_inference = ["FlashAttention", "UnfusedAttention"] -module_inference = ["TransformerLayer", "MultiheadAttention"] -input_formats_inference = ["sbhd", "bshd"] - @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model_key", model_configs_inference.keys()) diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 822b1450ec..7707264c7f 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -86,6 +86,12 @@ def set_max_seq_len(max_seq_len=128): os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{max_seq_len}" +@pytest.fixture(autouse=True) +def reset_global_fp8_state(): + yield + FP8GlobalStateManager.reset() + + def create_fp8_recipe(): return recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 217eacc9b3..e91e464fa4 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -48,6 +48,7 @@ def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor: """Custom func to test recipe.""" return torch.min(amax_history, dim=0).values + @dataclass class ModelConfig: """Transformer model configuration""" @@ -115,6 +116,12 @@ def _disable_wgrads(block): p.requires_grad = False +@pytest.fixture(autouse=True) +def reset_global_fp8_state(): + yield + FP8GlobalStateManager.reset() + + def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): # Initialize loss function and optimizer. 
loss_fn = torch.nn.MSELoss() @@ -137,7 +144,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): with torch.cuda.stream(s): for _ in range(3): optimizer.zero_grad(set_to_none=True) - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True): out = block(static_input) loss = loss_fn(out, static_target) loss.backward() @@ -148,7 +155,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): g = torch.cuda.CUDAGraph() optimizer.zero_grad(set_to_none=True) with torch.cuda.graph(g): - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True): static_output = block(static_input) static_loss = loss_fn(static_output, static_target) static_loss.backward() diff --git a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h new file mode 100644 index 0000000000..f9097679a6 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h @@ -0,0 +1,35 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file transpose_with_noop.h + * \brief Functions handling transposes with no-op. + */ + +#ifndef TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_ +#define TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void nvte_transpose_with_noop(const NVTETensor input, + const NVTETensor noop, + NVTETensor output, + cudaStream_t stream); + +void nvte_cast_transpose_with_noop(const NVTETensor input, + const NVTETensor noop, + NVTETensor cast_output, + NVTETensor transposed_output, + cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_ diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index ddb64be5e7..49cc9af914 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -56,6 +56,45 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_his float margin, cudaStream_t stream); + +/*! \brief Bulk-update FP8 scaling factors with delayed scaling recipe after amax reduction. + * + * Operations performed include, updating the most recent amax history + * with the relevant segment of global reduction buffer if it's not 0, + * rotating the amax history based on the rule below, and updating the + * scales and scale_invs. + * + * The amax history is rotated by -1 (e.g. the first entry shifts to + * the last, the last entry shifts to the second to last) and the + * first entry is set to zero. The scaling factor is estimated so the + * FP8 tensor's maximum absolute value is + * @f$ 2^{-\text{margin}} \text{max}_\text{fp8\_dtype} @f$. + * + * \param[in] amax_reduction_buffer The contiguous buffer used for amax reduction. + * Shape: [num_scales * num_tensors] + * \param[in,out] amax_histories List of amax histories of maximum absolute values. 
+ * Shape: num_tensors x [history_length, num_scales] + * \param[in,out] scales List of scaling factors for casting to FP8. + * Shape: num_tensors x [num_scales] + * \param[in,out] scale_invs List of scaling factors for casting from FP8. + * Shape: num_tensors x [num_scales] + * \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and + * "most_recent". + * \param[in] fp8_dtype FP8 datatype. + * \param[in] margin Scaling factor margin. + * \param[in] stream CUDA stream. + */ +void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( + const NVTETensor amax_reduction_buffer, + std::vector amax_histories, + std::vector scales, + std::vector scale_invs, + const char *amax_compute_algo, + NVTEDType fp8_dtype, + float margin, + cudaStream_t stream); + + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/layer_norm/ln_api.cpp b/transformer_engine/common/layer_norm/ln_api.cpp index f5eb1896c4..7a01cf0345 100644 --- a/transformer_engine/common/layer_norm/ln_api.cpp +++ b/transformer_engine/common/layer_norm/ln_api.cpp @@ -229,19 +229,29 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size // Query the kernel-specific launch parameters. launcher(launch_params, true); + if (launch_params.workspace_bytes == 0) { + launch_params.workspace_bytes = 1; + } + if (workspace->data.dptr == nullptr) { NVTE_CHECK(barrier->data.dptr == nullptr); workspace->data.dtype = layer_norm::DType::kByte; - if (launch_params.workspace_bytes == 0) { - launch_params.workspace_bytes = 1; - } workspace->data.shape = { launch_params.workspace_bytes }; barrier->data.dtype = layer_norm::DType::kInt32; barrier->data.shape = { launch_params.barrier_size }; return; + } else { + NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte); + NVTE_CHECK(workspace->data.shape == std::vector{ launch_params.workspace_bytes }); + } + + if (launch_params.barrier_size > 0) { + NVTE_CHECK(barrier->data.dptr != nullptr); + NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32); + NVTE_CHECK(barrier->data.shape == std::vector{ launch_params.barrier_size }); } // Tensor checks are delayed here in order to recover workspace sizes with null data @@ -368,6 +378,27 @@ void layernorm_bwd(const Tensor& dz, barrier->data.shape = { launch_params.barrier_size }; return; + } else { + NVTE_CHECK(dbeta_part->data.dptr != nullptr); + auto pdw_shape = std::vector{ + static_cast(launch_params.params.ctas_per_col), hidden_size}; + + NVTE_CHECK(dgamma_part->data.dtype == ctype); + NVTE_CHECK(dgamma_part->data.shape == pdw_shape); + NVTE_CHECK(dbeta_part->data.dtype == ctype); + NVTE_CHECK(dbeta_part->data.shape == pdw_shape); + } + + if (launch_params.barrier_size > 0) { + NVTE_CHECK(barrier->data.dptr != nullptr); + NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32); + NVTE_CHECK(barrier->data.shape == std::vector{ launch_params.barrier_size }); + } + + if (launch_params.workspace_bytes > 0) { + NVTE_CHECK(workspace->data.dptr != nullptr); + NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte); + NVTE_CHECK(workspace->data.shape == std::vector{ launch_params.workspace_bytes }); } // Tensor checks are delayed here in order to recover workspace sizes with null data diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 55a706492f..9abbb69cbe 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -133,3 +133,13 @@ def __post_init__(self) -> None: 
(False, False, False), (False, False, True), ), "Only wgrad GEMM override is currently supported." + + def __repr__(self) -> str: + return ( + f"margin={self.margin}, " + f"interval={self.interval}, " + f"format={str(self.fp8_format).split('.')[1]}, " + f"amax_history_len={self.amax_history_len}, " + f"wgrad_override={self.override_linear_precision.wgrad}, " + f"reduce_amax={self.reduce_amax}" + ) diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index 3fa64920df..6e07b1ce9f 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -11,6 +11,7 @@ #include "../common.h" #include "../util/logging.h" +#include "../util/cuda_runtime.h" namespace transformer_engine { namespace delayed_scaling_recipe { @@ -38,6 +39,36 @@ inline float fp8_dtype_max(DType dtype) { return 0; } +// struct for amax parameters +struct AmaxParam { + int num_scale = 0; + float* amax_history = nullptr; + float* scale = nullptr; + float* scale_inv = nullptr; +}; + +// dummy struct for kernel_bulk's other params +struct OtherParams { + float* a; + size_t b; + AmaxComputeAlgo c; + float d; +}; + +#if CUDART_VERSION >= 12010 +constexpr size_t max_constant_memory_per_kernel = 32000; +constexpr size_t AMAX_PARAMS_LIMIT = ( + max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam); +#else +constexpr size_t max_constant_memory_per_kernel = 4000; +constexpr size_t AMAX_PARAMS_LIMIT = ( + max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam); +#endif + +struct AmaxParams { + AmaxParam param[AMAX_PARAMS_LIMIT]; +}; + namespace amax_and_scale_update_impl { // CUDA block size @@ -133,11 +164,96 @@ kernel(const float* amax_history_ptr, } } -} // namespace amax_and_scale_update_impl +/* CUDA kernel to bulk-update amax history and FP8 scaling factors + * + * Block dims: bsize x 1 x 1 + * + * Grid dims: num_tensors x 1 x 1 + */ +__global__ void __launch_bounds__(bsize) +kernel_bulk( + float* amax_reduction_buffer, + AmaxParams p, + size_t amax_history_length, + AmaxComputeAlgo amax_compute_algo, + float scaled_max) { + const size_t bid = blockIdx.x; + const size_t tid = threadIdx.x; + const int num_scale = p.param[bid].num_scale; + + int offset_in_buffer = 0; + for (int j = 0; j < bid; j++) { + offset_in_buffer += p.param[j].num_scale; + } + for (int count = 0; count < num_scale; count++) { + // Update amax + float amax = 0; + { + // Roll amax history + const auto& length = amax_history_length; + const auto& stride = p.param[bid].num_scale; + auto* amax_history = p.param[bid].amax_history+count; + const auto last_amax = ((amax_reduction_buffer != nullptr) + && (amax_reduction_buffer[offset_in_buffer+count] != 0.0f)) ? + amax_reduction_buffer[offset_in_buffer+count] : amax_history[0]; + for (size_t off = 0; off < length; off += bsize) { + const size_t i = off + tid; + float a = 0; + if (i < length) { + a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax; + amax = fmaxf(amax, a); + } + __syncthreads(); // Inplace roll + if (i < length) { + amax_history[i*stride] = (i > 0) ? 
a : 0; + } + } + + // Compute amax to use for scaling factor + switch (amax_compute_algo) { + case AmaxComputeAlgo::MOST_RECENT: + amax = last_amax; + break; + case AmaxComputeAlgo::MAX: + { + __shared__ float shared_amax[bsize]; + shared_amax[tid] = amax; + __syncthreads(); +#pragma unroll + for (size_t off = bsize / 2; off > 0; off /= 2) { + if (tid < off) { + shared_amax[tid] = fmaxf(shared_amax[tid], shared_amax[tid + off]); + } + __syncthreads(); + } + amax = shared_amax[tid]; + } + break; + default: + amax = 0; + } + } + + // Update scale and scale inverse + if (tid == 0) { + float scale; + if (isfinite(amax) && amax > 0) { + scale = scaled_max / amax; + } else { + scale = p.param[bid].scale[count]; + } + p.param[bid].scale[count] = scale; + p.param[bid].scale_inv[count] = 1 / scale; + } + } +} + +} // namespace amax_and_scale_update_impl } // namespace + void amax_and_scale_update(const Tensor &amax_history, const Tensor &scale, const Tensor &scale_inv, @@ -238,9 +354,105 @@ void amax_and_scale_update(const Tensor &amax_history, NVTE_CHECK_CUDA(cudaGetLastError()); } + +void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, + std::vector amax_histories, + std::vector scales, + std::vector scale_invs, + const std::string &amax_compute_algo, + DType fp8_dtype, + float margin, + cudaStream_t stream) { + using namespace transformer_engine; + + // amax value to use for updating scaling factor + AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID; + if (amax_compute_algo == "max") { + amax_compute_algo_ = AmaxComputeAlgo::MAX; + } else if (amax_compute_algo == "most_recent") { + amax_compute_algo_ = AmaxComputeAlgo::MOST_RECENT; + } else { + NVTE_ERROR("Unsupported amax compute algorithm (", amax_compute_algo, ")"); + } + + // Expected maximum value after scale is applied + const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin); + + // Number of elements in tensor + auto numel = [] (const Tensor *tensor) -> size_t { + size_t acc = 1; + for (const auto& dim : tensor->data.shape) { + acc *= dim; + } + return acc; + }; + + // Number of tensors in the bulk + const size_t num_tensors = amax_histories.size(); + const int num_kernels = (num_tensors+AMAX_PARAMS_LIMIT-1)/AMAX_PARAMS_LIMIT; + size_t amax_history_length = 0; + if (num_tensors > 0) { + amax_history_length = amax_histories[0]->data.shape[0]; + } + + // amax parameters + float* amax_buffer = static_cast(amax_reduction_buffer.data.dptr); + AmaxParams p; + for (int iter = 0; iter < num_kernels; iter++) { + size_t kernel_num_scales = 0; + size_t kernel_num_tensors = (iter == (num_kernels -1)) + ? 
num_tensors % AMAX_PARAMS_LIMIT: AMAX_PARAMS_LIMIT; + for (size_t pi = 0; pi < kernel_num_tensors; pi++) { + size_t i = iter * AMAX_PARAMS_LIMIT + pi; + + // Check tensors + int num_scale = amax_histories[i]->data.shape[1]; + NVTE_CHECK(amax_histories[i]->data.dtype == DType::kFloat32, + "Found ", dtype_name(amax_histories[i]->data.dtype), "."); + NVTE_CHECK(amax_histories[i]->data.shape.size() == 2, + "Found ", amax_histories[i]->data.shape.size(), " dims"); + NVTE_CHECK(numel(amax_histories[i]) == amax_history_length * num_scale, + "Expected ", amax_history_length * num_scale, " elements, ", + "but found ", numel(amax_histories[i]), "."); + NVTE_CHECK(scales[i]->data.dtype == DType::kFloat32, + "Found ", dtype_name(scales[i]->data.dtype), "."); + NVTE_CHECK(scales[i]->data.shape.size() == 1, + "Found ", scales[i]->data.shape.size(), " dims"); + NVTE_CHECK(numel(scales[i]) == num_scale, + "Expected ", num_scale, " elements, ", + "Found ", numel(scales[i]), "."); + + // amax parameters + kernel_num_scales += num_scale; + p.param[pi].num_scale = num_scale; + p.param[pi].amax_history = static_cast(amax_histories[i]->data.dptr); + p.param[pi].scale = static_cast(scales[i]->data.dptr); + p.param[pi].scale_inv = static_cast(scale_invs[i]->data.dptr); + } + + // Launch CUDA kernel + size_t grid_size = kernel_num_tensors; + const size_t block_size = amax_and_scale_update_impl::bsize; + amax_and_scale_update_impl::kernel_bulk + <<>>( + amax_buffer, + p, + amax_history_length, + amax_compute_algo_, + scaled_max); + NVTE_CHECK_CUDA(cudaGetLastError()); + + // shift amax buffer pointer + if (amax_buffer != nullptr) { + amax_buffer += kernel_num_scales; + } + } +} + } // namespace delayed_scaling_recipe } // namespace transformer_engine + void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_history, const NVTETensor scale, const NVTETensor scale_inv, @@ -267,3 +479,33 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_his margin, stream); } + + +void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( + const NVTETensor amax_reduction_buffer, + std::vector amax_histories, + std::vector scales, + std::vector scale_invs, + const char *amax_compute_algo, + NVTEDType fp8_dtype, + float margin, + cudaStream_t stream) { + NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction); + using namespace transformer_engine; + size_t num_tensors = amax_histories.size(); + std::vector t_amax_histories, t_scales, t_scale_invs; + for (size_t i = 0; i < num_tensors; i++) { + t_amax_histories.push_back(reinterpret_cast(amax_histories[i])); + t_scales.push_back(reinterpret_cast(scales[i])); + t_scale_invs.push_back(reinterpret_cast(scale_invs[i])); + } + delayed_scaling_recipe::amax_and_scale_update_after_reduction( + *reinterpret_cast(amax_reduction_buffer), + t_amax_histories, + t_scales, + t_scale_invs, + amax_compute_algo, + static_cast(fp8_dtype), + margin, + stream); +} diff --git a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp index 86ffc64c25..5ccfae1922 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp @@ -153,21 +153,32 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens // Query the kernel-specific launch parameters. 
launcher(launch_params, true); + if (launch_params.workspace_bytes == 0) { + launch_params.workspace_bytes = 1; + } + if (workspace->data.dptr == nullptr) { NVTE_CHECK(barrier->data.dptr == nullptr); workspace->data.dtype = DType::kByte; - if (launch_params.workspace_bytes == 0) { - launch_params.workspace_bytes = 1; - } workspace->data.shape = {launch_params.workspace_bytes}; barrier->data.dtype = DType::kInt32; barrier->data.shape = {launch_params.barrier_size}; return; + } else { + NVTE_CHECK(workspace->data.dtype == DType::kByte); + NVTE_CHECK(workspace->data.shape == std::vector{ launch_params.workspace_bytes }); } + if (launch_params.barrier_size > 0) { + NVTE_CHECK(barrier->data.dptr != nullptr); + NVTE_CHECK(barrier->data.dtype == DType::kInt32); + NVTE_CHECK(barrier->data.shape == std::vector{ launch_params.barrier_size }); + } + + // Tensor checks are delayed here in order to recover workspace sizes with null data CheckInputTensor(x, "x"); CheckInputTensor(gamma, "gamma"); @@ -265,6 +276,23 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const barrier->data.shape = {launch_params.barrier_size}; return; + } else { + auto pdw_shape = std::vector{ + static_cast(launch_params.params.ctas_per_col), hidden_size}; + NVTE_CHECK(dgamma_part->data.dtype == ctype); + NVTE_CHECK(dgamma_part->data.shape == pdw_shape); + } + + if (launch_params.barrier_size > 0) { + NVTE_CHECK(barrier->data.dptr != nullptr); + NVTE_CHECK(barrier->data.dtype == DType::kInt32); + NVTE_CHECK(barrier->data.shape == std::vector{ launch_params.barrier_size }); + } + + if (launch_params.workspace_bytes > 0) { + NVTE_CHECK(workspace->data.dptr != nullptr); + NVTE_CHECK(workspace->data.dtype == DType::kByte); + NVTE_CHECK(workspace->data.shape == std::vector{ launch_params.workspace_bytes }); } // Tensor checks are delayed here in order to recover workspace sizes with null data diff --git a/transformer_engine/common/transpose/cast_transpose.cu b/transformer_engine/common/transpose/cast_transpose.cu index 9f1a18de7a..347aeb9b15 100644 --- a/transformer_engine/common/transpose/cast_transpose.cu +++ b/transformer_engine/common/transpose/cast_transpose.cu @@ -4,6 +4,7 @@ * See LICENSE for license information. 
************************************************************************/ +#include #include #include #include @@ -56,6 +57,7 @@ template ; using OVec = Vec; @@ -163,6 +167,7 @@ template ; using OVec = Vec; @@ -294,6 +301,7 @@ cast_transpose_kernel_notaligned(const IType * const input, } void cast_transpose(const Tensor &input, + const Tensor &noop, Tensor *cast_output, Tensor *transposed_output, cudaStream_t stream) { @@ -301,6 +309,22 @@ void cast_transpose(const Tensor &input, CheckOutputTensor(*cast_output, "cast_output"); CheckOutputTensor(*transposed_output, "transposed_output"); + // Number of elements in tensor + auto numel = [] (const Tensor &tensor) -> size_t { + size_t acc = 1; + for (const auto& dim : tensor.data.shape) { + acc *= dim; + } + return acc; + }; + + if (noop.data.dptr != nullptr) { + NVTE_CHECK(numel(noop) == 1, + "Expected 1 element, ", + "but found ", numel(noop), "."); + NVTE_CHECK(noop.data.dtype == DType::kFloat32); + NVTE_CHECK(noop.data.dptr != nullptr); + } NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions."); @@ -332,6 +356,7 @@ void cast_transpose(const Tensor &input, (THREADS_PER_WARP + 1) * sizeof(Vec), \ stream>>>( \ reinterpret_cast(input.data.dptr), \ + reinterpret_cast(noop.data.dptr), \ reinterpret_cast(cast_output->data.dptr), \ reinterpret_cast(transposed_output->data.dptr), \ reinterpret_cast(cast_output->scale.dptr), \ @@ -417,7 +442,23 @@ void nvte_cast_transpose(const NVTETensor input, cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose); using namespace transformer_engine; + auto noop = Tensor(); + cast_transpose(*reinterpret_cast(input), + noop, + reinterpret_cast(cast_output), + reinterpret_cast(transposed_output), + stream); +} + +void nvte_cast_transpose_with_noop(const NVTETensor input, + const NVTETensor noop, + NVTETensor cast_output, + NVTETensor transposed_output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_cast_transpose_with_noop); + using namespace transformer_engine; cast_transpose(*reinterpret_cast(input), + *reinterpret_cast(noop), reinterpret_cast(cast_output), reinterpret_cast(transposed_output), stream); diff --git a/transformer_engine/common/transpose/rtc/transpose.cu b/transformer_engine/common/transpose/rtc/transpose.cu index 72a1621763..f21014866b 100644 --- a/transformer_engine/common/transpose/rtc/transpose.cu +++ b/transformer_engine/common/transpose/rtc/transpose.cu @@ -22,9 +22,12 @@ constexpr size_t block_size = __BLOCK_SIZE__; __global__ void __launch_bounds__(block_size) transpose_optimized_kernel(const Type * __restrict__ const input, + const float * const noop, Type * __restrict__ const output, const size_t row_length, const size_t num_rows) { + if (noop != nullptr && noop[0] == 1.0f) return; + // Vectorized load/store sizes constexpr size_t nvec_in = load_size / sizeof(Type); constexpr size_t nvec_out = store_size / sizeof(Type); diff --git a/transformer_engine/common/transpose/transpose.cu b/transformer_engine/common/transpose/transpose.cu index f1b8d7a228..3ab83b944b 100644 --- a/transformer_engine/common/transpose/transpose.cu +++ b/transformer_engine/common/transpose/transpose.cu @@ -4,6 +4,7 @@ * See LICENSE for license information. 
************************************************************************/ +#include #include #include #include @@ -30,9 +31,12 @@ template __global__ void __launch_bounds__(block_size) transpose_general_kernel(const Type * __restrict__ const input, + const fp32 * const noop, Type * __restrict__ const output, const size_t row_length, const size_t num_rows) { + if (noop != nullptr && noop[0] == 1.0f) return; + // Vectorized load/store sizes constexpr size_t nvec_in = load_size / sizeof(Type); constexpr size_t nvec_out = store_size / sizeof(Type); @@ -124,6 +128,7 @@ transpose_general_kernel(const Type * __restrict__ const input, } void transpose(const Tensor &input, + const Tensor &noop, Tensor *output_, cudaStream_t stream) { Tensor &output = *output_; @@ -140,6 +145,23 @@ void transpose(const Tensor &input, NVTE_CHECK(input.data.dtype == output.data.dtype, "Input and output type must match."); + // Number of elements in tensor + auto numel = [] (const Tensor &tensor) -> size_t { + size_t acc = 1; + for (const auto& dim : tensor.data.shape) { + acc *= dim; + } + return acc; + }; + + if (noop.data.dptr != nullptr) { + NVTE_CHECK(numel(noop) == 1, + "Expected 1 element, ", + "but found ", numel(noop), "."); + NVTE_CHECK(noop.data.dtype == DType::kFloat32); + NVTE_CHECK(noop.data.dptr != nullptr); + } + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(input.data.dtype, Type, constexpr const char *type_name = TypeInfo::name; constexpr size_t type_size = sizeof(Type); @@ -239,6 +261,7 @@ void transpose(const Tensor &input, rtc_manager.launch(kernel_label, num_blocks(load_size, store_size), block_size, 0, stream, static_cast(input.data.dptr), + static_cast(noop.data.dptr), static_cast(output.data.dptr), row_length, num_rows); } else { // Statically-compiled general kernel @@ -250,6 +273,7 @@ void transpose(const Tensor &input, * DIVUP(num_rows, col_tile_size)); transpose_general_kernel<<>>( static_cast(input.data.dptr), + static_cast(noop.data.dptr), static_cast(output.data.dptr), row_length, num_rows); } @@ -263,7 +287,22 @@ void nvte_transpose(const NVTETensor input, cudaStream_t stream) { NVTE_API_CALL(nvte_transpose); using namespace transformer_engine; + auto noop = Tensor(); + transpose(*reinterpret_cast(input), + noop, + reinterpret_cast(output), + stream); +} + + +void nvte_transpose_with_noop(const NVTETensor input, + const NVTETensor noop, + NVTETensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_transpose_with_noop); + using namespace transformer_engine; transpose(*reinterpret_cast(input), + *reinterpret_cast(noop), reinterpret_cast(output), stream); } diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index e3abfa00fc..4c513339a0 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -14,6 +14,7 @@ from .transformer import TransformerLayer from .fp8 import fp8_autocast from .fp8 import fp8_model_init +from .graph import make_graphed_callables from .export import onnx_export from .distributed import checkpoint from .distributed import CudaRNGStatesTracker diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index f03350eb4e..f57b58d736 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -52,9 +52,14 @@ get_distributed_world_size, get_distributed_rank, checkpoint, + set_all_rng_states, + CudaRNGStatesTracker, + graph_safe_rng_available, ) from transformer_engine.pytorch.export import is_in_onnx_export_mode from 
transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo +from transformer_engine.pytorch.graph import is_graph_capturing + _flash_attn_version = packaging.version.Version(version("flash-attn")) _flash_attn_version_required = packaging.version.Version("2.0.6") @@ -2401,10 +2406,13 @@ def __init__( assert (num_attention_heads % self.num_gqa_groups == 0 ), "The number of attention heads must be divisible by the number of GQA groups!" + self.rng_states_tracker = None if sequence_parallel or get_rng_state_tracker is None: attention_dropout_ctx = nullcontext else: - attention_dropout_ctx = get_rng_state_tracker().fork + self.rng_states_tracker = get_rng_state_tracker() + set_all_rng_states(self.rng_states_tracker.get_states()) + attention_dropout_ctx = self.rng_states_tracker.fork norm_factor = math.sqrt(self.hidden_size_per_attention_head) @@ -2648,6 +2656,14 @@ def forward( assert (attn_mask_type in AttnMaskTypes ), f"Attention mask type {attn_mask_type} is not supported!" + if self.rng_states_tracker is not None and is_graph_capturing(): + assert ( + isinstance(self.rng_states_tracker, CudaRNGStatesTracker) + ), "Unsupported RNG states tracker." + assert ( + graph_safe_rng_available() + ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture." + if window_size is None: window_size = self.window_size @@ -3695,7 +3711,8 @@ def forward( # =================== projection_output = self.proj( - context_layer, is_first_microbatch=is_first_microbatch + context_layer, + is_first_microbatch=is_first_microbatch, ) if self.return_bias: diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index ce18dffca0..3671f2e064 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -22,19 +22,26 @@ def fp8_cast_transpose_fused( otype: tex.DType, cast_out: Optional[torch.Tensor] = None, transpose_out: Optional[torch.Tensor] = None, + noop_flag: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor, torch.Tensor], None]: """Cast + Transpose with FP8 output""" return_outputs = False - if cast_out is None or transpose_out is None: - cast_out = torch.empty_like(inp, dtype=torch.uint8) + if transpose_out is None: transpose_out = torch.empty( inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8 ) return_outputs = True + if cast_out is None: + cast_out = torch.empty_like(inp, dtype=torch.uint8) + return_outputs = True + + if noop_flag is None: + noop_flag = torch.Tensor() - tex.fused_cast_transpose( + tex.fused_cast_transpose_noop( inp, + noop_flag, fp8_meta_tensor.scale[fp8_tensor], fp8_meta_tensor.amax_history[0][fp8_tensor], fp8_meta_tensor.scale_inv[fp8_tensor], diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 4e3daf7512..3c039b9a88 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -157,7 +157,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { } // Catch up the default torch stream - at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); CHECK_CUDA(cudaEventRecord(_start_comm, (cudaStream_t)stream_main)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); @@ -238,13 +238,10 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { int ori_sms = 
_ub_comm->sms; // Catch up the default torch stream - at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); - for (int i = 0; i < _stream_compute.size(); i++) { - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _stop_comm, 0)); - } + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0)); if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; @@ -350,11 +347,12 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { int ori_sms = _ub_comm->sms; // Catch up the default torch stream - at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); for (int i = 0; i < _stream_compute.size(); i++) { CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); } + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0)); if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; @@ -469,13 +467,13 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); } } + for (int i = 0; i < _stream_compute.size(); i++) { + CHECK_CUDA( + cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); + } _ub_comm->sms = ori_sms; - int last_compute_stream_id = - (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); - CHECK_CUDA( - cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id])); CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); at::cuda::setCurrentCUDAStream(stream_main); @@ -506,7 +504,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { } } - at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), input.numel() * input.element_size(), @@ -805,14 +803,15 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); + for (int i = 0; i < _stream_compute.size(); i++) { + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); + } if (_aggregate2) { - // Catch up the default torch stream - 
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); - const int num_steps = _tp_size / 2; char *input_b_ptr = reinterpret_cast(_ubuf.data_ptr()); @@ -877,21 +876,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); - CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); } } - at::cuda::setCurrentCUDAStream(stream_main); - int last_compute_stream_id = - (num_steps + _stream_compute.size() - 1) % _stream_compute.size(); - CHECK_CUDA( - cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id])); } else { - // Catch up the default torch stream - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0)); - for (int i = 0; i < _tp_size; i++) { // Set the userbuffer id. Buffer under send is the input for the current // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to @@ -936,16 +923,19 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); - CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); } } - at::cuda::setCurrentCUDAStream(stream_main); - int last_compute_stream_id = (_tp_size + _stream_compute.size() - 1) % _stream_compute.size(); + } + for (int i = 0; i < _stream_compute.size(); i++) { CHECK_CUDA( - cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id])); + cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); } - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); + CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); + CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); + at::cuda::setCurrentCUDAStream(stream_main); return D; } // split_overlap_ag diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 4096280d17..f6d6bad57f 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -43,6 +43,7 @@ #include #include #include +#include namespace transformer_engine { diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index d3872c5b75..0887054665 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -223,6 +223,17 @@ void fused_cast_transpose(at::Tensor input, ); +void fused_cast_transpose_noop(at::Tensor input, + at::Tensor noop, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + at::Tensor input_cast, + at::Tensor 
input_transpose, + transformer_engine::DType otype +); + + std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, at::Tensor amax, @@ -263,6 +274,17 @@ at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype ); +void fp8_transpose_noalloc(at::Tensor input, + at::Tensor output, + transformer_engine::DType otype +); + +void fp8_transpose_noalloc_noop(at::Tensor input, + at::Tensor output, + at::Tensor noop, + transformer_engine::DType otype +); + /*************************************************************************************************** * Activations **************************************************************************************************/ @@ -559,16 +581,13 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads * FP8 recipe **************************************************************************************************/ -void fused_amax_and_scale_update(const at::Tensor &amax_history, - const at::Tensor &scale, - const at::Tensor &scale_inv, - const at::Tensor &scale_inv_mask, - at::Tensor updated_amax_history, - at::Tensor updated_scale, - at::Tensor updated_scale_inv, - const std::string& amax_compute_algo, - transformer_engine::DType fp8_dtype, - float margin); +void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, + std::vector amax_histories, + std::vector scales, + std::vector scale_invs, + const std::string &amax_compute_algo, + transformer_engine::DType fp8_dtype, + float margin); /*************************************************************************************************** * Rotary positional embedding diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 328bf1dcb4..4a7d51cada 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -42,6 +42,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD"); m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD"); m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose"); + m.def("fused_cast_transpose_noop", &fused_cast_transpose_noop, + "Fused Cast + Transpose with noop option"); m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad, "Fused Cast + Transpose + BGRAD"); m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad, @@ -67,6 +69,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_attn_bwd", &fused_attn_bwd, "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V"); m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O"); + m.def("fp8_transpose_noalloc", &fp8_transpose_noalloc, "Transpose with FP8 I/O"); + m.def("fp8_transpose_noalloc_noop", &fp8_transpose_noalloc_noop, + "Transpose with FP8 I/O with noop option."); m.def("gelu", &gelu, "GeLU with FP8 output"); m.def("relu", &relu, "ReLU with FP8 output"); m.def("geglu", &geglu, "GeGLU with FP8 output"); @@ -82,9 +87,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention"); m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention"); m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend"); - m.def("fused_amax_and_scale_update", - &fused_amax_and_scale_update, - "Update amax history and FP8 scale"); + m.def("fused_amax_and_scale_update_after_reduction", + 
&fused_amax_and_scale_update_after_reduction, + "Update amax history and FP8 scale/scale_inv after reduction"); // fused apply rope m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD"); diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cu b/transformer_engine/pytorch/csrc/extensions/recipe.cu index f97d24a011..d5d8e2f7c8 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cu +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cu @@ -11,24 +11,50 @@ #include #include -void fused_amax_and_scale_update(const at::Tensor &amax_history, - const at::Tensor &scale, - const at::Tensor &scale_inv, - const at::Tensor &scale_inv_mask, - at::Tensor updated_amax_history, - at::Tensor updated_scale, - at::Tensor updated_scale_inv, - const std::string& amax_compute_algo, - transformer_engine::DType fp8_dtype, - float margin) { - nvte_delayed_scaling_recipe_amax_and_scale_update( - makeTransformerEngineTensor(amax_history).data(), - makeTransformerEngineTensor(scale).data(), - makeTransformerEngineTensor(scale_inv).data(), - makeTransformerEngineTensor(scale_inv_mask).data(), - makeTransformerEngineTensor(updated_amax_history).data(), - makeTransformerEngineTensor(updated_scale).data(), - makeTransformerEngineTensor(updated_scale_inv).data(), + +void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, + std::vector amax_histories, + std::vector scales, + std::vector scale_invs, + const std::string &amax_compute_algo, + transformer_engine::DType fp8_dtype, + float margin) { + using namespace transformer_engine; + size_t num_tensors = amax_histories.size(); + std::vector t_amax_histories(num_tensors); + std::vector t_scales(num_tensors); + std::vector t_scale_invs(num_tensors); + std::vector te_amax_histories(num_tensors); + std::vector te_scales(num_tensors); + std::vector te_scale_invs(num_tensors); + for (size_t i = 0; i < num_tensors; i++) { + t_amax_histories[i].data.dptr = amax_histories[i].data_ptr(); + auto amax_sizes = amax_histories[i].sizes().vec(); + std::vector amax_shape{amax_sizes.begin(), amax_sizes.end()}; + t_amax_histories[i].data.shape = amax_shape; + t_amax_histories[i].data.dtype = DType::kFloat32; + + t_scales[i].data.dptr = scales[i].data_ptr(); + auto scale_sizes = scales[i].sizes().vec(); + std::vector scale_shape{scale_sizes.begin(), scale_sizes.end()}; + t_scales[i].data.shape = scale_shape; + t_scales[i].data.dtype = DType::kFloat32; + + t_scale_invs[i].data.dptr = scale_invs[i].data_ptr(); + auto scale_inv_sizes = scale_invs[i].sizes().vec(); + std::vector scale_inv_shape{scale_inv_sizes.begin(), scale_inv_sizes.end()}; + t_scale_invs[i].data.shape = scale_inv_shape; + t_scale_invs[i].data.dtype = DType::kFloat32; + + te_amax_histories[i] = reinterpret_cast(&t_amax_histories[i]); + te_scales[i] = reinterpret_cast(&t_scales[i]); + te_scale_invs[i] = reinterpret_cast(&t_scale_invs[i]); + } + nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( + makeTransformerEngineTensor(amax_reduction_buffer).data(), + te_amax_histories, + te_scales, + te_scale_invs, amax_compute_algo.c_str(), static_cast(fp8_dtype), margin, diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cu b/transformer_engine/pytorch/csrc/extensions/transpose.cu index 038e82d955..fc178adeb4 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cu +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cu @@ -32,6 +32,35 @@ void fused_cast_transpose(at::Tensor input, } +void 
fused_cast_transpose_noop(at::Tensor input, + at::Tensor noop, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + at::Tensor input_cast, + at::Tensor input_transpose, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t M = static_cast(input.size(0)); + size_t N = static_cast(input.size(1)); + + auto input_cu = makeTransformerEngineTensor(input); + auto noop_cu = makeTransformerEngineTensor(noop); + auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype, + amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype, + amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + + nvte_cast_transpose_with_noop(input_cu.data(), noop_cu.data(), output_cast_cu.data(), + output_transpose_cu.data(), + at::cuda::getCurrentCUDAStream()); +} + + std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, at::Tensor amax, @@ -319,3 +348,39 @@ at::Tensor fp8_transpose(at::Tensor input, return output; } + + +void fp8_transpose_noalloc(at::Tensor input, + at::Tensor output, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t M = static_cast(input.size(0)); + size_t N = static_cast(input.size(1)); + + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype); + + nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); +} + + +void fp8_transpose_noalloc_noop(at::Tensor input, + at::Tensor output, + at::Tensor noop, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t M = static_cast(input.size(0)); + size_t N = static_cast(input.size(1)); + + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); + auto noop_cu = makeTransformerEngineTensor(noop); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype); + + nvte_transpose_with_noop( + input_cu.data(), noop_cu.data(), output_cu.data(), + at::cuda::getCurrentCUDAStream()); +} diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 239cecf39b..8d499d88d6 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -5,10 +5,10 @@ """Methods needed for distributed training (DP/TP).""" import warnings from contextlib import contextmanager, AbstractContextManager, ContextDecorator -from typing import Any, Dict, Union, Optional, Callable, Tuple +from typing import Any, Dict, List, Union, Optional, Callable, Tuple import torch -from torch.cuda import _lazy_call +from torch.cuda import _lazy_call, _lazy_init from torch.utils.checkpoint import detach_variable, noop_context_fn from .utils import safely_set_viewless_tensor_data @@ -31,15 +31,60 @@ _FP8_ACTIVATION_RECOMPUTE_PHASE = False -def _set_cuda_rng_state(new_state: torch.Tensor, device: Union[int, str] = -1) -> None: - """Sets the random number generator state of the current GPU. 
+_ALL_ACTIVE_RNG_STATES = {} + + +def get_all_rng_states() -> bool: + """Returns all generator states used by `CudaRNGStatesTracker`.""" + return _ALL_ACTIVE_RNG_STATES + + +def set_all_rng_states(states: List) -> None: + """Updates all generator states used by `CudaRNGStatesTracker`.""" + global _ALL_ACTIVE_RNG_STATES + _ALL_ACTIVE_RNG_STATES = states + + +def graph_safe_rng_available() -> bool: + """Returns whether cuda graph safe RNG state manipulation is supported.""" + return (hasattr(torch.cuda.CUDAGraph, "register_generator_state") + and hasattr(torch.Generator, "graphsafe_set_state") + and hasattr(torch.Generator, "graphsafe_get_state") + and hasattr(torch.Generator, "clone_state")) + + +def _get_cuda_rng_state( + device: Union[int, str, torch.device] = "cuda", + clone: bool = False, + graph_safe: bool = True, +) -> torch.Tensor: + """Return the random number generator state of the specified GPU.""" + + _lazy_init() + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("cuda", device) + idx = device.index + if idx is None: + idx = torch.cuda.current_device() + default_generator = torch.cuda.default_generators[idx] + if graph_safe_rng_available() and graph_safe: + if clone: + # Reference to the cloned generator state + return default_generator.clone_state() + # Reference to the current generator state + return default_generator.graphsafe_get_state() + return default_generator.get_state() + + +def _set_cuda_rng_state( + new_state: torch.Tensor, + device: Union[int, str] = -1, + graph_safe = True, +) -> None: + """Sets the random number generator state of the current GPU.""" - Arguments: - new_state (torch.ByteTensor): The desired state - This function is adapted from PyTorch repo (torch.cuda.set_rng_state) - with a single change: the input state is not cloned. Cloning caused - major performance issues for +4 GPU cases. - """ if device == -1: device = torch.device("cuda") elif isinstance(device, str): @@ -52,6 +97,9 @@ def cb() -> None: if idx is None: idx = torch.cuda.current_device() default_generator = torch.cuda.default_generators[idx] + if graph_safe_rng_available() and graph_safe: + default_generator.graphsafe_set_state(new_state) + return default_generator.set_state(new_state) _lazy_call(cb) @@ -206,7 +254,7 @@ def forward( # Copy the rng states. ctx.fwd_cpu_rng_state = torch.get_rng_state() - ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() + ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False) if get_rng_state_tracker is not None: ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states() @@ -271,13 +319,13 @@ def backward( # Store the current states. bwd_cpu_rng_state = torch.get_rng_state() - bwd_cuda_rng_state = torch.cuda.get_rng_state() + bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False) if get_rng_state_tracker is not None: bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states() # Set the states to what it used to be before the forward pass. torch.set_rng_state(ctx.fwd_cpu_rng_state) - _set_cuda_rng_state(ctx.fwd_cuda_rng_state) + _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False) if get_rng_state_tracker is not None: get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) @@ -291,7 +339,7 @@ def backward( # Set the states back to what it was at the start of this function. 
torch.set_rng_state(bwd_cpu_rng_state) - _set_cuda_rng_state(bwd_cuda_rng_state) + _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False) if get_rng_state_tracker is not None: get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker) @@ -317,6 +365,7 @@ def backward( ) return (None, None, None, None, None, None) + grads + class _CheckpointFrame: """ Storage frame for forward RNG states and detached activations from the forward recompute. @@ -338,7 +387,7 @@ def cache_rng_states(self, forward=True): """Cache fwd/bwd RNG states in the frame to restore later.""" rng_states = ( torch.get_rng_state(), - torch.cuda.get_rng_state(), + _get_cuda_rng_state(graph_safe=False), ) if self.get_rng_state_tracker is not None: rng_states += (self.get_rng_state_tracker().get_states(), ) @@ -356,7 +405,7 @@ def restore_rng_states(self, forward=True): rng_states = self.bwd_rng_states torch.set_rng_state(rng_states[0]) - _set_cuda_rng_state(rng_states[1]) + _set_cuda_rng_state(rng_states[1], graph_safe=False) if self.get_rng_state_tracker is not None: self.get_rng_state_tracker().set_states(rng_states[2]) @@ -604,6 +653,7 @@ def recompute_fn(*args, **kwargs): return out + class CudaRNGStatesTracker: """ For model parallelism, multiple RNG states need to simultaneously exist in order @@ -664,13 +714,23 @@ def add(self, name: str, seed: int) -> None: # Check that state is not already defined. if name in self.states_: raise Exception(f"cuda rng state {name} already exists") - # Get the current rng state. - orig_rng_state = torch.cuda.get_rng_state() - # Set the new state and store it. - torch.cuda.manual_seed(seed) - self.states_[name] = torch.cuda.get_rng_state() - # Reset rng state to what it was. - _set_cuda_rng_state(orig_rng_state) + + if graph_safe_rng_available(): + new_state = _get_cuda_rng_state(clone=True) + new_state.manual_seed(seed) + self.states_[name] = new_state + # Update global states. + set_all_rng_states(self.states_) + else: + # Get the current rng state. + orig_rng_state = _get_cuda_rng_state() + # Set the new state and store it. + torch.cuda.manual_seed(seed) + self.states_[name] = _get_cuda_rng_state(clone=True) + # Reset rng state to what it was. + _set_cuda_rng_state(orig_rng_state) + # Update global states. + set_all_rng_states(self.states_) @contextmanager def fork(self, name: str = "model-parallel-rng"): @@ -684,16 +744,17 @@ def fork(self, name: str = "model-parallel-rng"): # Check if we have added the state if name not in self.states_: raise Exception(f"cuda rng state {name} is not added") - # Store current rng state. - orig_cuda_rng_state = torch.cuda.get_rng_state() + # Get the reference to current rng state. + orig_cuda_rng_state = _get_cuda_rng_state() # Set rng state to the desired one _set_cuda_rng_state(self.states_[name]) # Do the stuff we wanted to do. try: yield finally: - # Update the current rng state for later use. - self.states_[name] = torch.cuda.get_rng_state() + # this is redundant with graph-safe API + if not graph_safe_rng_available(): + self.states_[name] = _get_cuda_rng_state() # And set the state to the original state we started with. 
_set_cuda_rng_state(orig_cuda_rng_state) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 8092d2fccd..9923d24a42 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -16,6 +16,7 @@ aten = torch.ops.aten c10d = torch.ops.c10d +updated_fp8_params = {} def _make_fp8_attr_property_funcs(name: str) -> Any: @@ -67,6 +68,31 @@ def backward(ctx, grad): return grad, None +def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None: + """Amax scale and update when there is at least 1 trainable FP8 parameter.""" + param_id = id(param._data) + + if param_id not in FP8GlobalStateManager.fp8_param_to_autocast: + return + + autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id] + + if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params: + return + + if autocast_key in updated_fp8_params: + updated_fp8_params[autocast_key].add(param_id) + else: + updated_fp8_params[autocast_key] = {param_id} + + current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key] + # All FP8 trainable parameters have been updated. + if updated_fp8_params[autocast_key] == current_fp8_params_set: + FP8GlobalStateManager.reduce_and_update_fp8_tensors( + forward=True, fp8_weights=True) + del updated_fp8_params[autocast_key] + + class _ToFloat8Func(torch.autograd.Function): """Cast to FP8 from other dtype""" @staticmethod @@ -167,6 +193,7 @@ def backward(ctx, grad): # Assume that we want gradients in full precision return grad, None, None, None, None, None, None, None + class _IdentityFunc(torch.autograd.Function): """Identity function @@ -307,8 +334,9 @@ def __new__( ), f"Unsupported fp8_dtype {fp8_dtype}." self._fp8_dtype: tex.DType = fp8_dtype - # Cached transpose + # Transposed version of `_data`. self._transpose: Optional[Float8Tensor] = None + self._transpose_invalid: bool = True # FP8 scale-inverse self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv @@ -435,80 +463,51 @@ def expand_as(self, other: torch.Tensor): return _IdentityFunc.apply(self) return super().expand_as(other) - def transpose( + def transpose_2d( self, - dim0: int = 0, - dim1: int = 1, *, - update_cache: str | bool = "reuse_only", + cache: bool = False, + noop_flag: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ - Swap tensor dimensions - - For basic 2D matrix transposes, an optimized transpose kernel - is applied and a Float8Tensor is returned. + 2D transpose with caching support. Parameters ---------- - dim0: int, default = 0 - The first dimension to be transposed - dim1: int, default = 1 - The second dimension to be transposed - update_cache: str or bool, default = "reuse_only" - Memoization behavior. Options are - "reuse_only"/`False` (reuse cached value if - available, otherwise calculate transpose without - caching), "force"/`True` (calculate transpose - and cache), "lazy" (reuse cached value if - available, otherwise calculate transpose and - cache if possible). Caching is only supported - for basic 2D transposes and the cache is reset - after any in-place operations. - + cache: bool, default = `False` + Whether or not to cache the transpose. + noop_flag: Optional[torch.Tensor], default = `None` + Only used if argument `cache` is `True`, ignored otherwise. + A single element fp32 tensor with a value of 1.0 or 0.0 + which is treated as a boolean. `1.0` forces recompute + and `0.0` executes a noop using the same kernel. 
""" + assert self.dim() == 2, f"{self.dim()}-D transpose not supported." - # Check caching mode - if not isinstance(update_cache, str): - update_cache = "force" if update_cache else "reuse_only" - if update_cache not in ("force", "reuse_only", "lazy"): - raise ValueError( - "Supported values for update_cache are " - '"force" (True), "reuse_only" (False), "lazy" ' - f"(got {update_cache})" - ) + # Case: no caching. + if not cache: + return tex.fp8_transpose(self._data, self._fp8_dtype) - # Handle non-2D transposes - if -self.dim() <= dim0 < 0: - dim0 += self.dim() - if -self.dim() <= dim1 < 0: - dim1 += self.dim() - if self.dim() != 2 or dim0 == dim1: - if update_cache == "force": - raise ValueError( - "Transpose caching is only supported for basic 2D transposes " - f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})" - ) - return super().transpose(dim0, dim1) - - # Clear cache if needed - if update_cache == "force": - self._transpose = None - - # Compute transpose if needed - out = self._transpose - if out is None: - out = Float8Tensor.make_like( - self, - data=tex.fp8_transpose( - self._data.contiguous(), - self._fp8_dtype, - ), - ) + # Case: reuse cache without calling a kernel. + if not self._transpose_invalid and noop_flag is None: + assert self._transpose is not None, "Tranpose cache is empty." + return self._transpose - # Update cache if needed - if update_cache in ("force", "lazy"): - self._transpose = out - return out + # Allocate transpose if needed. + data_2d = self._data.reshape(-1, self._data.shape[-1]) + if self._transpose is None: + shape = (data_2d.shape[1], data_2d.shape[0]) + self._transpose = torch.empty(shape, dtype=torch.uint8, device=self._data.device) + + # Case: recompute transpose and store cache. + if noop_flag is None: + tex.fp8_transpose_noalloc(data_2d, self._transpose, self._fp8_dtype) + else: + # Case: cuda graph capture. + tex.fp8_transpose_noalloc_noop(data_2d, self._transpose, noop_flag, self._fp8_dtype) + + self._transpose_invalid = False + return self._transpose @torch.no_grad() def reset_fp8_meta_scale_inv(self) -> None: @@ -519,13 +518,11 @@ def reset_fp8_meta_scale_inv(self) -> None: the tensor. """ - if self._fp8_meta is None: - return + assert self._fp8_meta is not None, "FP8 meta tensors not found." fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( forward=self._fp8_meta_forward, ) - scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] - scale_inv.view(1).copy_(self._scale_inv.view(1)) + self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index].copy_(self._scale_inv[0]) def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: """Create `Float8Tensor` with given nominal dtype @@ -541,12 +538,11 @@ def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: ) def _reset_caches(self) -> None: - """Reset cached values - + """ + Set transpose cache as invalid. Should be called after any in-place operation. 
- """ - self._transpose = None + self._transpose_invalid = True @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -574,7 +570,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # Directly copy FP8 data if possible if dst._fp8_dtype == src._fp8_dtype: dst._data.copy_(src._data) - dst._scale_inv = src._scale_inv.clone() + dst._scale_inv.copy_(src._scale_inv.detach()) if dst._fp8_meta is not None: if src._fp8_meta is None: src_min, src_max = src.from_float8().aminmax() @@ -600,7 +596,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): dst.copy_(src.from_float8()) elif dst_is_fp8 and not src_is_fp8: - # Make sure input is in expected format src = src.expand(dst.size()) src = src.to( @@ -619,7 +614,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): fp8_meta_index = dst._fp8_meta_index scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index] amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - dst._scale_inv = scale.detach().view(1).reciprocal() + dst._scale_inv.copy_(scale.detach().reciprocal()) # Cast to FP8 if not dst._data.is_contiguous(): @@ -633,6 +628,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): dst._fp8_dtype, ) + # This branch is where the FP8 parameters are updated in-place during optimization. + # Handle forward amax reduction. + post_optimizer_step_fwd_amax_reduction(dst) else: # Invalid case @@ -641,6 +639,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # Nothing to return for in-place ops if dst_is_fp8: dst._reset_caches() + return None # Slice op @@ -764,6 +763,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: _fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index")) _fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype")) _transpose = property(**_make_fp8_attr_property_funcs("transpose")) + _transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid")) _scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv")) # Do not force the Float8Tensor type on the returned tensor diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index bbeea13af3..e821bfe11d 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -51,6 +51,17 @@ def get_fp8_te_dtype( return tex.DType.kFloat8E5M2 +def get_fp8_max( + fp8_recipe: DelayedScaling, fprop_tensor: bool = True +) -> tex.DType: + """Get max representible FP8 value.""" + if fp8_recipe.fp8_format == Format.E4M3 or ( + fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor + ): + return Format.E4M3.value.max_fwd + return Format.E5M2.value.max_fwd + + class FP8GlobalStateManager: """Class to keep track of and manipulate the global FP8 state at different stages of execution. 
@@ -61,20 +72,21 @@ class FP8GlobalStateManager: FP8_DISTRIBUTED_GROUP = None FP8_PARAMETERS = False IS_FIRST_FP8_MODULE = False - FP8_AUTOCAST_COUNTER = 0 - FP8_CURRENT_CONTEXT_ID = 0 + FP8_GRAPH_CAPTURING = False FP8_AUTOCAST_DEPTH = 0 - global_fp8_buffer = {} + global_amax_buffer = {} + global_amax_history_buffer = {} + global_scale_buffer = {} + global_scale_inv_buffer = {} fp8_tensors_recompute_buffer = [] - amax_forward_global_reduce_func = None - buffer_delete_key_fwd = None - buffer_delete_key_bwd = None - amax_reduce_handle_fwd = None fp8_available = None reason_for_no_fp8 = "" - dp_amax_reduce_interval = None - dp_amax_reduce_forward_idx = 0 - dp_amax_reduce_backward_idx = 0 + multi_grad_hook_tensors = [] + bwd_amax_update_hook_registered = False + autocast_arguments = {} + autocast_to_fp8_params = {} + fp8_param_to_autocast = {} + skip_fp8_weight_update_tensor = None @classmethod def reset(cls) -> None: @@ -83,21 +95,35 @@ def reset(cls) -> None: cls.FP8_CALIBRATION = False cls.FP8_RECIPE = None cls.FP8_DISTRIBUTED_GROUP = None + cls.FP8_PARAMETERS = False cls.IS_FIRST_FP8_MODULE = False - cls.FP8_AUTOCAST_COUNTER = 0 - cls.FP8_CURRENT_CONTEXT_ID = 0 + cls.FP8_GRAPH_CAPTURING = False cls.FP8_AUTOCAST_DEPTH = 0 - cls.global_fp8_buffer = {} + cls.global_amax_buffer = {} + cls.global_amax_history_buffer = {} + cls.global_scale_buffer = {} + cls.global_scale_inv_buffer = {} cls.fp8_tensors_recompute_buffer = [] - cls.amax_forward_global_reduce_func = None - cls.buffer_delete_key_fwd = None - cls.buffer_delete_key_bwd = None - cls.amax_reduce_handle_fwd = None cls.fp8_available = None cls.reason_for_no_fp8 = "" - cls.dp_amax_reduce_interval = None - cls.dp_amax_reduce_forward_idx = 0 - cls.dp_amax_reduce_backward_idx = 0 + cls.multi_grad_hook_tensors = [] + cls.bwd_amax_update_hook_registered = False + cls.autocast_arguments = {} + cls.autocast_to_fp8_params = {} + cls.fp8_param_to_autocast = {} + cls.skip_fp8_weight_update_tensor = None + + @classmethod + def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: + """`skip_fp8_weight_update_tensor` inplace setter.""" + if cls.skip_fp8_weight_update_tensor is None: + cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda") + cls.skip_fp8_weight_update_tensor.fill_(skip) + + @classmethod + def get_skip_fp8_weight_update_tensor(cls) -> None: + """`skip_fp8_weight_update_tensor` getter.""" + return cls.skip_fp8_weight_update_tensor @classmethod def is_fp8_available(cls) -> Tuple[bool, str]: @@ -106,44 +132,6 @@ def is_fp8_available(cls) -> Tuple[bool, str]: cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support() return cls.fp8_available, cls.reason_for_no_fp8 - @classmethod - def get_global_fp8_state_checkpoint(cls) -> Dict[str, Union[int, str]]: - """Returns global fp8 state variables.""" - # Convert attributes to dictionary to make future proof against - # changes in global state variables in order to make setting the - # checkpoint backwards compatible. 
- global_fp8_state = {} - global_fp8_state["FP8_AUTOCAST_COUNTER"] = cls.FP8_AUTOCAST_COUNTER - global_fp8_state["FP8_CURRENT_CONTEXT_ID"] = cls.FP8_CURRENT_CONTEXT_ID - global_fp8_state["FP8_AUTOCAST_DEPTH"] = cls.FP8_AUTOCAST_DEPTH - global_fp8_state["buffer_delete_key_fwd"] = cls.buffer_delete_key_fwd - global_fp8_state["buffer_delete_key_bwd"] = cls.buffer_delete_key_bwd - global_fp8_state["dp_amax_reduce_interval"] = cls.dp_amax_reduce_interval - global_fp8_state["dp_amax_reduce_forward_idx"] = cls.dp_amax_reduce_forward_idx - global_fp8_state["dp_amax_reduce_backward_idx"] = cls.dp_amax_reduce_backward_idx - return global_fp8_state - - @classmethod - def set_global_fp8_state_checkpoint(cls, state: Dict[str, Union[int, str]]) -> None: - """Sets global fp8 state variables.""" - for k, v in state.items(): - if hasattr(cls, k): - setattr(cls, k, v) - - @classmethod - def get_global_fp8_buffer_checkpoint(cls) -> Dict[str, List[torch.Tensor]]: - """Returns global fp8 amax buffer.""" - return cls.global_fp8_buffer - - @classmethod - def set_global_fp8_buffer_checkpoint(cls, buffer: Dict[str, List[torch.Tensor]]) -> None: - """Sets global fp8 amax buffer.""" - # Map all tensors back to GPU. - for k, v in buffer.items(): - buffer[k] = [tensor.cuda() for tensor in v] - - cls.global_fp8_buffer = buffer - @staticmethod def get_meta_tensor_key(forward: bool = True) -> str: """Returns scaling key in `fp8_meta`.""" @@ -152,121 +140,102 @@ def get_meta_tensor_key(forward: bool = True) -> str: return "scaling_bwd" @staticmethod - def get_buffer_position_key(forward: bool = True) -> str: - """Returns module position key in `fp8_meta`.""" - if forward: - return "global_fp8_buffer_pos_fwd" - return "global_fp8_buffer_pos_bwd" - - @staticmethod - def get_autocast_key(forward: bool = True) -> str: - """Returns module position key in `fp8_meta`.""" - if forward: - return "autocast_id_fwd" - return "autocast_id_bwd" - - @staticmethod - def get_amax_buffer_key(fp8_meta: Dict[str, Any], forward: bool = True) -> str: - """Return a key in `_global_fp8_buffer` for the AMAX storage.""" - if forward: - return f"FWD_AMAX_{fp8_meta['autocast_id_fwd']}" - return f"BWD_AMAX_{fp8_meta['autocast_id_bwd']}" + def get_fwd_bwd_key(forward: bool = True) -> str: + """Convert bool `forward` to string.""" + return "forward" if forward else "backward" @classmethod - def get_amax_reduce_handle_fwd(cls) -> Union[bool, None]: - """Return AMAX reduction wait handle of forward prop.""" - return cls.amax_reduce_handle_fwd + def get_buffer_info(cls) -> str: + """ + Returns a key for `fp8_meta` that stores the module's index + in the global buffers along with autocast information. 
+ """ + return "buffer_index_and_autocast_key" @classmethod - def setup_amax_forward_global_reduce_func(cls, f: Callable) -> None: - """Sets up the function to call during autocast exit.""" - cls.amax_forward_global_reduce_func = f + def get_key_in_buffer( + cls, + forward: bool, + fp8_weights: bool, + fp8_recipe: DelayedScaling, + fp8_group: dist_group_type, + ) -> str: + """Returns a key into the global FP8 buffers.""" + autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) + fwd_bwd_key = cls.get_fwd_bwd_key(forward) + return f"{fwd_bwd_key}_{fp8_weights}_{autocast_key}" @classmethod - def add_amax_to_global_buffer(cls, fp8_meta: Dict[str, Any], forward: bool = True) -> None: - """Append 1D tensor `amax` to global buffer.""" - buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward) - fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) - buffer_position_key = cls.get_buffer_position_key(forward=forward) - - if buffer_key not in cls.global_fp8_buffer: - cls.global_fp8_buffer[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] - else: - cls.global_fp8_buffer[buffer_key].append( - fp8_meta[fp8_meta_tensor_key].amax_history[0] - ) - - if buffer_position_key not in fp8_meta: - fp8_meta[buffer_position_key] = len(cls.global_fp8_buffer[buffer_key]) - 1 - - # Catch incorrect fp8_autocast usage. - assert fp8_meta[buffer_position_key] == len(cls.global_fp8_buffer[buffer_key]) - 1, \ - "Same module is being invoked more than once inside an `fp8_autocast` " \ - "region when using FP8 with amax reduction. This behavior is currently" \ - " unsupported. For more details and correct usage, please see " \ - "https://github.com/NVIDIA/TransformerEngine/pull/93." + def split_key_in_buffer(cls, key: str) -> Tuple[bool, bool, str]: + """Splits buffer key into relevant parts.""" + forward, fp8_weights, autocast_key = key.split("_", 2) + forward = forward == "forward" + fp8_weights = fp8_weights == "True" + return forward, fp8_weights, autocast_key @classmethod - def copy_amax_from_global_buffer( - cls, fp8_meta: Dict[str, Any], forward: bool = True + def add_fp8_tensors_to_global_buffer( + cls, + fp8_meta: Dict[str, Any], + fp8_weights: Optional[List[torch.Tensor]] = None, ) -> None: - """Populate current amax with the correct location from buffer.""" - fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) - buffer_position_key = cls.get_buffer_position_key(forward=forward) - if buffer_position_key not in fp8_meta: - return - - amax_buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward) - assert amax_buffer_key in cls.global_fp8_buffer, "TE internal error." - - fp8_meta[fp8_meta_tensor_key].amax_history[0] = cls.global_fp8_buffer[amax_buffer_key][ - fp8_meta[buffer_position_key] - ] + """ + The amax reduction process happens completely outside the FP8 modules. + To participate in the reduction, the only role played by a module is + to call this function in order to append it's FP8 tensor into a global + buffer. There are 5 global buffers maintained, one each for amax, amax + history, scale, scale-inverse, and non-weight-mask. Each buffer has + keys that hold FP8 tensors. Keys have a `forward_` or `backward_` prefix + to indicate the type of FP8 tensor, since the forward and backward + reductions happen separately. + + Note: For CG capture, this method is called from the graphed + wrapper. For non CG case, it's called from within the module. 
+ """ - @classmethod - def set_amax_buffer_key_deletion( - cls, fp8_meta: Dict[str, Any], forward: bool = True - ) -> None: - """Delete this amax key from global buffer during autocast end.""" - if cls.get_autocast_key(forward=forward) not in fp8_meta: + # Every module must call this function exactly once since + # the amax tensors are static. Ensures that compatibility + # with non-graphed modules is maintained. + index_in_buffer = cls.get_buffer_info() # Same index for fwd/bwd fp8 tensors. + if index_in_buffer in fp8_meta: return - if forward: - cls.buffer_delete_key_fwd = cls.get_amax_buffer_key(fp8_meta, forward=forward) - else: - cls.buffer_delete_key_bwd = cls.get_amax_buffer_key(fp8_meta, forward=forward) - - @classmethod - def delete_key_from_amax_buffer(cls, forward: bool = True) -> None: - """Delete the key from global amax buffer.""" - if forward: - if ( - cls.buffer_delete_key_fwd is not None - and cls.buffer_delete_key_fwd in cls.global_fp8_buffer - ): - del cls.global_fp8_buffer[cls.buffer_delete_key_fwd] - else: - if ( - cls.buffer_delete_key_bwd is not None - and cls.buffer_delete_key_bwd in cls.global_fp8_buffer - ): - del cls.global_fp8_buffer[cls.buffer_delete_key_bwd] - @classmethod - def get_fp8_context_id(cls) -> int: - """Returns an ID for the current FP8 context.""" - return cls.FP8_CURRENT_CONTEXT_ID - - @classmethod - def set_fp8_context_id(cls, ctx_id: int) -> None: - """Sets the current FP8 context.""" - cls.FP8_CURRENT_CONTEXT_ID = ctx_id - - @classmethod - def new_fp8_context_id(cls) -> int: - """Returns global autocast counter as a proxy to be used - as the autocast ID for FP8 modules. - """ - return cls.FP8_AUTOCAST_COUNTER + fp8_meta[index_in_buffer] = [] + for forward in (True, False): + # This algorithm creates a two-way map with `autocast_to_fp8_params` and + # `fp8_param_to_autocast`. This is used for keeping track of FP8 weights + # in an autocasted region and cross reference them in `float8_tensor.py` + # to perform the forward amax reduction. + if forward and fp8_weights is not None: + autocast_key = cls.get_unique_autocast_key( + fp8_meta["recipe"], fp8_meta["fp8_group"]) + fp8_weight_set = {id(w._data) for w in fp8_weights} + if autocast_key not in cls.autocast_to_fp8_params: + cls.autocast_to_fp8_params[autocast_key] = fp8_weight_set + else: + cls.autocast_to_fp8_params[autocast_key] = ( + cls.autocast_to_fp8_params[autocast_key].union(fp8_weight_set)) + # Identify correct autocast key for a given param. 
+ for w in fp8_weight_set: + cls.fp8_param_to_autocast[w] = autocast_key + + key = cls.get_key_in_buffer( + forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"]) + fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) + + if key not in cls.global_amax_buffer: + cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] + cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] + cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] + cls.global_scale_inv_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale_inv] + else: + cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) + cls.global_amax_history_buffer[key].append( + fp8_meta[fp8_meta_tensor_key].amax_history) + cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) + cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv) + fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1) + fp8_meta[index_in_buffer].append(key) @classmethod def is_fp8_enabled(cls) -> bool: @@ -283,6 +252,11 @@ def with_fp8_parameters(cls) -> bool: """Should the parameters be stored as FP8""" return cls.FP8_PARAMETERS + @classmethod + def fp8_graph_capturing(cls) -> bool: + """Is CUDA graph capture under way?""" + return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing() + @classmethod def is_first_fp8_module(cls): """Returns `True` only the first time when called multiple @@ -310,7 +284,8 @@ def get_fp8_autocast_state(cls) -> Tuple[bool, bool, DelayedScaling, dist_group_ cls.FP8_CALIBRATION, cls.FP8_RECIPE, cls.FP8_DISTRIBUTED_GROUP, - cls.IS_FIRST_FP8_MODULE) + cls.IS_FIRST_FP8_MODULE, + cls.FP8_GRAPH_CAPTURING) @classmethod def set_fp8_autocast_state( @@ -322,80 +297,100 @@ def set_fp8_autocast_state( cls.FP8_CALIBRATION, cls.FP8_RECIPE, cls.FP8_DISTRIBUTED_GROUP, - cls.IS_FIRST_FP8_MODULE) = fp8_state + cls.IS_FIRST_FP8_MODULE, + cls.FP8_GRAPH_CAPTURING) = fp8_state @staticmethod def reduce_tensor_across_group_op_max( - tensor: torch.Tensor, group: dist_group_type, async_op: bool + tensor: torch.Tensor, group: dist_group_type ) -> None: """Reduce tensor across given group.""" if torch.distributed.is_initialized(): - wait_handle = torch.distributed.all_reduce( + torch.distributed.all_reduce( tensor, op=torch.distributed.ReduceOp.MAX, group=group, - async_op=async_op, + async_op=False, ) - return wait_handle - return None @classmethod - def global_amax_reduction( + def reduce_and_update_fp8_tensors( cls, - fp8_meta: Dict[str, Any], - tp_group: dist_group_type, - tp_size: int, forward: bool = True, + fp8_weights: bool = False, ) -> None: """Concatenate, reduce, and split amaxes in the global buffer.""" - amax_buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward) - - # Key already deleted. - if amax_buffer_key not in cls.global_fp8_buffer: - return None - - # Reduce AMAX in DP-domain at an interval. - # `NVTE_DP_AMAX_REDUCE_INTERVAL` should be set as an integer value larger than 0. If - # `NVTE_DP_AMAX_REDUCE_INTERVAL` is set to 0, AMAX is reduced only in TP domain. 
- if cls.dp_amax_reduce_interval is None: - cls.dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1")) - - if cls.dp_amax_reduce_interval == 0: - tp_amax_reduce = True - else: - tp_amax_reduce = False - if forward: - if cls.dp_amax_reduce_forward_idx == 0: - reduce_group = fp8_meta["fp8_group"] - else: - tp_amax_reduce = True - cls.dp_amax_reduce_forward_idx = ( - (cls.dp_amax_reduce_forward_idx + 1) % cls.dp_amax_reduce_interval) + for buffer_key, amax_buffer in cls.global_amax_buffer.items(): + # Check for forward or backward reduction. + fwd_update, fp8_weights_update, autocast_key = cls.split_key_in_buffer(buffer_key) + if fwd_update != forward: + continue + # Only skip a forward update when `fp8_weights` is explicitly set to `True` + # (inside optimizer) and the current key is not an `fp8_weight_update` key. + # For other cases, we need to reduce because of activation tensors. + # TODO(ksivaman) consider separate weight and activation fp8_tensors. + if fwd_update and fp8_weights and not fp8_weights_update: + continue + if len(amax_buffer) == 0: + continue + + # Retrieve autocast specific args and concat amaxes. + recipe, group = cls.autocast_arguments[autocast_key] + contiguous_amax = torch.cat(amax_buffer) + + # Reduction. + if (recipe.reduce_amax + and torch.distributed.is_initialized() + and torch.distributed.get_world_size(group=group) > 1): + cls.reduce_tensor_across_group_op_max(contiguous_amax, group) + + # Amax and scale update. + unfused_update = (bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0"))) + or callable(recipe.amax_compute_algo) + or callable(recipe.scaling_factor_compute_algo)) + + if not unfused_update: + tex.fused_amax_and_scale_update_after_reduction( + contiguous_amax, + cls.global_amax_history_buffer[buffer_key], + cls.global_scale_buffer[buffer_key], + cls.global_scale_inv_buffer[buffer_key], + recipe.amax_compute_algo, + get_fp8_te_dtype(recipe, forward), + recipe.margin, + ) else: - if cls.dp_amax_reduce_backward_idx == 0: - reduce_group = fp8_meta["fp8_group"] - else: - tp_amax_reduce = True - cls.dp_amax_reduce_backward_idx = ( - (cls.dp_amax_reduce_backward_idx + 1) % cls.dp_amax_reduce_interval) + split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer]) - if tp_amax_reduce: - if tp_size > 1: - reduce_group = tp_group - else: - return None + for amax_history, scale, scale_inv in zip( + cls.global_amax_history_buffer[buffer_key], + cls.global_scale_buffer[buffer_key], + cls.global_scale_inv_buffer[buffer_key], + ): + _amax_and_scale_update( + amax_history, scale, scale_inv, get_fp8_max(recipe, forward), recipe) - chunk_sizes = [x.numel() for x in cls.global_fp8_buffer[amax_buffer_key]] - contiguous_amax = torch.cat(cls.global_fp8_buffer[amax_buffer_key]) + @classmethod + def add_tensor_for_bwd_reduction_multi_grad_hook(cls, tensor): + """Add tensor to list for multi grad hook.""" + cls.multi_grad_hook_tensors.append(tensor) - wait_handle = cls.reduce_tensor_across_group_op_max( - contiguous_amax, - reduce_group, - fp8_meta["async_amax_reduction"], - ) + @classmethod + def hook_for_bwd_amax_reduction(cls, grads: Tuple[torch.Tensor]) -> None: # pylint: disable=unused-argument + """Executes at the end of backward pass.""" + cls.reduce_and_update_fp8_tensors(forward=False) - cls.global_fp8_buffer[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes)) - return wait_handle + @classmethod + def get_unique_autocast_key( + cls, + recipe: Optional[DelayedScaling] = None, + group: Optional[dist_group_type] = None, + 
): + """ + For FP8, each autocast can be uniquely identified by the recipe and fp8 group. + Safely using `hash` as we never cross checkpoint boundaries. + """ + return f"{str(recipe)}:{hash(group)}" @classmethod def fp8_autocast_enter( @@ -404,21 +399,29 @@ def fp8_autocast_enter( calibrating: bool = False, fp8_recipe: Optional[DelayedScaling] = None, fp8_group: Optional[dist_group_type] = None, + _graph: bool = False, ) -> None: """Set state and tracking variables for entry into FP8 region.""" - if cls.FP8_AUTOCAST_DEPTH == 0: - if callable(cls.amax_forward_global_reduce_func): - cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func() # pylint: disable=not-callable - cls.delete_key_from_amax_buffer(forward=True) + + fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe + autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) + cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) + + if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled(): + if not cls.bwd_amax_update_hook_registered and len(cls.multi_grad_hook_tensors) > 0: + # This hook does not fire for graphed modules. + torch.autograd.graph.register_multi_grad_hook( + tuple(cls.multi_grad_hook_tensors), cls.hook_for_bwd_amax_reduction) + cls.bwd_amax_update_hook_registered = True cls.FP8_ENABLED = enabled cls.FP8_CALIBRATION = calibrating - cls.FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe + cls.FP8_RECIPE = fp8_recipe cls.FP8_DISTRIBUTED_GROUP = fp8_group + cls.FP8_GRAPH_CAPTURING = _graph if cls.FP8_AUTOCAST_DEPTH == 0: cls.IS_FIRST_FP8_MODULE = True - cls.FP8_AUTOCAST_COUNTER += 1 cls.FP8_AUTOCAST_DEPTH += 1 if enabled: @@ -426,9 +429,14 @@ def fp8_autocast_enter( assert fp8_available, reason_for_no_fp8 @classmethod - def fp8_autocast_exit(cls): + def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: """Set state and tracking variables for exit from FP8 region.""" cls.FP8_AUTOCAST_DEPTH -= 1 + # Reduce only the non-FP8 weight modules here. + # FP8 weight modules are reduced at the end of the optimizer + # step after the weight amax is populated. + if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled(): + cls.reduce_and_update_fp8_tensors(forward=True, fp8_weights=False) @classmethod def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: @@ -525,6 +533,7 @@ def fp8_autocast( calibrating: bool = False, fp8_recipe: Optional[DelayedScaling] = None, fp8_group: Optional[dist_group_type] = None, + _graph: bool = False, ) -> None: """ Context manager for FP8 usage. 
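
Since the pieces of the new reduction path (`reduce_and_update_fp8_tensors` above and the `split_and_copy` helper added at the end of `fp8.py`) are spread across this hunk, here is a minimal single-process sketch of the concat, max-reduce, split-back pattern they implement. The buffer contents are made up, and the all-reduce is left as a comment because a real fp8 process group is assumed in the patch:

import torch

# One amax_history[0] row per participating module (toy values).
amax_buffer = [torch.rand(4), torch.rand(2), torch.rand(4)]

# Concatenate so a single collective covers every module in the autocast.
contiguous_amax = torch.cat(amax_buffer)

# In reduce_and_update_fp8_tensors this is a MAX all-reduce over the fp8 group:
# torch.distributed.all_reduce(contiguous_amax, op=torch.distributed.ReduceOp.MAX, group=group)

# split_and_copy: scatter the reduced values back into the original tensors in place.
chunk_sizes = [x.numel() for x in amax_buffer]
torch._foreach_copy_(amax_buffer, list(contiguous_amax.split(chunk_sizes)))

After the copy, either the fused kernel or the unfused `_amax_and_scale_update` path turns the reduced amaxes into fresh scales and scale-inverses per module.
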
@@ -568,23 +577,25 @@ def fp8_autocast( FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled, calibrating=calibrating, fp8_recipe=fp8_recipe, - fp8_group=fp8_group) + fp8_group=fp8_group, + _graph=_graph) yield finally: FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) # pylint: disable=used-before-assignment - FP8GlobalStateManager.fp8_autocast_exit() + FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph) def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: """Update amax history and set next amax to zero.""" if amax_history.shape[0] > 1: - amax_history = torch.roll(amax_history, -1, 0) + new_amax_history = torch.roll(amax_history, -1, 0) + amax_history.copy_(new_amax_history) amax_history[0].fill_(0.0) return amax_history @torch.jit.script -def _default_get_amax( +def _default_get_amax_and_update_history( amax_history: torch.Tensor, amax_compute_algo: str, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -609,63 +620,23 @@ def _default_sf_compute( sf = (fp8_max / amax) / (2 ** margin) sf = torch.where(amax > 0.0, sf, scale) sf = torch.where(torch.isfinite(amax), sf, scale) - return sf - - -@jit_fuser -def _compute_scaling_factor_inverse( - scale: torch.Tensor, - scale_inv: torch.Tensor, - non_weight_mask: torch.Tensor, - update_weight_scale_inv: bool, -) -> torch.Tensor: - """Compute inverse of scaling factor.""" - if update_weight_scale_inv: - return 1.0 / scale - return torch.where(non_weight_mask, 1.0 / scale, scale_inv) - - -def _fused_amax_and_scale_update( - amax_history: torch.Tensor, - scale: torch.Tensor, - scale_inv: torch.Tensor, - fp8_dtype: tex.DType, - margin: int, - amax_compute_algo: str, - non_weight_mask: torch.Tensor, - update_weight_scale_inv: bool, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Update amax history and FP8 scaling factors""" - if update_weight_scale_inv: - non_weight_mask = torch.Tensor() - tex.fused_amax_and_scale_update( - amax_history, - scale, - scale_inv, - non_weight_mask, - amax_history, - scale, - scale_inv, - amax_compute_algo, - fp8_dtype, - margin, - ) - return amax_history, scale, scale_inv + scale.copy_(sf) + return scale -def _compute_amax( +def _compute_amax_and_update_history( amax_history: torch.Tensor, - recipe: DelayedScaling, + amax_compute_algo: Union[Callable, str], ) -> Tuple[torch.Tensor, torch.Tensor]: """Obtain the amax from the history.""" - if callable(recipe.amax_compute_algo): - amax = recipe.amax_compute_algo(amax_history) + if callable(amax_compute_algo): + amax = amax_compute_algo(amax_history) amax_history = _update_amax_history(amax_history) return amax_history, amax - return _default_get_amax( + return _default_get_amax_and_update_history( amax_history, - recipe.amax_compute_algo, + amax_compute_algo, ) @@ -687,46 +658,29 @@ def _compute_scaling_factor( return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe) -def amax_and_scale_update( - fp8_meta: Dict[str, Any], - fwd_update: bool, - update_weight_scale_inv: bool = True, +def _amax_and_scale_update( + amax_history: torch.Tensor, + scale: torch.Tensor, + scale_inv: torch.Tensor, + fp8_max: float, + recipe: DelayedScaling, ) -> None: - """Updates fp8 amaxes/scales for fwd | bwd.""" - amax_compute = fp8_meta["recipe"].amax_compute_algo - sf_compute = fp8_meta["recipe"].scaling_factor_compute_algo - fp8_meta_tensor_key = "scaling_fwd" if fwd_update else "scaling_bwd" - fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd" - - if not callable(amax_compute) and sf_compute is None: - ( - 
fp8_meta[fp8_meta_tensor_key].amax_history, - fp8_meta[fp8_meta_tensor_key].scale, - fp8_meta[fp8_meta_tensor_key].scale_inv, - ) = _fused_amax_and_scale_update( - fp8_meta[fp8_meta_tensor_key].amax_history, - fp8_meta[fp8_meta_tensor_key].scale, - fp8_meta[fp8_meta_tensor_key].scale_inv, - get_fp8_te_dtype(fp8_meta["recipe"], fwd_update), - fp8_meta["recipe"].margin, - fp8_meta["recipe"].amax_compute_algo, - fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"], - update_weight_scale_inv, - ) - else: - fp8_meta[fp8_meta_tensor_key].amax_history, amax = _compute_amax( - fp8_meta[fp8_meta_tensor_key].amax_history, - fp8_meta["recipe"], - ) - fp8_meta[fp8_meta_tensor_key].scale = _compute_scaling_factor( - amax, - fp8_meta[fp8_meta_tensor_key].scale, - fp8_meta[fp8_max_key], - fp8_meta["recipe"], - ) - fp8_meta[fp8_meta_tensor_key].scale_inv = _compute_scaling_factor_inverse( - fp8_meta[fp8_meta_tensor_key].scale, - fp8_meta[fp8_meta_tensor_key].scale_inv, - fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"], - update_weight_scale_inv, - ) + """Updates FP8 meta tensors.""" + new_amax_history, amax = _compute_amax_and_update_history( + amax_history, + recipe.amax_compute_algo, + ) + new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe) + scale.copy_(new_scale) + scale_inv.copy_(1.0 / new_scale) + amax_history.copy_(new_amax_history) + + +def split_and_copy( + buffer: torch.Tensor, + outputs: List[torch.Tensor], + chunk_sizes: List[int], +) -> None: + """Split `buffer` by `chunk_sizes` and copy into `outputs`.""" + splits = buffer.split(chunk_sizes) + torch._foreach_copy_(outputs, splits) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py new file mode 100644 index 0000000000..5de3b7a342 --- /dev/null +++ b/transformer_engine/pytorch/graph.py @@ -0,0 +1,548 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Functions for CUDA Graphs support in FP8""" +import torch +from torch.utils._pytree import tree_flatten as _tree_flatten +from torch.utils._pytree import tree_unflatten as _tree_unflatten +from torch._C import _graph_pool_handle + +from .fp8 import ( + fp8_autocast, + FP8GlobalStateManager, + get_default_fp8_recipe, +) +from .distributed import get_all_rng_states, graph_safe_rng_available +from .module.base import TransformerEngineBaseModule + + +__all__ = ["make_graphed_callables"] + + +_IS_GRAPH_CAPTURING = False + + +def set_capture_start() -> None: + """Record beginning of `make_graphed_callables`.""" + global _IS_GRAPH_CAPTURING + _IS_GRAPH_CAPTURING = True + + +def set_capture_end() -> None: + """Record end of `make_graphed_callables`.""" + global _IS_GRAPH_CAPTURING + _IS_GRAPH_CAPTURING = False + + +def is_graph_capturing() -> None: + """Return whether within `make_graphed_callables`.""" + return _IS_GRAPH_CAPTURING + + +def graph_pool_handle(): + """ + Returns an opaque token representing the id of a graph memory pool. + """ + return _graph_pool_handle() + + +def _make_graphed_callables( + callables, + sample_args, + num_warmup_iters=3, + allow_unused_input=False, + fp8_weight_caching=False, + _order=None, +): + """ + Helper method for `make_graphed_callables` + """ + + if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled(): + raise RuntimeError( + "make_graphed_callables does not support the autocast " + "caching. Please set `cache_enabled=False`." 
+ ) + + just_one_callable = False + + if not isinstance(callables, tuple): + just_one_callable = True + callables = (callables,) + sample_args = (sample_args,) + + flatten_sample_args = [] + if _order is not None: + # order is a list containing 1..model_chunk values in the order of microbatch schedule + num_model_chunks = max(_order) + num_microbatches = len(_order) // num_model_chunks // 2 + assert num_model_chunks * num_microbatches * 2 == len(_order) + assert ( + len(sample_args)*2 >= len(_order) + and (len(sample_args)*2 % len(_order) == 0) + ), f'{len(sample_args)} >= {len(_order)} and {len(sample_args)} % {len(_order)} == 0' + num_layers = len(sample_args) // num_model_chunks // num_microbatches + assert ( + len(callables) == num_model_chunks*num_layers + ), (f"Callables should have ({num_model_chunks * num_layers}) " + + f"entries when order input is provided but got {len(callables)}." + ) + assert ( + len(sample_args) == num_model_chunks * num_microbatches * num_layers + ), (f"Expected {num_model_chunks * num_microbatches}" + + f"args tuple, but got {len(sample_args)}." + ) + + if fp8_weight_caching: + FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False) + + for c in callables: + if isinstance(c, torch.nn.Module): + assert ( + len(c._backward_hooks) == 0 + and len(c._forward_hooks) == 0 + and len(c._forward_pre_hooks) == 0 + ), ( + "Modules must not have hooks registered at the time they are passed. " + + "However, registering hooks on modules after passing them " + + "through make_graphed_callables is allowed." + ) + assert all(b.requires_grad is False for b in c.buffers()), ( + "In any :class:`~torch.nn.Module` passed to " + + ":func:`~make_graphed_callables`, only parameters may be trainable. " + + "All buffers must have ``requires_grad=False``." + ) + for args in sample_args: + flatten_arg, _ = _tree_flatten(args) + flatten_sample_args.append(tuple(flatten_arg)) + assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), ( + "In the beta API, sample_args " + + "for each callable must contain only Tensors. Other types are not allowed." + ) + + # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly + # passes to forward (ie, its sample_args) AND the module's parameter attributes. + per_callable_len_user_args = [len(args) for args in flatten_sample_args] + if _order is None: + per_callable_module_params = [ + tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () + for c in callables + ] + per_callable_static_input_surfaces = [ + flatten_sample_args[i] + per_callable_module_params[i] + for i in range(len(callables)) + ] + else: + per_callable_module_params = [] + for c in callables: + for i in range(num_microbatches): + per_callable_module_params.append( + tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () + ) + assert len(per_callable_module_params) == len(flatten_sample_args) + per_callable_static_input_surfaces = [ + flatten_sample_args[i] + per_callable_module_params[i] + for i in range(len(flatten_sample_args)) + ] + + fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))] + bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))] + graph_callables = [None for _ in range(len(flatten_sample_args))] + # For cases with multiple active RNG states, e.g. TP. 
+ if graph_safe_rng_available(): + for _, state in get_all_rng_states().items(): + for fwd_graph, bwd_graph in zip(fwd_graphs, bwd_graphs): + fwd_graph.register_generator_state(state) + bwd_graph.register_generator_state(state) + + mempool = graph_pool_handle() + + # Warmup + # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work + # from ending up in any captures. + torch.cuda.synchronize() + with torch.cuda.stream(torch.cuda.Stream()): + for c_i, func in enumerate(callables): + args = sample_args[c_i] + static_input_surface = per_callable_static_input_surfaces[c_i] + for _ in range(num_warmup_iters): + outputs, _ = _tree_flatten(func(*args)) + grad_inputs = torch.autograd.grad( + outputs=tuple(o for o in outputs if o.requires_grad), + inputs=tuple(i for i in static_input_surface if i.requires_grad), + grad_outputs=tuple( + torch.empty_like(o) for o in outputs if o.requires_grad + ), + only_inputs=True, + allow_unused=allow_unused_input, + ) + del outputs, grad_inputs + torch.cuda.synchronize() + + # All captures here share a mempool. To avoid replays corrupting each other's memory, + # the safest approach is to capture all passes in the same order they'll run: + # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1. + + if _order is not None: # pylint: disable=too-many-nested-blocks + per_callable_static_outputs = [None] * len(flatten_sample_args) + per_callable_output_unflatten_spec = [None] * len(flatten_sample_args) + per_callable_static_grad_outputs = [None] * len(flatten_sample_args) + per_callable_static_grad_inputs = [None] * len(flatten_sample_args) + fwd_idx = [0] * num_model_chunks + bwd_idx = [0] * num_model_chunks + for c_id in _order: + if c_id > 0: + # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1] + m_chunk = c_id-1 + for l_no in range(num_layers): + func = callables[m_chunk*num_layers + l_no] + per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) \ + + (fwd_idx[m_chunk] * num_layers + l_no) + args = sample_args[per_callable_fwd_idx] + fwd_graph = fwd_graphs[per_callable_fwd_idx] + with torch.cuda.graph(fwd_graph, pool=mempool): + outputs = func(*args) + flatten_outputs, spec = _tree_flatten(outputs) + per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs) + per_callable_output_unflatten_spec[per_callable_fwd_idx] = spec + graph_callables[per_callable_fwd_idx] = func + fwd_idx[m_chunk] += 1 + else: + # Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1] + m_chunk = -c_id-1 + for l_no in list(reversed(range(num_layers))): + per_callable_bwd_idx = (m_chunk * num_microbatches * num_layers) \ + + (bwd_idx[m_chunk] * num_layers + l_no) + static_input_surface = per_callable_static_input_surfaces[per_callable_bwd_idx] + static_outputs = per_callable_static_outputs[per_callable_bwd_idx] + bwd_graph = bwd_graphs[per_callable_bwd_idx] + # For now, assumes all static_outputs require grad + static_grad_outputs = tuple( + torch.empty_like(o) if o.requires_grad else None for o in static_outputs + ) + with torch.cuda.graph(bwd_graph, pool=mempool): + grad_inputs = torch.autograd.grad( + outputs=tuple(o for o in static_outputs if o.requires_grad), + inputs=tuple(i for i in static_input_surface if i.requires_grad), + grad_outputs=tuple(o for o in static_grad_outputs if o is not None), + only_inputs=True, + allow_unused=allow_unused_input, + ) + # Constructs a tuple suitable for returning from Graphed.backward: + # Pads out the actually-needed grads with Nones in gradient slots for 
inputs + # that don't require grad. I couldn't think of a one-liner for this pattern. + static_grad_inputs = [] + grad_idx = 0 + for arg in static_input_surface: + if arg.requires_grad: + static_grad_inputs.append(grad_inputs[grad_idx]) + grad_idx += 1 + else: + static_grad_inputs.append(None) # type: ignore[arg-type] + static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment] + + per_callable_static_grad_outputs[per_callable_bwd_idx] = static_grad_outputs + per_callable_static_grad_inputs[per_callable_bwd_idx] = static_grad_inputs + bwd_idx[m_chunk] += 1 + else: + # Capture forward graphs + per_callable_static_outputs = [] + per_callable_output_unflatten_spec = [] + graph_id = 0 + for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs): + with torch.cuda.graph(fwd_graph, pool=mempool): + outputs = func(*args) + graph_callables[graph_id] = func + graph_id += 1 + + flatten_outputs, spec = _tree_flatten(outputs) + per_callable_static_outputs.append(tuple(flatten_outputs)) + per_callable_output_unflatten_spec.append(spec) + + # Capture backward graphs in reverse order + per_callable_static_grad_outputs = [] + per_callable_static_grad_inputs = [] + for static_input_surface, static_outputs, bwd_graph in zip( + reversed(per_callable_static_input_surfaces), + reversed(per_callable_static_outputs), + reversed(bwd_graphs), + ): + # For now, assumes all static_outputs require grad + static_grad_outputs = tuple( + torch.empty_like(o) if o.requires_grad else None for o in static_outputs + ) + with torch.cuda.graph(bwd_graph, pool=mempool): + grad_inputs = torch.autograd.grad( + outputs=tuple(o for o in static_outputs if o.requires_grad), + inputs=tuple(i for i in static_input_surface if i.requires_grad), + grad_outputs=tuple(o for o in static_grad_outputs if o is not None), + only_inputs=True, + allow_unused=allow_unused_input, + ) + # Constructs a tuple suitable for returning from Graphed.backward: + # Pads out the actually-needed grads with Nones in gradient slots for inputs that + # don't require grad. I couldn't think of a slick one-liner for this pattern. + static_grad_inputs = [] + grad_idx = 0 + for arg in static_input_surface: + if arg.requires_grad: + static_grad_inputs.append(grad_inputs[grad_idx]) + grad_idx += 1 + else: + static_grad_inputs.append(None) # type: ignore[arg-type] + static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment] + + per_callable_static_grad_outputs.append(static_grad_outputs) + per_callable_static_grad_inputs.append(static_grad_inputs) + + # Reverses the most recent two lists + per_callable_static_grad_outputs = list(reversed(per_callable_static_grad_outputs)) + per_callable_static_grad_inputs = list(reversed(per_callable_static_grad_inputs)) + # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable. + + def make_graphed_autograd_function( + fwd_graph, + bwd_graph, + module_params, + len_user_args, + output_unflatten_spec, + static_input_surface, + static_outputs, + static_grad_outputs, + static_grad_inputs, + ): + class Graphed(torch.autograd.Function): + """Autograd function for graph replay.""" + @staticmethod + def forward(ctx, skip_fp8_weight_update, *inputs): + # At this stage, only the user args may (potentially) be new tensors. 
+ ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() + if ctx.is_first_module and skip_fp8_weight_update is not None: + FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update) + + for i in range(len_user_args): + if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): + static_input_surface[i].copy_(inputs[i]) + fwd_graph.replay() + assert isinstance(static_outputs, tuple) + return tuple(o.detach() for o in static_outputs) + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, *grads): + assert len(grads) == len(static_grad_outputs) + for g, grad in zip(static_grad_outputs, grads): + if g is not None: + # don't copy if autograd gods have been kind and the + # incoming grad is already in the right place + if g.data_ptr() != grad.data_ptr(): + g.copy_(grad) + bwd_graph.replay() + + if ctx.is_first_module: + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + + # Input args that didn't require grad expect a None gradient. + assert isinstance(static_grad_inputs, tuple) + return (None,) + tuple( + b.detach() if b is not None else b for b in static_grad_inputs + ) + + def functionalized(*user_args, **user_kwargs): + # Runs the autograd function with inputs == all + # inputs to the graph that might require grad + # (explicit user args + module parameters) + # Assumes module params didn't change since capture. + skip_fp8_weight_update = None + if fp8_weight_caching: + assert ( + ("is_first_microbatch" in user_kwargs + and isinstance(user_kwargs["is_first_microbatch"], bool)) + ), "`is_first_microbatch` boolean kwarg must be provided for FP8 weight caching." + + skip_fp8_weight_update = not user_kwargs["is_first_microbatch"] + + flatten_user_args, _ = _tree_flatten(user_args) + out = Graphed.apply(skip_fp8_weight_update, *(tuple(flatten_user_args) + module_params)) + return _tree_unflatten(out, output_unflatten_spec) + + return functionalized + + # Put together the final graphed callables + ret = [] + for i in range(len(sample_args)): + graphed = make_graphed_autograd_function( + fwd_graphs[i], + bwd_graphs[i], + per_callable_module_params[i], + per_callable_len_user_args[i], + per_callable_output_unflatten_spec[i], + per_callable_static_input_surfaces[i], + per_callable_static_outputs[i], + per_callable_static_grad_outputs[i], + per_callable_static_grad_inputs[i], + ) + + func = graph_callables[i] + if isinstance(func, torch.nn.Module): + + def make_graphed_forward(func, graph_training_state, graphed, orig_fwd): + def new_fwd(*user_args, **user_kwargs): + # If the module's training-or-eval state matches what we graphed, + # run the graph, otherwise run the original forward method + if func.training == graph_training_state: + # Set the FP8 group from global amax reduction. 
+ for m in func.modules(): + if (isinstance(m, TransformerEngineBaseModule) + and FP8GlobalStateManager.is_fp8_enabled()): + m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() + m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( + m.fp8_meta, fp8_weights=m._get_fp8_params()) + return graphed(*user_args, **user_kwargs) + return orig_fwd(*user_args, **user_kwargs) + return new_fwd + + forward = make_graphed_forward(func, func.training, graphed, func.forward) + if _order is None: + func.forward = forward + ret.append(func) + else: + ret.append(forward) + else: + ret.append(graphed) + + if just_one_callable: + return ret[0] + + return tuple(ret) + + +def save_fp8_tensors(modules, amax_history_len): + """ + Returns the FP8 tensors for all modules + with adjusted amax history sizes. + """ + saved_fp8_meta_tensors = [] + for module in modules: + for m in module.modules(): + if isinstance(m, TransformerEngineBaseModule): + if m.primary_weights_in_fp8: + m.adjust_amax_history_length(amax_history_len) + saved_fp8_meta_tensors.append(m.get_fp8_meta_tensors()) + return saved_fp8_meta_tensors + + +def restore_fp8_tensors(modules, fp8_tensors): + """Restore FP8 tensors.""" + for module in modules: + for m in module.modules(): + if isinstance(m, TransformerEngineBaseModule): + m.reset_fp8_meta_tensors(fp8_tensors.pop(0)) + assert len(fp8_tensors) == 0, "TE internal error." + + +def make_graphed_callables( + modules, + sample_args, + num_warmup_iters=3, + allow_unused_input=False, + fp8_enabled=False, + fp8_calibrating=False, + fp8_recipe=None, + fp8_weight_caching=False, + _order=None, +): + """ + A version of PyTorch's `make_graphed_callables` utility function with support for + TransformerEngine modules and FP8. Please see the original version in upstream PyTorch + `here `_ + for extensive documentation. The documentation for additional parameters which are + specific to FP8 is given below. + + FP8 specific parameters + ----------------------- + fp8_enabled: bool, default = `False` + whether or not to enable fp8 + fp8_calibrating: bool, default = `False` + calibration mode allows collecting statistics such as amax and scale + data of fp8 tensors even when executing without fp8 enabled. This is + useful for saving an inference ready fp8 checkpoint while training + using a higher precision. + fp8_recipe: recipe.DelayedScaling, default = `None` + recipe used for FP8 training. + fp8_weight_caching: bool, default = `False` + Whether or not to cache FP8 weights across microbatches. If set to `True`, + the `is_first_microbatch` boolean argument must be passed into the forward + method for TransformerEngine modules. When storing primary weights in FP8 + using TE's `fp8_model_init` API and using an FP8 aware optimizer, this arg + must be set to `False` when calculating weight transposes outside TE, e.g., + in the optimizer step. + """ + set_capture_start() + + fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe + + # Handle single module. + just_one_callable = False + if not isinstance(modules, tuple): + just_one_callable = True + modules = (modules,) + + # Store FP8 tensors to reset later. + saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe.amax_history_len) + + # FP8 wrapper.
+ def wrap_autocast(block): + old_forward = block.forward + def forward_func(*args, **kwargs): + with fp8_autocast(enabled=fp8_enabled, + calibrating=fp8_calibrating, + fp8_recipe=fp8_recipe, + _graph=True): + outputs = old_forward(*args, **kwargs) + return outputs + block.forward = forward_func + + forward_funcs = [] + for module in modules: + assert isinstance(module, torch.nn.Module), f"Graphing for {type(module)} is not supported." + wrap_autocast(module) + forward_funcs.append(module) + + if just_one_callable: + forward_funcs = forward_funcs[0] + else: + forward_funcs = tuple(forward_funcs) + + # Save RNG state. + if graph_safe_rng_available(): + generators = [torch.cuda.default_generators[torch.cuda.current_device()], + *get_all_rng_states().values()] + original_rng_states = [state.get_state() for state in generators] + else: + original_rng_states = torch.cuda.get_rng_state() + + graphed_callables = _make_graphed_callables( + forward_funcs, sample_args, num_warmup_iters=num_warmup_iters, + allow_unused_input=allow_unused_input, + fp8_weight_caching=fp8_weight_caching, _order=_order) + + # Ensures warmup does not affect numerics for ops such as dropout. + if graph_safe_rng_available(): + for gen, state in zip(generators, original_rng_states): + gen.set_state(state) + else: + torch.cuda.set_rng_state(original_rng_states) + + # Reset FP8 gradients. + for module in modules: + for p in module.parameters(): + p.grad = None + + # Restore FP8 state. + restore_fp8_tensors(modules, saved_fp8_tensors) + + set_capture_end() + return graphed_callables diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 56dd3c8fc4..7e0cf5c106 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -8,8 +8,7 @@ import pickle import warnings from abc import ABC, abstractmethod -from typing import Generator, Union, Optional, Tuple, Dict, Any, List -from functools import partial +from typing import Generator, Union, Optional, Tuple, List from contextlib import contextmanager import torch @@ -22,13 +21,11 @@ get_default_fp8_recipe, get_fp8_te_dtype, FP8GlobalStateManager, - amax_and_scale_update, ) from ..distributed import ( gather_along_first_dim, is_fp8_activation_recompute_enabled, in_fp8_activation_recompute_phase, - get_distributed_world_size, ) from ..cpp_extensions import ( fp8_cast_transpose_fused, @@ -44,7 +41,6 @@ _cublas_workspace = None _ub_communicators = None _NUM_MAX_UB_STREAMS = 3 -_amax_reduce_handle_bwd = None layers_atomic_ring_exchange = [] @@ -64,49 +60,6 @@ def get_workspace() -> torch.Tensor: ) return _cublas_workspace -@contextmanager -def _prepare_backward( - fp8: bool, - fp8_meta: Dict[str, Any], - tp_group: dist_group_type, - tp_size: int, - name: str = "" -) -> Generator[None, None, None]: - """Checks and prep for BWD.""" - if fp8: - global _amax_reduce_handle_bwd - if _amax_reduce_handle_bwd is not None: - _amax_reduce_handle_bwd.wait() - _amax_reduce_handle_bwd = None - - # Update amax and scale; Skip all setup for global amax reduction - if fp8_meta["recipe"].reduce_amax and get_distributed_world_size(fp8_meta["fp8_group"]) > 1: - # From previous iteration - FP8GlobalStateManager.copy_amax_from_global_buffer(fp8_meta, forward=False) - amax_and_scale_update(fp8_meta, False) - FP8GlobalStateManager.set_amax_buffer_key_deletion(fp8_meta, forward=False) - - # Get new backward key. 
- fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0) - - FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False) - else: - amax_and_scale_update(fp8_meta, False) - - with torch.cuda.nvtx.range(name + " backward"): - yield - - if (fp8 and fp8_meta["recipe"].reduce_amax - and get_distributed_world_size(fp8_meta["fp8_group"]) > 1): - if fp8_meta["first_module"]: - _amax_reduce_handle_bwd = FP8GlobalStateManager.global_amax_reduction( - fp8_meta, - tp_group, - tp_size, - forward=False - ) - FP8GlobalStateManager.delete_key_from_amax_buffer(forward=False) - def initialize_ub( shape: list, @@ -300,31 +253,54 @@ def __init__(self) -> None: self.tp_size = 1 self.sequence_parallel = False self.fp8_weight_shapes = [] - self.fp8_meta["autocast_id_fwd_stack"] = [] - self.fp8_meta["async_amax_reduction"] = bool( - int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0")) - ) self.param_init_meta = {} self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() + def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: + """Increase or decrease size of amax history based on given `length`. + + .. warning:: + This changes the underlying amax memory location. + """ + if fwd is None: + fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd") + else: + fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",) + + for meta_key in fp8_meta_tensor_keys: + curr_len = self.fp8_meta[meta_key].amax_history.shape[0] + if length == curr_len: + continue + if length < curr_len: + self.fp8_meta[meta_key].amax_history = ( + self.fp8_meta[meta_key].amax_history[: length].clone()) + elif length > curr_len: + extra_rows = length - curr_len + self.fp8_meta[meta_key].amax_history = F.pad( + self.fp8_meta[meta_key].amax_history, pad=(0, 0, 0, extra_rows) + ) + + # Update the global buffers with new amax and history pointers. + if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta: + fwd_pos, fwd_key, bwd_pos, bwd_key = ( + self.fp8_meta[FP8GlobalStateManager.get_buffer_info()]) + for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)): + if buffer_key in FP8GlobalStateManager.global_amax_buffer: + assert ( + buffer_key in FP8GlobalStateManager.global_amax_history_buffer + ), "TE internal error during amax history change." + FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = ( + self.fp8_meta[meta_key].amax_history[0]) + FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = ( + self.fp8_meta[meta_key].amax_history) + def set_meta_tensor(self, fwd: bool) -> None: """Init scales and amaxes for fwd | bwd.""" fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" if self.fp8_meta_tensors_initialized: # Handle changed amax history size. - curr_len = self.fp8_meta[fp8_meta_tensor_key].amax_history.shape[0] - need_len = self.fp8_meta["recipe"].amax_history_len - if need_len < curr_len: - self.fp8_meta[fp8_meta_tensor_key].amax_history = ( - self.fp8_meta[fp8_meta_tensor_key] - .amax_history[: self.fp8_meta["recipe"].amax_history_len].clone() - ) - elif need_len > curr_len: - extra_rows = need_len - curr_len - self.fp8_meta[fp8_meta_tensor_key].amax_history = F.pad( - self.fp8_meta[fp8_meta_tensor_key].amax_history, pad=(0, 0, 0, extra_rows) - ) + self.adjust_amax_history_length(self.fp8_meta["recipe"].amax_history_len, fwd=fwd) return # Max. 
number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and @@ -347,25 +323,45 @@ def set_meta_tensor(self, fwd: bool) -> None: device="cuda", ) - # Needed for calculation of scale inverses to - # preserve scale_inv when caching FP8 weights - if fwd: - # [True, False, True]: -> [input, weight, output] - self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor( - [True, False, True] * self.fp8_meta["num_gemms"] - ).cuda() - else: - # [True, True]: -> [grad_output, grad_input] - self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor( - [True, True] * self.fp8_meta["num_gemms"] - ).cuda() - def init_fp8_meta_tensors(self) -> None: """Init scales and amaxes.""" self.set_meta_tensor(True) self.set_meta_tensor(False) self.fp8_meta_tensors_initialized = True + def get_fp8_meta_tensors(self) -> None: + """Get scales and amaxes.""" + fwd_key, bwd_key = "scaling_fwd", "scaling_bwd" + if fwd_key not in self.fp8_meta or bwd_key not in self.fp8_meta: + return None + + fp8_meta_tensors = {fwd_key: [], bwd_key: []} + with torch.no_grad(): + for key in (fwd_key, bwd_key): + fp8_meta_tensors[key].append(self.fp8_meta[key].scale.clone()) + fp8_meta_tensors[key].append(self.fp8_meta[key].scale_inv.clone()) + fp8_meta_tensors[key].append(self.fp8_meta[key].amax_history.clone()) + return fp8_meta_tensors + + def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None: + """Reset scales and amaxes.""" + def reset(key): + if key in self.fp8_meta: + if fp8_meta_tensors is None: + self.fp8_meta[key].scale.copy_(torch.ones_like(self.fp8_meta[key].scale)) + self.fp8_meta[key].scale_inv.copy_( + torch.ones_like(self.fp8_meta[key].scale_inv)) + self.fp8_meta[key].amax_history.copy_( + torch.zeros_like(self.fp8_meta[key].amax_history)) + else: + assert key in fp8_meta_tensors, "Cannot reset fp8 tensors." + self.fp8_meta[key].scale.copy_(fp8_meta_tensors[key][0]) + self.fp8_meta[key].scale_inv.copy_(fp8_meta_tensors[key][1]) + self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][2]) + with torch.no_grad(): + reset("scaling_fwd") + reset("scaling_bwd") + def get_extra_state(self) -> torch.Tensor: """Save before checkpointing.""" state = None @@ -380,13 +376,11 @@ def get_extra_state(self) -> torch.Tensor: state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history - state["global_fp8_buffer"] = FP8GlobalStateManager.get_global_fp8_buffer_checkpoint() - state["global_fp8_state"] = FP8GlobalStateManager.get_global_fp8_state_checkpoint() # Store other pickelable values. extra = {} for k, v in self.fp8_meta.items(): - if isinstance(v, (bool, int, float, str, list)): + if isinstance(v, (bool, int, float, str, tuple, list)): extra[k] = v state["extra_fp8_variables"] = extra @@ -414,11 +408,6 @@ def set_extra_state(self, state: torch.Tensor) -> None: if state is None: return - # Restore global FP8 amax buffer. - FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state["global_fp8_buffer"]) - # Restore global FP8 state. - FP8GlobalStateManager.set_global_fp8_state_checkpoint(state["global_fp8_state"]) - # Load extra items. 
self.fp8_meta.update(state["extra_fp8_variables"]) self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0] @@ -527,6 +516,16 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N self.tp_group = tp_group self.tp_group_initialized = True + def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: + """returns the FP8 weights.""" + fp8_params = [] + for param in self.parameters(): + if isinstance(param, Float8Tensor) and param.requires_grad: + fp8_params.append(param) + if len(fp8_params) == 0: + return None + return fp8_params + # This routine is shared across FP8 and FP8_calibration paths so should not actually # assume FP8 execution. def init_fp8_metadata(self, num_gemms: int = 1) -> None: @@ -576,7 +575,6 @@ def prepare_forward( to setup the forward aggregated amax reduction for every module just in case. The autocast exit will pick up the most recent one. """ - # Activation recomputation is used and this is the second forward phase. if self.fp8 and in_fp8_activation_recompute_phase(): FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) @@ -594,49 +592,14 @@ def prepare_forward( if is_first_microbatch is not None and not self.primary_weights_in_fp8: self.set_fp8_weights() - update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch if self.fp8 and self.sequence_parallel: assert self.fp8_meta["recipe"].reduce_amax, \ "Amax reduction across tensor parallel group is " \ "necessary when using sequence parallelism with FP8." - # Previous iteration was grad_enabled - if self.fp8_meta.get("update_amax_and_scale_fwd", False): - if (self.fp8_meta["recipe"].reduce_amax - and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1): - FP8GlobalStateManager.copy_amax_from_global_buffer(self.fp8_meta, forward=True) - amax_and_scale_update( - self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv - ) - FP8GlobalStateManager.set_amax_buffer_key_deletion(self.fp8_meta, forward=True) - else: - amax_and_scale_update( - self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv - ) - - if self.fp8 and self.training: - # Setup for amax reduction - if (self.fp8_meta["recipe"].reduce_amax - and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1): - self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module() - if self.fp8_meta["first_module"]: - # Wait for the prior AMAX reduction to finish - amax_reduce_handle_fwd = FP8GlobalStateManager.get_amax_reduce_handle_fwd() - if amax_reduce_handle_fwd is not None: - amax_reduce_handle_fwd.wait() - self.fp8_meta["autocast_id_fwd"] = ( - FP8GlobalStateManager.new_fp8_context_id()) - FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"]) - else: - self.fp8_meta["autocast_id_fwd"] = ( - FP8GlobalStateManager.get_fp8_context_id()) - self.fp8_meta["autocast_id_fwd_stack"].append( - self.fp8_meta["autocast_id_fwd"] - ) - FP8GlobalStateManager.add_amax_to_global_buffer(self.fp8_meta, forward=True) - self.fp8_meta["update_amax_and_scale_fwd"] = True - else: - self.fp8_meta["update_amax_and_scale_fwd"] = False + if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing(): + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( + self.fp8_meta, fp8_weights=self._get_fp8_params()) # Activation recomputation is used and this is the first forward phase. 
if ( @@ -653,18 +616,6 @@ def prepare_forward( FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) return - if (self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax - and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1): - FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"]) - reduce_func = partial( - FP8GlobalStateManager.global_amax_reduction, - self.fp8_meta, - self.tp_group, - self.tp_size, - forward=True - ) - FP8GlobalStateManager.setup_amax_forward_global_reduce_func(reduce_func) - def set_nccl_overlap_warning_if_tp(self) -> None: """When using TP, the NCCL communication needs to be scheduled before the GEMM for there to be a guaranteed overlap. From the diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 985d587e54..8fdd5d1356 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -14,7 +14,6 @@ from .base import ( get_workspace, - _prepare_backward, get_ub, TransformerEngineBaseModule, _2X_ACC_FPROP, @@ -65,6 +64,7 @@ def forward( use_bias: bool, eps: float, is_first_microbatch: Union[bool, None], + skip_fp8_weight_update: Union[torch.Tensor, None], fp8: bool, fp8_calibration: bool, fp8_meta: Dict[str, Any], @@ -89,6 +89,7 @@ def forward( ub_overlap_rs_dgrad: bool, ub_overlap_ag: bool, ub_name: str, + dummy_tensor: torch.Tensor, # pylint: disable=unused-argument ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() @@ -98,7 +99,11 @@ def forward( assert_dim_for_fp8_exec(inputmat) assert_dim_for_fp8_exec(weight) - update_fp8_weights = is_first_microbatch is None or is_first_microbatch + update_fp8_weights = ( + is_first_microbatch is None + or is_first_microbatch + or skip_fp8_weight_update is not None + ) # Cast for native AMP inputmat = cast_if_needed(inputmat, activation_dtype) @@ -196,7 +201,6 @@ def forward( # Weight is already in FP8 weight.reset_fp8_meta_scale_inv() weight_fp8 = weight - weight_t_fp8 = None elif update_fp8_weights: # Need to cast weights to FP8 weight_fp8 = Float8Tensor( @@ -214,6 +218,7 @@ def forward( fp8_dtype_forward, cast_out=weight_fp8._data, transpose_out=weight_t_fp8._data, + noop_flag=skip_fp8_weight_update, ) else: tex.cast_to_fp8( @@ -295,6 +300,7 @@ def forward( weight_t_fp8, ln_out if weight.requires_grad else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, + skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None, ) ctx.activation_dtype = activation_dtype @@ -321,6 +327,7 @@ def forward( ctx.ub_name = ub_name ctx.requires_dgrad = inp.requires_grad ctx.normalization = normalization + ctx.primary_weights_in_fp8 = primary_weights_in_fp8 # Row Parallel Linear if parallel_mode == "row" and sequence_parallel: @@ -344,9 +351,7 @@ def forward( def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] ) -> Tuple[Union[torch.Tensor, None], ...]: - with _prepare_backward( - ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear" - ): + with torch.cuda.nvtx.range("_LayerNormLinear_backward"): ( inputmat, ln_weight, @@ -357,6 +362,7 @@ def backward( weight_t_fp8, ln_out, fwd_scale_inverses, + skip_fp8_weight_update, ) = ctx.saved_tensors if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: @@ -364,10 +370,13 @@ def backward( weight.main_grad = main_grad # Primary weights are in FP8. 
- if ctx.fp8 and weight_t_fp8 is None: - weight_t_fp8 = weight.transpose( - update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy", + if ctx.primary_weights_in_fp8: + weight_t_fp8 = weight.transpose_2d( + cache=ctx.is_first_microbatch is not None, + noop_flag=skip_fp8_weight_update, ) + elif ctx.fp8: + weight_t_fp8 = weight_t_fp8._data if ctx.ub_overlap_rs_dgrad: ctx.ub_bulk_dgrad = False @@ -472,7 +481,7 @@ def backward( # DGRAD: Evaluated unconditionally to feed into Linear backward _ = tex.fp8_gemm( - weight_t_fp8._data, + weight_t_fp8, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -686,6 +695,8 @@ def backward( None, None, None, + None, + None, ) @@ -970,7 +981,6 @@ def __init__( if self.primary_weights_in_fp8: self.init_fp8_metadata() - self.fp8_meta["update_amax_and_scale_fwd"] = True self.reset_parameters(defer_init=(device == 'meta')) @@ -990,6 +1000,10 @@ def __init__( self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction. + self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True) + FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor) + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( @@ -1084,6 +1098,10 @@ def forward( produced) """ + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if skip_fp8_weight_update is not None: + is_first_microbatch = False + with self.prepare_forward(inp, is_first_microbatch) as inp: assert self.fp8 or not self.primary_weights_in_fp8, \ "Need to run inside fp8_autocast region when weights are stored in FP8." 
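The `noop_flag=skip_fp8_weight_update` argument threaded into `fp8_cast_transpose_fused` above lets a captured kernel decide at replay time whether to refresh the FP8 weight copies. A rough, heavily simplified illustration of that device-side flag idea (made-up helper names; TE implements this inside its fused kernels):

```python
import torch

def maybe_refresh_weight_copy(weight_hp, weight_lp, noop_flag):
    # Toy stand-in for a kernel honouring a device-side no-op flag:
    # a non-zero `noop_flag` (1-element tensor) means "skip the refresh".
    # Because the decision lives in memory rather than in Python control
    # flow, the same captured graph can skip or redo the cast between
    # replays without being re-captured.
    refreshed = weight_hp.to(weight_lp.dtype)  # stand-in for the FP8 cast
    weight_lp.copy_(torch.where(noop_flag.bool(), weight_lp, refreshed))

w = torch.randn(16, 16)
w_lp = w.to(torch.float16)   # low-precision copy standing in for FP8 data
flag = torch.zeros(1)        # 0.0 -> refresh, non-zero -> no-op
maybe_refresh_weight_copy(w, w_lp, flag)    # refreshes w_lp
flag.fill_(1.0)
maybe_refresh_weight_copy(w, w_lp, flag)    # leaves w_lp untouched
```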
@@ -1132,6 +1150,7 @@ def forward( self.apply_bias and not self.gemm_bias_unfused_add, self.eps, is_first_microbatch, + skip_fp8_weight_update, self.fp8, self.fp8_calibration, self.fp8_meta, @@ -1156,6 +1175,7 @@ def forward( self.ub_overlap_rs_dgrad, self.ub_overlap_ag, self.ub_name, + self.dummy_tensor, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index ad66e01e07..43103f06e1 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -13,7 +13,6 @@ from .base import ( get_workspace, - _prepare_backward, get_ub, TransformerEngineBaseModule, _2X_ACC_FPROP, @@ -94,6 +93,7 @@ def forward( use_fc2_bias: bool, eps: float, is_first_microbatch: Union[bool, None], + skip_fp8_weight_update: Union[torch.Tensor, None], fp8: bool, fp8_calibration: bool, fp8_meta: Dict[str, Any], @@ -121,6 +121,7 @@ def forward( ub_overlap_rs: bool, ub_overlap_ag: bool, gemm_gelu_fusion: bool, + dummy_tensor: torch.Tensor, # pylint: disable=unused-argument, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() @@ -131,7 +132,11 @@ def forward( assert_dim_for_fp8_exec(fc1_weight) assert_dim_for_fp8_exec(fc2_weight) - update_fp8_weights = is_first_microbatch is None or is_first_microbatch + update_fp8_weights = ( + is_first_microbatch is None + or is_first_microbatch + or skip_fp8_weight_update is not None + ) activation_func = _act_func(activation)[0] @@ -225,8 +230,6 @@ def forward( fc2_weight.reset_fp8_meta_scale_inv() fc1_weight_fp8 = fc1_weight fc2_weight_fp8 = fc2_weight - fc1_weight_t_fp8 = None - fc2_weight_t_fp8 = None elif update_fp8_weights: # Need to cast weights to FP8 fc1_weight_fp8 = Float8Tensor( @@ -250,6 +253,7 @@ def forward( fp8_dtype_forward, cast_out=fc1_weight_fp8._data, transpose_out=fc1_weight_t_fp8._data, + noop_flag=skip_fp8_weight_update, ) tex.fp8_cast_transpose_fused( fc2_weight, @@ -258,6 +262,7 @@ def forward( fp8_dtype_forward, cast_out=fc2_weight_fp8._data, transpose_out=fc2_weight_t_fp8._data, + noop_flag=skip_fp8_weight_update, ) else: tex.cast_to_fp8( @@ -510,6 +515,7 @@ def forward( fc2_weight_t_fp8, fc1_bias, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, + skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None, ) ctx.activation_dtype = activation_dtype ctx.activation = activation @@ -538,6 +544,7 @@ def forward( ctx.ub_overlap_ag = ub_overlap_ag ctx.requires_dgrad = inp.requires_grad ctx.normalization = normalization + ctx.primary_weights_in_fp8 = primary_weights_in_fp8 # Row Parallel Linear if ub_overlap_rs: @@ -563,9 +570,7 @@ def forward( def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] ) -> Tuple[Union[torch.Tensor, None], ...]: - with _prepare_backward( - ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormMLP" - ): + with torch.cuda.nvtx.range("_LayerNormMLP_backward"): ( inputmat, ln_weight, @@ -582,6 +587,7 @@ def backward( fc2_weight_t_fp8, fc1_bias, fwd_scale_inverses, + skip_fp8_weight_update, ) = ctx.saved_tensors if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: @@ -592,11 +598,18 @@ def backward( fc2_weight.main_grad = fc2_weight_main_grad # Primary weights are in FP8. 
- update_transpose_cache = "reuse_only" if ctx.is_first_microbatch is None else "lazy" - if ctx.fp8 and fc1_weight_t_fp8 is None: - fc1_weight_t_fp8 = fc1_weight.transpose(update_cache=update_transpose_cache) - if ctx.fp8 and fc2_weight_t_fp8 is None: - fc2_weight_t_fp8 = fc2_weight.transpose(update_cache=update_transpose_cache) + if ctx.primary_weights_in_fp8: + fc1_weight_t_fp8 = fc1_weight.transpose_2d( + cache=ctx.is_first_microbatch is not None, + noop_flag=skip_fp8_weight_update, + ) + fc2_weight_t_fp8 = fc2_weight.transpose_2d( + cache=ctx.is_first_microbatch is not None, + noop_flag=skip_fp8_weight_update, + ) + elif ctx.fp8: + fc1_weight_t_fp8 = fc1_weight_t_fp8._data + fc2_weight_t_fp8 = fc2_weight_t_fp8._data activation_func = _act_func(ctx.activation)[1] @@ -673,7 +686,7 @@ def backward( # FC2 DGRAD; Unconditional fc2_dgrad, _ = tex.fp8_gemm( - fc2_weight_t_fp8._data, + fc2_weight_t_fp8, fwd_scale_inverses, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, @@ -826,7 +839,7 @@ def backward( ub_obj = None # FC1 DGRAD: Unconditional _ = tex.fp8_gemm( - fc1_weight_t_fp8._data, + fc1_weight_t_fp8, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -1151,6 +1164,8 @@ def backward( None, None, None, + None, + None, ) @@ -1389,7 +1404,6 @@ def __init__( if self.primary_weights_in_fp8: self.init_fp8_metadata(num_gemms=2) - self.fp8_meta["update_amax_and_scale_fwd"] = True self.reset_parameters(defer_init=(device == 'meta')) @@ -1414,6 +1428,10 @@ def __init__( self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction. + self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True) + FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor) + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( @@ -1473,7 +1491,9 @@ def get_fp8_weights_scratchpad( @no_torch_dynamo() def forward( - self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None + self, + inp: torch.Tensor, + is_first_microbatch: Optional[bool] = None ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply layer normalization to the input followed by a feedforward network (MLP Block). @@ -1497,6 +1517,10 @@ def forward( produced) """ + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if skip_fp8_weight_update is not None: + is_first_microbatch = False + with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp: assert self.fp8 or not self.primary_weights_in_fp8, \ "Need to run inside fp8_autocast region when weights are stored in FP8." 
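As in the LayerNormLinear and Linear hunks, `is_first_microbatch` is forced to `False` whenever the global skip tensor is set, so skipping the FP8 weight update is controlled per graph replay rather than per Python call. A hedged sketch of the caller-side pattern this enables when `fp8_weight_caching=True` (`layer`, `inp`, and the microbatch count are illustrative placeholders):

```python
import transformer_engine.pytorch as te

# `layer` is assumed to be a TE module (e.g. te.LayerNormMLP) and `inp` a
# matching CUDA sample input; both are placeholders for illustration.
graphed_layer = te.make_graphed_callables(
    layer,
    (inp,),
    fp8_enabled=True,
    fp8_weight_caching=True,  # cache FP8 weight casts across microbatches
)

for microbatch in range(4):
    # The boolean kwarg is mandatory with fp8_weight_caching=True; only the
    # first microbatch re-casts (and re-transposes) the FP8 weights.
    out = graphed_layer(inp, is_first_microbatch=(microbatch == 0))
    out.sum().backward()
```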
@@ -1535,6 +1559,7 @@ def forward( self.apply_bias and not self.gemm_bias_unfused_add, self.eps, is_first_microbatch, + skip_fp8_weight_update, self.fp8, self.fp8_calibration, self.fp8_meta, @@ -1562,6 +1587,7 @@ def forward( self.ub_overlap_rs, self.ub_overlap_ag, self.gemm_gelu_fusion, + self.dummy_tensor, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1f7898a592..4baf2d5965 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -11,7 +11,6 @@ from .base import ( get_workspace, - _prepare_backward, get_ub, TransformerEngineBaseModule, _2X_ACC_FPROP, @@ -65,6 +64,7 @@ def forward( bias: torch.Tensor, use_bias: bool, is_first_microbatch: Union[bool, None], + skip_fp8_weight_update: Union[torch.Tensor, None], fp8: bool, fp8_calibration: bool, fp8_meta: Dict[str, Any], @@ -80,7 +80,8 @@ def forward( primary_weights_in_fp8: bool, ub_overlap_rs: bool, ub_overlap_ag: bool, - ub_name: str + ub_name: str, + dummy_tensor: torch.Tensor, # pylint: disable=unused-argument ) -> torch.Tensor: # Make sure input dimensions are compatible in_features = weight.shape[-1] @@ -90,7 +91,12 @@ def forward( assert_dim_for_fp8_exec(inputmat) assert_dim_for_fp8_exec(weight) - update_fp8_weights = is_first_microbatch is None or is_first_microbatch + update_fp8_weights = ( + is_first_microbatch is None + or is_first_microbatch + or skip_fp8_weight_update is not None + ) + tp_world_size = get_distributed_world_size(tp_group) ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs @@ -140,7 +146,6 @@ def forward( # Weight is already in FP8 weight.reset_fp8_meta_scale_inv() weight_fp8 = weight - weight_t_fp8 = None elif update_fp8_weights: # Need to cast weights to FP8 weight_fp8 = Float8Tensor( @@ -158,6 +163,7 @@ def forward( fp8_dtype_forward, cast_out=weight_fp8._data, transpose_out=weight_t_fp8._data, + noop_flag=skip_fp8_weight_update, ) else: cast_to_fp8( @@ -296,6 +302,7 @@ def forward( weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, weight_t_fp8 if fp8 else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, + skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None, ) ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 @@ -313,6 +320,7 @@ def forward( ctx.ub_name = ub_name ctx.tp_size = tp_size ctx.requires_dgrad = inp.requires_grad + ctx.primary_weights_in_fp8 = primary_weights_in_fp8 # Row Parallel Linear if ub_overlap_rs: @@ -330,9 +338,7 @@ def forward( def backward( ctx, grad_output: torch.Tensor ) -> Tuple[Union[torch.Tensor, None], ...]: - with _prepare_backward( - ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear" - ): + with torch.cuda.nvtx.range("_Linear_backward"): ( inputmat, inputmat_t, @@ -340,6 +346,7 @@ def backward( main_grad, weight_t_fp8, fwd_scale_inverses, + skip_fp8_weight_update, ) = ctx.saved_tensors if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: @@ -347,10 +354,14 @@ def backward( weight.main_grad = main_grad # Primary weights are in FP8. 
- if ctx.fp8 and weight_t_fp8 is None: - weight_t_fp8 = weight.transpose( - update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy", + if ctx.primary_weights_in_fp8: + weight_t_fp8 = weight.transpose_2d( + cache=ctx.is_first_microbatch is not None, + noop_flag=skip_fp8_weight_update, ) + elif ctx.fp8: + weight_t_fp8 = weight_t_fp8._data + tp_world_size = get_distributed_world_size(ctx.tp_group) ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag if ctx.ub_overlap_ag: @@ -361,6 +372,7 @@ def backward( ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P else: ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ( grad_output, grad_output_c, @@ -401,7 +413,7 @@ def backward( if ctx.requires_dgrad: if ctx.fp8: dgrad, _ = fp8_gemm( - weight_t_fp8._data, + weight_t_fp8, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -542,6 +554,8 @@ def backward( None, None, None, + None, + None, ) @@ -772,7 +786,6 @@ def __init__( if self.primary_weights_in_fp8: self.init_fp8_metadata() - self.fp8_meta["update_amax_and_scale_fwd"] = True self.reset_parameters(defer_init=(device == 'meta')) @@ -785,6 +798,10 @@ def __init__( else: self.gemm_bias_unfused_add = False + # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction. + self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True) + FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor) + def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) @@ -858,6 +875,10 @@ def forward( produced) """ + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if skip_fp8_weight_update is not None: + is_first_microbatch = False + with self.prepare_forward(inp, is_first_microbatch) as inp: assert self.fp8 or not self.primary_weights_in_fp8, \ "Need to run inside fp8_autocast region when weights are stored in FP8." @@ -903,6 +924,7 @@ def forward( bias_tensor, self.apply_bias and not self.gemm_bias_unfused_add, is_first_microbatch, + skip_fp8_weight_update, self.fp8, self.fp8_calibration, self.fp8_meta, @@ -919,6 +941,7 @@ def forward( self.ub_overlap_rs, self.ub_overlap_ag, self.ub_name, + self.dummy_tensor, ) out = linear_fn(*args) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 2e00333fa0..5b6fc1e5c3 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -473,6 +473,15 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N if hasattr(child, "set_tensor_parallel_group"): child.set_tensor_parallel_group(tp_group) + def reset_fp8_meta_tensors(self) -> None: + """Reset FP8 meta tensors of all sub-modules.""" + # Deep iterate but skip self to avoid infinite recursion. + for index, child in enumerate(self.modules()): + if index == 0: + continue + if hasattr(child, "reset_fp8_meta_tensors"): + child.reset_fp8_meta_tensors() + def set_context_parallel_group( self, cp_group: Union[dist_group_type, None], @@ -665,7 +674,8 @@ def forward( # MLP. mlp_outputs = self.layernorm_mlp( - hidden_states, is_first_microbatch=is_first_microbatch + hidden_states, + is_first_microbatch=is_first_microbatch, ) if self.apply_residual_connection_post_layernorm: mlp_output, mlp_bias, residual = mlp_outputs
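Putting the pieces of this patch together, a hedged end-to-end sketch of the new public API (module choice, shapes, and recipe values are illustrative assumptions, not part of the patch):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Illustrative sizes only; FP8 requires appropriately divisible dimensions.
layer = te.TransformerLayer(
    hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16,
).cuda()
inp = torch.randn(128, 2, 1024, device="cuda", requires_grad=True)

graphed_layer = te.make_graphed_callables(
    layer,
    (inp,),
    num_warmup_iters=3,
    fp8_enabled=True,
    fp8_recipe=DelayedScaling(
        margin=0, amax_history_len=16, amax_compute_algo="max"),
)

out = graphed_layer(inp)  # replays the captured forward graph
out.sum().backward()      # replays the captured backward graph; the backward
                          # amax reduction is triggered from the graphed
                          # autograd function rather than the multi-grad hook
```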