2 changes: 2 additions & 0 deletions modules/modelSetup/BaseChromaSetup.py
@@ -86,6 +86,8 @@ def setup_optimizations(
quantize_layers(model.vae, self.train_device, model.train_dtype, config)
quantize_layers(model.transformer, self.train_device, model.train_dtype, config)

self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)

def _setup_embeddings(
self,
model: ChromaModel,
2 changes: 2 additions & 0 deletions modules/modelSetup/BaseFluxSetup.py
@@ -89,6 +89,8 @@ def setup_optimizations(
quantize_layers(model.vae, self.train_device, model.train_dtype, config)
quantize_layers(model.transformer, self.train_device, model.train_dtype, config)

self._set_attention_backend(model.transformer, config.attention_mechanism, mask=False)

def _setup_embeddings(
self,
model: FluxModel,
2 changes: 2 additions & 0 deletions modules/modelSetup/BaseHiDreamSetup.py
@@ -105,6 +105,8 @@ def setup_optimizations(
quantize_layers(model.vae, self.train_device, model.train_dtype, config)
quantize_layers(model.transformer, self.train_device, model.transformer_train_dtype, config)

self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)

def _setup_embeddings(
self,
model: HiDreamModel,
2 changes: 2 additions & 0 deletions modules/modelSetup/BaseHunyuanVideoSetup.py
@@ -91,6 +91,8 @@ def setup_optimizations(
quantize_layers(model.transformer, self.train_device, model.transformer_train_dtype, config)

model.vae.enable_tiling()
self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)


def _setup_embeddings(
self,
13 changes: 13 additions & 0 deletions modules/modelSetup/BaseModelSetup.py
@@ -3,6 +3,7 @@

from modules.model.BaseModel import BaseModel
from modules.util.config.TrainConfig import TrainConfig, TrainEmbeddingConfig, TrainModelPartConfig
from modules.util.enum.AttentionMechanism import AttentionMechanism
from modules.util.enum.TrainingMethod import TrainingMethod
from modules.util.ModuleFilter import ModuleFilter
from modules.util.NamedParameterGroup import NamedParameterGroup, NamedParameterGroupCollection
@@ -235,3 +236,15 @@ def _setup_model_part_requires_grad(
if unique_name in self.frozen_parameters:
for param in self.frozen_parameters[unique_name]:
param.requires_grad_(False)

@staticmethod
def _set_attention_backend(component, attn: AttentionMechanism, mask: bool = False, varlen: bool = False):
    match attn:
        case AttentionMechanism.SDP:
            component.set_attention_backend("native")
        case AttentionMechanism.FLASH:
            if mask or varlen:
                print("Warning: FLASH attention might fail for this model, depending on other configuration (batch size > 1, etc.)")
            component.set_attention_backend("flash")
        case _:
            raise NotImplementedError(f"attention mechanism {str(attn)} not implemented")
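For reference, a minimal standalone sketch of the mapping this helper performs. The set_attention_backend("native") and set_attention_backend("flash") calls are the diffusers attention-dispatcher API used above; the apply_attention function name and the bare transformer argument are illustrative assumptions, not part of the PR.

from modules.util.enum.AttentionMechanism import AttentionMechanism

def apply_attention(transformer, mechanism: AttentionMechanism) -> None:
    # Hypothetical helper mirroring BaseModelSetup._set_attention_backend:
    # SDP selects torch scaled_dot_product_attention ("native"), FLASH selects flash-attn ("flash").
    backend = {AttentionMechanism.SDP: "native", AttentionMechanism.FLASH: "flash"}[mechanism]
    transformer.set_attention_backend(backend)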
1 change: 1 addition & 0 deletions modules/modelSetup/BasePixArtAlphaSetup.py
@@ -85,6 +85,7 @@ def setup_optimizations(
quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config)
quantize_layers(model.vae, self.train_device, model.train_dtype, config)
quantize_layers(model.transformer, self.train_device, model.train_dtype, config)
self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)

def _setup_embeddings(
self,
1 change: 1 addition & 0 deletions modules/modelSetup/BaseQwenSetup.py
@@ -80,6 +80,7 @@ def setup_optimizations(
quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config)
quantize_layers(model.vae, self.train_device, model.train_dtype, config)
quantize_layers(model.transformer, self.train_device, model.train_dtype, config)
self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)

def predict(
self,
1 change: 1 addition & 0 deletions modules/modelSetup/BaseSanaSetup.py
@@ -97,6 +97,7 @@ def setup_optimizations(
quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config)
quantize_layers(model.vae, self.train_device, model.train_dtype, config)
quantize_layers(model.transformer, self.train_device, model.train_dtype, config)
self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)

def _setup_embeddings(
self,
1 change: 1 addition & 0 deletions modules/modelSetup/BaseStableDiffusion3Setup.py
@@ -91,6 +91,7 @@ def setup_optimizations(
quantize_layers(model.text_encoder_3, self.train_device, model.text_encoder_3_train_dtype, config)
quantize_layers(model.vae, self.train_device, model.train_dtype, config)
quantize_layers(model.transformer, self.train_device, model.train_dtype, config)
self._set_attention_backend(model.transformer, config.attention_mechanism)

def _setup_embeddings(
self,
1 change: 1 addition & 0 deletions modules/modelSetup/BaseStableDiffusionSetup.py
@@ -71,6 +71,7 @@ def setup_optimizations(
quantize_layers(model.text_encoder, self.train_device, model.train_dtype, config)
quantize_layers(model.vae, self.train_device, model.train_dtype, config)
quantize_layers(model.unet, self.train_device, model.train_dtype, config)
self._set_attention_backend(model.unet, config.attention_mechanism)

def _setup_embeddings(
self,
1 change: 1 addition & 0 deletions modules/modelSetup/BaseStableDiffusionXLSetup.py
@@ -80,6 +80,7 @@ def setup_optimizations(
quantize_layers(model.text_encoder_2, self.train_device, model.train_dtype, config)
quantize_layers(model.vae, self.train_device, model.vae_train_dtype, config)
quantize_layers(model.unet, self.train_device, model.train_dtype, config)
self._set_attention_backend(model.unet, config.attention_mechanism)

def _setup_embeddings(
self,
1 change: 1 addition & 0 deletions modules/modelSetup/BaseZImageSetup.py
@@ -81,6 +81,7 @@ def setup_optimizations(
quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config)
quantize_layers(model.vae, self.train_device, model.train_dtype, config)
quantize_layers(model.transformer, self.train_device, model.train_dtype, config)
self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)

def predict(
self,
8 changes: 8 additions & 0 deletions modules/ui/TrainingTab.py
@@ -15,6 +15,7 @@
from modules.ui.SchedulerParamsWindow import SchedulerParamsWindow
from modules.ui.TimestepDistributionWindow import TimestepDistributionWindow
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.AttentionMechanism import AttentionMechanism
from modules.util.enum.DataType import DataType
from modules.util.enum.EMAMode import EMAMode
from modules.util.enum.GradientCheckpointingMethod import GradientCheckpointingMethod
@@ -336,6 +337,13 @@ def __create_base2_frame(self, master, row, video_training_enabled: bool = False
frame.grid_columnconfigure(0, weight=1)
row = 0

# attention mechanism
components.label(frame, row, 0, "Attention",
tooltip="The attention mechanism used during training. Use `SDP` on linux. On windows, `FLASH` can be faster but you have to install it, and it does not support all models.")
components.options(frame, row, 1, [str(x) for x in list(AttentionMechanism)], self.ui_state,
"attention_mechanism")
row += 1

# ema
components.label(frame, row, 0, "EMA",
tooltip="EMA averages the training progress over many steps, better preserving different concepts in big datasets")
3 changes: 3 additions & 0 deletions modules/util/config/TrainConfig.py
@@ -9,6 +9,7 @@
from modules.util.config.ConceptConfig import ConceptConfig
from modules.util.config.SampleConfig import SampleConfig
from modules.util.config.SecretsConfig import SecretsConfig
from modules.util.enum.AttentionMechanism import AttentionMechanism
from modules.util.enum.AudioFormat import AudioFormat
from modules.util.enum.ConfigPart import ConfigPart
from modules.util.enum.DataType import DataType
@@ -422,6 +423,7 @@ class TrainConfig(BaseConfig):
only_cache: bool
resolution: str
frames: str
attention_mechanism: AttentionMechanism
mse_strength: float
mae_strength: float
log_cosh_strength: float
@@ -1005,6 +1007,7 @@ def default_values() -> 'TrainConfig':
data.append(("only_cache", False, bool, False))
data.append(("resolution", "512", str, False))
data.append(("frames", "25", str, False))
data.append(("attention_mechanism", AttentionMechanism.SDP, AttentionMechanism, False))
data.append(("mse_strength", 1.0, float, False))
data.append(("mae_strength", 0.0, float, False))
data.append(("log_cosh_strength", 0.0, float, False))
9 changes: 9 additions & 0 deletions modules/util/enum/AttentionMechanism.py
@@ -0,0 +1,9 @@
from enum import Enum


class AttentionMechanism(Enum):
    SDP = 'SDP'
    FLASH = 'FLASH'

    def __str__(self):
        return self.value
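Since __str__ returns the member's value, the string shown in the UI round-trips cleanly back to the enum member. A small illustrative check (plain stdlib Enum behavior, not code from this PR):

from modules.util.enum.AttentionMechanism import AttentionMechanism

assert str(AttentionMechanism.FLASH) == 'FLASH'               # label used for the UI options list
assert AttentionMechanism('SDP') is AttentionMechanism.SDP    # lookup by value when a config is loaded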
2 changes: 1 addition & 1 deletion requirements-global.txt
@@ -20,7 +20,7 @@ pytorch-lightning==2.5.1.post0

# diffusion models
#Note: check whether Qwen bugs in diffusers have been fixed before upgrading diffusers (see BaseQwenSetup):
-e git+https://github.com/huggingface/diffusers.git@256e010#egg=diffusers
-e git+https://github.com/huggingface/diffusers.git@6fb4c99#egg=diffusers
gguf==0.17.1
transformers==4.56.2
sentencepiece==0.2.1 # transitive dependency of transformers for tokenizer loading