From b32908b791424abfcac2d8aa8df20489cdc2c222 Mon Sep 17 00:00:00 2001 From: Isaac Robinson Date: Fri, 1 Nov 2024 15:17:35 +0000 Subject: [PATCH] fixing and testing for empty prompts --- inference/models/owlv2/owlv2.py | 5 ++ .../models_predictions_tests/test_owlv2.py | 58 +++++++++++++++++-- 2 files changed, 59 insertions(+), 4 deletions(-) diff --git a/inference/models/owlv2/owlv2.py b/inference/models/owlv2/owlv2.py index 53a0a9c6e..a3d05ff40 100644 --- a/inference/models/owlv2/owlv2.py +++ b/inference/models/owlv2/owlv2.py @@ -327,6 +327,8 @@ def get_query_embedding( query_boxes_tensor = torch.tensor( query_boxes, dtype=image_boxes.dtype, device=image_boxes.device ) + if image_boxes.numel() == 0 or query_boxes_tensor.numel() == 0: + continue iou, _ = box_iou( to_corners(image_boxes), to_corners(query_boxes_tensor) ) # 3000, k @@ -456,6 +458,9 @@ def make_class_embeddings_dict( # NOTE: because we just computed the embedding for this image, this should never result in a KeyError embeddings = self.get_query_embedding(query_spec, iou_threshold) + if embeddings is None: + continue + # add the embeddings to their appropriate class and positive/negative list for embedding, class_name, is_positive in zip( embeddings, classes, is_positive diff --git a/tests/inference/models_predictions_tests/test_owlv2.py b/tests/inference/models_predictions_tests/test_owlv2.py index fa37bf13a..3ad5913ab 100644 --- a/tests/inference/models_predictions_tests/test_owlv2.py +++ b/tests/inference/models_predictions_tests/test_owlv2.py @@ -7,6 +7,8 @@ def test_owlv2(): "type": "url", "value": "https://media.roboflow.com/inference/seawithdock.jpeg", } + + # test we can handle a single positive prompt request = OwlV2InferenceRequest( image=image, training_data=[ @@ -14,8 +16,6 @@ def test_owlv2(): "image": image, "boxes": [ {"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post", "negative": False}, - {"x": 247, "y": 294, "w": 25, "h": 165, "cls": "post", "negative": True}, - {"x": 264, "y": 327, "w": 21, "h": 74, "cls": "post", "negative": False}, ], } ], @@ -25,7 +25,7 @@ def test_owlv2(): response = OwlV2().infer_from_request(request) # we assert that we're finding all of the posts in the image - assert len(response.predictions) == 4 + assert len(response.predictions) == 5 # next we check the x coordinates to force something about localization # the exact value here is sensitive to: # 1. the image interpolation mode used @@ -37,6 +37,56 @@ def test_owlv2(): posts = [p for p in response.predictions if p.class_name == "post"] posts.sort(key=lambda x: x.x) assert abs(223 - posts[0].x) < 1.5 + assert abs(248 - posts[1].x) < 1.5 + assert abs(264 - posts[2].x) < 1.5 + assert abs(532 - posts[3].x) < 1.5 + assert abs(572 - posts[4].x) < 1.5 + + + # test we can handle multiple (positive and negative) prompts for the same image + request = OwlV2InferenceRequest( + image=image, + training_data=[ + { + "image": image, + "boxes": [ + {"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post", "negative": False}, + {"x": 247, "y": 294, "w": 25, "h": 165, "cls": "post", "negative": True}, + {"x": 264, "y": 327, "w": 21, "h": 74, "cls": "post", "negative": False}, + ], + } + ], + visualize_predictions=True, + confidence=0.9, + ) + + response = OwlV2().infer_from_request(request) + assert len(response.predictions) == 4 + posts = [p for p in response.predictions if p.class_name == "post"] + posts.sort(key=lambda x: x.x) + assert abs(223 - posts[0].x) < 1.5 assert abs(264 - posts[1].x) < 1.5 assert abs(532 - posts[2].x) < 1.5 - assert abs(572 - posts[3].x) < 1.5 \ No newline at end of file + assert abs(572 - posts[3].x) < 1.5 + + # test that we can handle no prompts for an image + request = OwlV2InferenceRequest( + image=image, + training_data=[ + { + "image": image, + "boxes": [ + {"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post", "negative": False} + ], + }, + { + "image": image, + "boxes": [], + }, + ], + visualize_predictions=True, + confidence=0.9, + ) + + response = OwlV2().infer_from_request(request) + assert len(response.predictions) == 5 \ No newline at end of file