Commit 7482b05

Re-write for SDPA to support older models
1 parent 0054659 commit 7482b05

1 file changed: 98 additions, 205 deletions
@@ -1,166 +1,18 @@
+"""
+Flash Attention support for Windows platforms.
+
+This module provides a dynamic fallback mechanism for Flash Attention on Windows,
+patching PyTorch's scaled_dot_product_attention when native support is unavailable.
+"""
+
 import sys
 
 import torch
+import torch.nn.functional as F
 
-from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
 from diffusers.utils import is_flash_attn_available
 
-
-def all_tensors_on_device(query: torch.Tensor,
-                          key: torch.Tensor,
-                          value: torch.Tensor):
-    # Check that all tensors are on the GPU device
-    return query.is_cuda and key.is_cuda and value.is_cuda
-
-
-def check_for_attn_mask(attn_mask: torch.Tensor | None):
-    # Flash Attention does not support non-null attn_mask
-    return attn_mask is None
-
-
-def check_tensor_shapes(query: torch.Tensor,
-                        key: torch.Tensor,
-                        value: torch.Tensor):
-    # All fused kernels requires query, key and value to be 4 dimensional
-    query_dim = query.dim()
-    return query_dim == key.dim() and query_dim == value.dim() and query_dim == 4
-
-
-def check_head_dim_size_flash(query: torch.Tensor,
-                              key: torch.Tensor,
-                              value: torch.Tensor):
-    # All head_dim sizes must be equal and less than 256
-    # (ROCm with AOTriton 0.9+ supports up to 512, but we keep 256 for simplicity)
-    max_size = 256
-    query_size_last = query.size(-1)
-    key_size_last = key.size(-1)
-    value_size_last = value.size(-1)
-
-    same_head_dim_size = (query_size_last == key_size_last and
-                          query_size_last == value_size_last)
-
-    # Check that all head dims are equal, all <= max_size, and query_size_last > 0
-    return same_head_dim_size and max_size >= query_size_last > 0
-
-
-def check_flash_causal_non_square_seqlens(query: torch.Tensor,
-                                          key: torch.Tensor,
-                                          is_causal: bool):
-    # FlashAttention does not support the is_causal flag when seqlen_q != seqlen_k
-    # Flash attention layout is (N, S, H, E), so sequence length is at index -3
-    return not (is_causal and not query.is_nested and not key.is_nested and query.shape[-3] != key.shape[-3])
-
-
-def has_for_nested_inputs(query: torch.Tensor,
-                          key: torch.Tensor,
-                          value: torch.Tensor):
-    return (query.is_nested and query.layout == torch.strided) or \
-           (key.is_nested and key.layout == torch.strided) or \
-           (value.is_nested and value.layout == torch.strided)
-
-def has_only_dense_inputs(query: torch.Tensor,
-                          key: torch.Tensor,
-                          value: torch.Tensor):
-    return (not query.is_nested) and (not key.is_nested) and (not value.is_nested)
-
-def check_grouped_query_attention(query: torch.Tensor,
-                                  key: torch.Tensor,
-                                  value: torch.Tensor,
-                                  requires_same_num_heads: bool = True) -> bool:
-    """Check if grouped query attention configuration is valid."""
-    # Flash attention layout is (N, S, H, E), so num_heads is at index -2
-    q_num_heads = query.size(-2)
-    k_num_heads = key.size(-2)
-    v_num_heads = value.size(-2)
-    same_kv_heads = k_num_heads == v_num_heads
-
-    if requires_same_num_heads and not same_kv_heads:
-        return False
-
-    # Check if grouped query attention is supported and validate the number of heads
-    return not (q_num_heads % k_num_heads != 0 or (not requires_same_num_heads and q_num_heads % v_num_heads != 0))
-
-
-def check_batch_size_and_num_heads_dense(query: torch.Tensor,
-                                         key: torch.Tensor,
-                                         value: torch.Tensor,
-                                         enable_gqa: bool = False,
-                                         supports_gqa: bool = True,
-                                         requires_same_num_heads: bool = True) -> bool:
-    """Check batch size and num_heads compatibility for dense tensors.
-
-    This is expected to be called after check_tensor_shapes ensuring that the
-    size() calls won't error since the inputs are all 4 dimensional.
-    """
-    q_batch_size = query.size(0)
-    k_batch_size = key.size(0)
-    v_batch_size = value.size(0)
-
-    same_batch_size = (q_batch_size == k_batch_size and q_batch_size == v_batch_size)
-
-    # Flash attention layout is (N, S, H, E), so num_heads is at index -2
-    q_num_heads = query.size(-2)
-    k_num_heads = key.size(-2)
-    v_num_heads = value.size(-2)
-
-    same_num_heads = (q_num_heads == k_num_heads and q_num_heads == v_num_heads)
-
-    # For dense inputs, both fused kernels require query, key and value to have the same batch_size
-    if not same_batch_size:
-        return False
-
-    if enable_gqa and supports_gqa:
-        return check_grouped_query_attention(query, key, value, requires_same_num_heads)
-
-    # same num heads condition for non-gqa case
-    return same_num_heads
-
-
-def check_nonzero_sequence_lengths_dense(query: torch.Tensor,
-                                         key: torch.Tensor,
-                                         value: torch.Tensor) -> bool:
-    """Check that sequence lengths are non-zero for dense tensors."""
-    # In some cases people will pass in 0 sized tensors, this will
-    # cause the fused path to error with unaligned mask
-    # Flash attention layout is (N, S, H, E), so sequence length is at index -3
-    zero_seq_len_q = query.size(-3) == 0
-    zero_seq_len_k = key.size(-3) == 0
-    return not (zero_seq_len_q or zero_seq_len_k)
-
-
-def check_last_dim_stride_equals_1_dense(query: torch.Tensor,
-                                         key: torch.Tensor,
-                                         value: torch.Tensor,
-                                         attn_mask: torch.Tensor | None = None,
-                                         ignore_singleton_dim: bool = True) -> bool:
-    """Check that the last dimension of inputs has stride 1.
-
-    The stride checking for NestedTensors is done within the kernel
-    and .contiguous will be called if needed.
-
-    This function checks that the last dimension of the inputs to
-    fused_attention have stride 1.
-    """
-    qkv_strides_equal_1 = (query.stride(-1) == 1 and
-                           key.stride(-1) == 1 and
-                           value.stride(-1) == 1)
-
-    # If the head_dim is size 1 the stride won't matter, but we
-    # check this condition before padding the head_dim to 1
-    if ignore_singleton_dim:
-        qkv_strides_equal_1 = qkv_strides_equal_1 or query.size(-1) == 1
-
-    is_cpu = query.is_cpu
-    mask_stride_equal_1 = attn_mask.stride(-1) == 1 if attn_mask is not None else True
-    mask_stride_valid = True if is_cpu else mask_stride_equal_1
-
-    return qkv_strides_equal_1 and mask_stride_valid
-
-
-def check_dtypes_low_precision(query: torch.Tensor,
-                               key: torch.Tensor,
-                               value: torch.Tensor):
-    return query.dtype == key.dtype and query.dtype == value.dtype and query.dtype in [torch.float16, torch.bfloat16]
+ALLOWED_TYPES = {torch.float16, torch.bfloat16}
 
 
 def can_use_flash_attn(query: torch.Tensor,
@@ -169,72 +21,113 @@ def can_use_flash_attn(query: torch.Tensor,
                        attn_mask: torch.Tensor | None = None,
                        is_causal: bool = False,
                        enable_gqa: bool = False):
-    # Define gate functions that determine if a flash kernel can be ran
-    if not (all_tensors_on_device(query, key, value) and
-            check_tensor_shapes(query, key, value) and
-            check_for_attn_mask(attn_mask) and
-            check_head_dim_size_flash(query, key, value) and
-            check_flash_causal_non_square_seqlens(query, key, is_causal) and
-            check_dtypes_low_precision(query, key, value)):
+    """
+    Check if Flash Attention can be used for the given tensors.
+
+    Args:
+        query: Query tensor of shape (B, H, L, D)
+        key: Key tensor of shape (B, H, L, D)
+        value: Value tensor of shape (B, H, L, D)
+        attn_mask: Optional attention mask (not supported by flash_attn)
+        is_causal: Whether to use causal attention
+        enable_gqa: Whether grouped query attention is enabled
+
+    Returns:
+        bool: True if Flash Attention can be used, False otherwise
+    """
+    # Fast grouped early rejects (most common failures first).
+    dt = query.dtype
+    if (
+        attn_mask is not None  # Explicit attention masks are not supported by flash_attn
+        or dt not in ALLOWED_TYPES  # flash_attn requires fp16/bf16
+        or dt != key.dtype or dt != value.dtype  # Q/K/V must have identical dtypes
+        or not (query.is_cuda and key.is_cuda and value.is_cuda)  # flash_attn is CUDA-only
+        or query.dim() != 4 or key.dim() != 4 or value.dim() != 4  # Expect rank-4 (B, H, L, D)
+        or query.is_nested or key.is_nested or value.is_nested  # Nested tensors unsupported, keep our use-case simple
+    ):
+        return False
+
+    # Unpack shapes once.
+    (bq, q_heads, q_len, head_dim) = query.shape
+    (bk, k_heads, k_len, k_head_dim) = key.shape
+    (bv, v_heads, v_len, v_head_dim) = value.shape
+
+    # Batch & head dim validation.
+    if bq != bk or bq != bv:
+        return False
+    if not (0 < head_dim <= 256 and head_dim == k_head_dim == v_head_dim):
         return False
 
-    # While PyTorch's Flash Attention implementation supports nested tensors,
-    # we want to keep our use-case simple for now and avoid nested strided tensors as validations
-    # require digging into tensor internals.
-    if has_for_nested_inputs(query, key, value):
+    # Sequence length checks.
+    if q_len == 0 or k_len == 0:
+        return False
+    if is_causal and q_len != k_len:  # causal path requires equal seq lengths
        return False
 
-    if has_only_dense_inputs(query, key, value):
-        if not (check_batch_size_and_num_heads_dense(query, key, value, enable_gqa, supports_gqa=True) and
-                check_nonzero_sequence_lengths_dense(query, key, value) and
-                check_last_dim_stride_equals_1_dense(query, key, value, attn_mask, ignore_singleton_dim=True)):
+    # Head count validation (GQA aware).
+    if enable_gqa:
+        if k_heads != v_heads or k_heads == 0 or (q_heads % k_heads) != 0:
+            return False
+    else:
+        if not (q_heads == k_heads == v_heads):
+            return False
+
+    # Stride check (only if dim > 1).
+    if head_dim != 1:
+        qs = query.stride(-1)
+        ks = key.stride(-1)
+        vs = value.stride(-1)
+        if qs != 1 or ks != 1 or vs != 1:  # All last-dim strides must be 1 (contiguous)
            return False
 
    return True
 
 
 def supports_flash_attention_in_sdp():
+    """Check if Flash Attention is natively supported in scaled_dot_product."""
     return torch.cuda.is_available() and torch.backends.cuda.is_flash_attention_available()
 
 
 def register():
+    """
+    Register Flash Attention fallback on Windows when native support is unavailable.
+
+    Patches F.scaled_dot_product_attention to use flash_attn_func when conditions allow,
+    falling back to the original implementation otherwise.
+    """
     if sys.platform == "win32" and is_flash_attn_available() and not supports_flash_attention_in_sdp():
-        from flash_attn.flash_attn_interface import flash_attn_func
-
-        def _native_flash_attention(
-            query: torch.Tensor,
-            key: torch.Tensor,
-            value: torch.Tensor,
-            attn_mask: torch.Tensor | None = None,
-            dropout_p: float = 0.0,
-            is_causal: bool = False,
-            scale: float | None = None,
-            enable_gqa: bool = False
-        ) -> torch.Tensor:
-            # Determine if we can use flash attention
+        try:
+            from flash_attn.flash_attn_interface import flash_attn_func
+        except Exception:
+            return
+
+        _scaled_dot_product_attention = F.scaled_dot_product_attention
+
+        def _flash_dynamic_scaled_dot_product_attention(query: torch.Tensor,
+                                                        key: torch.Tensor,
+                                                        value: torch.Tensor,
+                                                        attn_mask: torch.Tensor | None = None,
+                                                        dropout_p: float = 0.0,
+                                                        is_causal: bool = False,
+                                                        scale: float | None = None,
+                                                        enable_gqa: bool = False):
             if can_use_flash_attn(query, key, value, attn_mask, is_causal, enable_gqa):
-                return flash_attn_func(
-                    q=query,
-                    k=key,
-                    v=value,
+                # transpose(1,2) is equivalent to permute(0,2,1,3) for (B,H,L,D) -> (B,L,H,D)
+                q = query.transpose(1, 2)
+                k = key.transpose(1, 2)
+                v = value.transpose(1, 2)
+                out = flash_attn_func(
+                    q=q, k=k, v=v,
                     dropout_p=dropout_p,
                     softmax_scale=scale,
                     causal=is_causal
                 )
-            else:
-                query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
-                out = torch.nn.functional.scaled_dot_product_attention(
-                    query=query,
-                    key=key,
-                    value=value,
-                    attn_mask=attn_mask,
-                    dropout_p=dropout_p,
-                    is_causal=is_causal,
-                    scale=scale,
-                    enable_gqa=enable_gqa,
-                )
-                out = out.permute(0, 2, 1, 3)
-            return out
+                return out.transpose(1, 2)
+
+            # Fallback
+            return _scaled_dot_product_attention(
+                query=query, key=key, value=value,
+                attn_mask=attn_mask, dropout_p=dropout_p,
+                is_causal=is_causal, scale=scale, enable_gqa=enable_gqa)
 
-        # Register the dynamic flash attention backend in place of the native one
-        _AttentionBackendRegistry.register(AttentionBackendName.NATIVE, [])(_native_flash_attention)
+        F.scaled_dot_product_attention = _flash_dynamic_scaled_dot_product_attention

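For reference, a minimal sketch (not part of the commit) of what the new can_use_flash_attn gate accepts and rejects. The import path is hypothetical, since the diff does not show where the module lives in the package; the shapes and dtypes are only illustrative.

import torch

from flash_attention_windows import can_use_flash_attn  # hypothetical module name

# (B, H, L, D) layout, fp16 on CUDA: the happy path for the flash gate.
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)

print(can_use_flash_attn(q, k, v))                  # True: CUDA, fp16, rank-4, head_dim <= 256, contiguous last dim
print(can_use_flash_attn(q, k, v, is_causal=True))  # True: q_len == k_len, so the causal path is allowed
print(can_use_flash_attn(q, k, v, attn_mask=torch.ones(1, device="cuda")))  # False: explicit masks are rejected
print(can_use_flash_attn(q.float(), k.float(), v.float()))                  # False: fp32 is not in ALLOWED_TYPES

A False result is not an error condition; in the patched wrapper below it simply means the call falls through to the stock scaled_dot_product_attention.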
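And an end-to-end sketch of how the patched path is meant to be used, again with a hypothetical module name. On Windows with flash-attn installed and no native SDPA flash backend, register() swaps F.scaled_dot_product_attention for the dynamic wrapper; eligible (B, H, L, D) inputs are transposed to flash_attn's (B, L, H, D) layout and routed through flash_attn_func, while everything else goes to the saved original implementation.

import torch
import torch.nn.functional as F

import flash_attention_windows  # hypothetical module name

# No-op unless running on win32 with flash-attn importable and no native SDPA flash backend.
flash_attention_windows.register()

q = torch.randn(1, 24, 4096, 128, device="cuda", dtype=torch.bfloat16)  # (B, H, L, D)
k = torch.randn(1, 24, 4096, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 24, 4096, 128, device="cuda", dtype=torch.bfloat16)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 24, 4096, 128]) on either code path

Because the wrapper keeps the original SDPA signature and the (B, H, L, D) calling convention, older models that call F.scaled_dot_product_attention directly, rather than going through the diffusers attention-dispatch registry used by the previous version, pick up the flash fallback without changes, which appears to be the motivation for this rewrite.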