fix: disable datafusion optimizers which perform list field name erasure

Today the make_array function from DataFusion uses "item" as the list
element's field name. With recent changes in delta-kernel-rs we have
switched to calling it "element", which is more conventional and matches
how Apache Parquet names list elements.

This change introduces a test which still fails with a deeply nested
type mismatch error:

        thread 'operations::update::tests::test_update_with_array_that_must_be_coerced' panicked at crates/core/src/operations/update.rs:1143:14:
        called `Result::unwrap()` on an `Err` value: Arrow { source: InvalidArgumentError("arguments need to have the same data type") }
        note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
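
For illustration, here is a minimal sketch (standalone, using only arrow_schema's public Field and DataType types) of why the element field name matters: Arrow folds the list element's Field, including its name, into the DataType, so two lists that differ only in that name compare as different types.

        use std::sync::Arc;
        use arrow_schema::{DataType, Field};

        fn main() {
            // DataFusion's make_array has historically named the element field "item" ...
            let datafusion_style =
                DataType::List(Arc::new(Field::new("item", DataType::Int64, true)));
            // ... while delta-kernel-rs and the Parquet convention use "element".
            let kernel_style =
                DataType::List(Arc::new(Field::new("element", DataType::Int64, true)));

            // The field name is part of the type, so these are not equal; this is the
            // mismatch that later surfaces as "arguments need to have the same data type".
            assert_ne!(datafusion_style, kernel_style);
        }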

Signed-off-by: R. Tyler Croy <[email protected]>
rtyler committed Nov 18, 2024
1 parent c60e29f commit 7f87b64
Showing 3 changed files with 273 additions and 24 deletions.
171 changes: 156 additions & 15 deletions crates/core/src/delta_datafusion/expr.rs
@@ -23,25 +23,166 @@
use std::fmt::{self, Display, Error, Formatter, Write};
use std::sync::Arc;

use arrow_schema::DataType;
use arrow_array::{Array, GenericListArray};
use arrow_schema::{DataType, Field};
use chrono::{DateTime, NaiveDate};
use datafusion::execution::context::SessionState;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::execution::FunctionRegistry;
use datafusion::functions_array::make_array::MakeArray;
use datafusion_common::Result as DFResult;
use datafusion_common::{config::ConfigOptions, DFSchema, Result, ScalarValue, TableReference};
use datafusion_expr::expr::InList;
use datafusion_expr::planner::ExprPlanner;
use datafusion_expr::{AggregateUDF, Between, BinaryExpr, Cast, Expr, Like, TableSource};
// Needed for MakeParquetArray
use datafusion_expr::{ColumnarValue, Documentation, ScalarUDF, ScalarUDFImpl, Signature};
use datafusion_functions::core::planner::CoreFunctionPlanner;
use datafusion_sql::planner::{ContextProvider, SqlToRel};
use datafusion_sql::sqlparser::ast::escape_quoted_string;
use datafusion_sql::sqlparser::dialect::GenericDialect;
use datafusion_sql::sqlparser::parser::Parser;
use datafusion_sql::sqlparser::tokenizer::Tokenizer;
use tracing::log::*;

use super::DeltaParserOptions;
use crate::{DeltaResult, DeltaTableError};

/// This struct is like DataFusion's MakeArray but ensures that `element` is used rather than `item`
/// as the field name within the list.
#[derive(Debug)]
struct MakeParquetArray {
/// The actual upstream UDF, which we're just totally cheating and using
actual: MakeArray,
/// Aliases for this UDF
aliases: Vec<String>,
}

impl MakeParquetArray {
pub fn new() -> Self {
let actual = MakeArray::default();
let aliases = vec!["make_array".into(), "make_list".into()];
Self { actual, aliases }
}
}

impl ScalarUDFImpl for MakeParquetArray {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"make_parquet_array"
}

fn signature(&self) -> &Signature {
self.actual.signature()
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
let r_type = match arg_types.len() {
0 => Ok(DataType::List(Arc::new(Field::new(
"element",
DataType::Int32,
true,
)))),
_ => {
// At this point, all the types in the array should have been coerced to the same one
Ok(DataType::List(Arc::new(Field::new(
"element",
arg_types[0].to_owned(),
true,
))))
}
};
debug!("MakeParquetArray return_type -> {r_type:?}");
r_type
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match self.actual.invoke(args)? {
ColumnarValue::Scalar(ScalarValue::List(df_array)) => {
let field = Arc::new(Field::new("element", DataType::Int64, true));
let result = Ok(ColumnarValue::Scalar(ScalarValue::List(Arc::new(
GenericListArray::<i32>::try_new(
field,
df_array.offsets().clone(),
arrow_array::make_array(df_array.values().into_data()),
None,
)?,
))));
debug!("MakeParquetArray;invoke returning: {result:?}");
result
}
others => {
error!("Unexpected response inside MakeParquetArray! {others:?}");
Ok(others)
}
}
}

fn invoke_no_args(&self, number_rows: usize) -> Result<ColumnarValue> {
self.actual.invoke_no_args(number_rows)
}

fn aliases(&self) -> &[String] {
&self.aliases
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
self.actual.coerce_types(arg_types)
}

fn documentation(&self) -> Option<&Documentation> {
self.actual.documentation()
}
}

use datafusion::functions_array::planner::{FieldAccessPlanner, NestedFunctionPlanner};

/// This exists because the NestedFunctionPlanner, _not_ the UserDefinedFunctionPlanner, handles the
/// insertion of "make_array", which is used to turn [100] into List<field=element, values=[100]>
///
/// **screaming intensifies**
#[derive(Debug)]
struct CustomNestedFunctionPlanner {
original: NestedFunctionPlanner,
}

impl Default for CustomNestedFunctionPlanner {
fn default() -> Self {
Self {
original: NestedFunctionPlanner,
}
}
}

use datafusion_expr::planner::{PlannerResult, RawBinaryExpr};
impl ExprPlanner for CustomNestedFunctionPlanner {
fn plan_array_literal(
&self,
exprs: Vec<Expr>,
_schema: &DFSchema,
) -> Result<PlannerResult<Vec<Expr>>> {
let udf = Arc::new(ScalarUDF::from(MakeParquetArray::new()));

Ok(PlannerResult::Planned(udf.call(exprs)))
}
fn plan_binary_op(
&self,
expr: RawBinaryExpr,
schema: &DFSchema,
) -> Result<PlannerResult<RawBinaryExpr>> {
self.original.plan_binary_op(expr, schema)
}
fn plan_make_map(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
self.original.plan_make_map(args)
}
fn plan_any(&self, expr: RawBinaryExpr) -> Result<PlannerResult<RawBinaryExpr>> {
self.original.plan_any(expr)
}
}

pub(crate) struct DeltaContextProvider<'a> {
state: SessionState,
/// Keeping this around just to make use of the 'a lifetime
@@ -51,22 +51,192 @@ pub(crate) struct DeltaContextProvider<'a> {

impl<'a> DeltaContextProvider<'a> {
fn new(state: &'a SessionState) -> Self {
let planners = state.expr_planners();
// default planners are [CoreFunctionPlanner, NestedFunctionPlanner, FieldAccessPlanner,
// UserDefinedFunctionPlanner]
let planners: Vec<Arc<dyn ExprPlanner>> = vec![
Arc::new(CoreFunctionPlanner::default()),
Arc::new(CustomNestedFunctionPlanner::default()),
Arc::new(FieldAccessPlanner),
Arc::new(datafusion::functions::planner::UserDefinedFunctionPlanner),
];
// Disable the above for testing
//let planners = state.expr_planners();
let new_state = SessionStateBuilder::new_from_existing(state.clone())
.with_expr_planners(planners.clone())
.build();
DeltaContextProvider {
planners,
// Creating a new session state with overridden scalar_functions since
// the get_field() UDF was dropped from the default scalar functions upstream in
// `36660fe10d9c0cdff62e0da0b94bee28422d3419`
state: SessionStateBuilder::new_from_existing(state.clone())
.with_scalar_functions(
state
.scalar_functions()
.values()
.cloned()
.chain(std::iter::once(datafusion::functions::core::get_field()))
.collect(),
)
.build(),
state: new_state,
_original: state,
}
}
34 changes: 28 additions & 6 deletions crates/core/src/operations/cast/mod.rs
@@ -275,12 +275,12 @@ mod tests {
fn test_merge_arrow_schema_with_nested() {
let left_schema = Arc::new(Schema::new(vec![Field::new(
"f",
DataType::LargeList(Arc::new(Field::new("item", DataType::Utf8, false))),
DataType::LargeList(Arc::new(Field::new("element", DataType::Utf8, false))),
false,
)]));
let right_schema = Arc::new(Schema::new(vec![Field::new(
"f",
DataType::List(Arc::new(Field::new("item", DataType::LargeUtf8, false))),
DataType::List(Arc::new(Field::new("element", DataType::LargeUtf8, false))),
true,
)]));

@@ -306,7 +306,7 @@

let fields = Fields::from(vec![Field::new_list(
"list_column",
Field::new("item", DataType::Int8, false),
Field::new("element", DataType::Int8, false),
false,
)]);
let target_schema = Arc::new(Schema::new(fields)) as SchemaRef;
@@ -316,7 +316,7 @@
let schema = result.unwrap().schema();
let field = schema.column_with_name("list_column").unwrap().1;
if let DataType::List(list_item) = field.data_type() {
assert_eq!(list_item.name(), "item");
assert_eq!(list_item.name(), "element");
} else {
panic!("Not a list");
}
@@ -343,12 +343,34 @@

#[test]
fn test_is_cast_required_with_list() {
let field1 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false)));
let field2 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false)));
let field1 = DataType::List(FieldRef::from(Field::new(
"element",
DataType::Int32,
false,
)));
let field2 = DataType::List(FieldRef::from(Field::new(
"element",
DataType::Int32,
false,
)));

assert!(!is_cast_required(&field1, &field2));
}

/// Delta has adopted "element" as the default list field name rather than the previously used
/// "item". This lines up with Apache Parquet's convention, but the difference still has to be handled when casting
#[test]
fn test_is_cast_required_with_old_and_new_list() {
let field1 = DataType::List(FieldRef::from(Field::new(
"element",
DataType::Int32,
false,
)));
let field2 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false)));

assert!(is_cast_required(&field1, &field2));
}

#[test]
fn test_is_cast_required_with_smol_int() {
assert!(is_cast_required(&DataType::Int8, &DataType::Int32));
92 changes: 89 additions & 3 deletions crates/core/src/operations/update.rs
@@ -242,6 +242,21 @@ async fn execute(
return Err(DeltaTableError::NotInitializedWithFiles("UPDATE".into()));
}

// NOTE: The optimize_projections and simplify_expressions rules are temporarily disabled because
// they error with our schemas for Lists due to the issues discussed
// [here](https://github.com/delta-io/delta-rs/pull/2886#issuecomment-2481550560)
let rules: Vec<Arc<dyn datafusion::optimizer::OptimizerRule + Send + Sync>> = state
.optimizers()
.into_iter()
.filter(|rule| {
rule.name() != "optimize_projections" && rule.name() != "simplify_expressions"
})
.cloned()
.collect();
let state = SessionStateBuilder::from(state)
.with_optimizer_rules(rules)
.build();

let update_planner = DeltaPlanner::<UpdateMetricExtensionPlanner> {
extension_planner: UpdateMetricExtensionPlanner {},
};
@@ -323,7 +338,6 @@ async fn execute(
enable_pushdown: false,
}),
});

let df_with_predicate_and_metrics = DataFrame::new(state.clone(), plan_with_metrics);

let expressions: Vec<Expr> = df_with_predicate_and_metrics
@@ -343,6 +357,8 @@ async fn execute(
})
.collect::<DeltaResult<Vec<Expr>>>()?;

//let updated_df = df_with_predicate_and_metrics.clone();
// Disabling the select allows the coerce test to pass, still not sure why
let updated_df = df_with_predicate_and_metrics.select(expressions.clone())?;
let physical_plan = updated_df.clone().create_physical_plan().await?;
let writer_stats_config = WriterStatsConfig::new(
@@ -1040,11 +1056,81 @@ mod tests {
assert_eq!(table.version(), 1);
// Completed the first creation/write

// Update
use arrow::array::{Int32Builder, ListBuilder};
let mut new_items_builder =
ListBuilder::new(Int32Builder::new()).with_field(arrow_field.clone());
new_items_builder.append_value([Some(100)]);
let new_items = ScalarValue::List(Arc::new(new_items_builder.finish()));

let (table, _metrics) = DeltaOps(table)
.update()
.with_predicate(col("id").eq(lit(1)))
.with_update("items", lit(new_items))
.await
.unwrap();
assert_eq!(table.version(), 2);
}

/// Lists coming in from the Python bindings are passed to the update operation as SQL expression
/// strings, so this test emulates that behavior to ensure those strings are parsed into list
/// expressions for the update operation correctly
#[tokio::test]
async fn test_update_with_array_that_must_be_coerced() {
let _ = pretty_env_logger::try_init();
let schema = StructType::new(vec![
StructField::new(
"id".to_string(),
DeltaDataType::Primitive(PrimitiveType::Integer),
true,
),
StructField::new(
"temp".to_string(),
DeltaDataType::Primitive(PrimitiveType::Integer),
true,
),
StructField::new(
"items".to_string(),
DeltaDataType::Array(Box::new(crate::kernel::ArrayType::new(
DeltaDataType::LONG,
false,
))),
true,
),
]);
let arrow_schema: ArrowSchema = (&schema).try_into().unwrap();

// Create the first batch
let arrow_field = Field::new("element", DataType::Int64, false);
let list_array = ListArray::new_null(arrow_field.clone().into(), 2);
let batch = RecordBatch::try_new(
Arc::new(arrow_schema.clone()),
vec![
Arc::new(Int32Array::from(vec![Some(0), Some(1)])),
Arc::new(Int32Array::from(vec![Some(30), Some(31)])),
Arc::new(list_array),
],
)
.expect("Failed to create record batch");
let _ = arrow::util::pretty::print_batches(&[batch.clone()]);

let table = DeltaOps::new_in_memory()
.create()
.with_columns(schema.fields().cloned())
.await
.unwrap();
assert_eq!(table.version(), 0);

let table = DeltaOps(table)
.write(vec![batch])
.await
.expect("Failed to write first batch");
assert_eq!(table.version(), 1);
// Completed the first creation/write

let (table, _metrics) = DeltaOps(table)
.update()
.with_predicate(col("id").eq(lit(1)))
.with_update("items", make_array(vec![lit(100)]))
.with_update("items", "[100]".to_string())
.await
.unwrap();
assert_eq!(table.version(), 2);
