From 0e1d9faed1ef8d341614c31b2fa7694b4a9f39a5 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 12 Dec 2024 08:00:46 -0500 Subject: [PATCH 01/11] [JAX] Bug fix for distributed normalization (#1366) * fix ctx.aval_out indexing for workspace * add cudnn init to prepare phase of norm custom calls * add thread_local for norm registry instance --------- Signed-off-by: Phuong Nguyen --- .../common/normalization/common.h | 3 +-- .../jax/cpp_extensions/normalization.py | 12 +++++----- .../jax/csrc/extensions/pybind.cpp | 24 ++++++++++++++----- 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index 8a8df63ba4..d1d56d5cc9 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -287,9 +287,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase { class NormalizationPlanRegistry { public: - // TODO thread-safe static NormalizationPlanRegistry& getInstance() { - static NormalizationPlanRegistry instance; + static thread_local NormalizationPlanRegistry instance; return instance; } diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 0b7df0b5a8..69d7962b62 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -147,7 +147,7 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(out_shape, output_type), @@ -441,7 +441,7 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon): sm_margin = get_backward_sm_margin() - wkspace_aval = ctx.avals_out[-4:] + wkspace_aval = ctx.avals_out[-1] opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, hidden_size, @@ -650,7 +650,7 @@ def lowering(ctx, x, gamma, *, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(out_shape, x_type.element_type), @@ -841,7 +841,7 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): hidden_size = reduce(operator.mul, g_shape) batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval = ctx.avals_out[-3:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(x_shape, x_type.element_type), @@ -1088,7 +1088,7 @@ def lowering( batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(out_shape, ir_out_dtype), @@ -1394,7 +1394,7 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(out_shape, ir_out_dtype), diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 9b5c156e5d..a319b74d76 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -83,12 +83,24 @@ pybind11::dict Registrations() { 
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler); // Normalization - dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler); - dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler); - dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler); - dict["te_rmsnorm_forward_ffi"] = EncapsulateFunction(RMSNormForwardHandler); - dict["te_rmsnorm_forward_fp8_ffi"] = EncapsulateFunction(RMSNormForwardFP8Handler); - dict["te_rmsnorm_backward_ffi"] = EncapsulateFunction(RMSNormBackwardHandler); + dict["te_layernorm_forward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardHandler)); + dict["te_layernorm_forward_fp8_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardFP8Handler)); + dict["te_layernorm_backward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(LayerNormBackwardHandler)); + dict["te_rmsnorm_forward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardHandler)); + dict["te_rmsnorm_forward_fp8_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardFP8Handler)); + dict["te_rmsnorm_backward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormBackwardHandler)); // Attention pybind11::dict fused_attn_forward_ffi; From e7bfc0c547d63332e4f8d65e606dc69f4c22ffbe Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 12 Dec 2024 14:16:09 -0800 Subject: [PATCH 02/11] Add user to CI (#1371) Add Jeremy to ci users Signed-off-by: Kirthi Shankar Sivamani --- .github/workflows/trigger-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index 586abd0541..86d22b7944 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -42,6 +42,7 @@ jobs: || github.actor == 'kocchop' || github.actor == 'youngeunkwon0405' || github.actor == 'KshitijLakhani' + || github.actor == 'jberchtold-nvidia' ) steps: - name: Check if comment is issued by authorized person From 1ae81903a16f274ccdfd199c91634ab9833e4c9a Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 13 Dec 2024 18:09:42 -0800 Subject: [PATCH 03/11] Fix an invalid reference in the doc (#1362) --- examples/pytorch/comm_gemm_overlap/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/pytorch/comm_gemm_overlap/README.md b/examples/pytorch/comm_gemm_overlap/README.md index bb3ba209ed..fc8458844b 100644 --- a/examples/pytorch/comm_gemm_overlap/README.md +++ b/examples/pytorch/comm_gemm_overlap/README.md @@ -16,7 +16,7 @@ Forward and backward passes with layer weights distributed over all GPUs in a single node. 
```bash -$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_with_overlap.py +$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_layer_with_overlap.py # Sample output on 8x H100s: # [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3, 4, 5, 6, 7] @@ -70,7 +70,7 @@ Uses `torch.nn.parallel.DistributedDataParallel` for replicatin the model across groups in a single node. ```bash -$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_overlap.py --num-replicas 2 +$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_layer_with_overlap.py --num-replicas 2 # Sample output on 8x H100s: # [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3] From 1975ace44b3d4255e2c2e7aa0546d394ab1c9ce3 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Sat, 14 Dec 2024 12:09:21 -0500 Subject: [PATCH 04/11] [JAX] Bug Fix: Softmax FFIs with correct Encapsulates (#1375) * softmax custom calls with correct encapsulates * rm jax deprecated features --------- Signed-off-by: Phuong Nguyen --- .../jax/cpp_extensions/activation.py | 6 +++--- transformer_engine/jax/cpp_extensions/base.py | 2 +- .../jax/cpp_extensions/normalization.py | 14 +++++++------- .../jax/cpp_extensions/softmax.py | 8 ++++---- .../jax/csrc/extensions/pybind.cpp | 19 ++++++++----------- 5 files changed, 23 insertions(+), 26 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 44b396ad55..7f09e6f900 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -8,7 +8,7 @@ import jax import jax.numpy as jnp -from jax import core, dtypes +from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding from jax.extend import ffi @@ -98,7 +98,7 @@ def abstract(x_aval, *, act_enum): # pylint: disable=unused-argument assert x_shape[-2] == 2 or x_shape[-2] == 1 hidden_size = x_shape[-1] batch_shapes = x_shape[:-2] - out_aval = core.raise_to_shaped(x_aval) + out_aval = x_aval out_shape = (batch_shapes) + (hidden_size,) out_aval = out_aval.update(shape=out_shape, dtype=dtype) @@ -225,7 +225,7 @@ def abstract(dz_aval, x_aval, *, act_enum): # pylint: disable=unused-argument i_hidden_size = dz_aval.shape[-1] g_hidden_size = x_aval.shape[-1] assert i_hidden_size == g_hidden_size - out_aval = core.raise_to_shaped(x_aval) + out_aval = x_aval return out_aval diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 3d88c1f078..3715e6f20c 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -7,7 +7,7 @@ from abc import ABCMeta, abstractmethod from functools import partial -from jax import core +from jax.extend import core from jax.interpreters import xla, mlir from jax.experimental.custom_partitioning import custom_partitioning from jax._src.interpreters import batching diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 69d7962b62..8ad7ee4fcb 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -9,7 +9,7 @@ import jax import jax.numpy as jnp -from jax import core, dtypes +from jax import dtypes from jax.interpreters import mlir from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding @@ 
-74,7 +74,7 @@ def abstract(x_aval, gamma_aval, beta_aval, **kwargs): mu_rsigama_dtype = jnp.float32 - out_aval = core.raise_to_shaped(x_aval) + out_aval = x_aval mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype) assert gamma_aval.size == beta_aval.size @@ -361,8 +361,8 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs): assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1] assert mu_dtype == rsigma_dtype == jnp.float32 - dx_aval = core.raise_to_shaped(dz_aval) - dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval) + dx_aval = dz_aval + dgamma_aval = dbeta_aval = gamma_aval (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes( x_aval.size // gamma_aval.size, # batch size @@ -589,7 +589,7 @@ def abstract(x_aval, gamma_aval, **kwargs): rsigama_dtype = jnp.float32 - out_aval = core.raise_to_shaped(x_aval) + out_aval = x_aval rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype) hidden_size = gamma_aval.size @@ -783,8 +783,8 @@ def abstract(dz_aval, x_aval, rsigma_aval, gamma_aval, **kwargs): assert rsigma_aval.shape == x_aval.shape[:-1] assert rsigma_dtype == jnp.float32 - dx_aval = core.raise_to_shaped(dz_aval) - dgamma_aval = core.raise_to_shaped(gamma_aval) + dx_aval = dz_aval + dgamma_aval = gamma_aval (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes( x_aval.size // gamma_aval.size, # batch size diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index a12943f4c2..67053ecd8e 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -9,7 +9,7 @@ import jax import jax.numpy as jnp -from jax import core, dtypes +from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding from jax.extend import ffi @@ -126,7 +126,7 @@ def forward_abstract(logits_aval, scale_factor): assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported assert q_seqlen > 1 - out_aval = core.raise_to_shaped(logits_aval) + out_aval = logits_aval return out_aval @staticmethod @@ -237,7 +237,7 @@ def backward_abstract( assert dz_aval.shape == softmax_out_aval.shape - dx_aval = core.raise_to_shaped(dz_aval) + dx_aval = dz_aval return dx_aval @staticmethod @@ -578,7 +578,7 @@ def abstract(logits_aval, mask_aval, scale_factor): # pylint: disable=unused-ar assert mask_shape[-2] == q_seqlen assert mask_shape[-1] == k_seqlen - out_aval = core.raise_to_shaped(logits_aval) + out_aval = logits_aval return out_aval @staticmethod diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index a319b74d76..a986b91b30 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -61,26 +61,23 @@ pybind11::dict Registrations() { dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler); dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler); dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler); - dict["te_dact_lu_dbias_cast_transpose_ffi"] = - EncapsulateFunction(DActLuDBiasCastTransposeHandler); - dict["te_dgated_act_lu_cast_transpose_ffi"] = - EncapsulateFunction(DGatedActLuCastTransposeHandler); + dict["te_dact_lu_dbias_cast_transpose_ffi"] = EncapsulateFFI(DActLuDBiasCastTransposeHandler); + dict["te_dgated_act_lu_cast_transpose_ffi"] = EncapsulateFFI(DGatedActLuCastTransposeHandler); // 
Quantization dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler); dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler); // Softmax - dict["te_scaled_softmax_forward_ffi"] = EncapsulateFunction(ScaledSoftmaxForwardHandler); - dict["te_scaled_softmax_backward_ffi"] = EncapsulateFunction(ScaledSoftmaxBackwardHandler); - dict["te_scaled_masked_softmax_forward_ffi"] = - EncapsulateFunction(ScaledMaskedSoftmaxForwardHandler); + dict["te_scaled_softmax_forward_ffi"] = EncapsulateFFI(ScaledSoftmaxForwardHandler); + dict["te_scaled_softmax_backward_ffi"] = EncapsulateFFI(ScaledSoftmaxBackwardHandler); + dict["te_scaled_masked_softmax_forward_ffi"] = EncapsulateFFI(ScaledMaskedSoftmaxForwardHandler); dict["te_scaled_masked_softmax_backward_ffi"] = - EncapsulateFunction(ScaledMaskedSoftmaxBackwardHandler); + EncapsulateFFI(ScaledMaskedSoftmaxBackwardHandler); dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] = - EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForwardHandler); + EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxForwardHandler); dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] = - EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler); + EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxBackwardHandler); // Normalization dict["te_layernorm_forward_ffi"] = From 0196ed4461ad561411aa828d1e9dc89a32ef7177 Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Mon, 16 Dec 2024 15:39:47 -0800 Subject: [PATCH 05/11] Enabling FP8 all-gather for TE Float8Tensor when using Torch FSDP2 (#1358) * draft implementation of fsdp2 fp8 all gather Signed-off-by: Youngeun Kwon * fix the convergence issue Signed-off-by: Youngeun Kwon * Add warning Signed-off-by: Youngeun Kwon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * disable lint error Signed-off-by: Youngeun Kwon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the lint error Signed-off-by: Youngeun Kwon * fix lint error Signed-off-by: Youngeun Kwon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint error Signed-off-by: Youngeun Kwon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint error Signed-off-by: Youngeun Kwon * add comments Signed-off-by: Youngeun Kwon * add ref Signed-off-by: Youngeun Kwon * add related tests Signed-off-by: Youngeun Kwon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Youngeun Kwon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- qa/L1_pytorch_distributed_unittest/test.sh | 1 + tests/pytorch/distributed/run_fsdp2_model.py | 181 ++++++++++++++++++ tests/pytorch/distributed/test_torch_fsdp2.py | 67 +++++++ .../pytorch/tensor/float8_tensor.py | 92 ++++++++- 4 files changed, 340 insertions(+), 1 deletion(-) create mode 100644 tests/pytorch/distributed/run_fsdp2_model.py create mode 100644 tests/pytorch/distributed/test_torch_fsdp2.py diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 9a11ccc008..4e52153db9 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -11,4 +11,5 @@ pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py pytest -v -s 
$TE_PATH/tests/pytorch/distributed/test_fusible_ops.py pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py new file mode 100644 index 0000000000..0f00a6717b --- /dev/null +++ b/tests/pytorch/distributed/run_fsdp2_model.py @@ -0,0 +1,181 @@ +#!/usr/bin/python3 + +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import sys +import argparse + +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import Format, DelayedScaling + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import nn, optim +from torch.distributed import DeviceMesh +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed.device_mesh import init_device_mesh +from contextlib import nullcontext + + +class SimpleNet(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(SimpleNet, self).__init__() + self.fc1 = te.Linear(input_size, hidden_size) + self.fc2 = te.Linear(hidden_size, output_size) + + def forward(self, x): + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return x + + +def save_custom_attrs(module): + custom_attrs = {} + for name, param in module.named_parameters(): + attrs = vars(param) + custom_attrs[name] = {k: v for k, v in attrs.items()} + return custom_attrs + + +def restore_custom_attrs(module, custom_attrs): + for name, param in module.named_parameters(): + if name in custom_attrs: + for attr_name, attr_value in custom_attrs[name].items(): + setattr(param, attr_name, attr_value) + + +def _parse_args(argv=None, namespace=None): + parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()") + parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model") + parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size") + parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model") + parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model") + parser.add_argument( + "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8." 
+ ) + parser.add_argument( + "--iter", type=int, default=10, help="Number of iterations for forward pass" + ) + parser.add_argument("--seed", type=int, default=42, help="RNG seed.") + # Adding hsdp_dim as a list argument, comma-separated + parser.add_argument( + "--sharding-dims", + type=int, + nargs="+", + help='FSDP/HSDP sharding dimensions ("replicate", "shard")', + ) + args = parser.parse_args(argv, namespace) + if args.sharding_dims: + assert len(args.sharding_dims) <= 2 + return args + + +sub_modules_to_wrap = [te.Linear] + + +def _train(args): + assert "TORCHELASTIC_RUN_ID" in os.environ + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + assert LOCAL_SIZE == WORLD_SIZE + + # Set device and initialize RNG states + torch.cuda.set_device(WORLD_RANK) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + + # Initialize torch.distributed global process group and get DP/TP groups + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + } + assert dist.is_nccl_available() + dist.init_process_group(**dist_init_kwargs) + nccl_world = dist.new_group(backend="nccl") + device = torch.device(f"cuda:{LOCAL_RANK}") + + # FP8 Configuration + fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") + + if not args.fp8_init: + # Build model context (FP8 init) + build_model_context = nullcontext + build_model_context_args = {} + + from transformer_engine.pytorch import fp8_model_init + + build_model_context = fp8_model_init + build_model_context_args["enabled"] = True + + # Build the model with the specified context + with build_model_context(**build_model_context_args): + model = SimpleNet(args.input_size, args.hidden_size, args.output_size) + else: + model = SimpleNet(args.input_size, args.hidden_size, args.output_size) + # Move the model to the correct device + + model.to(device) + + if LOCAL_RANK == 0: + print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...") + # Creating a DeviceMesh for fully_shard + world_size = int(WORLD_SIZE) + device_ids = list(range(world_size)) + if LOCAL_RANK == 0: + print(f"sharding-dims:{args.sharding_dims}") + # Setup the sharding mesh for FSDP/HSDP + if args.sharding_dims == None: # FSDP + mesh = DeviceMesh("cuda", device_ids) + elif len(args.sharding_dims) == 1: + assert args.sharding_dims[0] == device_ids[-1] + 1 + mesh = DeviceMesh("cuda", device_ids) + elif len(args.sharding_dims) == 2: # HSDP + assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1 + mesh = init_device_mesh( + "cuda", + (args.sharding_dims[0], args.sharding_dims[1]), + mesh_dim_names=("replicate", "shard"), + ) + else: + assert False + + # Apply FSDP/HSDP + custom_attrs = save_custom_attrs(model) + for sub_module in model.modules(): + if any( + isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap + ): + fully_shard(sub_module, mesh=mesh) + fully_shard(model, mesh=mesh) + restore_custom_attrs(model, custom_attrs) + + optimizer = optim.Adam(model.parameters(), lr=1e-3) + + for iteration in range(args.iter): + # Zero the parameter gradients + optimizer.zero_grad() + input_data = torch.randn(args.batch_size, args.input_size).to(device) + output = model(input_data) + target = torch.randn(args.batch_size, args.output_size).to(device) + loss = F.mse_loss(output, 
target) + loss.backward() + optimizer.step() + if LOCAL_RANK == 0: + print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.") + + dist.destroy_process_group() + if LOCAL_RANK == 0: + print(f"Rank {LOCAL_RANK}: Done...") + return 0 + + +if __name__ == "__main__": + sys.exit(_train(_parse_args())) diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py new file mode 100644 index 0000000000..3c9197c322 --- /dev/null +++ b/tests/pytorch/distributed/test_torch_fsdp2.py @@ -0,0 +1,67 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import pytest +import subprocess +from pathlib import Path +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +import torch +from packaging.version import Version as PkgVersion + + +def get_torch_version(): + """Get pytorch version from __version__""" + + def get_torch_version_str(): + import torch + + return str(torch.__version__) + + return PkgVersion(get_torch_version_str()) + + +if torch.cuda.device_count() < 4: + pytest.skip("FSDP2 test requires at least 4 GPUs.") + +if torch.cuda.device_count() % 2 != 0: + pytest.skip("Number of device should be divided by 2.") + +if not get_torch_version() >= PkgVersion("2.4"): + pytest.skip("FSDP2 requires PyTorch >= 2.4.0 with FSDP 2 support.") + +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + +TEST_ROOT = Path(__file__).parent.resolve() +NUM_PROCS: int = torch.cuda.device_count() +LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] + + +def _run_test(fp_init, sharding_dims): + test_path = TEST_ROOT / "run_fsdp2_model.py" + test_cmd = LAUNCH_CMD + [str(test_path)] + + if fp_init: + test_cmd += ["--fp8-init"] + if len(sharding_dims) == 1: + test_cmd += ["--sharding-dims", str(sharding_dims[0])] + elif len(sharding_dims) == 2: + test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])] + else: + assert False + result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) + if result.returncode != 0: + raise AssertionError(result.stderr.decode()) + + +all_boolean = [True, False] +sharding_dims = [[NUM_PROCS], [2, NUM_PROCS // 2]] + + +@pytest.mark.parametrize("sharding_dims", sharding_dims) +@pytest.mark.parametrize("fp8_init", all_boolean) +def test_distributed(fp8_init, sharding_dims): + if fp8_init and not fp8_available: + pytest.skip(reason_for_no_fp8) + _run_test(fp8_init, sharding_dims) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 7ace68a222..414e819f53 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -24,6 +24,19 @@ aten = torch.ops.aten updated_fp8_params = {} +_ops_to_preserve_subclass_in_fsdp2 = { + torch.ops.aten.empty_like.default, + torch.ops.aten.new_zeros.default, + torch.ops.aten.slice.Tensor, + torch.ops.aten.copy_.default, + torch.ops.aten.view.default, + torch.ops.aten.as_strided.default, + torch.ops.aten._to_copy.default, + torch.ops.aten._pin_memory.default, + torch.ops.aten.split.Tensor, + torch.ops.aten.clone.default, +} + def _make_fp8_attr_property_funcs(name: str) -> Any: """Make accessors for an FP8 attribute @@ -430,6 +443,37 @@ def __new__( return self + def fsdp_pre_all_gather(self, mesh): # pylint: disable=unused-argument + """ + A hook function used in torch fsdp2, called before all-gather + return 
(all-gather input), (metadata) + Ref: https://github.com/pytorch/pytorch/pull/122908 + + """ + + return (self._data,), (self,) + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[torch.Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, # pylint: disable=unused-argument + *, + out: Optional[torch.Tensor] = None, + ): + """ + A hook function used in torch fsdp2, called after all-gather + return (Float8Tensor class instance of all-gathered input), (Things to free after forward) + Ref: https://github.com/pytorch/pytorch/pull/122908 + + """ + (data,) = all_gather_outputs + (sample,) = metadata + if out is not None: + assert isinstance(out, Float8Tensor), f"{type(out)}" + return None + return Float8Tensor.make_like(sample, data=data), all_gather_outputs + @classmethod def make_like( cls, @@ -902,7 +946,53 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): ) return Float8Tensor.make_like(tensor, data=data_view) - # Default case + # Related to FSDP2 + if func == aten.split.Tensor: + tensor = args[0] + data = tensor._data + func_out = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return [Float8Tensor.make_like(tensor, data=split_tensor) for split_tensor in func_out] + if func == aten.new_zeros.default: + tensor = args[0] + data = tensor._data + func_out = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return Float8Tensor.make_like(tensor, data=func_out) + if func == torch.ops.aten.as_strided.default: + tensor = args[0] + data = tensor._data + func_out = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return Float8Tensor.make_like(tensor, data=func_out) + if func == torch.ops.aten.detach.default: + return cls.detach(args[0]) + if func == torch.ops.aten.clone.default: + return cls.clone(args[0]) + if func == torch.ops.aten.copy_.default: + # Implementation in the superclass (QuantizedTensor) returns a proper output + pass + elif func in _ops_to_preserve_subclass_in_fsdp2: + # Ops in the _ops_to_preserve_subclass_in_fsdp2 are recommened to return the same class instance to work fine with the torch fsdp2 + warnings.warn( + f"A function call({func}) in {cls} may not return {cls} tensor as an output. It" + " might cause an error in torch FSDP2!" 
+ ) + else: + pass + return super().__torch_dispatch__(func, types, args, kwargs) @classmethod From f4f35c2f715e8c219ee4f76de2b9e768af062cfe Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 16 Dec 2024 19:57:44 -0800 Subject: [PATCH 06/11] [common] Add max_t support for KV in THD (#1370) add max_t for KV Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index f242502261..b706eadace 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -661,6 +661,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( if (is_ragged && cudnn_runtime_version >= 90600) { sdpa_backward_options.set_max_total_seq_len_q(s_q); + sdpa_backward_options.set_max_total_seq_len_kv(s_kv); } if (cudnn_runtime_version >= 90200 && window_size_left != -1) { From 7f5c784e32391670cd4661f61edbca7912916a6c Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Tue, 17 Dec 2024 15:41:40 +0800 Subject: [PATCH 07/11] [JAX] Fused attention unit tests fixes and refinements (#1352) * Add util functions to attn_mask_type Signed-off-by: Reese Wang * Add util functions to qkv_layout Signed-off-by: Reese Wang * Fix THD cross reference code Signed-off-by: Reese Wang * Remove explicit segment_pad, encoding it to segment_ids Signed-off-by: Reese Wang * Add jax.jit, replace _token with segment_ids, rename bias shape enum Signed-off-by: Reese Wang * Add comment for make_mask Signed-off-by: Reese Wang * Clean code Signed-off-by: Reese Wang * Add doc strings for the added functions Signed-off-by: Reese Wang * Remove cache for fa deterministic which causes UT failed Signed-off-by: Reese Wang * Rename fixture to avoid conflict Signed-off-by: Reese Wang --------- Signed-off-by: Reese Wang --- tests/jax/conftest.py | 2 +- tests/jax/test_distributed_fused_attn.py | 6 +- tests/jax/test_fused_attn.py | 227 ++++++++++-------- tests/jax/utils.py | 16 +- transformer_engine/jax/attention.py | 99 +++++--- .../jax/cpp_extensions/attention.py | 3 +- 6 files changed, 201 insertions(+), 152 deletions(-) diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index ccb6690a87..5bb86c6081 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -20,7 +20,7 @@ def clear_live_arrays(): @pytest.fixture(autouse=True, scope="module") -def enable_fused_attn(): +def enable_fused_attn_after_hopper(): """ Enable fused attn for hopper+ arch. Fused attn kernels on pre-hopper arch are not deterministic. 
diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index e194a228d2..1538062975 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -20,7 +20,6 @@ from utils import ( make_causal_mask, make_self_mask, - assert_tree_like_allclose, assert_allclose, print_debug_tensor_stats, ) @@ -32,7 +31,6 @@ AttnMaskType, QKVLayout, QKVFormat, - get_qkv_format, reorder_causal_load_balancing, inverse_reorder_causal_load_balancing, CPStrategy, @@ -421,7 +419,7 @@ def impl_test_contex_parallel_attn( dropout_prob = 0.0 is_training = True dp_size, cp_size, tp_size = mesh_shape - qkv_format = get_qkv_format(qkv_layout) + qkv_format = qkv_layout.get_qkv_format() batch, seqlen, num_head, hidden = data_shape @@ -503,7 +501,7 @@ def grad_func(func, *args, **kwargs): # Gradient is small, use a gradient multiplier to amplify the gradient _, max_seq_len, num_heads, _ = data_shape gradient_multiplier = max_seq_len * num_heads - if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK]: + if attn_mask_type.is_causal(): gradient_multiplier /= 10 ret_valid = func(*args, **kwargs) return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index af05538ef5..759ea893ef 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -28,7 +28,6 @@ QKVFormat, fused_attn, fused_attn_thd, - get_qkv_format, make_swa_mask, ) from transformer_engine.jax.cpp_extensions import FusedAttnHelper @@ -50,6 +49,7 @@ def init(): yield +@partial(jax.jit, static_argnums=(5, 6, 7, 9)) def general_dot_product_attention( query: ArrayLike, key: ArrayLike, @@ -102,29 +102,36 @@ def general_dot_product_attention( return context -def is_causal_mask(mask: AttnMaskType): - """ - Check if the mask is a causal mask - """ - return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK] - - -def make_causal_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array: +@jax.jit +def make_causal_mask( + segment_ids_q: ArrayLike, + segment_ids_kv: ArrayLike, + segment_pos_q: ArrayLike = None, + segment_pos_kv: ArrayLike = None, +) -> Array: """ Create inverse padded causal mask where `True` means allowing the corresponding position to participate in attention and `False` means masking out that position. + If segment_pos is not provided, aragne of the segment_ids will be applied. 
""" - q_idxs = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32), q_tokens.shape) - kv_idxs = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32), kv_tokens.shape) - inv_causal_mask = make_attention_mask(q_idxs, kv_idxs, jnp.greater_equal) + if segment_pos_q is None: + segment_pos_q = jnp.broadcast_to( + jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape + ) + if segment_pos_kv is None: + segment_pos_kv = jnp.broadcast_to( + jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape + ) + inv_causal_mask = make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal) return inv_causal_mask +@partial(jax.jit, static_argnums=(4, 5)) def make_mask( - q_token: ArrayLike, - kv_token: ArrayLike, - segment_pad_q: ArrayLike, - segment_pad_kv: ArrayLike, + segment_ids_q: ArrayLike, + segment_ids_kv: ArrayLike, + segment_pos_q: ArrayLike, + segment_pos_kv: ArrayLike, attn_mask_type: AttnMaskType, window_size: Optional[Tuple[int, int]] = None, ) -> Array: @@ -132,18 +139,31 @@ def make_mask( Create attention mask based on mask type. A `True` value in the mask means masking out the corresponding position and a `False` value means allowing that position to participate in attention. + + - segment_ids should start with 1, and using 0s for the paddings. + Expected that each segment starts without paddings. + - segment_pos marks the token position in the segments. + + A example pair of segments_ids and segment_pos: + segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5] + segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] """ inv_mask = make_attention_mask( - q_token, kv_token, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0)) + segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0)) ) - if is_causal_mask(attn_mask_type): - inv_causal_mask = make_causal_mask(q_token, kv_token) - inv_mask = combine_masks(inv_causal_mask, inv_mask) - if segment_pad_q is not None and segment_pad_kv is not None: - inv_pad_mask = make_attention_mask( - segment_pad_q, segment_pad_kv, lambda x, y: jnp.logical_and(x != 1, y != 1) + if attn_mask_type.is_causal(): + if segment_pos_q is None: + segment_pos_q = jnp.broadcast_to( + jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape + ) + if segment_pos_kv is None: + segment_pos_kv = jnp.broadcast_to( + jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape + ) + inv_causal_mask = make_attention_mask( + segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y) ) - inv_mask = combine_masks(inv_pad_mask, inv_mask) + inv_mask = combine_masks(inv_causal_mask, inv_mask) if window_size is not None: max_seqlen_q = inv_mask.shape[-2] @@ -157,7 +177,8 @@ def make_mask( return mask -def get_seqlens_and_offsets(segment_ids, segment_pad): +@jax.jit +def get_seqlens_and_offsets(segment_ids): batch, max_seqlen = segment_ids.shape bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen)) seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32)) @@ -165,7 +186,7 @@ def get_seqlens_and_offsets(segment_ids, segment_pad): def _find_offsets(x): same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0) - first_column = jnp.ones((x.shape[0], 1), dtype=bool) + first_column = x[..., :1] != 0 same_as_previous = jnp.hstack((first_column, same_as_previous)) return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))( same_as_previous @@ -173,13 +194,9 @@ def _find_offsets(x): 
offsets = _find_offsets(segment_ids) offsets = jnp.insert(offsets, -1, values=-1, axis=-1) - if segment_pad is not None: - segment_id_with_paddings = jnp.where(segment_pad, 0, segment_ids) - padding_aware_seqlen = bincount_vmap(segment_id_with_paddings) - output = jnp.insert(padding_aware_seqlen[..., 1:], -1, values=0, axis=-1) - else: - output = jnp.insert(seqlens, -1, values=0, axis=-1) - return output, offsets + seqlens = jnp.insert(seqlens, -1, values=0, axis=-1) + seqlens = jnp.where(seqlens, seqlens, -1) + return seqlens, offsets @jax.jit @@ -200,8 +217,8 @@ def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs): query, key, value, - bias=bias, - mask=mask, + bias, + mask, deterministic=not kwargs["is_training"], scale_factor=kwargs["scaling_factor"], dropout_rate=kwargs["dropout_probability"], @@ -228,7 +245,6 @@ def customcall_fused_dpa( TE customcall dot product attention implementation """ qkv_layout = kwargs["qkv_layout"] - is_thd = get_qkv_format(qkv_layout) == QKVFormat.THD match qkv_layout: case QKVLayout.BS3HD | QKVLayout.T3HD: query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value]) @@ -242,7 +258,7 @@ def customcall_fused_dpa( qkv_args = (query, key, value) case _: raise ValueError(f"Unsupported {qkv_layout=}") - if not is_thd: + if not qkv_layout.is_thd(): kwargs.pop("max_segments_per_seq") return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype) return fused_attn_thd( @@ -262,10 +278,10 @@ class BiasShape(Enum): Enum class to represent the different bias shapes used in the fused attention. """ - BIAS_1HSS = "1HSS" - BIAS_B1SS = "B1SS" - BIAS_BHSS = "BHSS" - BIAS_11SS = "11SS" + _1HSS = "1HSS" + _B1SS = "B1SS" + _BHSS = "BHSS" + _11SS = "11SS" @dataclass @@ -300,18 +316,12 @@ def _get_max_segments_per_sequence(self): def _check_configs(self): # TODO(rewang): probably adds this in is_fused_attn_available - if get_qkv_format(self.qkv_layout) == QKVFormat.THD and not self.attn_mask_type in [ - AttnMaskType.PADDING_MASK, - AttnMaskType.PADDING_CAUSAL_MASK, - ]: + if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding(): pytest.skip("THD format requires padding masks.") - qkv_format = get_qkv_format(self.qkv_layout) - if self.qkv_layout == QKVLayout.BS3HD or qkv_format == QKVFormat.THD: + if self.qkv_layout.is_qkvpacked(): if self.max_seqlen_q != self.max_seqlen_kv: pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv") - - if self.qkv_layout == QKVLayout.BS3HD or self.qkv_layout == QKVLayout.T3HD: if self.num_heads_q != self.num_heads_kv: pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv") @@ -339,15 +349,11 @@ def _check_configs(self): if ( self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS - and self.bias_shape != BiasShape.BIAS_1HSS + and self.bias_shape != BiasShape._1HSS ): - if self.attn_mask_type not in [ - AttnMaskType.NO_MASK, - AttnMaskType.CAUSAL_MASK, - ]: + if self.attn_mask_type.is_padding(): pytest.skip( - "B1SS, BHSS and 11SS bias shapes are only supported for " - "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK." 
+ "B1SS, BHSS and 11SS bias shapes are only supported for non-padding mask" ) elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: pytest.skip( @@ -370,18 +376,18 @@ def _setup_inputs(self): if self.attn_bias_type == AttnBiasType.NO_BIAS: bias_shape = None - elif self.bias_shape == BiasShape.BIAS_1HSS: + elif self.bias_shape == BiasShape._1HSS: bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv) - elif self.bias_shape == BiasShape.BIAS_B1SS: + elif self.bias_shape == BiasShape._B1SS: bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv) - elif self.bias_shape == BiasShape.BIAS_BHSS: + elif self.bias_shape == BiasShape._BHSS: bias_shape = ( self.batch_size, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv, ) - elif self.bias_shape == BiasShape.BIAS_11SS: + elif self.bias_shape == BiasShape._11SS: bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv) else: pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!") @@ -391,7 +397,7 @@ def _setup_inputs(self): self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0) if self.attn_bias_type != AttnBiasType.NO_BIAS: - if self.bias_shape == BiasShape.BIAS_1HSS: + if self.bias_shape == BiasShape._1HSS: self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0) else: # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for @@ -408,10 +414,10 @@ def _setup_inputs(self): else: self.bias = None - if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: - pad_ratio = 0.0 - else: + if self.attn_mask_type.is_padding(): pad_ratio = 0.3 + else: + pad_ratio = 0.0 def gen_valid(bs, max_seqlen, pad_ratio): pad_len = int(max_seqlen * pad_ratio) @@ -425,6 +431,8 @@ def generate_random_segment_ids( rng = np.random.default_rng(seed=seed) # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad segment_ids = np.zeros((batch_size, sequence_length), dtype=int) + segment_pos = np.zeros((batch_size, sequence_length), dtype=int) + # [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0] # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad segment_pad = np.zeros((batch_size, sequence_length), dtype=int) @@ -440,58 +448,62 @@ def generate_random_segment_ids( break segment_end = current_pos + segment_size segment_ids[i, current_pos:segment_end] = segment_id + segment_pos[i, current_pos:segment_end] = np.arange(segment_size) if with_segment_pad: num_valid = rng.integers(1, segment_size + 1) segment_pad[i, current_pos + num_valid : segment_end] = 1 current_pos = segment_end segment_id += 1 segment_pad[i, current_pos:sequence_length] = 1 - return segment_ids, segment_pad - if get_qkv_format(self.qkv_layout) == QKVFormat.THD: + segment_ids, segment_pos, segment_pad = map( + jnp.asarray, [segment_ids, segment_pos, segment_pad] + ) + segment_ids = jnp.where(segment_pad, 0, segment_ids) + return segment_ids, segment_pos, segment_pad + + if self.qkv_layout.is_thd(): self.num_segments_per_seq = 2 - self.token_q, self.segment_pad_q = generate_random_segment_ids( + self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 ) - # TODO(rewang): Check if qkvpacked supported different q/kv - # TODO(rewang): Causal with different q/kv segment_id fails - if self.qkv_layout == QKVLayout.T3HD or is_causal_mask(self.attn_mask_type): - self.token_kv = self.token_q - self.segment_pad_kv = self.segment_pad_q + if self.qkv_layout == QKVLayout.T3HD: + self.segment_ids_kv = 
self.segment_ids_q + self.segment_pos_kv = self.segment_pos_q + self.pad_kv = self.pad_q else: - self.token_kv, self.segment_pad_kv = generate_random_segment_ids( + self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids( self.batch_size, self.max_seqlen_kv, self.num_segments_per_seq, seed=2024, ) - self.pad_q = self.segment_pad_q - self.pad_kv = self.segment_pad_kv + self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q) + self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv) else: self.num_segments_per_seq = 1 - self.token_q, self.pad_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio) - self.token_kv, self.pad_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio) - self.segment_pad_q = self.segment_pad_kv = None + self.segment_ids_q, self.pad_q = gen_valid( + self.batch_size, self.max_seqlen_q, pad_ratio + ) + self.segment_ids_kv, self.pad_kv = gen_valid( + self.batch_size, self.max_seqlen_kv, pad_ratio + ) + self.segment_pos_q = self.segment_pos_kv = None + self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None + # For reference code self.mask = make_mask( - self.token_q, - self.token_kv, - self.segment_pad_q, - self.segment_pad_kv, + self.segment_ids_q, + self.segment_ids_kv, + self.segment_pos_q, + self.segment_pos_kv, self.attn_mask_type, self.window_size, ) - if get_qkv_format(self.qkv_layout) == QKVFormat.THD: - self.seqlens_q, self.offsets_q = get_seqlens_and_offsets( - self.token_q, self.segment_pad_q - ) - self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets( - self.token_kv, self.segment_pad_kv - ) + if self.qkv_layout.is_thd(): self.mask_for_customcall = None # THD format doesn't support mask else: - self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None self.mask_for_customcall = self.mask self.dropout_rng = dropout_key if self.dropout_prob > 0 else None @@ -547,13 +559,11 @@ def test_backward(self): """ self._setup_inputs() - if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape != BiasShape.BIAS_1HSS: - pytest.skip("Bias gradient calculation is only supported for 1HSS bias shape.") def grad_func(func, *args, **kwargs): # Gradient is small, use a gradient multiplier to amplify the gradient gradient_multiplier = self.max_seqlen_q * self.num_heads_q - if is_causal_mask(self.attn_mask_type): + if self.attn_mask_type.is_causal(): gradient_multiplier /= 10 # Keep only valid result for the gradient ret_valid = jnp.where( @@ -586,7 +596,7 @@ def grad_func(func, *args, **kwargs): } # We can compute dBias only for the [1, h, s, s] layout - arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape.BIAS_1HSS else (0, 1, 2) + arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape._1HSS else (0, 1, 2) # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation jitted_primitive = jit( @@ -629,7 +639,7 @@ def check_dqkv(primitive, reference, pad): check_dqkv(primitive_dk, reference_dk, self.pad_kv) check_dqkv(primitive_dv, reference_dv, self.pad_kv) - if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS: + if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS: primitive_dbias = primitive_dgrad[3] reference_dbias = reference_dgrad[3] @@ -658,16 +668,6 @@ def check_dqkv(primitive, reference, pad): ) -@pytest.mark.parametrize( - "attn_bias_type, bias_shape", - [ - pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), - 
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id="POST_SCALE_BIAS-1HSS"), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id="POST_SCALE_BIAS-B1SS"), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id="POST_SCALE_BIAS-BHSS"), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id="POST_SCALE_BIAS-11SS"), - ], -) @pytest.mark.parametrize( "attn_mask_type", [ @@ -736,6 +736,16 @@ class TestFusedAttn: pytest.param(False, id="INFERENCE"), ], ) + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._B1SS, id="POST_SCALE_BIAS-B1SS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._BHSS, id="POST_SCALE_BIAS-BHSS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"), + ], + ) def _test_forward( b, s_q, @@ -779,6 +789,13 @@ def _test_forward( runner.test_forward() @staticmethod + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + ], + ) def test_backward( b, s_q, diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 78a6225e1f..242bafa5e2 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -19,7 +19,11 @@ from jax import nn as jax_nn from jax import random as jax_random -from transformer_engine.jax.attention import AttnMaskType, make_swa_mask +from transformer_engine.jax.attention import ( + AttnMaskType, + canonicalize_attn_mask_type, + make_swa_mask, +) from transformer_engine.jax.fp8 import DType as TEDType PRNGKey = Any @@ -913,15 +917,7 @@ def apply_swa_mask( window_size: Tuple[int, int] = (-1, -1), ) -> Array: """Apply the sliding window mask to a given mask""" - mask_map = { - "no_mask": AttnMaskType.NO_MASK, - "padding": AttnMaskType.PADDING_MASK, - "causal": AttnMaskType.CAUSAL_MASK, - "padding_causal": AttnMaskType.PADDING_CAUSAL_MASK, - "causal_bottom_right": AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, - "padding_causal_bottom_right": AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, - } - _attn_mask_type = mask_map.get(attn_mask_type, None) + _attn_mask_type = canonicalize_attn_mask_type(attn_mask_type) assert _attn_mask_type is not None max_seqlen_q = original_mask.shape[-2] max_seqlen_kv = original_mask.shape[-1] diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 3ecc9bcd75..53451b6a78 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -46,6 +46,42 @@ class AttnMaskType(Enum): CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK + def is_causal(self): + """Returns True if the mask is a causal mask""" + return self in [ + AttnMaskType.CAUSAL_MASK, + AttnMaskType.PADDING_CAUSAL_MASK, + AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, + ] + + def is_padding(self): + """Returns True if the mask includes padding""" + return self in [ + AttnMaskType.PADDING_MASK, + AttnMaskType.PADDING_CAUSAL_MASK, + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, + ] + + def is_bottom_right(self): + """Returns True if the causal mask is calculated from the bottom-right section""" + return self in [ 
+ AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, + ] + + +class QKVFormat(Enum): + """ + SBHD: q,k,v memory layout with [s, b, ..., h, d] + BSHD: q,k,v memory layout with [b, s, ..., h, d] + THD: q,k,v memory layout is same as BSHD, but allow multiple segments packed in a sequence. + """ + + SBHD = NVTE_QKV_Format.NVTE_SBHD + BSHD = NVTE_QKV_Format.NVTE_BSHD + THD = NVTE_QKV_Format.NVTE_THD + class QKVLayout(Enum): """ @@ -66,17 +102,35 @@ class QKVLayout(Enum): THD_T2HD = NVTE_QKV_Layout.NVTE_THD_T2HD THD_THD_THD = NVTE_QKV_Layout.NVTE_THD_THD_THD - -class QKVFormat(Enum): - """ - SBHD: q,k,v memory layout with [s, b, ..., h, d] - BSHD: q,k,v memory layout with [b, s, ..., h, d] - THD: q,k,v memory layout is same as BSHD, but allow multiple segments packed in a sequence. - """ - - SBHD = NVTE_QKV_Format.NVTE_SBHD - BSHD = NVTE_QKV_Format.NVTE_BSHD - THD = NVTE_QKV_Format.NVTE_THD + def get_qkv_format(self): + """ + Return the corresponding qkv_format (BSHD, SBHD, THD) + """ + return QKVFormat(nvte_get_qkv_format(self.value)) + + def is_qkvpacked(self): + """ + Return True if the query, key, value is packed + """ + return self in [QKVLayout.BS3HD, QKVLayout.T3HD] + + def is_kvpacked(self): + """ + Return True if the key, value is packed + """ + return self in [QKVLayout.BSHD_BS2HD, QKVLayout.THD_T2HD] + + def is_separate(self): + """ + Return True if the query, key, value are three separate tensors + """ + return self in [QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_THD_THD] + + def is_thd(self): + """ + Return True if the layout belongs to THD + """ + return self in [QKVLayout.T3HD, QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD] class CPStrategy(Enum): @@ -92,13 +146,6 @@ class CPStrategy(Enum): RING = 2 -def get_qkv_format(qkv_layout): - """ - Get qkv_format from qkv_layout - """ - return QKVFormat(nvte_get_qkv_format(qkv_layout.value)) - - def make_swa_mask( max_seqlen_q: int, max_seqlen_kv: int, @@ -136,12 +183,8 @@ def make_swa_mask( swa_mask = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=dtype) if window_size is None: return swa_mask - bottom_right_masks = [ - AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, - AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, - ] left_window, right_window = window_size - if attn_mask_type in bottom_right_masks: + if attn_mask_type.is_bottom_right(): if left_window < 0: left_window = max_seqlen_kv if right_window < 0: @@ -310,7 +353,7 @@ def fused_attn( (jnp.ndarray): The output tensor from the fused attention. """ assert ( - get_qkv_format(qkv_layout) != QKVFormat.THD + not qkv_layout.is_thd() ), "Please use transformer_engine.jax.attention.fused_attn_thd for THD format." # Check inputs qkv @@ -327,11 +370,7 @@ def fused_attn( ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" # convert the mask to seqlens, mask doesn't support ragged offsets - if attn_mask_type in [ - AttnMaskType.NO_MASK, - AttnMaskType.CAUSAL_MASK, - AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, - ]: + if not attn_mask_type.is_padding(): batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout) q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32) kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32) @@ -448,7 +487,7 @@ def fused_attn_thd( QKVLayout.T3HD, 0.125, 0, True, 3) """ assert ( - get_qkv_format(qkv_layout) == QKVFormat.THD + qkv_layout.is_thd() ), "Please use transformer_engine.jax.attention.fused_attn for non-THD format." 
# Check inputs qkv diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 6591861057..f3dfca21ef 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -3,7 +3,7 @@ # See LICENSE for license information. """JAX/TE custom ops for attention""" from dataclasses import dataclass -from functools import partial, reduce, cache +from functools import partial, reduce import operator import os from typing import Optional, Tuple @@ -133,7 +133,6 @@ def get_fused_attn_backend(self): ) @staticmethod - @cache def is_non_deterministic_allowed(): """Check if non-deterministic kernels are allowed""" return bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) From 83dac8cf30d8abe2af421eb82ffd1c5a4fc859cb Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 17 Dec 2024 18:15:37 -0800 Subject: [PATCH 08/11] [PyTorch] Add weights_only=False for torch.load (#1374) add weights_only=False for torch.load Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/test_float8tensor.py | 2 +- tests/pytorch/test_sanity.py | 2 +- tests/pytorch/test_torch_save_load.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 51f4c695dc..a25ffa773c 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -339,7 +339,7 @@ def test_serialization( del x_fp8, byte_stream # Deserialize tensor - x_fp8 = torch.load(io.BytesIO(x_bytes)) + x_fp8 = torch.load(io.BytesIO(x_bytes), weights_only=False) del x_bytes # Check results diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 4f057c12fe..32d517460a 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -1101,7 +1101,7 @@ def get_model(dtype, config): del block block = get_model(dtype, config) - block.load_state_dict(torch.load(path)) + block.load_state_dict(torch.load(path, weights_only=False)) torch.set_rng_state(_cpu_rng_state_new) torch.cuda.set_rng_state(_cuda_rng_state_new) diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py index 7bf8fb99d5..be77109cb7 100644 --- a/tests/pytorch/test_torch_save_load.py +++ b/tests/pytorch/test_torch_save_load.py @@ -124,7 +124,7 @@ def forward(self, inp, weight): torch.save(model_in.state_dict(), tmp_filename) model_out = Test_TE_Export(precision, True) - model_out.load_state_dict(torch.load(tmp_filename)) + model_out.load_state_dict(torch.load(tmp_filename, weights_only=False)) model_out.eval() # scaling fwd @@ -263,7 +263,7 @@ def test_fp8_model_checkpoint( # to load the fp8 metadata before loading tensors. 
# # Load checkpoint - model.load_state_dict(torch.load(io.BytesIO(model_bytes))) + model.load_state_dict(torch.load(io.BytesIO(model_bytes), weights_only=False)) del model_bytes # Check that loaded model matches saved model @@ -450,7 +450,7 @@ def train_step( torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols) # Load checkpoint - model.load_state_dict(torch.load(io.BytesIO(model_bytes))) + model.load_state_dict(torch.load(io.BytesIO(model_bytes), weights_only=False)) del model_bytes # Check that new model's FP8 metadata matches saved model From f033498f6c941b190c869bfa09310c2de3efd2c9 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 17 Dec 2024 18:15:47 -0800 Subject: [PATCH 09/11] [PyTorch] Fix get_swa_mask() for padding masks (#1281) * WIP: fix get_swa_mask for padding Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix mask type setting Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix the order of checking valid swa and changing mask type Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revamp to get full mask Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/fused_attn/test_fused_attn.py | 28 +-- transformer_engine/pytorch/attention.py | 227 ++++++++++++-------- 2 files changed, 157 insertions(+), 98 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 4e995dabb1..dea31b5971 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -531,18 +531,22 @@ def test_dpa_bias_shapes(dtype, model_configs, model): model_configs_swa = { # test: b, h, hg, d, sq, skv, p, mask, bias - "swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), - "swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), - "swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), - "swa_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"), - "swa_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"), - "swa_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), - "swa_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), + "swa_1_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), + "swa_1_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), + "swa_2_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), + "swa_2_1": ModelConfig(2, 24, 
24, 128, 2048, 4096, 0.0, "causal", "no_bias"), + "swa_3_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), + "swa_3_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), + "swa_4_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), + "swa_4_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), + "swa_5_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "swa_5_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), + "swa_6_0": ModelConfig( + 4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "swa_6_1": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + ), } diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 8c529c58d0..be0d176520 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1024,27 +1024,51 @@ def swap_key_value_dict(self, batch_indices): @torch.no_grad() -def get_swa_mask( - window_size: Tuple[int, int], +def get_full_mask( max_seqlen_q: int, max_seqlen_kv: int, attn_mask_type: str = "no_mask", - attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None, + window_size: Tuple[int, int] = None, + attention_type: str = "self", + bottom_right_alignment: bool = True, ) -> torch.Tensor: """ - Convert sliding window `window_size` to an equivalent "`arbitrary`" mask. - For "`causal`" mask type, the sliding window diagonal is aligned to the top left corner, - and for other mask types, the bottom right corner. + Get full attention mask in [..., max_seqlen_q, max_seqlen_kv] shape, based on `attn_mask_type`, + `attention_mask`, and `window_size`. For sliding window attention, the diagonal alignment depends + on both `attn_mask_type` and `bottom_right_alignment`, as detailed below.:: + + attn_mask_type output shape diagonal alignment + -------------------------------------------------------------------------------------------- + no_mask [1, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment + causal [1, 1, max_seqlen_q, max_seqlen_kv] always top left + causal_bottom_right [1, 1, max_seqlen_q, max_seqlen_kv] always bottom right + padding [batch_size, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment + padding_causal [batch_size, 1, max_seqlen_q, max_seqlen_kv] always top left + padding_causal_bottom_right [batch_size, 1, max_seqlen_q, max_seqlen_kv] always bottom right + arbitrary same as attention_mask follow bottom_right_alignment + + .. note:: + + For "padding_bottom_right" mask, or "padding" mask with `bottom_right_alignment` = True, the bottom right + diagonal comes from the bottom right corner of the [actual_seqlens_q[i], actual_seqlens_kv[i]] matrix, + i = 0,...,batch_size-1, not the [max_seqlen_q, max_seqlen_kv] matrix. 
For example, with max_seqlen_q = 4, + max_seqlen_kv = 4, attn_mask_type = "padding", attention_type = "cross", and attention_mask = ( + [[False, False, True, True], [False, False, False, False]], + [[False, False, False, True], [False, True, True, True]]), the returned full attention mask has [2, 4, 4] + shape and is,:: + + [[[False, False, False, True], + [False, False, False, True], + [ True, True, True, True], + [ True, True, True, True]], + [[False, True, True, True], + [False, True, True, True], + [False, True, True, True], + [False, True, True, True]]] Parameters ---------- - window_size: Tuple[int, int] - Sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding - window and causal mask specifically. Both `causal` and `causal_bottom_right` masks - map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on - `attn_mask_type`. max_seqlen_q: int Maximum sequence length for queries. max_seqlen_kv: int @@ -1052,33 +1076,105 @@ def get_swa_mask( attn_mask_type: str, default = `no_mask` Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`", "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"} - attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]], + attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None` - Boolean tensor(s) used to mask out attention softmax input. + Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention + for the requirements of `attention_mask` for different `attn_mask_type`s. + window_size: Tuple[int, int], default = `None` + Sliding window size for local attention, where query at position i attends to keys + in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding + window and causal mask specifically. Both `causal` and `causal_bottom_right` masks + map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on + `attn_mask_type`. + attention_type: str, default = "self" + Attention type, {"self", "cross"} + bottom_right_alignment: bool, default = `True` + Whether to align the diagonal of the sliding window attention to the bottom right (`True`) + or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly + specifies "causal" or "causal_bottom_right". Returns ---------- + attn_mask_type: str + For sliding window attention (>=0, >0), "arbitrary"; otherwise, the same as input `attn_mask_type` attention_mask: torch.Tensor - Combined `attention_mask` (input) and sliding window attention mask. - The shape is [max_seqlen_q, max_seqlen_kv] when input `attention_mask` is None; - else, the same shape as input `attention_mask`. + The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size` + actual_seqlens_q: torch.Tensor + For padding masks, the actual sequence lengths for queries, in shape [batch_size]. + For other masks, `None`. + actual_seqlens_kv: Optional[torch.Tensor], default = `None` + For padding masks, the actual sequence lengths for keys and values, in shape [batch_size]. + For other masks, `None`. 
""" - mask = torch.ones(max_seqlen_q, max_seqlen_kv, dtype=torch.bool, device="cuda") - if attn_mask_type in ["causal"]: - left = window_size[0] if window_size[0] != -1 else max_seqlen_q - right = window_size[1] if window_size[1] != -1 else max_seqlen_q - mask_upper = torch.triu(mask, diagonal=-left) - mask_lower = torch.tril(mask_upper, diagonal=right) - else: - left = window_size[0] if window_size[0] != -1 else max_seqlen_kv - right = window_size[1] if window_size[1] != -1 else max_seqlen_kv - mask_upper = torch.triu(mask, diagonal=max_seqlen_kv - max_seqlen_q - left) - mask_lower = torch.tril(mask_upper, diagonal=max_seqlen_kv - max_seqlen_q + right) - attn_mask_type = "arbitrary" - mask = mask_lower.logical_not() + # perform basic checks + change_type = window_size is not None and ( + window_size[0] != -1 or window_size[1] not in [-1, 0] + ) + if window_size is None: + window_size = (-1, -1) + if "causal" in attn_mask_type: + window_size = (window_size[0], 0) + window_size = ( + max_seqlen_kv if window_size[0] == -1 else window_size[0], + max_seqlen_q if window_size[1] == -1 else window_size[1], + ) + + # apply padding mask + actual_seqlens_q = None + actual_seqlens_kv = None + if "padding" in attn_mask_type: + if attention_type == "self": + attention_mask = torch.logical_or( + attention_mask.squeeze(1).unsqueeze(3), attention_mask + ) + else: + attention_mask = torch.logical_or( + attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1] + ) + m = attention_mask.logical_not() + actual_seqlens_q = m[:, 0, :, 0].sum(dim=1) + actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1) + + # apply SWA mask + mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( + 1, 1, max_seqlen_q, 1 + ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv) + swa_left = None + swa_right = None + if attn_mask_type == "causal_bottom_right" or ( + attn_mask_type in ["no_mask", "arbitrary"] and bottom_right_alignment + ): + swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0] + swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1] + elif attn_mask_type in ["causal", "padding_causal"] or ( + attn_mask_type in ["no_mask", "padding", "arbitrary"] and not bottom_right_alignment + ): + swa_left = mask - window_size[0] + swa_right = mask + window_size[1] + elif attn_mask_type == "padding_causal_bottom_right" or ( + attn_mask_type == "padding" and bottom_right_alignment + ): + batch_size = attention_mask.shape[0] + swa_left = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + ( + actual_seqlens_kv - actual_seqlens_q - window_size[0] + ).view(batch_size, 1, 1, 1) + swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + ( + actual_seqlens_kv - actual_seqlens_q + window_size[1] + ).view(batch_size, 1, 1, 1) + swa_mask = torch.logical_not( + torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right < 0, 1, 0) + ) if attention_mask is not None: - mask = torch.logical_and(attention_mask, mask) - return attn_mask_type, mask + attention_mask = torch.logical_or(swa_mask, attention_mask) + else: + attention_mask = swa_mask + + # change mask type + if change_type: + attn_mask_type = "arbitrary" + + return attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv @torch.no_grad() @@ -4733,6 +4829,7 @@ def forward( cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument attn_mask_type: str = "causal", attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + window_size: 
Optional[Tuple[int, int]] = None, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, @@ -4752,53 +4849,15 @@ def forward( query_layer.shape[0], key_layer.shape[0], ) - if "padding" in attn_mask_type: - if self.attention_type == "self": - assert attention_mask.shape == ( - batch_size, - 1, - 1, - max_seqlen_q, - ), "attention_mask should be a single tensor with [b, 1, 1, sq] shape!" - attention_mask = torch.logical_or( - attention_mask.squeeze(1).unsqueeze(3), attention_mask - ) - else: - assert ( - len(attention_mask) == 2 - and attention_mask[0].shape == (batch_size, 1, 1, max_seqlen_q) - and attention_mask[1].shape == (batch_size, 1, 1, max_seqlen_kv) - ), ( - "attention_mask should be a tuple of two tensors with shapes " - "[b, 1, 1, sq] and [b, 1, 1, skv]!" - ) - attention_mask = torch.logical_or( - attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1] - ) - mask = attention_mask.squeeze(1).logical_not() - actual_seqlens_q = mask[:, :, 0].sum(dim=1) - actual_seqlens_kv = mask[:, 0, :].sum(dim=1) - mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( - 1, 1, max_seqlen_q, 1 - ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( - 1, 1, 1, max_seqlen_kv - ) - if attn_mask_type == "padding_causal": - attention_mask = torch.logical_or( - torch.where(mask.view(1, 1, max_seqlen_q, max_seqlen_kv) < 0, 1, 0), - attention_mask, - ) - if attn_mask_type == "padding_causal_bottom_right": - attention_mask = torch.logical_or( - torch.where( - mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) - + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1) - < 0, - 1, - 0, - ), - attention_mask, - ) + + attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask( + max_seqlen_q, + max_seqlen_kv, + attn_mask_type=attn_mask_type, + attention_mask=attention_mask, + window_size=window_size, + attention_type=self.attention_type, + ) batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 @@ -8274,12 +8333,6 @@ def forward( ) if use_unfused_attention: - if window_size is not None and ( - window_size[0] != -1 or window_size[1] not in [-1, 0] - ): - attn_mask_type, attention_mask = get_swa_mask( - window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask - ) if checkpoint_core_attention: return self._checkpointed_attention_forward( self.unfused_attention, @@ -8291,6 +8344,7 @@ def forward( cu_seqlens_kv=cu_seqlens_kv, attn_mask_type=attn_mask_type, attention_mask=attention_mask, + window_size=window_size, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, @@ -8304,6 +8358,7 @@ def forward( cu_seqlens_kv=cu_seqlens_kv, attn_mask_type=attn_mask_type, attention_mask=attention_mask, + window_size=window_size, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, From a3b32ec6cb15dac8dc96ae03e40f51dfd072f195 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 18 Dec 2024 10:47:36 -0500 Subject: [PATCH 10/11] [JAX] Move parallel encoder tests to L0 distributed test set. 
(#1356) * Move test distributed encoder to L0 distributed test suit --------- Signed-off-by: Phuong Nguyen Co-authored-by: Reese Wang --- qa/L0_jax_distributed_unittest/test.sh | 15 +++++++++++++++ qa/L0_jax_unittest/test.sh | 3 +-- 2 files changed, 16 insertions(+), 2 deletions(-) create mode 100644 qa/L0_jax_distributed_unittest/test.sh diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh new file mode 100644 index 0000000000..f9e16793a4 --- /dev/null +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -0,0 +1,15 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -xe + +: ${TE_PATH:=/opt/transformerengine} + +pip install -r $TE_PATH/examples/jax/encoder/requirements.txt + +# Make encoder tests to have run-to-run deterministic to have the stable CI results +export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index db3aa31951..278a3c8b44 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -20,5 +20,4 @@ pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist # Make encoder tests to have run-to-run deterministic to have the stable CI results export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py From 838345eba4fdd2a169dd9e087d39c30a360e684a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 19 Dec 2024 21:32:41 -0800 Subject: [PATCH 11/11] [common/PyTorch] Add cuDNN SWA (left, 0) + padding + bottom right causal (#1378) * add swa (left,0) + padding + brcm support Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * final fixes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * upgrade to FE 1.9-rc Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix jax tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * skip thd + CP + fused attn tests for cuDNN 9.6+ due to different stats shapes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- qa/L0_pytorch_unittest/test.sh | 2 +- tests/jax/test_fused_attn.py | 18 +- tests/pytorch/fused_attn/test_fused_attn.py | 186 ++++++++++++------ .../fused_attn/test_fused_attn_with_cp.py | 2 + 
.../common/fused_attn/fused_attn.cpp | 49 +++-- .../fused_attn_f16_arbitrary_seqlen.cu | 6 +- transformer_engine/pytorch/attention.py | 31 ++- 8 files changed, 195 insertions(+), 101 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 936021bfed..cc5632eda7 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 936021bfed8c91dc416af1588b2c4eca631a9e45 +Subproject commit cc5632eda70bbdac34455c2d94066d27d10e2699 diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 17307574a9..61dd15d015 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -13,7 +13,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py pytest -v -s $TE_PATH/tests/pytorch/test_jit.py -NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py @@ -22,3 +21,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py +NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 759ea893ef..10da7486cf 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -170,8 +170,7 @@ def make_mask( max_seqlen_kv = inv_mask.shape[-1] inv_swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, window_size, attn_mask_type) inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape) - # In inv_swa_mask and inv_mask 0 is masked out - inv_mask = jnp.where(inv_mask != 0, inv_swa_mask, inv_mask) + inv_mask = combine_masks(inv_mask, inv_swa_mask) mask = jnp.logical_not(inv_mask) return mask @@ -315,6 +314,13 @@ def _get_max_segments_per_sequence(self): return self.num_segments_per_seq + 1 def _check_configs(self): + # TODO(rewang): Fix THD + PADDING_CAUSAL + SWA reference + if ( + self.qkv_layout.is_thd() + and self.attn_mask_type == AttnMaskType.PADDING_CAUSAL_MASK + and self.window_size is not None + ): + pytest.skip("THD + PADDING_CAUSAL + SWA reference is not implemented.") # TODO(rewang): probably adds this in is_fused_attn_available if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding(): pytest.skip("THD format requires padding masks.") @@ -504,7 +510,13 @@ def generate_random_segment_ids( if self.qkv_layout.is_thd(): self.mask_for_customcall = None # THD format doesn't support mask else: - self.mask_for_customcall = self.mask + self.mask_for_customcall = make_mask( + self.segment_ids_q, + self.segment_ids_kv, + self.segment_pos_q, + self.segment_pos_kv, + self.attn_mask_type, + ) self.dropout_rng = dropout_key if self.dropout_prob > 0 else None self.scaling_factor = 1.0 / sqrt(self.head_dim) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 
dea31b5971..588e6e4ecd 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -237,19 +237,18 @@ def test_dot_product_attention( tols = dict(atol=1.5e-2, rtol=1.5e-2) config = model_configs[model] is_mla = config.head_dim_qk != config.head_dim_v + is_mqa_gqa = config.num_heads != config.num_gqa_groups if qkv_layout is None: if config.attn_type == "self": - qkv_layout = "sb3hd" if not is_mla else "sbhd_sbhd_sbhd" + qkv_layout = "sb3hd" if not is_mla and not is_mqa_gqa else "sbhd_sbhd_sbhd" else: - qkv_layout = "bshd_bs2hd" if not is_mla else "bshd_bshd_bshd" + qkv_layout = "bshd_bs2hd" if not is_mla and not is_mqa_gqa else "bshd_bshd_bshd" if "3" in qkv_layout and config.attn_type == "cross": pytest.skip("No need to test this layout for cross attention") - # Test backend availability - window_size = (-1, -1) - if swa: - window_size = [2, 2] - config.window_size = check_set_window_size(config.attn_mask_type, window_size) + if config.window_size == (-1, -1) and swa: + config.window_size = [2, 2] + config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) available_backends, fused_attn_backends = _get_attention_backends( config, qkv_dtype=dtype, @@ -334,16 +333,16 @@ def test_dot_product_attention( is_training, ) - if unfused_attn_supported and fused_attn_supported: - logging.info("[test_dot_product_attention]: unfused attn vs fused attn") - torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) - for i, _ in enumerate(unfused_attn_bwd): - torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols) if unfused_attn_supported and flash_attn_supported: logging.info("[test_dot_product_attention]: unfused attn vs flash attn") torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols) for i, _ in enumerate(flash_attn_bwd): torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols) + if unfused_attn_supported and fused_attn_supported: + logging.info("[test_dot_product_attention]: unfused attn vs fused attn") + torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) + for i, _ in enumerate(unfused_attn_bwd): + torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols) if fused_attn_supported and flash_attn_supported: logging.info("[test_dot_product_attention]: fused attn vs flash attn") torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols) @@ -399,30 +398,41 @@ def test_dpa_mla(dtype, model_configs, model): model_configs_mask = { # test: b, h, hg, d, sq, skv, p, mask, bias - "mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"), - "mask_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"), - "mask_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "mask_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), - "mask_3_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"), - "mask_3_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), - "mask_4_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding", "no_bias"), - "mask_4_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), - "mask_5_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), - "mask_5_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding_causal", "no_bias"), - "mask_6_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, 
"padding_causal", "no_bias"), - "mask_7_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "mask_7_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), - "mask_8_0": ModelConfig( - 2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + "mask_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), + "mask_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "mask_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), + "mask_2_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), + "mask_2_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), + "mask_2_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), + "mask_3_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), + "mask_3_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"), + "mask_3_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), + "mask_4_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "mask_4_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "mask_4_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), + "mask_5_0": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "mask_5_1": ModelConfig( + 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "mask_5_2": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "mask_6_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"), + "mask_6_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"), + "mask_7_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), + "mask_7_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), + "mask_8_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding", "no_bias"), + "mask_8_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding", "no_bias"), + "mask_9_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding_causal", "no_bias"), + "mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding_causal", "no_bias"), + "mask_10_0": ModelConfig( + 2, 24, 24, 128, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias" ), - "mask_8_1": ModelConfig( - 1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + "mask_10_1": ModelConfig( + 2, 16, 16, 256, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias" ), - "mask_9_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"), - "mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"), - "mask_10_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), - "mask_10_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), } @@ -531,20 +541,28 @@ def test_dpa_bias_shapes(dtype, model_configs, model): model_configs_swa = { # test: b, h, hg, d, sq, skv, p, mask, bias - "swa_1_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), - "swa_1_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), - "swa_2_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), - "swa_2_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), - "swa_3_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - 
"swa_3_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), - "swa_4_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "swa_4_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), - "swa_5_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "swa_5_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), - "swa_6_0": ModelConfig( - 4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" - ), + "swa_1_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), + "swa_1_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), + "swa_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), + "swa_2_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), + "swa_3_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), + "swa_3_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), + "swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), + "swa_4_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), + "swa_4_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding", "no_bias"), + "swa_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), + "swa_5_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "swa_5_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "swa_5_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), "swa_6_1": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "swa_6_2": ModelConfig( + 2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "swa_6_3": ModelConfig( 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" ), } @@ -623,18 +641,57 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"] model_configs_layout_thd = { # test: b, h, hg, d, sq, skv, p, mask, bias - "layout_0_1": ModelConfig(3, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"), - "layout_0_2": ModelConfig(8, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"), - "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), - "layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), - "layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "layout_1_3": ModelConfig(1, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_1_4": ModelConfig(8, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"), - "layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), - "layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"), + "layout_0_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), + "layout_0_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"), + "layout_0_2": 
ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), + "layout_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "layout_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "layout_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), + "layout_2_0": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "layout_2_1": ModelConfig( + 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "layout_2_2": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "layout_3_0": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4) + ), + "layout_3_1": ModelConfig( + 2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4) + ), + "layout_3_2": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias", window_size=(4, 4) + ), + "layout_4_0": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + ), + "layout_4_1": ModelConfig( + 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + ), + "layout_4_2": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + ), + "layout_5_0": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0) + ), + "layout_5_1": ModelConfig( + 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0) + ), + "layout_5_2": ModelConfig( + 2, + 24, + 24, + 128, + 2048, + 4096, + 0.0, + "padding_causal_bottom_right", + "no_bias", + window_size=(4, 0), + ), } @@ -651,11 +708,13 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout): config = model_configs[model] if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: pytest.skip("qkv_layout not applicable for MQA/GQA") + logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = True") pad_between_seqs = True test_dot_product_attention( dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs ) if get_cudnn_version() >= (9, 3, 0): + logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = False") # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run pad_between_seqs = False test_dot_product_attention( @@ -695,9 +754,12 @@ def _run_dot_product_attention( ) seqlens_kv = seqlens_q if config.attn_type == "cross": - seqlens_q = torch.randint( - 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" - ) + if config.max_seqlen_q > 1: + seqlens_q = torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) + else: + seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda") seqlens_kv = torch.randint( 1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda" ) diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 1007d6aa34..fd8e543adc 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -121,6 +121,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") if dtype == "fp8" and get_device_compute_capability() < (9, 0): pytest.skip("FP8 attention 
is only supported on sm90+!") + if qkv_format == "thd" and get_cudnn_version() >= (9, 6, 0): + pytest.skip("THD format is not supported for cuDNN 9.6+!") config = model_configs_fused_attn[model] if qkv_format == "thd" and config.num_heads != config.num_gqa_groups: diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 9cde765401..32e6d4df8f 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -135,6 +135,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( !requires_64bit_ragged_offset) { flag_m512 = true; } + // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging if ( // architecture ((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) || (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) && @@ -152,7 +153,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) && // bias type ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || - ((cudnn_runtime_version >= 8906) && + (cudnn_runtime_version >= 8906 && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || (bias_type == NVTE_Bias_Type::NVTE_ALIBI && attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK && @@ -161,43 +162,67 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && sm_arch_ >= 90) || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) || - ((cudnn_runtime_version >= 90000) && + (cudnn_runtime_version >= 90000 && (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) && // mask type + // pre-8.9.6: causal ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || - ((cudnn_runtime_version >= 8906) && + // 8.9.6: {bshd, sbhd} + {no_mask, causal, padding, padding_causal} + (cudnn_runtime_version >= 8906 && + (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || - ((cudnn_runtime_version >= 90300) && - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + // 9.1: adds thd + {padding, padding_causal} + (cudnn_runtime_version >= 90100 && qkv_format == NVTE_QKV_Format::NVTE_THD && + (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || + // 9.3: adds {bshd, sbhd} + causal_bottom_right + self/cross-attn (sq <= skv) + (cudnn_runtime_version >= 90300 && (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && - max_seqlen_q <= max_seqlen_kv && dropout == 0.0)) && + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && + max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || + // 9.6: adds {bshd, sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv) + (cudnn_runtime_version >= 90600 && + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && 
max_seqlen_q <= max_seqlen_kv && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)) && // bias + mask combination (!(cudnn_runtime_version >= 8906 && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && // qkv format - ((qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) || + (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD || (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || - (cudnn_runtime_version >= 90600)))) && + cudnn_runtime_version >= 90600))) && // sliding window + // pre-9.2: full attn, causal ((cudnn_runtime_version < 90200 && window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || + // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd} (cudnn_runtime_version >= 90200 && ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && max_seqlen_q == max_seqlen_kv)) && - dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + max_seqlen_q <= max_seqlen_kv && dropout == 0.0 && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || - qkv_format == NVTE_QKV_Format::NVTE_SBHD))))) && + qkv_format == NVTE_QKV_Format::NVTE_SBHD)))) || + // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd} + (cudnn_runtime_version >= 90600 && + ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || + ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && + max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + dropout == 0.0)))) && // check 64-bit ragged offset support (supported_ragged_offset_size)) { flag_arb = true; diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index b706eadace..cade624c8d 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -71,7 +71,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( is_bottom_right = false; } bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); bool is_dropout = (is_training && dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); @@ -451,7 +452,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( is_bottom_right = false; } bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || + (mask_type == 
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); bool is_dropout = (dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index be0d176520..9268b9636e 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -602,6 +602,12 @@ def get_attention_backend( "Disabling FusedAttention as it does not support context parallelism with MLA" ) use_fused_attention = False + elif cudnn_version >= (9, 6, 0) and qkv_format == "thd": + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with THD for" + " cuDNN 9.6+" + ) + use_fused_attention = False # Filter: Attention mask # attn_mask_type | attention_mask | supported backends @@ -618,9 +624,7 @@ def get_attention_backend( # self-attention | | All # cross-attention | | FusedAttention, UnfusedDotProductAttention # causal_bottom_right | None | All - # padding_causal_bottom_right | Same as "padding" | - # self-attention | | All - # cross-attention | | FlashAttention, UnfusedDotProductAttention + # padding_causal_bottom_right | Same as "padding" | All # arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention # | [b, h, sq, skv] | if attn_mask_type == "arbitrary": @@ -697,29 +701,16 @@ def get_attention_backend( " for FP8" ) use_fused_attention = False - elif window_size[1] != 0 or attention_dropout != 0.0 or qkv_format == "thd": + elif window_size[1] != 0 or attention_dropout != 0.0: logger.debug( "Disabling FusedAttention as it only supports sliding window attention " - "with causal mask, no dropout, and qkv_format = bshd/sbhd" - ) - use_fused_attention = False - elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [ - "no_mask", - "padding", - "causal_bottom_right", - "padding_causal_bottom_right", - ]: - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention " - "with attn_mask_type = %s for cross-attention", - attn_mask_type, + "with (left, 0) and no dropout" ) use_fused_attention = False - elif "padding" in attn_mask_type: + elif max_seqlen_q > max_seqlen_kv: logger.debug( "Disabling FusedAttention as it does not support sliding window attention " - "with attn_mask_type = %s", - attn_mask_type, + "with s_q > s_kv for cross-attention" ) use_fused_attention = False if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
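Taken together, patches 09 and 11 allow a (left, 0) sliding window to be combined with padding and bottom-right-causal masks end to end. Below is a minimal usage sketch, not taken from the patches, assuming transformer_engine.pytorch.DotProductAttention as exercised by the tests above; on cuDNN 9.6+ this combination can route to FusedAttention, while older stacks may fall back to a path that builds the equivalent mask via get_full_mask().

import torch
import transformer_engine.pytorch as te

b, h, d, sq = 2, 16, 64, 2048
attn = te.DotProductAttention(
    num_attention_heads=h,
    kv_channels=d,
    qkv_format="sbhd",
    attn_mask_type="padding_causal_bottom_right",
    window_size=(128, 0),   # sliding window (left, 0), newly allowed with padding masks
    attention_dropout=0.0,  # SWA in FusedAttention requires dropout == 0.0
)

q, k, v = (torch.randn(sq, b, h, d, dtype=torch.bfloat16, device="cuda") for _ in range(3))

# Boolean padding mask for self-attention, shape [b, 1, 1, sq]; True = masked out.
seqlens = torch.tensor([2048, 1536], device="cuda")
padding_mask = (
    torch.arange(sq, device="cuda")[None, None, None, :] >= seqlens[:, None, None, None]
)

out = attn(q, k, v, attention_mask=padding_mask)  # [sq, b, h * d]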