Merge branch 'main' into feature/model-registry-onboarding
lugi0 authored Dec 17, 2024
2 parents aef064f + 452e555 commit 5396178
Showing 4 changed files with 374 additions and 56 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -34,7 +34,7 @@ repos:
       - id: detect-secrets
 
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.8.1
+    rev: v0.8.3
     hooks:
       - id: ruff
       - id: ruff-format
14 changes: 13 additions & 1 deletion tests/trustyai/drift/test_drift.py
@@ -1,6 +1,6 @@
 import pytest
 
-from tests.trustyai.drift.utils import send_inference_requests_and_verify_trustyai_service
+from tests.trustyai.drift.utils import send_inference_requests_and_verify_trustyai_service, verify_metric_request
 
 
 @pytest.mark.parametrize(
@@ -17,6 +17,7 @@ class TestDriftMetrics:
     Verifies all the basic operations with a drift metric (meanshift) available in TrustyAI, using PVC storage.
     1. Send data to the model (gaussian_credit_model) and verify that TrustyAI registers the observations.
+    2. Send metric request (meanshift) and verify the response.
     """
 
     def test_send_inference_request_and_verify_trustyai_service(
@@ -36,3 +37,14 @@ def test_send_inference_request_and_verify_trustyai_service(
         )
 
     # TODO: Add rest of operations in upcoming PRs (upload data directly to Trusty, send metric request, schedule period metric calculation, delete metric request).
+
+    def test_drift_metric_meanshift(
+        self, admin_client, openshift_token, trustyai_service_with_pvc_storage, gaussian_credit_model
+    ):
+        verify_metric_request(
+            client=admin_client,
+            trustyai_service=trustyai_service_with_pvc_storage,
+            token=openshift_token,
+            metric_name="meanshift",
+            json_data={"modelId": gaussian_credit_model.name, "referenceTag": "TRAINING"},
+        )
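
In practice, the new test issues a single POST to the drift endpoint built by send_drift_request (added to utils.py below). A rough standalone sketch of the equivalent raw request follows; the route URL, token, and model name are placeholders rather than values from the test fixtures, and bearer-token auth is assumed based on the openshift_token fixture.

import requests

# Placeholders only; in the test these come from the admin_client, openshift_token,
# and trustyai_service_with_pvc_storage fixtures rather than hard-coded values.
trustyai_route = "https://trustyai-service-<namespace>.apps.<cluster-domain>"
token = "<openshift-token>"

response = requests.post(
    url=f"{trustyai_route}/metrics/drift/meanshift",
    headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"},
    json={"modelId": "gaussian_credit_model", "referenceTag": "TRAINING"},
    verify=False,  # test clusters commonly use self-signed route certificates
)
response.raise_for_status()
print(response.json()["value"])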
61 changes: 61 additions & 0 deletions tests/trustyai/drift/utils.py
@@ -1,3 +1,4 @@
+import http
 import json
 import os
 from typing import Any, Dict, List, Optional
@@ -21,6 +22,10 @@
 TIMEOUT_30SEC: int = 30
 
 
+class MetricValidationError(Exception):
+    pass
+
+
 class TrustyAIServiceRequestHandler:
     """
     Class to encapsulate the behaviors associated to the different TrustyAIService requests we make in the tests
@@ -54,6 +59,14 @@ def _send_request(
     def get_model_metadata(self) -> Any:
         return self._send_request(endpoint="/info", method="GET")
 
+    def send_drift_request(
+        self,
+        metric_name: str,
+        json: Optional[Dict[str, Any]] = None,
+    ) -> Any:
+        LOGGER.info(f"Sending request for drift metric: {metric_name}")
+        return self._send_request(endpoint=f"/metrics/drift/{metric_name}", method="POST", json=json)
+
 
 # TODO: Refactor code to be under utilities.inference_utils.Inference
 def send_inference_request(
@@ -238,3 +251,51 @@ def _check_pods_ready_with_env() -> bool:
     for sample in samples:
         if sample:
             return
+
+
+def verify_metric_request(
+    client: DynamicClient, trustyai_service: TrustyAIService, token: str, metric_name: str, json_data: Any
+) -> None:
+    """
+    Sends a metric request to a TrustyAIService and validates the response.
+
+    Args:
+        client (DynamicClient): The client instance for interacting with the cluster.
+        trustyai_service (TrustyAIService): The TrustyAI service instance to interact with.
+        token (str): Authentication token for the service.
+        metric_name (str): Name of the metric to request.
+        json_data (Any): JSON payload for the metric request.
+
+    Raises:
+        MetricValidationError: If any of the response fields do not have the expected value.
+    """
+    response = TrustyAIServiceRequestHandler(token=token, service=trustyai_service, client=client).send_drift_request(
+        metric_name=metric_name, json=json_data
+    )
+    LOGGER.info(msg=f"TrustyAI metric request response: {json.dumps(json.loads(response.text), indent=2)}")
+    response_data = json.loads(response.text)
+
+    errors = []
+
+    if response.status_code != http.HTTPStatus.OK:
+        errors.append(f"Unexpected status code: {response.status_code}")
+    if response_data["timestamp"] == "":
+        errors.append("Timestamp is empty")
+    if response_data["type"] != "metric":
+        errors.append("Incorrect type")
+    if response_data["value"] == "":
+        errors.append("Value is empty")
+    if not isinstance(response_data["value"], float):
+        errors.append("Value must be a float")
+    if response_data["specificDefinition"] == "":
+        errors.append("Specific definition is empty")
+    if response_data["name"] != metric_name:
+        errors.append(f"Wrong name: {response_data['name']}, expected: {metric_name}")
+    if response_data["id"] == "":
+        errors.append("ID is empty")
+    if response_data["thresholds"] == "":
+        errors.append("Thresholds are empty")
+
+    if errors:
+        raise MetricValidationError("\n".join(errors))
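
For reference, a response body shaped like the sketch below would pass every check in verify_metric_request. The values are illustrative only, and the thresholds structure is an assumption, since the function only requires that field to be non-empty.

# Illustrative example: keys and types mirror what verify_metric_request asserts on,
# but the concrete values are made up.
example_response_body = {
    "timestamp": "2024-12-17T10:15:30+00:00",
    "type": "metric",  # must be exactly "metric"
    "value": 0.12,  # must be a float
    "specificDefinition": "Illustrative definition text for the meanshift metric",
    "name": "meanshift",  # must equal the requested metric_name
    "id": "00000000-0000-0000-0000-000000000000",
    "thresholds": {"lowerBound": 0.0, "upperBound": 0.05},  # structure assumed; only non-emptiness is checked
}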