39 changes: 36 additions & 3 deletions modules/dataLoader/BaseDataLoader.py
@@ -1,6 +1,11 @@
import copy
from abc import ABCMeta, abstractmethod

from modules.dataLoader.mixin.DataLoaderMgdsMixin import DataLoaderMgdsMixin
from modules.model.BaseModel import BaseModel
from modules.modelSetup.BaseModelSetup import BaseModelSetup
from modules.util.config.TrainConfig import TrainConfig
from modules.util.TrainProgress import TrainProgress

from mgds.MGDS import MGDS, TrainDataLoader

@@ -16,16 +21,44 @@ def __init__(
self,
train_device: torch.device,
temp_device: torch.device,
config: TrainConfig,
model: BaseModel,
model_setup: BaseModelSetup,
train_progress: TrainProgress,
is_validation: bool = False,
):
super().__init__()

self.train_device = train_device
self.temp_device = temp_device

@abstractmethod
if is_validation:
config = copy.copy(config)
config.batch_size = 1
config.multi_gpu = False

self.__ds = self._create_dataset(
config=config,
model=model,
model_setup=model_setup,
train_progress=train_progress,
is_validation=is_validation,
)
self.__dl = TrainDataLoader(self.__ds, config.batch_size)

def get_data_set(self) -> MGDS:
pass
return self.__ds

@abstractmethod
def get_data_loader(self) -> TrainDataLoader:
return self.__dl

@abstractmethod
def _create_dataset(
self,
config: TrainConfig,
model: BaseModel,
model_setup: BaseModelSetup,
train_progress: TrainProgress,
is_validation: bool,
):
pass
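
Editor's note: this hunk turns `BaseDataLoader` into a template method. The validation overrides (`batch_size = 1`, `multi_gpu = False`) and the MGDS/`TrainDataLoader` construction now live in the base `__init__`, and concrete loaders only implement the abstract `_create_dataset` hook; `get_data_set` and `get_data_loader` stop being abstract and simply return the cached objects. A minimal sketch of a subclass under the new contract — the class name is hypothetical, the signature mirrors the abstract method above:

```python
# Minimal sketch of a concrete loader under the refactored base class.
# "MyModelDataLoader" is hypothetical; the hook signature mirrors the
# abstract _create_dataset declared above.
from modules.dataLoader.BaseDataLoader import BaseDataLoader
from modules.model.BaseModel import BaseModel
from modules.modelSetup.BaseModelSetup import BaseModelSetup
from modules.util.config.TrainConfig import TrainConfig
from modules.util.TrainProgress import TrainProgress


class MyModelDataLoader(BaseDataLoader):
    def _create_dataset(
            self,
            config: TrainConfig,
            model: BaseModel,
            model_setup: BaseModelSetup,
            train_progress: TrainProgress,
            is_validation: bool,
    ):
        # Return an MGDS pipeline here. BaseDataLoader.__init__ applies the
        # validation overrides before this hook runs, then wraps the result
        # in TrainDataLoader(ds, config.batch_size).
        raise NotImplementedError
```

Callers keep using `get_data_set()` / `get_data_loader()` unchanged.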
172 changes: 30 additions & 142 deletions modules/dataLoader/ChromaBaseDataLoader.py
@@ -1,17 +1,16 @@
import copy
import os

from modules.dataLoader.BaseDataLoader import BaseDataLoader
from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin
from modules.model.BaseModel import BaseModel
from modules.model.ChromaModel import ChromaModel
from modules.modelSetup.BaseChromaSetup import BaseChromaSetup
from modules.modelSetup.BaseModelSetup import BaseModelSetup
from modules.util.config.TrainConfig import TrainConfig
from modules.util.torch_util import torch_gc
from modules.util.TrainProgress import TrainProgress

from mgds.MGDS import MGDS, TrainDataLoader
from mgds.pipelineModules.DecodeTokens import DecodeTokens
from mgds.pipelineModules.DecodeVAE import DecodeVAE
from mgds.pipelineModules.DiskCache import DiskCache
from mgds.pipelineModules.EncodeT5Text import EncodeT5Text
from mgds.pipelineModules.EncodeVAE import EncodeVAE
from mgds.pipelineModules.MapData import MapData
@@ -21,72 +20,34 @@
from mgds.pipelineModules.SaveText import SaveText
from mgds.pipelineModules.ScaleImage import ScaleImage
from mgds.pipelineModules.Tokenize import Tokenize
from mgds.pipelineModules.VariationSorting import VariationSorting

import torch


#TODO share more code with Flux
class ChromaBaseDataLoader(
BaseDataLoader,
DataLoaderText2ImageMixin,
):
def __init__(
self,
train_device: torch.device,
temp_device: torch.device,
config: TrainConfig,
model: ChromaModel,
train_progress: TrainProgress,
is_validation: bool = False,
):
super().__init__(
train_device,
temp_device,
)

if is_validation:
config = copy.copy(config)
config.batch_size = 1
config.multi_gpu = False

self.__ds = self.create_dataset(
config=config,
model=model,
train_progress=train_progress,
is_validation=is_validation,
)
self.__dl = TrainDataLoader(self.__ds, config.batch_size)

def get_data_set(self) -> MGDS:
return self.__ds

def get_data_loader(self) -> TrainDataLoader:
return self.__dl

def _preparation_modules(self, config: TrainConfig, model: ChromaModel):
rescale_image = RescaleImageChannels(image_in_name='image', image_out_name='image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1)
encode_image = EncodeVAE(in_name='image', out_name='latent_image_distribution', vae=model.vae, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype())
image_sample = SampleVAEDistribution(in_name='latent_image_distribution', out_name='latent_image', mode='mean')
downscale_mask = ScaleImage(in_name='mask', out_name='latent_mask', factor=0.125)
add_embeddings_to_prompt = MapData(in_name='prompt', out_name='prompt', map_fn=model.add_text_encoder_embeddings_to_prompt)
tokenize_prompt = Tokenize(in_name='prompt', tokens_out_name='tokens', mask_out_name='tokens_mask', tokenizer=model.tokenizer, max_token_length=model.tokenizer.model_max_length, expand_mask=1)
encode_prompt = EncodeT5Text(tokens_in_name='tokens', tokens_attention_mask_in_name="tokens_mask", hidden_state_out_name='text_encoder_hidden_state', pooled_out_name=None, add_layer_norm=True, text_encoder=model.text_encoder, hidden_state_output_index=-(1 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context, model.text_encoder_autocast_context], dtype=model.text_encoder_train_dtype.torch_dtype())
encode_prompt = EncodeT5Text(tokens_in_name='tokens', tokens_attention_mask_in_name="tokens_mask", hidden_state_out_name='text_encoder_hidden_state', pooled_out_name=None, add_layer_norm=True,
text_encoder=model.text_encoder, hidden_state_output_index=-(1 + config.text_encoder_layer_skip), autocast_contexts=[model.autocast_context, model.text_encoder_autocast_context],
dtype=model.text_encoder_train_dtype.torch_dtype())

modules = [rescale_image, encode_image, image_sample]

modules.append(add_embeddings_to_prompt)
modules.append(tokenize_prompt)

if config.masked_training or config.model_type.has_mask_input():
modules.append(downscale_mask)

modules += [add_embeddings_to_prompt, tokenize_prompt]
if not config.train_text_encoder_or_embedding():
modules.append(encode_prompt)

return modules

def _cache_modules(self, config: TrainConfig, model: ChromaModel):
def _cache_modules(self, config: TrainConfig, model: ChromaModel, model_setup: BaseChromaSetup):
image_split_names = ['latent_image', 'original_resolution', 'crop_offset']

if config.masked_training or config.model_type.has_mask_input():
@@ -102,53 +63,19 @@ def _cache_modules(self, config: TrainConfig, model: ChromaModel):
]

if not config.train_text_encoder_or_embedding():
text_split_names.append('tokens')
text_split_names.append('tokens_mask')
text_split_names.append('text_encoder_hidden_state')

image_cache_dir = os.path.join(config.cache_dir, "image")
text_cache_dir = os.path.join(config.cache_dir, "text")

#TODO share more code with other models
def before_cache_image_fun():
model.to(self.temp_device)
model.vae_to(self.train_device)
model.eval()
torch_gc()

def before_cache_text_fun():
model.to(self.temp_device)

if not config.train_text_encoder_or_embedding():
model.text_encoder_to(self.train_device)

model.eval()
torch_gc()

image_disk_cache = DiskCache(cache_dir=image_cache_dir, split_names=image_split_names, aggregate_names=image_aggregate_names, variations_in_name='concept.image_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.image'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_image_fun)

text_disk_cache = DiskCache(cache_dir=text_cache_dir, split_names=text_split_names, aggregate_names=[], variations_in_name='concept.text_variations', balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled', before_cache_fun=before_cache_text_fun)

modules = []

if config.latent_caching:
modules.append(image_disk_cache)

if config.latent_caching:
sort_names = [x for x in sort_names if x not in image_aggregate_names]
sort_names = [x for x in sort_names if x not in image_split_names]

if not config.train_text_encoder_or_embedding():
modules.append(text_disk_cache)
sort_names = [x for x in sort_names if x not in text_split_names]

if len(sort_names) > 0:
variation_sorting = VariationSorting(names=sort_names, balancing_in_name='concept.balancing', balancing_strategy_in_name='concept.balancing_strategy', variations_group_in_name=['concept.path', 'concept.seed', 'concept.include_subdirectories', 'concept.text'], group_enabled_in_name='concept.enabled')
modules.append(variation_sorting)

return modules
text_split_names += ['tokens', 'tokens_mask', 'text_encoder_hidden_state']

return self._cache_modules_from_names(
model, model_setup,
image_split_names=image_split_names,
image_aggregate_names=image_aggregate_names,
text_split_names=text_split_names,
sort_names=sort_names,
config=config,
text_caching=not config.train_text_encoder_or_embedding(),
)
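
Editor's note: the deleted block above — both `DiskCache` modules, the `before_cache_*` device shuffles, the sort-name filtering, and `VariationSorting` — now sits behind `_cache_modules_from_names`, presumably added to a shared mixin elsewhere in this PR; its real body isn't shown in this hunk. The following is a hedged reconstruction from the deleted Chroma code, with the Chroma-specific device callback inlined for illustration (the actual helper likely routes device placement through `model_setup`):

```python
# Hedged sketch of the shared helper assumed by the call above; reconstructed
# from the caching logic deleted in this hunk, not from the mixin's actual
# implementation.
import os

from modules.util.torch_util import torch_gc

from mgds.pipelineModules.DiskCache import DiskCache
from mgds.pipelineModules.VariationSorting import VariationSorting


def _cache_modules_from_names(
        self, model, model_setup,  # model_setup matches the call site; unused in this sketch
        image_split_names, image_aggregate_names,
        text_split_names, sort_names,
        config, text_caching,
):
    def before_cache_image_fun():
        # park everything on the temp device, keep only the VAE hot
        model.to(self.temp_device)
        model.vae_to(self.train_device)
        model.eval()
        torch_gc()

    modules = []

    if config.latent_caching:
        modules.append(DiskCache(
            cache_dir=os.path.join(config.cache_dir, "image"),
            split_names=image_split_names,
            aggregate_names=image_aggregate_names,
            variations_in_name='concept.image_variations',
            balancing_in_name='concept.balancing',
            balancing_strategy_in_name='concept.balancing_strategy',
            variations_group_in_name=['concept.path', 'concept.seed',
                                      'concept.include_subdirectories', 'concept.image'],
            group_enabled_in_name='concept.enabled',
            before_cache_fun=before_cache_image_fun,
        ))
        # names served from the cache no longer need variation sorting
        sort_names = [x for x in sort_names
                      if x not in image_aggregate_names and x not in image_split_names]

    if text_caching:
        # an analogous DiskCache over text_split_names goes here, keyed on
        # 'concept.text_variations' (see the deleted text_disk_cache above)
        sort_names = [x for x in sort_names if x not in text_split_names]

    if len(sort_names) > 0:
        modules.append(VariationSorting(
            names=sort_names,
            balancing_in_name='concept.balancing',
            balancing_strategy_in_name='concept.balancing_strategy',
            variations_group_in_name=['concept.path', 'concept.seed',
                                      'concept.include_subdirectories', 'concept.text'],
            group_enabled_in_name='concept.enabled',
        ))

    return modules
```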

def _output_modules(self, config: TrainConfig, model: ChromaModel):
def _output_modules(self, config: TrainConfig, model: ChromaModel, model_setup: BaseChromaSetup):
output_names = [
'image_path', 'latent_image',
'prompt',
@@ -163,16 +90,10 @@ def _output_modules(self, config: TrainConfig, model: ChromaModel):
if not config.train_text_encoder_or_embedding():
output_names.append('text_encoder_hidden_state')

def before_cache_image_fun():
model.to(self.temp_device)
model.vae_to(self.train_device)
model.eval()
torch_gc()

return self._output_modules_from_out_names(
model, model_setup,
output_names=output_names,
config=config,
before_cache_image_fun=before_cache_image_fun,
use_conditioning_image=False,
vae=model.vae,
autocast_context=[model.autocast_context],
@@ -197,57 +118,24 @@ def before_save_fun():
# SaveImage(image_in_name='mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1),
# SaveImage(image_in_name='image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1),

modules = []

modules.append(decode_image)
modules.append(save_image)
modules = [decode_image, save_image]

if config.masked_training or config.model_type.has_mask_input():
modules.append(upscale_mask)
modules.append(save_mask)
modules += [upscale_mask, save_mask]

modules.append(decode_prompt)
modules.append(save_prompt)
modules += [decode_prompt, save_prompt]

return modules

def create_dataset(
def _create_dataset(
self,
config: TrainConfig,
model: ChromaModel,
model: BaseModel,
model_setup: BaseModelSetup,
train_progress: TrainProgress,
is_validation: bool = False,
):
enumerate_input = self._enumerate_input_modules(config)
load_input = self._load_input_modules(config, model.train_dtype)
mask_augmentation = self._mask_augmentation_modules(config)
aspect_bucketing_in = self._aspect_bucketing_in(config, 64)
crop_modules = self._crop_modules(config)
augmentation_modules = self._augmentation_modules(config)
inpainting_modules = self._inpainting_modules(config)
preparation_modules = self._preparation_modules(config, model)
cache_modules = self._cache_modules(config, model)
output_modules = self._output_modules(config, model)

debug_modules = self._debug_modules(config, model)

return self._create_mgds(
config,
[
enumerate_input,
load_input,
mask_augmentation,
aspect_bucketing_in,
crop_modules,
augmentation_modules,
inpainting_modules,
preparation_modules,
cache_modules,
output_modules,

debug_modules if config.debug_mode else None,
# inserted before output_modules, which contains a sorting operation
],
train_progress,
is_validation
return DataLoaderText2ImageMixin._create_dataset(self,
config, model, model_setup, train_progress, is_validation,
aspect_bucketing_quantization=64,
)
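
Editor's note: with the model-specific hooks extracted, `_create_dataset` reduces to an explicit call into `DataLoaderText2ImageMixin` (explicit so the mixin's implementation is picked regardless of MRO), passing only the aspect-bucketing quantization (64 here). A hedged sketch of what the mixin entry point presumably does, reconstructed from the pipeline assembly deleted above — its real signature and body are an assumption based on the call sites in this diff:

```python
# Hedged sketch of DataLoaderText2ImageMixin._create_dataset, reconstructed
# from the pipeline assembly this hunk deletes.
def _create_dataset(
        self,
        config, model, model_setup, train_progress, is_validation,
        aspect_bucketing_quantization: int,
):
    return self._create_mgds(
        config,
        [
            self._enumerate_input_modules(config),
            self._load_input_modules(config, model.train_dtype),
            self._mask_augmentation_modules(config),
            self._aspect_bucketing_in(config, aspect_bucketing_quantization),
            self._crop_modules(config),
            self._augmentation_modules(config),
            self._inpainting_modules(config),
            # model-specific hooks, overridden per architecture
            self._preparation_modules(config, model),
            self._cache_modules(config, model, model_setup),
            self._output_modules(config, model, model_setup),

            self._debug_modules(config, model) if config.debug_mode else None,
            # inserted before output_modules, which contains a sorting operation
        ],
        train_progress,
        is_validation,
    )
```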