Skip to content

Commit

Permalink
add version control for export and modify hpi config (#14513)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyubo0722 authored Jan 8, 2025
1 parent a6b96bb commit bf2b73f
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions ppocr/utils/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def dump_infer_config(config, path, logger):
}
elif arch_config["model_type"] == "det":
common_dynamic_shapes = {
"x": [[1, 3, 160, 160], [1, 3, 160, 160], [1, 3, 1280, 1280]]
"x": [[1, 3, 160, 160], [1, 3, 640, 640], [1, 3, 1280, 1280]]
}
elif arch_config["algorithm"] == "SLANet":
common_dynamic_shapes = {
Expand All @@ -64,11 +64,17 @@ def dump_infer_config(config, path, logger):
"x": [[1, 3, 224, 224], [1, 3, 448, 448], [8, 3, 1280, 1280]]
}
elif arch_config["algorithm"] == "UniMERNet":
common_dynamic_shapes = {"x": [[1, 3, 192, 672]]}
common_dynamic_shapes = {
"x": [[1, 3, 192, 672], [1, 3, 192, 672], [8, 3, 192, 672]]
}
elif arch_config["algorithm"] == "PP-FormulaNet-L":
common_dynamic_shapes = {"x": [[1, 3, 768, 768]]}
common_dynamic_shapes = {
"x": [[1, 3, 768, 768], [1, 3, 768, 768], [8, 3, 768, 768]]
}
elif arch_config["algorithm"] == "PP-FormulaNet-S":
common_dynamic_shapes = {"x": [[1, 3, 384, 384]]}
common_dynamic_shapes = {
"x": [[1, 3, 384, 384], [1, 3, 384, 384], [8, 3, 384, 384]]
}
else:
common_dynamic_shapes = None

Expand Down Expand Up @@ -345,17 +351,22 @@ def export_single_model(
ModuleNotFoundError
): # Encryption is not needed if the module cannot be imported
print("Skipping import of the encryption module")
paddle_version = version.parse(paddle.__version__)
if config["Global"].get("export_with_pir", False):
paddle_version = version.parse(paddle.__version__)
assert (
paddle_version >= version.parse("3.0.0b2")
or paddle_version == version.parse("0.0.0")
) and os.environ.get("FLAGS_enable_pir_api", None) not in ["0", "False"]
paddle.jit.save(model, save_path)
else:
model.forward.rollback()
with paddle.pir_utils.OldIrGuard():
model = dynamic_to_static(model, arch_config, logger, input_shape)
if paddle_version >= version.parse(
"3.0.0b2"
) or paddle_version == version.parse("0.0.0"):
model.forward.rollback()
with paddle.pir_utils.OldIrGuard():
model = dynamic_to_static(model, arch_config, logger, input_shape)
paddle.jit.save(model, save_path)
else:
paddle.jit.save(model, save_path)
else:
quanter.save_quantized_model(model, save_path)
Expand Down

0 comments on commit bf2b73f

Please sign in to comment.