Skip to content

Commit

Permalink
chore: optimze char rec (#260)
Browse files Browse the repository at this point in the history
* chore: format code

* chore: format code

* test: update unit tests for different platform

* feat(rapidocr_onnxruntime): optimize code

* feat(rapidocr_onnxruntime): add --return_word_box params of CLI

* chore: update demo

* feat(rapidocr_onnxruntime): update vis of return_word_box
  • Loading branch information
SWHL authored Nov 19, 2024
1 parent 2ac3c35 commit df36c69
Show file tree
Hide file tree
Showing 11 changed files with 364 additions and 250 deletions.
13 changes: 9 additions & 4 deletions python/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,25 @@
# from rapidocr_paddle import RapidOCR, VisRes
# from rapidocr_openvino import RapidOCR, VisRes


engine = RapidOCR()
vis = VisRes()

image_path = "tests/test_files/black_font_color_transparent.png"
image_path = "tests/test_files/ch_en_num.jpg"
with open(image_path, "rb") as f:
img = f.read()

result, elapse_list = engine(img)
result, elapse_list = engine(img, return_word_box=True)
print(result)
print(elapse_list)

boxes, txts, scores = list(zip(*result))
(boxes, txts, scores, words_boxes, words) = list(zip(*result))

font_path = "resources/fonts/FZYTK.TTF"
vis_img = vis(img, boxes, txts, scores, font_path)
cv2.imwrite("vis.png", vis_img)

words_boxes = sum(words_boxes, [])
words_all = sum(words, [])
words_scores = [1.0] * len(words_boxes)
vis_img = vis(img, words_boxes, words_all, words_scores, font_path)
cv2.imwrite("vis_single.png", vis_img)
4 changes: 4 additions & 0 deletions python/rapidocr_onnxruntime/cal_rec_boxes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
from .main import CalRecBoxes
260 changes: 260 additions & 0 deletions python/rapidocr_onnxruntime/cal_rec_boxes/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL / Joker1212
# @Contact: [email protected]
import copy
import math
from typing import Any, List, Optional, Tuple

import cv2
import numpy as np


class CalRecBoxes:
"""计算识别文字的汉字单字和英文单词的坐标框。代码借鉴自PaddlePaddle/PaddleOCR和fanqie03/char-detection"""

def __init__(self):
pass

def __call__(
self,
imgs: Optional[List[np.ndarray]],
dt_boxes: Optional[List[np.ndarray]],
rec_res: Optional[List[Any]],
):
res = []
for img, box, rec_res in zip(imgs, dt_boxes, rec_res):
direction = self.get_box_direction(box)

rec_txt, rec_conf, rec_word_info = rec_res[0], rec_res[1], rec_res[2]
h, w = img.shape[:2]
img_box = np.array([[0, 0], [w, 0], [w, h], [0, h]])
word_box_content_list, word_box_list = self.cal_ocr_word_box(
rec_txt, img_box, rec_word_info
)
word_box_list = self.adjust_box_overlap(copy.deepcopy(word_box_list))
word_box_list = self.reverse_rotate_crop_image(
copy.deepcopy(box), word_box_list, direction
)
res.append([rec_txt, rec_conf, word_box_list, word_box_content_list])
return res

@staticmethod
def get_box_direction(box: np.ndarray) -> str:
direction = "w"
img_crop_width = int(
max(
np.linalg.norm(box[0] - box[1]),
np.linalg.norm(box[2] - box[3]),
)
)
img_crop_height = int(
max(
np.linalg.norm(box[0] - box[3]),
np.linalg.norm(box[1] - box[2]),
)
)
if img_crop_height * 1.0 / img_crop_width >= 1.5:
direction = "h"
return direction

@staticmethod
def cal_ocr_word_box(
rec_txt: str, box: np.ndarray, rec_word_info: List[Tuple[str, List[int]]]
) -> Tuple[List[str], List[List[int]]]:
"""Calculate the detection frame for each word based on the results of recognition and detection of ocr
汉字坐标是单字的
英语坐标是单词级别的
"""

col_num, word_list, word_col_list, state_list = rec_word_info
box = box.tolist()
bbox_x_start = box[0][0]
bbox_x_end = box[1][0]
bbox_y_start = box[0][1]
bbox_y_end = box[2][1]

cell_width = (bbox_x_end - bbox_x_start) / col_num
word_box_list = []
word_box_content_list = []
cn_width_list = []
cn_col_list = []
for word, word_col, state in zip(word_list, word_col_list, state_list):
if state == "cn":
if len(word_col) != 1:
char_seq_length = (word_col[-1] - word_col[0] + 1) * cell_width
char_width = char_seq_length / (len(word_col) - 1)
cn_width_list.append(char_width)
cn_col_list += word_col
word_box_content_list += word
else:
cell_x_start = bbox_x_start + int(word_col[0] * cell_width)
cell_x_end = bbox_x_start + int((word_col[-1] + 1) * cell_width)
cell = [
[cell_x_start, bbox_y_start],
[cell_x_end, bbox_y_start],
[cell_x_end, bbox_y_end],
[cell_x_start, bbox_y_end],
]
word_box_list.append(cell)
word_box_content_list.append("".join(word))

if len(cn_col_list) != 0:
if len(cn_width_list) != 0:
avg_char_width = np.mean(cn_width_list)
else:
avg_char_width = (bbox_x_end - bbox_x_start) / len(rec_txt)

for center_idx in cn_col_list:
center_x = (center_idx + 0.5) * cell_width
cell_x_start = max(int(center_x - avg_char_width / 2), 0) + bbox_x_start
cell_x_end = (
min(int(center_x + avg_char_width / 2), bbox_x_end - bbox_x_start)
+ bbox_x_start
)
cell = [
[cell_x_start, bbox_y_start],
[cell_x_end, bbox_y_start],
[cell_x_end, bbox_y_end],
[cell_x_start, bbox_y_end],
]
word_box_list.append(cell)
sorted_word_box_list = sorted(word_box_list, key=lambda box: box[0][0])
return word_box_content_list, sorted_word_box_list

@staticmethod
def adjust_box_overlap(
word_box_list: List[List[List[int]]],
) -> List[List[List[int]]]:
# 调整bbox有重叠的地方
for i in range(len(word_box_list) - 1):
cur, nxt = word_box_list[i], word_box_list[i + 1]
if cur[1][0] > nxt[0][0]: # 有交集
distance = abs(cur[1][0] - nxt[0][0])
cur[1][0] -= distance / 2
cur[2][0] -= distance / 2
nxt[0][0] += distance / 2
nxt[3][0] += distance / 2
return word_box_list

def reverse_rotate_crop_image(
self,
bbox_points: np.ndarray,
word_points_list: List[List[List[int]]],
direction: str = "w",
) -> List[List[List[int]]]:
"""
get_rotate_crop_image的逆操作
img为原图
part_img为crop后的图
bbox_points为part_img中对应在原图的bbox, 四个点,左上,右上,右下,左下
part_points为在part_img中的点[(x, y), (x, y)]
"""
bbox_points = np.float32(bbox_points)

left = int(np.min(bbox_points[:, 0]))
top = int(np.min(bbox_points[:, 1]))
bbox_points[:, 0] = bbox_points[:, 0] - left
bbox_points[:, 1] = bbox_points[:, 1] - top

img_crop_width = int(np.linalg.norm(bbox_points[0] - bbox_points[1]))
img_crop_height = int(np.linalg.norm(bbox_points[0] - bbox_points[3]))

pts_std = np.array(
[
[0, 0],
[img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height],
]
).astype(np.float32)
M = cv2.getPerspectiveTransform(bbox_points, pts_std)
_, IM = cv2.invert(M)

new_word_points_list = []
for word_points in word_points_list:
new_word_points = []
for point in word_points:
new_point = point
if direction == "h":
new_point = self.s_rotate(
math.radians(-90), new_point[0], new_point[1], 0, 0
)
new_point[0] = new_point[0] + img_crop_width

p = np.float32(new_point + [1])
x, y, z = np.dot(IM, p)
new_point = [x / z, y / z]

new_point = [int(new_point[0] + left), int(new_point[1] + top)]
new_word_points.append(new_point)
new_word_points = self.order_points(new_word_points)
new_word_points_list.append(new_word_points)
return new_word_points_list

@staticmethod
def s_rotate(angle, valuex, valuey, pointx, pointy):
"""绕pointx,pointy顺时针旋转
https://blog.csdn.net/qq_38826019/article/details/84233397
"""
valuex = np.array(valuex)
valuey = np.array(valuey)
sRotatex = (
(valuex - pointx) * math.cos(angle)
+ (valuey - pointy) * math.sin(angle)
+ pointx
)
sRotatey = (
(valuey - pointy) * math.cos(angle)
- (valuex - pointx) * math.sin(angle)
+ pointy
)
return [sRotatex, sRotatey]

@staticmethod
def order_points(box: List[List[int]]) -> List[List[int]]:
"""矩形框顺序排列"""
box = np.array(box).reshape((-1, 2))
center_x, center_y = np.mean(box[:, 0]), np.mean(box[:, 1])
if np.any(box[:, 0] == center_x) and np.any(
box[:, 1] == center_y
): # 有两点横坐标相等,有两点纵坐标相等,菱形
p1 = box[np.where(box[:, 0] == np.min(box[:, 0]))]
p2 = box[np.where(box[:, 1] == np.min(box[:, 1]))]
p3 = box[np.where(box[:, 0] == np.max(box[:, 0]))]
p4 = box[np.where(box[:, 1] == np.max(box[:, 1]))]
elif np.all(box[:, 0] == center_x): # 四个点的横坐标都相同
y_sort = np.argsort(box[:, 1])
p1 = box[y_sort[0]]
p2 = box[y_sort[1]]
p3 = box[y_sort[2]]
p4 = box[y_sort[3]]
elif np.any(box[:, 0] == center_x) and np.all(
box[:, 1] != center_y
): # 只有两点横坐标相等,先上下再左右
p12, p34 = (
box[np.where(box[:, 1] < center_y)],
box[np.where(box[:, 1] > center_y)],
)
p1, p2 = (
p12[np.where(p12[:, 0] == np.min(p12[:, 0]))],
p12[np.where(p12[:, 0] == np.max(p12[:, 0]))],
)
p3, p4 = (
p34[np.where(p34[:, 0] == np.max(p34[:, 0]))],
p34[np.where(p34[:, 0] == np.min(p34[:, 0]))],
)
else: # 只有两点纵坐标相等,或者是没有相等的,先左右再上下
p14, p23 = (
box[np.where(box[:, 0] < center_x)],
box[np.where(box[:, 0] > center_x)],
)
p1, p4 = (
p14[np.where(p14[:, 1] == np.min(p14[:, 1]))],
p14[np.where(p14[:, 1] == np.max(p14[:, 1]))],
)
p2, p3 = (
p23[np.where(p23[:, 1] == np.min(p23[:, 1]))],
p23[np.where(p23[:, 1] == np.max(p23[:, 1]))],
)

return np.array([p1, p2, p3, p4]).reshape((-1, 2)).tolist()
16 changes: 12 additions & 4 deletions python/rapidocr_onnxruntime/ch_ppocr_rec/text_recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import cv2
import numpy as np

from rapidocr_onnxruntime.utils import OrtInferSession, read_yaml

from .utils import CTCLabelDecode
Expand All @@ -40,7 +41,9 @@ def __init__(self, config: Dict[str, Any]):
self.rec_image_shape = config["rec_img_shape"]

def __call__(
self, img_list: Union[np.ndarray, List[np.ndarray]], rec_word_box=False
self,
img_list: Union[np.ndarray, List[np.ndarray]],
return_word_box: bool = False,
) -> Tuple[List[Tuple[str, float]], float]:
if isinstance(img_list, np.ndarray):
img_list = [img_list]
Expand All @@ -58,6 +61,7 @@ def __call__(
elapse = 0
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)

# Parameter Alignment for PaddleOCR
imgC, imgH, imgW = self.rec_image_shape[:3]
max_wh_ratio = imgW / imgH
Expand All @@ -67,6 +71,7 @@ def __call__(
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
wh_ratio_list.append(wh_ratio)

norm_img_batch = []
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
Expand All @@ -75,9 +80,12 @@ def __call__(

starttime = time.time()
preds = self.session(norm_img_batch)[0]
rec_result = self.postprocess_op(preds, rec_word_box,
wh_ratio_list=wh_ratio_list,
max_wh_ratio=max_wh_ratio,)
rec_result = self.postprocess_op(
preds,
return_word_box,
wh_ratio_list=wh_ratio_list,
max_wh_ratio=max_wh_ratio,
)

for rno, one_res in enumerate(rec_result):
rec_res[indices[beg_img_no + rno]] = one_res
Expand Down
Loading

0 comments on commit df36c69

Please sign in to comment.