Skip to content

Commit

Permalink
Added postgres erfc
Browse files Browse the repository at this point in the history
  • Loading branch information
dadepo committed Dec 3, 2023
1 parent f815a5d commit ee6bd09
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,4 @@ jobs:
- name: Build
run: cargo build --verbose
- name: Run tests
run: cargo test --features "postgres" --verbose
run: cargo test --all-features --verbose
82 changes: 82 additions & 0 deletions src/postgres/math_udfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,57 @@ pub fn erf(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(float64array_builder.finish()) as ArrayRef)
}

/// Complementary error function
pub fn erfc(args: &[ArrayRef]) -> Result<ArrayRef> {
let column_data = &args[0];
let data = column_data.into_data();
let data_type = data.data_type();

let mut float64array_builder = Float64Array::builder(args[0].len());
match data_type {
DataType::Float64 => {
let values = datafusion::common::cast::as_float64_array(&args[0])?;
values.iter().try_for_each(|value| {
if let Some(value) = value {
float64array_builder.append_value(libm::erfc(value))
} else {
float64array_builder.append_null();
}
Ok::<(), DataFusionError>(())
})?;
}
DataType::Int64 => {
let values = datafusion::common::cast::as_int64_array(&args[0])?;
values.iter().try_for_each(|value| {
if let Some(value) = value {
float64array_builder.append_value(libm::erfc(value as f64))
} else {
float64array_builder.append_null();
}
Ok::<(), DataFusionError>(())
})?;
}
DataType::UInt64 => {
let values = datafusion::common::cast::as_uint64_array(&args[0])?;
values.iter().try_for_each(|value| {
if let Some(value) = value {
float64array_builder.append_value(libm::erfc(value as f64))
} else {
float64array_builder.append_null();
}
Ok::<(), DataFusionError>(())
})?;
}
t => {
return Err(DataFusionError::Internal(format!(
"Unsupported type {t} for erf function"
)))
}
};

Ok(Arc::new(float64array_builder.finish()) as ArrayRef)
}

#[cfg(feature = "postgres")]
#[cfg(test)]
mod tests {
Expand Down Expand Up @@ -178,6 +229,37 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_erfc() -> Result<()> {
let ctx = register_udfs_for_test()?;
let df = ctx.sql("select index, uint as uint, int as int, float as float from maths_table ORDER BY index ASC").await?;
df.show().await?;

let df = ctx.sql("select index, erfc(uint) as uint, erfc(int) as int, erfc(float) as float from maths_table ORDER BY index ASC").await?;

let batches = df.clone().collect().await?;

let expected: Vec<&str> = r#"
+-------+-------------------------+-------------------------+----------------------+
| index | uint | int | float |
+-------+-------------------------+-------------------------+----------------------+
| 1 | 0.004677734981047266 | 1.9953222650189528 | 0.15729920705028513 |
| 2 | 0.000022090496998585438 | 0.000022090496998585438 | 3.057709796438165e-6 |
| 3 | | | |
+-------+-------------------------+-------------------------+----------------------+"#
.split('\n')
.filter_map(|input| {
if input.is_empty() {
None
} else {
Some(input.trim())
}
})
.collect();
assert_batches_sorted_eq!(expected, &batches);
Ok(())
}

fn register_udfs_for_test() -> Result<SessionContext> {
let ctx = set_up_maths_data_test()?;
register_postgres_udfs(&ctx)?;
Expand Down
16 changes: 15 additions & 1 deletion src/postgres/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use std::sync::Arc;

use crate::postgres::math_udfs::{ceiling, div, erf};
use crate::postgres::math_udfs::{ceiling, div, erf, erfc};
use crate::postgres::network_udfs::{
broadcast, family, host, hostmask, inet_merge, inet_same_family, masklen, netmask, network,
set_masklen,
Expand All @@ -25,6 +25,7 @@ pub fn register_postgres_udfs(ctx: &SessionContext) -> Result<()> {
fn register_math_udfs(ctx: &SessionContext) -> Result<()> {
register_ceiling(ctx);
register_erf(ctx);
register_erfc(ctx);
register_div(ctx);
Ok(())
}
Expand Down Expand Up @@ -55,6 +56,19 @@ fn register_erf(ctx: &SessionContext) {
ctx.register_udf(erf_udf);
}

fn register_erfc(ctx: &SessionContext) {
let erfc_udf = make_scalar_function(erfc);
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Float64)));
let erfc_udf = ScalarUDF::new(
"erfc",
&Signature::uniform(1, vec![Int64, UInt64, Float64], Volatility::Immutable),
&return_type,
&erfc_udf,
);

ctx.register_udf(erfc_udf);
}

fn register_div(ctx: &SessionContext) {
let udf = make_scalar_function(div);
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Int64)));
Expand Down
2 changes: 1 addition & 1 deletion supports/postgres.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ https://www.postgresql.org/docs/16/functions-math.html
| ✅︎ | ceiling ( numeric ) → numeric | Nearest integer greater than or equal to argument (same as ceil) | ceiling(95.3) → 96 |
| ✅︎ | div ( y numeric, x numeric ) → numeric | Integer quotient of y/x (truncates towards zero) | div(9, 4) → 2 |
| ✅︎ | erf ( double precision ) → double precision | Error function | erf(1.0) → 0.8427007929497149 |
| | erfc ( double precision ) → double precision | Complementary error function (1 - erf(x), without loss of precision for large inputs) | erfc(1.0) → 0.15729920705028513 |
| | erfc ( double precision ) → double precision | Complementary error function (1 - erf(x), without loss of precision for large inputs) | erfc(1.0) → 0.15729920705028513 |
|| min_scale ( numeric ) → integer | Minimum scale (number of fractional decimal digits) needed to represent the supplied value precisely | min_scale(8.4100) → 2 |
|| mod ( y numeric_type, x numeric_type ) → numeric_type | Remainder of y/x; available for smallint, integer, bigint, and numeric | mod(9, 4) → 1 |
|| scale ( numeric ) → integer | Scale of the argument (the number of decimal digits in the fractional part) | scale(8.4100) → 4 |
Expand Down

0 comments on commit ee6bd09

Please sign in to comment.