diff --git a/python/demo.py b/python/demo.py
index 9b886ca8b..81a77d683 100644
--- a/python/demo.py
+++ b/python/demo.py
@@ -15,18 +15,29 @@
 with open(image_path, "rb") as f:
     img = f.read()
 
-result, elapse_list = engine(img, return_word_box=True)
+# result, elapse_list = engine(img, use_det=True, use_cls=False, use_rec=False)
+result, elapse_list = engine(img)
 print(result)
 print(elapse_list)
 
-(boxes, txts, scores, words_boxes, words) = list(zip(*result))
+# result, elapse = engine(image_path, use_det=False, use_cls=False, use_rec=True)
+(boxes, txts, scores, words_boxes, words, words_scores) = list(zip(*result))
 
 font_path = "resources/fonts/FZYTK.TTF"
-vis_img = vis(img, boxes, txts, scores, font_path)
-cv2.imwrite("vis.png", vis_img)
 
 words_boxes = sum(words_boxes, [])
 words_all = sum(words, [])
-words_scores = [1.0] * len(words_boxes)
+words_scores = sum(words_scores, [])
 vis_img = vis(img, words_boxes, words_all, words_scores, font_path)
 cv2.imwrite("vis_single.png", vis_img)
+
+# (boxes, txts, scores, words_boxes, words) = list(zip(*result))
+
+# vis_img = vis(img, boxes, txts, scores, font_path)
+# cv2.imwrite("vis.png", vis_img)
+
+# words_boxes = sum(words_boxes, [])
+# words_all = sum(words, [])
+# words_scores = [1.0] * len(words_boxes)
+# vis_img = vis(img, words_boxes, words_all, words_scores, font_path)
+# cv2.imwrite("vis_single.png", vis_img)
diff --git a/python/rapidocr_onnxruntime/cal_rec_boxes/main.py b/python/rapidocr_onnxruntime/cal_rec_boxes/main.py
index c8ea5da4c..296997f0c 100644
--- a/python/rapidocr_onnxruntime/cal_rec_boxes/main.py
+++ b/python/rapidocr_onnxruntime/cal_rec_boxes/main.py
@@ -3,11 +3,13 @@
 # @Contact: liekkaskono@163.com
 import copy
 import math
-from typing import Any, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import cv2
 import numpy as np
 
+from ..ch_ppocr_rec.utils import TextRecognizerOutput
+
 
 class CalRecBoxes:
     """Compute boxes for single CJK characters and English words in recognized text. Code adapted from PaddlePaddle/PaddleOCR and fanqie03/char-detection"""
@@ -19,13 +21,16 @@
     def __call__(
         self,
         imgs: Optional[List[np.ndarray]],
         dt_boxes: Optional[List[np.ndarray]],
-        rec_res: Optional[List[Any]],
-    ):
-        res = []
-        for img, box, rec_res in zip(imgs, dt_boxes, rec_res):
+        rec_res: TextRecognizerOutput,
+    ) -> TextRecognizerOutput:
+        # rec_res = list(zip(rec_res.line_results, rec_res.word_results))
+        word_results = []
+        for idx, (img, box) in enumerate(zip(imgs, dt_boxes)):
             direction = self.get_box_direction(box)
-            rec_txt, rec_conf, rec_word_info = rec_res[0], rec_res[1], rec_res[2]
+            rec_txt, rec_conf = rec_res.line_results[idx]
+            rec_word_info = rec_res.word_results[idx]
+
             h, w = img.shape[:2]
             img_box = np.array([[0, 0], [w, 0], [w, h], [0, h]])
             word_box_content_list, word_box_list, conf_list = self.cal_ocr_word_box(
@@ -35,10 +40,12 @@
             word_box_list = self.reverse_rotate_crop_image(
                 copy.deepcopy(box), word_box_list, direction
             )
-            res.append(
-                [rec_txt, rec_conf, word_box_list, word_box_content_list, conf_list]
+            word_results.extend(
+                list(zip(word_box_content_list, conf_list, word_box_list))
             )
-        return res
+
+        rec_res.word_results = tuple(list(v) for v in word_results)
+        return rec_res
 
     @staticmethod
     def get_box_direction(box: np.ndarray) -> str:
diff --git a/python/rapidocr_onnxruntime/ch_ppocr_rec/__init__.py b/python/rapidocr_onnxruntime/ch_ppocr_rec/__init__.py
index 37eafdc7b..46477ba68 100644
--- a/python/rapidocr_onnxruntime/ch_ppocr_rec/__init__.py
+++ b/python/rapidocr_onnxruntime/ch_ppocr_rec/__init__.py
@@ -1,4 +1,5 @@
 # -*- encoding: utf-8 -*-
 # @Author: SWHL
 # @Contact: liekkaskono@163.com
-from .text_recognize import TextRecognizer
+from .main import TextRecognizer
+from .utils import TextRecognizerOutput
diff --git a/python/rapidocr_onnxruntime/ch_ppocr_rec/text_recognize.py b/python/rapidocr_onnxruntime/ch_ppocr_rec/text_recognize.py
deleted file mode 100644
index e823ea655..000000000
--- a/python/rapidocr_onnxruntime/ch_ppocr_rec/text_recognize.py
+++ /dev/null
@@ -1,130 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import argparse
-import math
-import time
-from typing import Any, Dict, List, Tuple, Union
-
-import cv2
-import numpy as np
-
-from rapidocr_onnxruntime.utils import OrtInferSession, read_yaml
-
-from .utils import CTCLabelDecode
-
-
-class TextRecognizer:
-    def __init__(self, config: Dict[str, Any]):
-        self.session = OrtInferSession(config)
-
-        character = None
-        if self.session.have_key():
-            character = self.session.get_character_list()
-
-        character_path = config.get("rec_keys_path", None)
-        self.postprocess_op = CTCLabelDecode(
-            character=character, character_path=character_path
-        )
-
-        self.rec_batch_num = config["rec_batch_num"]
-        self.rec_image_shape = config["rec_img_shape"]
-
-    def __call__(
-        self,
-        img_list: Union[np.ndarray, List[np.ndarray]],
-        return_word_box: bool = False,
-    ) -> Tuple[List[Tuple[str, float]], float]:
-        if isinstance(img_list, np.ndarray):
-            img_list = [img_list]
-
-        # Calculate the aspect ratio of all text bars
-        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
-
-        # Sorting can speed up the recognition process
-        indices = np.argsort(np.array(width_list))
-
-        img_num = len(img_list)
-        rec_res = [("", 0.0)] * img_num
-
-        batch_num = self.rec_batch_num
-        elapse = 0
-        for beg_img_no in range(0, img_num, batch_num):
-            end_img_no = min(img_num, beg_img_no + batch_num)
-
-            # Parameter Alignment for PaddleOCR
-            imgC, imgH, imgW = self.rec_image_shape[:3]
-            max_wh_ratio = imgW / imgH
-            wh_ratio_list = []
-            for ino in range(beg_img_no, end_img_no):
-                h, w = img_list[indices[ino]].shape[0:2]
-                wh_ratio = w * 1.0 / h
-                max_wh_ratio = max(max_wh_ratio, wh_ratio)
-                wh_ratio_list.append(wh_ratio)
-
-            norm_img_batch = []
-            for ino in range(beg_img_no, end_img_no):
-                norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
-                norm_img_batch.append(norm_img[np.newaxis, :])
-            norm_img_batch = np.concatenate(norm_img_batch).astype(np.float32)
-
-            starttime = time.time()
-            preds = self.session(norm_img_batch)[0]
-            rec_result = self.postprocess_op(
-                preds,
-                return_word_box,
-                wh_ratio_list=wh_ratio_list,
-                max_wh_ratio=max_wh_ratio,
-            )
-
-            for rno, one_res in enumerate(rec_result):
-                rec_res[indices[beg_img_no + rno]] = one_res
-            elapse += time.time() - starttime
-        return rec_res, elapse
-
-    def resize_norm_img(self, img: np.ndarray, max_wh_ratio: float) -> np.ndarray:
-        img_channel, img_height, img_width = self.rec_image_shape
-        assert img_channel == img.shape[2]
-
-        img_width = int(img_height * max_wh_ratio)
-
-        h, w = img.shape[:2]
-        ratio = w / float(h)
-        if math.ceil(img_height * ratio) > img_width:
-            resized_w = img_width
-        else:
-            resized_w = int(math.ceil(img_height * ratio))
-
-        resized_image = cv2.resize(img, (resized_w, img_height))
-        resized_image = resized_image.astype("float32")
-        resized_image = resized_image.transpose((2, 0, 1)) / 255
-        resized_image -= 0.5
-        resized_image /= 0.5
-
-        padding_im = np.zeros((img_channel, img_height, img_width), dtype=np.float32)
-        padding_im[:, :, 0:resized_w] = resized_image
-        return padding_im
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--image_path", type=str, help="image_dir|image_path")
-    parser.add_argument("--config_path", type=str, default="config.yaml")
-    args = parser.parse_args()
-
-    config = read_yaml(args.config_path)
-    text_recognizer = TextRecognizer(config)
-
-    img = cv2.imread(args.image_path)
-    rec_res, predict_time = text_recognizer(img)
-    print(f"rec result: {rec_res}\t cost: {predict_time}s")
diff --git a/python/rapidocr_onnxruntime/ch_ppocr_rec/utils.py b/python/rapidocr_onnxruntime/ch_ppocr_rec/utils.py
index 224b2f879..9837c9000 100644
--- a/python/rapidocr_onnxruntime/ch_ppocr_rec/utils.py
+++ b/python/rapidocr_onnxruntime/ch_ppocr_rec/utils.py
@@ -1,12 +1,43 @@
 # -*- encoding: utf-8 -*-
 # @Author: SWHL
 # @Contact: liekkaskono@163.com
+from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
 
 
+@dataclass
+class TextRecognizerConfig:
+    intra_op_num_threads: int = -1
+    inter_op_num_threads: int = -1
+    use_cuda: bool = False
+    use_dml: bool = False
+    model_path: Union[str, Path, None] = None
+
+    rec_batch_num: int = 6
+    rec_img_shape: Tuple[int, int, int] = (3, 48, 320)
+    rec_keys_path: Union[str, Path, None] = None
+
+
+@dataclass
+class TextRecognizerInput:
+    img_list: Union[np.ndarray, List[np.ndarray]]
+    return_word_box: bool = False
+
+    def __post_init__(self):
+        if isinstance(self.img_list, np.ndarray):
+            self.img_list = [self.img_list]
+
+
+@dataclass
+class TextRecognizerOutput:
+    line_results: Optional[Tuple[List]] = None
+    word_results: Optional[List[List]] = None
+    elapse: Optional[float] = None
+
+
 class CTCLabelDecode:
     def __init__(
         self,
@@ -18,18 +49,18 @@ def __init__(
 
     def __call__(
         self, preds: np.ndarray, return_word_box: bool = False, **kwargs
-    ) -> List[Tuple[str, float]]:
+    ) -> Tuple[List[Tuple[str, float]], List[List]]:
         preds_idx = preds.argmax(axis=2)
         preds_prob = preds.max(axis=2)
-        text = self.decode(
+        line_results, word_results = self.decode(
             preds_idx, preds_prob, return_word_box, is_remove_duplicate=True
         )
         if return_word_box:
-            for rec_idx, rec in enumerate(text):
+            for rec_idx, rec in enumerate(word_results):
                 wh_ratio = kwargs["wh_ratio_list"][rec_idx]
                 max_wh_ratio = kwargs["max_wh_ratio"]
-                rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio)
-        return text
+                rec[0] *= wh_ratio / max_wh_ratio
+        return line_results, word_results
 
     def get_character(
         self,
@@ -80,7 +111,7 @@ def decode(
         is_remove_duplicate: bool = False,
     ) -> List[Tuple[str, float]]:
         """convert text-index into text-label."""
-        result_list = []
+        result_list, result_words_list = [], []
         ignored_tokens = self.get_ignored_tokens()
         batch_size = len(text_index)
         for batch_idx in range(batch_size):
@@ -93,6 +124,7 @@
 
             if text_prob is not None:
                 conf_list = np.array(text_prob[batch_idx][selection]).tolist()
+                conf_list = [round(conf, 5) for conf in conf_list]
             else:
                 conf_list = [1] * len(selection)
 
@@ -103,26 +135,23 @@
                 self.character[text_id] for text_id in text_index[batch_idx][selection]
             ]
             text = "".join(char_list)
+
+            result_list.append([text, np.mean(conf_list).round(5).tolist()])
+
             if return_word_box:
                 word_list, word_col_list, state_list = self.get_word_info(
                     text, selection
                 )
-                result_list.append(
-                    (
-                        text,
-                        np.mean(conf_list).tolist(),
-                        [
-                            len(text_index[batch_idx]),
-                            word_list,
-                            word_col_list,
-                            state_list,
-                            conf_list,
-                        ],
-                    )
+                result_words_list.append(
+                    [
+                        len(text_index[batch_idx]),
+                        word_list,
+                        word_col_list,
+                        state_list,
+                        conf_list,
+                    ]
                 )
-            else:
-                result_list.append((text, np.mean(conf_list).tolist()))
-        return result_list
+        return result_list, result_words_list
 
     @staticmethod
     def get_word_info(
diff --git a/python/rapidocr_onnxruntime/main.py b/python/rapidocr_onnxruntime/main.py
index 3c0300e6b..d529fb4ca 100644
--- a/python/rapidocr_onnxruntime/main.py
+++ b/python/rapidocr_onnxruntime/main.py
@@ -11,7 +11,7 @@
 from .cal_rec_boxes import CalRecBoxes
 from .ch_ppocr_cls import TextClassifier
 from .ch_ppocr_det import TextDetector
-from .ch_ppocr_rec import TextRecognizer
+from .ch_ppocr_rec import TextRecognizer, TextRecognizerOutput
 from .utils import (
     LoadImage,
     UpdateParameters,
@@ -106,14 +106,18 @@
             img, cls_res, cls_elapse = self.text_cls(img)
 
         if use_rec:
-            rec_res, rec_elapse = self.text_rec(img, return_word_box)
+            rec_res = self.text_rec(img, return_word_box)
 
-        if dt_boxes is not None and rec_res is not None and return_word_box:
+        if (
+            return_word_box
+            and dt_boxes is not None
+            and rec_res.word_results is not None
+        ):
             rec_res = self.cal_rec_boxes(img, dt_boxes, rec_res)
-            for rec_res_i in rec_res:
+            for rec_res_i in rec_res.word_results:
                 if rec_res_i[2]:
                     rec_res_i[2] = (
-                        self._get_origin_points(rec_res_i[2], op_record, raw_h, raw_w)
+                        self._get_origin_points([rec_res_i[2]], op_record, raw_h, raw_w)
                         .astype(np.int32)
                         .tolist()
                     )
@@ -122,7 +126,12 @@
             dt_boxes = self._get_origin_points(dt_boxes, op_record, raw_h, raw_w)
 
         ocr_res = self.get_final_res(
-            dt_boxes, cls_res, rec_res, det_elapse, cls_elapse, rec_elapse
+            # dt_boxes, cls_res, rec_res, det_elapse, cls_elapse, rec_elapse
+            dt_boxes,
+            cls_res,
+            rec_res,
+            det_elapse,
+            cls_elapse,
         )
         return ocr_res
 
@@ -276,10 +285,9 @@ def get_final_res(
         self,
         dt_boxes: Optional[List[np.ndarray]],
         cls_res: Optional[List[List[Union[str, float]]]],
-        rec_res: Optional[List[Tuple[str, float, List[Union[str, float]]]]],
+        rec_res: TextRecognizerOutput,
        det_elapse: float,
         cls_elapse: float,
-        rec_elapse: float,
    ) -> Tuple[Optional[List[List[Union[Any, str]]]], Optional[List[float]]]:
         if dt_boxes is None and rec_res is None and cls_res is not None:
             return cls_res, [cls_elapse]
@@ -288,7 +296,7 @@
             return None, None
 
         if dt_boxes is None and rec_res is not None:
-            return [[res[0], res[1]] for res in rec_res], [rec_elapse]
+            return [[v[0], v[1]] for v in rec_res.line_results], [rec_res.elapse]
 
         if dt_boxes is not None and rec_res is None:
             return [box.tolist() for box in dt_boxes], [det_elapse]
@@ -297,10 +305,12 @@
 
         if not dt_boxes or not rec_res or len(dt_boxes) <= 0:
             return None, None
 
-        ocr_res = [[box.tolist(), *res] for box, res in zip(dt_boxes, rec_res)], [
+        ocr_res = [
+            [box.tolist(), *res] for box, res in zip(dt_boxes, rec_res.line_results)
+        ], [
             det_elapse,
             cls_elapse,
-            rec_elapse,
+            rec_res.elapse,
         ]
         return ocr_res
 
@@ -308,18 +318,19 @@ def filter_result(
         self,
         dt_boxes: Optional[List[np.ndarray]],
         rec_res: Optional[List[Tuple[str, float]]],
-    ) -> Tuple[Optional[List[np.ndarray]], Optional[List[Tuple[str, float]]]]:
+    ) -> Tuple[Optional[List[np.ndarray]], TextRecognizerOutput]:
         if dt_boxes is None or rec_res is None:
             return None, None
 
         filter_boxes, filter_rec_res = [], []
-        for box, rec_reuslt in zip(dt_boxes, rec_res):
+        for box, rec_reuslt in zip(dt_boxes, rec_res.line_results):
             text, score = rec_reuslt[0], rec_reuslt[1]
             if float(score) >= self.text_score:
                 filter_boxes.append(box)
                 filter_rec_res.append(rec_reuslt)
 
-        return filter_boxes, filter_rec_res
+        rec_res.line_results = filter_rec_res
+        return filter_boxes, rec_res
 
 
 def main():
@@ -330,11 +341,7 @@
     use_det = not args.no_det
     use_cls = not args.no_cls
     use_rec = not args.no_rec
     result, elapse_list = ocr_engine(
-        args.img_path,
-        use_det=use_det,
-        use_cls=use_cls,
-        use_rec=use_rec,
-        **vars(args)
+        args.img_path, use_det=use_det, use_cls=use_cls, use_rec=use_rec, **vars(args)
     )
     logger.info(result)
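
For reference, a minimal usage sketch of the engine after this refactor. It assumes word-level output is still opted into via return_word_box (this diff does not change the flag's default) and that each result row then carries box, text, score, word boxes, word texts, and word scores, as demo.py unpacks them; the image path below is a placeholder.

from rapidocr_onnxruntime import RapidOCR

engine = RapidOCR()

image_path = "path/to/image.png"  # placeholder path, supply your own image
with open(image_path, "rb") as f:
    img = f.read()

# return_word_box=True asks TextRecognizer to fill
# TextRecognizerOutput.word_results in addition to line_results.
result, elapse_list = engine(img, return_word_box=True)
for box, txt, score, word_boxes, words, word_scores in result:
    print(txt, score, words)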