Skip to content

Commit

Permalink
Merge pull request #782 from roboflow/improved_box_prompt_caching
Browse files Browse the repository at this point in the history
avoiding downloading images if possible
  • Loading branch information
PawelPeczek-Roboflow authored Nov 8, 2024
2 parents 5068dd0 + c230c24 commit 6049ebb
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
training_data=[
{
"image": img,
"boxes": [{"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post"}],
"boxes": [{"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post", "negative": False}],
}
],
visualize_predictions=False,
Expand Down
140 changes: 120 additions & 20 deletions inference/models/owlv2/owlv2.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import hashlib
import os
import pickle
from collections import defaultdict
from typing import Any, Dict, List, Literal, NewType, Tuple
from typing import Any, Dict, List, Literal, NewType, Tuple, Union

import numpy as np
import torch
Expand All @@ -10,7 +11,6 @@
from transformers import Owlv2ForObjectDetection, Owlv2Processor
from transformers.models.owlv2.modeling_owlv2 import box_iou

from inference.core.entities.requests.owlv2 import TrainingImage
from inference.core.entities.responses.inference import (
InferenceResponseImage,
ObjectDetectionInferenceResponse,
Expand All @@ -22,7 +22,11 @@
RoboflowCoreModel,
draw_detection_predictions,
)
from inference.core.utils.image_utils import load_image_rgb
from inference.core.utils.image_utils import (
ImageType,
extract_image_payload_and_type,
load_image_rgb,
)

# TYPES
Hash = NewType("Hash", str)
Expand Down Expand Up @@ -196,6 +200,58 @@ def make_class_map(
return class_map, class_names


def hash_function(value: Any) -> Hash:
    """Return the hexadecimal SHA-1 digest of ``value``.

    Hashing is funnelled through this single helper so the underlying
    algorithm can be changed in one place in the future.

    Args:
        value: Data fed directly to ``hashlib.sha1`` (for example ``bytes``).

    Returns:
        The digest rendered as a hexadecimal string.
    """
    hasher = hashlib.sha1()
    hasher.update(value)
    return hasher.hexdigest()


class LazyImageRetrievalWrapper:
    """Wrap an image reference and defer the expensive work until needed.

    Both the decoded array and the hash are computed at most once and then
    memoised on the instance. For URL and base64 inputs the hash is derived
    from the payload alone, without calling ``load_image_rgb``.

    Args:
        image: The image reference, in whatever form ``load_image_rgb`` and
            ``extract_image_payload_and_type`` accept.
    """

    def __init__(self, image: Any):
        self.image = image

        # Populated lazily by the properties below.
        self._image_as_numpy = None
        self._image_hash = None

    @property
    def image_as_numpy(self) -> np.ndarray:
        """The image as returned by ``load_image_rgb``, loaded on first access."""
        if self._image_as_numpy is None:
            self._image_as_numpy = load_image_rgb(self.image)
        return self._image_as_numpy

    @property
    def image_hash(self) -> Hash:
        """Cache key identifying this image, computed on first access.

        The key depends on the detected image type:

        * URL: the URL payload itself is used as the key.
        * base64: the hash of the payload (encoded to bytes if it is text).
        * anything else: the hash of the loaded array's raw bytes.
        """
        if self._image_hash is not None:
            return self._image_hash

        image_payload, image_type = extract_image_payload_and_type(self.image)
        if image_type is ImageType.URL:
            # we can use the url as the hash
            # NOTE(review): this keys on the address, not the content, so a
            # change of the image behind the same URL is not detected.
            self._image_hash = image_payload
        elif image_type is ImageType.BASE64:
            # this is presumably the compressed image bytes; hashing it
            # directly is faster than loading the raw image through numpy.
            # The hash function needs a buffer, so text payloads are encoded
            # first. isinstance (rather than an exact type check) makes sure
            # str subclasses are encoded as well instead of failing to hash.
            # see load_image_base64 in image_utils.py for more details about
            # the base64 encoding
            if isinstance(image_payload, str):
                image_payload = image_payload.encode("utf-8")
            self._image_hash = hash_function(image_payload)
        else:
            # not clear that there is something safe or faster to do than
            # just loading the numpy array and hashing that
            self._image_hash = hash_function(self.image_as_numpy.tobytes())
        return self._image_hash


def hash_wrapped_training_data(wrapped_training_data: List[Dict[str, Any]]) -> Hash:
    """Compute one hash covering every training sample's image and boxes.

    Args:
        wrapped_training_data: Samples whose ``"image"`` entry exposes an
            ``image_hash`` attribute and whose ``"boxes"`` entry holds the
            box prompts for that image.

    Returns:
        The hash of the pickled ``[image_hash, boxes]`` pairs, in input order.
    """
    hash_inputs = []
    for sample in wrapped_training_data:
        # only the image identity and its boxes influence the result
        hash_inputs.append([sample["image"].image_hash, sample["boxes"]])

    # pickling turns the whole structure into a single bytes object to hash
    serialized = pickle.dumps(hash_inputs)
    return hash_function(serialized)


class OwlV2(RoboflowCoreModel):
task_type = "object-detection"
box_format = "xywh"
Expand Down Expand Up @@ -224,9 +280,12 @@ def __init__(self, *args, model_id="owlv2/owlv2-base-patch16-ensemble", **kwargs
self.model.owlv2.vision_model = torch.compile(self.model.owlv2.vision_model)

def reset_cache(self):
    """(Re)create the bounded in-memory caches used by this model.

    Each cache is a fresh ``LimitedSizeDict``; any previously cached entries
    are dropped because the attributes are rebound to new objects.
    """
    # Image embeddings, looked up by image hash.
    # each entry should be on the order of 300*4KB, so 1000 is 400MB of CUDA memory
    # NOTE(review): 300*4KB is ~1.2MB per entry, which would put 1000 entries
    # closer to 1.2GB than 400MB - confirm which figure is right.
    self.image_embed_cache = LimitedSizeDict(size_limit=1000)
    # Image sizes, looked up by image hash.
    # each entry should be on the order of 10 bytes, so 1000 is 10KB
    self.image_size_cache = LimitedSizeDict(size_limit=1000)
    # Class embeddings, looked up by the hash of the training data.
    # entry size will vary depending on the number of samples, but 100 should be safe
    self.class_embeddings_cache = LimitedSizeDict(size_limit=100)

def draw_predictions(
self,
Expand Down Expand Up @@ -258,15 +317,35 @@ def download_weights(self) -> None:
# Download from huggingface
pass

def compute_image_size(
    self, image: Union[np.ndarray, LazyImageRetrievalWrapper]
) -> Tuple[int, int]:
    """Return the first two ``shape`` entries of ``image``, reversed.

    For lazily wrapped images the result is memoised in
    ``self.image_size_cache`` under the image hash, in the hope of not having
    to load an image solely to learn its size.

    Args:
        image: Either an already loaded array or a lazy wrapper around an
            image reference.

    Returns:
        ``shape[:2][::-1]`` of the image array.
    """
    if not isinstance(image, LazyImageRetrievalWrapper):
        # already an array: nothing to cache, just read the shape
        return image.shape[:2][::-1]

    cache_key = image.image_hash
    cached_size = self.image_size_cache.get(cache_key)
    if cached_size is not None:
        return cached_size

    # cache miss: this is the only path that forces the image to be loaded
    loaded_size = image.image_as_numpy.shape[:2][::-1]
    self.image_size_cache[cache_key] = loaded_size
    return loaded_size

@torch.no_grad()
def embed_image(self, image: np.ndarray) -> Hash:
image_hash = hashlib.sha256(image.tobytes()).hexdigest()
def embed_image(self, image: Union[np.ndarray, LazyImageRetrievalWrapper]) -> Hash:
if isinstance(image, LazyImageRetrievalWrapper):
image_hash = image.image_hash
else:
image_hash = hash_function(image.tobytes())

if (image_embeds := self.image_embed_cache.get(image_hash)) is not None:
if image_hash in self.image_embed_cache:
return image_hash

np_image = (
image.image_as_numpy
if isinstance(image, LazyImageRetrievalWrapper)
else image
)
pixel_values = preprocess_image(
image, self.image_size, self.image_mean, self.image_std
np_image, self.image_size, self.image_mean, self.image_std
)

# torch 2.4 lets you use "cuda:0" as device_type
Expand Down Expand Up @@ -422,12 +501,16 @@ def infer(
else:
images = image

images = [LazyImageRetrievalWrapper(image) for image in images]

results = []
image_sizes = []
for image in images:
image = load_image_rgb(image)
image_sizes.append(image.shape[:2][::-1])
image_hash = self.embed_image(image)
for image_wrapper in images:
# happy path here is that both image size and image embeddings are cached
# in which case we avoid loading the image at all
image_size = self.compute_image_size(image_wrapper)
image_sizes.append(image_size)
image_hash = self.embed_image(image_wrapper)
result = self.infer_from_embed(
image_hash, class_embeddings_dict, confidence, iou_threshold
)
Expand All @@ -439,20 +522,35 @@ def infer(
def make_class_embeddings_dict(
self, training_data: List[Any], iou_threshold: float
) -> Dict[str, PosNegDictType]:
wrapped_training_data = [
{
"image": LazyImageRetrievalWrapper(train_image["image"]),
"boxes": train_image["boxes"],
}
for train_image in training_data
]

wrapped_training_data_hash = hash_wrapped_training_data(wrapped_training_data)
if (
class_embeddings_dict := self.class_embeddings_cache.get(
wrapped_training_data_hash
)
) is not None:
return class_embeddings_dict

class_embeddings_dict = defaultdict(lambda: {"positive": [], "negative": []})

bool_to_literal = {True: "positive", False: "negative"}
for train_image in training_data:
for train_image in wrapped_training_data:
# grab and embed image
image = load_image_rgb(train_image["image"])
image_hash = self.embed_image(image)
image_hash = self.embed_image(train_image["image"])

# grab and normalize box prompts for this image
image_size = self.compute_image_size(train_image["image"])
boxes = train_image["boxes"]
print(f"boxes: {boxes}")
coords = [[box["x"], box["y"], box["w"], box["h"]] for box in boxes]
coords = [
tuple([c / max(image.shape[:2]) for c in coord]) for coord in coords
]
coords = [tuple([c / max(image_size) for c in coord]) for coord in coords]
classes = [box["cls"] for box in boxes]
is_positive = [not box["negative"] for box in boxes]

Expand Down Expand Up @@ -482,6 +580,8 @@ def make_class_embeddings_dict(
for k, v in class_embeddings_dict.items()
}

self.class_embeddings_cache[wrapped_training_data_hash] = class_embeddings_dict

return class_embeddings_dict

def make_response(self, predictions, image_sizes, class_names):
Expand Down
57 changes: 57 additions & 0 deletions tests/inference/models_predictions_tests/test_owlv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,5 +285,62 @@ def test_owlv2_multiple_training_images():
assert len(response.predictions) == 5


@pytest.mark.slow
def test_owlv2_multiple_training_images_repeated_inference():
    """Running the same request twice must yield identical predictions.

    The second call is expected to be served (at least partly) from the
    model's caches, so this guards against cached results diverging from
    freshly computed ones.
    """
    image = {
        "type": "url",
        "value": "https://media.roboflow.com/inference/seawithdock.jpeg",
    }
    second_image = {
        "type": "url",
        "value": "https://media.roboflow.com/inference/dock2.jpg",
    }

    request = OwlV2InferenceRequest(
        image=image,
        training_data=[
            {
                "image": image,
                "boxes": [
                    {
                        "x": 223,
                        "y": 306,
                        "w": 40,
                        "h": 226,
                        "cls": "post",
                        "negative": False,
                    }
                ],
            },
            {
                "image": second_image,
                "boxes": [
                    {
                        "x": 3009,
                        "y": 1873,
                        "w": 289,
                        "h": 811,
                        "cls": "post",
                        "negative": True,
                    }
                ],
            },
        ],
        visualize_predictions=True,
        confidence=0.9,
    )

    model = OwlV2()
    first_response = model.infer_from_request(request)
    second_response = model.infer_from_request(request)

    # zip() silently stops at the shorter sequence, so without this check a
    # response with missing or extra predictions would go unnoticed.
    assert len(first_response.predictions) == len(second_response.predictions)

    for p1, p2 in zip(first_response.predictions, second_response.predictions):
        assert p1.class_name == p2.class_name
        assert p1.x == p2.x
        assert p1.y == p2.y
        assert p1.width == p2.width
        assert p1.height == p2.height
        assert p1.confidence == p2.confidence


# Script entry point: when this file is executed directly, call test_owlv2().
if __name__ == "__main__":
    test_owlv2()

0 comments on commit 6049ebb

Please sign in to comment.