Skip to content

Commit

Permalink
add disable_onnx_optimisation flag and address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ptoupas committed Dec 12, 2024
1 parent 7237a6c commit 887ed47
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 168 deletions.
1 change: 1 addition & 0 deletions modelconverter/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ def get_target_specific_options(
json_cfg = cfg.model_dump(mode="json")
options = {
"disable_onnx_simplification": cfg.disable_onnx_simplification,
"disable_onnx_optimisation": cfg.disable_onnx_optimisation,
"inputs": json_cfg["inputs"],
}
if target == "rvc4":
Expand Down
1 change: 1 addition & 0 deletions modelconverter/packages/base_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
self.outputs = {out.name: out for out in config.outputs}
self.keep_intermediate_outputs = config.keep_intermediate_outputs
self.disable_onnx_simplification = config.disable_onnx_simplification
self.disable_onnx_optimisation = config.disable_onnx_optimisation

self.model_name = self.input_model.stem

Expand Down
28 changes: 16 additions & 12 deletions modelconverter/packages/rvc4/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,22 @@ def __init__(self, config: SingleStageConfig, output_dir: Path):
self.inputs,
)

onnx_modifier = ONNXModifier(
model_path=self.input_model,
output_path=self._attach_suffix(
self.input_model, "modified_optimised.onnx"
),
)
onnx_modifier.modify_onnx()
if onnx_modifier.compare_outputs():
logger.info("ONNX model has been optimised for RVC4.")
shutil.move(onnx_modifier.output_path, self.input_model)
else:
os.remove(onnx_modifier.output_path)
if not config.disable_onnx_optimisation:
onnx_modifier = ONNXModifier(
model_path=self.input_model,
output_path=self._attach_suffix(
self.input_model, "modified_optimised.onnx"
),
)

if (
onnx_modifier.modify_onnx()
and onnx_modifier.compare_outputs()
):
logger.info("ONNX model has been optimised for RVC4.")
shutil.move(onnx_modifier.output_path, self.input_model)
else:
os.remove(onnx_modifier.output_path)
else:
logger.warning(
"Input file type is not ONNX. Skipping pre-processing."
Expand Down
1 change: 1 addition & 0 deletions modelconverter/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ class SingleStageConfig(CustomBaseModel):

keep_intermediate_outputs: bool = True
disable_onnx_simplification: bool = False
disable_onnx_optimisation: bool = False
output_remote_url: Optional[str] = None
put_file_plugin: Optional[str] = None

Expand Down
Loading

0 comments on commit 887ed47

Please sign in to comment.