Skip to content

Commit

Permalink
chore: update files
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Jan 19, 2025
1 parent 86ae3f5 commit c2b409f
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 185 deletions.
21 changes: 16 additions & 5 deletions python/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,29 @@
with open(image_path, "rb") as f:
img = f.read()

result, elapse_list = engine(img, return_word_box=True)
# result, elapse_list = engine(img, use_det=True, use_cls=False, use_rec=False)
result, elapse_list = engine(img)
print(result)
print(elapse_list)

(boxes, txts, scores, words_boxes, words) = list(zip(*result))
# result, elapse = engine(image_path, use_det=False, use_cls=False, use_rec=True)

(boxes, txts, scores, words_boxes, words, words_scores) = list(zip(*result))
font_path = "resources/fonts/FZYTK.TTF"
vis_img = vis(img, boxes, txts, scores, font_path)
cv2.imwrite("vis.png", vis_img)

words_boxes = sum(words_boxes, [])
words_all = sum(words, [])
words_scores = [1.0] * len(words_boxes)
words_scores = sum(words_scores, [])
vis_img = vis(img, words_boxes, words_all, words_scores, font_path)
cv2.imwrite("vis_single.png", vis_img)

# (boxes, txts, scores, words_boxes, words) = list(zip(*result))

# vis_img = vis(img, boxes, txts, scores, font_path)
# cv2.imwrite("vis.png", vis_img)

# words_boxes = sum(words_boxes, [])
# words_all = sum(words, [])
# words_scores = [1.0] * len(words_boxes)
# vis_img = vis(img, words_boxes, words_all, words_scores, font_path)
# cv2.imwrite("vis_single.png", vis_img)
25 changes: 16 additions & 9 deletions python/rapidocr_onnxruntime/cal_rec_boxes/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
# @Contact: [email protected]
import copy
import math
from typing import Any, List, Optional, Tuple
from typing import List, Optional, Tuple

import cv2
import numpy as np

from ..ch_ppocr_rec.utils import TextRecognizerOutput


class CalRecBoxes:
"""计算识别文字的汉字单字和英文单词的坐标框。代码借鉴自PaddlePaddle/PaddleOCR和fanqie03/char-detection"""
Expand All @@ -19,13 +21,16 @@ def __call__(
self,
imgs: Optional[List[np.ndarray]],
dt_boxes: Optional[List[np.ndarray]],
rec_res: Optional[List[Any]],
):
res = []
for img, box, rec_res in zip(imgs, dt_boxes, rec_res):
rec_res: TextRecognizerOutput,
) -> TextRecognizerOutput:
# rec_res = list(zip(rec_res.line_results, rec_res.word_results))
word_results = []
for idx, (img, box) in enumerate(zip(imgs, dt_boxes)):
direction = self.get_box_direction(box)

rec_txt, rec_conf, rec_word_info = rec_res[0], rec_res[1], rec_res[2]
rec_txt, rec_conf = rec_res.line_results[idx]
rec_word_info = rec_res.word_results[idx]

h, w = img.shape[:2]
img_box = np.array([[0, 0], [w, 0], [w, h], [0, h]])
word_box_content_list, word_box_list, conf_list = self.cal_ocr_word_box(
Expand All @@ -35,10 +40,12 @@ def __call__(
word_box_list = self.reverse_rotate_crop_image(
copy.deepcopy(box), word_box_list, direction
)
res.append(
[rec_txt, rec_conf, word_box_list, word_box_content_list, conf_list]
word_results.extend(
list(zip(word_box_content_list, conf_list, word_box_list))
)
return res

rec_res.word_results = tuple(list(v) for v in word_results)
return rec_res

@staticmethod
def get_box_direction(box: np.ndarray) -> str:
Expand Down
3 changes: 2 additions & 1 deletion python/rapidocr_onnxruntime/ch_ppocr_rec/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
from .text_recognize import TextRecognizer
from .main import TextRecognizer
from .utils import TextRecognizerOutput
130 changes: 0 additions & 130 deletions python/rapidocr_onnxruntime/ch_ppocr_rec/text_recognize.py

This file was deleted.

71 changes: 50 additions & 21 deletions python/rapidocr_onnxruntime/ch_ppocr_rec/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,43 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple, Union

import numpy as np


@dataclass
class TextRecognizerConfig:
intra_op_num_threads: int = -1
inter_op_num_threads: int = -1
use_cuda: bool = False
use_dml: bool = False
model_path: Union[str, Path, None] = None

rec_batch_num: int = 6
rec_img_shape: Tuple[int, int, int] = (3, 48, 320)
rec_keys_path: Union[str, Path, None] = None


@dataclass
class TextRecognizerInput:
img_list: Union[np.ndarray, List[np.ndarray]] = field(init=False)
return_word_box: bool = False

def __pose_init__(self):
if isinstance(self.img_list, np.ndarray):
self.img_list = [self.img_list]


@dataclass
class TextRecognizerOutput:
line_results: Optional[Tuple[List]] = None
word_results: Optional[List[List]] = None
elapse: Optional[float] = None


class CTCLabelDecode:
def __init__(
self,
Expand All @@ -18,18 +49,18 @@ def __init__(

def __call__(
self, preds: np.ndarray, return_word_box: bool = False, **kwargs
) -> List[Tuple[str, float]]:
) -> Tuple[List[Tuple[str, float]], List[List]]:
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(
line_results, word_results = self.decode(
preds_idx, preds_prob, return_word_box, is_remove_duplicate=True
)
if return_word_box:
for rec_idx, rec in enumerate(text):
for rec_idx, rec in enumerate(word_results):
wh_ratio = kwargs["wh_ratio_list"][rec_idx]
max_wh_ratio = kwargs["max_wh_ratio"]
rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio)
return text
rec[0] *= wh_ratio / max_wh_ratio
return line_results, word_results

def get_character(
self,
Expand Down Expand Up @@ -80,7 +111,7 @@ def decode(
is_remove_duplicate: bool = False,
) -> List[Tuple[str, float]]:
"""convert text-index into text-label."""
result_list = []
result_list, result_words_list = [], []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
Expand All @@ -93,6 +124,7 @@ def decode(

if text_prob is not None:
conf_list = np.array(text_prob[batch_idx][selection]).tolist()
conf_list = [round(conf, 5) for conf in conf_list]
else:
conf_list = [1] * len(selection)

Expand All @@ -103,26 +135,23 @@ def decode(
self.character[text_id] for text_id in text_index[batch_idx][selection]
]
text = "".join(char_list)

result_list.append([text, np.mean(conf_list).round(5).tolist()])

if return_word_box:
word_list, word_col_list, state_list = self.get_word_info(
text, selection
)
result_list.append(
(
text,
np.mean(conf_list).tolist(),
[
len(text_index[batch_idx]),
word_list,
word_col_list,
state_list,
conf_list,
],
)
result_words_list.append(
[
len(text_index[batch_idx]),
word_list,
word_col_list,
state_list,
conf_list,
]
)
else:
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
return result_list, result_words_list

@staticmethod
def get_word_info(
Expand Down
Loading

0 comments on commit c2b409f

Please sign in to comment.