Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ONNXModifier for optimising the ONNX model before converting for RVC4 execution #55

Merged
merged 15 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/unittests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,6 @@ jobs:
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_S3_ENDPOINT_URL: ${{ secrets.AWS_S3_ENDPOINT_URL }}
GOOGLE_APPLICATION_CREDENTIALS: ${{ secrets.GCP_CREDENTIALS }}
HUB_AI_API_KEY: ${{ secrets.HUB_AI_API_KEY }}
run: python -m pytest tests/test_utils

2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
default_language_version:
python: python3
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.2
Expand Down
1 change: 1 addition & 0 deletions modelconverter/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ def get_target_specific_options(
json_cfg = cfg.model_dump(mode="json")
options = {
"disable_onnx_simplification": cfg.disable_onnx_simplification,
"disable_onnx_optimisation": cfg.disable_onnx_optimisation,
"inputs": json_cfg["inputs"],
}
if target == "rvc4":
Expand Down
1 change: 1 addition & 0 deletions modelconverter/packages/base_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
self.outputs = {out.name: out for out in config.outputs}
self.keep_intermediate_outputs = config.keep_intermediate_outputs
self.disable_onnx_simplification = config.disable_onnx_simplification
self.disable_onnx_optimisation = config.disable_onnx_optimisation

self.model_name = self.input_model.stem

Expand Down
29 changes: 17 additions & 12 deletions modelconverter/packages/rvc4/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,23 @@ def __init__(self, config: SingleStageConfig, output_dir: Path):
self.inputs,
)

onnx_modifier = ONNXModifier(
model_path=self.input_model,
output_path=self._attach_suffix(
self.input_model, "modified_optimised.onnx"
),
)
onnx_modifier.modify_onnx()
if onnx_modifier.compare_outputs():
logger.info("ONNX model has been optimised for RVC4.")
shutil.move(onnx_modifier.output_path, self.input_model)
else:
os.remove(onnx_modifier.output_path)
if not config.disable_onnx_optimisation:
onnx_modifier = ONNXModifier(
model_path=self.input_model,
output_path=self._attach_suffix(
self.input_model, "modified_optimised.onnx"
),
)

if (
onnx_modifier.modify_onnx()
and onnx_modifier.compare_outputs()
):
logger.info("ONNX model has been optimised for RVC4.")
shutil.move(onnx_modifier.output_path, self.input_model)
else:
if os.path.exists(onnx_modifier.output_path):
os.remove(onnx_modifier.output_path)
else:
logger.warning(
"Input file type is not ONNX. Skipping pre-processing."
Expand Down
1 change: 1 addition & 0 deletions modelconverter/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ class SingleStageConfig(CustomBaseModel):

keep_intermediate_outputs: bool = True
disable_onnx_simplification: bool = False
disable_onnx_optimisation: bool = False
output_remote_url: Optional[str] = None
put_file_plugin: Optional[str] = None

Expand Down
Loading
Loading