Skip to content

Commit

Permalink
switch implementation of acosd to use ScalarUDFImpl
Browse files Browse the repository at this point in the history
  • Loading branch information
dadepo committed Apr 20, 2024
1 parent a465fd2 commit 0efa0ab
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 35 deletions.
72 changes: 53 additions & 19 deletions src/postgres/math_udfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,61 @@ use rand::thread_rng;
use rand_distr::Normal;

/// Inverse cosine, result in degrees.
pub fn acosd(args: &[ArrayRef]) -> Result<ArrayRef> {
let values = datafusion::common::cast::as_float64_array(&args[0])?;
let mut float64array_builder = Float64Array::builder(args[0].len());

values.iter().try_for_each(|value| {
if let Some(value) = value {
if value > 1.0 {
return Err(DataFusionError::Internal(
"input is out of range".to_string(),
));
}
let result = value.acos().to_degrees();
float64array_builder.append_value(result);
Ok::<(), DataFusionError>(())
} else {
float64array_builder.append_null();
Ok::<(), DataFusionError>(())
#[derive(Debug)]
pub struct Acosd {
signature: Signature,
}

impl Acosd {
pub fn new() -> Self {
Self {
signature: Signature::uniform(1, vec![Float64], Volatility::Immutable),
}
})?;
}
}

impl ScalarUDFImpl for Acosd {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"acosd"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(Float64)
}

Ok(Arc::new(float64array_builder.finish()) as ArrayRef)
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;
let values = datafusion::common::cast::as_float64_array(&args[0])?;
let mut float64array_builder = Float64Array::builder(args[0].len());

values.iter().try_for_each(|value| {
if let Some(value) = value {
if value > 1.0 {
return Err(DataFusionError::Internal(
"input is out of range".to_string(),
));
}
let result = value.acos().to_degrees();
float64array_builder.append_value(result);
Ok::<(), DataFusionError>(())
} else {
float64array_builder.append_null();
Ok::<(), DataFusionError>(())
}
})?;

Ok(ColumnarValue::Array(
Arc::new(float64array_builder.finish()) as ArrayRef,
))
}
}

/// Cosine, argument in degrees.
Expand Down
19 changes: 3 additions & 16 deletions src/postgres/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

use std::sync::Arc;

use datafusion::arrow::datatypes::DataType::{Boolean, Float64, Int64, UInt8, Utf8};
use datafusion::arrow::datatypes::DataType::{Boolean, Int64, UInt8, Utf8};
use datafusion::error::Result;
use datafusion::logical_expr::{ReturnTypeFunction, ScalarUDF, Signature, Volatility};
use datafusion::physical_expr::functions::make_scalar_function;
use datafusion::prelude::SessionContext;

use crate::postgres::math_udfs::{
acosd, Asind, Atand, Ceiling, Cosd, Cotd, Div, Erf, Erfc, RandomNormal, Sind, Tand,
Acosd, Asind, Atand, Ceiling, Cosd, Cotd, Div, Erf, Erfc, RandomNormal, Sind, Tand,
};
use crate::postgres::network_udfs::{
broadcast, family, host, hostmask, inet_merge, inet_same_family, masklen, netmask, network,
Expand All @@ -27,7 +27,7 @@ pub fn register_postgres_udfs(ctx: &SessionContext) -> Result<()> {
}

fn register_math_udfs(ctx: &SessionContext) -> Result<()> {
register_acosd(ctx);
ctx.register_udf(ScalarUDF::from(Acosd::new()));
ctx.register_udf(ScalarUDF::from(Cosd::new()));
ctx.register_udf(ScalarUDF::from(Cotd::new()));
ctx.register_udf(ScalarUDF::from(Asind::new()));
Expand All @@ -42,19 +42,6 @@ fn register_math_udfs(ctx: &SessionContext) -> Result<()> {
Ok(())
}

fn register_acosd(ctx: &SessionContext) {
let acosd_udf = make_scalar_function(acosd);
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Float64)));
let acosd_udf = ScalarUDF::new(
"acosd",
&Signature::uniform(1, vec![Float64], Volatility::Immutable),
&return_type,
&acosd_udf,
);

ctx.register_udf(acosd_udf);
}

fn register_network_udfs(ctx: &SessionContext) -> Result<()> {
register_broadcast(ctx);
register_family(ctx);
Expand Down

0 comments on commit 0efa0ab

Please sign in to comment.