From 0c70a914eaa5da6e477191e5e66fe3da4d4ad555 Mon Sep 17 00:00:00 2001
From: SWHL
Date: Sat, 8 Jul 2023 19:11:07 +0800
Subject: [PATCH] Fix the result when a frame is missing.

---
 README.md                      |   5 +-
 docs/README_en.md              |   5 +-
 docs/doc_whl.md                |   4 +-
 rapid_videocr/main.py          |  78 ++++++-------
 rapid_videocr/rapid_videocr.py | 207 ++++++++++++++++++---------------
 rapid_videocr/utils.py         |  99 +++++++++++++---
 tests/test_rapid_videocr.py    |  83 +++++++------
 7 files changed, 282 insertions(+), 199 deletions(-)

diff --git a/README.md b/README.md
index 068efdc..00d1310 100644
--- a/README.md
+++ b/README.md
@@ -56,6 +56,9 @@ flowchart LR
 - [RapidVideOCR高级教程(有python基础的小伙伴)](https://github.com/SWHL/RapidVideOCR/wiki/RapidVideOCR%E9%AB%98%E7%BA%A7%E6%95%99%E7%A8%8B%EF%BC%88%E6%9C%89python%E5%9F%BA%E7%A1%80%E7%9A%84%E5%B0%8F%E4%BC%99%E4%BC%B4%EF%BC%89)
 
 ### 更新日志([more](https://github.com/SWHL/RapidVideOCR/wiki/%E6%9B%B4%E6%96%B0%E6%97%A5%E5%BF%97))
+- 🤓2023-07-08 v2.2.2 update:
+  - 修复批量识别时,不能读取中文路径的问题
+  - 修复漏轴时,SRT中跳过问题。目前当出现某一轴未能识别,则会空出位置,便于校对。
 - 🐲2023-06-22 Desktop v0.0.3 update:
   - 整合VideoSubFinder界面,增加视频批处理
   - 优化多次选取之后,路径保存问题
@@ -64,8 +67,6 @@ flowchart LR
   - 将VSF的CLI整合到库中,只需指定`VideoSubFinderWXW.exe`的全路径即可。
   - 增加批量识别功能,指定视频目录,即可自动提取目录下所有视频字幕
   - 使用示例, 参见:[demo.py](https://github.com/SWHL/RapidVideOCR/blob/main/demo.py)
-- ♠2023-06-04 Desktop v0.0.2 update:
-  - 修复issue #30: 保留上次选择的目录
 
 ### 写在最后
 - 微信扫描以下二维码,关注**RapidAI公众号**,回复video即可加入RapidVideOCR微信交流群:
diff --git a/docs/README_en.md b/docs/README_en.md
index a1f585e..55390fe 100644
--- a/docs/README_en.md
+++ b/docs/README_en.md
@@ -53,6 +53,9 @@ flowchart LR
 - [☆☆☆ RapidVideOCR Advanced Tutorial (Partners with python foundation)](https://github.com/SWHL/RapidVideOCR/wiki/RapidVideOCR%E9%AB%98%E7%BA%A7%E6%95%99%E7%A8%8B%EF%BC%88%E6%9C%89python%E5%9F%BA%E7%A1%80%E7%9A%84%E5%B0%8F%E4%BC%99%E4%BC%B4%EF%BC%89)
 
 ### Change log ([more](https://github.com/SWHL/RapidVideOCR/wiki/Changelog))
+- 🤓2023-07-08 v2.2.2 update:
+  - Fixed batch recognition failing to read images from paths that contain Chinese characters.
+  - Fixed missing subtitle frames being skipped in the SRT output. When a frame cannot be recognized, its entry is now left blank, which makes proofreading easier.
 - 🐲2023-06-22 Desktop v0.0.3 update:
   - Integrate VideoSubFinder interface, increase video batch processing.
   - Optimize the problem of path preservation after multiple selections.
@@ -61,8 +64,6 @@ flowchart LR
   - To integrate VSF's CLI into the library, just specify the full path of `VideoSubFinderWXW.exe`.
   - Add batch recognition function, specify the video directory, and then automatically extract all video subtitles in the directory
   - Use example, see: [demo.py](https://github.com/SWHL/RapidVideOCR/blob/main/demo.py)
-- ♠ 2023-06-04 Desktop v0.0.2 update:
-  - Fix issue #30: Keep the last selected directory.
### Announce diff --git a/docs/doc_whl.md b/docs/doc_whl.md index b11bc61..2aca81a 100644 --- a/docs/doc_whl.md +++ b/docs/doc_whl.md @@ -26,8 +26,8 @@ pip install rapid_videocr from rapid_videocr import RapidVideOCR extractor = RapidVideOCR(is_concat_rec=True, - concat_batch=10, - out_format='srt') + concat_batch=10, + out_format='srt') rgb_dir = 'RGBImages' save_dir = 'outputs' diff --git a/rapid_videocr/main.py b/rapid_videocr/main.py index 3f6ae66..1aa5bc7 100644 --- a/rapid_videocr/main.py +++ b/rapid_videocr/main.py @@ -13,16 +13,16 @@ class RapidVideoSubFinderOCR: def __init__(self, vsf_exe_path: str = None, **ocr_params) -> None: if vsf_exe_path is None: - raise ValueError('vsf_exe_path must not be None.') + raise ValueError("vsf_exe_path must not be None.") self.vsf = VideoSubFinder(vsf_exe_path) self.video_ocr = RapidVideOCR(**ocr_params) - self.video_formats = ['.mp4', '.avi', '.mov', '.mkv'] + self.video_formats = [".mp4", ".avi", ".mov", ".mkv"] self.logger = get_logger() - def __call__(self, video_path: str, output_dir: str = 'outputs'): + def __call__(self, video_path: str, output_dir: str = "outputs"): if Path(video_path).is_dir(): - video_list = Path(video_path).rglob('*.*') + video_list = Path(video_path).rglob("*.*") video_list = [ v for v in video_list if v.suffix.lower() in self.video_formats ] @@ -30,26 +30,26 @@ def __call__(self, video_path: str, output_dir: str = 'outputs'): video_list = [video_path] self.logger.info( - 'Extracting subtitle images with VideoSubFinder (takes quite a long time) ...' + "Extracting subtitle images with VideoSubFinder (takes quite a long time) ..." ) video_num = len(video_list) for i, one_video in enumerate(video_list): self.logger.info( - f'[{i+1}/{video_num}] Starting to extract {one_video} key frame' + f"[{i+1}/{video_num}] Starting to extract {one_video} key frame" ) with tempfile.TemporaryDirectory() as tmp_dir: try: self.vsf(str(one_video), tmp_dir) except Exception as e: - self.logger.error(f'Extract {one_video} error, {e}, skip') + self.logger.error(f"Extract {one_video} error, {e}, skip") continue - self.logger.info(f'[{i+1}/{video_num}] Starting to run {one_video} ocr') + self.logger.info(f"[{i+1}/{video_num}] Starting to run {one_video} ocr") - rgb_dir = Path(tmp_dir) / 'RGBImages' + rgb_dir = Path(tmp_dir) / "RGBImages" if not list(rgb_dir.iterdir()): self.logger.warning( - f'Extracting frames from {one_video} is 0, skip' + f"Extracting frames from {one_video} is 0, skip" ) continue @@ -61,71 +61,71 @@ def __call__(self, video_path: str, output_dir: str = 'outputs'): def main() -> None: parser = argparse.ArgumentParser() parser.add_argument( - '-vsf', - '--vsf_exe_path', + "-vsf", + "--vsf_exe_path", type=str, default=None, - help='The full path of VideoSubFinderWXW.exe.', + help="The full path of VideoSubFinderWXW.exe.", ) parser.add_argument( - '-video_dir', - '--video_dir', + "-video_dir", + "--video_dir", type=str, default=None, - help='The full path of video or the path of video directory.', + help="The full path of video or the path of video directory.", ) parser.add_argument( - '-i', - '--img_dir', + "-i", + "--img_dir", type=str, default=None, - help='The full path of RGBImages or TXTImages.', + help="The full path of RGBImages or TXTImages.", ) parser.add_argument( - '-s', - '--save_dir', + "-s", + "--save_dir", type=str, - default='outputs', + default="outputs", help='The path of saving the recognition result. 
Default is "outputs" under the current directory.', ) parser.add_argument( - '-o', - '--out_format', + "-o", + "--out_format", type=str, - default='all', - choices=['srt', 'txt', 'all'], + default="all", + choices=["srt", "txt", "all"], help='Output file format. Default is "all".', ) parser.add_argument( - '-m', - '--mode', + "-m", + "--mode", type=str, - default='single', - choices=['single', 'concat'], + default="single", + choices=["single", "concat"], help='Which mode to run (concat recognition or single recognition). Default is "single".', ) parser.add_argument( - '-b', - '--concat_batch', + "-b", + "--concat_batch", type=int, default=10, - help='The batch of concating image nums in concat recognition mode. Default is 10.', + help="The batch of concating image nums in concat recognition mode. Default is 10.", ) parser.add_argument( - '-p', - '--print_console', + "-p", + "--print_console", type=bool, default=0, choices=[0, 1], - help='Whether to print the subtitle results to console. 1 means to print results to console. Default is 0.', + help="Whether to print the subtitle results to console. 1 means to print results to console. Default is 0.", ) args = parser.parse_args() - is_concat_rec = 'concat' in args.mode + is_concat_rec = "concat" in args.mode if not (args.vsf_exe_path is None and args.video_dir is None): raise ValueError( - '--vsf_exe_path or --video_dir must not be None at the same time.' + "--vsf_exe_path or --video_dir must not be None at the same time." ) if args.vsf_exe_path and args.video_dir: @@ -148,5 +148,5 @@ def main() -> None: extractor(args.img_dir, args.save_dir) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/rapid_videocr/rapid_videocr.py b/rapid_videocr/rapid_videocr.py index 944473d..49b8fc5 100644 --- a/rapid_videocr/rapid_videocr.py +++ b/rapid_videocr/rapid_videocr.py @@ -3,14 +3,20 @@ # @Contact: liekkaskono@163.com import argparse from pathlib import Path -from typing import List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import cv2 import numpy as np from rapidocr_onnxruntime import RapidOCR from tqdm import tqdm -from .utils import CropByProject, get_logger, mkdir +from .utils import ( + CropByProject, + compute_poly_iou, + get_logger, + is_inclusive_each_other, + mkdir, +) CUR_DIR = Path(__file__).resolve().parent logger = get_logger() @@ -21,7 +27,7 @@ def __init__( self, is_concat_rec: bool = False, concat_batch: int = 10, - out_format: str = 'all', + out_format: str = "all", is_print_console: bool = False, ) -> None: """Init @@ -44,7 +50,7 @@ def __call__( self, video_sub_finder_dir: Union[str, Path], save_dir: Union[str, Path], - save_name: str = 'result', + save_name: str = "result", ) -> None: """call @@ -57,26 +63,26 @@ def __call__( """ video_sub_finder_dir = Path(video_sub_finder_dir) if not video_sub_finder_dir.exists(): - raise RapidVideOCRError(f'{video_sub_finder_dir} does not exist.') + raise RapidVideOCRError(f"{video_sub_finder_dir} does not exist.") dir_name = Path(video_sub_finder_dir).name - is_txt_dir = 'TXTImages' in dir_name + is_txt_dir = "TXTImages" in dir_name save_dir = Path(save_dir) mkdir(save_dir) - img_list = list(Path(video_sub_finder_dir).glob('*.jpeg')) + img_list = list(Path(video_sub_finder_dir).glob("*.jpeg")) img_list = sorted(img_list, key=self.get_sort_key) if not img_list: raise RapidVideOCRError( - f'{video_sub_finder_dir} has not images with jpeg as suffix.' + f"{video_sub_finder_dir} has not images with jpeg as suffix." 
) if self.is_concat_rec: - logger.info('[OCR] Running with concat recognition.') + logger.info("[OCR] Running with concat recognition.") srt_result, txt_result = self.concat_rec(img_list, is_txt_dir) else: - logger.info('[OCR] Running with single recognition.') + logger.info("[OCR] Running with single recognition.") srt_result, txt_result = self.single_rec(img_list) self.export_file(save_dir, save_name, srt_result, txt_result) @@ -85,12 +91,12 @@ def __call__( self.print_console(txt_result) @staticmethod - def get_sort_key(x): - return int(''.join(str(x.stem).split('_')[:4])) + def get_sort_key(x: Path) -> int: + return int("".join(str(x.stem).split("_")[:4])) - def single_rec(self, img_list: List[str]) -> Tuple[List, List]: + def single_rec(self, img_list: List[Path]) -> Tuple[List[str], List[str]]: srt_result, txt_result = [], [] - for i, img_path in enumerate(tqdm(img_list, desc='OCR')): + for i, img_path in enumerate(tqdm(img_list, desc="OCR")): time_str = self.get_time(img_path) img = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), 1) @@ -98,24 +104,25 @@ def single_rec(self, img_list: List[str]) -> Tuple[List, List]: dt_boxes, rec_res = self.run_ocr(img) if rec_res: txts = self.process_same_line(dt_boxes, rec_res) - srt_result.append(f'{i+1}\n{time_str}\n{txts}\n') - txt_result.append(f'{txts}\n') + else: + txts = "" + + srt_result.append(f"{i+1}\n{time_str}\n{txts}\n") + txt_result.append(f"{txts}\n") return srt_result, txt_result - def concat_rec( - self, img_list: List[np.ndarray], is_txt_dir: bool - ) -> Tuple[List, List]: + def concat_rec(self, img_list: List[Path], is_txt_dir: bool) -> Tuple[List, List]: srt_result, txt_result = [], [] img_nums = len(img_list) - for start_i in tqdm(range(0, img_nums, self.batch_size), desc='OCR'): + for start_i in tqdm(range(0, img_nums, self.batch_size), desc="OCR"): end_i = min(img_nums, start_i + self.batch_size) concat_img, img_coordinates, img_paths = self.get_batch( img_list, start_i, end_i, is_txt_dir ) dt_boxes, rec_res = self.run_ocr(concat_img) - if not rec_res: + if rec_res is None or dt_boxes is None: continue srt_part, txt_part = self.get_match_results( @@ -126,25 +133,29 @@ def concat_rec( return srt_result, txt_result def get_batch( - self, img_list: List[str], start: int, end: int, is_txt_dir: bool - ) -> Tuple[np.ndarray, np.ndarray, List]: + self, img_list: List[Path], start: int, end: int, is_txt_dir: bool + ) -> Tuple[np.ndarray, np.ndarray, List[Path]]: select_imgs = img_list[start:end] - img_list, img_coordinates, batch_img_paths = [], [], [] + padding_value = 10 + array_img_list, img_coordinates, batch_img_paths = [], [], [] for i, img_path in enumerate(select_imgs): batch_img_paths.append(img_path) - img = cv2.imread(str(img_path)) + img = cv2.imdecode(np.fromfile(str(img_path), dtype=np.uint8), 1) if is_txt_dir: img = cv2.resize(img, None, fx=0.25, fy=0.25) - img_list.append(img) + pad_img = self.padding_img(img, (0, padding_value, 0, 0)) + array_img_list.append(pad_img) h, w = img.shape[:2] - img_coordinates.append([(0, i * h), (w, (i + 1) * h)]) + x0, y0 = 0, i * (h + padding_value) + x1, y1 = w, (i + 1) * (h + padding_value) + img_coordinates.append([(x0, y0), (x1, y0), (x1, y1), (x0, y1)]) - concat_img = np.vstack(img_list) + concat_img = np.vstack(array_img_list) return concat_img, np.array(img_coordinates), batch_img_paths def get_match_results( @@ -152,36 +163,40 @@ def get_match_results( start_i: int, img_coordinates: np.ndarray, dt_boxes: np.ndarray, - rec_res: List, - img_paths: list, - ) -> Tuple[List, 
List]: + rec_res: Tuple[str], + img_paths: List[Path], + ) -> Tuple[List[str], List[str]]: srt_result_part, txt_result_part = [], [] - match_dict = {} - y_points = img_coordinates[:, :, 1] - left_top_boxes = dt_boxes[:, 0, :] - for i, one_left in enumerate(left_top_boxes): - y = one_left[1] - condition = (y >= y_points[:, 0]) & (y < y_points[:, 1]) - index = np.argwhere(condition) - if not index.size: - match_dict[i] = '' - continue - - matched_index = index.squeeze().tolist() - matched_path = img_paths[matched_index] - match_dict.setdefault(matched_index, []).append( - [matched_path, dt_boxes[i], rec_res[i]] - ) + match_dict: Dict[int, List[Union[Path, np.ndarray, str]]] = { + k: [] for k in range(len(img_coordinates)) + } + visited_idx = [] + for i, frame_boxes in enumerate(img_coordinates): + for idx, dt_box in enumerate(dt_boxes): + if idx in visited_idx: + continue + + box_iou = compute_poly_iou(frame_boxes, dt_box) + if is_inclusive_each_other(frame_boxes, dt_box) or box_iou > 0.1: + matched_path = img_paths[idx] + match_dict.setdefault(i, []).append( + [matched_path, dt_box, rec_res[idx]] + ) + visited_idx.append(idx) for k, v in match_dict.items(): cur_frame_idx = start_i + k - img_path, boxes, recs = list(zip(*v)) + if v: + img_path, boxes, recs = list(zip(*v)) + time_str = self.get_time(img_path[0]) + txts = self.process_same_line(boxes, recs) + else: + time_str = self.get_time(img_paths[k]) + txts = "" - time_str = self.get_time(img_path[0]) - txts = self.process_same_line(boxes, recs) - srt_result_part.append(f'{cur_frame_idx+1}\n{time_str}\n{txts}\n') - txt_result_part.append(f'{txts}\n') + srt_result_part.append(f"{cur_frame_idx+1}\n{time_str}\n{txts}\n") + txt_result_part.append(f"{txts}\n") return srt_result_part, txt_result_part @staticmethod @@ -194,17 +209,19 @@ def get_time(file_path: Path) -> str: Returns: str: 字幕开始和截止时间戳字符串 """ - split_paths = file_path.stem.split('_') + split_paths = file_path.stem.split("_") start_time = split_paths[:4] - start_time[0] = f'{start_time[0]:0>2}' - start_str = ':'.join(start_time[:3]) + f',{start_time[3]}' + start_time[0] = f"{start_time[0]:0>2}" + start_str = ":".join(start_time[:3]) + f",{start_time[3]}" end_time = split_paths[5:9] - end_time[0] = f'{end_time[0]:0>2}' - end_str = ':'.join(end_time[:3]) + f',{end_time[3]}' - return f'{start_str} --> {end_str}' + end_time[0] = f"{end_time[0]:0>2}" + end_str = ":".join(end_time[:3]) + f",{end_time[3]}" + return f"{start_str} --> {end_str}" - def run_ocr(self, img: np.ndarray) -> Tuple[np.ndarray, List]: + def run_ocr( + self, img: np.ndarray + ) -> Tuple[Optional[np.ndarray], Optional[Tuple[str]]]: ocr_result, _ = self.rapid_ocr(img) if ocr_result is None: return None, None @@ -216,7 +233,7 @@ def run_ocr(self, img: np.ndarray) -> Tuple[np.ndarray, List]: def padding_img( img: np.ndarray, padding_value: Tuple[int, int, int, int], - padding_color: Tuple = (0, 0, 0), + padding_color: Tuple[int, int, int] = (0, 0, 0), ) -> np.ndarray: padded_img = cv2.copyMakeBorder( img, @@ -247,12 +264,12 @@ def process_same_line(self, dt_boxes, rec_res): for v in pair_point: used[v] = True concat_str.append(rec_res[v]) - final_res.append(' '.join(concat_str)) + final_res.append(" ".join(concat_str)) else: for v in pair_point: if not used[v]: final_res.append(rec_res[v]) - return '\n'.join(final_res) + return "\n".join(final_res) def export_file( self, @@ -264,36 +281,36 @@ def export_file( if isinstance(save_dir, str): save_dir = Path(save_dir) - srt_path = save_dir / f'{save_name}.srt' - txt_path = 
save_dir / f'{save_name}.txt' + srt_path = save_dir / f"{save_name}.srt" + txt_path = save_dir / f"{save_name}.txt" - if self.out_format == 'txt': + if self.out_format == "txt": self.save_file(txt_path, txt_result) - elif self.out_format == 'srt': + elif self.out_format == "srt": self.save_file(srt_path, srt_result) - elif self.out_format == 'all': + elif self.out_format == "all": self.save_file(txt_path, txt_result) self.save_file(srt_path, srt_result) else: - raise ValueError(f'The {self.out_format} dost not support.') - logger.info(f'[OCR] The result has been saved to {save_dir} directory.') + raise ValueError(f"The {self.out_format} dost not support.") + logger.info(f"[OCR] The result has been saved to {save_dir} directory.") def print_console(self, txt_result: List) -> None: for v in txt_result: print(v.strip()) @staticmethod - def save_file(save_path: Union[str, Path], content: list, mode: str = 'w') -> None: + def save_file(save_path: Union[str, Path], content: List, mode: str = "w") -> None: if not isinstance(save_path, str): save_path = str(save_path) if not isinstance(content, list): content = [content] - with open(save_path, mode, encoding='utf-8') as f: + with open(save_path, mode, encoding="utf-8") as f: for value in content: - f.write(f'{value}\n') - logger.info(f'[OCR] The file has been saved in the {save_path}') + f.write(f"{value}\n") + logger.info(f"[OCR] The file has been saved in the {save_path}") @staticmethod def _compute_centroid(points: np.ndarray) -> List: @@ -326,53 +343,53 @@ class RapidVideOCRError(Exception): def main() -> None: parser = argparse.ArgumentParser() parser.add_argument( - '-i', - '--img_dir', + "-i", + "--img_dir", type=str, required=True, - help='The full path of RGBImages or TXTImages.', + help="The full path of RGBImages or TXTImages.", ) parser.add_argument( - '-s', - '--save_dir', + "-s", + "--save_dir", type=str, - default='outputs', + default="outputs", help='The path of saving the recognition result. Default is "outputs" under the current directory.', ) parser.add_argument( - '-o', - '--out_format', + "-o", + "--out_format", type=str, - default='all', - choices=['srt', 'txt', 'all'], + default="all", + choices=["srt", "txt", "all"], help='Output file format. Default is "all".', ) parser.add_argument( - '-m', - '--mode', + "-m", + "--mode", type=str, - default='single', - choices=['single', 'concat'], + default="single", + choices=["single", "concat"], help='Which mode to run (concat recognition or single recognition). Default is "single".', ) parser.add_argument( - '-b', - '--concat_batch', + "-b", + "--concat_batch", type=int, default=10, - help='The batch of concating image nums in concat recognition mode. Default is 10.', + help="The batch of concating image nums in concat recognition mode. Default is 10.", ) parser.add_argument( - '-p', - '--print_console', + "-p", + "--print_console", type=bool, default=0, choices=[0, 1], - help='Whether to print the subtitle results to console. 1 means to print results to console. Default is 0.', + help="Whether to print the subtitle results to console. 1 means to print results to console. 
Default is 0.", ) args = parser.parse_args() - is_concat_rec = 'concat' in args.mode + is_concat_rec = "concat" in args.mode extractor = RapidVideOCR( is_concat_rec=is_concat_rec, concat_batch=args.concat_batch, @@ -382,5 +399,5 @@ def main() -> None: extractor(args.img_dir, args.save_dir) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/rapid_videocr/utils.py b/rapid_videocr/utils.py index 871fcf6..d89f14a 100644 --- a/rapid_videocr/utils.py +++ b/rapid_videocr/utils.py @@ -5,11 +5,13 @@ import logging import sys from pathlib import Path -from typing import List +from typing import List, Union import colorlog import cv2 import numpy as np +import shapely +from shapely.geometry import MultiPoint, Polygon logger_initialized = {} @@ -30,10 +32,10 @@ def __call__(self, origin_img): closed = cv2.dilate(img, None, iterations=1) # 水平投影 - x0, x1 = self.get_project_loc(closed, direction='width') + x0, x1 = self.get_project_loc(closed, direction="width") # 竖直投影 - y0, y1 = self.get_project_loc(closed, direction='height') + y0, y1 = self.get_project_loc(closed, direction="height") return origin_img[y0:y1, x0:x1] @@ -48,12 +50,12 @@ def get_project_loc(img, direction): Returns: tuple: 起始索引位置 """ - if direction == 'width': + if direction == "width": axis = 0 - elif direction == 'height': + elif direction == "height": axis = 1 else: - raise ValueError(f'direction {direction} is not supported!') + raise ValueError(f"direction {direction} is not supported!") loc_sum = np.sum(img == 255, axis=axis) loc_range = np.argwhere(loc_sum > 0) @@ -65,17 +67,17 @@ def mkdir(dir_path): Path(dir_path).mkdir(parents=True, exist_ok=True) -def read_txt(txt_path: str) -> List: +def read_txt(txt_path: Union[str, Path]) -> List[str]: if not isinstance(txt_path, str): txt_path = str(txt_path) - with open(txt_path, 'r', encoding='utf-8') as f: - data = list(map(lambda x: x.rstrip('\n'), f)) + with open(txt_path, "r", encoding="utf-8") as f: + data = list(map(lambda x: x.rstrip("\n"), f)) return data @functools.lru_cache() -def get_logger(name='rapid_videocr'): +def get_logger(name="rapid_videocr"): logger = logging.getLogger(name) if name in logger_initialized: return logger @@ -84,13 +86,13 @@ def get_logger(name='rapid_videocr'): if name.startswith(logger_name): return logger - fmt_string = '%(log_color)s[%(asctime)s] [%(name)s] %(levelname)s: %(message)s' + fmt_string = "%(log_color)s[%(asctime)s] [%(name)s] %(levelname)s: %(message)s" log_colors = { - 'DEBUG': 'white', - 'INFO': 'white', - 'WARNING': 'yellow', - 'ERROR': 'red', - 'CRITICAL': 'purple', + "DEBUG": "white", + "INFO": "white", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "purple", } fmt = colorlog.ColoredFormatter(fmt_string, log_colors=log_colors) stream_handler = logging.StreamHandler(stream=sys.stdout) @@ -101,3 +103,68 @@ def get_logger(name='rapid_videocr'): logger_initialized[name] = True logger.propagate = False return logger + + +def compute_poly_iou(a: np.ndarray, b: np.ndarray) -> float: + """计算两个多边形的IOU + + Args: + poly1 (np.ndarray): (4, 2) + poly2 (np.ndarray): (4, 2) + + Returns: + float: iou + """ + poly1 = Polygon(a).convex_hull + poly2 = Polygon(b).convex_hull + + union_poly = np.concatenate((a, b)) + + if not poly1.intersects(poly2): + return 0.0 + + try: + inter_area = poly1.intersection(poly2).area + union_area = MultiPoint(union_poly).convex_hull.area + except shapely.geos.TopologicalError: + print("shapely.geos.TopologicalError occured, iou set to 0") + return 0.0 + + if union_area == 0: + return 0.0 
+ + return float(inter_area) / union_area + + +def is_inclusive_each_other(box1: np.ndarray, box2: np.ndarray) -> bool: + """判断两个多边形框是否存在包含关系 + + Args: + box1 (np.ndarray): (4, 2) + box2 (np.ndarray): (4, 2) + + Returns: + bool: 是否存在包含关系 + """ + poly1 = Polygon(box1) + poly2 = Polygon(box2) + + poly1_area = poly1.convex_hull.area + poly2_area = poly2.convex_hull.area + + if poly1_area > poly2_area: + box_max = box1 + box_min = box2 + else: + box_max = box2 + box_min = box1 + + x0, y0 = np.min(box_min[:, 0]), np.min(box_min[:, 1]) + x1, y1 = np.max(box_min[:, 0]), np.max(box_min[:, 1]) + + edge_x0, edge_y0 = np.min(box_max[:, 0]), np.min(box_max[:, 1]) + edge_x1, edge_y1 = np.max(box_max[:, 0]), np.max(box_max[:, 1]) + + if x0 >= edge_x0 and y0 >= edge_y0 and x1 <= edge_x1 and y1 <= edge_y1: + return True + return False diff --git a/tests/test_rapid_videocr.py b/tests/test_rapid_videocr.py index d5471d2..1ab1436 100644 --- a/tests/test_rapid_videocr.py +++ b/tests/test_rapid_videocr.py @@ -1,13 +1,11 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com -from pathlib import Path - import shutil - import sys -import pytest +from pathlib import Path +import pytest cur_dir = Path(__file__).resolve().parent root_dir = cur_dir.parent @@ -15,20 +13,19 @@ sys.path.append(str(root_dir)) from rapid_videocr import RapidVideOCR, RapidVideOCRError -from rapid_videocr.utils import read_txt, mkdir - +from rapid_videocr.utils import mkdir, read_txt -test_file_dir = cur_dir / 'test_files' -srt_path = test_file_dir / 'result.srt' -txt_path = test_file_dir / 'result.txt' +test_file_dir = cur_dir / "test_files" +srt_path = test_file_dir / "result.srt" +txt_path = test_file_dir / "result.txt" @pytest.mark.parametrize( - 'img_dir', + "img_dir", [ - test_file_dir / 'RGBImages', - test_file_dir / 'TXTImages', - ] + test_file_dir / "RGBImages", + test_file_dir / "TXTImages", + ], ) def test_single_rec(img_dir): extractor = RapidVideOCR(is_concat_rec=False) @@ -37,22 +34,22 @@ def test_single_rec(img_dir): srt_data = read_txt(srt_path) txt_data = read_txt(txt_path) - assert len(srt_data) == 12 - assert srt_data[2] == '空间里面他绝对赢不了的' - assert srt_data[-2] == '你们接着善后' + assert len(srt_data) == 16 + assert srt_data[2] == "空间里面他绝对赢不了的" + assert srt_data[-2] == "你们接着善后" - assert len(txt_data) == 6 - assert txt_data[-2] == '你们接着善后' + assert len(txt_data) == 8 + assert txt_data[-2] == "你们接着善后" srt_path.unlink() txt_path.unlink() @pytest.mark.parametrize( - 'img_dir', + "img_dir", [ - test_file_dir / 'RGBImages', - ] + test_file_dir / "RGBImages", + ], ) def test_concat_rec(img_dir): extractor = RapidVideOCR(is_concat_rec=True) @@ -61,23 +58,23 @@ def test_concat_rec(img_dir): srt_data = read_txt(srt_path) txt_data = read_txt(txt_path) - assert len(srt_data) == 12 - assert srt_data[2] == '空间里面他绝对赢不了的' - assert srt_data[-2] == '你们接着善后' + assert len(srt_data) == 16 + assert srt_data[2] == "空间里面他绝对赢不了的" + assert srt_data[-2] == "你们接着善后" - assert len(txt_data) == 6 - assert txt_data[-2] == '你们接着善后' + assert len(txt_data) == 8 + assert txt_data[-2] == "你们接着善后" srt_path.unlink() txt_path.unlink() @pytest.mark.parametrize( - 'img_dir', + "img_dir", [ - test_file_dir / 'RGBImage', - test_file_dir / 'TXTImage', - ] + test_file_dir / "RGBImage", + test_file_dir / "TXTImage", + ], ) def test_empty_dir(img_dir): extractor = RapidVideOCR(is_concat_rec=False) @@ -91,11 +88,11 @@ def test_empty_dir(img_dir): @pytest.mark.parametrize( - 'img_dir', + "img_dir", [ - test_file_dir / 'RGBImage', - test_file_dir / 
'TXTImage', - ] + test_file_dir / "RGBImage", + test_file_dir / "TXTImage", + ], ) def test_nothing_dir(img_dir): extractor = RapidVideOCR(is_concat_rec=False) @@ -108,27 +105,27 @@ def test_nothing_dir(img_dir): def test_out_only_srt(): - img_dir = test_file_dir / 'RGBImages' - extractor = RapidVideOCR(is_concat_rec=True, out_format='srt') + img_dir = test_file_dir / "RGBImages" + extractor = RapidVideOCR(is_concat_rec=True, out_format="srt") extractor(img_dir, test_file_dir) srt_data = read_txt(srt_path) - assert len(srt_data) == 12 - assert srt_data[2] == '空间里面他绝对赢不了的' - assert srt_data[-2] == '你们接着善后' + assert len(srt_data) == 16 + assert srt_data[2] == "空间里面他绝对赢不了的" + assert srt_data[-2] == "你们接着善后" srt_path.unlink() def test_out_only_txt(): - img_dir = test_file_dir / 'RGBImages' - extractor = RapidVideOCR(is_concat_rec=True, out_format='txt') + img_dir = test_file_dir / "RGBImages" + extractor = RapidVideOCR(is_concat_rec=True, out_format="txt") extractor(img_dir, test_file_dir) txt_data = read_txt(txt_path) - assert len(txt_data) == 6 - assert txt_data[-2] == '你们接着善后' + assert len(txt_data) == 8 + assert txt_data[-2] == "你们接着善后" txt_path.unlink()
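
The patch above is the authoritative change. As a small companion, the sketch below illustrates in plain Python the two ideas the v2.2.2 changelog entries describe: decoding frame images via `np.fromfile` + `cv2.imdecode` so paths containing Chinese characters can be read, and mapping OCR detection boxes back to stacked frame regions so a frame with no recognized text still gets an (empty) subtitle slot instead of being skipped. Function names, the frame layout, the example data and the 0.1 threshold are illustrative; the overlap test only mirrors the idea behind `compute_poly_iou` / `is_inclusive_each_other`, not their exact implementation.

```python
# Minimal sketch of the two fixes; assumptions are flagged inline.
from typing import Dict, List

import cv2
import numpy as np
from shapely.geometry import Polygon


def read_img_any_path(img_path: str) -> np.ndarray:
    # cv2.imread() can fail on Windows when the path contains non-ASCII
    # (e.g. Chinese) characters; decoding raw bytes with cv2.imdecode avoids that.
    return cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), cv2.IMREAD_COLOR)


def match_boxes_to_frames(
    frame_regions: List[np.ndarray],  # one (4, 2) rectangle per stacked frame
    dt_boxes: List[np.ndarray],       # OCR detection boxes, each (4, 2)
    rec_res: List[str],               # recognized text, parallel to dt_boxes
    iou_thresh: float = 0.1,          # assumed threshold, same value as the patch
) -> Dict[int, List[str]]:
    """Assign every detection box to the frame region it overlaps.

    Every frame index appears in the result, so a frame whose text could not
    be recognized still yields an empty subtitle slot instead of being skipped.
    """
    matches: Dict[int, List[str]] = {i: [] for i in range(len(frame_regions))}
    used = set()
    for i, region in enumerate(frame_regions):
        frame_poly = Polygon(region).convex_hull
        for j, box in enumerate(dt_boxes):
            if j in used:
                continue
            box_poly = Polygon(box).convex_hull
            inter = frame_poly.intersection(box_poly).area
            union = frame_poly.union(box_poly).area
            iou = inter / union if union else 0.0
            # Match if the box lies inside the frame region or overlaps it enough.
            if frame_poly.contains(box_poly) or iou > iou_thresh:
                matches[i].append(rec_res[j])
                used.add(j)
    return matches


if __name__ == "__main__":
    # Two stacked 400x100 frame regions; the single detection box falls inside
    # frame 0 only, so frame 1 keeps an empty list and would become a blank SRT entry.
    frames = [
        np.array([[0, 0], [400, 0], [400, 100], [0, 100]]),
        np.array([[0, 100], [400, 100], [400, 200], [0, 200]]),
    ]
    boxes = [np.array([[10, 20], [200, 20], [200, 60], [10, 60]])]
    print(match_boxes_to_frames(frames, boxes, ["你们接着善后"]))
```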