diff --git a/src/gretel_trainer/benchmark/core.py b/src/gretel_trainer/benchmark/core.py
index 33d9de30..39a0193a 100644
--- a/src/gretel_trainer/benchmark/core.py
+++ b/src/gretel_trainer/benchmark/core.py
@@ -47,13 +47,35 @@ def __init__(
         working_dir: Optional[Union[str, Path]] = None,
         additional_report_scores: Optional[list[str]] = None,
         n_jobs: int = 5,
+        generate_num_records: Optional[int] = None,
     ):
+        """Configuration for a benchmark comparison.
+
+        Args:
+            project_display_name: visible name for the Gretel project
+                containing all artifacts and models for this benchmark run;
+                uses a name based on the start time if None
+            refresh_interval: interval in seconds between job status
+                refreshes
+            trainer: use GretelTrainer for training and generation when True
+            working_dir: local directory to store benchmark artifacts;
+                defaults to project_display_name if None
+            additional_report_scores: scores besides SQS to extract and
+                show in the results DataFrame, given as abbreviations;
+                valid values are FCS, PCS, DFS, PPL
+            n_jobs: maximum number of jobs to submit to Gretel in parallel;
+                increase to run large benchmarks faster, but Gretel also
+                enforces server-side limits, so be considerate of other users
+            generate_num_records: number of records to generate for
+                evaluation; uses the size of each input dataset if None or 0
+        """
         self.project_display_name = project_display_name or _default_name()
         self.working_dir = Path(working_dir or self.project_display_name)
         self.refresh_interval = refresh_interval
         self.trainer = trainer
         self.additional_report_scores = additional_report_scores or []
         self.n_jobs = n_jobs
+        self.generate_num_records = generate_num_records


 class Timer:
diff --git a/src/gretel_trainer/benchmark/custom/strategy.py b/src/gretel_trainer/benchmark/custom/strategy.py
index ac12aa6e..a58f64ab 100644
--- a/src/gretel_trainer/benchmark/custom/strategy.py
+++ b/src/gretel_trainer/benchmark/custom/strategy.py
@@ -32,7 +32,7 @@ def generate(self) -> None:
         self.generate_timer = Timer()
         with self.generate_timer:
             synthetic_df = self.benchmark_model.generate(
-                num_records=self.dataset.row_count,
+                num_records=self.config.generate_num_records or self.dataset.row_count,
             )
             synthetic_df.to_csv(self._synthetic_data_path, index=False)

diff --git a/src/gretel_trainer/benchmark/gretel/strategy_sdk.py b/src/gretel_trainer/benchmark/gretel/strategy_sdk.py
index ce704011..d8c07a2e 100644
--- a/src/gretel_trainer/benchmark/gretel/strategy_sdk.py
+++ b/src/gretel_trainer/benchmark/gretel/strategy_sdk.py
@@ -79,7 +79,10 @@ def generate(self) -> None:
             raise BenchmarkException("Cannot generate before training")

         _record_handler = self.model.create_record_handler_obj(
-            params={"num_records": self.dataset.row_count}
+            params={
+                "num_records": self.config.generate_num_records
+                or self.dataset.row_count
+            }
         )
         self.record_handler = _record_handler.submit_cloud()
         job_status = self._await_job(self.record_handler, "generation")
diff --git a/src/gretel_trainer/benchmark/gretel/strategy_trainer.py b/src/gretel_trainer/benchmark/gretel/strategy_trainer.py
index 9b28711f..0b7f4062 100644
--- a/src/gretel_trainer/benchmark/gretel/strategy_trainer.py
+++ b/src/gretel_trainer/benchmark/gretel/strategy_trainer.py
@@ -67,7 +67,9 @@ def generate(self) -> None:

         self.generate_timer = Timer()
         with self.generate_timer:
-            synthetic_data = self.trainer.generate(num_records=self.dataset.row_count)
+            synthetic_data = self.trainer.generate(
+                num_records=self.config.generate_num_records or self.dataset.row_count
+            )
             synthetic_data.to_csv(self._synthetic_data_path, index=False)

     @property
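
Usage note: a minimal sketch of how the new option might be set from caller code. The BenchmarkConfig class name and the compare() entry point are assumptions inferred from the constructor above and are not confirmed by this diff; only the generate_num_records keyword is introduced here.

    # Hypothetical usage sketch -- class and entry-point names are assumptions.
    from gretel_trainer.benchmark import BenchmarkConfig, compare

    config = BenchmarkConfig(
        project_display_name="benchmark-fixed-size",
        # Cap generation at 5,000 records per dataset; None or 0 falls back
        # to each input dataset's row count (the previous behavior).
        generate_num_records=5_000,
    )
    results = compare(datasets=my_datasets, models=my_models, config=config)

Each strategy reads the value as `self.config.generate_num_records or self.dataset.row_count`, so the pre-existing behavior is preserved whenever the option is unset.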