From 01a5a65952fc868cc4401c851e1efde98a5bd0e8 Mon Sep 17 00:00:00 2001
From: Joker1212 <519548295@qq.com>
Date: Tue, 3 Dec 2024 09:30:54 +0800
Subject: [PATCH] feat(rapidocr_onnxruntime): support en char rec (#272)

* feat: sup en char rec

* test: add en char rec test
---
Review note: cal_char_width() now tests its own `word_col_` parameter
rather than the `word_col` loop variable of the enclosing function, so the
helper no longer depends on being called from inside that loop. The comment
added to order_points() is given in English. The index line for main.py
still names the post-image blob of the patch as originally submitted.

 .../cal_rec_boxes/main.py | 65 +++++++++++--------
 python/tests/test_ort.py  |  4 ++
 2 files changed, 41 insertions(+), 28 deletions(-)

diff --git a/python/rapidocr_onnxruntime/cal_rec_boxes/main.py b/python/rapidocr_onnxruntime/cal_rec_boxes/main.py
index 5e0d71b7e..a1975d058 100644
--- a/python/rapidocr_onnxruntime/cal_rec_boxes/main.py
+++ b/python/rapidocr_onnxruntime/cal_rec_boxes/main.py
@@ -59,7 +59,7 @@ def get_box_direction(box: np.ndarray) -> str:
 
     @staticmethod
     def cal_ocr_word_box(
-        rec_txt: str, box: np.ndarray, rec_word_info: List[Tuple[str, List[int]]]
+            rec_txt: str, box: np.ndarray, rec_word_info: List[Tuple[str, List[int]]]
     ) -> Tuple[List[str], List[List[int]]]:
         """Calculate the detection frame for each word based on the results of recognition and detection of ocr
         汉字坐标是单字的
@@ -77,39 +77,31 @@ def cal_ocr_word_box(
         word_box_list = []
         word_box_content_list = []
         cn_width_list = []
+        en_width_list = []
         cn_col_list = []
-        for word, word_col, state in zip(word_list, word_col_list, state_list):
-            if state == "cn":
-                if len(word_col) != 1:
-                    char_seq_length = (word_col[-1] - word_col[0] + 1) * cell_width
-                    char_width = char_seq_length / (len(word_col) - 1)
-                    cn_width_list.append(char_width)
-                cn_col_list += word_col
-                word_box_content_list += word
-            else:
-                cell_x_start = bbox_x_start + int(word_col[0] * cell_width)
-                cell_x_end = bbox_x_start + int((word_col[-1] + 1) * cell_width)
-                cell = [
-                    [cell_x_start, bbox_y_start],
-                    [cell_x_end, bbox_y_start],
-                    [cell_x_end, bbox_y_end],
-                    [cell_x_start, bbox_y_end],
-                ]
-                word_box_list.append(cell)
-                word_box_content_list.append("".join(word))
+        en_col_list = []
 
-        if len(cn_col_list) != 0:
-            if len(cn_width_list) != 0:
-                avg_char_width = np.mean(cn_width_list)
+        def cal_char_width(width_list, word_col_):
+            if len(word_col_) == 1:
+                return
+            char_total_length = (word_col_[-1] - word_col_[0] + 1) * cell_width
+            char_width = char_total_length / (len(word_col_) - 1)
+            width_list.append(char_width)
+
+        def cal_box(col_list, width_list, word_box_list_):
+            if len(col_list) == 0:
+                return
+            if len(width_list) != 0:
+                avg_char_width = np.mean(width_list)
             else:
                 avg_char_width = (bbox_x_end - bbox_x_start) / len(rec_txt)
 
-            for center_idx in cn_col_list:
+            for center_idx in col_list:
                 center_x = (center_idx + 0.5) * cell_width
                 cell_x_start = max(int(center_x - avg_char_width / 2), 0) + bbox_x_start
                 cell_x_end = (
-                    min(int(center_x + avg_char_width / 2), bbox_x_end - bbox_x_start)
-                    + bbox_x_start
+                        min(int(center_x + avg_char_width / 2), bbox_x_end - bbox_x_start)
+                        + bbox_x_start
                 )
                 cell = [
                     [cell_x_start, bbox_y_start],
@@ -117,7 +109,20 @@ def cal_ocr_word_box(
                     [cell_x_end, bbox_y_end],
                     [cell_x_start, bbox_y_end],
                 ]
-                word_box_list.append(cell)
+                word_box_list_.append(cell)
+
+        for word, word_col, state in zip(word_list, word_col_list, state_list):
+            if state == "cn":
+                cal_char_width(cn_width_list, word_col)
+                cn_col_list += word_col
+                word_box_content_list += word
+            else:
+                cal_char_width(en_width_list, word_col)
+                en_col_list += word_col
+                word_box_content_list += word
+
+        cal_box(cn_col_list, cn_width_list, word_box_list)
+        cal_box(en_col_list, en_width_list, word_box_list)
         sorted_word_box_list = sorted(word_box_list, key=lambda box: box[0][0])
         return word_box_content_list, sorted_word_box_list
 
@@ -256,5 +261,9 @@ def order_points(box: List[List[int]]) -> List[List[int]]:
             p23[np.where(p23[:, 1] == np.min(p23[:, 1]))],
             p23[np.where(p23[:, 1] == np.max(p23[:, 1]))],
         )
-
+        # Handle overlapping single-character rectangles that yield several identical boxes
+        p1 = p1[:1, :]
+        p2 = p2[:1, :]
+        p3 = p3[:1, :]
+        p4 = p4[:1, :]
         return np.array([p1, p2, p3, p4]).reshape((-1, 2)).tolist()
diff --git a/python/tests/test_ort.py b/python/tests/test_ort.py
index 3784e42dd..c81b28567 100644
--- a/python/tests/test_ort.py
+++ b/python/tests/test_ort.py
@@ -236,6 +236,10 @@ def test_input_three_ndim_one_channel():
             "text_vertical_words.png",
             ["已", "取", "之", "時", "不", "參", "一", "人", "見", "而"],
         ),
+        (
+            "issue_170.png",
+            ["T", "E", "S", "T"],
+        ),
     ],
 )
 def test_word_ocr(img_name: str, words: List[str]):