
Commit 7bbaf56

feat: add use_sync switch to ulysses (#103)
1 parent 1fb7f00 commit 7bbaf56

File tree

3 files changed: +13 -11 lines changed


yunchang/hybrid/attn_layer.py (+1)
@@ -96,6 +96,7 @@ def forward(
                 deterministic=deterministic,
                 return_attn_probs=return_attn_probs,
                 group=self.ring_pg,
+                attn_type=self.attn_type,
             )
         else:
             query_layer = SeqAllToAll4D.apply(

yunchang/kernels/__init__.py (+4)
@@ -30,6 +30,8 @@ def select_flash_attn_impl(impl_type: FlashAttentionImpl, stage : str = "fwd-bwd
         elif stage == "fwd-bwd":
             print(f"flash_attn_func: {flash_attn_func} here")
             return flash_attn_func
+        else:
+            raise ValueError(f"Unknown stage: {stage}")
 
     elif impl_type == FlashAttentionImpl.FA3:
         if stage == "fwd-only":
@@ -52,6 +54,8 @@ def fn(q,
                 return flash3_attn_func(q, k, v, softmax_scale=softmax_scale, causal=causal)
 
             return fn
+        else:
+            raise ValueError(f"Unknown stage: {stage}")
 
     elif impl_type == FlashAttentionImpl.TORCH:
         if stage == "fwd-bwd":
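
With the new else branches, select_flash_attn_impl fails fast on a stage value it does not recognize instead of falling through and implicitly returning None. A minimal sketch of that behavior, assuming yunchang.kernels imports cleanly in your environment (the stage string "fwd" below is just an example of an invalid value):

from yunchang.kernels import FlashAttentionImpl, select_flash_attn_impl

# A recognized stage returns a callable attention implementation.
attn_fn = select_flash_attn_impl(FlashAttentionImpl.TORCH, stage="fwd-bwd")

# An unrecognized stage now raises instead of silently returning None.
try:
    select_flash_attn_impl(FlashAttentionImpl.FA, stage="fwd")
except ValueError as err:
    print(err)  # Unknown stage: fwd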

yunchang/ulysses/attn_layer.py (+8 -11)
@@ -10,8 +10,6 @@
 from yunchang.kernels import FlashAttentionImpl, select_flash_attn_impl
 import torch.distributed as dist
 from yunchang.comm.all_to_all import SeqAllToAll4D
-import torch.nn.functional as F
-
 
 
 class UlyssesAttention(torch.nn.Module):
@@ -32,20 +30,21 @@ def __init__(
         scatter_idx: int = 2,
         gather_idx: int = 1,
         use_sync: bool = False,
-        attn_type : FlashAttentionImpl = FlashAttentionImpl.FA
+        attn_type : FlashAttentionImpl = FlashAttentionImpl.FA,
     ) -> None:
 
         super(UlyssesAttention, self).__init__()
         self.spg = sequence_process_group
         self.scatter_idx = scatter_idx
         self.gather_idx = gather_idx
+        self.use_sync = use_sync
         self.attn_type = attn_type
-        self.attn_fn = select_flash_attn_impl(attn_type)
 
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         gpu_name = torch.cuda.get_device_name(device)
         if "Turing" in gpu_name or "Tesla" in gpu_name or "T4" in gpu_name:
             self.attn_type = FlashAttentionImpl.TORCH
+        self.attn_fn = select_flash_attn_impl(self.attn_type, stage="fwd-bwd")
 
     def forward(
         self,
@@ -79,15 +78,13 @@ def forward(
         # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)
 
         # scatter 2, gather 1
-        q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
-        k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
-        v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx)
-
-
+        q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx, self.use_sync)
+        k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx, self.use_sync)
+        v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx, self.use_sync)
 
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** -0.5
-
+
         context_layer = self.attn_fn(
             q,
             k,
@@ -108,7 +105,7 @@ def forward(
         # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
         # scatter 1, gather 2
         output = SeqAllToAll4D.apply(
-            self.spg, context_layer, self.gather_idx, self.scatter_idx
+            self.spg, context_layer, self.gather_idx, self.scatter_idx, self.use_sync
         )
 
         # out e.g., [s/p::h]
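
Net effect in UlyssesAttention: the existing use_sync constructor argument is now stored and forwarded to every SeqAllToAll4D.apply call, and attn_fn is selected only after the Turing/Tesla/T4 fallback to FlashAttentionImpl.TORCH, so that fallback actually affects the kernel choice. A minimal usage sketch; the process group, shapes, and variable names below are illustrative, and torch.distributed is assumed to be initialized (e.g. via torchrun) before this runs:

import torch
import torch.distributed as dist
from yunchang.kernels import FlashAttentionImpl
from yunchang.ulysses.attn_layer import UlyssesAttention

# One sequence-parallel group spanning all ranks (illustrative).
sp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

attn = UlyssesAttention(
    sequence_process_group=sp_group,
    scatter_idx=2,
    gather_idx=1,
    use_sync=True,   # the new switch, passed through to each SeqAllToAll4D.apply
    attn_type=FlashAttentionImpl.FA,
)

# Each rank holds a (bs, seq_len / world_size, head_cnt, head_size) shard.
bs, local_seq, heads, head_dim = 2, 512, 8, 64
q = torch.randn(bs, local_seq, heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# forward() all-to-alls q/k/v across the group, runs the selected attention
# kernel, then all-to-alls the context back, passing use_sync each time.
out = attn(q, k, v)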

0 commit comments
