diff --git a/src/gretel_trainer/relational/core.py b/src/gretel_trainer/relational/core.py index e480e8ae..087fae91 100644 --- a/src/gretel_trainer/relational/core.py +++ b/src/gretel_trainer/relational/core.py @@ -25,7 +25,6 @@ import networkx import pandas as pd -import smart_open from networkx.algorithms.cycles import simple_cycles from networkx.algorithms.dag import dag_longest_path_length, topological_sort @@ -34,6 +33,7 @@ import gretel_trainer.relational.json as relational_json +from gretel_client.projects.artifact_handlers import open_artifact from gretel_trainer.relational.json import ( IngestResponseT, InventedTableMetadata, @@ -273,7 +273,7 @@ def add_table( preview_df = data.head(PREVIEW_ROW_COUNT) elif isinstance(data, (str, Path)): data_location = self.source_data_handler.resolve_data_location(data) - with smart_open.open(data_location, "rb") as d: + with open_artifact(data_location, "rb") as d: preview_df = pd.read_csv(d, nrows=PREVIEW_ROW_COUNT) columns = list(preview_df.columns) json_cols = relational_json.get_json_columns(preview_df) @@ -293,7 +293,7 @@ def add_table( if isinstance(data, pd.DataFrame): df = data elif isinstance(data, (str, Path)): - with smart_open.open(data, "rb") as d: + with open_artifact(data, "rb") as d: df = pd.read_csv(d) rj_ingest = relational_json.ingest(name, primary_key, df, json_cols) @@ -359,7 +359,7 @@ def _add_single_table( if columns is not None: cols = columns else: - with smart_open.open(source, "rb") as src: + with open_artifact(source, "rb") as src: cols = list(pd.read_csv(src, nrows=1).columns) metadata = TableMetadata( primary_key=primary_key, @@ -762,7 +762,7 @@ def get_table_data( """ source = self.get_table_source(table) usecols = usecols or self.get_table_columns(table) - with smart_open.open(source, "rb") as src: + with open_artifact(source, "rb") as src: return pd.read_csv(src, usecols=usecols) def get_table_columns(self, table: str) -> list[str]: diff --git a/src/gretel_trainer/relational/multi_table.py b/src/gretel_trainer/relational/multi_table.py index ca3e7e05..3ae58864 100644 --- a/src/gretel_trainer/relational/multi_table.py +++ b/src/gretel_trainer/relational/multi_table.py @@ -25,6 +25,7 @@ from gretel_client.config import get_session_config, RunnerMode from gretel_client.projects import create_project, get_project, Project +from gretel_client.projects.artifact_handlers import open_artifact from gretel_client.projects.jobs import ACTIVE_STATES, END_STATES, Status from gretel_client.projects.records import RecordHandler from gretel_trainer.relational.artifacts import ArtifactCollection @@ -651,7 +652,7 @@ def run_transforms( if isinstance(data_source, pd.DataFrame): data_source.to_csv(transforms_run_path, index=False) else: - with smart_open.open(data_source, "rb") as src, smart_open.open( + with open_artifact(data_source, "rb") as src, open_artifact( transforms_run_path, "wb" ) as dest: shutil.copyfileobj(src, dest) @@ -690,7 +691,10 @@ def run_transforms( for table, df in reshaped_tables.items(): filename = f"transformed_{table}.csv" out_path = self._output_handler.filepath_for(filename, subdir=run_subdir) - with smart_open.open(out_path, "wb") as dest: + with open_artifact( + out_path, + "wb", + ) as dest: df.to_csv( dest, index=False, @@ -899,7 +903,7 @@ def generate( synth_csv_path = self._output_handler.filepath_for( f"synth_{table}.csv", subdir=run_subdir ) - with smart_open.open(synth_csv_path, "wb") as dest: + with open_artifact(synth_csv_path, "wb") as dest: synth_df.to_csv( dest, index=False, @@ -1042,7 +1046,7 @@ def create_relational_report(self, run_identifier: str, filepath: str) -> None: now=datetime.utcnow(), run_identifier=run_identifier, ) - with smart_open.open(filepath, "w") as report: + with open_artifact(filepath, "w") as report: html_content = ReportRenderer().render(presenter) report.write(html_content) @@ -1054,8 +1058,8 @@ def _attach_existing_reports(self, run_id: str, table: str) -> None: f"synthetics_cross_table_evaluation_{table}.json", subdir=run_id ) - individual_report_json = json.loads(smart_open.open(individual_path).read()) - cross_table_report_json = json.loads(smart_open.open(cross_table_path).read()) + individual_report_json = json.loads(open_artifact(individual_path).read()) + cross_table_report_json = json.loads(open_artifact(cross_table_path).read()) self._evaluations[table].individual_report_json = individual_report_json self._evaluations[table].cross_table_report_json = cross_table_report_json diff --git a/src/gretel_trainer/relational/sdk_extras.py b/src/gretel_trainer/relational/sdk_extras.py index c9cc63dd..f08767d9 100644 --- a/src/gretel_trainer/relational/sdk_extras.py +++ b/src/gretel_trainer/relational/sdk_extras.py @@ -7,8 +7,8 @@ import pandas as pd import requests -import smart_open +from gretel_client.projects.artifact_handlers import open_artifact from gretel_client.projects.jobs import Job, Status from gretel_client.projects.models import Model from gretel_client.projects.projects import Project @@ -57,9 +57,9 @@ def download_file_artifact( out_path: Union[str, Path], ) -> bool: try: - with gretel_object.get_artifact_handle( - artifact_name - ) as src, smart_open.open(out_path, "wb") as dest: + with gretel_object.get_artifact_handle(artifact_name) as src, open_artifact( + out_path, "wb" + ) as dest: shutil.copyfileobj(src, dest) return True except: @@ -73,7 +73,7 @@ def download_tar_artifact( try: response = requests.get(download_link) if response.status_code == 200: - with smart_open.open(out_path, "wb") as out: + with open_artifact(out_path, "wb") as out: out.write(response.content) except: logger.warning(f"Failed to download `{artifact_name}`") diff --git a/src/gretel_trainer/relational/strategies/ancestral.py b/src/gretel_trainer/relational/strategies/ancestral.py index 7438f3e9..29140e2a 100644 --- a/src/gretel_trainer/relational/strategies/ancestral.py +++ b/src/gretel_trainer/relational/strategies/ancestral.py @@ -3,11 +3,11 @@ from typing import Any, Union import pandas as pd -import smart_open import gretel_trainer.relational.ancestry as ancestry import gretel_trainer.relational.strategies.common as common +from gretel_client.projects.artifact_handlers import open_artifact from gretel_trainer.relational.core import ( GretelModelConfig, MultiTableException, @@ -73,7 +73,7 @@ def prepare_training_data( tableset=altered_tableset, ancestral_seeding=True, ) - with smart_open.open(path, "wb") as dest: + with open_artifact(path, "wb") as dest: data.to_csv(dest, index=False) return table_paths @@ -164,7 +164,7 @@ def get_generation_job( seed_path = output_handler.filepath_for( f"synthetics_seed_{table}.csv", subdir=subdir ) - with smart_open.open(seed_path, "wb") as dest: + with open_artifact(seed_path, "wb") as dest: seed_df.to_csv(dest, index=False) return {"data_source": str(seed_path)} diff --git a/src/gretel_trainer/relational/strategies/independent.py b/src/gretel_trainer/relational/strategies/independent.py index 207a7b99..2820c39d 100644 --- a/src/gretel_trainer/relational/strategies/independent.py +++ b/src/gretel_trainer/relational/strategies/independent.py @@ -4,10 +4,10 @@ from typing import Any import pandas as pd -import smart_open import gretel_trainer.relational.strategies.common as common +from gretel_client.projects.artifact_handlers import open_artifact from gretel_trainer.relational.core import ForeignKey, GretelModelConfig, RelationalData from gretel_trainer.relational.output_handler import OutputHandler @@ -49,7 +49,7 @@ def prepare_training_data( use_columns = [col for col in all_columns if col not in columns_to_drop] source_path = rel_data.get_table_source(table) - with smart_open.open(source_path, "rb") as src, smart_open.open( + with open_artifact(source_path, "rb") as src, open_artifact( path, "wb" ) as dest: pd.DataFrame(columns=use_columns).to_csv(dest, index=False) diff --git a/src/gretel_trainer/relational/tasks/classify.py b/src/gretel_trainer/relational/tasks/classify.py index 72f84e7d..dafe23f0 100644 --- a/src/gretel_trainer/relational/tasks/classify.py +++ b/src/gretel_trainer/relational/tasks/classify.py @@ -1,9 +1,8 @@ import shutil -import smart_open - import gretel_trainer.relational.tasks.common as common +from gretel_client.projects.artifact_handlers import open_artifact from gretel_client.projects.jobs import Job from gretel_client.projects.models import Model from gretel_client.projects.projects import Project @@ -150,7 +149,7 @@ def _write_results(self, job: Job, table: str) -> None: destpath = self.output_handler.filepath_for(filename) - with job.get_artifact_handle(artifact_name) as src, smart_open.open( + with job.get_artifact_handle(artifact_name) as src, open_artifact( str(destpath), "wb" ) as dest: shutil.copyfileobj(src, dest) diff --git a/src/gretel_trainer/relational/tasks/synthetics_evaluate.py b/src/gretel_trainer/relational/tasks/synthetics_evaluate.py index 92fbceb0..34a4ccf7 100644 --- a/src/gretel_trainer/relational/tasks/synthetics_evaluate.py +++ b/src/gretel_trainer/relational/tasks/synthetics_evaluate.py @@ -8,6 +8,7 @@ import gretel_trainer.relational.tasks.common as common +from gretel_client.projects.artifact_handlers import open_artifact from gretel_client.projects.jobs import Job from gretel_client.projects.models import Model from gretel_client.projects.projects import Project @@ -144,7 +145,7 @@ def _read_json_report(model: Model, json_report_filepath: str) -> Optional[dict] also fails, log a warning and give up gracefully. """ try: - return json.loads(smart_open.open(json_report_filepath).read()) + return json.loads(open_artifact(json_report_filepath).read()) except: try: with model.get_artifact_handle("report_json") as report: