diff --git a/README.md b/README.md
index 042dd94..1b6c44f 100644
--- a/README.md
+++ b/README.md
@@ -22,10 +22,14 @@
| `pp_layout_table` | Table | `layout_table.onnx` | `table` |
| `pp_layout_publaynet` | English | `layout_publaynet.onnx` | `text title list table figure` |
| `pp_layout_cdla` | Chinese | `layout_cdla.onnx` | `text title figure figure_caption table table_caption header footer reference equation` |
+| `yolov8n_layout_paper` | Paper | `yolov8n_layout_paper.onnx` | `text title figure figure_caption table table_caption header footer reference equation` |
+| `yolov8n_layout_report` | Research report | `yolov8n_layout_report.onnx` | `text title header footer figure figure_caption table table_caption toc` |
-Model source: [PaddleOCR Layout Analysis](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/layout/README_ch.md)
+PP model source: [PaddleOCR Layout Analysis](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/layout/README_ch.md)
-Model downloads: [Baidu Pan](https://pan.baidu.com/s/1PI9fksW6F6kQfJhwUkewWg?pwd=p29g) | [Google Drive](https://drive.google.com/drive/folders/1DAPWSN2zGQ-ED_Pz7RaJGTjfkN2-Mvsf?usp=sharing)
+yolov8n series source: [360LayoutAnalysis](https://github.com/360AILAB-NLP/360LayoutAnalysis)
+
+Model downloads: [link](https://github.com/RapidAI/RapidLayout/releases/tag/v0.0.0)
### Installation
Since the models are small, the Chinese layout analysis model (`layout_cdla.onnx`) is bundled into the whl package. For Chinese layout analysis, you can install and use it directly.
@@ -41,7 +45,7 @@ import cv2
from rapid_layout import RapidLayout, VisLayout
# See the table above for supported model_type values. When a different model_type is specified, the corresponding model is downloaded automatically to the installation directory.
-layout_engine = RapidLayout(box_threshold=0.5, model_type="pp_layout_cdla")
+layout_engine = RapidLayout(conf_thres=0.5, model_type="pp_layout_cdla")
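+# A local ONNX file can also be passed via model_path (the path below is hypothetical):
+# layout_engine = RapidLayout(model_type="yolov8n_layout_paper", model_path="models/yolov8n_layout_paper.onnx")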
img = cv2.imread('test_images/layout.png')
@@ -55,18 +59,23 @@ if ploted_img is not None:
- Usage:
```bash
$ rapid_layout -h
- usage: rapid_layout [-h] -img IMG_PATH [-m {pp_layout_cdla,pp_layout_publaynet,pp_layout_table}]
- [--box_threshold {pp_layout_cdla,pp_layout_publaynet,pp_layout_table}] [-v]
+ usage: rapid_layout [-h] -img IMG_PATH
+ [-m {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report}]
+                        [--conf_thres CONF_THRES]
+                        [--iou_thres IOU_THRES]
+ [-v]
options:
- -h, --help show this help message and exit
- -img IMG_PATH, --img_path IMG_PATH
+ -h, --help show this help message and exit
+ -img IMG_PATH, --img_path IMG_PATH
Path to image for layout.
- -m {pp_layout_cdla,pp_layout_publaynet,pp_layout_table}, --model_type {pp_layout_cdla,pp_layout_publaynet,pp_layout_table}
+ -m {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report}, --model_type {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report}
                        Supported model type
-    --box_threshold {pp_layout_cdla,pp_layout_publaynet,pp_layout_table}
-                        Box threshold, the range is [0, 1]
+    --conf_thres CONF_THRES
+                        Confidence threshold, the range is [0, 1]
- -v, --vis Wheter to visualize the layout results.
+    --iou_thres IOU_THRES
+                        IoU threshold, the range is [0, 1]
+    -v, --vis           Whether to visualize the layout results.
```
- Example:
```bash
diff --git a/demo.py b/demo.py
index b73f279..d5e07f3 100644
--- a/demo.py
+++ b/demo.py
@@ -5,7 +5,7 @@
from rapid_layout import RapidLayout, VisLayout
-layout_engine = RapidLayout(box_threshold=0.5, model_type="pp_layout_cdla")
+layout_engine = RapidLayout(model_type="yolov8n_layout_paper")
img_path = "tests/test_files/layout.png"
img = cv2.imread(img_path)
diff --git a/rapid_layout/config.yaml b/rapid_layout/config.yaml
deleted file mode 100644
index 33a85d2..0000000
--- a/rapid_layout/config.yaml
+++ /dev/null
@@ -1,24 +0,0 @@
-model_path: models/layout_cdla.onnx
-
-use_cuda: false
-CUDAExecutionProvider:
- device_id: 0
- arena_extend_strategy: kNextPowerOfTwo
- cudnn_conv_algo_search: EXHAUSTIVE
- do_copy_in_default_stream: true
-
-pre_process:
- Resize:
- size: [800, 608]
- NormalizeImage:
- std: [0.229, 0.224, 0.225]
- mean: [0.485, 0.456, 0.406]
- scale: 1./255.
- order: hwc
- ToCHWImage:
- KeepKeys:
- keep_keys: ['image']
-
-post_process:
- score_threshold: 0.5
- nms_threshold: 0.5
\ No newline at end of file
diff --git a/rapid_layout/main.py b/rapid_layout/main.py
index b392c92..14cb89f 100644
--- a/rapid_layout/main.py
+++ b/rapid_layout/main.py
@@ -14,11 +14,11 @@
LoadImage,
OrtInferSession,
PicoDetPostProcess,
+ PPPreProcess,
VisLayout,
- create_operators,
+ YOLOv8PostProcess,
+ YOLOv8PreProcess,
get_logger,
- read_yaml,
- transform,
)
ROOT_DIR = Path(__file__).resolve().parent
@@ -29,64 +29,86 @@
"pp_layout_cdla": f"{ROOT_URL}/layout_cdla.onnx",
"pp_layout_publaynet": f"{ROOT_URL}/layout_publaynet.onnx",
"pp_layout_table": f"{ROOT_URL}/layout_table.onnx",
+ "yolov8n_layout_paper": f"{ROOT_URL}/yolov8n_layout_paper.onnx",
+ "yolov8n_layout_report": f"{ROOT_URL}/yolov8n_layout_report.onnx",
}
DEFAULT_MODEL_PATH = str(ROOT_DIR / "models" / "layout_cdla.onnx")
class RapidLayout:
+
def __init__(
self,
model_type: str = "pp_layout_cdla",
- box_threshold: float = 0.5,
+ model_path: Union[str, Path, None] = None,
+ conf_thres: float = 0.5,
+ iou_thres: float = 0.5,
use_cuda: bool = False,
):
- config_path = str(ROOT_DIR / "config.yaml")
- config = read_yaml(config_path)
- config["model_path"] = self.get_model_path(model_type)
- config["use_cuda"] = use_cuda
-
+ self.model_type = model_type
+ config = {
+ "model_path": self.get_model_path(model_type, model_path),
+ "use_cuda": use_cuda,
+ }
self.session = OrtInferSession(config)
labels = self.session.get_character_list()
logger.info("%s contains %s", model_type, labels)
- self.preprocess_op = create_operators(config["pre_process"])
+ # pp
+ self.pp_preprocess = PPPreProcess(img_size=(800, 608))
+ self.pp_postprocess = PicoDetPostProcess(labels, conf_thres, iou_thres)
+
+ # yolov8
+ self.yolov8_input_shape = (640, 640)
+ self.yolo_preprocess = YOLOv8PreProcess(img_size=self.yolov8_input_shape)
+ self.yolo_postprocess = YOLOv8PostProcess(labels, conf_thres, iou_thres)
- config["post_process"]["score_threshold"] = box_threshold
- self.postprocess_op = PicoDetPostProcess(labels, **config["post_process"])
self.load_img = LoadImage()
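+        # Model types routed to the PicoDet (PP-Structure) pipeline vs. the
+        # YOLOv8 pipeline; __call__ dispatches on self.model_type accordingly.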
+ self.pp_layout_type = [
+ "pp_layout_cdla",
+ "pp_layout_publaynet",
+ "pp_layout_table",
+ ]
+ self.yolov8_layout_type = ["yolov8n_layout_paper", "yolov8n_layout_report"]
+
def __call__(
self, img_content: Union[str, np.ndarray, bytes, Path]
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], float]:
img = self.load_img(img_content)
+ ori_img_shape = img.shape[:2]
- ori_im = img.copy()
- data = transform({"image": img}, self.preprocess_op)
- img = data[0]
- if img is None:
- return None, None, None, 0.0
+ if self.model_type in self.pp_layout_type:
+ return self.pp_layout(img, ori_img_shape)
- img = np.expand_dims(img, axis=0)
- img = img.copy()
+ if self.model_type in self.yolov8_layout_type:
+ return self.yolov8_layout(img, ori_img_shape)
- preds, elapse = 0, 1
- starttime = time.time()
+ raise ValueError(f"{self.model_type} is not supported.")
+
+ def pp_layout(self, img: np.ndarray, ori_img_shape: Tuple[int, int]):
+ s_time = time.time()
+
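+        # PP pipeline: resize/normalize to the PicoDet input, run the ONNX
+        # session, then map detections back to the original image shape.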
+ img = self.pp_preprocess(img)
preds = self.session(img)
+ boxes, scores, class_names = self.pp_postprocess(ori_img_shape, img, preds)
- score_list, boxes_list = [], []
- num_outs = int(len(preds) / 2)
- for out_idx in range(num_outs):
- score_list.append(preds[out_idx])
- boxes_list.append(preds[out_idx + num_outs])
+ elapse = time.time() - s_time
+ return boxes, scores, class_names, elapse
- boxes, scores, class_names = self.postprocess_op(
- ori_im, img, {"boxes": score_list, "boxes_num": boxes_list}
+    def yolov8_layout(self, img: np.ndarray, ori_img_shape: Tuple[int, int]):
+        s_time = time.time()
+        input_tensor = self.yolo_preprocess(img)
+        outputs = self.session(input_tensor)
+        boxes, scores, class_names = self.yolo_postprocess(
+            outputs, ori_img_shape, self.yolov8_input_shape
         )
-        elapse = time.time() - starttime
-        return boxes, scores, class_names, elapse
+        elapse = time.time() - s_time
+        return boxes, scores, class_names, elapse
@staticmethod
- def get_model_path(model_type: str) -> str:
+ def get_model_path(model_type: str, model_path: Union[str, Path, None]) -> str:
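+        # An explicit model_path overrides the model_type -> URL mapping.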
+ if model_path is not None:
+            return str(model_path)
+
model_url = KEY_TO_MODEL_URL.get(model_type, None)
if model_url:
model_path = DownloadModel.download(model_url)
@@ -110,12 +132,19 @@ def main():
help="Support model type",
)
parser.add_argument(
- "--box_threshold",
+ "--conf_thres",
type=float,
default=0.5,
-        choices=list(KEY_TO_MODEL_URL.keys()),
-        help="Box threshold, the range is [0, 1]",
+        help="Confidence threshold, the range is [0, 1]",
)
+ parser.add_argument(
+ "--iou_thres",
+ type=float,
+ default=0.5,
+ help="IoU threshold, the range is [0, 1]",
+ )
parser.add_argument(
"-v",
"--vis",
@@ -125,7 +154,7 @@ def main():
args = parser.parse_args()
layout_engine = RapidLayout(
- model_type=args.model_type, box_threshold=args.box_threshold
+ model_type=args.model_type, conf_thres=args.conf_thres, iou_thres=args.iou_thres
)
img = cv2.imread(args.img_path)
diff --git a/rapid_layout/utils/__init__.py b/rapid_layout/utils/__init__.py
index beb7dac..0cadd2a 100644
--- a/rapid_layout/utils/__init__.py
+++ b/rapid_layout/utils/__init__.py
@@ -7,8 +7,8 @@
from .infer_engine import OrtInferSession
from .load_image import LoadImage
from .logger import get_logger
-from .post_prepross import PicoDetPostProcess
-from .pre_procss import create_operators, transform
+from .post_prepross import PicoDetPostProcess, YOLOv8PostProcess
+from .pre_procss import PPPreProcess, YOLOv8PreProcess
from .vis_res import VisLayout
diff --git a/rapid_layout/utils/post_prepross.py b/rapid_layout/utils/post_prepross.py
index f134cf8..eacd006 100644
--- a/rapid_layout/utils/post_prepross.py
+++ b/rapid_layout/utils/post_prepross.py
@@ -1,39 +1,36 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
+from typing import List, Tuple
+
import numpy as np
class PicoDetPostProcess:
- def __init__(
- self,
- labels,
- strides=[8, 16, 32, 64],
- score_threshold=0.4,
- nms_threshold=0.5,
- nms_top_k=1000,
- keep_top_k=100,
- ):
+ def __init__(self, labels, conf_thres=0.4, iou_thres=0.5):
self.labels = labels
- self.strides = strides
- self.score_threshold = score_threshold
- self.nms_threshold = nms_threshold
- self.nms_top_k = nms_top_k
- self.keep_top_k = keep_top_k
-
- def __call__(self, ori_img, img, preds):
- scores, raw_boxes = preds["boxes"], preds["boxes_num"]
+ self.strides = [8, 16, 32, 64]
+ self.conf_thres = conf_thres
+ self.iou_thres = iou_thres
+ self.nms_top_k = 1000
+ self.keep_top_k = 100
+
+ def __call__(self, ori_shape, img, preds):
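+        # preds is a flat list of per-stride outputs: the first half holds the
+        # class-score maps, the second half the box-distribution maps.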
+ scores, raw_boxes = [], []
+ num_outs = int(len(preds) / 2)
+ for out_idx in range(num_outs):
+ scores.append(preds[out_idx])
+ raw_boxes.append(preds[out_idx + num_outs])
+
batch_size = raw_boxes[0].shape[0]
reg_max = int(raw_boxes[0].shape[-1] / 4 - 1)
- out_boxes_num = []
- out_boxes_list = []
- ori_shape, input_shape, scale_factor = self.img_info(ori_img, img)
+ out_boxes_num, out_boxes_list = [], []
+ ori_shape, input_shape, scale_factor = self.img_info(ori_shape, img)
for batch_id in range(batch_size):
# generate centers
- decode_boxes = []
- select_scores = []
+ decode_boxes, select_scores = [], []
for stride, box_distribute, score in zip(self.strides, raw_boxes, scores):
box_distribute = box_distribute[batch_id]
score = score[batch_id]
@@ -71,19 +68,19 @@ def __call__(self, ori_img, img, preds):
# nms
bboxes = np.concatenate(decode_boxes, axis=0)
confidences = np.concatenate(select_scores, axis=0)
- picked_box_probs = []
- picked_labels = []
+ picked_box_probs, picked_labels = [], []
for class_index in range(0, confidences.shape[1]):
probs = confidences[:, class_index]
- mask = probs > self.score_threshold
+ mask = probs > self.conf_thres
probs = probs[mask]
if probs.shape[0] == 0:
continue
+
subset_boxes = bboxes[mask, :]
box_probs = np.concatenate([subset_boxes, probs.reshape(-1, 1)], axis=1)
box_probs = self.hard_nms(
box_probs,
- iou_threshold=self.nms_threshold,
+ iou_thres=self.iou_thres,
top_k=self.keep_top_k,
)
picked_box_probs.append(box_probs)
@@ -92,7 +89,6 @@ def __call__(self, ori_img, img, preds):
if len(picked_box_probs) == 0:
out_boxes_list.append(np.empty((0, 4)))
out_boxes_num.append(0)
-
else:
picked_box_probs = np.concatenate(picked_box_probs)
@@ -129,11 +125,6 @@ def __call__(self, ori_img, img, preds):
class_names.append(label)
return np.array(boxes), np.array(scores), np.array(class_names)
- def load_layout_dict(self, layout_dict_path):
- with open(layout_dict_path, "r", encoding="utf-8") as fp:
- labels = fp.readlines()
- return [label.strip("\n") for label in labels]
-
def warp_boxes(self, boxes, ori_shape):
"""Apply transform to boxes"""
width, height = ori_shape[1], ori_shape[0]
@@ -158,8 +149,7 @@ def warp_boxes(self, boxes, ori_shape):
return xy.astype(np.float32)
return boxes
- def img_info(self, ori_img, img):
- origin_shape = ori_img.shape
+ def img_info(self, origin_shape, img):
resize_shape = img.shape
im_scale_y = resize_shape[2] / float(origin_shape[0])
im_scale_x = resize_shape[3] / float(origin_shape[1])
@@ -195,11 +185,11 @@ def logsumexp(a, axis=None, b=None, keepdims=False):
return np.exp(x - logsumexp(x, axis=axis, keepdims=True))
- def hard_nms(self, box_scores, iou_threshold, top_k=-1, candidate_size=200):
+ def hard_nms(self, box_scores, iou_thres, top_k=-1, candidate_size=200):
"""
Args:
box_scores (N, 5): boxes in corner-form and probabilities.
- iou_threshold: intersection over union threshold.
+ iou_thres: intersection over union threshold.
top_k: keep top_k results. If k <= 0, keep all the results.
candidate_size: only consider the candidates with the highest scores.
Returns:
@@ -222,7 +212,7 @@ def hard_nms(self, box_scores, iou_threshold, top_k=-1, candidate_size=200):
rest_boxes,
np.expand_dims(current_box, axis=0),
)
- indexes = indexes[iou <= iou_threshold]
+ indexes = indexes[iou <= iou_thres]
return box_scores[picked, :]
@@ -254,3 +244,135 @@ def area_of(left_top, right_bottom):
"""
hw = np.clip(right_bottom - left_top, 0.0, None)
return hw[..., 0] * hw[..., 1]
+
+
+class YOLOv8PostProcess:
+
+ def __init__(self, labels: List[str], conf_thres=0.7, iou_thres=0.5):
+ self.labels = labels
+ self.conf_threshold = conf_thres
+ self.iou_threshold = iou_thres
+ self.input_width, self.input_height = None, None
+ self.img_width, self.img_height = None, None
+
+ def __call__(
+ self, output, ori_img_shape: Tuple[int, int], img_shape: Tuple[int, int]
+ ):
+ self.img_height, self.img_width = ori_img_shape
+ self.input_height, self.input_width = img_shape
+
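+        # Raw YOLOv8 ONNX output is (1, 4 + num_classes, num_anchors); squeeze
+        # the batch dim and transpose so each row is [cx, cy, w, h, scores...].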
+ predictions = np.squeeze(output[0]).T
+
+ # Filter out object confidence scores below threshold
+ scores = np.max(predictions[:, 4:], axis=1)
+ predictions = predictions[scores > self.conf_threshold, :]
+ scores = scores[scores > self.conf_threshold]
+
+ if len(scores) == 0:
+ return [], [], []
+
+ # Get the class with the highest confidence
+ class_ids = np.argmax(predictions[:, 4:], axis=1)
+
+ # Get bounding boxes for each object
+ boxes = self.extract_boxes(predictions)
+
+ # Apply non-maxima suppression to suppress weak, overlapping bounding boxes
+ # indices = nms(boxes, scores, self.iou_threshold)
+ indices = multiclass_nms(boxes, scores, class_ids, self.iou_threshold)
+
+ labels = [self.labels[i] for i in class_ids[indices]]
+ return boxes[indices], scores[indices], labels
+
+ def extract_boxes(self, predictions):
+ # Extract boxes from predictions
+ boxes = predictions[:, :4]
+
+ # Scale boxes to original image dimensions
+ boxes = self.rescale_boxes(boxes)
+
+ # Convert boxes to xyxy format
+ boxes = xywh2xyxy(boxes)
+
+ return boxes
+
+ def rescale_boxes(self, boxes):
+ # Rescale boxes to original image dimensions
+ input_shape = np.array(
+ [self.input_width, self.input_height, self.input_width, self.input_height]
+ )
+ boxes = np.divide(boxes, input_shape, dtype=np.float32)
+ boxes *= np.array(
+ [self.img_width, self.img_height, self.img_width, self.img_height]
+ )
+ return boxes
+
+
+def nms(boxes, scores, iou_threshold):
+ # Sort by score
+ sorted_indices = np.argsort(scores)[::-1]
+
+ keep_boxes = []
+ while sorted_indices.size > 0:
+ # Pick the last box
+ box_id = sorted_indices[0]
+ keep_boxes.append(box_id)
+
+ # Compute IoU of the picked box with the rest
+ ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])
+
+ # Remove boxes with IoU over the threshold
+ keep_indices = np.where(ious < iou_threshold)[0]
+
+ # print(keep_indices.shape, sorted_indices.shape)
+ sorted_indices = sorted_indices[keep_indices + 1]
+
+ return keep_boxes
+
+
+def multiclass_nms(boxes, scores, class_ids, iou_threshold):
+ unique_class_ids = np.unique(class_ids)
+
+ keep_boxes = []
+ for class_id in unique_class_ids:
+ class_indices = np.where(class_ids == class_id)[0]
+ class_boxes = boxes[class_indices, :]
+ class_scores = scores[class_indices]
+
+ class_keep_boxes = nms(class_boxes, class_scores, iou_threshold)
+ keep_boxes.extend(class_indices[class_keep_boxes])
+
+ return keep_boxes
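+
+
+# A minimal sketch of multiclass_nms on toy data (illustrative only):
+#   boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]], float)
+#   scores = np.array([0.9, 0.8, 0.7])
+#   class_ids = np.array([0, 0, 1])
+#   multiclass_nms(boxes, scores, class_ids, iou_threshold=0.5) -> [0, 2]
+# Box 1 overlaps box 0 (IoU ~ 0.68) and is suppressed within class 0; the
+# class-1 box survives because NMS runs independently per class.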
+
+
+def compute_iou(box, boxes):
+ # Compute xmin, ymin, xmax, ymax for both boxes
+ xmin = np.maximum(box[0], boxes[:, 0])
+ ymin = np.maximum(box[1], boxes[:, 1])
+ xmax = np.minimum(box[2], boxes[:, 2])
+ ymax = np.minimum(box[3], boxes[:, 3])
+
+ # Compute intersection area
+ intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)
+
+ # Compute union area
+ box_area = (box[2] - box[0]) * (box[3] - box[1])
+ boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+ union_area = box_area + boxes_area - intersection_area
+
+ # Compute IoU
+ iou = intersection_area / union_area
+
+ return iou
+
+
+def xywh2xyxy(x):
+ # Convert bounding box (x, y, w, h) to bounding box (x1, y1, x2, y2)
+ y = np.copy(x)
+ y[..., 0] = x[..., 0] - x[..., 2] / 2
+ y[..., 1] = x[..., 1] - x[..., 3] / 2
+ y[..., 2] = x[..., 0] + x[..., 2] / 2
+ y[..., 3] = x[..., 1] + x[..., 3] / 2
+ return y
diff --git a/rapid_layout/utils/pre_procss.py b/rapid_layout/utils/pre_procss.py
index f5e3e21..78ce748 100644
--- a/rapid_layout/utils/pre_procss.py
+++ b/rapid_layout/utils/pre_procss.py
@@ -2,7 +2,7 @@
# @Author: SWHL
# @Contact: liekkaskono@163.com
from pathlib import Path
-from typing import Union
+from typing import Optional, Tuple, Union
import cv2
import numpy as np
@@ -10,94 +10,44 @@
InputType = Union[str, np.ndarray, bytes, Path]
-def transform(data, ops=None):
- """transform"""
- if ops is None:
- ops = []
+class PPPreProcess:
- for op in ops:
- data = op(data)
- if data is None:
- return None
- return data
+ def __init__(self, img_size: Tuple[int, int]):
+ self.size = img_size
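+        # ImageNet channel mean/std, applied after scaling pixel values to [0, 1]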
+ self.mean = np.array([0.485, 0.456, 0.406])
+ self.std = np.array([0.229, 0.224, 0.225])
+ self.scale = 1 / 255.0
+ def __call__(self, img: Optional[np.ndarray] = None) -> np.ndarray:
+ if img is None:
+ raise ValueError("img is None.")
-def create_operators(op_param_dict):
- ops = []
- for op_name, param in op_param_dict.items():
- if param is None:
- param = {}
- op = eval(op_name)(**param)
- ops.append(op)
- return ops
+ img = self.resize(img)
+ img = self.normalize(img)
+ img = self.permute(img)
+ img = np.expand_dims(img, axis=0)
+ return img.astype(np.float32)
-
-class Resize:
- def __init__(self, size=(640, 640)):
- self.size = size
-
- def resize_image(self, img):
+ def resize(self, img: np.ndarray) -> np.ndarray:
resize_h, resize_w = self.size
- ori_h, ori_w = img.shape[:2] # (h, w, c)
- ratio_h = float(resize_h) / ori_h
- ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
- return img, [ratio_h, ratio_w]
-
- def __call__(self, data):
- img = data["image"]
- if "polys" in data:
- text_polys = data["polys"]
-
- img_resize, [ratio_h, ratio_w] = self.resize_image(img)
- if "polys" in data:
- new_boxes = []
- for box in text_polys:
- new_box = []
- for cord in box:
- new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
- new_boxes.append(new_box)
- data["polys"] = np.array(new_boxes, dtype=np.float32)
- data["image"] = img_resize
- return data
-
-
-class NormalizeImage:
- def __init__(self, scale=None, mean=None, std=None, order="chw"):
- if isinstance(scale, str):
- scale = eval(scale)
-
- self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
- mean = mean if mean is not None else [0.485, 0.456, 0.406]
- std = std if std is not None else [0.229, 0.224, 0.225]
-
- shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
- self.mean = np.array(mean).reshape(shape).astype("float32")
- self.std = np.array(std).reshape(shape).astype("float32")
-
- def __call__(self, data):
- img = np.array(data["image"])
- assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
- data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
- return data
+ return img
+ def normalize(self, img: np.ndarray) -> np.ndarray:
+ return (img.astype("float32") * self.scale - self.mean) / self.std
-class ToCHWImage:
- def __init__(self, **kwargs):
- pass
+ def permute(self, img: np.ndarray) -> np.ndarray:
+ return img.transpose((2, 0, 1))
- def __call__(self, data):
- img = np.array(data["image"])
- data["image"] = img.transpose((2, 0, 1))
- return data
+class YOLOv8PreProcess:
-class KeepKeys:
- def __init__(self, keep_keys):
- self.keep_keys = keep_keys
+ def __init__(self, img_size: Tuple[int, int]):
+ self.img_size = img_size
- def __call__(self, data):
- data_list = []
- for key in self.keep_keys:
- data_list.append(data[key])
- return data_list
+ def __call__(self, image: np.ndarray) -> np.ndarray:
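+        # Plain resize to img_size (no letterbox padding), scale pixels to
+        # [0, 1], convert HWC -> CHW, and add a batch dimension.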
+ input_img = cv2.resize(image, self.img_size)
+ input_img = input_img / 255.0
+ input_img = input_img.transpose(2, 0, 1)
+ input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
+ return input_tensor
diff --git a/tests/test_layout.py b/tests/test_layout.py
index 1d3fa5c..fa56d7b 100644
--- a/tests/test_layout.py
+++ b/tests/test_layout.py
@@ -15,8 +15,6 @@
from rapid_layout import RapidLayout
test_file_dir = cur_dir / "test_files"
-layout_engine = RapidLayout()
-
img_path = test_file_dir / "layout.png"
img = cv2.imread(str(img_path))
@@ -26,5 +24,15 @@
"img_content", [img_path, str(img_path), open(img_path, "rb").read(), img]
)
def test_multi_input(img_content):
- boxes, scores, class_names, *elapse = layout_engine(img_content)
+ engine = RapidLayout()
+ boxes, scores, class_names, *elapse = engine(img_content)
assert len(boxes) == 15
+
+
+@pytest.mark.parametrize(
+ "img_content", [img_path, str(img_path), open(img_path, "rb").read(), img]
+)
+def test_yolov8_input(img_content):
+ engine = RapidLayout(model_type="yolov8n_layout_paper")
+ boxes, scores, class_names, *elapse = engine(img_content)
+ assert len(boxes) == 11