@@ -10,8 +10,6 @@
 from yunchang.kernels import FlashAttentionImpl, select_flash_attn_impl
 import torch.distributed as dist
 from yunchang.comm.all_to_all import SeqAllToAll4D
- import torch.nn.functional as F
-


 class UlyssesAttention(torch.nn.Module):
@@ -32,20 +30,21 @@ def __init__(
         scatter_idx: int = 2,
         gather_idx: int = 1,
         use_sync: bool = False,
-        attn_type: FlashAttentionImpl = FlashAttentionImpl.FA
+        attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
     ) -> None:

         super(UlyssesAttention, self).__init__()
         self.spg = sequence_process_group
         self.scatter_idx = scatter_idx
         self.gather_idx = gather_idx
+        self.use_sync = use_sync
         self.attn_type = attn_type
-        self.attn_fn = select_flash_attn_impl(attn_type)

         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         gpu_name = torch.cuda.get_device_name(device)
         if "Turing" in gpu_name or "Tesla" in gpu_name or "T4" in gpu_name:
             self.attn_type = FlashAttentionImpl.TORCH
+        self.attn_fn = select_flash_attn_impl(self.attn_type, stage="fwd-bwd")

     def forward(
         self,
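
With this change the layer caches `use_sync` and only calls `select_flash_attn_impl` after the Turing/Tesla/T4 fallback, so the selected kernel matches the attention type that will actually run. A minimal construction sketch, assuming a standard `torchrun` launch; the process-group setup here is illustrative, not taken from the repo:

```python
import torch.distributed as dist

# Hypothetical setup: one sequence-parallel group spanning all ranks
# (requires the usual torchrun/NCCL environment variables).
dist.init_process_group(backend="nccl")
sp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

attn = UlyssesAttention(
    sequence_process_group=sp_group,
    scatter_idx=2,                     # heads are scattered across ranks
    gather_idx=1,                      # sequence is gathered across ranks
    use_sync=True,                     # forwarded to every SeqAllToAll4D.apply call
    attn_type=FlashAttentionImpl.FA,   # falls back to TORCH on Turing/Tesla/T4 GPUs
)
```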
@@ -79,15 +78,13 @@ def forward(
         # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)

         # scatter 2, gather 1
-        q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
-        k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
-        v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx)
-
-
+        q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx, self.use_sync)
+        k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx, self.use_sync)
+        v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx, self.use_sync)

         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** -0.5
-
+
         context_layer = self.attn_fn(
             q,
             k,
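
The comment at the top of this hunk states the layout contract for the input all-to-all; a quick shape check of what `scatter_idx=2` / `gather_idx=1` implies (numbers are illustrative):

```python
bs, seq_len, head_cnt, head_size, N = 2, 4096, 32, 128, 8   # N = sequence-parallel world size

local_in  = (bs, seq_len // N, head_cnt, head_size)   # per-rank input: sequence shard, all heads
local_out = (bs, seq_len, head_cnt // N, head_size)   # per-rank output: full sequence, head shard

assert local_in[1] * N == local_out[1]   # dim 1 (sequence) is gathered
assert local_in[2] == local_out[2] * N   # dim 2 (heads) is scattered
```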
@@ -108,7 +105,7 @@ def forward(
         # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
         # scatter 1, gather 2
         output = SeqAllToAll4D.apply(
-            self.spg, context_layer, self.gather_idx, self.scatter_idx
+            self.spg, context_layer, self.gather_idx, self.scatter_idx, self.use_sync
         )

         # out e.g., [s/p::h]
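
The output all-to-all swaps the indices (gather 2, scatter 1), so it is the inverse of the input one and `forward` returns tensors in the same sequence-sharded layout it received. A single-process emulation of that round trip, sketching only the data movement (not the library's actual collective):

```python
import torch

def emulate_seq_all_to_all(shards, scatter_dim, gather_dim):
    # Split each per-rank shard along scatter_dim, send chunk j to "rank" j,
    # then concatenate the received chunks along gather_dim.
    N = len(shards)
    chunks = [list(t.chunk(N, dim=scatter_dim)) for t in shards]
    return [torch.cat([chunks[src][dst] for src in range(N)], dim=gather_dim)
            for dst in range(N)]

bs, s, h, d, N = 2, 8, 4, 16, 4
full = torch.randn(bs, s, h, d)
seq_shards = list(full.chunk(N, dim=1))                 # (bs, s/N, h, d) per "rank"

head_shards = emulate_seq_all_to_all(seq_shards, 2, 1)  # scatter 2, gather 1
assert head_shards[0].shape == (bs, s, h // N, d)       # full sequence, head shard

back = emulate_seq_all_to_all(head_shards, 1, 2)        # inverse: scatter 1, gather 2
assert all(torch.equal(a, b) for a, b in zip(back, seq_shards))
```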