1+ """
2+ Flash Attention support for Windows platforms.
3+
4+ This module provides a dynamic fallback mechanism for Flash Attention on Windows,
5+ patching PyTorch's scaled_dot_product_attention when native support is unavailable.
6+ """
7+
18import sys
29
310import torch
11+ import torch .nn .functional as F
412
5- from diffusers .models .attention_dispatch import AttentionBackendName , _AttentionBackendRegistry
613from diffusers .utils import is_flash_attn_available
714
-
-def all_tensors_on_device(query: torch.Tensor,
-                          key: torch.Tensor,
-                          value: torch.Tensor):
-    # Check that all tensors are on the GPU device
-    return query.is_cuda and key.is_cuda and value.is_cuda
-
-
-def check_for_attn_mask(attn_mask: torch.Tensor | None):
-    # Flash Attention does not support non-null attn_mask
-    return attn_mask is None
-
-
-def check_tensor_shapes(query: torch.Tensor,
-                        key: torch.Tensor,
-                        value: torch.Tensor):
-    # All fused kernels requires query, key and value to be 4 dimensional
-    query_dim = query.dim()
-    return query_dim == key.dim() and query_dim == value.dim() and query_dim == 4
-
-
-def check_head_dim_size_flash(query: torch.Tensor,
-                              key: torch.Tensor,
-                              value: torch.Tensor):
-    # All head_dim sizes must be equal and less than 256
-    # (ROCm with AOTriton 0.9+ supports up to 512, but we keep 256 for simplicity)
-    max_size = 256
-    query_size_last = query.size(-1)
-    key_size_last = key.size(-1)
-    value_size_last = value.size(-1)
-
-    same_head_dim_size = (query_size_last == key_size_last and
-                          query_size_last == value_size_last)
-
-    # Check that all head dims are equal, all <= max_size, and query_size_last > 0
-    return same_head_dim_size and max_size >= query_size_last > 0
-
-
-def check_flash_causal_non_square_seqlens(query: torch.Tensor,
-                                          key: torch.Tensor,
-                                          is_causal: bool):
-    # FlashAttention does not support the is_causal flag when seqlen_q != seqlen_k
-    # Flash attention layout is (N, S, H, E), so sequence length is at index -3
-    return not (is_causal and not query.is_nested and not key.is_nested and query.shape[-3] != key.shape[-3])
-
-
-def has_for_nested_inputs(query: torch.Tensor,
-                          key: torch.Tensor,
-                          value: torch.Tensor):
-    return (query.is_nested and query.layout == torch.strided) or \
-           (key.is_nested and key.layout == torch.strided) or \
-           (value.is_nested and value.layout == torch.strided)
-
-def has_only_dense_inputs(query: torch.Tensor,
-                          key: torch.Tensor,
-                          value: torch.Tensor):
-    return (not query.is_nested) and (not key.is_nested) and (not value.is_nested)
-
-def check_grouped_query_attention(query: torch.Tensor,
-                                  key: torch.Tensor,
-                                  value: torch.Tensor,
-                                  requires_same_num_heads: bool = True) -> bool:
-    """Check if grouped query attention configuration is valid."""
-    # Flash attention layout is (N, S, H, E), so num_heads is at index -2
-    q_num_heads = query.size(-2)
-    k_num_heads = key.size(-2)
-    v_num_heads = value.size(-2)
-    same_kv_heads = k_num_heads == v_num_heads
-
-    if requires_same_num_heads and not same_kv_heads:
-        return False
-
-    # Check if grouped query attention is supported and validate the number of heads
-    return not (q_num_heads % k_num_heads != 0 or (not requires_same_num_heads and q_num_heads % v_num_heads != 0))
-
-
-def check_batch_size_and_num_heads_dense(query: torch.Tensor,
-                                         key: torch.Tensor,
-                                         value: torch.Tensor,
-                                         enable_gqa: bool = False,
-                                         supports_gqa: bool = True,
-                                         requires_same_num_heads: bool = True) -> bool:
-    """Check batch size and num_heads compatibility for dense tensors.
-
-    This is expected to be called after check_tensor_shapes ensuring that the
-    size() calls won't error since the inputs are all 4 dimensional.
-    """
-    q_batch_size = query.size(0)
-    k_batch_size = key.size(0)
-    v_batch_size = value.size(0)
-
-    same_batch_size = (q_batch_size == k_batch_size and q_batch_size == v_batch_size)
-
-    # Flash attention layout is (N, S, H, E), so num_heads is at index -2
-    q_num_heads = query.size(-2)
-    k_num_heads = key.size(-2)
-    v_num_heads = value.size(-2)
-
-    same_num_heads = (q_num_heads == k_num_heads and q_num_heads == v_num_heads)
-
-    # For dense inputs, both fused kernels require query, key and value to have the same batch_size
-    if not same_batch_size:
-        return False
-
-    if enable_gqa and supports_gqa:
-        return check_grouped_query_attention(query, key, value, requires_same_num_heads)
-
-    # same num heads condition for non-gqa case
-    return same_num_heads
-
-
-def check_nonzero_sequence_lengths_dense(query: torch.Tensor,
-                                         key: torch.Tensor,
-                                         value: torch.Tensor) -> bool:
-    """Check that sequence lengths are non-zero for dense tensors."""
-    # In some cases people will pass in 0 sized tensors, this will
-    # cause the fused path to error with unaligned mask
-    # Flash attention layout is (N, S, H, E), so sequence length is at index -3
-    zero_seq_len_q = query.size(-3) == 0
-    zero_seq_len_k = key.size(-3) == 0
-    return not (zero_seq_len_q or zero_seq_len_k)
-
-
-def check_last_dim_stride_equals_1_dense(query: torch.Tensor,
-                                         key: torch.Tensor,
-                                         value: torch.Tensor,
-                                         attn_mask: torch.Tensor | None = None,
-                                         ignore_singleton_dim: bool = True) -> bool:
-    """Check that the last dimension of inputs has stride 1.
-
-    The stride checking for NestedTensors is done within the kernel
-    and .contiguous will be called if needed.
-
-    This function checks that the last dimension of the inputs to
-    fused_attention have stride 1.
-    """
-    qkv_strides_equal_1 = (query.stride(-1) == 1 and
-                           key.stride(-1) == 1 and
-                           value.stride(-1) == 1)
-
-    # If the head_dim is size 1 the stride won't matter, but we
-    # check this condition before padding the head_dim to 1
-    if ignore_singleton_dim:
-        qkv_strides_equal_1 = qkv_strides_equal_1 or query.size(-1) == 1
-
-    is_cpu = query.is_cpu
-    mask_stride_equal_1 = attn_mask.stride(-1) == 1 if attn_mask is not None else True
-    mask_stride_valid = True if is_cpu else mask_stride_equal_1
-
-    return qkv_strides_equal_1 and mask_stride_valid
-
-
-def check_dtypes_low_precision(query: torch.Tensor,
-                               key: torch.Tensor,
-                               value: torch.Tensor):
-    return query.dtype == key.dtype and query.dtype == value.dtype and query.dtype in [torch.float16, torch.bfloat16]
+ALLOWED_TYPES = {torch.float16, torch.bfloat16}
 
 
 def can_use_flash_attn(query: torch.Tensor,
@@ -169,72 +21,113 @@ def can_use_flash_attn(query: torch.Tensor,
                        attn_mask: torch.Tensor | None = None,
                        is_causal: bool = False,
                        enable_gqa: bool = False):
-    # Define gate functions that determine if a flash kernel can be ran
-    if not (all_tensors_on_device(query, key, value) and
-            check_tensor_shapes(query, key, value) and
-            check_for_attn_mask(attn_mask) and
-            check_head_dim_size_flash(query, key, value) and
-            check_flash_causal_non_square_seqlens(query, key, is_causal) and
-            check_dtypes_low_precision(query, key, value)):
+    """
+    Check if Flash Attention can be used for the given tensors.
+
+    Args:
+        query: Query tensor of shape (B, H, L, D)
+        key: Key tensor of shape (B, H, L, D)
+        value: Value tensor of shape (B, H, L, D)
+        attn_mask: Optional attention mask (not supported by flash_attn)
+        is_causal: Whether to use causal attention
+        enable_gqa: Whether grouped query attention is enabled
+
+    Returns:
+        bool: True if Flash Attention can be used, False otherwise
+    """
+    # Fast grouped early rejects (most common failures first).
+    dt = query.dtype
+    if (
+        attn_mask is not None  # Explicit attention masks are not supported by flash_attn
+        or dt not in ALLOWED_TYPES  # flash_attn requires fp16/bf16
+        or dt != key.dtype or dt != value.dtype  # Q/K/V must have identical dtypes
+        or not (query.is_cuda and key.is_cuda and value.is_cuda)  # flash_attn is CUDA-only
+        or query.dim() != 4 or key.dim() != 4 or value.dim() != 4  # Expect rank-4 (B, H, L, D)
+        or query.is_nested or key.is_nested or value.is_nested  # Nested tensors unsupported, keep our use-case simple
+    ):
+        return False
+
+    # Unpack shapes once.
+    (bq, q_heads, q_len, head_dim) = query.shape
+    (bk, k_heads, k_len, k_head_dim) = key.shape
+    (bv, v_heads, v_len, v_head_dim) = value.shape
+
+    # Batch & head dim validation.
+    if bq != bk or bq != bv:
+        return False
+    if not (0 < head_dim <= 256 and head_dim == k_head_dim == v_head_dim):
         return False
 
-    # While PyTorch's Flash Attention implementation supports nested tensors,
-    # we want to keep our use-case simple for now and avoid nested strided tensors as validations
-    # require digging into tensor internals.
-    if has_for_nested_inputs(query, key, value):
+    # Sequence length checks.
+    if q_len == 0 or k_len == 0:
+        return False
+    if is_causal and q_len != k_len:  # causal path requires equal seq lengths
        return False
 
-    if has_only_dense_inputs(query, key, value):
-        if not (check_batch_size_and_num_heads_dense(query, key, value, enable_gqa, supports_gqa=True) and
-                check_nonzero_sequence_lengths_dense(query, key, value) and
-                check_last_dim_stride_equals_1_dense(query, key, value, attn_mask, ignore_singleton_dim=True)):
+    # Head count validation (GQA aware).
+    if enable_gqa:
+        if k_heads != v_heads or k_heads == 0 or (q_heads % k_heads) != 0:
+            return False
+    else:
+        if not (q_heads == k_heads == v_heads):
+            return False
+
+    # Stride check (only needed when head_dim > 1).
+    if head_dim != 1:
+        qs = query.stride(-1)
+        ks = key.stride(-1)
+        vs = value.stride(-1)
+        if qs != 1 or ks != 1 or vs != 1:  # All last-dim strides must be 1 (contiguous)
            return False
 
     return True
 
 
 def supports_flash_attention_in_sdp():
87+ """Check if Flash Attention is natively supported in scaled_dot_product."""
     return torch.cuda.is_available() and torch.backends.cuda.is_flash_attention_available()
 
 
 def register():
+    """
+    Register Flash Attention fallback on Windows when native support is unavailable.
+
+    Patches F.scaled_dot_product_attention to use flash_attn_func when conditions allow,
+    falling back to the original implementation otherwise.
+    """
     if sys.platform == "win32" and is_flash_attn_available() and not supports_flash_attention_in_sdp():
-        from flash_attn.flash_attn_interface import flash_attn_func
-
-        def _native_flash_attention(
-            query: torch.Tensor,
-            key: torch.Tensor,
-            value: torch.Tensor,
-            attn_mask: torch.Tensor | None = None,
-            dropout_p: float = 0.0,
-            is_causal: bool = False,
-            scale: float | None = None,
-            enable_gqa: bool = False
-        ) -> torch.Tensor:
-            # Determine if we can use flash attention
+        try:
+            from flash_attn.flash_attn_interface import flash_attn_func
+        except Exception:
+            return
+
+        _scaled_dot_product_attention = F.scaled_dot_product_attention
+
+        def _flash_dynamic_scaled_dot_product_attention(query: torch.Tensor,
+                                                        key: torch.Tensor,
+                                                        value: torch.Tensor,
+                                                        attn_mask: torch.Tensor | None = None,
+                                                        dropout_p: float = 0.0,
+                                                        is_causal: bool = False,
+                                                        scale: float | None = None,
+                                                        enable_gqa: bool = False):
             if can_use_flash_attn(query, key, value, attn_mask, is_causal, enable_gqa):
-                return flash_attn_func(
-                    q=query,
-                    k=key,
-                    v=value,
+                # transpose(1, 2) is equivalent to permute(0, 2, 1, 3) for (B, H, L, D) -> (B, L, H, D)
+                q = query.transpose(1, 2)
+                k = key.transpose(1, 2)
+                v = value.transpose(1, 2)
+                out = flash_attn_func(
+                    q=q, k=k, v=v,
                     dropout_p=dropout_p,
                     softmax_scale=scale,
                     causal=is_causal
                 )
-            else:
-                query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
-                out = torch.nn.functional.scaled_dot_product_attention(
-                    query=query,
-                    key=key,
-                    value=value,
-                    attn_mask=attn_mask,
-                    dropout_p=dropout_p,
-                    is_causal=is_causal,
-                    scale=scale,
-                    enable_gqa=enable_gqa,
-                )
-                out = out.permute(0, 2, 1, 3)
-            return out
+                return out.transpose(1, 2)
+
+            # Fallback
+            return _scaled_dot_product_attention(
+                query=query, key=key, value=value,
+                attn_mask=attn_mask, dropout_p=dropout_p,
+                is_causal=is_causal, scale=scale, enable_gqa=enable_gqa)
 
-        # Register the dynamic flash attention backend in place of the native one
-        _AttentionBackendRegistry.register(AttentionBackendName.NATIVE, [])(_native_flash_attention)
+        F.scaled_dot_product_attention = _flash_dynamic_scaled_dot_product_attention
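
For reference, a minimal usage sketch of the patched path, assuming this file is importable as a module named flash_attention_fallback (a hypothetical name; the actual module path is not shown in this diff):

import torch
import torch.nn.functional as F

import flash_attention_fallback  # hypothetical import path for the module in this diff

# No-op unless running on Windows with the flash-attn package installed
# and no native SDPA flash backend available.
flash_attention_fallback.register()

# (B, H, L, D) inputs on CUDA in fp16 satisfy can_use_flash_attn's gates.
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# Routed through flash_attn_func when eligible, otherwise falls back to the
# original torch implementation saved before patching.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)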