From 0d984158fb964979f8070f8907a065a6a06324d6 Mon Sep 17 00:00:00 2001
From: Jokcer <519548295@qq.com>
Date: Mon, 25 Nov 2024 21:54:29 +0800
Subject: [PATCH 1/2] feat: add logic points decode & vis
---
rapid_table/main.py | 7 ++-
rapid_table/table_matcher/matcher.py | 67 +++++++++++++++++++++
rapid_table/utils.py | 88 ++++++++++++++++++++++++----
tests/test_table.py | 4 ++
4 files changed, 152 insertions(+), 14 deletions(-)
diff --git a/rapid_table/main.py b/rapid_table/main.py
index b2a52c0..d620fe5 100644
--- a/rapid_table/main.py
+++ b/rapid_table/main.py
@@ -43,6 +43,7 @@ def __call__(
self,
img_content: Union[str, np.ndarray, bytes, Path],
ocr_result: List[Union[List[List[float]], str, str]] = None,
+ return_logic_points = False
) -> Tuple[str, float]:
if self.ocr_engine is None and ocr_result is None:
raise ValueError(
@@ -63,7 +64,11 @@ def __call__(
if self.model_type == "slanet-plus":
pred_bboxes = self.adapt_slanet_plus(img, pred_bboxes)
pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res)
-
+ # 避免低版本升级后出现问题,默认不打开
+ if return_logic_points:
+ logic_points = self.table_matcher.decode_logic_points(pred_structures)
+ elapse = time.time() - s
+ return pred_html, pred_bboxes, logic_points, elapse
elapse = time.time() - s
return pred_html, pred_bboxes, elapse
diff --git a/rapid_table/table_matcher/matcher.py b/rapid_table/table_matcher/matcher.py
index b930c70..0453929 100644
--- a/rapid_table/table_matcher/matcher.py
+++ b/rapid_table/table_matcher/matcher.py
@@ -111,6 +111,73 @@ def get_pred_html(self, pred_structures, matched_index, ocr_contents):
filter_elements = ["", "", "
", ""]
end_html = [v for v in end_html if v not in filter_elements]
return "".join(end_html), end_html
+ def decode_logic_points(self, pred_structures):
+ logic_points = []
+ current_row = 0
+ current_col = 0
+ max_rows = 0
+ max_cols = 0
+ occupied_cells = {} # 用于记录已经被占用的单元格
+
+ def is_occupied(row, col):
+ return (row, col) in occupied_cells
+
+ def mark_occupied(row, col, rowspan, colspan):
+ for r in range(row, row + rowspan):
+ for c in range(col, col + colspan):
+ occupied_cells[(r, c)] = True
+
+ i = 0
+ while i < len(pred_structures):
+ token = pred_structures[i]
+
+ if token == '':
+ current_col = 0 # 每次遇到
时,重置当前列号
+ elif token == '
':
+ current_row += 1 # 行结束,行号增加
+ elif token .startswith(' | ':
+ j += 1
+ # 提取 colspan 和 rowspan 属性
+ while j < len(pred_structures) and not pred_structures[j].startswith('>'):
+ if 'colspan=' in pred_structures[j]:
+ colspan = int(pred_structures[j].split('=')[1].strip('"\''))
+ elif 'rowspan=' in pred_structures[j]:
+ rowspan = int(pred_structures[j].split('=')[1].strip('"\''))
+ j += 1
+
+ # 跳过已经处理过的属性 token
+ i = j
+
+ # 找到下一个未被占用的列
+ while is_occupied(current_row, current_col):
+ current_col += 1
+
+ # 计算逻辑坐标
+ r_start = current_row
+ r_end = current_row + rowspan - 1
+ col_start = current_col
+ col_end = current_col + colspan - 1
+
+ # 记录逻辑坐标
+ logic_points.append([r_start, r_end, col_start, col_end])
+
+ # 标记占用的单元格
+ mark_occupied(r_start, col_start, rowspan, colspan)
+
+ # 更新当前列号
+ current_col += colspan
+
+ # 更新最大行数和列数
+ max_rows = max(max_rows, r_end + 1)
+ max_cols = max(max_cols, col_end + 1)
+
+ i += 1
+
+ return logic_points
def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
y1 = pred_bboxes[:, 1::2].min()
diff --git a/rapid_table/utils.py b/rapid_table/utils.py
index cc860b3..350f2e5 100644
--- a/rapid_table/utils.py
+++ b/rapid_table/utils.py
@@ -1,9 +1,10 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
+import os
from io import BytesIO
from pathlib import Path
-from typing import Optional, Union
+from typing import Optional, Union, List
import cv2
import numpy as np
@@ -14,7 +15,7 @@
class LoadImage:
def __init__(
- self,
+ self,
):
pass
@@ -79,17 +80,19 @@ class LoadImageError(Exception):
class VisTable:
def __init__(
- self,
+ self,
):
self.load_img = LoadImage()
def __call__(
- self,
- img_path: Union[str, Path],
- table_html_str: str,
- save_html_path: Optional[str] = None,
- table_cell_bboxes: Optional[np.ndarray] = None,
- save_drawed_path: Optional[str] = None,
+ self,
+ img_path: Union[str, Path],
+ table_html_str: str,
+ save_html_path: Optional[str] = None,
+ table_cell_bboxes: Optional[np.ndarray] = None,
+ save_drawed_path: Optional[str] = None,
+ logic_points: List[List[float]] = None,
+ save_logic_path: Optional[str] = None,
) -> None:
if save_html_path:
html_with_border = self.insert_border_style(table_html_str)
@@ -110,18 +113,77 @@ def __call__(
if save_drawed_path:
self.save_img(save_drawed_path, drawed_img)
-
+ if save_logic_path and logic_points:
+ polygons = [[box[0],box[1], box[4], box[5]] for box in table_cell_bboxes]
+ self.plot_rec_box_with_logic_info(img_path, save_logic_path, logic_points, polygons)
return drawed_img
def insert_border_style(self, table_html_str: str):
- style_res = """"""
prefix_table, suffix_table = table_html_str.split("")
html_with_border = f"{prefix_table}{style_res}{suffix_table}"
return html_with_border
+ def plot_rec_box_with_logic_info(self, img_path, output_path, logic_points, sorted_polygons):
+ """
+ :param img_path
+ :param output_path
+ :param logic_points: [row_start,row_end,col_start,col_end]
+ :param sorted_polygons: [xmin,ymin,xmax,ymax]
+ :return:
+ """
+ # 读取原图
+ img = cv2.imread(img_path)
+ img = cv2.copyMakeBorder(
+ img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
+ )
+ # 绘制 polygons 矩形
+ for idx, polygon in enumerate(sorted_polygons):
+ x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3]
+ x0 = round(x0)
+ y0 = round(y0)
+ x1 = round(x1)
+ y1 = round(y1)
+ cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1)
+ # 增大字体大小和线宽
+ font_scale = 0.9 # 原先是0.5
+ thickness = 1 # 原先是1
+ logic_point = logic_points[idx]
+ cv2.putText(
+ img,
+ f"row: {logic_point[0]}-{logic_point[1]}",
+ (x0 + 3, y0 + 8),
+ cv2.FONT_HERSHEY_PLAIN,
+ font_scale,
+ (0, 0, 255),
+ thickness,
+ )
+ cv2.putText(
+ img,
+ f"col: {logic_point[2]}-{logic_point[3]}",
+ (x0 + 3, y0 + 18),
+ cv2.FONT_HERSHEY_PLAIN,
+ font_scale,
+ (0, 0, 255),
+ thickness,
+ )
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ # 保存绘制后的图像
+ cv2.imwrite(output_path, img)
+
@staticmethod
def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
img_copy = img.copy()
diff --git a/tests/test_table.py b/tests/test_table.py
index 8e69a3c..41a2060 100644
--- a/tests/test_table.py
+++ b/tests/test_table.py
@@ -29,3 +29,7 @@ def test_ocr_input():
def test_input_ocr_none():
table_html_str, table_cell_bboxes, elapse = table_engine(img_path)
assert table_html_str.count("") == 16
+
+def test_logic_points_out():
+ table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, return_logic_points=True)
+ assert len(table_cell_bboxes) == len(logic_points)
From e97025a9f672327d5a279ceaecbb6bdd5f8d32c0 Mon Sep 17 00:00:00 2001
From: Jokcer <519548295@qq.com>
Date: Mon, 25 Nov 2024 22:05:12 +0800
Subject: [PATCH 2/2] chore: update readme
---
README.md | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 34a870a..6a62216 100644
--- a/README.md
+++ b/README.md
@@ -44,7 +44,7 @@ slanet_plus是paddlex内置的SLANet升级版模型,准确率有大幅提升
#### 2024.11.24 update
-- 支持gpu推理,适配 rapidOCR 单字识别匹配
+- 支持gpu推理,适配 rapidOCR 单字识别匹配,支持逻辑坐标返回及可视化
#### 2024.10.13 update
- 补充最新paddlex-SLANet-plus 模型(paddle2onnx原因暂不能支持onnx)
@@ -143,6 +143,11 @@ save_drawed_path = save_dir / f"vis_{Path(img_path).name}"
viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path)
+# 返回逻辑坐标
+# table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_result, return_logic_points=True)
+# save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}"
+# viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path,logic_points, save_logic_path)
+
print(table_html_str)
```