Skip to content

Commit

Permalink
Merge pull request #87 from RapidAI/optim_wired_table_rotated
Browse files Browse the repository at this point in the history
feat: optim rotated wired table rec
  • Loading branch information
Joker1212 authored Nov 28, 2024
2 parents 639f6f7 + 2eed579 commit 725777a
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 31 deletions.
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
[English](README_en.md) | 简体中文
</div>

### 最近更新
- **2024.11.12**
- 抽离模型识别和处理过程核心阈值,方便大家进行微调适配自己的场景[输入参数](#核心参数)
### 最近更新
- **2024.11.16**
- 补充文档扭曲矫正方案,可作为前置处理 [RapidUnwrap](https://github.com/Joker1212/RapidUnWrap)
- **2024.11.22**
- 支持单字符匹配方案,需要RapidOCR>=1.4.0
- **2024.11.28**
- wiredV2模型提升对轻度旋转表格识别准确率,参见[输入参数](#核心参数)

### 简介
💖该仓库是用来对文档中表格做结构化识别的推理库,包括来自阿里读光有线和无线表格识别模型,llaipython(微信)贡献的有线表格模型,网易Qanything内置表格分类模型等。\
Expand Down Expand Up @@ -132,6 +132,7 @@ ocr_res = trans_char_ocr_res(ocr_res)

#### 表格旋转及透视修正
##### 1.简单背景,小角度场景
最新wiredV2模型自适应小角度旋转
```python
import cv2

Expand Down Expand Up @@ -178,6 +179,9 @@ html, elasp, polygons, logic_points, ocr_res = wired_table_rec(
ocr_result, # 输入rapidOCR识别结果,不传默认使用内部rapidocr模型
version="v2", #默认使用v2线框模型,切换阿里读光模型可改为v1
enhance_box_line=True, # 识别框切割增强(关闭避免多余切割,开启减少漏切割),默认为True
col_threshold=15, # 识别框左边界x坐标差值小于col_threshold的默认同列
row_threshold=10, # 识别框上边界y坐标差值小于row_threshold的默认同行
rotated_fix=True, # wiredV2支持,轻度旋转(-45°~45°)矫正,默认为True
need_ocr=True, # 是否进行OCR识别, 默认为True
rec_again=True,# 是否针对未识别到文字的表格框,进行单独截取再识别,默认为True
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_wired_table_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_squeeze_bug():
ocr_result, _ = ocr_engine(img_path)
table_str, *_ = table_recog(str(img_path), ocr_result)
td_nums = get_td_nums(table_str)
assert td_nums >= 192
assert td_nums >= 160


@pytest.mark.parametrize(
Expand Down
10 changes: 8 additions & 2 deletions wired_table_rec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,23 @@ def __call__(
s = time.perf_counter()
rec_again = True
need_ocr = True
col_threshold = 15
row_threshold = 10
if kwargs:
rec_again = kwargs.get("rec_again", True)
need_ocr = kwargs.get("need_ocr", True)
col_threshold = kwargs.get("col_threshold", 15)
row_threshold = kwargs.get("row_threshold", 10)
img = self.load_img(img)
polygons = self.table_line_rec(img, **kwargs)
polygons, rotated_polygons = self.table_line_rec(img, **kwargs)
if polygons is None:
logging.warning("polygons is None.")
return "", 0.0, None, None, None

try:
table_res, logi_points = self.table_recover(polygons)
table_res, logi_points = self.table_recover(
rotated_polygons, row_threshold, col_threshold
)
# 将坐标由逆时针转为顺时针方向,后续处理与无线表格对齐
polygons[:, 1, :], polygons[:, 3, :] = (
polygons[:, 3, :].copy(),
Expand Down
10 changes: 6 additions & 4 deletions wired_table_rec/table_line_rec.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple

import cv2
import numpy as np
Expand Down Expand Up @@ -36,12 +36,14 @@ def __init__(self, model_path: Optional[str] = None):

self.session = OrtInferSession(model_path)

def __call__(self, img: np.ndarray, **kwargs) -> Optional[np.ndarray]:
def __call__(
self, img: np.ndarray, **kwargs
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
img_info = self.preprocess(img)
pred = self.infer(img_info)
polygons = self.postprocess(pred)
if polygons.size == 0:
return None
return None, None

polygons = polygons.reshape(polygons.shape[0], 4, 2)
del_idxs = filter_duplicated_box(
Expand All @@ -53,7 +55,7 @@ def __call__(self, img: np.ndarray, **kwargs) -> Optional[np.ndarray]:
)
polygons = polygons[idx]
polygons = merge_adjacent_polys(polygons)
return polygons
return polygons, polygons

def preprocess(self, img) -> Dict[str, Any]:
height, width = img.shape[:2]
Expand Down
180 changes: 170 additions & 10 deletions wired_table_rec/table_line_rec_plus.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import copy
import math
from typing import Optional, Dict, Any
from typing import Optional, Dict, Any, Tuple

import cv2
import numpy as np
from skimage import measure

import matplotlib.pyplot as plt
from wired_table_rec.utils import OrtInferSession, resize_img
from wired_table_rec.utils_table_line_rec import (
get_table_line,
Expand All @@ -31,22 +31,31 @@ def __init__(self, model_path: Optional[str] = None):

self.session = OrtInferSession(model_path)

def __call__(self, img: np.ndarray, **kwargs) -> Optional[np.ndarray]:
def __call__(
self, img: np.ndarray, **kwargs
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
img_info = self.preprocess(img)
pred = self.infer(img_info)
polygons = self.postprocess(img, pred, **kwargs)
polygons, rotated_polygons = self.postprocess(img, pred, **kwargs)
if polygons.size == 0:
return None
return None, None
polygons = polygons.reshape(polygons.shape[0], 4, 2)
polygons[:, 3, :], polygons[:, 1, :] = (
polygons[:, 1, :].copy(),
polygons[:, 3, :].copy(),
)
rotated_polygons = rotated_polygons.reshape(rotated_polygons.shape[0], 4, 2)
rotated_polygons[:, 3, :], rotated_polygons[:, 1, :] = (
rotated_polygons[:, 1, :].copy(),
rotated_polygons[:, 3, :].copy(),
)
_, idx = sorted_ocr_boxes(
[box_4_2_poly_to_box_4_1(poly_box) for poly_box in polygons], threhold=0.4
[box_4_2_poly_to_box_4_1(poly_box) for poly_box in rotated_polygons],
threhold=0.4,
)
polygons = polygons[idx]
return polygons
rotated_polygons = rotated_polygons[idx]
return polygons, rotated_polygons

def preprocess(self, img) -> Dict[str, Any]:
scale = (self.inp_height, self.inp_width)
Expand Down Expand Up @@ -86,7 +95,8 @@ def postprocess(self, img, pred, **kwargs):
extend_line = (
kwargs.get("extend_line", enhance_box_line) if kwargs else enhance_box_line
) # 是否进行线段延长使得端点连接

# 是否进行旋转修正
rotated_fix = kwargs.get("rotated_fix") if kwargs else True
ori_shape = img.shape
pred = np.uint8(pred)
hpred = copy.deepcopy(pred) # 横线
Expand Down Expand Up @@ -120,8 +130,109 @@ def postprocess(self, img, pred, **kwargs):
colboxes += rboxes_col_
if extend_line:
rowboxes, colboxes = final_adjust_lines(rowboxes, colboxes)
tmp = np.zeros(img.shape[:2], dtype="uint8")
tmp = draw_lines(tmp, rowboxes + colboxes, color=255, lineW=2)
line_img = np.zeros(img.shape[:2], dtype="uint8")
line_img = draw_lines(line_img, rowboxes + colboxes, color=255, lineW=2)
rotated_angle = self.cal_rotate_angle(line_img)
if rotated_fix and abs(rotated_angle) > 0.3:
rotated_line_img = self.rotate_image(line_img, rotated_angle)
rotated_polygons = self.cal_region_boxes(rotated_line_img)
polygons = self.unrotate_polygons(
rotated_polygons, rotated_angle, line_img.shape
)
else:
polygons = self.cal_region_boxes(line_img)
rotated_polygons = polygons.copy()
return polygons, rotated_polygons

def find_max_corners(self, line_img, visualize=False):
    """Return the corners of the min-area rectangle around the largest contour.

    Args:
        line_img: single-channel image passed straight to ``cv2.findContours``
            (presumably the binary table-line mask; confirm against caller).
        visualize: when True, plot the largest contour and the four corners
            with matplotlib. Off by default so that library callers are not
            blocked by ``plt.show()``.

    Returns:
        An empty list when no contour is found; otherwise a list of four
        integer points sorted by their angle around the box centre. For a
        roughly upright rectangle in image coordinates (y pointing down)
        this yields top-left, top-right, bottom-right, bottom-left.
    """
    # Only outer contours are considered.
    contours, _ = cv2.findContours(
        line_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    # Nothing detected: there are no corners to report.
    if not contours:
        return []

    # The contour with the largest area, and its minimum-area bounding box.
    max_contour = max(contours, key=cv2.contourArea)
    box = cv2.boxPoints(cv2.minAreaRect(max_contour))
    # np.int0 was an alias of np.intp and was removed in NumPy 2.0.
    box = box.astype(np.intp)

    # Order the corners by their polar angle around the box centre.
    center = np.mean(box, axis=0)
    angles = np.arctan2(box[:, 1] - center[1], box[:, 0] - center[0])
    top_left, top_right, bottom_right, bottom_left = box[np.argsort(angles)]
    corners = [top_left, top_right, bottom_right, bottom_left]

    if visualize:
        # Debug view: contour outline and corners on a black background.
        black_img = np.zeros_like(line_img)
        plt.figure(figsize=(10, 10))
        plt.imshow(black_img, cmap="gray")
        plt.title("Max Contour and Corners on Black Background")

        outline = max_contour.reshape(-1, 2)
        plt.plot(outline[:, 0], outline[:, 1], "b-", linewidth=2)

        plt.scatter(
            [point[0] for point in corners],
            [point[1] for point in corners],
            c="g",
            s=100,
            marker="o",
        )

        plt.axis("off")
        plt.show()

    return corners

def extend_image_and_adjust_coordinates(self, img, corners, polygons):
    """Pad *img* with zeros so every corner fits, and shift coordinates to match.

    Args:
        img: image array; the first two axes are treated as (height, width)
            and any trailing axes (e.g. channels) are preserved.
        corners: sequence of (x, y) points that may lie outside the image.
            An empty sequence means no padding is required.
        polygons: array whose rows hold interleaved x, y values
            (x at even columns, y at odd columns). It is not modified.

    Returns:
        Tuple ``(extended_img, adjusted_corners, adjusted_polygons)`` where
        the image is a new zero-padded array and both coordinate sets are
        translated by the left/top padding.
    """
    height, width = img.shape[:2]

    if len(corners) > 0:
        xs = [point[0] for point in corners]
        ys = [point[1] for point in corners]
        # Padding is the overshoot on each side, rounded up to whole pixels
        # so fractional corner coordinates still give a valid array shape.
        left = max(0, int(np.ceil(-min(xs))))
        top = max(0, int(np.ceil(-min(ys))))
        right = max(0, int(np.ceil(max(xs) - width)))
        bottom = max(0, int(np.ceil(max(ys) - height)))
    else:
        # No corners (e.g. no contour was found): nothing to extend.
        left = top = right = bottom = 0

    # Allocate the padded canvas, keeping any channel axes of the input.
    new_shape = (height + top + bottom, width + left + right) + tuple(img.shape[2:])
    extended_img = np.zeros(new_shape, dtype=img.dtype)
    extended_img[top : top + height, left : left + width] = img

    # Translate corners and polygons into the padded coordinate frame.
    adjusted_corners = [(point[0] + left, point[1] + top) for point in corners]
    adjusted_polygons = polygons.copy()
    if adjusted_polygons.size:
        adjusted_polygons[:, 0::2] += left
        adjusted_polygons[:, 1::2] += top
    return extended_img, adjusted_corners, adjusted_polygons

def cal_region_boxes(self, tmp):
labels = measure.label(tmp < 255, connectivity=2) # 8连通区域标记
regions = measure.regionprops(labels)
ceilboxes = min_area_rect_box(
Expand All @@ -133,3 +244,52 @@ def postprocess(self, img, pred, **kwargs):
adjust_box=False,
) # 最后一个参数改为False
return np.array(ceilboxes)

def cal_rotate_angle(self, tmp):
    """Estimate the skew angle of *tmp* from its largest outer contour.

    The angle reported by ``cv2.minAreaRect`` for the largest external
    contour is shifted by 90 degrees when it lies below -45 or above 45.
    Returns 0 when no contour is found.
    """
    outer_contours, _ = cv2.findContours(
        tmp, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if not outer_contours:
        return 0

    # Rotated bounding rectangle of the biggest outline; index 2 is its angle.
    biggest = max(outer_contours, key=cv2.contourArea)
    skew = cv2.minAreaRect(biggest)[2]

    if skew < -45:
        return skew + 90
    if skew > 45:
        return skew - 90
    return skew

def rotate_image(self, image, angle):
    """Rotate *image* by *angle* about its centre, keeping the canvas size.

    Uses nearest-neighbour sampling and replicates border pixels for areas
    that fall outside the source image.
    """
    height, width = image.shape[:2]
    pivot = (width // 2, height // 2)

    # 2x3 affine matrix for a rotation about the pivot with scale 1.0.
    rotation = cv2.getRotationMatrix2D(pivot, angle, 1.0)

    return cv2.warpAffine(
        image,
        rotation,
        (width, height),
        flags=cv2.INTER_NEAREST,
        borderMode=cv2.BORDER_REPLICATE,
    )

def unrotate_polygons(
    self, polygons: np.ndarray, angle: float, img_shape: tuple
) -> np.ndarray:
    """Rotate polygons by ``-angle`` about the image centre.

    Args:
        polygons: array reshapeable to (N, 4, 2); may be empty.
        angle: the angle that was passed to ``cv2.getRotationMatrix2D`` for
            the forward rotation; the opposite rotation is applied here.
        img_shape: image shape; only the first two entries (height, width)
            are used, so a shape with a channel axis is accepted too.

    Returns:
        Array of shape (N, 8) holding the transformed points. An empty
        input yields an empty (0, 8) array without calling OpenCV.
    """
    # (N, 8) -> (N, 4, 2) so each point is an (x, y) pair.
    polygons_reshaped = polygons.reshape(-1, 4, 2)

    # Nothing detected: skip the OpenCV call, which is not meant for
    # empty inputs, and let the caller's size check handle it.
    if polygons_reshaped.size == 0:
        return polygons_reshaped.reshape(-1, 8)

    # Same centre convention as the forward rotation (integer midpoint).
    h, w = img_shape[:2]
    center = (w // 2, h // 2)
    M_inv = cv2.getRotationMatrix2D(center, -angle, 1.0)

    # Apply the inverse affine transform to all points at once.
    unrotated_polygons = cv2.transform(polygons_reshaped, M_inv)

    # (N, 4, 2) -> (N, 8)
    return unrotated_polygons.reshape(-1, 8)
Loading

0 comments on commit 725777a

Please sign in to comment.