Skip to content

Commit

Permalink
Add Deployment model and deployments collection property on Client
Browse files Browse the repository at this point in the history
Signed-off-by: Mattt Zmuda <[email protected]>
  • Loading branch information
mattt committed Sep 11, 2023
1 parent 2b1da58 commit b16235f
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 0 deletions.
1 change: 1 addition & 0 deletions replicate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
# Module-level aliases for the collections exposed by `default_client`.
models = default_client.models
predictions = default_client.predictions
trainings = default_client.trainings
deployments = default_client.deployments
5 changes: 5 additions & 0 deletions replicate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from requests.cookies import RequestsCookieJar

from replicate.__about__ import __version__
from replicate.deployment import DeploymentCollection
from replicate.exceptions import ModelError, ReplicateError
from replicate.model import ModelCollection
from replicate.prediction import PredictionCollection
Expand Down Expand Up @@ -113,6 +114,10 @@ def predictions(self) -> PredictionCollection:
def trainings(self) -> TrainingCollection:
return TrainingCollection(client=self)

@property
def deployments(self) -> DeploymentCollection:
    """
    Return a `DeploymentCollection` bound to this client.

    A new collection object is constructed on every access.
    """
    collection = DeploymentCollection(client=self)
    return collection

def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]:
"""
Run a model and wait for its output.
Expand Down
140 changes: 140 additions & 0 deletions replicate/deployment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from replicate.base_model import BaseModel
from replicate.collection import Collection
from replicate.files import upload_file
from replicate.json import encode_json
from replicate.prediction import Prediction

if TYPE_CHECKING:
from replicate.client import Client


class Deployment(BaseModel):
    """
    Represents a model deployment hosted on Replicate.

    A deployment is described by the `username` of its owner and its `name`.
    """

    username: str
    """
    The user or organization that owns this deployment.
    """

    name: str
    """
    The deployment's name.
    """

    @property
    def predictions(self) -> "DeploymentPredictionCollection":
        """
        Return a collection for working with this deployment's predictions.

        The collection is built with this deployment's client and with the
        deployment itself.
        """

        collection = DeploymentPredictionCollection(
            client=self._client,
            deployment=self,
        )
        return collection


class DeploymentCollection(Collection):
    """
    Collection of deployments, reached through a client.

    Only `get` is implemented; `list` and `create` raise `NotImplementedError`.
    """

    model = Deployment

    def list(self) -> List[Deployment]:
        """
        List deployments. Not implemented.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def get(self, name: str) -> Deployment:
        """
        Get a deployment by name.

        This method does not query the API itself (see the TODOs below); the
        deployment is built from the two parts of `name`.

        Args:
            name: The name of the deployment, in the format `owner/name`.
        Returns:
            The deployment.
        Raises:
            ValueError: If `name` is not exactly two non-empty parts
                separated by a single `/`.
        """

        # TODO: fetch model from server
        # TODO: support permanent IDs
        parts = name.split("/")
        # Validate explicitly so a malformed name produces a readable error
        # instead of an opaque tuple-unpacking failure, and so empty parts
        # never end up in a request path.
        if len(parts) != 2 or not all(parts):
            raise ValueError(
                f"Invalid deployment name {name!r}: expected the format `owner/name`"
            )
        username, deployment_name = parts
        return self.prepare_model({"username": username, "name": deployment_name})

    def create(self, **kwargs) -> Deployment:
        """
        Create a deployment. Not implemented.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def prepare_model(self, attrs: Union[Deployment, Dict]) -> Deployment:
        """
        Set the `id` of a deployment to `username/name`, then delegate to the
        parent collection's `prepare_model`.

        Args:
            attrs: Either a model instance or a dict of attributes. The `id`
                is written into `attrs` in place.
        Returns:
            Whatever the parent `prepare_model` returns for `attrs`.
        """
        if isinstance(attrs, BaseModel):
            attrs.id = f"{attrs.username}/{attrs.name}"
        elif isinstance(attrs, dict):
            attrs["id"] = f"{attrs['username']}/{attrs['name']}"
        return super().prepare_model(attrs)


class DeploymentPredictionCollection(Collection):
    """
    Collection of predictions tied to a single deployment.

    New predictions are created through the deployment's predictions endpoint.
    """

    model = Prediction

    def __init__(self, client: "Client", deployment: Deployment) -> None:
        """
        Args:
            client: The client passed on to the parent collection.
            deployment: The deployment whose `username` and `name` are used
                when creating predictions.
        """
        super().__init__(client=client)
        self._deployment = deployment

    def list(self) -> List[Prediction]:
        """
        List predictions. Not implemented.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def get(self, id: str) -> Prediction:
        """
        Get a prediction by ID.

        Args:
            id: The ID of the prediction.
        Returns:
            Prediction: The prediction object.
        """

        resp = self._client._request("GET", f"/v1/predictions/{id}")
        obj = resp.json()
        # HACK: resolve this? make it lazy somehow?
        # Drop `version` if present; tolerate responses that omit the key
        # instead of failing with a KeyError.
        obj.pop("version", None)
        return self.prepare_model(obj)

    def create(  # type: ignore
        self,
        input: Dict[str, Any],
        webhook: Optional[str] = None,
        webhook_completed: Optional[str] = None,
        webhook_events_filter: Optional[List[str]] = None,
        *,
        stream: Optional[bool] = None,
        **kwargs,
    ) -> Prediction:
        """
        Create a new prediction with the deployment.

        Args:
            input: The input data for the prediction.
            webhook: The URL to receive a POST request with prediction updates.
            webhook_completed: The URL to receive a POST request when the prediction is completed.
            webhook_events_filter: List of events to trigger webhooks.
            stream: Set to True to enable streaming of prediction output.
            **kwargs: Accepted for call compatibility; not used by this method.
        Returns:
            Prediction: The created prediction object.
        """

        encoded_input = encode_json(input, upload_file=upload_file)
        # Values are heterogeneous (dict, str, list, bool), hence Dict[str, Any].
        body: Dict[str, Any] = {
            "input": encoded_input,
        }
        if webhook is not None:
            body["webhook"] = webhook
        if webhook_completed is not None:
            body["webhook_completed"] = webhook_completed
        if webhook_events_filter is not None:
            body["webhook_events_filter"] = webhook_events_filter
        if stream is True:
            # Send a real JSON boolean rather than the string "true".
            body["stream"] = True

        resp = self._client._request(
            "POST",
            f"/v1/deployments/{self._deployment.username}/{self._deployment.name}/predictions",
            json=body,
        )
        obj = resp.json()
        obj["deployment"] = self._deployment
        # Drop `version` if present; tolerate responses that omit the key.
        obj.pop("version", None)
        return self.prepare_model(obj)
47 changes: 47 additions & 0 deletions tests/test_deployment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import responses
from responses import matchers

from replicate.client import Client


@responses.activate
def test_deployment_predictions_create():
    """
    Creating a prediction through a deployment sends one POST with the
    expected JSON body to the deployment's predictions endpoint.
    """
    # JSON body the mocked endpoint requires in order to match the request.
    expected_request_body = {
        "input": {"text": "world"},
        "webhook": "https://example.com/webhook",
        "webhook_events_filter": ["completed"],
    }
    # Payload the mocked endpoint answers with.
    mocked_response_body = {
        "id": "p1",
        "version": "v1",
        "urls": {
            "get": "https://api.replicate.com/v1/predictions/p1",
            "cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
        },
        "created_at": "2022-04-26T20:00:40.658234Z",
        "source": "api",
        "status": "processing",
        "input": {"text": "hello"},
        "output": None,
        "error": None,
        "logs": "",
    }

    client = Client(api_token="abc123")
    deployment = client.deployments.get("test/model")

    body_matcher = matchers.json_params_matcher(expected_request_body)
    rsp = responses.post(
        "https://api.replicate.com/v1/deployments/test/model/predictions",
        match=[body_matcher],
        json=mocked_response_body,
    )

    deployment.predictions.create(
        input={"text": "world"},
        webhook="https://example.com/webhook",
        webhook_events_filter=["completed"],
    )

    assert rsp.call_count == 1

0 comments on commit b16235f

Please sign in to comment.