From f815a5d5c1fb6c91e6510afceb5d055231e19d7e Mon Sep 17 00:00:00 2001 From: Dadepo Aderemi Date: Sun, 26 Nov 2023 17:57:55 +0400 Subject: [PATCH] Added postgres erf function --- Cargo.lock | 5 ++- Cargo.toml | 1 + src/common/test_utils.rs | 30 +++++++++++++- src/postgres/math_udfs.rs | 84 ++++++++++++++++++++++++++++++++++++++- src/postgres/mod.rs | 18 ++++++++- supports/postgres.md | 2 +- 6 files changed, 132 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 175e928..00d6f94 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -806,6 +806,7 @@ dependencies = [ "anyhow", "datafusion", "ipnet", + "libm", "serde", "serde_json", "serde_json_path", @@ -1283,9 +1284,9 @@ checksum = "f92be4933c13fd498862a9e02a3055f8a8d9c039ce33db97306fd5a6caa7f29b" [[package]] name = "libm" -version = "0.2.7" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] name = "linux-raw-sys" diff --git a/Cargo.toml b/Cargo.toml index d636f5b..839e56e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" anyhow = "1.0.75" datafusion = "33.0.0" ipnet = "2.7.2" +libm = "0.2.8" serde = "1.0.192" serde_json = { version = "1.0.108", features = ["preserve_order"] } serde_json_path = "0.6.4" diff --git a/src/common/test_utils.rs b/src/common/test_utils.rs index 861afe6..6f86579 100644 --- a/src/common/test_utils.rs +++ b/src/common/test_utils.rs @@ -1,4 +1,4 @@ -use datafusion::arrow::array::{StringArray, UInt8Array}; +use datafusion::arrow::array::{Float64Array, Int64Array, StringArray, UInt64Array, UInt8Array}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::error::Result; @@ -43,6 +43,34 @@ pub fn set_up_network_data_test() -> Result { // declare a table in memory. Ok(ctx) } + +pub fn set_up_maths_data_test() -> Result { + // define a schema. + let schema = Arc::new(Schema::new(vec![ + Field::new("index", DataType::UInt8, false), + Field::new("uint", DataType::UInt64, true), + Field::new("int", DataType::Int64, true), + Field::new("float", DataType::Float64, true), + ])); + + // define data. + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(UInt8Array::from_iter_values([1, 2, 3])), + Arc::new(UInt64Array::from(vec![Some(2), Some(3), None])), + Arc::new(Int64Array::from(vec![Some(-2), Some(3), None])), + Arc::new(Float64Array::from(vec![Some(1.0), Some(3.3), None])), + ], + )?; + + // declare a new context + let ctx = SessionContext::new(); + ctx.register_batch("maths_table", batch)?; + // declare a table in memory. + Ok(ctx) +} + pub fn set_up_json_data_test() -> Result { // define a schema. let schema = Arc::new(Schema::new(vec![ diff --git a/src/postgres/math_udfs.rs b/src/postgres/math_udfs.rs index e885298..003cb2d 100644 --- a/src/postgres/math_udfs.rs +++ b/src/postgres/math_udfs.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use datafusion::arrow::array::{Array, ArrayRef, Float64Array, Int64Array}; +use datafusion::arrow::datatypes::DataType; use datafusion::common::DataFusionError; use datafusion::error::Result; @@ -36,10 +37,61 @@ pub fn div(args: &[ArrayRef]) -> Result { Ok(Arc::new(int64array_builder.finish()) as ArrayRef) } +/// Error function +pub fn erf(args: &[ArrayRef]) -> Result { + let column_data = &args[0]; + let data = column_data.into_data(); + let data_type = data.data_type(); + + let mut float64array_builder = Float64Array::builder(args[0].len()); + match data_type { + DataType::Float64 => { + let values = datafusion::common::cast::as_float64_array(&args[0])?; + values.iter().try_for_each(|value| { + if let Some(value) = value { + float64array_builder.append_value(libm::erf(value)) + } else { + float64array_builder.append_null(); + } + Ok::<(), DataFusionError>(()) + })?; + } + DataType::Int64 => { + let values = datafusion::common::cast::as_int64_array(&args[0])?; + values.iter().try_for_each(|value| { + if let Some(value) = value { + float64array_builder.append_value(libm::erf(value as f64)) + } else { + float64array_builder.append_null(); + } + Ok::<(), DataFusionError>(()) + })?; + } + DataType::UInt64 => { + let values = datafusion::common::cast::as_uint64_array(&args[0])?; + values.iter().try_for_each(|value| { + if let Some(value) = value { + float64array_builder.append_value(libm::erf(value as f64)) + } else { + float64array_builder.append_null(); + } + Ok::<(), DataFusionError>(()) + })?; + } + t => { + return Err(DataFusionError::Internal(format!( + "Unsupported type {t} for erf function" + ))) + } + }; + + Ok(Arc::new(float64array_builder.finish()) as ArrayRef) +} + #[cfg(feature = "postgres")] #[cfg(test)] mod tests { - use crate::common::test_utils::set_up_network_data_test; + use crate::common::test_utils::set_up_maths_data_test; use crate::postgres::register_postgres_udfs; use datafusion::assert_batches_sorted_eq; use datafusion::prelude::SessionContext; @@ -98,8 +150,36 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_erf() -> Result<()> { + let ctx = register_udfs_for_test()?; + let df = ctx.sql("select index, erf(uint) as uint, erf(int) as int, erf(float) as float from maths_table ORDER BY index ASC").await?; + + let batches = df.clone().collect().await?; + + let expected: Vec<&str> = r#" ++-------+--------------------+---------------------+--------------------+ +| index | uint | int | float | ++-------+--------------------+---------------------+--------------------+ +| 1 | 0.9953222650189527 | -0.9953222650189527 | 0.8427007929497149 | +| 2 | 0.9999779095030014 | 0.9999779095030014 | 0.9999969422902035 | +| 3 | | | | ++-------+--------------------+---------------------+--------------------+"# + .split('\n') + .filter_map(|input| { + if input.is_empty() { + None + } else { + Some(input.trim()) + } + }) + .collect(); + assert_batches_sorted_eq!(expected, &batches); + Ok(()) + } + fn register_udfs_for_test() -> Result { - let ctx = set_up_network_data_test()?; + let ctx = set_up_maths_data_test()?; register_postgres_udfs(&ctx)?; Ok(ctx) } diff --git a/src/postgres/mod.rs b/src/postgres/mod.rs index d9cf7a5..baa7133 100644 --- a/src/postgres/mod.rs +++ b/src/postgres/mod.rs @@ -2,12 +2,12 @@ use std::sync::Arc; -use crate::postgres::math_udfs::{ceiling, div}; +use crate::postgres::math_udfs::{ceiling, div, erf}; use crate::postgres::network_udfs::{ broadcast, family, host, hostmask, inet_merge, inet_same_family, masklen, netmask, network, set_masklen, }; -use datafusion::arrow::datatypes::DataType::{Boolean, Float64, Int64, UInt8, Utf8}; +use datafusion::arrow::datatypes::DataType::{Boolean, Float64, Int64, UInt64, UInt8, Utf8}; use datafusion::error::Result; use datafusion::logical_expr::{ReturnTypeFunction, ScalarUDF, Signature, Volatility}; use datafusion::physical_expr::functions::make_scalar_function; @@ -24,6 +24,7 @@ pub fn register_postgres_udfs(ctx: &SessionContext) -> Result<()> { fn register_math_udfs(ctx: &SessionContext) -> Result<()> { register_ceiling(ctx); + register_erf(ctx); register_div(ctx); Ok(()) } @@ -41,6 +42,19 @@ fn register_ceiling(ctx: &SessionContext) { ctx.register_udf(ceiling_udf); } +fn register_erf(ctx: &SessionContext) { + let erf_udf = make_scalar_function(erf); + let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Float64))); + let erf_udf = ScalarUDF::new( + "erf", + &Signature::uniform(1, vec![Int64, UInt64, Float64], Volatility::Immutable), + &return_type, + &erf_udf, + ); + + ctx.register_udf(erf_udf); +} + fn register_div(ctx: &SessionContext) { let udf = make_scalar_function(div); let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Int64))); diff --git a/supports/postgres.md b/supports/postgres.md index da775f4..ee1f6bf 100644 --- a/supports/postgres.md +++ b/supports/postgres.md @@ -25,7 +25,7 @@ https://www.postgresql.org/docs/16/functions-math.html |-------------|--------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------| | ✅︎ | ceiling ( numeric ) → numeric | Nearest integer greater than or equal to argument (same as ceil) | ceiling(95.3) → 96 | | ✅︎ | div ( y numeric, x numeric ) → numeric | Integer quotient of y/x (truncates towards zero) | div(9, 4) → 2 | -| ❓ | erf ( double precision ) → double precision | Error function | erf(1.0) → 0.8427007929497149 | +| ✅︎ | erf ( double precision ) → double precision | Error function | erf(1.0) → 0.8427007929497149 | | ❓ | erfc ( double precision ) → double precision | Complementary error function (1 - erf(x), without loss of precision for large inputs) | erfc(1.0) → 0.15729920705028513 | | ❓ | min_scale ( numeric ) → integer | Minimum scale (number of fractional decimal digits) needed to represent the supplied value precisely | min_scale(8.4100) → 2 | | ❓ | mod ( y numeric_type, x numeric_type ) → numeric_type | Remainder of y/x; available for smallint, integer, bigint, and numeric | mod(9, 4) → 1 |