From 0cd7bc7f7505c7d2abee742cb02d149fa916f381 Mon Sep 17 00:00:00 2001 From: Dadepo Aderemi <272535+dadepo@users.noreply.github.com> Date: Sun, 20 Oct 2024 15:35:24 +0400 Subject: [PATCH] Added Postgres sign function --- src/postgres/math_udfs.rs | 147 +++++++++++++++++++++++++++++++++++++- src/postgres/mod.rs | 3 +- supports/postgres.md | 2 +- 3 files changed, 149 insertions(+), 3 deletions(-) diff --git a/src/postgres/math_udfs.rs b/src/postgres/math_udfs.rs index 03d33d2..25ad02f 100644 --- a/src/postgres/math_udfs.rs +++ b/src/postgres/math_udfs.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use datafusion::arrow::array::{Array, ArrayRef, Float64Array, Int64Array}; +use datafusion::arrow::array::{Array, ArrayRef, Float64Array, Int64Array, Int8Array}; use datafusion::arrow::datatypes::DataType; use datafusion::arrow::datatypes::DataType::{Float64, Int64, UInt64}; @@ -889,6 +889,99 @@ impl ScalarUDFImpl for Mod { } } +#[derive(Debug)] +pub struct Sign { + signature: Signature, +} + +impl Sign { + pub fn new() -> Self { + Self { + signature: Signature::uniform(1, vec![Int64, UInt64, Float64], Volatility::Volatile), + } + } +} + +impl ScalarUDFImpl for Sign { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "sign" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int8) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + let mut int8array_builder = Int8Array::builder(args[0].len()); + match args.first() { + Some(values) => { + match values.data_type() { + DataType::Int64 | DataType::Int32 | DataType::Int16 | DataType::Int8 => { + let values = datafusion::common::cast::as_int64_array(values)?; + values.iter().for_each(|value| { + match value { + Some(value) => int8array_builder.append_value(value.cmp(&0_i64) as i8), + None => int8array_builder.append_null() + } + }); + }, + DataType::UInt64 | DataType::UInt32 | DataType::UInt16 | DataType::UInt8 => { + let values = datafusion::common::cast::as_uint64_array(values)?; + values.iter().for_each(|value| { + match value { + Some(value) => int8array_builder.append_value(value.cmp(&0_u64) as i8), + None => int8array_builder.append_null() + } + }); + }, + DataType::Float64 | DataType::Float32 | DataType::Float16 => { + let values = datafusion::common::cast::as_float64_array(values)?; + values.iter().for_each(|value| { + match value { + Some(value) => { + if value == 0_f64 { + int8array_builder.append_value(0) + } else if value > 0_f64 { + int8array_builder.append_value(1) + } else { + int8array_builder.append_value(-1) + } + }, + None => int8array_builder.append_null() + } + }); + }, + _ => { + return Err(DataFusionError::Internal( + "No function matches the given name and argument types. You might need to add explicit type casts" + .to_string(), + )) + } + } + } + None => { + return Err(DataFusionError::Internal( + "No function matches the given name and argument types. You might need to add explicit type casts" + .to_string(), + )) + } + } + + Ok(ColumnarValue::Array( + Arc::new(int8array_builder.finish()) as ArrayRef + )) + } +} + #[cfg(feature = "postgres")] #[cfg(test)] mod tests { @@ -1349,6 +1442,58 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_sign() -> Result<()> { + let ctx = register_udfs_for_test()?; + + let df = ctx.sql("select index, sign(uint) as uint, sign(int) as int, sign(float) as float from maths_table ORDER BY index ASC").await?; + let batches = df.clone().collect().await?; + + let expected: Vec<&str> = r#" ++-------+------+-----+-------+ +| index | uint | int | float | ++-------+------+-----+-------+ +| 1 | 1 | -1 | 1 | +| 2 | 1 | 1 | 1 | +| 3 | | | | ++-------+------+-----+-------+"# + .split('\n') + .filter_map(|input| { + if input.is_empty() { + None + } else { + Some(input.trim()) + } + }) + .collect(); + + assert_batches_sorted_eq!(expected, &batches); + + let df = ctx + .sql("select sign(0) as uint, sign(0.0) as float") + .await?; + let batches = df.clone().collect().await?; + + let expected: Vec<&str> = r#" ++------+-------+ +| uint | float | ++------+-------+ +| 0 | 0 | ++------+-------+"# + .split('\n') + .filter_map(|input| { + if input.is_empty() { + None + } else { + Some(input.trim()) + } + }) + .collect(); + + assert_batches_sorted_eq!(expected, &batches); + Ok(()) + } + fn register_udfs_for_test() -> Result { let ctx = set_up_maths_data_test()?; register_postgres_udfs(&ctx)?; diff --git a/src/postgres/mod.rs b/src/postgres/mod.rs index d7a80e2..fa6863c 100644 --- a/src/postgres/mod.rs +++ b/src/postgres/mod.rs @@ -6,7 +6,7 @@ use datafusion::logical_expr::ScalarUDF; use datafusion::prelude::SessionContext; use crate::postgres::math_udfs::{ - Acosd, Asind, Atand, Ceiling, Cosd, Cotd, Div, Erf, Erfc, Mod, RandomNormal, Sind, Tand, + Acosd, Asind, Atand, Ceiling, Cosd, Cotd, Div, Erf, Erfc, Mod, RandomNormal, Sign, Sind, Tand, }; use crate::postgres::network_udfs::{ Broadcast, Family, Host, HostMask, InetMerge, InetSameFamily, MaskLen, Netmask, Network, @@ -36,6 +36,7 @@ fn register_math_udfs(ctx: &SessionContext) -> Result<()> { ctx.register_udf(ScalarUDF::from(Erfc::new())); ctx.register_udf(ScalarUDF::from(RandomNormal::new())); ctx.register_udf(ScalarUDF::from(Mod::new())); + ctx.register_udf(ScalarUDF::from(Sign::new())); Ok(()) } diff --git a/supports/postgres.md b/supports/postgres.md index 7dffadc..39d68bd 100644 --- a/supports/postgres.md +++ b/supports/postgres.md @@ -30,7 +30,7 @@ https://www.postgresql.org/docs/16/functions-math.html | ❓ | min_scale ( numeric ) → integer | Minimum scale (number of fractional decimal digits) needed to represent the supplied value precisely | min_scale(8.4100) → 2 | | ✅ | mod ( y numeric_type, x numeric_type ) → numeric_type | Remainder of y/x; available for smallint, integer, bigint, and numeric | mod(9, 4) → 1 | | ❓ | scale ( numeric ) → integer | Scale of the argument (the number of decimal digits in the fractional part) | scale(8.4100) → 4 | -| ❓ | sign ( numeric ) → numeric | Sign of the argument (-1, 0, or +1) | sign(-8.4) → -1 | +| ✅ | sign ( numeric ) → numeric | Sign of the argument (-1, 0, or +1) | sign(-8.4) → -1 | | ❓ | trim_scale ( numeric ) → numeric | Reduces the value's scale (number of fractional decimal digits) by removing trailing zeroes | trim_scale(8.4100) → 8.41 | | ❓ | width_bucket ( operand numeric, low numeric, high numeric, count integer ) → integer | Returns the number of the bucket in which operand falls in a histogram having count equal-width buckets spanning the range low to high. Returns 0 or count+1 for an input outside that range. | width_bucket(5.35, 0.024, 10.06, 5) → 3 | | ✅ | random_normal ( [ mean double precision [, stddev double precision ]] ) → double precision | Returns a random value from the normal distribution with the given parameters; mean defaults to 0.0 and stddev defaults to 1.0 | random_normal(0.0, 1.0) → 0.051285419 |