From deac6885b5e36332658a88264e9b272fdb55147a Mon Sep 17 00:00:00 2001 From: Petros Toupas Date: Sat, 14 Dec 2024 00:34:53 +0900 Subject: [PATCH] Add ONNXModifier for optimising the ONNX model before converting for RVC4 execution (#55) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Martin Kozlovský --- .github/workflows/unittests.yaml | 1 + modelconverter/cli/utils.py | 1 + modelconverter/packages/base_exporter.py | 1 + modelconverter/packages/rvc4/exporter.py | 20 + modelconverter/utils/__init__.py | 3 +- modelconverter/utils/config.py | 1 + modelconverter/utils/onnx_tools.py | 1001 +++++++++++++++++++++- requirements.txt | 3 + tests/test_utils/conftest.py | 11 + tests/test_utils/test_config.py | 1 + tests/test_utils/test_modifier.py | 240 ++++++ 11 files changed, 1278 insertions(+), 5 deletions(-) create mode 100644 tests/test_utils/conftest.py create mode 100644 tests/test_utils/test_modifier.py diff --git a/.github/workflows/unittests.yaml b/.github/workflows/unittests.yaml index c48b3b9..39fbfcb 100644 --- a/.github/workflows/unittests.yaml +++ b/.github/workflows/unittests.yaml @@ -33,5 +33,6 @@ jobs: AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} AWS_S3_ENDPOINT_URL: ${{ secrets.AWS_S3_ENDPOINT_URL }} GOOGLE_APPLICATION_CREDENTIALS: ${{ secrets.GCP_CREDENTIALS }} + HUB_AI_API_KEY: ${{ secrets.HUB_AI_API_KEY }} run: python -m pytest tests/test_utils diff --git a/modelconverter/cli/utils.py b/modelconverter/cli/utils.py index 7c27c6e..ceedd6a 100644 --- a/modelconverter/cli/utils.py +++ b/modelconverter/cli/utils.py @@ -376,6 +376,7 @@ def get_target_specific_options( json_cfg = cfg.model_dump(mode="json") options = { "disable_onnx_simplification": cfg.disable_onnx_simplification, + "disable_onnx_optimisation": cfg.disable_onnx_optimisation, "inputs": json_cfg["inputs"], } if target == "rvc4": diff --git a/modelconverter/packages/base_exporter.py b/modelconverter/packages/base_exporter.py index 6ef47a7..d98b4b7 100644 --- a/modelconverter/packages/base_exporter.py +++ b/modelconverter/packages/base_exporter.py @@ -39,6 +39,7 @@ def __init__( self.outputs = {out.name: out for out in config.outputs} self.keep_intermediate_outputs = config.keep_intermediate_outputs self.disable_onnx_simplification = config.disable_onnx_simplification + self.disable_onnx_optimisation = config.disable_onnx_optimisation self.model_name = self.input_model.stem diff --git a/modelconverter/packages/rvc4/exporter.py b/modelconverter/packages/rvc4/exporter.py index ae212b6..fd5f94e 100644 --- a/modelconverter/packages/rvc4/exporter.py +++ b/modelconverter/packages/rvc4/exporter.py @@ -1,3 +1,4 @@ +import os import shutil import subprocess import time @@ -6,6 +7,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, cast from modelconverter.utils import ( + ONNXModifier, exit_with, onnx_attach_normalization_to_inputs, read_image, @@ -57,6 +59,24 @@ def __init__(self, config: SingleStageConfig, output_dir: Path): self._attach_suffix(self.input_model, "modified.onnx"), self.inputs, ) + + if not config.disable_onnx_optimisation: + onnx_modifier = ONNXModifier( + model_path=self.input_model, + output_path=self._attach_suffix( + self.input_model, "modified_optimised.onnx" + ), + ) + + if ( + onnx_modifier.modify_onnx() + and onnx_modifier.compare_outputs() + ): + logger.info("ONNX model has been optimised for RVC4.") + shutil.move(onnx_modifier.output_path, self.input_model) + else: + if os.path.exists(onnx_modifier.output_path): + 
os.remove(onnx_modifier.output_path) else: logger.warning( "Input file type is not ONNX. Skipping pre-processing." diff --git a/modelconverter/utils/__init__.py b/modelconverter/utils/__init__.py index 6688f0f..9566125 100644 --- a/modelconverter/utils/__init__.py +++ b/modelconverter/utils/__init__.py @@ -28,7 +28,7 @@ modelconverter_config_to_nn, process_nn_archive, ) -from .onnx_tools import onnx_attach_normalization_to_inputs +from .onnx_tools import ONNXModifier, onnx_attach_normalization_to_inputs from .subprocess import subprocess_run __all__ = [ @@ -37,6 +37,7 @@ "S3Exception", "SubprocessException", "exit_with", + "ONNXModifier", "onnx_attach_normalization_to_inputs", "read_calib_dir", "read_image", diff --git a/modelconverter/utils/config.py b/modelconverter/utils/config.py index e617434..008362d 100644 --- a/modelconverter/utils/config.py +++ b/modelconverter/utils/config.py @@ -284,6 +284,7 @@ class SingleStageConfig(CustomBaseModel): keep_intermediate_outputs: bool = True disable_onnx_simplification: bool = False + disable_onnx_optimisation: bool = False output_remote_url: Optional[str] = None put_file_plugin: Optional[str] = None diff --git a/modelconverter/utils/onnx_tools.py b/modelconverter/utils/onnx_tools.py index 98138a8..fe01ab5 100644 --- a/modelconverter/utils/onnx_tools.py +++ b/modelconverter/utils/onnx_tools.py @@ -1,10 +1,13 @@ import logging from pathlib import Path -from typing import Dict +from typing import Dict, List, Optional, Tuple +import numpy as np import onnx +import onnx_graphsurgeon as gs +import onnxoptimizer from onnx import checker, helper -from onnx.onnx_pb import TensorProto +from onnxsim import simplify from modelconverter.utils.config import InputConfig @@ -37,6 +40,7 @@ def onnx_attach_normalization_to_inputs( for input_tensor in graph.input: input_name = input_tensor.name + input_dtype = input_tensor.type.tensor_type.elem_type if input_name not in input_configs: continue cfg = input_configs[input_name] @@ -109,7 +113,7 @@ def onnx_attach_normalization_to_inputs( mean_tensor = helper.make_tensor( f"mean_{input_name}", - TensorProto.FLOAT, + input_dtype, [1, len(cfg.mean_values), 1, 1] if layout == "NCHW" else [1, 1, 1, len(cfg.mean_values)], @@ -137,7 +141,7 @@ def onnx_attach_normalization_to_inputs( scale_tensor = helper.make_tensor( f"scale_{input_name}", - TensorProto.FLOAT, + input_dtype, [1, len(cfg.scale_values), 1, 1] if layout == "NCHW" else [1, 1, 1, len(cfg.scale_values)], @@ -176,3 +180,992 @@ def onnx_attach_normalization_to_inputs( onnx.save(model, str(save_path)) return save_path + + +class ONNXModifier: + """ONNX model modifier class to optimize and modify the ONNX model. 
+ + Attributes: + ---------- + model_path : Path + Path to the base ONNX model + output_path : Path + Path to save the modified ONNX model + """ + + def __init__(self, model_path: Path, output_path: Path) -> None: + self.model_path = model_path + self.output_path = output_path + self.load_onnx() + self.prev_onnx_model = self.onnx_model + self.prev_onnx_gs = self.onnx_gs + + def load_onnx(self) -> None: + """Load the ONNX model and store it as onnx.ModelProto and + onnx_graphsurgeon.GraphSurgeon graph.""" + + logger.info(f"Loading model: {self.model_path.stem}") + + self.onnx_model, _ = simplify( + self.model_path.as_posix(), perform_optimization=True + ) + + self.dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[ + self.onnx_model.graph.input[0].type.tensor_type.elem_type + ] + self.input_shape = [ + dim.dim_value + for dim in self.onnx_model.graph.input[ + 0 + ].type.tensor_type.shape.dim + ] + self.has_dynamic_shape = any( + dim == 0 or dim is None for dim in self.input_shape + ) + + self.onnx_gs = gs.import_onnx(self.onnx_model) + + def optimize_onnx(self, passes: Optional[List[str]] = None) -> None: + """Optimize and simplify the ONNX model's graph. + + @param passes: List of optimization passes to apply to the ONNX model + @type passes: Optional[List[str]] + """ + + optimised_onnx_model = onnxoptimizer.optimize( + self.onnx_model, passes=passes + ) + + optimised_onnx_model, _ = simplify( + optimised_onnx_model, perform_optimization=False + ) + + onnx.checker.check_model(optimised_onnx_model) + + self.onnx_model, self.onnx_gs = ( + optimised_onnx_model, + gs.import_onnx(optimised_onnx_model), + ) + + def export_onnx(self, passes: Optional[List[str]] = None) -> None: + """Export the modified ONNX model to the output path. + + @param passes: List of optimization passes to apply to the ONNX model + @type passes: Optional[List[str]] + """ + + self.optimize_onnx(passes) + + onnx.save(self.onnx_model, self.output_path) + + def add_outputs(self, output_names: List[str]) -> None: + """Add output nodes to the ONNX model. + + @param output_names: List of output node names to add to the ONNX model + @type output_names: List[str] + """ + + graph_outputs = [output.name for output in self.onnx_gs.outputs] + for name, tensor in self.onnx_gs.tensors().items(): + if name in output_names and name not in graph_outputs: + self.onnx_gs.outputs.append(tensor) + self.onnx_model = gs.export_onnx(self.onnx_gs) + + def get_constant_map(self, graph: gs.Graph) -> Dict[str, np.ndarray]: + """Extract constant tensors from the GraphSurgeon graph. + + @param graph: GraphSurgeon graph + @type graph: gs.Graph + @return: Constant tensor map with tensor name as key and tensor value as value + @rtype: Dict[str, np.ndarray] + """ + + return { + tensor.name: tensor.values + for tensor in graph.tensors().values() + if isinstance(tensor, gs.Constant) + } + + @staticmethod + def get_constant_value( + node: gs.Node, constant_map: Dict[str, np.ndarray] + ) -> Optional[Tuple[np.ndarray, int]]: + """Returns the constant value of a node if it is a constant node. 
+ + @param node: Node to check + @type node: gs.Node + @param constant_map: Constant tensor map with tensor name as key and tensor + value as value + @type constant_map: Dict[str, np.ndarray] + @return: Constant tensor value and index + @rtype: Optional[Tuple[np.ndarray, int]] + """ + + for idx, input in enumerate(node.inputs): + if input.name in constant_map: + return (constant_map[input.name], idx) + + return None + + @staticmethod + def get_variable_input(node: gs.Node) -> Optional[Tuple[gs.Variable, int]]: + """Returns the variable input of a node. + + @param node: Node to check + @type node: gs.Node + @return: Variable input and index + @rtype: Optional[Tuple[gs.Variable, int]] + """ + + for idx, input in enumerate(node.inputs): + if isinstance(input, gs.Variable): + return (input, idx) + + return None + + def graph_cleanup( + self, + nodes_to_add: List[gs.Node], + nodes_to_remove: List[gs.Node], + connections_to_fix: List[Tuple[gs.Variable, gs.Variable]], + ) -> None: + """Cleanup the graph by adding new nodes, removing old nodes, and fixing + connections. + + @param nodes_to_add: List of nodes to add to the graph + @type nodes_to_add: List[gs.Node] + @param nodes_to_remove: List of nodes to remove from the graph + @type nodes_to_remove: List[gs.Node] + @param connections_to_fix: List of connections to fix in the graph + @type connections_to_fix: List[Tuple[gs.Variable, gs.Variable]] + """ + + for node in nodes_to_add: + self.onnx_gs.nodes.append(node) + + for old_input, new_input in connections_to_fix: + for node in self.onnx_gs.nodes: + for idx, input in enumerate(node.inputs): + if input == old_input: + node.inputs[idx] = new_input + + for node in nodes_to_remove: + self.onnx_gs.nodes.remove(node) + + self.onnx_gs.cleanup( + remove_unused_node_outputs=True, remove_unused_graph_inputs=True + ).toposort() + + def substitute_node_by_type( + self, source_node: str, target_node: str + ) -> None: + """Substitute a source node of a particular type with a target node of a + different type. Currently, only Sub -> Add and Div -> Mul substitutions are + allowed. + + @param source_node: Source node type to substitute + @type source_node: str + @param target_node: Target node type to substitute with + @type target_node: str + """ + + if source_node not in ["Sub", "Div"] or target_node not in [ + "Add", + "Mul", + ]: + raise ValueError( + "Invalid source or target node type. Valid source types: Sub, Div. Valid target types: Add, Mul." + ) + + if ( + source_node == "Sub" + and target_node == "Mul" + or source_node == "Div" + and target_node == "Add" + ): + raise ValueError( + "Invalid substitution. 
Available substitutions: Sub -> Add, Div -> Mul"
+            )
+
+        constant_map = self.get_constant_map(self.onnx_gs)
+
+        def create_new_node(
+            node: gs.Node, target_node: str, const_idx: int
+        ) -> Optional[gs.Node]:
+            if const_idx == 0:
+                return None
+
+            first_input = node.inputs[0]
+            second_input = node.inputs[const_idx]
+            if target_node == "Add":
+                new_const_val = -second_input.values
+                return gs.Node(
+                    op="Add",
+                    inputs=[
+                        first_input,
+                        gs.Constant(
+                            name=f"{second_input.name}/Substitute",
+                            values=np.array(
+                                new_const_val, dtype=second_input.dtype
+                            ),
+                        ),
+                    ],
+                    outputs=[gs.Variable(name=f"{node.name}/Add_output")],
+                    name=f"{node.name}/To_Add",
+                )
+            elif target_node == "Mul":
+                if second_input.dtype not in [
+                    np.float16,
+                    np.float32,
+                    np.float64,
+                ]:
+                    return None
+                new_const_val = 1.0 / second_input.values
+                return gs.Node(
+                    op="Mul",
+                    inputs=[
+                        first_input,
+                        gs.Constant(
+                            name=f"{second_input.name}/Substitute",
+                            values=np.array(
+                                new_const_val, dtype=second_input.dtype
+                            ),
+                        ),
+                    ],
+                    outputs=[gs.Variable(name=f"{node.name}/Mul_output")],
+                    name=f"{node.name}/To_Mul",
+                )
+
+        nodes_to_add = []
+        nodes_to_remove = []
+        connections_to_fix = []
+
+        for node in self.onnx_gs.nodes:
+            if node.op == source_node:
+                constant = self.get_constant_value(node, constant_map)
+                if constant is not None:
+                    _, const_idx = constant
+                    new_node = create_new_node(node, target_node, const_idx)
+                    if new_node is not None:
+                        nodes_to_add.append(new_node)
+                        connections_to_fix.append(
+                            (
+                                node.outputs[0],
+                                new_node.outputs[0],
+                            )
+                        )
+                        nodes_to_remove.append(node)
+
+        self.graph_cleanup(nodes_to_add, nodes_to_remove, connections_to_fix)
+        self.onnx_model = gs.export_onnx(self.onnx_gs)
+
+        self.optimize_onnx(passes=["fuse_add_bias_into_conv"])
+
+    def fuse_add_mul_to_bn(self) -> None:
+        """Fuse Add/Sub and Mul nodes that come immediately after a Conv node into a
+        BatchNormalization node.
+
+        The fusion patterns considered are:
+        1. Conv -> Add -> Mul
+        2. Conv -> Mul -> Add
+        3. Conv -> Mul
+        4. Conv -> Add
+        """
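+
+        # A note on the rewrite target: with mean=0 and var=1, ONNX
+        # BatchNormalization computes s*(x - 0)/sqrt(1 + eps) + b, i.e.
+        # approximately s*x + b; the default eps of 1e-5 keeps the result
+        # within the 5e-3 tolerance used by compare_outputs().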
+
+        FUSION_PATTERNS = [
+            ("Conv", "Add", "Mul"),
+            ("Conv", "Mul", "Add"),
+            ("Conv", "Mul"),
+            ("Conv", "Add"),
+        ]
+
+        constant_map = self.get_constant_map(self.onnx_gs)
+
+        def create_batch_norm_node(
+            name: str, input_tensor: gs.Variable, scale: float, bias: float
+        ) -> gs.Node:
+            conv_channels = input_tensor.shape[1]
+            scale_values = np.array(
+                [scale] * conv_channels, dtype=self.dtype
+            ).squeeze()
+            bias_values = np.array(
+                [bias] * conv_channels, dtype=self.dtype
+            ).squeeze()
+            mean_values = np.zeros_like(scale_values)
+            var_values = np.ones_like(scale_values)
+            scale_tensor = gs.Constant(
+                name=f"{name}_scale",
+                values=scale_values,
+            )
+            bias_tensor = gs.Constant(
+                name=f"{name}_bias",
+                values=bias_values,
+            )
+            mean_tensor = gs.Constant(
+                name=f"{name}_mean",
+                values=mean_values,
+            )
+            var_tensor = gs.Constant(
+                name=f"{name}_var",
+                values=var_values,
+            )
+            bn_node = gs.Node(
+                op="BatchNormalization",
+                inputs=[
+                    input_tensor,
+                    scale_tensor,
+                    bias_tensor,
+                    mean_tensor,
+                    var_tensor,
+                ],
+                outputs=[gs.Variable(name=f"{name}_output")],
+                name=name,
+            )
+            return bn_node
+
+        all_sequences = []
+
+        for pattern in FUSION_PATTERNS:
+            for node in self.onnx_gs.nodes:
+                if node.op != pattern[0]:
+                    continue
+
+                sequence = [node]
+                current_node = node
+                for op_type in pattern[1:]:
+                    next_nodes = [
+                        n
+                        for n in self.onnx_gs.nodes
+                        if n.inputs
+                        and current_node.outputs[0] in n.inputs
+                        and n.op == op_type
+                    ]
+                    if not next_nodes:
+                        break
+                    current_node = next_nodes[0]
+                    sequence.append(current_node)
+
+                if len(sequence) == len(pattern):
+                    all_sequences.append(sequence)
+
+        longest_sequences = []
+        for seq in all_sequences:
+            is_subset = any(
+                all(node in longer_seq for node in seq)
+                and len(seq) < len(longer_seq)
+                for longer_seq in all_sequences
+            )
+            if not is_subset:
+                longest_sequences.append(seq)
+
+        nodes_to_add = []
+        nodes_to_remove = []
+        connections_to_fix = []
+
+        for sequence in longest_sequences:
+            valid_fusion = True
+            scale, bias = 1.0, 0.0
+
+            conv_node = None
+            for seq_node in sequence:
+                if seq_node.op == "Conv":
+                    conv_node = seq_node
+                    continue
+
+                constant = self.get_constant_value(seq_node, constant_map)
+                if constant is None:
+                    valid_fusion = False
+                    break
+
+                constant_val, _ = constant
+
+                if seq_node.op == "Add":
+                    bias += constant_val
+                elif seq_node.op == "Sub":
+                    bias -= constant_val
+                elif seq_node.op == "Mul":
+                    scale *= constant_val
+                    # A Mul that follows the Add/Sub also scales the bias
+                    # accumulated so far: (x + b) * s == x*s + b*s.
+                    bias *= constant_val
+
+            if (
+                not valid_fusion
+                or not conv_node
+                or len(conv_node.outputs[0].outputs) > 1
+            ):
+                continue
+
+            bn_name = f"BatchNorm_{conv_node.name.replace('/', '', 1)}"
+
+            bn_node = create_batch_norm_node(
+                bn_name, conv_node.outputs[0], scale, bias
+            )
+            nodes_to_add.append(bn_node)
+
+            if sequence[0].op == "Conv":
+                connections_to_fix.append(
+                    (
+                        sequence[-1].outputs[0],
+                        bn_node.outputs[0],
+                    )
+                )
+
+            for seq_node in sequence:
+                if seq_node.op != "Conv":
+                    nodes_to_remove.append(seq_node)
+
+        self.graph_cleanup(nodes_to_add, nodes_to_remove, connections_to_fix)
+        self.onnx_model = gs.export_onnx(self.onnx_gs)
+
+        self.optimize_onnx(passes=["fuse_bn_into_conv"])
+
+    def fuse_single_add_mul_to_conv(self) -> None:
+        """Fuse Add and Mul nodes that precede a Conv node directly into the Conv
+        node."""
+
+        nodes_to_remove = []
+        connections_to_fix = []
+
+        constant_map = self.get_constant_map(self.onnx_gs)
+
+        for node in self.onnx_gs.nodes:
+            if node.op == "Mul":
+                mul_node = node
+                if len(mul_node.outputs[0].outputs) > 1:
+                    continue
+
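+                # A Mul feeding a Conv can be folded into the weights:
+                # Conv(x * m) == Conv'(x) with W' = W * m, where m broadcasts
+                # over the input-channel axis of W.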
+                conv_node = next(
+                    (n for n in mul_node.outputs[0].outputs if n.op == "Conv"),
+                    None,
+                )
+                if conv_node is None:
+                    continue
+
+                constant = self.get_constant_value(mul_node, constant_map)
+                if constant is None:
+                    continue
+
+                mul_value, _ = constant
+
+                conv_weights = conv_node.inputs[1]
+
+                new_weights = conv_weights.values * mul_value
+
+                conv_node.inputs[1] = gs.Constant(
+                    name=conv_weights.name,
+                    values=new_weights,
+                )
+
+                nodes_to_remove.append(mul_node)
+
+                connections_to_fix.append(
+                    (
+                        mul_node.outputs[0],
+                        mul_node.inputs[0],
+                    )
+                )
+
+            if node.op == "Add":
+                add_node = node
+                if len(add_node.outputs[0].outputs) > 1:
+                    continue
+
+                conv_node = next(
+                    (n for n in add_node.outputs[0].outputs if n.op == "Conv"),
+                    None,
+                )
+                if (
+                    conv_node is None
+                    or (
+                        "pads" in conv_node.attrs
+                        and any(conv_node.attrs["pads"])
+                    )
+                    or (
+                        "auto_pad" in conv_node.attrs
+                        and conv_node.attrs["auto_pad"]
+                        in ["SAME_UPPER", "SAME_LOWER"]
+                    )
+                ):
+                    continue
+
+                constant = self.get_constant_value(add_node, constant_map)
+                if constant is None:
+                    continue
+
+                add_value, _ = constant
+
+                conv_weights = conv_node.inputs[1]
+                conv_bias = (
+                    conv_node.inputs[2] if len(conv_node.inputs) > 2 else None
+                )
+
+                if conv_bias is not None:
+                    new_bias = conv_bias.values + np.sum(
+                        add_value * conv_weights.values, axis=(1, 2, 3)
+                    )
+                    if new_bias.shape != conv_bias.values.shape:
+                        raise ValueError(
+                            f"New bias shape: {new_bias.shape} != Old bias shape: {conv_bias.values.shape}"
+                        )
+                else:
+                    new_bias = np.sum(
+                        add_value * conv_weights.values, axis=(1, 2, 3)
+                    )
+                    if new_bias.shape[0] != conv_weights.shape[0]:
+                        raise ValueError(
+                            f"New bias shape: {new_bias.shape} != Conv weights shape: {conv_weights.shape[0]}"
+                        )
+
+                if conv_bias is not None:
+                    conv_node.inputs[2] = gs.Constant(
+                        name=conv_bias.name,
+                        values=new_bias,
+                    )
+                else:
+                    conv_node.inputs.append(
+                        gs.Constant(
+                            name=f"{conv_node.name}_bias",
+                            values=new_bias,
+                        )
+                    )
+
+                nodes_to_remove.append(add_node)
+
+                connections_to_fix.append(
+                    (
+                        add_node.outputs[0],
+                        add_node.inputs[0],
+                    )
+                )
+
+        self.graph_cleanup([], nodes_to_remove, connections_to_fix)
+        self.onnx_model = gs.export_onnx(self.onnx_gs)
+
+        self.optimize_onnx()
+
+    def fuse_comb_add_mul_to_conv(self) -> None:
+        """Fuse combinations of Add and Mul nodes preceding a Conv node directly into
+        the Conv node itself.
+
+        The fusion patterns considered are:
+        1. Add -> Mul -> Conv
+        2. Mul -> Add -> Conv
+        """
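+
+        # In both orders the Conv weights absorb the multiplicative constant
+        # (W' = W * m) and the bias absorbs the convolved additive constant;
+        # for Mul -> Add -> Conv the additive term enters after the scaling,
+        # while for Add -> Mul -> Conv it is scaled by m first.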
+
+        nodes_to_remove = []
+        connections_to_fix = []
+
+        constant_map = self.get_constant_map(self.onnx_gs)
+
+        for node in self.onnx_gs.nodes:
+            if node.op == "Mul":
+                mul_node = node
+
+                add_node = next(
+                    (n for n in mul_node.outputs[0].outputs if n.op == "Add"),
+                    None,
+                )
+                if add_node is None:
+                    continue
+
+                conv_node = next(
+                    (n for n in add_node.outputs[0].outputs if n.op == "Conv"),
+                    None,
+                )
+                if (
+                    conv_node is None
+                    or (
+                        "pads" in conv_node.attrs
+                        and any(conv_node.attrs["pads"])
+                    )
+                    or (
+                        "auto_pad" in conv_node.attrs
+                        and conv_node.attrs["auto_pad"]
+                        in ["SAME_UPPER", "SAME_LOWER"]
+                    )
+                ):
+                    continue
+
+                constant = self.get_constant_value(mul_node, constant_map)
+                if constant is None:
+                    continue
+                mul_value, _ = constant
+
+                constant = self.get_constant_value(add_node, constant_map)
+                if constant is None:
+                    continue
+                add_value, _ = constant
+
+                conv_weights = conv_node.inputs[1]
+                conv_bias = (
+                    conv_node.inputs[2] if len(conv_node.inputs) > 2 else None
+                )
+
+                new_weights = conv_weights.values * mul_value
+
+                conv_node.inputs[1] = gs.Constant(
+                    name=conv_weights.name,
+                    values=new_weights,
+                )
+
+                if conv_bias is not None:
+                    new_bias = conv_bias.values + np.sum(
+                        add_value * conv_weights.values, axis=(1, 2, 3)
+                    )
+                    if new_bias.shape != conv_bias.values.shape:
+                        raise ValueError(
+                            f"New bias shape: {new_bias.shape} != Old bias shape: {conv_bias.values.shape}"
+                        )
+                    conv_node.inputs[2] = gs.Constant(
+                        name=conv_bias.name,
+                        values=new_bias,
+                    )
+                else:
+                    new_bias = np.sum(
+                        add_value * conv_weights.values, axis=(1, 2, 3)
+                    )
+                    if new_bias.shape[0] != conv_weights.shape[0]:
+                        raise ValueError(
+                            f"New bias shape: {new_bias.shape} != Conv weights shape: {conv_weights.shape[0]}"
+                        )
+                    conv_node.inputs.append(
+                        gs.Constant(
+                            name=f"{conv_node.name}_bias",
+                            values=new_bias,
+                        )
+                    )
+
+                variable = self.get_variable_input(mul_node)
+                if variable is None:
+                    continue
+                _, mul_idx = variable
+
+                nodes_to_remove.append(mul_node)
+                nodes_to_remove.append(add_node)
+
+                connections_to_fix.append(
+                    (
+                        add_node.outputs[0],
+                        mul_node.inputs[mul_idx],
+                    )
+                )
+
+            if node.op == "Add":
+                add_node = node
+
+                mul_node = next(
+                    (n for n in add_node.outputs[0].outputs if n.op == "Mul"),
+                    None,
+                )
+                if mul_node is None:
+                    continue
+
+                conv_node = next(
+                    (n for n in mul_node.outputs[0].outputs if n.op == "Conv"),
+                    None,
+                )
+                if (
+                    conv_node is None
+                    or (
+                        "pads" in conv_node.attrs
+                        and any(conv_node.attrs["pads"])
+                    )
+                    or (
+                        "auto_pad" in conv_node.attrs
+                        and conv_node.attrs["auto_pad"]
+                        in ["SAME_UPPER", "SAME_LOWER"]
+                    )
+                ):
+                    continue
+
+                constant = self.get_constant_value(add_node, constant_map)
+                if constant is None:
+                    continue
+                add_value, _ = constant
+
+                constant = self.get_constant_value(mul_node, constant_map)
+                if constant is None:
+                    continue
+                mul_value, _ = constant
+
+                # Avoid *= here so the shared constant tensor stored in the
+                # graph is not mutated in place.
+                add_value = add_value * mul_value
+
+                conv_weights = conv_node.inputs[1]
+                conv_bias = (
+                    conv_node.inputs[2] if len(conv_node.inputs) > 2 else None
+                )
+
+                if conv_bias is not None:
+                    new_bias = conv_bias.values + np.sum(
+                        add_value * conv_weights.values, axis=(1, 2, 3)
+                    )
+                    if new_bias.shape != conv_bias.values.shape:
+                        raise ValueError(
+                            f"New bias shape: {new_bias.shape} != Old bias shape: {conv_bias.values.shape}"
+                        )
+                    conv_node.inputs[2] = gs.Constant(
+                        name=conv_bias.name,
+                        values=new_bias,
+                    )
+                else:
+                    new_bias = np.sum(
+                        add_value * conv_weights.values, axis=(1, 2, 3)
+                    )
+                    if new_bias.shape[0] != conv_weights.shape[0]:
+                        raise ValueError(
+                            f"New bias shape: {new_bias.shape} != Conv weights shape: {conv_weights.shape[0]}"
+                        )
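+                    # The Conv has no bias input yet, so the folded constant
+                    # is appended as a new bias initializer.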
+                    conv_node.inputs.append(
+                        gs.Constant(
+                            name=f"{conv_node.name}_bias",
+                            values=new_bias,
+                        )
+                    )
+
+                new_weights = conv_weights.values * mul_value
+
+                conv_node.inputs[1] = gs.Constant(
+                    name=conv_weights.name,
+                    values=new_weights,
+                )
+
+                variable = self.get_variable_input(add_node)
+                if variable is None:
+                    continue
+                _, add_idx = variable
+
+                nodes_to_remove.append(add_node)
+                nodes_to_remove.append(mul_node)
+
+                connections_to_fix.append(
+                    (
+                        mul_node.outputs[0],
+                        add_node.inputs[add_idx],
+                    )
+                )
+
+        self.graph_cleanup([], nodes_to_remove, connections_to_fix)
+        self.onnx_model = gs.export_onnx(self.onnx_gs)
+
+        self.optimize_onnx()
+
+    def fuse_split_concat_to_conv(self) -> None:
+        """Fuse Split and Concat nodes that come before a Conv node into the Conv node.
+
+        This handles the channel-reversal preprocessing pattern (e.g. RGB <-> BGR):
+        the reversal is absorbed by flipping the Conv weights (and any intermediate
+        per-channel constants) along the channel axis.
+        """
+
+        nodes_to_remove = []
+        connections_to_fix = []
+
+        for node in self.onnx_gs.nodes:
+            if node.op == "Conv":
+                break
+
+            if node.op == "Split":
+                split_node = node
+
+                concat_node = next(
+                    (
+                        n
+                        for n in split_node.outputs[0].outputs
+                        if n.op == "Concat"
+                    ),
+                    None,
+                )
+                if concat_node is None:
+                    continue
+
+                intermediate_nodes = []
+                current_node = concat_node
+                while current_node.op != "Conv":
+                    current_node = next(
+                        (n for n in current_node.outputs[0].outputs), None
+                    )
+                    if current_node is None:
+                        break
+                    intermediate_nodes.append(current_node)
+
+                if not intermediate_nodes:
+                    continue
+                conv_node = intermediate_nodes[-1]
+                if conv_node.op != "Conv":
+                    continue
+
+                conv_weights = conv_node.inputs[1]
+
+                if split_node.attrs["axis"] != concat_node.attrs["axis"]:
+                    raise ValueError(
+                        f"Split and Concat axis mismatch: {split_node.attrs['axis']} != {concat_node.attrs['axis']}"
+                    )
+
+                channels_axis = split_node.attrs["axis"]
+                if conv_weights.shape[channels_axis] not in [1, 3]:
+                    break
+
+                for inter_node in intermediate_nodes[:-1]:
+                    constant = self.get_constant_value(
+                        inter_node, self.get_constant_map(self.onnx_gs)
+                    )
+                    if constant is None:
+                        continue
+                    constant_value, constant_idx = constant
+                    if constant_value.ndim == 1:
+                        continue
+
+                    if (
+                        constant_value.shape[channels_axis]
+                        != conv_weights.values.shape[1]
+                    ):
+                        logger.warning(
+                            f"Channel dimension mismatch between Conv and intermediate node {inter_node.name}: {constant_value.shape[channels_axis]} != {conv_weights.values.shape[1]}, skipping this node."
+                        )
+                        continue
+
+                    inter_node.inputs[constant_idx].values = np.flip(
+                        constant_value, axis=channels_axis
+                    )
+
+                conv_weights.values = np.flip(
+                    conv_weights.values, axis=channels_axis
+                )
+
+                nodes_to_remove.append(split_node)
+                nodes_to_remove.append(concat_node)
+
+                connections_to_fix.append(
+                    (
+                        concat_node.outputs[0],
+                        split_node.inputs[0],
+                    )
+                )
+
+                break
+
+        self.graph_cleanup([], nodes_to_remove, connections_to_fix)
+        self.onnx_model = gs.export_onnx(self.onnx_gs)
+
+        self.optimize_onnx()
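+
+    # The rewrites above are applied one at a time by modify_onnx() below;
+    # after every step the graph is re-checked against the last verified
+    # model with compare_outputs() and rolled back if the outputs diverge.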
+ ) + return False + + try: + logger.debug("Substituting Div -> Mul nodes...") + self.substitute_node_by_type(source_node="Div", target_node="Mul") + if not self.compare_outputs(from_modelproto=True): + logger.warning( + "Failed to substitute Div -> Mul nodes, reverting changes..." + ) + self.onnx_model = self.prev_onnx_model + self.onnx_gs = self.prev_onnx_gs + + logger.debug("Substituting Sub -> Add nodes...") + self.substitute_node_by_type(source_node="Sub", target_node="Add") + if not self.compare_outputs(from_modelproto=True): + logger.warning( + "Failed to substitute Sub -> Add nodes, reverting changes..." + ) + self.onnx_model = self.prev_onnx_model + self.onnx_gs = self.prev_onnx_gs + + logger.debug( + "Fusing Add and Mul nodes to BatchNormalization nodes and then into Conv nodes..." + ) + self.fuse_add_mul_to_bn() + if not self.compare_outputs(from_modelproto=True): + logger.warning( + "Failed to fuse Add and Mul nodes to BatchNormalization nodes, reverting changes..." + ) + self.onnx_model = self.prev_onnx_model + self.onnx_gs = self.prev_onnx_gs + + logger.debug("Fusing Add and Mul nodes to Conv nodes...") + self.fuse_comb_add_mul_to_conv() + if not self.compare_outputs(from_modelproto=True): + logger.warning( + "Failed to fuse Add and Mul nodes (combined) to Conv nodes, reverting changes..." + ) + self.onnx_model = self.prev_onnx_model + self.onnx_gs = self.prev_onnx_gs + self.fuse_single_add_mul_to_conv() + if not self.compare_outputs(from_modelproto=True): + logger.warning( + "Failed to fuse Add and Mul nodes (single) to Conv nodes, reverting changes..." + ) + self.onnx_model = self.prev_onnx_model + self.onnx_gs = self.prev_onnx_gs + + logger.debug("Fusing Split and Concat nodes to Conv nodes...") + self.fuse_split_concat_to_conv() + if not self.compare_outputs(from_modelproto=True): + logger.warning( + "Failed to fuse Split and Concat nodes to Conv nodes, reverting changes..." + ) + self.onnx_model = self.prev_onnx_model + self.onnx_gs = self.prev_onnx_gs + + self.export_onnx() + except Exception as e: + logger.error(f"Failed to modify the ONNX model: {e}") + return False + return True + + def compare_outputs(self, from_modelproto: bool = False) -> bool: + """Compare the outputs of two ONNX models. 
+
+        @param from_modelproto: Whether to compare the in-memory models
+            instead of the model files saved on disk
+        @type from_modelproto: bool
+        @return: True if the outputs of both models match within tolerance
+        @rtype: bool
+        """
+
+        import onnxruntime as ort
+
+        ort.set_default_logger_severity(3)
+
+        if from_modelproto:
+            onnx_model_1 = self.prev_onnx_model.SerializeToString()
+            onnx_model_2 = self.onnx_model.SerializeToString()
+        else:
+            onnx_model_1 = self.model_path.as_posix()
+            onnx_model_2 = self.output_path.as_posix()
+
+        ort_session_1 = ort.InferenceSession(onnx_model_1)
+        ort_session_2 = ort.InferenceSession(onnx_model_2)
+
+        inputs = dict()
+        for input in ort_session_1.get_inputs():
+            if input.type in ["tensor(float64)"]:
+                input_type = np.float64
+            elif input.type in ["tensor(float32)", "tensor(float)"]:
+                input_type = np.float32
+            elif input.type in ["tensor(float16)"]:
+                input_type = np.float16
+            elif input.type in ["tensor(int64)"]:
+                input_type = np.int64
+            elif input.type in ["tensor(int32)"]:
+                input_type = np.int32
+            elif input.type in ["tensor(int16)"]:
+                input_type = np.int16
+            elif input.type in ["tensor(int8)"]:
+                input_type = np.int8
+            else:
+                raise ValueError(f"Unsupported input type: {input.type}")
+
+            inputs[input.name] = np.random.rand(*input.shape).astype(
+                input_type
+            )
+
+        outputs_1 = ort_session_1.run(None, inputs)
+
+        outputs_2 = ort_session_2.run(None, inputs)
+
+        equal_outputs = True
+        for out1, out2 in zip(outputs_1, outputs_2):
+            equal_outputs = equal_outputs and np.allclose(
+                out1, out2, rtol=5e-3, atol=5e-3
+            )
+
+        return equal_outputs
diff --git a/requirements.txt b/requirements.txt
index 42edf7c..534d102 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,3 +10,6 @@ s3transfer
 typer
 docker
 keyring
+onnx_graphsurgeon
+onnxoptimizer
+wget
\ No newline at end of file
diff --git a/tests/test_utils/conftest.py b/tests/test_utils/conftest.py
new file mode 100644
index 0000000..1a15f27
--- /dev/null
+++ b/tests/test_utils/conftest.py
@@ -0,0 +1,11 @@
+import shutil
+
+import pytest
+
+from .test_modifier import DATA_DIR
+
+
+@pytest.hookimpl(tryfirst=True)
+def pytest_sessionfinish(session, exitstatus):
+    if DATA_DIR.exists():
+        shutil.rmtree(DATA_DIR)
diff --git a/tests/test_utils/test_config.py b/tests/test_utils/test_config.py
index 156c0ff..00f4933 100644
--- a/tests/test_utils/test_config.py
+++ b/tests/test_utils/test_config.py
@@ -88,6 +88,7 @@
 DEFAULT_GENERAL_CONFIG = {
     "keep_intermediate_outputs": True,
     "disable_onnx_simplification": False,
+    "disable_onnx_optimisation": False,
     "output_remote_url": None,
     "put_file_plugin": None,
     "input_bin": None,
diff --git a/tests/test_utils/test_modifier.py b/tests/test_utils/test_modifier.py
new file mode 100644
index 0000000..43a44d4
--- /dev/null
+++ b/tests/test_utils/test_modifier.py
@@ -0,0 +1,240 @@
+import json
+import os
+import shutil
+from pathlib import Path
+from typing import Tuple
+
+import requests
+import wget
+from luxonis_ml.nn_archive.config import Config as NNArchiveConfig
+from luxonis_ml.nn_archive.config_building_blocks import InputType
+
+from modelconverter.utils import ONNXModifier
+from modelconverter.utils.config import Config
+from modelconverter.utils.onnx_tools import onnx_attach_normalization_to_inputs
+
+DATA_DIR = Path("tests/data/test_utils/hub_ai_models")
+
+API_KEY = os.getenv("HUB_AI_API_KEY", None)
+HEADERS = {"Authorization": f"Bearer {API_KEY}"}
+
+EXCEMPTED_MODELS = [
+    "l2cs",
+    "zero-dce-400x600",
+    "mult_640x352",
+    "mult_512x288",
+]
+
+
+def download_onnx_models():
+    if not os.path.exists(DATA_DIR):
+        os.makedirs(DATA_DIR)
+
+    url = "https://easyml.cloud.luxonis.com/models/api/v1/models?is_public=true&limit=1000"
+    response = requests.get(url, 
headers=HEADERS) + if response.status_code != 200: + raise ValueError( + f"Failed to get models. Status code: {response.status_code}" + ) + hub_ai_models = response.json() + + for model in hub_ai_models: + if "ONNX" in model["exportable_types"]: + model_name = model["name"] + model_dir = DATA_DIR / f"{model_name}" + if not os.path.exists(model_dir): + os.makedirs(model_dir) + model_id = model["id"] + + url = f"https://easyml.cloud.luxonis.com/models/api/v1/modelVersions?model_id={model_id}" + response = requests.get(url, headers=HEADERS) + if response.status_code != 200: + raise ValueError( + f"Failed to get model versions. Status code: {response.status_code}" + ) + model_versions = response.json() + + for version in model_versions: + if "ONNX" in version["exportable_types"]: + model_version_id = version["id"] + break + url = f"https://easyml.cloud.luxonis.com/models/api/v1/modelVersions/{model_version_id}/download" + response = requests.get(url, headers=HEADERS) + if response.status_code != 200: + raise ValueError( + f"Failed to download model. Status code: {response.status_code}" + ) + download_info = response.json() + + model_download_link = download_info[0]["download_link"] + + filename = wget.download( + model_download_link, out=model_dir.as_posix() + ) + + if filename.endswith(".tar.xz"): + shutil.unpack_archive(filename, model_dir.as_posix()) + + with open(model_dir / "config.json") as f: + cfg = json.load(f) + model_name = cfg["model"]["metadata"]["path"].split(".onnx")[0] + + shutil.move(filename, model_dir / f"{model_name}.tar.xz") + shutil.move( + model_dir / "config.json", + model_dir / f"{model_name}_config.json", + ) + + for item in Path(model_dir).iterdir(): + shutil.move(str(item), DATA_DIR / item.name) + + shutil.rmtree(model_dir) + else: + os.remove(filename) + + onnx_models = [] + for onnx_file in DATA_DIR.glob("*.onnx"): + if ( + onnx_file.stem not in EXCEMPTED_MODELS + and "_modified" not in onnx_file.stem + ): + onnx_models.append(onnx_file) + return onnx_models + + +def get_config(nn_config: Path) -> Tuple[Config, str]: + with open(nn_config) as f: + archive_config = NNArchiveConfig(**json.load(f)) + + main_stage_config = { + "input_model": str(DATA_DIR / archive_config.model.metadata.path), + "inputs": [], + "outputs": [], + } + + for inp in archive_config.model.inputs: + reverse = inp.preprocessing.reverse_channels + interleaved_to_planar = inp.preprocessing.interleaved_to_planar + dai_type = inp.preprocessing.dai_type + + layout = inp.layout + encoding = "NONE" + if inp.input_type == InputType.IMAGE: + if dai_type is not None: + if dai_type.startswith("RGB"): + encoding = {"from": "RGB", "to": "BGR"} + elif dai_type.startswith("BGR"): + encoding = "BGR" + elif dai_type.startswith("GRAY"): + encoding = "GRAY" + else: + encoding = {"from": "RGB", "to": "BGR"} + + if dai_type.endswith("i"): + layout = "NHWC" + elif dai_type.endswith("p"): + layout = "NCHW" + else: + if reverse is not None: + if reverse: + encoding = {"from": "RGB", "to": "BGR"} + else: + encoding = "BGR" + else: + encoding = {"from": "RGB", "to": "BGR"} + + if interleaved_to_planar is not None: + if interleaved_to_planar: + layout = "NHWC" + else: + layout = "NCHW" + channels = ( + inp.shape[layout.index("C")] + if layout and "C" in layout + else None + ) + if channels and channels == 1: + encoding = "GRAY" + + mean = inp.preprocessing.mean or [0, 0, 0] + scale = inp.preprocessing.scale or [1, 1, 1] + + main_stage_config["inputs"].append( + { + "name": inp.name, + "shape": inp.shape, + "layout": 
layout, + "data_type": inp.dtype.value, + "mean_values": mean, + "scale_values": scale, + "encoding": encoding, + } + ) + + for out in archive_config.model.outputs: + main_stage_config["outputs"].append( + { + "name": out.name, + "shape": out.shape, + "layout": out.layout, + "data_type": out.dtype.value, + } + ) + + main_stage_key = archive_config.model.metadata.name + config = { + "name": main_stage_key, + "stages": { + main_stage_key: main_stage_config, + }, + } + + for head in archive_config.model.heads or []: + postprocessor_path = getattr(head.metadata, "postprocessor_path", None) + if postprocessor_path is not None: + input_model_path = DATA_DIR / postprocessor_path + head_stage_config = { + "input_model": str(input_model_path), + "inputs": [], + "outputs": [], + "encoding": "NONE", + } + config["stages"][input_model_path.stem] = head_stage_config + + return Config.get_config(config, None), main_stage_key + + +def pytest_generate_tests(metafunc): + params = download_onnx_models() + metafunc.parametrize("onnx_file", params) + + +def test_onnx_model(onnx_file): + nn_config = onnx_file.parent / f"{onnx_file.stem}_config.json" + cfg, main_stage_key = get_config(nn_config) + + input_configs = { + input_config.name: input_config + for input_config in cfg.stages[main_stage_key].inputs + } + for input_name in input_configs: + input_configs[input_name].layout = "NCHW" + + modified_onnx = onnx_file.parent / f"{onnx_file.stem}_modified.onnx" + onnx_attach_normalization_to_inputs( + onnx_file, modified_onnx, input_configs + ) + + modified_optimised_onnx = ( + onnx_file.parent / f"{onnx_file.stem}_modified_optimised.onnx" + ) + onnx_modifier = ONNXModifier( + model_path=modified_onnx, output_path=modified_optimised_onnx + ) + + if onnx_modifier.has_dynamic_shape: + return + + assert ( + onnx_modifier.modify_onnx() and onnx_modifier.compare_outputs() + ), f"Test failed for {onnx_file.name}"
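-- 

Usage sketch (illustrative, not part of the diff): how the new ONNXModifier is
driven, mirroring the RVC4 exporter hook above. The file names below are
hypothetical, and the model is assumed to have a static input shape:

    from pathlib import Path

    from modelconverter.utils import ONNXModifier

    modifier = ONNXModifier(
        model_path=Path("model_modified.onnx"),
        output_path=Path("model_modified_optimised.onnx"),
    )
    # modify_onnx() applies the graph rewrites; compare_outputs() then
    # re-checks the optimised file against the original within the 5e-3
    # tolerance before the exporter adopts it.
    if modifier.modify_onnx() and modifier.compare_outputs():
        print("Optimised model is numerically equivalent")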