Skip to content

Commit

Permalink
Add benchmark config option generate_num_records
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 4650fdd4414af761b09f6f66861add204f74c195
  • Loading branch information
kboyd committed Aug 20, 2024
1 parent 8d2effa commit 60d7eb8
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 3 deletions.
22 changes: 22 additions & 0 deletions src/gretel_trainer/benchmark/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,35 @@ def __init__(
working_dir: Optional[Union[str, Path]] = None,
additional_report_scores: Optional[list[str]] = None,
n_jobs: int = 5,
generate_num_records: Optional[int] = None,
):
"""Configuration for a benchmark comparison.
Args:
project_display_name: visible name for the Gretel project
containing all artifacts and models for this benchmark run,
uses name based on start time if None
refresh_interval: interval in seconds between refreshes for job
status
trainer: use GretelTrainer for training and generation when True
            working_dir: local directory to store benchmark artifacts,
                defaults to the project_display_name if None
additional_report_scores: other scores besides SQS to extract and
show in the results DataFrame, uses abbreviations for scores,
valid values are FCS, PCS, DFS, PPL
n_jobs: max jobs to submit to Gretel in parallel, increase to run
large benchmarks faster, but Gretel also has server-side
limitations, so be considerate to other users
            generate_num_records: number of records to generate for evaluation,
                uses the size of each input dataset if None or 0
"""
self.project_display_name = project_display_name or _default_name()
self.working_dir = Path(working_dir or self.project_display_name)
self.refresh_interval = refresh_interval
self.trainer = trainer
self.additional_report_scores = additional_report_scores or []
self.n_jobs = n_jobs
self.generate_num_records = generate_num_records


class Timer:
Expand Down
2 changes: 1 addition & 1 deletion src/gretel_trainer/benchmark/custom/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def generate(self) -> None:
self.generate_timer = Timer()
with self.generate_timer:
synthetic_df = self.benchmark_model.generate(
num_records=self.dataset.row_count,
num_records=self.config.generate_num_records or self.dataset.row_count,
)
synthetic_df.to_csv(self._synthetic_data_path, index=False)

Expand Down
5 changes: 4 additions & 1 deletion src/gretel_trainer/benchmark/gretel/strategy_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ def generate(self) -> None:
raise BenchmarkException("Cannot generate before training")

_record_handler = self.model.create_record_handler_obj(
params={"num_records": self.dataset.row_count}
params={
"num_records": self.config.generate_num_records
or self.dataset.row_count
}
)
self.record_handler = _record_handler.submit_cloud()
job_status = self._await_job(self.record_handler, "generation")
Expand Down
4 changes: 3 additions & 1 deletion src/gretel_trainer/benchmark/gretel/strategy_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def generate(self) -> None:

self.generate_timer = Timer()
with self.generate_timer:
synthetic_data = self.trainer.generate(num_records=self.dataset.row_count)
synthetic_data = self.trainer.generate(
num_records=self.config.generate_num_records or self.dataset.row_count
)
synthetic_data.to_csv(self._synthetic_data_path, index=False)

@property
Expand Down

0 comments on commit 60d7eb8

Please sign in to comment.