diff --git a/crates/core/src/delta_datafusion/expr.rs b/crates/core/src/delta_datafusion/expr.rs index eb542d98dd..587df851da 100644 --- a/crates/core/src/delta_datafusion/expr.rs +++ b/crates/core/src/delta_datafusion/expr.rs @@ -23,25 +23,166 @@ use std::fmt::{self, Display, Error, Formatter, Write}; use std::sync::Arc; -use arrow_schema::DataType; +use arrow_array::{Array, GenericListArray}; +use arrow_schema::{DataType, Field}; use chrono::{DateTime, NaiveDate}; use datafusion::execution::context::SessionState; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::execution::FunctionRegistry; +use datafusion::functions_array::make_array::MakeArray; use datafusion_common::Result as DFResult; use datafusion_common::{config::ConfigOptions, DFSchema, Result, ScalarValue, TableReference}; use datafusion_expr::expr::InList; use datafusion_expr::planner::ExprPlanner; use datafusion_expr::{AggregateUDF, Between, BinaryExpr, Cast, Expr, Like, TableSource}; +// Needed for MakeParquetArray +use datafusion_expr::{ColumnarValue, Documentation, ScalarUDF, ScalarUDFImpl, Signature}; +use datafusion_functions::core::planner::CoreFunctionPlanner; use datafusion_sql::planner::{ContextProvider, SqlToRel}; use datafusion_sql::sqlparser::ast::escape_quoted_string; use datafusion_sql::sqlparser::dialect::GenericDialect; use datafusion_sql::sqlparser::parser::Parser; use datafusion_sql::sqlparser::tokenizer::Tokenizer; +use tracing::log::*; use super::DeltaParserOptions; use crate::{DeltaResult, DeltaTableError}; +/// This struct is like Datafusion's MakeArray but ensures that `element` is used rather than `item +/// as the field name within the list. +#[derive(Debug)] +struct MakeParquetArray { + /// The actual upstream UDF, which we're just totally cheating and using + actual: MakeArray, + /// Aliases for this UDF + aliases: Vec, +} + +impl MakeParquetArray { + pub fn new() -> Self { + let actual = MakeArray::default(); + let aliases = vec!["make_array".into(), "make_list".into()]; + Self { actual, aliases } + } +} + +impl ScalarUDFImpl for MakeParquetArray { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "make_parquet_array" + } + + fn signature(&self) -> &Signature { + self.actual.signature() + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let r_type = match arg_types.len() { + 0 => Ok(DataType::List(Arc::new(Field::new( + "element", + DataType::Int32, + true, + )))), + _ => { + // At this point, all the type in array should be coerced to the same one + Ok(DataType::List(Arc::new(Field::new( + "element", + arg_types[0].to_owned(), + true, + )))) + } + }; + debug!("MakeParquetArray return_type -> {r_type:?}"); + r_type + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match self.actual.invoke(args)? { + ColumnarValue::Scalar(ScalarValue::List(df_array)) => { + let field = Arc::new(Field::new("element", DataType::Int64, true)); + let result = Ok(ColumnarValue::Scalar(ScalarValue::List(Arc::new( + GenericListArray::::try_new( + field, + df_array.offsets().clone(), + arrow_array::make_array(df_array.values().into_data()), + None, + )?, + )))); + debug!("MakeParquetArray;invoke returning: {result:?}"); + result + } + others => { + error!("Unexpected response inside MakeParquetArray! {others:?}"); + Ok(others) + } + } + } + + fn invoke_no_args(&self, number_rows: usize) -> Result { + self.actual.invoke_no_args(number_rows) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + self.actual.coerce_types(arg_types) + } + + fn documentation(&self) -> Option<&Documentation> { + self.actual.documentation() + } +} + +use datafusion::functions_array::planner::{FieldAccessPlanner, NestedFunctionPlanner}; + +/// This exists becxause the NestedFunctionPlanner _not_ the UserDefinedFunctionPlanner handles the +/// insertion of "make_array" which is used to turn [100] into List +/// +/// **screaming intensifies** +#[derive(Debug)] +struct CustomNestedFunctionPlanner { + original: NestedFunctionPlanner, +} + +impl Default for CustomNestedFunctionPlanner { + fn default() -> Self { + Self { + original: NestedFunctionPlanner, + } + } +} + +use datafusion_expr::planner::{PlannerResult, RawBinaryExpr}; +impl ExprPlanner for CustomNestedFunctionPlanner { + fn plan_array_literal( + &self, + exprs: Vec, + _schema: &DFSchema, + ) -> Result>> { + let udf = Arc::new(ScalarUDF::from(MakeParquetArray::new())); + + Ok(PlannerResult::Planned(udf.call(exprs))) + } + fn plan_binary_op( + &self, + expr: RawBinaryExpr, + schema: &DFSchema, + ) -> Result> { + self.original.plan_binary_op(expr, schema) + } + fn plan_make_map(&self, args: Vec) -> Result>> { + self.original.plan_make_map(args) + } + fn plan_any(&self, expr: RawBinaryExpr) -> Result> { + self.original.plan_any(expr) + } +} + pub(crate) struct DeltaContextProvider<'a> { state: SessionState, /// Keeping this around just to make use of the 'a lifetime @@ -51,22 +192,22 @@ pub(crate) struct DeltaContextProvider<'a> { impl<'a> DeltaContextProvider<'a> { fn new(state: &'a SessionState) -> Self { - let planners = state.expr_planners(); + // default planners are [CoreFunctionPlanner, NestedFunctionPlanner, FieldAccessPlanner, + // UserDefinedFunctionPlanner] + let planners: Vec> = vec![ + Arc::new(CoreFunctionPlanner::default()), + Arc::new(CustomNestedFunctionPlanner::default()), + Arc::new(FieldAccessPlanner), + Arc::new(datafusion::functions::planner::UserDefinedFunctionPlanner), + ]; + // Disable the above for testing + //let planners = state.expr_planners(); + let new_state = SessionStateBuilder::new_from_existing(state.clone()) + .with_expr_planners(planners.clone()) + .build(); DeltaContextProvider { planners, - // Creating a new session state with overridden scalar_functions since - // the get_field() UDF was dropped from the default scalar functions upstream in - // `36660fe10d9c0cdff62e0da0b94bee28422d3419` - state: SessionStateBuilder::new_from_existing(state.clone()) - .with_scalar_functions( - state - .scalar_functions() - .values() - .cloned() - .chain(std::iter::once(datafusion::functions::core::get_field())) - .collect(), - ) - .build(), + state: new_state, _original: state, } } diff --git a/crates/core/src/operations/cast/mod.rs b/crates/core/src/operations/cast/mod.rs index 278cb2bbfa..a358515194 100644 --- a/crates/core/src/operations/cast/mod.rs +++ b/crates/core/src/operations/cast/mod.rs @@ -275,12 +275,12 @@ mod tests { fn test_merge_arrow_schema_with_nested() { let left_schema = Arc::new(Schema::new(vec![Field::new( "f", - DataType::LargeList(Arc::new(Field::new("item", DataType::Utf8, false))), + DataType::LargeList(Arc::new(Field::new("element", DataType::Utf8, false))), false, )])); let right_schema = Arc::new(Schema::new(vec![Field::new( "f", - DataType::List(Arc::new(Field::new("item", DataType::LargeUtf8, false))), + DataType::List(Arc::new(Field::new("element", DataType::LargeUtf8, false))), true, )])); @@ -306,7 +306,7 @@ mod tests { let fields = Fields::from(vec![Field::new_list( "list_column", - Field::new("item", DataType::Int8, false), + Field::new("element", DataType::Int8, false), false, )]); let target_schema = Arc::new(Schema::new(fields)) as SchemaRef; @@ -316,7 +316,7 @@ mod tests { let schema = result.unwrap().schema(); let field = schema.column_with_name("list_column").unwrap().1; if let DataType::List(list_item) = field.data_type() { - assert_eq!(list_item.name(), "item"); + assert_eq!(list_item.name(), "element"); } else { panic!("Not a list"); } @@ -343,12 +343,34 @@ mod tests { #[test] fn test_is_cast_required_with_list() { - let field1 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false))); - let field2 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false))); + let field1 = DataType::List(FieldRef::from(Field::new( + "element", + DataType::Int32, + false, + ))); + let field2 = DataType::List(FieldRef::from(Field::new( + "element", + DataType::Int32, + false, + ))); assert!(!is_cast_required(&field1, &field2)); } + /// Delta has adopted "element" as the default list field name rather than the previously used + /// "item". This lines up more with Apache Parquet but should be handled in casting + #[test] + fn test_is_cast_required_with_old_and_new_list() { + let field1 = DataType::List(FieldRef::from(Field::new( + "element", + DataType::Int32, + false, + ))); + let field2 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false))); + + assert!(is_cast_required(&field1, &field2)); + } + #[test] fn test_is_cast_required_with_smol_int() { assert!(is_cast_required(&DataType::Int8, &DataType::Int32)); diff --git a/crates/core/src/operations/update.rs b/crates/core/src/operations/update.rs index 3cd9e8b80c..55f1ab4b30 100644 --- a/crates/core/src/operations/update.rs +++ b/crates/core/src/operations/update.rs @@ -242,6 +242,21 @@ async fn execute( return Err(DeltaTableError::NotInitializedWithFiles("UPDATE".into())); } + // NOTE: The optimize_projections rule is being temporarily disabled because it errors with + // our schemas for Lists due to issues discussed + // [here](https://github.com/delta-io/delta-rs/pull/2886#issuecomment-2481550560> + let rules: Vec> = state + .optimizers() + .into_iter() + .filter(|rule| { + rule.name() != "optimize_projections" && rule.name() != "simplify_expressions" + }) + .cloned() + .collect(); + let state = SessionStateBuilder::from(state) + .with_optimizer_rules(rules) + .build(); + let update_planner = DeltaPlanner:: { extension_planner: UpdateMetricExtensionPlanner {}, }; @@ -323,7 +338,6 @@ async fn execute( enable_pushdown: false, }), }); - let df_with_predicate_and_metrics = DataFrame::new(state.clone(), plan_with_metrics); let expressions: Vec = df_with_predicate_and_metrics @@ -343,6 +357,8 @@ async fn execute( }) .collect::>>()?; + //let updated_df = df_with_predicate_and_metrics.clone(); + // Disabling the select allows the coerce test to pass, still not sure why let updated_df = df_with_predicate_and_metrics.select(expressions.clone())?; let physical_plan = updated_df.clone().create_physical_plan().await?; let writer_stats_config = WriterStatsConfig::new( @@ -1040,11 +1056,81 @@ mod tests { assert_eq!(table.version(), 1); // Completed the first creation/write - // Update + use arrow::array::{Int32Builder, ListBuilder}; + let mut new_items_builder = + ListBuilder::new(Int32Builder::new()).with_field(arrow_field.clone()); + new_items_builder.append_value([Some(100)]); + let new_items = ScalarValue::List(Arc::new(new_items_builder.finish())); + + let (table, _metrics) = DeltaOps(table) + .update() + .with_predicate(col("id").eq(lit(1))) + .with_update("items", lit(new_items)) + .await + .unwrap(); + assert_eq!(table.version(), 2); + } + + /// Lists coming in from the Python bindings need to be parsed as SQL expressions by the update + /// and therefore this test emulates their behavior to ensure that the lists are being turned + /// into expressions for the update operation correctly + #[tokio::test] + async fn test_update_with_array_that_must_be_coerced() { + let _ = pretty_env_logger::try_init(); + let schema = StructType::new(vec![ + StructField::new( + "id".to_string(), + DeltaDataType::Primitive(PrimitiveType::Integer), + true, + ), + StructField::new( + "temp".to_string(), + DeltaDataType::Primitive(PrimitiveType::Integer), + true, + ), + StructField::new( + "items".to_string(), + DeltaDataType::Array(Box::new(crate::kernel::ArrayType::new( + DeltaDataType::LONG, + false, + ))), + true, + ), + ]); + let arrow_schema: ArrowSchema = (&schema).try_into().unwrap(); + + // Create the first batch + let arrow_field = Field::new("element", DataType::Int64, false); + let list_array = ListArray::new_null(arrow_field.clone().into(), 2); + let batch = RecordBatch::try_new( + Arc::new(arrow_schema.clone()), + vec![ + Arc::new(Int32Array::from(vec![Some(0), Some(1)])), + Arc::new(Int32Array::from(vec![Some(30), Some(31)])), + Arc::new(list_array), + ], + ) + .expect("Failed to create record batch"); + let _ = arrow::util::pretty::print_batches(&[batch.clone()]); + + let table = DeltaOps::new_in_memory() + .create() + .with_columns(schema.fields().cloned()) + .await + .unwrap(); + assert_eq!(table.version(), 0); + + let table = DeltaOps(table) + .write(vec![batch]) + .await + .expect("Failed to write first batch"); + assert_eq!(table.version(), 1); + // Completed the first creation/write + let (table, _metrics) = DeltaOps(table) .update() .with_predicate(col("id").eq(lit(1))) - .with_update("items", make_array(vec![lit(100)])) + .with_update("items", "[100]".to_string()) .await .unwrap(); assert_eq!(table.version(), 2);