Skip to content

Commit

Permalink
Added postgres erf function
Browse files Browse the repository at this point in the history
  • Loading branch information
dadepo committed Nov 26, 2023
1 parent 011dd37 commit f815a5d
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 8 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ edition = "2021"
anyhow = "1.0.75"
datafusion = "33.0.0"
ipnet = "2.7.2"
libm = "0.2.8"
serde = "1.0.192"
serde_json = { version = "1.0.108", features = ["preserve_order"] }
serde_json_path = "0.6.4"
Expand Down
30 changes: 29 additions & 1 deletion src/common/test_utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use datafusion::arrow::array::{StringArray, UInt8Array};
use datafusion::arrow::array::{Float64Array, Int64Array, StringArray, UInt64Array, UInt8Array};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::Result;
Expand Down Expand Up @@ -43,6 +43,34 @@ pub fn set_up_network_data_test() -> Result<SessionContext> {
// declare a table in memory.
Ok(ctx)
}

pub fn set_up_maths_data_test() -> Result<SessionContext> {
// define a schema.
let schema = Arc::new(Schema::new(vec![
Field::new("index", DataType::UInt8, false),
Field::new("uint", DataType::UInt64, true),
Field::new("int", DataType::Int64, true),
Field::new("float", DataType::Float64, true),
]));

// define data.
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(UInt8Array::from_iter_values([1, 2, 3])),
Arc::new(UInt64Array::from(vec![Some(2), Some(3), None])),
Arc::new(Int64Array::from(vec![Some(-2), Some(3), None])),
Arc::new(Float64Array::from(vec![Some(1.0), Some(3.3), None])),
],
)?;

// declare a new context
let ctx = SessionContext::new();
ctx.register_batch("maths_table", batch)?;
// declare a table in memory.
Ok(ctx)
}

pub fn set_up_json_data_test() -> Result<SessionContext> {
// define a schema.
let schema = Arc::new(Schema::new(vec![
Expand Down
84 changes: 82 additions & 2 deletions src/postgres/math_udfs.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::sync::Arc;

use datafusion::arrow::array::{Array, ArrayRef, Float64Array, Int64Array};
use datafusion::arrow::datatypes::DataType;
use datafusion::common::DataFusionError;
use datafusion::error::Result;

Expand Down Expand Up @@ -36,10 +37,61 @@ pub fn div(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(int64array_builder.finish()) as ArrayRef)
}

/// Error function
pub fn erf(args: &[ArrayRef]) -> Result<ArrayRef> {
let column_data = &args[0];
let data = column_data.into_data();
let data_type = data.data_type();

let mut float64array_builder = Float64Array::builder(args[0].len());
match data_type {
DataType::Float64 => {
let values = datafusion::common::cast::as_float64_array(&args[0])?;
values.iter().try_for_each(|value| {
if let Some(value) = value {
float64array_builder.append_value(libm::erf(value))
} else {
float64array_builder.append_null();
}
Ok::<(), DataFusionError>(())
})?;
}
DataType::Int64 => {
let values = datafusion::common::cast::as_int64_array(&args[0])?;
values.iter().try_for_each(|value| {
if let Some(value) = value {
float64array_builder.append_value(libm::erf(value as f64))
} else {
float64array_builder.append_null();
}
Ok::<(), DataFusionError>(())
})?;
}
DataType::UInt64 => {
let values = datafusion::common::cast::as_uint64_array(&args[0])?;
values.iter().try_for_each(|value| {
if let Some(value) = value {
float64array_builder.append_value(libm::erf(value as f64))
} else {
float64array_builder.append_null();
}
Ok::<(), DataFusionError>(())
})?;
}
t => {
return Err(DataFusionError::Internal(format!(
"Unsupported type {t} for erf function"
)))
}
};

Ok(Arc::new(float64array_builder.finish()) as ArrayRef)
}

#[cfg(feature = "postgres")]
#[cfg(test)]
mod tests {
use crate::common::test_utils::set_up_network_data_test;
use crate::common::test_utils::set_up_maths_data_test;
use crate::postgres::register_postgres_udfs;
use datafusion::assert_batches_sorted_eq;
use datafusion::prelude::SessionContext;
Expand Down Expand Up @@ -98,8 +150,36 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_erf() -> Result<()> {
let ctx = register_udfs_for_test()?;
let df = ctx.sql("select index, erf(uint) as uint, erf(int) as int, erf(float) as float from maths_table ORDER BY index ASC").await?;

let batches = df.clone().collect().await?;

let expected: Vec<&str> = r#"
+-------+--------------------+---------------------+--------------------+
| index | uint | int | float |
+-------+--------------------+---------------------+--------------------+
| 1 | 0.9953222650189527 | -0.9953222650189527 | 0.8427007929497149 |
| 2 | 0.9999779095030014 | 0.9999779095030014 | 0.9999969422902035 |
| 3 | | | |
+-------+--------------------+---------------------+--------------------+"#
.split('\n')
.filter_map(|input| {
if input.is_empty() {
None
} else {
Some(input.trim())
}
})
.collect();
assert_batches_sorted_eq!(expected, &batches);
Ok(())
}

fn register_udfs_for_test() -> Result<SessionContext> {
let ctx = set_up_network_data_test()?;
let ctx = set_up_maths_data_test()?;
register_postgres_udfs(&ctx)?;
Ok(ctx)
}
Expand Down
18 changes: 16 additions & 2 deletions src/postgres/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

use std::sync::Arc;

use crate::postgres::math_udfs::{ceiling, div};
use crate::postgres::math_udfs::{ceiling, div, erf};
use crate::postgres::network_udfs::{
broadcast, family, host, hostmask, inet_merge, inet_same_family, masklen, netmask, network,
set_masklen,
};
use datafusion::arrow::datatypes::DataType::{Boolean, Float64, Int64, UInt8, Utf8};
use datafusion::arrow::datatypes::DataType::{Boolean, Float64, Int64, UInt64, UInt8, Utf8};
use datafusion::error::Result;
use datafusion::logical_expr::{ReturnTypeFunction, ScalarUDF, Signature, Volatility};
use datafusion::physical_expr::functions::make_scalar_function;
Expand All @@ -24,6 +24,7 @@ pub fn register_postgres_udfs(ctx: &SessionContext) -> Result<()> {

fn register_math_udfs(ctx: &SessionContext) -> Result<()> {
register_ceiling(ctx);
register_erf(ctx);
register_div(ctx);
Ok(())
}
Expand All @@ -41,6 +42,19 @@ fn register_ceiling(ctx: &SessionContext) {
ctx.register_udf(ceiling_udf);
}

fn register_erf(ctx: &SessionContext) {
let erf_udf = make_scalar_function(erf);
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Float64)));
let erf_udf = ScalarUDF::new(
"erf",
&Signature::uniform(1, vec![Int64, UInt64, Float64], Volatility::Immutable),
&return_type,
&erf_udf,
);

ctx.register_udf(erf_udf);
}

fn register_div(ctx: &SessionContext) {
let udf = make_scalar_function(div);
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Int64)));
Expand Down
2 changes: 1 addition & 1 deletion supports/postgres.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ https://www.postgresql.org/docs/16/functions-math.html
|-------------|--------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------|
| ✅︎ | ceiling ( numeric ) → numeric | Nearest integer greater than or equal to argument (same as ceil) | ceiling(95.3) → 96 |
| ✅︎ | div ( y numeric, x numeric ) → numeric | Integer quotient of y/x (truncates towards zero) | div(9, 4) → 2 |
| | erf ( double precision ) → double precision | Error function | erf(1.0) → 0.8427007929497149 |
| ✅︎ | erf ( double precision ) → double precision | Error function | erf(1.0) → 0.8427007929497149 |
|| erfc ( double precision ) → double precision | Complementary error function (1 - erf(x), without loss of precision for large inputs) | erfc(1.0) → 0.15729920705028513 |
|| min_scale ( numeric ) → integer | Minimum scale (number of fractional decimal digits) needed to represent the supplied value precisely | min_scale(8.4100) → 2 |
|| mod ( y numeric_type, x numeric_type ) → numeric_type | Remainder of y/x; available for smallint, integer, bigint, and numeric | mod(9, 4) → 1 |
Expand Down

0 comments on commit f815a5d

Please sign in to comment.