11 changes: 11 additions & 0 deletions modules/trainer/GenericTrainer.py
@@ -16,6 +16,7 @@
from modules.modelSetup.BaseModelSetup import BaseModelSetup
from modules.trainer.BaseTrainer import BaseTrainer
from modules.util import create, path_util
from modules.util.attn.flash_attn_win import disable_flash_attn_win, enable_flash_attn_win
from modules.util.bf16_stochastic_rounding import set_seed as bf16_stochastic_rounding_set_seed
from modules.util.callbacks.TrainCallbacks import TrainCallbacks
from modules.util.commands.TrainCommands import TrainCommands
@@ -88,6 +89,8 @@ def start(self):
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

self.__apply_flash_attn_fallback()

self.model_loader = self.create_model_loader()
self.model_setup = self.create_model_setup()

@@ -600,6 +603,14 @@ def __before_eval(self):
torch.clear_autocast_cache()
self.model.optimizer.eval()


def __apply_flash_attn_fallback(self):
if self.config.use_flash_attn_fallback:
enable_flash_attn_win()
else:
disable_flash_attn_win()


def train(self):
train_device = torch.device(self.config.train_device)

4 changes: 4 additions & 0 deletions modules/ui/TrainUI.py
@@ -335,6 +335,10 @@ def create_general_tab(self, master):
tooltip="The device used to temporarily offload models while they are not used. Default:\"cpu\"")
components.entry(frame, 16, 1, self.ui_state, "temp_device")

components.label(frame, 17, 0, "Use Flash-Attention Fallback",
tooltip="Enables Flash-Attention fallback on Windows if native support is not available in PyTorch for a performance improvement during training/sampling.")
components.switch(frame, 17, 1, self.ui_state, "use_flash_attn_fallback")

frame.pack(fill="both", expand=1)
return frame

176 changes: 176 additions & 0 deletions modules/util/attn/flash_attn_win.py
@@ -0,0 +1,176 @@
"""
Flash Attention support for Windows platforms.

This module provides a dynamic fallback mechanism for Flash Attention on Windows,
patching PyTorch's scaled_dot_product_attention when native support is unavailable.
"""

import sys

import torch
import torch.nn.functional as F

from diffusers.utils import is_flash_attn_available

ALLOWED_TYPES = {torch.float16, torch.bfloat16}
SUPPORTED_DEVICES = []


def _check_device_capability(device_index: int) -> bool:
"""Check if a specific CUDA device supports Flash Attention (compute capability >= 8.0, ie. Ampere GPUs onwards)."""
try:
return torch.cuda.get_device_properties(device_index).major >= 8
except Exception:
return False


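# Pre-compute per-device eligibility once at import time, indexed by CUDA device ordinal.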
if torch.cuda.is_available():
device_count = torch.cuda.device_count()
SUPPORTED_DEVICES = [_check_device_capability(i) for i in range(device_count)]


def is_supported_hardware(device: torch.device) -> bool:
"""
Check if the given device supports Flash Attention.
"""
return SUPPORTED_DEVICES[device.index]


def can_use_flash_attn(query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
is_causal: bool = False,
enable_gqa: bool = False) -> bool:
"""
Check if Flash Attention can be used for the given tensors.

Args:
query: Query tensor of shape (B, H, L, D)
key: Key tensor of shape (B, H, L, D)
value: Value tensor of shape (B, H, L, D)
attn_mask: Optional attention mask (not supported by flash_attn)
is_causal: Whether to use causal attention
enable_gqa: Whether grouped query attention is enabled

Returns:
bool: True if Flash Attention can be used, False otherwise
"""
# Fast grouped early rejects (most common failures first).
dt = query.dtype
if (
attn_mask is not None # Explicit attention masks are not supported by flash_attn
or dt not in ALLOWED_TYPES # flash_attn requires fp16/bf16
or dt != key.dtype or dt != value.dtype # Q/K/V must have identical dtypes
or not (query.is_cuda and key.is_cuda and value.is_cuda) # flash_attn is CUDA-only
or query.dim() != 4 or key.dim() != 4 or value.dim() != 4 # Expect rank-4 (B, H, L, D)
or query.is_nested or key.is_nested or value.is_nested # Nested tensors unsupported, keep our use-case simple
):
return False

# Hardware capability check
if not is_supported_hardware(query.device):
return False

# Unpack shapes once.
(bq, q_heads, q_len, head_dim) = query.shape
(bk, k_heads, k_len, k_head_dim) = key.shape
(bv, v_heads, v_len, v_head_dim) = value.shape

# Batch & head dim validation.
if bq != bk or bq != bv:
return False
if not (0 < head_dim <= 256 and head_dim == k_head_dim == v_head_dim):
return False

# Sequence length checks.
if q_len == 0 or k_len == 0:
return False
if is_causal and q_len != k_len: # causal path requires equal seq lengths
return False

# Head count validation (GQA aware).
if enable_gqa:
if k_heads != v_heads or k_heads == 0 or (q_heads % k_heads) != 0:
return False
else:
if not (q_heads == k_heads == v_heads):
return False

# Stride check (only if dim > 1).
if head_dim != 1:
qs = query.stride(-1)
ks = key.stride(-1)
vs = value.stride(-1)
if qs != 1 or ks != 1 or vs != 1: # All last-dim strides must be 1 (contiguous)
return False

return True


_scaled_dot_product_attention = None


def enable_flash_attn_win():
"""Enable Flash Attention fallback on Windows."""
_register()


def disable_flash_attn_win():
"""Disable Flash Attention fallback on Windows."""
global _scaled_dot_product_attention
if _scaled_dot_product_attention is not None:
F.scaled_dot_product_attention = _scaled_dot_product_attention
_scaled_dot_product_attention = None


def supports_flash_attention_in_sdp():
"""Check if Flash Attention is natively supported in scaled_dot_product."""
return torch.cuda.is_available() and torch.backends.cuda.is_flash_attention_available()


def _register():
"""
Register Flash Attention fallback on Windows when native support is unavailable.

Patches F.scaled_dot_product_attention to use flash_attn_func when conditions allow,
falling back to the original implementation otherwise.
"""
global _scaled_dot_product_attention
if _scaled_dot_product_attention is None and sys.platform == "win32" and is_flash_attn_available() and not supports_flash_attention_in_sdp():
try:
from flash_attn.flash_attn_interface import flash_attn_func
except Exception:
return

_scaled_dot_product_attention = F.scaled_dot_product_attention

def _flash_dynamic_scaled_dot_product_attention(query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: float | None = None,
enable_gqa: bool = False,
_fallback_sdpa = _scaled_dot_product_attention):
if can_use_flash_attn(query, key, value, attn_mask, is_causal, enable_gqa):
# transpose(1,2) is equivalent to permute(0,2,1,3) for (B,H,L,D) -> (B,L,H,D)
q = query.transpose(1, 2)
k = key.transpose(1, 2)
v = value.transpose(1, 2)
out = flash_attn_func(
q=q, k=k, v=v,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal
)
return out.transpose(1, 2)

# Fallback
return _fallback_sdpa(
query=query, key=key, value=value,
attn_mask=attn_mask, dropout_p=dropout_p,
is_causal=is_causal, scale=scale, enable_gqa=enable_gqa)

F.scaled_dot_product_attention = _flash_dynamic_scaled_dot_product_attention
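
To make the runtime behaviour concrete, here is a minimal usage sketch of the patched entry point. Assumptions: Windows with the flash-attn wheel installed and a CUDA device available; the tensor shapes are illustrative. On other platforms, or when PyTorch already has native flash SDPA support, enable_flash_attn_win() is a no-op and the stock path runs.

import torch
import torch.nn.functional as F
from modules.util.attn.flash_attn_win import disable_flash_attn_win, enable_flash_attn_win

enable_flash_attn_win()  # patches F.scaled_dot_product_attention only on win32 without native flash SDPA

# Eligible call: fp16, CUDA, rank-4 (B, H, L, D), no mask -> routed through flash_attn_func when it applies.
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
out = F.scaled_dot_product_attention(q, k, v)  # output stays (B, H, L, D) on either path

# Ineligible call (explicit mask): can_use_flash_attn() returns False, so the original implementation runs.
mask = torch.ones(2, 8, 128, 128, device="cuda", dtype=torch.bool)
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

disable_flash_attn_win()  # restores the original F.scaled_dot_product_attention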
2 changes: 2 additions & 0 deletions modules/util/config/TrainConfig.py
@@ -378,6 +378,7 @@ class TrainConfig(BaseConfig):
loss_scaler: LossScaler
learning_rate_scaler: LearningRateScaler
clip_grad_norm: float
use_flash_attn_fallback: bool

#layer filter
layer_filter: str # comma-separated
@@ -933,6 +934,7 @@ def default_values() -> 'TrainConfig':
data.append(("loss_scaler", LossScaler.NONE, LossScaler, False))
data.append(("learning_rate_scaler", LearningRateScaler.NONE, LearningRateScaler, False))
data.append(("clip_grad_norm", 1.0, float, True))
data.append(("use_flash_attn_fallback", True, bool, False))

# noise
data.append(("offset_noise_weight", 0.0, float, False))
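
A short sketch of how the new flag surfaces on a config object, assuming the import path shown in this diff (illustrative, not part of the change):

from modules.util.config.TrainConfig import TrainConfig

config = TrainConfig.default_values()
print(config.use_flash_attn_fallback)   # True by default, per the entry added above
config.use_flash_attn_fallback = False  # opting out makes GenericTrainer.start() call disable_flash_attn_win()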
5 changes: 5 additions & 0 deletions requirements-cuda.txt
@@ -8,3 +8,8 @@ triton-windows==3.4.0.post20; sys_platform == "win32"

# optimizers
bitsandbytes==0.46.0 # bitsandbytes for 8-bit optimizers and weight quantization

# flash-attn
https://github.com/zzlol63/flash-attention-prebuild-wheels/releases/download/v0.2/flash_attn-2.8.2+cu128torch2.8-cp310-cp310-win_amd64.whl; sys_platform == "win32" and python_version == "3.10"
https://github.com/zzlol63/flash-attention-prebuild-wheels/releases/download/v0.2/flash_attn-2.8.2+cu128torch2.8-cp311-cp311-win_amd64.whl; sys_platform == "win32" and python_version == "3.11"
https://github.com/zzlol63/flash-attention-prebuild-wheels/releases/download/v0.2/flash_attn-2.8.2+cu128torch2.8-cp312-cp312-win_amd64.whl; sys_platform == "win32" and python_version == "3.12"
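# Note: the environment markers above mean only the wheel matching the running interpreter is installed on
# Windows (e.g. CPython 3.11 pulls the cp311 wheel); Linux/macOS and other Python versions skip these lines.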