def scale_resize(img, resize_value=(280, 32)):
    """Pad ``img`` with black borders to the target aspect ratio, then resize it.

    Only ``img.shape[:2]`` is inspected, so the padding step works for both
    2-D (grayscale) and 3-D (colour) arrays. The previous version also sampled
    ``img[0, 0, :]`` into an unused variable, which raised ``IndexError`` for
    2-D input; that dead statement has been removed.

    Args:
        img: Image as an ndarray whose first two axes are (height, width).
        resize_value: Target size as ``(width, height)``.

    Returns:
        The padded image resized to ``resize_value`` using Lanczos
        interpolation.
    """
    ratio = resize_value[0] / resize_value[1]  # target w / h
    h, w = img.shape[:2]

    top = bottom = left = right = 0
    if w / h < ratio:
        # Image is too narrow for the target ratio: pad the width equally
        # on the left and right.
        left = right = (int(h * ratio) - w) // 2
    else:
        # Image is too flat for the target ratio: pad the height equally
        # on the top and bottom.
        top = bottom = (int(w / ratio) - h) // 2

    # copyMakeBorder takes the border sizes in (top, bottom, left, right) order.
    img = cv2.copyMakeBorder(
        img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0)
    )
    return cv2.resize(img, resize_value, interpolation=cv2.INTER_LANCZOS4)
class RapidOrientation:
    """Classify the orientation of an image with an ONNX model.

    Construction reads the YAML config, builds the inference session, pulls
    the label list from the session, and prepares the image loader and the
    preprocessing pipeline. Calling the instance returns the predicted label
    together with the time spent on preprocessing and inference.
    """

    def __init__(
        self,
        model_path: Union[str, Path, None] = DEFAULT_PATH,
        cfg_path: Union[str, Path, None] = DEFAULT_CFG,
    ):
        """Build the session and helpers.

        Args:
            model_path: Path to the ONNX model. ``None`` selects
                ``DEFAULT_PATH``; the previous constructor treated ``None``
                the same way, so existing callers keep working.
            cfg_path: Path to the YAML config. ``None`` selects
                ``DEFAULT_CFG``.
        """
        if model_path is None:
            model_path = DEFAULT_PATH
        if cfg_path is None:
            cfg_path = DEFAULT_CFG

        config = read_yaml(cfg_path)
        config["model_path"] = model_path
        self.session = OrtInferSession(config)
        self.labels = self.session.get_character_list()

        self.preprocess = Preprocess()
        self.load_img = LoadImage()

    def __call__(self, img_content: Union[str, np.ndarray, bytes, Path]):
        """Predict the orientation label of one image.

        Args:
            img_content: Anything ``LoadImage`` accepts (path, bytes, ndarray).

        Returns:
            ``(pred_txt, elapse)``: the label at the arg-max position of the
            first model output, and the elapsed seconds measured with
            ``time.perf_counter``. Image loading happens before the timer
            starts, so it is not included in ``elapse``.
        """
        image = self.load_img(img_content)

        s = time.perf_counter()

        image = self.preprocess(image)
        image = image[None, ...]  # add the leading batch axis

        pred_output = self.session(image)[0]
        pred_output = pred_output.squeeze()
        # Plain int so the label lookup does not depend on numpy scalar types.
        pred_idx = int(np.argmax(pred_output))
        pred_txt = self.labels[pred_idx]

        elapse = time.perf_counter() - s
        return pred_txt, elapse
shifted to " - f"be executed under {cpu_ep}. " - f"Please ensure the installed onnxruntime-gpu " - f" version matches your cuda and cudnn version, " - f"you can check their relations from the offical web site: " - f"https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html", - RuntimeWarning, - ) - - def __call__(self, input_content: np.ndarray) -> np.ndarray: - input_dict = dict(zip(self.get_input_names(), [input_content])) - try: - return self.session.run(self.get_output_names(), input_dict) - except Exception as e: - raise ONNXRuntimeError("ONNXRuntime inferece failed.") from e - - def get_input_names( - self, - ): - return [v.name for v in self.session.get_inputs()] - - def get_output_names( - self, - ): - return [v.name for v in self.session.get_outputs()] - - def get_metadata(self): - meta_dict = self.session.get_modelmeta().custom_metadata_map - return meta_dict - - @staticmethod - def _verify_model(model_path): - model_path = Path(model_path) - if not model_path.exists(): - raise FileNotFoundError(f"{model_path} does not exists.") - if not model_path.is_file(): - raise FileExistsError(f"{model_path} is not a file.") - - -class ONNXRuntimeError(Exception): - pass - - -class LoadImage: - def __init__( - self, - ): - pass - - def __call__(self, img: InputType) -> np.ndarray: - if not isinstance(img, InputType.__args__): - raise LoadImageError( - f"The img type {type(img)} does not in {InputType.__args__}" - ) - - img = self.load_img(img) - - if img.ndim == 2: - return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - - if img.ndim == 3 and img.shape[2] == 4: - return self.cvt_four_to_three(img) - - return img - - def load_img(self, img: InputType) -> np.ndarray: - if isinstance(img, (str, Path)): - self.verify_exist(img) - try: - img = np.array(Image.open(img)) - img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) - except UnidentifiedImageError as e: - raise LoadImageError(f"cannot identify image file {img}") from e - return img - - if isinstance(img, bytes): 
- img = np.array(Image.open(BytesIO(img))) - img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) - return img - - if isinstance(img, np.ndarray): - return img - - raise LoadImageError(f"{type(img)} is not supported!") - - @staticmethod - def cvt_four_to_three(img: np.ndarray) -> np.ndarray: - """RGBA → RGB""" - r, g, b, a = cv2.split(img) - new_img = cv2.merge((b, g, r)) - - not_a = cv2.bitwise_not(a) - not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) - - new_img = cv2.bitwise_and(new_img, new_img, mask=a) - new_img = cv2.add(new_img, not_a) - return new_img - - @staticmethod - def verify_exist(file_path: Union[str, Path]): - if not Path(file_path).exists(): - raise LoadImageError(f"{file_path} does not exist.") - - -class LoadImageError(Exception): - pass - - -def create_operators(params): - """ - create operators based on the config - - Args: - params(list): a dict list, used to create some operators - """ - assert isinstance(params, list), "operator config should be a list" - mod = importlib.import_module(__name__) - ops = [] - for operator in params: - assert isinstance(operator, dict) and len(operator) == 1, "yaml format error" - op_name = list(operator)[0] - param = {} if operator[op_name] is None else operator[op_name] - op = getattr(mod, op_name)(**param) - ops.append(op) - return ops - - -class ResizeImage: - def __init__(self, size=None, resize_short=None): - if resize_short is not None and resize_short > 0: - self.resize_short = resize_short - self.w, self.h = None, None - elif size is not None: - self.resize_short = None - self.w = size if isinstance(size, int) else size[0] - self.h = size if isinstance(size, int) else size[1] - else: - raise ValueError( - "invalid params for ReisizeImage for '\ - 'both 'size' and 'resize_short' are None" - ) - - def __call__(self, img: np.ndarray): - img_h, img_w = img.shape[:2] - - if self.resize_short: - percent = float(self.resize_short) / min(img_w, img_h) - w = int(round(img_w * percent)) - h = int(round(img_h * percent)) - 
else: - w = self.w - h = self.h - return cv2.resize(img, (w, h)) - - -class NormalizeImage: - def __init__( - self, - ): - self.scale = np.float32(1.0 / 255.0) - mean = [0.485, 0.456, 0.406] - std = [0.229, 0.224, 0.225] - - shape = 1, 1, 3 - self.mean = np.array(mean).reshape(shape).astype("float32") - self.std = np.array(std).reshape(shape).astype("float32") - - def __call__(self, img): - img = np.array(img).astype(np.float32) - img = (img * self.scale - self.mean) / self.std - return img.astype(np.float32) - - -class ToCHWImage: - def __init__(self): - pass - - def __call__(self, img): - img = np.array(img) - return img.transpose((2, 0, 1)) - - -class CropImage: - def __init__(self, size): - self.size = size - if isinstance(size, int): - self.size = (size, size) - - def __call__(self, img): - w, h = self.size - img_h, img_w = img.shape[:2] - - if img_h < h or img_w < w: - raise ValueError( - f"The size({h}, {w}) of CropImage must be greater than " - f"size({img_h}, {img_w}) of image." 
class EP(Enum):
    """Execution-provider names referenced by ``OrtInferSession``."""

    CPU_EP = "CPUExecutionProvider"
    CUDA_EP = "CUDAExecutionProvider"
    DIRECTML_EP = "DmlExecutionProvider"


class OrtInferSession:
    """Wrap an onnxruntime ``InferenceSession`` built from a config mapping.

    Config keys read by this class: ``model_path`` (required), ``use_cuda``,
    ``use_dml``, ``intra_op_num_threads`` and ``inter_op_num_threads``.
    """

    def __init__(self, config: Dict[str, Any]):
        """Validate the model path, pick providers and create the session.

        Raises:
            ValueError: ``model_path`` is missing from ``config``.
            FileNotFoundError: ``model_path`` does not exist.
            FileExistsError: ``model_path`` exists but is not a file.
        """
        self.logger = get_logger("OrtInferSession")

        model_path = config.get("model_path", None)
        self._verify_model(model_path)

        self.cfg_use_cuda = config.get("use_cuda", None)
        self.cfg_use_dml = config.get("use_dml", None)

        self.had_providers: List[str] = get_available_providers()
        EP_list = self._get_ep_list()

        sess_opt = self._init_sess_opts(config)
        self.session = InferenceSession(
            # str() so a pathlib.Path also works on onnxruntime versions whose
            # InferenceSession only accepts str/bytes.
            str(model_path),
            sess_options=sess_opt,
            providers=EP_list,
        )
        self._verify_providers()

    @staticmethod
    def _valid_thread_num(value: Any, cpu_nums: Union[int, None]) -> bool:
        """Return True when ``value`` is a usable thread count.

        A usable count is an int of at least 1 and, when the CPU count is
        known, not larger than it. Non-int values (for example ``None`` from
        an empty YAML key) and the ``-1`` "unset" sentinel are rejected
        instead of raising ``TypeError``.
        """
        if not isinstance(value, int) or value < 1:
            return False
        # os.cpu_count() may return None; without a known upper bound the
        # explicit setting is accepted as-is.
        return cpu_nums is None or value <= cpu_nums

    @staticmethod
    def _init_sess_opts(config: Dict[str, Any]) -> "SessionOptions":
        """Create ``SessionOptions`` and apply optional thread settings."""
        sess_opt = SessionOptions()
        sess_opt.log_severity_level = 4
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

        cpu_nums = os.cpu_count()
        intra_op_num_threads = config.get("intra_op_num_threads", -1)
        if OrtInferSession._valid_thread_num(intra_op_num_threads, cpu_nums):
            sess_opt.intra_op_num_threads = intra_op_num_threads

        inter_op_num_threads = config.get("inter_op_num_threads", -1)
        if OrtInferSession._valid_thread_num(inter_op_num_threads, cpu_nums):
            sess_opt.inter_op_num_threads = inter_op_num_threads

        return sess_opt

    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
        """Build the provider list, highest priority first.

        CPU is always present as the last entry; CUDA and then DirectML are
        inserted in front of it when their checks pass. Also sets
        ``self.use_cuda`` and ``self.use_directml``.
        """
        cpu_provider_opts = {
            "arena_extend_strategy": "kSameAsRequested",
        }
        EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]

        cuda_provider_opts = {
            "device_id": 0,
            "arena_extend_strategy": "kNextPowerOfTwo",
            "cudnn_conv_algo_search": "EXHAUSTIVE",
            "do_copy_in_default_stream": True,
        }
        self.use_cuda = self._check_cuda()
        if self.use_cuda:
            EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))

        self.use_directml = self._check_dml()
        if self.use_directml:
            self.logger.info(
                "Windows 10 or above detected, try to use DirectML as primary provider"
            )
            # NOTE(review): DirectML is handed the CUDA or CPU option dict
            # unchanged; confirm DmlExecutionProvider accepts these keys.
            directml_options = (
                cuda_provider_opts if self.use_cuda else cpu_provider_opts
            )
            EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
        return EP_list

    def _check_cuda(self) -> bool:
        """Return True when CUDA was requested and is available."""
        if not self.cfg_use_cuda:
            return False

        cur_device = get_device()
        if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
            return True

        self.logger.warning(
            "%s is not in available providers (%s). Use %s inference by default.",
            EP.CUDA_EP.value,
            self.had_providers,
            self.had_providers[0],
        )
        self.logger.info("!!!Recommend to use rapidocr_paddle for inference on GPU.")
        self.logger.info(
            "(For reference only) If you want to use GPU acceleration, you must do:"
        )
        self.logger.info(
            "First, uninstall all onnxruntime pakcages in current environment."
        )
        self.logger.info(
            "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`."
        )
        self.logger.info(
            "\tNote the onnxruntime-gpu version must match your cuda and cudnn version."
        )
        self.logger.info(
            "\tYou can refer this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html"
        )
        self.logger.info(
            "Third, ensure %s is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
            EP.CUDA_EP.value,
        )
        return False

    @staticmethod
    def _parse_windows_major(release: str) -> Union[int, None]:
        """Return the leading integer of a ``platform.release()`` string.

        Returns ``None`` when the part before the first dot is not an
        integer (e.g. ``"Vista"`` or ``"2012ServerR2"``) instead of raising
        ``ValueError``.
        """
        try:
            return int(release.split(".")[0])
        except ValueError:
            return None

    def _check_dml(self) -> bool:
        """Return True when DirectML was requested and is usable here."""
        if not self.cfg_use_dml:
            return False

        cur_os = platform.system()
        if cur_os != "Windows":
            self.logger.warning(
                "DirectML is only supported in Windows OS. The current OS is %s. Use %s inference by default.",
                cur_os,
                self.had_providers[0],
            )
            return False

        release = platform.release()
        cur_window_version = self._parse_windows_major(release)
        if cur_window_version is None:
            # Unknown release format: do not crash, let the provider list
            # below decide whether DirectML can be used.
            self.logger.warning(
                "Unable to determine the Windows version from release %r.",
                release,
            )
        elif cur_window_version < 10:
            self.logger.warning(
                "DirectML is only supported in Windows 10 and above OS. The current Windows version is %s. Use %s inference by default.",
                cur_window_version,
                self.had_providers[0],
            )
            return False

        if EP.DIRECTML_EP.value in self.had_providers:
            return True

        self.logger.warning(
            "%s is not in available providers (%s). Use %s inference by default.",
            EP.DIRECTML_EP.value,
            self.had_providers,
            self.had_providers[0],
        )
        self.logger.info("If you want to use DirectML acceleration, you must do:")
        self.logger.info(
            "First, uninstall all onnxruntime pakcages in current environment."
        )
        self.logger.info(
            "Second, install onnxruntime-directml by `pip install onnxruntime-directml`"
        )
        self.logger.info(
            "Third, ensure %s is in available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']",
            EP.DIRECTML_EP.value,
        )
        return False

    def _verify_providers(self):
        """Warn when the session's first provider is not the requested one."""
        session_providers = self.session.get_providers()
        first_provider = session_providers[0]

        if self.use_cuda and first_provider != EP.CUDA_EP.value:
            self.logger.warning(
                "%s is not avaiable for current env, the inference part is automatically shifted to be executed under %s.",
                EP.CUDA_EP.value,
                first_provider,
            )

        if self.use_directml and first_provider != EP.DIRECTML_EP.value:
            self.logger.warning(
                "%s is not available for current env, the inference part is automatically shifted to be executed under %s.",
                EP.DIRECTML_EP.value,
                first_provider,
            )

    def __call__(self, input_content: np.ndarray) -> List[np.ndarray]:
        """Run the model on one input array.

        The array is bound to the first input name of the session.

        Returns:
            Whatever ``session.run`` returns for all output names.

        Raises:
            ONNXRuntimeError: Any exception from ``session.run``; the message
                carries the formatted traceback.
        """
        input_dict = dict(zip(self.get_input_names(), [input_content]))
        try:
            return self.session.run(self.get_output_names(), input_dict)
        except Exception as e:
            error_info = traceback.format_exc()
            raise ONNXRuntimeError(error_info) from e

    def get_input_names(self) -> List[str]:
        """Return the names of the session inputs."""
        return [v.name for v in self.session.get_inputs()]

    def get_output_names(self) -> List[str]:
        """Return the names of the session outputs."""
        return [v.name for v in self.session.get_outputs()]

    def get_character_list(self, key: str = "character") -> List[str]:
        """Return the model metadata entry ``key`` split into lines.

        Raises:
            KeyError: ``key`` is not in the model's custom metadata.
        """
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        return meta_dict[key].splitlines()

    def have_key(self, key: str = "character") -> bool:
        """Return True when ``key`` is in the model's custom metadata."""
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        return key in meta_dict

    @staticmethod
    def _verify_model(model_path: Union[str, Path, None]):
        """Check that ``model_path`` points at an existing file.

        Raises:
            ValueError: ``model_path`` is None.
            FileNotFoundError: The path does not exist.
            FileExistsError: The path exists but is not a regular file.
        """
        if model_path is None:
            raise ValueError("model_path is None!")

        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exists.")

        if not model_path.is_file():
            raise FileExistsError(f"{model_path} is not a file.")


class ONNXRuntimeError(Exception):
    """Raised when running the onnxruntime session fails."""
class LoadImage:
    """Load an image from several input kinds and return a 3-channel array.

    Accepted inputs are the members of ``InputType``. Paths, bytes and PIL
    images are decoded through PIL; ndarrays are passed through. The result
    is then normalised to three channels by ``convert_img``.
    """

    def __init__(self):
        pass

    def __call__(self, img: InputType) -> np.ndarray:
        """Validate the input type, load it and convert it to 3 channels.

        Raises:
            LoadImageError: The input type is not one of ``InputType``, or
                loading/conversion fails.
        """
        if not isinstance(img, InputType.__args__):
            raise LoadImageError(
                f"The img type {type(img)} does not in {InputType.__args__}"
            )

        # Remember where the data came from: convert_img uses it to decide
        # whether a 3-channel array still needs an RGB -> BGR swap.
        source_type = type(img)
        loaded = self.load_img(img)
        return self.convert_img(loaded, source_type)

    def load_img(self, img: InputType) -> np.ndarray:
        """Turn any supported input into an ndarray (no channel fix-up)."""
        if isinstance(img, np.ndarray):
            return img

        if isinstance(img, Image.Image):
            return self.img_to_ndarray(img)

        if isinstance(img, bytes):
            return self.img_to_ndarray(Image.open(BytesIO(img)))

        if isinstance(img, (str, Path)):
            self.verify_exist(img)
            try:
                return self.img_to_ndarray(Image.open(img))
            except UnidentifiedImageError as e:
                raise LoadImageError(f"cannot identify image file {img}") from e

        raise LoadImageError(f"{type(img)} is not supported!")

    def img_to_ndarray(self, img: Image.Image) -> np.ndarray:
        """Convert a PIL image to an ndarray; mode ``"1"`` goes via ``"L"``."""
        if img.mode == "1":
            img = img.convert("L")
        return np.array(img)

    def convert_img(self, img: np.ndarray, origin_img_type: Any) -> np.ndarray:
        """Normalise ``img`` to a 3-channel array.

        2-D and single-channel input is expanded with GRAY2BGR, 2-channel
        and 4-channel input goes through the alpha helpers, and 3-channel
        input is swapped RGB -> BGR only when the original input was a path,
        bytes or a PIL image (ndarray input is returned unchanged).

        Raises:
            LoadImageError: ``img.ndim`` is not 2 or 3, or the channel count
                is not 1, 2, 3 or 4.
        """
        if img.ndim == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        if img.ndim != 3:
            raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")

        channel = img.shape[2]
        if channel == 1:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        if channel == 2:
            return self.cvt_two_to_three(img)

        if channel == 4:
            return self.cvt_four_to_three(img)

        if channel != 3:
            raise LoadImageError(
                f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
            )

        decoded_by_pil = issubclass(origin_img_type, (str, Path, bytes, Image.Image))
        return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) if decoded_by_pil else img

    @staticmethod
    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
        """gray + alpha → BGR"""
        gray_bgr = cv2.cvtColor(img[..., 0], cv2.COLOR_GRAY2BGR)

        alpha = img[..., 1]
        inverted_alpha = cv2.cvtColor(cv2.bitwise_not(alpha), cv2.COLOR_GRAY2BGR)

        # Keep the gray pixels where the alpha mask is set, then add the
        # inverted alpha on top.
        masked = cv2.bitwise_and(gray_bgr, gray_bgr, mask=alpha)
        return cv2.add(masked, inverted_alpha)

    @staticmethod
    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
        """RGBA → BGR"""
        r, g, b, a = cv2.split(img)
        bgr = cv2.merge((b, g, r))

        inverted_alpha = cv2.cvtColor(cv2.bitwise_not(a), cv2.COLOR_GRAY2BGR)
        bgr = cv2.bitwise_and(bgr, bgr, mask=a)

        # When the masked image has a mean of zero the inverted alpha is
        # added; otherwise the masked image is bitwise-inverted.
        # NOTE(review): presumably a heuristic for dark-on-transparent
        # images; confirm the intent.
        if np.mean(bgr) <= 0.0:
            return cv2.add(bgr, inverted_alpha)
        return cv2.bitwise_not(bgr)

    @staticmethod
    def verify_exist(file_path: Union[str, Path]):
        """Raise ``LoadImageError`` when ``file_path`` does not exist."""
        if not Path(file_path).exists():
            raise LoadImageError(f"{file_path} does not exist.")


class LoadImageError(Exception):
    """Raised for unsupported inputs and image loading failures."""
class Preprocess:
    """Fixed preprocessing pipeline: resize, centre-crop, normalise, HWC→CHW."""

    def __init__(self):
        self.resize_img = ResizeImage(resize_short=256)
        self.crop_img = CropImage(size=224)
        self.normal_img = NormalizeImage()
        self.cvt_channel = ToCHWImage()

    def __call__(self, img: np.ndarray):
        """Run the four steps in order and return the resulting array."""
        img = self.resize_img(img)
        img = self.crop_img(img)
        img = self.normal_img(img)
        img = self.cvt_channel(img)
        return img


class ResizeImage:
    """Resize either so the short side hits a length, or to a fixed size."""

    def __init__(self, size=None, resize_short=None):
        """Configure the resize mode.

        Args:
            size: Fixed output size, an int (square) or ``(w, h)``. Used only
                when ``resize_short`` is None or not positive.
            resize_short: Target length of the shorter side; takes precedence
                over ``size`` when positive.

        Raises:
            ValueError: Neither a positive ``resize_short`` nor ``size`` is
                given.
        """
        if resize_short is not None and resize_short > 0:
            self.resize_short = resize_short
            self.w, self.h = None, None
        elif size is not None:
            self.resize_short = None
            self.w = size if isinstance(size, int) else size[0]
            self.h = size if isinstance(size, int) else size[1]
        else:
            # The old literal used a backslash continuation inside the string,
            # which embedded source indentation and stray quotes in the text.
            raise ValueError(
                "invalid params for ResizeImage: "
                "a positive 'resize_short' or a 'size' is required"
            )

    def __call__(self, img: np.ndarray):
        """Resize ``img`` with Lanczos interpolation."""
        img_h, img_w = img.shape[:2]

        w, h = self.w, self.h
        if self.resize_short:
            # Scale both sides by the same factor so the short side becomes
            # resize_short and the aspect ratio is kept.
            percent = float(self.resize_short) / min(img_w, img_h)
            w = int(round(img_w * percent))
            h = int(round(img_h * percent))
        return cv2.resize(img, (w, h), interpolation=cv2.INTER_LANCZOS4)


class CropImage:
    """Centre-crop an image to a fixed ``(w, h)`` size."""

    def __init__(self, size):
        """``size`` is an int (square crop) or a ``(w, h)`` pair."""
        self.size = size
        if isinstance(size, int):
            self.size = (size, size)

    def __call__(self, img):
        """Return the centred crop of ``img``.

        Raises:
            ValueError: The image is smaller than the crop in either axis.
        """
        w, h = self.size
        img_h, img_w = img.shape[:2]

        if img_h < h or img_w < w:
            # This fires when the crop is larger than the image; the old
            # wording ("must be greater than") said the opposite.
            raise ValueError(
                f"The size({h}, {w}) of CropImage must not be greater than "
                f"size({img_h}, {img_w}) of image."
            )

        w_start = (img_w - w) // 2
        h_start = (img_h - h) // 2

        w_end = w_start + w
        h_end = h_start + h
        # No trailing ``:`` so 2-D (grayscale) arrays are accepted as well;
        # for 3-D input the result is identical.
        return img[h_start:h_end, w_start:w_end]


class NormalizeImage:
    """Scale by 1/255, then subtract a per-channel mean and divide by std."""

    def __init__(
        self,
    ):
        self.scale = np.float32(1.0 / 255.0)
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        # (1, 1, 3) broadcasts over the last axis of an (H, W, 3) image.
        shape = 1, 1, 3
        self.mean = np.array(mean).reshape(shape).astype("float32")
        self.std = np.array(std).reshape(shape).astype("float32")

    def __call__(self, img):
        """Return the normalised image as float32."""
        img = np.array(img).astype(np.float32)
        img = (img * self.scale - self.mean) / self.std
        return img.astype(np.float32)


class ToCHWImage:
    """Move the last axis of a 3-D array to the front (HWC → CHW)."""

    def __init__(self):
        pass

    def __call__(self, img):
        """Return ``img`` transposed with axes ``(2, 0, 1)``."""
        img = np.array(img)
        return img.transpose((2, 0, 1))