Skip to content

Commit

Permalink
Added postgres acosd
Browse files Browse the repository at this point in the history
  • Loading branch information
dadepo committed Dec 10, 2023
1 parent ba991d8 commit af4a93f
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 1 deletion.
86 changes: 86 additions & 0 deletions src/postgres/math_udfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,34 @@ use datafusion::arrow::datatypes::DataType;
use datafusion::common::DataFusionError;
use datafusion::error::Result;

/// Inverse cosine, result in radians.
pub fn acosd(args: &[ArrayRef]) -> Result<ArrayRef> {
let values = datafusion::common::cast::as_float64_array(&args[0])?;
let mut float64array_builder = Float64Array::builder(args[0].len());

values.iter().try_for_each(|value| {
if let Some(value) = value {
if value > 1.0 {
return Err(DataFusionError::Internal(
"input is out of range".to_string(),
));
}
let result = value.acos().to_degrees();
if result.fract() < 0.9 {
float64array_builder.append_value(result);
} else {
float64array_builder.append_value(result.ceil());
}
Ok::<(), DataFusionError>(())
} else {
float64array_builder.append_null();
Ok::<(), DataFusionError>(())
}
})?;

Ok(Arc::new(float64array_builder.finish()) as ArrayRef)
}

/// Nearest integer greater than or equal to argument (same as ceil).
pub fn ceiling(args: &[ArrayRef]) -> Result<ArrayRef> {
let values = datafusion::common::cast::as_float64_array(&args[0])?;
Expand Down Expand Up @@ -149,6 +177,63 @@ mod tests {

use super::*;

#[tokio::test]
async fn test_acosd() -> Result<()> {
let ctx = register_udfs_for_test()?;
let df = ctx.sql("select acosd(0.5) as col_result").await?;

let batches = df.clone().collect().await?;

let expected: Vec<&str> = r#"
+------------+
| col_result |
+------------+
| 60.0 |
+------------+"#
.split('\n')
.filter_map(|input| {
if input.is_empty() {
None
} else {
Some(input.trim())
}
})
.collect();
assert_batches_sorted_eq!(expected, &batches);

let df = ctx.sql("select acosd(0.4) as col_result").await?;

let batches = df.clone().collect().await?;

let expected: Vec<&str> = r#"
+-------------------+
| col_result |
+-------------------+
| 66.42182152179817 |
+-------------------+"#
.split('\n')
.filter_map(|input| {
if input.is_empty() {
None
} else {
Some(input.trim())
}
})
.collect();
assert_batches_sorted_eq!(expected, &batches);

let df = ctx.sql("select acosd(1.4) as col_result").await?;

let result = df.clone().collect().await;
assert!(result
.err()
.unwrap()
.to_string()
.contains("input is out of range"));

Ok(())
}

#[tokio::test]
async fn test_ceiling() -> Result<()> {
let ctx = register_udfs_for_test()?;
Expand Down Expand Up @@ -232,6 +317,7 @@ mod tests {
#[tokio::test]
async fn test_erfc() -> Result<()> {
let ctx = register_udfs_for_test()?;

let df = ctx.sql("select index, erfc(uint) as uint, erfc(int) as int, erfc(float) as float from maths_table ORDER BY index ASC").await?;

let batches = df.clone().collect().await?;
Expand Down
16 changes: 15 additions & 1 deletion src/postgres/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use std::sync::Arc;

use crate::postgres::math_udfs::{ceiling, div, erf, erfc};
use crate::postgres::math_udfs::{acosd, ceiling, div, erf, erfc};
use crate::postgres::network_udfs::{
broadcast, family, host, hostmask, inet_merge, inet_same_family, masklen, netmask, network,
set_masklen,
Expand All @@ -23,13 +23,27 @@ pub fn register_postgres_udfs(ctx: &SessionContext) -> Result<()> {
}

fn register_math_udfs(ctx: &SessionContext) -> Result<()> {
register_acosd(ctx);
register_ceiling(ctx);
register_erf(ctx);
register_erfc(ctx);
register_div(ctx);
Ok(())
}

fn register_acosd(ctx: &SessionContext) {
let acosd_udf = make_scalar_function(acosd);
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Float64)));
let acosd_udf = ScalarUDF::new(
"acosd",
&Signature::uniform(1, vec![Float64], Volatility::Immutable),
&return_type,
&acosd_udf,
);

ctx.register_udf(acosd_udf);
}

fn register_ceiling(ctx: &SessionContext) {
let ceiling_udf = make_scalar_function(ceiling);
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Float64)));
Expand Down

0 comments on commit af4a93f

Please sign in to comment.