diff --git a/backend/app.py b/backend/app.py index 69e9a9a9..9c238ff2 100644 --- a/backend/app.py +++ b/backend/app.py @@ -95,6 +95,7 @@ async def get_state(): "scripts_img2img": get_scripts_metadata(True), "face_restorers": [model.name() for model in shared.face_restorers], "sd_models": modules.sd_models.checkpoint_tiles(), # yes internal API has spelling error + "sd_vaes": ["None", "Automatic" ] + (list(modules.sd_vae.vae_dict)) } diff --git a/backend/config.py b/backend/config.py index 38d56bbd..a9944529 100644 --- a/backend/config.py +++ b/backend/config.py @@ -26,6 +26,12 @@ class BaseOptions(BaseModel): class GenerationOptions(BaseModel): sd_model: str = "model.ckpt" """Model to use for generation.""" + sd_vae: str = "Automatic" + """VAE to use for generation.""" + + clip_skip: int = 1 + """CLIP layers to skip during generation.""" + script: str = "None" """Which script to use.""" script_args: list = Field(default_factory=list) diff --git a/backend/structs.py b/backend/structs.py index c4cbe191..8f6c2f95 100644 --- a/backend/structs.py +++ b/backend/structs.py @@ -71,6 +71,8 @@ class ConfigResponse(PluginOptions): """List of available face restorers.""" sd_models: List[str] """List of available models.""" + sd_vaes: List[str] + """List of available VAEs.""" class ImageResponse(BaseModel): diff --git a/backend/utils.py b/backend/utils.py index 89ede617..cff7c6d5 100644 --- a/backend/utils.py +++ b/backend/utils.py @@ -85,6 +85,13 @@ def prepare_backend(opt: BaseModel): shared.opts.sd_model_checkpoint = opt.sd_model modules.sd_models.reload_model_weights(shared.sd_model) + if hasattr(opt, "sd_vae"): + shared.opts.sd_vae = opt.sd_vae + modules.sd_vae.reload_vae_weights() + + if hasattr(opt, "clip_skip"): + shared.opts.CLIP_stop_at_last_layers = opt.clip_skip + if hasattr(opt, "upscaler_name"): shared.opts.upscaler_for_img2img = opt.upscaler_name diff --git a/frontends/krita/krita_diff/client.py b/frontends/krita/krita_diff/client.py index e53f019e..2f89a393 100644 --- a/frontends/krita/krita_diff/client.py +++ b/frontends/krita/krita_diff/client.py @@ -220,6 +220,8 @@ def common_params(self, has_selection): # its fine to stuff extra stuff here; pydantic will shave off irrelevant params params = dict( sd_model=self.cfg("sd_model", str), + sd_vae=self.cfg("sd_vae", str), + clip_skip=self.cfg("clip_skip", int), batch_count=self.cfg("sd_batch_count", int), batch_size=self.cfg("sd_batch_size", int), base_size=self.cfg("sd_base_size", int), @@ -246,6 +248,7 @@ def cb(obj): assert len(obj["samplers_img2img"]) > 0 assert len(obj["face_restorers"]) > 0 assert len(obj["sd_models"]) > 0 + assert len(obj["sd_vaes"]) > 0 assert len(obj["scripts_txt2img"]) > 0 assert len(obj["scripts_img2img"]) > 0 except: @@ -269,6 +272,7 @@ def cb(obj): self.cfg.set("inpaint_script_list", list(obj["scripts_img2img"].keys())) self.cfg.set("face_restorer_model_list", obj["face_restorers"]) self.cfg.set("sd_model_list", obj["sd_models"]) + self.cfg.set("sd_vae_list", ["Automatic", "None"] + obj["sd_vaes"]) # extension script cfg obj["scripts_inpaint"] = obj["scripts_img2img"] diff --git a/frontends/krita/krita_diff/defaults.py b/frontends/krita/krita_diff/defaults.py index 7b7090ac..02508df2 100644 --- a/frontends/krita/krita_diff/defaults.py +++ b/frontends/krita/krita_diff/defaults.py @@ -63,6 +63,9 @@ class Defaults: sd_model_list: List[str] = field(default_factory=lambda: [ERROR_MSG]) sd_model: str = "model.ckpt" + sd_vae_list: List[str] = field(default_factory=lambda: [ERROR_MSG]) + sd_vae: str = "Automatic" + clip_skip: int = 1 sd_batch_size: int = 1 sd_batch_count: int = 1 sd_base_size: int = 512 diff --git a/frontends/krita/krita_diff/pages/common.py b/frontends/krita/krita_diff/pages/common.py index 2cf0c108..f0463a95 100644 --- a/frontends/krita/krita_diff/pages/common.py +++ b/frontends/krita/krita_diff/pages/common.py @@ -21,6 +21,16 @@ def __init__(self, *args, **kwargs): script.cfg, "sd_model_list", "sd_model", label="SD model:" ) + # VAE list + self.sd_vae_layout = QComboBoxLayout( + script.cfg, "sd_vae_list", "sd_vae", label="VAE:" + ) + + # Clip skip + self.clip_skip_layout = QSpinBoxLayout( + script.cfg, "clip_skip", label="Clip skip:", min=1, max=12, step=1 + ) + # batch size & count self.batch_count_layout = QSpinBoxLayout( script.cfg, "sd_batch_count", label="Batch count:", min=1, max=9999, step=1 @@ -83,6 +93,8 @@ def __init__(self, *args, **kwargs): layout.addLayout(self.codeformer_weight_layout) layout.addLayout(checkboxes_layout) layout.addLayout(self.sd_model_layout) + layout.addLayout(self.sd_vae_layout) + layout.addLayout(self.clip_skip_layout) layout.addLayout(batch_layout) layout.addLayout(size_layout) layout.addWidget(self.interrupt_btn) @@ -92,6 +104,8 @@ def __init__(self, *args, **kwargs): def cfg_init(self): self.sd_model_layout.cfg_init() + self.sd_vae_layout.cfg_init() + self.clip_skip_layout.cfg_init() self.batch_count_layout.cfg_init() self.batch_size_layout.cfg_init() self.base_size_layout.cfg_init() @@ -106,6 +120,8 @@ def cfg_init(self): def cfg_connect(self): self.sd_model_layout.cfg_connect() + self.sd_vae_layout.cfg_connect() + self.clip_skip_layout.cfg_connect() self.batch_count_layout.cfg_connect() self.batch_size_layout.cfg_connect() self.base_size_layout.cfg_connect()