Skip to content

Commit

Permalink
Fix signature of create methods (#181)
Browse files Browse the repository at this point in the history
Provides backwards-compatible migration to typed kwargs following [PEP
692](https://peps.python.org/pep-0692/).

---------

Signed-off-by: Mattt Zmuda <[email protected]>
  • Loading branch information
mattt authored Nov 2, 2023
1 parent 39d6bc9 commit e531cff
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 46 deletions.
4 changes: 3 additions & 1 deletion replicate/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def get(self, key: str) -> Model: # pylint: disable=missing-function-docstring
pass

@abc.abstractmethod
def create(self, **kwargs) -> Model: # pylint: disable=missing-function-docstring
def create( # pylint: disable=missing-function-docstring
self, *args, **kwargs
) -> Model:
pass

def prepare_model(self, attrs: Union[Model, Dict]) -> Model:
Expand Down
59 changes: 43 additions & 16 deletions replicate/deployment.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, overload

from typing_extensions import Unpack

from replicate.base_model import BaseModel
from replicate.collection import Collection
from replicate.files import upload_file
from replicate.json import encode_json
from replicate.prediction import Prediction
from replicate.prediction import Prediction, PredictionCollection

if TYPE_CHECKING:
from replicate.client import Client
Expand Down Expand Up @@ -65,7 +67,11 @@ def get(self, name: str) -> Deployment:
username, name = name.split("/")
return self.prepare_model({"username": username, "name": name})

def create(self, **kwargs) -> Deployment:
def create(
self,
*args,
**kwargs,
) -> Deployment:
"""
Create a deployment.
Expand Down Expand Up @@ -114,15 +120,34 @@ def get(self, id: str) -> Prediction:
del obj["version"]
return self.prepare_model(obj)

def create( # type: ignore
@overload
def create( # pylint: disable=arguments-differ disable=too-many-arguments
self,
input: Dict[str, Any],
*,
webhook: Optional[str] = None,
webhook_completed: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
stream: Optional[bool] = None,
) -> Prediction:
...

@overload
def create( # pylint: disable=arguments-differ disable=too-many-arguments
self,
*,
input: Dict[str, Any],
webhook: Optional[str] = None,
webhook_completed: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
stream: Optional[bool] = None,
**kwargs,
) -> Prediction:
...

def create(
self,
*args,
**kwargs: Unpack[PredictionCollection.CreateParams], # type: ignore[misc]
) -> Prediction:
"""
Create a new prediction with the deployment.
Expand All @@ -138,18 +163,20 @@ def create( # type: ignore
Prediction: The created prediction object.
"""

input = encode_json(input, upload_file=upload_file)
body: Dict[str, Any] = {
"input": input,
input = args[0] if len(args) > 0 else kwargs.get("input")
if input is None:
raise ValueError(
"An input must be provided as a positional or keyword argument."
)

body = {
"input": encode_json(input, upload_file=upload_file),
}
if webhook is not None:
body["webhook"] = webhook
if webhook_completed is not None:
body["webhook_completed"] = webhook_completed
if webhook_events_filter is not None:
body["webhook_events_filter"] = webhook_events_filter
if stream is True:
body["stream"] = True

for key in ["webhook", "webhook_completed", "webhook_events_filter", "stream"]:
value = kwargs.get(key)
if value is not None:
body[key] = value

resp = self._client._request(
"POST",
Expand Down
6 changes: 5 additions & 1 deletion replicate/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,11 @@ def get(self, key: str) -> Model:
resp = self._client._request("GET", f"/v1/models/{key}")
return self.prepare_model(resp.json())

def create(self, **kwargs) -> Model:
def create(
self,
*args,
**kwargs,
) -> Model:
"""
Create a model.
Expand Down
69 changes: 55 additions & 14 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import re
import time
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Union
from typing import Any, Dict, Iterator, List, Optional, TypedDict, Union, overload

from typing_extensions import Unpack

from replicate.base_model import BaseModel
from replicate.collection import Collection
Expand Down Expand Up @@ -137,6 +139,16 @@ class PredictionCollection(Collection):
Namespace for operations related to predictions.
"""

class CreateParams(TypedDict):
"""Parameters for creating a prediction."""

version: Union[Version, str]
input: Dict[str, Any]
webhook: Optional[str]
webhook_completed: Optional[str]
webhook_events_filter: Optional[List[str]]
stream: Optional[bool]

model = Prediction

def list(self) -> List[Prediction]:
Expand Down Expand Up @@ -171,16 +183,36 @@ def get(self, id: str) -> Prediction:
del obj["version"]
return self.prepare_model(obj)

def create( # type: ignore
@overload
def create( # pylint: disable=arguments-differ disable=too-many-arguments
self,
version: Union[Version, str],
input: Dict[str, Any],
*,
webhook: Optional[str] = None,
webhook_completed: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
stream: Optional[bool] = None,
) -> Prediction:
...

@overload
def create( # pylint: disable=arguments-differ disable=too-many-arguments
self,
*,
version: Union[Version, str],
input: Dict[str, Any],
webhook: Optional[str] = None,
webhook_completed: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
stream: Optional[bool] = None,
**kwargs,
) -> Prediction:
...

def create(
self,
*args,
**kwargs: Unpack[CreateParams], # type: ignore[misc]
) -> Prediction:
"""
Create a new prediction for the specified model version.
Expand All @@ -197,19 +229,28 @@ def create( # type: ignore
Prediction: The created prediction object.
"""

input = encode_json(input, upload_file=upload_file)
body: Dict[str, Any] = {
# Support positional arguments for backwards compatibility
version = args[0] if args else kwargs.get("version")
if version is None:
raise ValueError(
"A version identifier must be provided as a positional or keyword argument."
)

input = args[1] if len(args) > 1 else kwargs.get("input")
if input is None:
raise ValueError(
"An input must be provided as a positional or keyword argument."
)

body = {
"version": version if isinstance(version, str) else version.id,
"input": input,
"input": encode_json(input, upload_file=upload_file),
}
if webhook is not None:
body["webhook"] = webhook
if webhook_completed is not None:
body["webhook_completed"] = webhook_completed
if webhook_events_filter is not None:
body["webhook_events_filter"] = webhook_events_filter
if stream is True:
body["stream"] = True

for key in ["webhook", "webhook_completed", "webhook_events_filter", "stream"]:
value = kwargs.get(key)
if value is not None:
body[key] = value

resp = self._client._request(
"POST",
Expand Down
77 changes: 66 additions & 11 deletions replicate/training.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, TypedDict, Union

from typing_extensions import NotRequired, Unpack, overload

from replicate.base_model import BaseModel
from replicate.collection import Collection
Expand Down Expand Up @@ -68,6 +70,16 @@ class TrainingCollection(Collection):

model = Training

class CreateParams(TypedDict):
"""Parameters for creating a prediction."""

version: Union[Version, str]
destination: str
input: Dict[str, Any]
webhook: NotRequired[str]
webhook_completed: NotRequired[str]
webhook_events_filter: NotRequired[List[str]]

def list(self) -> List[Training]:
"""
List your trainings.
Expand Down Expand Up @@ -103,14 +115,36 @@ def get(self, id: str) -> Training:
del obj["version"]
return self.prepare_model(obj)

def create( # type: ignore
@overload
def create( # pylint: disable=arguments-differ disable=too-many-arguments
self,
version: Union[Version, str],
input: Dict[str, Any],
destination: str,
*,
webhook: Optional[str] = None,
webhook_completed: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
) -> Training:
...

@overload
def create( # pylint: disable=arguments-differ disable=too-many-arguments
self,
version: str,
*,
version: Union[Version, str],
input: Dict[str, Any],
destination: str,
webhook: Optional[str] = None,
webhook_completed: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
**kwargs,
) -> Training:
...

def create(
self,
*args,
**kwargs: Unpack[CreateParams], # type: ignore[misc]
) -> Training:
"""
Create a new training using the specified model version as a base.
Expand All @@ -120,24 +154,45 @@ def create( # type: ignore
input: The input to the training.
destination: The desired model to push to in the format `{owner}/{model_name}`. This should be an existing model owned by the user or organization making the API request.
webhook: The URL to send a POST request to when the training is completed. Defaults to None.
webhook_completed: The URL to receive a POST request when the prediction is completed.
webhook_events_filter: The events to send to the webhook. Defaults to None.
Returns:
The training object.
"""

input = encode_json(input, upload_file=upload_file)
# Support positional arguments for backwards compatibility
version = args[0] if args else kwargs.get("version")
if version is None:
raise ValueError(
"A version identifier must be provided as a positional or keyword argument."
)

destination = args[1] if len(args) > 1 else kwargs.get("destination")
if destination is None:
raise ValueError(
"A destination must be provided as a positional or keyword argument."
)

input = args[2] if len(args) > 2 else kwargs.get("input")
if input is None:
raise ValueError(
"An input must be provided as a positional or keyword argument."
)

body = {
"input": input,
"input": encode_json(input, upload_file=upload_file),
"destination": destination,
}
if webhook is not None:
body["webhook"] = webhook
if webhook_events_filter is not None:
body["webhook_events_filter"] = webhook_events_filter

for key in ["webhook", "webhook_completed", "webhook_events_filter"]:
value = kwargs.get(key)
if value is not None:
body[key] = value

# Split version in format "username/model_name:version_id"
match = re.match(
r"^(?P<username>[^/]+)/(?P<model_name>[^:]+):(?P<version_id>.+)$", version
r"^(?P<username>[^/]+)/(?P<model_name>[^:]+):(?P<version_id>.+)$",
version.id if isinstance(version, Version) else version,
)
if not match:
raise ReplicateException(
Expand Down
6 changes: 5 additions & 1 deletion replicate/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ def get(self, id: str) -> Version:
)
return self.prepare_model(resp.json())

def create(self, **kwargs) -> Version:
def create(
self,
*args,
**kwargs,
) -> Version:
"""
Create a model version.
Expand Down
Loading

0 comments on commit e531cff

Please sign in to comment.