diff --git a/README.md b/README.md index 34a870a..6a62216 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ slanet_plus是paddlex内置的SLANet升级版模型,准确率有大幅提升
#### 2024.11.24 update -- 支持gpu推理,适配 rapidOCR 单字识别匹配 +- 支持gpu推理,适配 rapidOCR 单字识别匹配,支持逻辑坐标返回及可视化 #### 2024.10.13 update - 补充最新paddlex-SLANet-plus 模型(paddle2onnx原因暂不能支持onnx) @@ -143,6 +143,11 @@ save_drawed_path = save_dir / f"vis_{Path(img_path).name}" viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path) +# 返回逻辑坐标 +# table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_result, return_logic_points=True) +# save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}" +# viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path,logic_points, save_logic_path) + print(table_html_str) ``` diff --git a/rapid_table/main.py b/rapid_table/main.py index b2a52c0..d620fe5 100644 --- a/rapid_table/main.py +++ b/rapid_table/main.py @@ -43,6 +43,7 @@ def __call__( self, img_content: Union[str, np.ndarray, bytes, Path], ocr_result: List[Union[List[List[float]], str, str]] = None, + return_logic_points = False ) -> Tuple[str, float]: if self.ocr_engine is None and ocr_result is None: raise ValueError( @@ -63,7 +64,11 @@ def __call__( if self.model_type == "slanet-plus": pred_bboxes = self.adapt_slanet_plus(img, pred_bboxes) pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res) - + # 避免低版本升级后出现问题,默认不打开 + if return_logic_points: + logic_points = self.table_matcher.decode_logic_points(pred_structures) + elapse = time.time() - s + return pred_html, pred_bboxes, logic_points, elapse elapse = time.time() - s return pred_html, pred_bboxes, elapse diff --git a/rapid_table/table_matcher/matcher.py b/rapid_table/table_matcher/matcher.py index b930c70..0453929 100644 --- a/rapid_table/table_matcher/matcher.py +++ b/rapid_table/table_matcher/matcher.py @@ -111,6 +111,73 @@ def get_pred_html(self, pred_structures, matched_index, ocr_contents): filter_elements = ["", "", "", ""] end_html = [v for v in end_html if v not in filter_elements] return "".join(end_html), end_html + def decode_logic_points(self, pred_structures): + logic_points = [] + current_row = 0 + current_col = 0 + max_rows = 0 + max_cols = 0 + occupied_cells = {} # 用于记录已经被占用的单元格 + + def is_occupied(row, col): + return (row, col) in occupied_cells + + def mark_occupied(row, col, rowspan, colspan): + for r in range(row, row + rowspan): + for c in range(col, col + colspan): + occupied_cells[(r, c)] = True + + i = 0 + while i < len(pred_structures): + token = pred_structures[i] + + if token == '': + current_col = 0 # 每次遇到 时,重置当前列号 + elif token == '': + current_row += 1 # 行结束,行号增加 + elif token .startswith(''): + if 'colspan=' in pred_structures[j]: + colspan = int(pred_structures[j].split('=')[1].strip('"\'')) + elif 'rowspan=' in pred_structures[j]: + rowspan = int(pred_structures[j].split('=')[1].strip('"\'')) + j += 1 + + # 跳过已经处理过的属性 token + i = j + + # 找到下一个未被占用的列 + while is_occupied(current_row, current_col): + current_col += 1 + + # 计算逻辑坐标 + r_start = current_row + r_end = current_row + rowspan - 1 + col_start = current_col + col_end = current_col + colspan - 1 + + # 记录逻辑坐标 + logic_points.append([r_start, r_end, col_start, col_end]) + + # 标记占用的单元格 + mark_occupied(r_start, col_start, rowspan, colspan) + + # 更新当前列号 + current_col += colspan + + # 更新最大行数和列数 + max_rows = max(max_rows, r_end + 1) + max_cols = max(max_cols, col_end + 1) + + i += 1 + + return logic_points def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res): y1 = pred_bboxes[:, 1::2].min() diff --git a/rapid_table/utils.py b/rapid_table/utils.py index cc860b3..350f2e5 100644 --- a/rapid_table/utils.py +++ b/rapid_table/utils.py @@ -1,9 +1,10 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com +import os from io import BytesIO from pathlib import Path -from typing import Optional, Union +from typing import Optional, Union, List import cv2 import numpy as np @@ -14,7 +15,7 @@ class LoadImage: def __init__( - self, + self, ): pass @@ -79,17 +80,19 @@ class LoadImageError(Exception): class VisTable: def __init__( - self, + self, ): self.load_img = LoadImage() def __call__( - self, - img_path: Union[str, Path], - table_html_str: str, - save_html_path: Optional[str] = None, - table_cell_bboxes: Optional[np.ndarray] = None, - save_drawed_path: Optional[str] = None, + self, + img_path: Union[str, Path], + table_html_str: str, + save_html_path: Optional[str] = None, + table_cell_bboxes: Optional[np.ndarray] = None, + save_drawed_path: Optional[str] = None, + logic_points: List[List[float]] = None, + save_logic_path: Optional[str] = None, ) -> None: if save_html_path: html_with_border = self.insert_border_style(table_html_str) @@ -110,18 +113,77 @@ def __call__( if save_drawed_path: self.save_img(save_drawed_path, drawed_img) - + if save_logic_path and logic_points: + polygons = [[box[0],box[1], box[4], box[5]] for box in table_cell_bboxes] + self.plot_rec_box_with_logic_info(img_path, save_logic_path, logic_points, polygons) return drawed_img def insert_border_style(self, table_html_str: str): - style_res = """""" prefix_table, suffix_table = table_html_str.split("") html_with_border = f"{prefix_table}{style_res}{suffix_table}" return html_with_border + def plot_rec_box_with_logic_info(self, img_path, output_path, logic_points, sorted_polygons): + """ + :param img_path + :param output_path + :param logic_points: [row_start,row_end,col_start,col_end] + :param sorted_polygons: [xmin,ymin,xmax,ymax] + :return: + """ + # 读取原图 + img = cv2.imread(img_path) + img = cv2.copyMakeBorder( + img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255] + ) + # 绘制 polygons 矩形 + for idx, polygon in enumerate(sorted_polygons): + x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3] + x0 = round(x0) + y0 = round(y0) + x1 = round(x1) + y1 = round(y1) + cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1) + # 增大字体大小和线宽 + font_scale = 0.9 # 原先是0.5 + thickness = 1 # 原先是1 + logic_point = logic_points[idx] + cv2.putText( + img, + f"row: {logic_point[0]}-{logic_point[1]}", + (x0 + 3, y0 + 8), + cv2.FONT_HERSHEY_PLAIN, + font_scale, + (0, 0, 255), + thickness, + ) + cv2.putText( + img, + f"col: {logic_point[2]}-{logic_point[3]}", + (x0 + 3, y0 + 18), + cv2.FONT_HERSHEY_PLAIN, + font_scale, + (0, 0, 255), + thickness, + ) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + # 保存绘制后的图像 + cv2.imwrite(output_path, img) + @staticmethod def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray: img_copy = img.copy() diff --git a/tests/test_table.py b/tests/test_table.py index 8e69a3c..41a2060 100644 --- a/tests/test_table.py +++ b/tests/test_table.py @@ -29,3 +29,7 @@ def test_ocr_input(): def test_input_ocr_none(): table_html_str, table_cell_bboxes, elapse = table_engine(img_path) assert table_html_str.count("") == 16 + +def test_logic_points_out(): + table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, return_logic_points=True) + assert len(table_cell_bboxes) == len(logic_points)