Skip to content

Commit

Permalink
Optimize the func logic of the python version.
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Dec 28, 2023
1 parent f5fa9d1 commit f686541
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
1 change: 0 additions & 1 deletion python/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# @Author: SWHL
# @Contact: [email protected]
import cv2

from rapidocr_onnxruntime import RapidOCR, VisRes

# from rapidocr_paddle import RapidOCR, VisRes
Expand Down
6 changes: 3 additions & 3 deletions python/rapidocr_onnxruntime/ch_ppocr_v3_rec/text_recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def __init__(self, config):
self.session = OrtInferSession(config)

if self.session.have_key():
self.character_dict_path = self.session.get_character_list()
character_dict_path = self.session.get_character_list()
else:
self.character_dict_path = config.get("rec_character_dict_path", None)
self.postprocess_op = CTCLabelDecode(self.character_dict_path)
character_dict_path = config.get("rec_keys_path", None)
self.postprocess_op = CTCLabelDecode(character_dict_path)

self.rec_batch_num = config["rec_batch_num"]
self.rec_image_shape = config["rec_img_shape"]
Expand Down
8 changes: 5 additions & 3 deletions python/rapidocr_onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,12 @@ def get_output_names(
return [v.name for v in self.session.get_outputs()]

def get_character_list(self, key: str = "character"):
return self.meta_dict[key].splitlines()
meta_dict = self.session.get_modelmeta().custom_metadata_map
return meta_dict[key].splitlines()

def have_key(self, key: str = "character") -> bool:
self.meta_dict = self.session.get_modelmeta().custom_metadata_map
if key in self.meta_dict.keys():
meta_dict = self.session.get_modelmeta().custom_metadata_map
if key in meta_dict.keys():
return True
return False

Expand Down Expand Up @@ -262,6 +263,7 @@ def init_args():
rec_group = parser.add_argument_group(title="Rec")
rec_group.add_argument("--rec_use_cuda", action="store_true", default=False)
rec_group.add_argument("--rec_model_path", type=str, default=None)
rec_group.add_argument("--rec_keys_path", type=str, default=None)
rec_group.add_argument("--rec_img_shape", type=list, default=[3, 48, 320])
rec_group.add_argument("--rec_batch_num", type=int, default=6)

Expand Down

0 comments on commit f686541

Please sign in to comment.