Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Require all zero argument UDFs use Signature::Nullary, improve error messages #13871

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ async fn scalar_udf_zero_params() -> Result<()> {

let get_100_udf = Simple0ArgsScalarUDF {
name: "get_100".to_string(),
signature: Signature::exact(vec![], Volatility::Immutable),
signature: Signature::nullary(Volatility::Immutable),
return_type: DataType::Int32,
};

Expand Down Expand Up @@ -1119,6 +1119,61 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_valid_zero_argument_signatures() {
let signatures = vec![Signature::nullary(Volatility::Immutable)];
for signature in signatures {
let ctx = SessionContext::new();
let udf = ScalarFunctionWrapper {
name: "good_signature".to_string(),
expr: lit(1),
signature,
return_type: DataType::Int32,
};
ctx.register_udf(ScalarUDF::from(udf));
let results = ctx
.sql("select good_signature()")
.await
.unwrap()
.collect()
.await
.unwrap();
let expected = [
"+------------------+",
"| good_signature() |",
"+------------------+",
"| 1 |",
"+------------------+",
];
assert_batches_eq!(expected, &results);
}
}

#[tokio::test]
async fn test_invalid_zero_argument_signatures() {
let signatures = vec![
Signature::variadic(vec![], Volatility::Immutable),
Signature::variadic_any(Volatility::Immutable),
Signature::uniform(0, vec![], Volatility::Immutable),
Signature::coercible(vec![], Volatility::Immutable),
Signature::comparable(0, Volatility::Immutable),
Signature::any(0, Volatility::Immutable),
Signature::nullary(Volatility::Immutable),
];
for signature in signatures {
let ctx = SessionContext::new();
let udf = ScalarFunctionWrapper {
name: "bad_signature".to_string(),
expr: lit(1),
signature,
return_type: DataType::Int32,
};
ctx.register_udf(ScalarUDF::from(udf));
let results = ctx.sql("select bad_signature()").await.unwrap_err();
assert_contains!(results.to_string(), "Error during planning: Error during planning: bad_signature does not support zero arguments");
}
}

/// Saves whatever is passed to it as a scalar function
#[derive(Debug, Default)]
struct RecordingFunctionFactory {
Expand Down
43 changes: 0 additions & 43 deletions datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,6 @@ impl TypeSignature {
/// Check whether 0 input argument is valid for given `TypeSignature`
pub fn supports_zero_argument(&self) -> bool {
match &self {
TypeSignature::Exact(vec) => vec.is_empty(),
TypeSignature::Nullary => true,
TypeSignature::OneOf(types) => types
.iter()
Expand Down Expand Up @@ -613,48 +612,6 @@ mod tests {

use super::*;

#[test]
fn supports_zero_argument_tests() {
// Testing `TypeSignature`s which supports 0 arg
let positive_cases = vec![
TypeSignature::Exact(vec![]),
TypeSignature::OneOf(vec![
TypeSignature::Exact(vec![DataType::Int8]),
TypeSignature::Nullary,
TypeSignature::Uniform(1, vec![DataType::Int8]),
]),
TypeSignature::Nullary,
];

for case in positive_cases {
assert!(
case.supports_zero_argument(),
"Expected {:?} to support zero arguments",
case
);
}

// Testing `TypeSignature`s which doesn't support 0 arg
let negative_cases = vec![
TypeSignature::Exact(vec![DataType::Utf8]),
TypeSignature::Uniform(1, vec![DataType::Float64]),
TypeSignature::Any(1),
TypeSignature::VariadicAny,
TypeSignature::OneOf(vec![
TypeSignature::Exact(vec![DataType::Int8]),
TypeSignature::Uniform(1, vec![DataType::Int8]),
]),
];

for case in negative_cases {
assert!(
!case.supports_zero_argument(),
"Expected {:?} not to support zero arguments",
case
);
}
}

#[test]
fn type_signature_partial_ord() {
// Test validates that partial ord is defined for TypeSignature and Signature.
Expand Down
103 changes: 29 additions & 74 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ pub fn data_types_with_scalar_udf(
if signature.type_signature.supports_zero_argument() {
return Ok(vec![]);
} else {
return plan_err!("{} does not support zero arguments.", func.name());
return plan_err!("{} does not support zero arguments. Please add TypeSignature::Nullary to your function's signature", func.name());
}
}

Expand Down Expand Up @@ -88,34 +88,27 @@ pub fn data_types_with_aggregate_udf(
current_types: &[DataType],
func: &AggregateUDF,
) -> Result<Vec<DataType>> {
let signature = func.signature();
let type_signature = &func.signature().type_signature;

if current_types.is_empty() {
if signature.type_signature.supports_zero_argument() {
if type_signature.supports_zero_argument() {
return Ok(vec![]);
} else {
return plan_err!("{} does not support zero arguments.", func.name());
}
}

let valid_types = get_valid_types_with_aggregate_udf(
&signature.type_signature,
current_types,
func,
)?;
let valid_types =
get_valid_types_with_aggregate_udf(type_signature, current_types, func)?;

if valid_types
.iter()
.any(|data_type| data_type == current_types)
{
return Ok(current_types.to_vec());
}

try_coerce_types(
func.name(),
valid_types,
current_types,
&signature.type_signature,
)
try_coerce_types(func.name(), valid_types, current_types, type_signature)
}

/// Performs type coercion for window function arguments.
Expand All @@ -129,31 +122,27 @@ pub fn data_types_with_window_udf(
current_types: &[DataType],
func: &WindowUDF,
) -> Result<Vec<DataType>> {
let signature = func.signature();
let type_signature = &func.signature().type_signature;

if current_types.is_empty() {
if signature.type_signature.supports_zero_argument() {
if type_signature.supports_zero_argument() {
return Ok(vec![]);
} else {
return plan_err!("{} does not support zero arguments.", func.name());
}
}

let valid_types =
get_valid_types_with_window_udf(&signature.type_signature, current_types, func)?;
get_valid_types_with_window_udf(type_signature, current_types, func)?;

if valid_types
.iter()
.any(|data_type| data_type == current_types)
{
return Ok(current_types.to_vec());
}

try_coerce_types(
func.name(),
valid_types,
current_types,
&signature.type_signature,
)
try_coerce_types(func.name(), valid_types, current_types, type_signature)
}

/// Performs type coercion for function arguments.
Expand All @@ -168,31 +157,28 @@ pub fn data_types(
current_types: &[DataType],
signature: &Signature,
) -> Result<Vec<DataType>> {
let type_signature = &signature.type_signature;

if current_types.is_empty() {
if signature.type_signature.supports_zero_argument() {
if type_signature.supports_zero_argument() {
return Ok(vec![]);
} else {
return plan_err!(
"signature {:?} does not support zero arguments.",
&signature.type_signature
"{} does not support zero arguments.",
function_name.as_ref()
);
}
}

let valid_types = get_valid_types(&signature.type_signature, current_types)?;
let valid_types = get_valid_types(type_signature, current_types)?;
if valid_types
.iter()
.any(|data_type| data_type == current_types)
{
return Ok(current_types.to_vec());
}

try_coerce_types(
function_name,
valid_types,
current_types,
&signature.type_signature,
)
try_coerce_types(function_name, valid_types, current_types, type_signature)
}

fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
Expand Down Expand Up @@ -335,6 +321,7 @@ fn get_valid_types_with_window_udf(
}

/// Returns a Vec of all possible valid argument types for the given signature.
/// Empty argument is checked by the caller so no need to re-check here.
fn get_valid_types(
signature: &TypeSignature,
current_types: &[DataType],
Expand Down Expand Up @@ -441,12 +428,6 @@ fn get_valid_types(
}

fn function_length_check(length: usize, expected_length: usize) -> Result<()> {
if length < 1 {
return plan_err!(
"The signature expected at least one argument but received {expected_length}"
);
}

if length != expected_length {
return plan_err!(
"The signature expected {length} arguments but received {expected_length}"
Expand Down Expand Up @@ -645,27 +626,16 @@ fn get_valid_types(

vec![new_types]
}
TypeSignature::Uniform(number, valid_types) => {
if *number == 0 {
return plan_err!("The function expected at least one argument");
}

valid_types
.iter()
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
.collect()
}
TypeSignature::Uniform(number, valid_types) => valid_types
.iter()
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::UserDefined => {
return internal_err!(
"User-defined signature should be handled by function-specific coerce_types."
)
}
TypeSignature::VariadicAny => {
if current_types.is_empty() {
return plan_err!(
"The function expected at least one argument but received 0"
);
}
vec![current_types.to_vec()]
}
TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
Expand Down Expand Up @@ -716,28 +686,13 @@ fn get_valid_types(
}
},
TypeSignature::Nullary => {
if !current_types.is_empty() {
return plan_err!(
"The function expected zero argument but received {}",
current_types.len()
);
}
vec![vec![]]
return plan_err!(
"Nullary expects zero arguments, but received {}",
current_types.len()
);
}
TypeSignature::Any(number) => {
if current_types.is_empty() {
return plan_err!(
"The function expected at least one argument but received 0"
);
}

if current_types.len() != *number {
return plan_err!(
"The function expected {} arguments but received {}",
number,
current_types.len()
);
}
function_length_check(current_types.len(), *number)?;
vec![(0..*number).map(|i| current_types[i].clone()).collect()]
}
TypeSignature::OneOf(types) => types
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/core/version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl Default for VersionFunc {
impl VersionFunc {
pub fn new() -> Self {
Self {
signature: Signature::exact(vec![], Volatility::Immutable),
signature: Signature::nullary(Volatility::Immutable),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/string/uuid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl Default for UuidFunc {
impl UuidFunc {
pub fn new() -> Self {
Self {
signature: Signature::exact(vec![], Volatility::Volatile),
signature: Signature::nullary(Volatility::Volatile),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1683,7 +1683,7 @@ mod test {
impl RandomStub {
fn new() -> Self {
Self {
signature: Signature::exact(vec![], Volatility::Volatile),
signature: Signature::nullary(Volatility::Volatile),
}
}
}
Expand Down
Loading
Loading