-
-
Notifications
You must be signed in to change notification settings - Fork 392
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
110 additions
and
185 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,11 +3,13 @@ | |
# @Contact: [email protected] | ||
import copy | ||
import math | ||
from typing import Any, List, Optional, Tuple | ||
from typing import List, Optional, Tuple | ||
|
||
import cv2 | ||
import numpy as np | ||
|
||
from ..ch_ppocr_rec.utils import TextRecognizerOutput | ||
|
||
|
||
class CalRecBoxes: | ||
"""计算识别文字的汉字单字和英文单词的坐标框。代码借鉴自PaddlePaddle/PaddleOCR和fanqie03/char-detection""" | ||
|
@@ -19,13 +21,16 @@ def __call__( | |
self, | ||
imgs: Optional[List[np.ndarray]], | ||
dt_boxes: Optional[List[np.ndarray]], | ||
rec_res: Optional[List[Any]], | ||
): | ||
res = [] | ||
for img, box, rec_res in zip(imgs, dt_boxes, rec_res): | ||
rec_res: TextRecognizerOutput, | ||
) -> TextRecognizerOutput: | ||
# rec_res = list(zip(rec_res.line_results, rec_res.word_results)) | ||
word_results = [] | ||
for idx, (img, box) in enumerate(zip(imgs, dt_boxes)): | ||
direction = self.get_box_direction(box) | ||
|
||
rec_txt, rec_conf, rec_word_info = rec_res[0], rec_res[1], rec_res[2] | ||
rec_txt, rec_conf = rec_res.line_results[idx] | ||
rec_word_info = rec_res.word_results[idx] | ||
|
||
h, w = img.shape[:2] | ||
img_box = np.array([[0, 0], [w, 0], [w, h], [0, h]]) | ||
word_box_content_list, word_box_list, conf_list = self.cal_ocr_word_box( | ||
|
@@ -35,10 +40,12 @@ def __call__( | |
word_box_list = self.reverse_rotate_crop_image( | ||
copy.deepcopy(box), word_box_list, direction | ||
) | ||
res.append( | ||
[rec_txt, rec_conf, word_box_list, word_box_content_list, conf_list] | ||
word_results.extend( | ||
list(zip(word_box_content_list, conf_list, word_box_list)) | ||
) | ||
return res | ||
|
||
rec_res.word_results = tuple(list(v) for v in word_results) | ||
return rec_res | ||
|
||
@staticmethod | ||
def get_box_direction(box: np.ndarray) -> str: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
# -*- encoding: utf-8 -*- | ||
# @Author: SWHL | ||
# @Contact: [email protected] | ||
from .text_recognize import TextRecognizer | ||
from .main import TextRecognizer | ||
from .utils import TextRecognizerOutput |
130 changes: 0 additions & 130 deletions
130
python/rapidocr_onnxruntime/ch_ppocr_rec/text_recognize.py
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,43 @@ | ||
# -*- encoding: utf-8 -*- | ||
# @Author: SWHL | ||
# @Contact: [email protected] | ||
from dataclasses import dataclass, field | ||
from pathlib import Path | ||
from typing import List, Optional, Tuple, Union | ||
|
||
import numpy as np | ||
|
||
|
||
@dataclass
class TextRecognizerConfig:
    """Configuration values for the text recognizer.

    Every field has a default, so ``TextRecognizerConfig()`` is valid and
    individual settings can be overridden by keyword.
    """

    # Runtime / session options.
    # NOTE(review): -1 presumably means "leave the thread count to the
    # runtime" -- confirm against the code that consumes this config.
    intra_op_num_threads: int = -1
    inter_op_num_threads: int = -1
    use_cuda: bool = False
    use_dml: bool = False
    model_path: Union[str, Path, None] = None

    # Recognition options.
    rec_batch_num: int = 6
    # NOTE(review): looks like (channels, height, width) -- confirm.
    rec_img_shape: Tuple[int, int, int] = (3, 48, 320)
    rec_keys_path: Union[str, Path, None] = None
|
||
|
||
@dataclass
class TextRecognizerInput:
    """Input bundle for one text-recognition call.

    ``img_list`` may be supplied as a single ``np.ndarray`` or as a list of
    arrays. A single array passed to the constructor is wrapped into a
    one-element list, so after construction the attribute is either a list
    or ``None`` (when it was not supplied). Values assigned to the attribute
    after construction are stored as given and are not normalized.
    """

    # Declared first so the previously generated
    # ``__init__(return_word_box=False)`` keeps working for positional callers.
    return_word_box: bool = False
    # Optional at construction; the default also keeps ``repr()`` and ``==``
    # usable on an instance whose images have not been set yet.
    img_list: Union[np.ndarray, List[np.ndarray], None] = None

    def __post_init__(self):
        # dataclasses only invoke the hook under this exact name.
        if isinstance(self.img_list, np.ndarray):
            self.img_list = [self.img_list]
|
||
|
||
@dataclass
class TextRecognizerOutput:
    """Result container returned by the text recognizer.

    All fields default to ``None`` so an empty result can be created and
    filled in later.
    """

    # NOTE(review): callers index this per text line and unpack each entry
    # into (text, confidence) -- confirm the exact container type.
    line_results: Optional[Tuple[List]] = None
    # Per-line word/character level information; may be replaced by
    # downstream post-processing.
    word_results: Optional[List[List]] = None
    # NOTE(review): presumably the elapsed recognition time -- confirm units.
    elapse: Optional[float] = None
|
||
|
||
class CTCLabelDecode: | ||
def __init__( | ||
self, | ||
|
@@ -18,18 +49,18 @@ def __init__( | |
|
||
def __call__( | ||
self, preds: np.ndarray, return_word_box: bool = False, **kwargs | ||
) -> List[Tuple[str, float]]: | ||
) -> Tuple[List[Tuple[str, float]], List[List]]: | ||
preds_idx = preds.argmax(axis=2) | ||
preds_prob = preds.max(axis=2) | ||
text = self.decode( | ||
line_results, word_results = self.decode( | ||
preds_idx, preds_prob, return_word_box, is_remove_duplicate=True | ||
) | ||
if return_word_box: | ||
for rec_idx, rec in enumerate(text): | ||
for rec_idx, rec in enumerate(word_results): | ||
wh_ratio = kwargs["wh_ratio_list"][rec_idx] | ||
max_wh_ratio = kwargs["max_wh_ratio"] | ||
rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio) | ||
return text | ||
rec[0] *= wh_ratio / max_wh_ratio | ||
return line_results, word_results | ||
|
||
def get_character( | ||
self, | ||
|
@@ -80,7 +111,7 @@ def decode( | |
is_remove_duplicate: bool = False, | ||
) -> List[Tuple[str, float]]: | ||
"""convert text-index into text-label.""" | ||
result_list = [] | ||
result_list, result_words_list = [], [] | ||
ignored_tokens = self.get_ignored_tokens() | ||
batch_size = len(text_index) | ||
for batch_idx in range(batch_size): | ||
|
@@ -93,6 +124,7 @@ def decode( | |
|
||
if text_prob is not None: | ||
conf_list = np.array(text_prob[batch_idx][selection]).tolist() | ||
conf_list = [round(conf, 5) for conf in conf_list] | ||
else: | ||
conf_list = [1] * len(selection) | ||
|
||
|
@@ -103,26 +135,23 @@ def decode( | |
self.character[text_id] for text_id in text_index[batch_idx][selection] | ||
] | ||
text = "".join(char_list) | ||
|
||
result_list.append([text, np.mean(conf_list).round(5).tolist()]) | ||
|
||
if return_word_box: | ||
word_list, word_col_list, state_list = self.get_word_info( | ||
text, selection | ||
) | ||
result_list.append( | ||
( | ||
text, | ||
np.mean(conf_list).tolist(), | ||
[ | ||
len(text_index[batch_idx]), | ||
word_list, | ||
word_col_list, | ||
state_list, | ||
conf_list, | ||
], | ||
) | ||
result_words_list.append( | ||
[ | ||
len(text_index[batch_idx]), | ||
word_list, | ||
word_col_list, | ||
state_list, | ||
conf_list, | ||
] | ||
) | ||
else: | ||
result_list.append((text, np.mean(conf_list).tolist())) | ||
return result_list | ||
return result_list, result_words_list | ||
|
||
@staticmethod | ||
def get_word_info( | ||
|
Oops, something went wrong.