Skip to content

Commit 8ba6da2

Browse files
authored
fix a sliced prediction bug for small images (#33)
1 parent 79c1f00 commit 8ba6da2

File tree

2 files changed

+36
-19
lines changed

2 files changed

+36
-19
lines changed

sahi/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.3.0"
1+
__version__ = "0.3.1"

sahi/predict.py

+35-18
Original file line numberDiff line numberDiff line change
@@ -192,28 +192,45 @@ def get_sliced_prediction(
192192
# create prediction input
193193
num_group = int(num_slices / num_batch)
194194
if verbose == 1 or verbose == 2:
195-
print("Number of slices:", num_slices)
195+
if num_slices > 0:
196+
print("Number of slices:", num_slices)
197+
else:
198+
print("Number of slices:", 1)
196199
object_prediction_list = []
197-
for group_ind in range(num_group):
198-
# prepare batch (currently supports only 1 batch)
199-
image_list = []
200-
shift_amount_list = []
201-
for image_ind in range(num_batch):
202-
image_list.append(
203-
slice_image_result.images[group_ind * num_batch + image_ind]
204-
)
205-
shift_amount_list.append(
206-
slice_image_result.starting_pixels[group_ind * num_batch + image_ind]
200+
if num_slices > 0: # if zero_frac < max_allowed_zeros_ratio from slice_image
201+
for group_ind in range(num_group):
202+
# prepare batch (currently supports only 1 batch)
203+
image_list = []
204+
shift_amount_list = []
205+
for image_ind in range(num_batch):
206+
image_list.append(
207+
slice_image_result.images[group_ind * num_batch + image_ind]
208+
)
209+
shift_amount_list.append(
210+
slice_image_result.starting_pixels[
211+
group_ind * num_batch + image_ind
212+
]
213+
)
214+
# perform batch prediction
215+
prediction_result = get_prediction(
216+
image=image_list[0],
217+
detection_model=detection_model,
218+
shift_amount=shift_amount_list[0],
219+
full_shape=[
220+
slice_image_result.original_image_height,
221+
slice_image_result.original_image_width,
222+
],
207223
)
208-
# perform batch prediction
224+
object_prediction_list.extend(prediction_result["object_prediction_list"])
225+
else: # if zero_frac >= max_allowed_zeros_ratio from slice_image
209226
prediction_result = get_prediction(
210-
image=image_list[0],
227+
image=image,
211228
detection_model=detection_model,
212-
shift_amount=shift_amount_list[0],
213-
full_shape=[
214-
slice_image_result.original_image_height,
215-
slice_image_result.original_image_width,
216-
],
229+
shift_amount=[0, 0],
230+
full_shape=None,
231+
merger=None,
232+
matcher=None,
233+
verbose=0,
217234
)
218235
object_prediction_list.extend(prediction_result["object_prediction_list"])
219236

0 commit comments

Comments
 (0)