Skip to content

Commit

Permalink
Merge pull request #25 from RapidAI/adapt_onnx_gpu_inference
Browse files Browse the repository at this point in the history
feat: adapt for onnx-gpu
  • Loading branch information
SWHL authored Nov 23, 2024
2 parents 604aa2d + 814d49b commit f9a2a6c
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 67 deletions.
87 changes: 50 additions & 37 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,47 @@ slanet_plus是paddlex内置的SLANet升级版模型,准确率有大幅提升
<img src="https://github.com/RapidAI/RapidTable/releases/download/assets/preview.gif" alt="Demo" width="100%" height="100%">
</div>

### [TableStructureRec](https://github.com/RapidAI/TableStructureRec) 关系
### 更新日志

<details>

#### 2024.11.24 update
- 支持gpu推理,适配 rapidOCR 单字识别匹配

#### 2024.10.13 update
- 补充最新paddlex-SLANet-plus 模型(paddle2onnx原因暂不能支持onnx)

#### 2023-12-29 v0.1.3 update

- 优化可视化结果部分

#### 2023-12-27 v0.1.2 update

- 添加返回cell坐标框参数
- 完善可视化函数

#### 2023-07-17 v0.1.0 update

-`rapidocr_onnxruntime`部分从`rapid_table`中解耦合出来,给出选项是否依赖,更加灵活。

- 增加接口输入参数`ocr_result`
- 如果在调用函数时,事先指定了`ocr_result`参数值,则不会再走OCR。其中`ocr_result`格式需要和`rapidocr_onnxruntime`返回值一致。
- 如果未指定`ocr_result`参数值,但是事先安装了`rapidocr_onnxruntime`库,则会自动调用该库,进行识别。
- 如果`ocr_result`未指定,且`rapidocr_onnxruntime`未安装,则会报错。必须满足两个条件中一个。

#### 2023-07-10 v0.0.13 updata

- 更改传入表格还原中OCR的实例接口,可以传入其他OCR实例,前提要与`rapidocr_onnxruntime`接口一致

#### 2023-07-06 v0.0.12 update

- 去掉返回表格的html字符串中的`<thead></thead><tbody></tbody>`元素,便于后续统一。
- 采用Black工具优化代码

</details>


### [TableStructureRec](https://github.com/RapidAI/TableStructureRec)关系

TableStructureRec库是一个表格识别算法的集合库,当前有`wired_table_rec`有线表格识别算法和`lineless_table_rec`无线表格识别算法的推理包。

Expand All @@ -58,6 +98,7 @@ RapidTable是整理自PP-Structure中表格识别部分而来。由于PP-Structu
```bash
pip install rapidocr_onnxruntime
pip install rapid_table
#pip install onnxruntime-gpu # for gpu inference
```

### 使用方式
Expand All @@ -76,14 +117,22 @@ table_engine = RapidTable()
from pathlib import Path

from rapid_table import RapidTable, VisTable
from rapidocr_onnxruntime import RapidOCR
from rapid_table.table_structure.utils import trans_char_ocr_res


table_engine = RapidTable()
# 开启onnx-gpu推理
# table_engine = RapidTable(use_cuda=True)
ocr_engine = RapidOCR()
viser = VisTable()

img_path = 'test_images/table.jpg'

ocr_result, _ = ocr_engine(img_path)
# 单字匹配
# ocr_result, _ = ocr_engine(img_path, return_word_box=True)
# ocr_result = trans_char_ocr_res(ocr_result)
table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_result)

save_dir = Path("./inference_results/")
Expand Down Expand Up @@ -134,39 +183,3 @@ print(table_html_str)
<table><tr><td>Methods</td><td></td><td></td><td></td><td>FPS</td></tr><tr><td>SegLink [26]</td><td>70.0</td><td>86d><td.0</td><td>77.0</td><td>8.9</td></tr><tr><td>PixelLink [4]</td><td>73.2</td><td>83.0</td><td>77.8</td><td></td></tr><tr><td>TextSnake [18]</td><td>73.9</td><td>83.2</td><td>78.3</td><td>1.1</td></tr><tr><td>TextField [37]</td><td>75.9</td><td>87.4</td><td>81.3</td><td>5.2</td></tr><tr><td>MSR[38]</td><td>76.7</td><td>87.87.4</td><td>81.7</td><td></td></tr><tr><td>FTSN [3]</td><td>77.1</td><td>87.6</td><td>82.0</td><td></td></tr><tr><td>LSE[30]</td><td>81.7</td><td>84.2</td><td>82.9</td><><ttd></td></tr><tr><td>CRAFT [2]</td><td>78.2</td><td>88.2</td><td>82.9</td><td>8.6</td></tr><tr><td>MCN[16]</td><td>79</td><td>88</td><td>83</td><td></td></tr><tr><td>ATRR</>[35]</td><td>82.1</td><td>85.2</td><td>83.6</td><td></td></tr><tr><td>PAN [34]</td><td>83.8</td><td>84.4</td><td>84.1</td><td>30.2</td></tr><tr><td>DB[12]</td><td>79.2</t91/d><td>91.5</td><td>84.9</td><td>32.0</td></tr><tr><td>DRRG[41]</td><td>82.30</td><td>88.05</td><td>85.08</td><td></td></tr><tr><td>Ours (SynText)</td><td>80.68</td><td>85<t..40</td><td>82.97</td><td>12.68</td></tr><tr><td>Ours (MLT-17)</td><td>84.54</td><td>86.62</td><td>85.57</td><td>12.31</td></tr></table>

</div>

### 更新日志

<details>

#### 2023-12-29 v0.1.3 update

- 优化可视化结果部分

#### 2023-12-27 v0.1.2 update

- 添加返回cell坐标框参数
- 完善可视化函数

#### 2023-07-17 v0.1.0 update

-`rapidocr_onnxruntime`部分从`rapid_table`中解耦合出来,给出选项是否依赖,更加灵活。

- 增加接口输入参数`ocr_result`
- 如果在调用函数时,事先指定了`ocr_result`参数值,则不会再走OCR。其中`ocr_result`格式需要和`rapidocr_onnxruntime`返回值一致。
- 如果未指定`ocr_result`参数值,但是事先安装了`rapidocr_onnxruntime`库,则会自动调用该库,进行识别。
- 如果`ocr_result`未指定,且`rapidocr_onnxruntime`未安装,则会报错。必须满足两个条件中一个。

#### 2023-07-10 v0.0.13 updata

- 更改传入表格还原中OCR的实例接口,可以传入其他OCR实例,前提要与`rapidocr_onnxruntime`接口一致

#### 2023-07-06 v0.0.12 update

- 去掉返回表格的html字符串中的`<thead></thead><tbody></tbody>`元素,便于后续统一。
- 采用Black工具优化代码

#### 2024.10.13 update
- 补充最新paddlex-SLANet-plus 模型(paddle2onnx原因暂不能支持onnx)

</details>
8 changes: 6 additions & 2 deletions rapid_table/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,19 @@


class RapidTable:
def __init__(self, model_path: Optional[str] = None, model_type: str = None):
def __init__(self, model_path: Optional[str] = None, model_type: str = None, use_cuda: bool = False):
if model_path is None:
model_path = str(
root_dir / "models" / "slanet-plus.onnx"
)
model_type = "slanet-plus"
self.model_type = model_type
self.load_img = LoadImage()
self.table_structure = TableStructurer(model_path)
config = {
"model_path": model_path,
"use_cuda": use_cuda,
}
self.table_structure = TableStructurer(config)
self.table_matcher = TableMatch()

try:
Expand Down
21 changes: 21 additions & 0 deletions rapid_table/table_structure/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# -*- encoding: utf-8 -*-
# @Author: Jocker1212
# @Contact: [email protected]
import logging
from functools import lru_cache


@lru_cache(maxsize=32)
def get_logger(name: str) -> logging.Logger:
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)

fmt = "%(asctime)s - %(name)s - %(levelname)s: %(message)s"
format_str = logging.Formatter(fmt)

sh = logging.StreamHandler()
sh.setLevel(logging.DEBUG)

logger.addHandler(sh)
sh.setFormatter(format_str)
return logger
7 changes: 4 additions & 3 deletions rapid_table/table_structure/table_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import Dict, Any

import numpy as np

from .utils import OrtInferSession, TableLabelDecode, TablePreprocess


class TableStructurer:
def __init__(self, model_path: str):
def __init__(self, config: Dict[str, Any]):
self.preprocess_op = TablePreprocess()

self.session = OrtInferSession(model_path)
self.session = OrtInferSession(config)

self.character = self.session.get_metadata()
self.postprocess_op = TableLabelDecode(self.character)
Expand All @@ -37,7 +38,7 @@ def __call__(self, img):
img = np.expand_dims(img, axis=0)
img = img.copy()

outputs = self.session(img)
outputs = self.session([img])

preds = {"loc_preds": outputs[0], "structure_probs": outputs[1]}

Expand Down
Loading

0 comments on commit f9a2a6c

Please sign in to comment.