Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ONNXModifier for optimising the ONNX model before converting for RVC4 execution #55

Merged
merged 15 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/unittests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,6 @@ jobs:
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_S3_ENDPOINT_URL: ${{ secrets.AWS_S3_ENDPOINT_URL }}
GOOGLE_APPLICATION_CREDENTIALS: ${{ secrets.GCP_CREDENTIALS }}
HUB_AI_API_KEY: ${{ secrets.HUB_AI_API_KEY }}
run: python -m pytest tests/test_utils

1 change: 1 addition & 0 deletions modelconverter/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ def get_target_specific_options(
json_cfg = cfg.model_dump(mode="json")
options = {
"disable_onnx_simplification": cfg.disable_onnx_simplification,
"disable_onnx_optimisation": cfg.disable_onnx_optimisation,
"inputs": json_cfg["inputs"],
}
if target == "rvc4":
Expand Down
1 change: 1 addition & 0 deletions modelconverter/packages/base_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
self.outputs = {out.name: out for out in config.outputs}
self.keep_intermediate_outputs = config.keep_intermediate_outputs
self.disable_onnx_simplification = config.disable_onnx_simplification
self.disable_onnx_optimisation = config.disable_onnx_optimisation

self.model_name = self.input_model.stem

Expand Down
20 changes: 20 additions & 0 deletions modelconverter/packages/rvc4/exporter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import shutil
import subprocess
import time
Expand All @@ -6,6 +7,7 @@
from typing import Any, Dict, List, NamedTuple, Optional, cast

from modelconverter.utils import (
ONNXModifier,
exit_with,
onnx_attach_normalization_to_inputs,
read_image,
Expand Down Expand Up @@ -57,6 +59,24 @@ def __init__(self, config: SingleStageConfig, output_dir: Path):
self._attach_suffix(self.input_model, "modified.onnx"),
self.inputs,
)

if not config.disable_onnx_optimisation:
onnx_modifier = ONNXModifier(
model_path=self.input_model,
output_path=self._attach_suffix(
self.input_model, "modified_optimised.onnx"
),
)

if (
onnx_modifier.modify_onnx()
and onnx_modifier.compare_outputs()
):
logger.info("ONNX model has been optimised for RVC4.")
shutil.move(onnx_modifier.output_path, self.input_model)
else:
if os.path.exists(onnx_modifier.output_path):
os.remove(onnx_modifier.output_path)
else:
logger.warning(
"Input file type is not ONNX. Skipping pre-processing."
Expand Down
3 changes: 2 additions & 1 deletion modelconverter/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
modelconverter_config_to_nn,
process_nn_archive,
)
from .onnx_tools import onnx_attach_normalization_to_inputs
from .onnx_tools import ONNXModifier, onnx_attach_normalization_to_inputs
from .subprocess import subprocess_run

__all__ = [
Expand All @@ -37,6 +37,7 @@
"S3Exception",
"SubprocessException",
"exit_with",
"ONNXModifier",
"onnx_attach_normalization_to_inputs",
"read_calib_dir",
"read_image",
Expand Down
1 change: 1 addition & 0 deletions modelconverter/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ class SingleStageConfig(CustomBaseModel):

keep_intermediate_outputs: bool = True
disable_onnx_simplification: bool = False
disable_onnx_optimisation: bool = False
output_remote_url: Optional[str] = None
put_file_plugin: Optional[str] = None

Expand Down
Loading
Loading