diff --git a/.github/workflows/lineless_table_rec.yml b/.github/workflows/lineless_table_rec.yml index fdd84e5..9009ac4 100644 --- a/.github/workflows/lineless_table_rec.yml +++ b/.github/workflows/lineless_table_rec.yml @@ -35,40 +35,40 @@ jobs: pytest tests/test_lineless_table_rec.py - GenerateWHL_PushPyPi: - needs: UnitTesting - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - - - name: Set up Python 3.7 - uses: actions/setup-python@v4 - with: - python-version: '3.7' - architecture: 'x64' - - - name: Run setup.py - run: | - pip install -r requirements.txt - python -m pip install --upgrade pip - pip install wheel get_pypi_latest_version - - wget https://github.com/RapidAI/TableStructureRec/releases/download/v0.0.0/lineless_table_rec_models.zip - unzip lineless_table_rec_models.zip - mv lineless_table_rec_models/*.onnx lineless_table_rec/models/ - - python setup_lineless.py bdist_wheel "${{ github.event.head_commit.message }}" - - # - name: Publish distribution 📦 to Test PyPI - # uses: pypa/gh-action-pypi-publish@v1.5.0 - # with: - # password: ${{ secrets.TEST_PYPI_API_TOKEN }} - # repository_url: https://test.pypi.org/legacy/ - # packages_dir: dist/ - - - name: Publish distribution 📦 to PyPI - uses: pypa/gh-action-pypi-publish@v1.5.0 - with: - password: ${{ secrets.PYPI_API_TOKEN }} - packages_dir: dist/ +# GenerateWHL_PushPyPi: +# needs: UnitTesting +# runs-on: ubuntu-latest +# +# steps: +# - uses: actions/checkout@v3 +# +# - name: Set up Python 3.7 +# uses: actions/setup-python@v4 +# with: +# python-version: '3.7' +# architecture: 'x64' +# +# - name: Run setup.py +# run: | +# pip install -r requirements.txt +# python -m pip install --upgrade pip +# pip install wheel get_pypi_latest_version +# +# wget https://github.com/RapidAI/TableStructureRec/releases/download/v0.0.0/lineless_table_rec_models.zip +# unzip lineless_table_rec_models.zip +# mv lineless_table_rec_models/*.onnx lineless_table_rec/models/ +# +# python setup_lineless.py bdist_wheel "${{ github.event.head_commit.message }}" +# +# # - name: Publish distribution 📦 to Test PyPI +# # uses: pypa/gh-action-pypi-publish@v1.5.0 +# # with: +# # password: ${{ secrets.TEST_PYPI_API_TOKEN }} +# # repository_url: https://test.pypi.org/legacy/ +# # packages_dir: dist/ +# +# - name: Publish distribution 📦 to PyPI +# uses: pypa/gh-action-pypi-publish@v1.5.0 +# with: +# password: ${{ secrets.PYPI_API_TOKEN }} +# packages_dir: dist/ diff --git a/.github/workflows/table_cls.yml b/.github/workflows/table_cls.yml index 8398aed..eb9768b 100644 --- a/.github/workflows/table_cls.yml +++ b/.github/workflows/table_cls.yml @@ -35,33 +35,33 @@ jobs: pytest tests/test_table_cls.py - GenerateWHL_PushPyPi: - needs: UnitTesting - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - - - name: Set up Python 3.10 - uses: actions/setup-python@v4 - with: - python-version: '3.10' - architecture: 'x64' - - - name: Run setup.py - run: | - pip install -r requirements.txt - python -m pip install --upgrade pip - pip install wheel get_pypi_latest_version - - wget https://github.com/RapidAI/TableStructureRec/releases/download/v0.0.0/table_cls_models.zip - unzip table_cls_models.zip - mv table_cls_models/*.onnx table_cls/models/ - - python setup_table_cls.py bdist_wheel "${{ github.event.head_commit.message }}" - - - name: Publish distribution 📦 to PyPI - uses: pypa/gh-action-pypi-publish@v1.5.0 - with: - password: ${{ secrets.TABLE_CLS }} - packages_dir: dist/ +# GenerateWHL_PushPyPi: +# needs: UnitTesting +# runs-on: ubuntu-latest +# +# 
steps: +# - uses: actions/checkout@v3 +# +# - name: Set up Python 3.10 +# uses: actions/setup-python@v4 +# with: +# python-version: '3.10' +# architecture: 'x64' +# +# - name: Run setup.py +# run: | +# pip install -r requirements.txt +# python -m pip install --upgrade pip +# pip install wheel get_pypi_latest_version +# +# wget https://github.com/RapidAI/TableStructureRec/releases/download/v0.0.0/table_cls_models.zip +# unzip table_cls_models.zip +# mv table_cls_models/*.onnx table_cls/models/ +# +# python setup_table_cls.py bdist_wheel "${{ github.event.head_commit.message }}" +# +# - name: Publish distribution 📦 to PyPI +# uses: pypa/gh-action-pypi-publish@v1.5.0 +# with: +# password: ${{ secrets.TABLE_CLS }} +# packages_dir: dist/ diff --git a/.github/workflows/wired_table_rec.yml b/.github/workflows/wired_table_rec.yml index 68fbeef..b114855 100644 --- a/.github/workflows/wired_table_rec.yml +++ b/.github/workflows/wired_table_rec.yml @@ -35,33 +35,33 @@ jobs: pytest tests/test_wired_table_rec.py - GenerateWHL_PushPyPi: - needs: UnitTesting - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - - - name: Set up Python 3.7 - uses: actions/setup-python@v4 - with: - python-version: '3.7' - architecture: 'x64' - - - name: Run setup.py - run: | - pip install -r requirements.txt - python -m pip install --upgrade pip - pip install wheel get_pypi_latest_version - - wget https://github.com/RapidAI/TableStructureRec/releases/download/v0.0.0/wired_table_rec_models.zip - unzip wired_table_rec_models.zip - mv wired_table_rec_models/*.onnx wired_table_rec/models/ - - python setup_wired.py bdist_wheel "${{ github.event.head_commit.message }}" - - - name: Publish distribution 📦 to PyPI - uses: pypa/gh-action-pypi-publish@v1.5.0 - with: - password: ${{ secrets.PYPI_API_TOKEN }} - packages_dir: dist/ +# GenerateWHL_PushPyPi: +# needs: UnitTesting +# runs-on: ubuntu-latest +# +# steps: +# - uses: actions/checkout@v3 +# +# - name: Set up Python 3.7 +# uses: actions/setup-python@v4 +# with: +# python-version: '3.7' +# architecture: 'x64' +# +# - name: Run setup.py +# run: | +# pip install -r requirements.txt +# python -m pip install --upgrade pip +# pip install wheel get_pypi_latest_version +# +# wget https://github.com/RapidAI/TableStructureRec/releases/download/v0.0.0/wired_table_rec_models.zip +# unzip wired_table_rec_models.zip +# mv wired_table_rec_models/*.onnx wired_table_rec/models/ +# +# python setup_wired.py bdist_wheel "${{ github.event.head_commit.message }}" +# +# - name: Publish distribution 📦 to PyPI +# uses: pypa/gh-action-pypi-publish@v1.5.0 +# with: +# password: ${{ secrets.PYPI_API_TOKEN }} +# packages_dir: dist/ diff --git a/README.md b/README.md index 89c2b04..a84692c 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@
-📊 Table Structure Recognition
+📊 表格结构识别
@@ -10,61 +10,125 @@ SemVer2.0 GitHub +
+ +### 简介 + +💖该仓库是用来对文档中表格做结构化识别的推理库,包括来自paddle的表格识别模型, +阿里读光有线和无线表格识别模型,llaipython(微信)贡献的有线表格模型,网易Qanything内置表格分类模型等。 + +#### 特点 +⚡ **快** 采用ONNXRuntime作为推理引擎,cpu下单图推理1-7s + +🎯 **准**: 结合表格类型分类模型,区分有线表格,无线表格,任务更细分,精度更高 - [简体中文](./docs/README_zh.md) | English +🛡️ **稳**: 不依赖任何第三方训练框架,只依赖必要基础库,避免包冲突 + +### 效果展示 +
+ Demo
-### Introduction +### 指标结果 +[TableRecognitionMetric 评测工具](https://github.com/SWHL/TableRecognitionMetric) [评测数据集](https://huggingface.co/datasets/SWHL/table_rec_test_dataset) [Rapid OCR](https://github.com/RapidAI/RapidOCR) -This repo is an inference library used for structured recognition of tables in documents, including table structure recognition algorithm models from PaddleOCR, wired and wireless table recognition algorithm models from Alibaba Duguang, etc. +| 方法 | TEDS | +|:---------------------------------------------------------------------------------------------------------------------------|:----:| +| lineless_table_rec | 0.53561 | +| [RapidTable](https://github.com/RapidAI/RapidStructure/blob/b800b156015bf5cd6f5429295cdf48be682fd97e/docs/README_Table.md) | 0.58786 | +| wired_table_rec v1 | 0.70279 | +| wired_table_rec v2 | 0.78007 | +| table_cls + wired_table_rec v1 + lineless_table_rec | 0.74692 | +| table_cls + wired_table_rec v2 + lineless_table_rec |0.80235| -The repo has improved the pre- and post-processing of form recognition and combined with OCR to ensure that the form recognition part can be used directly. +### 安装 +``` python {linenos=table} +pip install wired_table_rec lineless_table_rec table_cls +``` -The repo will continue to focus on the field of table recognition, integrate the latest and most useful table recognition algorithms, and strive to create the most valuable table recognition tool library. +### 快速使用 +``` python {linenos=table} +import os -Welcome everyone to continue to pay attention. +from lineless_table_rec import LinelessTableRecognition +from lineless_table_rec.utils_table_recover import format_html, plot_rec_box_with_logic_info, plot_rec_box +from table_cls import TableCls +from wired_table_rec import WiredTableRecognition -### What is Table Structure Recognition? +lineless_engine = LinelessTableRecognition() +wired_engine = WiredTableRecognition() +table_cls = TableCls() +img_path = f'images/img14.jpg' -Table Structure Recognition (TSR) aims to extract the logical or physical structure of table images, thereby converting unstructured table images into machine-readable formats. +cls,elasp = table_cls(img_path) +if cls == 'wired': + table_engine = wired_engine +else: + table_engine = lineless_engine +html, elasp, polygons, logic_points, ocr_res = table_engine(img_path) +print(f"elasp: {elasp}") -Logical structure: represents the row/column relationship of cells (such as the same row, the same column) and the span information of cells. +# output_dir = f'outputs' +# complete_html = format_html(html) +# os.makedirs(os.path.dirname(f"{output_dir}/table.html"), exist_ok=True) +# with open(f"{output_dir}/table.html", "w", encoding="utf-8") as file: +# file.write(complete_html) +# # 可视化表格识别框 + 逻辑行列信息 +# plot_rec_box_with_logic_info( +# img_path, f"{output_dir}/table_rec_box.jpg", logic_points, polygons +# ) +# # 可视化 ocr 识别框 +# plot_rec_box(img_path, f"{output_dir}/ocr_box.jpg", ocr_res) +``` -Physical structure: includes not only the logical structure, but also the cell's bounding box, content and other information, emphasizing the physical location of the cell. +## FAQ (Frequently Asked Questions) -
+1. **问:偏移的图片能够处理吗?** + - 答:该项目暂时不支持偏移图片识别,请先修正图片,也欢迎提pr来解决这个问题。 -Figure from: [Improving Table Structure Recognition with Visual-Alignment Sequential Coordinate Modeling](https://openaccess.thecvf.com/content/CVPR2023/html/Huang_Improving_Table_Structure_Recognition_With_Visual-Alignment_Sequential_Coordinate_Modeling_CVPR_2023_paper.html) +2. **问:识别框丢失了内部文字信息** + - 答:默认使用的rapidocr小模型,如果需要更高精度的效果,可以从 [模型列表](https://rapidai.github.io/RapidOCRDocs/model_list/#_1) + 下载更高精度的ocr模型,在执行时传入ocr_result即可 + +3. **问:模型支持 gpu 加速吗?** + - 答:目前表格模型的推理非常快,有线表格在100ms级别,无线表格在500ms级别, + 主要耗时在ocr阶段,可以参考 [rapidocr_paddle](https://rapidai.github.io/RapidOCRDocs/install_usage/rapidocr_paddle/usage/#_3) 加速ocr识别过程 -### Documentation +### TODO List +- [ ] 识别前图片偏移修正 +- [ ] 增加数据集数量,增加更多评测对比 +- [ ] 优化无线表格模型 -Full documentation can be found on [docs](https://rapidai.github.io/TableStructureRec/docs/), in Chinese. +### 处理流程 +```mermaid +flowchart TD + A[/表格图片/] --> B([表格分类]) + B --> C([有线表格识别]) & D([无线表格识别]) --> E([文字识别 rapidocr_onnxruntime]) + E --> F[/html结构化输出/] +``` -### Acknowledgements +### 致谢 -[PaddleOCR Table](https://github.com/PaddlePaddle/PaddleOCR/blob/4b17511491adcfd0f3e2970895d06814d1ce56cc/ppstructure/table/README_ch.md) +[PaddleOCR 表格识别](https://github.com/PaddlePaddle/PaddleOCR/blob/4b17511491adcfd0f3e2970895d06814d1ce56cc/ppstructure/table/README_ch.md) -[Cycle CenterNet](https://www.modelscope.cn/models/damo/cv_dla34_table-structure-recognition_cycle-centernet/summary) +[读光-表格结构识别-有线表格](https://www.modelscope.cn/models/damo/cv_dla34_table-structure-recognition_cycle-centernet/summary) -[LORE](https://www.modelscope.cn/models/damo/cv_resnet-transformer_table-structure-recognition_lore/summary) +[读光-表格结构识别-无线表格](https://www.modelscope.cn/models/damo/cv_resnet-transformer_table-structure-recognition_lore/summary) -### Contributing +[Qanything-RAG](https://github.com/netease-youdao/QAnything) -Pull requests are welcome. For major changes, please open an issue first -to discuss what you would like to change. +非常感谢 llaipython(微信,提供全套有偿高精度表格提取) 提供高精度有线表格模型。 -Please make sure to update tests as appropriate. +### 贡献指南 -### [Sponsor](https://rapidai.github.io/Knowledge-QA-LLM/docs/sponsor/) +欢迎提交请求。对于重大更改,请先打开issue讨论您想要改变的内容。 -If you want to sponsor the project, you can directly click the **Buy me a coffee** image, please write a note (e.g. your github account name) to facilitate adding to the sponsorship list below. +请确保适当更新测试。 -
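FAQ 2 and 3 above both come down to supplying your own `ocr_result` instead of relying on the engine's built-in rapidocr call. A minimal sketch of that plumbing, modelled on tests/test_wired_table_rec.py in this patch — the image path and the choice of OCR engine are only examples, not part of the patch itself:

```python
from rapidocr_onnxruntime import RapidOCR
from wired_table_rec import WiredTableRecognition

# Any OCR engine that yields [[box(4 corner points), text, score], ...] will do;
# per FAQ 3, rapidocr_paddle could be swapped in here for GPU acceleration.
ocr_engine = RapidOCR()
table_engine = WiredTableRecognition()

img_path = "tests/test_files/wired/table_recognition.jpg"  # example image only
ocr_result, _ = ocr_engine(img_path)

# Passing ocr_result skips the engine's internal OCR step (see __call__ in
# wired_table_rec/main.py); LinelessTableRecognition accepts the same keyword.
html, elapse, polygons, logic_points, ocr_boxes = table_engine(
    img_path, ocr_result=ocr_result
)
print(f"elapse: {elapse:.3f}s")
print(html)
```

When routing through `TableCls` as in the quick start, the same precomputed `ocr_result` can be fed to whichever engine the classifier selects.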
+### [赞助](https://rapidai.github.io/Knowledge-QA-LLM/docs/sponsor/) + +如果您想要赞助该项目,可直接点击当前页最上面的Sponsor按钮,请写好备注(**您的Github账号名称**),方便添加到赞助列表中。 -### License +### 开源许可证 -This project is released under the [Apache 2.0 license](https://github.com/RapidAI/TableStructureRec/blob/c41bbd23898cb27a957ed962b0ffee3c74dfeff1/LICENSE). +该项目采用[Apache 2.0](https://github.com/RapidAI/TableStructureRec/blob/c41bbd23898cb27a957ed962b0ffee3c74dfeff1/LICENSE)开源许可证。 diff --git a/demo_wired.py b/demo_wired.py index cc09e68..db2f23d 100644 --- a/demo_wired.py +++ b/demo_wired.py @@ -1,16 +1,26 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com -from pathlib import Path - +import os +from lineless_table_rec.utils_table_recover import format_html from wired_table_rec import WiredTableRecognition +from wired_table_rec.utils_table_recover import ( + plot_rec_box, + plot_rec_box_with_logic_info, +) +output_dir = "outputs" table_rec = WiredTableRecognition() -img_path = "tests/test_files/wired/squeeze_error.jpeg" -table_str, elapse = table_rec(img_path) -print(table_str) -print(elapse) +img_path = "tests/test_files/wired/table1.png" +html, elasp, polygons, logic_points, ocr_res = table_rec(img_path) +print(f"cost: {elasp:.5f}") +complete_html = format_html(html) +os.makedirs(os.path.dirname(f"{output_dir}/table.html"), exist_ok=True) +with open(f"{output_dir}/table.html", "w", encoding="utf-8") as file: + file.write(complete_html) -with open(f"{Path(img_path).stem}.html", "w", encoding="utf-8") as f: - f.write(table_str) +plot_rec_box_with_logic_info( + img_path, f"{output_dir}/table_rec_box.jpg", logic_points, polygons +) +plot_rec_box(img_path, f"{output_dir}/ocr_box.jpg", ocr_res) diff --git a/docs/README_zh.md b/docs/README_zh.md deleted file mode 100644 index 12af069..0000000 --- a/docs/README_zh.md +++ /dev/null @@ -1,65 +0,0 @@ -
-📊 表格结构识别
-PyPI / SemVer2.0 / GitHub badges
-简体中文 | [English](https://github.com/RapidAI/TableStructureRec)
- -### 简介 - -该仓库是用来对文档中表格做结构化识别的推理库,包括来自PaddleOCR的表格结构识别算法模型、来自阿里读光有线和无线表格识别算法模型等。 - -该仓库将表格识别前后处理做了完善,并结合OCR,保证表格识别部分可直接使用。 - -该仓库会持续关注表格识别这一领域,集成最新最好用的表格识别算法,争取打造最具有落地价值的表格识别工具库。 - -欢迎大家持续关注。 - -### 表格结构化识别 - -表格结构识别(Table Structure Recognition, TSR)旨在提取表格图像的逻辑或物理结构,从而将非结构化的表格图像转换为机器可读的格式。 - -逻辑结构:表示单元格的行/列关系(例如同行、同列)和单元格的跨度信息。 - -物理结构:不仅包含逻辑结构,还包含单元格的包围框、内容等信息,强调单元格的物理位置。 - -
- -
- -图来自: [Improving Table Structure Recognition with Visual-Alignment Sequential Coordinate Modeling](https://openaccess.thecvf.com/content/CVPR2023/html/Huang_Improving_Table_Structure_Recognition_With_Visual-Alignment_Sequential_Coordinate_Modeling_CVPR_2023_paper.html) - -### 文档 - -完整文档请移步:[docs](https://rapidai.github.io/TableStructureRec/docs/) - -### 致谢 - -[PaddleOCR 表格识别](https://github.com/PaddlePaddle/PaddleOCR/blob/4b17511491adcfd0f3e2970895d06814d1ce56cc/ppstructure/table/README_ch.md) - -[读光-表格结构识别-有线表格](https://www.modelscope.cn/models/damo/cv_dla34_table-structure-recognition_cycle-centernet/summary) - -[读光-表格结构识别-无线表格](https://www.modelscope.cn/models/damo/cv_resnet-transformer_table-structure-recognition_lore/summary) - -### 贡献指南 - -欢迎提交请求。对于重大更改,请先打开issue讨论您想要改变的内容。 - -请确保适当更新测试。 - -### [赞助](https://rapidai.github.io/Knowledge-QA-LLM/docs/sponsor/) - -如果您想要赞助该项目,可直接点击当前页最上面的Sponsor按钮,请写好备注(**您的Github账号名称**),方便添加到赞助列表中。 - -### 开源许可证 - -该项目采用[Apache 2.0](https://github.com/RapidAI/TableStructureRec/blob/c41bbd23898cb27a957ed962b0ffee3c74dfeff1/LICENSE)开源许可证。 diff --git a/lineless_table_rec/main.py b/lineless_table_rec/main.py index 8741c82..62a4e6b 100644 --- a/lineless_table_rec/main.py +++ b/lineless_table_rec/main.py @@ -5,7 +5,7 @@ import time import traceback from pathlib import Path -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Tuple, Union, Optional import cv2 import numpy as np @@ -47,16 +47,25 @@ def __init__( self.det_process = DetProcess() self.ocr = RapidOCR() - def __call__(self, content: InputType): + def __call__( + self, + content: InputType, + ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None, + ): ss = time.perf_counter() img = self.load_img(content) - ocr_res, _ = self.ocr(img) + if self.ocr is None and ocr_result is None: + raise ValueError( + "One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed." 
+ ) + if ocr_result is None: + ocr_result, _ = self.ocr(img) input_info = self.preprocess(img) try: polygons, slct_logi = self.infer(input_info) logi_points = self.filter_logi_points(slct_logi) # ocr 结果匹配 - cell_box_det_map, no_match_ocr_det = match_ocr_cell(ocr_res, polygons) + cell_box_det_map, no_match_ocr_det = match_ocr_cell(ocr_result, polygons) # 如果有识别框没有ocr结果,直接进行rec补充 cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map) # 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理 @@ -92,7 +101,9 @@ def __call__(self, content: InputType): sorted_logi_points = [ t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list ] - ocr_boxes_res = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res] + ocr_boxes_res = [ + box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_result + ] sorted_ocr_boxes_res, _ = sorted_ocr_boxes(ocr_boxes_res) table_elapse = time.perf_counter() - ss return ( @@ -181,21 +192,19 @@ def infer(self, input_content: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]: def sort_and_gather_ocr_res(self, res): for i, dict_res in enumerate(res): _, sorted_idx = sorted_ocr_boxes( - [ocr_det[0] for ocr_det in dict_res["t_ocr_res"]] + [ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threhold=0.5 ) dict_res["t_ocr_res"] = [dict_res["t_ocr_res"][i] for i in sorted_idx] - dict_res["t_ocr_res"] = gather_ocr_list_by_row(dict_res["t_ocr_res"]) + dict_res["t_ocr_res"] = gather_ocr_list_by_row( + dict_res["t_ocr_res"], thehold=0.5 + ) return res def handle_overlap_row_col(self, res): max_row, max_col = 0, 0 for dict_res in res: - max_row = max( - max_row, dict_res["t_logic_box"][1] + 1 - ) # 加1是因为结束下标是包含在内的 - max_col = max( - max_col, dict_res["t_logic_box"][3] + 1 - ) # 加1是因为结束下标是包含在内的 + max_row = max(max_row, dict_res["t_logic_box"][1] + 1) # 加1是因为结束下标是包含在内的 + max_col = max(max_col, dict_res["t_logic_box"][3] + 1) # 加1是因为结束下标是包含在内的 # 创建一个二维数组来存储 sorted_logi_points 中的元素 grid = [[None] * max_col for _ in range(max_row)] diff --git a/lineless_table_rec/utils_table_recover.py b/lineless_table_rec/utils_table_recover.py index 98e3f9a..670364b 100644 --- a/lineless_table_rec/utils_table_recover.py +++ b/lineless_table_rec/utils_table_recover.py @@ -8,7 +8,6 @@ import cv2 import numpy as np import shapely -from numpy import ndarray from shapely.geometry import MultiPoint, Polygon @@ -196,7 +195,7 @@ def is_box_contained( def is_single_axis_contained( - box1: list | np.ndarray, box2: list | np.ndarray, axis="x", threshold=0.2 + box1: list | np.ndarray, box2: list | np.ndarray, axis="x", threhold=0.2 ) -> int | None: """ :param box1: Iterable [xmin,ymin,xmax,ymax] @@ -221,15 +220,15 @@ def is_single_axis_contained( ratio_b1 = b1_outside_area / b1_area if b1_area > 0 else 0 ratio_b2 = b2_outside_area / b2_area if b2_area > 0 else 0 - if ratio_b1 < threshold: + if ratio_b1 < threhold: return 1 - if ratio_b2 < threshold: + if ratio_b2 < threhold: return 2 return None def sorted_ocr_boxes( - dt_boxes: np.ndarray | list, + dt_boxes: np.ndarray | list, threhold: float = 0.2 ) -> tuple[np.ndarray | list, list[int]]: """ Sort text boxes in order from top to bottom, left to right @@ -249,10 +248,17 @@ def sorted_ocr_boxes( # 避免输出和输入格式不对应,与函数功能不符合 if isinstance(dt_boxes, np.ndarray): _boxes = np.array(_boxes) + threahold = 20 for i in range(num_boxes - 1): for j in range(i, -1, -1): - c_idx = is_single_axis_contained(_boxes[j], _boxes[j + 1], axis="y") - if c_idx is not None and _boxes[j + 1][0] < _boxes[j][0]: + c_idx = is_single_axis_contained( + _boxes[j], _boxes[j + 1], axis="y", 
threhold=threhold + ) + if ( + c_idx is not None + and _boxes[j + 1][0] < _boxes[j][0] + and abs(_boxes[j][1] - _boxes[j + 1][1]) < threahold + ): _boxes[j], _boxes[j + 1] = _boxes[j + 1].copy(), _boxes[j].copy() indices[j], indices[j + 1] = indices[j + 1], indices[j] else: @@ -261,12 +267,13 @@ def sorted_ocr_boxes( def gather_ocr_list_by_row( - ocr_list: list[list[list[float], str]] + ocr_list: list[list[list[float], str]], thehold: float = 0.2 ) -> list[list[list[float], str]]: """ :param ocr_list: [[[xmin,ymin,xmax,ymax], text]] :return: """ + threshold = 10 for i in range(len(ocr_list)): if not ocr_list[i]: continue @@ -278,9 +285,13 @@ def gather_ocr_list_by_row( next = ocr_list[j] cur_box = cur[0] next_box = next[0] - c_idx = is_single_axis_contained(cur[0], next[0], axis="y") + c_idx = is_single_axis_contained( + cur[0], next[0], axis="y", threhold=thehold + ) if c_idx: - cur[1] = cur[1] + next[1] + dis = max(next_box[0] - cur_box[1], 0) + blank_str = int(dis / threshold) * " " + cur[1] = cur[1] + blank_str + next[1] xmin = min(cur_box[0], next_box[0]) xmax = max(cur_box[2], next_box[2]) ymin = min(cur_box[1], next_box[1]) @@ -299,7 +310,7 @@ def box_4_1_poly_to_box_4_2(poly_box: list | np.ndarray) -> list[list[float]]: return [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]] -def box_4_2_poly_to_box_4_1(poly_box: list | np.ndarray) -> list[ndarray[Any, Any]]: +def box_4_2_poly_to_box_4_1(poly_box: list | np.ndarray) -> list[float]: """ 将poly_box转换为box_4_1 :param poly_box: @@ -409,6 +420,10 @@ def plot_html_table( # 创建一个二维数组来存储 sorted_logi_points 中的元素 grid = [[None] * max_col for _ in range(max_row)] + valid_start_row = (1 << 16) - 1 + valid_end_row = 0 + valid_start_col = (1 << 16) - 1 + valid_end_col = 0 # 将 sorted_logi_points 中的元素填充到 grid 中 for i, logic_point in enumerate(logi_points): row_start, row_end, col_start, col_end = ( @@ -417,43 +432,48 @@ def plot_html_table( logic_point[2], logic_point[3], ) + ocr_rec_text_list = cell_box_map.get(i) + if ocr_rec_text_list and "".join(ocr_rec_text_list): + valid_start_row = min(row_start, valid_start_row) + valid_start_col = min(col_start, valid_start_col) + valid_end_row = max(row_end, valid_end_row) + valid_end_col = max(col_end, valid_end_col) for row in range(row_start, row_end + 1): for col in range(col_start, col_end + 1): grid[row][col] = (i, row_start, row_end, col_start, col_end) # 创建表格 - table_html = "\n" + table_html = "
" # 遍历每行 for row in range(max_row): - empty_temp = True - temp = " \n" - + if row < valid_start_row or row > valid_end_row: + continue + temp = "" # 遍历每一列 for col in range(max_col): + if col < valid_start_col or col > valid_end_col: + continue if not grid[row][col]: - temp += " \n" + temp += "" else: i, row_start, row_end, col_start, col_end = grid[row][col] if not cell_box_map.get(i): continue - empty_temp = False if row == row_start and col == col_start: ocr_rec_text = cell_box_map.get(i) text = "
".join(ocr_rec_text) - if not text.strip(): - continue # 如果是起始单元格 row_span = row_end - row_start + 1 col_span = col_end - col_start + 1 cell_content = ( - f"\n" + f"" ) temp += cell_content - if not empty_temp: - table_html = table_html + temp + " \n" - table_html += "
{text}{text}
" + table_html = table_html + temp + "" + + table_html += "" return table_html diff --git a/requirements.txt b/requirements.txt index ff167db..31b4439 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,7 @@ numpy>=1.21.6 onnxruntime>=1.14.1 opencv_python +scipy +scikit-image Shapely rapidocr_onnxruntime>=1.3.8 diff --git a/tests/test_lineless_table_rec.py b/tests/test_lineless_table_rec.py index 6151c36..55453cf 100644 --- a/tests/test_lineless_table_rec.py +++ b/tests/test_lineless_table_rec.py @@ -21,8 +21,8 @@ @pytest.mark.parametrize( "img_path, table_str_len, td_nums", [ - ("lineless_table_recognition.jpg", 1869, 104), - ("table.jpg", 3000, 158), + ("lineless_table_recognition.jpg", 1840, 108), + ("table.jpg", 2870, 160), ], ) def test_input_normal(img_path, table_str_len, td_nums): @@ -182,3 +182,65 @@ def test_filter_duplicated_box(table_boxes, expected_delete_idx): assert ( delete_idx == expected_delete_idx ), f"Expected {expected_delete_idx}, but got {delete_idx}" + + +@pytest.mark.parametrize( + "logi_points, cell_box_map, expected_html", + [ + # 测试空输入 + ([], {}, "
"), + # 测试单个单元格,包含rowspan和colspan + ( + [[0, 0, 0, 0]], + {0: ["Cell 1"]}, + "
Cell 1
", + ), + # 测试多个独立单元格 + ( + [[0, 0, 0, 0], [1, 1, 1, 1]], + {0: ["Cell 1"], 1: ["Cell 2"]}, + "
Cell 1
Cell 2
", + ), + # 测试跨行的单元格 + ( + [[0, 1, 0, 0]], + {0: ["Row 1 Col 1", "Row 2 Col 1"]}, + "
Row 1 Col 1
Row 2 Col 1
", + ), + # 测试跨列的单元格 + ( + [[0, 0, 0, 1]], + {0: ["Col 1 Row 1", "Col 2 Row 1"]}, + "
Col 1 Row 1
Col 2 Row 1
", + ), + # 测试跨多行多列的单元格 + ( + [[0, 1, 0, 1]], + {0: ["Row 1 Col 1", "Row 2 Col 1"]}, + "
Row 1 Col 1
Row 2 Col 1
", + ), + # 测试跨行跨行跨列的单元格出现在中间 + ( + [[0, 0, 0, 0], [0, 1, 1, 2]], + {0: ["Cell 1"], 1: ["Row 2", "Col 2"]}, + "
Cell 1Row 2
Col 2
", + ), + # 测试跨行跨列的单元格出现在结尾 + ( + [[0, 0, 0, 0], [1, 1, 1, 1], [0, 1, 2, 2]], + {0: ["Cell 1"], 1: ["Cell 2"], 2: ["Row 1 Col 2", "Row 2 Col 2"]}, + "
Cell 1Row 1 Col 2
Row 2 Col 2
Cell 2
", + ), + # 测试去除无效行和无效列 + ( + [[0, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 1], [0, 0, 1, 2]], + {2: ["Row 3 Col 1", "Row 3 Col 2"]}, + "
Row 3 Col 1
Row 3 Col 2
", + ), + ], +) +def test_plot_html_table(logi_points, cell_box_map, expected_html): + html_output = plot_html_table(logi_points, cell_box_map) + assert ( + html_output == expected_html + ), f"Expected HTML does not match. Got: {html_output}" diff --git a/tests/test_wired_table_line_util.py b/tests/test_wired_table_line_util.py new file mode 100644 index 0000000..9d22bc3 --- /dev/null +++ b/tests/test_wired_table_line_util.py @@ -0,0 +1,218 @@ +import pytest +import numpy as np +from wired_table_rec.utils_table_line_rec import ( + _order_points, + calculate_center_rotate_angle, + fit_line, + line_to_line, + min_area_rect, + adjust_lines, +) + + +@pytest.mark.parametrize( + "pts, expected", + [ + # 顺时针顺序正确,无需排序 + ( + np.array([[10, 10], [20, 10], [20, 20], [10, 20]]), + np.array([[10, 10], [20, 10], [20, 20], [10, 20]], dtype="float32"), + ), + # 完全相反顺序,进行重排序 + ( + np.array([[20, 10], [20, 20], [10, 20], [10, 10]]), + np.array([[10, 10], [20, 10], [20, 20], [10, 20]], dtype="float32"), + ), + # 部分错位顺序,重排序 + ( + np.array([[10, 20], [20, 20], [20, 10], [10, 10]]), + np.array([[10, 10], [20, 10], [20, 20], [10, 20]], dtype="float32"), + ), + ], +) +def test_order_points(pts, expected): + """ + 排序后得到[(xmin,ymin),(xmax,ymin),(xmax,ymax),(xmin,ymax)] + """ + result = _order_points(pts) + assert np.allclose(result, expected) + + +@pytest.mark.parametrize( + "box, expected_angle, expected_w, expected_h, expected_cx, expected_cy", + [ + # 沿中心点无旋转 + ([10, 10, 20, 10, 20, 20, 10, 20], 0.0, 10.0, 10.0, 15.0, 15.0), + # 沿中心点有旋转30度 + ( + [ + 13.16987, + 8.1698, + 21.830, + 13.16987, + 16.830127018922195, + 21.83012701892219, + 8.169872981077807, + 16.830127018922195, + ], + np.pi / 6, + 10.0, + 10.0, + 15.0, + 15.0, + ), + ], +) +def test_calculate_center_rotate_angle( + box, expected_angle, expected_w, expected_h, expected_cx, expected_cy +): + angle, w, h, cx, cy = calculate_center_rotate_angle(box) + assert np.isclose(angle, expected_angle, atol=1e-5) + assert np.isclose(w, expected_w, atol=1e-5) + assert np.isclose(h, expected_h, atol=1e-5) + assert np.isclose(cx, expected_cx, atol=1e-5) + assert np.isclose(cy, expected_cy, atol=1e-5) + + +# 测试函数 +@pytest.mark.parametrize( + "points, expected_A, expected_B, expected_C", + [ + # 根据两个点计算直线方程的参数 + ([(0, 0), (1, 1)], 1, -1, 0) + ], +) +def test_fit_line(points, expected_A, expected_B, expected_C): + A, B, C = fit_line(points) + assert np.isclose(A, expected_A, atol=1e-5) + assert np.isclose(B, expected_B, atol=1e-5) + assert np.isclose(C, expected_C, atol=1e-5) + + +@pytest.mark.parametrize( + "points1, points2, expected_result", + [ + # 横线在竖线同边,无角度偏移,延长第二个点到相交点 + ([0, 0, 0.9, 0], [1, 0, 1, 1], np.array([0, 0, 1, 0], dtype="float32")), + # 横线在竖线同边,有角度偏移,延长第一个点到相交点 + ([4, 3, 0, 0], [8, 0, 8, 8], np.array([8, 6, 0, 0], dtype="float32")), + # 横线在竖线异边,不进行延伸 + ([0, 0, 2, 1], [1, 0, 1, 1], np.array([0, 0, 2, 1], dtype="float32")), + # 超过偏移角度,不进行延伸 + ([0, 0, 0.9, 0.9], [1, 0, 1, 4], np.array([0, 0, 0.9, 0.9], dtype="float32")), + # 超过交点绝对值长度,不进行延伸 + ([4, 3, 0, 0], [50, 0, 50, 50], np.array([4, 3, 0, 0], dtype="float32")) + # + ], +) +def test_line_to_line(points1, points2, expected_result): + # 为测试方便,提高角度阈值到60度 + result = line_to_line(points1, points2, angle=38) + assert np.allclose(result, expected_result, atol=1e-5) + + +@pytest.mark.parametrize( + "coords, expected_result", + [ + # 竖线求最小外接矩形 + ( + np.array([[0, 1000], [10, 1000], [10, 1002], [20, 1002]]), + [1000, 0, 1002, 20], + ), + # 横线求最小外接矩形 + ( + np.array([[1000, 0], [1000, 10], [1002, 15], [1001, 
30]]), + [0, 1000, 30, 1000], + ), + ], +) +def test_min_area_rect(coords, expected_result): + result = min_area_rect(coords) + assert np.allclose(result, expected_result, atol=2) + + +@pytest.mark.parametrize( + "lines, alph, angle, expected_result", + [ + # 每个坐标点都能合并 + ( + [(0, 0, 1, 0), (1, 0, 2, 0)], + # alph: 最大允许距离 + 50, + # angle: 角度阈值 + 50, + # 预期结果:两两合并 + [ + (0, 0, 1, 0), + (0, 0, 2, 0), + (1, 0, 1, 0), + (1, 0, 2, 0), + (1, 0, 0, 0), + (1, 0, 1, 0), + (2, 0, 0, 0), + (2, 0, 1, 0), + ], + ), + # y轴重叠过大不合并 + ( + [(0, 0.5, 0, 1.8), (0, 1, 0, 2)], + # alph: 最大允许距离 + 50, + # angle: 角度阈值 + 50, + [], + ), + # x轴重叠过大不合并 + ( + [(1, 0, 2, 0), (0, 0, 1.8, 0)], + # alph: 最大允许距离 + 50, + # angle: 角度阈值 + 50, + [], + ), + # 距离超过阈值不合并 + ( + [(0, 0, 1, 0), (11, 0, 13, 0)], + # alph: 最大允许距离 + 10, + # angle: 角度阈值 + 50, + ([]), + ), + # 角度超过阈值不合并 + ( + # 横线距离足够近 + [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)], + # alph: 最大允许距离 + 100, + # angle: 角度阈值 + 35, + # 预期结果:只有边界角度为0能合并 + ([(1, 1, 1, 1), (1, 1, 1, 1), (2, 2, 2, 2), (2, 2, 2, 2)]), + ), + # 多段合并,角度过滤,距离过滤同时存在,且有可以合并的点 + ( + [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 100, 100)], + # alph: 最大允许距离 + 50, + # angle: 角度阈值 + 30, + # 预期结果:多条竖线合并为一条线 + ([(1, 1, 1, 1), (1, 1, 1, 1), (2, 2, 2, 2), (2, 2, 2, 2)]), + ), + # 只有一条线 + ( + [(0, 0, 1, 0)], + # alph: 最大允许距离 + 50, + # angle: 角度阈值 + 50, + # 预期结果:横线不变 + ([]), + ), + ], +) +def test_adjust_lines(lines, alph, angle, expected_result): + result = adjust_lines(lines, alph, angle) + assert result == expected_result diff --git a/tests/test_wired_table_rec.py b/tests/test_wired_table_rec.py index bc19d62..37d46e0 100644 --- a/tests/test_wired_table_rec.py +++ b/tests/test_wired_table_rec.py @@ -3,11 +3,20 @@ # @Contact: liekkaskono@163.com import sys from pathlib import Path - +import numpy as np import pytest from bs4 import BeautifulSoup from rapidocr_onnxruntime import RapidOCR +from wired_table_rec.utils import rescale_size +from wired_table_rec.utils_table_recover import ( + plot_html_table, + is_single_axis_contained, + gather_ocr_list_by_row, + sorted_ocr_boxes, + is_box_contained, +) + cur_dir = Path(__file__).resolve().parent root_dir = cur_dir.parent @@ -30,16 +39,16 @@ def get_td_nums(html: str) -> int: def test_squeeze_bug(): img_path = test_file_dir / "squeeze_error.jpeg" ocr_result, _ = ocr_engine(img_path) - table_str, _ = table_recog(str(img_path), ocr_result) + table_str, *_ = table_recog(str(img_path), ocr_result) td_nums = get_td_nums(table_str) - assert td_nums == 153 + assert td_nums == 228 @pytest.mark.parametrize( "img_path, gt_td_nums, gt2", [ ("table_recognition.jpg", 35, "d colsp"), - ("table2.jpg", 22, "td><"), ], ) @@ -47,9 +56,208 @@ def test_input_normal(img_path, gt_td_nums, gt2): img_path = test_file_dir / img_path ocr_result, _ = ocr_engine(img_path) - table_str, _ = table_recog(str(img_path), ocr_result) + table_str, *_ = table_recog(str(img_path), ocr_result) td_nums = get_td_nums(table_str) assert td_nums == gt_td_nums - assert table_str[-53:-46] == gt2 + +@pytest.mark.parametrize( + "box1, box2, threshold, expected", + [ + # Box1 完全包含在 Box2 内 + ([[10, 20, 30, 40], [5, 15, 45, 55], 0.2, 1]), + # Box2 完全包含在 Box1 内 + ([[5, 15, 45, 55], [10, 20, 30, 40], 0.2, 2]), + # Box1 和 Box2 部分重叠,但不满足阈值 + ([[10, 20, 30, 40], [25, 35, 45, 55], 0.2, None]), + # Box1 和 Box2 完全不重叠 + ([[10, 20, 30, 40], [50, 60, 70, 80], 0.2, None]), + # Box1 和 Box2 有交集,但不满足阈值 + ([[10, 20, 30, 40], [15, 25, 35, 45], 0.2, None]), + # Box1 和 Box2 有交集,且满足阈值 + ([[10, 20, 30, 40], [15, 25, 35, 45], 0.5, 1]), + # Box1 和 
Box2 有交集,且满足阈值 + ([[15, 25, 35, 45], [14, 24, 16, 44], 0.6, 2]), + # Box1 和 Box2 相同 + ([[10, 20, 30, 40], [10, 20, 30, 40], 0.2, 1]), + # 使用 NumPy 数组作为输入 + ([np.array([10, 20, 30, 40]), np.array([5, 15, 45, 55]), 0.2, 1]), + ], +) +def test_is_box_contained(box1, box2, threshold, expected): + result = is_box_contained(box1, box2, threshold) + assert result == expected, f"Expected {expected}, but got {result}" + + +@pytest.mark.parametrize( + "box1, box2, axis, threshold, expected", + [ + # Box1 完全包含 Box2 (X轴) + ([10, 10, 20, 20], [12, 12, 18, 18], "x", 0.2, 2), + # Box2 完全包含 Box1 (X轴) + ([12, 12, 18, 18], [10, 10, 20, 20], "x", 0.2, 1), + # Box1 完全包含 Box2 (Y轴) + ([10, 10, 20, 20], [12, 12, 18, 18], "y", 0.2, 2), + # Box2 完全包含 Box1 (Y轴) + ([12, 12, 18, 18], [10, 10, 20, 20], "y", 0.2, 1), + # Box1 和 Box2 不相交 (X轴) + ([10, 10, 20, 20], [25, 25, 30, 30], "x", 0.2, None), + # Box1 和 Box2 不相交 (Y轴) + ([10, 10, 20, 20], [25, 25, 30, 30], "y", 0.2, None), + # Box1 部分包含 Box2 (X轴)-超过阈值 + ([10, 10, 20, 20], [15, 15, 25, 25], "x", 0.2, None), + # Box1 部分包含 Box2 (Y轴)-超过阈值 + ([10, 10, 20, 20], [15, 15, 25, 25], "y", 0.2, None), + # Box1 部分包含 Box2 (X轴)-满足阈值 + ([10, 10, 20, 20], [13, 15, 21, 25], "x", 0.2, 2), + # Box2 部分包含 Box1 (Y轴)-满足阈值 + ([10, 14, 20, 20], [15, 15, 25, 50], "y", 0.2, 1), + # Box1 和 Box2 完全重合 (X轴) + ([10, 10, 20, 20], [10, 10, 20, 20], "x", 0.2, 1), + # Box1 和 Box2 完全重合 (Y轴) + ([10, 10, 20, 20], [10, 10, 20, 20], "y", 0.2, 1), + ], +) +def test_is_single_axis_contained(box1, box2, axis, threshold, expected): + result = is_single_axis_contained(box1, box2, axis, threshold) + assert result == expected + + +@pytest.mark.parametrize( + "input_ocr_list, expected_output", + [ + ( + [[[10, 20, 30, 40], "text1"], [[15, 23, 35, 43], "text2"]], + [[[10, 20, 35, 43], "text1text2"]], + ), + ( + [ + [[10, 24, 30, 30], "text1"], + [[15, 25, 35, 45], "text2"], + [[5, 30, 15, 50], "text3"], + ], + [[[10, 24, 35, 45], "text1text2"], [[5, 30, 15, 50], "text3"]], + ), + ([], []), + ( + [[[10, 20, 30, 40], "text1"], [], [[15, 25, 35, 45], "text2"]], + [[[10, 20, 30, 40], "text1"], [[15, 25, 35, 45], "text2"]], + ), + ], +) +def test_gather_ocr_list_by_row(input_ocr_list, expected_output): + result = gather_ocr_list_by_row(input_ocr_list) + assert result == expected_output, f"Expected {expected_output}, but got {result}" + + +@pytest.mark.parametrize( + "dt_boxes, expected_boxes, expected_indices", + [ + # 基本排序情况 + ( + np.array([[2, 3, 4, 5], [3, 4, 5, 6], [1, 2, 2, 3]]), + np.array([[1, 2, 2, 3], [2, 3, 4, 5], [3, 4, 5, 6]]), + [2, 0, 1], + ), + # 基本排序错误,修正正确 + ( + np.array([[59, 0, 148, 52], [134, 0, 254, 53], [12, 13, 30, 40]]), + np.array([[12, 13, 30, 40], [59, 0, 148, 52], [134, 0, 254, 53]]), + [2, 0, 1], + ), + # 一个盒子的情况 + (np.array([[2, 3, 4, 5]]), np.array([[2, 3, 4, 5]]), [0]), + # 无盒子的情况 + (np.array([]), np.array([]), []), + ], +) +def test_sorted_ocr_boxes(dt_boxes, expected_boxes, expected_indices): + sorted_boxes, indices = sorted_ocr_boxes(dt_boxes) + assert ( + sorted_boxes.tolist() == expected_boxes.tolist() + ), f"Expected {expected_boxes.tolist()}, but got {sorted_boxes.tolist()}" + assert ( + indices == expected_indices + ), f"Expected {expected_indices}, but got {indices}" + + +@pytest.mark.parametrize( + "old_size, scale, return_scale, expected_result", + [ + # 以短边为准进行缩放 + ((100, 50), (300, 100), True, ((200, 100), 2.0)), + ((50, 100), (100, 300), True, ((100, 200), 2.0)), + # 以长边为准进行缩放 + ((100, 50), (200, 150), True, ((200, 100), 2.0)), + ((50, 100), (150, 200), True, ((100, 200), 
2.0)), + ], +) +def test_rescale_size(old_size, scale, return_scale, expected_result): + result = rescale_size(old_size, scale, return_scale) + assert np.isclose(result[1], expected_result[1], atol=1e-5) + assert ( + result[0] == expected_result[0] + ), f"Expected {expected_result}, but got {result}" + + +@pytest.mark.parametrize( + "logi_points, cell_box_map, expected_html", + [ + # 测试空输入 + ([], {}, "
"), + # 测试单个单元格,包含rowspan和colspan + ( + [[0, 0, 0, 0]], + {0: ["Cell 1"]}, + "
Cell 1
", + ), + # 测试多个独立单元格 + ( + [[0, 0, 0, 0], [1, 1, 1, 1]], + {0: ["Cell 1"], 1: ["Cell 2"]}, + "
Cell 1
Cell 2
", + ), + # 测试跨行的单元格 + ( + [[0, 1, 0, 0]], + {0: ["Row 1 Col 1", "Row 2 Col 1"]}, + "
Row 1 Col 1
Row 2 Col 1
", + ), + # 测试跨列的单元格 + ( + [[0, 0, 0, 1]], + {0: ["Col 1 Row 1", "Col 2 Row 1"]}, + "
Col 1 Row 1
Col 2 Row 1
", + ), + # 测试跨多行多列的单元格 + ( + [[0, 1, 0, 1]], + {0: ["Row 1 Col 1", "Row 2 Col 1"]}, + "
Row 1 Col 1
Row 2 Col 1
", + ), + # 测试跨行跨行跨列的单元格出现在中间 + ( + [[0, 0, 0, 0], [0, 1, 1, 2]], + {0: ["Cell 1"], 1: ["Row 2", "Col 2"]}, + "
Cell 1Row 2
Col 2
", + ), + # 测试跨行跨列的单元格出现在结尾 + ( + [[0, 0, 0, 0], [1, 1, 1, 1], [0, 1, 2, 2]], + {0: ["Cell 1"], 1: ["Cell 2"], 2: ["Row 1 Col 2", "Row 2 Col 2"]}, + "
Cell 1Row 1 Col 2
Row 2 Col 2
Cell 2
", + ), + # 测试去除无效行和无效列 + ( + [[0, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 1], [0, 0, 1, 2]], + {2: ["Row 3 Col 1", "Row 3 Col 2"]}, + "
Row 3 Col 1
Row 3 Col 2
", + ), + ], +) +def test_plot_html_table(logi_points, cell_box_map, expected_html): + html_output = plot_html_table(logi_points, cell_box_map) + assert ( + html_output == expected_html + ), f"Expected HTML does not match. Got: {html_output}" diff --git a/wired_table_rec/main.py b/wired_table_rec/main.py index 2b107fb..0b0b784 100644 --- a/wired_table_rec/main.py +++ b/wired_table_rec/main.py @@ -7,21 +7,38 @@ import time import traceback from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Dict +import numpy as np +import cv2 -from .table_line_rec import TableLineRecognition +from wired_table_rec.table_line_rec import TableLineRecognition +from wired_table_rec.table_line_rec_plus import TableLineRecognitionPlus from .table_recover import TableRecover from .utils import InputType, LoadImage -from .utils_table_recover import match_ocr_cell, plot_html_table +from .utils_table_recover import ( + match_ocr_cell, + plot_html_table, + box_4_2_poly_to_box_4_1, + get_rotate_crop_image, + sorted_ocr_boxes, + gather_ocr_list_by_row, +) cur_dir = Path(__file__).resolve().parent default_model_path = cur_dir / "models" / "cycle_center_net_v1.onnx" +default_model_path_v2 = cur_dir / "models" / "cycle_center_net_v2.onnx" class WiredTableRecognition: - def __init__(self, table_model_path: Union[str, Path] = default_model_path): + def __init__(self, table_model_path: Union[str, Path] = None, version="v2"): self.load_img = LoadImage() - self.table_line_rec = TableLineRecognition(str(table_model_path)) + if version == "v2": + model_path = table_model_path if table_model_path else default_model_path_v2 + self.table_line_rec = TableLineRecognitionPlus(str(model_path)) + else: + model_path = table_model_path if table_model_path else default_model_path + self.table_line_rec = TableLineRecognition(str(model_path)) + self.table_recover = TableRecover() try: @@ -33,7 +50,7 @@ def __call__( self, img: InputType, ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None, - ) -> Tuple[str, float]: + ) -> Tuple[str, float, list]: if self.ocr is None and ocr_result is None: raise ValueError( "One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed." 
@@ -45,21 +62,142 @@ def __call__( polygons = self.table_line_rec(img) if polygons is None: logging.warning("polygons is None.") - return "", 0.0 + return "", 0.0, None, None, None try: - table_res = self.table_recover(polygons) - + table_res, logi_points = self.table_recover(polygons) + # 将坐标由逆时针转为顺时针方向,后续处理与无线表格对齐 + polygons[:, 1, :], polygons[:, 3, :] = ( + polygons[:, 3, :].copy(), + polygons[:, 1, :].copy(), + ) if ocr_result is None: ocr_result, _ = self.ocr(img) + cell_box_det_map, not_match_orc_boxes = match_ocr_cell(ocr_result, polygons) + # 如果有识别框没有ocr结果,直接进行rec补充 + # cell_box_det_map = self.re_rec_high_precise(img, polygons, cell_box_det_map) + cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map) + # 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理 + t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points) + # 将每个单元格中的ocr识别结果排序和同行合并,输出的html能完整保留文字的换行格式 + t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list) + # cell_box_map = + logi_points = [t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list] + cell_box_det_map = { + i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]] + for i, t_box_ocr in enumerate(t_rec_ocr_list) + } + table_str = plot_html_table(logi_points, cell_box_det_map) + ocr_boxes_res = [ + box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_result + ] + sorted_ocr_boxes_res, _ = sorted_ocr_boxes(ocr_boxes_res) + sorted_polygons = [box_4_2_poly_to_box_4_1(box) for box in polygons] + sorted_logi_points = logi_points + table_elapse = time.perf_counter() - s - cell_box_map = match_ocr_cell(polygons, ocr_result) - table_str = plot_html_table(table_res, cell_box_map) - elapse = time.perf_counter() - s except Exception: logging.warning(traceback.format_exc()) - return "", 0.0 - return table_str, elapse + return "", 0.0, None, None, None + return ( + table_str, + table_elapse, + sorted_polygons, + sorted_logi_points, + sorted_ocr_boxes_res, + ) + + def transform_res( + self, + cell_box_det_map: dict[int, List[any]], + polygons: np.ndarray, + logi_points: List[np.ndarray], + ) -> List[dict[str, any]]: + res = [] + for i in range(len(polygons)): + ocr_res_list = cell_box_det_map.get(i) + if not ocr_res_list: + continue + xmin = min([ocr_box[0][0][0] for ocr_box in ocr_res_list]) + ymin = min([ocr_box[0][0][1] for ocr_box in ocr_res_list]) + xmax = max([ocr_box[0][2][0] for ocr_box in ocr_res_list]) + ymax = max([ocr_box[0][2][1] for ocr_box in ocr_res_list]) + dict_res = { + # xmin,xmax,ymin,ymax + "t_box": [xmin, ymin, xmax, ymax], + # row_start,row_end,col_start,col_end + "t_logic_box": logi_points[i].tolist(), + # [[xmin,xmax,ymin,ymax], text] + "t_ocr_res": [ + [box_4_2_poly_to_box_4_1(ocr_det[0]), ocr_det[1]] + for ocr_det in ocr_res_list + ], + } + res.append(dict_res) + return res + + def sort_and_gather_ocr_res(self, res): + for i, dict_res in enumerate(res): + _, sorted_idx = sorted_ocr_boxes( + [ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threhold=0.5 + ) + dict_res["t_ocr_res"] = [dict_res["t_ocr_res"][i] for i in sorted_idx] + dict_res["t_ocr_res"] = gather_ocr_list_by_row( + dict_res["t_ocr_res"], threhold=0.5 + ) + return res + + def re_rec( + self, + img: np.ndarray, + sorted_polygons: np.ndarray, + cell_box_map: Dict[int, List[str]], + ) -> Dict[int, List[any]]: + """找到poly对应为空的框,尝试将直接将poly框直接送到识别中""" + # + for i in range(sorted_polygons.shape[0]): + if cell_box_map.get(i): + continue + crop_img = get_rotate_crop_image(img, sorted_polygons[i]) + pad_img = cv2.copyMakeBorder( + crop_img, 
5, 5, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255) + ) + rec_res, _ = self.ocr(pad_img, use_det=False, use_cls=True, use_rec=True) + box = sorted_polygons[i] + text = [rec[0] for rec in rec_res] + scores = [rec[1] for rec in rec_res] + cell_box_map[i] = [[box, "".join(text), min(scores)]] + return cell_box_map + + def re_rec_high_precise( + self, + img: np.ndarray, + sorted_polygons: np.ndarray, + cell_box_map: Dict[int, List[str]], + ) -> Dict[int, List[any]]: + """找到poly对应为空的框,尝试将直接将poly框直接送到识别中""" + # + cell_box_map = {} + for i in range(sorted_polygons.shape[0]): + if cell_box_map.get(i): + continue + crop_img = get_rotate_crop_image(img, sorted_polygons[i]) + pad_img = cv2.copyMakeBorder( + crop_img, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=(255, 255, 255) + ) + rec_res, _ = self.ocr(pad_img, use_det=True, use_cls=True, use_rec=True) + if not rec_res: + det_boxes = [sorted_polygons[i]] + text = [""] + scores = [1.0] + else: + det_boxes = [rec[0] for rec in rec_res] + text = [rec[1] for rec in rec_res] + scores = [rec[2] for rec in rec_res] + cell_box_map[i] = [ + [box, text, score] for box, text, score in zip(det_boxes, text, scores) + ] + return cell_box_map def main(): diff --git a/wired_table_rec/table_line_rec.py b/wired_table_rec/table_line_rec.py index 3549304..7be1004 100644 --- a/wired_table_rec/table_line_rec.py +++ b/wired_table_rec/table_line_rec.py @@ -16,7 +16,12 @@ group_bbox_by_gbox, nms, ) -from .utils_table_recover import merge_adjacent_polys, sorted_boxes +from .utils_table_recover import ( + merge_adjacent_polys, + sorted_ocr_boxes, + box_4_2_poly_to_box_4_1, + filter_duplicated_box, +) class TableLineRecognition: @@ -39,7 +44,12 @@ def __call__(self, img: np.ndarray) -> Optional[np.ndarray]: return None polygons = polygons.reshape(polygons.shape[0], 4, 2) - polygons = sorted_boxes(polygons) + del_idxs = filter_duplicated_box( + [box_4_2_poly_to_box_4_1(box) for box in polygons] + ) + polygons = np.delete(polygons, list(del_idxs), axis=0) + _, idx = sorted_ocr_boxes([box_4_2_poly_to_box_4_1(box) for box in polygons]) + polygons = polygons[idx] polygons = merge_adjacent_polys(polygons) return polygons diff --git a/wired_table_rec/table_line_rec_plus.py b/wired_table_rec/table_line_rec_plus.py new file mode 100644 index 0000000..8880891 --- /dev/null +++ b/wired_table_rec/table_line_rec_plus.py @@ -0,0 +1,114 @@ +import copy +import math +from typing import Optional, Dict, Any + +import cv2 +import numpy as np +from skimage import measure + +from wired_table_rec.utils import OrtInferSession, resize_img +from wired_table_rec.utils_table_line_rec import ( + get_table_line, + final_adjust_lines, + min_area_rect_box, + draw_lines, + adjust_lines, +) +from wired_table_rec.utils_table_recover import ( + sorted_ocr_boxes, + box_4_2_poly_to_box_4_1, +) + + +class TableLineRecognitionPlus: + def __init__(self, model_path: Optional[str] = None): + self.K = 1000 + self.MK = 4000 + self.mean = np.array([123.675, 116.28, 103.53], dtype=np.float32) + self.std = np.array([58.395, 57.12, 57.375], dtype=np.float32) + self.inp_height = 1024 + self.inp_width = 1024 + + self.session = OrtInferSession(model_path) + + def __call__(self, img: np.ndarray) -> Optional[np.ndarray]: + img_info = self.preprocess(img) + pred = self.infer(img_info) + polygons = self.postprocess(img, pred) + if polygons.size == 0: + return None + + polygons = polygons.reshape(polygons.shape[0], 4, 2) + polygons[:, 3, :], polygons[:, 1, :] = ( + polygons[:, 1, :].copy(), + polygons[:, 3, :].copy(), + ) + 
_, idx = sorted_ocr_boxes( + [box_4_2_poly_to_box_4_1(poly_box) for poly_box in polygons] + ) + polygons = polygons[idx] + return polygons + + def preprocess(self, img) -> Dict[str, Any]: + scale = (self.inp_height, self.inp_width) + img, _, _ = resize_img(img, scale, True) + img = img.copy().astype(np.float32) + assert img.dtype != np.uint8 + mean = np.float64(self.mean.reshape(1, -1)) + stdinv = 1 / np.float64(self.std.reshape(1, -1)) + cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace + cv2.subtract(img, mean, img) # inplace + cv2.multiply(img, stdinv, img) # inplace + img = img.transpose(2, 0, 1) + images = img[None, :] + return {"img": images} + + def infer(self, input): + result = self.session(input["img"][None, ...])[0][0] + result = result[0].astype(np.uint8) + return result + + def postprocess(self, img, pred, row=50, col=30, alph=15, angle=50): + ori_shape = img.shape + pred = np.uint8(pred) + hpred = copy.deepcopy(pred) # 横线 + vpred = copy.deepcopy(pred) # 竖线 + whereh = np.where(hpred == 1) + wherev = np.where(vpred == 2) + hpred[wherev] = 0 + vpred[whereh] = 0 + + hpred = cv2.resize(hpred, (ori_shape[1], ori_shape[0])) + vpred = cv2.resize(vpred, (ori_shape[1], ori_shape[0])) + + h, w = pred.shape + hors_k = int(math.sqrt(w) * 1.2) + vert_k = int(math.sqrt(h) * 1.2) + hkernel = cv2.getStructuringElement(cv2.MORPH_RECT, (hors_k, 1)) + vkernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, vert_k)) + vpred = cv2.morphologyEx( + vpred, cv2.MORPH_CLOSE, vkernel, iterations=1 + ) # 先膨胀后腐蚀的过程 + hpred = cv2.morphologyEx(hpred, cv2.MORPH_CLOSE, hkernel, iterations=1) + colboxes = get_table_line(vpred, axis=1, lineW=col) # 竖线 + rowboxes = get_table_line(hpred, axis=0, lineW=row) # 横线 + # rboxes_row_, rboxes_col_ = adjust_lines(rowboxes, colboxes, alph = alph, angle=angle) + rboxes_row_ = adjust_lines(rowboxes, alph=100, angle=angle) + rboxes_col_ = adjust_lines(colboxes, alph=alph, angle=angle) + rowboxes += rboxes_row_ + colboxes += rboxes_col_ + rowboxes, colboxes = final_adjust_lines(rowboxes, colboxes) + + tmp = np.zeros(img.shape[:2], dtype="uint8") + tmp = draw_lines(tmp, rowboxes + colboxes, color=255, lineW=2) + labels = measure.label(tmp < 255, connectivity=2) # 8连通区域标记 + regions = measure.regionprops(labels) + ceilboxes = min_area_rect_box( + regions, + False, + tmp.shape[1], + tmp.shape[0], + filtersmall=True, + adjust_box=False, + ) # 最后一个参数改为False + return np.array(ceilboxes) diff --git a/wired_table_rec/table_recover.py b/wired_table_rec/table_recover.py index 6123870..be0502c 100644 --- a/wired_table_rec/table_recover.py +++ b/wired_table_rec/table_recover.py @@ -16,7 +16,7 @@ def __call__(self, polygons: np.ndarray) -> Dict[int, Dict]: rows = self.get_rows(polygons) longest_col, each_col_widths, col_nums = self.get_benchmark_cols(rows, polygons) each_row_heights, row_nums = self.get_benchmark_rows(rows, polygons) - table_res = self.get_merge_cells( + table_res, logic_points_dict = self.get_merge_cells( polygons, rows, row_nums, @@ -25,7 +25,10 @@ def __call__(self, polygons: np.ndarray) -> Dict[int, Dict]: each_col_widths, each_row_heights, ) - return table_res + logic_points = np.array( + [logic_points_dict[i] for i in range(len(polygons))] + ).astype(np.int32) + return table_res, logic_points @staticmethod def get_rows(polygons: np.array) -> Dict[int, List[int]]: @@ -38,8 +41,8 @@ def get_rows(polygons: np.array) -> Dict[int, List[int]]: minus_res = concat_y[:, 1] - concat_y[:, 0] result = {} - thresh = 5.0 - split_idxs = np.argwhere(minus_res > thresh).squeeze() + 
thresh = 10.0 + split_idxs = np.argwhere(abs(minus_res) > thresh).squeeze() if split_idxs.ndim == 0: split_idxs = split_idxs[None, ...] @@ -62,7 +65,7 @@ def get_benchmark_cols( longest_col_points = polygons[longest_col] longest_x = longest_col_points[:, 0, 0] - theta = 10 + theta = 15 for row_value in rows.values(): cur_row = polygons[row_value][:, 0, 0] @@ -112,11 +115,12 @@ def get_benchmark_rows( leftmost_cell_idxs = [v[0] for v in rows.values()] benchmark_x = polygons[leftmost_cell_idxs][:, 0, 1] - theta = 10 + theta = 15 # 遍历其他所有的框,按照y轴进行区间划分 range_res = {} for cur_idx, cur_box in enumerate(polygons): - if cur_idx in benchmark_x: + # fix cur_idx in benchmark_x + if cur_idx in leftmost_cell_idxs: continue cur_y = cur_box[0, 1] @@ -148,7 +152,8 @@ def get_benchmark_rows( # 求出最后一行cell中,最大的高度作为最后一行的高度 bottommost_idxs = list(rows.values())[-1] bottommost_boxes = polygons[bottommost_idxs] - max_height = max([self.compute_L2(v[3, :], v[0, :]) for v in bottommost_boxes]) + # fix self.compute_L2(v[3, :], v[0, :]), v为逆时针,即v[3]为右上,v[0]为左上,v[1]为左下 + max_height = max([self.compute_L2(v[1, :], v[0, :]) for v in bottommost_boxes]) each_row_widths.append(max_height) row_nums = benchmark_x.shape[0] @@ -169,7 +174,8 @@ def get_merge_cells( each_row_heights: List[float], ) -> Dict[int, Dict[int, int]]: col_res_merge, row_res_merge = {}, {} - merge_thresh = 20 + logic_points = {} + merge_thresh = 10 for cur_row, col_list in rows.items(): one_col_result, one_row_result = {}, {} for one_col in col_list: @@ -178,40 +184,62 @@ def get_merge_cells( # 不一定是从0开始的,应该综合已有值和x坐标位置来确定起始位置 loc_col_idx = np.argmin(np.abs(longest_col[:, 0, 0] - box[0, 0])) - merge_col_cell = max(sum(one_col_result.values()), loc_col_idx) + col_start = max(sum(one_col_result.values()), loc_col_idx) # 计算合并多少个列方向单元格 - for i in range(merge_col_cell, col_nums): - col_cum_sum = sum(each_col_widths[merge_col_cell : i + 1]) - if i == merge_col_cell and col_cum_sum > box_width: + for i in range(col_start, col_nums): + col_cum_sum = sum(each_col_widths[col_start : i + 1]) + if i == col_start and col_cum_sum > box_width: one_col_result[one_col] = 1 break elif abs(col_cum_sum - box_width) <= merge_thresh: - one_col_result[one_col] = i + 1 - merge_col_cell + one_col_result[one_col] = i + 1 - col_start + break + # 这里必须进行修正,不然会出现超越阈值范围后列交错 + elif col_cum_sum > box_width: + idx = ( + i + if abs(col_cum_sum - box_width) + < abs(col_cum_sum - each_col_widths[i] - box_width) + else i - 1 + ) + one_col_result[one_col] = idx + 1 - col_start break else: - one_col_result[one_col] = i + 1 - merge_col_cell + 1 - + one_col_result[one_col] = col_nums - col_start + col_end = one_col_result[one_col] + col_start - 1 box_height = self.compute_L2(box[1, :], box[0, :]) - merge_row_cell = cur_row - for j in range(merge_row_cell, row_nums): - row_cum_sum = sum(each_row_heights[merge_row_cell : j + 1]) + row_start = cur_row + for j in range(row_start, row_nums): + row_cum_sum = sum(each_row_heights[row_start : j + 1]) # box_height 不确定是几行的高度,所以要逐个试验,找一个最近的几行的高 # 如果第一次row_cum_sum就比box_height大,那么意味着?丢失了一行 - if j == merge_row_cell and row_cum_sum > box_height: + if j == row_start and row_cum_sum > box_height: one_row_result[one_col] = 1 break - elif abs(box_height - row_cum_sum) <= merge_thresh: - one_row_result[one_col] = j + 1 - merge_row_cell + one_row_result[one_col] = j + 1 - row_start + break + # 这里必须进行修正,不然会出现超越阈值范围后行交错 + elif row_cum_sum > box_height: + idx = ( + j + if abs(row_cum_sum - box_height) + < abs(row_cum_sum - each_row_heights[j] - box_height) + else j - 1 
+ ) + one_row_result[one_col] = idx + 1 - row_start break else: - one_row_result[one_col] = j + 1 - merge_row_cell + 1 - + one_row_result[one_col] = row_nums - row_start + row_end = one_row_result[one_col] + row_start - 1 + logic_points[one_col] = np.array( + [row_start, row_end, col_start, col_end] + ) col_res_merge[cur_row] = one_col_result row_res_merge[cur_row] = one_row_result res = {} for i, (c, r) in enumerate(zip(col_res_merge.values(), row_res_merge.values())): res[i] = {k: [cc, r[k]] for k, cc in c.items()} - return res + return res, logic_points diff --git a/wired_table_rec/utils.py b/wired_table_rec/utils.py index 28f6eda..a721751 100644 --- a/wired_table_rec/utils.py +++ b/wired_table_rec/utils.py @@ -172,3 +172,181 @@ def verify_exist(file_path: Union[str, Path]): class LoadImageError(Exception): pass + + +# Pillow >=v9.1.0 use a slightly different naming scheme for filters. +# Set pillow_interp_codes according to the naming scheme used. +if Image is not None: + if hasattr(Image, "Resampling"): + pillow_interp_codes = { + "nearest": Image.Resampling.NEAREST, + "bilinear": Image.Resampling.BILINEAR, + "bicubic": Image.Resampling.BICUBIC, + "box": Image.Resampling.BOX, + "lanczos": Image.Resampling.LANCZOS, + "hamming": Image.Resampling.HAMMING, + } + else: + pillow_interp_codes = { + "nearest": Image.NEAREST, + "bilinear": Image.BILINEAR, + "bicubic": Image.BICUBIC, + "box": Image.BOX, + "lanczos": Image.LANCZOS, + "hamming": Image.HAMMING, + } + +cv2_interp_codes = { + "nearest": cv2.INTER_NEAREST, + "bilinear": cv2.INTER_LINEAR, + "bicubic": cv2.INTER_CUBIC, + "area": cv2.INTER_AREA, + "lanczos": cv2.INTER_LANCZOS4, +} + + +def resize_img(img, scale, keep_ratio=True): + if keep_ratio: + # 缩小使用area更保真 + if min(img.shape[:2]) > min(scale): + interpolation = "area" + else: + interpolation = "bicubic" # bilinear + img_new, scale_factor = imrescale( + img, scale, return_scale=True, interpolation=interpolation + ) + # the w_scale and h_scale has minor difference + # a real fix should be done in the mmcv.imrescale in the future + new_h, new_w = img_new.shape[:2] + h, w = img.shape[:2] + w_scale = new_w / w + h_scale = new_h / h + else: + img_new, w_scale, h_scale = imresize(img, scale, return_scale=True) + return img_new, w_scale, h_scale + + +def imrescale(img, scale, return_scale=False, interpolation="bilinear", backend=None): + """Resize image while keeping the aspect ratio. + + Args: + img (ndarray): The input image. + scale (float | tuple[int]): The scaling factor or maximum size. + If it is a float number, then the image will be rescaled by this + factor, else if it is a tuple of 2 integers, then the image will + be rescaled as large as possible within the scale. + return_scale (bool): Whether to return the scaling factor besides the + rescaled image. + interpolation (str): Same as :func:`resize`. + backend (str | None): Same as :func:`resize`. + + Returns: + ndarray: The rescaled image. + """ + h, w = img.shape[:2] + new_size, scale_factor = rescale_size((w, h), scale, return_scale=True) + rescaled_img = imresize(img, new_size, interpolation=interpolation, backend=backend) + if return_scale: + return rescaled_img, scale_factor + else: + return rescaled_img + + +def imresize( + img, size, return_scale=False, interpolation="bilinear", out=None, backend=None +): + """Resize image to a given size. + + Args: + img (ndarray): The input image. + size (tuple[int]): Target size (w, h). + return_scale (bool): Whether to return `w_scale` and `h_scale`. 
+ interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. + out (ndarray): The output destination. + backend (str | None): The image resize backend type. Options are `cv2`, + `pillow`, `None`. If backend is None, the global imread_backend + specified by ``mmcv.use_backend()`` will be used. Default: None. + + Returns: + tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or + `resized_img`. + """ + h, w = img.shape[:2] + if backend is None: + backend = "cv2" + if backend not in ["cv2", "pillow"]: + raise ValueError( + f"backend: {backend} is not supported for resize." + f"Supported backends are 'cv2', 'pillow'" + ) + + if backend == "pillow": + assert img.dtype == np.uint8, "Pillow backend only support uint8 type" + pil_image = Image.fromarray(img) + pil_image = pil_image.resize(size, pillow_interp_codes[interpolation]) + resized_img = np.array(pil_image) + else: + resized_img = cv2.resize( + img, size, dst=out, interpolation=cv2_interp_codes[interpolation] + ) + if not return_scale: + return resized_img + else: + w_scale = size[0] / w + h_scale = size[1] / h + return resized_img, w_scale, h_scale + + +def rescale_size(old_size, scale, return_scale=False): + """Calculate the new size to be rescaled to. + + Args: + old_size (tuple[int]): The old size (w, h) of image. + scale (float | tuple[int]): The scaling factor or maximum size. + If it is a float number, then the image will be rescaled by this + factor, else if it is a tuple of 2 integers, then the image will + be rescaled as large as possible within the scale. + return_scale (bool): Whether to return the scaling factor besides the + rescaled image size. + + Returns: + tuple[int]: The new rescaled image size. + """ + w, h = old_size + if isinstance(scale, (float, int)): + if scale <= 0: + raise ValueError(f"Invalid scale {scale}, must be positive.") + scale_factor = scale + elif isinstance(scale, tuple): + max_long_edge = max(scale) + max_short_edge = min(scale) + scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w)) + else: + raise TypeError( + f"Scale must be a number or tuple of int, but got {type(scale)}" + ) + + new_size = _scale_size((w, h), scale_factor) + + if return_scale: + return new_size, scale_factor + else: + return new_size + + +def _scale_size(size, scale): + """Rescale a size by a ratio. + + Args: + size (tuple[int]): (w, h). + scale (float | tuple(float)): Scaling factor. + + Returns: + tuple[int]: scaled size. 
+ """ + if isinstance(scale, (float, int)): + scale = (scale, scale) + w, h = size + return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5) diff --git a/wired_table_rec/utils_table_line_rec.py b/wired_table_rec/utils_table_line_rec.py index 0e2caa5..6ebca07 100644 --- a/wired_table_rec/utils_table_line_rec.py +++ b/wired_table_rec/utils_table_line_rec.py @@ -6,6 +6,8 @@ import cv2 import numpy as np +from scipy.spatial import distance as dist +from skimage import measure def bbox_decode(heat, wh, reg=None, K=100): @@ -393,3 +395,282 @@ def get_distance(pt1, pt2): bboxes[k][2 * min_id + 1] = vertex[1] sign[k][min_id] = 1 return bboxes + + +def get_table_line(binimg, axis=0, lineW=10): + ##获取表格线 + ##axis=0 横线 + ##axis=1 竖线 + labels = measure.label(binimg > 0, connectivity=2) # 8连通区域标记 + regions = measure.regionprops(labels) + if axis == 1: + lineboxes = [ + min_area_rect(line.coords) + for line in regions + if line.bbox[2] - line.bbox[0] > lineW + ] + else: + lineboxes = [ + min_area_rect(line.coords) + for line in regions + if line.bbox[3] - line.bbox[1] > lineW + ] + return lineboxes + + +def min_area_rect(coords): + """ + 多边形外接矩形 + """ + rect = cv2.minAreaRect(coords[:, ::-1]) + box = cv2.boxPoints(rect) + box = box.reshape((8,)).tolist() + + box = image_location_sort_box(box) + + x1, y1, x2, y2, x3, y3, x4, y4 = box + degree, w, h, cx, cy = calculate_center_rotate_angle(box) + if w < h: + xmin = (x1 + x2) / 2 + xmax = (x3 + x4) / 2 + ymin = (y1 + y2) / 2 + ymax = (y3 + y4) / 2 + + else: + xmin = (x1 + x4) / 2 + xmax = (x2 + x3) / 2 + ymin = (y1 + y4) / 2 + ymax = (y2 + y3) / 2 + # degree,w,h,cx,cy = solve(box) + # x1,y1,x2,y2,x3,y3,x4,y4 = box + # return {'degree':degree,'w':w,'h':h,'cx':cx,'cy':cy} + return [xmin, ymin, xmax, ymax] + + +def image_location_sort_box(box): + x1, y1, x2, y2, x3, y3, x4, y4 = box[:8] + pts = (x1, y1), (x2, y2), (x3, y3), (x4, y4) + pts = np.array(pts, dtype="float32") + (x1, y1), (x2, y2), (x3, y3), (x4, y4) = _order_points(pts) + return [x1, y1, x2, y2, x3, y3, x4, y4] + + +def calculate_center_rotate_angle(box): + """ + 绕 cx,cy点 w,h 旋转 angle 的坐标,能一定程度缓解图片的内部倾斜,但是还是依赖模型稳妥 + x = cx-w/2 + y = cy-h/2 + x1-cx = -w/2*cos(angle) +h/2*sin(angle) + y1 -cy= -w/2*sin(angle) -h/2*cos(angle) + + h(x1-cx) = -wh/2*cos(angle) +hh/2*sin(angle) + w(y1 -cy)= -ww/2*sin(angle) -hw/2*cos(angle) + (hh+ww)/2sin(angle) = h(x1-cx)-w(y1 -cy) + + """ + x1, y1, x2, y2, x3, y3, x4, y4 = box[:8] + cx = (x1 + x3 + x2 + x4) / 4.0 + cy = (y1 + y3 + y4 + y2) / 4.0 + w = ( + np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + + np.sqrt((x3 - x4) ** 2 + (y3 - y4) ** 2) + ) / 2 + h = ( + np.sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2) + + np.sqrt((x1 - x4) ** 2 + (y1 - y4) ** 2) + ) / 2 + # x = cx-w/2 + # y = cy-h/2 + sinA = (h * (x1 - cx) - w * (y1 - cy)) * 1.0 / (h * h + w * w) * 2 + angle = np.arcsin(sinA) + return angle, w, h, cx, cy + + +def _order_points(pts): + # 根据x坐标对点进行排序 + """ + --------------------- + 本项目中是为了排序后得到[(xmin,ymin),(xmax,ymin),(xmax,ymax),(xmin,ymax)] + 作者:Tong_T + 来源:CSDN + 原文:https://blog.csdn.net/Tong_T/article/details/81907132 + 版权声明:本文为博主原创文章,转载请附上博文链接! 
+ """ + x_sorted = pts[np.argsort(pts[:, 0]), :] + + left_most = x_sorted[:2, :] + right_most = x_sorted[2:, :] + left_most = left_most[np.argsort(left_most[:, 1]), :] + (tl, bl) = left_most + + distance = dist.cdist(tl[np.newaxis], right_most, "euclidean")[0] + (br, tr) = right_most[np.argsort(distance)[::-1], :] + + return np.array([tl, tr, br, bl], dtype="float32") + + +def sqrt(p1, p2): + return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) + + +def adjust_lines(lines, alph=50, angle=50): + lines_n = len(lines) + new_lines = [] + for i in range(lines_n): + x1, y1, x2, y2 = lines[i] + cx1, cy1 = (x1 + x2) / 2, (y1 + y2) / 2 + for j in range(lines_n): + if i != j: + x3, y3, x4, y4 = lines[j] + cx2, cy2 = (x3 + x4) / 2, (y3 + y4) / 2 + if (x3 < cx1 < x4 or y3 < cy1 < y4) or ( + x1 < cx2 < x2 or y1 < cy2 < y2 + ): # 判断两个横线在y方向的投影重不重合 + continue + else: + r = sqrt((x1, y1), (x3, y3)) + k = abs((y3 - y1) / (x3 - x1 + 1e-10)) + a = math.atan(k) * 180 / math.pi + if r < alph and a < angle: + new_lines.append((x1, y1, x3, y3)) + + r = sqrt((x1, y1), (x4, y4)) + k = abs((y4 - y1) / (x4 - x1 + 1e-10)) + a = math.atan(k) * 180 / math.pi + if r < alph and a < angle: + new_lines.append((x1, y1, x4, y4)) + + r = sqrt((x2, y2), (x3, y3)) + k = abs((y3 - y2) / (x3 - x2 + 1e-10)) + a = math.atan(k) * 180 / math.pi + if r < alph and a < angle: + new_lines.append((x2, y2, x3, y3)) + r = sqrt((x2, y2), (x4, y4)) + k = abs((y4 - y2) / (x4 - x2 + 1e-10)) + a = math.atan(k) * 180 / math.pi + if r < alph and a < angle: + new_lines.append((x2, y2, x4, y4)) + return new_lines + + +def final_adjust_lines(rowboxes, colboxes): + nrow = len(rowboxes) + ncol = len(colboxes) + for i in range(nrow): + for j in range(ncol): + rowboxes[i] = line_to_line(rowboxes[i], colboxes[j], alpha=20, angle=30) + colboxes[j] = line_to_line(colboxes[j], rowboxes[i], alpha=20, angle=30) + return rowboxes, colboxes + + +def draw_lines(im, bboxes, color=(0, 0, 0), lineW=3): + """ + boxes: bounding boxes + """ + tmp = np.copy(im) + c = color + h, w = im.shape[:2] + + for box in bboxes: + x1, y1, x2, y2 = box[:4] + cv2.line( + tmp, (int(x1), int(y1)), (int(x2), int(y2)), c, lineW, lineType=cv2.LINE_AA + ) + + return tmp + + +def line_to_line(points1, points2, alpha=10, angle=30): + """ + 线段之间的距离 + """ + x1, y1, x2, y2 = points1 + ox1, oy1, ox2, oy2 = points2 + xy = np.array([(x1, y1), (x2, y2)], dtype="float32") + A1, B1, C1 = fit_line(xy) + oxy = np.array([(ox1, oy1), (ox2, oy2)], dtype="float32") + A2, B2, C2 = fit_line(oxy) + flag1 = point_line_cor(np.array([x1, y1], dtype="float32"), A2, B2, C2) + flag2 = point_line_cor(np.array([x2, y2], dtype="float32"), A2, B2, C2) + + if (flag1 > 0 and flag2 > 0) or (flag1 < 0 and flag2 < 0): # 横线或者竖线在竖线或者横线的同一侧 + if (A1 * B2 - A2 * B1) != 0: + x = (B1 * C2 - B2 * C1) / (A1 * B2 - A2 * B1) + y = (A2 * C1 - A1 * C2) / (A1 * B2 - A2 * B1) + # x, y = round(x, 2), round(y, 2) + p = (x, y) # 横线与竖线的交点 + r0 = sqrt(p, (x1, y1)) + r1 = sqrt(p, (x2, y2)) + + if min(r0, r1) < alpha: # 若交点与线起点或者终点的距离小于alpha,则延长线到交点 + if r0 < r1: + k = abs((y2 - p[1]) / (x2 - p[0] + 1e-10)) + a = math.atan(k) * 180 / math.pi + if a < angle or abs(90 - a) < angle: + points1 = np.array([p[0], p[1], x2, y2], dtype="float32") + else: + k = abs((y1 - p[1]) / (x1 - p[0] + 1e-10)) + a = math.atan(k) * 180 / math.pi + if a < angle or abs(90 - a) < angle: + points1 = np.array([x1, y1, p[0], p[1]], dtype="float32") + return points1 + + +def min_area_rect_box( + regions, flag=True, W=0, H=0, filtersmall=False, adjust_box=False +): 
+ """ + 多边形外接矩形 + """ + boxes = [] + for region in regions: + if region.bbox_area > H * W * 3 / 4: # 过滤大的单元格 + continue + rect = cv2.minAreaRect(region.coords[:, ::-1]) + + box = cv2.boxPoints(rect) + box = box.reshape((8,)).tolist() + box = image_location_sort_box(box) + x1, y1, x2, y2, x3, y3, x4, y4 = box + angle, w, h, cx, cy = calculate_center_rotate_angle(box) + # if adjustBox: + # x1, y1, x2, y2, x3, y3, x4, y4 = xy_rotate_box(cx, cy, w + 5, h + 5, angle=0, degree=None) + # x1, x4 = max(x1, 0), max(x4, 0) + # y1, y2 = max(y1, 0), max(y2, 0) + + # if w > 32 and h > 32 and flag: + # if abs(angle / np.pi * 180) < 20: + # if filtersmall and (w < 10 or h < 10): + # continue + # boxes.append([x1, y1, x2, y2, x3, y3, x4, y4]) + # else: + if w * h < 0.5 * W * H: + if filtersmall and ( + w < 15 or h < 15 + ): # or w / h > 30 or h / w > 30): # 过滤小的单元格 + continue + boxes.append([x1, y1, x2, y2, x3, y3, x4, y4]) + return boxes + + +def point_line_cor(p, A, B, C): + ##判断点与线之间的位置关系 + # 一般式直线方程(Ax+By+c)=0 + x, y = p + r = A * x + B * y + C + return r + + +def fit_line(p): + """A = Y2 - Y1 + B = X1 - X2 + C = X2*Y1 - X1*Y2 + AX+BY+C=0 + 直线一般方程 + """ + x1, y1 = p[0] + x2, y2 = p[1] + A = y2 - y1 + B = x1 - x2 + C = x2 * y1 - x1 * y2 + return A, B, C diff --git a/wired_table_rec/utils_table_recover.py b/wired_table_rec/utils_table_recover.py index 10ddc0e..ca2cae5 100644 --- a/wired_table_rec/utils_table_recover.py +++ b/wired_table_rec/utils_table_recover.py @@ -1,8 +1,9 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com +import os import random -from typing import Dict, List, Union +from typing import Dict, List, Union, Any import cv2 import numpy as np @@ -35,6 +36,366 @@ def sorted_boxes(dt_boxes: np.ndarray) -> np.ndarray: return np.array(_boxes) +def calculate_iou(box1: list | np.ndarray, box2: list | np.ndarray) -> float: + """ + :param box1: Iterable [xmin,ymin,xmax,ymax] + :param box2: Iterable [xmin,ymin,xmax,ymax] + :return: iou: float 0-1 + """ + b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3] + b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3] + # 不相交直接退出检测 + if b1_x2 < b2_x1 or b1_x1 > b2_x2 or b1_y2 < b2_y1 or b1_y1 > b2_y2: + return 0.0 + # 计算交集 + inter_x1 = max(b1_x1, b2_x1) + inter_y1 = max(b1_y1, b2_y1) + inter_x2 = min(b1_x2, b2_x2) + inter_y2 = min(b1_y2, b2_y2) + i_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1) + + # 计算并集 + b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1) + b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + u_area = b1_area + b2_area - i_area + + # 避免除零错误,如果区域小到乘积为0,认为是错误识别,直接去掉 + if u_area == 0: + return 1 + # 检查完全包含 + iou = i_area / u_area + return iou + + +def caculate_single_axis_iou( + box1: list | np.ndarray, box2: list | np.ndarray, axis="x" +) -> float: + """ + :param box1: Iterable [xmin,ymin,xmax,ymax] + :param box2: Iterable [xmin,ymin,xmax,ymax] + :return: iou: float 0-1 + """ + b1_x1, b1_y1, b1_x2, b1_y2 = box1 + b2_x1, b2_y1, b2_x2, b2_y2 = box2 + if axis == "x": + i_min = max(b1_x1, b2_x1) + i_max = min(b1_x2, b2_x2) + u_area = max(b1_x2, b2_x2) - min(b1_x1, b2_x1) + else: + i_min = max(b1_y1, b2_y1) + i_max = min(b1_y2, b2_y2) + u_area = max(b1_y2, b2_y2) - min(b1_y1, b2_y1) + i_area = max(i_max - i_min, 0) + if u_area == 0: + return 1 + return i_area / u_area + + +def is_box_contained( + box1: list | np.ndarray, box2: list | np.ndarray, threshold=0.2 +) -> int | None: + """ + :param box1: Iterable [xmin,ymin,xmax,ymax] + :param box2: Iterable [xmin,ymin,xmax,ymax] + :return: 1: box1 
is contained 2: box2 is contained None: no contain these + """ + b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3] + b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3] + # 不相交直接退出检测 + if b1_x2 < b2_x1 or b1_x1 > b2_x2 or b1_y2 < b2_y1 or b1_y1 > b2_y2: + return None + # 计算box2的总面积 + b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1) + + # 计算box1和box2的交集 + intersect_x1 = max(b1_x1, b2_x1) + intersect_y1 = max(b1_y1, b2_y1) + intersect_x2 = min(b1_x2, b2_x2) + intersect_y2 = min(b1_y2, b2_y2) + + # 计算交集的面积 + intersect_area = max(0, intersect_x2 - intersect_x1) * max( + 0, intersect_y2 - intersect_y1 + ) + + # 计算外面的面积 + b1_outside_area = b1_area - intersect_area + b2_outside_area = b2_area - intersect_area + + # 计算外面的面积占box2总面积的比例 + ratio_b1 = b1_outside_area / b1_area if b1_area > 0 else 0 + ratio_b2 = b2_outside_area / b2_area if b2_area > 0 else 0 + + if ratio_b1 < threshold: + return 1 + if ratio_b2 < threshold: + return 2 + # 判断比例是否大于阈值 + return None + + +def is_single_axis_contained( + box1: list | np.ndarray, box2: list | np.ndarray, axis="x", threhold: float = 0.2 +) -> int | None: + """ + :param box1: Iterable [xmin,ymin,xmax,ymax] + :param box2: Iterable [xmin,ymin,xmax,ymax] + :return: 1: box1 is contained 2: box2 is contained None: no contain these + """ + b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3] + b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3] + + # 计算轴重叠大小 + if axis == "x": + b1_area = b1_x2 - b1_x1 + b2_area = b2_x2 - b2_x1 + i_area = min(b1_x2, b2_x2) - max(b1_x1, b2_x1) + else: + b1_area = b1_y2 - b1_y1 + b2_area = b2_y2 - b2_y1 + i_area = min(b1_y2, b2_y2) - max(b1_y1, b2_y1) + # 计算外面的面积 + b1_outside_area = b1_area - i_area + b2_outside_area = b2_area - i_area + + ratio_b1 = b1_outside_area / b1_area if b1_area > 0 else 0 + ratio_b2 = b2_outside_area / b2_area if b2_area > 0 else 0 + if ratio_b1 < threhold: + return 1 + if ratio_b2 < threhold: + return 2 + return None + + +def filter_duplicated_box(table_boxes: list[list[float]]) -> set[int]: + """ + :param table_boxes: [[xmin,ymin,xmax,ymax]] + :return: + """ + delete_idx = set() + for i in range(len(table_boxes)): + polygons_i = table_boxes[i] + if i in delete_idx: + continue + for j in range(i + 1, len(table_boxes)): + if j in delete_idx: + continue + # 下一个box + polygons_j = table_boxes[j] + # 重叠关系先记录,后续删除掉 + if calculate_iou(polygons_i, polygons_j) > 0.8: + delete_idx.add(j) + continue + # 是否存在包含关系 + contained_idx = is_box_contained(polygons_i, polygons_j) + if contained_idx == 2: + delete_idx.add(j) + elif contained_idx == 1: + delete_idx.add(i) + return delete_idx + + +def sorted_ocr_boxes( + dt_boxes: np.ndarray | list, threhold: float = 0.2 +) -> tuple[np.ndarray | list, list[int]]: + """ + Sort text boxes in order from top to bottom, left to right + args: + dt_boxes(array):detected text boxes with (xmin, ymin, xmax, ymax) + return: + sorted boxes(array) with (xmin, ymin, xmax, ymax) + """ + num_boxes = len(dt_boxes) + if num_boxes <= 0: + return dt_boxes, [] + indexed_boxes = [(box, idx) for idx, box in enumerate(dt_boxes)] + sorted_boxes_with_idx = sorted(indexed_boxes, key=lambda x: (x[0][1], x[0][0])) + _boxes, indices = zip(*sorted_boxes_with_idx) + indices = list(indices) + _boxes = [dt_boxes[i] for i in indices] + threahold = 20 + # 避免输出和输入格式不对应,与函数功能不符合 + if isinstance(dt_boxes, np.ndarray): + _boxes = np.array(_boxes) + for i in range(num_boxes - 1): + for j in range(i, -1, -1): + c_idx = 
is_single_axis_contained( + _boxes[j], _boxes[j + 1], axis="y", threhold=threhold + ) + if ( + c_idx is not None + and _boxes[j + 1][0] < _boxes[j][0] + and abs(_boxes[j][1] - _boxes[j + 1][1]) < threahold + ): + _boxes[j], _boxes[j + 1] = _boxes[j + 1].copy(), _boxes[j].copy() + indices[j], indices[j + 1] = indices[j + 1], indices[j] + else: + break + return _boxes, indices + + +def plot_rec_box_with_logic_info(img_path, output_path, logic_points, sorted_polygons): + """ + :param img_path + :param output_path + :param logic_points: [row_start,row_end,col_start,col_end] + :param sorted_polygons: [xmin,ymin,xmax,ymax] + :return: + """ + # 读取原图 + img = cv2.imread(img_path) + img = cv2.copyMakeBorder( + img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255] + ) + # 绘制 polygons 矩形 + for idx, polygon in enumerate(sorted_polygons): + x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3] + x0 = round(x0) + y0 = round(y0) + x1 = round(x1) + y1 = round(y1) + cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1) + # 增大字体大小和线宽 + font_scale = 1.0 # 原先是0.5 + thickness = 2 # 原先是1 + + cv2.putText( + img, + f"{idx}-{logic_points[idx]}", + (x1, y1), + cv2.FONT_HERSHEY_PLAIN, + font_scale, + (0, 0, 255), + thickness, + ) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + # 保存绘制后的图像 + cv2.imwrite(output_path, img) + + +def plot_rec_box(img_path, output_path, sorted_polygons): + """ + :param img_path + :param output_path + :param sorted_polygons: [xmin,ymin,xmax,ymax] + :return: + """ + # 处理ocr_res + img = cv2.imread(img_path) + img = cv2.copyMakeBorder( + img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255] + ) + # 绘制 ocr_res 矩形 + for idx, polygon in enumerate(sorted_polygons): + x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3] + x0 = round(x0) + y0 = round(y0) + x1 = round(x1) + y1 = round(y1) + cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1) + # 增大字体大小和线宽 + font_scale = 1.0 # 原先是0.5 + thickness = 2 # 原先是1 + + cv2.putText( + img, + str(idx), + (x1, y1), + cv2.FONT_HERSHEY_PLAIN, + font_scale, + (0, 0, 255), + thickness, + ) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + # 保存绘制后的图像 + cv2.imwrite(output_path, img) + + +def box_4_1_poly_to_box_4_2(poly_box: list | np.ndarray) -> list[list[float]]: + xmin, ymin, xmax, ymax = tuple(poly_box) + return [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]] + + +def box_4_2_poly_to_box_4_1(poly_box: list | np.ndarray) -> list[float]: + """ + 将poly_box转换为box_4_1 + :param poly_box: + :return: + """ + return [poly_box[0][0], poly_box[0][1], poly_box[2][0], poly_box[2][1]] + + +def match_ocr_cell(dt_rec_boxes: List[List[Union[Any, str]]], pred_bboxes: np.ndarray): + """ + :param dt_rec_boxes: [[(4.2), text, score]] + :param pred_bboxes: shap (4,2) + :return: + """ + matched = {} + not_match_orc_boxes = [] + for i, gt_box in enumerate(dt_rec_boxes): + for j, pred_box in enumerate(pred_bboxes): + pred_box = [pred_box[0][0], pred_box[0][1], pred_box[2][0], pred_box[2][1]] + ocr_boxes = gt_box[0] + # xmin,ymin,xmax,ymax + ocr_box = ( + ocr_boxes[0][0], + ocr_boxes[0][1], + ocr_boxes[2][0], + ocr_boxes[2][1], + ) + contained = is_box_contained(ocr_box, pred_box, 0.6) + if contained == 1 or calculate_iou(ocr_box, pred_box) > 0.8: + if j not in matched: + matched[j] = [gt_box] + else: + matched[j].append(gt_box) + else: + not_match_orc_boxes.append(gt_box) + + return matched, not_match_orc_boxes + + +def gather_ocr_list_by_row( + ocr_list: list[list[list[float], str]], threhold: float = 0.2 +) -> 
list[list[list[float], str]]: + """ + :param ocr_list: [[[xmin,ymin,xmax,ymax], text]] + :return: + """ + threshold = 10 + for i in range(len(ocr_list)): + if not ocr_list[i]: + continue + + for j in range(i + 1, len(ocr_list)): + if not ocr_list[j]: + continue + cur = ocr_list[i] + next = ocr_list[j] + cur_box = cur[0] + next_box = next[0] + c_idx = is_single_axis_contained( + cur[0], next[0], axis="y", threhold=threhold + ) + if c_idx: + dis = max(next_box[0] - cur_box[1], 0) + blank_str = int(dis / threshold) * " " + cur[1] = cur[1] + blank_str + next[1] + xmin = min(cur_box[0], next_box[0]) + xmax = max(cur_box[2], next_box[2]) + ymin = min(cur_box[1], next_box[1]) + ymax = max(cur_box[3], next_box[3]) + cur_box[0] = xmin + cur_box[1] = ymin + cur_box[2] = xmax + cur_box[3] = ymax + ocr_list[j] = None + ocr_list = [x for x in ocr_list if x] + return ocr_list + + def compute_poly_iou(a: np.ndarray, b: np.ndarray) -> float: """计算两个多边形的IOU @@ -122,33 +483,41 @@ def combine_two_poly(polygons: np.ndarray, idxs: np.ndarray) -> np.ndarray: return polygons -def match_ocr_cell( - polygons: np.ndarray, ocr_res: List[Union[List[List[float]], str, str]] -) -> Dict[int, List]: - cell_box_map = {} - dt_boxes, rec_res, _ = list(zip(*ocr_res)) - dt_boxes = np.array(dt_boxes) - iou_thresh = 0.009 - for i, cell_box in enumerate(polygons): - ious = [compute_poly_iou(dt_box, cell_box) for dt_box in dt_boxes] - - # 对有iou的值,计算是否存在包含关系。如存在→iou=1 - have_iou_idxs = np.argwhere(ious) - if have_iou_idxs.size > 0: - have_iou_idxs = have_iou_idxs.squeeze(1) - for idx in have_iou_idxs: - if is_inclusive_each_other(cell_box, dt_boxes[idx]): - ious[idx] = 1.0 - - if all(x <= iou_thresh for x in ious): - # 说明这个cell中没有文本 - cell_box_map.setdefault(i, []).append("") - continue - - same_cell_idxs = np.argwhere(np.array(ious) >= iou_thresh).squeeze(1) - one_cell_txts = "\n".join([rec_res[idx] for idx in same_cell_idxs]) - cell_box_map.setdefault(i, []).append(one_cell_txts) - return cell_box_map +def get_rotate_crop_image(img: np.ndarray, points: np.ndarray) -> np.ndarray: + img_crop_width = int( + max( + np.linalg.norm(points[0] - points[1]), + np.linalg.norm(points[2] - points[3]), + ) + ) + img_crop_height = int( + max( + np.linalg.norm(points[0] - points[3]), + np.linalg.norm(points[1] - points[2]), + ) + ) + pts_std = np.float32( + [ + [0, 0], + [img_crop_width, 0], + [img_crop_width, img_crop_height], + [0, img_crop_height], + ] + ) + M = cv2.getPerspectiveTransform( + points.astype(np.float32), pts_std.astype(np.float32) + ) + dst_img = cv2.warpPerspective( + img, + M, + (img_crop_width, img_crop_height), + borderMode=cv2.BORDER_REPLICATE, + flags=cv2.INTER_CUBIC, + ) + dst_img_height, dst_img_width = dst_img.shape[0:2] + if dst_img_height * 1.0 / dst_img_width >= 1.5: + dst_img = np.rot90(dst_img) + return dst_img def is_inclusive_each_other(box1: np.ndarray, box2: np.ndarray): @@ -185,27 +554,75 @@ def is_inclusive_each_other(box1: np.ndarray, box2: np.ndarray): return False -def plot_html_table(table_res: Dict, cell_box_map: Dict) -> str: - table_str = f"" - for v in table_res.values(): - table_str += "" - for cell_idx, span_info in v.items(): - cur_txt = "
".join(cell_box_map.get(cell_idx, "")) - colspan, rowspan = span_info - if colspan != 1 and rowspan != 1: - table_str += ( - f'' - ) - elif colspan == 1 and rowspan != 1: - table_str += f'' - elif colspan != 1 and rowspan == 1: - table_str += f'' +def plot_html_table( + logi_points: np.ndarray | list, cell_box_map: Dict[int, List[str]] +) -> str: + # 初始化最大行数和列数 + max_row = 0 + max_col = 0 + # 计算最大行数和列数 + for point in logi_points: + max_row = max(max_row, point[1] + 1) # 加1是因为结束下标是包含在内的 + max_col = max(max_col, point[3] + 1) # 加1是因为结束下标是包含在内的 + + # 创建一个二维数组来存储 sorted_logi_points 中的元素 + grid = [[None] * max_col for _ in range(max_row)] + + valid_start_row = (1 << 16) - 1 + valid_end_row = 0 + valid_start_col = (1 << 16) - 1 + valid_end_col = 0 + # 将 sorted_logi_points 中的元素填充到 grid 中 + for i, logic_point in enumerate(logi_points): + row_start, row_end, col_start, col_end = ( + logic_point[0], + logic_point[1], + logic_point[2], + logic_point[3], + ) + ocr_rec_text_list = cell_box_map.get(i) + if ocr_rec_text_list and "".join(ocr_rec_text_list): + valid_start_row = min(row_start, valid_start_row) + valid_start_col = min(col_start, valid_start_col) + valid_end_row = max(row_end, valid_end_row) + valid_end_col = max(col_end, valid_end_col) + for row in range(row_start, row_end + 1): + for col in range(col_start, col_end + 1): + grid[row][col] = (i, row_start, row_end, col_start, col_end) + + # 创建表格 + table_html = "
{cur_txt}{cur_txt}{cur_txt}
" + + # 遍历每行 + for row in range(max_row): + if row < valid_start_row or row > valid_end_row: + continue + temp = "" + # 遍历每一列 + for col in range(max_col): + if col < valid_start_col or col > valid_end_col: + continue + if not grid[row][col]: + temp += "" else: - table_str += f"" - - table_str += "" - table_str += "
{cur_txt}
" - return table_str + i, row_start, row_end, col_start, col_end = grid[row][col] + if not cell_box_map.get(i): + continue + if row == row_start and col == col_start: + ocr_rec_text = cell_box_map.get(i) + text = "
".join(ocr_rec_text) + # 如果是起始单元格 + row_span = row_end - row_start + 1 + col_span = col_end - col_start + 1 + cell_content = ( + f"{text}" + ) + temp += cell_content + + table_html = table_html + temp + "" + + table_html += "" + return table_html def vis_table(img: np.ndarray, polygons: np.ndarray) -> np.ndarray: