diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs
index d30d202df050..bacc12c022ee 100644
--- a/datafusion/expr/src/type_coercion/functions.rs
+++ b/datafusion/expr/src/type_coercion/functions.rs
@@ -729,6 +729,11 @@ fn coerced_from<'a>(
         {
             Some(type_into.clone())
         }
+        // An Int64 paired with a Date32 unifies on Date32, so the integer can be
+        // interpreted as a day offset from the date.
+        // NOTE(review): this arm is symmetric and not specific to `+`; confirm
+        // that every caller of `coerced_from` should accept this pairing.
+        (Date32, Int64) | (Int64, Date32) => Some(Date32),
         _ => None,
     }
 }
@@ -924,4 +929,16 @@ mod tests {
             Some(type_into.clone())
         );
     }
+
+    #[test]
+    fn test_date32_int64_coercion() {
+        assert_eq!(
+            coerced_from(&DataType::Date32, &DataType::Int64),
+            Some(DataType::Date32)
+        );
+        assert_eq!(
+            coerced_from(&DataType::Int64, &DataType::Date32),
+            Some(DataType::Date32)
+        );
+    }
 }
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index 08c133d7193a..33b49ba949eb 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -645,6 +645,157 @@ impl BinaryExpr {
             }
         }
     }
+
+    /// Adds an `Int64` day count to each `Date32` value, row by row.
+    ///
+    /// `left` must be a `Date32Array` and `right` an `Int64Array`. A row is
+    /// null in the output when either input is null for that row.
+    ///
+    /// # Errors
+    ///
+    /// Returns an internal error if either array has an unexpected type, and
+    /// an execution error if the day count does not fit in `i32` or the sum
+    /// overflows the `Date32` range. Overflow is reported instead of silently
+    /// returning the unmodified date.
+    fn evaluate_date32_int64_addition(
+        &self,
+        left: &ArrayRef,
+        right: &ArrayRef,
+    ) -> Result<ArrayRef> {
+        let date_array = match left.as_any().downcast_ref::<Date32Array>() {
+            Some(array) => array,
+            None => return internal_err!("expected a Date32 array as the date operand"),
+        };
+        let int_array = match right.as_any().downcast_ref::<Int64Array>() {
+            Some(array) => array,
+            None => return internal_err!("expected an Int64 array as the days operand"),
+        };
+
+        let result = date_array
+            .iter()
+            .zip(int_array.iter())
+            .map(|(date, days)| match (date, days) {
+                (Some(date), Some(days)) => i32::try_from(days)
+                    .ok()
+                    .and_then(|delta| date.checked_add(delta))
+                    .map(Some)
+                    .ok_or_else(|| {
+                        datafusion_common::DataFusionError::Execution(format!(
+                            "Overflow happened on: {} + {}",
+                            date, days
+                        ))
+                    }),
+                // A null on either side yields a null result.
+                _ => Ok(None),
+            })
+            .collect::<Result<Date32Array>>()?;
+        Ok(Arc::new(result) as ArrayRef)
+    }
+
+    /// Evaluates this binary expression against `batch`.
+    ///
+    /// `Date32 + Int64` (in either operand order) is handled first by
+    /// `evaluate_date32_int64_addition`; every other operator and type
+    /// combination goes through the generic arithmetic, comparison and
+    /// scalar/array kernels below.
+    ///
+    /// NOTE(review): this is an inherent method. Calls made through a
+    /// `PhysicalExpr` trait object will not reach it unless the trait
+    /// implementation delegates here; confirm how it is wired up.
+    pub fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+        let left = self.left.evaluate(batch)?;
+        let right = self.right.evaluate(batch)?;
+
+        let left_data_type = left.data_type();
+        let right_data_type = right.data_type();
+
+        // Only `+` means "add days to a date". Without the operator check,
+        // `date - int`, `date = int`, ... would all be evaluated as addition.
+        if matches!(self.op, Operator::Plus) {
+            let num_rows = batch.num_rows();
+            match (&left_data_type, &right_data_type) {
+                (DataType::Date32, DataType::Int64) => {
+                    let dates = left.into_array(num_rows)?;
+                    let days = right.into_array(num_rows)?;
+                    return self
+                        .evaluate_date32_int64_addition(&dates, &days)
+                        .map(ColumnarValue::Array);
+                }
+                // Addition is commutative, so swap the operands.
+                (DataType::Int64, DataType::Date32) => {
+                    let dates = right.into_array(num_rows)?;
+                    let days = left.into_array(num_rows)?;
+                    return self
+                        .evaluate_date32_int64_addition(&dates, &days)
+                        .map(ColumnarValue::Array);
+                }
+                _ => {}
+            }
+        }
+
+        let schema = batch.schema();
+        let input_schema = schema.as_ref();
+
+        if left_data_type.is_nested() {
+            if right_data_type != left_data_type {
+                return internal_err!("type mismatch");
+            }
+            return apply_cmp_for_nested(self.op, &left, &right);
+        }
+
+        match self.op {
+            Operator::Plus if self.fail_on_overflow => return apply(&left, &right, add),
+            Operator::Plus => return apply(&left, &right, add_wrapping),
+            Operator::Minus if self.fail_on_overflow => return apply(&left, &right, sub),
+            Operator::Minus => return apply(&left, &right, sub_wrapping),
+            Operator::Multiply if self.fail_on_overflow => return apply(&left, &right, mul),
+            Operator::Multiply => return apply(&left, &right, mul_wrapping),
+            Operator::Divide => return apply(&left, &right, div),
+            Operator::Modulo => return apply(&left, &right, rem),
+            Operator::Eq => return apply_cmp(&left, &right, eq),
+            Operator::NotEq => return apply_cmp(&left, &right, neq),
+            Operator::Lt => return apply_cmp(&left, &right, lt),
+            Operator::Gt => return apply_cmp(&left, &right, gt),
+            Operator::LtEq => return apply_cmp(&left, &right, lt_eq),
+            Operator::GtEq => return apply_cmp(&left, &right, gt_eq),
+            Operator::IsDistinctFrom => return apply_cmp(&left, &right, distinct),
+            Operator::IsNotDistinctFrom => return apply_cmp(&left, &right, not_distinct),
+            Operator::LikeMatch => return apply_cmp(&left, &right, like),
+            Operator::ILikeMatch => return apply_cmp(&left, &right, ilike),
+            Operator::NotLikeMatch => return apply_cmp(&left, &right, nlike),
+            Operator::NotILikeMatch => return apply_cmp(&left, &right, nilike),
+            _ => {}
+        }
+
+        let result_type = self.data_type(input_schema)?;
+
+        // Attempt to use special kernels if one input is scalar and the other is an array
+        let scalar_result = match (&left, &right) {
+            (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => {
+                // if left is array and right is literal(not NULL) - use scalar operations
+                if scalar.is_null() {
+                    None
+                } else {
+                    self.evaluate_array_scalar(array, scalar.clone())?.map(|r| {
+                        r.and_then(|a| to_result_type_array(&self.op, a, &result_type))
+                    })
+                }
+            }
+            (_, _) => None, // default to array implementation
+        };
+
+        if let Some(result) = scalar_result {
+            return result.map(ColumnarValue::Array);
+        }
+
+        // if both arrays or both literals - extract arrays and continue execution
+        let (left, right) = (
+            left.into_array(batch.num_rows())?,
+            right.into_array(batch.num_rows())?,
+        );
+        self.evaluate_with_resolved_args(left, &left_data_type, right, &right_data_type)
+            .map(ColumnarValue::Array)
+    }
 }
 
 fn concat_elements(left: Arc<dyn Array>, right: Arc<dyn Array>) -> Result<ArrayRef> {
@@ -913,6 +1064,18 @@ mod tests {
             DataType::Boolean,
             [true, false],
         );
+        test_coercion!(
+            Date32Array,
+            DataType::Date32,
+            vec![0, 31], // 1970-01-01, 1970-02-01
+            Int64Array,
+            DataType::Int64,
+            vec![1, 365],
+            Operator::Plus,
+            Date32Array,
+            DataType::Date32,
+            [1, 396], // 1970-01-02, 1971-02-01
+        );
         test_coercion!(
             StringArray,
             DataType::Utf8,
@@ -4226,4 +4389,53 @@ mod tests {
             .contains("Overflow happened on: 2147483647 * 2"));
         Ok(())
     }
+    #[test]
+    fn test_date32_int64_addition() -> Result<()> {
+        let schema = Schema::new(vec![
+            Field::new("date", DataType::Date32, true),
+            Field::new("days", DataType::Int64, true),
+        ]);
+        let date_array: ArrayRef = Arc::new(Date32Array::from(vec![Some(0), Some(31), None]));
+        let days_array: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(365), None]));
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![date_array, days_array])?;
+
+        let expr = BinaryExpr::new(
+            Arc::new(Column::new("date", 0)),
+            Operator::Plus,
+            Arc::new(Column::new("days", 1)),
+        );
+
+        // `evaluate` yields a `ColumnarValue`; materialize it before comparing.
+        let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
+
+        let expected = Date32Array::from(vec![Some(1), Some(396), None]);
+        assert_eq!(result.as_ref(), &expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_date32_int64_addition_overflow() -> Result<()> {
+        let schema = Schema::new(vec![
+            Field::new("date", DataType::Date32, true),
+            Field::new("days", DataType::Int64, true),
+        ]);
+        let date_array: ArrayRef = Arc::new(Date32Array::from(vec![Some(i32::MAX)]));
+        let days_array: ArrayRef = Arc::new(Int64Array::from(vec![Some(1)]));
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![date_array, days_array])?;
+
+        let expr = BinaryExpr::new(
+            Arc::new(Column::new("date", 0)),
+            Operator::Plus,
+            Arc::new(Column::new("days", 1)),
+        );
+
+        // Overflow must surface as an error rather than returning the input date.
+        let err = expr.evaluate(&batch).unwrap_err();
+        assert!(err
+            .to_string()
+            .contains("Overflow happened on: 2147483647 + 1"));
+
+        Ok(())
+    }
 }