Skip to content

Commit

Permalink
feat: add batch inference and vote to generate results
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Nov 8, 2024
1 parent 1fcb4bc commit 0eceb06
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 60 deletions.
34 changes: 3 additions & 31 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,36 +5,8 @@

from rapid_orientation import RapidOrientation


def scale_resize(img, resize_value=(280, 32)):
    """Pad an image with black borders to the target aspect ratio, then resize it.

    @params:
        img: ndarray
        resize_value: (width, height)
    @returns:
        the padded image resized to ``resize_value``
    """
    target_ratio = resize_value[0] / resize_value[1]  # w / h
    h, w = img.shape[:2]
    if w / h < target_ratio:
        # Image is too narrow for the target ratio: pad the width on both sides.
        padded_w = int(h * target_ratio)
        w_padding = (padded_w - w) // 2
        img = cv2.copyMakeBorder(
            img, 0, 0, w_padding, w_padding, cv2.BORDER_CONSTANT, value=(0, 0, 0)
        )
    else:
        # Image is too wide: pad the height. Border order is (top, bottom, left, right).
        # The former `color = img[0, 0, :]` lookup was never used and raised
        # IndexError on 2-D (grayscale) arrays, so it has been dropped.
        padded_h = int(w / target_ratio)
        h_padding = (padded_h - h) // 2
        img = cv2.copyMakeBorder(
            img, h_padding, h_padding, 0, 0, cv2.BORDER_CONSTANT, value=(0, 0, 0)
        )
    img = cv2.resize(img, resize_value, interpolation=cv2.INTER_LANCZOS4)
    return img


# Demo script: build the engine, run it on a test image and print what it returns.
# NOTE(review): this span comes from a commit diff rendered without +/- markers; the two
# imread/inference pairs below look like the pre- and post-commit versions — confirm
# against the repository which pair is live.
orientation_engine = RapidOrientation()
img = cv2.imread("tests/test_files/1.png")
cls_result, _ = orientation_engine(img)
img = cv2.imread("tests/test_files/img_rot0_demo.jpg")
# The engine call returns a pair; here the second element is kept as `elapse` and printed.
cls_result, elapse = orientation_engine(img)
print(cls_result)
print(elapse)
23 changes: 4 additions & 19 deletions rapid_orientation/main.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,3 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
Expand Down Expand Up @@ -44,7 +31,7 @@ def __init__(
self.session = OrtInferSession(config)
self.labels = self.session.get_character_list()

self.preprocess = Preprocess()
self.preprocess = Preprocess(batch_size=3)
self.load_img = LoadImage()

def __call__(self, img_content: Union[str, np.ndarray, bytes, Path]):
Expand All @@ -53,13 +40,11 @@ def __call__(self, img_content: Union[str, np.ndarray, bytes, Path]):
s = time.perf_counter()

image = self.preprocess(image)
image = image[None, ...]

pred_output = self.session(image)[0]

pred_output = pred_output.squeeze()
pred_idx = np.argmax(pred_output)
pred_txt = self.labels[pred_idx]
pred_idxs = list(np.argmax(pred_output, axis=1))
final_idx = max(set(pred_idxs), key=pred_idxs.count)
pred_txt = self.labels[final_idx]

elapse = time.perf_counter() - s
return pred_txt, elapse
Expand Down
62 changes: 52 additions & 10 deletions rapid_orientation/utils/preprocess.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,35 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
import copy
import random

import cv2
import numpy as np


# Preprocess chains the resize / crop / normalize / channel-conversion helpers built in
# __init__ and applies them to an input image in __call__.
# NOTE(review): this span comes from a commit diff rendered without +/- markers. The two
# __init__ signatures and the two __call__ bodies appear to be the pre- and post-commit
# versions interleaved — confirm against the repository before relying on either.
class Preprocess:
def __init__(self):
def __init__(self, batch_size: int = 3):
self.resize_img = ResizeImage(resize_short=256)
self.crop_img = CropImage(size=224)
# NOTE(review): rand_crop is constructed here but is not called anywhere in the lines
# shown below; the batch loop uses crop_img instead — confirm this is intended.
self.rand_crop = RandCropImageV2(size=224)
self.normal_img = NormalizeImage()
self.cvt_channel = ToCHWImage()

# Number of samples __call__ stacks into one batch.
self.batch_size = batch_size

def __call__(self, img: np.ndarray):
# Single-image pipeline: resize -> crop -> normalize -> channel conversion.
img = self.resize_img(img)
img = self.crop_img(img)
img = self.normal_img(img)
img = self.cvt_channel(img)
return img
# Batch pipeline: resize once, then build batch_size samples from the resized image.
ori_img = self.resize_img(img)

norm_img_batch = []
for _ in range(self.batch_size):
# Each sample starts from its own deep copy of the resized image.
img = self.crop_img(copy.deepcopy(ori_img))
img = self.normal_img(img)
img = self.cvt_channel(img)
# Add a leading axis of length 1 so the samples can be concatenated into a batch.
img = img[None, ...]
norm_img_batch.append(img)
# Join the samples along the first axis and cast the result to float32.
norm_img_batch = np.concatenate(norm_img_batch).astype(np.float32)
return norm_img_batch


class ResizeImage:
Expand All @@ -43,7 +55,8 @@ def __call__(self, img: np.ndarray):
percent = float(self.resize_short) / min(img_w, img_h)
w = int(round(img_w * percent))
h = int(round(img_h * percent))
return cv2.resize(img, (w, h), interpolation=cv2.INTER_LANCZOS4)

return cv2.resize(img, dsize=(w, h), interpolation=cv2.INTER_LANCZOS4)


class CropImage:
Expand All @@ -70,10 +83,39 @@ def __call__(self, img):
return img[h_start:h_end, w_start:w_end, :]


class RandCropImageV2:
    """Cut a fixed-size window at a uniformly random position, without resizing.

    Unlike RandCropImage, the crop is taken at exactly the requested size and
    is not rescaled afterwards.
    Modified from https://github.com/PaddlePaddle/PaddleClas/blob/177e4be74639c0960efeae2c5166d3226c9a02eb/ppcls/data/preprocess/ops/operators.py#L448C1-L479C62

    Args:
        size: Side length of a square crop, or a 2-sequence that ``__call__``
            unpacks as ``(width, height)``.
    """

    def __init__(self, size):
        self.size = size
        if isinstance(size, int):
            # Square crop; note that __call__ unpacks the pair as (width, height).
            self.size = (size, size)

    def __call__(self, img: np.ndarray):
        """Return a randomly positioned crop of ``img``.

        Args:
            img: Array indexed as ``[row, column, channel]``.

        Returns:
            ``img`` itself when it already has the requested size; otherwise a
            slice of ``img`` with exactly ``th`` rows and ``tw`` columns.

        Raises:
            ValueError: If the image is smaller than the crop in either dimension.
        """
        img_h, img_w = img.shape[0], img.shape[1]

        tw, th = self.size
        # An image even one pixel smaller than the crop cannot yield a full-size
        # window, so reject it instead of returning a silently truncated slice.
        if img_h < th or img_w < tw:
            raise ValueError(
                f"Required crop size {(th, tw)} is larger then input image size {(img_h, img_w)}"
            )

        if img_w == tw and img_h == th:
            return img

        # random.randint includes BOTH endpoints, so the largest valid offset is
        # (image extent - crop extent); a "+ 1" here would overrun the image and
        # produce a crop that is one row/column short.
        top = random.randint(0, img_h - th)
        left = random.randint(0, img_w - tw)
        return img[top : top + th, left : left + tw, :]


class NormalizeImage:
def __init__(
self,
):
def __init__(self):
self.scale = np.float32(1.0 / 255.0)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
Expand Down

0 comments on commit 0eceb06

Please sign in to comment.