Skip to content

Commit

Permalink
chore: tests
Browse files Browse the repository at this point in the history
Signed-off-by: Ion Koutsouris <[email protected]>
  • Loading branch information
ion-elgreco authored and rtyler committed Jan 15, 2025
1 parent 5ba1e64 commit b7f75dd
Show file tree
Hide file tree
Showing 3 changed files with 267 additions and 4 deletions.
44 changes: 43 additions & 1 deletion crates/core/src/kernel/models/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ impl StructTypeExt for StructType {
};
}
}
dbg!(generated_cols.clone());
Ok(generated_cols)
}

Expand Down Expand Up @@ -183,6 +182,49 @@ mod tests {
use serde_json;
use serde_json::json;

#[test]
fn test_get_generated_columns() {
let schema: StructType = serde_json::from_value(json!(
{
"type":"struct",
"fields":[
{"name":"id","type":"integer","nullable":true,"metadata":{}},
{"name":"gc","type":"integer","nullable":true,"metadata":{}}]
}
))
.unwrap();
let cols = schema.get_generated_columns().unwrap();
assert_eq!(cols.len(), 0);

let schema: StructType = serde_json::from_value(json!(
{
"type":"struct",
"fields":[
{"name":"id","type":"integer","nullable":true,"metadata":{}},
{"name":"gc","type":"integer","nullable":true,"metadata":{"delta.generationExpression":"\"5\""}}]
}
)).unwrap();
let cols = schema.get_generated_columns().unwrap();
assert_eq!(cols.len(), 1);
assert_eq!(cols[0].data_type, DataType::INTEGER);
assert_eq!(
cols[0].validation_expr,
"gc = 5 OR (gc IS NULL AND 5 IS NULL)"
);

let schema: StructType = serde_json::from_value(json!(
{
"type":"struct",
"fields":[
{"name":"id","type":"integer","nullable":true,"metadata":{}},
{"name":"gc","type":"integer","nullable":true,"metadata":{"delta.generationExpression":"\"5\""}},
{"name":"id2","type":"integer","nullable":true,"metadata":{"delta.generationExpression":"\"id * 10\""}},]
}
)).unwrap();
let cols = schema.get_generated_columns().unwrap();
assert_eq!(cols.len(), 2);
}

#[test]
fn test_get_invariants() {
let schema: StructType = serde_json::from_value(json!({
Expand Down
7 changes: 4 additions & 3 deletions crates/core/src/operations/merge/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -761,11 +761,12 @@ async fn execute(
for generated_col in generated_cols {
let col_name = generated_col.get_name();

if !df
if df
.clone()
.schema()
.field_names()
.contains(&col_name.to_string())
.field_with_unqualified_name(&col_name.to_string())
.is_err()
// implies it doesn't exist
{
debug!(
"Adding missing generated column {} in source as placeholder",
Expand Down
220 changes: 220 additions & 0 deletions python/tests/test_generated_columns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
import pyarrow as pa
import pytest

from deltalake import DeltaTable, Field, Schema, write_deltalake
from deltalake.exceptions import DeltaError, SchemaMismatchError
from deltalake.schema import PrimitiveType


@pytest.fixture
def gc_schema() -> Schema:
return Schema(
[
Field(name="id", type=PrimitiveType("integer")),
Field(
name="gc",
type=PrimitiveType("integer"),
metadata={"delta.generationExpression": "'5'"},
),
]
)


@pytest.fixture
def valid_gc_data() -> pa.Table:
id_col = pa.field("id", pa.int32())
gc = pa.field("gc", pa.int32()).with_metadata({"delta.generationExpression": "10"})
data = pa.Table.from_pydict(
{"id": [1, 2], "gc": [10, 10]}, schema=pa.schema([id_col, gc])
)
return data


@pytest.fixture
def data_without_gc() -> pa.Table:
id_col = pa.field("id", pa.int32())
data = pa.Table.from_pydict({"id": [1, 2]}, schema=pa.schema([id_col]))
return data


@pytest.fixture
def invalid_gc_data() -> pa.Table:
id_col = pa.field("id", pa.int32())
gc = pa.field("gc", pa.int32()).with_metadata({"delta.generationExpression": "10"})
data = pa.Table.from_pydict(
{"id": [1, 2], "gc": [5, 10]}, schema=pa.schema([id_col, gc])
)
return data


@pytest.fixture
def table_with_gc(tmp_path, gc_schema) -> DeltaTable:
dt = DeltaTable.create(
tmp_path,
schema=gc_schema,
)
return dt


def test_create_table_with_generated_columns(tmp_path, gc_schema: Schema):
dt = DeltaTable.create(
tmp_path,
schema=gc_schema,
)
protocol = dt.protocol()
assert protocol.min_writer_version == 4

dt = DeltaTable.create(
tmp_path,
schema=gc_schema,
mode="overwrite",
configuration={"delta.minWriterVersion": "7"},
)
protocol = dt.protocol()

assert dt.version() == 1
assert protocol.writer_features is not None
assert "generatedColumns" in protocol.writer_features


def test_write_with_gc(tmp_path, valid_gc_data):
write_deltalake(tmp_path, mode="append", data=valid_gc_data)
dt = DeltaTable(tmp_path)

assert dt.protocol().min_writer_version == 4
assert dt.to_pyarrow_table() == valid_gc_data


def test_write_with_gc_higher_writer_version(tmp_path, valid_gc_data):
write_deltalake(
tmp_path,
mode="append",
data=valid_gc_data,
configuration={"delta.minWriterVersion": "7"},
)
dt = DeltaTable(tmp_path)
protocol = dt.protocol()
assert protocol.min_writer_version == 7
assert protocol.writer_features is not None
assert "generatedColumns" in protocol.writer_features
assert dt.to_pyarrow_table() == valid_gc_data


def test_write_with_invalid_gc(tmp_path, invalid_gc_data):
import re

with pytest.raises(
DeltaError,
match=re.escape(
'Invariant violations: ["Check or Invariant (gc = 10 OR (gc IS NULL AND 10 IS NULL)) violated by value in row: [5]"]'
),
):
write_deltalake(tmp_path, mode="append", data=invalid_gc_data)


def test_write_with_invalid_gc_to_table(table_with_gc, invalid_gc_data):
import re

with pytest.raises(
DeltaError,
match=re.escape(
"Invariant violations: [\"Check or Invariant (gc = '5' OR (gc IS NULL AND '5' IS NULL)) violated by value in row: [10]\"]"
),
):
write_deltalake(table_with_gc, mode="append", data=invalid_gc_data)


def test_write_to_table_generating_data(table_with_gc: DeltaTable):
id_col = pa.field("id", pa.int32())
data = pa.Table.from_pydict({"id": [1, 2]}, schema=pa.schema([id_col]))
write_deltalake(table_with_gc, mode="append", data=data)

id_col = pa.field("id", pa.int32())
gc = pa.field("gc", pa.int32())
expected_data = pa.Table.from_pydict(
{"id": [1, 2], "gc": [5, 5]}, schema=pa.schema([id_col, gc])
)

assert table_with_gc.version() == 1
assert table_with_gc.to_pyarrow_table() == expected_data


def test_raise_when_gc_passed_during_schema_evolution(
tmp_path, data_without_gc, valid_gc_data
):
write_deltalake(
tmp_path,
mode="append",
data=data_without_gc,
)
dt = DeltaTable(tmp_path)
assert dt.protocol().min_writer_version == 2

with pytest.raises(
SchemaMismatchError,
match="Schema evolved fields cannot have generated expressions. Recreate the table to achieve this.",
):
write_deltalake(
dt,
mode="append",
data=valid_gc_data,
schema_mode="merge",
)


def test_raise_when_gc_passed_during_adding_new_columns(tmp_path, data_without_gc):
write_deltalake(
tmp_path,
mode="append",
data=data_without_gc,
)
dt = DeltaTable(tmp_path)
assert dt.protocol().min_writer_version == 2

with pytest.raises(DeltaError, match="New columns cannot be a generated column"):
dt.alter.add_columns(
fields=[
Field(
name="gc",
type=PrimitiveType("integer"),
metadata={"delta.generationExpression": "'5'"},
)
]
)


def test_merge_with_gc(table_with_gc: DeltaTable, data_without_gc):
(
table_with_gc.merge(
data_without_gc, predicate="s.id = t.id", source_alias="s", target_alias="t"
)
.when_not_matched_insert_all()
.execute()
)
id_col = pa.field("id", pa.int32())
gc = pa.field("gc", pa.int32())
expected_data = pa.Table.from_pydict(
{"id": [1, 2], "gc": [5, 5]}, schema=pa.schema([id_col, gc])
)
assert table_with_gc.to_pyarrow_table() == expected_data


def test_merge_with_gc_invalid(table_with_gc: DeltaTable, invalid_gc_data):
import re

with pytest.raises(
DeltaError,
match=re.escape(
"Invariant violations: [\"Check or Invariant (gc = '5' OR (gc IS NULL AND '5' IS NULL)) violated by value in row: [10]\"]"
),
):
(
table_with_gc.merge(
invalid_gc_data,
predicate="s.id = t.id",
source_alias="s",
target_alias="t",
)
.when_not_matched_insert_all()
.execute()
)

0 comments on commit b7f75dd

Please sign in to comment.