diff --git a/README.md b/README.md index 997f639f5..67a7c3e1a 100644 --- a/README.md +++ b/README.md @@ -45,15 +45,15 @@ OneTrainer is a one-stop solution for all your Diffusion training needs. - Windows: Double click or execute `install.bat` - Linux and Mac: Execute `install.sh` - #### Manual installation - - 1. Clone the repository `git clone https://github.com/Nerogar/OneTrainer.git` - 2. Navigate into the cloned directory `cd OneTrainer` - 3. Set up a virtual environment `python -m venv venv` - 4. Activate the new venv: - - Windows: `venv\scripts\activate` - - Linux and Mac: Depends on your shell, activate the venv accordingly - 5. Install the requirements `pip install -r requirements.txt` +#### Manual installation + +1. Clone the repository `git clone https://github.com/Nerogar/OneTrainer.git` +2. Navigate into the cloned directory `cd OneTrainer` +3. Set up a virtual environment `python -m venv venv` +4. Activate the new venv: + - Windows: `venv\scripts\activate` + - Linux and Mac: Depends on your shell, activate the venv accordingly +5. Install the requirements `pip install -r requirements.txt` > [!Tip] > Some Linux distributions are missing required packages for instance: On Ubuntu you must install `libGL`: diff --git a/modules/dataLoader/Flux2BaseDataLoader.py b/modules/dataLoader/Flux2BaseDataLoader.py new file mode 100644 index 000000000..be02f24a4 --- /dev/null +++ b/modules/dataLoader/Flux2BaseDataLoader.py @@ -0,0 +1,164 @@ +import os + +from modules.dataLoader.BaseDataLoader import BaseDataLoader +from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin +from modules.model.Flux2Model import ( + MISTRAL_HIDDEN_STATES_LAYERS, + MISTRAL_SYSTEM_MESSAGE, + QWEN3_HIDDEN_STATES_LAYERS, + Flux2Model, + mistral_format_input, + qwen3_format_input, +) +from modules.modelSetup.BaseFlux2Setup import BaseFlux2Setup +from modules.util import factory +from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.TrainProgress import TrainProgress + +from mgds.pipelineModules.DecodeTokens import DecodeTokens +from mgds.pipelineModules.DecodeVAE import DecodeVAE +from mgds.pipelineModules.EncodeMistralText import EncodeMistralText +from mgds.pipelineModules.EncodeQwenText import EncodeQwenText +from mgds.pipelineModules.EncodeVAE import EncodeVAE +from mgds.pipelineModules.RescaleImageChannels import RescaleImageChannels +from mgds.pipelineModules.SampleVAEDistribution import SampleVAEDistribution +from mgds.pipelineModules.SaveImage import SaveImage +from mgds.pipelineModules.SaveText import SaveText +from mgds.pipelineModules.ScaleImage import ScaleImage +from mgds.pipelineModules.Tokenize import Tokenize + + +class Flux2BaseDataLoader( + BaseDataLoader, + DataLoaderText2ImageMixin, +): + def _preparation_modules(self, config: TrainConfig, model: Flux2Model): + rescale_image = RescaleImageChannels(image_in_name='image', image_out_name='image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1) + encode_image = EncodeVAE(in_name='image', out_name='latent_image_distribution', vae=model.vae, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) + image_sample = SampleVAEDistribution(in_name='latent_image_distribution', out_name='latent_image', mode='mean') + downscale_mask = ScaleImage(in_name='mask', out_name='latent_mask', factor=0.125) + if model.is_dev(): + tokenize_prompt = Tokenize(in_name='prompt', tokens_out_name='tokens', 
mask_out_name='tokens_mask', tokenizer=model.tokenizer, max_token_length=config.text_encoder_sequence_length, + apply_chat_template = lambda caption: mistral_format_input([caption], MISTRAL_SYSTEM_MESSAGE), apply_chat_template_kwargs = {'add_generation_prompt': False}, + ) + encode_prompt = EncodeMistralText(tokens_name='tokens', tokens_attention_mask_in_name='tokens_mask', hidden_state_out_name='text_encoder_hidden_state', tokens_attention_mask_out_name='tokens_mask', + text_encoder=model.text_encoder, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype(), + hidden_state_output_index=MISTRAL_HIDDEN_STATES_LAYERS, + ) + else: #klein + tokenize_prompt = Tokenize(in_name='prompt', tokens_out_name='tokens', mask_out_name='tokens_mask', tokenizer=model.tokenizer, max_token_length=config.text_encoder_sequence_length, + apply_chat_template = lambda caption: qwen3_format_input(caption), apply_chat_template_kwargs = {'add_generation_prompt': True, 'enable_thinking': False} + ) + if config.dataloader_threads > 1: + #TODO this code is copied from Z-Image, which also uses Qwen3ForCausalLM. The leak issue probably also applies for Flux2.Klein: + raise NotImplementedError("Multiple data loader threads are not supported due to an issue with the transformers library: https://github.com/huggingface/transformers/issues/42673") + encode_prompt = EncodeQwenText(tokens_name='tokens', tokens_attention_mask_in_name='tokens_mask', hidden_state_out_name='text_encoder_hidden_state', tokens_attention_mask_out_name='tokens_mask', + text_encoder=model.text_encoder, hidden_state_output_index=QWEN3_HIDDEN_STATES_LAYERS, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) + + + modules = [rescale_image, encode_image, image_sample] + if config.masked_training or config.model_type.has_mask_input(): + modules.append(downscale_mask) + + modules += [tokenize_prompt, encode_prompt] + return modules + + def _cache_modules(self, config: TrainConfig, model: Flux2Model, model_setup: BaseFlux2Setup): + image_split_names = ['latent_image', 'original_resolution', 'crop_offset'] + + if config.masked_training or config.model_type.has_mask_input(): + image_split_names.append('latent_mask') + + image_aggregate_names = ['crop_resolution', 'image_path'] + + text_split_names = [] + + sort_names = image_aggregate_names + image_split_names + [ + 'prompt', 'tokens', 'tokens_mask', 'text_encoder_hidden_state', + 'concept' + ] + + text_split_names += ['tokens', 'tokens_mask', 'text_encoder_hidden_state'] + + return self._cache_modules_from_names( + model, model_setup, + image_split_names=image_split_names, + image_aggregate_names=image_aggregate_names, + text_split_names=text_split_names, + sort_names=sort_names, + config=config, + text_caching=True, + ) + + def _output_modules(self, config: TrainConfig, model: Flux2Model, model_setup: BaseFlux2Setup): + output_names = [ + 'image_path', 'latent_image', + 'prompt', + 'tokens', + 'tokens_mask', + 'original_resolution', 'crop_resolution', 'crop_offset', + ] + + if config.masked_training or config.model_type.has_mask_input(): + output_names.append('latent_mask') + + output_names.append('text_encoder_hidden_state') + + return self._output_modules_from_out_names( + model, model_setup, + output_names=output_names, + config=config, + use_conditioning_image=False, + vae=model.vae, + autocast_context=[model.autocast_context], + train_dtype=model.train_dtype, + ) + + def _debug_modules(self, config: TrainConfig, model: Flux2Model): + debug_dir = 
os.path.join(config.debug_dir, "dataloader") + + def before_save_fun(): + model.vae_to(self.train_device) + + decode_image = DecodeVAE(in_name='latent_image', out_name='decoded_image', vae=model.vae, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype()) + upscale_mask = ScaleImage(in_name='latent_mask', out_name='decoded_mask', factor=8) + decode_prompt = DecodeTokens(in_name='tokens', out_name='decoded_prompt', tokenizer=model.tokenizer) + save_image = SaveImage(image_in_name='decoded_image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1, before_save_fun=before_save_fun) + # SaveImage(image_in_name='latent_mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1, before_save_fun=before_save_fun) + save_mask = SaveImage(image_in_name='decoded_mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1, before_save_fun=before_save_fun) + save_prompt = SaveText(text_in_name='decoded_prompt', original_path_in_name='image_path', path=debug_dir, before_save_fun=before_save_fun) + + # These modules don't really work, since they are inserted after a sorting operation that does not include this data + # SaveImage(image_in_name='mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1), + # SaveImage(image_in_name='image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1), + + modules = [] + + modules.append(decode_image) + modules.append(save_image) + + if config.masked_training or config.model_type.has_mask_input(): + modules.append(upscale_mask) + modules.append(save_mask) + + modules.append(decode_prompt) + modules.append(save_prompt) + + return modules + + def _create_dataset( + self, + config: TrainConfig, + model: Flux2Model, + model_setup: BaseFlux2Setup, + train_progress: TrainProgress, + is_validation: bool = False, + ): + return DataLoaderText2ImageMixin._create_dataset(self, + config, model, model_setup, train_progress, is_validation, + aspect_bucketing_quantization=64, + ) + + +factory.register(BaseDataLoader, Flux2BaseDataLoader, ModelType.FLUX_2) diff --git a/modules/model/Flux2Model.py b/modules/model/Flux2Model.py new file mode 100644 index 000000000..46744a704 --- /dev/null +++ b/modules/model/Flux2Model.py @@ -0,0 +1,330 @@ +import math +from contextlib import nullcontext +from random import Random + +from modules.model.BaseModel import BaseModel +from modules.module.LoRAModule import LoRAModuleWrapper +from modules.util.convert_util import add_prefix, lora_qkv_fusion, qkv_fusion, remove_prefix, swap_chunks +from modules.util.enum.ModelType import ModelType +from modules.util.LayerOffloadConductor import LayerOffloadConductor + +import torch +from torch import Tensor + +from diffusers import ( + AutoencoderKLFlux2, + DiffusionPipeline, + FlowMatchEulerDiscreteScheduler, + Flux2KleinPipeline, + Flux2Pipeline, + Flux2Transformer2DModel, +) +from diffusers.pipelines.flux2.pipeline_flux2 import format_input as mistral_format_input +from transformers import Mistral3ForConditionalGeneration, PixtralProcessor, Qwen2Tokenizer, Qwen3ForCausalLM + +MISTRAL_SYSTEM_MESSAGE = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation." 
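+# intermediate hidden-state layers used for text conditioning; encode_text() concatenates the hidden states of these layer indices along the feature dimension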
+MISTRAL_HIDDEN_STATES_LAYERS = [10, 20, 30] +QWEN3_HIDDEN_STATES_LAYERS = [9, 18, 27] + +def qwen3_format_input(text: str): + return [ + {"role": "user", "content": text}, + ] + + +def diffusers_to_original(qkv_fusion): + return [ + ("context_embedder", "txt_in"), + ("x_embedder", "img_in"), + ("time_guidance_embed.timestep_embedder", "time_in", [ + ("linear_1", "in_layer"), + ("linear_2", "out_layer"), + ]), + ("time_guidance_embed.guidance_embedder", "guidance_in", [ + ("linear_1", "in_layer"), + ("linear_2", "out_layer"), + ]), + ("double_stream_modulation_img.linear", "double_stream_modulation_img.lin"), + ("double_stream_modulation_txt.linear", "double_stream_modulation_txt.lin"), + ("single_stream_modulation.linear", "single_stream_modulation.lin"), + ("proj_out", "final_layer.linear"), + ("norm_out.linear", "final_layer.adaLN_modulation.1", swap_chunks, swap_chunks), + ("transformer_blocks.{i}", "double_blocks.{i}", + qkv_fusion("attn.to_q", "attn.to_k", "attn.to_v", "img_attn.qkv") + \ + qkv_fusion("attn.add_q_proj", "attn.add_k_proj", "attn.add_v_proj", "txt_attn.qkv") + [ + ("attn.norm_k.weight", "img_attn.norm.key_norm.scale"), + ("attn.norm_q.weight", "img_attn.norm.query_norm.scale"), + ("attn.to_out.0", "img_attn.proj"), + ("ff.linear_in", "img_mlp.0"), + ("ff.linear_out", "img_mlp.2"), + ("attn.norm_added_k.weight", "txt_attn.norm.key_norm.scale"), + ("attn.norm_added_q.weight", "txt_attn.norm.query_norm.scale"), + ("attn.to_add_out", "txt_attn.proj"), + ("ff_context.linear_in", "txt_mlp.0"), + ("ff_context.linear_out", "txt_mlp.2"), + ]), + ("single_transformer_blocks.{i}", "single_blocks.{i}", [ + ("attn.to_qkv_mlp_proj", "linear1"), + ("attn.to_out", "linear2"), + ("attn.norm_k.weight", "norm.key_norm.scale"), + ("attn.norm_q.weight", "norm.query_norm.scale"), + ]), + ] + +diffusers_lora_to_original = diffusers_to_original(lora_qkv_fusion) +diffusers_checkpoint_to_original = diffusers_to_original(qkv_fusion) +diffusers_lora_to_comfy = [remove_prefix("transformer"), diffusers_to_original(lora_qkv_fusion), add_prefix("diffusion_model")] + + +class Flux2Model(BaseModel): + # base model data + tokenizer: PixtralProcessor | Qwen2Tokenizer | None + noise_scheduler: FlowMatchEulerDiscreteScheduler | None + text_encoder: Mistral3ForConditionalGeneration | Qwen3ForCausalLM | None + vae: AutoencoderKLFlux2 | None + transformer: Flux2Transformer2DModel | None + + # autocast context + text_encoder_autocast_context: torch.autocast | nullcontext + + text_encoder_offload_conductor: LayerOffloadConductor | None + transformer_offload_conductor: LayerOffloadConductor | None + + transformer_lora: LoRAModuleWrapper | None + lora_state_dict: dict | None + + def __init__( + self, + model_type: ModelType, + ): + super().__init__( + model_type=model_type, + ) + + self.tokenizer = None + self.noise_scheduler = None + self.text_encoder = None + self.vae = None + self.transformer = None + + self.text_encoder_autocast_context = nullcontext() + + self.text_encoder_offload_conductor = None + self.transformer_offload_conductor = None + + self.transformer_lora = None + self.lora_state_dict = None + + def adapters(self) -> list[LoRAModuleWrapper]: + return [a for a in [ + self.transformer_lora, + ] if a is not None] + + def vae_to(self, device: torch.device): + self.vae.to(device=device) + + def text_encoder_to(self, device: torch.device): + if self.text_encoder is not None: + if self.text_encoder_offload_conductor is not None and \ + self.text_encoder_offload_conductor.layer_offload_activated(): + 
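+                # layer offloading is active: let the offload conductor manage device placement instead of moving the whole text encoder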
self.text_encoder_offload_conductor.to(device) + else: + self.text_encoder.to(device=device) + + def transformer_to(self, device: torch.device): + if self.transformer_offload_conductor is not None and \ + self.transformer_offload_conductor.layer_offload_activated(): + self.transformer_offload_conductor.to(device) + else: + self.transformer.to(device=device) + + if self.transformer_lora is not None: + self.transformer_lora.to(device) + + def to(self, device: torch.device): + self.vae_to(device) + self.text_encoder_to(device) + self.transformer_to(device) + + def eval(self): + self.vae.eval() + if self.text_encoder is not None: + self.text_encoder.eval() + self.transformer.eval() + + def create_pipeline(self) -> DiffusionPipeline: + klass = Flux2Pipeline if self.is_dev() else Flux2KleinPipeline + return klass( + transformer=self.transformer, + scheduler=self.noise_scheduler, + vae=self.vae, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + ) + + def encode_text( + self, + train_device: torch.device, + batch_size: int = 1, #TODO unused + rand: Random | None = None, + text: str = None, + tokens: Tensor = None, + tokens_mask: Tensor = None, + text_encoder_sequence_length: int | None = None, + text_encoder_dropout_probability: float | None = None, + text_encoder_output: Tensor = None, + ) -> tuple[Tensor, Tensor]: + + if tokens is None and text is not None: + if isinstance(text, str): + text = [text] + + if self.is_dev(): + messages = mistral_format_input(prompts=text, system_message=MISTRAL_SYSTEM_MESSAGE) + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=False, + ) + + tokenizer_output = self.tokenizer( + text, + max_length=text_encoder_sequence_length, #max length is including system message + padding='max_length', + truncation=True, + return_tensors="pt" + ) + else: #Flux2.Klein + for i, prompt_item in enumerate(text): + messages = qwen3_format_input(prompt_item) + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + text[i] = prompt_item + + tokenizer_output = self.tokenizer( + text, + max_length=text_encoder_sequence_length, + padding='max_length', + truncation=True, + return_tensors="pt" + ) + + tokens = tokenizer_output.input_ids.to(self.text_encoder.device) + tokens_mask = tokenizer_output.attention_mask.to(self.text_encoder.device) + + if text_encoder_output is None and self.text_encoder is not None: + with self.text_encoder_autocast_context: + text_encoder_output = self.text_encoder( + tokens, + attention_mask=tokens_mask.float(), + output_hidden_states=True, + use_cache=False, + ) + text_encoder_output = torch.cat([text_encoder_output.hidden_states[k] + for k in (MISTRAL_HIDDEN_STATES_LAYERS if self.is_dev() else QWEN3_HIDDEN_STATES_LAYERS)], dim=2) + + if text_encoder_dropout_probability is not None and text_encoder_dropout_probability > 0.0: + raise NotImplementedError #https://github.com/Nerogar/OneTrainer/issues/957 + + return text_encoder_output + + def is_dev(self) -> bool: + return isinstance(self.tokenizer, PixtralProcessor) + + def is_klein(self) -> bool: + return not self.is_dev() + + #code adapted from https://github.com/huggingface/diffusers/blob/c8656ed73c638e51fc2e777a5fd355d69fa5220f/src/diffusers/pipelines/flux2/pipeline_flux2.py + @staticmethod + def prepare_latent_image_ids(latents: torch.Tensor) -> torch.Tensor: + batch_size, _, height, width = latents.shape + + t = torch.arange(1, device=latents.device) + h = 
torch.arange(height, device=latents.device) + w = torch.arange(width, device=latents.device) + l_ = torch.arange(1, device=latents.device) + + latent_ids = torch.cartesian_prod(t, h, w, l_) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + #packing and unpacking on patchified latents + @staticmethod + def pack_latents(latents) -> Tensor: + batch_size, num_channels, height, width = latents.shape + return latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + @staticmethod + def unpack_latents(latents, height: int, width: int) -> Tensor: + batch_size, seq_len, num_channels = latents.shape + return latents.reshape(batch_size, height, width, num_channels).permute(0, 3, 1, 2) + + #TODO inference code uses empirical mu. But that code cannot be used for training because it depends on num of inference steps + # is dynamic timestep shifting during training still applicable? + #unpatchified width and height + def calculate_timestep_shift(self, latent_height: int, latent_width: int) -> float: + base_seq_len = self.noise_scheduler.config.base_image_seq_len + max_seq_len = self.noise_scheduler.config.max_image_seq_len + base_shift = self.noise_scheduler.config.base_shift + max_shift = self.noise_scheduler.config.max_shift + patch_size = 2 + + image_seq_len = (latent_width // patch_size) * (latent_height // patch_size) + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return math.exp(mu) + + @staticmethod + def prepare_text_ids(x: torch.Tensor) -> torch.Tensor: + B, L, _ = x.shape + out_ids = [] + + for _ in range(B): #TODO why iterate? can text ids have different length? according to diffusers and original inference code: no + t = torch.arange(1, device=x.device) + h = torch.arange(1, device=x.device) + w = torch.arange(1, device=x.device) + l_ = torch.arange(L, device=x.device) + + coords = torch.cartesian_prod(t, h, w, l_) + out_ids.append(coords) + + return torch.stack(out_ids) + + @staticmethod + def patchify_latents(latents: torch.Tensor) -> torch.Tensor: + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + @staticmethod + def unpatchify_latents(latents: torch.Tensor) -> torch.Tensor: + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + #scaling on patchified latents + def scale_latents(self, latents: Tensor) -> Tensor: + #TODO moves to device - necessary? save in model? 
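+        # normalize the patchified latents with the VAE BatchNorm running statistics: (latents - running_mean) / sqrt(running_var + eps)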
+ latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + return (latents - latents_bn_mean) / latents_bn_std + + + def unscale_latents(self, latents: Tensor) -> Tensor: + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + return latents * latents_bn_std + latents_bn_mean diff --git a/modules/model/FluxModel.py b/modules/model/FluxModel.py index c9ec81c88..4b85f1272 100644 --- a/modules/model/FluxModel.py +++ b/modules/model/FluxModel.py @@ -341,7 +341,7 @@ def unpack_latents(self, latents, height: int, width: int): return latents - def calculate_timestep_shift(self, latent_width: int, latent_height: int): + def calculate_timestep_shift(self, latent_height: int, latent_width: int): base_seq_len = self.noise_scheduler.config.base_image_seq_len max_seq_len = self.noise_scheduler.config.max_image_seq_len base_shift = self.noise_scheduler.config.base_shift diff --git a/modules/modelLoader/Flux2ModelLoader.py b/modules/modelLoader/Flux2ModelLoader.py new file mode 100644 index 000000000..20d36293f --- /dev/null +++ b/modules/modelLoader/Flux2ModelLoader.py @@ -0,0 +1,244 @@ +import os +import traceback + +from modules.model.BaseModel import BaseModel +from modules.model.Flux2Model import Flux2Model +from modules.modelLoader.GenericFineTuneModelLoader import make_fine_tune_model_loader +from modules.modelLoader.GenericLoRAModelLoader import make_lora_model_loader +from modules.modelLoader.mixin.HFModelLoaderMixin import HFModelLoaderMixin +from modules.modelLoader.mixin.LoRALoaderMixin import LoRALoaderMixin +from modules.util.config.TrainConfig import QuantizationConfig +from modules.util.convert.lora.convert_lora_util import LoraConversionKeySet +from modules.util.enum.ModelType import ModelType +from modules.util.ModelNames import ModelNames +from modules.util.ModelWeightDtypes import ModelWeightDtypes + +import torch + +from diffusers import ( + AutoencoderKLFlux2, + FlowMatchEulerDiscreteScheduler, + Flux2Transformer2DModel, + GGUFQuantizationConfig, +) +from transformers import ( + Mistral3ForConditionalGeneration, + PixtralProcessor, + Qwen2Tokenizer, + Qwen3ForCausalLM, +) + + +class Flux2ModelLoader( + HFModelLoaderMixin, +): + def __init__(self): + super().__init__() + + def __load_internal( + self, + model: Flux2Model, + model_type: ModelType, + weight_dtypes: ModelWeightDtypes, + base_model_name: str, + transformer_model_name: str, + vae_model_name: str, + quantization: QuantizationConfig, + ): + if os.path.isfile(os.path.join(base_model_name, "meta.json")): + self.__load_diffusers( + model, model_type, weight_dtypes, base_model_name, transformer_model_name, vae_model_name, quantization, + ) + else: + raise Exception("not an internal model") + + def __load_diffusers( + self, + model: Flux2Model, + model_type: ModelType, + weight_dtypes: ModelWeightDtypes, + base_model_name: str, + transformer_model_name: str, + vae_model_name: str, + quantization: QuantizationConfig, + ): + diffusers_sub = [] + transformers_sub = ["text_encoder"] + if not transformer_model_name: + diffusers_sub.append("transformer") + if not vae_model_name: + diffusers_sub.append("vae") + + self._prepare_sub_modules( + base_model_name, + 
diffusers_modules=diffusers_sub, + transformers_modules=transformers_sub, + ) + + if transformer_model_name: + transformer = Flux2Transformer2DModel.from_single_file( + transformer_model_name, + #avoid loading the transformer in float32: + torch_dtype = torch.bfloat16 if weight_dtypes.transformer.torch_dtype() is None else weight_dtypes.transformer.torch_dtype(), + quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.transformer.is_gguf() else None, + ) + transformer = self._convert_diffusers_sub_module_to_dtype( + transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quantization, + ) + else: + transformer = self._load_diffusers_sub_module( + Flux2Transformer2DModel, + weight_dtypes.transformer, + weight_dtypes.train_dtype, + base_model_name, + "transformer", + quantization, + ) + + if transformer.config.num_attention_heads == 48: #Flux2.Dev + tokenizer = PixtralProcessor.from_pretrained( + base_model_name, + subfolder="tokenizer", + ).tokenizer + + text_encoder = self._load_transformers_sub_module( + Mistral3ForConditionalGeneration, + weight_dtypes.text_encoder, + weight_dtypes.fallback_train_dtype, + base_model_name, + "text_encoder", + ) + else: #Flux2.Klein + tokenizer = Qwen2Tokenizer.from_pretrained( + base_model_name, + subfolder="tokenizer", + ) + text_encoder = self._load_transformers_sub_module( + Qwen3ForCausalLM, + weight_dtypes.text_encoder, + weight_dtypes.fallback_train_dtype, + base_model_name, + "text_encoder", + ) + #TODO this is a tied weight. The dtype conversion code in _load_transformers_sub_module + #currently does not support tied weights. Reconstruct but clone, because the quantization code + #doesn't support tied weights either: + text_encoder.lm_head.weight = type(text_encoder.lm_head.weight)(text_encoder.model.embed_tokens.weight) + + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + base_model_name, + subfolder="scheduler", + ) + + if vae_model_name: + vae = self._load_diffusers_sub_module( + AutoencoderKLFlux2, + weight_dtypes.vae, + weight_dtypes.train_dtype, + vae_model_name, + ) + else: + vae = self._load_diffusers_sub_module( + AutoencoderKLFlux2, + weight_dtypes.vae, + weight_dtypes.train_dtype, + base_model_name, + "vae", + ) + + model.model_type = model_type + model.tokenizer = tokenizer + model.noise_scheduler = noise_scheduler + model.text_encoder = text_encoder + model.vae = vae + model.transformer = transformer + + def __load_safetensors( + self, + model: Flux2Model, + model_type: ModelType, + weight_dtypes: ModelWeightDtypes, + base_model_name: str, + transformer_model_name: str, + vae_model_name: str, + quantization: QuantizationConfig, + ): + #no single file .safetensors for Qwen available at the time of writing this code + raise NotImplementedError("Loading of single file Flux2 models not supported. Use the diffusers model instead. 
Optionally, transformer-only safetensor files can be loaded by overriding the transformer.") + + def load( + self, + model: Flux2Model, + model_type: ModelType, + model_names: ModelNames, + weight_dtypes: ModelWeightDtypes, + quantization: QuantizationConfig, + ): + stacktraces = [] + + try: + self.__load_internal( + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quantization, + ) + return + except Exception: + stacktraces.append(traceback.format_exc()) + + try: + self.__load_diffusers( + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quantization, + ) + return + except Exception: + stacktraces.append(traceback.format_exc()) + + try: + self.__load_safetensors( + model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quantization, + ) + return + except Exception: + stacktraces.append(traceback.format_exc()) + + for stacktrace in stacktraces: + print(stacktrace) + raise Exception("could not load model: " + model_names.base_model) + + + +class Flux2LoRALoader( + LoRALoaderMixin +): + def __init__(self): + super().__init__() + + def _get_convert_key_sets(self, model: BaseModel) -> list[LoraConversionKeySet] | None: + return None #TODO + #return convert_flux_lora_key_sets() + + def load( + self, + model: Flux2Model, + model_names: ModelNames, + ): + return self._load(model, model_names) + + +Flux2LoRAModelLoader = make_lora_model_loader( + model_spec_map={ + ModelType.FLUX_2: "resources/sd_model_spec/flux_2.0-lora.json", + }, + model_class=Flux2Model, + model_loader_class=Flux2ModelLoader, + lora_loader_class=Flux2LoRALoader, + embedding_loader_class=None, +) + +Flux2FineTuneModelLoader = make_fine_tune_model_loader( + model_spec_map={ + ModelType.FLUX_2: "resources/sd_model_spec/flux_2.0.json", + }, + model_class=Flux2Model, + model_loader_class=Flux2ModelLoader, + embedding_loader_class=None, +) diff --git a/modules/modelSampler/Flux2Sampler.py b/modules/modelSampler/Flux2Sampler.py new file mode 100644 index 000000000..146f410b5 --- /dev/null +++ b/modules/modelSampler/Flux2Sampler.py @@ -0,0 +1,198 @@ +import copy +import inspect +from collections.abc import Callable + +from modules.model.Flux2Model import Flux2Model +from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput +from modules.util import factory +from modules.util.config.SampleConfig import SampleConfig +from modules.util.enum.AudioFormat import AudioFormat +from modules.util.enum.FileType import FileType +from modules.util.enum.ImageFormat import ImageFormat +from modules.util.enum.ModelType import ModelType +from modules.util.enum.NoiseScheduler import NoiseScheduler +from modules.util.enum.VideoFormat import VideoFormat +from modules.util.torch_util import torch_gc + +import torch + +from diffusers.pipelines.flux2.pipeline_flux2 import compute_empirical_mu + +import numpy as np +from tqdm import tqdm + + +class Flux2Sampler(BaseModelSampler): + def __init__( + self, + train_device: torch.device, + temp_device: torch.device, + model: Flux2Model, + model_type: ModelType, + ): + super().__init__(train_device, temp_device) + + self.model = model + self.model_type = model_type + self.pipeline = model.create_pipeline() + + @torch.no_grad() + def __sample_base( + self, + prompt: str, + negative_prompt: str, + height: int, + width: int, + seed: int, + random_seed: bool, + diffusion_steps: int, + cfg_scale: float, + 
noise_scheduler: NoiseScheduler, + text_encoder_sequence_length: int | None = None, + on_update_progress: Callable[[int, int], None] = lambda _, __: None, + ) -> ModelSamplerOutput: + with self.model.autocast_context: + generator = torch.Generator(device=self.train_device) + if random_seed: + generator.seed() + else: + generator.manual_seed(seed) + + noise_scheduler = copy.deepcopy(self.model.noise_scheduler) + image_processor = self.pipeline.image_processor + transformer = self.pipeline.transformer + vae = self.pipeline.vae + + vae_scale_factor = 8 + num_latent_channels = 32 + patch_size = 2 + + # prepare prompt + self.model.text_encoder_to(self.train_device) + + batch_size = 2 if cfg_scale > 1.0 and not transformer.config.guidance_embeds else 1 + prompt_embedding = self.model.encode_text( + text=[prompt, negative_prompt] if batch_size == 2 else prompt, + train_device=self.train_device, + text_encoder_sequence_length=text_encoder_sequence_length, + ) + + self.model.text_encoder_to(self.temp_device) + torch_gc() + + # prepare latent image + latent_image = torch.randn( + size=(1, num_latent_channels, height // vae_scale_factor, width // vae_scale_factor), + generator=generator, + device=self.train_device, + dtype=torch.float32, + ) + + latent_image = self.model.patchify_latents(latent_image) + image_ids = self.model.prepare_latent_image_ids(latent_image) + + #TODO test dynamic timestep shifting instead of empirical + #shift = self.model.calculate_timestep_shift(latent_image.shape[-2], latent_image.shape[-1]) + #mu = math.log(shift) + + latent_image = self.model.pack_latents(latent_image) + image_seq_len = latent_image.shape[1] + mu = compute_empirical_mu(image_seq_len, diffusion_steps) + + # prepare timesteps + #TODO for other models, too? This is different than with sigmas=None + sigmas = np.linspace(1.0, 1 / diffusion_steps, diffusion_steps) + noise_scheduler.set_timesteps(diffusion_steps, device=self.train_device, mu=mu, sigmas=sigmas) + timesteps = noise_scheduler.timesteps + + # denoising loop + extra_step_kwargs = {} #TODO remove + if "generator" in set(inspect.signature(noise_scheduler.step).parameters.keys()): + extra_step_kwargs["generator"] = generator + + text_ids = self.model.prepare_text_ids(prompt_embedding) + + + self.model.transformer_to(self.train_device) + guidance = (torch.tensor([cfg_scale], device=self.train_device, dtype=self.model.train_dtype.torch_dtype()) + if transformer.config.guidance_embeds else None) + for i, timestep in enumerate(tqdm(timesteps, desc="sampling")): + latent_model_input = torch.cat([latent_image] * batch_size) + expanded_timestep = timestep.expand(latent_model_input.shape[0]) + + + noise_pred = transformer( + hidden_states=latent_model_input.to(dtype=self.model.train_dtype.torch_dtype()), + timestep=expanded_timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embedding.to(dtype=self.model.train_dtype.torch_dtype()), + txt_ids=text_ids, + img_ids=image_ids, + joint_attention_kwargs=None, + return_dict=True + ).sample + + if batch_size == 2: + noise_pred_positive, noise_pred_negative = noise_pred.chunk(2) + noise_pred = noise_pred_negative + cfg_scale * (noise_pred_positive - noise_pred_negative) + + latent_image = noise_scheduler.step(noise_pred, timestep, latent_image, return_dict=False, **extra_step_kwargs)[0] + + on_update_progress(i + 1, len(timesteps)) + + self.model.transformer_to(self.temp_device) + torch_gc() + self.model.vae_to(self.train_device) + + latent_image = self.model.unpack_latents( + latent_image, + height // 
vae_scale_factor // patch_size, + width // vae_scale_factor // patch_size, + ) + latents = self.model.unscale_latents(latent_image) + latents = self.model.unpatchify_latents(latents) + + image = vae.decode(latents, return_dict=False)[0] + + image = image_processor.postprocess(image, output_type='pil') + + self.model.vae_to(self.temp_device) + torch_gc() + + return ModelSamplerOutput( + file_type=FileType.IMAGE, + data=image[0], + ) + + def sample( + self, + sample_config: SampleConfig, + destination: str, + image_format: ImageFormat | None = None, + video_format: VideoFormat | None = None, + audio_format: AudioFormat | None = None, + on_sample: Callable[[ModelSamplerOutput], None] = lambda _: None, + on_update_progress: Callable[[int, int], None] = lambda _, __: None, + ): + sampler_output = self.__sample_base( + prompt=sample_config.prompt, + negative_prompt=sample_config.negative_prompt, + height=self.quantize_resolution(sample_config.height, 64), + width=self.quantize_resolution(sample_config.width, 64), + seed=sample_config.seed, + random_seed=sample_config.random_seed, + diffusion_steps=sample_config.diffusion_steps, + cfg_scale=sample_config.cfg_scale, + noise_scheduler=sample_config.noise_scheduler, + text_encoder_sequence_length=sample_config.text_encoder_1_sequence_length, + on_update_progress=on_update_progress, + ) + + self.save_sampler_output( + sampler_output, destination, + image_format, video_format, audio_format, + ) + + on_sample(sampler_output) + +factory.register(BaseModelSampler, Flux2Sampler, ModelType.FLUX_2) diff --git a/modules/modelSampler/FluxSampler.py b/modules/modelSampler/FluxSampler.py index a0e73b593..93f8837ba 100644 --- a/modules/modelSampler/FluxSampler.py +++ b/modules/modelSampler/FluxSampler.py @@ -147,7 +147,6 @@ def __sample_base( self.model.transformer_to(self.temp_device) torch_gc() - latent_image = self.model.unpack_latents( latent_image, height // vae_scale_factor, @@ -160,7 +159,7 @@ def __sample_base( latents = (latent_image / vae.config.scaling_factor) + vae.config.shift_factor image = vae.decode(latents, return_dict=False)[0] - do_denormalize = [True] * image.shape[0] + do_denormalize = [True] * image.shape[0] #TODO remove and test, from Flux and other models. 
True is the default image = image_processor.postprocess(image, output_type='pil', do_denormalize=do_denormalize) self.model.vae_to(self.temp_device) diff --git a/modules/modelSaver/Flux2FineTuneModelSaver.py b/modules/modelSaver/Flux2FineTuneModelSaver.py new file mode 100644 index 000000000..cbc6b7520 --- /dev/null +++ b/modules/modelSaver/Flux2FineTuneModelSaver.py @@ -0,0 +1,11 @@ +from modules.model.Flux2Model import Flux2Model +from modules.modelSaver.flux2.Flux2ModelSaver import Flux2ModelSaver +from modules.modelSaver.GenericFineTuneModelSaver import make_fine_tune_model_saver +from modules.util.enum.ModelType import ModelType + +Flux2FineTuneModelSaver = make_fine_tune_model_saver( + ModelType.FLUX_2, + model_class=Flux2Model, + model_saver_class=Flux2ModelSaver, + embedding_saver_class=None, +) diff --git a/modules/modelSaver/Flux2LoRAModelSaver.py b/modules/modelSaver/Flux2LoRAModelSaver.py new file mode 100644 index 000000000..cd5507d91 --- /dev/null +++ b/modules/modelSaver/Flux2LoRAModelSaver.py @@ -0,0 +1,11 @@ +from modules.model.Flux2Model import Flux2Model +from modules.modelSaver.flux2.Flux2LoRASaver import Flux2LoRASaver +from modules.modelSaver.GenericLoRAModelSaver import make_lora_model_saver +from modules.util.enum.ModelType import ModelType + +Flux2LoRAModelSaver = make_lora_model_saver( + ModelType.FLUX_2, + model_class=Flux2Model, + lora_saver_class=Flux2LoRASaver, + embedding_saver_class=None, +) diff --git a/modules/modelSaver/flux2/Flux2LoRASaver.py b/modules/modelSaver/flux2/Flux2LoRASaver.py new file mode 100644 index 000000000..15471a82c --- /dev/null +++ b/modules/modelSaver/flux2/Flux2LoRASaver.py @@ -0,0 +1,52 @@ +import os +from pathlib import Path + +from modules.model.Flux2Model import Flux2Model, diffusers_lora_to_comfy +from modules.modelSaver.mixin.LoRASaverMixin import LoRASaverMixin +from modules.util.convert.lora.convert_lora_util import LoraConversionKeySet +from modules.util.convert_util import convert +from modules.util.enum.ModelFormat import ModelFormat + +import torch +from torch import Tensor + +from safetensors.torch import save_file + + +class Flux2LoRASaver( + LoRASaverMixin, +): + def __init__(self): + super().__init__() + + def _get_convert_key_sets(self, model: Flux2Model) -> list[LoraConversionKeySet] | None: + return None + + def _get_state_dict( + self, + model: Flux2Model, + ) -> dict[str, Tensor]: + state_dict = {} + if model.transformer_lora is not None: + state_dict |= model.transformer_lora.state_dict() + if model.lora_state_dict is not None: + state_dict |= model.lora_state_dict + + return state_dict + + def save( + self, + model: Flux2Model, + output_model_format: ModelFormat, + output_model_destination: str, + dtype: torch.dtype | None, + ): + if output_model_format == ModelFormat.COMFY_LORA: + state_dict = self._get_state_dict(model) + save_state_dict = self._convert_state_dict_dtype(state_dict, dtype) + save_state_dict = convert(save_state_dict, diffusers_lora_to_comfy) + + os.makedirs(Path(output_model_destination).parent.absolute(), exist_ok=True) + save_file(save_state_dict, output_model_destination, self._create_safetensors_header(model, save_state_dict)) + else: + self._save(model, output_model_format, output_model_destination, dtype) diff --git a/modules/modelSaver/flux2/Flux2ModelSaver.py b/modules/modelSaver/flux2/Flux2ModelSaver.py new file mode 100644 index 000000000..e2976244e --- /dev/null +++ b/modules/modelSaver/flux2/Flux2ModelSaver.py @@ -0,0 +1,85 @@ +import copy +import os.path +from pathlib import Path 
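+# saves either a full diffusers folder or a single transformer .safetensors file converted to the original key layout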
+ +from modules.model.Flux2Model import Flux2Model, diffusers_checkpoint_to_original +from modules.modelSaver.mixin.DtypeModelSaverMixin import DtypeModelSaverMixin +from modules.util.convert_util import convert +from modules.util.enum.ModelFormat import ModelFormat + +import torch + +from safetensors.torch import save_file + + +class Flux2ModelSaver( + DtypeModelSaverMixin, +): + def __init__(self): + super().__init__() + + def __save_diffusers( + self, + model: Flux2Model, + destination: str, + dtype: torch.dtype | None, + ): + # Copy the model to cpu by first moving the original model to cpu. This preserves some VRAM. + pipeline = model.create_pipeline() + pipeline.to("cpu") + if dtype is not None: #TODO necessary? + # replace the tokenizers __deepcopy__ before calling deepcopy, to prevent a copy being made. + # the tokenizer tries to reload from the file system otherwise + tokenizer = pipeline.tokenizer + tokenizer.__deepcopy__ = lambda memo: tokenizer + + save_pipeline = copy.deepcopy(pipeline) + save_pipeline.to(device="cpu", dtype=dtype, silence_dtype_warnings=True) + + delattr(tokenizer, '__deepcopy__') + else: + save_pipeline = pipeline + + os.makedirs(Path(destination).absolute(), exist_ok=True) + save_pipeline.save_pretrained(destination) + + if dtype is not None: + del save_pipeline + + def __save_safetensors( + self, + model: Flux2Model, + destination: str, + dtype: torch.dtype | None, + ): + state_dict = model.transformer.state_dict() + state_dict = convert(state_dict, diffusers_checkpoint_to_original) + + save_state_dict = self._convert_state_dict_dtype(state_dict, dtype) + self._convert_state_dict_to_contiguous(save_state_dict) + + os.makedirs(Path(destination).parent.absolute(), exist_ok=True) + + save_file(save_state_dict, destination, self._create_safetensors_header(model, save_state_dict)) + + def __save_internal( + self, + model: Flux2Model, + destination: str, + ): + self.__save_diffusers(model, destination, None) + + def save( + self, + model: Flux2Model, + output_model_format: ModelFormat, + output_model_destination: str, + dtype: torch.dtype | None, + ): + match output_model_format: + case ModelFormat.DIFFUSERS: + self.__save_diffusers(model, output_model_destination, dtype) + case ModelFormat.SAFETENSORS: + self.__save_safetensors(model, output_model_destination, dtype) + case ModelFormat.INTERNAL: + self.__save_internal(model, output_model_destination) diff --git a/modules/modelSetup/BaseFlux2Setup.py b/modules/modelSetup/BaseFlux2Setup.py new file mode 100644 index 000000000..f21a1cca1 --- /dev/null +++ b/modules/modelSetup/BaseFlux2Setup.py @@ -0,0 +1,206 @@ +from abc import ABCMeta +from random import Random + +import modules.util.multi_gpu_util as multi +from modules.model.Flux2Model import Flux2Model +from modules.model.FluxModel import FluxModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.modelSetup.mixin.ModelSetupDebugMixin import ModelSetupDebugMixin +from modules.modelSetup.mixin.ModelSetupDiffusionLossMixin import ModelSetupDiffusionLossMixin +from modules.modelSetup.mixin.ModelSetupEmbeddingMixin import ModelSetupEmbeddingMixin +from modules.modelSetup.mixin.ModelSetupFlowMatchingMixin import ModelSetupFlowMatchingMixin +from modules.modelSetup.mixin.ModelSetupNoiseMixin import ModelSetupNoiseMixin +from modules.util.checkpointing_util import ( + enable_checkpointing_for_flux2_transformer, + enable_checkpointing_for_mistral_encoder_layers, + enable_checkpointing_for_qwen3_encoder_layers, +) +from 
modules.util.config.TrainConfig import TrainConfig +from modules.util.dtype_util import create_autocast_context, disable_fp16_autocast_context +from modules.util.enum.TrainingMethod import TrainingMethod +from modules.util.quantization_util import quantize_layers +from modules.util.torch_util import torch_gc +from modules.util.TrainProgress import TrainProgress + +import torch +from torch import Tensor + + +class BaseFlux2Setup( + BaseModelSetup, + ModelSetupDiffusionLossMixin, + ModelSetupDebugMixin, + ModelSetupNoiseMixin, + ModelSetupFlowMatchingMixin, + ModelSetupEmbeddingMixin, + metaclass=ABCMeta +): + LAYER_PRESETS = { + "blocks": ["transformer_block"], + "full": [], + } + + def setup_optimizations( + self, + model: Flux2Model, + config: TrainConfig, + ): + if config.gradient_checkpointing.enabled(): + model.transformer_offload_conductor = \ + enable_checkpointing_for_flux2_transformer(model.transformer, config) + if model.text_encoder is not None: + if model.is_dev(): + model.text_encoder_offload_conductor = \ + enable_checkpointing_for_mistral_encoder_layers(model.text_encoder, config) + else: + model.text_encoder_offload_conductor = \ + enable_checkpointing_for_qwen3_encoder_layers(model.text_encoder, config) + + if config.force_circular_padding: + raise NotImplementedError #TODO applies to Flux2? +# apply_circular_padding_to_conv2d(model.vae) +# apply_circular_padding_to_conv2d(model.transformer) +# if model.transformer_lora is not None: +# apply_circular_padding_to_conv2d(model.transformer_lora) + + model.autocast_context, model.train_dtype = create_autocast_context(self.train_device, config.train_dtype, [ + config.weight_dtypes().transformer, + config.weight_dtypes().text_encoder, + config.weight_dtypes().vae, + config.weight_dtypes().lora if config.training_method == TrainingMethod.LORA else None, + ], config.enable_autocast_cache) + + model.text_encoder_autocast_context, model.text_encoder_train_dtype = \ + disable_fp16_autocast_context( + self.train_device, + config.train_dtype, + config.fallback_train_dtype, + [ + config.weight_dtypes().text_encoder, + config.weight_dtypes().lora if config.training_method == TrainingMethod.LORA else None, + ], + config.enable_autocast_cache, + ) + + quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config) + quantize_layers(model.vae, self.train_device, model.train_dtype, config) + quantize_layers(model.transformer, self.train_device, model.train_dtype, config) + + def predict( + self, + model: Flux2Model, + batch: dict, + config: TrainConfig, + train_progress: TrainProgress, + *, + deterministic: bool = False, + ) -> dict: + with model.autocast_context: + batch_seed = 0 if deterministic else train_progress.global_step * multi.world_size() + multi.rank() + generator = torch.Generator(device=config.train_device) + generator.manual_seed(batch_seed) + rand = Random(batch_seed) + + text_encoder_output = model.encode_text( + train_device=self.train_device, + batch_size=batch['latent_image'].shape[0], + rand=rand, + tokens=batch.get("tokens"), + tokens_mask=batch.get("tokens_mask"), + text_encoder_sequence_length=config.text_encoder_sequence_length, + text_encoder_output=batch.get('text_encoder_hidden_state'), + text_encoder_dropout_probability=config.text_encoder.dropout_probability, + ) + latent_image = model.patchify_latents(batch['latent_image'].float()) + latent_height = latent_image.shape[-2] + latent_width = latent_image.shape[-1] + scaled_latent_image = model.scale_latents(latent_image) + + 
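# flow matching: noise the scaled latents at a sampled (shifted) timestep and train the transformer to predict the flow latent_noise - scaled_latent_image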
latent_noise = self._create_noise(scaled_latent_image, config, generator) + + shift = model.calculate_timestep_shift(latent_height, latent_width) + timestep = self._get_timestep_discrete( + model.noise_scheduler.config['num_train_timesteps'], + deterministic, + generator, + scaled_latent_image.shape[0], + config, + shift = shift if config.dynamic_timestep_shifting else config.timestep_shift, + ) + + scaled_noisy_latent_image, sigma = self._add_noise_discrete( + scaled_latent_image, + latent_noise, + timestep, + model.noise_scheduler.timesteps, + ) + latent_input = scaled_noisy_latent_image + + if model.transformer.config.guidance_embeds: + guidance = torch.tensor([config.transformer.guidance_scale], device=self.train_device, dtype=model.train_dtype.torch_dtype()) + guidance = guidance.expand(latent_input.shape[0]) + else: + guidance = None + + text_ids = model.prepare_text_ids(text_encoder_output) + image_ids = model.prepare_latent_image_ids(latent_input) + packed_latent_input = model.pack_latents(latent_input) + + packed_predicted_flow = model.transformer( + hidden_states=packed_latent_input.to(dtype=model.train_dtype.torch_dtype()), + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=text_encoder_output.to(dtype=model.train_dtype.torch_dtype()), + txt_ids=text_ids, + img_ids=image_ids, + joint_attention_kwargs=None, + return_dict=True + ).sample + + predicted_flow = model.unpack_latents( + packed_predicted_flow, + latent_input.shape[2], + latent_input.shape[3], + ) + + flow = latent_noise - scaled_latent_image + model_output_data = { + 'loss_type': 'target', + 'timestep': timestep, + 'predicted': predicted_flow, + 'target': flow, + } + + if config.debug_mode: + with torch.no_grad(): + predicted_scaled_latent_image = scaled_noisy_latent_image - predicted_flow * sigma + self._save_tokens("7-prompt", batch['tokens'], model.tokenizer, config, train_progress) + self._save_latent("1-noise", latent_noise, config, train_progress) + self._save_latent("2-noisy_image", scaled_noisy_latent_image, config, train_progress) + self._save_latent("3-predicted_flow", predicted_flow, config, train_progress) + self._save_latent("4-flow", flow, config, train_progress) + self._save_latent("5-predicted_image", predicted_scaled_latent_image, config, train_progress) + self._save_latent("6-image", scaled_latent_image, config, train_progress) + + return model_output_data + + def calculate_loss( + self, + model: Flux2Model, + batch: dict, + data: dict, + config: TrainConfig, + ) -> Tensor: + return self._flow_matching_losses( + batch=batch, + data=data, + config=config, + train_device=self.train_device, + sigmas=model.noise_scheduler.sigmas, + ).mean() + + def prepare_text_caching(self, model: FluxModel, config: TrainConfig): + model.to(self.temp_device) + model.text_encoder_to(self.train_device) + model.eval() + torch_gc() diff --git a/modules/modelSetup/BaseQwenSetup.py b/modules/modelSetup/BaseQwenSetup.py index dd2d97d0b..b9385b651 100644 --- a/modules/modelSetup/BaseQwenSetup.py +++ b/modules/modelSetup/BaseQwenSetup.py @@ -9,7 +9,7 @@ from modules.modelSetup.mixin.ModelSetupFlowMatchingMixin import ModelSetupFlowMatchingMixin from modules.modelSetup.mixin.ModelSetupNoiseMixin import ModelSetupNoiseMixin from modules.util.checkpointing_util import ( - enable_checkpointing_for_qwen_encoder_layers, + enable_checkpointing_for_qwen25vl_encoder_layers, enable_checkpointing_for_qwen_transformer, ) from modules.util.config.TrainConfig import TrainConfig @@ -50,7 +50,7 @@ def setup_optimizations( 
enable_checkpointing_for_qwen_transformer(model.transformer, config) if model.text_encoder is not None: model.text_encoder_offload_conductor = \ - enable_checkpointing_for_qwen_encoder_layers(model.text_encoder, config) + enable_checkpointing_for_qwen25vl_encoder_layers(model.text_encoder, config) if config.force_circular_padding: #TODO useful for Qwen? apply_circular_padding_to_conv2d(model.vae) diff --git a/modules/modelSetup/BaseZImageSetup.py b/modules/modelSetup/BaseZImageSetup.py index e792500c1..eedcc96c4 100644 --- a/modules/modelSetup/BaseZImageSetup.py +++ b/modules/modelSetup/BaseZImageSetup.py @@ -10,7 +10,7 @@ from modules.modelSetup.mixin.ModelSetupFlowMatchingMixin import ModelSetupFlowMatchingMixin from modules.modelSetup.mixin.ModelSetupNoiseMixin import ModelSetupNoiseMixin from modules.util.checkpointing_util import ( - enable_checkpointing_for_z_image_encoder_layers, + enable_checkpointing_for_qwen25vl_encoder_layers, enable_checkpointing_for_z_image_transformer, ) from modules.util.config.TrainConfig import TrainConfig @@ -50,7 +50,7 @@ def setup_optimizations( enable_checkpointing_for_z_image_transformer(model.transformer, config) if model.text_encoder is not None: model.text_encoder_offload_conductor = \ - enable_checkpointing_for_z_image_encoder_layers(model.text_encoder, config) + enable_checkpointing_for_qwen25vl_encoder_layers(model.text_encoder, config) if config.force_circular_padding: raise NotImplementedError #TODO applies to Z-Image? diff --git a/modules/modelSetup/Flux2FineTuneSetup.py b/modules/modelSetup/Flux2FineTuneSetup.py new file mode 100644 index 000000000..c157921e5 --- /dev/null +++ b/modules/modelSetup/Flux2FineTuneSetup.py @@ -0,0 +1,88 @@ +from modules.model.Flux2Model import Flux2Model +from modules.modelSetup.BaseFlux2Setup import BaseFlux2Setup +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.util import factory +from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod +from modules.util.ModuleFilter import ModuleFilter +from modules.util.NamedParameterGroup import NamedParameterGroupCollection +from modules.util.optimizer_util import init_model_parameters +from modules.util.TrainProgress import TrainProgress + +import torch + + +class Flux2FineTuneSetup( + BaseFlux2Setup, +): + def __init__( + self, + train_device: torch.device, + temp_device: torch.device, + debug_mode: bool, + ): + super().__init__( + train_device=train_device, + temp_device=temp_device, + debug_mode=debug_mode, + ) + + def create_parameters( + self, + model: Flux2Model, + config: TrainConfig, + ) -> NamedParameterGroupCollection: + parameter_group_collection = NamedParameterGroupCollection() + + self._create_model_part_parameters(parameter_group_collection, "transformer", model.transformer, config.transformer, + freeze=ModuleFilter.create(config), debug=config.debug_mode) + return parameter_group_collection + + def __setup_requires_grad( + self, + model: Flux2Model, + config: TrainConfig, + ): + self._setup_model_part_requires_grad("transformer", model.transformer, config.transformer, model.train_progress) + model.vae.requires_grad_(False) + model.text_encoder.requires_grad_(False) + + + def setup_model( + self, + model: Flux2Model, + config: TrainConfig, + ): + self.__setup_requires_grad(model, config) + init_model_parameters(model, self.create_parameters(model, config), self.train_device) + + def setup_train_device( + self, + model: 
Flux2Model, + config: TrainConfig, + ): + vae_on_train_device = not config.latent_caching + text_encoder_on_train_device = not config.latent_caching + + model.text_encoder_to(self.train_device if text_encoder_on_train_device else self.temp_device) + model.vae_to(self.train_device if vae_on_train_device else self.temp_device) + model.transformer_to(self.train_device) + + model.text_encoder.eval() + model.vae.eval() + + if config.transformer.train: + model.transformer.train() + else: + model.transformer.eval() + + def after_optimizer_step( + self, + model: Flux2Model, + config: TrainConfig, + train_progress: TrainProgress + ): + self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, Flux2FineTuneSetup, ModelType.FLUX_2, TrainingMethod.FINE_TUNE) diff --git a/modules/modelSetup/Flux2LoRASetup.py b/modules/modelSetup/Flux2LoRASetup.py new file mode 100644 index 000000000..a9be60d47 --- /dev/null +++ b/modules/modelSetup/Flux2LoRASetup.py @@ -0,0 +1,101 @@ +from modules.model.Flux2Model import Flux2Model +from modules.modelSetup.BaseFlux2Setup import BaseFlux2Setup +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.module.LoRAModule import LoRAModuleWrapper +from modules.util import factory +from modules.util.config.TrainConfig import TrainConfig +from modules.util.enum.ModelType import ModelType +from modules.util.enum.TrainingMethod import TrainingMethod +from modules.util.NamedParameterGroup import NamedParameterGroupCollection +from modules.util.optimizer_util import init_model_parameters +from modules.util.TrainProgress import TrainProgress + +import torch + + +class Flux2LoRASetup( + BaseFlux2Setup, +): + def __init__( + self, + train_device: torch.device, + temp_device: torch.device, + debug_mode: bool, + ): + super().__init__( + train_device=train_device, + temp_device=temp_device, + debug_mode=debug_mode, + ) + + def create_parameters( + self, + model: Flux2Model, + config: TrainConfig, + ) -> NamedParameterGroupCollection: + parameter_group_collection = NamedParameterGroupCollection() + + self._create_model_part_parameters(parameter_group_collection, "transformer_lora", model.transformer_lora, config.transformer) + return parameter_group_collection + + def __setup_requires_grad( + self, + model: Flux2Model, + config: TrainConfig, + ): + model.text_encoder.requires_grad_(False) + model.transformer.requires_grad_(False) + model.vae.requires_grad_(False) + + self._setup_model_part_requires_grad("transformer_lora", model.transformer_lora, config.transformer, model.train_progress) + + def setup_model( + self, + model: Flux2Model, + config: TrainConfig, + ): + model.transformer_lora = LoRAModuleWrapper( + model.transformer, "lora_transformer", config, config.layer_filter.split(",") + ) + + if model.lora_state_dict: + model.transformer_lora.load_state_dict(model.lora_state_dict) + model.lora_state_dict = None + + model.transformer_lora.set_dropout(config.dropout_probability) + model.transformer_lora.to(dtype=config.lora_weight_dtype.torch_dtype()) + model.transformer_lora.hook_to_module() + + self.__setup_requires_grad(model, config) + + init_model_parameters(model, self.create_parameters(model, config), self.train_device) + + def setup_train_device( + self, + model: Flux2Model, + config: TrainConfig, + ): + vae_on_train_device = not config.latent_caching + text_encoder_on_train_device = not config.latent_caching + + model.text_encoder_to(self.train_device if text_encoder_on_train_device else self.temp_device) + model.vae_to(self.train_device 
if vae_on_train_device else self.temp_device) + model.transformer_to(self.train_device) + + model.text_encoder.eval() + model.vae.eval() + + if config.transformer.train: + model.transformer.train() + else: + model.transformer.eval() + + def after_optimizer_step( + self, + model: Flux2Model, + config: TrainConfig, + train_progress: TrainProgress + ): + self.__setup_requires_grad(model, config) + +factory.register(BaseModelSetup, Flux2LoRASetup, ModelType.FLUX_2, TrainingMethod.LORA) diff --git a/modules/module/quantized/LinearW8A8.py b/modules/module/quantized/LinearW8A8.py index b73745879..3da6342ca 100644 --- a/modules/module/quantized/LinearW8A8.py +++ b/modules/module/quantized/LinearW8A8.py @@ -164,8 +164,8 @@ def torch_backward(a, b): run_benchmark(lambda: torch_backward(y_8, w_8), "torch mm backward int8") run_benchmark(lambda: mm_8bit(y_8, w_8), "triton mm backward int8") - run_benchmark(lambda: int8_forward_tokenwise(x, w_8, w_scale), "torch forward int", compile=True) - run_benchmark(lambda: int8_backward_axiswise(y, w_8, w_scale), "triton backward int", compile=True) + run_benchmark(lambda: int8_forward_tokenwise(x, w_8, w_scale, bias=None, compute_dtype=torch.bfloat16), "torch forward int", compile=True) + run_benchmark(lambda: int8_backward_axiswise(y, w_8, w_scale, bias=None, compute_dtype=torch.bfloat16), "triton backward int", compile=True) @torch.no_grad() diff --git a/modules/trainer/GenericTrainer.py b/modules/trainer/GenericTrainer.py index 0607fe349..52f812b61 100644 --- a/modules/trainer/GenericTrainer.py +++ b/modules/trainer/GenericTrainer.py @@ -158,7 +158,7 @@ def start(self): if self.config.validation: self.validation_data_loader = self.create_data_loader( - self.model, self.model.train_progress, is_validation=True + self.model, self.model_setup, self.model.train_progress, is_validation=True ) def __save_config_to_workspace(self): diff --git a/modules/ui/ModelTab.py b/modules/ui/ModelTab.py index 1e336ab2b..7d52d1748 100644 --- a/modules/ui/ModelTab.py +++ b/modules/ui/ModelTab.py @@ -5,6 +5,7 @@ from modules.util.enum.ConfigPart import ConfigPart from modules.util.enum.DataType import DataType from modules.util.enum.ModelFormat import ModelFormat +from modules.util.enum.ModelType import PeftType from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.ui import components from modules.util.ui.UIState import UIState @@ -55,8 +56,10 @@ def refresh_ui(self): self.__setup_wuerstchen_ui(base_frame) elif self.train_config.model_type.is_pixart(): self.__setup_pixart_alpha_ui(base_frame) - elif self.train_config.model_type.is_flux(): + elif self.train_config.model_type.is_flux_1(): self.__setup_flux_ui(base_frame) + elif self.train_config.model_type.is_flux_2(): + self.__setup_flux_2_ui(base_frame) elif self.train_config.model_type.is_z_image(): self.__setup_z_image_ui(base_frame) elif self.train_config.model_type.is_chroma(): @@ -131,6 +134,26 @@ def __setup_flux_ui(self, frame): allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA, ) + def __setup_flux_2_ui(self, frame): + row = 0 + row = self.__create_base_dtype_components(frame, row) + row = self.__create_base_components( + frame, + row, + has_transformer=True, + allow_override_transformer=True, + has_text_encoder_1=True, + has_vae=True, + ) + row = self.__create_output_components( + frame, + row, + allow_safetensors=True, + allow_diffusers=self.train_config.training_method == TrainingMethod.FINE_TUNE, + allow_legacy_safetensors=self.train_config.training_method == 
TrainingMethod.LORA, + allow_comfy=self.train_config.training_method == TrainingMethod.LORA and self.train_config.peft_type == PeftType.LORA, + ) + def __setup_z_image_ui(self, frame): row = 0 row = self.__create_base_dtype_components(frame, row) @@ -590,6 +613,7 @@ def __create_output_components( allow_safetensors: bool = False, allow_diffusers: bool = False, allow_legacy_safetensors: bool = False, + allow_comfy: bool = False, ) -> int: # output model destination components.label(frame, row, 0, "Model Output Destination", @@ -617,6 +641,8 @@ def __create_output_components( formats.append(("Diffusers", ModelFormat.DIFFUSERS)) # if allow_legacy_safetensors: # formats.append(("Legacy Safetensors", ModelFormat.LEGACY_SAFETENSORS)) + if allow_comfy: + formats.append(("Comfy", ModelFormat.COMFY_LORA)) components.label(frame, row, 0, "Output Format", tooltip="Format to use when saving the output model") diff --git a/modules/ui/OptimizerParamsWindow.py b/modules/ui/OptimizerParamsWindow.py index 8c3e038c2..ed50e9f57 100644 --- a/modules/ui/OptimizerParamsWindow.py +++ b/modules/ui/OptimizerParamsWindow.py @@ -198,6 +198,7 @@ def create_dynamic_ui( 'approx_mars': {'title': 'Approx MARS-M', 'tooltip': 'Enables Approximated MARS-M, a variance reduction technique. It uses the previous step\'s gradient to correct the current update, leading to lower losses and improved convergence stability. This requires additional state to store the previous gradient.', 'type': 'bool'}, 'kappa_p': {'title': 'Lion-K P-value', 'tooltip': 'Controls the Lp-norm geometry for the Lion update. 1.0 = Standard Lion (Sign update, coordinate-wise), best for Transformers. 2.0 = Spherical Lion (Normalized update, rotational invariant), best for Conv2d layers (in unet models). Values between 1.0 and 2.0 interpolate behavior between the two.', 'type': 'float'}, 'auto_kappa_p': {'title': 'Auto Lion-K', 'tooltip': 'Automatically determines the optimal P-value based on layer dimensions. Uses p=2.0 (Spherical) for 4D (Conv) tensors for stability and rotational invariance, and p=1.0 (Sign) for 2D (Linear) tensors for sparsity. Overrides the manual P-value. Recommend for unet models.', 'type': 'bool'}, + 'compile': {'title': 'Compiled Optimizer', 'tooltip': 'Enables PyTorch compilation for the optimizer internal step logic. 
This is intended to improve performance by allowing PyTorch to fuse operations and optimize the computational graph.', 'type': 'bool'}, } # @formatter:on diff --git a/modules/ui/TopBar.py b/modules/ui/TopBar.py index c53ea4160..91fd6dfa1 100644 --- a/modules/ui/TopBar.py +++ b/modules/ui/TopBar.py @@ -92,8 +92,9 @@ def __init__( ("Stable Cascade", ModelType.STABLE_CASCADE_1), ("PixArt Alpha", ModelType.PIXART_ALPHA), ("PixArt Sigma", ModelType.PIXART_SIGMA), - ("Flux Dev", ModelType.FLUX_DEV_1), + ("Flux Dev.1", ModelType.FLUX_DEV_1), ("Flux Fill Dev", ModelType.FLUX_FILL_DEV_1), + ("Flux 2 [Dev, Klein]", ModelType.FLUX_2), ("Sana", ModelType.SANA), ("Hunyuan Video", ModelType.HUNYUAN_VIDEO), ("HiDream Full", ModelType.HI_DREAM_FULL), diff --git a/modules/ui/TrainingTab.py b/modules/ui/TrainingTab.py index 931ba8039..2c564e8d3 100644 --- a/modules/ui/TrainingTab.py +++ b/modules/ui/TrainingTab.py @@ -69,8 +69,10 @@ def refresh_ui(self): self.__setup_wuerstchen_ui(column_0, column_1, column_2) elif self.train_config.model_type.is_pixart(): self.__setup_pixart_alpha_ui(column_0, column_1, column_2) - elif self.train_config.model_type.is_flux(): + elif self.train_config.model_type.is_flux_1(): self.__setup_flux_ui(column_0, column_1, column_2) + elif self.train_config.model_type.is_flux_2(): + self.__setup_flux_2_ui(column_0, column_1, column_2) elif self.train_config.model_type.is_chroma(): self.__setup_chroma_ui(column_0, column_1, column_2) elif self.train_config.model_type.is_qwen(): @@ -167,6 +169,18 @@ def __setup_flux_ui(self, column_0, column_1, column_2): self.__create_loss_frame(column_2, 2) self.__create_layer_frame(column_2, 3) + def __setup_flux_2_ui(self, column_0, column_1, column_2): + self.__create_base_frame(column_0, 0) + self.__create_text_encoder_frame(column_0, 1, supports_clip_skip=False, supports_training=False, supports_sequence_length=True) + + self.__create_base2_frame(column_1, 0) + self.__create_transformer_frame(column_1, 1, supports_guidance_scale=True, supports_force_attention_mask=False) + self.__create_noise_frame(column_1, 2, supports_dynamic_timestep_shifting=True) + + self.__create_masked_frame(column_2, 1) + self.__create_loss_frame(column_2, 2) + self.__create_layer_frame(column_2, 3) + def __setup_chroma_ui(self, column_0, column_1, column_2): self.__create_base_frame(column_0, 0) self.__create_text_encoder_frame(column_0, 1) @@ -400,12 +414,11 @@ def __create_base2_frame(self, master, row, video_training_enabled: bool = False tooltip="Enables circular padding for all conv layers to better train seamless images") components.switch(frame, row, 1, self.ui_state, "force_circular_padding") - def __create_text_encoder_frame(self, master, row, supports_clip_skip=True, supports_training=True): + def __create_text_encoder_frame(self, master, row, supports_clip_skip=True, supports_training=True, supports_sequence_length=False): frame = ctk.CTkFrame(master=master, corner_radius=5) frame.grid(row=row, column=0, padx=5, pady=5, sticky="nsew") frame.grid_columnconfigure(0, weight=1) - # train text encoder if supports_training: components.label(frame, 0, 0, "Train Text Encoder", tooltip="Enables training the text encoder model") @@ -434,6 +447,13 @@ def __create_text_encoder_frame(self, master, row, supports_clip_skip=True, supp tooltip="The number of additional clip layers to skip. 
0 = the model default") components.entry(frame, 4, 1, self.ui_state, "text_encoder_layer_skip") + if supports_sequence_length: + # text encoder sequence length + components.label(frame, row, 0, "Text Encoder Sequence Length", + tooltip="Number of tokens for captions") + components.entry(frame, row, 1, self.ui_state, "text_encoder_sequence_length") + row += 1 + def __create_text_encoder_n_frame( self, master, diff --git a/modules/util/checkpointing_util.py b/modules/util/checkpointing_util.py index 133f97cf0..74df81769 100644 --- a/modules/util/checkpointing_util.py +++ b/modules/util/checkpointing_util.py @@ -25,6 +25,7 @@ from transformers.models.clip.modeling_clip import CLIPEncoderLayer from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer from transformers.models.llama.modeling_llama import LlamaDecoderLayer +from transformers.models.mistral.modeling_mistral import MistralDecoderLayer from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLDecoderLayer from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer from transformers.models.t5.modeling_t5 import T5Block @@ -111,7 +112,6 @@ def __init__(self, orig_module: nn.Module, orig_forward, train_device: torch.dev self.layer_index = layer_index def __checkpointing_forward(self, dummy: torch.Tensor, call_id: int, *args): - if self.layer_index == 0 and not torch.is_grad_enabled(): self.conductor.start_forward(True) @@ -131,7 +131,6 @@ def __checkpointing_forward(self, dummy: torch.Tensor, call_id: int, *args): def forward(self, *args, **kwargs): call_id = _generate_call_index() args = _kwargs_to_args(self.orig_forward if self.checkpoint is None else self.checkpoint.forward, args, kwargs) - if torch.is_grad_enabled(): return torch.utils.checkpoint.checkpoint( self.__checkpointing_forward, @@ -306,7 +305,17 @@ def enable_checkpointing_for_llama_encoder_layers( (LlamaDecoderLayer, []), ]) -def enable_checkpointing_for_qwen_encoder_layers( +def enable_checkpointing_for_mistral_encoder_layers( + model: nn.Module, + config: TrainConfig, +) -> LayerOffloadConductor: + return enable_checkpointing(model, config, False, [ + (MistralDecoderLayer, []), + ]) + + + +def enable_checkpointing_for_qwen25vl_encoder_layers( model: nn.Module, config: TrainConfig, ) -> LayerOffloadConductor: @@ -314,12 +323,12 @@ def enable_checkpointing_for_qwen_encoder_layers( (Qwen2_5_VLDecoderLayer, []), # TODO No activation offloading for other encoders, see above. But clip skip is not implemented for QwenVL. Then do activation offloading? ]) -def enable_checkpointing_for_z_image_encoder_layers( +def enable_checkpointing_for_qwen3_encoder_layers( model: nn.Module, config: TrainConfig, ) -> LayerOffloadConductor: return enable_checkpointing(model, config, False, [ - (Qwen3DecoderLayer, []), # TODO No activation offloading for other encoders, see above. But clip skip is not implemented for QwenVL. Then do activation offloading? 
+ (Qwen3DecoderLayer, []), # No activation offloading, because hidden states are taken from the middle of the network by Flux2 ]) def enable_checkpointing_for_stable_diffusion_3_transformer( @@ -339,6 +348,15 @@ def enable_checkpointing_for_flux_transformer( (model.single_transformer_blocks, ["hidden_states" ]), ]) +def enable_checkpointing_for_flux2_transformer( + model: nn.Module, + config: TrainConfig, +) -> LayerOffloadConductor: + return enable_checkpointing(model, config, config.compile, [ + (model.transformer_blocks, ["hidden_states", "encoder_hidden_states"]), + (model.single_transformer_blocks, ["hidden_states" ]), + ]) + def enable_checkpointing_for_chroma_transformer( model: nn.Module, diff --git a/modules/util/config/SampleConfig.py b/modules/util/config/SampleConfig.py index 4b72345f8..1b2aba652 100644 --- a/modules/util/config/SampleConfig.py +++ b/modules/util/config/SampleConfig.py @@ -19,6 +19,7 @@ class SampleConfig(BaseConfig): noise_scheduler: NoiseScheduler text_encoder_1_layer_skip: int + text_encoder_1_sequence_length: int | None text_encoder_2_layer_skip: int text_encoder_2_sequence_length: int | None text_encoder_3_layer_skip: int @@ -35,6 +36,7 @@ def __init__(self, data: list[(str, Any, type, bool)]): def from_train_config(self, train_config): self.text_encoder_1_layer_skip = train_config.text_encoder_layer_skip + self.text_encoder_1_sequence_length = train_config.text_encoder_sequence_length self.text_encoder_2_layer_skip = train_config.text_encoder_2_layer_skip self.text_encoder_2_sequence_length = train_config.text_encoder_2_sequence_length self.text_encoder_3_layer_skip = train_config.text_encoder_3_layer_skip @@ -60,6 +62,7 @@ def default_values(): data.append(("noise_scheduler", NoiseScheduler.DDIM, NoiseScheduler, False)) data.append(("text_encoder_1_layer_skip", 0, int, False)) + data.append(("text_encoder_1_sequence_length", None, int, True)) data.append(("text_encoder_2_layer_skip", 0, int, False)) data.append(("text_encoder_2_sequence_length", None, int, True)) data.append(("text_encoder_3_layer_skip", 0, int, False)) diff --git a/modules/util/config/TrainConfig.py b/modules/util/config/TrainConfig.py index ddaee4b89..756b4c5bb 100644 --- a/modules/util/config/TrainConfig.py +++ b/modules/util/config/TrainConfig.py @@ -143,6 +143,7 @@ class TrainOptimizerConfig(BaseConfig): approx_mars: False kappa_p: float auto_kappa_p: False + compile: False def __init__(self, data: list[(str, Any, type, bool)]): super().__init__(data) @@ -261,6 +262,7 @@ def default_values(): data.append(("approx_mars", False, bool, False)) data.append(("kappa_p", None, float, True)) data.append(("auto_kappa_p", False, bool, False)) + data.append(("compile", False, bool, False)) return TrainOptimizerConfig(data) @@ -273,7 +275,7 @@ class TrainModelPartConfig(BaseConfig): stop_training_after_unit: TimeUnit learning_rate: float weight_dtype: DataType - dropout_probability: float + dropout_probability: float #this is text encoder caption dropout! train_embedding: bool attention_mask: bool guidance_scale: float @@ -430,7 +432,7 @@ class TrainConfig(BaseConfig): vb_loss_strength: float loss_weight_fn: LossWeight loss_weight_strength: float - dropout_probability: float + dropout_probability: float #this is LoRA dropout! 
loss_scaler: LossScaler learning_rate_scaler: LearningRateScaler clip_grad_norm: float @@ -1069,6 +1071,7 @@ def default_values() -> 'TrainConfig': text_encoder.learning_rate = None data.append(("text_encoder", text_encoder, TrainModelPartConfig, False)) data.append(("text_encoder_layer_skip", 0, int, False)) + data.append(("text_encoder_sequence_length", 512, int, True)) # text encoder 2 text_encoder_2 = TrainModelPartConfig.default_values() diff --git a/modules/util/create.py b/modules/util/create.py index 99cb74e1a..74edf85d0 100644 --- a/modules/util/create.py +++ b/modules/util/create.py @@ -685,6 +685,7 @@ def create_optimizer( alpha=optimizer_config.alpha if optimizer_config.alpha is not None else 5, kourkoutas_beta=optimizer_config.kourkoutas_beta if optimizer_config.kourkoutas_beta is not None else False, k_warmup_steps=optimizer_config.k_warmup_steps if optimizer_config.k_warmup_steps is not None else 0, + compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False, ) # ADOPT_ADV Optimizer @@ -711,6 +712,7 @@ def create_optimizer( alpha_grad=optimizer_config.alpha_grad if optimizer_config.alpha_grad is not None else 100, kourkoutas_beta=optimizer_config.kourkoutas_beta if optimizer_config.kourkoutas_beta is not None else False, k_warmup_steps=optimizer_config.k_warmup_steps if optimizer_config.k_warmup_steps is not None else 0, + compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False, ) # PRODIGY_ADV Optimizer @@ -744,6 +746,7 @@ def create_optimizer( alpha_grad=optimizer_config.alpha_grad if optimizer_config.alpha_grad is not None else 100, kourkoutas_beta=optimizer_config.kourkoutas_beta if optimizer_config.kourkoutas_beta is not None else False, k_warmup_steps=optimizer_config.k_warmup_steps if optimizer_config.k_warmup_steps is not None else 0, + compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False, ) # SIMPLIFIED_AdEMAMix Optimizer @@ -766,6 +769,24 @@ def create_optimizer( orthogonal_gradient=optimizer_config.orthogonal_gradient if optimizer_config.orthogonal_gradient is not None else False, kourkoutas_beta=optimizer_config.kourkoutas_beta if optimizer_config.kourkoutas_beta is not None else False, k_warmup_steps=optimizer_config.k_warmup_steps if optimizer_config.k_warmup_steps is not None else 0, + compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False, + ) + + # SignSGD_ADV Optimizer + case Optimizer.SIGNSGD_ADV: + from adv_optm import SignSGD_adv + optimizer = SignSGD_adv( + params=parameters, + lr=config.learning_rate, + momentum=optimizer_config.momentum if optimizer_config.momentum is not None else 0, + weight_decay=optimizer_config.weight_decay if optimizer_config.weight_decay is not None else 0.0, + nnmf_factor=optimizer_config.nnmf_factor if optimizer_config.nnmf_factor is not None else False, + cautious_wd=optimizer_config.cautious_wd if optimizer_config.cautious_wd is not None else False, + stochastic_rounding=optimizer_config.stochastic_rounding, + orthogonal_gradient=optimizer_config.orthogonal_gradient if optimizer_config.orthogonal_gradient is not None else False, + compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False, + Simplified_AdEMAMix=optimizer_config.Simplified_AdEMAMix if optimizer_config.Simplified_AdEMAMix is not None else False, + alpha_grad=optimizer_config.alpha_grad if optimizer_config.alpha_grad is not None else 100, ) # LION_ADV Optimizer @@ 
-785,6 +806,7 @@ def create_optimizer( orthogonal_gradient=optimizer_config.orthogonal_gradient if optimizer_config.orthogonal_gradient is not None else False, kappa_p=optimizer_config.kappa_p if optimizer_config.kappa_p is not None else 1.0, auto_kappa_p=optimizer_config.auto_kappa_p if optimizer_config.auto_kappa_p is not None else False, + compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False, ) # LION_PRODIGY_ADV Optimizer @@ -811,6 +833,7 @@ def create_optimizer( orthogonal_gradient=optimizer_config.orthogonal_gradient if optimizer_config.orthogonal_gradient is not None else False, kappa_p=optimizer_config.kappa_p if optimizer_config.kappa_p is not None else 1.0, auto_kappa_p=optimizer_config.auto_kappa_p if optimizer_config.auto_kappa_p is not None else False, + compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False, ) # MUON_ADV Optimizer @@ -859,6 +882,7 @@ def create_optimizer( accelerated_ns=optimizer_config.accelerated_ns if optimizer_config.accelerated_ns is not None else False, orthogonal_gradient=optimizer_config.orthogonal_gradient if optimizer_config.orthogonal_gradient is not None else False, approx_mars=optimizer_config.approx_mars if optimizer_config.approx_mars is not None else False, + compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False, **adam_kwargs ) @@ -912,6 +936,7 @@ def create_optimizer( accelerated_ns=optimizer_config.accelerated_ns if optimizer_config.accelerated_ns is not None else False, orthogonal_gradient=optimizer_config.orthogonal_gradient if optimizer_config.orthogonal_gradient is not None else False, approx_mars=optimizer_config.approx_mars if optimizer_config.approx_mars is not None else False, + compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False, **adam_kwargs ) diff --git a/modules/util/enum/ModelFormat.py b/modules/util/enum/ModelFormat.py index 597ad4442..70193a61b 100644 --- a/modules/util/enum/ModelFormat.py +++ b/modules/util/enum/ModelFormat.py @@ -6,6 +6,7 @@ class ModelFormat(Enum): CKPT = 'CKPT' SAFETENSORS = 'SAFETENSORS' LEGACY_SAFETENSORS = 'LEGACY_SAFETENSORS' + COMFY_LORA = 'COMFY_LORA' INTERNAL = 'INTERNAL' # an internal format that stores all information to resume training @@ -23,6 +24,8 @@ def file_extension(self) -> str: return '.safetensors' case ModelFormat.LEGACY_SAFETENSORS: return '.safetensors' + case ModelFormat.COMFY_LORA: + return '.safetensors' case _: return '' diff --git a/modules/util/enum/ModelType.py b/modules/util/enum/ModelType.py index bb8740e97..f3da05941 100644 --- a/modules/util/enum/ModelType.py +++ b/modules/util/enum/ModelType.py @@ -25,6 +25,7 @@ class ModelType(Enum): FLUX_DEV_1 = 'FLUX_DEV_1' FLUX_FILL_DEV_1 = 'FLUX_FILL_DEV_1' + FLUX_2 = 'FLUX_2' SANA = 'SANA' @@ -77,9 +78,17 @@ def is_pixart_sigma(self): return self == ModelType.PIXART_SIGMA def is_flux(self): + return self == ModelType.FLUX_DEV_1 \ + or self == ModelType.FLUX_FILL_DEV_1 \ + or self == ModelType.FLUX_2 + + def is_flux_1(self): return self == ModelType.FLUX_DEV_1 \ or self == ModelType.FLUX_FILL_DEV_1 + def is_flux_2(self): + return self == ModelType.FLUX_2 + def is_chroma(self): return self == ModelType.CHROMA_1 @@ -116,7 +125,7 @@ def has_depth_input(self): def has_multiple_text_encoders(self): return self.is_stable_diffusion_3() \ or self.is_stable_diffusion_xl() \ - or self.is_flux() \ + or self.is_flux_1() \ or self.is_hunyuan_video() \ or self.is_hi_dream() \ diff 
--git a/modules/util/enum/Optimizer.py b/modules/util/enum/Optimizer.py index c3edfb837..58edf419c 100644 --- a/modules/util/enum/Optimizer.py +++ b/modules/util/enum/Optimizer.py @@ -42,6 +42,7 @@ class Optimizer(Enum): # 32 bit is torch and not bnb SGD = 'SGD' SGD_8BIT = 'SGD_8BIT' + SIGNSGD_ADV = 'SIGNSGD_ADV' # Schedule-free optimizers SCHEDULE_FREE_ADAMW = 'SCHEDULE_FREE_ADAMW' @@ -116,6 +117,7 @@ def supports_fused_back_pass(self): Optimizer.LION_PRODIGY_ADV, Optimizer.MUON_ADV, Optimizer.ADAMUON_ADV, + Optimizer.SIGNSGD_ADV, ] # Small helper for adjusting learning rates to adaptive optimizers. diff --git a/modules/util/optimizer_util.py b/modules/util/optimizer_util.py index add2bec24..f879bb5d9 100644 --- a/modules/util/optimizer_util.py +++ b/modules/util/optimizer_util.py @@ -457,6 +457,7 @@ def init_model_parameters( "use_bias_correction": True, "nnmf_factor": False, "stochastic_rounding": True, + "compile": False, "fused_back_pass": False, "use_atan2": False, "cautious_mask": False, @@ -476,6 +477,7 @@ def init_model_parameters( "weight_decay": 0.0, "nnmf_factor": False, "stochastic_rounding": True, + "compile": False, "fused_back_pass": False, "use_atan2": False, "cautious_mask": False, @@ -498,6 +500,7 @@ def init_model_parameters( "weight_decay": 0.0, "nnmf_factor": False, "stochastic_rounding": True, + "compile": False, "fused_back_pass": False, "d0": 1e-6, "d_coef": 1.0, @@ -529,11 +532,24 @@ def init_model_parameters( "use_bias_correction": True, "nnmf_factor": False, "stochastic_rounding": True, + "compile": False, "fused_back_pass": False, "orthogonal_gradient": False, "kourkoutas_beta": False, "k_warmup_steps": None, }, + Optimizer.SIGNSGD_ADV: { + "momentum": 0.99, + "cautious_wd": False, + "weight_decay": 0.0, + "nnmf_factor": False, + "stochastic_rounding": True, + "compiled_optimizer": False, + "fused_back_pass": False, + "orthogonal_gradient": False, + "Simplified_AdEMAMix": False, + "alpha_grad": 100.0, + }, Optimizer.LION_ADV: { "beta1": 0.9, "beta2": 0.99, @@ -542,6 +558,7 @@ def init_model_parameters( "clip_threshold": None, "nnmf_factor": False, "stochastic_rounding": True, + "compile": False, "fused_back_pass": False, "cautious_mask": False, "orthogonal_gradient": False, @@ -557,6 +574,7 @@ def init_model_parameters( "clip_threshold": None, "nnmf_factor": False, "stochastic_rounding": True, + "compile": False, "fused_back_pass": False, "d0": 1e-6, "d_coef": 1.0, @@ -580,6 +598,7 @@ def init_model_parameters( "rms_rescaling": True, "nnmf_factor": False, "stochastic_rounding": True, + "compile": False, "fused_back_pass": False, "MuonWithAuxAdam": True, "muon_hidden_layers": None, @@ -610,6 +629,7 @@ def init_model_parameters( "rms_rescaling": True, "nnmf_factor": False, "stochastic_rounding": True, + "compile": False, "fused_back_pass": False, "MuonWithAuxAdam": True, "muon_hidden_layers": None, diff --git a/requirements-global.txt b/requirements-global.txt index 561d7195d..8bec198c0 100644 --- a/requirements-global.txt +++ b/requirements-global.txt @@ -8,7 +8,7 @@ PyYAML==6.0.2 huggingface-hub==0.34.4 scipy==1.15.3 matplotlib==3.10.3 -av==14.4.0 +av==16.1.0 yt-dlp #no pinned version, frequently updated for compatibility with sites scenedetect==0.6.6 parse==1.20.2 @@ -21,7 +21,9 @@ pytorch-lightning==2.5.1.post0 # diffusion models #Note: check whether Qwen bugs in diffusers have been fixed before upgrading diffusers (see BaseQwenSetup): --e git+https://github.com/huggingface/diffusers.git@256e010#egg=diffusers +#-e 
git+https://github.com/huggingface/diffusers.git@256e010#egg=diffusers
+#FIXME: the Qwen bugs have been fixed in this diffusers release. Remove the corresponding workarounds in OneTrainer.
+-e git+https://github.com/dxqb/diffusers.git@flux2_tuples#egg=diffusers
 gguf==0.17.1
 transformers==4.56.2
 sentencepiece==0.2.1 # transitive dependency of transformers for tokenizer loading
@@ -33,7 +35,7 @@ pooch==1.8.2
 open-clip-torch==2.32.0
 
 # data loader
--e git+https://github.com/Nerogar/mgds.git@385578f#egg=mgds
+-e git+https://github.com/dxqb/mgds.git@flux2_klein#egg=mgds
 
 # optimizers
 dadaptation==3.2 # dadaptation optimizers
@@ -42,7 +44,7 @@ prodigyopt==1.1.2 # prodigy optimizer
 schedulefree==1.4.1 # schedule-free optimizers
 pytorch_optimizer==3.6.0 # pytorch optimizers
 prodigy-plus-schedule-free==2.0.1 # Prodigy plus optimizer
-adv_optm==1.4.0 # advanced optimizers
+adv_optm==2.1.0 # advanced optimizers
 -e git+https://github.com/KellerJordan/Muon.git@f90a42b#egg=muon-optimizer
 
 # Profiling
@@ -57,5 +59,5 @@ fabric==3.2.2
 
 # debug
 psutil==7.0.0
-requests==2.32.3
+requests==2.32.5
 deepdiff==8.6.1 # output easy to read diff for troublshooting
diff --git a/resources/sd_model_spec/flux_2.0-lora.json b/resources/sd_model_spec/flux_2.0-lora.json
new file mode 100644
index 000000000..d2fb7ed78
--- /dev/null
+++ b/resources/sd_model_spec/flux_2.0-lora.json
@@ -0,0 +1,6 @@
+{
+    "modelspec.sai_model_spec": "1.0.0",
+    "modelspec.architecture": "Flux.2/lora",
+    "modelspec.implementation": "https://github.com/huggingface/diffusers",
+    "modelspec.title": "Flux 2.0 LoRA"
+}
diff --git a/resources/sd_model_spec/flux_2.0.json b/resources/sd_model_spec/flux_2.0.json
new file mode 100644
index 000000000..9648a7d1a
--- /dev/null
+++ b/resources/sd_model_spec/flux_2.0.json
@@ -0,0 +1,6 @@
+{
+    "modelspec.sai_model_spec": "1.0.0",
+    "modelspec.architecture": "Flux.2",
+    "modelspec.implementation": "https://github.com/huggingface/diffusers",
+    "modelspec.title": "Flux 2.0"
+}
diff --git a/training_presets/#flux2 LoRA 16GB.json b/training_presets/#flux2 LoRA 16GB.json
new file mode 100644
index 000000000..cb770f1f3
--- /dev/null
+++ b/training_presets/#flux2 LoRA 16GB.json
@@ -0,0 +1,31 @@
+{
+    "base_model_name": "black-forest-labs/FLUX.2-klein-base-9B",
+    "batch_size": 2,
+    "learning_rate": 0.0003,
+    "model_type": "FLUX_2",
+    "resolution": "512",
+    "compile": true,
+    "transformer": {
+        "train": true,
+        "weight_dtype": "INT_W8A8"
+    },
+    "text_encoder": {
+        "train": false,
+        "weight_dtype": "FLOAT_8"
+    },
+    "training_method": "LORA",
+    "vae": {
+        "weight_dtype": "FLOAT_32"
+    },
+    "train_dtype": "BFLOAT_16",
+    "output_dtype": "BFLOAT_16",
+    "layer_filter": "blocks",
+    "layer_filter_preset": "transformer_block",
+    "quantization": {
+        "layer_filter": "transformer_block",
+        "layer_filter_preset": "blocks"
+    },
+    "timestep_distribution": "LOGIT_NORMAL",
+    "dataloader_threads": 1,
+    "output_model_format": "COMFY_LORA"
+}
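
As a quick, illustrative cross-check (not part of the patch): the new training preset can be inspected with nothing beyond the Python standard library. This is a minimal sketch; the file path and every key it touches are taken verbatim from the preset hunk above, and the assertions merely restate values the preset already sets.

```python
import json
from pathlib import Path

# Preset added by this patch, relative to the OneTrainer checkout.
preset_path = Path("training_presets/#flux2 LoRA 16GB.json")

with preset_path.open(encoding="utf-8") as f:
    preset = json.load(f)

# Restate the key choices the preset makes for a Flux 2 Klein LoRA run.
assert preset["model_type"] == "FLUX_2"
assert preset["training_method"] == "LORA"
assert preset["output_model_format"] == "COMFY_LORA"  # the newly added ModelFormat.COMFY_LORA
assert preset["dataloader_threads"] == 1              # the preset pins a single data loader thread

print(
    f"{preset['base_model_name']}: "
    f"transformer in {preset['transformer']['weight_dtype']}, "
    f"batch size {preset['batch_size']}, learning rate {preset['learning_rate']}"
)
```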