diff --git a/python/rapidocr_onnxruntime/cal_rec_boxes/main.py b/python/rapidocr_onnxruntime/cal_rec_boxes/main.py index 4071633de..c8ea5da4c 100644 --- a/python/rapidocr_onnxruntime/cal_rec_boxes/main.py +++ b/python/rapidocr_onnxruntime/cal_rec_boxes/main.py @@ -35,7 +35,9 @@ def __call__( word_box_list = self.reverse_rotate_crop_image( copy.deepcopy(box), word_box_list, direction ) - res.append([rec_txt, rec_conf, word_box_list, word_box_content_list, conf_list]) + res.append( + [rec_txt, rec_conf, word_box_list, word_box_content_list, conf_list] + ) return res @staticmethod @@ -137,8 +139,8 @@ def adjust_box_overlap( distance = abs(cur[1][0] - nxt[0][0]) cur[1][0] -= distance / 2 cur[2][0] -= distance / 2 - nxt[0][0] += distance / 2 - nxt[3][0] += distance / 2 + nxt[0][0] += distance - distance / 2 + nxt[3][0] += distance - distance / 2 return word_box_list def reverse_rotate_crop_image( @@ -218,6 +220,15 @@ def s_rotate(angle, valuex, valuey, pointx, pointy): @staticmethod def order_points(box: List[List[int]]) -> List[List[int]]: """矩形框顺序排列""" + + def convert_to_1x2(p): + if p.shape == (2,): + return p.reshape((1, 2)) + elif p.shape == (1, 2): + return p + else: + return p[:1, :] + box = np.array(box).reshape((-1, 2)) center_x, center_y = np.mean(box[:, 0]), np.mean(box[:, 1]) if np.any(box[:, 0] == center_x) and np.any( @@ -261,9 +272,10 @@ def order_points(box: List[List[int]]) -> List[List[int]]: p23[np.where(p23[:, 1] == np.min(p23[:, 1]))], p23[np.where(p23[:, 1] == np.max(p23[:, 1]))], ) - # 解决单字矩形框重叠导致多个相同框的情况 - p1 = p1[:1, :] - p2 = p2[:1, :] - p3 = p3[:1, :] - p4 = p4[:1, :] + + # 解决单字切割后横坐标完全相同的shape错误 + p1 = convert_to_1x2(p1) + p2 = convert_to_1x2(p2) + p3 = convert_to_1x2(p3) + p4 = convert_to_1x2(p4) return np.array([p1, p2, p3, p4]).reshape((-1, 2)).tolist() diff --git a/python/rapidocr_onnxruntime/ch_ppocr_rec/utils.py b/python/rapidocr_onnxruntime/ch_ppocr_rec/utils.py index ace70ad44..224b2f879 100644 --- a/python/rapidocr_onnxruntime/ch_ppocr_rec/utils.py +++ b/python/rapidocr_onnxruntime/ch_ppocr_rec/utils.py @@ -116,7 +116,7 @@ def decode( word_list, word_col_list, state_list, - conf_list + conf_list, ], ) ) @@ -152,7 +152,9 @@ def get_word_info( col_width = np.zeros(valid_col.shape) if len(valid_col) > 0: col_width[1:] = valid_col[1:] - valid_col[:-1] - col_width[0] = min(3 if "\u4e00" <= text[0] <= "\u9fff" else 2, int(valid_col[0])) + col_width[0] = min( + 3 if "\u4e00" <= text[0] <= "\u9fff" else 2, int(valid_col[0]) + ) for c_i, char in enumerate(text): if "\u4e00" <= char <= "\u9fff":