diff --git a/src/gretel_trainer/benchmark/custom/models.py b/src/gretel_trainer/benchmark/custom/models.py index 36ade1a6..454301b0 100644 --- a/src/gretel_trainer/benchmark/custom/models.py +++ b/src/gretel_trainer/benchmark/custom/models.py @@ -6,8 +6,6 @@ class CustomModel(Protocol): - def train(self, source: Dataset, **kwargs) -> None: - ... + def train(self, source: Dataset, **kwargs) -> None: ... - def generate(self, **kwargs) -> pd.DataFrame: - ... + def generate(self, **kwargs) -> pd.DataFrame: ... diff --git a/src/gretel_trainer/benchmark/executor.py b/src/gretel_trainer/benchmark/executor.py index c3ddd594..942d306e 100644 --- a/src/gretel_trainer/benchmark/executor.py +++ b/src/gretel_trainer/benchmark/executor.py @@ -36,27 +36,20 @@ def cannot_proceed(self) -> bool: class Strategy(Protocol): @property - def dataset(self) -> Dataset: - ... + def dataset(self) -> Dataset: ... @property - def evaluate_ref_data(self) -> str: - ... + def evaluate_ref_data(self) -> str: ... - def runnable(self) -> bool: - ... + def runnable(self) -> bool: ... - def train(self) -> None: - ... + def train(self) -> None: ... - def generate(self) -> None: - ... + def generate(self) -> None: ... - def get_train_time(self) -> Optional[float]: - ... + def get_train_time(self) -> Optional[float]: ... - def get_generate_time(self) -> Optional[float]: - ... + def get_generate_time(self) -> Optional[float]: ... class Executor: diff --git a/src/gretel_trainer/relational/connectors.py b/src/gretel_trainer/relational/connectors.py index 79a01954..2547056b 100644 --- a/src/gretel_trainer/relational/connectors.py +++ b/src/gretel_trainer/relational/connectors.py @@ -6,6 +6,7 @@ which you can then use with the "MultiTable" class to process data with Gretel Transforms, Classify, Synthetics, or a combination of both. """ + from __future__ import annotations import logging diff --git a/src/gretel_trainer/relational/json.py b/src/gretel_trainer/relational/json.py index 2a1edf3f..009c5fc4 100644 --- a/src/gretel_trainer/relational/json.py +++ b/src/gretel_trainer/relational/json.py @@ -179,13 +179,11 @@ def get_foreign_keys( ) -> list: # can't specify element type (ForeignKey) without cyclic dependency ... - def get_table_columns(self, table: str) -> list[str]: - ... + def get_table_columns(self, table: str) -> list[str]: ... def get_invented_table_metadata( self, table: str - ) -> Optional[InventedTableMetadata]: - ... + ) -> Optional[InventedTableMetadata]: ... @dataclass diff --git a/src/gretel_trainer/relational/multi_table.py b/src/gretel_trainer/relational/multi_table.py index 4eb9920f..7edcc09a 100644 --- a/src/gretel_trainer/relational/multi_table.py +++ b/src/gretel_trainer/relational/multi_table.py @@ -3,6 +3,7 @@ take extracted data from a database or data warehouse, and process it with Gretel using Transforms, Classify, and Synthetics. """ + from __future__ import annotations import json @@ -574,9 +575,10 @@ def run_transforms( if isinstance(data_source, pd.DataFrame): data_source.to_csv(transforms_run_path, index=False) else: - with open_artifact(data_source, "rb") as src, open_artifact( - transforms_run_path, "wb" - ) as dest: + with ( + open_artifact(data_source, "rb") as src, + open_artifact(transforms_run_path, "wb") as dest, + ): shutil.copyfileobj(src, dest) transforms_run_paths[table] = transforms_run_path diff --git a/src/gretel_trainer/relational/sdk_extras.py b/src/gretel_trainer/relational/sdk_extras.py index 51b215c4..66fa37c4 100644 --- a/src/gretel_trainer/relational/sdk_extras.py +++ b/src/gretel_trainer/relational/sdk_extras.py @@ -56,9 +56,10 @@ def download_file_artifact( out_path: Union[str, Path], ) -> bool: try: - with gretel_object.get_artifact_handle(artifact_name) as src, open_artifact( - out_path, "wb" - ) as dest: + with ( + gretel_object.get_artifact_handle(artifact_name) as src, + open_artifact(out_path, "wb") as dest, + ): shutil.copyfileobj(src, dest) return True except: diff --git a/src/gretel_trainer/relational/strategies/independent.py b/src/gretel_trainer/relational/strategies/independent.py index 1361fec6..75d8658c 100644 --- a/src/gretel_trainer/relational/strategies/independent.py +++ b/src/gretel_trainer/relational/strategies/independent.py @@ -62,9 +62,10 @@ def prepare_training_data( continue source_path = rel_data.get_table_source(table) - with open_artifact(source_path, "rb") as src, open_artifact( - path, "wb" - ) as dest: + with ( + open_artifact(source_path, "rb") as src, + open_artifact(path, "wb") as dest, + ): pd.DataFrame(columns=use_columns).to_csv(dest, index=False) for chunk in pd.read_csv(src, usecols=use_columns, chunksize=10_000): chunk.to_csv(dest, index=False, mode="a", header=False) diff --git a/src/gretel_trainer/relational/table_evaluation.py b/src/gretel_trainer/relational/table_evaluation.py index 0993577b..1d4b189a 100644 --- a/src/gretel_trainer/relational/table_evaluation.py +++ b/src/gretel_trainer/relational/table_evaluation.py @@ -25,14 +25,12 @@ def is_complete(self) -> bool: @overload def _field_from_json( self, report_json: Optional[dict], entry: str, field: Literal["score"] - ) -> Optional[int]: - ... + ) -> Optional[int]: ... @overload def _field_from_json( self, report_json: Optional[dict], entry: str, field: Literal["grade"] - ) -> Optional[str]: - ... + ) -> Optional[str]: ... def _field_from_json( self, report_json: Optional[dict], entry: str, field: str diff --git a/src/gretel_trainer/relational/task_runner.py b/src/gretel_trainer/relational/task_runner.py index 787e3cdb..b791a929 100644 --- a/src/gretel_trainer/relational/task_runner.py +++ b/src/gretel_trainer/relational/task_runner.py @@ -35,39 +35,28 @@ def maybe_start_job(self, job: Job, table_name: str, action: str) -> None: class Task(Protocol): @property - def ctx(self) -> TaskContext: - ... + def ctx(self) -> TaskContext: ... - def action(self, job: Job) -> str: - ... + def action(self, job: Job) -> str: ... @property - def table_collection(self) -> list[str]: - ... + def table_collection(self) -> list[str]: ... - def more_to_do(self) -> bool: - ... + def more_to_do(self) -> bool: ... - def is_finished(self, table: str) -> bool: - ... + def is_finished(self, table: str) -> bool: ... - def get_job(self, table: str) -> Job: - ... + def get_job(self, table: str) -> Job: ... - def handle_completed(self, table: str, job: Job) -> None: - ... + def handle_completed(self, table: str, job: Job) -> None: ... - def handle_failed(self, table: str, job: Job) -> None: - ... + def handle_failed(self, table: str, job: Job) -> None: ... - def handle_in_progress(self, table: str, job: Job) -> None: - ... + def handle_in_progress(self, table: str, job: Job) -> None: ... - def handle_lost_contact(self, table: str, job: Job) -> None: - ... + def handle_lost_contact(self, table: str, job: Job) -> None: ... - def each_iteration(self) -> None: - ... + def each_iteration(self) -> None: ... def run_task(task: Task, extended_sdk: ExtendedGretelSDK) -> None: diff --git a/src/gretel_trainer/relational/tasks/classify.py b/src/gretel_trainer/relational/tasks/classify.py index 9e394eda..9a473750 100644 --- a/src/gretel_trainer/relational/tasks/classify.py +++ b/src/gretel_trainer/relational/tasks/classify.py @@ -129,8 +129,9 @@ def _write_results(self, job: Job, table: str) -> None: destpath = self.output_handler.filepath_for(filename) - with job.get_artifact_handle(artifact_name) as src, open_artifact( - str(destpath), "wb" - ) as dest: + with ( + job.get_artifact_handle(artifact_name) as src, + open_artifact(str(destpath), "wb") as dest, + ): shutil.copyfileobj(src, dest) self.result_filepaths[table] = destpath diff --git a/src/gretel_trainer/runner.py b/src/gretel_trainer/runner.py index 1b0c6e16..59915974 100644 --- a/src/gretel_trainer/runner.py +++ b/src/gretel_trainer/runner.py @@ -20,6 +20,7 @@ } } """ + from __future__ import annotations import json diff --git a/tests/relational/conftest.py b/tests/relational/conftest.py index 0666f037..6cbbc384 100644 --- a/tests/relational/conftest.py +++ b/tests/relational/conftest.py @@ -61,11 +61,10 @@ def output_handler(tmpdir, project): @pytest.fixture() def project(): - with patch( - "gretel_trainer.relational.multi_table.create_project" - ) as create_project, patch( - "gretel_trainer.relational.multi_table.get_project" - ) as get_project: + with ( + patch("gretel_trainer.relational.multi_table.create_project") as create_project, + patch("gretel_trainer.relational.multi_table.get_project") as get_project, + ): project = Mock() project.name = "name" project.display_name = "display_name" diff --git a/tests/relational/test_ancestral_strategy.py b/tests/relational/test_ancestral_strategy.py index cba5ad69..9c74e951 100644 --- a/tests/relational/test_ancestral_strategy.py +++ b/tests/relational/test_ancestral_strategy.py @@ -20,7 +20,10 @@ def test_preparing_training_data_does_not_mutate_source_data(pets): strategy = AncestralStrategy() - with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest: + with ( + tempfile.NamedTemporaryFile() as pets_dest, + tempfile.NamedTemporaryFile() as humans_dest, + ): strategy.prepare_training_data( pets, {"pets": pets_dest.name, "humans": humans_dest.name} ) @@ -32,7 +35,10 @@ def test_preparing_training_data_does_not_mutate_source_data(pets): def test_prepare_training_data_subset_of_tables(pets): strategy = AncestralStrategy() - with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest: + with ( + tempfile.NamedTemporaryFile() as pets_dest, + tempfile.NamedTemporaryFile() as humans_dest, + ): # We aren't synthesizing the "humans" table, so it is not in this list argument... training_data = strategy.prepare_training_data(pets, {"pets": pets_dest.name}) @@ -53,7 +59,10 @@ def test_prepare_training_data_subset_of_tables(pets): def test_prepare_training_data_returns_multigenerational_data(pets): strategy = AncestralStrategy() - with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest: + with ( + tempfile.NamedTemporaryFile() as pets_dest, + tempfile.NamedTemporaryFile() as humans_dest, + ): training_data = strategy.prepare_training_data( pets, {"pets": pets_dest.name, "humans": humans_dest.name} ) @@ -64,7 +73,10 @@ def test_prepare_training_data_returns_multigenerational_data(pets): def test_prepare_training_data_drops_highly_unique_categorical_ancestor_fields(art): - with tempfile.NamedTemporaryFile() as artists_csv, tempfile.NamedTemporaryFile() as paintings_csv: + with ( + tempfile.NamedTemporaryFile() as artists_csv, + tempfile.NamedTemporaryFile() as paintings_csv, + ): pd.DataFrame( data={ "id": [f"A{i}" for i in range(100)], @@ -84,7 +96,10 @@ def test_prepare_training_data_drops_highly_unique_categorical_ancestor_fields(a strategy = AncestralStrategy() - with tempfile.NamedTemporaryFile() as artists_dest, tempfile.NamedTemporaryFile() as paintings_dest: + with ( + tempfile.NamedTemporaryFile() as artists_dest, + tempfile.NamedTemporaryFile() as paintings_dest, + ): training_data = strategy.prepare_training_data( art, { @@ -111,7 +126,10 @@ def test_prepare_training_data_drops_highly_nan_ancestor_fields(art): else: highly_nan_names.append("some name") - with tempfile.NamedTemporaryFile() as artists_csv, tempfile.NamedTemporaryFile() as paintings_csv: + with ( + tempfile.NamedTemporaryFile() as artists_csv, + tempfile.NamedTemporaryFile() as paintings_csv, + ): pd.DataFrame( data={ "id": [f"A{i}" for i in range(100)], @@ -131,7 +149,10 @@ def test_prepare_training_data_drops_highly_nan_ancestor_fields(art): strategy = AncestralStrategy() - with tempfile.NamedTemporaryFile() as artists_dest, tempfile.NamedTemporaryFile() as paintings_dest: + with ( + tempfile.NamedTemporaryFile() as artists_dest, + tempfile.NamedTemporaryFile() as paintings_dest, + ): training_data = strategy.prepare_training_data( art, { @@ -155,7 +176,10 @@ def test_prepare_training_data_translates_alphanumeric_keys_and_adds_min_max_rec ): strategy = AncestralStrategy() - with tempfile.NamedTemporaryFile() as artists_dest, tempfile.NamedTemporaryFile() as paintings_dest: + with ( + tempfile.NamedTemporaryFile() as artists_dest, + tempfile.NamedTemporaryFile() as paintings_dest, + ): training_data = strategy.prepare_training_data( art, { @@ -191,7 +215,12 @@ def test_prepare_training_data_translates_alphanumeric_keys_and_adds_min_max_rec def test_prepare_training_data_with_composite_keys(tpch): strategy = AncestralStrategy() - with tempfile.NamedTemporaryFile() as supplier_dest, tempfile.NamedTemporaryFile() as part_dest, tempfile.NamedTemporaryFile() as partsupp_dest, tempfile.NamedTemporaryFile() as lineitem_dest: + with ( + tempfile.NamedTemporaryFile() as supplier_dest, + tempfile.NamedTemporaryFile() as part_dest, + tempfile.NamedTemporaryFile() as partsupp_dest, + tempfile.NamedTemporaryFile() as lineitem_dest, + ): training_data = strategy.prepare_training_data( tpch, { diff --git a/tests/relational/test_independent_strategy.py b/tests/relational/test_independent_strategy.py index 18c29d86..8e478eb6 100644 --- a/tests/relational/test_independent_strategy.py +++ b/tests/relational/test_independent_strategy.py @@ -18,7 +18,10 @@ def test_preparing_training_data_does_not_mutate_source_data(pets): strategy = IndependentStrategy() - with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest: + with ( + tempfile.NamedTemporaryFile() as pets_dest, + tempfile.NamedTemporaryFile() as humans_dest, + ): strategy.prepare_training_data( pets, {"pets": pets_dest.name, "humans": humans_dest.name} ) @@ -30,7 +33,10 @@ def test_preparing_training_data_does_not_mutate_source_data(pets): def test_prepare_training_data_removes_primary_and_foreign_keys(pets): strategy = IndependentStrategy() - with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest: + with ( + tempfile.NamedTemporaryFile() as pets_dest, + tempfile.NamedTemporaryFile() as humans_dest, + ): training_data = strategy.prepare_training_data( pets, {"pets": pets_dest.name, "humans": humans_dest.name} ) @@ -42,7 +48,10 @@ def test_prepare_training_data_removes_primary_and_foreign_keys(pets): def test_prepare_training_data_subset_of_tables(pets): strategy = IndependentStrategy() - with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest: + with ( + tempfile.NamedTemporaryFile() as pets_dest, + tempfile.NamedTemporaryFile() as humans_dest, + ): training_data = strategy.prepare_training_data( pets, {"humans": humans_dest.name} ) @@ -53,7 +62,10 @@ def test_prepare_training_data_subset_of_tables(pets): def test_prepare_training_data_join_table(insurance): strategy = IndependentStrategy() - with tempfile.NamedTemporaryFile() as beneficiary_dest, tempfile.NamedTemporaryFile() as policies_dest: + with ( + tempfile.NamedTemporaryFile() as beneficiary_dest, + tempfile.NamedTemporaryFile() as policies_dest, + ): training_data = strategy.prepare_training_data( insurance, {