Fix iterator support for replicate.run() (#383)
Prior to 1.0.0, `replicate.run()` would return an iterator for Cog models that declare an output type of `Iterator[Any]`. This iterator polled the `predictions.get` endpoint for the in-progress prediction and yielded any new output.
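
For context, the streaming pattern this refers to looks roughly like this (a minimal sketch; the model name is illustrative):

```python
import replicate

# For a Cog model whose output type is Iterator[Any], run() returns an
# iterator that yields new chunks as the prediction produces them.
for chunk in replicate.run(
    "meta/meta-llama-3-8b-instruct",  # illustrative streaming model
    input={"prompt": "Write a haiku about corgis"},
):
    print(chunk, end="")
```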

When implementing the new file interface we introduced two bugs:

1. The iterator didn't convert URLs returned by the model into `FileOutput` types, making it inconsistent with the non-iterator interface. This behavior is controlled by the `use_file_output` argument.
2. The iterator was returned without checking whether we are using the new blocking API, which is enabled by default and controlled by the `wait` argument.

This commit fixes both issues: it consistently applies the `transform_output` function to the iterator's output, and it returns the polling iterator (`prediction.output_iterator`) when the blocking API has not returned a completed prediction.

The tests have been updated to exercise both of these code paths.
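
With both fixes in place, the iterator interface behaves like the non-iterator one. A sketch of the expected behavior, assuming a hypothetical model that streams file URLs, and assuming `FileOutput` is importable from `replicate.helpers` alongside `transform_output`:

```python
import replicate
from replicate.helpers import FileOutput

# use_file_output defaults to True; URL chunks yielded by the iterator are
# now wrapped in FileOutput, matching the non-iterator interface.
output = replicate.run(
    "owner/streaming-image-model",  # hypothetical model that streams file URLs
    input={"prompt": "a rainbow colored corgi"},
    use_file_output=True,
)
for chunk in output:
    if isinstance(chunk, FileOutput):
        print(chunk.url)
```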
aron authored Oct 28, 2024
1 parent 23bd903 commit de717a0
Showing 2 changed files with 606 additions and 228 deletions.
77 changes: 54 additions & 23 deletions replicate/run.py
```diff
@@ -15,7 +15,6 @@
 from replicate.exceptions import ModelError
 from replicate.helpers import transform_output
 from replicate.model import Model
-from replicate.prediction import Prediction
 from replicate.schema import make_schema_backwards_compatible
 from replicate.version import Version, Versions
```

```diff
@@ -59,15 +58,36 @@ def run(
     if not version and (owner and name and version_id):
         version = Versions(client, model=(owner, name)).get(version_id)

-    if version and (iterator := _make_output_iterator(version, prediction)):
-        return iterator
+    # Currently the "Prefer: wait" interface will return a prediction with a status
+    # of "processing" rather than a terminal state because it returns before the
+    # prediction has been fully processed. If the request exceeds the wait time, even
+    # if it is actually processing, the prediction will be in a "starting" state.
+    #
+    # We should fix this in the blocking API itself. Predictions that are done should
+    # be in a terminal state and predictions that are processing should be in state
+    # "processing".
+    in_terminal_state = is_blocking and prediction.status != "starting"
+    if not in_terminal_state:
+        # Return a "polling" iterator if the model has an output iterator array type.
+        if version and _has_output_iterator_array_type(version):
+            return (
+                transform_output(chunk, client)
+                for chunk in prediction.output_iterator()
+            )

-    if not (is_blocking and prediction.status != "starting"):
         prediction.wait()

     if prediction.status == "failed":
         raise ModelError(prediction)

+    # Return an iterator for the completed prediction when needed.
+    if (
+        version
+        and _has_output_iterator_array_type(version)
+        and prediction.output is not None
+    ):
+        return (transform_output(chunk, client) for chunk in prediction.output)
+
     if use_file_output:
         return transform_output(prediction.output, client)
```
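
The restored polling path can be exercised directly by opting out of the blocking API (a hedged sketch; the model name is illustrative):

```python
import replicate

# wait=False skips the "Prefer: wait" blocking API, so the prediction starts
# in a non-terminal state and run() returns the polling iterator, with
# transform_output applied to every chunk.
for chunk in replicate.run(
    "owner/streaming-model",  # hypothetical iterator-output model
    input={"prompt": "Hello"},
    wait=False,
):
    print(chunk)
```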

```diff
@@ -108,15 +128,39 @@ async def async_run(
     if not version and (owner and name and version_id):
         version = await Versions(client, model=(owner, name)).async_get(version_id)

-    if version and (iterator := _make_async_output_iterator(version, prediction)):
-        return iterator
+    # Currently the "Prefer: wait" interface will return a prediction with a status
+    # of "processing" rather than a terminal state because it returns before the
+    # prediction has been fully processed. If the request exceeds the wait time, even
+    # if it is actually processing, the prediction will be in a "starting" state.
+    #
+    # We should fix this in the blocking API itself. Predictions that are done should
+    # be in a terminal state and predictions that are processing should be in state
+    # "processing".
+    in_terminal_state = is_blocking and prediction.status != "starting"
+    if not in_terminal_state:
+        # Return a "polling" iterator if the model has an output iterator array type.
+        if version and _has_output_iterator_array_type(version):
+            return (
+                transform_output(chunk, client)
+                async for chunk in prediction.async_output_iterator()
+            )

-    if not (is_blocking and prediction.status != "starting"):
         await prediction.async_wait()

     if prediction.status == "failed":
         raise ModelError(prediction)

+    # Return an iterator for completed output if the model has an output iterator array type.
+    if (
+        version
+        and _has_output_iterator_array_type(version)
+        and prediction.output is not None
+    ):
+        return (
+            transform_output(chunk, client)
+            async for chunk in _make_async_iterator(prediction.output)
+        )
+
     if use_file_output:
         return transform_output(prediction.output, client)
```
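
`async_run()` mirrors this logic; a minimal usage sketch, assuming `async_run` is exposed at the module level like `run` (model name illustrative):

```python
import asyncio

import replicate

async def main() -> None:
    output = await replicate.async_run(
        "owner/streaming-model",  # hypothetical iterator-output model
        input={"prompt": "Hello"},
    )
    # Both the polling iterator and the completed-output iterator are
    # consumed the same way.
    async for chunk in output:
        print(chunk)

asyncio.run(main())
```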

```diff
@@ -133,22 +177,9 @@ def _has_output_iterator_array_type(version: Version) -> bool:
     )


-def _make_output_iterator(
-    version: Version, prediction: Prediction
-) -> Optional[Iterator[Any]]:
-    if _has_output_iterator_array_type(version):
-        return prediction.output_iterator()
-
-    return None
-
-
-def _make_async_output_iterator(
-    version: Version, prediction: Prediction
-) -> Optional[AsyncIterator[Any]]:
-    if _has_output_iterator_array_type(version):
-        return prediction.async_output_iterator()
-
-    return None
+async def _make_async_iterator(list: list) -> AsyncIterator:
+    for item in list:
+        yield item


 __all__: List = []
```
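
The new `_make_async_iterator` helper exists because a completed prediction's `output` is a plain list, which `async for` cannot consume directly; wrapping it keeps the completed-output path interface-compatible with `async_output_iterator()`. A standalone sketch (parameter renamed here to avoid shadowing the built-in `list`):

```python
import asyncio
from typing import Any, AsyncIterator, List

async def make_async_iterator(items: List[Any]) -> AsyncIterator[Any]:
    # Wrap a plain list (like a completed prediction.output) in an async
    # generator so callers can use `async for` uniformly.
    for item in items:
        yield item

async def demo() -> None:
    async for item in make_async_iterator(["chunk-1", "chunk-2"]):
        print(item)

asyncio.run(demo())
```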