Skip to content

Commit b8d4a88

Browse files
authored
Merge pull request #12 from Koldim2001/feature/batch_processing
Feature/batch processing
2 parents b7d0ffd + 4220629 commit b8d4a88

File tree

6 files changed

+99
-8
lines changed

6 files changed

+99
-8
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@ setup.cfg
1919
build
2020
info_how_pip_upload.txt
2121
examples/patched_yolo_infer
22+
**.engine
2223
**.ipynb

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ Class implementing cropping and passing crops through a neural network for detec
121121
| resize_initial_size | bool | False | Whether to resize the results to the original input image size (ps: slow operation). |
122122
| memory_optimize | bool | True | Memory optimization option for segmentation (less accurate results when enabled). |
123123
| inference_extra_args | dict | None | Dictionary with extra ultralytics [inference parameters](https://docs.ultralytics.com/modes/predict/#inference-arguments) (possible keys: half, device, max_det, augment, agnostic_nms and retina_masks) |
124+
| batch_inference | bool | False | Whether to pass all image crops through the neural network as a single batch instead of processing them one at a time (ps: faster inference, higher GPU memory use). |
124125

125126

126127
**CombineDetections**

patched_yolo_infer/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ This library facilitates various visualizations of inference results from ultral
99
You can install the library via pip:
1010

1111
```bash
12-
pip install patched_yolo_infer
12+
pip install patched-yolo-infer
1313
```
1414

1515
Note: If CUDA support is available, it's recommended to pre-install PyTorch with CUDA support before installing the library. Otherwise, the CPU version will be installed by default.
@@ -99,6 +99,7 @@ Class implementing cropping and passing crops through a neural network for detec
9999
- **resize_initial_size** (*bool*): Whether to resize the results to the original image size (ps: slow operation).
100100
- **memory_optimize** (*bool*): Memory optimization option for segmentation (less accurate results when enabled).
101101
- **inference_extra_args** (*dict*): Dictionary with extra ultralytics [inference parameters](https://docs.ultralytics.com/modes/predict/#inference-arguments) (possible keys: half, device, max_det, augment, agnostic_nms and retina_masks)
102+
- **batch_inference** (*bool*): Whether to pass all image crops through the neural network as a single batch instead of processing them one at a time (ps: faster inference, higher GPU memory use).
102103

103104
**CombineDetections**
104105
Class implementing combining masks/boxes from multiple crops + NMS (Non-Maximum Suppression).\

patched_yolo_infer/nodes/MakeCropsDetectThem.py

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ class MakeCropsDetectThem:
5050
image size (ps: slow operation).
5151
class_names_dict (dict): Dictionary containing class names of the YOLO model.
5252
memory_optimize (bool): Memory optimization option for segmentation (less accurate results)
53+
batch_inference (bool): Batch inference of image crops through a neural network instead of
54+
sequential passes of crops (ps: Faster inference, higher memory use)
5355
inference_extra_args (dict): Dictionary with extra ultralytics inference parameters
5456
"""
5557
def __init__(
@@ -70,6 +72,7 @@ def __init__(
7072
model=None,
7173
memory_optimize=True,
7274
inference_extra_args=None,
75+
batch_inference=False,
7376
) -> None:
7477
if model is None:
7578
self.model = YOLO(model_path) # Load the model from the specified path
@@ -91,6 +94,7 @@ def __init__(
9194
self.memory_optimize = memory_optimize # memory optimization option for segmentation
9295
self.class_names_dict = self.model.names # dict with human-readable class names
9396
self.inference_extra_args = inference_extra_args # dict with extra ultralytics inference parameters
97+
self.batch_inference = batch_inference # batch inference of image crops through a neural network
9498

9599
self.crops = self.get_crops_xy(
96100
self.image,
@@ -100,7 +104,10 @@ def __init__(
100104
overlap_y=self.overlap_y,
101105
show=self.show_crops,
102106
)
103-
self._detect_objects()
107+
if self.batch_inference:
108+
self._detect_objects_batch()
109+
else:
110+
self._detect_objects()
104111

105112
def get_crops_xy(
106113
self,
@@ -141,6 +148,7 @@ def get_crops_xy(
141148
x_new = round((x_steps-1) * (shape_x * cross_koef_x) + shape_x)
142149
image_innitial = image_full.copy()
143150
image_full = cv2.resize(image_full, (x_new, y_new))
151+
batch_of_crops = []
144152

145153
if show:
146154
plt.figure(figsize=[x_steps*0.9, y_steps*0.9])
@@ -176,12 +184,17 @@ def get_crops_xy(
176184
x_start=x_start,
177185
y_start=y_start,
178186
))
187+
if self.batch_inference:
188+
batch_of_crops.append(im_temp)
179189

180190
if show:
181191
plt.show()
182192
print('Number of generated images:', count)
183193

184-
return data_all_crops
194+
if self.batch_inference:
195+
return data_all_crops, batch_of_crops
196+
else:
197+
return data_all_crops
185198

186199
def _detect_objects(self):
187200
"""
@@ -207,3 +220,76 @@ def _detect_objects(self):
207220
crop.calculate_real_values()
208221
if self.resize_initial_size:
209222
crop.resize_results()
223+
224+
def _detect_objects_batch(self):
225+
"""
226+
Method to detect objects in a batch of image crops.
227+
228+
This method performs batch inference using the YOLO model,
229+
calculates real values, and optionally resizes the results.
230+
231+
Returns:
232+
None
233+
"""
234+
crops, batch = self.crops
235+
self.crops = crops
236+
self._calculate_batch_inference(
237+
batch,
238+
self.crops,
239+
self.model,
240+
imgsz=self.imgsz,
241+
conf=self.conf,
242+
iou=self.iou,
243+
segment=self.segment,
244+
classes_list=self.classes_list,
245+
memory_optimize=self.memory_optimize,
246+
extra_args=self.inference_extra_args
247+
)
248+
for crop in self.crops:
249+
crop.calculate_real_values()
250+
if self.resize_initial_size:
251+
crop.resize_results()
252+
253+
def _calculate_batch_inference(
254+
self,
255+
batch,
256+
crops,
257+
model,
258+
imgsz=640,
259+
conf=0.35,
260+
iou=0.7,
261+
segment=False,
262+
classes_list=None,
263+
memory_optimize=False,
264+
extra_args=None,
265+
):
266+
# Perform batch inference of image crops through a neural network
267+
extra_args = {} if extra_args is None else extra_args
268+
predictions = model.predict(
269+
batch,
270+
imgsz=imgsz,
271+
conf=conf,
272+
iou=iou,
273+
classes=classes_list,
274+
verbose=False,
275+
**extra_args
276+
)
277+
278+
for pred, crop in zip(predictions, crops):
279+
280+
# Get the bounding boxes and convert them to a list of lists
281+
crop.detected_xyxy = pred.boxes.xyxy.cpu().int().tolist()
282+
283+
# Get the classes and convert them to a list
284+
crop.detected_cls = pred.boxes.cls.cpu().int().tolist()
285+
286+
# Get the confidence scores of the detected boxes
287+
crop.detected_conf = pred.boxes.conf.cpu().numpy()
288+
289+
if segment and len(crop.detected_cls) != 0:
290+
if memory_optimize:
291+
# Get the polygons
292+
crop.polygons = [mask.astype(np.uint16) for mask in pred.masks.xy]
293+
else:
294+
# Get the masks
295+
crop.detected_masks = pred.masks.data.cpu().numpy()

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
numpy<2.0
12
torch
2-
numpy
33
opencv-python
44
matplotlib
55
ultralytics

setup.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
long_description = "\n" + fh.read()
99

1010

11-
VERSION = '1.2.6'
12-
DESCRIPTION = '''YOLO-Patch-Based-Inference for detection/segmentation of small objects in images.'''
11+
VERSION = '1.2.7'
12+
DESCRIPTION = '''Patch-Based-Inference for detection/segmentation of small objects in images.'''
1313

1414
setup(
1515
name="patched_yolo_infer",
@@ -23,7 +23,7 @@
2323
packages=find_packages(),
2424
python_requires=">=3.8",
2525
install_requires=[
26-
'numpy',
26+
'numpy<2.0',
2727
'opencv-python',
2828
'matplotlib',
2929
'torch',
@@ -33,8 +33,10 @@
3333
"python",
3434
"yolov8",
3535
"yolov9",
36+
"yolov10",
3637
"rtdetr",
37-
"sam",
38+
"fastsam",
39+
"sahi",
3840
"object detection",
3941
"instance segmentation",
4042
"patch-based inference",

0 commit comments

Comments
 (0)