34 changes: 5 additions & 29 deletions modules/modelSaver/ChromaEmbeddingModelSaver.py
@@ -1,32 +1,8 @@
 from modules.model.ChromaModel import ChromaModel
-from modules.modelSaver.BaseModelSaver import BaseModelSaver
 from modules.modelSaver.chroma.ChromaEmbeddingSaver import ChromaEmbeddingSaver
-from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin
-from modules.util.enum.ModelFormat import ModelFormat
-from modules.util.enum.ModelType import ModelType
+from modules.modelSaver.GenericEmbeddingModelSaver import make_embedding_model_saver
 
-import torch
-
-
-class ChromaEmbeddingModelSaver(
-    BaseModelSaver,
-    InternalModelSaverMixin,
-):
-    def __init__(self):
-        super().__init__()
-
-    def save(
-            self,
-            model: ChromaModel,
-            model_type: ModelType,
-            output_model_format: ModelFormat,
-            output_model_destination: str,
-            dtype: torch.dtype | None,
-    ):
-        embedding_model_saver = ChromaEmbeddingSaver()
-
-        embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype)
-        embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype)
-
-        if output_model_format == ModelFormat.INTERNAL:
-            self._save_internal_data(model, output_model_destination)
+ChromaEmbeddingModelSaver = make_embedding_model_saver(
+    model_class=ChromaModel,
+    embedding_saver_class=ChromaEmbeddingSaver,
+)
36 changes: 6 additions & 30 deletions modules/modelSaver/ChromaFineTuneModelSaver.py
@@ -1,34 +1,10 @@
 from modules.model.ChromaModel import ChromaModel
-from modules.modelSaver.BaseModelSaver import BaseModelSaver
 from modules.modelSaver.chroma.ChromaEmbeddingSaver import ChromaEmbeddingSaver
 from modules.modelSaver.chroma.ChromaModelSaver import ChromaModelSaver
-from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin
-from modules.util.enum.ModelFormat import ModelFormat
-from modules.util.enum.ModelType import ModelType
+from modules.modelSaver.GenericFineTuneModelSaver import make_fine_tune_model_saver
 
-import torch
-
-
-class ChromaFineTuneModelSaver(
-    BaseModelSaver,
-    InternalModelSaverMixin,
-):
-    def __init__(self):
-        super().__init__()
-
-    def save(
-            self,
-            model: ChromaModel,
-            model_type: ModelType,
-            output_model_format: ModelFormat,
-            output_model_destination: str,
-            dtype: torch.dtype | None,
-    ):
-        base_model_saver = ChromaModelSaver()
-        embedding_model_saver = ChromaEmbeddingSaver()
-
-        base_model_saver.save(model, output_model_format, output_model_destination, dtype)
-        embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype)
-
-        if output_model_format == ModelFormat.INTERNAL:
-            self._save_internal_data(model, output_model_destination)
+ChromaFineTuneModelSaver = make_fine_tune_model_saver(
+    model_class=ChromaModel,
+    model_saver_class=ChromaModelSaver,
+    embedding_saver_class=ChromaEmbeddingSaver,
+)
37 changes: 6 additions & 31 deletions modules/modelSaver/ChromaLoRAModelSaver.py
@@ -1,35 +1,10 @@
 from modules.model.ChromaModel import ChromaModel
-from modules.modelSaver.BaseModelSaver import BaseModelSaver
 from modules.modelSaver.chroma.ChromaEmbeddingSaver import ChromaEmbeddingSaver
 from modules.modelSaver.chroma.ChromaLoRASaver import ChromaLoRASaver
-from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin
-from modules.util.enum.ModelFormat import ModelFormat
-from modules.util.enum.ModelType import ModelType
+from modules.modelSaver.GenericLoRAModelSaver import make_lora_model_saver
 
-import torch
-
-
-class ChromaLoRAModelSaver(
-    BaseModelSaver,
-    InternalModelSaverMixin,
-):
-    def __init__(self):
-        super().__init__()
-
-    def save(
-            self,
-            model: ChromaModel,
-            model_type: ModelType,
-            output_model_format: ModelFormat,
-            output_model_destination: str,
-            dtype: torch.dtype | None,
-    ):
-        lora_model_saver = ChromaLoRASaver()
-        embedding_model_saver = ChromaEmbeddingSaver()
-
-        lora_model_saver.save(model, output_model_format, output_model_destination, dtype)
-        if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL:
-            embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype)
-
-        if output_model_format == ModelFormat.INTERNAL:
-            self._save_internal_data(model, output_model_destination)
+ChromaLoRAModelSaver = make_lora_model_saver(
+    model_class=ChromaModel,
+    lora_saver_class=ChromaLoRASaver,
+    embedding_saver_class=ChromaEmbeddingSaver,
+)
34 changes: 5 additions & 29 deletions modules/modelSaver/FluxEmbeddingModelSaver.py
@@ -1,32 +1,8 @@
 from modules.model.FluxModel import FluxModel
-from modules.modelSaver.BaseModelSaver import BaseModelSaver
 from modules.modelSaver.flux.FluxEmbeddingSaver import FluxEmbeddingSaver
-from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin
-from modules.util.enum.ModelFormat import ModelFormat
-from modules.util.enum.ModelType import ModelType
+from modules.modelSaver.GenericEmbeddingModelSaver import make_embedding_model_saver
 
-import torch
-
-
-class FluxEmbeddingModelSaver(
-    BaseModelSaver,
-    InternalModelSaverMixin,
-):
-    def __init__(self):
-        super().__init__()
-
-    def save(
-            self,
-            model: FluxModel,
-            model_type: ModelType,
-            output_model_format: ModelFormat,
-            output_model_destination: str,
-            dtype: torch.dtype | None,
-    ):
-        embedding_model_saver = FluxEmbeddingSaver()
-
-        embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype)
-        embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype)
-
-        if output_model_format == ModelFormat.INTERNAL:
-            self._save_internal_data(model, output_model_destination)
+FluxEmbeddingModelSaver = make_embedding_model_saver(
+    model_class=FluxModel,
+    embedding_saver_class=FluxEmbeddingSaver,
+)
36 changes: 6 additions & 30 deletions modules/modelSaver/FluxFineTuneModelSaver.py
@@ -1,34 +1,10 @@
 from modules.model.FluxModel import FluxModel
-from modules.modelSaver.BaseModelSaver import BaseModelSaver
 from modules.modelSaver.flux.FluxEmbeddingSaver import FluxEmbeddingSaver
 from modules.modelSaver.flux.FluxModelSaver import FluxModelSaver
-from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin
-from modules.util.enum.ModelFormat import ModelFormat
-from modules.util.enum.ModelType import ModelType
+from modules.modelSaver.GenericFineTuneModelSaver import make_fine_tune_model_saver
 
-import torch
-
-
-class FluxFineTuneModelSaver(
-    BaseModelSaver,
-    InternalModelSaverMixin,
-):
-    def __init__(self):
-        super().__init__()
-
-    def save(
-            self,
-            model: FluxModel,
-            model_type: ModelType,
-            output_model_format: ModelFormat,
-            output_model_destination: str,
-            dtype: torch.dtype | None,
-    ):
-        base_model_saver = FluxModelSaver()
-        embedding_model_saver = FluxEmbeddingSaver()
-
-        base_model_saver.save(model, output_model_format, output_model_destination, dtype)
-        embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype)
-
-        if output_model_format == ModelFormat.INTERNAL:
-            self._save_internal_data(model, output_model_destination)
+FluxFineTuneModelSaver = make_fine_tune_model_saver(
+    model_class=FluxModel,
+    model_saver_class=FluxModelSaver,
+    embedding_saver_class=FluxEmbeddingSaver,
+)
37 changes: 6 additions & 31 deletions modules/modelSaver/FluxLoRAModelSaver.py
@@ -1,35 +1,10 @@
 from modules.model.FluxModel import FluxModel
-from modules.modelSaver.BaseModelSaver import BaseModelSaver
 from modules.modelSaver.flux.FluxEmbeddingSaver import FluxEmbeddingSaver
 from modules.modelSaver.flux.FluxLoRASaver import FluxLoRASaver
-from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin
-from modules.util.enum.ModelFormat import ModelFormat
-from modules.util.enum.ModelType import ModelType
+from modules.modelSaver.GenericLoRAModelSaver import make_lora_model_saver
 
-import torch
-
-
-class FluxLoRAModelSaver(
-    BaseModelSaver,
-    InternalModelSaverMixin,
-):
-    def __init__(self):
-        super().__init__()
-
-    def save(
-            self,
-            model: FluxModel,
-            model_type: ModelType,
-            output_model_format: ModelFormat,
-            output_model_destination: str,
-            dtype: torch.dtype | None,
-    ):
-        lora_model_saver = FluxLoRASaver()
-        embedding_model_saver = FluxEmbeddingSaver()
-
-        lora_model_saver.save(model, output_model_format, output_model_destination, dtype)
-        if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL:
-            embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype)
-
-        if output_model_format == ModelFormat.INTERNAL:
-            self._save_internal_data(model, output_model_destination)
+FluxLoRAModelSaver = make_lora_model_saver(
+    model_class=FluxModel,
+    lora_saver_class=FluxLoRASaver,
+    embedding_saver_class=FluxEmbeddingSaver,
+)
37 changes: 37 additions & 0 deletions modules/modelSaver/GenericEmbeddingModelSaver.py
@@ -0,0 +1,37 @@
+from modules.model.BaseModel import BaseModel
+from modules.modelSaver.BaseModelSaver import BaseModelSaver
+from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin
+from modules.util.enum.ModelFormat import ModelFormat
+from modules.util.enum.ModelType import ModelType
+
+import torch
+
+
+def make_embedding_model_saver(
+        model_class: type[BaseModel],
+        embedding_saver_class: type,
+):
+    class GenericEmbeddingModelSaver(
+        BaseModelSaver,
+        InternalModelSaverMixin,
+    ):
+        def __init__(self):
+            super().__init__()
+
+        def save(
+                self,
+                model: model_class,
+                model_type: ModelType,
+                output_model_format: ModelFormat,
+                output_model_destination: str,
+                dtype: torch.dtype | None,
+        ):
+            embedding_model_saver = embedding_saver_class()
+
+            embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype)
+            embedding_model_saver.save_single(model, output_model_format, output_model_destination, dtype)
+
+            if output_model_format == ModelFormat.INTERNAL:
+                self._save_internal_data(model, output_model_destination)
+
+    return GenericEmbeddingModelSaver
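
Note for reviewers: because each wrapper module binds the factory result to the old class name, existing imports and the save() signature are untouched. A minimal sketch of a call site, with assumed enum members and illustrative arguments (not taken from this PR):

import torch

from modules.modelSaver.ChromaEmbeddingModelSaver import ChromaEmbeddingModelSaver
from modules.util.enum.ModelFormat import ModelFormat
from modules.util.enum.ModelType import ModelType

saver = ChromaEmbeddingModelSaver()  # the class produced by make_embedding_model_saver
saver.save(
    model=chroma_model,  # a loaded ChromaModel, produced elsewhere by a model loader
    model_type=ModelType.CHROMA,  # assumed enum member; the generic saver accepts but never reads it
    output_model_format=ModelFormat.SAFETENSORS,  # assumed enum member
    output_model_destination="embeddings/my-token.safetensors",  # illustrative path
    dtype=torch.float16,
)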
40 changes: 40 additions & 0 deletions modules/modelSaver/GenericFineTuneModelSaver.py
@@ -0,0 +1,40 @@
+from modules.model.BaseModel import BaseModel
+from modules.modelSaver.BaseModelSaver import BaseModelSaver
+from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin
+from modules.util.enum.ModelFormat import ModelFormat
+from modules.util.enum.ModelType import ModelType
+
+import torch
+
+
+def make_fine_tune_model_saver(
+        model_class: type[BaseModel],
+        model_saver_class: type,
+        embedding_saver_class: type | None,
+):
+    class GenericFineTuneModelSaver(
+        BaseModelSaver,
+        InternalModelSaverMixin,
+    ):
+        def __init__(self):
+            super().__init__()
+
+        def save(
+                self,
+                model: model_class,
+                model_type: ModelType,
+                output_model_format: ModelFormat,
+                output_model_destination: str,
+                dtype: torch.dtype | None,
+        ):
+            base_model_saver = model_saver_class()
+            base_model_saver.save(model, output_model_format, output_model_destination, dtype)
+
+            if embedding_saver_class is not None:
+                embedding_model_saver = embedding_saver_class()
+                embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype)
+
+            if output_model_format == ModelFormat.INTERNAL:
+                self._save_internal_data(model, output_model_destination)
+
+    return GenericFineTuneModelSaver
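
One behavior the Chroma and Flux call sites above never exercise: embedding_saver_class may be None, in which case the save_multiple() step is skipped. A hypothetical wiring for a model family without embedding support (NewModel and NewModelSaver are stand-ins, not part of this PR):

from modules.modelSaver.GenericFineTuneModelSaver import make_fine_tune_model_saver

NewFineTuneModelSaver = make_fine_tune_model_saver(
    model_class=NewModel,  # hypothetical model class
    model_saver_class=NewModelSaver,  # hypothetical base-model saver
    embedding_saver_class=None,  # no embedding saver: the embedding step is skipped
)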
41 changes: 41 additions & 0 deletions modules/modelSaver/GenericLoRAModelSaver.py
@@ -0,0 +1,41 @@
+from modules.model.BaseModel import BaseModel
+from modules.modelSaver.BaseModelSaver import BaseModelSaver
+from modules.modelSaver.mixin.InternalModelSaverMixin import InternalModelSaverMixin
+from modules.util.enum.ModelFormat import ModelFormat
+from modules.util.enum.ModelType import ModelType
+
+import torch
+
+
+def make_lora_model_saver(
+        model_class: type[BaseModel],
+        lora_saver_class: type,
+        embedding_saver_class: type | None,
+):
+    class GenericLoRAModelSaver(
+        BaseModelSaver,
+        InternalModelSaverMixin,
+    ):
+        def __init__(self):
+            super().__init__()
+
+        def save(
+                self,
+                model: model_class,
+                model_type: ModelType,
+                output_model_format: ModelFormat,
+                output_model_destination: str,
+                dtype: torch.dtype | None,
+        ):
+            lora_model_saver = lora_saver_class()
+            lora_model_saver.save(model, output_model_format, output_model_destination, dtype)
+
+            if embedding_saver_class is not None:
+                embedding_model_saver = embedding_saver_class()
+                if not model.train_config.bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL:
+                    embedding_model_saver.save_multiple(model, output_model_format, output_model_destination, dtype)
+
+            if output_model_format == ModelFormat.INTERNAL:
+                self._save_internal_data(model, output_model_destination)
+
+    return GenericLoRAModelSaver
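
The guarded save_multiple() call carries over the condition that was previously duplicated in ChromaLoRAModelSaver and FluxLoRAModelSaver: embeddings are written as standalone files unless they are bundled into the LoRA output, and INTERNAL saves always write them separately. Restated as a standalone predicate for review (the helper name is illustrative, not part of the PR):

from modules.util.enum.ModelFormat import ModelFormat

def writes_separate_embedding_files(bundle_additional_embeddings: bool, output_model_format: ModelFormat) -> bool:
    # Mirrors the condition in GenericLoRAModelSaver.save():
    #   bundle=False           -> separate embedding files are always written
    #   bundle=True, INTERNAL  -> separate files are still written
    #   bundle=True, otherwise -> skipped; embeddings ride along inside the LoRA file
    return not bundle_additional_embeddings or output_model_format == ModelFormat.INTERNAL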