Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HWORKS-937] Explicit model provenance - hsml optimization - init serving/batching #236

Merged
merged 3 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions python/hsml/core/serving_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def get_by_id(self, id: int):
str(id),
]
deployment_json = _client._send_request("GET", path_params)
return deployment.Deployment.from_response_json(deployment_json)
deployment_instance = deployment.Deployment.from_response_json(deployment_json)
deployment_instance.model_registry_id = _client._project_id
return deployment_instance

def get(self, name: str):
"""Get the metadata of a deployment with a certain name.
Expand All @@ -71,7 +73,9 @@ def get(self, name: str):
deployment_json = _client._send_request(
"GET", path_params, query_params=query_params
)
return deployment.Deployment.from_response_json(deployment_json)
deployment_instance = deployment.Deployment.from_response_json(deployment_json)
deployment_instance.model_registry_id = _client._project_id
return deployment_instance

def get_all(self, model_name: str = None, status: str = None):
"""Get the metadata of all deployments.
Expand All @@ -89,7 +93,12 @@ def get_all(self, model_name: str = None, status: str = None):
deployments_json = _client._send_request(
"GET", path_params, query_params=query_params
)
return deployment.Deployment.from_response_json(deployments_json)
deployment_instances = deployment.Deployment.from_response_json(
deployments_json
)
for deployment_instance in deployment_instances:
deployment_instance.model_registry_id = _client._project_id
return deployment_instances

def get_inference_endpoints(self):
"""Get inference endpoints.
Expand Down Expand Up @@ -119,14 +128,16 @@ def put(self, deployment_instance):
if deployment_instance.artifact_version == ARTIFACT_VERSION.CREATE:
deployment_instance.artifact_version = -1

return deployment_instance.update_from_response_json(
deployment_instance = deployment_instance.update_from_response_json(
_client._send_request(
"PUT",
path_params,
headers=headers,
data=deployment_instance.json(),
)
)
deployment_instance.model_registry_id = _client._project_id
return deployment_instance

def post(self, deployment_instance, action: str):
"""Perform an action on the deployment
Expand Down Expand Up @@ -195,7 +206,10 @@ def reset_changes(self, deployment_instance):
deployment_json = _client._send_request(
"GET", path_params, query_params=query_params
)
return deployment_instance.update_from_response_json(deployment_json)
deployment_aux = deployment_instance.update_from_response_json(deployment_json)
# TODO: remove when model_registry_id is added properly to deployments in backend
deployment_aux.model_registry_id = _client._project_id
return deployment_aux

def send_inference_request(
self,
Expand Down
19 changes: 18 additions & 1 deletion python/hsml/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from hsml.client.exceptions import ModelServingException
from hsml.client.istio.utils.infer_type import InferInput
from hsml.constants import DEPLOYABLE_COMPONENT, PREDICTOR_STATE
from hsml.core import serving_api
from hsml.core import model_api, serving_api
from hsml.engine import serving_engine
from hsml.inference_batcher import InferenceBatcher
from hsml.inference_logger import InferenceLogger
Expand Down Expand Up @@ -59,7 +59,9 @@ def __init__(

self._serving_api = serving_api.ServingApi()
self._serving_engine = serving_engine.ServingEngine()
self._model_api = model_api.ModelApi()
self._grpc_channel = None
self._model_registry_id = None

def save(self, await_update: Optional[int] = 60):
"""Persist this deployment including the predictor and metadata to Model Serving.
Expand Down Expand Up @@ -203,6 +205,12 @@ def predict(

return self._serving_engine.predict(self, data, inputs)

def get_model(self):
    """Retrieve the metadata object for the model being used by this deployment.

    The lookup is delegated to the model API, using this deployment's model
    name, model version and model registry id.

    # Returns
        The object returned by the model API for this deployment's model.
    """
    return self._model_api.get(
        self.model_name, self.model_version, self.model_registry_id
    )

def download_artifact(self):
"""Download the model artifact served by the deployment"""

Expand Down Expand Up @@ -425,6 +433,15 @@ def transformer(self):
def transformer(self, transformer: Transformer):
self._predictor.transformer = transformer

@property
def model_registry_id(self) -> Optional[int]:
    """Id of the model registry associated with this deployment (`None` until it is assigned)."""
    return self._model_registry_id

@model_registry_id.setter
def model_registry_id(self, model_registry_id: int) -> None:
    # Plain assignment; the value is stored as given, without validation.
    self._model_registry_id = model_registry_id

@property
def created_at(self):
"""Created at date of the predictor."""
Expand Down
25 changes: 23 additions & 2 deletions python/hsml/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#

import json
import logging
import os
import warnings
from typing import Any, Dict, Optional, Union

Expand All @@ -31,6 +33,9 @@
from hsml.transformer import Transformer


_logger = logging.getLogger(__name__)


class Model:
"""Metadata object representing a model in the Model Registry."""

Expand Down Expand Up @@ -280,7 +285,7 @@ def get_url(self):
)
return util.get_hostname_replaced_url(sub_path=path)

def get_feature_view(self, init: bool = True, online: Optional[bool] = None):
    """Get the parent feature view of this model, based on explicit provenance.
    Only accessible, usable feature view objects are returned. Otherwise an Exception is raised.
    For more details, call the base method - get_feature_view_provenance

    # Arguments
        init: If `True` (default), the parent training dataset of this model is looked up
            and, depending on `online` and the runtime environment, the feature view is
            initialized before it is returned. If `False`, the feature view is returned as is.
        online: `True` initializes the feature view for batch and online retrieval of
            feature vectors; `False` initializes it for batch retrieval only. If `None`
            (default), the feature view is only initialized when running within a
            deployment, i.e. when the `DEPLOYMENT_NAME` environment variable is set.
            Within a deployment the feature view is always initialized for serving,
            regardless of `online`.

    # Returns
        The accessible parent feature view, or `None` if no accessible parent was found.

    # Raises
        `Exception` in case the backend fails to retrieve the tags.
    """
    fv_prov = self.get_feature_view_provenance()
    fv = explicit_provenance.Links.get_one_accessible_parent(fv_prov)
    if fv is None:
        return None
    if not init:
        return fv

    td_prov = self.get_training_dataset_provenance()
    td = explicit_provenance.Links.get_one_accessible_parent(td_prov)
    is_deployment = "DEPLOYMENT_NAME" in os.environ
    init_for_serving = bool(online) or is_deployment
    init_for_batch = not init_for_serving and online is False

    if not (init_for_serving or init_for_batch):
        # online is None and we are not inside a deployment: nothing to initialize.
        return fv

    if td is None:
        # Without an accessible training dataset there is no version to initialize
        # with; return the feature view uninitialized instead of failing on `td.version`.
        _logger.warning(
            "Feature view could not be initialized: "
            "no accessible parent training dataset found for this model"
        )
        return fv

    if init_for_serving:
        _logger.info(
            "Initializing for batch and online retrieval of feature vectors"
            + (" - within a deployment" if is_deployment else "")
        )
        fv.init_serving(training_dataset_version=td.version)
    else:
        _logger.info("Initializing for batch retrieval of feature vectors")
        fv.init_batch_scoring(training_dataset_version=td.version)
    return fv

def get_feature_view_provenance(self):
"""Get the parent feature view of this model, based on explicit provenance.
Expand Down
5 changes: 4 additions & 1 deletion python/hsml/model_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
#

import os
from typing import Optional, Union

from hsml import util
Expand Down Expand Up @@ -66,7 +67,7 @@ def get_deployment_by_id(self, id: int):

return self._serving_api.get_by_id(id)

def get_deployment(self, name: str):
def get_deployment(self, name: str = None):
"""Get a deployment by name from Model Serving.

!!! example
Expand All @@ -88,6 +89,8 @@ def get_deployment(self, name: str):
`RestAPIError`: If unable to retrieve deployment from model serving.
"""

if name is None and ("DEPLOYMENT_NAME" in os.environ):
name = os.environ["DEPLOYMENT_NAME"]
return self._serving_api.get(name)

def get_deployments(self, model: Model = None, status: str = None):
Expand Down
70 changes: 70 additions & 0 deletions python/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
#

import copy
import os

import humps
from hsml import model
from hsml.constants import MODEL
from hsml.core import explicit_provenance


class TestModel:
Expand Down Expand Up @@ -398,3 +400,71 @@ def assert_model(self, mocker, m, m_json, model_framework):
mock_read_file.assert_called_once_with(
model_instance=m, resource=m_json["environment"]
)

def test_get_feature_view(self, mocker):
    """With default arguments and outside a deployment, no initialization is performed."""
    mock_fv = mocker.Mock()
    links = explicit_provenance.Links(accessible=[mock_fv])
    mock_fv_provenance = mocker.patch(
        "hsml.model.Model.get_feature_view_provenance", return_value=links
    )
    mock_td_provenance = mocker.patch(
        "hsml.model.Model.get_training_dataset_provenance", return_value=links
    )
    # Work on a snapshot of the real environment (restored on teardown) and make
    # sure the deployment marker is absent, instead of swapping os.environ for a mock.
    mocker.patch.dict(os.environ)
    os.environ.pop("DEPLOYMENT_NAME", None)
    m = model.Model(1, "test")
    m.get_feature_view()
    mock_fv_provenance.assert_called_once()
    mock_td_provenance.assert_called_once()
    assert not mock_fv.init_serving.called
    assert not mock_fv.init_batch_scoring.called

def test_get_feature_view_online(self, mocker):
    """`online=True` initializes the feature view for serving, not for batch scoring."""
    mock_fv = mocker.Mock()
    links = explicit_provenance.Links(accessible=[mock_fv])
    mock_fv_provenance = mocker.patch(
        "hsml.model.Model.get_feature_view_provenance", return_value=links
    )
    mock_td_provenance = mocker.patch(
        "hsml.model.Model.get_training_dataset_provenance", return_value=links
    )
    # Work on a snapshot of the real environment (restored on teardown) and make
    # sure the deployment marker is absent, instead of swapping os.environ for a mock.
    mocker.patch.dict(os.environ)
    os.environ.pop("DEPLOYMENT_NAME", None)
    m = model.Model(1, "test")
    m.get_feature_view(online=True)
    mock_fv_provenance.assert_called_once()
    mock_td_provenance.assert_called_once()
    assert mock_fv.init_serving.called
    assert not mock_fv.init_batch_scoring.called

def test_get_feature_view_batch(self, mocker):
    """`online=False` outside a deployment initializes the feature view for batch scoring only."""
    mock_fv = mocker.Mock()
    links = explicit_provenance.Links(accessible=[mock_fv])
    mock_fv_provenance = mocker.patch(
        "hsml.model.Model.get_feature_view_provenance", return_value=links
    )
    mock_td_provenance = mocker.patch(
        "hsml.model.Model.get_training_dataset_provenance", return_value=links
    )
    # Work on a snapshot of the real environment (restored on teardown) and make
    # sure the deployment marker is absent, instead of swapping os.environ for a mock.
    mocker.patch.dict(os.environ)
    os.environ.pop("DEPLOYMENT_NAME", None)
    m = model.Model(1, "test")
    m.get_feature_view(online=False)
    mock_fv_provenance.assert_called_once()
    mock_td_provenance.assert_called_once()
    assert not mock_fv.init_serving.called
    assert mock_fv.init_batch_scoring.called

def test_get_feature_view_deployment(self, mocker):
    """Inside a deployment (DEPLOYMENT_NAME set), the default call initializes for serving."""
    fv_mock = mocker.Mock()
    provenance_links = explicit_provenance.Links(accessible=[fv_mock])
    fv_provenance_patch = mocker.patch(
        "hsml.model.Model.get_feature_view_provenance",
        return_value=provenance_links,
    )
    td_provenance_patch = mocker.patch(
        "hsml.model.Model.get_training_dataset_provenance",
        return_value=provenance_links,
    )
    # Simulate running within a deployment for the duration of this test.
    mocker.patch.dict(os.environ, {"DEPLOYMENT_NAME": "test"})

    model.Model(1, "test").get_feature_view()

    fv_provenance_patch.assert_called_once()
    td_provenance_patch.assert_called_once()
    fv_mock.init_serving.assert_called()
    fv_mock.init_batch_scoring.assert_not_called()