From b16235fcda29ed343fc0ab93538daff7ac6e9711 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 3 Aug 2023 09:43:22 -0700 Subject: [PATCH] Add Deployment model and deployments collection property on Client Signed-off-by: Mattt Zmuda --- replicate/__init__.py | 1 + replicate/client.py | 5 ++ replicate/deployment.py | 140 +++++++++++++++++++++++++++++++++++++++ tests/test_deployment.py | 47 +++++++++++++ 4 files changed, 193 insertions(+) create mode 100644 replicate/deployment.py create mode 100644 tests/test_deployment.py diff --git a/replicate/__init__.py b/replicate/__init__.py index 0d35ad3b..c8aaeb01 100644 --- a/replicate/__init__.py +++ b/replicate/__init__.py @@ -5,3 +5,4 @@ models = default_client.models predictions = default_client.predictions trainings = default_client.trainings +deployments = default_client.deployments diff --git a/replicate/client.py b/replicate/client.py index 91a2bf07..e78296a4 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -8,6 +8,7 @@ from requests.cookies import RequestsCookieJar from replicate.__about__ import __version__ +from replicate.deployment import DeploymentCollection from replicate.exceptions import ModelError, ReplicateError from replicate.model import ModelCollection from replicate.prediction import PredictionCollection @@ -113,6 +114,10 @@ def predictions(self) -> PredictionCollection: def trainings(self) -> TrainingCollection: return TrainingCollection(client=self) + @property + def deployments(self) -> DeploymentCollection: + return DeploymentCollection(client=self) + def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: """ Run a model and wait for its output. 
diff --git a/replicate/deployment.py b/replicate/deployment.py
new file mode 100644
index 00000000..9cab3a7b
--- /dev/null
+++ b/replicate/deployment.py
@@ -0,0 +1,178 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from replicate.base_model import BaseModel
+from replicate.collection import Collection
+from replicate.files import upload_file
+from replicate.json import encode_json
+from replicate.prediction import Prediction
+
+if TYPE_CHECKING:
+    from replicate.client import Client
+
+
+class Deployment(BaseModel):
+    """
+    A deployment of a model hosted on Replicate.
+    """
+
+    username: str
+    """
+    The name of the user or organization that owns the deployment.
+    """
+
+    name: str
+    """
+    The name of the deployment.
+    """
+
+    @property
+    def predictions(self) -> "DeploymentPredictionCollection":
+        """
+        Get the predictions for this deployment.
+        """
+
+        return DeploymentPredictionCollection(client=self._client, deployment=self)
+
+
+class DeploymentCollection(Collection):
+    """
+    A collection of deployments.
+    """
+
+    model = Deployment
+
+    def list(self) -> List[Deployment]:
+        """
+        List deployments. Not implemented.
+        """
+
+        raise NotImplementedError()
+
+    def get(self, name: str) -> Deployment:
+        """
+        Get a deployment by name.
+
+        The deployment is built from `name`; it is not yet fetched from the server.
+
+        Args:
+            name: The name of the deployment, in the format `owner/deployment-name`.
+        Returns:
+            The deployment.
+        Raises:
+            ValueError: If `name` is not in the format `owner/deployment-name`.
+        """
+
+        # TODO: fetch model from server
+        # TODO: support permanent IDs
+        parts = name.split("/")
+        if len(parts) != 2 or not all(parts):
+            raise ValueError(
+                f"Invalid deployment name {name!r}; expected `owner/deployment-name`"
+            )
+        username, deployment_name = parts
+        return self.prepare_model({"username": username, "name": deployment_name})
+
+    def create(self, **kwargs) -> Deployment:
+        """
+        Create a deployment. Not implemented.
+        """
+
+        raise NotImplementedError()
+
+    def prepare_model(self, attrs: Union[Deployment, Dict]) -> Deployment:
+        """
+        Set `id` to `username/name` on `attrs`, then defer to the base collection.
+        """
+
+        if isinstance(attrs, BaseModel):
+            attrs.id = f"{attrs.username}/{attrs.name}"
+        elif isinstance(attrs, dict):
+            attrs["id"] = f"{attrs['username']}/{attrs['name']}"
+        return super().prepare_model(attrs)
+
+
+class DeploymentPredictionCollection(Collection):
+    """
+    A collection of predictions made with a specific deployment.
+    """
+
+    model = Prediction
+
+    def __init__(self, client: "Client", deployment: Deployment) -> None:
+        super().__init__(client=client)
+        self._deployment = deployment
+
+    def list(self) -> List[Prediction]:
+        """
+        List predictions for the deployment. Not implemented.
+        """
+
+        raise NotImplementedError()
+
+    def get(self, id: str) -> Prediction:
+        """
+        Get a prediction by ID.
+
+        Args:
+            id: The ID of the prediction.
+        Returns:
+            Prediction: The prediction object.
+        """
+
+        resp = self._client._request("GET", f"/v1/predictions/{id}")
+        obj = resp.json()
+        # HACK: resolve this? make it lazy somehow?
+        # Drop the version without assuming the response includes one.
+        obj.pop("version", None)
+        return self.prepare_model(obj)
+
+    def create(  # type: ignore
+        self,
+        input: Dict[str, Any],
+        webhook: Optional[str] = None,
+        webhook_completed: Optional[str] = None,
+        webhook_events_filter: Optional[List[str]] = None,
+        *,
+        stream: Optional[bool] = None,
+        **kwargs,
+    ) -> Prediction:
+        """
+        Create a new prediction with the deployment.
+
+        Args:
+            input: The input data for the prediction.
+            webhook: The URL to receive a POST request with prediction updates.
+            webhook_completed: The URL to receive a POST request when the prediction is completed.
+            webhook_events_filter: List of events to trigger webhooks.
+            stream: Set to True to enable streaming of prediction output.
+            **kwargs: Additional keyword arguments are accepted but not used.
+
+        Returns:
+            Prediction: The created prediction object.
+        """
+
+        input = encode_json(input, upload_file=upload_file)
+        body: Dict[str, Any] = {
+            "input": input,
+        }
+        if webhook is not None:
+            body["webhook"] = webhook
+        if webhook_completed is not None:
+            body["webhook_completed"] = webhook_completed
+        if webhook_events_filter is not None:
+            body["webhook_events_filter"] = webhook_events_filter
+        if stream is True:
+            # NOTE(review): sent as the string "true", not a JSON boolean -- confirm
+            # this is what the API expects.
+            body["stream"] = "true"
+
+        resp = self._client._request(
+            "POST",
+            f"/v1/deployments/{self._deployment.username}/{self._deployment.name}/predictions",
+            json=body,
+        )
+        obj = resp.json()
+        obj["deployment"] = self._deployment
+        # Drop the version without assuming the response includes one.
+        obj.pop("version", None)
+        return self.prepare_model(obj)
diff --git a/tests/test_deployment.py b/tests/test_deployment.py
new file mode 100644
index 00000000..8ec63a77
--- /dev/null
+++ b/tests/test_deployment.py
@@ -0,0 +1,58 @@
+import responses
+from responses import matchers
+
+from replicate.client import Client
+
+
+@responses.activate
+def test_deployment_predictions_create():
+    client = Client(api_token="abc123")
+
+    deployment = client.deployments.get("test/model")
+
+    rsp = responses.post(
+        "https://api.replicate.com/v1/deployments/test/model/predictions",
+        match=[
+            matchers.json_params_matcher(
+                {
+                    "input": {"text": "world"},
+                    "webhook": "https://example.com/webhook",
+                    "webhook_events_filter": ["completed"],
+                }
+            ),
+        ],
+        json={
+            "id": "p1",
+            "version": "v1",
+            "urls": {
+                "get": "https://api.replicate.com/v1/predictions/p1",
+                "cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
+            },
+            "created_at": "2022-04-26T20:00:40.658234Z",
+            "source": "api",
+            "status": "processing",
+            "input": {"text": "hello"},
+            "output": None,
+            "error": None,
+            "logs": "",
+        },
+    )
+
+    deployment.predictions.create(
+        input={"text": "world"},
+        webhook="https://example.com/webhook",
+        webhook_events_filter=["completed"],
+    )
+
+    assert rsp.call_count == 1
+
+
+def test_deployment_get_rejects_malformed_name():
+    client = Client(api_token="abc123")
+
+    for name in ("model", "owner/model/extra", "/model", "owner/", ""):
+        try:
+            client.deployments.get(name)
+        except ValueError:
+            continue
+        raise AssertionError(f"expected ValueError for {name!r}")