Skip to content

Commit

Permalink
Add azure and google storage deps + wrap smart_open calls for relational
Browse files Browse the repository at this point in the history
* Add azure and google storage deps

* Update relational SDK to use transport params

* Wrap smart_open calls to insert the transport_params as necessary for Azure

GitOrigin-RevId: e13e70c9308c6a27099c998f82f4e3267463ef95
  • Loading branch information
mckornfield authored and tylersbray committed Jan 5, 2024
1 parent 35684de commit 3cddc7c
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 25 deletions.
10 changes: 5 additions & 5 deletions src/gretel_trainer/relational/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

import networkx
import pandas as pd
import smart_open

from networkx.algorithms.cycles import simple_cycles
from networkx.algorithms.dag import dag_longest_path_length, topological_sort
Expand All @@ -34,6 +33,7 @@

import gretel_trainer.relational.json as relational_json

from gretel_client.projects.artifact_handlers import open_artifact
from gretel_trainer.relational.json import (
IngestResponseT,
InventedTableMetadata,
Expand Down Expand Up @@ -273,7 +273,7 @@ def add_table(
preview_df = data.head(PREVIEW_ROW_COUNT)
elif isinstance(data, (str, Path)):
data_location = self.source_data_handler.resolve_data_location(data)
with smart_open.open(data_location, "rb") as d:
with open_artifact(data_location, "rb") as d:
preview_df = pd.read_csv(d, nrows=PREVIEW_ROW_COUNT)
columns = list(preview_df.columns)
json_cols = relational_json.get_json_columns(preview_df)
Expand All @@ -293,7 +293,7 @@ def add_table(
if isinstance(data, pd.DataFrame):
df = data
elif isinstance(data, (str, Path)):
with smart_open.open(data, "rb") as d:
with open_artifact(data, "rb") as d:
df = pd.read_csv(d)
rj_ingest = relational_json.ingest(name, primary_key, df, json_cols)

Expand Down Expand Up @@ -359,7 +359,7 @@ def _add_single_table(
if columns is not None:
cols = columns
else:
with smart_open.open(source, "rb") as src:
with open_artifact(source, "rb") as src:
cols = list(pd.read_csv(src, nrows=1).columns)
metadata = TableMetadata(
primary_key=primary_key,
Expand Down Expand Up @@ -762,7 +762,7 @@ def get_table_data(
"""
source = self.get_table_source(table)
usecols = usecols or self.get_table_columns(table)
with smart_open.open(source, "rb") as src:
with open_artifact(source, "rb") as src:
return pd.read_csv(src, usecols=usecols)

def get_table_columns(self, table: str) -> list[str]:
Expand Down
16 changes: 10 additions & 6 deletions src/gretel_trainer/relational/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from gretel_client.config import get_session_config, RunnerMode
from gretel_client.projects import create_project, get_project, Project
from gretel_client.projects.artifact_handlers import open_artifact
from gretel_client.projects.jobs import ACTIVE_STATES, END_STATES, Status
from gretel_client.projects.records import RecordHandler
from gretel_trainer.relational.artifacts import ArtifactCollection
Expand Down Expand Up @@ -651,7 +652,7 @@ def run_transforms(
if isinstance(data_source, pd.DataFrame):
data_source.to_csv(transforms_run_path, index=False)
else:
with smart_open.open(data_source, "rb") as src, smart_open.open(
with open_artifact(data_source, "rb") as src, open_artifact(
transforms_run_path, "wb"
) as dest:
shutil.copyfileobj(src, dest)
Expand Down Expand Up @@ -690,7 +691,10 @@ def run_transforms(
for table, df in reshaped_tables.items():
filename = f"transformed_{table}.csv"
out_path = self._output_handler.filepath_for(filename, subdir=run_subdir)
with smart_open.open(out_path, "wb") as dest:
with open_artifact(
out_path,
"wb",
) as dest:
df.to_csv(
dest,
index=False,
Expand Down Expand Up @@ -899,7 +903,7 @@ def generate(
synth_csv_path = self._output_handler.filepath_for(
f"synth_{table}.csv", subdir=run_subdir
)
with smart_open.open(synth_csv_path, "wb") as dest:
with open_artifact(synth_csv_path, "wb") as dest:
synth_df.to_csv(
dest,
index=False,
Expand Down Expand Up @@ -1042,7 +1046,7 @@ def create_relational_report(self, run_identifier: str, filepath: str) -> None:
now=datetime.utcnow(),
run_identifier=run_identifier,
)
with smart_open.open(filepath, "w") as report:
with open_artifact(filepath, "w") as report:
html_content = ReportRenderer().render(presenter)
report.write(html_content)

Expand All @@ -1054,8 +1058,8 @@ def _attach_existing_reports(self, run_id: str, table: str) -> None:
f"synthetics_cross_table_evaluation_{table}.json", subdir=run_id
)

individual_report_json = json.loads(smart_open.open(individual_path).read())
cross_table_report_json = json.loads(smart_open.open(cross_table_path).read())
individual_report_json = json.loads(open_artifact(individual_path).read())
cross_table_report_json = json.loads(open_artifact(cross_table_path).read())

self._evaluations[table].individual_report_json = individual_report_json
self._evaluations[table].cross_table_report_json = cross_table_report_json
Expand Down
10 changes: 5 additions & 5 deletions src/gretel_trainer/relational/sdk_extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import pandas as pd
import requests
import smart_open

from gretel_client.projects.artifact_handlers import open_artifact
from gretel_client.projects.jobs import Job, Status
from gretel_client.projects.models import Model
from gretel_client.projects.projects import Project
Expand Down Expand Up @@ -57,9 +57,9 @@ def download_file_artifact(
out_path: Union[str, Path],
) -> bool:
try:
with gretel_object.get_artifact_handle(
artifact_name
) as src, smart_open.open(out_path, "wb") as dest:
with gretel_object.get_artifact_handle(artifact_name) as src, open_artifact(
out_path, "wb"
) as dest:
shutil.copyfileobj(src, dest)
return True
except:
Expand All @@ -73,7 +73,7 @@ def download_tar_artifact(
try:
response = requests.get(download_link)
if response.status_code == 200:
with smart_open.open(out_path, "wb") as out:
with open_artifact(out_path, "wb") as out:
out.write(response.content)
except:
logger.warning(f"Failed to download `{artifact_name}`")
Expand Down
6 changes: 3 additions & 3 deletions src/gretel_trainer/relational/strategies/ancestral.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from typing import Any, Union

import pandas as pd
import smart_open

import gretel_trainer.relational.ancestry as ancestry
import gretel_trainer.relational.strategies.common as common

from gretel_client.projects.artifact_handlers import open_artifact
from gretel_trainer.relational.core import (
GretelModelConfig,
MultiTableException,
Expand Down Expand Up @@ -73,7 +73,7 @@ def prepare_training_data(
tableset=altered_tableset,
ancestral_seeding=True,
)
with smart_open.open(path, "wb") as dest:
with open_artifact(path, "wb") as dest:
data.to_csv(dest, index=False)

return table_paths
Expand Down Expand Up @@ -164,7 +164,7 @@ def get_generation_job(
seed_path = output_handler.filepath_for(
f"synthetics_seed_{table}.csv", subdir=subdir
)
with smart_open.open(seed_path, "wb") as dest:
with open_artifact(seed_path, "wb") as dest:
seed_df.to_csv(dest, index=False)
return {"data_source": str(seed_path)}

Expand Down
4 changes: 2 additions & 2 deletions src/gretel_trainer/relational/strategies/independent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from typing import Any

import pandas as pd
import smart_open

import gretel_trainer.relational.strategies.common as common

from gretel_client.projects.artifact_handlers import open_artifact
from gretel_trainer.relational.core import ForeignKey, GretelModelConfig, RelationalData
from gretel_trainer.relational.output_handler import OutputHandler

Expand Down Expand Up @@ -49,7 +49,7 @@ def prepare_training_data(
use_columns = [col for col in all_columns if col not in columns_to_drop]

source_path = rel_data.get_table_source(table)
with smart_open.open(source_path, "rb") as src, smart_open.open(
with open_artifact(source_path, "rb") as src, open_artifact(
path, "wb"
) as dest:
pd.DataFrame(columns=use_columns).to_csv(dest, index=False)
Expand Down
5 changes: 2 additions & 3 deletions src/gretel_trainer/relational/tasks/classify.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import shutil

import smart_open

import gretel_trainer.relational.tasks.common as common

from gretel_client.projects.artifact_handlers import open_artifact
from gretel_client.projects.jobs import Job
from gretel_client.projects.models import Model
from gretel_client.projects.projects import Project
Expand Down Expand Up @@ -150,7 +149,7 @@ def _write_results(self, job: Job, table: str) -> None:

destpath = self.output_handler.filepath_for(filename)

with job.get_artifact_handle(artifact_name) as src, smart_open.open(
with job.get_artifact_handle(artifact_name) as src, open_artifact(
str(destpath), "wb"
) as dest:
shutil.copyfileobj(src, dest)
Expand Down
3 changes: 2 additions & 1 deletion src/gretel_trainer/relational/tasks/synthetics_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import gretel_trainer.relational.tasks.common as common

from gretel_client.projects.artifact_handlers import open_artifact
from gretel_client.projects.jobs import Job
from gretel_client.projects.models import Model
from gretel_client.projects.projects import Project
Expand Down Expand Up @@ -144,7 +145,7 @@ def _read_json_report(model: Model, json_report_filepath: str) -> Optional[dict]
also fails, log a warning and give up gracefully.
"""
try:
return json.loads(smart_open.open(json_report_filepath).read())
return json.loads(open_artifact(json_report_filepath).read())
except:
try:
with model.get_artifact_handle("report_json") as report:
Expand Down

0 comments on commit 3cddc7c

Please sign in to comment.