Skip to content

Commit

Permalink
Don't assign job until after it is submitted
Browse files: browse the repository at this point in the history
GitOrigin-RevId: ed7badb8d98cc7bc0fbbd8c44b59b9548cd3d2a8
  • Loading branch information
mikeknep committed Apr 10, 2024
1 parent c9d3982 commit f175bb9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/gretel_trainer/benchmark/gretel/strategy_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ def get_generate_time(self) -> Optional[float]:
def train(self) -> None:
model_config = self._format_model_config()
data_source = self.artifact_key or self.dataset.data_source
self.model = self.project.create_model_obj(
_model = self.project.create_model_obj(
model_config=model_config, data_source=data_source
)
# Calling this in lieu of submit_cloud() is supposed to avoid
# artifact upload. Doesn't work for more recent client versions!
self.model.submit(runner_mode=RunnerMode.CLOUD)
self.model = _model.submit(runner_mode=RunnerMode.CLOUD)
job_status = self._await_job(self.model, "training")
if job_status in END_STATES and job_status != Status.COMPLETED:
raise BenchmarkException("Training failed")
Expand All @@ -78,10 +78,10 @@ def generate(self) -> None:
if self.model is None:
raise BenchmarkException("Cannot generate before training")

self.record_handler = self.model.create_record_handler_obj(
_record_handler = self.model.create_record_handler_obj(
params={"num_records": self.dataset.row_count}
)
self.record_handler.submit_cloud()
self.record_handler = _record_handler.submit_cloud()
job_status = self._await_job(self.record_handler, "generation")
if job_status == Status.COMPLETED:
self._download_synthetic_data(self.record_handler)
Expand Down
4 changes: 4 additions & 0 deletions tests/benchmark/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,14 @@ def test_run_happy_path_gretel_sdk(
status=Status.COMPLETED,
billing_details={"total_time_seconds": 15},
)
record_handler.submit_cloud.return_value = record_handler

model = Mock(
status=Status.COMPLETED,
billing_details={"total_time_seconds": 30},
)
model.create_record_handler_obj.return_value = record_handler
model.submit.return_value = model

evaluate_model = Mock(
status=Status.COMPLETED,
Expand Down Expand Up @@ -268,6 +270,7 @@ def test_sdk_model_failure(working_dir, iris, project):
status=Status.ERROR,
billing_details={"total_time_seconds": 30},
)
model.submit.return_value = model

project.create_model_obj.side_effect = [model]

Expand Down Expand Up @@ -309,6 +312,7 @@ def test_custom_gretel_model_configs_do_not_overwrite_each_other(
status=Status.ERROR,
billing_details={"total_time_seconds": 30},
)
model.submit.return_value = model
project.create_model_obj.return_value = model

pets = create_dataset(df, datatype="tabular", name="pets")
Expand Down

0 comments on commit f175bb9

Please sign in to comment.