Skip to content

Commit 6d81418

Browse files
committed
Fix recursive flatten
The fix is covered by recursive flatten test case in array.slt
1 parent 6903259 commit 6d81418

File tree

3 files changed

+36
-2
lines changed

3 files changed

+36
-2
lines changed

datafusion/expr-common/src/signature.rs

+6
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ pub enum ArrayFunctionSignature {
175175
/// The function takes a single argument that must be a List/LargeList/FixedSizeList
176176
/// or something that can be coerced to one of those types.
177177
Array,
178+
/// A function takes a single argument that must be a List/LargeList/FixedSizeList
179+
/// which gets coerced to List, with element type recursively coerced to List too if it is list-like.
180+
RecursiveArray,
178181
/// Specialized Signature for MapArray
179182
/// The function takes a single argument that must be a MapArray
180183
MapArray,
@@ -198,6 +201,9 @@ impl std::fmt::Display for ArrayFunctionSignature {
198201
ArrayFunctionSignature::Array => {
199202
write!(f, "array")
200203
}
204+
ArrayFunctionSignature::RecursiveArray => {
205+
write!(f, "recursive_array")
206+
}
201207
ArrayFunctionSignature::MapArray => {
202208
write!(f, "map_array")
203209
}

datafusion/expr/src/type_coercion/functions.rs

+21
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use arrow::{
2121
compute::can_cast_types,
2222
datatypes::{DataType, TimeUnit},
2323
};
24+
use datafusion_common::utils::coerced_fixed_size_list_to_list;
2425
use datafusion_common::{
2526
exec_err, internal_datafusion_err, internal_err, plan_err,
2627
types::{LogicalType, NativeType},
@@ -414,6 +415,7 @@ fn get_valid_types(
414415
_ => Ok(vec![vec![]]),
415416
}
416417
}
418+
417419
fn array(array_type: &DataType) -> Option<DataType> {
418420
match array_type {
419421
DataType::List(_) => Some(array_type.clone()),
@@ -424,6 +426,18 @@ fn get_valid_types(
424426
}
425427
}
426428

429+
fn recursive_array(array_type: &DataType) -> Option<DataType> {
430+
match array_type {
431+
DataType::List(_)
432+
| DataType::LargeList(_)
433+
| DataType::FixedSizeList(_, _) => {
434+
let array_type = coerced_fixed_size_list_to_list(array_type);
435+
Some(array_type)
436+
}
437+
_ => None,
438+
}
439+
}
440+
427441
fn function_length_check(length: usize, expected_length: usize) -> Result<()> {
428442
if length < 1 {
429443
return plan_err!(
@@ -651,6 +665,13 @@ fn get_valid_types(
651665
array(&current_types[0])
652666
.map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
653667
}
668+
ArrayFunctionSignature::RecursiveArray => {
669+
if current_types.len() != 1 {
670+
return Ok(vec![vec![]]);
671+
}
672+
recursive_array(&current_types[0])
673+
.map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
674+
}
654675
ArrayFunctionSignature::MapArray => {
655676
if current_types.len() != 1 {
656677
return Ok(vec![vec![]]);

datafusion/functions-nested/src/flatten.rs

+9-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ use datafusion_common::cast::{
2828
use datafusion_common::{exec_err, Result};
2929
use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
3030
use datafusion_expr::{
31-
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
31+
ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature,
32+
TypeSignature, Volatility,
3233
};
3334
use std::any::Any;
3435
use std::sync::{Arc, OnceLock};
@@ -56,7 +57,13 @@ impl Default for Flatten {
5657
impl Flatten {
5758
pub fn new() -> Self {
5859
Self {
59-
signature: Signature::array(Volatility::Immutable),
60+
signature: Signature {
61+
// TODO (https://github.com/apache/datafusion/issues/13757) flatten should be single-step, not recursive
62+
type_signature: TypeSignature::ArraySignature(
63+
ArrayFunctionSignature::RecursiveArray,
64+
),
65+
volatility: Volatility::Immutable,
66+
},
6067
aliases: vec![],
6168
}
6269
}

0 commit comments

Comments
 (0)