From 941b5d43776c20b3c59cacf6864e2b78084c8e50 Mon Sep 17 00:00:00 2001
From: Isaac Robinson
Date: Thu, 7 Nov 2024 00:30:02 +0000
Subject: [PATCH 1/2] avoiding downloading images if possible

---
 .../benchmark_owlv2_inference_time.py      |   2 +-
 inference/models/owlv2/owlv2.py            | 133 +++++++++++++++---
 .../models_predictions_tests/test_owlv2.py |  57 ++++++++
 3 files changed, 171 insertions(+), 21 deletions(-)

diff --git a/development/benchmark_scripts/benchmark_owlv2_inference_time.py b/development/benchmark_scripts/benchmark_owlv2_inference_time.py
index dfcc23f4d..95592ed87 100644
--- a/development/benchmark_scripts/benchmark_owlv2_inference_time.py
+++ b/development/benchmark_scripts/benchmark_owlv2_inference_time.py
@@ -28,7 +28,7 @@
     training_data=[
         {
             "image": img,
-            "boxes": [{"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post"}],
+            "boxes": [{"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post", "negative": False}],
         }
     ],
     visualize_predictions=False,
diff --git a/inference/models/owlv2/owlv2.py b/inference/models/owlv2/owlv2.py
index 8ff543453..3f2879fea 100644
--- a/inference/models/owlv2/owlv2.py
+++ b/inference/models/owlv2/owlv2.py
@@ -1,7 +1,8 @@
 import hashlib
 import os
+import pickle
 from collections import defaultdict
-from typing import Any, Dict, List, Literal, NewType, Tuple
+from typing import Any, Dict, List, Literal, NewType, Tuple, Union

 import numpy as np
 import torch
@@ -10,7 +11,6 @@
 from transformers import Owlv2ForObjectDetection, Owlv2Processor
 from transformers.models.owlv2.modeling_owlv2 import box_iou

-from inference.core.entities.requests.owlv2 import TrainingImage
 from inference.core.entities.responses.inference import (
     InferenceResponseImage,
     ObjectDetectionInferenceResponse,
@@ -22,7 +22,11 @@
     RoboflowCoreModel,
     draw_detection_predictions,
 )
-from inference.core.utils.image_utils import load_image_rgb
+from inference.core.utils.image_utils import (
+    ImageType,
+    extract_image_payload_and_type,
+    load_image_rgb,
+)

 # TYPES
 Hash = NewType("Hash", str)
@@ -196,6 +200,52 @@ def make_class_map(
     return class_map, class_names


+def hash_function(value: Any) -> Hash:
+    # wrapper so we can change the hashing function in the future
+    return hashlib.sha1(value).hexdigest()
+
+
+class LazyImageRetrievalWrapper:
+    def __init__(self, image: Any):
+        self.image = image
+
+        self.types_that_can_hash_raw = {ImageType.URL, ImageType.BASE64}
+
+        self._image_as_numpy = None
+        self._image_hash = None
+
+    @property
+    def image_as_numpy(self) -> np.ndarray:
+        if self._image_as_numpy is None:
+            self._image_as_numpy = load_image_rgb(self.image)
+        return self._image_as_numpy
+
+    @property
+    def image_hash(self) -> Hash:
+        if self._image_hash is None:
+            image_payload, image_type = extract_image_payload_and_type(self.image)
+            if image_type in self.types_that_can_hash_raw:
+                # for these types, hashing directly is faster than loading the raw image through numpy
+                # and is safe against the pointer changing, i.e. a filepath mapping to a different file
+                if type(image_payload) is str:
+                    image_payload = image_payload.encode("utf-8")
+                self._image_hash = hash_function(image_payload)
+            else:
+                self._image_hash = hash_function(self.image_as_numpy.tobytes())
+        return self._image_hash
+
+
+def hash_wrapped_training_data(wrapped_training_data: List[Dict[str, Any]]) -> Hash:
+    just_hash_relevant_data = [
+        [
+            d["image"].image_hash,
+            d["boxes"],
+        ]
+        for d in wrapped_training_data
+    ]
+    return hash_function(pickle.dumps(just_hash_relevant_data))
+
+
 class OwlV2(RoboflowCoreModel):
"object-detection" box_format = "xywh" @@ -224,9 +274,12 @@ def __init__(self, *args, model_id="owlv2/owlv2-base-patch16-ensemble", **kwargs self.model.owlv2.vision_model = torch.compile(self.model.owlv2.vision_model) def reset_cache(self): - self.image_embed_cache = LimitedSizeDict( - size_limit=1000 - ) # NOTE: this should have a max size + # each entry should be on the order of 300*4KB, so 1000 is 400MB of CUDA memory + self.image_embed_cache = LimitedSizeDict(size_limit=1000) + # each entry should be on the order of 10 bytes, so 1000 is 10KB + self.image_size_cache = LimitedSizeDict(size_limit=1000) + # entry size will vary depending on the number of samples, but 100 should be safe + self.class_embeddings_cache = LimitedSizeDict(size_limit=100) def draw_predictions( self, @@ -258,15 +311,35 @@ def download_weights(self) -> None: # Download from huggingface pass + def compute_image_size( + self, image: Union[np.ndarray, LazyImageRetrievalWrapper] + ) -> Tuple[int, int]: + # we build this in hopes of avoiding having to load the image solely for the purpose of getting its size + if isinstance(image, LazyImageRetrievalWrapper): + if (image_size := self.image_size_cache.get(image.image_hash)) is None: + image_size = image.image_as_numpy.shape[:2][::-1] + self.image_size_cache[image.image_hash] = image_size + else: + image_size = image.shape[:2][::-1] + return image_size + @torch.no_grad() - def embed_image(self, image: np.ndarray) -> Hash: - image_hash = hashlib.sha256(image.tobytes()).hexdigest() + def embed_image(self, image: Union[np.ndarray, LazyImageRetrievalWrapper]) -> Hash: + if isinstance(image, LazyImageRetrievalWrapper): + image_hash = image.image_hash + else: + image_hash = hash_function(image.tobytes()) - if (image_embeds := self.image_embed_cache.get(image_hash)) is not None: + if image_hash in self.image_embed_cache: return image_hash + np_image = ( + image.image_as_numpy + if isinstance(image, LazyImageRetrievalWrapper) + else image + ) pixel_values = preprocess_image( - image, self.image_size, self.image_mean, self.image_std + np_image, self.image_size, self.image_mean, self.image_std ) # torch 2.4 lets you use "cuda:0" as device_type @@ -422,12 +495,16 @@ def infer( else: images = image + images = [LazyImageRetrievalWrapper(image) for image in images] + results = [] image_sizes = [] - for image in images: - image = load_image_rgb(image) - image_sizes.append(image.shape[:2][::-1]) - image_hash = self.embed_image(image) + for image_wrapper in images: + # happy path here is that both image size and image embeddings are cached + # in which case we avoid loading the image at all + image_size = self.compute_image_size(image_wrapper) + image_sizes.append(image_size) + image_hash = self.embed_image(image_wrapper) result = self.infer_from_embed( image_hash, class_embeddings_dict, confidence, iou_threshold ) @@ -439,20 +516,35 @@ def infer( def make_class_embeddings_dict( self, training_data: List[Any], iou_threshold: float ) -> Dict[str, PosNegDictType]: + wrapped_training_data = [ + { + "image": LazyImageRetrievalWrapper(train_image["image"]), + "boxes": train_image["boxes"], + } + for train_image in training_data + ] + + wrapped_training_data_hash = hash_wrapped_training_data(wrapped_training_data) + if ( + class_embeddings_dict := self.class_embeddings_cache.get( + wrapped_training_data_hash + ) + ) is not None: + return class_embeddings_dict + class_embeddings_dict = defaultdict(lambda: {"positive": [], "negative": []}) bool_to_literal = {True: "positive", False: "negative"} - 
-        for train_image in training_data:
+        for train_image in wrapped_training_data:
             # grab and embed image
-            image = load_image_rgb(train_image["image"])
-            image_hash = self.embed_image(image)
+            image_hash = self.embed_image(train_image["image"])

             # grab and normalize box prompts for this image
+            image_size = self.compute_image_size(train_image["image"])
             boxes = train_image["boxes"]
             coords = [[box["x"], box["y"], box["w"], box["h"]] for box in boxes]
-            coords = [
-                tuple([c / max(image.shape[:2]) for c in coord]) for coord in coords
-            ]
+            coords = [tuple([c / max(image_size) for c in coord]) for coord in coords]
             classes = [box["cls"] for box in boxes]
             is_positive = [not box["negative"] for box in boxes]
@@ -482,6 +574,8 @@ def make_class_embeddings_dict(
             for k, v in class_embeddings_dict.items()
         }

+        self.class_embeddings_cache[wrapped_training_data_hash] = class_embeddings_dict
+
         return class_embeddings_dict

     def make_response(self, predictions, image_sizes, class_names):
diff --git a/tests/inference/models_predictions_tests/test_owlv2.py b/tests/inference/models_predictions_tests/test_owlv2.py
index 33e8b7f44..95bf2d540 100644
--- a/tests/inference/models_predictions_tests/test_owlv2.py
+++ b/tests/inference/models_predictions_tests/test_owlv2.py
@@ -285,5 +285,62 @@ def test_owlv2_multiple_training_images():
         assert len(response.predictions) == 5


+@pytest.mark.slow
+def test_owlv2_multiple_training_images_repeated_inference():
+    image = {
+        "type": "url",
+        "value": "https://media.roboflow.com/inference/seawithdock.jpeg",
+    }
+    second_image = {
+        "type": "url",
+        "value": "https://media.roboflow.com/inference/dock2.jpg",
+    }
+
+    request = OwlV2InferenceRequest(
+        image=image,
+        training_data=[
+            {
+                "image": image,
+                "boxes": [
+                    {
+                        "x": 223,
+                        "y": 306,
+                        "w": 40,
+                        "h": 226,
+                        "cls": "post",
+                        "negative": False,
+                    }
+                ],
+            },
+            {
+                "image": second_image,
+                "boxes": [
+                    {
+                        "x": 3009,
+                        "y": 1873,
+                        "w": 289,
+                        "h": 811,
+                        "cls": "post",
+                        "negative": True,
+                    }
+                ],
+            },
+        ],
+        visualize_predictions=True,
+        confidence=0.9,
+    )
+
+    model = OwlV2()
+    first_response = model.infer_from_request(request)
+    second_response = model.infer_from_request(request)
+    for p1, p2 in zip(first_response.predictions, second_response.predictions):
+        assert p1.class_name == p2.class_name
+        assert p1.x == p2.x
+        assert p1.y == p2.y
+        assert p1.width == p2.width
+        assert p1.height == p2.height
+        assert p1.confidence == p2.confidence
+
+
 if __name__ == "__main__":
     test_owlv2()

From 12ac465857007dc0111ce10dd5f650a87c97606b Mon Sep 17 00:00:00 2001
From: Isaac Robinson
Date: Thu, 7 Nov 2024 00:45:03 +0000
Subject: [PATCH 2/2] don't hash the url

---
 inference/models/owlv2/owlv2.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/inference/models/owlv2/owlv2.py b/inference/models/owlv2/owlv2.py
index 3f2879fea..7e73c9407 100644
--- a/inference/models/owlv2/owlv2.py
+++ b/inference/models/owlv2/owlv2.py
@@ -209,8 +209,6 @@ class LazyImageRetrievalWrapper:
     def __init__(self, image: Any):
         self.image = image

-        self.types_that_can_hash_raw = {ImageType.URL, ImageType.BASE64}
-
         self._image_as_numpy = None
         self._image_hash = None

@@ -224,13 +222,20 @@ def image_as_numpy(self) -> np.ndarray:
     def image_hash(self) -> Hash:
         if self._image_hash is None:
             image_payload, image_type = extract_image_payload_and_type(self.image)
-            if image_type in self.types_that_can_hash_raw:
-                # for these types, hashing directly is faster than loading the raw image through numpy
-                # and is safe against the pointer changing, i.e. a filepath mapping to a different file
+            if image_type is ImageType.URL:
+                # we can use the url as the hash
+                self._image_hash = image_payload
+            elif image_type is ImageType.BASE64:
+                # this is presumably the compressed image bytes
+                # hashing this directly is faster than loading the raw image through numpy
+                # we have to make sure we're passing a buffer, so we encode to bytes if necessary
+                # see load_image_base64 in image_utils.py for more details about the base64 encoding
                 if type(image_payload) is str:
                     image_payload = image_payload.encode("utf-8")
                 self._image_hash = hash_function(image_payload)
             else:
+                # it is not clear that anything is safer or faster than just loading the numpy array
+                # and hashing that
                 self._image_hash = hash_function(self.image_as_numpy.tobytes())
         return self._image_hash

@@ -243,6 +248,7 @@ def hash_wrapped_training_data(wrapped_training_data: List[Dict[str, Any]]) -> H
         ]
         for d in wrapped_training_data
     ]
+    # we dump to pickle to serialize the data as a single object
    return hash_function(pickle.dumps(just_hash_relevant_data))
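A note on LimitedSizeDict, which all three caches in reset_cache rely on: its
implementation lives elsewhere in the repo and is not part of this diff. Below
is a minimal sketch of such a size-bounded dict, assuming FIFO eviction on top
of OrderedDict; the class the codebase actually ships may differ in detail.

    from collections import OrderedDict

    class LimitedSizeDict(OrderedDict):
        """Dict that evicts its oldest entries once size_limit is exceeded."""

        def __init__(self, *args, size_limit=None, **kwargs):
            self.size_limit = size_limit
            super().__init__(*args, **kwargs)
            self._check_size_limit()

        def __setitem__(self, key, value):
            super().__setitem__(key, value)
            self._check_size_limit()

        def _check_size_limit(self):
            if self.size_limit is not None:
                while len(self) > self.size_limit:
                    # popitem(last=False) removes in insertion order: oldest first
                    self.popitem(last=False)

Under FIFO eviction a cache hit does not refresh an entry's age, so a frequently
reused embedding can still be evicted once 1000 newer images have been seen; an
LRU variant would additionally move a key to the end of the dict on every read.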
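The key property of LazyImageRetrievalWrapper is that the cache key can be
computed without materializing the pixels (after PATCH 2/2, a URL input is its
own key), so a warm cache never touches the network. A self-contained toy
version of that flow follows; fake_download, LazyImage, and embed are
illustrative stand-ins, not the library's API.

    import numpy as np

    def fake_download(url: str) -> np.ndarray:
        # stand-in for load_image_rgb; pretend this performs an HTTP fetch
        print(f"downloading {url} ...")
        return np.zeros((480, 640, 3), dtype=np.uint8)

    class LazyImage:
        # simplified stand-in for LazyImageRetrievalWrapper, URL inputs only
        def __init__(self, url: str):
            self.url = url
            self._array = None

        @property
        def array(self) -> np.ndarray:
            if self._array is None:  # load at most once
                self._array = fake_download(self.url)
            return self._array

        @property
        def key(self) -> str:
            return self.url  # cheap: no download, no hashing required

    embed_cache = {}

    def embed(image: LazyImage) -> str:
        if image.key not in embed_cache:
            # the cache-miss branch is the only path that touches image.array
            embed_cache[image.key] = image.array.mean()
        return image.key

    embed(LazyImage("https://example.com/dog.jpg"))  # prints "downloading ..."
    embed(LazyImage("https://example.com/dog.jpg"))  # silent: nothing is loaded

The same short-circuit is what makes the repeated-inference test above cheap on
its second call: embed_image and compute_image_size both answer from their
caches before image_as_numpy is ever evaluated.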
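For the class-embeddings cache, hash_wrapped_training_data derives a single key
from the per-image hashes plus the box prompts. Roughly, with made-up values
standing in for the real wrapped training data:

    import hashlib
    import pickle

    boxes = [{"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post", "negative": False}]
    just_hash_relevant_data = [["<per-image-hash>", boxes]]
    cache_key = hashlib.sha1(pickle.dumps(just_hash_relevant_data)).hexdigest()

pickle.dumps is stable for equal structures built the same way within a single
process, which is all an in-memory cache needs; a client that sends the same
boxes with dict keys in a different order simply produces a different key and
takes a cache miss rather than a wrong hit.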