Skip to content

Commit

Permalink
Add progress property to Prediction
Browse files Browse the repository at this point in the history
Signed-off-by: Mattt Zmuda <[email protected]>
  • Loading branch information
mattt committed Sep 17, 2023
1 parent 3b9cd7a commit 3fa37dc
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 0 deletions.
32 changes: 32 additions & 0 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import re
import time
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional

from replicate.base_model import BaseModel
Expand Down Expand Up @@ -56,6 +58,36 @@ class Prediction(BaseModel):
- `cancel`: A URL to cancel the prediction.
"""

@dataclass
class Progress:
percentage: float
"""The percentage of the prediction that has completed."""

current: int
"""The number of items that have been processed."""

total: int
"""The total number of items to process."""

@property
def progress(self) -> Optional[Progress]:
if self.logs is None or self.logs == "":
return None

pattern = r"^\s*(?P<percentage>\d+)%\s*\|.+?\|\s*(?P<current>\d+)\/(?P<total>\d+)"
re_compiled = re.compile(pattern)

lines = self.logs.split("\n")
for i in reversed(range(len(lines))):
line = lines[i].strip()
if re_compiled.match(line):
matches = re_compiled.findall(line)
if len(matches) == 1:
percentage, current, total = map(int, matches[0])
return Prediction.Progress(percentage / 100.0, current, total)
return None


def wait(self) -> None:
"""
Wait for prediction to finish.
Expand Down
63 changes: 63 additions & 0 deletions tests/test_prediction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import responses
from responses import matchers

from replicate.prediction import Prediction

from .factories import create_client, create_version


Expand Down Expand Up @@ -214,3 +216,64 @@ def test_async_timings():
assert prediction.completed_at == "2022-04-26T20:02:27.648305Z"
assert prediction.output == "hello world"
assert prediction.metrics["predict_time"] == 1.2345

def test_prediction_progress():
client = create_client()
version = create_version(client)
prediction = Prediction(
id="ufawqhfynnddngldkgtslldrkq",
version=version,
status="starting"
)

lines = [
"Using seed: 12345",
"0%| | 0/5 [00:00<?, ?it/s]",
"20%|██ | 1/5 [00:00<00:01, 21.38it/s]",
"40%|████▍ | 2/5 [00:01<00:01, 22.46it/s]",
"60%|████▍ | 3/5 [00:01<00:01, 22.46it/s]",
"80%|████████ | 4/5 [00:01<00:00, 22.86it/s]",
"100%|██████████| 5/5 [00:02<00:00, 22.26it/s]",
]
logs = ""

for i, line in enumerate(lines):
logs += "\n" + line
prediction.logs = logs

progress = prediction.progress

if i == 0:
prediction.status = "processing"
assert progress is None
elif i == 1:
assert progress is not None
assert progress.current == 0
assert progress.total == 5
assert progress.percentage == 0.0
elif i == 2:
assert progress is not None
assert progress.current == 1
assert progress.total == 5
assert progress.percentage == 0.2
elif i == 3:
assert progress is not None
assert progress.current == 2
assert progress.total == 5
assert progress.percentage == 0.4
elif i == 4:
assert progress is not None
assert progress.current == 3
assert progress.total == 5
assert progress.percentage == 0.6
elif i == 5:
assert progress is not None
assert progress.current == 4
assert progress.total == 5
assert progress.percentage == 0.8
elif i == 6:
assert progress is not None
prediction.status = "succeeded"
assert progress.current == 5
assert progress.total == 5
assert progress.percentage == 1.0

0 comments on commit 3fa37dc

Please sign in to comment.