Merge pull request #773 from roboflow/fix-classification-caching
Fix model type for classification
PawelPeczek-Roboflow authored Nov 7, 2024
2 parents 8c4fdd4 + 63de2ce commit c0f4bc7
Showing 4 changed files with 336 additions and 13 deletions.
83 changes: 71 additions & 12 deletions inference/core/cache/serializers.py
@@ -4,7 +4,14 @@
 
 from inference.core.devices.utils import GLOBAL_INFERENCE_SERVER_ID
 from inference.core.entities.requests.inference import InferenceRequest
-from inference.core.entities.responses.inference import InferenceResponse
+from inference.core.entities.responses.inference import (
+    ClassificationInferenceResponse,
+    InferenceResponse,
+    InstanceSegmentationInferenceResponse,
+    KeypointsDetectionInferenceResponse,
+    MultiLabelClassificationInferenceResponse,
+    ObjectDetectionInferenceResponse,
+)
 from inference.core.env import TINY_CACHE
 from inference.core.logger import logger
 from inference.core.version import __version__
@@ -33,7 +40,6 @@ def to_cachable_inference_item(
     }
     request = infer_request.dict(include=included_request_fields)
     response = build_condensed_response(infer_response)
-
     return {
        "inference_id": infer_request.id,
        "inference_server_version": __version__,
@@ -47,22 +53,75 @@ def build_condensed_response(responses):
     if not isinstance(responses, list):
         responses = [responses]
 
+    response_handlers = {
+        ClassificationInferenceResponse: from_classification_response,
+        MultiLabelClassificationInferenceResponse: from_multilabel_classification_response,
+        ObjectDetectionInferenceResponse: from_object_detection_response,
+        InstanceSegmentationInferenceResponse: from_instance_segmentation_response,
+        KeypointsDetectionInferenceResponse: from_keypoints_detection_response,
+    }
+
     formatted_responses = []
     for response in responses:
         if not getattr(response, "predictions", None):
             continue
         try:
-            predictions = [
-                {"confidence": pred.confidence, "class": pred.class_name}
-                for pred in response.predictions
-            ]
-            formatted_responses.append(
-                {
-                    "predictions": predictions,
-                    "time": response.time,
-                }
-            )
+            handler = None
+            for cls, h in response_handlers.items():
+                if isinstance(response, cls):
+                    handler = h
+                    break
+            if handler:
+                predictions = handler(response)
+                formatted_responses.append(
+                    {
+                        "predictions": predictions,
+                        "time": response.time,
+                    }
+                )
         except Exception as e:
             logger.warning(f"Error formatting response, skipping caching: {e}")
 
     return formatted_responses
+
+
+def from_classification_response(response: ClassificationInferenceResponse):
+    return [
+        {"class": pred.class_name, "confidence": pred.confidence}
+        for pred in response.predictions
+    ]
+
+
+def from_multilabel_classification_response(
+    response: MultiLabelClassificationInferenceResponse,
+):
+    return [
+        {"class": cls, "confidence": pred.confidence}
+        for cls, pred in response.predictions.items()
+    ]
+
+
+def from_object_detection_response(response: ObjectDetectionInferenceResponse):
+    return [
+        {"class": pred.class_name, "confidence": pred.confidence}
+        for pred in response.predictions
+    ]
+
+
+def from_instance_segmentation_response(
+    response: InstanceSegmentationInferenceResponse,
+):
+    return [
+        {"class": pred.class_name, "confidence": pred.confidence}
+        for pred in response.predictions
+    ]
+
+
+def from_keypoints_detection_response(response: KeypointsDetectionInferenceResponse):
+    predictions = []
+    for pred in response.predictions:
+        for keypoint in pred.keypoints:
+            predictions.append(
+                {"class": keypoint.class_name, "confidence": keypoint.confidence}
+            )
+    return predictions
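
The rewritten build_condensed_response swaps the old one-size-fits-all comprehension for per-type handlers dispatched via isinstance, so classification, multi-label, detection, segmentation, and keypoints responses all condense to the same {"class", "confidence"} shape. A minimal standalone sketch of that dispatch pattern, using a toy dataclass rather than the actual Roboflow response models (which need more fields to construct):

from dataclasses import dataclass

@dataclass
class ToyClassificationResponse:  # stand-in, not the Roboflow class
    predictions: list
    time: float = 0.0

def from_toy_classification(response):
    return [{"class": p["class"], "confidence": p["confidence"]} for p in response.predictions]

TOY_HANDLERS = {ToyClassificationResponse: from_toy_classification}

def condense(responses):
    if not isinstance(responses, list):
        responses = [responses]
    formatted = []
    for response in responses:
        if not getattr(response, "predictions", None):
            continue  # empty responses are skipped, as in the real serializer
        handler = next(
            (h for cls, h in TOY_HANDLERS.items() if isinstance(response, cls)), None
        )
        if handler:
            formatted.append({"predictions": handler(response), "time": response.time})
    return formatted

print(condense(ToyClassificationResponse(predictions=[{"class": "cat", "confidence": 0.8}])))
# [{'predictions': [{'class': 'cat', 'confidence': 0.8}], 'time': 0.0}]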
4 changes: 4 additions & 0 deletions inference/core/entities/requests/inference.py
@@ -202,6 +202,10 @@ class ClassificationInferenceRequest(CVInferenceRequest):
         visualize_predictions (Optional[bool]): If true, the predictions will be drawn on the original image and returned as a base64 string.
     """
 
+    def __init__(self, **kwargs):
+        kwargs["model_type"] = "classification"
+        super().__init__(**kwargs)
+
     confidence: Optional[float] = Field(
         default=0.4,
         examples=[0.5],
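
This __init__ override is the actual fix named in the PR title: it pins model_type to "classification" before Pydantic validation runs, so cached inference items record the right model type even when the caller omits or mis-sets the field. A minimal sketch of the same pattern on a toy Pydantic model (not the real request class; the protected_namespaces line is only needed on Pydantic v2, where a field named model_type otherwise triggers a warning):

from pydantic import BaseModel

class ToyClassificationRequest(BaseModel):  # stand-in, not the Roboflow class
    model_config = {"protected_namespaces": ()}  # Pydantic v2 only; see note above
    model_type: str = "unknown"
    confidence: float = 0.4

    def __init__(self, **kwargs):
        kwargs["model_type"] = "classification"  # force the correct type
        super().__init__(**kwargs)

assert ToyClassificationRequest(confidence=0.9).model_type == "classification"
# Even an explicit wrong value gets overridden:
assert ToyClassificationRequest(model_type="object_detection").model_type == "classification"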
1 change: 0 additions & 1 deletion inference/core/interfaces/http/http_api.py
@@ -2308,7 +2308,6 @@ async def legacy_infer_from_request(
         usage_billable=countinference,
         **args,
     )
-
     inference_response = await self.model_manager.infer_from_request(
         inference_request.model_id,
         inference_request,
261 changes: 261 additions & 0 deletions tests/inference/unit_tests/core/cache/test_serializers.py
@@ -0,0 +1,261 @@
import os
from unittest.mock import MagicMock
import pytest
from inference.core.cache.serializers import (
to_cachable_inference_item,
build_condensed_response,
)
from inference.core.entities.requests.inference import (
ClassificationInferenceRequest,
ObjectDetectionInferenceRequest,
)
from inference.core.entities.responses.inference import (
ClassificationInferenceResponse,
MultiLabelClassificationInferenceResponse,
InstanceSegmentationInferenceResponse,
KeypointsDetectionInferenceResponse,
ObjectDetectionInferenceResponse,
ObjectDetectionPrediction,
ClassificationPrediction,
MultiLabelClassificationPrediction,
InstanceSegmentationPrediction,
KeypointsPrediction,
Keypoint,
Point,
)


@pytest.fixture
def mock_classification_data():
mock_response = MagicMock(spec=ClassificationInferenceResponse)
predictions = [
ClassificationPrediction(**{"class": "cat", "class_id": 1, "confidence": 0.8})
]
mock_response.top = "cat"
mock_response.predictions = predictions
mock_response.confidence = 0.8
mock_response.time = "2023-10-01T12:00:00Z"
return mock_response


def test_build_condensed_response_single_classification(mock_classification_data):
mock_response = mock_classification_data
result = build_condensed_response(mock_response)
assert len(result) == 1
assert "predictions" in result[0]
assert "time" in result[0]


def test_build_condensed_response_multiple_classification(mock_classification_data):
mock_response = mock_classification_data
responses = [mock_response, mock_response]
result = build_condensed_response(responses)
assert len(result) == 2


def test_build_condensed_response_no_predictions_classification(
mock_classification_data,
):
mock_response = mock_classification_data
mock_response.predictions = None
result = build_condensed_response(mock_response)
assert len(result) == 0


@pytest.fixture
def mock_object_detection_data():
mock_request = MagicMock(spec=ObjectDetectionInferenceRequest)
mock_request.id = "test_id"
mock_request.confidence = 0.85
mock_request.dict.return_value = {
"api_key": "test_key",
"confidence": 0.85,
"model_id": "sharks",
"model_type": "object_detection",
}

mock_response = MagicMock(spec=ObjectDetectionInferenceResponse)
mock_response.predictions = [
ObjectDetectionPrediction(
**{
"class_name": "tiger-shark",
"confidence": 0.95,
"x": 0,
"y": 0,
"width": 0,
"height": 0,
"class_confidence": None,
"class_id": 1,
"class": "tiger-shark",
}
),
ObjectDetectionPrediction(
**{
"class_name": "hammerhead",
"confidence": 0.95,
"x": 0,
"y": 0,
"width": 0,
"height": 0,
"class_confidence": None,
"class_id": 2,
"class": "hammerhead",
}
),
ObjectDetectionPrediction(
**{
"class_name": "white-shark",
"confidence": 0.95,
"x": 0,
"y": 0,
"width": 0,
"height": 0,
"class_confidence": None,
"class_id": 3,
"class": "white-shark",
}
),
]
mock_response.time = "2023-10-01T12:00:00Z"

return mock_request, mock_response


def test_to_cachable_inference_item_no_tiny_cache_object_detection(
mock_object_detection_data,
):
mock_request, mock_response = mock_object_detection_data
os.environ["TINY_CACHE"] = "False"
result = to_cachable_inference_item(mock_request, mock_response)
assert result["inference_id"] == mock_request.id
assert result["request"]["api_key"] == mock_request.dict.return_value["api_key"]
assert (
result["response"][0]["predictions"][0]["class"]
== mock_response.predictions[0].class_name
)
assert (
result["response"][0]["predictions"][0]["confidence"]
== mock_response.predictions[0].confidence
)


def test_to_cachable_inference_item_with_tiny_cache_object_detection(
mock_object_detection_data,
):
mock_request, mock_response = mock_object_detection_data
os.environ["TINY_CACHE"] = "True"
result = to_cachable_inference_item(mock_request, mock_response)
assert result["inference_id"] == mock_request.id
assert result["request"]["api_key"] == mock_request.dict.return_value["api_key"]
assert (
result["response"][0]["predictions"][0]["class"]
== mock_response.predictions[0].class_name
)
assert (
result["response"][0]["predictions"][0]["confidence"]
== mock_response.predictions[0].confidence
)


def test_build_condensed_response_no_predictions_object_detection(
mock_object_detection_data,
):
_, mock_response = mock_object_detection_data
mock_response.predictions = None
result = build_condensed_response(mock_response)
assert len(result) == 0


@pytest.fixture
def mock_multilabel_classification_data():
mock_response = MagicMock(spec=MultiLabelClassificationInferenceResponse)
mock_response.predictions = {
"cat": MultiLabelClassificationPrediction(confidence=0.8, class_id=1),
"dog": MultiLabelClassificationPrediction(confidence=0.7, class_id=2),
}
mock_response.time = "2023-10-01T12:00:00Z"
return mock_response


@pytest.fixture
def mock_instance_segmentation_data():
mock_response = MagicMock(spec=InstanceSegmentationInferenceResponse)
mock_response.predictions = [
InstanceSegmentationPrediction(
**{
"class": "person",
"confidence": 0.9,
"class_confidence": None,
"detection_id": "1",
"parent_id": None,
"x": 0,
"y": 0,
"width": 0,
"height": 0,
"points": [Point(x=0, y=0)],
"class_id": 1,
}
)
]
mock_response.time = "2023-10-01T12:00:00Z"
return mock_response


@pytest.fixture
def mock_keypoints_detection_data():
mock_response = MagicMock(spec=KeypointsDetectionInferenceResponse)
mock_response.predictions = [
KeypointsPrediction(
**{
"class": "person",
"confidence": 0.9,
"class_confidence": None,
"detection_id": "1",
"parent_id": None,
"x": 0,
"y": 0,
"width": 0,
"height": 0,
"keypoints": [
Keypoint(
**{
"x": 0,
"y": 0,
"confidence": 0.8,
"class_id": 1,
"class_name": "nose",
}
)
],
"class_id": 1,
}
)
]
mock_response.time = "2023-10-01T12:00:00Z"
return mock_response


def test_build_condensed_response_instance_segmentation(
mock_instance_segmentation_data,
):
mock_response = mock_instance_segmentation_data
result = build_condensed_response(mock_response)
assert len(result) == 1
assert "predictions" in result[0]
assert "time" in result[0]


def test_build_condensed_response_keypoints_detection(mock_keypoints_detection_data):
mock_response = mock_keypoints_detection_data
result = build_condensed_response(mock_response)
assert len(result) == 1
assert "predictions" in result[0]
assert "time" in result[0]


def test_build_condensed_response_object_detection(mock_object_detection_data):
_, mock_response = mock_object_detection_data
result = build_condensed_response(mock_response)
assert len(result) == 1
assert "predictions" in result[0]
assert "time" in result[0]
