From e9a4777c3605c650163fe5dca7c9268a9a13b47e Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Mon, 24 Jun 2024 20:45:55 -0700
Subject: [PATCH 1/2] Update model patcher to support torch.export

---
 optimum/exporters/onnx/model_patcher.py | 45 +++++++++++++++++--------
 1 file changed, 31 insertions(+), 14 deletions(-)

diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py
index 0a10534354..29a0eb7d83 100644
--- a/optimum/exporters/onnx/model_patcher.py
+++ b/optimum/exporters/onnx/model_patcher.py
@@ -141,13 +141,18 @@ def __init__(
         allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past
 
+        # Workaround https://github.com/pytorch/pytorch/issues/122649.
+        @torch._dynamo.assume_constant_result
+        def _config_outputs():
+            return config.outputs
+
         @functools.wraps(self.orig_forward)
         def patched_forward(*args, **kwargs):
             signature = inspect.signature(self.orig_forward)
             args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
 
             outputs = self.orig_forward(*args, **kwargs)
-
+            config_outputs = _config_outputs()
             # This code block handles different cases of the filterd_outputs input to align it with the expected
             # format of outputs. It is common for the output type of a model to vary, such as tensor, list,
             # tuple, etc. For Transformers models, the output is encapsulated in a ModelOutput object that
@@ -159,25 +164,25 @@ def patched_forward(*args, **kwargs):
                 for name, value in outputs.items():
                     onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
                     if (
-                        onnx_output_name in config.outputs
+                        onnx_output_name in config_outputs
                         or (allow_past_in_outputs and name.startswith("past_key_values"))
-                        or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
+                        or any(key.startswith(onnx_output_name) for key in config_outputs.keys())
                     ):
                         filterd_outputs[name] = value
             elif isinstance(outputs, (list, tuple)):
-                outputs_list = list(config.outputs.keys())
+                outputs_list = list(config_outputs.keys())
                 dict(zip(outputs_list, outputs))
             else:
-                if len(config.outputs) > 1:
-                    num_outputs = len(config.outputs)
-                    outputs_str = ", ".join(config.outputs.keys())
+                if len(config_outputs) > 1:
+                    num_outputs = len(config_outputs)
+                    outputs_str = ", ".join(config_outputs.keys())
                     raise ValueError(
-                        f"config.outputs should have only one outputs, but it has {num_outputs} keys: {outputs_str}"
+                        f"config_outputs should have only one output, but it has {num_outputs} keys: {outputs_str}"
                     )
                 else:
-                    name = list(config.outputs.keys())[0]
+                    name = list(config_outputs.keys())[0]
                     filterd_outputs[name] = outputs
-            name = list(config.outputs.keys())[0]
+            name = list(config_outputs.keys())[0]
             filterd_outputs[name] = outputs
             return filterd_outputs
@@ -223,21 +228,27 @@ def __init__(
         if model.config.model_type == "pix2struct" and allow_past_in_outputs:
             model.config.text_config.use_cache = True
 
+        # Workaround https://github.com/pytorch/pytorch/issues/122649.
+        @torch._dynamo.assume_constant_result
+        def _config_outputs():
+            return config.outputs
+
         @functools.wraps(self.orig_forward)
         def patched_forward(*args, **kwargs):
             signature = inspect.signature(self.orig_forward)
             args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
 
             outputs = self.orig_forward(*args, **kwargs)
+            config_outputs = _config_outputs()
 
             # Filter out cross attention past key values output from the decoder using KV cache, as they are constants.
             filterd_outputs = {}
             for name, value in outputs.items():
                 onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
                 if (
-                    onnx_output_name in config.outputs
+                    onnx_output_name in config_outputs
                     or (allow_past_in_outputs and name.startswith("past_key_values"))
-                    or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
+                    or any(key.startswith(onnx_output_name) for key in config_outputs.keys())
                 ):
                     if name != "past_key_values":
                         if self.real_config._behavior == "decoder" and name == "encoder_last_hidden_state":
@@ -473,6 +484,11 @@ def __init__(
         allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past
 
+        # Workaround https://github.com/pytorch/pytorch/issues/122649.
+        @torch._dynamo.assume_constant_result
+        def _config_outputs():
+            return config.outputs
+
         @functools.wraps(self.orig_forward)
         def patched_forward(*args, **kwargs):
             model_kwargs = self.model_kwargs
@@ -484,14 +500,15 @@ def patched_forward(*args, **kwargs):
             args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=model_kwargs)
 
             outputs = self.orig_forward(*args, **kwargs)
+            config_outputs = _config_outputs()
 
             filterd_outputs = {}
             for name, value in outputs.items():
                 onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
                 if (
-                    onnx_output_name in config.outputs
+                    onnx_output_name in config_outputs
                     or (allow_past_in_outputs and name.startswith("past_key_values"))
-                    or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
+                    or any(key.startswith(onnx_output_name) for key in config_outputs.keys())
                 ):
                     filterd_outputs[name] = value
             return filterd_outputs

From 7e0597b80c0ef501c1d15ab42c4acc69031477a2 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 3 Jul 2024 15:06:26 -0700
Subject: [PATCH 2/2] Use torch

---
 optimum/exporters/onnx/__main__.py | 2 ++
 optimum/exporters/onnx/convert.py  | 2 +-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py
index 585a779c2e..dfcd2adbe9 100644
--- a/optimum/exporters/onnx/__main__.py
+++ b/optimum/exporters/onnx/__main__.py
@@ -34,6 +34,8 @@
 
 if is_torch_available():
     import torch
+    import torch_onnx
+    torch_onnx.patch_torch(error_report=True, profile=True, dump_exported_program=True)
 
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
 
diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py
index 4d5a2afc37..5f50340a94 100644
--- a/optimum/exporters/onnx/convert.py
+++ b/optimum/exporters/onnx/convert.py
@@ -580,7 +580,7 @@ def remap(value):
             f=output.as_posix(),
             input_names=input_names,
             output_names=output_names,
-            dynamic_axes=dynamix_axes,
+            # dynamic_axes=dynamix_axes,
             do_constant_folding=do_constant_folding,
             opset_version=opset,
         )
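
Reviewer note (not part of the patches): the three _config_outputs blocks in PATCH 1/2 all apply the same workaround. torch._dynamo.assume_constant_result tells the tracer to treat the wrapped function's return value as a trace-time constant, so the lookup into the ONNX config's output map is not traced (sidestepping pytorch/pytorch#122649). A minimal, self-contained sketch of the pattern, with made-up model and output names and torch.compile standing in for the torch.export path:

    import torch

    # Hypothetical stand-in for config.outputs (an OnnxConfig attribute in optimum).
    ONNX_OUTPUTS = {"logits": {0: "batch_size"}}

    @torch._dynamo.assume_constant_result
    def _config_outputs():
        # The tracer bakes this return value in as a constant instead of
        # tracing the dictionary access.
        return ONNX_OUTPUTS

    class TinyModel(torch.nn.Module):
        def forward(self, x):
            outputs = {"logits": x * 2, "hidden_states": x + 1}
            config_outputs = _config_outputs()
            # Mirrors patched_forward: keep only outputs named in the ONNX config.
            return {name: value for name, value in outputs.items() if name in config_outputs}

    compiled = torch.compile(TinyModel(), fullgraph=True)
    print(compiled(torch.randn(2, 4)))  # only "logits" survives the filter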
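On PATCH 2/2: as I read it, torch_onnx.patch_torch() monkey-patches torch.onnx.export so that the existing call site in convert.py is routed through the torch.export-based exporter with no further changes; judging by their names, the three flags enable an error report, profiling, and dumping of the ExportedProgram. That would also explain why dynamic_axes is commented out rather than passed through (the pre-existing dynamix_axes spelling in convert.py is kept as-is). A usage sketch, assuming the torch-onnx package is installed:

    import torch
    import torch_onnx

    # Flags copied verbatim from the diff; after this call, plain
    # torch.onnx.export goes through the torch.export-based exporter.
    torch_onnx.patch_torch(error_report=True, profile=True, dump_exported_program=True)

    model = torch.nn.Linear(4, 2)
    torch.onnx.export(model, (torch.randn(1, 4),), "linear.onnx")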