
Commit 5de3e14

Refactor attention.py part 2 (#1704)
* Move MultiheadAttention into its own file; modify tests and files in t_e/pytorch to import from the new MHA module.
* Resolve lost MHA changes from PR 1614 caused by a rebase.
* Move the context-parallelism code into its own file; modify tests and local imports of CP code accordingly.
* Move softmax.py from pytorch/ to pytorch/d_p_a.
* Move unfused and fused attention to backends.py and some utility functions to pytorch/utils.py.
* Resolve lost mark_activation_offload changes from PR 1678 caused by a rebase.
* Code clean-up.
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Refactor the attention directory.
* Refactor the directory structure: make the relevant symbols public in __init__ for the attention and d_p_a directories, move the FA package imports to backends.py, and clean up the code.
* Modify tests to import the attention modules correctly.
* Lint fixes.
* Code clean-up and typo fix.
* Allow InferenceParams and RoPE imports from the attention module and the pytorch module.
* Allow InferenceParams and RoPE imports via the transformer_engine.pytorch and transformer_engine.pytorch.attention modules; remove unnecessary checks for check_set_window_size in MHA and TransformerLayer; reorder the backends so that smaller classes come first and larger ones last; code clean-up.
* Reinstate the changes from PR 1478 for rope.py that were lost during rebase conflict resolution.
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Fix lint issues.
* nit: code clean-up.
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Make imports leaner.
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a3d464c commit 5de3e14

25 files changed (+7444, -7241 lines)
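
For orientation, the attention package layout implied by the commit message and the diffs below (an inferred sketch, not a complete listing of every file touched by the commit):

transformer_engine/pytorch/attention/
    __init__.py                  # re-exports DotProductAttention, MultiheadAttention,
                                 # InferenceParams, RotaryPositionEmbedding
    multi_head_attention.py      # MultiheadAttention
    inference.py                 # InferenceParams
    rope.py                      # RotaryPositionEmbedding, apply_rotary_pos_emb
    dot_product_attention/
        __init__.py              # DotProductAttention, _attention_backends
        backends.py              # unfused/fused attention backends, FlashAttention package imports
        context_parallel.py      # context-parallelism helpers, e.g. get_cu_seqlens_on_cp_rank
        softmax.py               # moved from transformer_engine/pytorch/softmax.py
        utils.py                 # get_attention_backend, get_qkv_layout, FlashAttentionUtils, ...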

docs/examples/attention/attention.ipynb

Lines changed: 1 addition & 1 deletion
@@ -458,7 +458,7 @@
 " </tr>\n",
 "</table>\n",
 "\n",
-"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
+"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
 "\n",
 "<div class=\"alert alert-info\">\n",
 "<b>Note</b>\n",

docs/examples/te_llama/te_llama.py

Lines changed: 1 addition & 3 deletions
@@ -8,11 +8,9 @@
 from contextlib import contextmanager
 
 import torch
-from torch import nn
 
 import transformer_engine as te
-from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding
-from transformer_engine.pytorch.fp8 import fp8_model_init
+from transformer_engine.pytorch.attention import RotaryPositionEmbedding
 
 import transformers
 from transformers.models.llama.modeling_llama import (
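
te_llama.py now pulls RotaryPositionEmbedding from transformer_engine.pytorch.attention. A hedged sketch of how that class is typically used follows; the head dimension, sequence length, and output shape are illustrative assumptions, not values from this diff.

from transformer_engine.pytorch.attention import RotaryPositionEmbedding

head_dim = 4096 // 32                    # hidden size / number of attention heads (illustrative)
rope = RotaryPositionEmbedding(head_dim)
freqs = rope(max_seq_len=2048)           # assumed to return per-position rotary frequencies
print(freqs.shape)                       # e.g. torch.Size([2048, 1, 1, head_dim])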

tests/pytorch/fused_attn/run_fused_attn_with_cp.py

Lines changed: 6 additions & 2 deletions
@@ -2,12 +2,16 @@
 #
 # See LICENSE for license information.
 
-import os, sys, logging
+import os
+import sys
+import logging
 from contextlib import nullcontext
 import torch
 import torch.distributed as dist
 from transformer_engine.pytorch.attention import DotProductAttention
-from transformer_engine.pytorch.attention import get_cu_seqlens_on_cp_rank
+from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
+    get_cu_seqlens_on_cp_rank,
+)
 import transformer_engine_torch as tex
 from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
 from transformer_engine.pytorch.fp8 import fp8_autocast
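
For code that previously imported get_cu_seqlens_on_cp_rank from the top-level attention module, only the import path changes. A before/after sketch with paths taken from this diff (the helper's signature is not shown here):

# Old location (before this commit):
#   from transformer_engine.pytorch.attention import get_cu_seqlens_on_cp_rank
# New location, as used by the updated test:
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
    get_cu_seqlens_on_cp_rank,
)

print(get_cu_seqlens_on_cp_rank.__module__)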

tests/pytorch/fused_attn/test_fused_attn.py

Lines changed: 5 additions & 14 deletions
@@ -1,12 +1,9 @@
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
-
-import functools
 import logging
 import math
 import os
-from importlib.metadata import version
 from typing import Any, Dict, List, Tuple, Union, Optional
 from contextlib import contextmanager
 
@@ -15,26 +12,22 @@
 
 from transformer_engine.common import recipe
 from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
-from transformer_engine.pytorch.attention import (
+from transformer_engine.pytorch.attention.dot_product_attention import (
     DotProductAttention,
-    MultiheadAttention,
     _attention_backends,
 )
-from transformer_engine.pytorch.dot_product_attention.utils import (
+from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
+from transformer_engine.pytorch.attention.dot_product_attention.utils import (
     FlashAttentionUtils,
     get_attention_backend,
     check_set_window_size,
     AttentionParams,
 )
-from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
-from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding
-from transformer_engine.pytorch.constants import TE_DType
+from transformer_engine.pytorch.attention import InferenceParams
+from transformer_engine.pytorch.attention import RotaryPositionEmbedding
 import transformer_engine.pytorch.cpp_extensions as ext
 from transformer_engine.pytorch.cpp_extensions.fused_attn import (
-    AttnBiasType,
-    AttnMaskType,
     FusedAttnBackend,
-    QKVLayout,
     fused_attn_bwd,
     fused_attn_fwd,
 )
@@ -49,9 +42,7 @@
 )
 from transformer_engine.pytorch.utils import get_cudnn_version
 import transformer_engine_torch as tex
-from transformer_engine_torch import NVTE_Fused_Attn_Backend
 from transformer_engine.pytorch.tensor.quantized_tensor import (
-    QuantizedTensor,
     Quantizer,
     prepare_for_saving,
     restore_from_saved,
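
Taken together, the updated test exercises the new module boundaries. A consolidated view of the import paths it now relies on (all paths appear in this diff; only the grouping is editorial):

from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch.attention.dot_product_attention import (
    DotProductAttention,
    _attention_backends,
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
    AttentionParams,
    FlashAttentionUtils,
    check_set_window_size,
    get_attention_backend,
)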

tests/pytorch/fused_attn/test_fused_attn_with_cp.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
     get_device_compute_capability,
     get_cudnn_version,
 )
-from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils
+from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils
 from test_fused_attn import ModelConfig
 
 model_configs_flash_attn = {

tests/pytorch/fused_attn/test_kv_cache.py

Lines changed: 10 additions & 9 deletions
@@ -11,27 +11,28 @@
 import pytest
 import torch
 
+from test_fused_attn import (
+    ModelConfig,
+    reset_rng_states,
+    _get_attention_backends,
+)
+
 from torch.distributions import Exponential
 from transformer_engine.pytorch import make_graphed_callables
 from transformer_engine.common import recipe
 from transformer_engine.pytorch import fp8_autocast, fp8_model_init
 from transformer_engine.pytorch.transformer import (
     TransformerLayer,
 )
-from transformer_engine.pytorch.attention import DotProductAttention
-from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
-from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
+from transformer_engine.pytorch.attention import DotProductAttention, InferenceParams
+from transformer_engine.pytorch.attention.dot_product_attention.utils import (
+    FlashAttentionUtils as fa_utils,
+)
 from transformer_engine.pytorch.utils import (
-    get_device_compute_capability,
     init_method_normal,
     scaled_init_method_normal,
     is_bf16_compatible,
 )
-from test_fused_attn import (
-    ModelConfig,
-    reset_rng_states,
-    _get_attention_backends,
-)
 
 # Initialize RNG state
 seed = 1234

tests/pytorch/test_fused_optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 from torch.testing._internal.common_device_type import largeTensorTest
 import transformer_engine.pytorch as te
 from transformer_engine.common.recipe import DelayedScaling
-from transformer_engine.pytorch.attention import MultiheadAttention
+from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
 from transformer_engine.pytorch import fp8_model_init
 from transformer_engine.pytorch.utils import is_bf16_compatible
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

tests/pytorch/test_fused_rope.py

Lines changed: 3 additions & 3 deletions
@@ -1,11 +1,11 @@
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
+from typing import Callable, Tuple, Union
 import math
-import pytest
 import torch
-from typing import Callable, Tuple, Union
-from transformer_engine.pytorch.dot_product_attention.rope import (
+import pytest
+from transformer_engine.pytorch.attention.rope import (
     RotaryPositionEmbedding,
     apply_rotary_pos_emb,
 )
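
A hedged end-to-end sketch pairing the two symbols imported above; the tensor_format keyword and the shape of the returned frequencies are assumptions about the API rather than facts from this diff.

import torch
from transformer_engine.pytorch.attention.rope import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)

seq_len, batch, heads, dim = 256, 2, 8, 64
t = torch.randn(seq_len, batch, heads, dim)   # activations laid out as (seq, batch, head, dim)

rope = RotaryPositionEmbedding(dim)
freqs = rope(max_seq_len=seq_len)             # rotary frequency table for seq_len positions

# Assumed keyword: tensor_format describing the input layout.
t_rotated = apply_rotary_pos_emb(t, freqs, tensor_format="sbhd")
print(t_rotated.shape)                        # same shape as the input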

tests/pytorch/test_numerics.py

Lines changed: 1 addition & 2 deletions
@@ -7,7 +7,6 @@
 import os
 from typing import Dict, List, Tuple, Optional
 import pytest
-import copy
 import random
 
 import torch
@@ -38,7 +37,7 @@
     Fp8Padding,
     Fp8Unpadding,
 )
-from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
+from transformer_engine.pytorch.attention.inference import InferenceParams
 from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
 from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

transformer_engine/pytorch/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -90,7 +90,8 @@ def _load_library():
 from transformer_engine.pytorch.module import destroy_ub
 from transformer_engine.pytorch.attention import DotProductAttention
 from transformer_engine.pytorch.attention import MultiheadAttention
-from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
+from transformer_engine.pytorch.attention import InferenceParams
+from transformer_engine.pytorch.attention import RotaryPositionEmbedding
 from transformer_engine.pytorch.transformer import TransformerLayer
 from transformer_engine.pytorch.permutation import (
     moe_permute,
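
With these re-exports, downstream code can keep reaching the attention classes through the top-level namespace. A minimal sketch, assuming a CUDA device (Transformer Engine layers initialize on the GPU by default) and illustrative layer sizes:

import transformer_engine.pytorch as te

attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64)
mha = te.MultiheadAttention(hidden_size=1024, num_attention_heads=16)
rope = te.RotaryPositionEmbedding(64)
inference_params_cls = te.InferenceParams   # also available from the same namespace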
