From f6865413acea7798c882637c0b45c24a0ba4217f Mon Sep 17 00:00:00 2001 From: SWHL Date: Thu, 28 Dec 2023 09:55:15 +0800 Subject: [PATCH] Optimize the func logic of the python version. --- python/demo.py | 1 - .../ch_ppocr_v3_rec/text_recognize.py | 6 +++--- python/rapidocr_onnxruntime/utils.py | 8 +++++--- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python/demo.py b/python/demo.py index 037f5502f..0fad6acd3 100644 --- a/python/demo.py +++ b/python/demo.py @@ -2,7 +2,6 @@ # @Author: SWHL # @Contact: liekkaskono@163.com import cv2 - from rapidocr_onnxruntime import RapidOCR, VisRes # from rapidocr_paddle import RapidOCR, VisRes diff --git a/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/text_recognize.py b/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/text_recognize.py index 5d499f110..344f6adc4 100644 --- a/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/text_recognize.py +++ b/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/text_recognize.py @@ -29,10 +29,10 @@ def __init__(self, config): self.session = OrtInferSession(config) if self.session.have_key(): - self.character_dict_path = self.session.get_character_list() + character_dict_path = self.session.get_character_list() else: - self.character_dict_path = config.get("rec_character_dict_path", None) - self.postprocess_op = CTCLabelDecode(self.character_dict_path) + character_dict_path = config.get("rec_keys_path", None) + self.postprocess_op = CTCLabelDecode(character_dict_path) self.rec_batch_num = config["rec_batch_num"] self.rec_image_shape = config["rec_img_shape"] diff --git a/python/rapidocr_onnxruntime/utils.py b/python/rapidocr_onnxruntime/utils.py index 2e5e34505..1eb01e618 100644 --- a/python/rapidocr_onnxruntime/utils.py +++ b/python/rapidocr_onnxruntime/utils.py @@ -88,11 +88,12 @@ def get_output_names( return [v.name for v in self.session.get_outputs()] def get_character_list(self, key: str = "character"): - return self.meta_dict[key].splitlines() + meta_dict = self.session.get_modelmeta().custom_metadata_map + return meta_dict[key].splitlines() def have_key(self, key: str = "character") -> bool: - self.meta_dict = self.session.get_modelmeta().custom_metadata_map - if key in self.meta_dict.keys(): + meta_dict = self.session.get_modelmeta().custom_metadata_map + if key in meta_dict.keys(): return True return False @@ -262,6 +263,7 @@ def init_args(): rec_group = parser.add_argument_group(title="Rec") rec_group.add_argument("--rec_use_cuda", action="store_true", default=False) rec_group.add_argument("--rec_model_path", type=str, default=None) + rec_group.add_argument("--rec_keys_path", type=str, default=None) rec_group.add_argument("--rec_img_shape", type=list, default=[3, 48, 320]) rec_group.add_argument("--rec_batch_num", type=int, default=6)