diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 75e686316..9438b3aea 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -604,6 +604,90 @@ def outputs(self) -> Dict[str, Dict[int, str]]: } +@register_in_tasks_manager("unet-controlnet", *["semantic-segmentation"], library_name="diffusers") +class UNetControlNetOpenVINOConfig(UNetOnnxConfig): + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + common_inputs = { + "sample": {0: "batch_size", 2: "height", 3: "width"}, + "timestep": {0: "steps"}, + "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, + "mid_block_additional_residual": {0: "batch_size", 2: "height", 3: "width"}, + } + for a in range(1, 25, 2): + if a == 23: + common_inputs["down_block_additional_residual"] = {0: "batch_size", 2: "height", 3: "width"} + break + else: + common_inputs[f"down_block_additional_residual.{a}"] = {0: "batch_size", 2: "height", 3: "width"} + # TODO : add text_image, image and image_embeds + if getattr(self._normalized_config, "addition_embed_type", None) == "text_time": + common_inputs["text_embeds"] = {0: "batch_size"} + common_inputs["time_ids"] = {0: "batch_size"} + + if getattr(self._normalized_config, "time_cond_proj_dim", None) is not None: + common_inputs["timestep_cond"] = {0: "batch_size"} + return common_inputs + + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + return { + "out_sample": {0: "batch_size", 2: "height", 3: "width"}, + } + + def generate_dummy_inputs(self, framework: str = "pt", **kwargs): + dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs) + + dummy_inputs = {} + for input_name in self.inputs: + for dummy_input_gen in dummy_inputs_generators: + if dummy_input_gen.supports_input(input_name): + dummy_inputs[input_name] = dummy_input_gen.generate( + input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype + ) + break + + import torch + + dummy_inputs["down_block_additional_residual_1"] = torch.randn(2, 320, 64, 64) + dummy_inputs["down_block_additional_residual_3"] = torch.randn(2, 320, 64, 64) + dummy_inputs["down_block_additional_residual_5"] = torch.randn(2, 320, 64, 64) + dummy_inputs["down_block_additional_residual_7"] = torch.randn(2, 320, 32, 32) + dummy_inputs["down_block_additional_residual_9"] = torch.randn(2, 640, 32, 32) + dummy_inputs["down_block_additional_residual_11"] = torch.randn(2, 640, 32, 32) + dummy_inputs["down_block_additional_residual_13"] = torch.randn(2, 640, 16, 16) + dummy_inputs["down_block_additional_residual_15"] = torch.randn(2, 1280, 16, 16) + dummy_inputs["down_block_additional_residual_17"] = torch.randn(2, 1280, 16, 16) + dummy_inputs["down_block_additional_residual_19"] = torch.randn(2, 1280, 8, 8) + dummy_inputs["down_block_additional_residual_21"] = torch.randn(2, 1280, 8, 8) + dummy_inputs["down_block_additional_residual"] = torch.randn(2, 1280, 8, 8) + dummy_inputs["mid_block_additional_residual"] = torch.randn(2, 1280, 8, 8) + dummy_inputs["encoder_hidden_states"] = dummy_inputs["encoder_hidden_states"][0] + + if getattr(self._normalized_config, "addition_embed_type", None) == "text_time": + dummy_inputs["added_cond_kwargs"] = { + "text_embeds": dummy_inputs.pop("text_embeds"), + "time_ids": dummy_inputs.pop("time_ids"), + } + + return dummy_inputs + + def rename_ambiguous_inputs(self, inputs) -> Dict[str, Dict[int, str]]: + """ + Updates the input names of the model to export. 
+ Override the function when the model input names are ambiguous or too generic. + + Returns: + `Dict[str, Dict[int, str]]`: Updated inputs. + """ + new_inputs = {} + for name, v in inputs.items(): + if name.startswith("down_block_additional_residual"): + new_inputs[name.replace(".", "_")] = v + else: + new_inputs[name] = v + return new_inputs + @register_in_tasks_manager("vae-encoder", *["semantic-segmentation"], library_name="diffusers") class VaeEncoderOpenVINOConfig(VaeEncoderOnnxConfig): @property diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index d0aabfb2d..cd27b9565 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -382,9 +382,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice + causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( + mask_slice + ) if ( self.config._attn_implementation == "sdpa" @@ -1655,9 +1655,9 @@ def _dbrx_update_causal_mask_legacy( offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice + causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( + mask_slice + ) if ( self.config._attn_implementation == "sdpa" diff --git a/optimum/intel/__init__.py b/optimum/intel/__init__.py index 2f2d1cb66..4af8d5349 100644 --- a/optimum/intel/__init__.py +++ b/optimum/intel/__init__.py @@ -100,6 +100,7 @@ "OVStableDiffusionXLPipeline", "OVStableDiffusionXLImg2ImgPipeline", "OVLatentConsistencyModelPipeline", + "OVStableDiffusionControlNetPipeline", ] else: _import_structure["openvino"].extend( @@ -110,6 +111,7 @@ "OVStableDiffusionXLPipeline", "OVStableDiffusionXLImg2ImgPipeline", "OVLatentConsistencyModelPipeline", + "OVStableDiffusionControlNetPipeline", ] ) @@ -233,6 +235,7 @@ except OptionalDependencyNotAvailable: from .utils.dummy_openvino_and_diffusers_objects import ( OVLatentConsistencyModelPipeline, + OVStableDiffusionControlNetPipeline, OVStableDiffusionImg2ImgPipeline, OVStableDiffusionInpaintPipeline, OVStableDiffusionPipeline, @@ -242,6 +245,7 @@ else: from .openvino import ( OVLatentConsistencyModelPipeline, + OVStableDiffusionControlNetPipeline, OVStableDiffusionImg2ImgPipeline, OVStableDiffusionInpaintPipeline, OVStableDiffusionPipeline, diff --git a/optimum/intel/openvino/__init__.py b/optimum/intel/openvino/__init__.py index 4ee285f07..114ac812d 100644 --- a/optimum/intel/openvino/__init__.py +++ b/optimum/intel/openvino/__init__.py @@ -71,6 +71,7 @@ if is_diffusers_available(): from .modeling_diffusion import ( OVLatentConsistencyModelPipeline, + OVStableDiffusionControlNetPipeline, OVStableDiffusionImg2ImgPipeline, OVStableDiffusionInpaintPipeline, OVStableDiffusionPipeline, diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py index 1b880e736..b94ded7f6 100644 --- a/optimum/intel/openvino/modeling_diffusion.py +++ b/optimum/intel/openvino/modeling_diffusion.py @@ -20,25 +20,31 @@ from copy import deepcopy from pathlib import Path from tempfile import TemporaryDirectory, 
gettempdir -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import openvino +import openvino.runtime import PIL +import torch from diffusers import ( + ConfigMixin, + ControlNetModel, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, + StableDiffusionControlNetPipeline, StableDiffusionPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline, ) from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME -from diffusers.utils import CONFIG_NAME, is_invisible_watermark_available +from diffusers.utils import CONFIG_NAME, is_invisible_watermark_available, numpy_to_pil from huggingface_hub import snapshot_download from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from openvino._offline_transformations import compress_model_transformation from openvino.runtime import Core +from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTokenizer from optimum.pipelines.diffusers.pipeline_latent_consistency import LatentConsistencyPipelineMixin @@ -54,6 +60,7 @@ DIFFUSION_MODEL_UNET_SUBFOLDER, DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER, DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, + DIFFUSION_MODEL_CONTROLNET_SUBFOLDER, ) from ...exporters.openvino import main_export @@ -714,10 +721,10 @@ def __call__(self, latent_sample: np.ndarray): outputs = self.request(inputs, share_inputs=True) return list(outputs.values()) - def _compile(self): - if "GPU" in self.device: - self.ov_config.update({"INFERENCE_PRECISION_HINT": "f32"}) - super()._compile() + # def _compile(self): + # if "GPU" in self.device: + # self.ov_config.update({"INFERENCE_PRECISION_HINT": "f32"}) + # super()._compile() class OVModelVaeEncoder(OVModelPart): @@ -735,10 +742,10 @@ def __call__(self, sample: np.ndarray): outputs = self.request(inputs, share_inputs=True) return list(outputs.values()) - def _compile(self): - if "GPU" in self.device: - self.ov_config.update({"INFERENCE_PRECISION_HINT": "f32"}) - super()._compile() + # def _compile(self): + # if "GPU" in self.device: + # self.ov_config.update({"INFERENCE_PRECISION_HINT": "f32"}) + # super()._compile() class OVStableDiffusionPipeline(OVStableDiffusionPipelineBase, StableDiffusionPipelineMixin): @@ -1100,3 +1107,1038 @@ def _raise_invalid_batch_size( f"To fix this, please either provide a different inputs to your model so that `batch_size` * `num_images_per_prompt` * 2 is equal to {expected_batch_size} " "or reshape it again accordingly using the `.reshape()` method by setting `batch_size` to -1. 
" + msg ) + + +class OVModelControlNet(OVModelPart): + def __init__( + self, model: openvino.runtime.Model, parent_model: OVBaseModel, ov_config: Optional[Dict[str, str]] = None + ): + super().__init__(model, parent_model, ov_config, "controlnet") + + def __call__( + self, + sample: np.ndarray, + timestep: np.ndarray, + encoder_hidden_states: np.ndarray, + controlnet_cond: np.ndarray, + text_embeds: Optional[np.ndarray] = None, + time_ids: Optional[np.ndarray] = None, + timestep_cond: Optional[np.ndarray] = None, + ): + self._compile() + + inputs = { + "sample": sample, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "controlnet_cond": controlnet_cond, + } + + if text_embeds is not None: + inputs["text_embeds"] = text_embeds + if time_ids is not None: + inputs["time_ids"] = time_ids + if timestep_cond is not None: + inputs["timestep_cond"] = timestep_cond + + outputs = self.request(inputs, share_inputs=True) + return list(outputs.values()) + + +class OVModelUnetControlNet(OVModelPart): + def __init__( + self, model: openvino.runtime.Model, parent_model: OVBaseModel, ov_config: Optional[Dict[str, str]] = None + ): + super().__init__(model, parent_model, ov_config, "unet") + + def __call__( + self, + sample: np.ndarray, + timestep: np.ndarray, + encoder_hidden_states: np.ndarray, + down_and_mid_block_samples: Tuple[np.ndarray], + text_embeds: Optional[np.ndarray] = None, + time_ids: Optional[np.ndarray] = None, + timestep_cond: Optional[np.ndarray] = None, + ): + self._compile() + + inputs = { + "sample": sample, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "mid_block_additional_residual": down_and_mid_block_samples[-1], + } + a = 1 + for block in down_and_mid_block_samples: + if a == 23: + inputs["down_block_additional_residual"] = block + break + else: + inputs[f"down_block_additional_residual_{a}"] = block + a += 2 + + if text_embeds is not None: + inputs["text_embeds"] = text_embeds + if time_ids is not None: + inputs["time_ids"] = time_ids + if timestep_cond is not None: + inputs["timestep_cond"] = timestep_cond + + outputs = self.request(inputs, share_inputs=True) + return list(outputs.values()) + + +class OVStableDiffusionControlNetPipelineBase(OVStableDiffusionPipelineBase): + """ + OpenVINO inference pipeline for Stable Diffusion with ControlNet guidence + """ + + auto_model_class = StableDiffusionControlNetPipeline + export_feature = "stable-diffusion-controlnet" + config_name = "model_index.json" + + def __init__( + self, + unet: openvino.runtime.Model, + controlnet: openvino.runtime.Model, + config: Dict[str, Any], + scheduler: Union[None, "DDIMScheduler", "PNDMScheduler", "LMSDiscreteScheduler"], + text_encoder: Optional[openvino.runtime.Model] = None, + text_encoder_2: Optional[openvino.runtime.Model] = None, + vae_decoder: Optional[openvino.runtime.Model] = None, + vae_encoder: Optional[openvino.runtime.Model] = None, + tokenizer: Optional["CLIPTokenizer"] = None, + tokenizer_2: Optional["CLIPTokenizer"] = None, + feature_extractor: Optional["CLIPFeatureExtractor"] = None, + device: str = "CPU", + dynamic_shapes: bool = False, + compile: bool = True, + ov_config: Optional[Dict[str, str]] = None, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None, + **kwargs, + ): + # super().__init__() + self._device = device.upper() + self._internal_dict = config + self.ov_config = {} if ov_config is None else {**ov_config} + 
self.is_dynamic = dynamic_shapes + self.preprocessors = [] + # This attribute is needed to keep one reference on the temporary directory, since garbage collecting + # would end-up removing the directory containing the underlying OpenVINO model + self._model_save_dir_tempdirectory_instance = None + if isinstance(model_save_dir, TemporaryDirectory): + self._model_save_dir_tempdirectory_instance = model_save_dir + self._model_save_dir = Path(model_save_dir.name) + elif isinstance(model_save_dir, str): + self._model_save_dir = Path(model_save_dir) + else: + self._model_save_dir = model_save_dir + + self.vae_decoder = OVModelVaeDecoder(vae_decoder, self) + self.unet = OVModelUnetControlNet(unet, self) + self.controlnet = OVModelControlNet(controlnet, self) + self.text_encoder = OVModelTextEncoder(text_encoder, self) if text_encoder is not None else None + self.vae_encoder = OVModelVaeEncoder(vae_encoder, self) if vae_encoder is not None else None + self.vae_scale_factor = 8 + self.text_encoder_2 = ( + OVModelTextEncoder(text_encoder_2, self, model_name=DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER) + if text_encoder_2 is not None + else None + ) + self.tokenizer_2 = tokenizer_2 + self.scheduler = scheduler + self.feature_extractor = feature_extractor + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer = tokenizer + self._openvino_config = None + + sub_models = { + DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER: self.text_encoder, + DIFFUSION_MODEL_UNET_SUBFOLDER: self.unet, + DIFFUSION_MODEL_CONTROLNET_SUBFOLDER: self.controlnet, + DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER: self.vae_decoder, + DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER: self.vae_encoder, + DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER: self.text_encoder_2, + } + for name in sub_models.keys(): + self._internal_dict[name] = ( + ("optimum", sub_models[name].__class__.__name__) if sub_models[name] is not None else (None, None) + ) + + if self.is_dynamic: + self.reshape(batch_size=-1, height=-1, width=-1, num_images_per_prompt=-1) + + self._internal_dict.pop("vae", None) + + if compile: + self.compile() + + def compile(self): + self.vae_decoder._compile() + self.unet._compile() + self.controlnet._compile() + for component in {self.text_encoder, self.text_encoder_2, self.vae_encoder}: + if component is not None: + component._compile() + + def export_controlnet( + model_id: str, + save_dir_path: Optional[Union[str, Path]] = None, + ): + controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float32) + controlnet.eval() + dummy_inputs = { + "sample": torch.randn((2, 4, 64, 64)), + "timestep": torch.tensor(1), + "encoder_hidden_states": torch.randn((2, 77, 768)), + "controlnet_cond": torch.randn((2, 3, 512, 512)), + } + input_info = [] + for name, inp in dummy_inputs.items(): + shape = openvino.PartialShape(inp.shape) + # element_type = dtype_mapping[input_tensor.dtype] + if len(shape) == 4: + shape[0] = -1 + shape[2] = -1 + shape[3] = -1 + elif len(shape) == 3: + shape[0] = -1 + input_info.append((shape)) + + CONTROLNET_OV_PATH = save_dir_path / "controlnet/openvino_model.xml" + with torch.no_grad(): + from functools import partial + + controlnet.forward = partial(controlnet.forward, return_dict=False) + ov_model = openvino.convert_model(controlnet, example_input=dummy_inputs, input=input_info) + openvino.save_model(ov_model, CONTROLNET_OV_PATH) + del ov_model + torch._C._jit_clear_class_registry() + torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore() + 
torch.jit._state._clear_class_state() + print("ControlNet successfully converted to IR") + + def _save_pretrained(self, save_directory: Union[str, Path]): + """ + Saves the model to the OpenVINO IR format so that it can be re-loaded using the + [`~optimum.intel.openvino.modeling.OVModel.from_pretrained`] class method. + + Arguments: + save_directory (`str` or `Path`): + The directory where to save the model files + """ + save_directory = Path(save_directory) + + sub_models_to_save = { + self.controlnet: DIFFUSION_MODEL_CONTROLNET_SUBFOLDER, + self.unet: DIFFUSION_MODEL_UNET_SUBFOLDER, + self.vae_decoder: DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER, + self.vae_encoder: DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, + self.text_encoder: DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER, + self.text_encoder_2: DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER, + } + + for ov_model, dst_path in sub_models_to_save.items(): + if ov_model is not None: + dst_path = save_directory / dst_path / OV_XML_FILE_NAME + dst_path.parent.mkdir(parents=True, exist_ok=True) + openvino.save_model(ov_model.model, dst_path, compress_to_fp16=False) + model_dir = ov_model.config.get("_name_or_path", None) or ov_model._model_dir / ov_model._model_name + config_path = Path(model_dir) / ov_model.CONFIG_NAME + if config_path.is_file(): + shutil.copyfile(config_path, dst_path.parent / ov_model.CONFIG_NAME) + + self.scheduler.save_pretrained(save_directory / "scheduler") + if self.feature_extractor is not None: + self.feature_extractor.save_pretrained(save_directory / "feature_extractor") + if self.tokenizer is not None: + self.tokenizer.save_pretrained(save_directory / "tokenizer") + if self.tokenizer_2 is not None: + self.tokenizer_2.save_pretrained(save_directory / "tokenizer_2") + + self._save_openvino_config(save_directory) + + def _save_config(self, save_directory): + self.save_config(save_directory) + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + config: Dict[str, Any], + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + vae_decoder_file_name: Optional[str] = None, + text_encoder_file_name: Optional[str] = None, + unet_file_name: Optional[str] = None, + controlnet_file_name: Optional[str] = None, + vae_encoder_file_name: Optional[str] = None, + text_encoder_2_file_name: Optional[str] = None, + local_files_only: bool = False, + from_onnx: bool = False, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + load_in_8bit: bool = False, + quantization_config: Union[OVWeightQuantizationConfig, Dict] = None, + **kwargs, + ): + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + + default_file_name = ONNX_WEIGHTS_NAME if from_onnx else OV_XML_FILE_NAME + vae_decoder_file_name = vae_decoder_file_name or default_file_name + text_encoder_file_name = text_encoder_file_name or default_file_name + text_encoder_2_file_name = text_encoder_2_file_name or default_file_name + unet_file_name = unet_file_name or default_file_name + controlnet_file_name = controlnet_file_name or default_file_name + vae_encoder_file_name = vae_encoder_file_name or default_file_name + model_id = str(model_id) + patterns = set(config.keys()) + sub_models_names = patterns.intersection({"tokenizer", "tokenizer_2", "scheduler"}) + + if not os.path.isdir(model_id): + patterns.update({"vae_encoder", "vae_decoder"}) + allow_patterns = {os.path.join(k, "*") for k in patterns if not k.startswith("_")} + allow_patterns.update( + { + vae_decoder_file_name, + text_encoder_file_name, + text_encoder_2_file_name, + unet_file_name, + controlnet_file_name, + vae_encoder_file_name, + vae_decoder_file_name.replace(".xml", ".bin"), + text_encoder_file_name.replace(".xml", ".bin"), + text_encoder_2_file_name.replace(".xml", ".bin"), + unet_file_name.replace(".xml", ".bin"), + controlnet_file_name.replace(".xml", ".bin"), + vae_encoder_file_name.replace(".xml", ".bin"), + SCHEDULER_CONFIG_NAME, + CONFIG_NAME, + cls.config_name, + } + ) + ignore_patterns = ["*.msgpack", "*.safetensors", "*pytorch_model.bin"] + if not from_onnx: + ignore_patterns.extend(["*.onnx", "*.onnx_data"]) + # Downloads all repo's files matching the allowed patterns + model_id = snapshot_download( + model_id, + cache_dir=cache_dir, + local_files_only=local_files_only, + token=token, + revision=revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + new_model_save_dir = Path(model_id) + + for name in sub_models_names: + # Check if the subcomponent needs to be loaded + if kwargs.get(name, None) is not None: + continue + library_name, library_classes = config[name] + if library_classes is not None: + library = importlib.import_module(library_name) + class_obj = getattr(library, library_classes) + load_method = getattr(class_obj, "from_pretrained") + # Check if the module is in a subdirectory + if (new_model_save_dir / name).is_dir(): + kwargs[name] = load_method(new_model_save_dir / name) + else: + kwargs[name] = load_method(new_model_save_dir) + + unet_path = new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name + controlnet_path = new_model_save_dir / DIFFUSION_MODEL_CONTROLNET_SUBFOLDER / controlnet_file_name + components = { + "vae_encoder": new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name, + "vae_decoder": new_model_save_dir / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / vae_decoder_file_name, + "text_encoder": new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name, + "text_encoder_2": new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / text_encoder_2_file_name, + } + + if model_save_dir is None: + model_save_dir = new_model_save_dir + + quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit) + if quantization_config is None or quantization_config.dataset is None: + unet = cls.load_model(unet_path, quantization_config) + controlnet = cls.load_model(controlnet_path) + for key, value 
in components.items(): + components[key] = cls.load_model(value, quantization_config) if value.is_file() else None + else: + # Load uncompressed models to apply hybrid quantization further + raise NotImplementedError(f"Quantization in hybrid mode is not supported for {cls.__name__}") + + return cls( + unet=unet, + controlnet=controlnet, + config=config, + model_save_dir=model_save_dir, + quantization_config=quantization_config, + **components, + **kwargs, + ) + + @classmethod + def _from_transformers( + cls, + model_id: str, + config: Dict[str, Any], + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + force_download: bool = False, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + local_files_only: bool = False, + tokenizer: Optional["CLIPTokenizer"] = None, + scheduler: Union["DDIMScheduler", "PNDMScheduler", "LMSDiscreteScheduler"] = None, + feature_extractor: Optional["CLIPFeatureExtractor"] = None, + tokenizer_2: Optional["CLIPTokenizer"] = None, + load_in_8bit: Optional[bool] = None, + quantization_config: Union[OVWeightQuantizationConfig, Dict] = None, + **kwargs, + ): + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + + save_dir = TemporaryDirectory() + save_dir_path = Path(save_dir.name) + + # If load_in_8bit and quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size + if load_in_8bit is None and not quantization_config: + ov_config = None + else: + ov_config = OVConfig(dtype="fp32") + + if "controlnet_model_id" not in kwargs.keys(): + raise ValueError("You must give controlnet id with controlnet_model_id=controlnet_model_id.") + else: + cls.export_controlnet(model_id=kwargs["controlnet_model_id"], save_dir_path=save_dir_path) + + main_export( + model_name_or_path=model_id, + output=save_dir_path, + task=cls.export_feature, + do_validation=False, + no_post_process=True, + revision=revision, + cache_dir=cache_dir, + token=token, + local_files_only=local_files_only, + force_download=force_download, + ov_config=ov_config, + ) + + return cls._from_pretrained( + model_id=save_dir_path, + config=config, + from_onnx=False, + token=token, + revision=revision, + force_download=force_download, + cache_dir=cache_dir, + local_files_only=local_files_only, + model_save_dir=save_dir, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + scheduler=scheduler, + feature_extractor=feature_extractor, + load_in_8bit=load_in_8bit, + quantization_config=quantization_config, + **kwargs, + ) + + def _reshape_unet_controlnet( + self, + model: openvino.runtime.Model, + batch_size: int = -1, + height: int = -1, + width: int = -1, + num_images_per_prompt: int = -1, + tokenizer_max_length: int = -1, + ): + if batch_size == -1 or num_images_per_prompt == -1: + batch_size = -1 + else: + batch_size *= num_images_per_prompt + # The factor of 2 comes from the guidance scale > 1 + if "timestep_cond" not in {inputs.get_any_name() for inputs in model.inputs}: + batch_size *= 2 + + height = height // self.vae_scale_factor if height > 0 else height + width = width // self.vae_scale_factor if width > 0 else width + shapes = {} + for inputs in model.inputs: + shapes[inputs] = 
inputs.get_partial_shape() + if inputs.get_any_name() == "timestep": + shapes[inputs][0] = 1 + elif inputs.get_any_name() == "sample": + in_channels = self.unet.config.get("in_channels", None) + if in_channels is None: + in_channels = shapes[inputs][1] + if in_channels.is_dynamic: + logger.warning( + "Could not identify `in_channels` from the unet configuration, to statically reshape the unet please provide a configuration." + ) + self.is_dynamic = True + + shapes[inputs] = [batch_size, in_channels, height, width] + elif inputs.get_any_name() == "text_embeds": + shapes[inputs] = [batch_size, self.text_encoder_2.config["projection_dim"]] + elif inputs.get_any_name() == "time_ids": + shapes[inputs] = [batch_size, inputs.get_partial_shape()[1]] + elif inputs.get_any_name() == "timestep_cond": + shapes[inputs] = [batch_size, self.unet.config["time_cond_proj_dim"]] + elif inputs.get_any_name() == "encoder_hidden_states": + shapes[inputs][0] = batch_size + shapes[inputs][1] = tokenizer_max_length + elif inputs.get_any_name() == "down_block_additional_residual_1": + shapes[inputs][0] = batch_size + shapes[inputs][2] = height + shapes[inputs][3] = width + elif inputs.get_any_name() == "down_block_additional_residual_3": + shapes[inputs][0] = batch_size + shapes[inputs][2] = height + shapes[inputs][3] = width + elif inputs.get_any_name() == "down_block_additional_residual_5": + shapes[inputs][0] = batch_size + shapes[inputs][2] = height + shapes[inputs][3] = width + elif inputs.get_any_name() == "down_block_additional_residual_7": + shapes[inputs][0] = batch_size + shapes[inputs][2] = height // 2 + shapes[inputs][3] = width // 2 + elif inputs.get_any_name() == "down_block_additional_residual_9": + shapes[inputs][0] = batch_size + shapes[inputs][2] = height // 2 + shapes[inputs][3] = width // 2 + elif inputs.get_any_name() == "down_block_additional_residual_11": + shapes[inputs][0] = batch_size + shapes[inputs][2] = height // 2 + shapes[inputs][3] = width // 2 + elif inputs.get_any_name() == "down_block_additional_residual_13": + shapes[inputs][0] = batch_size + shapes[inputs][2] = height // 4 + shapes[inputs][3] = width // 4 + elif inputs.get_any_name() == "down_block_additional_residual_15": + shapes[inputs][0] = batch_size + shapes[inputs][2] = height // 4 + shapes[inputs][3] = width // 4 + elif inputs.get_any_name() == "down_block_additional_residual_17": + shapes[inputs][0] = batch_size + shapes[inputs][2] = height // 4 + shapes[inputs][3] = width // 4 + elif inputs.get_any_name() == "down_block_additional_residual_19": + shapes[inputs][0] = batch_size + shapes[inputs][2] = height // 8 + shapes[inputs][3] = width // 8 + elif inputs.get_any_name() == "down_block_additional_residual_21": + shapes[inputs][0] = batch_size + shapes[inputs][2] = height // 8 + shapes[inputs][3] = width // 8 + elif inputs.get_any_name() == "down_block_additional_residual": + shapes[inputs][0] = batch_size + shapes[inputs][2] = height // 8 + shapes[inputs][3] = width // 8 + elif inputs.get_any_name() == "mid_block_additional_residual": + shapes[inputs][0] = batch_size + shapes[inputs][2] = height // 8 + shapes[inputs][3] = width // 8 + + model.reshape(shapes) + return model + + def _reshape_controlnet( + self, + model: openvino.runtime.Model, + batch_size: int = -1, + height: int = -1, + width: int = -1, + num_images_per_prompt: int = -1, + tokenizer_max_length: int = -1, + ): + if batch_size == -1 or num_images_per_prompt == -1: + batch_size = -1 + else: + batch_size *= num_images_per_prompt + # The factor of 2 comes 
from the guidance scale > 1 + if "timestep_cond" not in {inputs.get_any_name() for inputs in model.inputs}: + batch_size *= 2 + + height_ = height // self.vae_scale_factor if height > 0 else height + width_ = width // self.vae_scale_factor if width > 0 else width + shapes = {} + for inputs in model.inputs: + shapes[inputs] = inputs.get_partial_shape() + if inputs.get_any_name() == "timestep": + shapes[inputs] = shapes[inputs] + elif inputs.get_any_name() == "sample": + in_channels = self.unet.config.get("in_channels", None) + if in_channels is None: + in_channels = shapes[inputs][1] + if in_channels.is_dynamic: + logger.warning( + "Could not identify `in_channels` from the unet configuration, to statically reshape the unet please provide a configuration." + ) + self.is_dynamic = True + + shapes[inputs] = [batch_size, in_channels, height_, width_] + elif inputs.get_any_name() == "controlnet_cond": + shapes[inputs][0] = batch_size + shapes[inputs][2] = height + shapes[inputs][3] = width + elif inputs.get_any_name() == "time_ids": + shapes[inputs] = [batch_size, inputs.get_partial_shape()[1]] + elif inputs.get_any_name() == "timestep_cond": + shapes[inputs] = [batch_size, self.unet.config["time_cond_proj_dim"]] + elif inputs.get_any_name() == "encoder_hidden_states": + shapes[inputs][0] = batch_size + shapes[inputs][1] = tokenizer_max_length + model.reshape(shapes) + return model + + def reshape( + self, + batch_size: int, + height: int, + width: int, + num_images_per_prompt: int = -1, + ): + self.is_dynamic = -1 in {batch_size, height, width, num_images_per_prompt} + self.vae_decoder.model = self._reshape_vae_decoder(self.vae_decoder.model, height, width) + if self.tokenizer is None and self.tokenizer_2 is None: + tokenizer_max_len = -1 + else: + tokenizer_max_len = ( + self.tokenizer.model_max_length if self.tokenizer is not None else self.tokenizer_2.model_max_length + ) + self.unet.model = self._reshape_unet_controlnet( + self.unet.model, batch_size, height, width, num_images_per_prompt, tokenizer_max_len + ) + + self.controlnet.model = self._reshape_controlnet( + self.controlnet.model, batch_size, height, width, num_images_per_prompt, tokenizer_max_len + ) + + if self.text_encoder is not None: + self.text_encoder.model = self._reshape_text_encoder( + self.text_encoder.model, batch_size, self.tokenizer.model_max_length + ) + + if self.text_encoder_2 is not None: + self.text_encoder_2.model = self._reshape_text_encoder( + self.text_encoder_2.model, batch_size, self.tokenizer_2.model_max_length + ) + + if self.vae_encoder is not None: + self.vae_encoder.model = self._reshape_vae_encoder(self.vae_encoder.model, batch_size, height, width) + + self.clear_requests() + return self + + +class StableDiffusionContrlNetPipelineMixin(ConfigMixin): + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Union[str, List[str]] = None, + ): + """ + Encodes the prompt into text encoder hidden states. 
+ + Parameters: + prompt (str or list(str)): prompt to be encoded + num_images_per_prompt (int): number of images that should be generated per prompt + do_classifier_free_guidance (bool): whether to use classifier free guidance or not + negative_prompt (str or list(str)): negative prompt to be encoded + Returns: + text_embeddings (np.ndarray): text encoder hidden states + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # tokenize input prompts + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + + text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + + # duplicate text embeddings for each generation per prompt + if num_images_per_prompt != 1: + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = np.tile(text_embeddings, (1, num_images_per_prompt, 1)) + text_embeddings = np.reshape(text_embeddings, (bs_embed * num_images_per_prompt, seq_len, -1)) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + max_length = text_input_ids.shape[-1] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + else: + uncond_tokens = negative_prompt + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + + uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = np.tile(uncond_embeddings, (1, num_images_per_prompt, 1)) + uncond_embeddings = np.reshape(uncond_embeddings, (batch_size * num_images_per_prompt, seq_len, -1)) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: np.dtype = np.float32, + latents: np.ndarray = None, + ): + """ + Preparing noise to image generation. 
If initial latents are not provided, they are generated randomly and
+        then scaled by the standard deviation required by the scheduler.
+
+        Parameters:
+            batch_size (int): input batch size
+            num_channels_latents (int): number of channels for noise generation
+            height (int): image height
+            width (int): image width
+            dtype (np.dtype, *optional*, np.float32): dtype for latents generation
+            latents (np.ndarray, *optional*, None): initial latent noise tensor; if not provided, it will be generated
+        Returns:
+            latents (np.ndarray): scaled initial noise for diffusion
+        """
+        shape = (
+            batch_size,
+            num_channels_latents,
+            height // self.vae_scale_factor,
+            width // self.vae_scale_factor,
+        )
+        if latents is None:
+            latents = self.randn_tensor(shape, dtype=dtype)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
+    def decode_latents(self, latents: np.array, pad: Tuple[int]):
+        """
+        Decode the predicted image from latent space using the VAE decoder and remove the padding from the result
+
+        Parameters:
+            latents (np.ndarray): image encoded in diffusion latent space
+            pad (Tuple[int]): padding sizes for each side obtained during preprocessing
+        Returns:
+            image: image decoded by the VAE decoder
+        """
+        latents = 1 / 0.18215 * latents
+        image = self.vae_decoder(latents)[0]
+        (_, end_h), (_, end_w) = pad[1:3]
+        h, w = image.shape[2:]
+        unpad_h = h - end_h
+        unpad_w = w - end_w
+        image = image[:, :, :unpad_h, :unpad_w]
+        image = np.clip(image / 2 + 0.5, 0, 1)
+        image = np.transpose(image, (0, 2, 3, 1))
+        return image
+
+    def scale_fit_to_window(self, dst_width: int, dst_height: int, image_width: int, image_height: int):
+        """
+        Preprocessing helper function for calculating the image resize dimensions, preserving the original aspect
+        ratio while fitting the image into a specific window size
+
+        Parameters:
+            dst_width (int): destination window width
+            dst_height (int): destination window height
+            image_width (int): source image width
+            image_height (int): source image height
+        Returns:
+            result_width (int): calculated width for resize
+            result_height (int): calculated height for resize
+        """
+        im_scale = min(dst_height / image_height, dst_width / image_width)
+        return int(im_scale * image_width), int(im_scale * image_height)
+
+    def preprocess(self, image: PIL.Image.Image, height, width):
+        """
+        Image preprocessing function. Takes an image in PIL.Image format, resizes it to fit the model input window
+        (`height` x `width`) while keeping the aspect ratio, converts it to np.ndarray and pads the right or bottom
+        side with zeros (depending on the aspect ratio), then converts the data to float32, rescales values from
+        [0, 255] to [0, 1], and finally converts the layout from NHWC to NCHW.
+        The function returns the preprocessed input tensor and the padding sizes, which are needed in postprocessing.
+
+        Parameters:
+            image (PIL.Image.Image): input image
+        Returns:
+            image (np.ndarray): preprocessed image tensor
+            pad (Tuple[int]): padding size for each dimension, used to restore the image size in postprocessing
+        """
+        src_width, src_height = image.size
+        dst_width, dst_height = self.scale_fit_to_window(width, height, src_width, src_height)
+        image = np.array(image.resize((dst_width, dst_height), resample=PIL.Image.Resampling.LANCZOS))[None, :]
+        pad_width = width - dst_width
+        pad_height = height - dst_height
+        pad = ((0, 0), (0, pad_height), (0, pad_width), (0, 0))
+        image = np.pad(image, pad, mode="constant")
+        image = image.astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        return image, pad
+
+    def randn_tensor(
+        self,
+        shape: Union[Tuple, List],
+        dtype: Optional[np.dtype] = np.float32,
+    ):
+        """
+        Helper function for generating a random tensor with the given shape and data type
+
+        Parameters:
+            shape (Union[Tuple, List]): shape for filling random values
+            dtype (np.dtype, *optional*, np.float32): data type for result
+        Returns:
+            latents (np.ndarray): tensor with random values of the given data type and shape (usually represents noise in latent space)
+        """
+        latents = np.random.randn(*shape).astype(dtype)
+
+        return latents
+
+    def progress_bar(self, iterable=None, total=None):
+        if not hasattr(self, "_progress_bar_config"):
+            self._progress_bar_config = {}
+        elif not isinstance(self._progress_bar_config, dict):
+            raise ValueError(
+                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+            )
+
+        if iterable is not None:
+            return tqdm(iterable, **self._progress_bar_config)
+        elif total is not None:
+            return tqdm(total=total, **self._progress_bar_config)
+        else:
+            raise ValueError("Either `total` or `iterable` has to be defined.")
+
+    def set_progress_bar_config(self, **kwargs):
+        self._progress_bar_config = kwargs
+
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        image: PIL.Image.Image,
+        num_inference_steps: int = 10,
+        negative_prompt: Union[str, List[str]] = None,
+        guidance_scale: float = 7.5,
+        controlnet_conditioning_scale: float = 1.0,
+        eta: float = 0.0,
+        latents: Optional[np.array] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+    ):
+        """
+        Function invoked when calling the pipeline for generation.
+
+        Parameters:
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            image (`PIL.Image.Image`):
+                The conditioning image (e.g. an edge, depth or pose map) that guides generation through the ControlNet.
+            num_inference_steps (`int`, *optional*, defaults to 10):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            negative_prompt (`str` or `List[str]`):
+                negative prompt or prompts for generation
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality. This pipeline requires a value of at least `1`.
+            latents (`np.ndarray`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling random noise.
+        Returns:
+            image (`List[PIL.Image.Image]`): generated images
+
+        """
+
+        # 1. Define call parameters
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+        # 2. Encode input prompt
+        text_embeddings = self._encode_prompt(prompt, negative_prompt=negative_prompt)
+
+        # 3. Preprocess image
+        orig_width, orig_height = image.size
+        image, pad = self.preprocess(image, height=height, width=width)
+        height, width = image.shape[-2:]
+        if do_classifier_free_guidance:
+            image = np.concatenate([image] * 2)
+
+        # 4. set timesteps
+        self.scheduler.set_timesteps(num_inference_steps)
+        timesteps = self.scheduler.timesteps
+
+        # 6. Prepare latent variables
+        num_channels_latents = 4
+        latents = self.prepare_latents(
+            batch_size,
+            num_channels_latents,
+            height,
+            width,
+            text_embeddings.dtype,
+            latents,
+        )
+
+        # 7. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        self.set_progress_bar_config(disable=True)
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # Expand the latents if we are doing classifier-free guidance:
+                # the unconditional and text-conditional branches are processed in a single batch.
+                latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                result = self.controlnet(
+                    sample=latent_model_input, timestep=t, encoder_hidden_states=text_embeddings, controlnet_cond=image
+                )
+
+                down_and_mid_block_samples = [sample * controlnet_conditioning_scale for sample in result]
+                down_and_mid_block_samples = tuple(down_and_mid_block_samples)
+                # predict the noise residual
+                noise_pred = self.unet(
+                    sample=latent_model_input,
+                    timestep=t,
+                    encoder_hidden_states=text_embeddings,
+                    down_and_mid_block_samples=down_and_mid_block_samples,
+                )[0]
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred[0], noise_pred[1]
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(
+                    torch.from_numpy(noise_pred), t, torch.from_numpy(latents)
+                ).prev_sample.numpy()
+
+                # update progress
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+        # 8. Post-processing
+        image = self.decode_latents(latents, pad)
+
+        # 9.
Convert to PIL + image = numpy_to_pil(image) + image = [img.resize((orig_width, orig_height), PIL.Image.Resampling.LANCZOS) for img in image] + + return image + + +class OVStableDiffusionControlNetPipeline( + OVStableDiffusionControlNetPipelineBase, StableDiffusionContrlNetPipelineMixin +): + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + image: Optional[PIL.Image.Image] = None, + num_inference_steps: int = 10, + guidance_scale: float = 7.5, + controlnet_conditioning_scale: float = 1.0, + eta: float = 0.0, + latents: Optional[np.array] = None, + height: Optional[int] = None, + width: Optional[int] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + **kwargs, + ): + height = height or self.unet.config.get("sample_size", 64) * self.vae_scale_factor + width = width or self.unet.config.get("sample_size", 64) * self.vae_scale_factor + _height = self.height + _width = self.width + expected_batch_size = self._batch_size + + if _height != -1 and height != _height: + logger.warning( + f"`height` was set to {height} but the static model will output images of height {_height}." + "To fix the height, please reshape your model accordingly using the `.reshape()` method." + ) + height = _height + + if _width != -1 and width != _width: + logger.warning( + f"`width` was set to {width} but the static model will output images of width {_width}." + "To fix the width, please reshape your model accordingly using the `.reshape()` method." + ) + width = _width + + if expected_batch_size != -1: + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = kwargs.get("prompt_embeds").shape[0] + + _raise_invalid_batch_size(expected_batch_size, batch_size, num_images_per_prompt, guidance_scale) + + return StableDiffusionContrlNetPipelineMixin.__call__( + self, + prompt=prompt, + image=image, + num_inference_steps=num_inference_steps, + negative_prompt=negative_prompt, + guidance_scale=guidance_scale, + controlnet_conditioning_scale=controlnet_conditioning_scale, + eta=eta, + latents=latents, + height=height, + width=width, + ) diff --git a/optimum/intel/utils/dummy_openvino_and_diffusers_objects.py b/optimum/intel/utils/dummy_openvino_and_diffusers_objects.py index 78016ea71..f0877772c 100644 --- a/optimum/intel/utils/dummy_openvino_and_diffusers_objects.py +++ b/optimum/intel/utils/dummy_openvino_and_diffusers_objects.py @@ -26,6 +26,17 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["openvino", "diffusers"]) +class OVStableDiffusionControlNetPipeline(metaclass=DummyObject): + _backends = ["openvino", "diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["openvino", "diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["openvino", "diffusers"]) + + class OVStableDiffusionImg2ImgPipeline(metaclass=DummyObject): _backends = ["openvino", "diffusers"]
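
For reference, below is a minimal usage sketch of the OVStableDiffusionControlNetPipeline introduced by this patch. It is not part of the diff: the model IDs, the conditioning image path and the output path are placeholders, and it assumes that extra keyword arguments passed to `from_pretrained(..., export=True)` are forwarded to `_from_transformers`, which requires the `controlnet_model_id` argument added above.

# Sketch: export a Stable Diffusion model plus a ControlNet to OpenVINO IR and run one generation.
# "runwayml/stable-diffusion-v1-5", "lllyasviel/sd-controlnet-canny", "canny_edges.png" and
# "result.png" are placeholder names, not values required by the patch.
from PIL import Image

from optimum.intel import OVStableDiffusionControlNetPipeline

pipe = OVStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # base Stable Diffusion checkpoint (placeholder)
    controlnet_model_id="lllyasviel/sd-controlnet-canny",  # ControlNet checkpoint (placeholder)
    export=True,  # convert both models to OpenVINO IR on the fly
)

control_image = Image.open("canny_edges.png")  # preprocessed conditioning image (placeholder)
images = pipe(
    prompt="a photo of a futuristic city, high detail",
    image=control_image,
    num_inference_steps=20,
    guidance_scale=7.5,
    height=512,
    width=512,
)
images[0].save("result.png")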