Skip to content

Commit

Permalink
Switch to using file outputs and blocking api by default
Browse files Browse the repository at this point in the history
  • Loading branch information
aron committed Oct 4, 2024
1 parent 08ee31a commit a0a06fe
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 11 deletions.
10 changes: 6 additions & 4 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,11 +395,13 @@ class CreatePredictionParams(TypedDict):

wait: NotRequired[Union[int, bool]]
"""
Wait until the prediction is completed before returning.
Block until the prediction is completed before returning.
If `True`, wait a predetermined number of seconds until the prediction
is completed before returning.
If an `int`, wait for the specified number of seconds.
If `True`, keep the request open for up to 60 seconds, falling back to
polling until the prediction is completed.
If an `int`, the same as `True`, but hold the request open for the
specified number of seconds (between 1 and 60).
If `False`, poll for the prediction status until completed.
"""

file_encoding_strategy: NotRequired[FileEncodingStrategy]
Expand Down
9 changes: 6 additions & 3 deletions replicate/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,17 @@ def run(
client: "Client",
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
input: Optional[Dict[str, Any]] = None,
use_file_output: Optional[bool] = None,
use_file_output: Optional[bool] = True,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
"""
Run a model and wait for its output.
"""

is_blocking = "wait" in params
if "wait" not in params:
params["wait"] = True
is_blocking = params["wait"] != False

version, owner, name, version_id = identifier._resolve(ref)

if version_id is not None:
Expand Down Expand Up @@ -74,7 +77,7 @@ async def async_run(
client: "Client",
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
input: Optional[Dict[str, Any]] = None,
use_file_output: Optional[bool] = None,
use_file_output: Optional[bool] = True,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401
"""
Expand Down
8 changes: 4 additions & 4 deletions tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def prediction_with_status(status: str) -> dict:
router.route(method="POST", path="/predictions").mock(
return_value=httpx.Response(
201,
json=prediction_with_status("processing"),
json=prediction_with_status("starting"),
)
)
router.route(method="GET", path="/predictions/p1").mock(
Expand Down Expand Up @@ -212,7 +212,7 @@ def prediction_with_status(status: str) -> dict:
router.route(method="POST", path="/predictions").mock(
return_value=httpx.Response(
201,
json=prediction_with_status("processing"),
json=prediction_with_status("starting"),
)
)
router.route(method="GET", path="/predictions/p1").mock(
Expand Down Expand Up @@ -454,7 +454,7 @@ def prediction_with_status(
router.route(method="POST", path="/predictions").mock(
return_value=httpx.Response(
201,
json=prediction_with_status("processing"),
json=prediction_with_status("starting"),
)
)
router.route(method="GET", path="/predictions/p1").mock(
Expand Down Expand Up @@ -541,7 +541,7 @@ def prediction_with_status(
router.route(method="POST", path="/predictions").mock(
return_value=httpx.Response(
201,
json=prediction_with_status("processing"),
json=prediction_with_status("starting"),
)
)
router.route(method="GET", path="/predictions/p1").mock(
Expand Down

0 comments on commit a0a06fe

Please sign in to comment.