From 6bcabddbcaf9619711b31ce49a55c1a46afd05a8 Mon Sep 17 00:00:00 2001 From: SWHL Date: Thu, 12 Dec 2024 21:58:48 +0800 Subject: [PATCH] fix(rapidocr_openvino): merge PR #293 #294 --- .../rapidocr_openvino/cal_rec_boxes/main.py | 38 ++++++++++++------- .../ch_ppocr_det/text_detect.py | 18 +++++++-- .../rapidocr_openvino/ch_ppocr_rec/utils.py | 15 ++++++-- python/rapidocr_openvino/main.py | 7 +--- 4 files changed, 53 insertions(+), 25 deletions(-) diff --git a/python/rapidocr_openvino/cal_rec_boxes/main.py b/python/rapidocr_openvino/cal_rec_boxes/main.py index cbc529493..c8ea5da4c 100644 --- a/python/rapidocr_openvino/cal_rec_boxes/main.py +++ b/python/rapidocr_openvino/cal_rec_boxes/main.py @@ -28,14 +28,16 @@ def __call__( rec_txt, rec_conf, rec_word_info = rec_res[0], rec_res[1], rec_res[2] h, w = img.shape[:2] img_box = np.array([[0, 0], [w, 0], [w, h], [0, h]]) - word_box_content_list, word_box_list = self.cal_ocr_word_box( + word_box_content_list, word_box_list, conf_list = self.cal_ocr_word_box( rec_txt, img_box, rec_word_info ) word_box_list = self.adjust_box_overlap(copy.deepcopy(word_box_list)) word_box_list = self.reverse_rotate_crop_image( copy.deepcopy(box), word_box_list, direction ) - res.append([rec_txt, rec_conf, word_box_list, word_box_content_list]) + res.append( + [rec_txt, rec_conf, word_box_list, word_box_content_list, conf_list] + ) return res @staticmethod @@ -60,13 +62,13 @@ def get_box_direction(box: np.ndarray) -> str: @staticmethod def cal_ocr_word_box( rec_txt: str, box: np.ndarray, rec_word_info: List[Tuple[str, List[int]]] - ) -> Tuple[List[str], List[List[int]]]: + ) -> Tuple[List[str], List[List[int]], List[float]]: """Calculate the detection frame for each word based on the results of recognition and detection of ocr 汉字坐标是单字的 英语坐标是单词级别的 """ - col_num, word_list, word_col_list, state_list = rec_word_info + col_num, word_list, word_col_list, state_list, conf_list = rec_word_info box = box.tolist() bbox_x_start = box[0][0] bbox_x_end = box[1][0] @@ -84,7 +86,7 @@ def cal_ocr_word_box( def cal_char_width(width_list, word_col_): if len(word_col_) == 1: return - char_total_length = (word_col_[-1] - word_col_[0] + 1) * cell_width + char_total_length = (word_col_[-1] - word_col_[0]) * cell_width char_width = char_total_length / (len(word_col_) - 1) width_list.append(char_width) @@ -124,7 +126,7 @@ def cal_box(col_list, width_list, word_box_list_): cal_box(cn_col_list, cn_width_list, word_box_list) cal_box(en_col_list, en_width_list, word_box_list) sorted_word_box_list = sorted(word_box_list, key=lambda box: box[0][0]) - return word_box_content_list, sorted_word_box_list + return word_box_content_list, sorted_word_box_list, conf_list @staticmethod def adjust_box_overlap( @@ -137,8 +139,8 @@ def adjust_box_overlap( distance = abs(cur[1][0] - nxt[0][0]) cur[1][0] -= distance / 2 cur[2][0] -= distance / 2 - nxt[0][0] += distance / 2 - nxt[3][0] += distance / 2 + nxt[0][0] += distance - distance / 2 + nxt[3][0] += distance - distance / 2 return word_box_list def reverse_rotate_crop_image( @@ -218,6 +220,15 @@ def s_rotate(angle, valuex, valuey, pointx, pointy): @staticmethod def order_points(box: List[List[int]]) -> List[List[int]]: """矩形框顺序排列""" + + def convert_to_1x2(p): + if p.shape == (2,): + return p.reshape((1, 2)) + elif p.shape == (1, 2): + return p + else: + return p[:1, :] + box = np.array(box).reshape((-1, 2)) center_x, center_y = np.mean(box[:, 0]), np.mean(box[:, 1]) if np.any(box[:, 0] == center_x) and np.any( @@ -261,9 +272,10 @@ def order_points(box: List[List[int]]) -> List[List[int]]: p23[np.where(p23[:, 1] == np.min(p23[:, 1]))], p23[np.where(p23[:, 1] == np.max(p23[:, 1]))], ) - # 解决单字矩形框重叠导致多个相同框的情况 - p1 = p1[:1, :] - p2 = p2[:1, :] - p3 = p3[:1, :] - p4 = p4[:1, :] + + # 解决单字切割后横坐标完全相同的shape错误 + p1 = convert_to_1x2(p1) + p2 = convert_to_1x2(p2) + p3 = convert_to_1x2(p3) + p4 = convert_to_1x2(p4) return np.array([p1, p2, p3, p4]).reshape((-1, 2)).tolist() diff --git a/python/rapidocr_openvino/ch_ppocr_det/text_detect.py b/python/rapidocr_openvino/ch_ppocr_det/text_detect.py index 2f2a0a967..2fdb4e9ac 100644 --- a/python/rapidocr_openvino/ch_ppocr_det/text_detect.py +++ b/python/rapidocr_openvino/ch_ppocr_det/text_detect.py @@ -26,9 +26,9 @@ class TextDetector: def __init__(self, config: Dict[str, Any]): - limit_side_len = config.get("limit_side_len", 736) - limit_type = config.get("limit_type", "min") - self.preprocess_op = DetPreProcess(limit_side_len, limit_type) + self.limit_type = config.get("limit_type", "min") + self.limit_side_len = config.get("limit_side_len", 736) + self.preprocess_op = None post_process = { "thresh": config.get("thresh", 0.3), @@ -49,6 +49,7 @@ def __call__(self, img: np.ndarray) -> Tuple[Optional[np.ndarray], float]: raise ValueError("img is None") ori_img_shape = img.shape[0], img.shape[1] + self.preprocess_op = self.get_preprocess(max(img.shape[0], img.shape[1])) prepro_img = self.preprocess_op(img) if prepro_img is None: return None, 0 @@ -59,6 +60,17 @@ def __call__(self, img: np.ndarray) -> Tuple[Optional[np.ndarray], float]: elapse = time.perf_counter() - start_time return dt_boxes, elapse + def get_preprocess(self, max_wh): + if self.limit_type == "min": + limit_side_len = self.limit_side_len + elif max_wh < 960: + limit_side_len = 960 + elif max_wh < 1500: + limit_side_len = 1500 + else: + limit_side_len = 2000 + return DetPreProcess(limit_side_len, self.limit_type) + def filter_tag_det_res( self, dt_boxes: np.ndarray, image_shape: Tuple[int, int] ) -> np.ndarray: diff --git a/python/rapidocr_openvino/ch_ppocr_rec/utils.py b/python/rapidocr_openvino/ch_ppocr_rec/utils.py index 83b89518d..224b2f879 100644 --- a/python/rapidocr_openvino/ch_ppocr_rec/utils.py +++ b/python/rapidocr_openvino/ch_ppocr_rec/utils.py @@ -92,7 +92,7 @@ def decode( selection &= text_index[batch_idx] != ignored_token if text_prob is not None: - conf_list = text_prob[batch_idx][selection] + conf_list = np.array(text_prob[batch_idx][selection]).tolist() else: conf_list = [1] * len(selection) @@ -116,6 +116,7 @@ def decode( word_list, word_col_list, state_list, + conf_list, ], ) ) @@ -147,7 +148,13 @@ def get_word_info( word_list = [] word_col_list = [] state_list = [] - valid_col = np.where(selection == True)[0] + valid_col = np.where(selection)[0] + col_width = np.zeros(valid_col.shape) + if len(valid_col) > 0: + col_width[1:] = valid_col[1:] - valid_col[:-1] + col_width[0] = min( + 3 if "\u4e00" <= text[0] <= "\u9fff" else 2, int(valid_col[0]) + ) for c_i, char in enumerate(text): if "\u4e00" <= char <= "\u9fff": @@ -155,10 +162,10 @@ def get_word_info( else: c_state = "en&num" - if state == None: + if state is None: state = c_state - if state != c_state: + if state != c_state or col_width[c_i] > 4: if len(word_content) != 0: word_list.append(word_content) word_col_list.append(word_col_content) diff --git a/python/rapidocr_openvino/main.py b/python/rapidocr_openvino/main.py index b12acaeef..330651dd4 100644 --- a/python/rapidocr_openvino/main.py +++ b/python/rapidocr_openvino/main.py @@ -276,7 +276,7 @@ def get_final_res( self, dt_boxes: Optional[List[np.ndarray]], cls_res: Optional[List[List[Union[str, float]]]], - rec_res: Optional[List[Tuple[str, float]]], + rec_res: Optional[List[Tuple[str, float, List[Union[str, float]]]]], det_elapse: float, cls_elapse: float, rec_elapse: float, @@ -330,10 +330,7 @@ def main(): use_cls = not args.no_cls use_rec = not args.no_rec result, elapse_list = ocr_engine( - args.img_path, - use_det=use_det, - use_cls=use_cls, - use_rec=use_rec, + args.img_path, use_det=use_det, use_cls=use_cls, use_rec=use_rec, **vars(args) ) logger.info(result)