
Commit

Benchmark snapshots
GitOrigin-RevId: 82ee45b13432ec2950fa2436ba651f941a5f6ef8
mikeknep committed Mar 15, 2024
1 parent 69b78ab commit e110d64
Showing 3 changed files with 23 additions and 5 deletions.
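
In short, the commit threads a zero-argument snapshot callable into the benchmark Executor and invokes it after each phase (train, generate, evaluate); the session wires that callable to export_results so intermediate results land in a per-run CSV under the working directory. A minimal sketch of the pattern, using simplified, hypothetical stand-ins rather than the actual gretel_trainer classes:

```python
from pathlib import Path
from typing import Callable


class Executor:
    """Simplified stand-in for gretel_trainer's Executor."""

    def __init__(self, snapshot: Callable[[], None]):
        self.snapshot = snapshot

    def run(self) -> None:
        # Persist intermediate results after every phase so an interrupted
        # run does not lose what has been accumulated so far.
        for phase in (self._train, self._generate, self._evaluate):
            phase()
            self.snapshot()

    def _train(self) -> None: ...
    def _generate(self) -> None: ...
    def _evaluate(self) -> None: ...


def make_executor(
    working_dir: Path,
    run_identifier: str,
    export_results: Callable[[str], None],
) -> Executor:
    # Mirrors the session wiring in the diff below: each run gets its own
    # snapshot destination, captured by the lambda passed to the Executor.
    snapshot_dest = str(working_dir / f"{run_identifier}_result_data.csv")
    return Executor(snapshot=lambda: export_results(snapshot_dest))
```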
7 changes: 6 additions & 1 deletion src/gretel_trainer/benchmark/executor.py
@@ -1,7 +1,7 @@
import logging

from enum import Enum
from typing import Optional, Protocol
from typing import Callable, Optional, Protocol

from gretel_client.projects.models import Model
from gretel_client.projects.projects import Project
@@ -66,11 +66,13 @@ def __init__(
run_identifier: str,
evaluate_project: Project,
config: BenchmarkConfig,
snapshot: Callable[[], None],
):
self.strategy = strategy
self.run_identifier = run_identifier
self.evaluate_project = evaluate_project
self.config = config
self.snapshot = snapshot

self.status = Status.NotStarted
self.exception: Optional[Exception] = None
@@ -81,10 +83,13 @@ def run(self) -> None:
self._maybe_skip()
if self.status.can_proceed:
self._train()
self.snapshot()
if self.status.can_proceed:
self._generate()
self.snapshot()
if self.status.can_proceed:
self._evaluate()
self.snapshot()

def get_report_score(self, key: str) -> Optional[int]:
if self.evaluate_report_json is None:
8 changes: 8 additions & 0 deletions src/gretel_trainer/benchmark/session.py
@@ -209,11 +209,15 @@ def _setup_gretel_run(
trainer_project_index=trainer_project_index,
artifact_key=artifact_key,
)
snapshot_dest = str(
self._config.working_dir / f"{run_key.identifier}_result_data.csv"
)
executor = Executor(
strategy=strategy,
run_identifier=run_identifier,
evaluate_project=self._project,
config=self._config,
snapshot=lambda: self.export_results(snapshot_dest),
)
self._gretel_executors[run_key] = executor

@@ -231,11 +235,15 @@ def _setup_custom_run(
config=self._config,
artifact_key=artifact_key,
)
snapshot_dest = str(
self._config.working_dir / f"{run_key.identifier}_result_data.csv"
)
executor = Executor(
strategy=strategy,
run_identifier=run_identifier,
evaluate_project=self._project,
config=self._config,
snapshot=lambda: self.export_results(snapshot_dest),
)
self._custom_executors[run_key] = executor

13 changes: 9 additions & 4 deletions tests/benchmark/test_benchmark.py
@@ -15,9 +15,7 @@
compare,
create_dataset,
Datatype,
GretelGPTX,
GretelLSTM,
launch,
)
from gretel_trainer.benchmark.core import Dataset
from gretel_trainer.benchmark.gretel.models import GretelModel
@@ -247,16 +245,23 @@ def test_run_happy_path_gretel_sdk(
assert result["Generate time (sec)"] == 15
assert result["Total time (sec)"] == 45

# The synthetic data is written to the working directory
working_dir_contents = os.listdir(working_dir)
assert len(working_dir_contents) == 1
assert len(working_dir_contents) == 2

# The synthetic data is written to the working directory
filename = f"synth_{model_name}-iris.csv"
assert filename in working_dir_contents
df = pd.read_csv(f"{working_dir}/{filename}")
pdtest.assert_frame_equal(
df, pd.DataFrame(data={"synthetic": [1, 2], "data": [3, 4]})
)

# Snapshot results are written to the working directory
filename = f"{model_name}-iris_result_data.csv"
assert filename in working_dir_contents
df = pd.read_csv(f"{working_dir}/{filename}")
pdtest.assert_frame_equal(df, session.results)


def test_sdk_model_failure(working_dir, iris, project):
model = Mock(
