diff --git a/src/postgres/math_udfs.rs b/src/postgres/math_udfs.rs index d3bfdaf..03d33d2 100644 --- a/src/postgres/math_udfs.rs +++ b/src/postgres/math_udfs.rs @@ -830,6 +830,65 @@ impl ScalarUDFImpl for RandomNormal { } } +#[derive(Debug)] +pub struct Mod { + signature: Signature, +} + +impl Mod { + pub fn new() -> Self { + Self { + signature: Signature::variadic(vec![Int64, UInt64], Volatility::Volatile), + } + } +} + +/// Remainder of y/x +impl ScalarUDFImpl for Mod { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "mod" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Int64) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + if args.len() > 2_usize { + return Err(DataFusionError::Internal( + "No function matches the given name and argument types.".to_string(), + )); + } + + let args = ColumnarValue::values_to_arrays(args)?; + let first_values = datafusion::common::cast::as_int64_array(&args[0])?; + let second_values = datafusion::common::cast::as_int64_array(&args[1])?; + + let mut int64array_builder = Int64Array::builder(args[0].len()); + + first_values + .iter() + .flatten() + .zip(second_values.iter().flatten()) + .try_for_each(|(first, second)| { + int64array_builder.append_value(first % second); + Ok::<(), DataFusionError>(()) + })?; + + Ok(ColumnarValue::Array( + Arc::new(int64array_builder.finish()) as ArrayRef + )) + } +} + #[cfg(feature = "postgres")] #[cfg(test)] mod tests { @@ -1274,6 +1333,22 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_mod() -> Result<()> { + let ctx = register_udfs_for_test()?; + let df = ctx.sql("select mod(9, 4) as col_result").await?; + + let batches = df.clone().collect().await?; + + let columns = &batches.first().unwrap().column(0); + let result = as_int64_array(columns)?; + let result = result.value(0); + + assert_eq!(result, 1_i64); + + Ok(()) + } + fn register_udfs_for_test() -> Result { let ctx = set_up_maths_data_test()?; register_postgres_udfs(&ctx)?; diff --git a/src/postgres/mod.rs b/src/postgres/mod.rs index f73187d..d7a80e2 100644 --- a/src/postgres/mod.rs +++ b/src/postgres/mod.rs @@ -6,7 +6,7 @@ use datafusion::logical_expr::ScalarUDF; use datafusion::prelude::SessionContext; use crate::postgres::math_udfs::{ - Acosd, Asind, Atand, Ceiling, Cosd, Cotd, Div, Erf, Erfc, RandomNormal, Sind, Tand, + Acosd, Asind, Atand, Ceiling, Cosd, Cotd, Div, Erf, Erfc, Mod, RandomNormal, Sind, Tand, }; use crate::postgres::network_udfs::{ Broadcast, Family, Host, HostMask, InetMerge, InetSameFamily, MaskLen, Netmask, Network, @@ -35,6 +35,7 @@ fn register_math_udfs(ctx: &SessionContext) -> Result<()> { ctx.register_udf(ScalarUDF::from(Erf::new())); ctx.register_udf(ScalarUDF::from(Erfc::new())); ctx.register_udf(ScalarUDF::from(RandomNormal::new())); + ctx.register_udf(ScalarUDF::from(Mod::new())); Ok(()) }