[PyTorch] cuda graph support #575

Merged 8 commits on Apr 12, 2024
2 changes: 2 additions & 0 deletions docs/api/pytorch.rst
@@ -41,4 +41,6 @@ pyTorch

.. autoapifunction:: transformer_engine.pytorch.onnx_export

.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables

.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
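
For context on the new API entry above, here is a minimal usage sketch of `make_graphed_callables`. The exact signature is assumed to mirror `torch.cuda.make_graphed_callables` (a module plus a tuple of sample inputs), and the `te.Linear` sample module is illustrative only, not taken from this diff.

```python
# Hypothetical sketch: capture a Transformer Engine module into a CUDA graph.
# Assumes make_graphed_callables(module, sample_args) mirrors
# torch.cuda.make_graphed_callables; not taken from this PR's diff.
import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024).cuda()
sample_input = torch.randn(32, 1024, device="cuda", requires_grad=True)

# Warm up and capture forward/backward into CUDA graphs.
graphed_layer = te.make_graphed_callables(layer, (sample_input,))

# Replay: inputs should match the sample shape/dtype used at capture time.
out = graphed_layer(torch.randn(32, 1024, device="cuda", requires_grad=True))
out.sum().backward()
```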
5 changes: 3 additions & 2 deletions qa/L0_pytorch_unittest/test.sh
@@ -9,9 +9,10 @@ set -e
pip install pytest==6.2.5 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
60 changes: 48 additions & 12 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -5,7 +5,6 @@
import functools
from importlib.metadata import version
import os
import math
from typing import Any, Dict, List, Tuple, Union

from pkg_resources import packaging
@@ -28,15 +27,9 @@
fused_attn_bwd,
fused_attn_fwd,
)
from transformer_engine.pytorch.distributed import (
_set_cuda_rng_state,
CudaRNGStatesTracker,
)
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch.module.base import (
TransformerEngineBaseModule,
_prepare_backward,
)
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
init_method_normal,
@@ -58,10 +51,18 @@

_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))


def reset_rng_states() -> None:
"""Revert back to initial RNG state"""
torch.set_rng_state(_cpu_rng_state)
_set_cuda_rng_state(_cuda_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)

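
The `reset_rng_states` helper above relies on module-level `_cpu_rng_state`/`_cuda_rng_state` that fall outside the displayed hunk; a minimal sketch of the assumed setup, snapshotted once at import time:

```python
# Assumed module-level snapshots used by reset_rng_states() above
# (taken once at import time, outside the displayed hunk).
import torch

_cpu_rng_state = torch.get_rng_state()        # CPU generator state
_cuda_rng_state = torch.cuda.get_rng_state()  # default CUDA device state
```

With `torch.cuda.set_rng_state` now called directly, the helper no longer depends on the internal `_set_cuda_rng_state` utility removed from the imports above.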

@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
fp8.FP8GlobalStateManager.reset()


@functools.cache
def _cudnn_version() -> Tuple[int, int, int]:
@@ -71,6 +72,7 @@ def _cudnn_version() -> Tuple[int, int, int]:
minor, patch = divmod(encoded_version, 100)
return (major, minor, patch)


class ModelConfig:
def __init__(
self,
@@ -103,6 +105,7 @@ def __init__(
self.num_layers = num_layers
self.bias_shape = bias_shape


def _is_fused_attention_supported(
config: ModelConfig,
dtype: torch.dtype,
@@ -151,24 +154,28 @@ def _is_fused_attention_supported(
return True, backends
return False, backends


@functools.cache
def _is_flash_attention_2_available() -> bool:
"""Check if flash-attn 2.0+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2")


@functools.cache
def _is_flash_attention_2_1() -> bool:
"""Check if flash-attn 2.1+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2.1")


@functools.cache
def _is_flash_attention_2_3() -> bool:
"""Check if flash-attn 2.3+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2.3")


def _is_flash_attention_supported(config: ModelConfig) -> bool:
"""Check if FlashAttention supports a model configuration"""
if get_device_compute_capability() < (8, 0):
@@ -184,6 +191,7 @@ def _is_flash_attention_supported(config: ModelConfig) -> bool:
return False
return True


def _is_unfused_attention_supported(config: ModelConfig) -> bool:
"""Check if UnfusedDotProductAttention supports a model configuration"""
if ("padding" in config.attn_mask_type):
@@ -192,6 +200,7 @@ def _is_unfused_attention_supported(config: ModelConfig) -> bool:
return False
return True


model_configs_base = {
# test: b, h, hg, d, sq, skv, p, mask, bias # attn , backend
"base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0
@@ -200,11 +209,13 @@ def _is_unfused_attention_supported(config: ModelConfig) -> bool:
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
}


param_types = [torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]


def get_swa(seq_q, seq_kv, w=None):
"""Generate a random sliding window size (left, right) if w is None,
and create its equivalent attention mask in [seq_q, seq_kv] shape"""
@@ -216,6 +227,7 @@ def get_swa(seq_q, seq_kv, w=None):
ml = ~ ml
return w, ml


@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@@ -313,6 +325,7 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace
for i,_ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)


@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@@ -321,6 +334,7 @@ def test_dpa_checkpoint(dtype, model_configs, model):
"""Test DotProductAttention module with checkpointing"""
test_dot_product_attention(dtype, model_configs, model, True, True, None, False)


model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
@@ -337,6 +351,7 @@ def test_dpa_checkpoint(dtype, model_configs, model):
"mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
}


@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_mask])
@@ -345,6 +360,7 @@ def test_dpa_mask(dtype, model_configs, model):
"""Test DotProductAttention module with different mask types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)


model_configs_bias = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
@@ -373,6 +389,7 @@ def test_dpa_mask(dtype, model_configs, model):
"bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped
}


@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias])
@@ -381,6 +398,7 @@ def test_dpa_bias(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)


model_configs_bias_shapes = {
# test: b, h, hg, d, sq, skv, p,
"bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0,
@@ -398,6 +416,7 @@ def test_dpa_bias(dtype, model_configs, model):
"causal", "alibi", bias_shape='bhss', alibi_type='custom'),
}


@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
@@ -413,6 +432,8 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
"swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
}


@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@@ -428,6 +449,8 @@ def test_dpa_sliding_window(dtype, model_configs, model):
"alibi_2_0": ModelConfig(2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type= "custom"),
"alibi_2_1": ModelConfig(1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type= "custom"),
}


@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@@ -436,13 +459,15 @@ def test_dpa_alibi_slopes(dtype, model_configs, model):
"""Test DotProductAttention module with ALiBi slopes"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)


qkv_layouts = [
'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd',
# will add tests for thd layouts later when the support is available in fused attention
#'t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd',
]


model_configs_layout = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
@@ -455,6 +480,7 @@ def test_dpa_alibi_slopes(dtype, model_configs, model):
"layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"),
}


@pytest.mark.skipif(_cudnn_version() < (8,9,5), reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout])
@@ -464,6 +490,7 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False)


def _run_dot_product_attention(
dtype: torch.dtype,
config: ModelConfig,
@@ -646,6 +673,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:

return out, (inp[0].grad, inp[1].grad, inp[2].grad)

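
For reference, a hedged sketch of how a `CudaRNGStatesTracker` (still imported after the import cleanup above) might be wired into `DotProductAttention` for dropout, similar in spirit to the `get_dummy_cuda_rng_tracker` helper referenced in the hunk header. The tracker name, seed, and module arguments are illustrative assumptions, not taken from this diff.

```python
# Hypothetical sketch: register a named CUDA RNG state and hand the tracker
# to DotProductAttention for attention dropout. Names/values are illustrative.
from transformer_engine.pytorch import DotProductAttention
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker

_rng_tracker = CudaRNGStatesTracker()
_rng_tracker.add("model-parallel-rng", 1234)  # named generator state + seed

def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
    """Return the tracker used for attention dropout."""
    return _rng_tracker

attn = DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attention_dropout=0.1,
    get_rng_state_tracker=get_dummy_cuda_rng_tracker,
)
```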

model_configs_te_layer = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
@@ -658,6 +686,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
}


@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@@ -742,6 +771,7 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)


@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@@ -755,6 +785,7 @@ def test_te_layer_misc(dtype, model_configs, model, qkv_format):
test_transformer_layer(dtype, model_configs, model,
ckpt_attn, qkv_format, fused_qkv_params, RoPE)


@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@@ -780,6 +811,7 @@ def find_factors(x):
test_transformer_layer(dtype, model_configs, model,
ckpt_attn, qkv_format, fused_qkv_params, RoPE)


def _run_transformer_layer(
dtype: torch.dtype,
config: ModelConfig,
@@ -912,8 +944,10 @@ def _run_transformer_layer(
"fp8_1": ModelConfig(1, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
}

param_types_fp8 = [torch.float16]


@pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
@@ -946,6 +980,7 @@ def test_dpa_fp8(dtype, model):
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)


def _run_dpa_fp8(dtype, config, backend):
"""Run FusedAttention FP8 backend, i.e.
fused_attn_fwd/bwd_qkvpacked from cpp_extensions"""
@@ -989,6 +1024,7 @@ def _run_dpa_fp8(dtype, config, backend):
dqkv.view(config.batch_size, config.max_seqlen_q, 3,
config.num_heads, config.head_dim).transpose(0,1).contiguous())


def _run_dpa_fp8_ref(dtype, config, backend):
"""Run UnfusedDotProductAttention as a reference, i.e.
plain PyTorch implementation in FP16 and inputs/outputs
@@ -1188,8 +1224,7 @@ def forward(
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:

with _prepare_backward(True, ctx.fp8_meta, None, 1, name="_DPA"):
with torch.cuda.nvtx.range("_DPA"):
(
inputmat_t,
qkv_weight_t_fp8,
@@ -1298,6 +1333,7 @@ def backward(
None,
None)


class DPA_FP8(TransformerEngineBaseModule):
def __init__(
self,