From 6165f2d74d7749ea81e70856341a1a0b37e88da5 Mon Sep 17 00:00:00 2001 From: Gretel Team Date: Wed, 22 May 2024 15:38:49 -0500 Subject: [PATCH] Project import generated by Copybara. GitOrigin-RevId: c0020212f05a7787fe084054e6fbcf5b0b9d2f62 --- docs/index.rst | 1 - docs/relational.rst | 39 - notebooks/relational.ipynb | 609 --------- .../table_extraction_with_subsetting.ipynb | 87 -- requirements.txt | 8 - src/gretel_trainer/relational/__init__.py | 24 - src/gretel_trainer/relational/ancestry.py | 222 ---- src/gretel_trainer/relational/backup.py | 165 --- src/gretel_trainer/relational/connectors.py | 211 --- src/gretel_trainer/relational/core.py | 932 ------------- src/gretel_trainer/relational/extractor.py | 618 --------- src/gretel_trainer/relational/json.py | 452 ------- src/gretel_trainer/relational/log.py | 37 - src/gretel_trainer/relational/model_config.py | 184 --- src/gretel_trainer/relational/multi_table.py | 1069 --------------- .../relational/output_handler.py | 216 --- .../relational/report/figures.py | 147 --- .../relational/report/key_highlight.js | 65 - .../relational/report/report.css | 122 -- .../relational/report/report.py | 161 --- .../report/report_privacy_protection.css | 40 - .../report/report_synthetic_quality.css | 172 --- .../relational/report/report_template.html | 170 --- src/gretel_trainer/relational/sdk_extras.py | 99 -- .../relational/strategies/ancestral.py | 360 ----- .../relational/strategies/common.py | 207 --- .../relational/strategies/independent.py | 335 ----- .../relational/table_evaluation.py | 99 -- src/gretel_trainer/relational/task_runner.py | 101 -- .../relational/tasks/classify.py | 137 -- src/gretel_trainer/relational/tasks/common.py | 35 - .../relational/tasks/synthetics_evaluate.py | 148 --- .../relational/tasks/synthetics_run.py | 199 --- .../relational/tasks/synthetics_train.py | 57 - .../relational/tasks/transforms_run.py | 67 - .../relational/tasks/transforms_train.py | 57 - .../relational/workflow_state.py | 31 - tests/relational/conftest.py | 201 --- tests/relational/example_dbs/art.sql | 29 - tests/relational/example_dbs/documents.sql | 48 - tests/relational/example_dbs/ecom.sql | 50 - tests/relational/example_dbs/insurance.sql | 32 - tests/relational/example_dbs/mutagenesis.sql | 23 - tests/relational/example_dbs/pets.sql | 29 - tests/relational/example_dbs/tpch.sql | 72 - tests/relational/example_dbs/trips.sql | 19 - tests/relational/test_ancestral_strategy.py | 697 ---------- tests/relational/test_ancestry.py | 280 ---- tests/relational/test_backup.py | 165 --- tests/relational/test_common_strategy.py | 85 -- tests/relational/test_connectors.py | 37 - tests/relational/test_extractor.py | 138 -- tests/relational/test_independent_strategy.py | 442 ------- tests/relational/test_model_config.py | 240 ---- tests/relational/test_multi_table_restore.py | 362 ----- tests/relational/test_output_handler.py | 21 - tests/relational/test_relational_data.py | 357 ----- .../test_relational_data_with_json.py | 1171 ----------------- tests/relational/test_report.py | 275 ---- tests/relational/test_synthetics_evaluate.py | 43 - tests/relational/test_synthetics_run_task.py | 172 --- tests/relational/test_task_runner.py | 256 ---- tests/relational/test_train_synthetics.py | 233 ---- tests/relational/test_train_transforms.py | 97 -- 64 files changed, 13257 deletions(-) delete mode 100644 docs/relational.rst delete mode 100644 notebooks/relational.ipynb delete mode 100644 notebooks/table_extraction_with_subsetting.ipynb delete mode 100644 src/gretel_trainer/relational/__init__.py delete mode 100644 src/gretel_trainer/relational/ancestry.py delete mode 100644 src/gretel_trainer/relational/backup.py delete mode 100644 src/gretel_trainer/relational/connectors.py delete mode 100644 src/gretel_trainer/relational/core.py delete mode 100644 src/gretel_trainer/relational/extractor.py delete mode 100644 src/gretel_trainer/relational/json.py delete mode 100644 src/gretel_trainer/relational/log.py delete mode 100644 src/gretel_trainer/relational/model_config.py delete mode 100644 src/gretel_trainer/relational/multi_table.py delete mode 100644 src/gretel_trainer/relational/output_handler.py delete mode 100644 src/gretel_trainer/relational/report/figures.py delete mode 100644 src/gretel_trainer/relational/report/key_highlight.js delete mode 100644 src/gretel_trainer/relational/report/report.css delete mode 100644 src/gretel_trainer/relational/report/report.py delete mode 100644 src/gretel_trainer/relational/report/report_privacy_protection.css delete mode 100644 src/gretel_trainer/relational/report/report_synthetic_quality.css delete mode 100644 src/gretel_trainer/relational/report/report_template.html delete mode 100644 src/gretel_trainer/relational/sdk_extras.py delete mode 100644 src/gretel_trainer/relational/strategies/ancestral.py delete mode 100644 src/gretel_trainer/relational/strategies/common.py delete mode 100644 src/gretel_trainer/relational/strategies/independent.py delete mode 100644 src/gretel_trainer/relational/table_evaluation.py delete mode 100644 src/gretel_trainer/relational/task_runner.py delete mode 100644 src/gretel_trainer/relational/tasks/classify.py delete mode 100644 src/gretel_trainer/relational/tasks/common.py delete mode 100644 src/gretel_trainer/relational/tasks/synthetics_evaluate.py delete mode 100644 src/gretel_trainer/relational/tasks/synthetics_run.py delete mode 100644 src/gretel_trainer/relational/tasks/synthetics_train.py delete mode 100644 src/gretel_trainer/relational/tasks/transforms_run.py delete mode 100644 src/gretel_trainer/relational/tasks/transforms_train.py delete mode 100644 src/gretel_trainer/relational/workflow_state.py delete mode 100644 tests/relational/conftest.py delete mode 100644 tests/relational/example_dbs/art.sql delete mode 100644 tests/relational/example_dbs/documents.sql delete mode 100644 tests/relational/example_dbs/ecom.sql delete mode 100644 tests/relational/example_dbs/insurance.sql delete mode 100644 tests/relational/example_dbs/mutagenesis.sql delete mode 100644 tests/relational/example_dbs/pets.sql delete mode 100644 tests/relational/example_dbs/tpch.sql delete mode 100644 tests/relational/example_dbs/trips.sql delete mode 100644 tests/relational/test_ancestral_strategy.py delete mode 100644 tests/relational/test_ancestry.py delete mode 100644 tests/relational/test_backup.py delete mode 100644 tests/relational/test_common_strategy.py delete mode 100644 tests/relational/test_connectors.py delete mode 100644 tests/relational/test_extractor.py delete mode 100644 tests/relational/test_independent_strategy.py delete mode 100644 tests/relational/test_model_config.py delete mode 100644 tests/relational/test_multi_table_restore.py delete mode 100644 tests/relational/test_output_handler.py delete mode 100644 tests/relational/test_relational_data.py delete mode 100644 tests/relational/test_relational_data_with_json.py delete mode 100644 tests/relational/test_report.py delete mode 100644 tests/relational/test_synthetics_evaluate.py delete mode 100644 tests/relational/test_synthetics_run_task.py delete mode 100644 tests/relational/test_task_runner.py delete mode 100644 tests/relational/test_train_synthetics.py delete mode 100644 tests/relational/test_train_transforms.py diff --git a/docs/index.rst b/docs/index.rst index 209840cd..76caee98 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -94,7 +94,6 @@ Modules quickstart.rst trainer.rst models.rst - relational.rst diff --git a/docs/relational.rst b/docs/relational.rst deleted file mode 100644 index 96a4c8e3..00000000 --- a/docs/relational.rst +++ /dev/null @@ -1,39 +0,0 @@ -Relational -========== - -The Gretel Trainer Relational sub-package provides interfaces for processing -multiple tables from relational databases and data warehouses. You may -utilize a `Connector` to automatically connect to supported databases or -provide your own CSVs and define relationships manually. - -Please see https://docs.gretel.ai/reference/relational to get started. - -The primary interfaces are documented below. - -.. contents:: - :local: - :depth: 1 - -Connectors ----------- - -.. automodule:: gretel_trainer.relational.connectors - :members: - -Core ----- - -.. automodule:: gretel_trainer.relational.core - :members: - -Multi Table ------------ - -.. automodule:: gretel_trainer.relational.multi_table - :members: - -Extractor ---------- - -.. automodule:: gretel_trainer.relational.extractor - :members: \ No newline at end of file diff --git a/notebooks/relational.ipynb b/notebooks/relational.ipynb deleted file mode 100644 index 6764387a..00000000 --- a/notebooks/relational.ipynb +++ /dev/null @@ -1,609 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Gretel Relational\n", - "Synthetics and Transforms for relational data." - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Quickstart" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from gretel_client import configure_session\n", - "\n", - "configure_session(api_key=\"prompt\", cache=\"yes\", validate=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# End-to-end synthetics example\n", - "\n", - "from gretel_trainer.relational import MultiTable, sqlite_conn\n", - "\n", - "\n", - "!curl -o \"ecom_xf.db\" \"https://gretel-blueprints-pub.s3.us-west-2.amazonaws.com/rdb/ecom_xf.db\"\n", - "\n", - "\n", - "connector = sqlite_conn(\"ecom_xf.db\")\n", - "relational_data = connector.extract()\n", - "\n", - "mt = MultiTable(relational_data)\n", - "mt.train_synthetics(config=\"synthetics/amplify\")\n", - "mt.generate()\n", - "\n", - "connector.save(mt.synthetic_output_tables, prefix=\"synthetic_\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Detailed walkthrough" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Set up source relational data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Display the schema of our demo database\n", - "\n", - "from IPython.display import Image\n", - "\n", - "Image(\"https://gretel-blueprints-pub.s3.us-west-2.amazonaws.com/rdb/ecommerce_db.png\", width=600, height=600)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Download the demo database\n", - "\n", - "!curl -o \"ecom_xf.db\" \"https://gretel-blueprints-pub.s3.us-west-2.amazonaws.com/rdb/ecom_xf.db\"" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The core Python object capturing source relational data and metadata is named `RelationalData`.\n", - "It can be created automatically using a `Connector`, or it can be created manually.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Connect to SQLite database and extract relational data\n", - "\n", - "from gretel_trainer.relational import sqlite_conn\n", - "\n", - "ecommerce_db_path = \"ecom_xf.db\"\n", - "\n", - "sqlite = sqlite_conn(path=ecommerce_db_path)\n", - "relational_data = sqlite.extract()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Alternatively, manually define relational data\n", - "\n", - "from gretel_trainer.relational import RelationalData\n", - "import pandas as pd\n", - "\n", - "csv_dir = \"/path/to/extracted_csvs\"\n", - "\n", - "tables = [\n", - " (\"events\", \"id\"),\n", - " (\"users\", \"id\"),\n", - " (\"distribution_center\", \"id\"),\n", - " (\"products\", \"id\"),\n", - " (\"inventory_items\", \"id\"),\n", - " (\"order_items\", \"id\"),\n", - "]\n", - "\n", - "foreign_keys = [\n", - " {\n", - " \"table\": \"events\",\n", - " \"constrained_columns\": [\"user_id\"],\n", - " \"referred_table\": \"users\",\n", - " \"referred_columns\": [\"id\"],\n", - " },\n", - " {\n", - " \"table\": \"order_items\",\n", - " \"constrained_columns\": [\"user_id\"],\n", - " \"referred_table\": \"users\",\n", - " \"referred_columns\": [\"id\"],\n", - " },\n", - " {\n", - " \"table\": \"order_items\",\n", - " \"constrained_columns\": [\"inventory_item_id\"],\n", - " \"referred_table\": \"inventory_items\",\n", - " \"referred_columns\": [\"id\"],\n", - " },\n", - " {\n", - " \"table\": \"inventory_items\",\n", - " \"constrained_columns\": [\"product_id\"],\n", - " \"referred_table\": \"products\",\n", - " \"referred_columns\": [\"id\"],\n", - " },\n", - " {\n", - " \"table\": \"inventory_items\",\n", - " \"constrained_columns\": [\"product_distribution_center_id\"],\n", - " \"referred_table\": \"distribution_center\",\n", - " \"referred_columns\": [\"id\"],\n", - " },\n", - " {\n", - " \"table\": \"products\",\n", - " \"constrained_columns\": [\"distribution_center_id\"],\n", - " \"referred_table\": \"distribution_center\",\n", - " \"referred_columns\": [\"id\"],\n", - " },\n", - "]\n", - "\n", - "rel_data = RelationalData()\n", - "\n", - "for table, pk in tables:\n", - " rel_data.add_table(name=table, primary_key=pk, data=pd.read_csv(f\"{csv_dir}/{table}.csv\"))\n", - "\n", - "for foreign_key in foreign_keys:\n", - " rel_data.add_foreign_key_constraint(**foreign_key)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Regardless of how it was created, a `RelationalData` instance can be modified after creation if necessary. In addition to the methods in the manual example above, you can modify source table data, change primary keys, and remove foreign keys." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# We don't actually want to make these changes; they just serve as examples\n", - "\n", - "# Overwrite source data with a different dataframe\n", - "# rel_data.update_table_data(table=\"users\", data=pd.read_csv(\"alt_users.csv\"))\n", - "\n", - "# Change which column (if any) is designated as the primary key\n", - "# rel_data.set_primary_key(table=\"distribution_center\", primary_key=\"name\")\n", - "# rel_data.set_primary_key(table=\"order_items\", primary_key=None)\n", - "\n", - "# Remove a foreign key relationship\n", - "# rel_data.remove_foreign_key(\"inventory_items.product_distribution_center_id\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Operate on the source data" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The `MultiTable` class is the interface to working with relational data. It requires a `RelationalData` instance. Several other options can be configured; the defaults are shown below as comments." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from gretel_trainer.relational import MultiTable\n", - "\n", - "multitable = MultiTable(\n", - " relational_data,\n", - " # project_display_name=\"multi-table\",\n", - " # strategy=\"independent\",\n", - " # refresh_interval=60,\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Classify\n", - "\n", - "Run Gretel Classify on all tables to identify PII.\n", - "\n", - "By default, Relational Classify will provide results for the first 100 rows in each table. To process the entire table, set `all_rows=True`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import yaml\n", - "\n", - "config_yaml = \"\"\"\n", - "schema_version: \"1.0\"\n", - "name: \"classify-default\"\n", - "models:\n", - " - classify:\n", - " data_source: \"_\"\n", - " labels:\n", - " - person_name\n", - " - credit_card_number\n", - " - phone_number\n", - " - us_social_security_number\n", - " - email_address\n", - " - location\n", - " - acme/*\n", - " \n", - "label_predictors:\n", - " namespace: acme\n", - " regex:\n", - " user_id:\n", - " patterns:\n", - " - score: high\n", - " regex: ^user_[\\d]{5}$\n", - "\"\"\"\n", - "config = yaml.safe_load(config_yaml)\n", - "\n", - "multitable.classify(config)\n", - "\n", - "# Run classify on all rows\n", - "# multitable.classify(config, all_rows=True)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Transforms\n", - "\n", - "Train Gretel Transforms models by providing a transforms model config. By default this config will be applied to all tables. You can limit the tables being transformed via the optional `only` (tables to include) or `ignore` (tables to exclude) arguments." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "config = \"https://raw.githubusercontent.com/gretelai/gdpr-helpers/main/src/config/transforms_config.yaml\"\n", - "\n", - "multitable.train_transforms(config)\n", - "\n", - "# Optionally limit which tables are trained for transforms via `only` (included) or `ignore` (excluded).\n", - "# Given our example data, the two calls below will lead to the same tables getting trained, just specified different ways.\n", - "#\n", - "# multitable.train_transforms(config, ignore={\"distribution_center\", \"products\"})\n", - "# multitable.train_transforms(config, only={\"users\", \"events\", \"inventory_items\", \"order_items\"})" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Run transforms to get transformed output. Each call to `run_transforms` is assigned (or supplied) a unique identifier; look for the transformed output tables in a subdirectory matching that identifier name in the working directory. An archive file containing all runs' outputs is also uploaded to the Gretel project as a project artifact, visible in the Data Sources tab in the Console.\n", - "\n", - "By default, `run_transforms` operates on the original source data for all tables with successfully completed transforms models.\n", - "\n", - "You can optionally run other data through transforms by passing it in as Pandas DataFrames to the optional `data` argument. In this case, only the provided tables will be transformed (not _all_ tables as in the default, no-`data`-argument case).\n", - "\n", - "If you intend to train synthetic models on the transformed output instead of the original source data, add the argument `in_place=True`. **This will modify the data in the `RelationalData` instance.**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "multitable.run_transforms()\n", - "\n", - "# Provide a specific identifier for the run (default is `transforms_{timestamp}`)\n", - "# multitable.run_transforms(identifier=\"my-transforms-run\")\n", - "\n", - "# Overwrite source data so that future synthetics actions consider the transformed output as the source\n", - "# multitable.run_transforms(in_place=True)\n", - "\n", - "# Run other data through the trained transforms models\n", - "# multitable.run_transforms(data={\"events\": some_other_events_dataframe})" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Compare original to transformed\n", - "\n", - "print(multitable.relational_data.get_table_data(\"users\").head(5))\n", - "print(multitable.transform_output_tables[\"users\"].head(5))" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Synthetics" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Start by training models for synthetics. By default, a synthetics model will be trained for every table in the `RelationalData`. However, this scope can be reduced to a subset of tables using the optional `only` (tables to include) or `ignore` (tables to exclude) arguments. This can be particularly useful if certain tables contain static reference data that should not be synthesized." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Train synthetic models for all tables\n", - "\n", - "multitable.train_synthetics(config=\"synthetics/amplify\")\n", - "\n", - "# Optionally limit which tables are trained for synthetics via `only` (included) or `ignore` (excluded).\n", - "# Given our example data, the two calls below will lead to the same tables getting trained, just specified different ways.\n", - "#\n", - "# multitable.train_synthetics(config=\"synthetics/amplify\", ignore={\"distribution_center\", \"products\"})\n", - "# multitable.train_synthetics(config=\"synthetics/amplify\", only={\"users\", \"events\", \"inventory_items\", \"order_items\"})" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "When training is complete, you'll find a number of artifacts in your working directory, including the CSVs on which models were trained (`synthetics_train_{table}.csv`) and evaluation reports (`synthetics_[type]_evaluation_{table}.[html|json]`).\n", - "\n", - "You can also view some evaluation metrics at this point. (We'll expand upon them after generating synthetic data.)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "multitable.evaluations" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Each synthetic data generation run is assigned (or supplied) a unique identifier. Look for a subdirectory with this identifier name in the working directory to find all synthetic outputs, including data and reports. An archive file containing all runs' outputs is also uploaded to the Gretel project as a project artifact, visible in the Data Sources tab in the Console.\n", - "\n", - "When you generate synthetic data, you can optionally change the amount of data to generate via `record_size_ratio`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Generate synthetic data\n", - "\n", - "multitable.generate()\n", - "\n", - "# Provide a specific identifier for the run (default is `synthetics_{timestamp}`)\n", - "# multitable.generate(identifier=\"my-synthetics-run\")\n", - "\n", - "# Generate twice as much synthetic data\n", - "# multitable.generate(record_size_ratio=2.0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Compare original to synthetic data\n", - "\n", - "print(multitable.relational_data.get_table_data(\"users\").head(5))\n", - "print(multitable.synthetic_output_tables[\"users\"].head(5))" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "If we take another look at our evaluations, we'll see additional metrics are available." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "multitable.evaluations" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We've also automatically generated a full relational report summarizing and explaining all this information. Look for `relational_report.html` in the generate run subdirectory." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import IPython\n", - "from smart_open import open\n", - "\n", - "report_path = str(multitable._working_dir / multitable._synthetics_run.identifier / \"relational_report.html\")\n", - "\n", - "IPython.display.HTML(data=open(report_path).read())" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The synthetic data is automatically written to the working directory in CSV format as `synth_{table}.csv`. You can optionally use a `Connector` to write the synthetic data to a database. (If you're writing back to the same database as your source, pass a `prefix: str` argument to the `save` method to avoid overwriting your source tables!)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Write output data to a new SQLite database\n", - "\n", - "from gretel_trainer.relational import sqlite_conn\n", - "\n", - "synthetic_db_path = \"out.db\"\n", - "\n", - "synthetic_db_conn = sqlite_conn(synthetic_db_path)\n", - "synthetic_db_conn.save(multitable.synthetic_output_tables)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Postgres demo via Docker" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Start up a postgres container with docker\n", - "\n", - "!docker run --rm -d --name multitable_pgdemo -e POSTGRES_PASSWORD=password -p 5432:5432 postgres" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Write synthetic tables to the Postgres db\n", - "\n", - "from gretel_trainer.relational import postgres_conn\n", - "\n", - "out_db = postgres_conn(\"postgres\", \"password\", \"localhost\", 5432)\n", - "out_db.save(multitable.synthetic_output_tables)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Inspect the postgres database\n", - "\n", - "!docker exec multitable_pgdemo psql -U postgres -c \"\\dt\"\n", - "!docker exec multitable_pgdemo psql -U postgres -c \"select * from users limit 5;\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Tear down the docker container\n", - "\n", - "!docker stop multitable_pgdemo" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.9.10 64-bit ('3.9.10')", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.10" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "c8726cf33f00e2373738d19e8a73b26d03723d6c732c72211354be2991192c77" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/table_extraction_with_subsetting.ipynb b/notebooks/table_extraction_with_subsetting.ipynb deleted file mode 100644 index 0f67d9d2..00000000 --- a/notebooks/table_extraction_with_subsetting.ipynb +++ /dev/null @@ -1,87 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "a2900f42", - "metadata": {}, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "\n", - "from gretel_trainer.relational.connectors import Connector\n", - "from gretel_trainer.relational.extractor import ExtractorConfig, TableExtractor\n", - "from sqlalchemy import create_engine\n", - "\n", - "# Where we are gonna drop these tables!\n", - "storage_dir = Path(\"table-data\")\n", - "storage_dir.mkdir(exist_ok=True)\n", - "\n", - "!curl -o \"ecom.db\" \"https://gretel-blueprints-pub.s3.us-west-2.amazonaws.com/rdb/ecom_xf.db\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cf21bac7", - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "connector = Connector.from_conn_str(\"sqlite:///ecom.db\")\n", - "\n", - "# Change this if you want more/less subsetting. You can set it to 0 and only the headers\n", - "# of the tables will be extracted. The default mode is -1 (get all the tables)\n", - "#\n", - "# If you set to a value between 0..1 (exclusive), then that will be the rough\n", - "# percentage of rows that are sampled. So a value of .5 will sample roughly half.\n", - "\n", - "config = ExtractorConfig(\n", - " target_row_count=100,\n", - ")\n", - "\n", - "extractor = TableExtractor(connector=connector, config=config, storage_dir=storage_dir)\n", - "extractor.sample_tables()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9088f8a0", - "metadata": {}, - "outputs": [], - "source": [ - "# All the tables are on disk in the `storage_dir`, you can load one\n", - "# back in as a DF just based on the table name.\n", - "\n", - "import random\n", - "import pandas as pd\n", - "\n", - "df = extractor.get_table_df(random.choice(extractor.table_order))\n", - "print(df.head())" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.10" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/requirements.txt b/requirements.txt index 24e9f1fc..7731cb78 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,7 @@ boto3~=1.20 -dask[dataframe]==2023.5.1 gretel-client>=0.17.7 -jinja2~=3.1 -networkx~=3.0 -numpy~=1.20 pandas~=1.5 -plotly~=5.11 pydantic~=1.9 requests~=2.25 -scikit-learn~=1.0 smart_open[s3]~=5.2 -sqlalchemy~=1.4 typing_extensions~=4.7 -unflatten==0.1.1 diff --git a/src/gretel_trainer/relational/__init__.py b/src/gretel_trainer/relational/__init__.py deleted file mode 100644 index 979b25bb..00000000 --- a/src/gretel_trainer/relational/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -import logging - -import gretel_trainer.relational.log - -from gretel_trainer.relational.connectors import ( - Connector, - mariadb_conn, - mysql_conn, - postgres_conn, - snowflake_conn, - sqlite_conn, -) -from gretel_trainer.relational.core import RelationalData -from gretel_trainer.relational.extractor import ExtractorConfig -from gretel_trainer.relational.log import set_log_level -from gretel_trainer.relational.multi_table import MultiTable - -logger = logging.getLogger(__name__) - -logger.warn( - "Relational Trainer is deprecated, and will be removed in the next Trainer release. " - "To transform and synthesize relational data, use Gretel Workflows. " - "Visit the docs to learn more: https://docs.gretel.ai/create-synthetic-data/workflows-and-connectors" -) diff --git a/src/gretel_trainer/relational/ancestry.py b/src/gretel_trainer/relational/ancestry.py deleted file mode 100644 index db034b30..00000000 --- a/src/gretel_trainer/relational/ancestry.py +++ /dev/null @@ -1,222 +0,0 @@ -import re - -from typing import Optional - -import pandas as pd - -from gretel_trainer.relational.core import ForeignKey, RelationalData - -_START_LINEAGE = "self" -_GEN_DELIMITER = "." -_COL_DELIMITER = "+" -_END_LINEAGE = "|" - - -def get_multigenerational_primary_key( - rel_data: RelationalData, table: str -) -> list[str]: - "Returns the provided table's primary key with the ancestral lineage prefix appended" - return [ - f"{_START_LINEAGE}{_END_LINEAGE}{pk}" for pk in rel_data.get_primary_key(table) - ] - - -def get_all_key_columns(rel_data: RelationalData, table: str) -> list[str]: - tableset = _empty_data_tableset(rel_data) - return list( - get_table_data_with_ancestors(rel_data, table, tableset, keys_only=True).columns - ) - - -def get_ancestral_foreign_key_maps( - rel_data: RelationalData, table: str -) -> list[tuple[str, str]]: - """ - Returns a list of two-element tuples where the first element is a foreign key column - with ancestral lineage prefix, and the second element is the ancestral-lineage-prefixed - referred column. This function ultimately provides a list of which columns are duplicates - in a fully-joined ancestral table (i.e. `get_table_data_with_ancestors`) (only between - the provided table and its direct parents, not between parents and grandparents). - - For example: given an events table with foreign key `events.user_id` => `users.id`, - this method returns: [("self|user_id", "self.user_id|id")] - """ - - def _ancestral_fk_map(fk: ForeignKey) -> list[tuple[str, str]]: - maps = [] - fk_lineage = _COL_DELIMITER.join(fk.columns) - - for i in range(len(fk.columns)): - fk_col = fk.columns[i] - ref_col = fk.parent_columns[i] - - maps.append( - ( - f"{_START_LINEAGE}{_END_LINEAGE}{fk_col}", - f"{_START_LINEAGE}{_GEN_DELIMITER}{fk_lineage}{_END_LINEAGE}{ref_col}", - ) - ) - - return maps - - return [ - fkmap - for fk in rel_data.get_foreign_keys(table) - for fkmap in _ancestral_fk_map(fk) - ] - - -def _empty_data_tableset(rel_data: RelationalData) -> dict[str, pd.DataFrame]: - return { - table: pd.DataFrame(columns=list(rel_data.get_table_columns(table))) - for table in rel_data.list_all_tables() - } - - -def get_seed_safe_multigenerational_columns( - rel_data: RelationalData, -) -> dict[str, list[str]]: - """ - Returns a dict with Scope.MODELABLE table names as keys and lists of columns to use - for conditional seeding as values. By using a tableset of empty dataframes, this provides - a significantly faster / less resource-intensive way to get just the column names - from the results of `get_table_data_with_ancestors` for all tables. - """ - tableset = _empty_data_tableset(rel_data) - return { - table: list( - get_table_data_with_ancestors( - rel_data, table, tableset, ancestral_seeding=True - ).columns - ) - for table in rel_data.list_all_tables() - } - - -def get_table_data_with_ancestors( - rel_data: RelationalData, - table: str, - tableset: Optional[dict[str, pd.DataFrame]] = None, - ancestral_seeding: bool = False, - keys_only: bool = False, -) -> pd.DataFrame: - """ - Returns a data frame with all ancestral data joined to each record. - Column names are modified to the format `LINAGE|COLUMN_NAME`. - Lineage begins with `self` for the supplied `table`, and as older - generations are joined, the foreign keys to those generations are appended, - separated by periods. - - If `tableset` is provided, use it in place of the source data in `self.graph`. - - If `ancestral_seeding` is True, the returned dataframe only includes columns - that can be used as conditional seeds. - - If `keys_only` is True, the returned dataframe only includes columns that are primary - or foreign keys. - """ - if tableset is not None: - df = tableset[table] - else: - df = rel_data.get_table_data(table) - - if keys_only: - df = df[rel_data.get_all_key_columns(table)] - - lineage = _START_LINEAGE - df = df.add_prefix(f"{_START_LINEAGE}{_END_LINEAGE}") - return _join_parents( - rel_data, df, table, lineage, tableset, ancestral_seeding, keys_only - ) - - -def _join_parents( - rel_data: RelationalData, - df: pd.DataFrame, - table: str, - lineage: str, - tableset: Optional[dict[str, pd.DataFrame]], - ancestral_seeding: bool, - keys_only: bool, -) -> pd.DataFrame: - for foreign_key in rel_data.get_foreign_keys(table): - fk_lineage = _COL_DELIMITER.join(foreign_key.columns) - next_lineage = f"{lineage}{_GEN_DELIMITER}{fk_lineage}" - - parent_table_name = foreign_key.parent_table_name - - if ancestral_seeding: - usecols = list(rel_data.get_safe_ancestral_seed_columns(parent_table_name)) - elif keys_only: - usecols = rel_data.get_all_key_columns(parent_table_name) - else: - usecols = rel_data.get_table_columns(parent_table_name) - - if tableset is not None: - parent_data = tableset[parent_table_name][list(usecols)] - else: - parent_data = rel_data.get_table_data(parent_table_name, usecols=usecols) - - df = df.merge( - parent_data.add_prefix(f"{next_lineage}{_END_LINEAGE}"), - how="left", - left_on=[f"{lineage}{_END_LINEAGE}{col}" for col in foreign_key.columns], - right_on=[ - f"{next_lineage}{_END_LINEAGE}{parent_col}" - for parent_col in foreign_key.parent_columns - ], - ) - - df = _join_parents( - rel_data, - df, - parent_table_name, - next_lineage, - tableset, - ancestral_seeding, - keys_only, - ) - return df - - -def is_ancestral_column(column: str) -> bool: - """ - Returns True if the provided column name corresponds to an elder-generation ancestor. - """ - regex_string = rf"\{_GEN_DELIMITER}[^\{_END_LINEAGE}]+\{_END_LINEAGE}" - regex = re.compile(regex_string) - return bool(regex.search(column)) - - -def drop_ancestral_data(df: pd.DataFrame) -> pd.DataFrame: - """ - Drops ancestral columns from the given dataframe and removes the lineage prefix - from the remaining columns, restoring them to their original source names. - """ - root_columns = [ - col for col in df.columns if col.startswith(f"{_START_LINEAGE}{_END_LINEAGE}") - ] - mapper = { - col: col.removeprefix(f"{_START_LINEAGE}{_END_LINEAGE}") for col in root_columns - } - return df[root_columns].rename(columns=mapper) - - -def prepend_foreign_key_lineage(df: pd.DataFrame, fk_cols: list[str]) -> pd.DataFrame: - """ - Given a multigenerational dataframe, renames all columns such that the provided - foreign key columns act as the lineage from some child table to the provided data. - The resulting column names are elder-generation ancestral column names from the - perspective of a child table that relates to that parent via the provided foreign key. - """ - fk_lineage = _COL_DELIMITER.join(fk_cols) - - def _adjust(col: str) -> str: - return col.replace( - _START_LINEAGE, - f"{_START_LINEAGE}{_GEN_DELIMITER}{fk_lineage}", - 1, - ) - - mapper = {col: _adjust(col) for col in df.columns} - return df.rename(columns=mapper) diff --git a/src/gretel_trainer/relational/backup.py b/src/gretel_trainer/relational/backup.py deleted file mode 100644 index 81533d54..00000000 --- a/src/gretel_trainer/relational/backup.py +++ /dev/null @@ -1,165 +0,0 @@ -from __future__ import annotations - -from dataclasses import asdict, dataclass -from typing import Any, Optional, Union - -from gretel_trainer.relational.core import ForeignKey, RelationalData, Scope -from gretel_trainer.relational.json import InventedTableMetadata, ProducerMetadata - - -@dataclass -class BackupRelationalDataTable: - primary_key: list[str] - invented_table_metadata: Optional[dict[str, Any]] = None - producer_metadata: Optional[dict[str, Any]] = None - - -@dataclass -class BackupForeignKey: - table: str - constrained_columns: list[str] - referred_table: str - referred_columns: list[str] - - @classmethod - def from_fk(cls, fk: ForeignKey) -> BackupForeignKey: - return BackupForeignKey( - table=fk.table_name, - constrained_columns=fk.columns, - referred_table=fk.parent_table_name, - referred_columns=fk.parent_columns, - ) - - -@dataclass -class BackupRelationalData: - tables: dict[str, BackupRelationalDataTable] - foreign_keys: list[BackupForeignKey] - - @classmethod - def from_relational_data(cls, rel_data: RelationalData) -> BackupRelationalData: - tables = {} - foreign_keys = [] - for table in rel_data.list_all_tables(Scope.ALL): - tables[table] = BackupRelationalDataTable( - primary_key=rel_data.get_primary_key(table), - invented_table_metadata=_optionally_as_dict( - rel_data.get_invented_table_metadata(table) - ), - producer_metadata=_optionally_as_dict( - rel_data.get_producer_metadata(table) - ), - ) - - # Producer tables delegate their foreign keys to root invented tables. - # We exclude producers here to avoid adding duplicate foreign keys. - if not rel_data.is_producer_of_invented_tables(table): - foreign_keys.extend( - [ - BackupForeignKey.from_fk(key) - for key in rel_data.get_foreign_keys(table) - ] - ) - return BackupRelationalData(tables=tables, foreign_keys=foreign_keys) - - -def _optionally_as_dict( - metadata: Optional[Union[InventedTableMetadata, ProducerMetadata]] -) -> Optional[dict[str, Any]]: - if metadata is None: - return None - - return asdict(metadata) - - -@dataclass -class BackupClassify: - model_ids: dict[str, str] - - -@dataclass -class BackupTransformsTrain: - model_ids: dict[str, str] - lost_contact: list[str] - - -@dataclass -class BackupSyntheticsTrain: - model_ids: dict[str, str] - lost_contact: list[str] - - -@dataclass -class BackupGenerate: - identifier: str - preserved: list[str] - record_size_ratio: float - record_handler_ids: dict[str, str] - lost_contact: list[str] - - -@dataclass -class Backup: - project_name: str - strategy: str - refresh_interval: int - source_archive: Optional[str] - relational_data: BackupRelationalData - classify: Optional[BackupClassify] = None - transforms_train: Optional[BackupTransformsTrain] = None - synthetics_train: Optional[BackupSyntheticsTrain] = None - generate: Optional[BackupGenerate] = None - - @property - def as_dict(self): - return asdict(self) - - @classmethod - def from_dict(cls, b: dict[str, Any]): - relational_data = b["relational_data"] - brd = BackupRelationalData( - tables={ - k: BackupRelationalDataTable(**v) - for k, v in relational_data.get("tables", {}).items() - }, - foreign_keys=[ - BackupForeignKey( - table=fk["table"], - constrained_columns=fk["constrained_columns"], - referred_table=fk["referred_table"], - referred_columns=fk["referred_columns"], - ) - for fk in relational_data.get("foreign_keys", []) - ], - ) - - # source_archive previously was stored under artifact_collection - source_archive = b.get( - "source_archive", b.get("artifact_collection", {}).get("source_archive") - ) - - backup = Backup( - project_name=b["project_name"], - strategy=b["strategy"], - refresh_interval=b["refresh_interval"], - source_archive=source_archive, - relational_data=brd, - ) - - classify = b.get("classify") - if classify is not None: - backup.classify = BackupClassify(**classify) - - transforms_train = b.get("transforms_train") - if transforms_train is not None: - backup.transforms_train = BackupTransformsTrain(**transforms_train) - - synthetics_train = b.get("synthetics_train") - if synthetics_train is not None: - backup.synthetics_train = BackupSyntheticsTrain(**synthetics_train) - - generate = b.get("generate") - if generate is not None: - backup.generate = BackupGenerate(**generate) - - return backup diff --git a/src/gretel_trainer/relational/connectors.py b/src/gretel_trainer/relational/connectors.py deleted file mode 100644 index 2547056b..00000000 --- a/src/gretel_trainer/relational/connectors.py +++ /dev/null @@ -1,211 +0,0 @@ -""" -This module provides the "Connector" class which allows for reading from -and writing to databases and data warehouses. This class can handle -metadata and table extraction automatically. When this is done with -the "Connector.extract" method, a "RelationalData" instance is provided -which you can then use with the "MultiTable" class to process data with -Gretel Transforms, Classify, Synthetics, or a combination of both. -""" - -from __future__ import annotations - -import logging - -from pathlib import Path -from typing import Optional - -import pandas as pd - -from sqlalchemy import create_engine, inspect -from sqlalchemy.engine.base import Engine -from sqlalchemy.exc import OperationalError -from sqlalchemy.sql import text - -from gretel_trainer.relational.core import ( - DEFAULT_RELATIONAL_SOURCE_DIR, - MultiTableException, - RelationalData, -) -from gretel_trainer.relational.extractor import ExtractorConfig, TableExtractor - -logger = logging.getLogger(__name__) - - -class Connector: - """ - Wraps connections to relational databases and backups. - - Args: - engine (sqlalchemy.engine.base.Engine): A SQLAlchemy engine configured - to connect to some database. A variety of helper functions exist to - assist with creating engines for some popular databases, but these - should not be considered exhaustive. You may need to install - additional dialect/adapter packages via pip, such as psycopg2 for - connecting to postgres. - - For more detail, see the SQLAlchemy docs: - https://docs.sqlalchemy.org/en/20/core/engines.html - """ - - def __init__(self, engine: Engine): - self.engine = engine - logger.info("Connecting to database") - try: - self.engine.connect() - except OperationalError as e: - logger.error(f"{e}, {e.__cause__}") - raise e - logger.info("Successfully connected to db") - - @classmethod - def from_conn_str(cls, conn_str: str) -> Connector: - """ - Alternate constructor that creates a Connector instance - directly from a connection string. - - Args: - conn_str: A full connection string for the target database. - """ - engine = create_engine(conn_str) - return cls(engine) - - def extract( - self, - only: Optional[set[str]] = None, - ignore: Optional[set[str]] = None, - schema: Optional[str] = None, - config: Optional[ExtractorConfig] = None, - storage_dir: str = DEFAULT_RELATIONAL_SOURCE_DIR, - ) -> RelationalData: - """ - Extracts table data and relationships from the database. Optional args include: - - Args: - only: Only extract these table names, cannot be used with `ignore` - ignore: Skip extracting these table names, cannot be used with `only` - schema: An optional schema name that is passed through to SQLAlchemy, may only - be used with certain dialects. - config: An optional extraction config. This config can be used to only include - specific tables, ignore specific tables, and configure subsetting. Please - see the `ExtractorConfig` docs for more details. - storage_dir: The output directory where extracted data is stored. - """ - if only is not None and ignore is not None: - raise MultiTableException("Cannot specify both `only` and `ignore`.") - - if config is None: - config = ExtractorConfig( - only=only, ignore=ignore, schema=schema # pyright: ignore - ) - - storage_dir_path = Path(storage_dir) - storage_dir_path.mkdir(parents=True, exist_ok=True) - - extractor = TableExtractor( - config=config, connector=self, storage_dir=storage_dir_path - ) - extractor.sample_tables(schema=schema) - - # We ensure to re-create RelationalData after extraction so - # we can account for any embedded JSON. This also loads - # each table as a DF in the object which is currently - # the expected behavior for later operations. - extractor._relational_data = extractor._create_rel_data( - extracted_tables=extractor.table_order - ) - - return extractor.relational_data - - def save( - self, - output_tables: dict[str, pd.DataFrame], - prefix: str = "", - source_relational_data: Optional[RelationalData] = None, - ) -> None: - # Default to unsorted if source_relational_data isn't provided - ordered_source_tables = output_tables.keys() - if source_relational_data is not None: - # Depencencies can exist between tables due to foreign key constraints - # This topologically sorted dependency graph allows for deleting/writing tables in the proper order - ordered_source_tables = ( - source_relational_data.list_tables_parents_before_children() - ) - # Tables may not be in output_tables if they failed to train/generate - table_collection = [ - tbl for tbl in ordered_source_tables if tbl in output_tables - ] - - logger.info("\n--- Executing SQL insertion for output tables ---") - logger.info("Emptying output tables if they already exist.") - - # Traverse dependency graph from leaf up to parent - for source_table_name in reversed(table_collection): - # Check that this table was not dropped with an "only" param - # and check if table exists in destination db. - if inspect(self.engine).has_table(f"{prefix}{source_table_name}"): - logger.info( - f"Table: {prefix}{source_table_name} - Exists in destination db. Deleting all records from table." - ) - # Not all SQL DBs support table truncation with FK constraints (ie. MSSQL) - self.engine.execute( - text(f"delete from {prefix}{source_table_name}").execution_options( - autocommit=True - ) - ) - else: - logger.info( - f"Table: {prefix}{source_table_name} - Does not exist in destination db. Skipping record deletion." - ) - - logger.info("\nWriting data to destination.") - - # Traverse dependency graph from parent down to leaf - for source_table_name in table_collection: - logger.info(f"Table: {prefix}{source_table_name} - Writing data.") - data = output_tables[source_table_name] - data.to_sql( - f"{prefix}{source_table_name}", - con=self.engine, - if_exists="append", - index=False, - ) - - -def sqlite_conn(path: str) -> Connector: - engine = create_engine(f"sqlite:///{path}") - return Connector(engine) - - -def postgres_conn( - *, user: str, password: str, host: str, port: int, database: str -) -> Connector: - conn_string = f"postgresql://{user}:{password}@{host}:{port}/{database}" - engine = create_engine(conn_string) - return Connector(engine) - - -def mysql_conn(*, user: str, password: str, host: str, port: int, database: str): - conn_string = f"mysql://{user}:{password}@{host}:{port}/{database}" - engine = create_engine(conn_string) - return Connector(engine) - - -def mariadb_conn(*, user: str, password: str, host: str, port: int, database: str): - conn_string = f"mariadb://{user}:{password}@{host}:{port}/{database}" - engine = create_engine(conn_string) - return Connector(engine) - - -def snowflake_conn( - *, - user: str, - password: str, - account_identifier: str, - database: str, - schema: str, - warehouse: str, - role: str, -) -> Connector: - conn_string = f"snowflake://{user}:{password}@{account_identifier}/{database}/{schema}?warehouse={warehouse}&role={role}" - engine = create_engine(conn_string) - return Connector(engine) diff --git a/src/gretel_trainer/relational/core.py b/src/gretel_trainer/relational/core.py deleted file mode 100644 index e9a0422d..00000000 --- a/src/gretel_trainer/relational/core.py +++ /dev/null @@ -1,932 +0,0 @@ -""" -This module exposes the "RelationalData" class to users, which allows the processing -of relational databases and data warehouses with Gretel.ai. - -When using a "Connector" or a "TableExtractor" instance to automatically connect -to a database, a "RelationalData" instance will be created for you that contains -all of the learned metadata. - -If you are processing relational tables manually, with your own CSVs, you -will need to create a "RelationalData" instance and populate it yourself. - -Please see the specific docs for the "RelationalData" class on how to do this. -""" - -from __future__ import annotations - -import logging -import shutil -import tempfile - -from dataclasses import dataclass, replace -from enum import Enum -from pathlib import Path -from typing import Any, Optional, Protocol, Union - -import networkx -import pandas as pd - -from networkx.algorithms.cycles import simple_cycles -from networkx.algorithms.dag import dag_longest_path_length, topological_sort -from networkx.classes.function import number_of_edges -from pandas.api.types import is_string_dtype - -import gretel_trainer.relational.json as relational_json - -from gretel_client.projects.artifact_handlers import open_artifact -from gretel_trainer.relational.json import ( - IngestResponseT, - InventedTableMetadata, - ProducerMetadata, -) - -logger = logging.getLogger(__name__) - -DEFAULT_RELATIONAL_SOURCE_DIR = "relational_source" -PREVIEW_ROW_COUNT = 5 - - -class MultiTableException(Exception): - pass - - -GretelModelConfig = Union[str, Path, dict] - - -@dataclass -class ForeignKey: - table_name: str - columns: list[str] - parent_table_name: str - parent_columns: list[str] - - def is_composite(self) -> bool: - return len(self.columns) > 1 - - -UserFriendlyDataT = Union[pd.DataFrame, str, Path] -UserFriendlyPrimaryKeyT = Optional[Union[str, list[str]]] - - -class Scope(str, Enum): - """ - Various non-mutually-exclusive sets of tables known to the system - """ - - ALL = "all" - """ - Every known table (all user-supplied tables, all invented tables) - """ - - PUBLIC = "public" - """ - Includes all user-supplied tables, omits invented tables - """ - - MODELABLE = "modelable" - """ - Includes flat source tables and all invented tables, omits source tables that led to invented tables - """ - - EVALUATABLE = "evaluatable" - """ - A subset of MODELABLE that additionally omits invented child tables (but includes invented root tables) - """ - - INVENTED = "invented" - """ - Includes all tables invented from un-modelable user source tables - """ - - -@dataclass -class TableMetadata: - primary_key: list[str] - source: str - columns: list[str] - invented_table_metadata: Optional[InventedTableMetadata] = None - producer_metadata: Optional[ProducerMetadata] = None - safe_ancestral_seed_columns: Optional[set[str]] = None - - -@dataclass -class _RemovedTableMetadata: - source: str - primary_key: list[str] - fks_to_parents: list[ForeignKey] - fks_from_children: list[ForeignKey] - - -class SourceDataHandler(Protocol): - def resolve_data_location(self, data: Union[str, Path]) -> str: - """ - Returns a string handle that can be used with smart_open to read source data. - """ - ... - - def put_source(self, name: str, data: UserFriendlyDataT) -> str: - """ - Ensures source data exists in a preferred, accessible internal location. - Returns a string handle to the data that can be used with smart_open. - """ - ... - - def put_invented_table_source(self, name: str, data: pd.DataFrame) -> str: - """ - Ensures invented table data exists in a preferred, accessible internal location. - Returns a string handle to the data that can be used with smart_open. - """ - ... - - -class SDKSourceDataHandler: - def __init__(self, directory: Union[str, Path]): - self.dir = Path(directory) - self.dir.mkdir(parents=True, exist_ok=True) - - def resolve_data_location(self, data: Union[str, Path]) -> str: - """ - Returns the provided data location as a string. May be a relative path. - """ - return str(data) - - def put_source(self, name: str, data: UserFriendlyDataT) -> str: - """ - Writes (or copies) the provided data to a CSV file in the local working directory. - """ - source_path = self.dir / f"{name}.csv" - if isinstance(data, pd.DataFrame): - data.to_csv(source_path, index=False) - elif isinstance(data, (str, Path)): - shutil.copyfile(data, source_path) - - return str(source_path) - - def put_invented_table_source(self, name: str, data: pd.DataFrame) -> str: - """ - Writes invented table data to the local working directory. - """ - return self.put_source(name, data) - - -class RelationalData: - """ - Stores information about multiple tables and their relationships. When - using this object you could create it without any arguments and rely - on the instance methods for adding tables and key relationships. - - Example:: - - rel_data = RelationalData() - rel_data.add_table(...) - rel_data.add_table(...) - rel_data.add_foreign_key_constraint(...) - - See the specific method docstrings for details on each method. - """ - - def __init__( - self, - directory: Optional[Union[str, Path]] = None, - source_data_handler: Optional[SourceDataHandler] = None, - ): - self.source_data_handler = source_data_handler or SDKSourceDataHandler( - directory=Path(directory or DEFAULT_RELATIONAL_SOURCE_DIR) - ) - self.graph = networkx.DiGraph() - - @property - def is_empty(self) -> bool: - """ - Return a bool to indicate if the `RelationalData` contains - any table information. - """ - return not self.graph.number_of_nodes() > 0 - - @property - def foreign_key_cycles(self) -> list[list[str]]: - """ - Returns lists of tables that have cyclic foreign key relationships. - """ - return list(simple_cycles(self.graph)) - - def restore(self, tableset: dict[str, pd.DataFrame]) -> dict[str, pd.DataFrame]: - """Restores a given tableset (presumably output from some MultiTable workflow, - i.e. transforms or synthetics) to its original shape (specifically, "re-nests" - any JSON that had been expanded out. - - Users should rely on MultiTable calling this internally when appropriate and not - need to do so themselves. - """ - restored = {} - discarded = set() - - # Restore any invented tables to nested-JSON format - producers = { - table: pmeta - for table in self.list_all_tables(Scope.ALL) - if (pmeta := self.get_producer_metadata(table)) is not None - } - for table_name, producer_metadata in producers.items(): - tables = { - table: data - for table, data in tableset.items() - if table in producer_metadata.table_names - } - data = relational_json.restore( - tables=tables, - rel_data=self, - root_table_name=producer_metadata.invented_root_table_name, - original_columns=self.get_table_columns(table_name), - table_name_mappings=producer_metadata.table_name_mappings, - original_table_name=table_name, - ) - if data is not None: - restored[table_name] = data - discarded.update(producer_metadata.table_names) - - # Add remaining tables - for table, data in tableset.items(): - if table not in discarded: - restored[table] = data - - return restored - - def add_table( - self, - *, - name: str, - primary_key: UserFriendlyPrimaryKeyT, - data: UserFriendlyDataT, - ) -> None: - """ - Add a table. The primary key can be None (if one is not defined on the table), - a string column name (most common), or a list of multiple string column names (composite key). - - This call MAY result in multiple tables getting "registered," specifically if - the table includes nested JSON data. - """ - primary_key = self._format_key_column(primary_key) - - # Preview data to get list of columns and determine if there is any JSON - if isinstance(data, pd.DataFrame): - preview_df = data.head(PREVIEW_ROW_COUNT) - elif isinstance(data, (str, Path)): - data_location = self.source_data_handler.resolve_data_location(data) - with open_artifact(data_location, "rb") as d: - preview_df = pd.read_csv(d, nrows=PREVIEW_ROW_COUNT) - columns = list(preview_df.columns) - json_cols = relational_json.get_json_columns(preview_df) - - # Write/copy source to preferred internal location - data_to_write = data - if isinstance(data, pd.DataFrame): - data_to_write = relational_json.jsonencode(data, json_cols) - source = self.source_data_handler.put_source(name, data_to_write) - - # If we found JSON in preview above, run JSON ingestion on full data - rj_ingest = None - if len(json_cols) > 0: - logger.info( - f"Detected JSON data in table `{name}`. Running JSON normalization." - ) - if isinstance(data, pd.DataFrame): - df = data - elif isinstance(data, (str, Path)): - with open_artifact(data, "rb") as d: - df = pd.read_csv(d) - rj_ingest = relational_json.ingest(name, primary_key, df, json_cols) - - # Add the table(s) - if rj_ingest is not None: - self._add_producer_and_invented_tables( - name, primary_key, source, columns, rj_ingest - ) - else: - self._add_single_table( - name=name, - primary_key=primary_key, - source=source, - columns=columns, - ) - - def _add_producer_and_invented_tables( - self, - table: str, - primary_key: list[str], - source: str, - columns: list[str], - rj_ingest: IngestResponseT, - ) -> None: - commands, producer_metadata = rj_ingest - tables, foreign_keys = commands - - # Add the producer table - self._add_single_table( - name=table, - primary_key=primary_key, - source=source, - columns=columns, - producer_metadata=producer_metadata, - ) - - # Add the invented tables - for tbl in tables: - name = tbl["name"] - df = tbl["data"] - tbl_source = self.source_data_handler.put_invented_table_source(name, df) - self._add_single_table( - name=name, - primary_key=tbl["primary_key"], - source=tbl_source, - columns=list(df.columns), - invented_table_metadata=tbl["invented_table_metadata"], - ) - for foreign_key in foreign_keys: - self.add_foreign_key_constraint(**foreign_key) - - def _add_single_table( - self, - *, - name: str, - primary_key: UserFriendlyPrimaryKeyT, - source: str, - columns: Optional[list[str]] = None, - invented_table_metadata: Optional[InventedTableMetadata] = None, - producer_metadata: Optional[ProducerMetadata] = None, - ) -> None: - primary_key = self._format_key_column(primary_key) - if columns is not None: - cols = columns - else: - with open_artifact(source, "rb") as src: - cols = list(pd.read_csv(src, nrows=1).columns) - metadata = TableMetadata( - primary_key=primary_key, - source=source, - columns=cols, - invented_table_metadata=invented_table_metadata, - producer_metadata=producer_metadata, - ) - self.graph.add_node(name, metadata=metadata) - - def _get_table_metadata(self, table: str) -> TableMetadata: - try: - return self.graph.nodes[table]["metadata"] - except KeyError: - raise MultiTableException(f"Unrecognized table: `{table}`") - - def set_primary_key( - self, *, table: str, primary_key: UserFriendlyPrimaryKeyT - ) -> None: - """ - (Re)set the primary key on an existing table. - If the table does not yet exist in the instance's collection, add it via `add_table`. - """ - if table not in self.list_all_tables(Scope.ALL): - raise MultiTableException(f"Unrecognized table name: `{table}`") - - primary_key = self._format_key_column(primary_key) - - known_columns = self.get_table_columns(table) - for col in primary_key: - if col not in known_columns: - raise MultiTableException(f"Unrecognized column name: `{col}`") - - # Prevent interfering with manually invented tables - if self._is_invented(table): - raise MultiTableException("Cannot change primary key on invented tables") - - # If `table` is a producer of invented tables, we redo JSON ingestion - # to ensure primary keys are set properly on invented tables - elif self.is_producer_of_invented_tables(table): - source = self.get_table_source(table) - with tempfile.TemporaryDirectory() as tmpdir: - tmpfile = Path(tmpdir) / f"{table}.csv" - shutil.move(source, tmpfile) - removal_metadata = self._remove_producer(table) - self.add_table(name=table, primary_key=primary_key, data=tmpfile) - self._restore_fks_in_both_directions(table, removal_metadata) - - # At this point we are working with a "normal" table - else: - self._get_table_metadata(table).primary_key = primary_key - self._clear_safe_ancestral_seed_columns(table) - - def _restore_fks_in_both_directions( - self, table: str, removal_metadata: _RemovedTableMetadata - ) -> None: - for fk in removal_metadata.fks_to_parents: - self.add_foreign_key_constraint( - table=table, - constrained_columns=fk.columns, - referred_table=fk.parent_table_name, - referred_columns=fk.parent_columns, - ) - - for fk in removal_metadata.fks_from_children: - self.add_foreign_key_constraint( - table=fk.table_name, - constrained_columns=fk.columns, - referred_table=table, - referred_columns=fk.parent_columns, - ) - - def _get_user_defined_fks_to_table(self, table: str) -> list[ForeignKey]: - return [ - fk - for child in self.graph.predecessors(table) - for fk in self.get_foreign_keys(child) - if fk.parent_table_name == table and not self._is_invented(fk.table_name) - ] - - def _remove_producer(self, table: str) -> _RemovedTableMetadata: - """ - Removes the producer table and all its invented tables from the graph - (which in turn removes all edges (foreign keys) to/from other tables). - - Returns a _RemovedTableMetadata object for restoring metadata in broader "update" contexts. - """ - table_metadata = self._get_table_metadata(table) - producer_metadata = table_metadata.producer_metadata - - if producer_metadata is None: - raise MultiTableException( - "Cannot remove invented tables from non-producer table" - ) - - removal_metadata = _RemovedTableMetadata( - source=table_metadata.source, - primary_key=table_metadata.primary_key, - fks_to_parents=self.get_foreign_keys(table), - fks_from_children=self._get_user_defined_fks_to_table( - self._get_fk_delegate_table(table) - ), - ) - - for invented_table_name in producer_metadata.table_names: - if invented_table_name in self.graph.nodes: - self._remove_node(invented_table_name) - self._remove_node(table) - - return removal_metadata - - def _remove_node(self, table: str) -> None: - self.graph.remove_node(table) - - def _format_key_column(self, key: Optional[Union[str, list[str]]]) -> list[str]: - if key is None: - return [] - elif isinstance(key, str): - return [key] - else: - return key - - def add_foreign_key_constraint( - self, - *, - table: str, - constrained_columns: list[str], - referred_table: str, - referred_columns: list[str], - ) -> None: - """ - Add a foreign key relationship between two tables. - - Args: - table: The table name that contains the foreign key. - constrained_columns: The column name(s) defining a relationship to the `referred_table` (the parent table). - referred_table: The table name that the foreign key in `table` refers to (the parent table). - referred_columns: The column name(s) in the parent table that the `constrained_columns` point to. - """ - known_tables = self.list_all_tables(Scope.ALL) - - abort = False - if table not in known_tables: - logger.warning(f"Unrecognized table name: `{table}`") - abort = True - if referred_table not in known_tables: - logger.warning(f"Unrecognized table name: `{referred_table}`") - abort = True - - if abort: - raise MultiTableException("Unrecognized table(s) in foreign key arguments") - - if len(constrained_columns) != len(referred_columns): - logger.warning( - "Constrained and referred columns must be of the same length" - ) - raise MultiTableException( - "Invalid column constraints in foreign key arguments" - ) - - table_all_columns = self.get_table_columns(table) - for col in constrained_columns: - if col not in table_all_columns: - logger.warning( - f"Constrained column `{col}` does not exist on table `{table}`" - ) - abort = True - referred_table_all_columns = self.get_table_columns(referred_table) - for col in referred_columns: - if col not in referred_table_all_columns: - logger.warning( - f"Referred column `{col}` does not exist on table `{referred_table}`" - ) - abort = True - - if abort: - raise MultiTableException("Unrecognized column(s) in foreign key arguments") - - fk_delegate_table = self._get_fk_delegate_table(table) - fk_delegate_referred_table = self._get_fk_delegate_table(referred_table) - - self.graph.add_edge(fk_delegate_table, fk_delegate_referred_table) - edge = self.graph.edges[fk_delegate_table, fk_delegate_referred_table] - via = edge.get("via", []) - via.append( - ForeignKey( - table_name=fk_delegate_table, - columns=constrained_columns, - parent_table_name=fk_delegate_referred_table, - parent_columns=referred_columns, - ) - ) - edge["via"] = via - self._clear_safe_ancestral_seed_columns(fk_delegate_table) - self._clear_safe_ancestral_seed_columns(table) - - def remove_foreign_key_constraint( - self, table: str, constrained_columns: list[str] - ) -> None: - """ - Remove an existing foreign key. - """ - if table not in self.list_all_tables(Scope.ALL): - raise MultiTableException(f"Unrecognized table name: `{table}`") - - key_to_remove = None - for fk in self.get_foreign_keys(table): - if fk.columns == constrained_columns: - key_to_remove = fk - - if key_to_remove is None: - raise MultiTableException( - f"`{table} does not have a foreign key with constrained columns {constrained_columns}`" - ) - - fk_delegate_table = self._get_fk_delegate_table(table) - - edge = self.graph.edges[fk_delegate_table, key_to_remove.parent_table_name] - via = edge.get("via") - via.remove(key_to_remove) - if len(via) == 0: - self.graph.remove_edge(fk_delegate_table, key_to_remove.parent_table_name) - else: - edge["via"] = via - self._clear_safe_ancestral_seed_columns(fk_delegate_table) - self._clear_safe_ancestral_seed_columns(table) - - def update_table_data(self, table: str, data: UserFriendlyDataT) -> None: - """ - Set a DataFrame as the table data for a given table name. - """ - if self._is_invented(table): - raise MultiTableException("Cannot modify invented tables' data") - elif self.is_producer_of_invented_tables(table): - removal_metadata = self._remove_producer(table) - else: - removal_metadata = _RemovedTableMetadata( - source="", # we don't care about the old data - primary_key=self.get_primary_key(table), - fks_to_parents=self.get_foreign_keys(table), - fks_from_children=self._get_user_defined_fks_to_table(table), - ) - self._remove_node(table) - - self.add_table(name=table, primary_key=removal_metadata.primary_key, data=data) - self._restore_fks_in_both_directions(table, removal_metadata) - - def list_all_tables(self, scope: Scope = Scope.MODELABLE) -> list[str]: - """ - Returns a list of table names belonging to the provided Scope. - See "Scope" enum documentation for details. - By default, returns tables that can be submitted as jobs to Gretel - (i.e. that are MODELABLE). - """ - graph_nodes = list(self.graph.nodes) - - producer_tables = [ - t for t in graph_nodes if self.is_producer_of_invented_tables(t) - ] - - modelable_tables = [] - evaluatable_tables = [] - invented_tables: list[str] = [] - - for n in graph_nodes: - meta = self._get_table_metadata(n) - if (invented_meta := meta.invented_table_metadata) is not None: - invented_tables.append(n) - if invented_meta.invented_root_table_name == n: - evaluatable_tables.append(n) - if not invented_meta.empty: - modelable_tables.append(n) - else: - if n not in producer_tables: - modelable_tables.append(n) - evaluatable_tables.append(n) - - if scope == Scope.MODELABLE: - return modelable_tables - elif scope == Scope.EVALUATABLE: - return evaluatable_tables - elif scope == Scope.INVENTED: - return invented_tables - elif scope == Scope.ALL: - return graph_nodes - elif scope == Scope.PUBLIC: - return [t for t in graph_nodes if t not in invented_tables] - - def _is_invented(self, table: str) -> bool: - return self.get_invented_table_metadata(table) is not None - - def is_producer_of_invented_tables(self, table: str) -> bool: - return self.get_producer_metadata(table) is not None - - def get_modelable_table_names(self, table: str) -> list[str]: - """ - Returns a list of MODELABLE table names connected to the provided table. - If the provided table is the source of invented tables, returns the modelable invented tables created from it. - If the provided table is itself modelable, returns that table name back. - Otherwise returns an empty list. - """ - try: - table_metadata = self._get_table_metadata(table) - except MultiTableException: - return [] - - if (pmeta := table_metadata.producer_metadata) is not None: - return [ - t - for t in pmeta.table_names - if t in self.list_all_tables(Scope.MODELABLE) - ] - elif table in self.list_all_tables(Scope.MODELABLE): - return [table] - else: - return [] - - def get_public_name(self, table: str) -> Optional[str]: - if (imeta := self.get_invented_table_metadata(table)) is not None: - return imeta.original_table_name - - return table - - def get_invented_table_metadata( - self, table: str - ) -> Optional[InventedTableMetadata]: - return self._get_table_metadata(table).invented_table_metadata - - def get_producer_metadata(self, table: str) -> Optional[ProducerMetadata]: - return self._get_table_metadata(table).producer_metadata - - def get_parents(self, table: str) -> list[str]: - """ - Given a table name, return the table names that are referred to - by the foreign keys in this table. - """ - return list(self.graph.successors(table)) - - def get_ancestors(self, table: str) -> list[str]: - """ - Same as `get_parents` except recursively keep adding - parent tables until there are no more. - """ - - def _add_parents(ancestors, table): - parents = self.get_parents(table) - if len(parents) > 0: - ancestors.update(parents) - for parent in parents: - _add_parents(ancestors, parent) - - ancestors = set() - _add_parents(ancestors, table) - - return list(ancestors) - - def get_descendants(self, table: str) -> list[str]: - """ - Given a table name, recursively return all tables that - carry foreign keys that reference the primary key in this table - and all subsequent tables that are discovered. - """ - - def _add_children(descendants, table): - children = list(self.graph.predecessors(table)) - if len(children) > 0: - descendants.update(children) - for child in children: - _add_children(descendants, child) - - descendants = set() - _add_children(descendants, table) - - return list(descendants) - - def list_tables_parents_before_children(self) -> list[str]: - """ - Returns a list of all tables with the guarantee that a parent table - appears before any of its children. No other guarantees about order - are made, e.g. the following (and others) are all valid outputs: - [p1, p2, c1, c2] or [p2, c2, p1, c1] or [p2, p1, c1, c2] etc. - """ - return list(reversed(list(topological_sort(self.graph)))) - - def get_primary_key(self, table: str) -> list[str]: - """ - Return the list of columns defining the primary key for a table. - It may be a single column or multiple columns (composite key). - """ - return self._get_table_metadata(table).primary_key - - def get_table_source(self, table: str) -> str: - return self._get_table_metadata(table).source - - def get_table_data( - self, table: str, usecols: Optional[list[str]] = None - ) -> pd.DataFrame: - """ - Return the table contents for a given table name as a DataFrame. - """ - source = self.get_table_source(table) - usecols = usecols or self.get_table_columns(table) - with open_artifact(source, "rb") as src: - return pd.read_csv(src, usecols=usecols) - - def get_table_columns(self, table: str) -> list[str]: - """ - Return the column names for a provided table name. - """ - return self._get_table_metadata(table).columns - - def get_table_row_count(self, table: str) -> int: - """ - Return the number of rows in the table. - """ - source = self.get_table_source(table) - with open_artifact(source, "rb") as src: - return sum(1 for line in src) - 1 - - def get_safe_ancestral_seed_columns(self, table: str) -> set[str]: - safe_columns = self._get_table_metadata(table).safe_ancestral_seed_columns - if safe_columns is None: - safe_columns = self._set_safe_ancestral_seed_columns(table) - return safe_columns - - def _set_safe_ancestral_seed_columns(self, table: str) -> set[str]: - cols = set() - - # Key columns are always kept - cols.update(self.get_primary_key(table)) - for fk in self.get_foreign_keys(table): - cols.update(fk.columns) - - data = self.get_table_data(table) - for col in self.get_table_columns(table): - if col in cols: - continue - if _ok_for_train_and_seed(col, data): - cols.add(col) - - self._get_table_metadata(table).safe_ancestral_seed_columns = cols - return cols - - def _clear_safe_ancestral_seed_columns(self, table: str) -> None: - self._get_table_metadata(table).safe_ancestral_seed_columns = None - - def _get_fk_delegate_table(self, table: str) -> str: - if (pmeta := self.get_producer_metadata(table)) is not None: - return pmeta.invented_root_table_name - - return table - - def get_foreign_keys( - self, table: str, rename_invented_tables: bool = False - ) -> list[ForeignKey]: - def _rename_invented(fk: ForeignKey) -> ForeignKey: - table_name = self.get_public_name(fk.table_name) - parent_table_name = self.get_public_name(fk.parent_table_name) - return replace( - fk, table_name=table_name, parent_table_name=parent_table_name - ) - - table = self._get_fk_delegate_table(table) - foreign_keys = [] - for parent in self.get_parents(table): - fks = self.graph.edges[table, parent]["via"] - foreign_keys.extend(fks) - - if rename_invented_tables: - return [_rename_invented(fk) for fk in foreign_keys] - else: - return foreign_keys - - def get_all_key_columns(self, table: str) -> list[str]: - all_key_cols = [] - all_key_cols.extend(self.get_primary_key(table)) - for fk in self.get_foreign_keys(table): - all_key_cols.extend(fk.columns) - return sorted(list(set(all_key_cols))) - - def any_table_relationships(self) -> bool: - return number_of_edges(self.graph) > 0 - - def debug_summary(self) -> dict[str, Any]: - if len(self.foreign_key_cycles) > 0: - max_depth = "indeterminate (cycles in foreign keys)" - else: - max_depth = dag_longest_path_length(self.graph) - public_table_count = len(self.list_all_tables(Scope.PUBLIC)) - invented_table_count = len(self.list_all_tables(Scope.INVENTED)) - - all_tables = self.list_all_tables(Scope.ALL) - total_foreign_key_count = 0 - tables = {} - for table in all_tables: - this_table_foreign_key_count = 0 - foreign_keys = [] - for key in self.get_foreign_keys(table): - total_foreign_key_count = total_foreign_key_count + 1 - this_table_foreign_key_count = this_table_foreign_key_count + 1 - foreign_keys.append( - { - "columns": key.columns, - "parent_table_name": key.parent_table_name, - "parent_columns": key.parent_columns, - } - ) - table_metadata = { - "column_count": len(self.get_table_columns(table)), - "primary_key": self.get_primary_key(table), - "foreign_key_count": this_table_foreign_key_count, - "foreign_keys": foreign_keys, - "is_invented_table": self._is_invented(table), - } - if (producer_metadata := self.get_producer_metadata(table)) is not None: - table_metadata["invented_table_details"] = { - "table_type": "producer", - "json_to_table_mappings": producer_metadata.table_name_mappings, - } - elif ( - invented_table_metadata := self.get_invented_table_metadata(table) - ) is not None: - table_metadata["invented_table_details"] = { - "table_type": "invented", - "json_breadcrumb_path": invented_table_metadata.json_breadcrumb_path, - } - tables[table] = table_metadata - - return { - "foreign_key_count": total_foreign_key_count, - "max_depth": max_depth, - "tables": tables, - "public_table_count": public_table_count, - "invented_table_count": invented_table_count, - } - - -def _ok_for_train_and_seed(col: str, df: pd.DataFrame) -> bool: - if _is_highly_nan(col, df): - return False - - if _is_highly_unique_categorical(col, df): - return False - - return True - - -def _is_highly_nan(col: str, df: pd.DataFrame) -> bool: - total = len(df) - if total == 0: - return False - - missing = df[col].isnull().sum() - missing_perc = missing / total - return missing_perc > 0.2 - - -def _is_highly_unique_categorical(col: str, df: pd.DataFrame) -> bool: - return is_string_dtype(df[col]) and _percent_unique(col, df) >= 0.7 - - -def _percent_unique(col: str, df: pd.DataFrame) -> float: - col_no_nan = df[col].dropna() - total = len(col_no_nan) - distinct = col_no_nan.nunique() - - if total == 0: - return 0.0 - else: - return distinct / total diff --git a/src/gretel_trainer/relational/extractor.py b/src/gretel_trainer/relational/extractor.py deleted file mode 100644 index 1b04aef2..00000000 --- a/src/gretel_trainer/relational/extractor.py +++ /dev/null @@ -1,618 +0,0 @@ -""" -Extract database or data warehouse SQL tables to flat files with optional subsetting. -""" - -from __future__ import annotations - -import logging - -from contextlib import nullcontext -from dataclasses import asdict, dataclass -from enum import Enum -from pathlib import Path -from threading import Lock -from typing import Iterator, Optional, TYPE_CHECKING - -import dask.dataframe as dd -import numpy as np -import pandas as pd - -from sqlalchemy import func, inspect, MetaData, select, Table, tuple_ - -from gretel_trainer.relational.core import RelationalData - -if TYPE_CHECKING: - from sqlalchemy.engine import Engine - - from gretel_trainer.relational.connectors import Connector - -logger = logging.getLogger(__name__) - - -class SampleMode(str, Enum): - RANDOM = "random" - CONTIGUOUS = "contiguous" - - -@dataclass -class ExtractorConfig: - """ - Configuration class for extracting tables from a remote database. An instance - of this class should be passed as a param to the "TableExtractor" constructor. - """ - - target_row_count: float = -1.0 - """ - The target number of rows (or ratio of rows) to sample. This will be used as the sample target for "leaf" tables, - or tables that do not have any references to their primary keys. If this number is >= 1 then - that number of rows will be used, if the value is between 0..1 then it is considered to be a percetange - of the total number of rows. A 0 value will just extract headers and -1 will extract entire tables. - - The default value, -1, implies that full tables should be extracted. - """ - - sample_mode: SampleMode = SampleMode.CONTIGUOUS - """ - The method to sample records from tables that do not contain - any primary keys that are referenced by other tables. We call these - "leaf" tables because in a graph representation they do not - have any children. - - The default mode is to sample contiguously based on how the - specific database/data warehouse supports it. This essentially - does a 'SELECT * FROM table LIMIT ' based on the provided - `target_row_count`. - - If using "random" sampling, the extractor will attempt to select - leaf table rows randomly, however different dialects - have different support for this. If the "random" sampling fails, - the extractor will fall back to the "contiguous" method. - """ - - only: Optional[set[str]] = None - """ - Only extract these tables. Cannot be used with `ignore.` - """ - - ignore: Optional[set[str]] = None - """ - Ignore these tables during extraction. Cannot be used with `only.` - """ - - schema: Optional[None] = None - """ - Limit scope to a specific schema, this is a pass-through param to SQLAlchemy. It is not - supported by all dialects - """ - - def __post_init__(self): - errors = [] - - if self.sample_mode not in (SampleMode.RANDOM, SampleMode.CONTIGUOUS): - raise ValueError("Invalid `sample_mode`") - - if self.target_row_count < -1: - errors.append("The `target_row_count` must be -1 or higher") - - if self.ignore is not None and self.only is not None: - errors.append("Cannot specify both `only` and `ignore` together") - - if self.sample_mode not in ("random", "contiguous"): - errors.append("`sample_mode` must be one of 'random', 'contiguous'") - - if errors: - raise ValueError(f"The following errors occured: {', '.join(errors)}") - - @property - def entire_table(self) -> bool: - """ - Returns True if the config is set to extract entire tables - from the remote database. - """ - return self.target_row_count == -1 - - @property - def empty_table(self) -> bool: - """ - Returns True if the config is set to only extract column names. - """ - return self.target_row_count == 0 - - def _should_skip_table(self, table_name: str) -> bool: - if self.only and table_name not in self.only: - return True - - if self.ignore and table_name in self.ignore: - return True - - return False - - -def _determine_sample_size(config: ExtractorConfig, total_row_count: int) -> int: - """ - Given the actual total row count of a table, determine how - many rows we should sample from it. - """ - if config.target_row_count >= 1: - return int(config.target_row_count) - - if config.entire_table: - return total_row_count - - return int(total_row_count * config.target_row_count) - - -@dataclass -class _TableSession: - table: Table - engine: Engine - - @property - def total_row_count(self) -> int: - with self.engine.connect() as conn: - query = select(func.count()).select_from(self.table) - count = conn.execute(query).scalar() - return 0 if count is None else int(count) - - @property - def total_column_count(self) -> int: - return len(self.columns) - - @property - def columns(self) -> list[str]: - return [column.name for column in self.table.columns] - - -@dataclass -class _PKValues: - """ - Contains information that is needed to sample rows from a parent table - where we need the foreign key values of the table's children so we - can extract only those rows from the parent table. - """ - - table_name: str - values_ddf: dd.DataFrame # pyright: ignore - column_names: list[str] - - -@dataclass -class TableMetadata: - """ - Contains information about an extracted table. - """ - - original_row_count: int - sampled_row_count: int - column_count: int - - def dict(self) -> dict[str, int]: - return asdict(self) - - -def _stream_df_to_path( - df_iter: Iterator[pd.DataFrame], path: Path, lock: Optional[Lock] = None -) -> int: - """ - Stream the contents of a DF to disk, this function only does appending - """ - if lock is None: - lock_ = nullcontext() - else: - lock_ = lock - - row_count = 0 - - for df in df_iter: - with lock_: - df.to_csv(path, mode="a", index=False, header=False) - row_count += len(df) - - return row_count - - -class TableExtractorError(Exception): - pass - - -class TableExtractor: - _connector: Connector - _config: ExtractorConfig - _storage_dir: Path - _relational_data: RelationalData - _chunk_size: int - - table_order: list[str] - - def __init__( - self, - *, - config: ExtractorConfig, - connector: Connector, - storage_dir: Path, - ): - self._connector = connector - self._config = config - - if not storage_dir.is_dir(): - raise ValueError("The `storage_dir` must be a directory!") - - self._storage_dir = storage_dir - - self._relational_data = RelationalData(directory=self._storage_dir) - self.table_order = [] - self._chunk_size = 50_000 - - def _get_table_session( - self, table_name: str, schema: Optional[str] = None - ) -> _TableSession: - metadata = MetaData() - metadata.reflect(only=[table_name], bind=self._connector.engine, schema=schema) - if schema: - # In cases where a schema exists, it can be prepended to the table name - schema_prepended_table_name = f"{schema}.{table_name}" - if ( - table_name not in metadata.tables - and schema_prepended_table_name in metadata.tables - ): - table_name = schema_prepended_table_name - table = metadata.tables[table_name] - return _TableSession(table=table, engine=self._connector.engine) - - def _create_rel_data( - self, extracted_tables: Optional[list[str]] = None - ) -> RelationalData: - """ - Internal helper method. This can be used to construct a `RelationalData` - object that either contains just the table headers and FK/PK relationships - or create an instance that has loaded DataFrames. - - You may need to use this in order to build up a "fresh" `RelationalData` object - _after_ tables have already been sampled. Especially if you need to consider - any embedded JSON data that is used to create additional PK/FK mappings - that are invented. - - If any table names are provided in the `extracted_tables` list, then those tables - will be loaded as DFs and added as the data to nodes. - - NOTE: If `extracted_tables` are provided, then these tables must have already been - extracted! - """ - if extracted_tables is None: - extracted_tables = [] - - rel_data = RelationalData(directory=self._storage_dir) - inspector = inspect(self._connector.engine) - foreign_keys: list[tuple[str, dict]] = [] - - for table_name in inspector.get_table_names(schema=self._config.schema): - if self._config._should_skip_table(table_name): - continue - - logger.debug(f"Extracting source schema data from `{table_name}`") - - if table_name not in extracted_tables: - df = pd.DataFrame( - columns=[col["name"] for col in inspector.get_columns(table_name)] - ) - else: - df = self.get_table_df(table_name) - - primary_key = inspector.get_pk_constraint(table_name)["constrained_columns"] - for fk in inspector.get_foreign_keys(table_name): - if self._config._should_skip_table(fk["referred_table"]): - continue - foreign_keys.append((table_name, fk)) - - rel_data.add_table(name=table_name, primary_key=primary_key, data=df) - - for foreign_key in foreign_keys: - table, fk = foreign_key - rel_data.add_foreign_key_constraint( - table=table, - constrained_columns=fk["constrained_columns"], - referred_table=fk["referred_table"], - referred_columns=fk["referred_columns"], - ) - - return rel_data - - def _extract_schema(self) -> None: - # This will initially only populate RelationalData with empty - # DataFrames which are only used for building up the right order - # to extract tables for subsetting purposes. There will be no - # actual table contents stored on the Graph. This means that - # after this runs the `RelationalData` object will not have - # any relationships that may exist from embedded JSON. - - self._relational_data = self._create_rel_data() - - # Set the table processing order for extraction - self.table_order = list( - reversed(self._relational_data.list_tables_parents_before_children()) - ) - - def _table_path(self, table_name: str) -> Path: - return self._storage_dir / f"{table_name}.csv" - - def get_table_df(self, table_name: str) -> pd.DataFrame: - """ - Return a sampled table as a DataFrame. This assumes tables have - already been sampled and are stored on disk. - - Args: - table_name: The name of the table to fetch as a DataFrame. - """ - table_path = self._table_path(table_name) - if not table_path.is_file(): - raise ValueError(f"The table name: `{table_name}` does not exist.") - - return pd.read_csv(table_path) - - def _load_table_pk_values( - self, table_name: str, child_table_names: list[str] - ) -> _PKValues: - """ - Given a table name, extract all of the values of the primary key of the - table as they exist in already sampled tables of this table's children. - - In otherwords, iterate all the children of this table and extract the - values of the foreign keys that reference this table. The values of - the FKs will represent all the PK values for this table that - should be extracted as a subset. - - This function assumes the children table already sampled and stored - based on the required table ordering needed to completed subsetting.1 - """ - values_ddf = None - parent_column_names: list[str] = [] - pk_set = set(self._relational_data.get_primary_key(table_name)) - logger.debug( - f"Extracting primary key values for sampling from table '{table_name}'" - ) - - for child_table_name in child_table_names: - child_fks = self._relational_data.get_foreign_keys(child_table_name) - for fk in child_fks: - if fk.parent_table_name == table_name and pk_set == set( - fk.parent_columns - ): - if not parent_column_names: - parent_column_names = fk.parent_columns - logger.debug( - f"Found primary key values for table '{table_name}' in child table '{child_table_name}'" - ) - - # NOTE: When we extract the FK values from the child tables, we store them under the PK - # column names for the current parent table we are processing. - rename_map = dict(zip(fk.columns, fk.parent_columns)) - - # NOTE: The child tables MUST have already been extracted! - child_table_path = self._table_path(child_table_name) - - tmp_ddf = dd.read_csv( # pyright: ignore - str(child_table_path), usecols=fk.columns - ) - tmp_ddf = tmp_ddf.rename(columns=rename_map) - if values_ddf is None: - values_ddf = tmp_ddf - else: - values_ddf = dd.concat([values_ddf, tmp_ddf]) # pyright: ignore - - # Dropping the duplicates *only* works consistently - # when operating on a specific subset of columns using the [] - # notation. Using the "subset=" kwarg does not work, and neither - # does operating on the entire DDF. - if parent_column_names and values_ddf is not None: - values_ddf = values_ddf[ # pyright: ignore - parent_column_names - ].drop_duplicates() # pyright: ignore - else: - raise TableExtractorError( - f"Could not extract primary key values needed to sample from table `{table_name}`" - ) - - return _PKValues( - table_name=table_name, - values_ddf=values_ddf, # pyright: ignore - column_names=parent_column_names, - ) - - def _sample_pk_values(self, table_path: Path, pk_values: _PKValues) -> int: - """ - Given a DDF of PK values for a table, we query for those rows and start - streaming them to the target path. This assumes the target file already - exists with column names and we will be appending to that file. - """ - row_count = 0 - - lock = Lock() - - def handle_partition(df: pd.DataFrame, lock: Lock): - # This runs in another thread so we have to re-create our table session info - table_session = self._get_table_session(pk_values.table_name) - nonlocal row_count - - chunk_size = 15_000 # limit how many checks go into a WHERE clause - - for _, chunk_df in df.groupby(np.arange(len(df)) // chunk_size): - values_list = chunk_df.to_records(index=False).tolist() - query = table_session.table.select().where( - tuple_( - *[table_session.table.c[col] for col in pk_values.column_names] - ).in_(values_list) - ) - - with table_session.engine.connect() as conn: - df_iter = pd.read_sql_query(query, conn, chunksize=self._chunk_size) - write_count = _stream_df_to_path(df_iter, table_path, lock=lock) - row_count += write_count - - logger.debug( - f"Sampling primary key values for parent table '{pk_values.table_name}'" - ) - - # By providing the "meta" kwarg, this prevents - # Dask from running the map function ("handle_partition") on - # dummy data in an attempt to infer the metdata (which we don't - # need for the purposes of making the SQL queries). When this - # dummy partition is mapped, it was using the values in the - # partition to make additional SQL queries which can have - # unintended side effects. See the "map_partition" docs - # for more details if interested. - pk_values.values_ddf.map_partitions( - handle_partition, lock, meta=(None, "object") - ).compute() - - return row_count - - def _flat_sample( - self, table_path: Path, table_session: _TableSession - ) -> TableMetadata: - sample_row_count = _determine_sample_size( - self._config, table_session.total_row_count - ) - - logger.debug( - f"Sampling {sample_row_count} rows from table '{table_session.table.name}'" - ) - - df_iter = iter([pd.DataFrame()]) - - with table_session.engine.connect() as conn: - contiguous_query = select(table_session.table).limit(sample_row_count) - if self._config.sample_mode == SampleMode.RANDOM: - random_success = False - random_errs = [] - # Different dialects will use different random functions - # so we just try them until one works. If none work, - # we fall back to contiguous mode - for rand_func in (func.random(), func.rand()): - random_query = ( - select(table_session.table) - .order_by(rand_func) - .limit(sample_row_count) - ) - try: - df_iter = pd.read_sql_query( - random_query, conn, chunksize=self._chunk_size - ) - except Exception as err: - random_errs.append(str(err)) - else: - random_success = True - break - - if not random_success: - logger.info( - f"Could not sample randomly, received the following errors: {', '.join(random_errs)}. Will fall back to contiguous mode." - ) - - df_iter = pd.read_sql_query( - contiguous_query, conn, chunksize=self._chunk_size - ) - - else: - df_iter = pd.read_sql_query( - contiguous_query, conn, chunksize=self._chunk_size - ) - - sampled_row_count = _stream_df_to_path(df_iter, table_path) - - return TableMetadata( - original_row_count=table_session.total_row_count, - sampled_row_count=sampled_row_count, - column_count=table_session.total_column_count, - ) - - def _sample_table( - self, - table_name: str, - child_tables: Optional[list[str]] = None, - schema: Optional[str] = None, - ) -> TableMetadata: - if self._relational_data.is_empty: - self._extract_schema() - - table_path = self._table_path(table_name) - table_session = self._get_table_session(table_name, schema=schema) - engine = self._connector.engine - - # First we'll create our table file on disk and bootstrap - # it with just the column names - df = pd.DataFrame(columns=table_session.columns) - df.to_csv(table_path, index=False) - - # If we aren't sampling any rows, we're done! - if self._config.empty_table: - return TableMetadata( - original_row_count=table_session.total_row_count, - sampled_row_count=0, - column_count=table_session.total_column_count, - ) - - # If we are sampling the entire table, we can just short circuit here - # and start streaing data into the file - if self._config.entire_table: - logger.debug(f"Extracting entire table: {table_name}") - with engine.connect() as conn: - df_iter = pd.read_sql_table( - table_name, conn, chunksize=self._chunk_size, schema=schema - ) - sampled_count = _stream_df_to_path(df_iter, table_path) - - return TableMetadata( - original_row_count=table_session.total_row_count, - sampled_row_count=sampled_count, - column_count=table_session.total_column_count, - ) - - # If this is a leaf table, determine how many rows to sample and - # run the query and start streaming the results - if not child_tables: - return self._flat_sample(table_path, table_session) - - # Child nodes exist at this point. - - # At this point, we are at a parent table, first we build a DDF that contains - # all of the PK values that we will sample from this parent table. - # These PK values are the set union of all the FK values of this - # parent table's child tables - pk_values = self._load_table_pk_values(table_name, child_tables) - sampled_row_count = self._sample_pk_values(table_path, pk_values) - - return TableMetadata( - original_row_count=table_session.total_row_count, - sampled_row_count=sampled_row_count, - column_count=table_session.total_column_count, - ) - - def sample_tables(self, schema: Optional[str] = None) -> dict[str, TableMetadata]: - """ - Extract database tables according to the `ExtractorConfig.` Tables will be stored in the - configured storage directory that is configured on the `ExtractorConfig` object. - """ - if self._relational_data.is_empty: - self._extract_schema() - - table_data = {} - for table_name in self.table_order: - child_tables = self._relational_data.get_descendants(table_name) - meta = self._sample_table( - table_name, child_tables=child_tables, schema=schema - ) - table_data[table_name] = meta - - return table_data - - @property - def relational_data(self) -> RelationalData: - """ - Return the "RelationalData" instance that was created - during table extraction. - """ - if self._relational_data.is_empty: - raise TableExtractorError( - "Cannot return `RelationalData`, `sample_tables()` must be run first." - ) - return self._relational_data diff --git a/src/gretel_trainer/relational/json.py b/src/gretel_trainer/relational/json.py deleted file mode 100644 index 009c5fc4..00000000 --- a/src/gretel_trainer/relational/json.py +++ /dev/null @@ -1,452 +0,0 @@ -from __future__ import annotations - -import logging -import re - -from dataclasses import dataclass -from json import dumps, JSONDecodeError, loads -from typing import Any, Optional, Protocol, Union -from uuid import uuid4 - -import numpy as np -import pandas as pd - -from unflatten import unflatten - -logger = logging.getLogger(__name__) - -# JSON dict to multi-column and list to multi-table - -FIELD_SEPARATOR = ">" -TABLE_SEPARATOR = "^" -ID_SUFFIX = "~id" -ORDER_COLUMN = "array~order" -CONTENT_COLUMN = "content" -PRIMARY_KEY_COLUMN = "~PRIMARY_KEY_ID~" - - -def load_json(obj: Any) -> Union[dict, list]: - if isinstance(obj, (dict, list)): - return obj - else: - return loads(obj) - - -def is_json(obj: Any, json_type=(dict, list)) -> bool: - try: - obj = load_json(obj) - except (ValueError, TypeError, JSONDecodeError): - return False - else: - return isinstance(obj, json_type) - - -def is_dict(obj: Any) -> bool: - return is_json(obj, dict) - - -def is_list(obj: Any) -> bool: - return isinstance(obj, np.ndarray) or is_json(obj, list) - - -def pandas_json_normalize(series: pd.Series) -> pd.DataFrame: - return pd.json_normalize(series.apply(load_json).to_list(), sep=FIELD_SEPARATOR) - - -def nulls_to_empty_dicts(df: pd.DataFrame) -> pd.DataFrame: - return df.applymap(lambda x: {} if pd.isnull(x) else x) - - -def nulls_to_empty_lists(series: pd.Series) -> pd.Series: - return series.apply(lambda x: x if isinstance(x, list) or not pd.isnull(x) else []) - - -def _normalize_json( - nested_dfs: list[tuple[str, pd.DataFrame]], - flat_dfs: list[tuple[str, pd.DataFrame]], - columns: Optional[list[str]] = None, -) -> list[tuple[str, pd.DataFrame]]: - if not nested_dfs: - return flat_dfs - name, df = nested_dfs.pop() - cols_to_scan = columns or [col for col in df.columns if df.dtypes[col] == "object"] - dict_cols = [col for col in cols_to_scan if df[col].dropna().apply(is_dict).all()] - if dict_cols: - df[dict_cols] = nulls_to_empty_dicts(df[dict_cols]) - for col in dict_cols: - new_cols = pandas_json_normalize(df[col]).add_prefix(col + FIELD_SEPARATOR) - df = pd.concat([df, new_cols], axis="columns") - df = df.drop(columns=new_cols.columns[new_cols.isnull().all()]) - nested_dfs.append((name, df.drop(columns=dict_cols))) - else: - list_cols = [ - col for col in cols_to_scan if df[col].dropna().apply(is_list).all() - ] - if list_cols: - for col in list_cols: - new_table = df[col].explode().dropna().rename(CONTENT_COLUMN).to_frame() - new_table[ORDER_COLUMN] = new_table.groupby(level=0).cumcount() - nested_dfs.append( - ( - name + TABLE_SEPARATOR + col, - new_table.reset_index(names=name + ID_SUFFIX), - ) - ) - nested_dfs.append((name, df.drop(columns=list_cols))) - else: - flat_dfs.append((name, df)) - return _normalize_json(nested_dfs, flat_dfs) - - -# Multi-table and multi-column back to single-table with JSON - - -def get_id_columns(df: pd.DataFrame) -> list[str]: - return [col for col in df.columns if col.endswith(ID_SUFFIX)] - - -def get_parent_table_name_from_child_id_column(id_column_name: str) -> str: - return id_column_name[: -len(ID_SUFFIX)] - - -def get_parent_column_name_from_child_table_name(table_name: str) -> str: - return table_name.split(TABLE_SEPARATOR)[-1] - - -def _is_invented_child_table(table: str, rel_data: _RelationalData) -> bool: - imeta = rel_data.get_invented_table_metadata(table) - return imeta is not None and imeta.invented_root_table_name != table - - -def generate_unique_table_name(s: str): - sanitized_str = "-".join(re.findall(r"[a-zA-Z_0-9]+", s)) - # Generate unique suffix to prevent collisions - unique_suffix = make_suffix() - # Max length for a table/filename is 128 chars - return f"{sanitized_str[:80]}_invented_{unique_suffix}" - - -def make_suffix(): - return uuid4().hex - - -def jsonencode(df: pd.DataFrame, cols: list[str]) -> pd.DataFrame: - """ - Returns a dataframe with the specified columns transformed such that their JSON-like - values can be written to CSV and later re-read back to Pandas from CSV. - """ - # Save memory and return the *original dataframe* (not a copy!) if no columns to transform - if len(cols) == 0: - return df - - def _jsonencode(x): - if isinstance(x, str): - return x - elif isinstance(x, (dict, list)): - return dumps(x) - - copy = df.copy() - for col in cols: - copy[col] = copy[col].map(_jsonencode) - - return copy - - -def jsondecode(df: pd.DataFrame, cols: list[str]) -> pd.DataFrame: - """ - Returns a dataframe with the specified columns parsed from JSON to Python objects. - """ - # Save memory and return the *original dataframe* (not a copy!) if no columns to transform - if len(cols) == 0: - return df - - def _jsondecode(obj): - try: - return loads(obj) - except (ValueError, TypeError, JSONDecodeError): - return obj - - copy = df.copy() - for col in cols: - copy[col] = copy[col].map(_jsondecode) - - return copy - - -class _RelationalData(Protocol): - def get_foreign_keys( - self, table: str - ) -> list: # can't specify element type (ForeignKey) without cyclic dependency - ... - - def get_table_columns(self, table: str) -> list[str]: ... - - def get_invented_table_metadata( - self, table: str - ) -> Optional[InventedTableMetadata]: ... - - -@dataclass -class InventedTableMetadata: - invented_root_table_name: str - original_table_name: str - json_breadcrumb_path: str - empty: bool - - -@dataclass -class ProducerMetadata: - invented_root_table_name: str - table_name_mappings: dict[str, str] - - @property - def table_names(self) -> list[str]: - return list(self.table_name_mappings.values()) - - -def ingest( - table_name: str, - primary_key: list[str], - df: pd.DataFrame, - json_columns: list[str], -) -> Optional[IngestResponseT]: - json_decoded = jsondecode(df, json_columns) - tables = _normalize_json([(table_name, json_decoded)], [], json_columns) - - # If we created additional tables (from JSON lists) or added columns (from JSON dicts) - if len(tables) > 1 or len(tables[0][1].columns) > len(df.columns): - # Map json breadcrumbs to uniquely generated table name - mappings = {name: generate_unique_table_name(table_name) for name, _ in tables} - logger.info(f"Transformed JSON into {len(mappings)} tables for modeling.") - logger.debug(f"Invented table names: {list(mappings.values())}") - commands = _generate_commands( - tables=tables, - table_name_mappings=mappings, - original_table_name=table_name, - original_primary_key=primary_key, - ) - producer_metadata = ProducerMetadata( - invented_root_table_name=mappings[table_name], - table_name_mappings=mappings, - ) - return (commands, producer_metadata) - - -def _generate_commands( - tables: list[tuple[str, pd.DataFrame]], - table_name_mappings: dict[str, str], - original_table_name: str, - original_primary_key: list[str], -) -> CommandsT: - """ - Returns lists of keyword arguments designed to be passed to a - RelationalData instance's _add_single_table and add_foreign_key methods - """ - root_table_name = table_name_mappings[original_table_name] - - _add_single_table = [] - add_foreign_key = [] - - for table_breadcrumb_name, table_df in tables: - table_name = table_name_mappings[table_breadcrumb_name] - if table_name == root_table_name: - table_pk = original_primary_key + [PRIMARY_KEY_COLUMN] - else: - table_pk = [PRIMARY_KEY_COLUMN] - table_df.index.rename(PRIMARY_KEY_COLUMN, inplace=True) - table_df.reset_index(inplace=True) - metadata = InventedTableMetadata( - invented_root_table_name=root_table_name, - original_table_name=original_table_name, - json_breadcrumb_path=table_breadcrumb_name, - empty=table_df.empty, - ) - _add_single_table.append( - { - "name": table_name, - "primary_key": table_pk, - "data": table_df, - "invented_table_metadata": metadata, - } - ) - - for table_breadcrumb_name, table_df in tables: - table_name = table_name_mappings[table_breadcrumb_name] - for column in get_id_columns(table_df): - referred_table = table_name_mappings[ - get_parent_table_name_from_child_id_column(column) - ] - add_foreign_key.append( - { - "table": table_name, - "constrained_columns": [column], - "referred_table": referred_table, - "referred_columns": [PRIMARY_KEY_COLUMN], - } - ) - return (_add_single_table, add_foreign_key) - - -def restore( - tables: dict[str, pd.DataFrame], - rel_data: _RelationalData, - root_table_name: str, - original_columns: list[str], - table_name_mappings: dict[str, str], - original_table_name: str, -) -> Optional[pd.DataFrame]: - # If the root invented table is not present, we are completely out of luck - # (Missing invented child tables can be replaced with empty lists so we at least provide _something_) - if root_table_name not in tables: - logger.warning( - f"Cannot restore nested JSON data: root invented table `{root_table_name}` is missing from output tables." - ) - return None - - return _denormalize_json( - tables, rel_data, table_name_mappings, original_table_name - )[original_columns] - - -def _denormalize_json( - tables: dict[str, pd.DataFrame], - rel_data: _RelationalData, - table_name_mappings: dict[str, str], - original_table_name: str, -) -> pd.DataFrame: - # Throughout this function, "provenance name" refers to the json-breadcrumb-style name (e.g. foo^bar>baz) - # and "node name" refers to the table name as it appears in the RelationalData graph (e.g. foo_invented_{uuid}) - - # The provided `table_name_mappings` argument (from producer metadata) maps from provenance name to node name. - # `inverse_table_name_mappings` inverts that mapping (so: node name keys, provenance name values). - # `table_dict` replaces the keys (node names) in `tables` with corresponding provenance names - table_node_names = list(table_name_mappings.values()) - inverse_table_name_mappings = {v: k for k, v in table_name_mappings.items()} - table_dict = {inverse_table_name_mappings[k]: v for k, v in tables.items()} - - get_table = lambda node_name: _get_table_or_empty_fallback( - rel_data, inverse_table_name_mappings, table_dict, node_name - ) - - for table_name in list(reversed(table_node_names)): - table_df = get_table(table_name) - table_provenance_name = inverse_table_name_mappings[table_name] - - if table_df.empty and _is_invented_child_table(table_name, rel_data): - # Mutate the parent dataframe by adding a column with empty lists - - # We know invented child tables have exactly one foreign key - fk = rel_data.get_foreign_keys(table_name)[0] - parent_node_name = fk.parent_table_name - parent_provenance_name = inverse_table_name_mappings[parent_node_name] - col_name = get_parent_column_name_from_child_table_name( - table_provenance_name - ) - - parent_df = get_table(parent_node_name) - - # Add the column of empty lists and "save" the modified df back to the dict for later use - parent_df[col_name] = [[] for _ in range(len(parent_df))] - table_dict[parent_provenance_name] = parent_df - else: - # First, make a version of `table_df` with column names altered to send to `unflatten` - # See: https://github.com/dairiki/unflatten/#synopsis - col_names = [col for col in table_df.columns if FIELD_SEPARATOR in col] - new_col_names = [col.replace(FIELD_SEPARATOR, ".") for col in col_names] - flat_df = table_df[col_names].rename( - columns=dict(zip(col_names, new_col_names)) - ) - - # Turn this dataframe into a dict to send to `unflatten`... - flat_dict = { - k: { - k1: v1 - for k1, v1 in v.items() - if v1 is not np.nan and v1 is not None - } - for k, v in flat_df.to_dict("index").items() - } - # ...and after unflattening, go back to a dataframe - dict_df = nulls_to_empty_dicts( - pd.DataFrame.from_dict( - {k: unflatten(v) for k, v in flat_dict.items()}, orient="index" - ) - ) - - # Join the flattened dict_df onto table_df. - # You might think we could drop the original non-nested columns at this point, - # but under certain JSON shapes (e.g. nested lists of objects) we still need them present. - nested_df = table_df.join(dict_df) - if _is_invented_child_table(table_name, rel_data): - # we know there is exactly one foreign key on invented child tables... - fk = rel_data.get_foreign_keys(table_name)[0] - # ...with exactly one column - fk_col = fk.columns[0] - parent_node_name = fk.parent_table_name - parent_provenance_name = inverse_table_name_mappings[parent_node_name] - nested_df = ( - nested_df.sort_values(ORDER_COLUMN) - .groupby(fk_col)[CONTENT_COLUMN] - .agg(list) - ) - col_name = get_parent_column_name_from_child_table_name( - table_provenance_name - ) - parent_df = get_table(parent_node_name) - parent_df = parent_df.join(nested_df.rename(col_name)) - parent_df[col_name] = nulls_to_empty_lists(parent_df[col_name]) - table_dict[parent_provenance_name] = parent_df - table_dict[table_provenance_name] = nested_df - return table_dict[original_table_name] - - -def _get_table_or_empty_fallback( - rel_data: _RelationalData, - inverse_table_name_mappings: dict[str, str], - table_dict: dict[str, pd.DataFrame], - node_name: str, -) -> pd.DataFrame: - """ - Helper function used inside _denormalize_json to safely retrieve a table dataframe. - If `node_name` is not present in `table_dict` (e.g. because the Gretel Job failed), - return an empty dataframe with the expected columns. - """ - empty_fallback = pd.DataFrame(columns=rel_data.get_table_columns(node_name)) - provenance_name = inverse_table_name_mappings[node_name] - return table_dict.get(provenance_name, empty_fallback) - - -def get_json_columns(df: pd.DataFrame) -> list[str]: - """ - Samples non-null values from all columns and returns those that contain - valid JSON dicts or lists. - - Raises an error if *all* columns are lists, as that is not currently supported. - """ - object_cols = { - col: data - for col in df.columns - if df.dtypes[col] == "object" and len(data := df[col].dropna()) > 0 - } - - if len(object_cols) == 0: - return [] - - list_cols = [ - col for col, series in object_cols.items() if series.apply(is_list).all() - ] - - if len(list_cols) == len(df.columns): - raise ValueError("Cannot accept tables with JSON lists in all columns") - - dict_cols = [ - col - for col, series in object_cols.items() - if col not in list_cols and series.apply(is_dict).all() - ] - - return dict_cols + list_cols - - -CommandsT = tuple[list[dict], list[dict]] -IngestResponseT = tuple[CommandsT, ProducerMetadata] diff --git a/src/gretel_trainer/relational/log.py b/src/gretel_trainer/relational/log.py deleted file mode 100644 index 75e9a53d..00000000 --- a/src/gretel_trainer/relational/log.py +++ /dev/null @@ -1,37 +0,0 @@ -import logging - -from contextlib import contextmanager - -RELATIONAL = "gretel_trainer.relational" - -log_format = "%(levelname)s - %(asctime)s - %(message)s" -time_format = "%Y-%m-%d %H:%M:%S" -formatter = logging.Formatter(log_format, time_format) -handler = logging.StreamHandler() -handler.setFormatter(formatter) - -logger = logging.getLogger(RELATIONAL) -logger.handlers.clear() -logger.addHandler(handler) -logger.setLevel("INFO") - -# Clear out any existing root handlers -# (This prevents duplicate log output in Colab) -for root_handler in logging.root.handlers: - logging.root.removeHandler(root_handler) - - -def set_log_level(level: str): - logger = logging.getLogger(RELATIONAL) - logger.setLevel(level) - - -@contextmanager -def silent_logs(): - logger = logging.getLogger(RELATIONAL) - current_level = logger.getEffectiveLevel() - logger.setLevel("CRITICAL") - try: - yield - finally: - logger.setLevel(current_level) diff --git a/src/gretel_trainer/relational/model_config.py b/src/gretel_trainer/relational/model_config.py deleted file mode 100644 index ad73e322..00000000 --- a/src/gretel_trainer/relational/model_config.py +++ /dev/null @@ -1,184 +0,0 @@ -from copy import deepcopy -from typing import Any, Optional - -from gretel_client.projects.exceptions import ModelConfigError -from gretel_client.projects.models import read_model_config -from gretel_trainer.relational.core import ( - GretelModelConfig, - MultiTableException, - RelationalData, -) - -TRANSFORM_MODEL_KEYS = ["transform", "transforms", "transform_v2"] - - -def get_model_key(config_dict: dict[str, Any]) -> Optional[str]: - try: - models = config_dict["models"] - assert isinstance(models, list) - assert isinstance(models[0], dict) - return list(models[0])[0] - except (AssertionError, IndexError, KeyError): - return None - - -def ingest(config: GretelModelConfig) -> dict[str, Any]: - try: - return read_model_config(deepcopy(config)) - except ModelConfigError as e: - raise MultiTableException("Invalid config") from e - - -def _model_name(workflow: str, table: str) -> str: - ok_table_name = table.replace("--", "__") - return f"{workflow}-{ok_table_name}" - - -def make_classify_config(table: str, config: GretelModelConfig) -> dict[str, Any]: - tailored_config = ingest(config) - tailored_config["name"] = _model_name("classify", table) - return tailored_config - - -def make_evaluate_config(table: str, sqs_type: str) -> dict[str, Any]: - tailored_config = ingest("evaluate/default") - tailored_config["name"] = f"evaluate-{_model_name(sqs_type, table)}" - return tailored_config - - -def make_synthetics_config(table: str, config: GretelModelConfig) -> dict[str, Any]: - tailored_config = ingest(config) - tailored_config["name"] = _model_name("synthetics", table) - return tailored_config - - -def make_transform_config( - rel_data: RelationalData, table: str, config: GretelModelConfig -) -> dict[str, Any]: - tailored_config = ingest(config) - tailored_config["name"] = _model_name("transforms", table) - - model_key, model = next(iter(tailored_config["models"][0].items())) - - # Ensure we have a transform config - if model_key not in TRANSFORM_MODEL_KEYS: - raise MultiTableException("Invalid transform config") - - # Tv2 configs pass through unaltered (except for name, above) - if model_key == "transform_v2": - return tailored_config - - # We add a passthrough policy to Tv1 configs to avoid transforming PK/FK columns - key_columns = rel_data.get_all_key_columns(table) - if len(key_columns) > 0: - policies = model["policies"] - passthrough_policy = _passthrough_policy(key_columns) - adjusted_policies = [passthrough_policy] + policies - - tailored_config["models"][0][model_key]["policies"] = adjusted_policies - - return tailored_config - - -def _passthrough_policy(columns: list[str]) -> dict[str, Any]: - return { - "name": "ignore-keys", - "rules": [ - { - "name": "ignore-key-columns", - "conditions": {"field_name": columns}, - "transforms": [ - { - "type": "passthrough", - } - ], - } - ], - } - - -def assemble_configs( - rel_data: RelationalData, - config: GretelModelConfig, - table_specific_configs: Optional[dict[str, GretelModelConfig]], - only: Optional[set[str]], - ignore: Optional[set[str]], -) -> dict[str, dict]: - only, ignore = _expand_only_and_ignore(rel_data, only, ignore) - - tables_in_scope = [ - table - for table in rel_data.list_all_tables() - if not _skip_table(table, only, ignore) - ] - - # Standardize type of all provided models - config_dict = ingest(config) - table_specific_config_dicts = { - table: ingest(conf) for table, conf in (table_specific_configs or {}).items() - } - - # Translate any JSON-source tables in table_specific_configs to invented tables - all_table_specific_config_dicts = {} - for table, conf in table_specific_config_dicts.items(): - m_names = rel_data.get_modelable_table_names(table) - if len(m_names) == 0: - raise MultiTableException(f"Unrecognized table name: `{table}`") - for m_name in m_names: - all_table_specific_config_dicts[m_name] = table_specific_config_dicts.get( - m_name, conf - ) - - # Ensure compatibility between only/ignore and table_specific_configs - omitted_tables_with_overrides_specified = [] - for table in all_table_specific_config_dicts: - if _skip_table(table, only, ignore): - omitted_tables_with_overrides_specified.append(table) - if len(omitted_tables_with_overrides_specified) > 0: - raise MultiTableException( - f"Cannot provide configs for tables that have been omitted from synthetics training: " - f"{omitted_tables_with_overrides_specified}" - ) - - return { - table: all_table_specific_config_dicts.get(table, config_dict) - for table in tables_in_scope - } - - -def _expand_only_and_ignore( - rel_data: RelationalData, only: Optional[set[str]], ignore: Optional[set[str]] -) -> tuple[Optional[set[str]], Optional[set[str]]]: - """ - Accepts the `only` and `ignore` parameter values as provided by the user and: - - ensures both are not set (must provide one or the other, or neither) - - translates any JSON-source tables to the invented tables - """ - if only is not None and ignore is not None: - raise MultiTableException("Cannot specify both `only` and `ignore`.") - - modelable_tables = set() - for table in only or ignore or {}: - m_names = rel_data.get_modelable_table_names(table) - if len(m_names) == 0: - raise MultiTableException(f"Unrecognized table name: `{table}`") - modelable_tables.update(m_names) - - if only is None: - return (None, modelable_tables) - elif ignore is None: - return (modelable_tables, None) - else: - return (None, None) - - -def _skip_table( - table: str, only: Optional[set[str]], ignore: Optional[set[str]] -) -> bool: - skip = False - if only is not None and table not in only: - skip = True - if ignore is not None and table in ignore: - skip = True - - return skip diff --git a/src/gretel_trainer/relational/multi_table.py b/src/gretel_trainer/relational/multi_table.py deleted file mode 100644 index 7edcc09a..00000000 --- a/src/gretel_trainer/relational/multi_table.py +++ /dev/null @@ -1,1069 +0,0 @@ -""" -This module provides the "MultiTable" class to users. This allows you to -take extracted data from a database or data warehouse, and process it -with Gretel using Transforms, Classify, and Synthetics. -""" - -from __future__ import annotations - -import json -import logging -import shutil - -from collections import defaultdict -from datetime import datetime -from pathlib import Path -from typing import Any, cast, Optional, Union - -import pandas as pd - -import gretel_trainer.relational.ancestry as ancestry - -from gretel_client.config import add_session_context, ClientConfig, RunnerMode -from gretel_client.projects import create_project, get_project, Project -from gretel_client.projects.artifact_handlers import open_artifact -from gretel_client.projects.jobs import ACTIVE_STATES, END_STATES, Status -from gretel_client.projects.records import RecordHandler -from gretel_trainer.relational.backup import ( - Backup, - BackupClassify, - BackupGenerate, - BackupRelationalData, - BackupSyntheticsTrain, - BackupTransformsTrain, -) -from gretel_trainer.relational.core import ( - GretelModelConfig, - MultiTableException, - RelationalData, - Scope, - UserFriendlyDataT, -) -from gretel_trainer.relational.json import InventedTableMetadata, ProducerMetadata -from gretel_trainer.relational.model_config import ( - assemble_configs, - get_model_key, - make_classify_config, - make_evaluate_config, - make_synthetics_config, - make_transform_config, -) -from gretel_trainer.relational.output_handler import OutputHandler, SDKOutputHandler -from gretel_trainer.relational.report.report import ReportPresenter, ReportRenderer -from gretel_trainer.relational.sdk_extras import ExtendedGretelSDK -from gretel_trainer.relational.strategies.ancestral import AncestralStrategy -from gretel_trainer.relational.strategies.independent import IndependentStrategy -from gretel_trainer.relational.table_evaluation import TableEvaluation -from gretel_trainer.relational.task_runner import run_task, TaskContext -from gretel_trainer.relational.tasks.classify import ClassifyTask -from gretel_trainer.relational.tasks.synthetics_evaluate import SyntheticsEvaluateTask -from gretel_trainer.relational.tasks.synthetics_run import SyntheticsRunTask -from gretel_trainer.relational.tasks.synthetics_train import SyntheticsTrainTask -from gretel_trainer.relational.tasks.transforms_run import TransformsRunTask -from gretel_trainer.relational.tasks.transforms_train import TransformsTrainTask -from gretel_trainer.relational.workflow_state import ( - Classify, - SyntheticsRun, - SyntheticsTrain, - TransformsTrain, -) -from gretel_trainer.version import get_trainer_version - -RELATIONAL_SESSION_METADATA = {"trainer_relational": get_trainer_version()} - -logger = logging.getLogger(__name__) - - -class MultiTable: - """ - Relational data support for the Trainer SDK - - Args: - relational_data (RelationalData): Core data structure representing the source tables and their relationships. - strategy (str, optional): The strategy to use for synthetics. Supports "independent" (default) and "ancestral". - project_display_name (str, optional): Display name in the console for a new Gretel project holding models and artifacts. Defaults to "multi-table". Conflicts with `project`. - project (Project, optional): Existing project to use for models and artifacts. Conflicts with `project_display_name`. - refresh_interval (int, optional): Frequency in seconds to poll Gretel Cloud for job statuses. Must be at least 30. Defaults to 60 (1m). - backup (Backup, optional): Should not be supplied manually; instead use the `restore` classmethod. - """ - - def __init__( - self, - relational_data: RelationalData, - *, - strategy: str = "independent", - project_display_name: Optional[str] = None, - project: Optional[Project] = None, - refresh_interval: Optional[int] = None, - backup: Optional[Backup] = None, - output_handler: Optional[OutputHandler] = None, - session: Optional[ClientConfig] = None, - ): - if project_display_name is not None and project is not None: - raise MultiTableException( - "Cannot set both `project_display_name` and `project`. " - "Set `project_display_name` to create a new project with that display name, " - "or set `project` to run in an existing project." - ) - - if len(cycles := relational_data.foreign_key_cycles) > 0: - logger.warning( - f"Detected cyclic foreign key relationships in schema: {cycles}. " - "Support for cyclic table dependencies is limited. " - "You may need to remove some foreign keys to ensure no cycles exist." - ) - - self._session = _configure_session(project, session) - self._strategy = _validate_strategy(strategy) - self._set_refresh_interval(refresh_interval) - self.relational_data = relational_data - self._extended_sdk = ExtendedGretelSDK(hybrid=self._hybrid) - self._latest_backup: Optional[Backup] = None - self._classify = Classify() - self._transforms_train = TransformsTrain() - self.transform_output_tables: dict[str, pd.DataFrame] = {} - self._synthetics_train = SyntheticsTrain() - self._synthetics_run: Optional[SyntheticsRun] = None - self.synthetic_output_tables: dict[str, pd.DataFrame] = {} - self._evaluations = defaultdict(lambda: TableEvaluation()) - - if backup is None: - self._complete_fresh_init(project_display_name, project, output_handler) - else: - # The current restore-from-backup implementation is hyper-specific to direct SDK usage. - # We do not need to pass the Optional[OutputHandler] `output_handler` here because we know it - # will be None; instead we create an SDKOutputHandler in that method and (for better or worse) - # access its private `_working_dir` attribute (without pyright complaining about that attribute - # not existing on the OutputHandler protocol). - # In the future, other clients restoring state should implement their own `_complete_init_from...` - # method using their own client-specific, not-None implementation of OutputHandler. - self._complete_init_from_backup(backup) - - def _complete_fresh_init( - self, - project_display_name: Optional[str], - project: Optional[Project], - output_handler: Optional[OutputHandler], - ) -> None: - if project is None: - self._project = create_project( - display_name=project_display_name or "multi-table", - session=self._session, - ) - logger.info( - f"Created project `{self._project.display_name}` with unique name `{self._project.name}`." - ) - else: - self._project = project.with_session(self._session) - - self._set_output_handler(output_handler, None) - self._output_handler.save_debug_summary(self.relational_data.debug_summary()) - self._upload_sources_to_project() - - def _set_output_handler( - self, output_handler: Optional[OutputHandler], source_archive: Optional[str] - ) -> None: - self._output_handler = output_handler or SDKOutputHandler( - workdir=self._project.name, - project=self._project, - hybrid=self._hybrid, - source_archive=source_archive, - ) - - def _complete_init_from_backup(self, backup: Backup) -> None: - # Raises GretelProjectEror if not found - self._project = get_project(name=backup.project_name, session=self._session) - logger.info( - f"Connected to existing project `{self._project.display_name}` with unique name `{self._project.name}`." - ) - self._set_output_handler(None, backup.source_archive) - # We currently only support restoring from backup via the SDK, so we know the concrete type of the output handler - # (and set it here so pyright doesn't complain about us peeking in to a private attribute). - self._output_handler = cast(SDKOutputHandler, self._output_handler) - - # RelationalData - source_archive_path = self._output_handler.filepath_for("source_tables.tar.gz") - source_archive_id = self._output_handler.get_source_archive() - if source_archive_id is not None: - self._extended_sdk.download_file_artifact( - gretel_object=self._project, - artifact_name=source_archive_id, - out_path=source_archive_path, - ) - if not Path(source_archive_path).exists(): - raise MultiTableException( - "Cannot restore from backup: source archive is missing." - ) - shutil.unpack_archive( - filename=source_archive_path, - extract_dir=self._output_handler._working_dir, - format="gztar", - ) - for table_name, table_backup in backup.relational_data.tables.items(): - source_data = self._output_handler.filepath_for(f"{table_name}.csv") - invented_table_metadata = None - producer_metadata = None - if (imeta := table_backup.invented_table_metadata) is not None: - invented_table_metadata = InventedTableMetadata(**imeta) - if (pmeta := table_backup.producer_metadata) is not None: - producer_metadata = ProducerMetadata(**pmeta) - self.relational_data._add_single_table( - name=table_name, - primary_key=table_backup.primary_key, - source=source_data, - invented_table_metadata=invented_table_metadata, - producer_metadata=producer_metadata, - ) - for fk_backup in backup.relational_data.foreign_keys: - self.relational_data.add_foreign_key_constraint( - table=fk_backup.table, - constrained_columns=fk_backup.constrained_columns, - referred_table=fk_backup.referred_table, - referred_columns=fk_backup.referred_columns, - ) - - # Classify - backup_classify = backup.classify - if backup_classify is None: - logger.info("No classify data found in backup.") - else: - logger.info("Restoring classify models") - self._classify.models = { - table: self._project.get_model(model_id) - for table, model_id in backup_classify.model_ids.items() - } - - # Transforms Train - backup_transforms_train = backup.transforms_train - if backup_transforms_train is None: - logger.info("No transforms training data found in backup.") - else: - logger.info("Restoring transforms models") - self._transforms_train.models = { - table: self._project.get_model(model_id) - for table, model_id in backup_transforms_train.model_ids.items() - } - - # Synthetics Train - backup_synthetics_train = backup.synthetics_train - if backup_synthetics_train is None: - logger.info("No synthetics training data found in backup.") - return None - - logger.info("Restoring synthetics models") - - self._synthetics_train.lost_contact = backup_synthetics_train.lost_contact - self._synthetics_train.models = { - table: self._project.get_model(model_id) - for table, model_id in backup_synthetics_train.model_ids.items() - } - - still_in_progress = [ - table - for table, model in self._synthetics_train.models.items() - if model.status in ACTIVE_STATES - ] - if len(still_in_progress) > 0: - logger.warning( - f"Cannot restore at this time: model training still in progress for tables `{still_in_progress}`. " - "Please wait for training to finish, and re-attempt restoring from backup once all models have completed training. " - f"You can view training progress in the Console here: {self._project.get_console_url()}" - ) - raise MultiTableException( - "Cannot restore while model training is actively in progress. Wait for training to finish and try again." - ) - - training_failed = [ - table - for table, model in self._synthetics_train.models.items() - if model.status in END_STATES and model.status != Status.COMPLETED - ] - if len(training_failed) > 0: - logger.info( - f"Training failed for tables: {training_failed}. You may try retraining them with modified data by calling `retrain_tables`." - ) - return None - - # Synthetics Generate - backup_generate = backup.generate - if backup_generate is None: - logger.info( - "No synthetic generation jobs had been started in previous instance." - ) - return None - - record_handlers = { - table: self._synthetics_train.models[table].get_record_handler(rh_id) - for table, rh_id in backup_generate.record_handler_ids.items() - } - self._synthetics_run = SyntheticsRun( - identifier=backup_generate.identifier, - record_size_ratio=backup_generate.record_size_ratio, - preserved=backup_generate.preserved, - lost_contact=backup_generate.lost_contact, - record_handlers=record_handlers, - ) - - artifact_keys = [artifact["key"] for artifact in self._project.artifacts] - latest_run_id = self._synthetics_run.identifier - if any(latest_run_id in artifact_key for artifact_key in artifact_keys): - logger.info( - f"Results of most recent run `{latest_run_id}` can be reviewed by downloading the output archive from the project; visit {self._project.get_console_url()}" - ) - else: - logger.info( - f"At time of last backup, generation run `{latest_run_id}` was still in progress. " - "You can attempt to resume that generate job via `generate(resume=True)`, or restart generation from scratch via a regular call to `generate`." - ) - - @classmethod - def restore( - cls, - backup_file: str, - session: Optional[ClientConfig] = None, - ) -> MultiTable: - """ - Create a `MultiTable` instance from a backup file. - """ - logger.info(f"Restoring from backup file `{backup_file}`.") - with open(backup_file, "r") as b: - backup = Backup.from_dict(json.load(b)) - - return MultiTable( - relational_data=RelationalData(directory=backup.project_name), - strategy=backup.strategy, - refresh_interval=backup.refresh_interval, - backup=backup, - session=session, - ) - - def _backup(self) -> None: - backup = self._build_backup() - # exit early if nothing has changed since last backup - if backup == self._latest_backup: - return None - - self._output_handler.save_backup(backup) - - self._latest_backup = backup - - def _build_backup(self) -> Backup: - backup = Backup( - project_name=self._project.name, - strategy=self._strategy.name, - refresh_interval=self._refresh_interval, - source_archive=self._output_handler.get_source_archive(), - relational_data=BackupRelationalData.from_relational_data( - self.relational_data - ), - ) - - # Classify - if len(self._classify.models) > 0: - backup.classify = BackupClassify( - model_ids={ - table: model.model_id - for table, model in self._classify.models.items() - } - ) - - # Transforms Train - if len(self._transforms_train.models) > 0: - backup.transforms_train = BackupTransformsTrain( - model_ids={ - table: model.model_id - for table, model in self._transforms_train.models.items() - }, - lost_contact=self._transforms_train.lost_contact, - ) - - # Synthetics Train - if len(self._synthetics_train.models) > 0: - backup.synthetics_train = BackupSyntheticsTrain( - model_ids={ - table: model.model_id - for table, model in self._synthetics_train.models.items() - }, - lost_contact=self._synthetics_train.lost_contact, - ) - - # Generate - if self._synthetics_run is not None: - backup.generate = BackupGenerate( - identifier=self._synthetics_run.identifier, - record_size_ratio=self._synthetics_run.record_size_ratio, - preserved=self._synthetics_run.preserved, - lost_contact=self._synthetics_run.lost_contact, - record_handler_ids={ - table: rh.record_id - for table, rh in self._synthetics_run.record_handlers.items() - }, - ) - - return backup - - @property - def _hybrid(self) -> bool: - return self._session.default_runner == RunnerMode.HYBRID - - @property - def evaluations(self) -> dict[str, TableEvaluation]: - evaluations = defaultdict(lambda: TableEvaluation()) - - for table, evaluation in self._evaluations.items(): - if (public_name := self.relational_data.get_public_name(table)) is not None: - evaluations[public_name] = evaluation - - return evaluations - - def _set_refresh_interval(self, interval: Optional[int]) -> None: - if interval is None: - self._refresh_interval = 60 - else: - if interval < 30: - logger.warning( - "Refresh interval must be at least 30 seconds. Setting to 30." - ) - self._refresh_interval = 30 - else: - self._refresh_interval = interval - - def classify(self, config: GretelModelConfig, all_rows: bool = False) -> None: - classify_data_sources = {} - for table in self.relational_data.list_all_tables(): - classify_config = make_classify_config(table, config) - data_source = str(self.relational_data.get_table_source(table)) - classify_data_sources[table] = data_source - - # Create model if necessary - if self._classify.models.get(table) is not None: - continue - - model = self._project.create_model_obj( - model_config=classify_config, data_source=data_source - ) - self._classify.models[table] = model - - self._backup() - - task = ClassifyTask( - classify=self._classify, - data_sources=classify_data_sources, - all_rows=all_rows, - ctx=self._new_task_context(), - output_handler=self._output_handler, - ) - run_task(task, self._extended_sdk) - - self._output_handler.save_classify_outputs(task.result_filepaths) - - def _setup_transforms_train_state(self, configs: dict[str, dict]) -> None: - for table, config in configs.items(): - model = self._project.create_model_obj( - model_config=make_transform_config(self.relational_data, table, config), - data_source=str(self.relational_data.get_table_source(table)), - ) - self._transforms_train.models[table] = model - - self._backup() - - def transform_v2( - self, - config: GretelModelConfig, - *, - table_specific_configs: Optional[dict[str, GretelModelConfig]] = None, - only: Optional[set[str]] = None, - ignore: Optional[set[str]] = None, - encode_keys: bool = False, - in_place: bool = False, - identifier: Optional[str] = None, - ) -> None: - configs = assemble_configs( - self.relational_data, config, table_specific_configs, only, ignore - ) - _validate_all_transform_v2_configs(configs) - - self._setup_transforms_train_state(configs) - task = TransformsTrainTask( - transforms_train=self._transforms_train, - ctx=self._new_task_context(), - ) - run_task(task, self._extended_sdk) - - output_tables = {} - for table, model in self._transforms_train.models.items(): - if table in task.completed: - with model.get_artifact_handle("data_preview") as data_preview: - output_tables[table] = pd.read_csv(data_preview) - - self._post_process_transformed_tables( - output_tables=output_tables, - identifier=identifier or f"transforms_{_timestamp()}", - encode_keys=encode_keys, - in_place=in_place, - ) - - def train_transforms( - self, - config: GretelModelConfig, - *, - table_specific_configs: Optional[dict[str, GretelModelConfig]] = None, - only: Optional[set[str]] = None, - ignore: Optional[set[str]] = None, - ) -> None: - configs = assemble_configs( - self.relational_data, config, table_specific_configs, only, ignore - ) - self._setup_transforms_train_state(configs) - task = TransformsTrainTask( - transforms_train=self._transforms_train, - ctx=self._new_task_context(), - ) - run_task(task, self._extended_sdk) - - def run_transforms( - self, - identifier: Optional[str] = None, - in_place: bool = False, - data: Optional[dict[str, UserFriendlyDataT]] = None, - encode_keys: bool = False, - ) -> None: - """ - Run pre-trained Gretel Transform models on Relational table data: - - Args: - identifier: Unique string identifying a specific call to this method. Defaults to `transforms_` + current timestamp - in_place: If True, overwrites source data in all locations - (internal Python state, local working directory, project artifact archive). - Used for transforms->synthetics workflows. - data: If supplied, runs only the supplied data through the corresponding transforms models. - Otherwise runs source data through all existing transforms models. - encode_keys: If set, primary and foreign keys will be replaced with label encoded variants. This can add - an additional level of privacy at the cost of referential integrity between transformed and - original data. - """ - if encode_keys and len(self.relational_data.foreign_key_cycles) > 0: - raise MultiTableException( - "Cannot encode keys when schema includes cyclic foreign key relationships." - ) - - if data is not None: - unrunnable_tables = [ - table - for table in data - if not _table_trained_successfully(self._transforms_train, table) - ] - if len(unrunnable_tables) > 0: - raise MultiTableException( - f"Cannot run transforms on provided data without successfully trained models for {unrunnable_tables}" - ) - - identifier = identifier or f"transforms_{_timestamp()}" - logger.info(f"Starting transforms run `{identifier}`") - transforms_run_paths = {} - - data_sources = data or { - table: self.relational_data.get_table_source(table) - for table in self._transforms_train.models - if _table_trained_successfully(self._transforms_train, table) - } - - for table, data_source in data_sources.items(): - transforms_run_path = self._output_handler.filepath_for( - f"transforms_input_{table}.csv" - ) - if isinstance(data_source, pd.DataFrame): - data_source.to_csv(transforms_run_path, index=False) - else: - with ( - open_artifact(data_source, "rb") as src, - open_artifact(transforms_run_path, "wb") as dest, - ): - shutil.copyfileobj(src, dest) - transforms_run_paths[table] = transforms_run_path - - transforms_record_handlers: dict[str, RecordHandler] = {} - - for table_name, transforms_run_path in transforms_run_paths.items(): - model = self._transforms_train.models[table_name] - record_handler = model.create_record_handler_obj( - data_source=transforms_run_path - ) - transforms_record_handlers[table_name] = record_handler - - task = TransformsRunTask( - record_handlers=transforms_record_handlers, - ctx=self._new_task_context(), - ) - run_task(task, self._extended_sdk) - - self._post_process_transformed_tables( - output_tables=task.output_tables, - identifier=identifier, - encode_keys=encode_keys, - in_place=in_place, - ) - - def _post_process_transformed_tables( - self, - output_tables: dict[str, pd.DataFrame], - identifier: str, - encode_keys: bool, - in_place: bool, - ) -> None: - if encode_keys: - output_tables = self._strategy.label_encode_keys( - self.relational_data, output_tables - ) - - if in_place: - for table_name, transformed_table in output_tables.items(): - self.relational_data.update_table_data(table_name, transformed_table) - self._upload_sources_to_project() - - reshaped_tables = self.relational_data.restore(output_tables) - - run_subdir = self._output_handler.make_subdirectory(identifier) - final_output_filepaths = {} - for table, df in reshaped_tables.items(): - filename = f"transformed_{table}.csv" - out_path = self._output_handler.filepath_for(filename, subdir=run_subdir) - with open_artifact( - out_path, - "wb", - ) as dest: - df.to_csv( - dest, - index=False, - columns=self.relational_data.get_table_columns(table), - ) - final_output_filepaths[table] = out_path - - self._output_handler.save_transforms_outputs(final_output_filepaths, run_subdir) - - self._backup() - self.transform_output_tables = reshaped_tables - - def _train_synthetics_models(self, configs: dict[str, dict[str, Any]]) -> None: - """ - Uses the configured strategy to prepare training data sources for each table, - exported to the working directory. Creates a model for each table and submits - it for training. Upon completion, downloads the evaluation reports for each - table to the working directory. - """ - training_paths = { - table: self._output_handler.filepath_for(f"synthetics_train_{table}.csv") - for table in configs - } - - training_paths = self._strategy.prepare_training_data( - self.relational_data, training_paths - ) - - for table_name, config in configs.items(): - if table_name not in training_paths: - logger.info(f"Bypassing model training for table `{table_name}`") - self._synthetics_train.bypass.append(table_name) - continue - - synthetics_config = make_synthetics_config(table_name, config) - model = self._project.create_model_obj( - model_config=synthetics_config, - data_source=training_paths[table_name], - ) - self._synthetics_train.models[table_name] = model - - self._backup() - - task = SyntheticsTrainTask( - synthetics_train=self._synthetics_train, - ctx=self._new_task_context(), - ) - run_task(task, self._extended_sdk) - - def train_synthetics( - self, - *, - config: Optional[GretelModelConfig] = None, - table_specific_configs: Optional[dict[str, GretelModelConfig]] = None, - only: Optional[set[str]] = None, - ignore: Optional[set[str]] = None, - ) -> None: - """ - Train synthetic data models for the tables in the tableset, - optionally scoped by either `only` or `ignore`. - """ - if len(self.relational_data.foreign_key_cycles) > 0: - raise MultiTableException( - "Cyclic foreign key relationships are not supported by relational synthetics." - ) - - if config is None: - config = self._strategy.default_config - - configs = assemble_configs( - self.relational_data, config, table_specific_configs, only, ignore - ) - - # validate table scope (preserved tables) against the strategy - excluded_tables = [ - table - for table in self.relational_data.list_all_tables() - if table not in configs - ] - self._strategy.validate_preserved_tables(excluded_tables, self.relational_data) - - # validate all provided model configs are supported by the strategy - for conf in configs.values(): - self._validate_synthetics_config(conf) - - self._train_synthetics_models(configs) - - def retrain_tables(self, tables: dict[str, UserFriendlyDataT]) -> None: - """ - Provide updated table data and retrain. This method overwrites the table data in the - `RelationalData` instance. It should be used when initial training fails and source data - needs to be altered, but progress on other tables can be left as-is. - """ - # The strategy determines the full set of tables that need to be retrained based on those provided. - tables_to_retrain = self._strategy.tables_to_retrain( - list(tables.keys()), self.relational_data - ) - - # Grab the configs from the about-to-be-replaced models. If any can't be found, - # we have to abort because we don't know what model config to use with the new data. - configs = {} - for table in tables_to_retrain: - if (old_model := self._synthetics_train.models.get(table)) is None: - raise MultiTableException( - f"Could not find an existing model for table `{table}`. You may need to rerun all training via `train_synthetics`." - ) - else: - configs[table] = old_model.model_config - - # Orphan the old models - for table in tables_to_retrain: - del self._synthetics_train.models[table] - - # Update the source table data. - for table_name, table_data in tables.items(): - self.relational_data.update_table_data(table_name, table_data) - self._upload_sources_to_project() - - # Train new synthetics models for the subset of tables - self._train_synthetics_models(configs) - - def _upload_sources_to_project(self) -> None: - self._output_handler.save_sources(self.relational_data) - self._backup() - - def generate( - self, - record_size_ratio: float = 1.0, - preserve_tables: Optional[list[str]] = None, - identifier: Optional[str] = None, - resume: bool = False, - ) -> None: - """ - Sample synthetic data from trained models. - Tables that did not train successfully will be omitted from the output dictionary. - Tables listed in `preserve_tables` *may differ* from source tables in foreign key columns, to ensure - joining to parent tables (which may have been synthesized) continues to work properly. - - Args: - record_size_ratio (float, optional): Ratio to upsample real world data size with. Defaults to 1. - preserve_tables (list[str], optional): List of tables to skip sampling and leave (mostly) identical to source. - identifier (str, optional): Unique string identifying a specific call to this method. Defaults to `synthetics_` + current timestamp. - resume (bool, optional): Set to True when restoring from a backup to complete a previous, interrupted run. - """ - if resume: - if identifier is not None: - logger.warning( - "Cannot set identifier when resuming previous generation. Ignoring." - ) - if record_size_ratio is not None: - logger.warning( - "Cannot set record_size_ratio when resuming previous generation. Ignoring." - ) - if preserve_tables is not None: - logger.warning( - "Cannot set preserve_tables when resuming previous generation. Ignoring." - ) - if self._synthetics_run is None: - raise MultiTableException( - "Cannot resume a synthetics generation run without existing run information." - ) - logger.info(f"Resuming synthetics run `{self._synthetics_run.identifier}`") - else: - preserve_tables = preserve_tables or [] - preserve_tables.extend( - [ - table - for table in self.relational_data.list_all_tables() - if table not in self._synthetics_train.models - and table not in self._synthetics_train.bypass - ] - ) - self._strategy.validate_preserved_tables( - preserve_tables, self.relational_data - ) - - identifier = identifier or f"synthetics_{_timestamp()}" - - self._synthetics_run = SyntheticsRun( - identifier=identifier, - record_size_ratio=record_size_ratio, - preserved=preserve_tables, - record_handlers={}, - lost_contact=[], - ) - logger.info(f"Starting synthetics run `{self._synthetics_run.identifier}`") - - run_subdir = self._output_handler.make_subdirectory( - self._synthetics_run.identifier - ) - - task = SyntheticsRunTask( - synthetics_run=self._synthetics_run, - synthetics_train=self._synthetics_train, - subdir=run_subdir, - output_handler=self._output_handler, - ctx=self._new_task_context(), - rel_data=self.relational_data, - strategy=self._strategy, - ) - run_task(task, self._extended_sdk) - - output_tables = self._strategy.post_process_synthetic_results( - synth_tables=task.output_tables, - preserved=self._synthetics_run.preserved, - rel_data=self.relational_data, - record_size_ratio=self._synthetics_run.record_size_ratio, - ) - - reshaped_tables = self.relational_data.restore(output_tables) - - synthetic_table_filepaths = {} - for table, synth_df in reshaped_tables.items(): - synth_csv_path = self._output_handler.filepath_for( - f"synth_{table}.csv", subdir=run_subdir - ) - with open_artifact(synth_csv_path, "wb") as dest: - synth_df.to_csv( - dest, - index=False, - columns=self.relational_data.get_table_columns(table), - ) - synthetic_table_filepaths[table] = synth_csv_path - - individual_evaluate_models = {} - cross_table_evaluate_models = {} - for table, synth_df in output_tables.items(): - if table in self._synthetics_run.preserved: - continue - - if table not in self._synthetics_train.models: - continue - - if table not in self.relational_data.list_all_tables(Scope.EVALUATABLE): - continue - - # Create an evaluate model for individual SQS - individual_data = self._get_individual_evaluate_data( - table=table, - synthetic_tables=output_tables, - ) - individual_sqs_job = self._project.create_model_obj( - model_config=make_evaluate_config(table, "individual"), - data_source=individual_data["synthetic"], - ref_data=individual_data["source"], - ) - individual_evaluate_models[table] = individual_sqs_job - - # Create an evaluate model for cross-table SQS (if we can/should) - cross_table_data = self._get_cross_table_evaluate_data( - table=table, - synthetic_tables=output_tables, - ) - if cross_table_data is not None: - cross_table_sqs_job = self._project.create_model_obj( - model_config=make_evaluate_config(table, "cross_table"), - data_source=cross_table_data["synthetic"], - ref_data=cross_table_data["source"], - ) - cross_table_evaluate_models[table] = cross_table_sqs_job - - synthetics_evaluate_task = SyntheticsEvaluateTask( - individual_evaluate_models=individual_evaluate_models, - cross_table_evaluate_models=cross_table_evaluate_models, - subdir=run_subdir, - output_handler=self._output_handler, - evaluations=self._evaluations, - ctx=self._new_task_context(), - ) - run_task(synthetics_evaluate_task, self._extended_sdk) - - relational_report_filepath = None - if self.relational_data.any_table_relationships(): - logger.info("Creating relational report") - relational_report_filepath = self._output_handler.filepath_for( - "relational_report.html", subdir=run_subdir - ) - self.create_relational_report( - run_identifier=self._synthetics_run.identifier, - filepath=relational_report_filepath, - ) - - self._output_handler.save_synthetics_outputs( - tables=synthetic_table_filepaths, - table_reports=synthetics_evaluate_task.report_filepaths, - relational_report=relational_report_filepath, - run_subdir=run_subdir, - ) - self.synthetic_output_tables = reshaped_tables - self._backup() - - def _get_individual_evaluate_data( - self, table: str, synthetic_tables: dict[str, pd.DataFrame] - ) -> dict[str, pd.DataFrame]: - """ - Returns a dictionary containing source and synthetic versions of a table, - to be used in an Evaluate job. - - Removes all key columns to avoid artificially deflating the score - (key types may not match, and key values carry no semantic meaning). - """ - all_cols = self.relational_data.get_table_columns(table) - key_cols = self.relational_data.get_all_key_columns(table) - use_cols = [c for c in all_cols if c not in key_cols] - - return { - "source": self.relational_data.get_table_data(table, usecols=use_cols), - "synthetic": synthetic_tables[table].drop(columns=key_cols), - } - - def _get_cross_table_evaluate_data( - self, table: str, synthetic_tables: dict[str, pd.DataFrame] - ) -> Optional[dict[str, pd.DataFrame]]: - """ - Returns a dictionary containing source and synthetic versions of a table - with ancestral data attached, to be used in an Evaluate job for cross-table SQS. - - Removes all key columns to avoid artificially deflating the score - (key types may not match, and key values carry no semantic meaning). - - Returns None if a cross-table SQS job cannot or should not be performed. - """ - # Exit early if table does not have parents (no need for cross-table evaluation) - if len(self.relational_data.get_parents(table)) == 0: - return None - - # Exit early if we can't create synthetic cross-table data - # (e.g. parent data missing due to job failure) - missing_ancestors = [ - ancestor - for ancestor in self.relational_data.get_ancestors(table) - if ancestor not in synthetic_tables - ] - if len(missing_ancestors) > 0: - logger.info( - f"Cannot run cross_table evaluations for `{table}` because no synthetic data exists for ancestor tables {missing_ancestors}." - ) - return None - - source_data = ancestry.get_table_data_with_ancestors( - self.relational_data, table - ) - synthetic_data = ancestry.get_table_data_with_ancestors( - self.relational_data, table, synthetic_tables - ) - key_cols = ancestry.get_all_key_columns(self.relational_data, table) - return { - "source": source_data.drop(columns=key_cols), - "synthetic": synthetic_data.drop(columns=key_cols), - } - - def create_relational_report(self, run_identifier: str, filepath: str) -> None: - presenter = ReportPresenter( - rel_data=self.relational_data, - evaluations=self.evaluations, - now=datetime.utcnow(), - run_identifier=run_identifier, - ) - with open_artifact(filepath, "w") as report: - html_content = ReportRenderer().render(presenter) - report.write(html_content) - - def _new_task_context(self) -> TaskContext: - return TaskContext( - in_flight_jobs=0, - refresh_interval=self._refresh_interval, - project=self._project, - extended_sdk=self._extended_sdk, - backup=self._backup, - ) - - def _validate_synthetics_config(self, config_dict: dict[str, Any]) -> None: - """ - Validates that the provided config (in dict form) - is supported by the configured synthetics strategy - """ - if (model_key := get_model_key(config_dict)) is None: - raise MultiTableException("Invalid config") - else: - supported_models = self._strategy.supported_model_keys - if model_key not in supported_models: - raise MultiTableException( - f"Invalid gretel model requested: {model_key}. " - f"The selected strategy supports: {supported_models}." - ) - - -def _configure_session( - project: Optional[Project], session: Optional[ClientConfig] -) -> ClientConfig: - if session is None and project is not None: - session = project.session - - return add_session_context( - session=session, client_metrics=RELATIONAL_SESSION_METADATA - ) - - -def _validate_strategy(strategy: str) -> Union[IndependentStrategy, AncestralStrategy]: - strategy = strategy.lower() - - if strategy == "independent": - return IndependentStrategy() - elif strategy == "ancestral": - return AncestralStrategy() - else: - msg = f"Unrecognized strategy requested: {strategy}. Supported strategies are `independent` and `ancestral`." - logger.warning(msg) - raise MultiTableException(msg) - - -def _validate_all_transform_v2_configs(configs: dict[str, dict]) -> None: - invalid = 0 - for table, config in configs.items(): - model_key = get_model_key(config) - if model_key is None or model_key != "transform_v2": - logger.warning( - f"Invalid model config for `{table}`, expected `transform_v2` model type, got `{model_key}`" - ) - invalid += 1 - - if invalid > 0: - raise MultiTableException(f"Received {invalid} invalid model configs.") - - -def _table_trained_successfully(train_state: TransformsTrain, table: str) -> bool: - model = train_state.models.get(table) - if model is None: - return False - else: - return model.status == Status.COMPLETED - - -def _timestamp() -> str: - return datetime.now().strftime("%Y%m%d%H%M%S") diff --git a/src/gretel_trainer/relational/output_handler.py b/src/gretel_trainer/relational/output_handler.py deleted file mode 100644 index 941c47e5..00000000 --- a/src/gretel_trainer/relational/output_handler.py +++ /dev/null @@ -1,216 +0,0 @@ -import json -import logging -import shutil -import tempfile - -from pathlib import Path -from typing import Optional, Protocol - -from gretel_client.projects import Project -from gretel_trainer.relational.backup import Backup -from gretel_trainer.relational.core import RelationalData - -logger = logging.getLogger(__name__) - - -class OutputHandler(Protocol): - def filepath_for(self, filename: str, subdir: Optional[str] = None) -> str: - """ - Returns a string handle that can be used with smart_open to write data to an internal location. - """ - ... - - def make_subdirectory(self, name: str) -> str: - """ - Returns a string that can be passed to `filepath_for` to support output organization. - """ - ... - - def get_source_archive(self) -> Optional[str]: - """ - Returns an ID for a source archive artifact if one exists. - """ - - def save_sources(self, relational_data: RelationalData) -> None: - """ - Callback when source data is finalized for an implementation to persist source data. - """ - ... - - def save_backup(self, backup: Backup) -> None: - """ - Callback at several notable moments of execution for an implementation to persist backup data. - """ - ... - - def save_debug_summary(self, content: dict) -> None: - """ - Callback when initial state is set up to persist debug information. - """ - ... - - def save_classify_outputs(self, filepaths: dict[str, str]) -> None: - """ - Callback when classify completes to persist classify output data. - """ - ... - - def save_transforms_outputs( - self, filepaths: dict[str, str], run_subdir: str - ) -> None: - """ - Callback when transform completes to persist transform output data. - """ - ... - - def save_synthetics_outputs( - self, - tables: dict[str, str], - table_reports: dict[str, dict[str, dict[str, str]]], - relational_report: Optional[str], - run_subdir: str, - ) -> None: - """ - Callback when synthetics completes to persist synthetic output data. - """ - ... - - -class SDKOutputHandler: - def __init__( - self, - workdir: str, - project: Project, - hybrid: bool, - source_archive: Optional[str], - ): - self._project = project - self._hybrid = hybrid - self._working_dir = _mkdir(workdir) - self._source_archive = source_archive - - def filepath_for(self, filename: str, subdir: Optional[str] = None) -> str: - """ - Returns a path inside the working directory for the provided file. - """ - if subdir is not None: - return str(self._working_dir / subdir / filename) - else: - return str(self._working_dir / filename) - - def make_subdirectory(self, name: str) -> str: - """ - Creates a subdirectory in the working dir with name, - and returns just the name (or "stem") back. - """ - _mkdir(str(self._working_dir / name)) - return name - - def get_source_archive(self) -> Optional[str]: - return self._source_archive - - def save_sources(self, relational_data: RelationalData) -> None: - """ - Creates an archive of all tables' source files and uploads it as a project artifact. - """ - source_data_dir = relational_data.source_data_handler.dir # type:ignore - latest = self._archive_and_upload( - archive_name=str(self._working_dir / "source_tables.tar.gz"), - dir_to_archive=source_data_dir, - ) - - # Delete the older version if present (Cloud-only, deletes not supported in Hybrid). - if (not self._hybrid) and (existing := self._source_archive) is not None: - self._project.delete_artifact(existing) - - self._source_archive = latest - - def save_backup(self, backup: Backup) -> None: - """ - Writes backup data as JSON to the working directory, - uploads the file as a project artifact, and deletes any stale backups. - """ - backup_path = self._working_dir / "_gretel_backup.json" - with open(backup_path, "w") as bak: - json.dump(backup.as_dict, bak) - - # Exit early if hybrid, because this should be a "singleton" project artifact - # and we cannot delete hybrid project artifacts. - if self._hybrid: - return None - - # The backup file does not use the ArtifactCollection because the AC's data - # (artifact keys) is included in the full backup data, so we would end up - # "chasing our own tail", for example: - # - create backup data with AC.backup_key=1, write to file - # - upload backup file => new backup_key returned from API => AC.backup_key=2 - # Backup data would always be stale and we'd write more backups than we need - # (we skip uploading backup data if we detect no changes from the latest snapshot). - latest = self._project.upload_artifact(str(backup_path)) - for artifact in self._project.artifacts: - key = artifact["key"] - if key != latest and key.endswith("__gretel_backup.json"): - self._project.delete_artifact(key) - - def save_debug_summary(self, content: dict) -> None: - """ - Writes the debug summary content as JSON to the working directory. - """ - debug_summary_path = self._working_dir / "_gretel_debug_summary.json" - with open(debug_summary_path, "w") as dbg: - json.dump(content, dbg) - - def save_classify_outputs(self, filepaths: dict[str, str]) -> None: - """ - Creates an archive file of the provided classify output files and uploads it - as a project artifact. - """ - with tempfile.TemporaryDirectory() as tmpdir: - for item in filepaths.values(): - shutil.copy(item, tmpdir) - - self._archive_and_upload( - archive_name=str(self._working_dir / "classify_outputs.tar.gz"), - dir_to_archive=Path(tmpdir), - ) - - def save_transforms_outputs( - self, filepaths: dict[str, str], run_subdir: str - ) -> None: - """ - Archives the run subdirectory and uploads it to the project. - """ - self._archive_and_upload_run_outputs(run_subdir) - - def save_synthetics_outputs( - self, - tables: dict[str, str], - table_reports: dict[str, dict[str, dict[str, str]]], - relational_report: Optional[str], - run_subdir: str, - ) -> None: - """ - Archives the run subdirectory and uploads it to the project. - """ - self._archive_and_upload_run_outputs(run_subdir) - - def _archive_and_upload_run_outputs(self, run_subdir: str) -> None: - root_dir = self._working_dir / run_subdir - self._archive_and_upload( - archive_name=str(root_dir), - dir_to_archive=root_dir, - ) - - def _archive_and_upload(self, archive_name: str, dir_to_archive: Path) -> str: - archive_location = shutil.make_archive( - base_name=archive_name.removesuffix(".tar.gz"), - format="gztar", - root_dir=dir_to_archive, - ) - return self._project.upload_artifact(archive_location) - - -def _mkdir(name: str) -> Path: - d = Path(name) - d.mkdir(parents=True, exist_ok=True) - return d diff --git a/src/gretel_trainer/relational/report/figures.py b/src/gretel_trainer/relational/report/figures.py deleted file mode 100644 index 46695824..00000000 --- a/src/gretel_trainer/relational/report/figures.py +++ /dev/null @@ -1,147 +0,0 @@ -import math - -from typing import Optional - -import plotly.graph_objects as go - -_GRETEL_PALETTE = ["#A051FA", "#18E7AA"] - -SCORE_VALUES = [ - {"label": "Very poor", "color": "rgb(229, 60, 26)"}, - {"label": "Poor", "color": "rgb(229, 128, 26)"}, - {"label": "Average", "color": "rgb(229, 161, 26)"}, - {"label": "Good", "color": "rgb(183, 210, 45)"}, - {"label": "Excellent", "color": "rgb(72, 210, 45)"}, -] - -PRIVACY_LEVEL_VALUES = [ - {"label": "Poor", "color": "rgb(203, 210, 252)"}, - {"label": "Normal", "color": "rgb(160, 171, 245)"}, - {"label": "Good", "color": "rgb(124, 135, 233)"}, - {"label": "Very Good", "color": "rgb(83, 81, 222)"}, - {"label": "Excellent", "color": "rgb(59, 46, 208)"}, -] - - -def _generate_pointer_path(score: int): - """ - Helper to generate an svg path for the needle in the gauge and needle chart. The path is a triangle, - basically a tall skinny pyramid with the base at the center of the circle and the apex at the score - on the outer ring. - - Args: - score: Integer score in [0,100]. Pointer path will point at this value on the gauge. - - Returns: - A string containing the raw svg path. It does NOT return any tags. - - """ - theta = score * (282 - 34) / 100 - 34 - rads = math.radians(theta) - - radius = 0.45 - size = 0.025 - - x1 = -1 * radius * math.cos(rads) + 0.5 - y1 = radius * math.sin(rads) + 0.5 - return f""" - M {x1} {y1} - L {-1 * size * math.cos(math.radians(theta - 90)) + 0.5} - {size * math.sin(math.radians(theta - 90)) + 0.5} - L {-1 * size * math.cos(math.radians(theta + 90)) + 0.5} - {size * math.sin(math.radians(theta + 90)) + 0.5} - Z""" - - -def gauge_and_needle_chart( - score: Optional[int], - display_score: bool = True, - marker_colors: Optional[list[str]] = None, -) -> go.Figure: - """ - The 'fancy' gauge and needle chart to go with the overall score of the report. Has colored segments for - each grade range and an svg pointer supplied by _generate_pointer_path - - Args: - score: Integer score in [0,100]. Pointer path will point at this value on the gauge. - - Returns: - A plotly.graph_objects.Figure - - """ - if score is None: - fig = go.Figure( - layout=go.Layout( - annotations=[ - go.layout.Annotation( - text="N/A", - font=dict(color="rgba(174, 95, 5, 1)", size=18), - showarrow=False, - xref="paper", - yref="paper", - x=0.5, - y=0.5, - ) - ] - ) - ) - marker_colors = ["rgb(220, 220, 220)", "rgba(255, 255, 255, 0)"] - pie_values = [70, 30] - else: - if not marker_colors: - marker_colors = [s["color"] for s in SCORE_VALUES] - if marker_colors[-1] != "rgba(255, 255, 255, 0)": - marker_colors.append("rgba(255, 255, 255, 0)") - pie_values = [70 // (len(marker_colors) - 1)] * (len(marker_colors) - 1) - pie_values.append(30) - fig = go.Figure() - - fig.update_layout( - autosize=False, - showlegend=False, - xaxis=dict(visible=False), - yaxis=dict(visible=False), - height=180, - width=180, - margin=dict(l=0, r=0, t=0, b=0), - paper_bgcolor="rgba(0,0,0,0)", - hovermode=False, - modebar=None, - ) - fig.add_trace( - go.Pie( - name="gauge", - values=pie_values, - marker=dict( - colors=marker_colors, - line=dict(width=4, color="#fafafa"), - ), - hole=0.75, - direction="clockwise", - sort=False, - rotation=234, - showlegend=False, - hoverinfo="none", - textinfo="none", - textposition="outside", - ) - ) - - if score is not None: - if display_score: - fig.add_trace( - go.Indicator( - mode="number", value=score, domain=dict(x=[0, 1], y=[0.28, 0.45]) - ) - ) - fig.add_shape( - type="circle", fillcolor="black", x0=0.475, x1=0.525, y0=0.475, y1=0.525 - ) - fig.add_shape( - type="path", - fillcolor="black", - line=dict(width=0), - path=_generate_pointer_path(score), - ) - - return fig diff --git a/src/gretel_trainer/relational/report/key_highlight.js b/src/gretel_trainer/relational/report/key_highlight.js deleted file mode 100644 index b7bafda6..00000000 --- a/src/gretel_trainer/relational/report/key_highlight.js +++ /dev/null @@ -1,65 +0,0 @@ -function mouseOverPk(event) { - event.preventDefault(); - let { currentTarget } = event; - let target = currentTarget; - - target.classList.add("label__highlighted"); - target.parentNode.parentNode.classList.add("row__highlighted"); - - document.querySelectorAll(`[data-fk-from=${target.id}]`).forEach( - fk => { - fk.classList.add("label__highlighted"); - fk.parentNode.parentNode.classList.add("row__highlighted"); - } - ) -} - -function mouseOutPk(event) { - event.preventDefault(); - let { currentTarget } = event; - let target = currentTarget; - - target.classList.remove("label__highlighted"); - target.parentNode.parentNode.classList.remove("row__highlighted"); - - document.querySelectorAll(`[data-fk-from=${target.id}]`).forEach( - fk => { - fk.classList.remove("label__highlighted"); - fk.parentNode.parentNode.classList.remove("row__highlighted"); - } - ) -} - -function mouseOverFk(event) { - event.preventDefault(); - let { currentTarget } = event; - let target = currentTarget; - - target.classList.add("label__highlighted"); - target.parentNode.parentNode.classList.add("row__highlighted"); - - let pk = document.getElementById(target.getAttribute('data-fk-from')); - pk.classList.add("label__highlighted"); - pk.parentNode.parentNode.classList.add("row__highlighted"); -} - -function mouseOutFk(event) { - event.preventDefault(); - let { currentTarget } = event; - let target = currentTarget; - - target.classList.remove("label__highlighted"); - target.parentNode.parentNode.classList.remove("row__highlighted"); - - let pk = document.getElementById(target.getAttribute('data-fk-from')); - pk.classList.remove("label__highlighted"); - pk.parentNode.parentNode.classList.remove("row__highlighted"); -} - -const pks = document.querySelectorAll('[data-pk]'); -pks.forEach(pk => { pk.onmouseover = mouseOverPk }); -pks.forEach(pk => { pk.onmouseout = mouseOutPk }); - -const fks = document.querySelectorAll('[data-fk-from]'); -fks.forEach(fk => { fk.onmouseover = mouseOverFk }); -fks.forEach(fk => { fk.onmouseout = mouseOutFk }); diff --git a/src/gretel_trainer/relational/report/report.css b/src/gretel_trainer/relational/report/report.css deleted file mode 100644 index b2c8140d..00000000 --- a/src/gretel_trainer/relational/report/report.css +++ /dev/null @@ -1,122 +0,0 @@ - html, - body { - font-size: 16px; - color: #25212B; - margin: 0; - padding: 0; - background-color: #fafafa; - line-height: 1.5; - font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, - Helvetica, Arial, sans-serif, "Apple Color Emoji", "Segoe UI Emoji", - "Segoe UI Symbol"; - } - a { - color: #0c052e; - } - a:hover { - color: #3c1ae5; - } - .wrapper { - margin: auto; - max-width: 75ch; - max-width: 960px; - /* border-left: 1px solid #E5E5E6; - border-right: 1px solid #E5E5E6; */ - } - section { - background-color: #fff; - padding: 2rem; - margin: 2rem 0; - box-shadow: 0 4px 12px rgba(0, 0, 0, 0.2); - } - section.break { - page-break-after: always; - } - figure { - text-align: center; - /* border: 2px solid #180a5c; */ - width: 100%; - margin: 0 0 2rem 0; - } - figcaption { - text-align: center; - color: #646369; - font-size: 0.875rem; - margin: 0.5rem 0; - } - header { - padding: 2rem 0 4rem 0; - margin-bottom: -4rem; - color: #fff; - background-color: #0c0c0d; - } - h1 { - color: #25212B; - font-size: 3rem; - font-weight: 500; - line-height: 1.2; - margin-bottom: 0.5em; - } - h2 { - color: #25212B; - margin: 0; - page-break-after: avoid; - } - .note { - background-color: #faecd1; - } - /* TABLE */ - table { - width: 100%; - border-collapse: collapse; - } - th { - background-color: #fafafa; - text-align: left; - border: none; - } - th, - td { - padding: 0.25rem 0.5rem; - border-bottom: 1px solid #e5e5e6; - border-top: 1px solid #e5e5e6; - border-right: 1px solid #e5e5e6; - } - th:last-of-type, - td:last-of-type { - border-right: none; - } - th.tnum, - td.tnum { - font-feature-settings: "tnum"; - text-align: right; - } - .highlight { - background-color: #fafafa; - font-family: monospace; - font-weight: normal; - } - footer { - color: #646369; - } - @media print { - html, - body { - font-size: 12px; - background-color: transparent; - } - header { - color: #000; - background-color: transparent; - } - section { - box-shadow: none; - border: 1px solid #e5e5e6; - } - a { - text-decoration: none; - } - a:not([href*="#"])::after { - content: " (" attr(href) ")"; - } - } diff --git a/src/gretel_trainer/relational/report/report.py b/src/gretel_trainer/relational/report/report.py deleted file mode 100644 index ac2f188f..00000000 --- a/src/gretel_trainer/relational/report/report.py +++ /dev/null @@ -1,161 +0,0 @@ -from __future__ import annotations - -import datetime - -from dataclasses import dataclass -from functools import cached_property -from math import ceil -from pathlib import Path -from typing import Optional - -import plotly.graph_objects as go - -from jinja2 import Environment, FileSystemLoader - -from gretel_trainer.relational.core import ForeignKey, RelationalData, Scope -from gretel_trainer.relational.report.figures import ( - gauge_and_needle_chart, - PRIVACY_LEVEL_VALUES, -) -from gretel_trainer.relational.table_evaluation import TableEvaluation - -_TEMPLATE_DIR = str(Path(__file__).parent) -_TEMPLATE_FILE = "report_template.html" - - -class ReportRenderer: - def __init__(self): - file_loader = FileSystemLoader(_TEMPLATE_DIR) - env = Environment(loader=file_loader) - self.template = env.get_template(_TEMPLATE_FILE) - - def render(self, presenter: ReportPresenter) -> str: - return self.template.render(presenter=presenter) - - -@dataclass -class ReportTableData: - table: str - pk: list[str] - fks: list[ForeignKey] - - -@dataclass -class ReportPresenter: - rel_data: RelationalData - now: datetime.datetime - run_identifier: str - evaluations: dict[str, TableEvaluation] - - @property - def generated_at(self) -> str: - return self.now.strftime("%Y-%m-%d") - - @property - def copyright_year(self) -> str: - return self.now.strftime("%Y") - - @cached_property - def composite_sqs_score_and_grade(self) -> tuple[Optional[int], str]: - # Add up all the non-None SQS scores and track how many there are. - _total_score = 0 - _num_scores = 0 - for eval in self.evaluations.values(): - if eval.individual_sqs is not None: - _total_score += eval.individual_sqs - _num_scores += 1 - if eval.cross_table_sqs is not None: - _total_score += eval.cross_table_sqs - _num_scores += 1 - - # Take the average. - if _num_scores > 0: - score = ceil(_total_score / _num_scores) - # Or tell the user the bad news. - else: - score = None - return (score, self.sqs_score_to_grade(score)) - - @property - def composite_sqs_label(self) -> str: - _formatted_grade = self.css_label_format(self.composite_sqs_score_and_grade[1]) - return f"label__{_formatted_grade}" - - @property - def composite_sqs_figure(self) -> go.Figure: - score, grade = self.composite_sqs_score_and_grade - return gauge_and_needle_chart(score) - - @cached_property - def composite_ppl_score_and_grade(self) -> tuple[Optional[int], str]: - # Collect all the non-None PPLs, individual and cross-table. - scores = [ - eval.individual_ppl - for eval in self.evaluations.values() - if eval.individual_ppl is not None - ] - scores += [ - eval.cross_table_ppl - for eval in self.evaluations.values() - if eval.cross_table_ppl is not None - ] - # Take the min, our "weakest link" - if len(scores) > 0: - score = min(scores) - GRADES = ["Normal", "Good", "Very Good", "Excellent"] - if 0 <= score < 0.5: - return (score, GRADES[0]) - if 0.5 <= score < 2.5: - return (score, GRADES[1]) - if 2.5 <= score < 4.5: - return (score, GRADES[2]) - return (score, GRADES[3]) - # Or tell the user the bad news. - else: - GRADE_UNAVAILABLE = "Unavailable" - return (None, GRADE_UNAVAILABLE) - - @property - def composite_ppl_label(self) -> str: - _formatted_grade = self.css_label_format(self.composite_ppl_score_and_grade[1]) - return f"label__privacy__{_formatted_grade}" - - @property - def composite_ppl_figure(self) -> go.Figure: - score, grade = self.composite_ppl_score_and_grade - ppl_score_map = {0: 30, 1: 46, 2: 54, 3: 66, 4: 74, 5: 86, 6: 94} - if score is None: - ppl_score = 30 - else: - ppl_score = ppl_score_map.get(score, 30) - return gauge_and_needle_chart( - ppl_score, - display_score=False, - marker_colors=[s["color"] for s in PRIVACY_LEVEL_VALUES], - ) - - @property - def report_table_data(self) -> list[ReportTableData]: - table_data = [] - for table in self.rel_data.list_all_tables(Scope.PUBLIC): - pk = self.rel_data.get_primary_key(table) - fks = self.rel_data.get_foreign_keys(table, rename_invented_tables=True) - table_data.append(ReportTableData(table=table, pk=pk, fks=fks)) - - # Sort tables alphabetically because that's nice. - table_data = sorted(table_data, key=lambda x: x.table) - return table_data - - # Helper, making it "just a method" so it is easily accessible in jinja template. - def sqs_score_to_grade(self, score: Optional[int]) -> str: - GRADES = ["Very Poor", "Poor", "Moderate", "Good", "Excellent"] - GRADE_UNAVAILABLE = "Unavailable" - if score is None: - return GRADE_UNAVAILABLE - idx = score // 20 - # Constrain to [0,4] - idx = max(0, min(4, idx)) - return GRADES[idx] - - def css_label_format(self, grade: str) -> str: - return grade.lower().replace(" ", "-") diff --git a/src/gretel_trainer/relational/report/report_privacy_protection.css b/src/gretel_trainer/relational/report/report_privacy_protection.css deleted file mode 100644 index 5b151bb5..00000000 --- a/src/gretel_trainer/relational/report/report_privacy_protection.css +++ /dev/null @@ -1,40 +0,0 @@ -:root { - --color-privacy-poor: #CBD2FC; - --color-privacy-normal: #A0ABF5; - --color-privacy-good: #7C87E9; - --color-privacy-very-good: #5351DE; - --color-privacy-excellent: #3B2ED0; - --color-white: #FFFFFF; - --color-privacy-text: #144BB8; - --color-privacy-text-not-excellent: #0C052E; -} - -.grid-pps { - display: grid; - grid-template-columns: repeat(4, 1fr); -} - -.label__privacy__poor { - background-color: var(--color-privacy-poor); - color: var(--color-privacy-text-not-excellent); -} -.label__privacy__normal { - background-color: var(--color-privacy-normal); - color: var(--color-privacy-text-not-excellent); -} -.label__privacy__good { - background-color: var(--color-privacy-good); - color: var(--color-privacy-text-not-excellent); -} -.label__privacy__very-good { - background-color: var(--color-privacy-good); - color: var(--color-privacy-text-not-excellent); -} -.label__privacy__excellent { - background-color: var(--color-privacy-excellent); - color: var(--color-white); -} - -.privacy-level { - color: var(--color-privacy-text); -} diff --git a/src/gretel_trainer/relational/report/report_synthetic_quality.css b/src/gretel_trainer/relational/report/report_synthetic_quality.css deleted file mode 100644 index dad843a0..00000000 --- a/src/gretel_trainer/relational/report/report_synthetic_quality.css +++ /dev/null @@ -1,172 +0,0 @@ -:root { - --color-unavailable: #DCDCDC; - --color-v-poor: #F5B1A3; - --color-poor: #F5CCA3; - --color-moderate: #F5DAA3; - --color-good: #E9F1C0; - --color-excellent: #C8F1C0; - --color-highlighted: #FDE68A; - --color-highlighted-row: #FFF9E0; - --color-figure-v-poor: #E53C1A; - --color-figure-poor: #E5801A; - --color-figure-moderate: #E5A11A; - --color-figure-good: #B7D22D; - --color-figure-excellent: #48D22D; - --color-neutral-50: #FAFAFA; - --color-neutral-100: #E5E5E6; - --color-neutral-400: #97969C; - --color-plum: #C18DFC; - --gap-disclosure: 0.5rem; -} - -body { - font-family: "Helvetica Neue", Helvetica, Arial, sans-serif; - min-height: 100vh; - margin: 0; -} -body::before { - content: " "; - display: block; - width: 100%; - height: 0.5rem; - background-color: var(--color-plum); -} -.section { - width: 100%; - max-width: 70rem; - margin: 4rem auto; -} -.header__container { - display: grid; - grid-template-columns: 1fr 1fr 1fr; -} -.score-container { - display: flex; - align-items: center; - flex-direction: column; -} -.score-container-text { - font-size: 1.3125rem; - text-align: center; - float: right; -} -.score { - width: 180px; - height: 180px; -} -.ring { - margin: 1rem; - width: 90px; - height: 90px; - display: inline-block; -} -.grid { - display: grid; - grid-template-columns: repeat(3, 1fr); -} -.grid__cell { - display: flex; - flex-direction: column; - align-items: center; - border-bottom: 1px solid var(--color-neutral-100); - border-right: 1px solid var(--color-neutral-100); -} -.grid__cell[aria-expanded="true"] { - background-color: var(--color-neutral-50); -} -.grid__cell:last-of-type { - border-right: none; -} -.grid__cell__row { - display: flex; - align-items: center; - width: 95%; -} - -.label { - display: inline-block; - font-size: 0.9375rem; - line-height: 1; - padding: 0.25rem 0.75rem; - border-radius: 1.25rem; -} -.label__unavailable { - background-color: var(--color-unavailable); -} -.label__very-poor { - background-color: var(--color-v-poor); -} -.label__poor { - background-color: var(--color-poor); -} -.label__moderate { - background-color: var(--color-moderate); -} -.label__good { - background-color: var(--color-good); -} -.label__excellent { - background-color: var(--color-excellent); -} - -.label__highlighted { - background-color: var(--color-highlighted); -} -.row__highlighted { - background-color: var(--color-highlighted-row); -} - -.figure__very-poor { - background-color: var(--color-figure-v-poor); -} -.figure__poor { - background-color: var(--color-figure-poor); -} -.figure__moderate { - background-color: var(--color-figure-moderate); -} -.figure__good { - background-color: var(--color-figure-good); -} -.figure__excellent { - background-color: var(--color-figure-excellent); -} - -.sqs-table-score { - padding-left: 1rem; - padding-right: 1rem; -} -.sqs-table-link { - float: right; - padding-right: 1rem; -} - -/* Utils */ -.h3 { - font-size: 1.3125rem; -} -.m-0\.25 { - margin: 0.25rem 0; -} -.mb-0\.5 { - margin-bottom: 0.5rem; -} -.mb-1 { - margin-bottom: 1rem; -} -.mt-0 { - margin-top: 0; -} -.mt-0\.5 { - margin-top: 0.5rem; -} -.mt-1 { - margin-top: 1rem; -} -.mt--2 { - margin-top: -2rem; -} -.py-2 { - padding-top: 2rem; - padding-bottom: 2rem; -} diff --git a/src/gretel_trainer/relational/report/report_template.html b/src/gretel_trainer/relational/report/report_template.html deleted file mode 100644 index 92832e10..00000000 --- a/src/gretel_trainer/relational/report/report_template.html +++ /dev/null @@ -1,170 +0,0 @@ - - - - - - Gretel Relational Synthetic Report - - - - - -
- -
-
- -
-

Gretel Relational
Synthetic Report

- - Generated {{presenter.generated_at}} -
- Generation ID {{ presenter.run_identifier }} -
-
- -
- - {{ presenter.composite_sqs_score_and_grade[1] }} - -
-
{{ presenter.composite_sqs_figure.to_html(config={'displayModeBar': False}, full_html=False, include_plotlyjs=False) }}
-
-
- Composite
Synthetic Data Quality Score
-
-
- -
- - {{ presenter.composite_ppl_score_and_grade[1] }} - -
-
{{ presenter.composite_ppl_figure.to_html(config={'displayModeBar': False}, full_html=False, include_plotlyjs=False) }}
-
-
- Composite
Privacy Protection Level
-
-
- -
-
- -
-

Table Relationships

- -

- The primary and foreign keys for each table in the synthesized database and their - relationships are displayed below. -

- - - - - - - - {% for table_data in presenter.report_table_data %} - - - - - - {% endfor %} -
Table NamePrimary KeyForeign Keys
{{ table_data.table }} - {% for pk in table_data.pk %} - {{ pk }} - {% endfor %} - - {% for fk in table_data.fks %} - {% for i in range(fk.columns | length) %} - {{ fk.columns[i] }}  - {% endfor %} - {% endfor %} -
-
- -
-

Synthetic Data Quality Results

- -

For each table, individual and cross-table Gretel Synthetic Reports are generated, which - include the Synthetic Data - Quality Score (SQS). The individual Synthetic Report evaluates the statistical accuracy of the - individual synthetic table compared to the real world data that it is - based on. This provides insight into the accuracy of the synthetic output of the stand-alone - table. The individual SQS does not take into account statistical correlations of data across - related tables. The - cross-table Synthetic Report evaluates the statistical accuracy of the synthetic data of a - table with consideration to the - correlations between data across related tables. The cross-table SQS provides insight into - the accuracy of the - table in the context of the database as a whole. More information about the Gretel Synthetic - Report and Synthetic Data Quality Score is - available here.

- - Synthetic Data Quality Scores - -

For each table, individual and cross-table synthetic data quality scores (SQS) are computed and displayed below.

- - - - - - - - {% for table, evaluation in presenter.evaluations.items() %} - {% set individual_grade = presenter.sqs_score_to_grade(evaluation.individual_sqs) %} - {% set cross_table_grade = presenter.sqs_score_to_grade(evaluation.cross_table_sqs) %} - - - - - - {% endfor %} -
Table NameIndividual SQSCross-table SQS
{{ table }} - {{ evaluation.individual_sqs }} - - {{ individual_grade }} - - - Report - - - {{ evaluation.cross_table_sqs }} - - {{ cross_table_grade }} - - - Report - -
- -

The Synthetic Data Quality Score is an estimate of how well the generated synthetic - data maintains the same statistical properties as the original dataset. In this - sense, the Synthetic Data Quality Score can be viewed as a utility score or a - confidence score as to whether scientific conclusions drawn from the synthetic - dataset would be the same if one were to have used the original dataset instead. - If you do not require statistical symmetry, as might be the case in a testing or - demo environment, a lower score may be just as acceptable.

-

If your Synthetic Data Quality Score isn't as high as you'd like it to be, - read here - for a multitude of ideas for improving your model.

- -
- -
-

- Copyright © {{ presenter.copyright_year }} Gretel Labs, Inc. All rights reserved. -

-
- -
- - - diff --git a/src/gretel_trainer/relational/sdk_extras.py b/src/gretel_trainer/relational/sdk_extras.py deleted file mode 100644 index 66fa37c4..00000000 --- a/src/gretel_trainer/relational/sdk_extras.py +++ /dev/null @@ -1,99 +0,0 @@ -import logging -import shutil - -from pathlib import Path -from typing import Optional, Union - -import pandas as pd - -from gretel_client.projects.artifact_handlers import open_artifact -from gretel_client.projects.exceptions import MaxConcurrentJobsException -from gretel_client.projects.jobs import Job, Status -from gretel_client.projects.models import Model -from gretel_client.projects.projects import Project -from gretel_client.projects.records import RecordHandler -from gretel_trainer.relational.core import MultiTableException - -logger = logging.getLogger(__name__) - -MAX_IN_FLIGHT_JOBS = 10 - - -class ExtendedGretelSDK: - def __init__(self, hybrid: bool): - self._hybrid = hybrid - - def get_job_id(self, job: Job) -> Optional[str]: - if isinstance(job, Model): - return job.model_id - elif isinstance(job, RecordHandler): - return job.record_id - else: - raise MultiTableException("Unexpected job object") - - def delete_data_sources(self, project: Project, job: Job) -> None: - if not self._hybrid: - if job.data_source is not None: - project.delete_artifact(job.data_source) - for ref_data in job.ref_data.values: - project.delete_artifact(ref_data) - - def cautiously_refresh_status( - self, job: Job, key: str, refresh_attempts: dict[str, int] - ) -> Status: - try: - job.refresh() - refresh_attempts[key] = 0 - except: - refresh_attempts[key] = refresh_attempts[key] + 1 - - return job.status - - def download_file_artifact( - self, - gretel_object: Union[Project, Model], - artifact_name: str, - out_path: Union[str, Path], - ) -> bool: - try: - with ( - gretel_object.get_artifact_handle(artifact_name) as src, - open_artifact(out_path, "wb") as dest, - ): - shutil.copyfileobj(src, dest) - return True - except: - logger.warning(f"Failed to download `{artifact_name}`") - return False - - def get_record_handler_data(self, record_handler: RecordHandler) -> pd.DataFrame: - with record_handler.get_artifact_handle("data") as data: - return pd.read_csv(data) - - def start_job_if_possible( - self, - job: Job, - table_name: str, - action: str, - project: Project, - in_flight_jobs: int, - ) -> int: - if in_flight_jobs < MAX_IN_FLIGHT_JOBS: - self._log_start(table_name, action) - try: - job.submit() - return 1 - except MaxConcurrentJobsException: - self._log_waiting(table_name, action) - return 0 - else: - self._log_waiting(table_name, action) - return 0 - - def _log_start(self, table_name: str, action: str) -> None: - logger.info(f"Starting {action} for `{table_name}`.") - - def _log_waiting(self, table_name: str, action: str) -> None: - logger.info( - f"Maximum concurrent jobs reached. Deferring start of `{table_name}` {action}." - ) diff --git a/src/gretel_trainer/relational/strategies/ancestral.py b/src/gretel_trainer/relational/strategies/ancestral.py deleted file mode 100644 index 29140e2a..00000000 --- a/src/gretel_trainer/relational/strategies/ancestral.py +++ /dev/null @@ -1,360 +0,0 @@ -import logging - -from typing import Any, Union - -import pandas as pd - -import gretel_trainer.relational.ancestry as ancestry -import gretel_trainer.relational.strategies.common as common - -from gretel_client.projects.artifact_handlers import open_artifact -from gretel_trainer.relational.core import ( - GretelModelConfig, - MultiTableException, - RelationalData, -) -from gretel_trainer.relational.output_handler import OutputHandler - -logger = logging.getLogger(__name__) - - -class AncestralStrategy: - @property - def name(self) -> str: - return "ancestral" - - @property - def supported_model_keys(self) -> list[str]: - return ["amplify"] - - @property - def default_config(self) -> GretelModelConfig: - return "synthetics/amplify" - - def label_encode_keys( - self, rel_data: RelationalData, tables: dict[str, pd.DataFrame] - ) -> dict[str, pd.DataFrame]: - return common.label_encode_keys(rel_data, tables) - - def prepare_training_data( - self, rel_data: RelationalData, table_paths: dict[str, str] - ) -> dict[str, str]: - """ - Writes tables' training data to provided paths. - Training data has: - - all safe-for-seed ancestor fields added - - columns in multigenerational format - - all keys translated to contiguous integers - - artificial min/max seed records added - """ - all_tables = rel_data.list_all_tables() - omitted_tables = [t for t in all_tables if t not in table_paths] - altered_tableset = {} - - # Create a new table set identical to source data - for table_name in all_tables: - altered_tableset[table_name] = rel_data.get_table_data(table_name).copy() - - # Translate all keys to a contiguous list of integers - altered_tableset = common.label_encode_keys( - rel_data, altered_tableset, omit=omitted_tables - ) - - # Add artificial rows to support seeding - altered_tableset = _add_artifical_rows_for_seeding( - rel_data, altered_tableset, omitted_tables - ) - - # Collect all data in multigenerational format - for table, path in table_paths.items(): - data = ancestry.get_table_data_with_ancestors( - rel_data=rel_data, - table=table, - tableset=altered_tableset, - ancestral_seeding=True, - ) - with open_artifact(path, "wb") as dest: - data.to_csv(dest, index=False) - - return table_paths - - def tables_to_retrain( - self, tables: list[str], rel_data: RelationalData - ) -> list[str]: - """ - Given a set of tables requested to retrain, returns those tables with all their - descendants, because those descendant tables were trained with data from their - parents appended. - """ - retrain = set(tables) - for table in tables: - retrain.update(rel_data.get_descendants(table)) - return list(retrain) - - def validate_preserved_tables( - self, tables: list[str], rel_data: RelationalData - ) -> None: - """ - Ensures that for every table marked as preserved, all its ancestors are also preserved. - """ - for table in tables: - for parent in rel_data.get_parents(table): - if parent not in tables: - raise MultiTableException( - f"Cannot preserve table {table} without also preserving parent {parent}." - ) - - def get_preserved_data(self, table: str, rel_data: RelationalData) -> pd.DataFrame: - """ - Returns preserved source data in multigenerational format for synthetic children - to reference during generation post-processing. - """ - return ancestry.get_table_data_with_ancestors(rel_data, table) - - def ready_to_generate( - self, - rel_data: RelationalData, - in_progress: list[str], - finished: list[str], - ) -> list[str]: - """ - Tables with no parents are immediately ready for generation. - Tables with parents are ready once their parents are finished. - All tables are no longer considered ready once they are at least in progress. - """ - ready = [] - - for table in rel_data.list_all_tables(): - if table in in_progress or table in finished: - continue - - parents = rel_data.get_parents(table) - if len(parents) == 0: - ready.append(table) - elif all([parent in finished for parent in parents]): - ready.append(table) - - return ready - - def get_generation_job( - self, - table: str, - rel_data: RelationalData, - record_size_ratio: float, - output_tables: dict[str, pd.DataFrame], - subdir: str, - output_handler: OutputHandler, - ) -> dict[str, Any]: - """ - Returns kwargs for creating a record handler job via the Gretel SDK. - - If the table does not have any parents, job requests an output - record count based on the initial table data size and the record size ratio. - - If the table does have parents, builds a seed dataframe to use in generation. - """ - source_data_size = len(rel_data.get_table_data(table)) - synth_size = int(source_data_size * record_size_ratio) - if len(rel_data.get_parents(table)) == 0: - return {"params": {"num_records": synth_size}} - else: - seed_df = self._build_seed_data_for_table( - table, output_tables, rel_data, synth_size - ) - seed_path = output_handler.filepath_for( - f"synthetics_seed_{table}.csv", subdir=subdir - ) - with open_artifact(seed_path, "wb") as dest: - seed_df.to_csv(dest, index=False) - return {"data_source": str(seed_path)} - - def _build_seed_data_for_table( - self, - table: str, - output_tables: dict[str, pd.DataFrame], - rel_data: RelationalData, - synth_size: int, - ) -> pd.DataFrame: - column_legend = ancestry.get_seed_safe_multigenerational_columns(rel_data) - seed_df = pd.DataFrame() - - source_data = rel_data.get_table_data(table) - for fk in rel_data.get_foreign_keys(table): - parent_table_data = output_tables[fk.parent_table_name] - parent_table_data = ancestry.prepend_foreign_key_lineage( - parent_table_data, fk.columns - ) - - # Get FK frequencies - freqs = common.get_frequencies(source_data, fk.columns) - freqs = sorted(freqs, reverse=True) - f = 0 - - # Make a list of parent_table indicies matching FK frequencies - parent_indices = range(len(parent_table_data)) - p = 0 - parent_indices_to_use_as_fks = [] - while len(parent_indices_to_use_as_fks) < synth_size: - parent_index_to_use = parent_indices[p] - for _ in range(freqs[f]): - parent_indices_to_use_as_fks.append(parent_index_to_use) - p = _safe_inc(p, parent_indices) - f = _safe_inc(f, freqs) - - # Turn list into a DF and merge the parent table data - tmp_column_name = "tmp_parent_merge" - this_fk_seed_df = pd.DataFrame( - data={tmp_column_name: parent_indices_to_use_as_fks} - ) - this_fk_seed_df = this_fk_seed_df.merge( - parent_table_data, - how="left", - left_on=tmp_column_name, - right_index=True, - ) - - # Drop any columns that weren't used in training, as well as the temporary merge column - columns_to_drop = [ - col - for col in this_fk_seed_df.columns - if col not in column_legend[table] - ] - columns_to_drop.append(tmp_column_name) - this_fk_seed_df = this_fk_seed_df.drop(columns=columns_to_drop) - - seed_df = pd.concat( - [ - seed_df.reset_index(drop=True), - this_fk_seed_df.reset_index(drop=True), - ], - axis=1, - ) - - return seed_df - - def tables_to_skip_when_failed( - self, table: str, rel_data: RelationalData - ) -> list[str]: - return rel_data.get_descendants(table) - - def post_process_individual_synthetic_result( - self, - table_name: str, - rel_data: RelationalData, - synthetic_table: pd.DataFrame, - record_size_ratio: float, - ) -> pd.DataFrame: - """ - Replaces primary key values with a new, contiguous set of values. - Replaces synthesized foreign keys with seed primary keys. - """ - processed = synthetic_table - - multigenerational_primary_key = ancestry.get_multigenerational_primary_key( - rel_data, table_name - ) - - if len(multigenerational_primary_key) == 0: - pass - elif len(multigenerational_primary_key) == 1: - processed[multigenerational_primary_key[0]] = [ - i for i in range(len(synthetic_table)) - ] - else: - synthetic_pk_columns = common.make_composite_pks( - table_name=table_name, - rel_data=rel_data, - primary_key=multigenerational_primary_key, - synth_row_count=len(synthetic_table), - ) - - # make_composite_pks may not have created as many unique keys as we have - # synthetic rows, so we truncate the table to avoid inserting NaN PKs. - processed = pd.concat( - [ - pd.DataFrame.from_records(synthetic_pk_columns), - processed.drop(multigenerational_primary_key, axis="columns").head( - len(synthetic_pk_columns) - ), - ], - axis=1, - ) - - for fk_map in ancestry.get_ancestral_foreign_key_maps(rel_data, table_name): - fk_col, parent_pk_col = fk_map - processed[fk_col] = processed[parent_pk_col] - - return processed - - def post_process_synthetic_results( - self, - synth_tables: dict[str, pd.DataFrame], - preserved: list[str], - rel_data: RelationalData, - record_size_ratio: float, - ) -> dict[str, pd.DataFrame]: - """ - Restores tables from multigenerational to original shape - """ - return { - table_name: ancestry.drop_ancestral_data(df) - for table_name, df in synth_tables.items() - } - - -def _add_artifical_rows_for_seeding( - rel_data: RelationalData, tables: dict[str, pd.DataFrame], omitted: list[str] -) -> dict[str, pd.DataFrame]: - # On each table, add an artifical row with the max possible PK value - # unless the table is omitted from synthetics - max_pk_values = {} - for table_name, data in tables.items(): - if table_name in omitted: - continue - max_pk_values[table_name] = len(data) * 50 - - random_record = tables[table_name].sample().copy() - for pk_col in rel_data.get_primary_key(table_name): - random_record[pk_col] = max_pk_values[table_name] - tables[table_name] = pd.concat([data, random_record]).reset_index(drop=True) - - # On each table with foreign keys, add two more artificial rows containing the min and max FK values - for table_name, data in tables.items(): - foreign_keys = rel_data.get_foreign_keys(table_name) - if len(foreign_keys) == 0: - continue - - # Skip if the parent table is omitted and is the only parent - if len(foreign_keys) == 1 and foreign_keys[0].parent_table_name in omitted: - continue - - two_records = tables[table_name].sample(2) - min_fk_record = two_records.head(1).copy() - max_fk_record = two_records.tail(1).copy() - - # By default, just auto-increment the primary key - for pk_col in rel_data.get_primary_key(table_name): - min_fk_record[pk_col] = max_pk_values[table_name] + 1 - max_fk_record[pk_col] = max_pk_values[table_name] + 2 - - # This can potentially overwrite the auto-incremented primary keys above in the case of composite keys - for foreign_key in foreign_keys: - # Treat FK columns to omitted parents as normal columns - if foreign_key.parent_table_name in omitted: - continue - for fk_col in foreign_key.columns: - min_fk_record[fk_col] = 0 - max_fk_record[fk_col] = max_pk_values[foreign_key.parent_table_name] - - tables[table_name] = pd.concat( - [data, min_fk_record, max_fk_record] - ).reset_index(drop=True) - - return tables - - -def _safe_inc(i: int, col: Union[list, range]) -> int: - i = i + 1 - if i == len(col): - i = 0 - return i diff --git a/src/gretel_trainer/relational/strategies/common.py b/src/gretel_trainer/relational/strategies/common.py deleted file mode 100644 index 0253e03c..00000000 --- a/src/gretel_trainer/relational/strategies/common.py +++ /dev/null @@ -1,207 +0,0 @@ -import logging -import random - -from dataclasses import dataclass -from typing import Optional - -import pandas as pd - -from sklearn import preprocessing - -from gretel_trainer.relational.core import MultiTableException, RelationalData - -logger = logging.getLogger(__name__) - - -def label_encode_keys( - rel_data: RelationalData, - tables: dict[str, pd.DataFrame], - omit: Optional[list[str]] = None, -) -> dict[str, pd.DataFrame]: - """ - Crawls tables for all key columns (primary and foreign). For each PK (and FK columns referencing it), - runs all values through a LabelEncoder and updates tables' columns to use LE-transformed values. - """ - omit = omit or [] - for table_name in rel_data.list_tables_parents_before_children(): - if table_name in omit: - continue - - df = tables.get(table_name) - if df is None: - continue - - for primary_key_column in rel_data.get_primary_key(table_name): - # Get a set of the tables and columns in `tables` referencing this PK - fk_references: set[tuple[str, str]] = set() - for descendant in rel_data.get_descendants(table_name): - if tables.get(descendant) is None: - continue - fks = rel_data.get_foreign_keys(descendant) - for fk in fks: - if fk.parent_table_name != table_name: - continue - - for i in range(len(fk.columns)): - if fk.parent_columns[i] == primary_key_column: - fk_references.add((descendant, fk.columns[i])) - - # Collect column values from PK and FK columns into a set - source_values = set() - source_values.update(df[primary_key_column].to_list()) - for fk_ref in fk_references: - fk_tbl, fk_col = fk_ref - fk_df = tables.get(fk_tbl) - if fk_df is None: - continue - source_values.update(fk_df[fk_col].to_list()) - - # Fit a label encoder on all values - le = preprocessing.LabelEncoder() - le.fit(list(source_values)) - - # Update PK and FK columns using the label encoder - df[primary_key_column] = le.transform(df[primary_key_column]) - - for fk_ref in fk_references: - fk_tbl, fk_col = fk_ref - fk_df = tables.get(fk_tbl) - if fk_df is None: - continue - fk_df[fk_col] = le.transform(fk_df[fk_col]) - - return tables - - -def make_composite_pks( - table_name: str, - rel_data: RelationalData, - primary_key: list[str], - synth_row_count: int, -) -> list[dict]: - # Given the randomness involved in this process, it is possible for this function to generate - # fewer unique composite keys than desired to completely fill the dataframe (i.e. the length - # of the tuple values in the dictionary may be < synth_row_count). It is the client's - # responsibility to check for this and drop synthetic records if necessary to fit. - table_data = rel_data.get_table_data(table_name) - original_primary_key = rel_data.get_primary_key(table_name) - - pk_component_frequencies = { - col: get_frequencies(table_data, [col]) for col in original_primary_key - } - - # each key in new_cols is a column name, and each value is a list of - # column values. The values are a contiguous list of integers, with - # each integer value appearing 1-N times to match the frequencies of - # (original source) values' appearances in the source data. - new_cols: dict[str, list] = {} - for i, col in enumerate(primary_key): - freqs = pk_component_frequencies[original_primary_key[i]] - next_freq = 0 - next_key = 0 - new_col_values = [] - - while len(new_col_values) < synth_row_count: - for i in range(freqs[next_freq]): - new_col_values.append(next_key) - next_key += 1 - next_freq += 1 - if next_freq == len(freqs): - next_freq = 0 - - # A large frequency may have added more values than we need, - # so trim to synth_row_count - new_cols[col] = new_col_values[0:synth_row_count] - - # Shuffle for realism - for col_name, col_values in new_cols.items(): - random.shuffle(col_values) - - # Zip the individual columns into a list of records. - # Each element in the list is a composite key dict. - composite_keys: list[dict] = [] - for i in range(synth_row_count): - comp_key = {} - for col_name, col_values in new_cols.items(): - comp_key[col_name] = col_values[i] - composite_keys.append(comp_key) - - # The zip above may not have produced unique composite key dicts. - # Using the most unique column (to give us the most options), try - # changing a value to "resolve" candidate composite keys to unique combinations. - cant_resolve = 0 - seen: set[str] = set() - final_synthetic_composite_keys: list[dict] = [] - most_unique_column = _get_most_unique_column(primary_key, pk_component_frequencies) - - for i in range(synth_row_count): - y = i + 1 - if y == len(composite_keys): - y = 0 - - comp_key = composite_keys[i] - - while str(comp_key) in seen and y != i: - last_val = new_cols[most_unique_column][y] - y += 1 - if y == len(composite_keys): - y = 0 - comp_key[most_unique_column] = last_val - if str(comp_key) in seen: - cant_resolve += 1 - else: - final_synthetic_composite_keys.append(comp_key) - seen.add(str(comp_key)) - - return final_synthetic_composite_keys - - -def _get_most_unique_column(pk: list[str], col_freqs: dict[str, list]) -> str: - most_unique = None - max_length = 0 - for col, freqs in col_freqs.items(): - if len(freqs) > max_length: - most_unique = col - - if most_unique is None: - raise MultiTableException( - f"Failed to identify most unique column from column frequencies: {col_freqs}" - ) - - # The keys in col_freqs are always the source column names from the original primary key. - # Meanwhile, `pk` could be either the same (independent strategy) or in multigenerational - # format (ancestral strategy). We need to return the column name in the format matching - # the rest of the synthetic data undergoing post-processing. - idx = list(col_freqs.keys()).index(most_unique) - return pk[idx] - - -def get_frequencies(table_data: pd.DataFrame, cols: list[str]) -> list[int]: - return list(table_data.groupby(cols).size().reset_index()[0]) - - -# Frequency metadata for a list of columns (typically a foreign key). -# -# Example: pd.DataFrame(data={ -# "col_1": ["a", "a", "b"], -# "col_2": [100, 100, None], -# }) -# -# null_percentages: [0.0, 0.33333] -# (col_1 has no null values; col_2 is 1/3 null) -# not_null_frequencies: [2] -# (["a", 100] occurs twice; there are no other non-null values) -@dataclass -class FrequencyData: - null_percentages: list[float] - not_null_frequencies: list[int] - - @classmethod - def for_columns(cls, table_data: pd.DataFrame, cols: list[str]): - null_percentages = (table_data[cols].isnull().sum() / len(table_data)).tolist() - not_null_frequencies = table_data.groupby(cols).size().tolist() - - return FrequencyData( - null_percentages=null_percentages, - not_null_frequencies=not_null_frequencies, - ) diff --git a/src/gretel_trainer/relational/strategies/independent.py b/src/gretel_trainer/relational/strategies/independent.py deleted file mode 100644 index 75d8658c..00000000 --- a/src/gretel_trainer/relational/strategies/independent.py +++ /dev/null @@ -1,335 +0,0 @@ -import logging -import random - -from typing import Any - -import pandas as pd - -import gretel_trainer.relational.strategies.common as common - -from gretel_client.projects.artifact_handlers import open_artifact -from gretel_trainer.relational.core import ForeignKey, GretelModelConfig, RelationalData -from gretel_trainer.relational.output_handler import OutputHandler - -logger = logging.getLogger(__name__) - - -class IndependentStrategy: - @property - def name(self) -> str: - return "independent" - - @property - def supported_model_keys(self) -> list[str]: - return ["amplify", "actgan", "synthetics", "tabular_dp"] - - @property - def default_config(self) -> GretelModelConfig: - return "synthetics/tabular-actgan" - - def label_encode_keys( - self, rel_data: RelationalData, tables: dict[str, pd.DataFrame] - ) -> dict[str, pd.DataFrame]: - return common.label_encode_keys(rel_data, tables) - - def prepare_training_data( - self, rel_data: RelationalData, table_paths: dict[str, str] - ) -> dict[str, str]: - """ - Writes tables' training data to provided paths. - Training data has primary and foreign key columns removed. - """ - prepared_tables = {} - - for table, path in table_paths.items(): - columns_to_drop = set() - columns_to_drop.update(rel_data.get_primary_key(table)) - for foreign_key in rel_data.get_foreign_keys(table): - columns_to_drop.update(foreign_key.columns) - - all_columns = rel_data.get_table_columns(table) - use_columns = [col for col in all_columns if col not in columns_to_drop] - - # It's possible for *all columns* on a table to be part of a PK or FK, - # leaving no columns to send to a model for training. We omit such tables - # from the returned dictionary, indicating to MultiTable that it should - # "bypass" training and running a model for that table and instead leave - # it alone until post-processing (synthesizing key columns). - if len(use_columns) == 0: - logger.info( - f"All columns in table `{table}` are associated with key constraints" - ) - continue - - source_path = rel_data.get_table_source(table) - with ( - open_artifact(source_path, "rb") as src, - open_artifact(path, "wb") as dest, - ): - pd.DataFrame(columns=use_columns).to_csv(dest, index=False) - for chunk in pd.read_csv(src, usecols=use_columns, chunksize=10_000): - chunk.to_csv(dest, index=False, mode="a", header=False) - prepared_tables[table] = path - - return prepared_tables - - def tables_to_retrain( - self, tables: list[str], rel_data: RelationalData - ) -> list[str]: - """ - Returns the provided tables requested to retrain, unaltered. - """ - return tables - - def validate_preserved_tables( - self, tables: list[str], rel_data: RelationalData - ) -> None: - """ - No-op. Under this strategy, any collection of tables can be preserved. - """ - pass - - def get_preserved_data(self, table: str, rel_data: RelationalData) -> pd.DataFrame: - """ - Returns preserved source data for synthetic children - to reference during generation post-processing. - """ - return rel_data.get_table_data(table) - - def ready_to_generate( - self, - rel_data: RelationalData, - in_progress: list[str], - finished: list[str], - ) -> list[str]: - """ - All tables are immediately ready for generation. Once they are - at least in progress, they are no longer ready. - """ - return [ - table - for table in rel_data.list_all_tables() - if table not in in_progress and table not in finished - ] - - def get_generation_job( - self, - table: str, - rel_data: RelationalData, - record_size_ratio: float, - output_tables: dict[str, pd.DataFrame], - subdir: str, - output_handler: OutputHandler, - ) -> dict[str, Any]: - """ - Returns kwargs for a record handler job requesting an output record - count based on the initial table data size and the record size ratio. - """ - source_data_size = len(rel_data.get_table_data(table)) - synth_size = int(source_data_size * record_size_ratio) - return {"params": {"num_records": synth_size}} - - def tables_to_skip_when_failed( - self, table: str, rel_data: RelationalData - ) -> list[str]: - return [] - - def post_process_individual_synthetic_result( - self, - table_name: str, - rel_data: RelationalData, - synthetic_table: pd.DataFrame, - record_size_ratio: float, - ) -> pd.DataFrame: - """ - No-op. This strategy does not apply any changes to individual table results upon record handler completion. - All post-processing is performed on the output tables collectively when they are all finished. - """ - return synthetic_table - - def post_process_synthetic_results( - self, - synth_tables: dict[str, pd.DataFrame], - preserved: list[str], - rel_data: RelationalData, - record_size_ratio: float, - ) -> dict[str, pd.DataFrame]: - "Synthesizes primary and foreign keys" - synth_tables = _synthesize_primary_keys( - synth_tables, preserved, rel_data, record_size_ratio - ) - synth_tables = _synthesize_foreign_keys(synth_tables, rel_data) - return synth_tables - - -def _synthesize_primary_keys( - synth_tables: dict[str, pd.DataFrame], - preserved: list[str], - rel_data: RelationalData, - record_size_ratio: float, -) -> dict[str, pd.DataFrame]: - """ - Alters primary key columns on all tables *except* preserved. - Assumes the primary key column is of type integer. - """ - processed = {} - for table_name, synth_data in synth_tables.items(): - processed[table_name] = synth_data.copy() - if table_name in preserved: - continue - - primary_key = rel_data.get_primary_key(table_name) - synth_row_count = len(synth_data) - - if len(primary_key) == 0: - continue - elif len(primary_key) == 1: - processed[table_name][primary_key[0]] = [i for i in range(synth_row_count)] - else: - synthetic_pk_columns = common.make_composite_pks( - table_name=table_name, - rel_data=rel_data, - primary_key=primary_key, - synth_row_count=synth_row_count, - ) - - # make_composite_pks may not have created as many unique keys as we have - # synthetic rows, so we truncate the table to avoid inserting NaN PKs. - processed[table_name] = pd.concat( - [ - processed[table_name].head(len(synthetic_pk_columns)), - pd.DataFrame.from_records(synthetic_pk_columns), - ], - axis=1, - ) - - return processed - - -def _synthesize_foreign_keys( - synth_tables: dict[str, pd.DataFrame], rel_data: RelationalData -) -> dict[str, pd.DataFrame]: - """ - Alters foreign key columns on all tables (*including* those flagged as not to - be synthesized to ensure joining to a synthesized parent table continues to work) - by replacing foreign key column values with valid values from the parent table column - being referenced. - """ - processed = {} - for table_name in rel_data.list_tables_parents_before_children(): - out_df = synth_tables.get(table_name) - if out_df is None: - continue - for foreign_key in rel_data.get_foreign_keys(table_name): - # We pull the parent from `processed` instead of from `synth_tables` because "this" table - # (`table_name` / `out_df`) may have a FK pointing to a parent column that _is itself_ a FK to some third table. - # We want to ensure the synthetic values we're using to populate "this" table's FK column are - # the final output values we've produced for its parent table. - # We are synthesizing foreign keys in parent->child order, so the parent table - # should(*) already exist in the processed dict with its final synthetic values... - parent_synth_table = processed.get(foreign_key.parent_table_name) - if parent_synth_table is None: - # (*)...BUT the parent table generation job may have failed and therefore not be present in either `processed` or `synth_tables`. - # The synthetic data for "this" table may still be useful, but we do not have valid/any synthetic - # values from the parent to set in "this" table's foreign key column. Instead of introducing dangling - # pointers, set the entire column to None. - synth_parent_values = [None] * len(foreign_key.parent_columns) - else: - synth_parent_values = parent_synth_table[ - foreign_key.parent_columns - ].values.tolist() - - original_table_data = rel_data.get_table_data(table_name) - fk_frequency_data = common.FrequencyData.for_columns( - original_table_data, foreign_key.columns - ) - - new_fk_values = _collect_fk_values( - synth_parent_values, fk_frequency_data, len(out_df), foreign_key - ) - - out_df[foreign_key.columns] = new_fk_values - - processed[table_name] = out_df - - return processed - - -def _collect_fk_values( - values: list, - freq_data: common.FrequencyData, - total: int, - foreign_key: ForeignKey, -) -> list: - # Support for and restrictions on null values in composite foreign keys varies - # across database dialects. The simplest and safest thing to do here is - # exclusively produce composite foreign key values that contain no NULLs. - if foreign_key.is_composite(): - return _collect_values(values, freq_data.not_null_frequencies, total, []) - - # Here, the foreign key is a single column. Start by adding an appropriate - # amount of NULL values. - num_nulls = round(freq_data.null_percentages[0] * total) - new_values = [(None,)] * num_nulls - - # Dedupe repeated values and discard None if present. - # 1. Duplicates should not exist, because foreign keys are required to reference - # columns with unique values. If the referred column is a PRIMARY KEY, there - # *will not* be duplicates given how we synthesize primary keys. However, the - # foreign key could be referencing a non-PK column with a UNIQUE constraint; - # in that case, our model will not have known about the UNIQUE consraint and - # may have produced non-unique values. - # 2. Nulls are already accounted for above; we don't want to synthesize more. - def _unique_not_null_values(values: list) -> list: - unique_values = {tuple(v) for v in values} - unique_values.discard((None,)) - vals = list(unique_values) - random.shuffle(vals) - return vals - - # Collect final output values by adding non-null values to `new_values` - # (which has the requisite number of nulls already). - return _collect_values( - _unique_not_null_values(values), - freq_data.not_null_frequencies, - total, - new_values, - ) - - -def _collect_values( - values: list, - frequencies: list[int], - total: int, - new_values: list, -) -> list: - freqs = sorted(frequencies) - - # Loop through frequencies in ascending order, - # adding "that many" of the next valid value - # to the output collection - v = 0 - f = 0 - while len(new_values) < total: - fk_value = values[v] - - for _ in range(freqs[f]): - new_values.append(fk_value) - - v = _safe_inc(v, values) - f = _safe_inc(f, freqs) - - # trim potential excess - new_values = new_values[0:total] - - # shuffle for realism - random.shuffle(new_values) - - return new_values - - -def _safe_inc(i: int, col: list) -> int: - i = i + 1 - if i == len(col): - i = 0 - return i diff --git a/src/gretel_trainer/relational/table_evaluation.py b/src/gretel_trainer/relational/table_evaluation.py deleted file mode 100644 index 1d4b189a..00000000 --- a/src/gretel_trainer/relational/table_evaluation.py +++ /dev/null @@ -1,99 +0,0 @@ -import json - -from dataclasses import dataclass, field -from typing import Literal, Optional, overload, Union - -_SQS = "synthetic_data_quality_score" -_PPL = "privacy_protection_level" -_SCORE = "score" -_GRADE = "grade" - - -@dataclass -class TableEvaluation: - cross_table_report_json: Optional[dict] = field(default=None, repr=False) - individual_report_json: Optional[dict] = field(default=None, repr=False) - - def is_complete(self) -> bool: - return ( - self.cross_table_report_json is not None - and self.cross_table_sqs is not None - and self.individual_report_json is not None - and self.individual_sqs is not None - ) - - @overload - def _field_from_json( - self, report_json: Optional[dict], entry: str, field: Literal["score"] - ) -> Optional[int]: ... - - @overload - def _field_from_json( - self, report_json: Optional[dict], entry: str, field: Literal["grade"] - ) -> Optional[str]: ... - - def _field_from_json( - self, report_json: Optional[dict], entry: str, field: str - ) -> Optional[Union[int, str]]: - if report_json is None: - return None - else: - return report_json.get(entry, {}).get(field) - - @property - def cross_table_sqs(self) -> Optional[int]: - return self._field_from_json(self.cross_table_report_json, _SQS, _SCORE) - - @property - def cross_table_sqs_grade(self) -> Optional[str]: - return self._field_from_json(self.cross_table_report_json, _SQS, _GRADE) - - @property - def cross_table_ppl(self) -> Optional[int]: - return self._field_from_json(self.cross_table_report_json, _PPL, _SCORE) - - @property - def cross_table_ppl_grade(self) -> Optional[str]: - return self._field_from_json(self.cross_table_report_json, _PPL, _GRADE) - - @property - def individual_sqs(self) -> Optional[int]: - return self._field_from_json(self.individual_report_json, _SQS, _SCORE) - - @property - def individual_sqs_grade(self) -> Optional[str]: - return self._field_from_json(self.individual_report_json, _SQS, _GRADE) - - @property - def individual_ppl(self) -> Optional[int]: - return self._field_from_json(self.individual_report_json, _PPL, _SCORE) - - @property - def individual_ppl_grade(self) -> Optional[str]: - return self._field_from_json(self.individual_report_json, _PPL, _GRADE) - - def __repr__(self) -> str: - d = {} - if self.cross_table_report_json is not None: - d["cross_table"] = { - "sqs": { - "score": self.cross_table_sqs, - "grade": self.cross_table_sqs_grade, - }, - "ppl": { - "score": self.cross_table_ppl, - "grade": self.cross_table_ppl_grade, - }, - } - if self.individual_report_json is not None: - d["individual"] = { - "sqs": { - "score": self.individual_sqs, - "grade": self.individual_sqs_grade, - }, - "ppl": { - "score": self.individual_ppl, - "grade": self.individual_ppl_grade, - }, - } - return json.dumps(d) diff --git a/src/gretel_trainer/relational/task_runner.py b/src/gretel_trainer/relational/task_runner.py deleted file mode 100644 index b791a929..00000000 --- a/src/gretel_trainer/relational/task_runner.py +++ /dev/null @@ -1,101 +0,0 @@ -import logging - -from collections import defaultdict -from dataclasses import dataclass -from typing import Callable, Protocol - -import gretel_trainer.relational.tasks.common as common - -from gretel_client.projects.jobs import END_STATES, Job, Status -from gretel_client.projects.projects import Project -from gretel_trainer.relational.sdk_extras import ExtendedGretelSDK - -MAX_REFRESH_ATTEMPTS = 3 - -logger = logging.getLogger(__name__) - - -@dataclass -class TaskContext: - in_flight_jobs: int - refresh_interval: int - project: Project - extended_sdk: ExtendedGretelSDK - backup: Callable[[], None] - - def maybe_start_job(self, job: Job, table_name: str, action: str) -> None: - self.in_flight_jobs += self.extended_sdk.start_job_if_possible( - job=job, - table_name=table_name, - action=action, - project=self.project, - in_flight_jobs=self.in_flight_jobs, - ) - - -class Task(Protocol): - @property - def ctx(self) -> TaskContext: ... - - def action(self, job: Job) -> str: ... - - @property - def table_collection(self) -> list[str]: ... - - def more_to_do(self) -> bool: ... - - def is_finished(self, table: str) -> bool: ... - - def get_job(self, table: str) -> Job: ... - - def handle_completed(self, table: str, job: Job) -> None: ... - - def handle_failed(self, table: str, job: Job) -> None: ... - - def handle_in_progress(self, table: str, job: Job) -> None: ... - - def handle_lost_contact(self, table: str, job: Job) -> None: ... - - def each_iteration(self) -> None: ... - - -def run_task(task: Task, extended_sdk: ExtendedGretelSDK) -> None: - refresh_attempts: dict[str, int] = defaultdict(int) - first_pass = True - - while task.more_to_do(): - if first_pass: - first_pass = False - else: - common.wait(task.ctx.refresh_interval) - - for table_name in task.table_collection: - if task.is_finished(table_name): - continue - - job = task.get_job(table_name) - if extended_sdk.get_job_id(job) is None: - task.ctx.maybe_start_job( - job=job, table_name=table_name, action=task.action(job) - ) - continue - - status = extended_sdk.cautiously_refresh_status( - job, table_name, refresh_attempts - ) - - if refresh_attempts[table_name] >= MAX_REFRESH_ATTEMPTS: - task.ctx.in_flight_jobs -= 1 - task.handle_lost_contact(table_name, job) - continue - - if status == Status.COMPLETED: - task.ctx.in_flight_jobs -= 1 - task.handle_completed(table_name, job) - elif status in END_STATES: - task.ctx.in_flight_jobs -= 1 - task.handle_failed(table_name, job) - else: - task.handle_in_progress(table_name, job) - - task.each_iteration() diff --git a/src/gretel_trainer/relational/tasks/classify.py b/src/gretel_trainer/relational/tasks/classify.py deleted file mode 100644 index 9a473750..00000000 --- a/src/gretel_trainer/relational/tasks/classify.py +++ /dev/null @@ -1,137 +0,0 @@ -import shutil - -import gretel_trainer.relational.tasks.common as common - -from gretel_client.projects.artifact_handlers import open_artifact -from gretel_client.projects.jobs import Job -from gretel_client.projects.models import Model -from gretel_client.projects.records import RecordHandler -from gretel_trainer.relational.output_handler import OutputHandler -from gretel_trainer.relational.task_runner import TaskContext -from gretel_trainer.relational.workflow_state import Classify - - -class ClassifyTask: - def __init__( - self, - classify: Classify, - data_sources: dict[str, str], - all_rows: bool, - ctx: TaskContext, - output_handler: OutputHandler, - ): - self.classify = classify - self.data_sources = data_sources - self.all_rows = all_rows - self.ctx = ctx - self.output_handler = output_handler - self.classify_record_handlers: dict[str, RecordHandler] = {} - self.completed_models = [] - self.failed_models = [] - self.completed_record_handlers = [] - self.failed_record_handlers = [] - self.result_filepaths: dict[str, str] = {} - - def action(self, job: Job) -> str: - if self.all_rows: - if isinstance(job, Model): - return "classify training" - else: - return "classification (all rows)" - else: - return "classification" - - @property - def table_collection(self) -> list[str]: - return list(self.classify.models.keys()) - - def more_to_do(self) -> bool: - total_tables = len(self.classify.models) - any_unfinished_models = len(self._finished_models) < total_tables - any_unfinished_record_handlers = ( - len(self._finished_record_handlers) < total_tables - ) - - if self.all_rows: - return any_unfinished_models or any_unfinished_record_handlers - else: - return any_unfinished_models - - @property - def _finished_models(self) -> list[str]: - return self.completed_models + self.failed_models - - @property - def _finished_record_handlers(self) -> list[str]: - return self.completed_record_handlers + self.failed_record_handlers - - def is_finished(self, table: str) -> bool: - if self.all_rows: - return ( - table in self._finished_models - and table in self._finished_record_handlers - ) - else: - return table in self._finished_models - - def get_job(self, table: str) -> Job: - record_handler = self.classify_record_handlers.get(table) - model = self.classify.models.get(table) - return record_handler or model - - def handle_completed(self, table: str, job: Job) -> None: - if isinstance(job, Model): - self.completed_models.append(table) - common.log_success(table, self.action(job)) - if self.all_rows: - record_handler = job.create_record_handler_obj( - data_source=self.data_sources[table] - ) - self.classify_record_handlers[table] = record_handler - self.ctx.maybe_start_job( - job=record_handler, table_name=table, action=self.action(job) - ) - elif isinstance(job, RecordHandler): - self.completed_record_handlers.append(table) - common.log_success(table, self.action(job)) - self._write_results(job=job, table=table) - common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job) - - def handle_failed(self, table: str, job: Job) -> None: - if isinstance(job, Model): - self.failed_models.append(table) - elif isinstance(job, RecordHandler): - self.failed_record_handlers.append(table) - common.log_failed(table, self.action(job)) - common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job) - - def handle_lost_contact(self, table: str, job: Job) -> None: - if isinstance(job, Model): - self.failed_models.append(table) - elif isinstance(job, RecordHandler): - self.failed_record_handlers.append(table) - common.log_lost_contact(table) - common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job) - - def handle_in_progress(self, table: str, job: Job) -> None: - common.log_in_progress(table, job.status, self.action(job)) - - def each_iteration(self) -> None: - self.ctx.backup() - - def _write_results(self, job: Job, table: str) -> None: - if isinstance(job, Model): - filename = f"classify_{table}.gz" - artifact_name = "data_preview" - else: - filename = f"classify_all_rows_{table}.gz" - artifact_name = "data" - - destpath = self.output_handler.filepath_for(filename) - - with ( - job.get_artifact_handle(artifact_name) as src, - open_artifact(str(destpath), "wb") as dest, - ): - shutil.copyfileobj(src, dest) - self.result_filepaths[table] = destpath diff --git a/src/gretel_trainer/relational/tasks/common.py b/src/gretel_trainer/relational/tasks/common.py deleted file mode 100644 index 4dc61007..00000000 --- a/src/gretel_trainer/relational/tasks/common.py +++ /dev/null @@ -1,35 +0,0 @@ -import logging -import time - -from gretel_client.projects.jobs import Job, Status -from gretel_client.projects.projects import Project -from gretel_trainer.relational.sdk_extras import ExtendedGretelSDK - -logger = logging.getLogger(__name__) - - -def wait(seconds: int) -> None: - logger.info(f"Next status check in {seconds} seconds.") - time.sleep(seconds) - - -def log_in_progress(table_name: str, status: Status, action: str) -> None: - logger.info( - f"{action.capitalize()} job for `{table_name}` still in progress (status: {status})." - ) - - -def log_success(table_name: str, action: str) -> None: - logger.info(f"{action.capitalize()} successfully completed for `{table_name}`.") - - -def log_failed(table_name: str, action: str) -> None: - logger.info(f"{action.capitalize()} failed for `{table_name}`.") - - -def log_lost_contact(table_name: str) -> None: - logger.warning(f"Lost contact with job for `{table_name}`.") - - -def cleanup(sdk: ExtendedGretelSDK, project: Project, job: Job) -> None: - sdk.delete_data_sources(project, job) diff --git a/src/gretel_trainer/relational/tasks/synthetics_evaluate.py b/src/gretel_trainer/relational/tasks/synthetics_evaluate.py deleted file mode 100644 index 8f30b376..00000000 --- a/src/gretel_trainer/relational/tasks/synthetics_evaluate.py +++ /dev/null @@ -1,148 +0,0 @@ -import json -import logging - -from collections import defaultdict -from typing import Optional - -import gretel_trainer.relational.tasks.common as common - -from gretel_client.projects.artifact_handlers import open_artifact -from gretel_client.projects.jobs import Job -from gretel_client.projects.models import Model -from gretel_trainer.relational.output_handler import OutputHandler -from gretel_trainer.relational.table_evaluation import TableEvaluation -from gretel_trainer.relational.task_runner import TaskContext - -logger = logging.getLogger(__name__) - -ACTION = "synthetic data evaluation" - - -class SyntheticsEvaluateTask: - def __init__( - self, - individual_evaluate_models: dict[str, Model], - cross_table_evaluate_models: dict[str, Model], - subdir: str, - output_handler: OutputHandler, - evaluations: dict[str, TableEvaluation], - ctx: TaskContext, - ): - self.jobs = {} - for table, model in individual_evaluate_models.items(): - self.jobs[f"individual-{table}"] = model - for table, model in cross_table_evaluate_models.items(): - self.jobs[f"cross_table-{table}"] = model - self.subdir = subdir - self.output_handler = output_handler - self.evaluations = evaluations - self.completed = [] - self.failed = [] - self.ctx = ctx - # Nested dict organizing by table > sqs_type > file_type, e.g. - # { - # "users": { - # "individual": { - # "json": "/path/to/report.json", - # "html": "/path/to/report.html", - # }, - # "cross_table": { - # "json": "/path/to/report.json", - # "html": "/path/to/report.html", - # }, - # }, - # } - self.report_filepaths: dict[str, dict[str, dict[str, str]]] = defaultdict( - lambda: defaultdict(dict) - ) - - def action(self, job: Job) -> str: - return ACTION - - @property - def table_collection(self) -> list[str]: - return list(self.jobs.keys()) - - def more_to_do(self) -> bool: - return len(self.completed + self.failed) < len(self.jobs) - - def is_finished(self, table: str) -> bool: - return table in (self.completed + self.failed) - - def get_job(self, table: str) -> Job: - return self.jobs[table] - - def handle_completed(self, table: str, job: Job) -> None: - self.completed.append(table) - common.log_success(table, self.action(job)) - - model = self.get_job(table) - sqs_type, table_name = table.split("-", 1) - - filename_stem = _filename_stem(sqs_type, table_name) - - # JSON - json_filepath = self.output_handler.filepath_for( - f"{filename_stem}.json", subdir=self.subdir - ) - json_ok = self.ctx.extended_sdk.download_file_artifact( - model, "report_json", json_filepath - ) - if json_ok: - self.report_filepaths[table_name][sqs_type]["json"] = json_filepath - # Set json data on local evaluations object for use in report - json_data = _read_json_report(model, json_filepath) - if sqs_type == "individual": - self.evaluations[table_name].individual_report_json = json_data - else: - self.evaluations[table_name].cross_table_report_json = json_data - - # HTML - html_filepath = self.output_handler.filepath_for( - f"{filename_stem}.html", subdir=self.subdir - ) - html_ok = self.ctx.extended_sdk.download_file_artifact( - model, "report", html_filepath - ) - if html_ok: - self.report_filepaths[table_name][sqs_type]["html"] = html_filepath - - common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job) - - def handle_failed(self, table: str, job: Job) -> None: - self.failed.append(table) - common.log_failed(table, self.action(job)) - common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job) - - def handle_lost_contact(self, table: str, job: Job) -> None: - self.failed.append(table) - common.log_lost_contact(table) - common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job) - - def handle_in_progress(self, table: str, job: Job) -> None: - common.log_in_progress(table, job.status, self.action(job)) - - def each_iteration(self) -> None: - pass - - -def _read_json_report(model: Model, json_report_filepath: str) -> Optional[dict]: - """ - Reads the JSON report data in to a dictionary to be appended to the MultiTable - evaluations property. First try reading the file we just downloaded to the run - directory. If that fails, try reading the data remotely from the model. If that - also fails, log a warning and give up gracefully. - """ - try: - return json.loads(open_artifact(json_report_filepath).read()) - except: - try: - with model.get_artifact_handle("report_json") as report: - return json.loads(report.read()) - except: - logger.warning("Failed to fetch model evaluation report JSON.") - return None - - -def _filename_stem(sqs_type: str, table_name: str) -> str: - return f"synthetics_{sqs_type}_evaluation_{table_name}" diff --git a/src/gretel_trainer/relational/tasks/synthetics_run.py b/src/gretel_trainer/relational/tasks/synthetics_run.py deleted file mode 100644 index 42de5897..00000000 --- a/src/gretel_trainer/relational/tasks/synthetics_run.py +++ /dev/null @@ -1,199 +0,0 @@ -import logging - -from typing import Optional, Union - -import pandas as pd - -import gretel_trainer.relational.tasks.common as common - -from gretel_client.projects.jobs import ACTIVE_STATES, Job, Status -from gretel_client.projects.records import RecordHandler -from gretel_trainer.relational.core import RelationalData -from gretel_trainer.relational.output_handler import OutputHandler -from gretel_trainer.relational.strategies.ancestral import AncestralStrategy -from gretel_trainer.relational.strategies.independent import IndependentStrategy -from gretel_trainer.relational.task_runner import TaskContext -from gretel_trainer.relational.workflow_state import SyntheticsRun, SyntheticsTrain - -logger = logging.getLogger(__name__) - -ACTION = "synthetic data generation" - - -class SyntheticsRunTask: - def __init__( - self, - synthetics_run: SyntheticsRun, - synthetics_train: SyntheticsTrain, - subdir: str, - output_handler: OutputHandler, - rel_data: RelationalData, - strategy: Union[IndependentStrategy, AncestralStrategy], - ctx: TaskContext, - ): - self.synthetics_run = synthetics_run - self.synthetics_train = synthetics_train - self.subdir = subdir - self.output_handler = output_handler - self.rel_data = rel_data - self.strategy = strategy - self.ctx = ctx - self.working_tables = self._setup_working_tables() - - def _setup_working_tables(self) -> dict[str, Optional[pd.DataFrame]]: - working_tables = {} - all_tables = self.rel_data.list_all_tables() - - for table in all_tables: - if table in self.synthetics_train.bypass: - source_row_count = self.rel_data.get_table_row_count(table) - out_row_count = int( - source_row_count * self.synthetics_run.record_size_ratio - ) - working_tables[table] = pd.DataFrame(index=range(out_row_count)) - continue - - model = self.synthetics_train.models.get(table) - - # Table was either omitted from training or marked as to-be-preserved during generation - if model is None or table in self.synthetics_run.preserved: - working_tables[table] = self.strategy.get_preserved_data( - table, self.rel_data - ) - continue - - # Table was included in training, but failed at that step - if model.status != Status.COMPLETED: - working_tables[table] = None - continue - - return working_tables - - @property - def output_tables(self) -> dict[str, pd.DataFrame]: - return { - table: data - for table, data in self.working_tables.items() - if data is not None - } - - def action(self, job: Job) -> str: - return ACTION - - @property - def table_collection(self) -> list[str]: - return list(self.synthetics_run.record_handlers.keys()) - - def more_to_do(self) -> bool: - return len(self.working_tables) < len(self._all_tables) - - def is_finished(self, table: str) -> bool: - return table in self.working_tables - - def get_job(self, table: str) -> Job: - return self.synthetics_run.record_handlers[table] - - def handle_completed(self, table: str, job: Job) -> None: - record_handler_data = self.ctx.extended_sdk.get_record_handler_data(job) - post_processed_data = self.strategy.post_process_individual_synthetic_result( - table_name=table, - rel_data=self.rel_data, - synthetic_table=record_handler_data, - record_size_ratio=self.synthetics_run.record_size_ratio, - ) - self.working_tables[table] = post_processed_data - common.log_success(table, self.action(job)) - common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job) - - def handle_failed(self, table: str, job: Job) -> None: - self.working_tables[table] = None - self._fail_table(table) - common.log_failed(table, self.action(job)) - common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job) - self.ctx.backup() - - def handle_lost_contact(self, table: str, job: Job) -> None: - self.synthetics_run.lost_contact.append(table) - self._fail_table(table) - common.log_lost_contact(table) - common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job) - self.ctx.backup() - - def handle_in_progress(self, table: str, job: Job) -> None: - common.log_in_progress(table, job.status, self.action(job)) - - def each_iteration(self) -> None: - # Determine if we can start any more jobs - in_progress_tables = [ - table - for table in self._all_tables - if _table_is_in_progress(self.synthetics_run.record_handlers, table) - ] - finished_tables = [table for table in self.working_tables] - - ready_tables = self.strategy.ready_to_generate( - self.rel_data, in_progress_tables, finished_tables - ) - - for table_name in ready_tables: - # Any record handlers we created but deferred submitting will continue to register as "ready" until they are actually submitted and become "in progress". - # This check prevents repeatedly incurring the cost of fetching the job details (and logging duplicatively) while the job is deferred. - if self.synthetics_run.record_handlers.get(table_name) is not None: - continue - - present_working_tables = { - table: data - for table, data in self.working_tables.items() - if data is not None - } - - table_job = self.strategy.get_generation_job( - table_name, - self.rel_data, - self.synthetics_run.record_size_ratio, - present_working_tables, - self.subdir, - self.output_handler, - ) - model = self.synthetics_train.models[table_name] - record_handler = model.create_record_handler_obj(**table_job) - self.synthetics_run.record_handlers[table_name] = record_handler - # Attempt starting the record handler right away. If it can't start right at this moment, - # the regular task runner check will handle starting it when possible. - self.ctx.maybe_start_job( - job=record_handler, - table_name=table_name, - action=self.action(record_handler), - ) - - self.ctx.backup() - - @property - def _all_tables(self) -> list[str]: - return self.rel_data.list_all_tables() - - def _fail_table(self, table: str) -> None: - self.working_tables[table] = None - for other_table in self.strategy.tables_to_skip_when_failed( - table, self.rel_data - ): - _log_skipping(skip=other_table, failed_parent=table) - self.working_tables[other_table] = None - - -def _log_skipping(skip: str, failed_parent: str) -> None: - logger.info( - f"Skipping synthetic data generation for `{skip}` because it depends on `{failed_parent}`" - ) - - -def _table_is_in_progress( - record_handlers: dict[str, RecordHandler], table: str -) -> bool: - in_progress = False - - record_handler = record_handlers.get(table) - if record_handler is not None and record_handler.record_id is not None: - in_progress = record_handler.status in ACTIVE_STATES - - return in_progress diff --git a/src/gretel_trainer/relational/tasks/synthetics_train.py b/src/gretel_trainer/relational/tasks/synthetics_train.py deleted file mode 100644 index 95e64437..00000000 --- a/src/gretel_trainer/relational/tasks/synthetics_train.py +++ /dev/null @@ -1,57 +0,0 @@ -import gretel_trainer.relational.tasks.common as common - -from gretel_client.projects.jobs import Job -from gretel_trainer.relational.task_runner import TaskContext -from gretel_trainer.relational.workflow_state import SyntheticsTrain - -ACTION = "synthetics model training" - - -class SyntheticsTrainTask: - def __init__( - self, - synthetics_train: SyntheticsTrain, - ctx: TaskContext, - ): - self.synthetics_train = synthetics_train - self.ctx = ctx - self.completed = [] - self.failed = [] - - def action(self, job: Job) -> str: - return ACTION - - @property - def table_collection(self) -> list[str]: - return list(self.synthetics_train.models.keys()) - - def more_to_do(self) -> bool: - return len(self.completed + self.failed) < len(self.synthetics_train.models) - - def is_finished(self, table: str) -> bool: - return table in (self.completed + self.failed) - - def get_job(self, table: str) -> Job: - return self.synthetics_train.models[table] - - def handle_completed(self, table: str, job: Job) -> None: - self.completed.append(table) - common.log_success(table, self.action(job)) - common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job) - - def handle_failed(self, table: str, job: Job) -> None: - self.failed.append(table) - common.log_failed(table, self.action(job)) - common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job) - - def handle_lost_contact(self, table: str, job: Job) -> None: - self.synthetics_train.lost_contact.append(table) - self.failed.append(table) - common.log_lost_contact(table) - common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job) - - def handle_in_progress(self, table: str, job: Job) -> None: - common.log_in_progress(table, job.status, self.action(job)) - - def each_iteration(self) -> None: - self.ctx.backup() diff --git a/src/gretel_trainer/relational/tasks/transforms_run.py b/src/gretel_trainer/relational/tasks/transforms_run.py deleted file mode 100644 index b70545a6..00000000 --- a/src/gretel_trainer/relational/tasks/transforms_run.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import Optional - -import pandas as pd - -import gretel_trainer.relational.tasks.common as common - -from gretel_client.projects.jobs import Job -from gretel_client.projects.records import RecordHandler -from gretel_trainer.relational.task_runner import TaskContext - -ACTION = "transforms run" - - -class TransformsRunTask: - def __init__( - self, - record_handlers: dict[str, RecordHandler], - ctx: TaskContext, - ): - self.record_handlers = record_handlers - self.ctx = ctx - self.working_tables: dict[str, Optional[pd.DataFrame]] = {} - - @property - def output_tables(self) -> dict[str, pd.DataFrame]: - return { - table: data - for table, data in self.working_tables.items() - if data is not None - } - - def action(self, job: Job) -> str: - return ACTION - - @property - def table_collection(self) -> list[str]: - return list(self.record_handlers.keys()) - - def more_to_do(self) -> bool: - return len(self.working_tables) < len(self.record_handlers) - - def is_finished(self, table: str) -> bool: - return table in self.working_tables - - def get_job(self, table: str) -> Job: - return self.record_handlers[table] - - def handle_completed(self, table: str, job: Job) -> None: - self.working_tables[table] = self.ctx.extended_sdk.get_record_handler_data(job) - common.log_success(table, self.action(job)) - common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job) - - def handle_failed(self, table: str, job: Job) -> None: - self.working_tables[table] = None - common.log_failed(table, self.action(job)) - common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job) - - def handle_lost_contact(self, table: str, job: Job) -> None: - self.working_tables[table] = None - common.log_lost_contact(table) - common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job) - - def handle_in_progress(self, table: str, job: Job) -> None: - common.log_in_progress(table, job.status, self.action(job)) - - def each_iteration(self) -> None: - pass diff --git a/src/gretel_trainer/relational/tasks/transforms_train.py b/src/gretel_trainer/relational/tasks/transforms_train.py deleted file mode 100644 index 90ae3f53..00000000 --- a/src/gretel_trainer/relational/tasks/transforms_train.py +++ /dev/null @@ -1,57 +0,0 @@ -import gretel_trainer.relational.tasks.common as common - -from gretel_client.projects.jobs import Job -from gretel_trainer.relational.task_runner import TaskContext -from gretel_trainer.relational.workflow_state import TransformsTrain - -ACTION = "transforms model training" - - -class TransformsTrainTask: - def __init__( - self, - transforms_train: TransformsTrain, - ctx: TaskContext, - ): - self.transforms_train = transforms_train - self.ctx = ctx - self.completed = [] - self.failed = [] - - def action(self, job: Job) -> str: - return ACTION - - @property - def table_collection(self) -> list[str]: - return list(self.transforms_train.models.keys()) - - def more_to_do(self) -> bool: - return len(self.completed + self.failed) < len(self.transforms_train.models) - - def is_finished(self, table: str) -> bool: - return table in (self.completed + self.failed) - - def get_job(self, table: str) -> Job: - return self.transforms_train.models[table] - - def handle_completed(self, table: str, job: Job) -> None: - self.completed.append(table) - common.log_success(table, self.action(job)) - common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job) - - def handle_failed(self, table: str, job: Job) -> None: - self.failed.append(table) - common.log_failed(table, self.action(job)) - common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job) - - def handle_lost_contact(self, table: str, job: Job) -> None: - self.transforms_train.lost_contact.append(table) - self.failed.append(table) - common.log_lost_contact(table) - common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job) - - def handle_in_progress(self, table: str, job: Job) -> None: - common.log_in_progress(table, job.status, self.action(job)) - - def each_iteration(self) -> None: - self.ctx.backup() diff --git a/src/gretel_trainer/relational/workflow_state.py b/src/gretel_trainer/relational/workflow_state.py deleted file mode 100644 index ff1004f2..00000000 --- a/src/gretel_trainer/relational/workflow_state.py +++ /dev/null @@ -1,31 +0,0 @@ -from dataclasses import dataclass, field - -from gretel_client.projects.models import Model -from gretel_client.projects.records import RecordHandler - - -@dataclass -class Classify: - models: dict[str, Model] = field(default_factory=dict) - - -@dataclass -class TransformsTrain: - models: dict[str, Model] = field(default_factory=dict) - lost_contact: list[str] = field(default_factory=list) - - -@dataclass -class SyntheticsTrain: - models: dict[str, Model] = field(default_factory=dict) - lost_contact: list[str] = field(default_factory=list) - bypass: list[str] = field(default_factory=list) - - -@dataclass -class SyntheticsRun: - identifier: str - record_size_ratio: float - preserved: list[str] - record_handlers: dict[str, RecordHandler] - lost_contact: list[str] diff --git a/tests/relational/conftest.py b/tests/relational/conftest.py deleted file mode 100644 index 6cbbc384..00000000 --- a/tests/relational/conftest.py +++ /dev/null @@ -1,201 +0,0 @@ -import itertools -import sqlite3 -import tempfile - -from pathlib import Path -from typing import Callable, Generator -from unittest.mock import Mock, patch - -import pandas as pd -import pytest - -from sqlalchemy import create_engine - -from gretel_trainer.relational.connectors import Connector -from gretel_trainer.relational.core import RelationalData -from gretel_trainer.relational.output_handler import SDKOutputHandler - -EXAMPLE_DBS = Path(__file__).parent.resolve() / "example_dbs" - - -@pytest.fixture(autouse=True) -def static_suffix(request): - if "no_mock_suffix" in request.keywords: - yield - return - with patch("gretel_trainer.relational.json.make_suffix") as make_suffix: - # Each call to make_suffix must be unique or there will be table collisions - make_suffix.side_effect = itertools.count(start=1) - yield make_suffix - - -@pytest.fixture() -def get_invented_table_suffix() -> Callable[[int], str]: - return _get_invented_table_suffix - - -def _get_invented_table_suffix(make_suffix_execution_number: int): - return f"invented_{str(make_suffix_execution_number)}" - - -@pytest.fixture -def invented_tables(get_invented_table_suffix) -> dict[str, str]: - return { - "purchases_root": f"purchases_{get_invented_table_suffix(1)}", - "purchases_data_years": f"purchases_{get_invented_table_suffix(2)}", - "bball_root": f"bball_{get_invented_table_suffix(1)}", - "bball_suspensions": f"bball_{get_invented_table_suffix(2)}", - "bball_teams": f"bball_{get_invented_table_suffix(3)}", - } - - -@pytest.fixture() -def output_handler(tmpdir, project): - return SDKOutputHandler( - workdir=tmpdir, - project=project, - hybrid=False, - source_archive=None, - ) - - -@pytest.fixture() -def project(): - with ( - patch("gretel_trainer.relational.multi_table.create_project") as create_project, - patch("gretel_trainer.relational.multi_table.get_project") as get_project, - ): - project = Mock() - project.name = "name" - project.display_name = "display_name" - - create_project.return_value = project - get_project.return_value = project - - yield project - - -def _rel_data_connector(name) -> Connector: - con = sqlite3.connect(f"file:{name}?mode=memory&cache=shared", uri=True) - cur = con.cursor() - with open(EXAMPLE_DBS / f"{name}.sql") as f: - cur.executescript(f.read()) - return Connector( - create_engine(f"sqlite:///file:{name}?mode=memory&cache=shared&uri=true") - ) - - -@pytest.fixture() -def example_dbs(): - return EXAMPLE_DBS - - -@pytest.fixture() -def tmpdir(): - with tempfile.TemporaryDirectory() as tmpdir: - yield tmpdir - - -@pytest.fixture() -def tmpfile(): - with tempfile.NamedTemporaryFile() as tmpfile: - yield tmpfile - - -@pytest.fixture() -def pets(tmpdir) -> Generator[RelationalData, None, None]: - yield _rel_data_connector("pets").extract(storage_dir=tmpdir) - - -@pytest.fixture() -def ecom(tmpdir) -> Generator[RelationalData, None, None]: - yield _rel_data_connector("ecom").extract(storage_dir=tmpdir) - - -@pytest.fixture() -def mutagenesis(tmpdir) -> Generator[RelationalData, None, None]: - yield _rel_data_connector("mutagenesis").extract(storage_dir=tmpdir) - - -@pytest.fixture() -def tpch(tmpdir) -> Generator[RelationalData, None, None]: - yield _rel_data_connector("tpch").extract(storage_dir=tmpdir) - - -@pytest.fixture() -def art(tmpdir) -> Generator[RelationalData, None, None]: - yield _rel_data_connector("art").extract(storage_dir=tmpdir) - - -@pytest.fixture() -def insurance(tmpdir) -> Generator[RelationalData, None, None]: - yield _rel_data_connector("insurance").extract(storage_dir=tmpdir) - - -@pytest.fixture() -def documents(tmpdir) -> Generator[RelationalData, None, None]: - yield _rel_data_connector("documents").extract(storage_dir=tmpdir) - - -@pytest.fixture() -def trips(tmpdir) -> Generator[RelationalData, None, None]: - with tempfile.NamedTemporaryFile() as tmpfile: - data = pd.DataFrame( - data={ - "id": list(range(100)), - "purpose": ["work"] * 100, - "vehicle_type_id": [1] * 60 + [2] * 30 + [3] * 5 + [4] * 5, - } - ) - data.to_csv(tmpfile.name, index=False) - rel_data = _rel_data_connector("trips").extract(storage_dir=tmpdir) - rel_data.update_table_data(table="trips", data=tmpfile.name) - yield rel_data - - -# These two NBA fixtures need their own temporary directories instead of -# using the tmpdir fixture because otherwise they stomp on each other -@pytest.fixture() -def source_nba(): - with tempfile.TemporaryDirectory() as tmpdir: - yield _setup_nba(tmpdir, synthetic=False) - - -@pytest.fixture() -def synthetic_nba(): - with tempfile.TemporaryDirectory() as tmpdir: - yield _setup_nba(tmpdir, synthetic=True) - - -def _setup_nba(directory: str, synthetic: bool): - if synthetic: - states = ["PA", "FL"] - cities = ["Philadelphia", "Miami"] - teams = ["Sixers", "Heat"] - else: - states = ["CA", "TN"] - cities = ["Los Angeles", "Memphis"] - teams = ["Lakers", "Grizzlies"] - - states = pd.DataFrame(data={"name": states, "id": [1, 2]}) - cities = pd.DataFrame(data={"name": cities, "id": [1, 2], "state_id": [1, 2]}) - teams = pd.DataFrame(data={"name": teams, "id": [1, 2], "city_id": [1, 2]}) - - rel_data = RelationalData(directory=directory) - rel_data.add_table(name="states", primary_key="id", data=states) - rel_data.add_table(name="cities", primary_key="id", data=cities) - rel_data.add_table(name="teams", primary_key="id", data=teams) - rel_data.add_foreign_key_constraint( - table="teams", - constrained_columns=["city_id"], - referred_table="cities", - referred_columns=["id"], - ) - rel_data.add_foreign_key_constraint( - table="cities", - constrained_columns=["state_id"], - referred_table="states", - referred_columns=["id"], - ) - - return rel_data, states, cities, teams diff --git a/tests/relational/example_dbs/art.sql b/tests/relational/example_dbs/art.sql deleted file mode 100644 index 0adb6b8e..00000000 --- a/tests/relational/example_dbs/art.sql +++ /dev/null @@ -1,29 +0,0 @@ -create table if not exists artists ( - id text primary key, - name text not null -); -delete from artists; - -create table if not exists paintings ( - id text primary key, - name text not null, - artist_id text not null, - -- - foreign key (artist_id) references artists (id) -); -delete from paintings; - -insert into artists (id, name) values - ("A001", "Wassily Kandinsky"), - ("A002", "Pablo Picasso"), - ("A003", "Vincent van Gogh"), - ("A004", "Leonardo da Vinci"); - -insert into paintings (id, artist_id, name) values - ("P001", "A004", "Mona Lisa"), - ("P002", "A004", "The Last Supper"), - ("P004", "A002", "Guernica"), - ("P005", "A002", "The Old Guitarist"), - ("P006", "A003", "Starry Night"), - ("P007", "A003", "Bedroom in Arles"), - ("P008", "A003", "Irises"); diff --git a/tests/relational/example_dbs/documents.sql b/tests/relational/example_dbs/documents.sql deleted file mode 100644 index ebc861bf..00000000 --- a/tests/relational/example_dbs/documents.sql +++ /dev/null @@ -1,48 +0,0 @@ -create table if not exists users ( - id integer primary key, - name text not null -); -delete from users; - -create table if not exists purchases ( - id integer primary key, - user_id integer not null, - data text not null, - -- - foreign key (user_id) references users (id) -); -delete from purchases; - -create table if not exists payments ( - id integer primary key, - purchase_id integer not null, - amount integer not null, - -- - foreign key (purchase_id) references purchases (id) -); -delete from payments; - -insert into users (id, name) values - (1, "Andy"), - (2, "Bob"), - (3, "Charlie"), - (4, "David"); - -insert into purchases (id, user_id, data) values - (1, 1, '{"item": "pen", "cost": 100, "details": {"color": "red"}, "years": [2023]}'), - (2, 2, '{"item": "paint", "cost": 100, "details": {"color": "red"}, "years": [2023, 2022]}'), - (3, 2, '{"item": "ink", "cost": 100, "details": {"color": "red"}, "years": [2020, 2019]}'), - (4, 3, '{"item": "pen", "cost": 100, "details": {"color": "blue"}, "years": []}'), - (5, 3, '{"item": "paint", "cost": 100, "details": {"color": "blue"}, "years": [2021]}'), - (6, 3, '{"item": "ink", "cost": 100, "details": {"color": "blue"}, "years": []}'); - -insert into payments (id, purchase_id, amount) values - (1, 1, 42), - (2, 1, 42), - (3, 2, 42), - (4, 2, 42), - (5, 2, 42), - (6, 3, 42), - (7, 4, 42), - (8, 4, 42), - (9, 5, 42); diff --git a/tests/relational/example_dbs/ecom.sql b/tests/relational/example_dbs/ecom.sql deleted file mode 100644 index de5b4166..00000000 --- a/tests/relational/example_dbs/ecom.sql +++ /dev/null @@ -1,50 +0,0 @@ -create table if not exists users ( - id integer primary key, - first_name text not null, - last_name text not null -); - -create table if not exists events ( - id integer primary key, - browser text not null, - traffic_source text not null, - user_id text not null, - -- - foreign key (user_id) references users (id) -); - -create table if not exists distribution_center ( - id integer primary key, - name text not null -); - -create table if not exists products ( - id integer primary key, - name text not null, - brand text not null, - distribution_center_id integer not null, - -- - foreign key (distribution_center_id) references distribution_center (id) -); - -create table if not exists inventory_items ( - id integer primary key, - sold_at text not null, - cost text not null, - product_id integer not null, - product_distribution_center_id integer not null, - -- - foreign key (product_id) references products (id), - foreign key (product_distribution_center_id) references distribution_center (id) -); - -create table if not exists order_items ( - id integer primary key, - sale_price text not null, - status text not null, - user_id integer not null, - inventory_item_id integer not null, - -- - foreign key (user_id) references users (id), - foreign key (inventory_item_id) references inventory_items (id) -); diff --git a/tests/relational/example_dbs/insurance.sql b/tests/relational/example_dbs/insurance.sql deleted file mode 100644 index c00049b7..00000000 --- a/tests/relational/example_dbs/insurance.sql +++ /dev/null @@ -1,32 +0,0 @@ -create table if not exists beneficiary ( - id integer primary key, - name text not null -); - -create table if not exists insurance_policies ( - id integer primary key, - primary_beneficiary integer not null, - secondary_beneficiary integer not null, - -- - foreign key (primary_beneficiary) references beneficiary (id), - foreign key (secondary_beneficiary) references beneficiary (id) -); - -insert into beneficiary (name) values - ("John Doe"), - ("Jane Smith"), - ("Michael Johnson"), - ("Emily Brown"), - ("William Wilson"); - -insert into insurance_policies (primary_beneficiary, secondary_beneficiary) values - (1, 2), - (2, 3), - (3, 4), - (4, 5), - (5, 1), - (1, 3), - (2, 4), - (3, 5), - (4, 1), - (5, 2); diff --git a/tests/relational/example_dbs/mutagenesis.sql b/tests/relational/example_dbs/mutagenesis.sql deleted file mode 100644 index d9427c54..00000000 --- a/tests/relational/example_dbs/mutagenesis.sql +++ /dev/null @@ -1,23 +0,0 @@ -create table if not exists molecule ( - molecule_id text primary key, - mutagenic text not null -); - -create table if not exists atom ( - atom_id integer primary key, - element text not null, - charge real not null, - molecule_id text not null, - -- - foreign key (molecule_id) references molecule (molecule_id) -); - -create table if not exists bond ( - type text not null, - atom1_id integer not null, - atom2_id integer not null, - -- - primary key (atom1_id, atom2_id), - foreign key (atom1_id) references atom (atom_id), - foreign key (atom2_id) references atom (atom_id) -); diff --git a/tests/relational/example_dbs/pets.sql b/tests/relational/example_dbs/pets.sql deleted file mode 100644 index 5c97d726..00000000 --- a/tests/relational/example_dbs/pets.sql +++ /dev/null @@ -1,29 +0,0 @@ -create table if not exists humans ( - id integer primary key, - name text not null, - city text not null -); -delete from humans; - -create table if not exists pets ( - id integer primary key, - name text not null, - age integer not null, - human_id integer not null, - foreign key (human_id) references humans (id) -); -delete from pets; - -insert into humans (name, city) values - ("John", "Liverpool"), - ("Paul", "Liverpool"), - ("George", "Liverpool"), - ("Ringo", "Liverpool"), - ("Billy", "Houston"); - -insert into pets (human_id, name, age) values - (1, "Lennon", 6), - (2, "McCartney", 14), - (3, "Harrison", 8), - (4, "Starr", 7), - (5, "Preston", 2); diff --git a/tests/relational/example_dbs/tpch.sql b/tests/relational/example_dbs/tpch.sql deleted file mode 100644 index 21175ee3..00000000 --- a/tests/relational/example_dbs/tpch.sql +++ /dev/null @@ -1,72 +0,0 @@ -create table if not exists supplier ( - s_suppkey integer primary key, - s_name text not null -); -delete from supplier; - -create table if not exists part ( - p_partkey integer primary key, - p_name text not null -); -delete from part; - -create table if not exists partsupp ( - ps_partkey integer not null, - ps_suppkey integer not null, - ps_availqty integer not null, - -- - primary key (ps_partkey, ps_suppkey), - foreign key (ps_partkey) references part (p_partkey), - foreign key (ps_suppkey) references supplier (s_suppkey) -); -delete from partsupp; - -create table if not exists lineitem ( - l_partkey integer not null, - l_suppkey integer not null, - l_quantity integer not null, - -- - primary key (l_partkey, l_suppkey), - foreign key (l_partkey, l_suppkey) references partsupp (ps_partkey, ps_suppkey) -); -delete from lineitem; - -insert into supplier (s_suppkey, s_name) values - (1, "SupplierA"), - (2, "SupplierB"), - (3, "SupplierC"), - (4, "SupplierD"); - -insert into part (p_partkey, p_name) values - (1, "burlywood plum powder puff mint"), - (2, "hot spring dodger dim light"), - (3, "dark slate grey steel misty"), - (4, "cream turquoise dark thistle light"); - -insert into partsupp (ps_partkey, ps_suppkey, ps_availqty) values - (1, 3, 103), - (1, 2, 102), - (1, 4, 104), - (2, 1, 201), - (2, 2, 202), - (2, 3, 203), - (3, 1, 301), - (3, 3, 303), - (3, 4, 304), - (4, 1, 401), - (4, 4, 404), - (4, 2, 402); - -insert into lineitem (l_partkey, l_suppkey, l_quantity) values - (1, 3, 13), - (1, 2, 12), - (1, 4, 14), - (2, 1, 21), - (2, 2, 22), - (2, 3, 23), - (3, 1, 31), - (3, 3, 33), - (3, 4, 34), - (4, 1, 41), - (4, 4, 44), - (4, 2, 42); diff --git a/tests/relational/example_dbs/trips.sql b/tests/relational/example_dbs/trips.sql deleted file mode 100644 index 004a7f7e..00000000 --- a/tests/relational/example_dbs/trips.sql +++ /dev/null @@ -1,19 +0,0 @@ -create table if not exists vehicle_types ( - id integer primary key, - name text not null -); -delete from vehicle_types; - -create table if not exists trips ( - id integer primary key, - purpose text not null, - vehicle_type_id integer not null, - -- - foreign key (vehicle_type_id) references vehicle_types (id) -); -delete from vehicle_types; - -insert into vehicle_types (name) values - ("car"), ("train"), ("bike"), ("plane"); - --- trip data inserted via python fixture diff --git a/tests/relational/test_ancestral_strategy.py b/tests/relational/test_ancestral_strategy.py deleted file mode 100644 index 9c74e951..00000000 --- a/tests/relational/test_ancestral_strategy.py +++ /dev/null @@ -1,697 +0,0 @@ -import os -import tempfile - -from unittest.mock import patch - -import pandas as pd -import pandas.testing as pdtest -import pytest - -import gretel_trainer.relational.ancestry as ancestry - -from gretel_trainer.relational.core import MultiTableException -from gretel_trainer.relational.strategies.ancestral import AncestralStrategy - - -def test_preparing_training_data_does_not_mutate_source_data(pets): - original_tables = { - table: pets.get_table_data(table).copy() for table in pets.list_all_tables() - } - - strategy = AncestralStrategy() - - with ( - tempfile.NamedTemporaryFile() as pets_dest, - tempfile.NamedTemporaryFile() as humans_dest, - ): - strategy.prepare_training_data( - pets, {"pets": pets_dest.name, "humans": humans_dest.name} - ) - - for table in pets.list_all_tables(): - pdtest.assert_frame_equal(original_tables[table], pets.get_table_data(table)) - - -def test_prepare_training_data_subset_of_tables(pets): - strategy = AncestralStrategy() - - with ( - tempfile.NamedTemporaryFile() as pets_dest, - tempfile.NamedTemporaryFile() as humans_dest, - ): - # We aren't synthesizing the "humans" table, so it is not in this list argument... - training_data = strategy.prepare_training_data(pets, {"pets": pets_dest.name}) - - train_pets = pd.read_csv(training_data["pets"]) - - # ...nor do we create training data for it - assert not train_pets.empty - assert os.stat(humans_dest.name).st_size == 0 - - # Since the humans table is omitted from synthetics, we leave the FK values alone; specifically: - # - they are not label-encoded (which would effectively zero-index them) - # - we do not add artificial min/max values - assert set(train_pets["self|human_id"].values) == {1, 2, 3, 4, 5} - # We do add the artificial max PK row, though, since this table is being synthesized - assert len(train_pets) == 6 - - -def test_prepare_training_data_returns_multigenerational_data(pets): - strategy = AncestralStrategy() - - with ( - tempfile.NamedTemporaryFile() as pets_dest, - tempfile.NamedTemporaryFile() as humans_dest, - ): - training_data = strategy.prepare_training_data( - pets, {"pets": pets_dest.name, "humans": humans_dest.name} - ) - train_pets = pd.read_csv(training_data["pets"]) - - for expected_column in ["self|id", "self|name", "self.human_id|id"]: - assert expected_column in train_pets - - -def test_prepare_training_data_drops_highly_unique_categorical_ancestor_fields(art): - with ( - tempfile.NamedTemporaryFile() as artists_csv, - tempfile.NamedTemporaryFile() as paintings_csv, - ): - pd.DataFrame( - data={ - "id": [f"A{i}" for i in range(100)], - "name": [f"artist_name_{i}" for i in range(100)], - } - ).to_csv(artists_csv.name, index=False) - art.update_table_data(table="artists", data=artists_csv.name) - - pd.DataFrame( - data={ - "id": [f"P{i}" for i in range(100)], - "artist_id": [f"A{i}" for i in range(100)], - "name": [f"painting_name_{i}" for i in range(100)], - } - ).to_csv(paintings_csv.name, index=False) - art.update_table_data(table="paintings", data=paintings_csv.name) - - strategy = AncestralStrategy() - - with ( - tempfile.NamedTemporaryFile() as artists_dest, - tempfile.NamedTemporaryFile() as paintings_dest, - ): - training_data = strategy.prepare_training_data( - art, - { - "artists": artists_dest.name, - "paintings": paintings_dest.name, - }, - ) - train_paintings = pd.read_csv(training_data["paintings"]) - - # Does not contain `self.artist_id|name` because it is highly unique categorical - assert set(train_paintings.columns) == { - "self|id", - "self|name", - "self|artist_id", - "self.artist_id|id", - } - - -def test_prepare_training_data_drops_highly_nan_ancestor_fields(art): - highly_nan_names = [] - for i in range(100): - if i > 70: - highly_nan_names.append(None) - else: - highly_nan_names.append("some name") - - with ( - tempfile.NamedTemporaryFile() as artists_csv, - tempfile.NamedTemporaryFile() as paintings_csv, - ): - pd.DataFrame( - data={ - "id": [f"A{i}" for i in range(100)], - "name": highly_nan_names, - } - ).to_csv(artists_csv.name, index=False) - art.update_table_data(table="artists", data=artists_csv.name) - - pd.DataFrame( - data={ - "id": [f"P{i}" for i in range(100)], - "artist_id": [f"A{i}" for i in range(100)], - "name": [str(i) for i in range(100)], - } - ).to_csv(paintings_csv.name, index=False) - art.update_table_data(table="paintings", data=paintings_csv.name) - - strategy = AncestralStrategy() - - with ( - tempfile.NamedTemporaryFile() as artists_dest, - tempfile.NamedTemporaryFile() as paintings_dest, - ): - training_data = strategy.prepare_training_data( - art, - { - "artists": artists_dest.name, - "paintings": paintings_dest.name, - }, - ) - train_paintings = pd.read_csv(training_data["paintings"]) - - # Does not contain `self.artist_id|name` because it is highly NaN - assert set(train_paintings.columns) == { - "self|id", - "self|name", - "self|artist_id", - "self.artist_id|id", - } - - -def test_prepare_training_data_translates_alphanumeric_keys_and_adds_min_max_records( - art, -): - strategy = AncestralStrategy() - - with ( - tempfile.NamedTemporaryFile() as artists_dest, - tempfile.NamedTemporaryFile() as paintings_dest, - ): - training_data = strategy.prepare_training_data( - art, - { - "artists": artists_dest.name, - "paintings": paintings_dest.name, - }, - ) - train_artists = pd.read_csv(training_data["artists"]) - train_paintings = pd.read_csv(training_data["paintings"]) - - # Artists, a parent table, should have 1 additional row - assert len(train_artists) == len(art.get_table_data("artists")) + 1 - # The last record has the artifical max PK - assert train_artists["self|id"].to_list() == [0, 1, 2, 3, 200] - # We do not assert on the value of "self|name" because the artificial max PK record is - # randomly sampled from source and so the exact value is not deterministic - - # Paintings, as a child table, should have 3 additional rows - # - artificial max PK - # - artificial min FKs - # - artificial max FKs - assert len(train_paintings) == len(art.get_table_data("paintings")) + 3 - - last_three = train_paintings.tail(3) - last_two = last_three.tail(2) - - # PKs are max, +1, +2 - assert last_three["self|id"].to_list() == [350, 351, 352] - # FKs on last two rows (artifical FKs) are min, max - assert last_two["self|artist_id"].to_list() == [0, 200] - assert last_two["self.artist_id|id"].to_list() == [0, 200] - - -def test_prepare_training_data_with_composite_keys(tpch): - strategy = AncestralStrategy() - with ( - tempfile.NamedTemporaryFile() as supplier_dest, - tempfile.NamedTemporaryFile() as part_dest, - tempfile.NamedTemporaryFile() as partsupp_dest, - tempfile.NamedTemporaryFile() as lineitem_dest, - ): - training_data = strategy.prepare_training_data( - tpch, - { - "supplier": supplier_dest.name, - "part": part_dest.name, - "partsupp": partsupp_dest.name, - "lineitem": lineitem_dest.name, - }, - ) - - train_partsupp = pd.read_csv(training_data["partsupp"]) - train_lineitem = pd.read_csv(training_data["lineitem"]) - - l_max = len(tpch.get_table_data("lineitem")) * 50 - ps_max = len(tpch.get_table_data("partsupp")) * 50 - p_max = len(tpch.get_table_data("part")) * 50 - s_max = len(tpch.get_table_data("supplier")) * 50 - - # partsupp table, composite PK - assert set(train_partsupp.columns) == { - "self|ps_partkey", - "self|ps_suppkey", - "self|ps_availqty", - "self.ps_partkey|p_partkey", - "self.ps_suppkey|s_suppkey", - } - assert len(train_partsupp) == len(tpch.get_table_data("partsupp")) + 3 - last_three_partsupp_keys = train_partsupp.tail(3).reset_index()[ - ["self|ps_partkey", "self|ps_suppkey"] - ] - pdtest.assert_frame_equal( - last_three_partsupp_keys, - pd.DataFrame( - data={ - "self|ps_partkey": [ps_max, 0, p_max], - "self|ps_suppkey": [ps_max, 0, s_max], - } - ), - ) - - # lineitem table, composite FK to partsupp - assert set(train_lineitem.columns) == { - "self|l_partkey", - "self|l_suppkey", - "self|l_quantity", - "self.l_partkey+l_suppkey|ps_partkey", - "self.l_partkey+l_suppkey|ps_suppkey", - "self.l_partkey+l_suppkey|ps_availqty", - "self.l_partkey+l_suppkey.ps_partkey|p_partkey", - "self.l_partkey+l_suppkey.ps_suppkey|s_suppkey", - } - assert len(train_lineitem) == len(tpch.get_table_data("lineitem")) + 3 - last_three_lineitem_keys = train_lineitem.tail(3).reset_index()[ - ["self|l_partkey", "self|l_suppkey"] - ] - pdtest.assert_frame_equal( - last_three_lineitem_keys, - pd.DataFrame( - data={ - "self|l_partkey": [l_max, 0, ps_max], - "self|l_suppkey": [l_max, 0, ps_max], - } - ), - ) - - -def test_retraining_a_set_of_tables_forces_retraining_descendants_as_well(ecom): - strategy = AncestralStrategy() - assert set(strategy.tables_to_retrain(["users"], ecom)) == { - "users", - "events", - "order_items", - } - assert set(strategy.tables_to_retrain(["products"], ecom)) == { - "products", - "inventory_items", - "order_items", - } - assert set(strategy.tables_to_retrain(["users", "products"], ecom)) == { - "users", - "events", - "products", - "inventory_items", - "order_items", - } - - -def test_preserve_tables(ecom): - strategy = AncestralStrategy() - - with pytest.raises(MultiTableException): - # Need to also preserve parent users - strategy.validate_preserved_tables(["events"], ecom) - - with pytest.raises(MultiTableException): - # Need to also preserve parent products - strategy.validate_preserved_tables( - ["distribution_center", "inventory_items"], ecom - ) - - assert strategy.validate_preserved_tables(["users", "events"], ecom) is None - assert ( - strategy.validate_preserved_tables( - ["distribution_center", "products", "inventory_items"], ecom - ) - is None - ) - - -def test_table_generation_readiness(ecom): - strategy = AncestralStrategy() - - # To start, "eldest generation" tables (those with no parents / outbound foreign keys) are ready - assert set(strategy.ready_to_generate(ecom, [], [])) == { - "users", - "distribution_center", - } - - # Once a table has been started, it is no longer ready - assert set(strategy.ready_to_generate(ecom, ["users"], [])) == { - "distribution_center" - } - - # It's possible to be in a state where work is happening but nothing is ready - assert ( - set(strategy.ready_to_generate(ecom, ["users", "distribution_center"], [])) - == set() - ) - - # `events` was only blocked by `users`; once the latter completes, the former is ready, - # regardless of the state of the unrelated `distribution_center` table - assert set( - strategy.ready_to_generate(ecom, ["distribution_center"], ["users"]) - ) == {"events"} - - # Similarly, the completion of `distribution_center` unblocks `products`, - # regardless of progress on `events` - assert set( - strategy.ready_to_generate(ecom, [], ["users", "distribution_center"]) - ) == {"events", "products"} - - # Remaining tables become ready as their parents complete - assert set( - strategy.ready_to_generate( - ecom, [], ["users", "distribution_center", "events", "products"] - ) - ) == {"inventory_items"} - - # As above, being in progress is not enough! Work is happening but nothing new is ready - assert ( - set( - strategy.ready_to_generate( - ecom, - ["inventory_items"], - ["users", "distribution_center", "events", "products"], - ) - ) - == set() - ) - - assert set( - strategy.ready_to_generate( - ecom, - [], - ["users", "distribution_center", "events", "products", "inventory_items"], - ) - ) == {"order_items"} - - assert ( - set( - strategy.ready_to_generate( - ecom, - ["order_items"], - [ - "users", - "distribution_center", - "events", - "products", - "inventory_items", - ], - ) - ) - == set() - ) - assert ( - set( - strategy.ready_to_generate( - ecom, - [], - [ - "users", - "distribution_center", - "events", - "products", - "inventory_items", - "order_items", - ], - ) - ) - == set() - ) - - -def test_generation_job(pets, output_handler): - def _num_objects_in_run_dir(): - return len(os.listdir(output_handler._working_dir / "run-id")) - - strategy = AncestralStrategy() - output_handler.make_subdirectory("run-id") - - # Table with no ancestors - parent_table_job = strategy.get_generation_job( - "humans", pets, 2.0, {}, "run-id", output_handler - ) - assert _num_objects_in_run_dir() == 0 - assert parent_table_job == {"params": {"num_records": 10}} - - # Table with ancestors - synthetic_humans = pd.DataFrame( - data={ - "self|name": [ - "Miles Davis", - "Wayne Shorter", - "Herbie Hancock", - "Ron Carter", - "Tony Williams", - ], - "self|city": [ - "New York", - "New York", - "New York", - "New York", - "Los Angeles", - ], - "self|id": [1, 2, 3, 4, 5], - } - ) - output_tables = {"humans": synthetic_humans} - child_table_job = strategy.get_generation_job( - "pets", pets, 2.0, output_tables, "run-id", output_handler - ) - - assert _num_objects_in_run_dir() == 1 - assert set(child_table_job.keys()) == {"data_source"} - child_table_seed_df = pd.read_csv(child_table_job["data_source"]) - - # `self.human_id|name` should not be present in seed because it was - # excluded from training data (highly-unique categorical field) - expected_seed_df = pd.DataFrame( - data={ - "self.human_id|city": [ - "New York", - "New York", - "New York", - "New York", - "Los Angeles", - "New York", - "New York", - "New York", - "New York", - "Los Angeles", - ], - "self.human_id|id": [1, 2, 3, 4, 5, 1, 2, 3, 4, 5], - } - ) - - pdtest.assert_frame_equal(child_table_seed_df, expected_seed_df) - - # sanity check: assert we did not mutate the originally-supplied synthetic tables - assert set(output_tables["humans"].columns) == {"self|name", "self|city", "self|id"} - - -def test_generation_job_seeds_go_back_multiple_generations( - source_nba, synthetic_nba, output_handler -): - source_nba = source_nba[0] - synthetic_nba = synthetic_nba[0] - output_tables = { - "cities": ancestry.get_table_data_with_ancestors(synthetic_nba, "cities"), - "states": ancestry.get_table_data_with_ancestors(synthetic_nba, "states"), - } - - strategy = AncestralStrategy() - output_handler.make_subdirectory("run-id") - - job = strategy.get_generation_job( - "teams", - source_nba, - 1.0, - output_tables, - "run-id", - output_handler, - ) - seed_df = pd.read_csv(job["data_source"]) - - expected_seed_df_columns = { - "self.city_id|id", - # "self.city_id|name", # highly unique categorical - "self.city_id|state_id", - "self.city_id.state_id|id", - # "self.city_id.state_id|name", # highly unique categorical - } - - assert set(seed_df.columns) == expected_seed_df_columns - - -def test_post_processing_individual_synthetic_result(ecom): - strategy = AncestralStrategy() - synth_events = pd.DataFrame( - data={ - "self|id": [100, 101, 102, 103, 104], - "self|user_id": [200, 201, 202, 203, 204], - "self.user_id|id": [10, 11, 12, 13, 14], - } - ) - - processed_events = strategy.post_process_individual_synthetic_result( - "events", ecom, synth_events, 1 - ) - - expected_post_processing = pd.DataFrame( - data={ - "self|id": [0, 1, 2, 3, 4], - "self|user_id": [10, 11, 12, 13, 14], - "self.user_id|id": [10, 11, 12, 13, 14], - } - ) - - pdtest.assert_frame_equal(expected_post_processing, processed_events) - - -def test_post_processing_individual_synthetic_result_composite_keys(tpch): - strategy = AncestralStrategy() - synth_lineitem = pd.DataFrame( - data={ - "self|l_partkey": [10, 20, 30, 40] * 3, - "self|l_suppkey": [10, 20, 30, 40] * 3, - "self|l_quantity": [42, 42, 42, 42] * 3, - "self.l_partkey+l_suppkey|ps_partkey": [2, 3, 4, 5] * 3, - "self.l_partkey+l_suppkey|ps_suppkey": [6, 7, 8, 9] * 3, - "self.l_partkey+l_suppkey|ps_availqty": [80, 80, 80, 80] * 3, - "self.l_partkey+l_suppkey.ps_partkey|p_partkey": [2, 3, 4, 5] * 3, - "self.l_partkey+l_suppkey.ps_partkey|p_name": ["a", "b", "c", "d"] * 3, - "self.l_partkey+l_suppkey.ps_suppkey|s_suppkey": [6, 7, 8, 9] * 3, - "self.l_partkey+l_suppkey.ps_suppkey|s_name": ["e", "f", "g", "h"] * 3, - } - ) - - processed_lineitem = strategy.post_process_individual_synthetic_result( - "lineitem", tpch, synth_lineitem, 1 - ) - - expected_post_processing = pd.DataFrame( - data={ - "self|l_partkey": [2, 3, 4, 5] * 3, - "self|l_suppkey": [6, 7, 8, 9] * 3, - "self|l_quantity": [42, 42, 42, 42] * 3, - "self.l_partkey+l_suppkey|ps_partkey": [2, 3, 4, 5] * 3, - "self.l_partkey+l_suppkey|ps_suppkey": [6, 7, 8, 9] * 3, - "self.l_partkey+l_suppkey|ps_availqty": [80, 80, 80, 80] * 3, - "self.l_partkey+l_suppkey.ps_partkey|p_partkey": [2, 3, 4, 5] * 3, - "self.l_partkey+l_suppkey.ps_partkey|p_name": ["a", "b", "c", "d"] * 3, - "self.l_partkey+l_suppkey.ps_suppkey|s_suppkey": [6, 7, 8, 9] * 3, - "self.l_partkey+l_suppkey.ps_suppkey|s_name": ["e", "f", "g", "h"] * 3, - } - ) - - pdtest.assert_frame_equal(expected_post_processing, processed_lineitem) - - -def test_post_processing_individual_composite_too_few_keys_created(tpch): - strategy = AncestralStrategy() - synth_lineitem = pd.DataFrame( - data={ - "self|l_partkey": [10, 20, 30, 40] * 3, - "self|l_suppkey": [10, 20, 30, 40] * 3, - "self|l_quantity": [42, 42, 42, 42] * 3, - "self.l_partkey+l_suppkey|ps_partkey": [2, 3, 4, 5] * 3, - "self.l_partkey+l_suppkey|ps_suppkey": [6, 7, 8, 9] * 3, - "self.l_partkey+l_suppkey|ps_availqty": [80, 80, 80, 80] * 3, - "self.l_partkey+l_suppkey.ps_partkey|p_partkey": [2, 3, 4, 5] * 3, - "self.l_partkey+l_suppkey.ps_partkey|p_name": ["a", "b", "c", "d"] * 3, - "self.l_partkey+l_suppkey.ps_suppkey|s_suppkey": [6, 7, 8, 9] * 3, - "self.l_partkey+l_suppkey.ps_suppkey|s_name": ["e", "f", "g", "h"] * 3, - } - ) - - # Given inherent randomness, make_composite_pks can fail to produce enough - # unique composite keys to fill the entire synthetic dataframe. In such situations, - # the client drops records from the raw record handler output and the resulting - # synthetic table only has as many records as keys produced. - with patch( - "gretel_trainer.relational.strategies.ancestral.common.make_composite_pks" - ) as make_keys: - make_keys.return_value = [ - {"self|l_partkey": 55, "self|l_suppkey": 55}, - {"self|l_partkey": 66, "self|l_suppkey": 66}, - {"self|l_partkey": 77, "self|l_suppkey": 77}, - {"self|l_partkey": 88, "self|l_suppkey": 88}, - ] - processed_lineitem = strategy.post_process_individual_synthetic_result( - "lineitem", tpch, synth_lineitem, 1 - ) - - # make_composite_pks only created 4 unique keys, so the table is truncated. - # The values (2-5 and 6-9) come from the subsequent foreign key step. - expected_post_processing = pd.DataFrame( - data={ - "self|l_partkey": [2, 3, 4, 5], - "self|l_suppkey": [6, 7, 8, 9], - "self|l_quantity": [42, 42, 42, 42], - "self.l_partkey+l_suppkey|ps_partkey": [2, 3, 4, 5], - "self.l_partkey+l_suppkey|ps_suppkey": [6, 7, 8, 9], - "self.l_partkey+l_suppkey|ps_availqty": [80, 80, 80, 80], - "self.l_partkey+l_suppkey.ps_partkey|p_partkey": [2, 3, 4, 5], - "self.l_partkey+l_suppkey.ps_partkey|p_name": ["a", "b", "c", "d"], - "self.l_partkey+l_suppkey.ps_suppkey|s_suppkey": [6, 7, 8, 9], - "self.l_partkey+l_suppkey.ps_suppkey|s_name": ["e", "f", "g", "h"], - } - ) - - pdtest.assert_frame_equal(expected_post_processing, processed_lineitem) - - -def test_post_process_synthetic_results(ecom): - strategy = AncestralStrategy() - out_events = pd.DataFrame( - data={ - "self|id": [0, 1, 2], - "self|browser": ["chrome", "safari", "brave"], - "self|traffic_source": ["mobile", "mobile", "mobile"], - "self|user_id": [0, 1, 2], - "self.user_id|id": [0, 1, 2], - "self.user_id|first_name": ["a", "b", "c"], - "self.user_id|last_name": ["A", "B", "C"], - "self.user_id|ssn": ["111", "222", "333"], - } - ) - out_users = pd.DataFrame( - data={ - "self|id": [0, 1, 2], - "self|first_name": ["a", "b", "c"], - "self|last_name": ["A", "B", "C"], - "self|ssn": ["111", "222", "333"], - } - ) - output_tables = { - "events": out_events, - "users": out_users, - } - - processed_tables = strategy.post_process_synthetic_results( - output_tables, [], ecom, 1 - ) - - expected_events = pd.DataFrame( - data={ - "id": [0, 1, 2], - "browser": ["chrome", "safari", "brave"], - "traffic_source": ["mobile", "mobile", "mobile"], - "user_id": [0, 1, 2], - } - ) - expected_users = pd.DataFrame( - data={ - "id": [0, 1, 2], - "first_name": ["a", "b", "c"], - "last_name": ["A", "B", "C"], - "ssn": ["111", "222", "333"], - } - ) - - pdtest.assert_frame_equal(expected_events, processed_tables["events"]) - pdtest.assert_frame_equal(expected_users, processed_tables["users"]) diff --git a/tests/relational/test_ancestry.py b/tests/relational/test_ancestry.py deleted file mode 100644 index f2116d4b..00000000 --- a/tests/relational/test_ancestry.py +++ /dev/null @@ -1,280 +0,0 @@ -import gretel_trainer.relational.ancestry as ancestry - - -def test_ecom_add_and_remove_ancestor_data(ecom): - users_with_ancestors = ancestry.get_table_data_with_ancestors(ecom, "users") - assert set(users_with_ancestors.columns) == { - "self|id", - "self|first_name", - "self|last_name", - } - restored_users = ancestry.drop_ancestral_data(users_with_ancestors) - assert set(restored_users.columns) == {"id", "first_name", "last_name"} - - events_with_ancestors = ancestry.get_table_data_with_ancestors(ecom, "events") - assert set(events_with_ancestors.columns) == { - "self|id", - "self|browser", - "self|traffic_source", - "self|user_id", - "self.user_id|id", - "self.user_id|first_name", - "self.user_id|last_name", - } - restored_events = ancestry.drop_ancestral_data(events_with_ancestors) - assert set(restored_events.columns) == { - "id", - "browser", - "traffic_source", - "user_id", - } - - inventory_items_with_ancestors = ancestry.get_table_data_with_ancestors( - ecom, "inventory_items" - ) - assert set(inventory_items_with_ancestors.columns) == { - "self|id", - "self|sold_at", - "self|cost", - "self|product_id", - "self|product_distribution_center_id", - "self.product_id|id", - "self.product_id|name", - "self.product_id|brand", - "self.product_id|distribution_center_id", - "self.product_distribution_center_id|id", - "self.product_distribution_center_id|name", - "self.product_id.distribution_center_id|id", - "self.product_id.distribution_center_id|name", - } - restored_inventory_items = ancestry.drop_ancestral_data( - inventory_items_with_ancestors - ) - assert set(restored_inventory_items.columns) == { - "id", - "sold_at", - "cost", - "product_id", - "product_distribution_center_id", - } - - -def test_mutagenesis_add_and_remove_ancestor_data(mutagenesis): - bond_with_ancestors = ancestry.get_table_data_with_ancestors(mutagenesis, "bond") - assert set(bond_with_ancestors.columns) == { - "self|type", - "self|atom1_id", - "self|atom2_id", - "self.atom1_id|atom_id", - "self.atom1_id|element", - "self.atom1_id|charge", - "self.atom1_id|molecule_id", - "self.atom2_id|atom_id", - "self.atom2_id|element", - "self.atom2_id|charge", - "self.atom2_id|molecule_id", - "self.atom1_id.molecule_id|molecule_id", - "self.atom1_id.molecule_id|mutagenic", - "self.atom2_id.molecule_id|molecule_id", - "self.atom2_id.molecule_id|mutagenic", - } - restored_bond = ancestry.drop_ancestral_data(bond_with_ancestors) - assert set(restored_bond.columns) == {"type", "atom1_id", "atom2_id"} - - -def test_tpch_add_and_remove_ancestor_data(tpch): - lineitem_with_ancestors = ancestry.get_table_data_with_ancestors(tpch, "lineitem") - assert set(lineitem_with_ancestors.columns) == { - "self|l_partkey", - "self|l_suppkey", - "self|l_quantity", - "self.l_partkey+l_suppkey|ps_partkey", - "self.l_partkey+l_suppkey|ps_suppkey", - "self.l_partkey+l_suppkey|ps_availqty", - "self.l_partkey+l_suppkey.ps_partkey|p_partkey", - "self.l_partkey+l_suppkey.ps_partkey|p_name", - "self.l_partkey+l_suppkey.ps_suppkey|s_suppkey", - "self.l_partkey+l_suppkey.ps_suppkey|s_name", - } - restored_lineitem = ancestry.drop_ancestral_data(lineitem_with_ancestors) - assert set(restored_lineitem.columns) == {"l_partkey", "l_suppkey", "l_quantity"} - - -def test_ancestral_data_from_different_tablesets(source_nba, synthetic_nba): - source_nba = source_nba[0] - _, custom_states, custom_cities, custom_teams = synthetic_nba - - # By default, get data from source - source_teams_with_ancestors = ancestry.get_table_data_with_ancestors( - source_nba, "teams" - ) - assert set(source_teams_with_ancestors["self|name"]) == {"Lakers", "Grizzlies"} - - custom_tableset = { - "states": custom_states, - "cities": custom_cities, - "teams": custom_teams, - } - - # Optionally provide a different tableset - custom_teams_with_ancestors = ancestry.get_table_data_with_ancestors( - source_nba, "teams", custom_tableset - ) - assert set(custom_teams_with_ancestors["self|name"]) == {"Sixers", "Heat"} - - -def test_whether_column_is_ancestral(mutagenesis): - assert ancestry.is_ancestral_column("self|atom1_id") is False - assert ancestry.is_ancestral_column("self.atom1_id|atom1_id") - assert ancestry.is_ancestral_column("self.atom1_id.molecule_id|atom1_id") - - -def test_primary_key_in_multigenerational_format(mutagenesis): - assert ancestry.get_multigenerational_primary_key(mutagenesis, "bond") == [ - "self|atom1_id", - "self|atom2_id", - ] - assert ancestry.get_multigenerational_primary_key(mutagenesis, "atom") == [ - "self|atom_id" - ] - - -def test_get_all_key_columns(ecom, mutagenesis): - assert set(ancestry.get_all_key_columns(ecom, "distribution_center")) == {"self|id"} - assert set(ancestry.get_all_key_columns(ecom, "events")) == { - "self|id", - "self|user_id", - "self.user_id|id", - } - assert set(ancestry.get_all_key_columns(ecom, "inventory_items")) == { - "self|id", - "self|product_id", - "self|product_distribution_center_id", - "self.product_id|id", - "self.product_id|distribution_center_id", - "self.product_distribution_center_id|id", - "self.product_id.distribution_center_id|id", - } - - assert set(ancestry.get_all_key_columns(mutagenesis, "molecule")) == { - "self|molecule_id" - } - assert set(ancestry.get_all_key_columns(mutagenesis, "atom")) == { - "self|atom_id", - "self|molecule_id", - "self.molecule_id|molecule_id", - } - assert set(ancestry.get_all_key_columns(mutagenesis, "bond")) == { - "self|atom1_id", - "self|atom2_id", - "self.atom1_id|atom_id", - "self.atom1_id|molecule_id", - "self.atom1_id.molecule_id|molecule_id", - "self.atom2_id|atom_id", - "self.atom2_id|molecule_id", - "self.atom2_id.molecule_id|molecule_id", - } - - -def test_ancestral_foreign_key_maps(ecom): - events_afk_maps = ancestry.get_ancestral_foreign_key_maps(ecom, "events") - assert events_afk_maps == [("self|user_id", "self.user_id|id")] - - inventory_items_afk_maps = ancestry.get_ancestral_foreign_key_maps( - ecom, "inventory_items" - ) - assert ("self|product_id", "self.product_id|id") in inventory_items_afk_maps - assert ( - "self|product_distribution_center_id", - "self.product_distribution_center_id|id", - ) in inventory_items_afk_maps - - -def test_ancestral_foreign_key_maps_composite(tpch): - lineitem_afk_maps = ancestry.get_ancestral_foreign_key_maps(tpch, "lineitem") - assert ( - "self|l_partkey", - "self.l_partkey+l_suppkey|ps_partkey", - ) in lineitem_afk_maps - assert ( - "self|l_suppkey", - "self.l_partkey+l_suppkey|ps_suppkey", - ) in lineitem_afk_maps - - -def test_prepend_foreign_key_lineage(ecom): - multigen_inventory_items = ancestry.get_table_data_with_ancestors( - ecom, "inventory_items" - ) - order_items_parent_data = ancestry.prepend_foreign_key_lineage( - multigen_inventory_items, ["inventory_item_id"] - ) - assert set(order_items_parent_data.columns) == { - "self.inventory_item_id|id", - "self.inventory_item_id|sold_at", - "self.inventory_item_id|cost", - "self.inventory_item_id|product_id", - "self.inventory_item_id|product_distribution_center_id", - "self.inventory_item_id.product_id|id", - "self.inventory_item_id.product_id|name", - "self.inventory_item_id.product_id|brand", - "self.inventory_item_id.product_id|distribution_center_id", - "self.inventory_item_id.product_distribution_center_id|id", - "self.inventory_item_id.product_distribution_center_id|name", - "self.inventory_item_id.product_id.distribution_center_id|id", - "self.inventory_item_id.product_id.distribution_center_id|name", - } - - -def test_get_seed_safe_multigenerational_columns_1(pets): - table_cols = ancestry.get_seed_safe_multigenerational_columns(pets) - - expected = { - "humans": {"self|id", "self|name", "self|city"}, - "pets": { - "self|id", - "self|name", - "self|age", - "self|human_id", - "self.human_id|id", - # "self.human_id|name", # highly unique categorical - "self.human_id|city", - }, - } - - assert set(table_cols.keys()) == set(expected.keys()) - for table, expected_cols in expected.items(): - assert set(table_cols[table]) == expected_cols - - -def test_get_seed_safe_multigenerational_columns_2(source_nba): - source_nba = source_nba[0] - table_cols = ancestry.get_seed_safe_multigenerational_columns(source_nba) - - expected = { - "teams": { - "self|name", - "self|id", - "self|city_id", - "self.city_id|id", - "self.city_id|state_id", - # "self.city_id|name", # highly unique categorical - "self.city_id.state_id|id", - # "self.city_id.state_id|name", # highly unique categorical - }, - "cities": { - "self|id", - "self|state_id", - "self|name", - "self.state_id|id", - # "self.state_id|name", # highly unique categorical - }, - "states": { - "self|id", - "self|name", - }, - } - - assert set(table_cols.keys()) == set(expected.keys()) - for table, expected_cols in expected.items(): - assert set(table_cols[table]) == expected_cols diff --git a/tests/relational/test_backup.py b/tests/relational/test_backup.py deleted file mode 100644 index 2f8c11fa..00000000 --- a/tests/relational/test_backup.py +++ /dev/null @@ -1,165 +0,0 @@ -import json - -from gretel_trainer.relational.backup import ( - Backup, - BackupClassify, - BackupForeignKey, - BackupGenerate, - BackupRelationalData, - BackupRelationalDataTable, - BackupSyntheticsTrain, - BackupTransformsTrain, -) - - -def test_backup_relational_data(trips): - expected = BackupRelationalData( - tables={ - "vehicle_types": BackupRelationalDataTable( - primary_key=["id"], - ), - "trips": BackupRelationalDataTable( - primary_key=["id"], - ), - }, - foreign_keys=[ - BackupForeignKey( - table="trips", - constrained_columns=["vehicle_type_id"], - referred_table="vehicle_types", - referred_columns=["id"], - ) - ], - ) - - assert BackupRelationalData.from_relational_data(trips) == expected - - -def test_backup_relational_data_with_json(documents, get_invented_table_suffix): - purchases_root_invented_table = f"purchases_{get_invented_table_suffix(1)}" - purchases_data_years_invented_table = f"purchases_{get_invented_table_suffix(2)}" - - expected = BackupRelationalData( - tables={ - "users": BackupRelationalDataTable(primary_key=["id"]), - "purchases": BackupRelationalDataTable( - primary_key=["id"], - producer_metadata={ - "invented_root_table_name": purchases_root_invented_table, - "table_name_mappings": { - "purchases": purchases_root_invented_table, - "purchases^data>years": purchases_data_years_invented_table, - }, - }, - ), - purchases_root_invented_table: BackupRelationalDataTable( - primary_key=["id", "~PRIMARY_KEY_ID~"], - invented_table_metadata={ - "invented_root_table_name": purchases_root_invented_table, - "original_table_name": "purchases", - "json_breadcrumb_path": "purchases", - "empty": False, - }, - ), - purchases_data_years_invented_table: BackupRelationalDataTable( - primary_key=["~PRIMARY_KEY_ID~"], - invented_table_metadata={ - "invented_root_table_name": purchases_root_invented_table, - "original_table_name": "purchases", - "json_breadcrumb_path": "purchases^data>years", - "empty": False, - }, - ), - "payments": BackupRelationalDataTable(primary_key=["id"]), - }, - foreign_keys=[ - BackupForeignKey( - table="payments", - constrained_columns=["purchase_id"], - referred_table=purchases_root_invented_table, - referred_columns=["id"], - ), - BackupForeignKey( - table=purchases_root_invented_table, - constrained_columns=["user_id"], - referred_table="users", - referred_columns=["id"], - ), - BackupForeignKey( - table=purchases_data_years_invented_table, - constrained_columns=["purchases~id"], - referred_table=purchases_root_invented_table, - referred_columns=["~PRIMARY_KEY_ID~"], - ), - ], - ) - - assert BackupRelationalData.from_relational_data(documents) == expected - - -def test_backup(): - backup_relational = BackupRelationalData( - tables={ - "customer": BackupRelationalDataTable( - primary_key=["id"], - ), - "address": BackupRelationalDataTable( - primary_key=[], - ), - }, - foreign_keys=[ - BackupForeignKey( - table="address", - constrained_columns=["customer_id"], - referred_table="customer", - referred_columns=["id"], - ) - ], - ) - backup_classify = BackupClassify( - model_ids={ - "customer": "aaabbbccc", - "address": "dddeeefff", - }, - ) - backup_transforms_train = BackupTransformsTrain( - model_ids={ - "customer": "222333444", - "address": "888777666", - }, - lost_contact=[], - ) - backup_synthetics_train = BackupSyntheticsTrain( - model_ids={ - "customer": "1234567890", - "address": "0987654321", - }, - lost_contact=[], - ) - backup_generate = BackupGenerate( - identifier="run-id", - preserved=[], - record_size_ratio=1.0, - lost_contact=[], - record_handler_ids={ - "customer": "555444666", - "address": "333111222", - }, - ) - source_archive = "gretel_abc_source_tables.tar.gz" - backup = Backup( - project_name="my-project", - strategy="independent", - refresh_interval=120, - source_archive=source_archive, - relational_data=backup_relational, - classify=backup_classify, - transforms_train=backup_transforms_train, - synthetics_train=backup_synthetics_train, - generate=backup_generate, - ) - - j = json.dumps(backup.as_dict) - rehydrated = Backup.from_dict(json.loads(j)) - - assert rehydrated == backup diff --git a/tests/relational/test_common_strategy.py b/tests/relational/test_common_strategy.py deleted file mode 100644 index 86051d89..00000000 --- a/tests/relational/test_common_strategy.py +++ /dev/null @@ -1,85 +0,0 @@ -import pandas as pd - -import gretel_trainer.relational.strategies.common as common - -from gretel_trainer.relational.core import RelationalData - - -def test_composite_pk_columns(tmpdir): - df = pd.DataFrame( - data={ - "letter": ["a", "a", "a", "a", "b", "b", "b", "b"], - "number": [1, 2, 3, 4, 1, 2, 3, 4], - } - ) - rel_data = RelationalData(directory=tmpdir) - rel_data.add_table( - name="table", - primary_key=["letter", "number"], - data=df, - ) - - result = common.make_composite_pks( - table_name="table", - rel_data=rel_data, - primary_key=["letter", "number"], - synth_row_count=8, - ) - - # Label-encoding turns the keys into zero-indexed contiguous integers. - # It is absolutely required that all composite keys returned are unique. - # We also ideally recreate the original data frequencies (in this case, - # two unique letters and four unique numbers). - expected_keys = [ - {"letter": 0, "number": 0}, - {"letter": 0, "number": 1}, - {"letter": 0, "number": 2}, - {"letter": 0, "number": 3}, - {"letter": 1, "number": 0}, - {"letter": 1, "number": 1}, - {"letter": 1, "number": 2}, - {"letter": 1, "number": 3}, - ] - - for expected_key in expected_keys: - assert expected_key in result - - -def test_composite_pk_columns_2(tmpdir): - df = pd.DataFrame( - data={ - "letter": ["a", "a", "a", "a", "b", "b", "b", "b"], - "number": [1, 2, 3, 4, 5, 6, 7, 8], - } - ) - rel_data = RelationalData(directory=tmpdir) - rel_data.add_table( - name="table", - primary_key=["letter", "number"], - data=df, - ) - - result = common.make_composite_pks( - table_name="table", - rel_data=rel_data, - primary_key=["letter", "number"], - synth_row_count=8, - ) - - # We create as many keys as we need - assert len(result) == 8 - - # Each combination is unique - assert len(set([str(composite_key) for composite_key in result])) == 8 - - # In this case, there are more potential unique combinations than there are synthetic rows, - # so we can't say for sure what the exact composite values will be. However, we do expect - # the original frequencies to be maintained. - synthetic_letters = [key["letter"] for key in result] - assert len(synthetic_letters) == 8 - assert set(synthetic_letters) == {0, 1} - assert len([x for x in synthetic_letters if x != 0]) == 4 - - synthetic_numbers = [key["number"] for key in result] - assert len(synthetic_numbers) == 8 - assert set(synthetic_numbers) == {0, 1, 2, 3, 4, 5, 6, 7} diff --git a/tests/relational/test_connectors.py b/tests/relational/test_connectors.py deleted file mode 100644 index 7f582e19..00000000 --- a/tests/relational/test_connectors.py +++ /dev/null @@ -1,37 +0,0 @@ -import sqlite3 -import tempfile - -import pytest - -from gretel_trainer.relational.connectors import sqlite_conn -from gretel_trainer.relational.core import MultiTableException, Scope - - -def test_extract_subsets_of_relational_data(example_dbs, tmpdir): - with tempfile.NamedTemporaryFile() as f: - con = sqlite3.connect(f.name) - cur = con.cursor() - with open(example_dbs / "ecom.sql") as sql_script: - cur.executescript(sql_script.read()) - - connector = sqlite_conn(f.name) - - with pytest.raises(MultiTableException): - connector.extract(only={"users"}, ignore={"events"}, storage_dir=tmpdir) - - only = connector.extract( - only={"users", "events", "products"}, storage_dir=tmpdir - ) - ignore = connector.extract( - ignore={"distribution_center", "order_items", "inventory_items"}, - storage_dir=tmpdir, - ) - - expected_tables = {"users", "events", "products"} - assert set(only.list_all_tables(Scope.ALL)) == expected_tables - assert set(ignore.list_all_tables(Scope.ALL)) == expected_tables - - # `products` has a foreign key to `distribution_center` in the source, but because the - # latter table was not extracted, the relationship is not recognized - assert only.get_parents("products") == [] - assert ignore.get_parents("products") == [] diff --git a/tests/relational/test_extractor.py b/tests/relational/test_extractor.py deleted file mode 100644 index 761fb88f..00000000 --- a/tests/relational/test_extractor.py +++ /dev/null @@ -1,138 +0,0 @@ -import sqlite3 -import tempfile - -from pathlib import Path -from typing import Iterable - -import pytest - -from gretel_trainer.relational.connectors import Connector, sqlite_conn -from gretel_trainer.relational.extractor import ( - _determine_sample_size, - ExtractorConfig, - TableExtractor, -) - - -def test_subset_config(): - # Can't have a target row count < -1 - with pytest.raises(ValueError): - ExtractorConfig(target_row_count=-2) - - # Concrete row count - config = ExtractorConfig(target_row_count=100) - assert _determine_sample_size(config, 200) == 100 - - # Ratio - config = ExtractorConfig(target_row_count=0.5) - assert _determine_sample_size(config, 100) == 50 - - # Entire table - config = ExtractorConfig() - assert config.entire_table - assert _determine_sample_size(config, 101) == 101 - - # Empty table - config = ExtractorConfig(target_row_count=0) - assert not config.entire_table - assert config.empty_table - assert _determine_sample_size(config, 101) == 0 - - # Can't have both only and ignore - with pytest.raises(ValueError): - ExtractorConfig(ignore={"foo"}, only={"bar"}) - - -@pytest.fixture -def connector_ecom(example_dbs) -> Iterable[Connector]: - with tempfile.NamedTemporaryFile() as f: - con = sqlite3.connect(f.name) - cur = con.cursor() - with open(example_dbs / "ecom.sql") as sql_script: - cur.executescript(sql_script.read()) - - connector = sqlite_conn(f.name) - yield connector - - -@pytest.fixture -def connector_art(example_dbs) -> Iterable[Connector]: - with tempfile.NamedTemporaryFile() as f: - con = sqlite3.connect(f.name) - cur = con.cursor() - with open(example_dbs / "art.sql") as sql_script: - cur.executescript(sql_script.read()) - - connector = sqlite_conn(f.name) - yield connector - - -def test_extract_schema(connector_ecom: Connector, tmpdir): - config = ExtractorConfig() - extractor = TableExtractor( - config=config, connector=connector_ecom, storage_dir=Path(tmpdir) - ) - extractor._extract_schema() - assert extractor.table_order == [ - "events", - "order_items", - "users", - "inventory_items", - "products", - "distribution_center", - ] - - -def test_table_session(connector_art, tmpdir): - config = ExtractorConfig() - extractor = TableExtractor( - config=config, connector=connector_art, storage_dir=Path(tmpdir) - ) - table_session = extractor._get_table_session("paintings") - assert table_session.total_row_count == 7 - assert set(table_session.columns) == {"id", "artist_id", "name"} - assert table_session.total_column_count == 3 - - -@pytest.mark.parametrize("target,expect", [(-1, 7), (3, 3), (0, 0)]) -def test_sample_table(target, expect, connector_art, tmpdir): - config = ExtractorConfig(target_row_count=target) - extractor = TableExtractor( - config=config, connector=connector_art, storage_dir=Path(tmpdir) - ) - extractor._chunk_size = 1 - meta = extractor._sample_table("paintings") - assert meta.original_row_count == 7 - assert meta.sampled_row_count == expect - assert meta.column_count == 3 - df = extractor.get_table_df("paintings") - assert len(df) == expect - - # Now we can sample from an intermediate table - meta = extractor._sample_table("artists", child_tables=["paintings"]) - assert meta.original_row_count == 4 - assert ( - 0 <= meta.sampled_row_count <= 4 - ) # could vary based on what other FKs were selected - df = extractor.get_table_df("artists") - assert 0 <= len(df) <= 4 - - # A001 should never be sampled, sorry Wassily - assert "A001" not in df["id"] - - -@pytest.mark.parametrize("sample_mode", ["random", "contiguous"]) -def test_sample_tables(sample_mode, connector_art, tmpdir): - config = ExtractorConfig(target_row_count=0.5, sample_mode=sample_mode) - extractor = TableExtractor( - config=config, - connector=connector_art, - storage_dir=Path(tmpdir), - ) - meta = extractor.sample_tables() - paintings = meta["paintings"] - assert paintings.sampled_row_count == 3 - - # artists will vary between 1 and 3 - artists = meta["artists"] - assert 1 <= artists.sampled_row_count <= 3 diff --git a/tests/relational/test_independent_strategy.py b/tests/relational/test_independent_strategy.py deleted file mode 100644 index 8e478eb6..00000000 --- a/tests/relational/test_independent_strategy.py +++ /dev/null @@ -1,442 +0,0 @@ -import os -import tempfile - -from collections import defaultdict -from unittest.mock import patch - -import pandas as pd -import pandas.testing as pdtest - -from gretel_trainer.relational.core import RelationalData -from gretel_trainer.relational.strategies.independent import IndependentStrategy - - -def test_preparing_training_data_does_not_mutate_source_data(pets): - original_tables = { - table: pets.get_table_data(table).copy() for table in pets.list_all_tables() - } - - strategy = IndependentStrategy() - - with ( - tempfile.NamedTemporaryFile() as pets_dest, - tempfile.NamedTemporaryFile() as humans_dest, - ): - strategy.prepare_training_data( - pets, {"pets": pets_dest.name, "humans": humans_dest.name} - ) - - for table in pets.list_all_tables(): - pdtest.assert_frame_equal(original_tables[table], pets.get_table_data(table)) - - -def test_prepare_training_data_removes_primary_and_foreign_keys(pets): - strategy = IndependentStrategy() - - with ( - tempfile.NamedTemporaryFile() as pets_dest, - tempfile.NamedTemporaryFile() as humans_dest, - ): - training_data = strategy.prepare_training_data( - pets, {"pets": pets_dest.name, "humans": humans_dest.name} - ) - train_pets = pd.read_csv(training_data["pets"]) - - assert set(train_pets.columns) == {"name", "age"} - - -def test_prepare_training_data_subset_of_tables(pets): - strategy = IndependentStrategy() - - with ( - tempfile.NamedTemporaryFile() as pets_dest, - tempfile.NamedTemporaryFile() as humans_dest, - ): - training_data = strategy.prepare_training_data( - pets, {"humans": humans_dest.name} - ) - assert not pd.read_csv(training_data["humans"]).empty - assert os.stat(pets_dest.name).st_size == 0 - - -def test_prepare_training_data_join_table(insurance): - strategy = IndependentStrategy() - - with ( - tempfile.NamedTemporaryFile() as beneficiary_dest, - tempfile.NamedTemporaryFile() as policies_dest, - ): - training_data = strategy.prepare_training_data( - insurance, - { - "beneficiary": beneficiary_dest.name, - "insurance_policies": policies_dest.name, - }, - ) - assert set(training_data.keys()) == {"beneficiary"} - assert not pd.read_csv(training_data["beneficiary"]).empty - assert os.stat(policies_dest.name).st_size == 0 - - -def test_retraining_a_set_of_tables_only_retrains_those_tables(ecom): - strategy = IndependentStrategy() - assert set(strategy.tables_to_retrain(["users"], ecom)) == {"users"} - assert set(strategy.tables_to_retrain(["users", "events"], ecom)) == { - "users", - "events", - } - assert set(strategy.tables_to_retrain(["products"], ecom)) == {"products"} - - -def test_table_generation_readiness(ecom): - strategy = IndependentStrategy() - - # All tables are immediately ready for generation - assert set(strategy.ready_to_generate(ecom, [], [])) == { - "users", - "events", - "distribution_center", - "products", - "inventory_items", - "order_items", - } - - # Tables that are in progress or finished are no longer ready - assert set(strategy.ready_to_generate(ecom, ["users"], ["events"])) == { - "distribution_center", - "products", - "inventory_items", - "order_items", - } - - -def test_generation_job_requests_num_records(pets, output_handler): - strategy = IndependentStrategy() - job = strategy.get_generation_job("pets", pets, 2.0, {}, "run-id", output_handler) - - assert job == {"params": {"num_records": 10}} - - -def test_post_processing_one_to_one(pets): - strategy = IndependentStrategy() - - # Models train on data with PKs and FKs removed, - # so those fields won't be present in raw results - raw_synth_tables = { - "humans": pd.DataFrame( - data={ - "name": ["Michael", "Dominique", "Dirk"], - "city": ["Chicago", "Atlanta", "Dallas"], - } - ), - "pets": pd.DataFrame( - data={ - "name": ["Bull", "Hawk", "Maverick"], - "age": [6, 0, 1], - } - ), - } - - # Normally we shuffle synthesized keys for realism, but for deterministic testing we sort instead - with patch("random.shuffle", wraps=sorted): - processed = strategy.post_process_synthetic_results( - raw_synth_tables, [], pets, 1 - ) - - # Fields from the raw results do not change - pdtest.assert_frame_equal( - processed["humans"], - pd.DataFrame( - data={ - "name": ["Michael", "Dominique", "Dirk"], - "city": ["Chicago", "Atlanta", "Dallas"], - "id": [0, 1, 2], - } - ), - ) - pdtest.assert_frame_equal( - processed["pets"], - pd.DataFrame( - data={ - "name": ["Bull", "Hawk", "Maverick"], - "age": [6, 0, 1], - "id": [0, 1, 2], - "human_id": [0, 1, 2], - } - ), - ) - - -def test_post_processing_foreign_keys_with_skewed_frequencies_and_different_size_tables( - trips, -): - strategy = IndependentStrategy() - - # Simulate a record_size_ratio of 1.5 - raw_synth_tables = { - "vehicle_types": pd.DataFrame( - data={"name": ["car", "train", "plane", "bus", "walk", "bike"]} - ), - "trips": pd.DataFrame(data={"purpose": ["w"] * 150}), - } - - processed = strategy.post_process_synthetic_results( - raw_synth_tables, [], trips, 1.5 - ) - processed_trips = processed["trips"] - - fk_values = set(processed["trips"]["vehicle_type_id"]) - assert fk_values == {0, 1, 2, 3, 4, 5} - - fk_value_counts = defaultdict(int) - for _, row in processed_trips.iterrows(): - fk_value = row["vehicle_type_id"] - fk_value_counts[fk_value] = fk_value_counts[fk_value] + 1 - - fk_value_counts = sorted(list(fk_value_counts.values())) - - assert fk_value_counts == [5, 5, 15, 30, 35, 60] - - -# In this scenario, a table (shipping_notifications) has a FK (customer_id) pointing to -# a column that is itself a FK but *not* a PK (orders.customer_id). -# (No, this is not a "perfectly normalized" schema, but it can happen in the wild.) -# We need to ensure tables have FKs synthesized in parent->child order to avoid blowing up -# due to missing columns. -def test_post_processing_fks_to_non_pks(tmpdir): - rel_data = RelationalData(directory=tmpdir) - - rel_data.add_table( - name="customers", - primary_key="id", - data=pd.DataFrame(data={"id": [1, 2], "name": ["Xavier", "Yesenia"]}), - ) - rel_data.add_table( - name="orders", - primary_key="id", - data=pd.DataFrame( - data={ - "id": [1, 2], - "customer_id": [1, 2], - "total": [42, 43], - } - ), - ) - rel_data.add_table( - name="shipping_notifications", - primary_key="id", - data=pd.DataFrame( - data={ - "id": [1, 2], - "order_id": [1, 2], - "customer_id": [1, 2], - "service": ["FedEx", "USPS"], - } - ), - ) - - # Add FKs. The third one is the critical one for this test. - rel_data.add_foreign_key_constraint( - table="orders", - constrained_columns=["customer_id"], - referred_table="customers", - referred_columns=["id"], - ) - rel_data.add_foreign_key_constraint( - table="shipping_notifications", - constrained_columns=["order_id"], - referred_table="orders", - referred_columns=["id"], - ) - rel_data.add_foreign_key_constraint( - table="shipping_notifications", - constrained_columns=["customer_id"], - referred_table="orders", - referred_columns=["customer_id"], - ) - - strategy = IndependentStrategy() - - # This dict is deliberately ordered child->parent for this unit test. - # Were it not for logic in the strategy (processing tables in parent->child order), - # this setup would cause an exception. - raw_synth_tables = { - "shipping_notifications": pd.DataFrame(data={"service": ["FedEx", "USPS"]}), - "orders": pd.DataFrame(data={"total": [55, 56]}), - "customers": pd.DataFrame(data={"name": ["Alice", "Bob"]}), - } - - processed = strategy.post_process_synthetic_results( - raw_synth_tables, [], rel_data, 1 - ) - - for table in rel_data.list_all_tables(): - assert set(processed[table].columns) == set(rel_data.get_table_columns(table)) - - -def test_post_processing_null_foreign_key(tmpdir): - rel_data = RelationalData(directory=tmpdir) - - rel_data.add_table( - name="customers", - primary_key="id", - data=pd.DataFrame(data={"id": [1, 2], "name": ["Xavier", "Yesenia"]}), - ) - rel_data.add_table( - name="events", - primary_key="id", - data=pd.DataFrame( - data={ - "id": [1, 2], - "customer_id": [1, None], - "total": [42, 43], - } - ), - ) - rel_data.add_foreign_key_constraint( - table="events", - constrained_columns=["customer_id"], - referred_table="customers", - referred_columns=["id"], - ) - - strategy = IndependentStrategy() - - raw_synth_tables = { - "events": pd.DataFrame(data={"total": [55, 56, 57, 58]}), - "customers": pd.DataFrame( - data={"name": ["Alice", "Bob", "Christina", "David"]} - ), - } - - # Patch shuffle for deterministic testing, but don't swap in `sorted` - # because that function doesn't cooperate with `None` (raises TypeError) - with patch("random.shuffle", wraps=lambda x: x): - processed = strategy.post_process_synthetic_results( - raw_synth_tables, [], rel_data, 2 - ) - - # Given 50% of source FKs are null and record_size_ratio=2, - # we expect 2/4 customer_ids to be null - pdtest.assert_frame_equal( - processed["events"], - pd.DataFrame( - data={ - "total": [55, 56, 57, 58], - "id": [0, 1, 2, 3], - "customer_id": [None, None, 0, 1], - } - ), - ) - - -def test_post_processing_null_composite_foreign_key(tmpdir): - rel_data = RelationalData(directory=tmpdir) - - rel_data.add_table( - name="customers", - primary_key="id", - data=pd.DataFrame( - data={ - "id": [1, 2], - "first": ["Albert", "Betsy"], - "last": ["Anderson", "Bond"], - } - ), - ) - rel_data.add_table( - name="events", - primary_key="id", - data=pd.DataFrame( - data={ - "id": [1, 2, 3, 4, 5], - "customer_first": ["Albert", "Betsy", None, "Betsy", None], - "customer_last": ["Anderson", "Bond", None, None, "Bond"], - "total": [42, 43, 44, 45, 46], - } - ), - ) - rel_data.add_foreign_key_constraint( - table="events", - constrained_columns=["customer_first", "customer_last"], - referred_table="customers", - referred_columns=["first", "last"], - ) - - strategy = IndependentStrategy() - - raw_synth_tables = { - "events": pd.DataFrame(data={"total": [55, 56, 57, 58, 59]}), - "customers": pd.DataFrame( - data={ - "first": ["Herbert", "Isabella", "Jack", "Kevin", "Louise"], - "last": ["Hoover", "Irvin", "Johnson", "Knight", "Lane"], - } - ), - } - - # Patch shuffle for deterministic testing - with patch("random.shuffle", wraps=sorted): - processed = strategy.post_process_synthetic_results( - raw_synth_tables, [], rel_data, 1 - ) - - # We do not create composite foreign key values with nulls, - # even if some existed in the source data. - pdtest.assert_frame_equal( - processed["events"], - pd.DataFrame( - data={ - "total": [55, 56, 57, 58, 59], - "id": [0, 1, 2, 3, 4], - "customer_first": ["Herbert", "Isabella", "Jack", "Kevin", "Louise"], - "customer_last": ["Hoover", "Irvin", "Johnson", "Knight", "Lane"], - } - ), - ) - - -def test_post_processing_with_bypass_table(insurance): - strategy = IndependentStrategy() - - raw_synth_tables = { - "beneficiary": pd.DataFrame( - data={ - "name": ["Adam", "Beth", "Chris", "Demi", "Eric"], - } - ), - "insurance_policies": pd.DataFrame(index=range(40)), - } - - processed = strategy.post_process_synthetic_results( - raw_synth_tables, [], insurance, 1 - ) - - beneficiary_ids = [0, 1, 2, 3, 4] - pdtest.assert_frame_equal( - processed["beneficiary"], - pd.DataFrame( - data={ - "name": ["Adam", "Beth", "Chris", "Demi", "Eric"], - "id": beneficiary_ids, - } - ), - ) - assert set(processed["insurance_policies"].columns) == { - "id", - "primary_beneficiary", - "secondary_beneficiary", - } - assert list(processed["insurance_policies"]["id"].values) == list(range(40)) - assert all( - [ - v in beneficiary_ids - for v in processed["insurance_policies"]["primary_beneficiary"].values - ] - ) - assert all( - [ - v in beneficiary_ids - for v in processed["insurance_policies"]["secondary_beneficiary"].values - ] - ) diff --git a/tests/relational/test_model_config.py b/tests/relational/test_model_config.py deleted file mode 100644 index 749960d2..00000000 --- a/tests/relational/test_model_config.py +++ /dev/null @@ -1,240 +0,0 @@ -import pytest - -from gretel_client.projects.models import read_model_config -from gretel_trainer.relational.core import MultiTableException -from gretel_trainer.relational.model_config import ( - assemble_configs, - get_model_key, - make_evaluate_config, - make_synthetics_config, - make_transform_config, -) - - -def test_get_model_key(): - # Returns the model key when config is valid (at least as far as model key) - assert get_model_key({"models": [{"amplify": {}}]}) == "amplify" - - # Returns None when given an invalid config - assert get_model_key({"foo": "bar"}) is None - assert get_model_key({"models": "wrong type"}) is None - assert get_model_key({"models": {"wrong": "type"}}) is None - assert get_model_key({"models": []}) is None - assert get_model_key({"models": ["wrong type"]}) is None - - -def test_evaluate_config_prepends_evaluation_type(): - config = make_evaluate_config("users", "individual") - assert config["name"] == "evaluate-individual-users" - - -def test_synthetics_config_prepends_workflow(): - config = make_synthetics_config("users", "synthetics/amplify") - assert config["name"] == "synthetics-users" - - -def test_synthetics_config_handles_noncompliant_table_names(): - config = make_synthetics_config("hello--world", "synthetics/amplify") - assert config["name"] == "synthetics-hello__world" - - -def test_transform_requires_valid_config(mutagenesis): - with pytest.raises(MultiTableException): - make_transform_config(mutagenesis, "atom", "synthetics/amplify") - - -def test_transform_v2_config_is_unaltered(mutagenesis): - tv2_config = { - "schema_version": "1.0", - "name": "original-name", - "models": [{"transform_v2": {"some": "Tv2 config"}}], - } - config = make_transform_config(mutagenesis, "atom", tv2_config) - assert config["name"] == "transforms-atom" - assert config["schema_version"] == tv2_config["schema_version"] - assert config["models"] == tv2_config["models"] - - -def test_transforms_config_prepends_workflow(mutagenesis): - config = make_transform_config(mutagenesis, "atom", "transform/default") - assert config["name"] == "transforms-atom" - - -def test_transforms_config_adds_passthrough_policy(mutagenesis): - def get_policies(config): - # default blueprint uses `transforms` model key - return config["models"][0]["transforms"]["policies"] - - original = read_model_config("transform/default") - original_policies = get_policies(original) - - xform_config = make_transform_config(mutagenesis, "atom", "transform/default") - xform_config_policies = get_policies(xform_config) - - assert len(xform_config_policies) == len(original_policies) + 1 - assert xform_config_policies[1:] == original_policies - assert xform_config_policies[0] == { - "name": "ignore-keys", - "rules": [ - { - "name": "ignore-key-columns", - "conditions": {"field_name": ["atom_id", "molecule_id"]}, - "transforms": [ - { - "type": "passthrough", - } - ], - } - ], - } - - -_ACTGAN_CONFIG = {"models": [{"actgan": {}}]} -_LSTM_CONFIG = {"models": [{"synthetics": {}}]} -_TABULAR_DP_CONFIG = {"models": [{"tabular_dp": {}}]} - - -def test_assemble_configs(ecom): - # Apply a config to all tables - configs = assemble_configs( - rel_data=ecom, - config=_ACTGAN_CONFIG, - table_specific_configs=None, - only=None, - ignore=None, - ) - assert len(configs) == len(ecom.list_all_tables()) - assert all([config == _ACTGAN_CONFIG for config in configs.values()]) - - # Limit scope to a subset of tables - configs = assemble_configs( - rel_data=ecom, - config=_ACTGAN_CONFIG, - only={"events", "users"}, - table_specific_configs=None, - ignore=None, - ) - assert len(configs) == 2 - - # Exclude a table - configs = assemble_configs( - rel_data=ecom, - config=_ACTGAN_CONFIG, - ignore={"events"}, - table_specific_configs=None, - only=None, - ) - assert len(configs) == len(ecom.list_all_tables()) - 1 - - # Cannot specify both only and ignore - with pytest.raises(MultiTableException): - assemble_configs( - rel_data=ecom, - config=_ACTGAN_CONFIG, - only={"users"}, - ignore={"events"}, - table_specific_configs=None, - ) - - # Provide table-specific configs - configs = assemble_configs( - rel_data=ecom, - config=_ACTGAN_CONFIG, - table_specific_configs={"events": _LSTM_CONFIG}, - only=None, - ignore=None, - ) - assert configs["events"] == _LSTM_CONFIG - assert all( - [ - config == _ACTGAN_CONFIG - for table, config in configs.items() - if table != "events" - ] - ) - - # Ensure no conflicts between table-specific configs and scope - with pytest.raises(MultiTableException): - assemble_configs( - rel_data=ecom, - config=_ACTGAN_CONFIG, - table_specific_configs={"events": _LSTM_CONFIG}, - ignore={"events"}, - only=None, - ) - with pytest.raises(MultiTableException): - assemble_configs( - rel_data=ecom, - config=_ACTGAN_CONFIG, - table_specific_configs={"events": _LSTM_CONFIG}, - only={"users"}, - ignore=None, - ) - - -def test_assemble_configs_json(documents, invented_tables): - # If table_specific_configs includes a producer table, we apply it to all invented tables - configs = assemble_configs( - rel_data=documents, - config=_ACTGAN_CONFIG, - table_specific_configs={"purchases": _LSTM_CONFIG}, - only=None, - ignore=None, - ) - assert configs == { - "users": _ACTGAN_CONFIG, - "payments": _ACTGAN_CONFIG, - invented_tables["purchases_root"]: _LSTM_CONFIG, - invented_tables["purchases_data_years"]: _LSTM_CONFIG, - } - - # If table_specific_configs includes a producer table AND an invented table, - # the more specific config takes precedence. - configs = assemble_configs( - rel_data=documents, - config=_ACTGAN_CONFIG, - table_specific_configs={ - "purchases": _LSTM_CONFIG, - invented_tables["purchases_data_years"]: _TABULAR_DP_CONFIG, - }, - only=None, - ignore=None, - ) - assert configs == { - "users": _ACTGAN_CONFIG, - "payments": _ACTGAN_CONFIG, - invented_tables["purchases_root"]: _LSTM_CONFIG, - invented_tables["purchases_data_years"]: _TABULAR_DP_CONFIG, - } - - # Ensure no conflicts between (invented) table-specific configs and scope - with pytest.raises(MultiTableException): - assemble_configs( - rel_data=documents, - config=_ACTGAN_CONFIG, - table_specific_configs={ - "purchases": _LSTM_CONFIG, - }, - ignore={"purchases"}, - only=None, - ) - with pytest.raises(MultiTableException): - assemble_configs( - rel_data=documents, - config=_ACTGAN_CONFIG, - table_specific_configs={ - "purchases": _LSTM_CONFIG, - }, - ignore={invented_tables["purchases_root"]}, - only=None, - ) - with pytest.raises(MultiTableException): - assemble_configs( - rel_data=documents, - config=_ACTGAN_CONFIG, - table_specific_configs={ - invented_tables["purchases_root"]: _LSTM_CONFIG, - }, - ignore={"purchases"}, - only=None, - ) diff --git a/tests/relational/test_multi_table_restore.py b/tests/relational/test_multi_table_restore.py deleted file mode 100644 index fade442b..00000000 --- a/tests/relational/test_multi_table_restore.py +++ /dev/null @@ -1,362 +0,0 @@ -import json -import os -import shutil -import tempfile - -from pathlib import Path -from typing import Optional -from unittest.mock import Mock - -import pytest -import smart_open - -import gretel_trainer.relational.backup as b - -from gretel_trainer.relational.core import MultiTableException, RelationalData -from gretel_trainer.relational.multi_table import MultiTable, SyntheticsRun - -SOURCE_ARCHIVE_ARTIFACT_ID = "gretel_abc_source_tables.tar.gz" -SOURCE_ARCHIVE_LOCAL_FILENAME = "source_tables.tar.gz" - - -def make_backup( - rel_data: RelationalData, - source_archive: Optional[str], - transforms_models: dict[str, Mock] = {}, - synthetics_models: dict[str, Mock] = {}, - synthetics_record_handlers: dict[str, Mock] = {}, -) -> b.Backup: - backup = b.Backup( - project_name="project_name", - strategy="independent", - refresh_interval=60, - source_archive=source_archive, - relational_data=b.BackupRelationalData.from_relational_data(rel_data), - ) - if len(transforms_models) > 0: - backup.transforms_train = b.BackupTransformsTrain( - model_ids={ - table: mock.model_id for table, mock in transforms_models.items() - }, - lost_contact=[], - ) - if len(synthetics_models) > 0: - backup.synthetics_train = b.BackupSyntheticsTrain( - model_ids={ - table: mock.model_id for table, mock in synthetics_models.items() - }, - lost_contact=[], - ) - if len(synthetics_record_handlers) > 0: - backup.generate = b.BackupGenerate( - identifier="run-id", - preserved=[], - record_size_ratio=1.0, - lost_contact=[], - record_handler_ids={ - table: mock.record_id - for table, mock in synthetics_record_handlers.items() - }, - ) - return backup - - -def write_backup(backup: b.Backup, out_dir: Path) -> str: - dest = out_dir / "_gretel_backup.json" - with open(dest, "w") as b: - json.dump(backup.as_dict, b) - return str(dest) - - -def create_backup( - rel_data: RelationalData, - working_dir: Path, - synthetics_models: dict[str, Mock] = {}, - synthetics_record_handlers: dict[str, Mock] = {}, - transforms_models: dict[str, Mock] = {}, -) -> str: - backup = make_backup( - rel_data, - SOURCE_ARCHIVE_ARTIFACT_ID, - transforms_models, - synthetics_models, - synthetics_record_handlers, - ) - - # Clear everything (i.e. original RelationalData source files) from the working directory so that - # we start the restore process with the dir containing just the backup file and nothing else. - shutil.rmtree(working_dir) - working_dir.mkdir(parents=True, exist_ok=True) - return write_backup(backup, working_dir) - - -def make_mock_get_artifact_handle(setup_path: Path): - def get_artifact_handle(artifact_id): - if artifact_id == SOURCE_ARCHIVE_ARTIFACT_ID: - return smart_open.open(setup_path / SOURCE_ARCHIVE_LOCAL_FILENAME, "rb") - else: - raise ValueError(f"unexpected artifact_id: {artifact_id}") - - return get_artifact_handle - - -def make_mock_get_model(models: dict[str, Mock]): - def get_model(model_id): - return models[model_id] - - return get_model - - -def make_mock_model( - name: str, - status: str, - setup_path: Path, - record_handler: Optional[Mock] = None, -) -> Mock: - model = Mock() - model.status = status - model.model_id = name - model.get_artifact_handle = make_mock_get_artifact_handle(setup_path) - model.get_record_handler.return_value = record_handler - return model - - -def make_mock_record_handler(name: str, status: str) -> Mock: - record_handler = Mock() - record_handler.status = status - record_handler.record_id = name - return record_handler - - -# Creates a source_archive.tar.gz in the temporary setup path (standing in for Gretel Cloud) -def create_standin_source_archive_artifact( - rel_data: RelationalData, setup_path: Path -) -> None: - shutil.make_archive( - base_name=str(setup_path / SOURCE_ARCHIVE_LOCAL_FILENAME).removesuffix( - ".tar.gz" - ), - format="gztar", - root_dir=rel_data.source_data_handler.dir, # type: ignore - ) - - -def configure_mocks( - project: Mock, - setup_path: Path, - working_path: Path, - models: dict[str, Mock] = {}, -) -> None: - # The working directory is always named after the project. In these tests we need the name to - # match the working path that we've configured our other mock artifacts and handlers to use. - project.name = str(working_path) - project.get_artifact_handle = make_mock_get_artifact_handle(setup_path) - project.get_model = make_mock_get_model(models) - project.artifacts = [] - - -@pytest.fixture(autouse=True) -def working_dir(output_handler): - return output_handler._working_dir - - -@pytest.fixture(autouse=True) -def testsetup_dir(): - with tempfile.TemporaryDirectory() as testsetup_dir: - yield Path(testsetup_dir) - - -def test_restore_initial_state(project, pets, working_dir, testsetup_dir): - create_standin_source_archive_artifact(pets, testsetup_dir) - configure_mocks(project, testsetup_dir, working_dir) - backup_file = create_backup(pets, working_dir) - - # Restore a MultiTable instance, starting with only the backup file present in working_dir - assert os.listdir(working_dir) == ["_gretel_backup.json"] - mt = MultiTable.restore(backup_file) - - # Backup + Source archive + (2) Source CSVs - assert len(os.listdir(working_dir)) == 4 - - # RelationalData is restored - assert os.path.exists(working_dir / "humans.csv") - assert os.path.exists(working_dir / "pets.csv") - assert mt.relational_data.debug_summary() == pets.debug_summary() - - -def test_restore_transforms(project, pets, working_dir, testsetup_dir): - transforms_models = { - "humans": make_mock_model( - name="humans", - status="completed", - setup_path=testsetup_dir, - ), - "pets": make_mock_model( - name="pets", - status="completed", - setup_path=testsetup_dir, - ), - } - - create_standin_source_archive_artifact(pets, testsetup_dir) - configure_mocks( - project, - testsetup_dir, - working_dir, - transforms_models, - ) - backup_file = create_backup(pets, working_dir, transforms_models=transforms_models) - - mt = MultiTable.restore(backup_file) - - # Transforms state is restored - assert len(mt._transforms_train.models) == 2 - assert len(mt._transforms_train.lost_contact) == 0 - - -def test_restore_synthetics_training_still_in_progress( - project, pets, working_dir, testsetup_dir -): - synthetics_models = { - "humans": make_mock_model( - name="humans", - status="active", - setup_path=testsetup_dir, - ), - "pets": make_mock_model( - name="pets", - status="pending", - setup_path=testsetup_dir, - ), - } - - create_standin_source_archive_artifact(pets, testsetup_dir) - configure_mocks( - project, - testsetup_dir, - working_dir, - synthetics_models, - ) - backup_file = create_backup(pets, working_dir, synthetics_models=synthetics_models) - - with pytest.raises(MultiTableException): - MultiTable.restore(backup_file) - - -def test_restore_training_complete(project, pets, working_dir, testsetup_dir): - synthetics_models = { - "humans": make_mock_model( - name="humans", - status="completed", - setup_path=testsetup_dir, - ), - "pets": make_mock_model( - name="pets", - status="completed", - setup_path=testsetup_dir, - ), - } - - create_standin_source_archive_artifact(pets, testsetup_dir) - configure_mocks( - project, - testsetup_dir, - working_dir, - synthetics_models, - ) - backup_file = create_backup( - pets, - working_dir, - synthetics_models=synthetics_models, - ) - - mt = MultiTable.restore(backup_file) - - # Training state is restored - assert len(mt._synthetics_train.models) == 2 - - -def test_restore_training_one_failed(project, pets, working_dir, testsetup_dir): - synthetics_models = { - "humans": make_mock_model( - name="humans", - status="error", - setup_path=testsetup_dir, - ), - "pets": make_mock_model( - name="pets", - status="completed", - setup_path=testsetup_dir, - ), - } - - create_standin_source_archive_artifact(pets, testsetup_dir) - configure_mocks( - project, - testsetup_dir, - working_dir, - synthetics_models, - ) - backup_file = create_backup( - pets, - working_dir, - synthetics_models=synthetics_models, - ) - - mt = MultiTable.restore(backup_file) - - # Training state is restored - assert len(mt._synthetics_train.models) == 2 - - -def test_restore_generate_completed(project, pets, working_dir, testsetup_dir): - synthetics_record_handlers = { - "humans": make_mock_record_handler(name="humans", status="completed"), - "pets": make_mock_record_handler(name="pets", status="completed"), - } - - synthetics_models = { - "humans": make_mock_model( - name="humans", - status="completed", - setup_path=testsetup_dir, - record_handler=synthetics_record_handlers["humans"], - ), - "pets": make_mock_model( - name="pets", - status="completed", - setup_path=testsetup_dir, - record_handler=synthetics_record_handlers["pets"], - ), - } - - create_standin_source_archive_artifact(pets, testsetup_dir) - configure_mocks( - project, - testsetup_dir, - working_dir, - synthetics_models, - ) - backup_file = create_backup( - pets, - working_dir, - synthetics_models=synthetics_models, - synthetics_record_handlers=synthetics_record_handlers, - ) - - mt = MultiTable.restore(backup_file) - - # Generate task state is restored - assert mt._synthetics_run == SyntheticsRun( - identifier="run-id", - preserved=[], - record_size_ratio=1.0, - lost_contact=[], - record_handlers=synthetics_record_handlers, - ) - # but note we don't (re)set synthetic_output_tables or evaluations - assert len(mt.synthetic_output_tables) == 0 - assert mt.evaluations["humans"].individual_sqs is None - assert mt.evaluations["humans"].cross_table_sqs is None - assert mt.evaluations["pets"].individual_sqs is None - assert mt.evaluations["pets"].cross_table_sqs is None diff --git a/tests/relational/test_output_handler.py b/tests/relational/test_output_handler.py deleted file mode 100644 index 164e5c2d..00000000 --- a/tests/relational/test_output_handler.py +++ /dev/null @@ -1,21 +0,0 @@ -def test_uploads_path_to_project_and_stores_artifact_key(output_handler, pets): - project = output_handler._project - project.upload_artifact.return_value = "artifact_key" - - output_handler.save_sources(pets) - - project.upload_artifact.assert_called_once() - assert output_handler.get_source_archive() == "artifact_key" - - -def test_overwrites_project_artifacts(output_handler, pets): - output_handler._source_archive = "first_key" - - project = output_handler._project - project.upload_artifact.return_value = "second_key" - - output_handler.save_sources(pets) - - project.upload_artifact.assert_called_once() - project.delete_artifact.assert_called_once_with("first_key") - assert output_handler.get_source_archive() == "second_key" diff --git a/tests/relational/test_relational_data.py b/tests/relational/test_relational_data.py deleted file mode 100644 index b1a94e69..00000000 --- a/tests/relational/test_relational_data.py +++ /dev/null @@ -1,357 +0,0 @@ -import pandas as pd -import pytest - -from gretel_trainer.relational.core import MultiTableException - - -def test_ecommerce_relational_data(ecom): - assert ecom.get_parents("users") == [] - assert ecom.get_parents("events") == ["users"] - assert set(ecom.get_parents("inventory_items")) == { - "products", - "distribution_center", - } - - # get_parents goes back one generation, - # get_ancestors goes back all generations - assert set(ecom.get_parents("order_items")) == { - "users", - "inventory_items", - } - assert set(ecom.get_ancestors("order_items")) == { - "users", - "inventory_items", - "products", - "distribution_center", - } - - # get_descendants goes forward all generations - assert set(ecom.get_descendants("products")) == { - "inventory_items", - "order_items", - } - - -def test_mutagenesis_relational_data(mutagenesis): - assert mutagenesis.get_parents("bond") == ["atom"] - assert mutagenesis.get_parents("atom") == ["molecule"] - - assert mutagenesis.get_primary_key("bond") == ["atom1_id", "atom2_id"] - assert mutagenesis.get_primary_key("atom") == ["atom_id"] - - assert set(mutagenesis.get_all_key_columns("bond")) == {"atom1_id", "atom2_id"} - assert set(mutagenesis.get_all_key_columns("atom")) == {"atom_id", "molecule_id"} - - -def test_row_count(art): - assert art.get_table_row_count("artists") == 4 - assert art.get_table_row_count("paintings") == 7 - - -def test_column_metadata(pets, tmpfile): - assert pets.get_table_columns("humans") == ["id", "name", "city"] - - # Name is a highly unique categorical field, so is excluded - assert pets.get_safe_ancestral_seed_columns("humans") == {"id", "city"} - - # Update the table data such that: - # - id is highly unique categorical (_ to force string instead of int), but still the PK - # - name is no longer highly unique - # - city is highly NaN - pd.DataFrame( - data={ - "id": ["1_", "2_", "3_"], - "name": ["n", "n", "n"], - "city": [None, None, "Chicago"], - } - ).to_csv(tmpfile.name, index=False) - pets.update_table_data( - "humans", - tmpfile.name, - ) - - assert pets.get_safe_ancestral_seed_columns("humans") == {"id", "name"} - - # Resetting the primary key refreshes the cache state - # In this case, since id is no longer the PK and is highly unique, it is excluded - pets.set_primary_key(table="humans", primary_key=None) - assert pets.get_safe_ancestral_seed_columns("humans") == {"name"} - - # Reset back to normal - pets.set_primary_key(table="humans", primary_key="id") - - # Setting a column as a foreign key ensures it is included - pets.add_foreign_key_constraint( - table="humans", - constrained_columns=["city"], - referred_table="pets", - referred_columns=["id"], - ) - assert pets.get_safe_ancestral_seed_columns("humans") == {"id", "name", "city"} - - # Removing a foreign key refreshes the cache state - pets.remove_foreign_key_constraint("humans", ["city"]) - assert pets.get_safe_ancestral_seed_columns("humans") == {"id", "name"} - - -def test_adding_and_removing_foreign_keys(pets): - # pets has a foreign key defined out of the box. - # First lets successfully remove it and re-add it. - assert len(pets.get_foreign_keys("pets")) == 1 - - pets.remove_foreign_key_constraint(table="pets", constrained_columns=["human_id"]) - assert len(pets.get_foreign_keys("pets")) == 0 - - pets.add_foreign_key_constraint( - table="pets", - constrained_columns=["human_id"], - referred_table="humans", - referred_columns=["id"], - ) - assert len(pets.get_foreign_keys("pets")) == 1 - - # Now we'll make some assertions about our defense - - # Cannot add to an unrecognized table - with pytest.raises(MultiTableException): - pets.add_foreign_key_constraint( - table="unrecognized", - constrained_columns=["user_id"], - referred_table="humans", - referred_columns=["id"], - ) - - # Cannot add to an unrecognized referred table - with pytest.raises(MultiTableException): - pets.add_foreign_key_constraint( - table="pets", - constrained_columns=["human_id"], - referred_table="unrecognized", - referred_columns=["id"], - ) - - # Cannot add unrecognized columns - with pytest.raises(MultiTableException): - pets.add_foreign_key_constraint( - table="pets", - constrained_columns=["human_id"], - referred_table="humans", - referred_columns=["unrecognized"], - ) - with pytest.raises(MultiTableException): - pets.add_foreign_key_constraint( - table="pets", - constrained_columns=["unrecognized"], - referred_table="humans", - referred_columns=["id"], - ) - - # Cannot remove from unrecognized table - with pytest.raises(MultiTableException): - pets.remove_foreign_key_constraint( - table="unrecognized", constrained_columns=["id"] - ) - - # Cannot remove a non-existent key - with pytest.raises(MultiTableException): - pets.remove_foreign_key_constraint(table="pets", constrained_columns=["id"]) - - -def test_set_primary_key(ecom): - assert ecom.get_primary_key("users") == ["id"] - - ecom.set_primary_key(table="users", primary_key=None) - assert ecom.get_primary_key("users") == [] - - ecom.set_primary_key(table="users", primary_key=["first_name", "last_name"]) - assert ecom.get_primary_key("users") == ["first_name", "last_name"] - - ecom.set_primary_key(table="users", primary_key="id") - assert ecom.get_primary_key("users") == ["id"] - - # Can't set primary key on an unknown table - with pytest.raises(MultiTableException): - ecom.set_primary_key(table="not_a_table", primary_key="id") - - # Can't set primary key to a non-existent column - with pytest.raises(MultiTableException): - ecom.set_primary_key(table="users", primary_key="not_a_column") - - -def test_get_subset_of_data(pets): - normal_length = len(pets.get_table_data("humans")) - subset = pets.get_table_data("humans", ["name", "city"]) - assert list(subset.columns) == ["name", "city"] - assert len(subset) == normal_length - - -def test_list_tables_parents_before_children(ecom): - def in_order(col, t1, t2): - return col.index(t1) < col.index(t2) - - tables = ecom.list_tables_parents_before_children() - assert in_order(tables, "users", "events") - assert in_order(tables, "distribution_center", "products") - assert in_order(tables, "distribution_center", "inventory_items") - assert in_order(tables, "products", "inventory_items") - assert in_order(tables, "inventory_items", "order_items") - assert in_order(tables, "users", "order_items") - - -def test_detect_cycles(ecom): - assert ecom.foreign_key_cycles == [] - - ecom.add_foreign_key_constraint( - table="users", - constrained_columns=["first_name"], - referred_table="users", - referred_columns=["last_name"], - ) - ecom.debug_summary() - - assert ecom.foreign_key_cycles == [["users"]] - assert "indeterminate" in ecom.debug_summary()["max_depth"] - - ecom.add_foreign_key_constraint( - table="users", - constrained_columns=["first_name"], - referred_table="events", - referred_columns=["user_id"], - ) - - sorted_cycles = sorted([sorted(cycle) for cycle in ecom.foreign_key_cycles]) - assert sorted_cycles == [["events", "users"], ["users"]] - - -def test_debug_summary(ecom, mutagenesis): - assert ecom.debug_summary() == { - "foreign_key_count": 6, - "max_depth": 3, - "public_table_count": 6, - "invented_table_count": 0, - "tables": { - "users": { - "column_count": 3, - "primary_key": ["id"], - "foreign_key_count": 0, - "foreign_keys": [], - "is_invented_table": False, - }, - "events": { - "column_count": 4, - "primary_key": ["id"], - "foreign_key_count": 1, - "foreign_keys": [ - { - "columns": ["user_id"], - "parent_table_name": "users", - "parent_columns": ["id"], - } - ], - "is_invented_table": False, - }, - "distribution_center": { - "column_count": 2, - "primary_key": ["id"], - "foreign_key_count": 0, - "foreign_keys": [], - "is_invented_table": False, - }, - "products": { - "column_count": 4, - "primary_key": ["id"], - "foreign_key_count": 1, - "foreign_keys": [ - { - "columns": ["distribution_center_id"], - "parent_table_name": "distribution_center", - "parent_columns": ["id"], - } - ], - "is_invented_table": False, - }, - "inventory_items": { - "column_count": 5, - "primary_key": ["id"], - "foreign_key_count": 2, - "foreign_keys": [ - { - "columns": ["product_id"], - "parent_table_name": "products", - "parent_columns": ["id"], - }, - { - "columns": ["product_distribution_center_id"], - "parent_table_name": "distribution_center", - "parent_columns": ["id"], - }, - ], - "is_invented_table": False, - }, - "order_items": { - "column_count": 5, - "primary_key": ["id"], - "foreign_key_count": 2, - "foreign_keys": [ - { - "columns": ["user_id"], - "parent_table_name": "users", - "parent_columns": ["id"], - }, - { - "columns": ["inventory_item_id"], - "parent_table_name": "inventory_items", - "parent_columns": ["id"], - }, - ], - "is_invented_table": False, - }, - }, - } - - assert mutagenesis.debug_summary() == { - "foreign_key_count": 3, - "max_depth": 2, - "public_table_count": 3, - "invented_table_count": 0, - "tables": { - "bond": { - "column_count": 3, - "primary_key": ["atom1_id", "atom2_id"], - "foreign_key_count": 2, - "foreign_keys": [ - { - "columns": ["atom1_id"], - "parent_table_name": "atom", - "parent_columns": ["atom_id"], - }, - { - "columns": ["atom2_id"], - "parent_table_name": "atom", - "parent_columns": ["atom_id"], - }, - ], - "is_invented_table": False, - }, - "atom": { - "column_count": 4, - "primary_key": ["atom_id"], - "foreign_key_count": 1, - "foreign_keys": [ - { - "columns": ["molecule_id"], - "parent_table_name": "molecule", - "parent_columns": ["molecule_id"], - } - ], - "is_invented_table": False, - }, - "molecule": { - "column_count": 2, - "primary_key": ["molecule_id"], - "foreign_key_count": 0, - "foreign_keys": [], - "is_invented_table": False, - }, - }, - } diff --git a/tests/relational/test_relational_data_with_json.py b/tests/relational/test_relational_data_with_json.py deleted file mode 100644 index 491ff86d..00000000 --- a/tests/relational/test_relational_data_with_json.py +++ /dev/null @@ -1,1171 +0,0 @@ -import itertools -import re -import tempfile - -import pandas as pd -import pandas.testing as pdtest -import pytest - -from gretel_trainer.relational.core import ForeignKey, RelationalData, Scope -from gretel_trainer.relational.json import generate_unique_table_name, get_json_columns - - -@pytest.fixture -def bball(tmpdir): - bball_jsonl = """ - {"name": "LeBron James", "age": 38, "draft": {"year": 2003}, "teams": ["Cavaliers", "Heat", "Lakers"], "suspensions": []} - {"name": "Steph Curry", "age": 35, "draft": {"year": 2009, "college": "Davidson"}, "teams": ["Warriors"], "suspensions": []} - """ - bball_df = pd.read_json(bball_jsonl, lines=True) - - rel_data = RelationalData(directory=tmpdir) - rel_data.add_table(name="bball", primary_key=None, data=bball_df) - - return rel_data - - -@pytest.fixture -def deeply_nested(tmpdir): - jsonl = """ - {"hello1":"world1","level1_long_property_name":[{"hello2":"world2","level2_long_property_name":[{"hello3":"world3","level3_long_property_name":[{"hello4":"world4","level4_long_property_name":[{"hello5":"world5","level5_long_property_name":[{"hello6":"world6","level6_long_property_name":[{"hello7":"world7","level7_long_property_name":[{"hello8":"world8","level8_long_property_name":[{"hello9":"world9"}]}]}]}]}]}]}]}]} - """ - df = pd.read_json(jsonl, lines=True) - - rel_data = RelationalData(directory=tmpdir) - rel_data.add_table(name="deeply_nested", primary_key=None, data=df) - - return rel_data - - -def test_list_json_cols(documents, bball): - assert get_json_columns(documents.get_table_data("users")) == [] - assert get_json_columns(documents.get_table_data("purchases")) == ["data"] - - assert set(get_json_columns(bball.get_table_data("bball"))) == { - "draft", - "teams", - "suspensions", - } - - -def test_json_columns_produce_invented_flattened_tables(documents, invented_tables): - pdtest.assert_frame_equal( - documents.get_table_data(invented_tables["purchases_root"]), - pd.DataFrame( - data={ - "~PRIMARY_KEY_ID~": [0, 1, 2, 3, 4, 5], - "id": [1, 2, 3, 4, 5, 6], - "user_id": [1, 2, 2, 3, 3, 3], - "data>item": ["pen", "paint", "ink", "pen", "paint", "ink"], - "data>cost": [100, 100, 100, 100, 100, 100], - "data>details>color": ["red", "red", "red", "blue", "blue", "blue"], - } - ), - check_like=True, - ) - - pdtest.assert_frame_equal( - documents.get_table_data(invented_tables["purchases_data_years"]), - pd.DataFrame( - data={ - "content": [2023, 2023, 2022, 2020, 2019, 2021], - "array~order": [0, 0, 1, 0, 1, 0], - "~PRIMARY_KEY_ID~": [0, 1, 2, 3, 4, 5], - "purchases~id": [0, 1, 1, 2, 2, 4], - } - ), - check_like=True, - check_dtype=False, # Without this, test fails asserting dtype mismatch in `content` field (object vs. int) - ) - - assert documents.get_foreign_keys(invented_tables["purchases_data_years"]) == [ - ForeignKey( - table_name=invented_tables["purchases_data_years"], - columns=["purchases~id"], - parent_table_name=invented_tables["purchases_root"], - parent_columns=["~PRIMARY_KEY_ID~"], - ) - ] - - -def test_list_tables_accepts_various_scopes(documents, invented_tables): - # PUBLIC reflects the user's source - assert set(documents.list_all_tables(Scope.PUBLIC)) == { - "users", - "purchases", - "payments", - } - - # MODELABLE replaces any source tables containing JSON with the invented tables - assert set(documents.list_all_tables(Scope.MODELABLE)) == { - "users", - "payments", - invented_tables["purchases_root"], - invented_tables["purchases_data_years"], - } - - # EVALUATABLE is similar to MODELABLE, but omits invented child tables—we only evaluate the root invented table - assert set(documents.list_all_tables(Scope.EVALUATABLE)) == { - "users", - "payments", - invented_tables["purchases_root"], - } - - # INVENTED returns only tables invented from source tables with JSON - assert set(documents.list_all_tables(Scope.INVENTED)) == { - invented_tables["purchases_root"], - invented_tables["purchases_data_years"], - } - - # ALL returns every table name, including both source-with-JSON tables and those invented from such tables - assert set(documents.list_all_tables(Scope.ALL)) == { - "users", - "purchases", - "payments", - invented_tables["purchases_root"], - invented_tables["purchases_data_years"], - } - - # Default scope is MODELABLE - assert set(documents.list_all_tables()) == set( - documents.list_all_tables(Scope.MODELABLE) - ) - - -def test_get_modelable_table_names(documents, invented_tables): - # Given a source-with-JSON name, returns the tables invented from that source - assert set(documents.get_modelable_table_names("purchases")) == { - invented_tables["purchases_root"], - invented_tables["purchases_data_years"], - } - - # Invented tables are modelable - assert documents.get_modelable_table_names(invented_tables["purchases_root"]) == [ - invented_tables["purchases_root"] - ] - assert documents.get_modelable_table_names( - invented_tables["purchases_data_years"] - ) == [invented_tables["purchases_data_years"]] - - # Unknown tables return empty list - assert documents.get_modelable_table_names("nonsense") == [] - - -def test_get_modelable_names_ignores_empty_mapped_tables(bball, invented_tables): - # The `suspensions` column in the source data contained empty lists for all records. - # The normalization process transforms that into a standalone, empty table. - # We need to hold onto that table name to support denormalizing back to the original - # source data shape. It is therefore present when listing ALL tables... - assert set(bball.list_all_tables(Scope.ALL)) == { - "bball", - invented_tables["bball_root"], - invented_tables["bball_suspensions"], - invented_tables["bball_teams"], - } - - # ...and the producer metadata is aware of it... - assert set(bball.get_producer_metadata("bball").table_names) == { - invented_tables["bball_root"], - invented_tables["bball_suspensions"], - invented_tables["bball_teams"], - } - - # ...BUT most clients only care about invented tables that can be modeled - # (i.e. that contain data), so the empty table does not appear in these contexts: - assert set(bball.get_modelable_table_names("bball")) == { - invented_tables["bball_root"], - invented_tables["bball_teams"], - } - assert set(bball.list_all_tables()) == { - invented_tables["bball_root"], - invented_tables["bball_teams"], - } - - -def test_invented_json_column_names_documents(documents, invented_tables): - # The root invented table adds columns for dictionary properties lifted from nested JSON objects - assert documents.get_table_columns(invented_tables["purchases_root"]) == [ - "~PRIMARY_KEY_ID~", - "id", - "user_id", - "data>item", - "data>cost", - "data>details>color", - ] - - # JSON lists lead to invented child tables. These tables store the original content, - # a new primary key, a foreign key back to the parent, and the original array index - assert documents.get_table_columns(invented_tables["purchases_data_years"]) == [ - "~PRIMARY_KEY_ID~", - "purchases~id", - "content", - "array~order", - ] - - -def test_invented_json_column_names_bball(bball, invented_tables): - # If the source table does not have a primary key defined, one is created on the root invented table - assert bball.get_table_columns(invented_tables["bball_root"]) == [ - "~PRIMARY_KEY_ID~", - "name", - "age", - "draft>year", - "draft>college", - ] - - -def test_set_some_primary_key_to_none(static_suffix, documents, invented_tables): - # The producer table has a single column primary key, - # so the root invented table has a composite key that includes the source PK and an invented column - assert documents.get_primary_key("purchases") == ["id"] - assert documents.get_primary_key(invented_tables["purchases_root"]) == [ - "id", - "~PRIMARY_KEY_ID~", - ] - - # Setting an existing primary key to None puts us in the correct state - assert len(documents.list_all_tables(Scope.ALL)) == 5 - original_payments_fks = documents.get_foreign_keys("payments") - - # Reset the make_suffix iterator back to original count since set_primary_key will call it again - # once for each invented table. - static_suffix.side_effect = itertools.count(start=1) - - # Setting the primary key causes json invented tables to be dropped and reingested - documents.set_primary_key(table="purchases", primary_key=None) - assert len(documents.list_all_tables(Scope.ALL)) == 5 - assert documents.get_primary_key("purchases") == [] - assert documents.get_primary_key(invented_tables["purchases_root"]) == [ - "~PRIMARY_KEY_ID~" - ] - assert documents.get_foreign_keys(invented_tables["purchases_data_years"]) == [ - ForeignKey( - table_name=invented_tables["purchases_data_years"], - columns=["purchases~id"], - parent_table_name=invented_tables["purchases_root"], - parent_columns=["~PRIMARY_KEY_ID~"], - ) - ] - assert documents.get_foreign_keys("payments") == original_payments_fks - - -def test_set_none_primary_key_to_some_value(static_suffix, bball, invented_tables): - # The producer table has no primary key, - # so the root invented table has a single invented key column - assert bball.get_primary_key("bball") == [] - assert bball.get_primary_key(invented_tables["bball_root"]) == ["~PRIMARY_KEY_ID~"] - - # Setting a None primary key to some column puts us in the correct state - assert len(bball.list_all_tables(Scope.ALL)) == 4 - - # Reset the make_suffix iterator back to original count since set_primary_key will call it again - # once for each invented table. - static_suffix.side_effect = itertools.count(start=1) - - bball.set_primary_key(table="bball", primary_key="name") - assert len(bball.list_all_tables(Scope.ALL)) == 4 - assert bball.get_primary_key("bball") == ["name"] - assert bball.get_primary_key(invented_tables["bball_root"]) == [ - "name", - "~PRIMARY_KEY_ID~", - ] - assert bball.get_foreign_keys(invented_tables["bball_suspensions"]) == [ - ForeignKey( - table_name=invented_tables["bball_suspensions"], - columns=["bball~id"], - parent_table_name=invented_tables["bball_root"], - parent_columns=["~PRIMARY_KEY_ID~"], - ) - ] - - -def test_foreign_keys(documents, invented_tables): - # Foreign keys from the source-with-JSON table are present on the root invented table - assert documents.get_foreign_keys("purchases") == documents.get_foreign_keys( - invented_tables["purchases_root"] - ) - - # The root invented table name is used in the ForeignKey - assert documents.get_foreign_keys("purchases") == [ - ForeignKey( - table_name=invented_tables["purchases_root"], - columns=["user_id"], - parent_table_name="users", - parent_columns=["id"], - ) - ] - - # Invented children point to invented parents - assert documents.get_foreign_keys(invented_tables["purchases_data_years"]) == [ - ForeignKey( - table_name=invented_tables["purchases_data_years"], - columns=["purchases~id"], - parent_table_name=invented_tables["purchases_root"], - parent_columns=["~PRIMARY_KEY_ID~"], - ) - ] - - # Source children of the source-with-JSON table point to the root invented table - assert documents.get_foreign_keys("payments") == [ - ForeignKey( - table_name="payments", - columns=["purchase_id"], - parent_table_name=invented_tables["purchases_root"], - parent_columns=["id"], - ) - ] - - # You can request public/user-supplied names instead of the default invented table names - assert documents.get_foreign_keys("payments", rename_invented_tables=True) == [ - ForeignKey( - table_name="payments", - columns=["purchase_id"], - parent_table_name="purchases", - parent_columns=["id"], - ) - ] - assert documents.get_foreign_keys("purchases", rename_invented_tables=True) == [ - ForeignKey( - table_name="purchases", - columns=["user_id"], - parent_table_name="users", - parent_columns=["id"], - ) - ] - - # Removing a foreign key from the source-with-JSON table updates the root invented table - documents.remove_foreign_key_constraint( - table="purchases", constrained_columns=["user_id"] - ) - assert documents.get_foreign_keys("purchases") == [] - assert documents.get_foreign_keys(invented_tables["purchases_root"]) == [] - - -def test_update_data_with_existing_json_to_new_json( - static_suffix, documents, invented_tables -): - new_purchases_jsonl = """ - {"id": 1, "user_id": 1, "data": {"item": "watercolor", "cost": 200, "details": {"color": "aquamarine"}, "years": [1999]}} - {"id": 2, "user_id": 2, "data": {"item": "watercolor", "cost": 200, "details": {"color": "aquamarine"}, "years": [1999]}} - {"id": 3, "user_id": 2, "data": {"item": "watercolor", "cost": 200, "details": {"color": "aquamarine"}, "years": [1999]}} - {"id": 4, "user_id": 3, "data": {"item": "charcoal", "cost": 200, "details": {"color": "aquamarine"}, "years": [1998]}} - {"id": 5, "user_id": 3, "data": {"item": "charcoal", "cost": 200, "details": {"color": "aquamarine"}, "years": [1998]}} - {"id": 6, "user_id": 3, "data": {"item": "charcoal", "cost": 200, "details": {"color": "aquamarine"}, "years": [1998]}} - """ - new_purchases_df = pd.read_json(new_purchases_jsonl, lines=True) - - # Reset the make_suffix iterator back to original count since make_suffix will be called again - # once for each invented table. - static_suffix.side_effect = itertools.count(start=1) - - documents.update_table_data("purchases", data=new_purchases_df) - - assert len(documents.list_all_tables(Scope.ALL)) == 5 - assert len(documents.list_all_tables(Scope.MODELABLE)) == 4 - - expected = { - invented_tables["purchases_root"]: pd.DataFrame( - data={ - "~PRIMARY_KEY_ID~": [0, 1, 2, 3, 4, 5], - "id": [1, 2, 3, 4, 5, 6], - "user_id": [1, 2, 2, 3, 3, 3], - "data>item": [ - "watercolor", - "watercolor", - "watercolor", - "charcoal", - "charcoal", - "charcoal", - ], - "data>cost": [200, 200, 200, 200, 200, 200], - "data>details>color": [ - "aquamarine", - "aquamarine", - "aquamarine", - "aquamarine", - "aquamarine", - "aquamarine", - ], - } - ), - invented_tables["purchases_data_years"]: pd.DataFrame( - data={ - "content": [1999, 1999, 1999, 1998, 1998, 1998], - "array~order": [0, 0, 0, 0, 0, 0], - "~PRIMARY_KEY_ID~": [0, 1, 2, 3, 4, 5], - "purchases~id": [0, 1, 2, 3, 4, 5], - } - ), - } - - pdtest.assert_frame_equal( - documents.get_table_data(invented_tables["purchases_root"]), - expected[invented_tables["purchases_root"]], - check_like=True, - ) - - pdtest.assert_frame_equal( - documents.get_table_data(invented_tables["purchases_data_years"]), - expected[invented_tables["purchases_data_years"]], - check_like=True, - check_dtype=False, # Without this, test fails asserting dtype mismatch in `content` field (object vs. int) - ) - - # User-supplied child table FK still exists - assert documents.get_foreign_keys("payments") == [ - ForeignKey( - table_name="payments", - columns=["purchase_id"], - parent_table_name=invented_tables["purchases_root"], - parent_columns=["id"], - ) - ] - - -def test_update_data_existing_json_to_no_json(documents): - new_purchases_df = pd.DataFrame( - data={ - "id": [1, 2, 3, 4, 5, 6], - "user_id": [1, 2, 2, 3, 3, 3], - "data": ["pen", "paint", "ink", "pen", "paint", "ink"], - } - ) - - documents.update_table_data("purchases", data=new_purchases_df) - - assert len(documents.list_all_tables(Scope.ALL)) == 3 - - pdtest.assert_frame_equal( - documents.get_table_data("purchases"), - new_purchases_df, - check_like=True, - ) - - assert documents.get_foreign_keys("payments") == [ - ForeignKey( - table_name="payments", - columns=["purchase_id"], - parent_table_name="purchases", - parent_columns=["id"], - ) - ] - - -def test_update_data_existing_flat_to_json(static_suffix, documents, invented_tables): - # Build up a RelationalData instance that basically mirrors documents, - # but purchases is flat to start and thus there are no RelationalJson instances - flat_purchases_df = pd.DataFrame( - data={ - "id": [1, 2, 3, 4, 5, 6], - "user_id": [1, 2, 2, 3, 3, 3], - "data": ["pen", "paint", "ink", "pen", "paint", "ink"], - } - ) - with tempfile.TemporaryDirectory() as tmpdir: - rel_data = RelationalData(directory=tmpdir) - rel_data.add_table( - name="users", primary_key="id", data=documents.get_table_data("users") - ) - rel_data.add_table(name="purchases", primary_key="id", data=flat_purchases_df) - rel_data.add_table( - name="payments", primary_key="id", data=documents.get_table_data("payments") - ) - rel_data.add_foreign_key_constraint( - table="purchases", - constrained_columns=["user_id"], - referred_table="users", - referred_columns=["id"], - ) - rel_data.add_foreign_key_constraint( - table="payments", - constrained_columns=["purchase_id"], - referred_table="purchases", - referred_columns=["id"], - ) - assert len(rel_data.list_all_tables(Scope.ALL)) == 3 - assert len(rel_data.list_all_tables(Scope.MODELABLE)) == 3 - - # Reset the make_suffix iterator back to original count since make_suffix will be called again - # once for each invented table. - static_suffix.side_effect = itertools.count(start=1) - rel_data.update_table_data("purchases", documents.get_table_data("purchases")) - - assert set(rel_data.list_all_tables(Scope.ALL)) == { - "users", - "purchases", - invented_tables["purchases_root"], - invented_tables["purchases_data_years"], - "payments", - } - # the original purchases table is no longer flat, nor (therefore) MODELABLE - assert set(rel_data.list_all_tables(Scope.MODELABLE)) == { - "users", - invented_tables["purchases_root"], - invented_tables["purchases_data_years"], - "payments", - } - assert rel_data.get_foreign_keys("payments") == [ - ForeignKey( - table_name="payments", - columns=["purchase_id"], - parent_table_name=invented_tables[ - "purchases_root" - ], # The foreign key now points to the root invented table - parent_columns=["id"], - ) - ] - - -# Simulates output tables from MultiTable transforms or synthetics, which will only include the MODELABLE tables -@pytest.fixture() -def mt_output_tables(invented_tables): - return { - "users": pd.DataFrame( - data={ - "id": [1, 2, 3], - "name": ["Rob", "Sam", "Tim"], - } - ), - "payments": pd.DataFrame( - data={ - "id": [1, 2, 3, 4], - "amount": [10, 10, 10, 10], - "purchase_id": [1, 2, 3, 4], - } - ), - invented_tables["purchases_root"]: pd.DataFrame( - data={ - "~PRIMARY_KEY_ID~": [0, 1, 2, 3], - "id": [1, 2, 3, 4], - "user_id": [1, 1, 2, 3], - "data>item": ["pen", "paint", "ink", "ink"], - "data>cost": [18, 19, 20, 21], - "data>details>color": ["blue", "yellow", "pink", "orange"], - } - ), - invented_tables["purchases_data_years"]: pd.DataFrame( - data={ - "content": [2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007], - "~PRIMARY_KEY_ID~": [0, 1, 2, 3, 4, 5, 6, 7], - "purchases~id": [0, 0, 0, 1, 2, 2, 3, 3], - "array~order": [0, 1, 2, 0, 0, 1, 0, 1], - } - ), - } - - -def test_restoring_output_tables_to_original_shape(documents, mt_output_tables): - restored_tables = documents.restore(mt_output_tables) - - # We expect our restored tables to match the PUBLIC tables - assert len(restored_tables) == 3 - expected = { - "users": mt_output_tables["users"], - "payments": mt_output_tables["payments"], - "purchases": pd.DataFrame( - data={ - "id": [1, 2, 3, 4], - "user_id": [1, 1, 2, 3], - "data": [ - { - "item": "pen", - "cost": 18, - "details": {"color": "blue"}, - "years": [2000, 2001, 2002], - }, - { - "item": "paint", - "cost": 19, - "details": {"color": "yellow"}, - "years": [2003], - }, - { - "item": "ink", - "cost": 20, - "details": {"color": "pink"}, - "years": [2004, 2005], - }, - { - "item": "ink", - "cost": 21, - "details": {"color": "orange"}, - "years": [2006, 2007], - }, - ], - } - ), - } - - for t, df in restored_tables.items(): - pdtest.assert_frame_equal(df, expected[t]) - - -def test_restore_with_incomplete_tableset(documents, mt_output_tables, invented_tables): - without_invented_root = { - k: v - for k, v in mt_output_tables.items() - if k != invented_tables["purchases_root"] - } - - without_invented_child = { - k: v - for k, v in mt_output_tables.items() - if k != invented_tables["purchases_data_years"] - } - - restored_without_invented_root = documents.restore(without_invented_root) - restored_without_invented_child = documents.restore(without_invented_child) - - # non-JSON-related tables are fine/unaffected - pdtest.assert_frame_equal( - restored_without_invented_child["users"], mt_output_tables["users"] - ) - pdtest.assert_frame_equal( - restored_without_invented_child["payments"], mt_output_tables["payments"] - ) - pdtest.assert_frame_equal( - restored_without_invented_root["users"], mt_output_tables["users"] - ) - pdtest.assert_frame_equal( - restored_without_invented_root["payments"], mt_output_tables["payments"] - ) - - # If the invented root is missing, the table is omitted from the result dict entirely - assert "purchases" not in restored_without_invented_root - - # If an invented child is missing, we restore the shape but populate the list column with empty lists - pdtest.assert_frame_equal( - restored_without_invented_child["purchases"], - pd.DataFrame( - data={ - "id": [1, 2, 3, 4], - "user_id": [1, 1, 2, 3], - "data": [ - { - "item": "pen", - "cost": 18, - "details": {"color": "blue"}, - "years": [], - }, - { - "item": "paint", - "cost": 19, - "details": {"color": "yellow"}, - "years": [], - }, - { - "item": "ink", - "cost": 20, - "details": {"color": "pink"}, - "years": [], - }, - { - "item": "ink", - "cost": 21, - "details": {"color": "orange"}, - "years": [], - }, - ], - } - ), - ) - - -def test_restore_with_empty_tables(bball, invented_tables): - synthetic_bball_output_tables = { - invented_tables["bball_root"]: pd.DataFrame( - data={ - "name": ["Jimmy Butler"], - "age": [33], - "draft>year": [2011], - "draft>college": ["Marquette"], - "~PRIMARY_KEY_ID~": [0], - } - ), - invented_tables["bball_teams"]: pd.DataFrame( - data={ - "content": ["Bulls", "Timberwolves", "Sixers", "Heat"], - "array~order": [0, 1, 2, 3], - "~PRIMARY_KEY_ID~": [0, 1, 2, 3], - "bball~id": [0, 0, 0, 0], - } - ), - } - - restored_tables = bball.restore(synthetic_bball_output_tables) - jimmy = restored_tables["bball"].iloc[0] - - assert jimmy["name"] == "Jimmy Butler" - assert jimmy["age"] == 33 - assert jimmy["draft"] == {"year": 2011, "college": "Marquette"} - assert jimmy["teams"] == ["Bulls", "Timberwolves", "Sixers", "Heat"] - assert jimmy["suspensions"] == [] - - -@pytest.fixture -def nested_lists_of_objects(tmpdir): - json = """ -{ - "Records": [ - { - "userAgent": "hello", - "responseElements": { - "accountAttributes": [ - { - "attributeName": "duration", - "attributeValues": [ - { - "attributeValue": "45" - } - ] - } - ] - } - } - ] -} -""" - json_df = pd.read_json(json) - - rel_data = RelationalData(directory=tmpdir) - rel_data.add_table(name="demo", primary_key=None, data=json_df) - - return rel_data - - -def test_nested_lists_of_objects(nested_lists_of_objects): - output_tables = { - "demo_invented_1": pd.DataFrame( - data={ - "~PRIMARY_KEY_ID~": [0, 1], - "Records>userAgent": ["abc", "def"], - } - ), - "demo_invented_2": pd.DataFrame( - data={ - "~PRIMARY_KEY_ID~": [0, 1], - "demo~id": [0, 1], - "array~order": [0, 0], - "content>attributeName": ["duration", "duration"], - } - ), - "demo_invented_3": pd.DataFrame( - data={ - "~PRIMARY_KEY_ID~": [0, 1], - "demo^Records>responseElements>accountAttributes~id": [0, 1], - "array~order": [0, 0], - "content>attributeValue": ["42", "43"], - } - ), - } - - restored = nested_lists_of_objects.restore(output_tables) - - pdtest.assert_frame_equal( - restored["demo"], - pd.DataFrame( - data={ - "Records": [ - { - "userAgent": "abc", - "responseElements": { - "accountAttributes": [ - { - "attributeName": "duration", - "attributeValues": [{"attributeValue": "42"}], - } - ] - }, - }, - { - "userAgent": "def", - "responseElements": { - "accountAttributes": [ - { - "attributeName": "duration", - "attributeValues": [{"attributeValue": "43"}], - } - ] - }, - }, - ] - } - ), - ) - - -# TODO: This test documents current behavior, but ideally we'd improve our handling of this scenario -# to retain more synthetic data from "deeper" levels that trained and ran successfully. -def test_handles_missing_interior_invented_tables(nested_lists_of_objects): - # Same setup as the test above except we omit demo_invented_2 - # (simulating that table's model/rh erroring out) - output_tables = { - "demo_invented_1": pd.DataFrame( - data={ - "~PRIMARY_KEY_ID~": [0, 1], - "Records>userAgent": ["abc", "def"], - } - ), - # Since demo_invented_2 is missing, Independent strategy post-processing - # will set all foreign key values on this table to None. - "demo_invented_3": pd.DataFrame( - data={ - "~PRIMARY_KEY_ID~": [0, 1], - "demo^Records>responseElements>accountAttributes~id": [None, None], - "array~order": [0, 0], - "content>attributeValue": ["42", "43"], - } - ), - } - - restored = nested_lists_of_objects.restore(output_tables) - - pdtest.assert_frame_equal( - restored["demo"], - pd.DataFrame( - data={ - "Records": [ - { - "userAgent": "abc", - "responseElements": { - "accountAttributes": [ - # We lost the table at this level, and with it everything below. - ] - }, - }, - { - "userAgent": "def", - "responseElements": {"accountAttributes": []}, - }, - ] - } - ), - ) - - -def test_flatten_and_restore_all_sorts_of_json(tmpdir, get_invented_table_suffix): - json = """ -[ - { - "a": 1, - "b": {"bb": 1}, - "c": {"cc": {"ccc": 1}}, - "d": [1, 2, 3], - "e": [ - {"ee": 1}, - {"ee": 2} - ], - "f": [ - { - "ff": [ - {"fff": 1}, - {"fff": 2} - ] - } - ], - } -] -""" - demo_root_invented_table = f"demo_{get_invented_table_suffix(1)}" - demo_invented_f_table = f"demo_{get_invented_table_suffix(2)}" - demo_invented_f_content_ff_table = f"demo_{get_invented_table_suffix(3)}" - demo_invented_e_table = f"demo_{get_invented_table_suffix(4)}" - demo_invented_d_table = f"demo_{get_invented_table_suffix(5)}" - - json_df = pd.read_json(json, orient="records") - rel_data = RelationalData(directory=tmpdir) - rel_data.add_table(name="demo", primary_key=None, data=json_df) - - assert set(rel_data.list_all_tables(Scope.ALL)) == { - "demo", - demo_root_invented_table, - demo_invented_f_table, - demo_invented_f_content_ff_table, - demo_invented_e_table, - demo_invented_d_table, - } - - assert rel_data.get_table_columns(demo_root_invented_table) == [ - "~PRIMARY_KEY_ID~", - "a", - "b>bb", - "c>cc>ccc", - ] - assert rel_data.get_table_columns(demo_invented_d_table) == [ - "~PRIMARY_KEY_ID~", - "demo~id", - "content", - "array~order", - ] - assert rel_data.get_table_columns(demo_invented_e_table) == [ - "~PRIMARY_KEY_ID~", - "demo~id", - "array~order", - "content>ee", - ] - assert rel_data.get_table_columns(demo_invented_f_table) == [ - "~PRIMARY_KEY_ID~", - "demo~id", - "array~order", - ] - assert rel_data.get_table_columns(demo_invented_f_content_ff_table) == [ - "~PRIMARY_KEY_ID~", - "demo^f~id", - "array~order", - "content>fff", - ] - - output_tables = { - demo_root_invented_table: pd.DataFrame( - data={ - "a": [1, 2], - "b>bb": [3, 4], - "c>cc>ccc": [5, 6], - "~PRIMARY_KEY_ID~": [0, 1], - } - ), - demo_invented_d_table: pd.DataFrame( - data={ - "content": [10, 11, 12, 13], - "~PRIMARY_KEY_ID~": [0, 1, 2, 3], - "demo~id": [0, 0, 0, 1], - "array~order": [0, 1, 2, 0], - } - ), - demo_invented_e_table: pd.DataFrame( - data={ - "content>ee": [100, 200, 300], - "~PRIMARY_KEY_ID~": [0, 1, 2], - "demo~id": [0, 1, 1], - "array~order": [0, 0, 1], - } - ), - demo_invented_f_table: pd.DataFrame( - data={"~PRIMARY_KEY_ID~": [0, 1], "demo~id": [0, 1], "array~order": [0, 0]} - ), - demo_invented_f_content_ff_table: pd.DataFrame( - data={ - "content>fff": [10, 11, 12], - "~PRIMARY_KEY_ID~": [0, 1, 2], - "demo^f~id": [0, 0, 0], - "array~order": [0, 1, 2], - } - ), - } - - restored = rel_data.restore(output_tables) - - expected = pd.DataFrame( - data={ - "a": [1, 2], - "b": [{"bb": 3}, {"bb": 4}], - "c": [{"cc": {"ccc": 5}}, {"cc": {"ccc": 6}}], - "d": [[10, 11, 12], [13]], - "e": [[{"ee": 100}], [{"ee": 200}, {"ee": 300}]], - "f": [[{"ff": [{"fff": 10}, {"fff": 11}, {"fff": 12}]}], [{"ff": []}]], - } - ) - - assert len(restored) == 1 - pdtest.assert_frame_equal(restored["demo"], expected) - - -def test_only_lists_edge_case(tmpdir): - # Smallest reproduction: a dataframe with just one row and one column, and the value is a list - list_df = pd.DataFrame(data={"l": [[1, 2, 3, 4]]}) - rel_data = RelationalData(directory=tmpdir) - - # Since there are no flat fields on the source, the invented root table would be empty. - # The root table is what we use for evaluation, so we bail. - with pytest.raises(ValueError): - rel_data.add_table(name="list", primary_key=None, data=list_df) - - assert rel_data.list_all_tables(Scope.ALL) == [] - - -def test_lists_of_lists(tmpdir, get_invented_table_suffix): - # Enough flat data in the source to create a root invented table. - # Upping the complexity by making the special value a list of lists, - # but not to fear: we can handle this correctly. - lol_df = pd.DataFrame(data={"a": [1], "l": [[[1, 2], [3, 4]]]}) - rel_data = RelationalData(directory=tmpdir) - rel_data.add_table(name="lol", primary_key=None, data=lol_df) - - lol_invented_root_table = f"lol_{get_invented_table_suffix(1)}" - lol_invented_l_table = f"lol_{get_invented_table_suffix(2)}" - lol_invented_l_content_table = f"lol_{get_invented_table_suffix(3)}" - - assert set(rel_data.list_all_tables(Scope.ALL)) == { - "lol", - lol_invented_root_table, - lol_invented_l_table, - lol_invented_l_content_table, - } - - output = { - lol_invented_root_table: pd.DataFrame( - data={"a": [1, 2], "~PRIMARY_KEY_ID~": [0, 1]} - ), - lol_invented_l_table: pd.DataFrame( - data={"~PRIMARY_KEY_ID~": [0, 1], "lol~id": [0, 0], "array~order": [0, 1]} - ), - lol_invented_l_content_table: pd.DataFrame( - data={ - "content": [10, 20, 30, 40], - "~PRIMARY_KEY_ID~": [0, 1, 2, 3], - "lol^l~id": [0, 0, 1, 1], - "array~order": [0, 1, 0, 1], - } - ), - } - restored = rel_data.restore(output) - - assert len(restored) == 1 - pdtest.assert_frame_equal( - restored["lol"], - pd.DataFrame( - data={ - "a": [1, 2], - "l": [[[10, 20], [30, 40]], []], - } - ), - ) - - -def test_mix_of_dict_and_list_cols(tmpdir, get_invented_table_suffix): - df = pd.DataFrame( - data={ - "id": [1, 2], - "dcol": [{"language": "english"}, {"language": "spanish"}], - "lcol": [["a", "b"], ["c", "d"]], - } - ) - mix_invented_root_table = f"mix_{get_invented_table_suffix(1)}" - mix_invented_lcol_table = f"mix_{get_invented_table_suffix(2)}" - - rel_data = RelationalData(directory=tmpdir) - rel_data.add_table(name="mix", primary_key=None, data=df) - assert set(rel_data.list_all_tables()) == { - mix_invented_root_table, - mix_invented_lcol_table, - } - assert rel_data.get_table_columns(mix_invented_root_table) == [ - "~PRIMARY_KEY_ID~", - "id", - "dcol>language", - ] - assert rel_data.get_table_columns(mix_invented_lcol_table) == [ - "~PRIMARY_KEY_ID~", - "mix~id", - "content", - "array~order", - ] - - -def test_all_tables_are_present_in_debug_summary(documents, invented_tables): - assert documents.debug_summary() == { - "foreign_key_count": 4, - "max_depth": 2, - "public_table_count": 3, - "invented_table_count": 2, - "tables": { - "users": { - "column_count": 2, - "primary_key": ["id"], - "foreign_key_count": 0, - "foreign_keys": [], - "is_invented_table": False, - }, - "payments": { - "column_count": 3, - "primary_key": ["id"], - "foreign_key_count": 1, - "foreign_keys": [ - { - "columns": ["purchase_id"], - "parent_table_name": invented_tables["purchases_root"], - "parent_columns": ["id"], - } - ], - "is_invented_table": False, - }, - "purchases": { - "column_count": 3, - "primary_key": ["id"], - "foreign_key_count": 1, - "foreign_keys": [ - { - "columns": ["user_id"], - "parent_table_name": "users", - "parent_columns": ["id"], - } - ], - "is_invented_table": False, - "invented_table_details": { - "table_type": "producer", - "json_to_table_mappings": { - "purchases": invented_tables["purchases_root"], - "purchases^data>years": invented_tables["purchases_data_years"], - }, - }, - }, - invented_tables["purchases_root"]: { - "column_count": 6, - "primary_key": ["id", "~PRIMARY_KEY_ID~"], - "foreign_key_count": 1, - "foreign_keys": [ - { - "columns": ["user_id"], - "parent_table_name": "users", - "parent_columns": ["id"], - } - ], - "is_invented_table": True, - "invented_table_details": { - "table_type": "invented", - "json_breadcrumb_path": "purchases", - }, - }, - invented_tables["purchases_data_years"]: { - "column_count": 4, - "primary_key": ["~PRIMARY_KEY_ID~"], - "foreign_key_count": 1, - "foreign_keys": [ - { - "columns": ["purchases~id"], - "parent_table_name": invented_tables["purchases_root"], - "parent_columns": ["~PRIMARY_KEY_ID~"], - } - ], - "is_invented_table": True, - "invented_table_details": { - "table_type": "invented", - "json_breadcrumb_path": "purchases^data>years", - }, - }, - }, - } - - -@pytest.mark.no_mock_suffix -def test_invented_table_names_contain_uuid(documents: RelationalData): - regex = re.compile(r"purchases_invented_[a-fA-F0-9]{32}") - tables = documents.list_all_tables(Scope.INVENTED) - assert len(tables) == 2 - assert regex.match(tables[0]) - assert regex.match(tables[1]) - - -@pytest.mark.no_mock_suffix -def test_generate_unique_table_name_truncates_length(): - table_name_128_chars = "loremipsumdolorsitametconsecteturadipiscingelitseddoeiusmodtemporincididuntutlaboreetdoloremagnaaliquautenimadminimveniamquisnos" - result = generate_unique_table_name(table_name_128_chars) - assert len(result) < 128 - - -@pytest.mark.no_mock_suffix -def test_deeply_nested_json_truncates_length(deeply_nested): - tables = deeply_nested.list_all_tables(Scope.ALL) - assert len(tables) == 10 - for table in tables: - assert len(table) < 128 diff --git a/tests/relational/test_report.py b/tests/relational/test_report.py deleted file mode 100644 index 7d6c3c28..00000000 --- a/tests/relational/test_report.py +++ /dev/null @@ -1,275 +0,0 @@ -from datetime import datetime - -from lxml import html - -from gretel_trainer.relational.core import Scope -from gretel_trainer.relational.report.report import ReportPresenter, ReportRenderer -from gretel_trainer.relational.table_evaluation import TableEvaluation - - -def _evals_from_rel_data(rel_data): - d = { - "synthetic_data_quality_score": {"score": 90, "grade": "Excellent"}, - "privacy_protection_level": {"score": 2, "grade": "Good"}, - } - evals = {} - for table in rel_data.list_all_tables(Scope.PUBLIC): - eval = TableEvaluation(cross_table_report_json=d, individual_report_json=d) - evals[table] = eval - return evals - - -def test_ecommerce_relational_data_report(ecom): - # Fake these - evaluations = _evals_from_rel_data(ecom) - - presenter = ReportPresenter( - rel_data=ecom, - evaluations=evaluations, - now=datetime.utcnow(), - run_identifier="run_identifier", - ) - - html_content = ReportRenderer().render(presenter) - - # DEV ONLY if you want to save a local copy to look at - # with open("report.html", 'w') as f: - # f.write(html_content) - - tree = html.fromstring(html_content) - - # Top level scores - assert ( - len( - tree.xpath( - '//div[contains(@class, "test-report-main-score")]' - + '//div[contains(@class, "score-container")]' - ) - ) - == 2 - ) - # SQS score label and bottom text - assert ( - tree.xpath( - '//div[contains(@class, "test-report-main-score")]' - + '//div[contains(@class, "score-container")]' - + '//span[contains(@class, "label")]' - )[0].text.strip() - == "Excellent" - ) - assert ( - tree.xpath( - '//div[contains(@class, "test-report-main-score")]' - + '//div[contains(@class, "score-container")]' - + '//span[contains(@class, "score-container-text")]' - )[0].text - == "Composite" #
cuts off the rest - ) - # PPL score label and bottom text - assert ( - tree.xpath( - '//div[contains(@class, "test-report-main-score")]' - + '//div[contains(@class, "score-container")]' - + '//span[contains(@class, "label")]' - )[1].text.strip() - == "Good" - ) - assert ( - tree.xpath( - '//div[contains(@class, "test-report-main-score")]' - + '//div[contains(@class, "score-container")]' - + '//span[contains(@class, "score-container-text")]' - )[0].text.strip() - == "Composite" #
cuts off the rest - ) - - # Table relationships - assert ( - len( - tree.xpath( - '//section[contains(@class, "test-table-relationships")]' + "//tr" - ) - ) - == 7 # Header plus six tables - ) - relations_data_rows = tree.xpath( - '//section[contains(@class, "test-table-relationships")]' + "//tr" - )[1:] - # First row, Table name td, bold tag wrapping table name - assert ( - relations_data_rows[0].getchildren()[0].getchildren()[0].text - == "distribution_center" - ) - # pk column/td, each is a span, unpack then text - pks = [row.getchildren()[1].getchildren()[0].text for row in relations_data_rows] - for pk in pks: - assert pk == "id" - # First row has no fk's - assert len(relations_data_rows[0].getchildren()[2].getchildren()) == 0 - # Third row has two fk's - assert len(relations_data_rows[2].getchildren()[2].getchildren()) == 2 - - # SQS score table - assert ( - len(tree.xpath('//section[contains(@class, "test-sqs-results")]' + "//tr")) - == 7 # Header plus six tables again - ) - assert ( - len( - tree.xpath( - '//section[contains(@class, "test-sqs-results")]' - + "//tr" - + '//span[contains(@class, "sqs-table-score")]' - ) - ) - == 12 # Six tables, each has two numeric scores - ) - assert ( - len( - tree.xpath( - '//section[contains(@class, "test-sqs-results")]' - + "//tr" - + '//span[contains(@class, "label")]' - ) - ) - == 12 # Six tables, each has two grade labels - ) - assert ( - len( - tree.xpath( - '//section[contains(@class, "test-sqs-results")]' - + "//tr" - + '//span[contains(@class, "sqs-table-link")]' - ) - ) - == 12 # Six tables, each has two linked reports - ) - # Check the first report link - assert ( - tree.xpath( - '//section[contains(@class, "test-sqs-results")]' - + "//tr" - + '//span[contains(@class, "sqs-table-link")]' - + "/a/@href" - )[0] - == "synthetics_individual_evaluation_distribution_center.html" - ) - - -def test_mutagenesis_relational_data_report(mutagenesis): - # Fake these - evaluations = _evals_from_rel_data(mutagenesis) - - presenter = ReportPresenter( - rel_data=mutagenesis, - evaluations=evaluations, - now=datetime.utcnow(), - run_identifier="run_identifier", - ) - - html_content = ReportRenderer().render(presenter) - - # DEV ONLY if you want to save a local copy to look at - # with open("report.html", 'w') as f: - # f.write(html_content) - - tree = html.fromstring(html_content) - - # Two scores at top - assert ( - len( - tree.xpath( - '//div[contains(@class, "test-report-main-score")]' - + '//div[contains(@class, "score-container")]' - ) - ) - == 2 - ) - - # Table relationships - assert ( - len( - tree.xpath( - '//section[contains(@class, "test-table-relationships")]' + "//tr" - ) - ) - == 4 # Header plus three tables - ) - - # SQS score table - assert ( - len(tree.xpath('//section[contains(@class, "test-sqs-results")]' + "//tr")) - == 4 # Header plus three tables again - ) - assert ( - len( - tree.xpath( - '//section[contains(@class, "test-sqs-results")]' - + "//tr" - + '//span[contains(@class, "sqs-table-score")]' - ) - ) - == 6 # Three tables, each has two numeric scores - ) - assert ( - len( - tree.xpath( - '//section[contains(@class, "test-sqs-results")]' - + "//tr" - + '//span[contains(@class, "label")]' - ) - ) - == 6 # Three tables, each has two grade labels - ) - assert ( - len( - tree.xpath( - '//section[contains(@class, "test-sqs-results")]' - + "//tr" - + '//span[contains(@class, "sqs-table-link")]' - ) - ) - == 6 # Three tables, each has two linked reports - ) - # Check the first report link - assert ( - tree.xpath( - '//section[contains(@class, "test-sqs-results")]' - + "//tr" - + '//span[contains(@class, "sqs-table-link")]' - + "/a/@href" - )[0] - == "synthetics_individual_evaluation_atom.html" - ) - - -def test_source_data_including_json(documents): - # Fake these - evaluations = _evals_from_rel_data(documents) - - presenter = ReportPresenter( - rel_data=documents, - evaluations=evaluations, - now=datetime.utcnow(), - run_identifier="run_identifier", - ) - - html_content = ReportRenderer().render(presenter) - - # DEV ONLY if you want to save a local copy to look at - # with open("report.html", 'w') as f: - # f.write(html_content) - - tree = html.fromstring(html_content) - - relations_data_rows = tree.xpath( - '//section[contains(@class, "test-table-relationships")]' + "//tr" - )[1:] - - # Ensure public names, not invented table names, are displayed - table_names = [ - # Row, Table name td, bold tag wrapping table name - row.getchildren()[0].getchildren()[0].text - for row in relations_data_rows - ] - assert table_names == ["payments", "purchases", "users"] diff --git a/tests/relational/test_synthetics_evaluate.py b/tests/relational/test_synthetics_evaluate.py deleted file mode 100644 index fde5ef04..00000000 --- a/tests/relational/test_synthetics_evaluate.py +++ /dev/null @@ -1,43 +0,0 @@ -import json - -from collections import defaultdict -from unittest.mock import Mock - -from gretel_trainer.relational.table_evaluation import TableEvaluation -from gretel_trainer.relational.task_runner import TaskContext -from gretel_trainer.relational.tasks.synthetics_evaluate import SyntheticsEvaluateTask - - -def test_sets_json_data_on_evaluations(output_handler, project): - json_result = {"sqs": 99} - - def mock_download_file_artifact(gretel_object, artifact_name, out_path): - if artifact_name == "report_json": - with open(out_path, "w") as out: - json.dump(json_result, out) - - ind_users_model = Mock() - evaluations = defaultdict(lambda: TableEvaluation()) - ext_sdk = Mock() - ext_sdk.download_file_artifact.side_effect = mock_download_file_artifact - context = TaskContext( - in_flight_jobs=0, - refresh_interval=0, - project=project, - extended_sdk=ext_sdk, - backup=lambda: None, - ) - - task = SyntheticsEvaluateTask( - individual_evaluate_models={"users": ind_users_model}, - cross_table_evaluate_models={}, - subdir="run", - output_handler=output_handler, - evaluations=evaluations, - ctx=context, - ) - - output_handler.make_subdirectory("run") - task.handle_completed(table="individual-users", job=ind_users_model) - - assert evaluations["users"].individual_report_json == json_result diff --git a/tests/relational/test_synthetics_run_task.py b/tests/relational/test_synthetics_run_task.py deleted file mode 100644 index 6b7b54b7..00000000 --- a/tests/relational/test_synthetics_run_task.py +++ /dev/null @@ -1,172 +0,0 @@ -from typing import Optional -from unittest.mock import Mock, patch - -import pandas as pd -import pandas.testing as pdtest - -from gretel_client.projects.jobs import Status -from gretel_trainer.relational.core import RelationalData -from gretel_trainer.relational.output_handler import OutputHandler -from gretel_trainer.relational.sdk_extras import ExtendedGretelSDK, MAX_IN_FLIGHT_JOBS -from gretel_trainer.relational.strategies.ancestral import AncestralStrategy -from gretel_trainer.relational.task_runner import TaskContext -from gretel_trainer.relational.tasks.synthetics_run import SyntheticsRunTask -from gretel_trainer.relational.workflow_state import SyntheticsRun, SyntheticsTrain - - -class MockStrategy(AncestralStrategy): - def post_process_individual_synthetic_result( - self, table_name, rel_data, synthetic_table, record_size_ratio - ): - return synthetic_table.head(1) - - -def make_task( - rel_data: RelationalData, - output_handler: OutputHandler, - preserved: Optional[list[str]] = None, - failed: Optional[list[str]] = None, - omitted: Optional[list[str]] = None, -) -> SyntheticsRunTask: - def _status_for_table(table: str, failed: list[str]) -> Status: - if table in failed: - return Status.ERROR - else: - return Status.COMPLETED - - context = TaskContext( - in_flight_jobs=0, - refresh_interval=0, - project=Mock(), - extended_sdk=ExtendedGretelSDK(hybrid=False), - backup=lambda: None, - ) - return SyntheticsRunTask( - synthetics_run=SyntheticsRun( - identifier="generate", - record_handlers={}, - lost_contact=[], - preserved=preserved or [], - record_size_ratio=1.0, - ), - synthetics_train=SyntheticsTrain( - models={ - table: Mock( - create_record_handler=Mock(), - status=_status_for_table(table, failed or []), - ) - for table in rel_data.list_all_tables() - if table not in (omitted or []) - }, - ), - output_handler=output_handler, - subdir="run-identifier", - ctx=context, - strategy=MockStrategy(), - rel_data=rel_data, - ) - - -def test_ignores_preserved_tables(pets, output_handler): - task = make_task(pets, output_handler, preserved=["pets"]) - - # Source data is used - assert task.working_tables["pets"] is not None - assert "pets" in task.output_tables - task.each_iteration() - assert "pets" not in task.synthetics_run.record_handlers - - -def test_ignores_tables_that_were_omitted_from_training(pets, output_handler): - task = make_task(pets, output_handler, omitted=["pets"]) - - # Source data is used - assert task.working_tables["pets"] is not None - assert "pets" in task.output_tables - task.each_iteration() - assert "pets" not in task.synthetics_run.record_handlers - - -def test_ignores_tables_that_failed_during_training(pets, output_handler): - task = make_task(pets, output_handler, failed=["pets"]) - - # We set tables that failed to explicit None - assert task.working_tables["pets"] is None - assert "pets" not in task.output_tables - task.each_iteration() - assert "pets" not in task.synthetics_run.record_handlers - - -def test_runs_post_processing_when_table_completes(pets, output_handler): - task = make_task(pets, output_handler) - - raw_df = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]}) - - with patch( - "gretel_trainer.relational.sdk_extras.ExtendedGretelSDK.get_record_handler_data" - ) as get_rh_data: - get_rh_data.return_value = raw_df - task.handle_completed("table", Mock(ref_data=Mock(values=[]))) - - post_processed = task.working_tables["table"] - assert post_processed is not None - pdtest.assert_frame_equal(post_processed, raw_df.head(1)) - - -def test_starts_jobs_for_ready_tables(pets, output_handler): - task = make_task(pets, output_handler) - - assert len(task.synthetics_run.record_handlers) == 0 - - task.each_iteration() - - assert len(task.synthetics_run.record_handlers) == 1 - assert "humans" in task.synthetics_run.record_handlers - task.synthetics_train.models[ - "humans" - ].create_record_handler_obj.assert_called_once() - task.synthetics_run.record_handlers["humans"].submit.assert_called_once() - - -def test_defers_job_submission_if_max_jobs(pets, output_handler): - task = make_task(pets, output_handler) - - assert len(task.synthetics_run.record_handlers) == 0 - - humans_model = task.synthetics_train.models["humans"] - - # If we already have the max number of jobs in flight... - task.ctx.in_flight_jobs = MAX_IN_FLIGHT_JOBS - - task.each_iteration() - - # ...the record handler is created, but not submitted - assert len(task.synthetics_run.record_handlers) == 1 - assert "humans" in task.synthetics_run.record_handlers - humans_model.create_record_handler_obj.assert_called_once() - humans_record_handler = task.synthetics_run.record_handlers["humans"] - humans_record_handler.submit.assert_not_called() - - # Subsequent passes through the task loop will neither submit the job, - # nor recreate a new record handler instance. - humans_model.reset_mock() - task.ctx.maybe_start_job( - job=humans_record_handler, - table_name="humans", - action=task.action(humans_record_handler), - ) - task.each_iteration() - humans_model.create_record_handler_obj.assert_not_called() - humans_record_handler.submit.assert_not_called() - - # Once there is room again for more jobs... - task.ctx.in_flight_jobs = 0 - - # ...the next pass through submits the record handler since there is now room for another job. - task.ctx.maybe_start_job( - job=humans_record_handler, - table_name="humans", - action=task.action(humans_record_handler), - ) - humans_record_handler.submit.assert_called_once() - assert task.ctx.in_flight_jobs == 1 diff --git a/tests/relational/test_task_runner.py b/tests/relational/test_task_runner.py deleted file mode 100644 index 42679795..00000000 --- a/tests/relational/test_task_runner.py +++ /dev/null @@ -1,256 +0,0 @@ -from typing import Optional -from unittest.mock import Mock, patch - -import pytest - -from gretel_client.projects.exceptions import MaxConcurrentJobsException -from gretel_client.projects.jobs import Job, Status -from gretel_trainer.relational.sdk_extras import ExtendedGretelSDK -from gretel_trainer.relational.task_runner import run_task, TaskContext - - -class MockTask: - def __init__(self, project, models): - self.project = project - self.models = models - self.iteration_count = 0 - self.completed = [] - self.failed = [] - self.lost_contact = [] - self.ctx = TaskContext( - in_flight_jobs=0, - refresh_interval=0, - project=project, - extended_sdk=ExtendedGretelSDK(hybrid=False), - backup=lambda: None, - ) - - def action(self, job: Job) -> str: - return "mock task" - - @property - def table_collection(self) -> list[str]: - return list(self.models.keys()) - - def more_to_do(self) -> bool: - return len(self.completed + self.failed + self.lost_contact) < len(self.models) - - def is_finished(self, table: str) -> bool: - return table in (self.completed + self.failed + self.lost_contact) - - def get_job(self, table: str) -> Job: - return self.models[table] - - def handle_completed(self, table: str, job: Job) -> None: - self.completed.append(table) - - def handle_failed(self, table: str, job: Job) -> None: - self.failed.append(table) - - def handle_lost_contact(self, table: str, job: Job) -> None: - self.lost_contact.append(table) - - def handle_in_progress(self, table: str, job: Job) -> None: - pass - - def each_iteration(self) -> None: - self.iteration_count += 1 - - -class MockModel: - def __init__(self, statuses: list[Optional[str]], fail_n_times: int = 0): - self.identifier = None - - self._statuses = statuses - self.status = None - - self._fail_n_times = fail_n_times - self._fail_count = 0 - - def submit(self): - if self._fail_count < self._fail_n_times: - self._fail_count += 1 - raise MaxConcurrentJobsException() - self.identifier = "identifier" - - def refresh(self): - next_status = self._statuses.pop(0) - if next_status is None: - raise Exception() - self.status = next_status - - -@pytest.fixture(autouse=True) -def mock_extended_sdk(): - def _get_job_id(mock_model): - return mock_model.identifier - - extended_sdk = ExtendedGretelSDK(hybrid=False) - extended_sdk.get_job_id = _get_job_id # type:ignore - return extended_sdk - - -def test_one_successful_model(mock_extended_sdk): - models = { - "table": MockModel(statuses=[Status.COMPLETED]), - } - - task = MockTask( - project=Mock(), - models=models, - ) - run_task(task, mock_extended_sdk) - - assert task.iteration_count == 2 - assert task.completed == ["table"] - assert task.failed == [] - - -def test_one_failed_model(mock_extended_sdk): - models = { - "table": MockModel(statuses=[Status.ERROR]), - } - - task = MockTask( - project=Mock(), - models=models, - ) - run_task(task, mock_extended_sdk) - - assert task.iteration_count == 2 - assert task.completed == [] - assert task.failed == ["table"] - - -def test_model_taking_awhile(mock_extended_sdk): - models = { - "table": MockModel(statuses=[Status.ACTIVE, Status.ACTIVE, Status.COMPLETED]), - } - - task = MockTask( - project=Mock(), - models=models, - ) - run_task(task, mock_extended_sdk) - - assert task.iteration_count == 4 - assert task.completed == ["table"] - assert task.failed == [] - - -def test_lose_contact_with_model(mock_extended_sdk): - # By only setting one status, subsequent calls to `refresh` will throw - # an IndexError (as a stand-in for SDK refresh errors) - models = { - "table": MockModel(statuses=[Status.ACTIVE]), - } - - task = MockTask( - project=Mock(), - models=models, - ) - run_task(task, mock_extended_sdk) - - # Bail after refresh fails MAX_REFRESH_ATTEMPTS times - # (first iteration creates the job, +4 refresh failures) - assert task.iteration_count == 5 - assert task.completed == [] - assert task.failed == [] - assert task.lost_contact == ["table"] - - -def test_refresh_status_can_tolerate_blips(mock_extended_sdk): - models = { - "table": MockModel( - statuses=[Status.ACTIVE, None, Status.ACTIVE, Status.COMPLETED] - ), - } - - task = MockTask( - project=Mock(), - models=models, - ) - run_task(task, mock_extended_sdk) - - # 1. Create - # 2. Active - # 3. Blip - # 4. Active - # 5. Completed - assert task.iteration_count == 5 - assert task.completed == ["table"] - assert task.failed == [] - assert task.lost_contact == [] - - -def test_defers_submission_if_max_jobs_in_flight(mock_extended_sdk): - model_1 = MockModel(statuses=[Status.ACTIVE, Status.COMPLETED]) - model_2 = MockModel(statuses=[Status.ACTIVE, Status.COMPLETED]) - - models = {"t1": model_1, "t2": model_2} - - task = MockTask( - project=Mock(), - models=models, - ) - with patch("gretel_trainer.relational.sdk_extras.MAX_IN_FLIGHT_JOBS", 1): - run_task(task, mock_extended_sdk) - - # 1: Started, Deferred - # 2: Active, Deferred - # 3: Completed, Started - # 4: Completed, Active - # 5: Completed, Completed - assert task.iteration_count == 5 - assert task.completed == ["t1", "t2"] - - -def test_defers_submission_if_max_jobs_in_created_state(mock_extended_sdk): - # In this test, we're not running into our client-side max jobs limit; - # rather, the API is not allowing us to submit due to too many jobs in created state. - # The second model fails to be submitted 5 times (e.g. due to other unrelated jobs) - # before getting submitted successfully - model_1 = MockModel(statuses=[Status.ACTIVE, Status.COMPLETED]) - model_2 = MockModel(statuses=[Status.ACTIVE, Status.COMPLETED], fail_n_times=5) - - models = {"t1": model_1, "t2": model_2} - - task = MockTask( - project=Mock(), - models=models, - ) - run_task(task, mock_extended_sdk) - - # 1: Started, Deferred - # 2: Active, Deferred - # 3: Completed, Deferred - # 4: Completed, Deferred - # 5: Completed, Deferred - # 6: Completed, Started - # 7: Completed, Active - # 8: Completed, Completed - assert task.iteration_count == 8 - assert task.completed == ["t1", "t2"] - - -def test_several_models(mock_extended_sdk): - completed_model = MockModel(statuses=[Status.COMPLETED]) - error_model = MockModel(statuses=[Status.ERROR]) - cancelled_model = MockModel(statuses=[Status.CANCELLED]) - lost_model = MockModel(statuses=[Status.LOST]) - - models = { - "completed": completed_model, - "error": error_model, - "cancelled": cancelled_model, - "lost": lost_model, - } - - task = MockTask( - project=Mock(), - models=models, - ) - run_task(task, mock_extended_sdk) - - assert task.completed == ["completed"] - assert set(task.failed) == {"error", "cancelled", "lost"} diff --git a/tests/relational/test_train_synthetics.py b/tests/relational/test_train_synthetics.py deleted file mode 100644 index 1b8a53ff..00000000 --- a/tests/relational/test_train_synthetics.py +++ /dev/null @@ -1,233 +0,0 @@ -import tempfile - -from unittest.mock import ANY, patch - -import pytest - -from gretel_trainer.relational.core import MultiTableException -from gretel_trainer.relational.multi_table import MultiTable - - -# The assertions in this file are concerned with setting up the synthetics train -# workflow state properly, and stop short of kicking off the task. -@pytest.fixture(autouse=True) -def run_task(): - with patch("gretel_trainer.relational.multi_table.run_task"): - yield - - -@pytest.fixture(autouse=True) -def backup(): - with patch.object(MultiTable, "_backup", return_value=None): - yield - - -@pytest.fixture() -def tmpdir(project): - with tempfile.TemporaryDirectory() as tmpdir: - project.name = tmpdir - yield tmpdir - - -class ModelConfigMatcher: - def __init__(self, model_key: str): - self.model_key = model_key - - def __eq__(self, other): - return list(other["models"][0])[0] == self.model_key - - -def test_train_synthetics_strategy_specific_default_configs(pets, tmpdir, project): - mt = MultiTable(pets, strategy="independent", project_display_name=tmpdir) - mt.train_synthetics() - project.create_model_obj.assert_called_with( - model_config=ModelConfigMatcher("actgan"), - data_source=f"{tmpdir}/synthetics_train_pets.csv", - ) - - mt = MultiTable(pets, strategy="ancestral", project_display_name=tmpdir) - mt.train_synthetics() - project.create_model_obj.assert_called_with( - model_config=ModelConfigMatcher("amplify"), - data_source=f"{tmpdir}/synthetics_train_pets.csv", - ) - - -def test_train_synthetics_defaults_to_training_all_tables(ecom, tmpdir): - mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_synthetics(config="synthetics/amplify") - - assert set(mt._synthetics_train.models.keys()) == set(ecom.list_all_tables()) - - -def test_train_synthetics_only_includes_specified_tables(ecom, tmpdir, project): - mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_synthetics(config="synthetics/amplify", only={"users"}) - - assert set(mt._synthetics_train.models.keys()) == {"users"} - project.create_model_obj.assert_called_with( - model_config=ANY, # a tailored synthetics config, in dict form - data_source=f"{tmpdir}/synthetics_train_users.csv", - ) - - -def test_train_synthetics_ignore_excludes_specified_tables(ecom, tmpdir): - mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_synthetics( - config="synthetics/amplify", ignore={"distribution_center", "products"} - ) - - assert set(mt._synthetics_train.models.keys()) == { - "events", - "users", - "order_items", - "inventory_items", - } - - -def test_train_synthetics_exits_early_if_unrecognized_tables(ecom, tmpdir, project): - mt = MultiTable(ecom, project_display_name=tmpdir) - with pytest.raises(MultiTableException): - mt.train_synthetics(config="synthetics/amplify", ignore={"nonsense"}) - - assert len(mt._synthetics_train.models) == 0 - project.create_model_obj.assert_not_called() - - -def test_train_synthetics_custom_configs_per_table(ecom, tmpdir, project): - mock_actgan_config = {"models": [{"actgan": {}}]} - mock_tabdp_config = {"models": [{"tabular_dp": {}}]} - - mt = MultiTable(ecom, project_display_name=tmpdir) - - # We provide an actgan config to use for tables PLUS a tabular-dp config for one specific table. - mt.train_synthetics( - config=mock_actgan_config, table_specific_configs={"events": mock_tabdp_config} - ) - - # The tabular-dp config is used for the singularly called-out table... - project.create_model_obj.assert_any_call( - model_config={"name": "synthetics-events", **mock_tabdp_config}, - data_source=f"{tmpdir}/synthetics_train_events.csv", - ) - - # ...and the actgan config is used for all the rest. - project.create_model_obj.assert_any_call( - model_config={"name": "synthetics-users", **mock_actgan_config}, - data_source=f"{tmpdir}/synthetics_train_users.csv", - ) - - -def test_train_synthetics_validates_against_configured_strategy(pets, tmpdir): - # Independent strategy - mt_independent = MultiTable( - pets, project_display_name=tmpdir, strategy="independent" - ) - - mt_independent.train_synthetics(config="synthetics/tabular-lstm") - mt_independent.train_synthetics(config="synthetics/tabular-actgan") - mt_independent.train_synthetics(config="synthetics/amplify") - mt_independent.train_synthetics(config="synthetics/tabular-differential-privacy") - with pytest.raises(MultiTableException): - mt_independent.train_synthetics(config="synthetics/time-series") - - # Ancestral strategy - mt_ancestral = MultiTable(pets, project_display_name=tmpdir, strategy="ancestral") - - mt_ancestral.train_synthetics(config="synthetics/amplify") - with pytest.raises(MultiTableException): - mt_ancestral.train_synthetics(config="synthetics/tabular-lstm") - with pytest.raises(MultiTableException): - mt_ancestral.train_synthetics(config="synthetics/tabular-actgan") - with pytest.raises(MultiTableException): - mt_ancestral.train_synthetics(config="synthetics/tabular-differential-privacy") - with pytest.raises(MultiTableException): - mt_ancestral.train_synthetics(config="synthetics/time-series") - - -def test_train_synthetics_errors(ecom, tmpdir): - actgan_config = {"models": [{"actgan": {}}]} - mt = MultiTable(ecom, project_display_name=tmpdir) - - # Invalid config - with pytest.raises(MultiTableException): - mt.train_synthetics(config="nonsense") - - # Unrecognized table - with pytest.raises(MultiTableException): - mt.train_synthetics( - config="synthetics/amplify", - table_specific_configs={"not-a-table": actgan_config}, - ) - - # Config provided for omitted table - with pytest.raises(MultiTableException): - mt.train_synthetics( - config="synthetics/amplify", - ignore={"users"}, - table_specific_configs={"users": actgan_config}, - ) - - # Config for unsupported model - mt = MultiTable(ecom, project_display_name=tmpdir, strategy="ancestral") - with pytest.raises(MultiTableException): - mt.train_synthetics(config=actgan_config) - - # Table config for unsupported model - mt = MultiTable(ecom, project_display_name=tmpdir, strategy="ancestral") - with pytest.raises(MultiTableException): - mt.train_synthetics( - config="synthetics/amplify", table_specific_configs={"users": actgan_config} - ) - - -def test_train_synthetics_multiple_calls_additive(ecom, tmpdir): - mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_synthetics(config="synthetics/amplify", only={"products"}) - mt.train_synthetics(config="synthetics/amplify", only={"users"}) - - # We do not lose the first table model - assert set(mt._synthetics_train.models.keys()) == {"products", "users"} - - -def test_train_synthetics_models_for_dbs_with_invented_tables( - documents, tmpdir, get_invented_table_suffix -): - mt = MultiTable(documents, project_display_name=tmpdir) - mt.train_synthetics(config="synthetics/amplify") - - purchases_root_invented_table = f"purchases_{get_invented_table_suffix(1)}" - purchases_data_years_invented_table = f"purchases_{get_invented_table_suffix(2)}" - - assert set(mt._synthetics_train.models.keys()) == { - "users", - "payments", - purchases_root_invented_table, - purchases_data_years_invented_table, - } - - -def test_train_synthetics_table_filters_cascade_to_invented_tables(documents, tmpdir): - # When a user provides the ("public") name of a table that contained JSON and led - # to the creation of invented tables, we recognize that as implicitly applying to - # all the tables internally created from that source table. - mt = MultiTable(documents, project_display_name=tmpdir) - mt.train_synthetics(config="synthetics/amplify", ignore={"purchases"}) - - assert set(mt._synthetics_train.models.keys()) == {"users", "payments"} - - -def test_train_synthetics_multiple_calls_overwrite(ecom, tmpdir, project): - project.create_model_obj.return_value = "m1" - - mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_synthetics(config="synthetics/amplify", only={"products"}) - - assert mt._synthetics_train.models["products"] == "m1" - - project.reset_mock() - project.create_model_obj.return_value = "m2" - - # calling a second time will create a new model for the table that overwrites the original - mt.train_synthetics(config="synthetics/amplify", only={"products"}) - assert mt._synthetics_train.models["products"] == "m2" diff --git a/tests/relational/test_train_transforms.py b/tests/relational/test_train_transforms.py deleted file mode 100644 index 65895c90..00000000 --- a/tests/relational/test_train_transforms.py +++ /dev/null @@ -1,97 +0,0 @@ -import tempfile - -from unittest.mock import ANY, patch - -import pytest - -from gretel_trainer.relational.core import MultiTableException -from gretel_trainer.relational.multi_table import MultiTable - - -# The assertions in this file are concerned with setting up the transforms train -# workflow state properly, and stop short of kicking off the task. -@pytest.fixture(autouse=True) -def run_task(): - with patch("gretel_trainer.relational.multi_table.run_task"): - yield - - -@pytest.fixture(autouse=True) -def backup(): - with patch.object(MultiTable, "_backup", return_value=None): - yield - - -@pytest.fixture() -def tmpdir(project): - with tempfile.TemporaryDirectory() as tmpdir: - project.name = tmpdir - yield tmpdir - - -def test_train_transforms_defaults_to_transforming_all_tables(ecom, tmpdir): - mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_transforms("transform/default") - transforms_train = mt._transforms_train - - assert set(transforms_train.models.keys()) == set(ecom.list_all_tables()) - - -def test_train_transforms_only_includes_specified_tables(ecom, tmpdir, project): - mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_transforms("transform/default", only={"users"}) - transforms_train = mt._transforms_train - - assert set(transforms_train.models.keys()) == {"users"} - project.create_model_obj.assert_called_with( - model_config=ANY, # a tailored transforms config, in dict form - data_source=f"{tmpdir}/users.csv", - ) - - -def test_train_transforms_ignore_excludes_specified_tables(ecom, tmpdir): - mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_transforms("transform/default", ignore={"distribution_center", "products"}) - transforms_train = mt._transforms_train - - assert set(transforms_train.models.keys()) == { - "events", - "users", - "order_items", - "inventory_items", - } - - -def test_train_transforms_exits_early_if_unrecognized_tables(ecom, tmpdir, project): - mt = MultiTable(ecom, project_display_name=tmpdir) - with pytest.raises(MultiTableException): - mt.train_transforms("transform/default", ignore={"nonsense"}) - transforms_train = mt._transforms_train - - assert len(transforms_train.models) == 0 - project.create_model_obj.assert_not_called() - - -def test_train_transforms_multiple_calls_additive(ecom, tmpdir): - mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_transforms("transform/default", only={"products"}) - mt.train_transforms("transform/default", only={"users"}) - - # We do not lose the first table model - assert set(mt._transforms_train.models.keys()) == {"products", "users"} - - -def test_train_transforms_multiple_calls_overwrite(ecom, tmpdir, project): - project.create_model_obj.return_value = "m1" - - mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_transforms("transform/default", only={"products"}) - - assert mt._transforms_train.models["products"] == "m1" - - project.reset_mock() - project.create_model_obj.return_value = "m2" - - # calling a second time will create a new model for the table that overwrites the original - mt.train_transforms("transform/default", only={"products"}) - assert mt._transforms_train.models["products"] == "m2"