Skip to content

Commit

Permalink
issue apache#12342 partially solved
Browse files Browse the repository at this point in the history
  • Loading branch information
ashrafshaik09 committed Sep 6, 2024
1 parent 9d819e1 commit e9d152e
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 0 deletions.
14 changes: 14 additions & 0 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,8 @@ fn coerced_from<'a>(
{
Some(type_into.clone())
}
// Added this new case to handle Date32 + Int64
(Date32, Int64) | (Int64, Date32) => Some(Date32),
_ => None,
}
}
Expand Down Expand Up @@ -924,4 +926,16 @@ mod tests {
Some(type_into.clone())
);
}

/// `coerced_from` must resolve a Date32/Int64 pair to `Date32` regardless of
/// which side is the target type and which is the source type.
#[test]
fn test_date32_int64_coercion() {
    // Exercise both argument orders with a single table of cases.
    let cases = [
        (DataType::Date32, DataType::Int64),
        (DataType::Int64, DataType::Date32),
    ];
    for (first, second) in cases.iter() {
        let coerced = coerced_from(first, second);
        assert_eq!(coerced, Some(DataType::Date32));
    }
}
}
143 changes: 143 additions & 0 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,112 @@ impl BinaryExpr {
}
}
}

fn evaluate_date32_int64_addition(
&self,
left: &ArrayRef,
right: &ArrayRef,
) -> Result<ArrayRef> {
let date_array = left.as_any().downcast_ref::<Date32Array>().unwrap();
let int_array = right.as_any().downcast_ref::<Int64Array>().unwrap();
let result = date_array
.iter()
.zip(int_array.iter())
.map(|(date, days)| {
date.and_then(|d| days.map(|days| d.checked_add_days(days as i32).unwrap_or(d)))
})
.collect::<Date32Array>();
Ok(Arc::new(result) as ArrayRef)
}

/// Evaluates this binary expression against `batch`.
///
/// Both operands are evaluated first, then the result is computed by the
/// first applicable strategy:
///
/// 1. `Date32 + Int64` (either operand order) is handled by
///    `evaluate_date32_int64_addition`. This applies to `Operator::Plus`
///    only; every other operator on that type pair takes the general path.
/// 2. Nested left-hand types go to `apply_cmp_for_nested`, after checking
///    that both operand types are equal.
/// 3. Arithmetic, comparison and LIKE-family operators dispatch directly to
///    the matching kernel through `apply` / `apply_cmp`.
/// 4. For the remaining operators, an array on the left with a non-null
///    scalar on the right is tried through `evaluate_array_scalar`.
/// 5. Otherwise both sides are materialized as arrays and handed to
///    `evaluate_with_resolved_args`.
///
/// # Errors
///
/// Propagates any error from operand evaluation or from the kernels, and
/// returns an internal error when a nested left operand is paired with a
/// right operand of a different type.
pub fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
    let left = self.left.evaluate(batch)?;
    let right = self.right.evaluate(batch)?;

    let left_data_type = left.data_type();
    let right_data_type = right.data_type();

    // Date + day-count addition. Gated on `Plus` so that other operators on
    // the same type pair are not silently evaluated as an addition.
    if matches!(self.op, Operator::Plus) {
        let num_rows = batch.num_rows();
        match (&left_data_type, &right_data_type) {
            (DataType::Date32, DataType::Int64) => {
                let dates = left.into_array(num_rows)?;
                let days = right.into_array(num_rows)?;
                return self
                    .evaluate_date32_int64_addition(&dates, &days)
                    .map(ColumnarValue::Array);
            }
            (DataType::Int64, DataType::Date32) => {
                // Addition is commutative: swap so the dates come first.
                let dates = right.into_array(num_rows)?;
                let days = left.into_array(num_rows)?;
                return self
                    .evaluate_date32_int64_addition(&dates, &days)
                    .map(ColumnarValue::Array);
            }
            _ => {}
        }
    }

    let schema = batch.schema();
    let input_schema = schema.as_ref();

    if left_data_type.is_nested() {
        if right_data_type != left_data_type {
            return internal_err!("type mismatch");
        }
        return apply_cmp_for_nested(self.op, &left, &right);
    }

    match self.op {
        Operator::Plus if self.fail_on_overflow => return apply(&left, &right, add),
        Operator::Plus => return apply(&left, &right, add_wrapping),
        Operator::Minus if self.fail_on_overflow => return apply(&left, &right, sub),
        Operator::Minus => return apply(&left, &right, sub_wrapping),
        Operator::Multiply if self.fail_on_overflow => return apply(&left, &right, mul),
        Operator::Multiply => return apply(&left, &right, mul_wrapping),
        Operator::Divide => return apply(&left, &right, div),
        Operator::Modulo => return apply(&left, &right, rem),
        Operator::Eq => return apply_cmp(&left, &right, eq),
        Operator::NotEq => return apply_cmp(&left, &right, neq),
        Operator::Lt => return apply_cmp(&left, &right, lt),
        Operator::Gt => return apply_cmp(&left, &right, gt),
        Operator::LtEq => return apply_cmp(&left, &right, lt_eq),
        Operator::GtEq => return apply_cmp(&left, &right, gt_eq),
        Operator::IsDistinctFrom => return apply_cmp(&left, &right, distinct),
        Operator::IsNotDistinctFrom => return apply_cmp(&left, &right, not_distinct),
        Operator::LikeMatch => return apply_cmp(&left, &right, like),
        Operator::ILikeMatch => return apply_cmp(&left, &right, ilike),
        Operator::NotLikeMatch => return apply_cmp(&left, &right, nlike),
        Operator::NotILikeMatch => return apply_cmp(&left, &right, nilike),
        // Remaining operators are handled by the paths below.
        _ => {}
    }

    let result_type = self.data_type(input_schema)?;

    // Attempt to use special kernels if one input is scalar and the other is an array
    let scalar_result = match (&left, &right) {
        (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => {
            // if left is array and right is literal(not NULL) - use scalar operations
            if scalar.is_null() {
                None
            } else {
                self.evaluate_array_scalar(array, scalar.clone())?.map(|r| {
                    r.and_then(|a| to_result_type_array(&self.op, a, &result_type))
                })
            }
        }
        (_, _) => None, // default to array implementation
    };

    if let Some(result) = scalar_result {
        return result.map(ColumnarValue::Array);
    }

    // if both arrays or both literals - extract arrays and continue execution
    let (left, right) = (
        left.into_array(batch.num_rows())?,
        right.into_array(batch.num_rows())?,
    );
    self.evaluate_with_resolved_args(left, &left_data_type, right, &right_data_type)
        .map(ColumnarValue::Array)
}
}

fn concat_elements(left: Arc<dyn Array>, right: Arc<dyn Array>) -> Result<ArrayRef> {
Expand Down Expand Up @@ -913,6 +1019,18 @@ mod tests {
DataType::Boolean,
[true, false],
);
test_coercion!(
Date32Array,
DataType::Date32,
vec![0, 31], // 1970-01-01, 1970-02-01
Int64Array,
DataType::Int64,
vec![1, 365],
Operator::Plus,
Date32Array,
DataType::Date32,
[1, 396], // 1970-01-02, 1971-02-01
);
test_coercion!(
StringArray,
DataType::Utf8,
Expand Down Expand Up @@ -4226,4 +4344,29 @@ mod tests {
.contains("Overflow happened on: 2147483647 * 2"));
Ok(())
}
/// `Date32 + Int64` adds the integer as a day count and propagates nulls.
#[test]
fn test_date32_int64_addition() -> Result<()> {
    let schema = Schema::new(vec![
        Field::new("date", DataType::Date32, true),
        Field::new("days", DataType::Int64, true),
    ]);
    // 1970-01-01, 1970-02-01, NULL
    let date_array: ArrayRef =
        Arc::new(Date32Array::from(vec![Some(0), Some(31), None]));
    let days_array: ArrayRef =
        Arc::new(Int64Array::from(vec![Some(1), Some(365), None]));
    let batch = RecordBatch::try_new(Arc::new(schema), vec![date_array, days_array])?;

    let expr = BinaryExpr::new(
        Arc::new(Column::new("date", 0)),
        Operator::Plus,
        Arc::new(Column::new("days", 1)),
    );

    // `evaluate` returns a `ColumnarValue`; materialize it as an array and
    // downcast so the comparison is between two typed `Date32Array`s.
    let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
    let result = result
        .as_any()
        .downcast_ref::<Date32Array>()
        .expect("Date32 + Int64 should produce a Date32Array");

    // 1970-01-02, 1971-02-01, NULL
    let expected = Date32Array::from(vec![Some(1), Some(396), None]);
    assert_eq!(result, &expected);

    Ok(())
}
}

0 comments on commit e9d152e

Please sign in to comment.