From 0eceb064961c6a9e979de488f19f7e7a37cc7480 Mon Sep 17 00:00:00 2001 From: SWHL Date: Fri, 8 Nov 2024 15:25:58 +0800 Subject: [PATCH] feat: add batch inference and vote to generate results --- demo.py | 34 ++------------- rapid_orientation/main.py | 23 ++-------- rapid_orientation/utils/preprocess.py | 62 ++++++++++++++++++++++----- 3 files changed, 59 insertions(+), 60 deletions(-) diff --git a/demo.py b/demo.py index 75af9d5..530eb5d 100644 --- a/demo.py +++ b/demo.py @@ -5,36 +5,8 @@ from rapid_orientation import RapidOrientation - -def scale_resize(img, resize_value=(280, 32)): - """ - @params: - img: ndarray - resize_value: (width, height) - """ - # padding - ratio = resize_value[0] / resize_value[1] # w / h - h, w = img.shape[:2] - if w / h < ratio: - # 补宽 - t = int(h * ratio) - w_padding = (t - w) // 2 - img = cv2.copyMakeBorder( - img, 0, 0, w_padding, w_padding, cv2.BORDER_CONSTANT, value=(0, 0, 0) - ) - else: - # 补高 (top, bottom, left, right) - t = int(w / ratio) - h_padding = (t - h) // 2 - color = tuple([int(i) for i in img[0, 0, :]]) - img = cv2.copyMakeBorder( - img, h_padding, h_padding, 0, 0, cv2.BORDER_CONSTANT, value=(0, 0, 0) - ) - img = cv2.resize(img, resize_value, interpolation=cv2.INTER_LANCZOS4) - return img - - orientation_engine = RapidOrientation() -img = cv2.imread("tests/test_files/1.png") -cls_result, _ = orientation_engine(img) +img = cv2.imread("tests/test_files/img_rot0_demo.jpg") +cls_result, elapse = orientation_engine(img) print(cls_result) +print(elapse) diff --git a/rapid_orientation/main.py b/rapid_orientation/main.py index aa90fd5..56aaa9c 100644 --- a/rapid_orientation/main.py +++ b/rapid_orientation/main.py @@ -1,16 +1,3 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com @@ -44,7 +31,7 @@ def __init__( self.session = OrtInferSession(config) self.labels = self.session.get_character_list() - self.preprocess = Preprocess() + self.preprocess = Preprocess(batch_size=3) self.load_img = LoadImage() def __call__(self, img_content: Union[str, np.ndarray, bytes, Path]): @@ -53,13 +40,11 @@ def __call__(self, img_content: Union[str, np.ndarray, bytes, Path]): s = time.perf_counter() image = self.preprocess(image) - image = image[None, ...] pred_output = self.session(image)[0] - - pred_output = pred_output.squeeze() - pred_idx = np.argmax(pred_output) - pred_txt = self.labels[pred_idx] + pred_idxs = list(np.argmax(pred_output, axis=1)) + final_idx = max(set(pred_idxs), key=pred_idxs.count) + pred_txt = self.labels[final_idx] elapse = time.perf_counter() - s return pred_txt, elapse diff --git a/rapid_orientation/utils/preprocess.py b/rapid_orientation/utils/preprocess.py index 0631b75..3dccff6 100644 --- a/rapid_orientation/utils/preprocess.py +++ b/rapid_orientation/utils/preprocess.py @@ -1,23 +1,35 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com +import copy +import random + import cv2 import numpy as np class Preprocess: - def __init__(self): + def __init__(self, batch_size: int = 3): self.resize_img = ResizeImage(resize_short=256) self.crop_img = CropImage(size=224) + self.rand_crop = RandCropImageV2(size=224) self.normal_img = NormalizeImage() self.cvt_channel = ToCHWImage() + self.batch_size = batch_size + def __call__(self, img: np.ndarray): - img = self.resize_img(img) - img = self.crop_img(img) - img = self.normal_img(img) - img = self.cvt_channel(img) - return img + ori_img = self.resize_img(img) + + norm_img_batch = [] + for _ in range(self.batch_size): + img = self.crop_img(copy.deepcopy(ori_img)) + img = self.normal_img(img) + img = self.cvt_channel(img) + img = img[None, ...] + norm_img_batch.append(img) + norm_img_batch = np.concatenate(norm_img_batch).astype(np.float32) + return norm_img_batch class ResizeImage: @@ -43,7 +55,8 @@ def __call__(self, img: np.ndarray): percent = float(self.resize_short) / min(img_w, img_h) w = int(round(img_w * percent)) h = int(round(img_h * percent)) - return cv2.resize(img, (w, h), interpolation=cv2.INTER_LANCZOS4) + + return cv2.resize(img, dsize=(w, h), interpolation=cv2.INTER_LANCZOS4) class CropImage: @@ -70,10 +83,39 @@ def __call__(self, img): return img[h_start:h_end, w_start:w_end, :] +class RandCropImageV2: + """RandCropImageV2 is different from RandCropImage, + it will Select a cutting position randomly in a uniform distribution way, + and cut according to the given size without resize at last. + + Modified from https://github.com/PaddlePaddle/PaddleClas/blob/177e4be74639c0960efeae2c5166d3226c9a02eb/ppcls/data/preprocess/ops/operators.py#L448C1-L479C62 + + """ + + def __init__(self, size): + self.size = size + if isinstance(size, int): + self.size = (size, size) # (h, w) + + def __call__(self, img: np.ndarray): + img_h, img_w = img.shape[0], img.shape[1] + + tw, th = self.size + if img_h + 1 < th or img_w + 1 < tw: + raise ValueError( + f"Required crop size {(th, tw)} is larger then input image size {(img_h, img_w)}" + ) + + if img_w == tw and img_h == th: + return img + + top = random.randint(0, img_h - th + 1) + left = random.randint(0, img_w - tw + 1) + return img[top : top + th, left : left + tw, :] + + class NormalizeImage: - def __init__( - self, - ): + def __init__(self): self.scale = np.float32(1.0 / 255.0) mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225]