Skip to content

Commit

Permalink
Fix: Defer creation of the deploable table for a new forward-only mod…
Browse files Browse the repository at this point in the history
…el (#3657)
  • Loading branch information
izeigerman authored Jan 20, 2025
1 parent 90cd883 commit 017da33
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 14 deletions.
63 changes: 55 additions & 8 deletions sqlmesh/core/snapshot/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,8 @@ def _create_snapshot(
and adapter.SUPPORTS_CLONING
# managed models cannot have their schema mutated because theyre based on queries, so clone + alter wont work
and not snapshot.is_managed
# If the deployable table is missing we can't clone it
and True not in deployability_flags
):
target_table_name = snapshot.table_name(is_deployable=False)
tmp_table_name = f"{target_table_name}__schema_migration_source"
Expand Down Expand Up @@ -797,6 +799,19 @@ def _create_snapshot(
else:
dry_run = len(deployability_flags) == 1
for is_table_deployable in deployability_flags:
if (
is_table_deployable
and snapshot.model.forward_only
and not is_snapshot_representative
):
logger.info(
"Skipping creation of the deployable table '%s' for the forward-only model %s. "
"The table will be created when the snapshot is deployed to production",
snapshot.table_name(is_deployable=is_table_deployable),
snapshot.snapshot_id,
)
continue

evaluation_strategy.create(
table_name=snapshot.table_name(is_deployable=is_table_deployable),
model=snapshot.model,
Expand Down Expand Up @@ -829,15 +844,47 @@ def _migrate_snapshot(
if not needs_migration:
return

tmp_table_name = snapshot.table_name(is_deployable=False)
evaluation_strategy = _evaluation_strategy(snapshot, adapter)

target_table_name = snapshot.table_name()
_evaluation_strategy(snapshot, adapter).migrate(
target_table_name=target_table_name,
source_table_name=tmp_table_name,
snapshot=snapshot,
snapshots=parent_snapshots_by_name(snapshot, snapshots),
allow_destructive_snapshots=allow_destructive_snapshots,
)
if adapter.table_exists(target_table_name):
tmp_table_name = snapshot.table_name(is_deployable=False)
logger.info(
"Migrating table schema from '%s' to '%s'",
tmp_table_name,
target_table_name,
)
evaluation_strategy.migrate(
target_table_name=target_table_name,
source_table_name=tmp_table_name,
snapshot=snapshot,
snapshots=parent_snapshots_by_name(snapshot, snapshots),
allow_destructive_snapshots=allow_destructive_snapshots,
)
else:
logger.info(
"Creating table '%s' for the snapshot of the forward-only model %s",
target_table_name,
snapshot.snapshot_id,
)
render_kwargs: t.Dict[str, t.Any] = dict(
engine_adapter=adapter,
snapshots=parent_snapshots_by_name(snapshot, snapshots),
runtime_stage=RuntimeStage.CREATING,
deployability_index=DeployabilityIndex.all_deployable(),
)
with adapter.transaction(), adapter.session(snapshot.model.session_properties):
adapter.execute(snapshot.model.render_pre_statements(**render_kwargs))
evaluation_strategy.create(
table_name=target_table_name,
model=snapshot.model,
is_table_deployable=True,
render_kwargs=render_kwargs,
is_snapshot_deployable=True,
is_snapshot_representative=True,
dry_run=False,
)
adapter.execute(snapshot.model.render_post_statements(**render_kwargs))

def _promote_snapshot(
self,
Expand Down
8 changes: 6 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
SnapshotChangeCategory,
SnapshotDataVersion,
SnapshotFingerprint,
DeployabilityIndex,
)
from sqlmesh.utils import random_id
from sqlmesh.utils.date import TimeLike, to_date
Expand Down Expand Up @@ -246,9 +247,12 @@ def push_plan(context: Context, plan: Plan) -> None:
context.create_scheduler,
context.default_catalog,
)
plan_evaluator._push(plan.to_evaluatable(), plan.snapshots)
deployability_index = DeployabilityIndex.create(context.snapshots.values())
plan_evaluator._push(plan.to_evaluatable(), plan.snapshots, deployability_index)
promotion_result = plan_evaluator._promote(plan.to_evaluatable(), plan.snapshots)
plan_evaluator._update_views(plan.to_evaluatable(), plan.snapshots, promotion_result)
plan_evaluator._update_views(
plan.to_evaluatable(), plan.snapshots, promotion_result, deployability_index
)


@pytest.fixture()
Expand Down
4 changes: 0 additions & 4 deletions tests/core/engine_adapter/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1476,10 +1476,6 @@ def test_sushi(ctx: TestContext, tmp_path_factory: pytest.TempPathFactory):
"table": "Sushi customer data",
"column": {"customer_id": "customer_id uniquely identifies customers"},
},
"marketing": {
"table": "Sushi marketing data",
"column": {"customer_id": "customer_id uniquely identifies customers \\"},
},
"orders": {
"table": "Table of sushi orders.",
},
Expand Down
18 changes: 18 additions & 0 deletions tests/core/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,24 @@ def test_forward_only_parent_created_in_dev_child_created_in_prod(
context.apply(plan)


@time_machine.travel("2023-01-08 00:00:00 UTC")
def test_new_forward_only_model(init_and_plan_context: t.Callable):
context, _ = init_and_plan_context("examples/sushi")

context.plan("dev", skip_tests=True, no_prompts=True, auto_apply=True)

snapshot = context.get_snapshot("sushi.marketing")

# The deployable table should not exist yet
assert not context.engine_adapter.table_exists(snapshot.table_name())
assert context.engine_adapter.table_exists(snapshot.table_name(is_deployable=False))

context.plan("prod", skip_tests=True, no_prompts=True, auto_apply=True)

assert context.engine_adapter.table_exists(snapshot.table_name())
assert context.engine_adapter.table_exists(snapshot.table_name(is_deployable=False))


@time_machine.travel("2023-01-08 15:00:00 UTC")
def test_plan_set_choice_is_reflected_in_missing_intervals(init_and_plan_context: t.Callable):
context, plan = init_and_plan_context("examples/sushi")
Expand Down
153 changes: 153 additions & 0 deletions tests/core/test_snapshot_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,54 @@ def test_create_only_dev_table_exists(mocker: MockerFixture, adapter_mock, make_
)


def test_create_new_forward_only_model(mocker: MockerFixture, adapter_mock, make_snapshot):
model = load_sql_based_model(
parse( # type: ignore
"""
MODEL (
name test_schema.test_model,
kind INCREMENTAL_BY_TIME_RANGE (
time_column ds,
forward_only true,
)
);
SELECT a::int, '2024-01-01' as ds FROM tbl;
"""
),
)

snapshot = make_snapshot(model)
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)

adapter_mock.get_data_objects.return_value = []
adapter_mock.table_exists.return_value = False
evaluator = SnapshotEvaluator(adapter_mock)

evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable())
adapter_mock.create_schema.assert_called_once_with(to_schema("sqlmesh__test_schema"))
# Only non-deployable table should be created
adapter_mock.create_table.assert_called_once_with(
f"sqlmesh__test_schema.test_schema__test_model__{snapshot.temp_version_get_or_generate()}__temp",
columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("varchar")},
table_format=None,
storage_format=None,
partitioned_by=model.partitioned_by,
partition_interval_unit=model.partition_interval_unit,
clustered_by=[],
table_properties={},
table_description=None,
column_descriptions=None,
)
adapter_mock.get_data_objects.assert_called_once_with(
schema_("sqlmesh__test_schema"),
{
f"test_schema__test_model__{snapshot.version}",
f"test_schema__test_model__{snapshot.temp_version_get_or_generate()}__temp",
},
)


@pytest.mark.parametrize(
"deployability_index, snapshot_category, deployability_flags",
[
Expand Down Expand Up @@ -1122,6 +1170,7 @@ def columns(table_name):
}

adapter.columns = columns # type: ignore
adapter.table_exists = lambda _: True # type: ignore

evaluator = SnapshotEvaluator(adapter)

Expand All @@ -1148,6 +1197,42 @@ def columns(table_name):
)


def test_migrate_missing_table(mocker: MockerFixture, make_snapshot):
connection_mock = mocker.NonCallableMock()
cursor_mock = mocker.Mock()
connection_mock.cursor.return_value = cursor_mock
adapter = EngineAdapter(lambda: connection_mock, "")

adapter.table_exists = lambda _: False # type: ignore

evaluator = SnapshotEvaluator(adapter)

model = SqlModel(
name="test_schema.test_model",
kind=IncrementalByTimeRangeKind(
time_column="a", on_destructive_change=OnDestructiveChange.ALLOW
),
storage_format="parquet",
query=parse_one("SELECT c, a FROM tbl WHERE ds BETWEEN @start_ds and @end_ds"),
pre_statements=[parse_one("CREATE TABLE pre (a INT)")],
post_statements=[parse_one("DROP TABLE pre")],
)
snapshot = make_snapshot(model, version="1")
snapshot.change_category = SnapshotChangeCategory.FORWARD_ONLY

evaluator.migrate([snapshot], {})

cursor_mock.execute.assert_has_calls(
[
call('CREATE TABLE "pre" ("a" INT)'),
call(
'CREATE TABLE IF NOT EXISTS "sqlmesh__test_schema"."test_schema__test_model__1" AS SELECT "c" AS "c", "a" AS "a" FROM "tbl" AS "tbl" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\' AND FALSE LIMIT 0'
),
call('DROP TABLE "pre"'),
]
)


@pytest.mark.parametrize(
"change_category",
[SnapshotChangeCategory.FORWARD_ONLY, SnapshotChangeCategory.INDIRECT_NON_BREAKING],
Expand Down Expand Up @@ -1386,6 +1471,14 @@ def test_create_clone_in_dev(mocker: MockerFixture, adapter_mock, make_snapshot)
snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY)
snapshot.previous_versions = snapshot.all_versions

adapter_mock.get_data_objects.return_value = [
DataObject(
name=f"test_schema__test_model__{snapshot.version}",
schema="sqlmesh__test_schema",
type=DataObjectType.TABLE,
),
]

evaluator.create([snapshot], {})

adapter_mock.create_table.assert_called_once_with(
Expand Down Expand Up @@ -1419,6 +1512,50 @@ def test_create_clone_in_dev(mocker: MockerFixture, adapter_mock, make_snapshot)
)


def test_create_clone_in_dev_missing_table(mocker: MockerFixture, adapter_mock, make_snapshot):
adapter_mock.SUPPORTS_CLONING = True
adapter_mock.get_alter_expressions.return_value = []
evaluator = SnapshotEvaluator(adapter_mock)

model = load_sql_based_model(
parse( # type: ignore
"""
MODEL (
name test_schema.test_model,
kind INCREMENTAL_BY_TIME_RANGE (
time_column ds,
forward_only true,
)
);
SELECT 1::INT as a, ds::DATE FROM a;
"""
),
)

snapshot = make_snapshot(model)
snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY)
snapshot.previous_versions = snapshot.all_versions

evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable())

adapter_mock.create_table.assert_called_once_with(
f"sqlmesh__test_schema.test_schema__test_model__{snapshot.temp_version_get_or_generate()}__temp",
columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("date")},
table_format=None,
storage_format=None,
partitioned_by=[exp.to_column("ds", quoted=True)],
partition_interval_unit=IntervalUnit.DAY,
clustered_by=[],
table_properties={},
table_description=None,
column_descriptions=None,
)

adapter_mock.clone_table.assert_not_called()
adapter_mock.alter_table.assert_not_called()


def test_drop_clone_in_dev_when_migration_fails(mocker: MockerFixture, adapter_mock, make_snapshot):
adapter_mock.SUPPORTS_CLONING = True
adapter_mock.get_alter_expressions.return_value = []
Expand All @@ -1445,6 +1582,14 @@ def test_drop_clone_in_dev_when_migration_fails(mocker: MockerFixture, adapter_m
snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY)
snapshot.previous_versions = snapshot.all_versions

adapter_mock.get_data_objects.return_value = [
DataObject(
name=f"test_schema__test_model__{snapshot.version}",
schema="sqlmesh__test_schema",
type=DataObjectType.TABLE,
),
]

evaluator.create([snapshot], {})

adapter_mock.clone_table.assert_called_once_with(
Expand Down Expand Up @@ -1494,6 +1639,14 @@ def test_create_clone_in_dev_self_referencing(mocker: MockerFixture, adapter_moc
snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY)
snapshot.previous_versions = snapshot.all_versions

adapter_mock.get_data_objects.return_value = [
DataObject(
name=f"test_schema__test_model__{snapshot.version}",
schema="sqlmesh__test_schema",
type=DataObjectType.TABLE,
),
]

evaluator.create([snapshot], {})

adapter_mock.create_table.assert_called_once_with(
Expand Down

0 comments on commit 017da33

Please sign in to comment.