diff --git a/README.md b/README.md
index b6dd8ce..5663b5b 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,9 @@
+
+
+
diff --git a/demo.py b/demo.py
index 58d9a0c..75af9d5 100644
--- a/demo.py
+++ b/demo.py
@@ -5,7 +5,36 @@
from rapid_orientation import RapidOrientation
+
+def scale_resize(img, resize_value=(280, 32)):
+    """
+    Pad img with black borders up to the target aspect ratio, then resize it.
+
+    @params:
+        img: ndarray, only its first two dimensions (height, width) are read
+        resize_value: (width, height) of the returned image
+    @returns:
+        the padded image resized to resize_value (Lanczos interpolation)
+    """
+    ratio = resize_value[0] / resize_value[1]  # w / h
+    h, w = img.shape[:2]
+    if w / h < ratio:
+        # Too narrow for the target ratio: pad left and right equally.
+        w_padding = max(0, (int(h * ratio) - w) // 2)
+        borders = (0, 0, w_padding, w_padding)
+    else:
+        # Too wide (or already matching): pad top and bottom equally.
+        # Float rounding can make int(w / ratio) land one below h, so the
+        # padding is clamped; copyMakeBorder rejects negative border widths.
+        h_padding = max(0, (int(w / ratio) - h) // 2)
+        borders = (h_padding, h_padding, 0, 0)
+    # Border order is (top, bottom, left, right).
+    img = cv2.copyMakeBorder(img, *borders, cv2.BORDER_CONSTANT, value=(0, 0, 0))
+    return cv2.resize(img, resize_value, interpolation=cv2.INTER_LANCZOS4)
+
+
orientation_engine = RapidOrientation()
-img = cv2.imread("tests/test_files/img_rot180_demo.jpg")
+# The image is handed to the engine exactly as read; scale_resize is not applied here.
+img = cv2.imread("tests/test_files/1.png")
cls_result, _ = orientation_engine(img)
print(cls_result)
diff --git a/rapid_orientation/config.yaml b/rapid_orientation/config.yaml
index 78847b2..cf96003 100644
--- a/rapid_orientation/config.yaml
+++ b/rapid_orientation/config.yaml
@@ -6,11 +6,3 @@ CUDAExecutionProvider:
arena_extend_strategy: kNextPowerOfTwo
cudnn_conv_algo_search: EXHAUSTIVE
do_copy_in_default_stream: true
-
-PreProcess:
- - ResizeImage:
- resize_short: 256
- - CropImage:
- size: 224
- - NormalizeImage:
- - ToCHWImage:
diff --git a/rapid_orientation/main.py b/rapid_orientation/main.py
index 387ba6d..aa90fd5 100644
--- a/rapid_orientation/main.py
+++ b/rapid_orientation/main.py
@@ -17,53 +17,52 @@
import argparse
import time
from pathlib import Path
-from typing import Optional, Union
+from typing import Union
import cv2
import numpy as np
-import yaml
-from .utils import LoadImage, OrtInferSession, create_operators
+from .utils.infer_engine import OrtInferSession
+from .utils.load_image import LoadImage
+from .utils.preprocess import Preprocess
+from .utils.utils import read_yaml
root_dir = Path(__file__).resolve().parent
+DEFAULT_PATH = root_dir / "models" / "rapid_orientation.onnx"
+DEFAULT_CFG = root_dir / "config.yaml"
class RapidOrientation:
+ """Callable that loads an image, preprocesses it and returns the label the model predicts."""
- def __init__(self, model_path: Optional[str] = None):
- config_path = str(root_dir / "config.yaml")
- config = self.read_yaml(config_path)
- if model_path is None:
- model_path = str(root_dir / "models" / "rapid_orientation.onnx")
+ def __init__(
+ self,
+ model_path: Union[str, Path] = DEFAULT_PATH,
+ cfg_path: Union[str, Path] = DEFAULT_CFG,
+ ):
+ """Read the YAML config, create the ONNX session and the pre-processing helpers.
+
+ model_path and cfg_path default to the model file and config.yaml located next to this module.
+ """
+ # NOTE(review): model_path=None is no longer replaced by the default here; it is
+ # forwarded as-is and OrtInferSession._verify_model raises ValueError for None.
+ config = read_yaml(cfg_path)
config["model_path"] = model_path
self.session = OrtInferSession(config)
- self.labels = self.session.get_metadata()["character"].splitlines()
-
- self.preprocess_ops = create_operators(config["PreProcess"])
+ self.labels = self.session.get_character_list()
+ self.preprocess = Preprocess()
self.load_img = LoadImage()
def __call__(self, img_content: Union[str, np.ndarray, bytes, Path]):
+ """Return (predicted label, elapsed seconds) for one image.
+
+ The timer starts after the image has been loaded, so the elapsed time
+ covers pre-processing, inference and the label lookup only.
+ """
- images = self.load_img(img_content)
+ image = self.load_img(img_content)
+
+ s = time.perf_counter()
- s = time.time()
- for ops in self.preprocess_ops:
- images = ops(images)
- image = np.array(images)[None, ...]
+ image = self.preprocess(image)
+ image = image[None, ...]
pred_output = self.session(image)[0]
pred_output = pred_output.squeeze()
pred_idx = np.argmax(pred_output)
pred_txt = self.labels[pred_idx]
- elapse = time.time() - s
- return pred_txt, elapse
- @staticmethod
- def read_yaml(yaml_path):
- with open(yaml_path, "rb") as f:
- data = yaml.load(f, Loader=yaml.Loader)
- return data
+ elapse = time.perf_counter() - s
+ return pred_txt, elapse
def main():
diff --git a/rapid_orientation/utils.py b/rapid_orientation/utils.py
deleted file mode 100644
index 018e577..0000000
--- a/rapid_orientation/utils.py
+++ /dev/null
@@ -1,258 +0,0 @@
-# -*- encoding: utf-8 -*-
-# @Author: SWHL
-# @Contact: liekkaskono@163.com
-import importlib
-import warnings
-from io import BytesIO
-from pathlib import Path
-from typing import Union
-
-import cv2
-import numpy as np
-from onnxruntime import (
- GraphOptimizationLevel,
- InferenceSession,
- SessionOptions,
- get_available_providers,
- get_device,
-)
-from PIL import Image, UnidentifiedImageError
-
-InputType = Union[str, np.ndarray, bytes, Path]
-
-
-class OrtInferSession:
- def __init__(self, config):
- sess_opt = SessionOptions()
- sess_opt.log_severity_level = 4
- sess_opt.enable_cpu_mem_arena = False
- sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
-
- cuda_ep = "CUDAExecutionProvider"
- cpu_ep = "CPUExecutionProvider"
- cpu_provider_options = {
- "arena_extend_strategy": "kSameAsRequested",
- }
-
- EP_list = []
- if (
- config["use_cuda"]
- and get_device() == "GPU"
- and cuda_ep in get_available_providers()
- ):
- EP_list = [(cuda_ep, config[cuda_ep])]
- EP_list.append((cpu_ep, cpu_provider_options))
-
- self._verify_model(config["model_path"])
- self.session = InferenceSession(
- config["model_path"], sess_options=sess_opt, providers=EP_list
- )
-
- has_cuda_ep = cuda_ep not in self.session.get_providers()
- if config["use_cuda"] and has_cuda_ep:
- warnings.warn(
- f"{cuda_ep} is not avaiable for current env,"
- f"the inference part is automatically shifted to "
- f"be executed under {cpu_ep}. "
- f"Please ensure the installed onnxruntime-gpu "
- f" version matches your cuda and cudnn version, "
- f"you can check their relations from the offical web site: "
- f"https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
- RuntimeWarning,
- )
-
- def __call__(self, input_content: np.ndarray) -> np.ndarray:
- input_dict = dict(zip(self.get_input_names(), [input_content]))
- try:
- return self.session.run(self.get_output_names(), input_dict)
- except Exception as e:
- raise ONNXRuntimeError("ONNXRuntime inferece failed.") from e
-
- def get_input_names(
- self,
- ):
- return [v.name for v in self.session.get_inputs()]
-
- def get_output_names(
- self,
- ):
- return [v.name for v in self.session.get_outputs()]
-
- def get_metadata(self):
- meta_dict = self.session.get_modelmeta().custom_metadata_map
- return meta_dict
-
- @staticmethod
- def _verify_model(model_path):
- model_path = Path(model_path)
- if not model_path.exists():
- raise FileNotFoundError(f"{model_path} does not exists.")
- if not model_path.is_file():
- raise FileExistsError(f"{model_path} is not a file.")
-
-
-class ONNXRuntimeError(Exception):
- pass
-
-
-class LoadImage:
- def __init__(
- self,
- ):
- pass
-
- def __call__(self, img: InputType) -> np.ndarray:
- if not isinstance(img, InputType.__args__):
- raise LoadImageError(
- f"The img type {type(img)} does not in {InputType.__args__}"
- )
-
- img = self.load_img(img)
-
- if img.ndim == 2:
- return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-
- if img.ndim == 3 and img.shape[2] == 4:
- return self.cvt_four_to_three(img)
-
- return img
-
- def load_img(self, img: InputType) -> np.ndarray:
- if isinstance(img, (str, Path)):
- self.verify_exist(img)
- try:
- img = np.array(Image.open(img))
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
- except UnidentifiedImageError as e:
- raise LoadImageError(f"cannot identify image file {img}") from e
- return img
-
- if isinstance(img, bytes):
- img = np.array(Image.open(BytesIO(img)))
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
- return img
-
- if isinstance(img, np.ndarray):
- return img
-
- raise LoadImageError(f"{type(img)} is not supported!")
-
- @staticmethod
- def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
- """RGBA → RGB"""
- r, g, b, a = cv2.split(img)
- new_img = cv2.merge((b, g, r))
-
- not_a = cv2.bitwise_not(a)
- not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
-
- new_img = cv2.bitwise_and(new_img, new_img, mask=a)
- new_img = cv2.add(new_img, not_a)
- return new_img
-
- @staticmethod
- def verify_exist(file_path: Union[str, Path]):
- if not Path(file_path).exists():
- raise LoadImageError(f"{file_path} does not exist.")
-
-
-class LoadImageError(Exception):
- pass
-
-
-def create_operators(params):
- """
- create operators based on the config
-
- Args:
- params(list): a dict list, used to create some operators
- """
- assert isinstance(params, list), "operator config should be a list"
- mod = importlib.import_module(__name__)
- ops = []
- for operator in params:
- assert isinstance(operator, dict) and len(operator) == 1, "yaml format error"
- op_name = list(operator)[0]
- param = {} if operator[op_name] is None else operator[op_name]
- op = getattr(mod, op_name)(**param)
- ops.append(op)
- return ops
-
-
-class ResizeImage:
- def __init__(self, size=None, resize_short=None):
- if resize_short is not None and resize_short > 0:
- self.resize_short = resize_short
- self.w, self.h = None, None
- elif size is not None:
- self.resize_short = None
- self.w = size if isinstance(size, int) else size[0]
- self.h = size if isinstance(size, int) else size[1]
- else:
- raise ValueError(
- "invalid params for ReisizeImage for '\
- 'both 'size' and 'resize_short' are None"
- )
-
- def __call__(self, img: np.ndarray):
- img_h, img_w = img.shape[:2]
-
- if self.resize_short:
- percent = float(self.resize_short) / min(img_w, img_h)
- w = int(round(img_w * percent))
- h = int(round(img_h * percent))
- else:
- w = self.w
- h = self.h
- return cv2.resize(img, (w, h))
-
-
-class NormalizeImage:
- def __init__(
- self,
- ):
- self.scale = np.float32(1.0 / 255.0)
- mean = [0.485, 0.456, 0.406]
- std = [0.229, 0.224, 0.225]
-
- shape = 1, 1, 3
- self.mean = np.array(mean).reshape(shape).astype("float32")
- self.std = np.array(std).reshape(shape).astype("float32")
-
- def __call__(self, img):
- img = np.array(img).astype(np.float32)
- img = (img * self.scale - self.mean) / self.std
- return img.astype(np.float32)
-
-
-class ToCHWImage:
- def __init__(self):
- pass
-
- def __call__(self, img):
- img = np.array(img)
- return img.transpose((2, 0, 1))
-
-
-class CropImage:
- def __init__(self, size):
- self.size = size
- if isinstance(size, int):
- self.size = (size, size)
-
- def __call__(self, img):
- w, h = self.size
- img_h, img_w = img.shape[:2]
-
- if img_h < h or img_w < w:
- raise ValueError(
- f"The size({h}, {w}) of CropImage must be greater than "
- f"size({img_h}, {img_w}) of image."
- )
-
- w_start = (img_w - w) // 2
- h_start = (img_h - h) // 2
-
- w_end = w_start + w
- h_end = h_start + h
- return img[h_start:h_end, w_start:w_end, :]
diff --git a/rapid_orientation/utils/__init__.py b/rapid_orientation/utils/__init__.py
new file mode 100644
index 0000000..0ecdd4f
--- /dev/null
+++ b/rapid_orientation/utils/__init__.py
@@ -0,0 +1,3 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
diff --git a/rapid_orientation/utils/infer_engine.py b/rapid_orientation/utils/infer_engine.py
new file mode 100644
index 0000000..fb1fd35
--- /dev/null
+++ b/rapid_orientation/utils/infer_engine.py
@@ -0,0 +1,231 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+import os
+import platform
+import traceback
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
+import numpy as np
+from onnxruntime import (
+ GraphOptimizationLevel,
+ InferenceSession,
+ SessionOptions,
+ get_available_providers,
+ get_device,
+)
+
+from .logger import get_logger
+
+
+class EP(Enum):
+ """Execution-provider name strings, compared against onnxruntime's available-provider list."""
+
+ CPU_EP = "CPUExecutionProvider"
+ CUDA_EP = "CUDAExecutionProvider"
+ DIRECTML_EP = "DmlExecutionProvider"
+
+
+class OrtInferSession:
+    """Thin wrapper around an ONNX Runtime ``InferenceSession``.
+
+    The execution providers are chosen from the ``use_cuda`` / ``use_dml``
+    flags in the config and from the providers the installed onnxruntime
+    reports as available. The CPU provider is always kept as the last entry.
+    """
+
+    def __init__(self, config: Dict[str, Any]):
+        """Validate the model path, pick the providers and create the session.
+
+        Args:
+            config: Settings dict. Keys read: ``model_path`` (required),
+                ``use_cuda``, ``use_dml``, ``intra_op_num_threads`` and
+                ``inter_op_num_threads`` (all optional).
+
+        Raises:
+            ValueError: ``model_path`` is missing or ``None``.
+            FileNotFoundError: ``model_path`` does not exist.
+            FileExistsError: ``model_path`` exists but is not a file.
+        """
+        self.logger = get_logger("OrtInferSession")
+
+        model_path = config.get("model_path", None)
+        self._verify_model(model_path)
+
+        self.cfg_use_cuda = config.get("use_cuda", None)
+        self.cfg_use_dml = config.get("use_dml", None)
+
+        self.had_providers: List[str] = get_available_providers()
+        EP_list = self._get_ep_list()
+
+        sess_opt = self._init_sess_opts(config)
+        self.session = InferenceSession(
+            # The path may arrive as a pathlib.Path; hand over a plain string
+            # so str and Path inputs create the session the same way.
+            str(model_path),
+            sess_options=sess_opt,
+            providers=EP_list,
+        )
+        self._verify_providers()
+
+    @staticmethod
+    def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
+        """Build the session options, applying thread counts only when valid.
+
+        A thread count is applied when it is at least 1 and does not exceed
+        the CPU count; the default of -1 leaves onnxruntime's own setting.
+        """
+        sess_opt = SessionOptions()
+        sess_opt.log_severity_level = 4
+        sess_opt.enable_cpu_mem_arena = False
+        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+
+        # os.cpu_count() returns None when the count cannot be determined;
+        # in that case only the lower bound is enforced.
+        cpu_nums = os.cpu_count()
+        for option_name in ("intra_op_num_threads", "inter_op_num_threads"):
+            num_threads = config.get(option_name, -1)
+            if num_threads == -1 or num_threads < 1:
+                continue
+            if cpu_nums is not None and num_threads > cpu_nums:
+                continue
+            setattr(sess_opt, option_name, num_threads)
+
+        return sess_opt
+
+    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
+        """Return the provider list in priority order (first entry is preferred).
+
+        Also sets ``self.use_cuda`` and ``self.use_directml``.
+        """
+        cpu_provider_opts = {
+            "arena_extend_strategy": "kSameAsRequested",
+        }
+        EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]
+
+        cuda_provider_opts = {
+            "device_id": 0,
+            "arena_extend_strategy": "kNextPowerOfTwo",
+            "cudnn_conv_algo_search": "EXHAUSTIVE",
+            "do_copy_in_default_stream": True,
+        }
+        self.use_cuda = self._check_cuda()
+        if self.use_cuda:
+            EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))
+
+        self.use_directml = self._check_dml()
+        if self.use_directml:
+            self.logger.info(
+                "Windows 10 or above detected, try to use DirectML as primary provider"
+            )
+            # NOTE(review): the DirectML entry reuses the CUDA or CPU option
+            # dict; confirm DmlExecutionProvider accepts these keys.
+            directml_options = (
+                cuda_provider_opts if self.use_cuda else cpu_provider_opts
+            )
+            EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
+        return EP_list
+
+    def _check_cuda(self) -> bool:
+        """Return True when CUDA was requested and onnxruntime reports it usable.
+
+        When CUDA was requested but is unavailable, installation hints are
+        logged and False is returned.
+        """
+        if not self.cfg_use_cuda:
+            return False
+
+        cur_device = get_device()
+        if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
+            return True
+
+        self.logger.warning(
+            "%s is not in available providers (%s). Use %s inference by default.",
+            EP.CUDA_EP.value,
+            self.had_providers,
+            self.had_providers[0],
+        )
+        hints = (
+            "!!!Recommend to use rapidocr_paddle for inference on GPU.",
+            "(For reference only) If you want to use GPU acceleration, you must do:",
+            "First, uninstall all onnxruntime pakcages in current environment.",
+            "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`.",
+            "\tNote the onnxruntime-gpu version must match your cuda and cudnn version.",
+            "\tYou can refer this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html",
+        )
+        for hint in hints:
+            self.logger.info(hint)
+        self.logger.info(
+            "Third, ensure %s is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
+            EP.CUDA_EP.value,
+        )
+        return False
+
+    def _check_dml(self) -> bool:
+        """Return True when DirectML was requested and can be used.
+
+        Requires Windows, a release number of at least 10 when the release
+        string is numeric, and ``DmlExecutionProvider`` among the available
+        providers. Otherwise the reason is logged and False is returned.
+        """
+        if not self.cfg_use_dml:
+            return False
+
+        cur_os = platform.system()
+        if cur_os != "Windows":
+            self.logger.warning(
+                "DirectML is only supported in Windows OS. The current OS is %s. Use %s inference by default.",
+                cur_os,
+                self.had_providers[0],
+            )
+            return False
+
+        # platform.release() is not always numeric on Windows (for example
+        # "2019Server" or "Vista"); int() would raise on those. When the
+        # release cannot be parsed the version gate is skipped and the
+        # decision is left to the available-providers check below.
+        release = platform.release()
+        try:
+            cur_window_version = int(release.split(".")[0])
+        except ValueError:
+            cur_window_version = None
+            self.logger.warning(
+                "Cannot parse a Windows version from release %r. Falling back to the available providers check.",
+                release,
+            )
+
+        if cur_window_version is not None and cur_window_version < 10:
+            self.logger.warning(
+                "DirectML is only supported in Windows 10 and above OS. The current Windows version is %s. Use %s inference by default.",
+                cur_window_version,
+                self.had_providers[0],
+            )
+            return False
+
+        if EP.DIRECTML_EP.value in self.had_providers:
+            return True
+
+        self.logger.warning(
+            "%s is not in available providers (%s). Use %s inference by default.",
+            EP.DIRECTML_EP.value,
+            self.had_providers,
+            self.had_providers[0],
+        )
+        hints = (
+            "If you want to use DirectML acceleration, you must do:",
+            "First, uninstall all onnxruntime pakcages in current environment.",
+            "Second, install onnxruntime-directml by `pip install onnxruntime-directml`",
+        )
+        for hint in hints:
+            self.logger.info(hint)
+        self.logger.info(
+            "Third, ensure %s is in available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']",
+            EP.DIRECTML_EP.value,
+        )
+        return False
+
+    def _verify_providers(self):
+        """Warn when the session's first provider is not the one that was selected."""
+        session_providers = self.session.get_providers()
+        first_provider = session_providers[0]
+
+        if self.use_cuda and first_provider != EP.CUDA_EP.value:
+            self.logger.warning(
+                "%s is not avaiable for current env, the inference part is automatically shifted to be executed under %s.",
+                EP.CUDA_EP.value,
+                first_provider,
+            )
+
+        if self.use_directml and first_provider != EP.DIRECTML_EP.value:
+            self.logger.warning(
+                "%s is not available for current env, the inference part is automatically shifted to be executed under %s.",
+                EP.DIRECTML_EP.value,
+                first_provider,
+            )
+
+    def __call__(self, input_content: np.ndarray) -> List[np.ndarray]:
+        """Run the model on ``input_content`` and return all of its outputs.
+
+        The array is bound to the first input name of the model.
+
+        Raises:
+            ONNXRuntimeError: The underlying ``session.run`` call failed; the
+                formatted traceback is used as the message.
+        """
+        input_dict = dict(zip(self.get_input_names(), [input_content]))
+        try:
+            return self.session.run(self.get_output_names(), input_dict)
+        except Exception as e:
+            error_info = traceback.format_exc()
+            raise ONNXRuntimeError(error_info) from e
+
+    def get_input_names(self) -> List[str]:
+        """Return the names of the model inputs."""
+        return [v.name for v in self.session.get_inputs()]
+
+    def get_output_names(self) -> List[str]:
+        """Return the names of the model outputs."""
+        return [v.name for v in self.session.get_outputs()]
+
+    def get_character_list(self, key: str = "character") -> List[str]:
+        """Return the lines stored under ``key`` in the model's custom metadata.
+
+        Raises:
+            KeyError: ``key`` is not present in the metadata.
+        """
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        return meta_dict[key].splitlines()
+
+    def have_key(self, key: str = "character") -> bool:
+        """Return whether ``key`` exists in the model's custom metadata."""
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        return key in meta_dict
+
+    @staticmethod
+    def _verify_model(model_path: Union[str, Path, None]):
+        """Raise unless ``model_path`` points at an existing file."""
+        if model_path is None:
+            raise ValueError("model_path is None!")
+
+        model_path = Path(model_path)
+        if not model_path.exists():
+            raise FileNotFoundError(f"{model_path} does not exists.")
+
+        if not model_path.is_file():
+            raise FileExistsError(f"{model_path} is not a file.")
+
+
+class ONNXRuntimeError(Exception):
+    """Raised by OrtInferSession when running the ONNX session fails."""
diff --git a/rapid_orientation/utils/load_image.py b/rapid_orientation/utils/load_image.py
new file mode 100644
index 0000000..f34b549
--- /dev/null
+++ b/rapid_orientation/utils/load_image.py
@@ -0,0 +1,123 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+from io import BytesIO
+from pathlib import Path
+from typing import Any, Union
+
+import cv2
+import numpy as np
+from PIL import Image, UnidentifiedImageError
+
+root_dir = Path(__file__).resolve().parent
+InputType = Union[str, np.ndarray, bytes, Path, Image.Image]
+
+
+class LoadImage:
+    """Turn any supported image input into a 3-channel ``numpy`` array.
+
+    Accepted inputs are the members of ``InputType``: a path (``str`` or
+    ``Path``), raw ``bytes``, a ``numpy`` array or a ``PIL.Image.Image``.
+    """
+
+    def __init__(self):
+        pass
+
+    def __call__(self, img: InputType) -> np.ndarray:
+        """Load ``img`` and convert the result to three channels.
+
+        Raises:
+            LoadImageError: The input type is unsupported, the path does not
+                exist or cannot be identified as an image, or the decoded
+                array has an unsupported ndim / channel count.
+        """
+        if not isinstance(img, InputType.__args__):
+            raise LoadImageError(
+                f"The img type {type(img)} does not in {InputType.__args__}"
+            )
+
+        # The original input type decides whether 3-channel data gets swapped
+        # from RGB to BGR in convert_img.
+        return self.convert_img(self.load_img(img), type(img))
+
+    def load_img(self, img: InputType) -> np.ndarray:
+        """Decode ``img`` into an array; arrays are returned unchanged."""
+        if isinstance(img, np.ndarray):
+            return img
+
+        if isinstance(img, Image.Image):
+            return self.img_to_ndarray(img)
+
+        if isinstance(img, bytes):
+            return self.img_to_ndarray(Image.open(BytesIO(img)))
+
+        if isinstance(img, (str, Path)):
+            self.verify_exist(img)
+            try:
+                return self.img_to_ndarray(Image.open(img))
+            except UnidentifiedImageError as e:
+                raise LoadImageError(f"cannot identify image file {img}") from e
+
+        raise LoadImageError(f"{type(img)} is not supported!")
+
+    def img_to_ndarray(self, img: Image.Image) -> np.ndarray:
+        """Convert a PIL image to an array; mode "1" images go through mode "L" first."""
+        if img.mode == "1":
+            img = img.convert("L")
+        return np.array(img)
+
+    def convert_img(self, img: np.ndarray, origin_img_type: Any) -> np.ndarray:
+        """Bring a 1-, 2-, 3- or 4-channel array to three channels.
+
+        Three-channel data is swapped RGB -> BGR only when the original input
+        was a path, bytes or a PIL image; arrays passed in directly are
+        returned as they are.
+        """
+        if img.ndim == 2:
+            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+        if img.ndim != 3:
+            raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")
+
+        channel = img.shape[2]
+        if channel == 1:
+            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+        if channel == 2:
+            return self.cvt_two_to_three(img)
+
+        if channel == 4:
+            return self.cvt_four_to_three(img)
+
+        if channel != 3:
+            raise LoadImageError(
+                f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
+            )
+
+        decoded_by_pil = issubclass(origin_img_type, (str, Path, bytes, Image.Image))
+        if decoded_by_pil:
+            return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+        return img
+
+    @staticmethod
+    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
+        """gray + alpha -> BGR"""
+        gray_plane = img[..., 0]
+        alpha_plane = img[..., 1]
+
+        gray_as_bgr = cv2.cvtColor(gray_plane, cv2.COLOR_GRAY2BGR)
+        inverted_alpha = cv2.cvtColor(
+            cv2.bitwise_not(alpha_plane), cv2.COLOR_GRAY2BGR
+        )
+
+        # Keep pixels where alpha is non-zero, then add the inverted alpha.
+        kept = cv2.bitwise_and(gray_as_bgr, gray_as_bgr, mask=alpha_plane)
+        return cv2.add(kept, inverted_alpha)
+
+    @staticmethod
+    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
+        """RGBA -> BGR"""
+        # NOTE(review): the planes are unpacked as R, G, B, A regardless of
+        # where the array came from; confirm for arrays passed in directly.
+        red, green, blue, alpha = cv2.split(img)
+        bgr = cv2.merge((blue, green, red))
+
+        inverted_alpha = cv2.cvtColor(cv2.bitwise_not(alpha), cv2.COLOR_GRAY2BGR)
+
+        bgr = cv2.bitwise_and(bgr, bgr, mask=alpha)
+
+        # All-zero colour after masking: add the inverted alpha. Otherwise
+        # the masked image itself is inverted.
+        if np.mean(bgr) <= 0.0:
+            return cv2.add(bgr, inverted_alpha)
+        return cv2.bitwise_not(bgr)
+
+    @staticmethod
+    def verify_exist(file_path: Union[str, Path]):
+        """Raise LoadImageError when ``file_path`` does not exist."""
+        if not Path(file_path).exists():
+            raise LoadImageError(f"{file_path} does not exist.")
+
+
+class LoadImageError(Exception):
+    """Raised by LoadImage when an input cannot be loaded or converted."""
diff --git a/rapid_orientation/utils/logger.py b/rapid_orientation/utils/logger.py
new file mode 100644
index 0000000..66522c4
--- /dev/null
+++ b/rapid_orientation/utils/logger.py
@@ -0,0 +1,21 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+import logging
+from functools import lru_cache
+
+
+@lru_cache(maxsize=32)
+def get_logger(name: str) -> logging.Logger:
+    """Return the logger called ``name`` with a DEBUG-level stream handler attached.
+
+    Results are memoised by ``lru_cache``; while an entry stays cached, asking
+    for the same name again returns the same logger without adding a handler.
+    """
+    stream_handler = logging.StreamHandler()
+    stream_handler.setLevel(logging.DEBUG)
+    stream_handler.setFormatter(
+        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s: %(message)s")
+    )
+
+    named_logger = logging.getLogger(name)
+    named_logger.setLevel(logging.DEBUG)
+    named_logger.addHandler(stream_handler)
+    return named_logger
diff --git a/rapid_orientation/utils/preprocess.py b/rapid_orientation/utils/preprocess.py
new file mode 100644
index 0000000..0631b75
--- /dev/null
+++ b/rapid_orientation/utils/preprocess.py
@@ -0,0 +1,97 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+import cv2
+import numpy as np
+
+
+class Preprocess:
+    """Fixed pipeline: resize the short side to 256, crop 224, normalise, HWC -> CHW."""
+
+    def __init__(self):
+        self.resize_img = ResizeImage(resize_short=256)
+        self.crop_img = CropImage(size=224)
+        self.normal_img = NormalizeImage()
+        self.cvt_channel = ToCHWImage()
+
+    def __call__(self, img: np.ndarray):
+        """Apply the four steps in order and return the final array."""
+        steps = (self.resize_img, self.crop_img, self.normal_img, self.cvt_channel)
+        for step in steps:
+            img = step(img)
+        return img
+
+
+class ResizeImage:
+    """Resize an image to a fixed size, or so that its short side has a given length."""
+
+    def __init__(self, size=None, resize_short=None):
+        """
+        Args:
+            size: Target size, an int (square) or a (width, height) pair.
+                Only used when ``resize_short`` is not a positive number.
+            resize_short: Target length of the shorter side; the aspect
+                ratio is kept.
+
+        Raises:
+            ValueError: Neither a positive ``resize_short`` nor a ``size``
+                was given.
+        """
+        if resize_short is not None and resize_short > 0:
+            self.resize_short = resize_short
+            self.w, self.h = None, None
+        elif size is not None:
+            self.resize_short = None
+            self.w = size if isinstance(size, int) else size[0]
+            self.h = size if isinstance(size, int) else size[1]
+        else:
+            # Adjacent literals instead of a backslash continuation, so the
+            # message no longer embeds source indentation and stray quotes.
+            raise ValueError(
+                "invalid params for ResizeImage: "
+                "pass a positive 'resize_short' or a 'size'"
+            )
+
+    def __call__(self, img: np.ndarray):
+        """Return ``img`` resized with Lanczos interpolation."""
+        img_h, img_w = img.shape[:2]
+
+        w, h = self.w, self.h
+        if self.resize_short:
+            # Scale both sides by the factor that maps the short side to
+            # resize_short.
+            percent = float(self.resize_short) / min(img_w, img_h)
+            w = int(round(img_w * percent))
+            h = int(round(img_h * percent))
+        return cv2.resize(img, (w, h), interpolation=cv2.INTER_LANCZOS4)
+
+
+class CropImage:
+    """Cut a centred window of a fixed size out of an image."""
+
+    def __init__(self, size):
+        """
+        Args:
+            size: Crop size, an int (square) or a (width, height) pair.
+        """
+        self.size = size
+        if isinstance(size, int):
+            self.size = (size, size)
+
+    def __call__(self, img):
+        """Return the centred crop of ``img``.
+
+        The slice keeps the third axis, so ``img`` must have three dimensions.
+
+        Raises:
+            ValueError: The image is smaller than the crop in height or width.
+        """
+        w, h = self.size
+        img_h, img_w = img.shape[:2]
+
+        if img_h < h or img_w < w:
+            # The crop has to fit inside the image; the previous wording
+            # stated the opposite of the condition checked here.
+            raise ValueError(
+                f"The size({h}, {w}) of CropImage must not be greater than "
+                f"size({img_h}, {img_w}) of image."
+            )
+
+        w_start = (img_w - w) // 2
+        h_start = (img_h - h) // 2
+
+        w_end = w_start + w
+        h_end = h_start + h
+        return img[h_start:h_end, w_start:w_end, :]
+
+
+class NormalizeImage:
+    """Multiply by 1/255, then standardise the last axis with a fixed mean and std."""
+
+    def __init__(self):
+        per_channel = (1, 1, 3)  # broadcasts over an (H, W, 3) array
+        self.scale = np.float32(1.0 / 255.0)
+        self.mean = (
+            np.array([0.485, 0.456, 0.406]).reshape(per_channel).astype("float32")
+        )
+        self.std = (
+            np.array([0.229, 0.224, 0.225]).reshape(per_channel).astype("float32")
+        )
+
+    def __call__(self, img):
+        """Return the normalised image as float32."""
+        pixels = np.array(img).astype(np.float32)
+        normalized = (pixels * self.scale - self.mean) / self.std
+        return normalized.astype(np.float32)
+
+
+class ToCHWImage:
+    """Move the last axis of a 3-D image array to the front (HWC -> CHW)."""
+
+    def __call__(self, img):
+        """Return ``img`` with its axes reordered to (2, 0, 1)."""
+        return np.transpose(np.array(img), (2, 0, 1))
diff --git a/rapid_orientation/utils/utils.py b/rapid_orientation/utils/utils.py
new file mode 100644
index 0000000..f55f635
--- /dev/null
+++ b/rapid_orientation/utils/utils.py
@@ -0,0 +1,10 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+import yaml
+
+
+def read_yaml(yaml_path):
+    """Parse the YAML file at ``yaml_path`` and return the loaded object."""
+    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
+    # point this at trusted config files (yaml.SafeLoader is the safe choice).
+    with open(yaml_path, "rb") as yaml_file:
+        return yaml.load(yaml_file, Loader=yaml.Loader)
diff --git a/tests/test_files/table.jpg b/tests/test_files/table.jpg
deleted file mode 100644
index 95fdf84..0000000
Binary files a/tests/test_files/table.jpg and /dev/null differ
diff --git a/tests/test_orientation.py b/tests/test_orientation.py
index bf708df..d8781dc 100644
--- a/tests/test_orientation.py
+++ b/tests/test_orientation.py
@@ -14,7 +14,6 @@
from rapid_orientation import RapidOrientation
-
test_file_dir = cur_dir / "test_files"
text_orientation = RapidOrientation()