Skip to content

Commit

Permalink
Support for mod
Browse files Browse the repository at this point in the history
  • Loading branch information
dadepo committed Aug 24, 2024
1 parent 905fde4 commit 57eb0b3
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 1 deletion.
75 changes: 75 additions & 0 deletions src/postgres/math_udfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,65 @@ impl ScalarUDFImpl for RandomNormal {
}
}

#[derive(Debug)]
pub struct Mod {
signature: Signature,
}

impl Mod {
pub fn new() -> Self {
Self {
signature: Signature::variadic(vec![Int64, UInt64], Volatility::Volatile),
}
}
}

/// Remainder of y/x
impl ScalarUDFImpl for Mod {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"mod"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(Int64)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() > 2_usize {
return Err(DataFusionError::Internal(
"No function matches the given name and argument types.".to_string(),
));
}

let args = ColumnarValue::values_to_arrays(args)?;
let first_values = datafusion::common::cast::as_int64_array(&args[0])?;
let second_values = datafusion::common::cast::as_int64_array(&args[1])?;

let mut int64array_builder = Int64Array::builder(args[0].len());

first_values
.iter()
.flatten()
.zip(second_values.iter().flatten())
.try_for_each(|(first, second)| {
int64array_builder.append_value(first % second);
Ok::<(), DataFusionError>(())
})?;

Ok(ColumnarValue::Array(
Arc::new(int64array_builder.finish()) as ArrayRef
))
}
}

#[cfg(feature = "postgres")]
#[cfg(test)]
mod tests {
Expand Down Expand Up @@ -1274,6 +1333,22 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_mod() -> Result<()> {
let ctx = register_udfs_for_test()?;
let df = ctx.sql("select mod(9, 4) as col_result").await?;

let batches = df.clone().collect().await?;

let columns = &batches.first().unwrap().column(0);
let result = as_int64_array(columns)?;
let result = result.value(0);

assert_eq!(result, 1_i64);

Ok(())
}

fn register_udfs_for_test() -> Result<SessionContext> {
let ctx = set_up_maths_data_test()?;
register_postgres_udfs(&ctx)?;
Expand Down
3 changes: 2 additions & 1 deletion src/postgres/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use datafusion::logical_expr::ScalarUDF;
use datafusion::prelude::SessionContext;

use crate::postgres::math_udfs::{
Acosd, Asind, Atand, Ceiling, Cosd, Cotd, Div, Erf, Erfc, RandomNormal, Sind, Tand,
Acosd, Asind, Atand, Ceiling, Cosd, Cotd, Div, Erf, Erfc, Mod, RandomNormal, Sind, Tand,
};
use crate::postgres::network_udfs::{
Broadcast, Family, Host, HostMask, InetMerge, InetSameFamily, MaskLen, Netmask, Network,
Expand Down Expand Up @@ -35,6 +35,7 @@ fn register_math_udfs(ctx: &SessionContext) -> Result<()> {
ctx.register_udf(ScalarUDF::from(Erf::new()));
ctx.register_udf(ScalarUDF::from(Erfc::new()));
ctx.register_udf(ScalarUDF::from(RandomNormal::new()));
ctx.register_udf(ScalarUDF::from(Mod::new()));
Ok(())
}

Expand Down

0 comments on commit 57eb0b3

Please sign in to comment.