Skip to content

Commit

Permalink
fix(rapidocr_paddle): merge PR #293 #294
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Dec 12, 2024
1 parent 7f5594b commit 777f9f0
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 17 deletions.
17 changes: 11 additions & 6 deletions python/rapidocr_paddle/cal_rec_boxes/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@ def __call__(
rec_txt, rec_conf, rec_word_info = rec_res[0], rec_res[1], rec_res[2]
h, w = img.shape[:2]
img_box = np.array([[0, 0], [w, 0], [w, h], [0, h]])
word_box_content_list, word_box_list = self.cal_ocr_word_box(
word_box_content_list, word_box_list, conf_list = self.cal_ocr_word_box(
rec_txt, img_box, rec_word_info
)
word_box_list = self.adjust_box_overlap(copy.deepcopy(word_box_list))
word_box_list = self.reverse_rotate_crop_image(
copy.deepcopy(box), word_box_list, direction
)
res.append([rec_txt, rec_conf, word_box_list, word_box_content_list])
res.append(
[rec_txt, rec_conf, word_box_list, word_box_content_list, conf_list]
)
return res

@staticmethod
Expand All @@ -60,13 +62,13 @@ def get_box_direction(box: np.ndarray) -> str:
@staticmethod
def cal_ocr_word_box(
rec_txt: str, box: np.ndarray, rec_word_info: List[Tuple[str, List[int]]]
) -> Tuple[List[str], List[List[int]]]:
) -> Tuple[List[str], List[List[int]], List[float]]:
"""Calculate the detection frame for each word based on the results of recognition and detection of ocr
汉字坐标是单字的
英语坐标是单词级别的
"""

col_num, word_list, word_col_list, state_list = rec_word_info
col_num, word_list, word_col_list, state_list, conf_list = rec_word_info
box = box.tolist()
bbox_x_start = box[0][0]
bbox_x_end = box[1][0]
Expand All @@ -84,7 +86,7 @@ def cal_ocr_word_box(
def cal_char_width(width_list, word_col_):
if len(word_col_) == 1:
return
char_total_length = (word_col_[-1] - word_col_[0] + 1) * cell_width
char_total_length = (word_col_[-1] - word_col_[0]) * cell_width
char_width = char_total_length / (len(word_col_) - 1)
width_list.append(char_width)

Expand Down Expand Up @@ -124,7 +126,7 @@ def cal_box(col_list, width_list, word_box_list_):
cal_box(cn_col_list, cn_width_list, word_box_list)
cal_box(en_col_list, en_width_list, word_box_list)
sorted_word_box_list = sorted(word_box_list, key=lambda box: box[0][0])
return word_box_content_list, sorted_word_box_list
return word_box_content_list, sorted_word_box_list, conf_list

@staticmethod
def adjust_box_overlap(
Expand Down Expand Up @@ -218,13 +220,15 @@ def s_rotate(angle, valuex, valuey, pointx, pointy):
@staticmethod
def order_points(box: List[List[int]]) -> List[List[int]]:
"""矩形框顺序排列"""

def convert_to_1x2(p):
if p.shape == (2,):
return p.reshape((1, 2))
elif p.shape == (1, 2):
return p
else:
return p[:1, :]

box = np.array(box).reshape((-1, 2))
center_x, center_y = np.mean(box[:, 0]), np.mean(box[:, 1])
if np.any(box[:, 0] == center_x) and np.any(
Expand Down Expand Up @@ -268,6 +272,7 @@ def convert_to_1x2(p):
p23[np.where(p23[:, 1] == np.min(p23[:, 1]))],
p23[np.where(p23[:, 1] == np.max(p23[:, 1]))],
)

# 解决单字切割后横坐标完全相同的shape错误
p1 = convert_to_1x2(p1)
p2 = convert_to_1x2(p2)
Expand Down
18 changes: 15 additions & 3 deletions python/rapidocr_paddle/ch_ppocr_det/text_detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@

class TextDetector:
def __init__(self, config: Dict[str, Any]):
limit_side_len = config.get("limit_side_len", 736)
limit_type = config.get("limit_type", "min")
self.preprocess_op = DetPreProcess(limit_side_len, limit_type)
self.limit_type = config.get("limit_type", "min")
self.limit_side_len = config.get("limit_side_len", 736)
self.preprocess_op = None

post_process = {
"thresh": config.get("thresh", 0.3),
Expand All @@ -49,6 +49,7 @@ def __call__(self, img: np.ndarray) -> Tuple[Optional[np.ndarray], float]:
raise ValueError("img is None")

ori_img_shape = img.shape[0], img.shape[1]
self.preprocess_op = self.get_preprocess(max(img.shape[0], img.shape[1]))
prepro_img = self.preprocess_op(img)
if prepro_img is None:
return None, 0
Expand All @@ -59,6 +60,17 @@ def __call__(self, img: np.ndarray) -> Tuple[Optional[np.ndarray], float]:
elapse = time.perf_counter() - start_time
return dt_boxes, elapse

def get_preprocess(self, max_wh):
if self.limit_type == "min":
limit_side_len = self.limit_side_len
elif max_wh < 960:
limit_side_len = 960
elif max_wh < 1500:
limit_side_len = 1500
else:
limit_side_len = 2000
return DetPreProcess(limit_side_len, self.limit_type)

def filter_tag_det_res(
self, dt_boxes: np.ndarray, image_shape: Tuple[int, int]
) -> np.ndarray:
Expand Down
15 changes: 11 additions & 4 deletions python/rapidocr_paddle/ch_ppocr_rec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def decode(
selection &= text_index[batch_idx] != ignored_token

if text_prob is not None:
conf_list = text_prob[batch_idx][selection]
conf_list = np.array(text_prob[batch_idx][selection]).tolist()
else:
conf_list = [1] * len(selection)

Expand All @@ -116,6 +116,7 @@ def decode(
word_list,
word_col_list,
state_list,
conf_list,
],
)
)
Expand Down Expand Up @@ -147,18 +148,24 @@ def get_word_info(
word_list = []
word_col_list = []
state_list = []
valid_col = np.where(selection == True)[0]
valid_col = np.where(selection)[0]
col_width = np.zeros(valid_col.shape)
if len(valid_col) > 0:
col_width[1:] = valid_col[1:] - valid_col[:-1]
col_width[0] = min(
3 if "\u4e00" <= text[0] <= "\u9fff" else 2, int(valid_col[0])
)

for c_i, char in enumerate(text):
if "\u4e00" <= char <= "\u9fff":
c_state = "cn"
else:
c_state = "en&num"

if state == None:
if state is None:
state = c_state

if state != c_state:
if state != c_state or col_width[c_i] > 4:
if len(word_content) != 0:
word_list.append(word_content)
word_col_list.append(word_col_content)
Expand Down
5 changes: 1 addition & 4 deletions python/rapidocr_paddle/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,10 +330,7 @@ def main():
use_cls = not args.no_cls
use_rec = not args.no_rec
result, elapse_list = ocr_engine(
args.img_path,
use_det=use_det,
use_cls=use_cls,
use_rec=use_rec,
args.img_path, use_det=use_det, use_cls=use_cls, use_rec=use_rec, **vars(args)
)
logger.info(result)

Expand Down

0 comments on commit 777f9f0

Please sign in to comment.