
Commit

Merge pull request #40 from RapidAI/develop
feat: support unitable and optimize code and bump to v1.0.2
SWHL authored Jan 8, 2025
2 parents 7810dfd + f0a3dfe commit cf06b0e
Showing 21 changed files with 310 additions and 375 deletions.
18 changes: 5 additions & 13 deletions .github/workflows/publish_whl.yml
@@ -6,7 +6,7 @@ on:
- v*

env:
RESOURCES_URL: https://github.com/RapidAI/RapidTable/releases/download/assets/rapid_table_models.zip
DEFAULT_MODEL: https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/slanet-plus.onnx

jobs:
UnitTesting:
@@ -26,20 +26,15 @@ jobs:

- name: Unit testings
run: |
wget $RESOURCES_URL
ZIP_NAME=${RESOURCES_URL##*/}
DIR_NAME=${ZIP_NAME%.*}
unzip $DIR_NAME
cp $DIR_NAME/*.onnx rapid_table/models/
wget $DEFAULT_MODEL -P rapid_table/models
pip install -r requirements.txt
pip install rapidocr_onnxruntime
pip install torch
pip install torchvision
pip install tokenizers
pip install pytest
pytest tests/test_table.py
pytest tests/test_table_torch.py
pytest tests/test_main.py
GenerateWHL_PushPyPi:
needs: UnitTesting
@@ -59,11 +54,8 @@ jobs:
pip install -r requirements.txt
python -m pip install --upgrade pip
pip install wheel get_pypi_latest_version
wget $RESOURCES_URL
ZIP_NAME=${RESOURCES_URL##*/}
DIR_NAME=${ZIP_NAME%.*}
unzip $ZIP_NAME
mv $DIR_NAME/slanet-plus.onnx rapid_table/models/
wget $DEFAULT_MODEL -P rapid_table/models
python setup.py bdist_wheel ${{ github.ref_name }}
- name: Publish distribution 📦 to PyPI
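The removed steps derived the archive and directory names from `RESOURCES_URL` with shell parameter expansion: `${VAR##*/}` drops everything up to the last slash, and `${VAR%.*}` drops the final extension. A minimal Python sketch of the same derivation, using only the standard library, may make the expansion easier to follow. The helper name `split_resource_url` is hypothetical and is not part of the workflow.

```python
from pathlib import PurePosixPath
from typing import Tuple
from urllib.parse import urlparse


def split_resource_url(url: str) -> Tuple[str, str]:
    """Return (file name, file name without its last extension) for a URL."""
    # Equivalent in spirit to ${URL##*/}: keep only the last path segment.
    file_name = PurePosixPath(urlparse(url).path).name
    # Equivalent in spirit to ${ZIP_NAME%.*}: strip the final extension.
    base_name = PurePosixPath(file_name).stem
    return file_name, base_name


zip_name, dir_name = split_resource_url(
    "https://github.com/RapidAI/RapidTable/releases/download/assets/rapid_table_models.zip"
)
```

With the single-file `DEFAULT_MODEL` URL, no unpacking step is needed: `wget -P rapid_table/models` saves the download under its URL base name, which for the URL in this workflow would be `slanet-plus.onnx`.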
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
*.json

# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
2 changes: 1 addition & 1 deletion LICENSE
@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2024 RapidAI
Copyright 2025 RapidAI

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
106 changes: 59 additions & 47 deletions README.md
@@ -21,6 +21,14 @@ slanet_plus is the upgraded SLANet model built into PaddleX, with substantially improved accuracy

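unitable is the transformer model from unitable. It has the highest accuracy, currently supports only PyTorch inference (with GPU acceleration), and its trained weights come from the [OhMyTable project](https://github.com/Sanster/OhMyTable)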

### Demo

<div align="center">
<img src="https://github.com/RapidAI/RapidTable/releases/download/assets/preview.gif" alt="Demo" width="80%" height="80%">
</div>

### Model list

| `model_type` | Model name | Inference framework | Model size | Inference time (single 60 KB image) |
|:--------------|:--------------------------------------| :------: |:------ |:------ |
| `ppstructure_en` | `en_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime |7.3M |0.15s |
@@ -33,23 +41,7 @@ unitable is the transformer model from unitable, with the highest accuracy; currently supports only pytor
[PaddleX-SlaNetPlus 表格识别](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/module_usage/tutorials/ocr_modules/table_structure_recognition.md)\
[Unitable](https://github.com/poloclub/unitable?tab=readme-ov-file)

Model download link: [link](https://github.com/RapidAI/RapidTable/releases/tag/assets)

### Demo

<div align="center">
<img src="https://github.com/RapidAI/RapidTable/releases/download/assets/preview.gif" alt="Demo" width="100%" height="100%">
</div>

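### Relationship with [TableStructureRec](https://github.com/RapidAI/TableStructureRec)

TableStructureRec is a collection of table recognition algorithms. It currently provides inference packages for `wired_table_rec` (wired-table recognition) and `lineless_table_rec` (lineless-table recognition).

RapidTable was adapted from the table recognition part of PP-Structure. Because PP-Structure came earlier, this library ended up being named `rapid_table`.

In short, RapidTable and TableStructureRec are both table recognition repositories. Try both and use whichever works better for you. Because the algorithms differ considerably, there are no plans to unify them for now.

For a comparison of table recognition algorithms, see the [TableStructureRec evaluation](https://github.com/RapidAI/TableStructureRec#指标结果)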
The models are hosted on ModelScope; they can be downloaded from: [link](https://www.modelscope.cn/models/RapidAI/RapidTable/files)

### Installation

@@ -73,11 +65,33 @@ pip install onnxruntime-gpu # for onnx gpu inference

#### Running from a Python script

The RapidTable class provides a `model_path` parameter, which can be used to select either of the two models above; the default is `slanet-plus.onnx`. For example:
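> ⚠️ Note: from `rapid_table>=1.0.0` onward, model inputs are wrapped in dataclasses, which simplifies parameter passing and keeps it compatible. The input and output are defined as follows: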
```python
table_engine = RapidTable()
# table_engine = RapidTable(use_cuda=True, device="cuda:0", model_type="unitable")
# Input
@dataclass
class RapidTableInput:
    model_type: Optional[str] = ModelType.SLANETPLUS.value
    model_path: Union[str, Path, None, Dict[str, str]] = None
    use_cuda: bool = False
    device: str = "cpu"

# Output
@dataclass
class RapidTableOutput:
    pred_html: Optional[str] = None
    cell_bboxes: Optional[np.ndarray] = None
    logic_points: Optional[np.ndarray] = None
    elapse: Optional[float] = None

# Usage example
input_args = RapidTableInput(model_type="unitable")
table_engine = RapidTable(input_args)

img_path = 'test_images/table.jpg'
table_results = table_engine(img_path)

print(table_results.pred_html)
```
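The dataclass-wrapped input described above can be tried out without installing the package. The sketch below uses stand-in names (`TableInput`, `describe`) that are hypothetical and only mirror the fields shown in the README; the enum values are assumed from the `model_type` names mentioned there, not taken from the library source.

```python
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Dict, Optional, Union


class ModelType(Enum):
    # Assumed values, based on the model_type names used in the README.
    SLANETPLUS = "slanet_plus"
    UNITABLE = "unitable"


@dataclass
class TableInput:  # hypothetical stand-in for RapidTableInput
    model_type: Optional[str] = ModelType.SLANETPLUS.value
    model_path: Union[str, Path, None, Dict[str, str]] = None
    use_cuda: bool = False
    device: str = "cpu"


def describe(cfg: TableInput) -> str:
    """Summarise a config in one line; illustration only."""
    backend = "gpu" if cfg.use_cuda else "cpu"
    return f"{cfg.model_type} on {cfg.device} ({backend})"


# Defaults cover the common case; overrides are passed by keyword.
default_cfg = TableInput()
gpu_cfg = TableInput(model_type="unitable", use_cuda=True, device="cuda:0")
```

Bundling the options in one object means new fields can be added with defaults later without changing the call signature of the engine, which is presumably the compatibility benefit the note refers to.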

Complete example:
@@ -92,57 +106,45 @@ from rapid_table.table_structure.utils import trans_char_ocr_res
table_engine = RapidTable()

# Enable onnx-gpu inference
# table_engine = RapidTable(use_cuda=True)
# input_args = RapidTableInput(use_cuda=True)
# table_engine = RapidTable(input_args)

# Use the torch inference version of the unitable model
# table_engine = RapidTable(use_cuda=True, device="cuda:0", model_type="unitable")
# input_args = RapidTableInput(model_type="unitable", use_cuda=True, device="cuda:0")
# table_engine = RapidTable(input_args)

ocr_engine = RapidOCR()
viser = VisTable()

img_path = 'test_images/table.jpg'

ocr_result, _ = ocr_engine(img_path)

# Single-character matching
# ocr_result, _ = ocr_engine(img_path, return_word_box=True)
# ocr_result = trans_char_ocr_res(ocr_result)

table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_result)
table_results = table_engine(img_path, ocr_result)

save_dir = Path("./inference_results/")
save_dir.mkdir(parents=True, exist_ok=True)

save_html_path = save_dir / f"{Path(img_path).stem}.html"
save_drawed_path = save_dir / f"vis_{Path(img_path).name}"

viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path)

# Return logical coordinates
# table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_result, return_logic_points=True)

# save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}"
# viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path,logic_points, save_logic_path)

viser(
    img_path,
    table_results.pred_html,
    save_html_path,
    table_results.cell_bboxes,
    save_drawed_path,
    table_results.logic_points,
    save_logic_path,
)
print(table_html_str)
```
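Callers written against the earlier tuple return (`table_html_str, table_cell_bboxes, elapse`) now receive a single result object instead. A minimal sketch of bridging the two styles follows; `TableOutput` and `to_legacy_tuple` are hypothetical stand-ins, and plain lists replace the `np.ndarray` fields so the example has no third-party dependency.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

Boxes = Optional[List[List[float]]]
LegacyResult = Tuple[Optional[str], Boxes, Optional[float]]


@dataclass
class TableOutput:  # hypothetical stand-in for RapidTableOutput
    pred_html: Optional[str] = None
    cell_bboxes: Boxes = None  # np.ndarray in the README definition
    logic_points: Optional[List[List[int]]] = None  # np.ndarray in the README
    elapse: Optional[float] = None


def to_legacy_tuple(res: TableOutput) -> LegacyResult:
    """Rebuild the old (html, bboxes, elapse) triple from a result object."""
    return res.pred_html, res.cell_bboxes, res.elapse


# Made-up values, for illustration only.
result = TableOutput(
    pred_html="<table></table>",
    cell_bboxes=[[0.0, 0.0, 10.0, 5.0]],
    elapse=0.15,
)
table_html_str, table_cell_bboxes, elapse = to_legacy_tuple(result)
```

Fields that were not produced simply stay `None`, so code that reads `logic_points` should check for that before using it.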

#### Running from the terminal

```bash
$ rapid_table -h
usage: rapid_table [-h] [-v] -img IMG_PATH [-m MODEL_PATH]

optional arguments:
  -h, --help            show this help message and exit
  -v, --vis             Whether to visualize the layout results.
  -img IMG_PATH, --img_path IMG_PATH
                        Path to image for layout.
  -m MODEL_PATH, --model_path MODEL_PATH
                        The model path used for inference.
```

Example:

```bash
rapid_table -v -img test_images/table.jpg
```
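The options listed in the `rapid_table -h` output can be reproduced with a small `argparse` parser. This is a reconstruction from the help text only, not the package's actual entry point; `build_parser` is a hypothetical name, and the `None` default for `--model_path` is an assumption.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build a parser mirroring the options shown in the help text."""
    parser = argparse.ArgumentParser(prog="rapid_table")
    parser.add_argument(
        "-v", "--vis", action="store_true",
        help="Whether to visualize the layout results.",
    )
    parser.add_argument(
        "-img", "--img_path", type=str, required=True,
        help="Path to image for layout.",
    )
    parser.add_argument(
        "-m", "--model_path", type=str, default=None,  # default is assumed
        help="The model path used for inference.",
    )
    return parser


# Same arguments as the terminal example: rapid_table -v -img test_images/table.jpg
args = build_parser().parse_args(["-v", "-img", "test_images/table.jpg"])
```

Because `-img` is marked `required=True`, omitting it makes `argparse` exit with a usage error, which matches the help text showing `-img IMG_PATH` outside square brackets.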
@@ -289,6 +291,16 @@ rapid_table -v -img test_images/table.jpg

</div>

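### Relationship with [TableStructureRec](https://github.com/RapidAI/TableStructureRec)

TableStructureRec is a collection of table recognition algorithms. It currently provides inference packages for `wired_table_rec` (wired-table recognition) and `lineless_table_rec` (lineless-table recognition).

RapidTable was adapted from the table recognition part of PP-Structure. Because PP-Structure came earlier, this library ended up being named `rapid_table`.

In short, RapidTable and TableStructureRec are both table recognition repositories. Try both and use whichever works better for you. Because the algorithms differ considerably, there are no plans to unify them for now.

For a comparison of table recognition algorithms, see the [TableStructureRec evaluation](https://github.com/RapidAI/TableStructureRec#指标结果)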

### Changelog

<details>
13 changes: 7 additions & 6 deletions demo.py
@@ -6,12 +6,14 @@
import cv2
from rapidocr_onnxruntime import RapidOCR, VisRes

from rapid_table import RapidTable, VisTable
from rapid_table import RapidTable, RapidTableInput, VisTable

# Init
ocr_engine = RapidOCR()
vis_ocr = VisRes()
table_engine = RapidTable()

input_args = RapidTableInput(model_type="unitable")
table_engine = RapidTable(input_args)
viser = VisTable()

img_path = "tests/test_files/table.jpg"
@@ -21,7 +23,8 @@
boxes, txts, scores = list(zip(*ocr_result))

# Table Rec
table_html_str, table_cell_bboxes, _ = table_engine(img_path, ocr_result)
table_results = table_engine(img_path, ocr_result)
table_html_str, table_cell_bboxes = table_results.pred_html, table_results.cell_bboxes

# Save
save_dir = Path("outputs")
@@ -31,9 +34,7 @@
save_drawed_path = save_dir / f"{Path(img_path).stem}_table_vis{Path(img_path).suffix}"

# Visualize table rec result
vis_imged = viser(
    img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path
)
vis_imged = viser(img_path, table_results, save_html_path, save_drawed_path)

# Visualize OCR result
save_ocr_path = save_dir / f"{Path(img_path).stem}_ocr_vis{Path(img_path).suffix}"
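The line `boxes, txts, scores = list(zip(*ocr_result))` in `demo.py` relies on `zip(*rows)` to transpose a list of rows into columns. A small sketch with made-up data shows the effect; the `(box, text, score)` row layout is inferred from the three names being unpacked and is an assumption about the OCR output.

```python
# Made-up OCR rows in an assumed (box, text, score) layout, for illustration.
ocr_result = [
    ([[0, 0], [10, 0], [10, 5], [0, 5]], "Name", 0.98),
    ([[12, 0], [30, 0], [30, 5], [12, 5]], "Price", 0.95),
]

# zip(*rows) groups the first items together, then the second, then the third.
# Each resulting column is a tuple with one entry per OCR row.
boxes, txts, scores = list(zip(*ocr_result))
```

The unpacking only works when every row has exactly three items and the result is non-empty; an empty `ocr_result` would raise a `ValueError` at the assignment.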
32 changes: 0 additions & 32 deletions demo_torch.py

This file was deleted.

4 changes: 2 additions & 2 deletions rapid_table/__init__.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
from .main import RapidTable
from .utils import VisTable
from .main import RapidTable, RapidTableInput
from .utils.utils import VisTable
