From ee6bd09e9179b6dd32452a3e1ef66f470c314d63 Mon Sep 17 00:00:00 2001 From: Dadepo Aderemi Date: Sun, 3 Dec 2023 05:35:24 +0400 Subject: [PATCH] Added postgres erfc --- .github/workflows/rust.yml | 2 +- src/postgres/math_udfs.rs | 82 ++++++++++++++++++++++++++++++++++++++ src/postgres/mod.rs | 16 +++++++- supports/postgres.md | 2 +- 4 files changed, 99 insertions(+), 3 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 0bce216..0b075d4 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -113,4 +113,4 @@ jobs: - name: Build run: cargo build --verbose - name: Run tests - run: cargo test --features "postgres" --verbose + run: cargo test --all-features --verbose diff --git a/src/postgres/math_udfs.rs b/src/postgres/math_udfs.rs index 003cb2d..046cbd6 100644 --- a/src/postgres/math_udfs.rs +++ b/src/postgres/math_udfs.rs @@ -88,6 +88,57 @@ pub fn erf(args: &[ArrayRef]) -> Result { Ok(Arc::new(float64array_builder.finish()) as ArrayRef) } +/// Complementary error function +pub fn erfc(args: &[ArrayRef]) -> Result { + let column_data = &args[0]; + let data = column_data.into_data(); + let data_type = data.data_type(); + + let mut float64array_builder = Float64Array::builder(args[0].len()); + match data_type { + DataType::Float64 => { + let values = datafusion::common::cast::as_float64_array(&args[0])?; + values.iter().try_for_each(|value| { + if let Some(value) = value { + float64array_builder.append_value(libm::erfc(value)) + } else { + float64array_builder.append_null(); + } + Ok::<(), DataFusionError>(()) + })?; + } + DataType::Int64 => { + let values = datafusion::common::cast::as_int64_array(&args[0])?; + values.iter().try_for_each(|value| { + if let Some(value) = value { + float64array_builder.append_value(libm::erfc(value as f64)) + } else { + float64array_builder.append_null(); + } + Ok::<(), DataFusionError>(()) + })?; + } + DataType::UInt64 => { + let values = datafusion::common::cast::as_uint64_array(&args[0])?; + values.iter().try_for_each(|value| { + if let Some(value) = value { + float64array_builder.append_value(libm::erfc(value as f64)) + } else { + float64array_builder.append_null(); + } + Ok::<(), DataFusionError>(()) + })?; + } + t => { + return Err(DataFusionError::Internal(format!( + "Unsupported type {t} for erf function" + ))) + } + }; + + Ok(Arc::new(float64array_builder.finish()) as ArrayRef) +} + #[cfg(feature = "postgres")] #[cfg(test)] mod tests { @@ -178,6 +229,37 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_erfc() -> Result<()> { + let ctx = register_udfs_for_test()?; + let df = ctx.sql("select index, uint as uint, int as int, float as float from maths_table ORDER BY index ASC").await?; + df.show().await?; + + let df = ctx.sql("select index, erfc(uint) as uint, erfc(int) as int, erfc(float) as float from maths_table ORDER BY index ASC").await?; + + let batches = df.clone().collect().await?; + + let expected: Vec<&str> = r#" ++-------+-------------------------+-------------------------+----------------------+ +| index | uint | int | float | ++-------+-------------------------+-------------------------+----------------------+ +| 1 | 0.004677734981047266 | 1.9953222650189528 | 0.15729920705028513 | +| 2 | 0.000022090496998585438 | 0.000022090496998585438 | 3.057709796438165e-6 | +| 3 | | | | ++-------+-------------------------+-------------------------+----------------------+"# + .split('\n') + .filter_map(|input| { + if input.is_empty() { + None + } else { + Some(input.trim()) + } + }) + .collect(); + assert_batches_sorted_eq!(expected, &batches); + Ok(()) + } + fn register_udfs_for_test() -> Result { let ctx = set_up_maths_data_test()?; register_postgres_udfs(&ctx)?; diff --git a/src/postgres/mod.rs b/src/postgres/mod.rs index baa7133..572fcde 100644 --- a/src/postgres/mod.rs +++ b/src/postgres/mod.rs @@ -2,7 +2,7 @@ use std::sync::Arc; -use crate::postgres::math_udfs::{ceiling, div, erf}; +use crate::postgres::math_udfs::{ceiling, div, erf, erfc}; use crate::postgres::network_udfs::{ broadcast, family, host, hostmask, inet_merge, inet_same_family, masklen, netmask, network, set_masklen, @@ -25,6 +25,7 @@ pub fn register_postgres_udfs(ctx: &SessionContext) -> Result<()> { fn register_math_udfs(ctx: &SessionContext) -> Result<()> { register_ceiling(ctx); register_erf(ctx); + register_erfc(ctx); register_div(ctx); Ok(()) } @@ -55,6 +56,19 @@ fn register_erf(ctx: &SessionContext) { ctx.register_udf(erf_udf); } +fn register_erfc(ctx: &SessionContext) { + let erfc_udf = make_scalar_function(erfc); + let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Float64))); + let erfc_udf = ScalarUDF::new( + "erfc", + &Signature::uniform(1, vec![Int64, UInt64, Float64], Volatility::Immutable), + &return_type, + &erfc_udf, + ); + + ctx.register_udf(erfc_udf); +} + fn register_div(ctx: &SessionContext) { let udf = make_scalar_function(div); let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Int64))); diff --git a/supports/postgres.md b/supports/postgres.md index ee1f6bf..25f7c4d 100644 --- a/supports/postgres.md +++ b/supports/postgres.md @@ -26,7 +26,7 @@ https://www.postgresql.org/docs/16/functions-math.html | ✅︎ | ceiling ( numeric ) → numeric | Nearest integer greater than or equal to argument (same as ceil) | ceiling(95.3) → 96 | | ✅︎ | div ( y numeric, x numeric ) → numeric | Integer quotient of y/x (truncates towards zero) | div(9, 4) → 2 | | ✅︎ | erf ( double precision ) → double precision | Error function | erf(1.0) → 0.8427007929497149 | -| ❓ | erfc ( double precision ) → double precision | Complementary error function (1 - erf(x), without loss of precision for large inputs) | erfc(1.0) → 0.15729920705028513 | +| ✅ | erfc ( double precision ) → double precision | Complementary error function (1 - erf(x), without loss of precision for large inputs) | erfc(1.0) → 0.15729920705028513 | | ❓ | min_scale ( numeric ) → integer | Minimum scale (number of fractional decimal digits) needed to represent the supplied value precisely | min_scale(8.4100) → 2 | | ❓ | mod ( y numeric_type, x numeric_type ) → numeric_type | Remainder of y/x; available for smallint, integer, bigint, and numeric | mod(9, 4) → 1 | | ❓ | scale ( numeric ) → integer | Scale of the argument (the number of decimal digits in the fractional part) | scale(8.4100) → 4 |