SD3 support #2073

Draft · wants to merge 2 commits into main
1 change: 1 addition & 0 deletions optimum/exporters/onnx/base.py
@@ -319,6 +319,7 @@ def fix_dynamic_axes(
input_shapes = {}
dummy_inputs = self.generate_dummy_inputs(framework="np", **input_shapes)
dummy_inputs = self.generate_dummy_inputs_for_validation(dummy_inputs, onnx_input_names=onnx_input_names)
dummy_inputs = self.rename_ambiguous_inputs(dummy_inputs)

onnx_inputs = {}
for name, value in dummy_inputs.items():
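The rename step added above matters because the dummy input names produced by the generators no longer necessarily match the input names of the exported graph. A minimal sketch of the renaming logic, mirroring SD3TransformerOnnxConfig.rename_ambiguous_inputs from this PR (placeholder strings instead of tensors):

def rename_ambiguous_inputs(inputs):
    # The dummy generators emit "sample", but the SD3 transformer graph input is "hidden_states".
    hidden_states = inputs.pop("sample", None)
    if hidden_states is not None:
        inputs["hidden_states"] = hidden_states
    return inputs

dummy_inputs = {"sample": "tensor", "timestep": "tensor", "pooled_projections": "tensor"}
print(sorted(rename_ambiguous_inputs(dummy_inputs)))
# ['hidden_states', 'pooled_projections', 'timestep'] -- names now line up with the ONNX
# graph, so the name-based lookup in fix_dynamic_axes finds every input.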
4 changes: 4 additions & 0 deletions optimum/exporters/onnx/convert.py
@@ -1183,6 +1183,10 @@ def onnx_export_from_model(
if tokenizer_2 is not None:
tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))

tokenizer_3 = getattr(model, "tokenizer_3", None)
if tokenizer_3 is not None:
tokenizer_3.save_pretrained(output.joinpath("tokenizer_3"))

model.save_config(output)

if float_dtype == "bf16":
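For an SD3 pipeline this saves a third tokenizer next to the existing ones. A rough sketch of the resulting export layout (subfolder names inferred from the submodel keys added elsewhere in this PR, not verified against an actual export):

sd3_onnx/
    model_index.json
    text_encoder/    text_encoder_2/    text_encoder_3/
    tokenizer/       tokenizer_2/       tokenizer_3/    <- saved by the added branch
    transformer/
    vae_encoder/     vae_decoder/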
81 changes: 71 additions & 10 deletions optimum/exporters/onnx/model_configs.py
@@ -1015,22 +1015,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
"last_hidden_state": {0: "batch_size", 1: "sequence_length"},
"pooler_output": {0: "batch_size"},
}

if self._normalized_config.output_hidden_states:
for i in range(self._normalized_config.num_layers + 1):
common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}

return common_outputs

def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)

# TODO: fix should be by casting inputs during inference and not export
if framework == "pt":
import torch

dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int32)
return dummy_inputs

def patch_model_for_export(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
@@ -1160,6 +1151,76 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
}


class PooledProjectionsDummyInputGenerator(DummyInputGenerator):
SUPPORTED_INPUT_NAMES = "pooled_projections"

def __init__(
self,
task: str,
normalized_config: NormalizedConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
**kwargs,
):
self.task = task
self.batch_size = batch_size
self.pooled_projection_dim = normalized_config.config.pooled_projection_dim

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
return self.random_float_tensor(
[self.batch_size, self.pooled_projection_dim], framework=framework, dtype=float_dtype
)


class DummyTransformerTimestpsInputGenerator(DummyTimestepInputGenerator):
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "timestep":
shape = [self.batch_size]
return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype)
return super().generate(input_name, framework, int_dtype, float_dtype)


class SD3TransformerOnnxConfig(UNetOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
(DummyTransformerTimestpsInputGenerator,)
+ UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
+ (PooledProjectionsDummyInputGenerator,)
)
NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
image_size="sample_size",
num_channels="in_channels",
hidden_size="joint_attention_dim",
vocab_size="attention_head_dim",
allow_new=True,
)

@property
def inputs(self):
common_inputs = super().inputs
common_inputs["pooled_projections"] = {0: "batch_size"}
return common_inputs

def rename_ambiguous_inputs(self, inputs):
# The input name in the model's forward signature is `hidden_states`, hence the export input name is updated.
hidden_states = inputs.pop("sample", None)
if hidden_states is not None:
inputs["hidden_states"] = hidden_states
return inputs


class T5EncoderOnnxConfig(CLIPTextOnnxConfig):
@property
def inputs(self):
return {
"input_ids": {0: "batch_size", 1: "sequence_length"},
}

@property
def outputs(self):
return {
"last_hidden_state": {0: "batch_size", 1: "sequence_length"},
}


class GroupViTOnnxConfig(CLIPOnnxConfig):
pass

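A minimal sketch of how the new SD3TransformerOnnxConfig resolves its export inputs, assuming diffusers >= 0.30 and this branch of optimum are installed (the checkpoint id is illustrative and gated, and the printed mapping is what the code above implies rather than verified output):

from diffusers import SD3Transformer2DModel
from optimum.exporters.onnx.model_configs import SD3TransformerOnnxConfig

transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer"
)
export_config = SD3TransformerOnnxConfig(transformer.config, task="semantic-segmentation")

print(export_config.inputs)
# Inherited UNet inputs ("sample", "timestep", "encoder_hidden_states") plus the
# "pooled_projections" axis added above; "sample" is later swapped to "hidden_states"
# by rename_ambiguous_inputs so it matches the transformer's forward signature.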
16 changes: 13 additions & 3 deletions optimum/exporters/tasks.py
@@ -335,6 +335,10 @@ class TasksManager:
}

_DIFFUSERS_SUPPORTED_MODEL_TYPE = {
"t5-encoder": supported_tasks_mapping(
"feature-extraction",
onnx="T5EncoderOnnxConfig",
),
"clip-text-model": supported_tasks_mapping(
"feature-extraction",
onnx="CLIPTextOnnxConfig",
@@ -347,6 +351,10 @@ class TasksManager:
"semantic-segmentation",
onnx="UNetOnnxConfig",
),
"sd3-transformer": supported_tasks_mapping(
"semantic-segmentation",
onnx="SD3TransformerOnnxConfig",
),
"vae-encoder": supported_tasks_mapping(
"semantic-segmentation",
onnx="VaeEncoderOnnxConfig",
@@ -1170,12 +1178,14 @@ class TasksManager:
"transformers": _SUPPORTED_MODEL_TYPE,
}
_UNSUPPORTED_CLI_MODEL_TYPE = {
"unet",
"vae-encoder",
"vae-decoder",
"clip-text-model",
"clip-text-with-projection",
"sd3-transformer",
"t5-encoder",
"trocr", # supported through the vision-encoder-decoder model type
"unet",
"vae-encoder",
"vae-decoder",
}
_SUPPORTED_CLI_MODEL_TYPE = (
set(_SUPPORTED_MODEL_TYPE.keys())
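With `sd3-transformer` and `t5-encoder` registered for diffusers (and kept out of the standalone CLI model types, since they are only exported as pipeline components), a full SD3 export would go through the usual entry point. A hedged sketch of what that could look like on this branch (model id, output path and automatic task inference are assumptions, not verified against this draft):

from optimum.exporters.onnx import main_export

main_export(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative, gated SD3 checkpoint
    output="sd3_onnx",
    # Each component returned by _get_submodels_for_export_diffusion should end up
    # as its own ONNX model inside the output folder.
)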
161 changes: 111 additions & 50 deletions optimum/exporters/utils.py
@@ -50,6 +50,14 @@
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
)

if check_if_diffusers_greater("0.30.0"):
from diffusers import (
StableDiffusion3Img2ImgPipeline,
StableDiffusion3InpaintPipeline,
StableDiffusion3Pipeline,
)

from diffusers.models.attention_processor import (
Attention,
AttnAddedKVProcessor,
@@ -87,56 +95,95 @@ def _get_submodels_for_export_diffusion(
Returns the components of a Stable Diffusion model.
"""

models_for_export = {}

is_stable_diffusion_xl = isinstance(
pipeline, (StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline)
)
if is_stable_diffusion_xl:
projection_dim = pipeline.text_encoder_2.config.projection_dim
else:
projection_dim = pipeline.text_encoder.config.projection_dim
is_stable_diffusion_3 = isinstance(
pipeline, (StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline, StableDiffusion3InpaintPipeline)
)

models_for_export = {}
is_torch_greater_or_equal_than_2_1 = version.parse(torch.__version__) >= version.parse("2.1.0")

# Text encoder
text_encoder = getattr(pipeline, "text_encoder", None)
if text_encoder is not None:
if is_stable_diffusion_xl:
if is_stable_diffusion_xl or is_stable_diffusion_3:
text_encoder.config.output_hidden_states = True
text_encoder.text_model.config.output_hidden_states = True

if is_stable_diffusion_3:
text_encoder.config.export_model_type = "clip-text-with-projection"
else:
text_encoder.config.export_model_type = "clip-text-model"

models_for_export["text_encoder"] = text_encoder

# U-NET
# ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0
is_torch_greater_or_equal_than_2_1 = version.parse(torch.__version__) >= version.parse("2.1.0")
if not is_torch_greater_or_equal_than_2_1:
pipeline.unet.set_attn_processor(AttnProcessor())
# Text encoder 2
text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
if text_encoder_2 is not None:
text_encoder_2.config.output_hidden_states = True
text_encoder_2.text_model.config.output_hidden_states = True
text_encoder_2.config.export_model_type = "clip-text-with-projection"

pipeline.unet.config.text_encoder_projection_dim = projection_dim
# The U-NET time_ids input shapes depend on the value of `requires_aesthetics_score`
# https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L571
pipeline.unet.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)
models_for_export["unet"] = pipeline.unet
models_for_export["text_encoder_2"] = text_encoder_2

# VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565
# Text encoder 3
text_encoder_3 = getattr(pipeline, "text_encoder_3", None)
if text_encoder_3 is not None:
text_encoder_3.config.export_model_type = "t5-encoder"
models_for_export["text_encoder_3"] = text_encoder_3

# U-NET
unet = getattr(pipeline, "unet", None)
if unet is not None:
# ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0
if not is_torch_greater_or_equal_than_2_1:
unet.set_attn_processor(AttnProcessor())

# The U-NET time_ids input shapes depend on the value of `requires_aesthetics_score`
# https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L571
unet.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)
unet.config.time_cond_proj_dim = getattr(pipeline.unet.config, "time_cond_proj_dim", None)
unet.config.text_encoder_projection_dim = pipeline.text_encoder.config.projection_dim
unet.config.export_model_type = "unet"
models_for_export["unet"] = unet

# Transformer
transformer = getattr(pipeline, "transformer", None)
if transformer is not None:
# ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0
if not is_torch_greater_or_equal_than_2_1:
transformer.set_attn_processor(AttnProcessor())

transformer.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)
transformer.config.time_cond_proj_dim = getattr(pipeline.transformer.config, "time_cond_proj_dim", None)
transformer.config.text_encoder_projection_dim = pipeline.text_encoder.config.projection_dim
transformer.config.export_model_type = "sd3-transformer"
models_for_export["transformer"] = transformer

# VAE Encoder
vae_encoder = copy.deepcopy(pipeline.vae)

# ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0
if not is_torch_greater_or_equal_than_2_1:
vae_encoder = override_diffusers_2_0_attn_processors(vae_encoder)

# we return the distribution parameters to be able to recreate it in the decoder
vae_encoder.forward = lambda sample: {"latent_parameters": vae_encoder.encode(x=sample)["latent_dist"].parameters}
models_for_export["vae_encoder"] = vae_encoder

# VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600
# VAE Decoder
vae_decoder = copy.deepcopy(pipeline.vae)

# ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0
if not is_torch_greater_or_equal_than_2_1:
vae_decoder = override_diffusers_2_0_attn_processors(vae_decoder)

vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample)
models_for_export["vae_decoder"] = vae_decoder

text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
if text_encoder_2 is not None:
text_encoder_2.config.output_hidden_states = True
text_encoder_2.text_model.config.output_hidden_states = True
models_for_export["text_encoder_2"] = text_encoder_2

return models_for_export


@@ -294,31 +341,58 @@ def get_diffusion_models_for_export(
`Dict[str, Tuple[Union[`PreTrainedModel`, `TFPreTrainedModel`], `ExportConfig`]: A Dict containing the model and
export configs for the different components of the model.
"""

models_for_export = _get_submodels_for_export_diffusion(pipeline)

# Text encoder
if "text_encoder" in models_for_export:
text_encoder_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.text_encoder,
exporter=exporter,
library_name="diffusers",
task="feature-extraction",
model=pipeline.text_encoder, exporter=exporter, library_name="diffusers", task="feature-extraction"
)
text_encoder_export_config = text_encoder_config_constructor(
pipeline.text_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["text_encoder"] = (models_for_export["text_encoder"], text_encoder_export_config)

# Text encoder 2
if "text_encoder_2" in models_for_export:
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.text_encoder_2, exporter=exporter, library_name="diffusers", task="feature-extraction"
)
export_config = export_config_constructor(
pipeline.text_encoder_2.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["text_encoder_2"] = (models_for_export["text_encoder_2"], export_config)

# Text encoder 3
if "text_encoder_3" in models_for_export:
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.text_encoder_3, exporter=exporter, library_name="diffusers", task="feature-extraction"
)
export_config = export_config_constructor(
pipeline.text_encoder_3.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["text_encoder_3"] = (models_for_export["text_encoder_3"], export_config)

# U-NET
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.unet,
exporter=exporter,
library_name="diffusers",
task="semantic-segmentation",
model_type="unet",
)
unet_export_config = export_config_constructor(pipeline.unet.config, int_dtype=int_dtype, float_dtype=float_dtype)
models_for_export["unet"] = (models_for_export["unet"], unet_export_config)
if "unet" in models_for_export:
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.unet, exporter=exporter, library_name="diffusers", task="semantic-segmentation"
)
unet_export_config = export_config_constructor(
pipeline.unet.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["unet"] = (models_for_export["unet"], unet_export_config)

# Transformer
if "transformer" in models_for_export:
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.transformer, exporter=exporter, library_name="diffusers", task="semantic-segmentation"
)
transformer_export_config = export_config_constructor(
pipeline.transformer.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["transformer"] = (models_for_export["transformer"], transformer_export_config)

# VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565
vae_encoder = models_for_export["vae_encoder"]
@@ -344,19 +418,6 @@ def get_diffusion_models_for_export(
vae_export_config = vae_config_constructor(vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype)
models_for_export["vae_decoder"] = (vae_decoder, vae_export_config)

if "text_encoder_2" in models_for_export:
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.text_encoder_2,
exporter=exporter,
library_name="diffusers",
task="feature-extraction",
model_type="clip-text-with-projection",
)
export_config = export_config_constructor(
pipeline.text_encoder_2.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["text_encoder_2"] = (models_for_export["text_encoder_2"], export_config)

return models_for_export


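Taken together, for an SD3 pipeline the two functions in this file should yield one (submodel, export config) pair per component. A rough sketch, assuming diffusers >= 0.30 and this branch (checkpoint id illustrative, exact helper signature assumed from the diff, output not verified):

from diffusers import StableDiffusion3Pipeline
from optimum.exporters.utils import get_diffusion_models_for_export

pipeline = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers")
models_for_export = get_diffusion_models_for_export(pipeline, exporter="onnx")

print(sorted(models_for_export))
# Expected from the code above: ['text_encoder', 'text_encoder_2', 'text_encoder_3',
#  'transformer', 'vae_decoder', 'vae_encoder'], each mapped to (submodel, export_config).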