Merge pull request #768 from roboflow/fixed_empty_prompts_crash
fixing and testing for empty prompts
PawelPeczek-Roboflow authored Nov 1, 2024
2 parents e1803bb + b32908b commit 9602ddc
Showing 2 changed files with 59 additions and 4 deletions.
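In short: OWLv2's few-shot path crashed when a prompt image contributed no boxes. The fix adds two guards in owlv2.py, one before image boxes and query boxes are compared by IoU and one before per-query embeddings are folded into the class-embeddings dict, and extends the OWLv2 integration test to cover positive-only, mixed positive/negative, and empty prompt sets.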
inference/models/owlv2/owlv2.py (5 additions, 0 deletions)
@@ -327,6 +327,8 @@ def get_query_embedding(
        query_boxes_tensor = torch.tensor(
            query_boxes, dtype=image_boxes.dtype, device=image_boxes.device
        )
        if image_boxes.numel() == 0 or query_boxes_tensor.numel() == 0:
            continue
        iou, _ = box_iou(
            to_corners(image_boxes), to_corners(query_boxes_tensor)
        )  # 3000, k
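For context on the guard: Tensor.numel() returns the total element count, so it is zero for any empty tensor, and reductions such as .max() or .argmax() over an empty IoU matrix raise a RuntimeError, which is the likely failure mode being fixed here. Below is a minimal sketch of the pattern; torchvision.ops.box_iou stands in for the repository's own box_iou/to_corners helpers, which is an assumption since only the call site is visible in the diff.

import torch
from torchvision.ops import box_iou  # stand-in; owlv2.py defines its own box_iou

image_boxes = torch.empty((0, 4))  # an image that produced no candidate boxes
query_boxes_tensor = torch.tensor([[0.0, 0.0, 10.0, 10.0]])  # one prompt box, xyxy corners

if image_boxes.numel() == 0 or query_boxes_tensor.numel() == 0:
    print("empty side, skipping this image/query pair")  # mirrors the new continue
else:
    iou = box_iou(image_boxes, query_boxes_tensor)  # shape (N, k)
    best = iou.argmax(dim=0)  # raises RuntimeError when N == 0, hence the guard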
@@ -456,6 +458,9 @@ def make_class_embeddings_dict(
            # NOTE: because we just computed the embedding for this image, this should never result in a KeyError
            embeddings = self.get_query_embedding(query_spec, iou_threshold)

            if embeddings is None:
                continue

            # add the embeddings to their appropriate class and positive/negative list
            for embedding, class_name, is_positive in zip(
                embeddings, classes, is_positive
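The second guard is the caller-side half of the same fix: once get_query_embedding can bail out on empty prompts, it returns None instead of an embedding tensor, and make_class_embeddings_dict skips that query rather than iterating over None. A self-contained sketch of the skip pattern, with the method mocked; the helper name and the 512-dim embedding size are illustrative, not taken from the source.

from typing import Dict, List, Optional

import torch


def mock_get_query_embedding(boxes: List[dict]) -> Optional[torch.Tensor]:
    # Mock of the real method: an empty prompt yields nothing to embed.
    if len(boxes) == 0:
        return None
    return torch.randn(len(boxes), 512)  # one embedding per prompt box


class_embeddings: Dict[str, List[torch.Tensor]] = {"post": []}
for boxes in ([{"cls": "post"}], []):  # second entry simulates an empty prompt
    embeddings = mock_get_query_embedding(boxes)
    if embeddings is None:
        continue  # same shape as the fix: skip empty prompts silently
    class_embeddings["post"].append(embeddings)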
tests/inference/models_predictions_tests/test_owlv2.py (54 additions, 4 deletions)
@@ -7,15 +7,15 @@ def test_owlv2():
"type": "url",
"value": "https://media.roboflow.com/inference/seawithdock.jpeg",
}

# test we can handle a single positive prompt
request = OwlV2InferenceRequest(
image=image,
training_data=[
{
"image": image,
"boxes": [
{"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post", "negative": False},
{"x": 247, "y": 294, "w": 25, "h": 165, "cls": "post", "negative": True},
{"x": 264, "y": 327, "w": 21, "h": 74, "cls": "post", "negative": False},
],
}
],
@@ -25,7 +25,7 @@ def test_owlv2():

    response = OwlV2().infer_from_request(request)
    # we assert that we're finding all of the posts in the image
    assert len(response.predictions) == 4  # before this change
    assert len(response.predictions) == 5  # after this change
    # next we check the x coordinates to force something about localization
    # the exact value here is sensitive to:
    # 1. the image interpolation mode used
@@ -37,6 +37,56 @@
    posts = [p for p in response.predictions if p.class_name == "post"]
    posts.sort(key=lambda x: x.x)
    assert abs(223 - posts[0].x) < 1.5
    assert abs(248 - posts[1].x) < 1.5
    assert abs(264 - posts[2].x) < 1.5
    assert abs(532 - posts[3].x) < 1.5
    assert abs(572 - posts[4].x) < 1.5
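Reading the two sections against each other pins down what negative prompts do: the first request turns up five posts (x of roughly 223, 248, 264, 532, 572), while the mixed positive/negative request below yields four, with the detection near x = 248, right where the negative box at x = 247 sits, suppressed.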


    # test we can handle multiple (positive and negative) prompts for the same image
    request = OwlV2InferenceRequest(
        image=image,
        training_data=[
            {
                "image": image,
                "boxes": [
                    {"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post", "negative": False},
                    {"x": 247, "y": 294, "w": 25, "h": 165, "cls": "post", "negative": True},
                    {"x": 264, "y": 327, "w": 21, "h": 74, "cls": "post", "negative": False},
                ],
            }
        ],
        visualize_predictions=True,
        confidence=0.9,
    )

    response = OwlV2().infer_from_request(request)
    assert len(response.predictions) == 4
    posts = [p for p in response.predictions if p.class_name == "post"]
    posts.sort(key=lambda x: x.x)
    assert abs(223 - posts[0].x) < 1.5
    assert abs(264 - posts[1].x) < 1.5
    assert abs(532 - posts[2].x) < 1.5
    assert abs(572 - posts[3].x) < 1.5

    # test that we can handle no prompts for an image
    request = OwlV2InferenceRequest(
        image=image,
        training_data=[
            {
                "image": image,
                "boxes": [
                    {"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post", "negative": False}
                ],
            },
            {
                "image": image,
                "boxes": [],
            },
        ],
        visualize_predictions=True,
        confidence=0.9,
    )

    response = OwlV2().infer_from_request(request)
    assert len(response.predictions) == 5
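This last request is the regression case the branch name points at: the second training_data entry carries an empty boxes list, the exact input that used to crash, and after the fix it is simply skipped while the positive prompt from the first entry still yields all five posts.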
