Skip to content

Commit

Permalink
Added Postgres sign function
Browse files Browse the repository at this point in the history
  • Loading branch information
dadepo committed Oct 20, 2024
1 parent 05a023a commit 0cd7bc7
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 3 deletions.
147 changes: 146 additions & 1 deletion src/postgres/math_udfs.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::sync::Arc;

use datafusion::arrow::array::{Array, ArrayRef, Float64Array, Int64Array};
use datafusion::arrow::array::{Array, ArrayRef, Float64Array, Int64Array, Int8Array};
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::datatypes::DataType::{Float64, Int64, UInt64};

Expand Down Expand Up @@ -889,6 +889,99 @@ impl ScalarUDFImpl for Mod {
}
}

/// Scalar UDF exposed under the SQL name `sign`.
///
/// Its `ScalarUDFImpl` implementation reports an `Int8` return type and
/// produces one output value per input row.
#[derive(Debug)]
pub struct Sign {
/// Argument signature handed to DataFusion; built once in `Sign::new`.
signature: Signature,
}

impl Sign {
    /// Creates the `sign` UDF.
    ///
    /// The function takes exactly one argument, which must be (or be coerced
    /// to) one of `Int64`, `UInt64` or `Float64`.
    pub fn new() -> Self {
        Self {
            // `sign` depends only on its argument, so it is declared immutable.
            // Declaring it volatile would force re-evaluation for every row and
            // prevent the optimizer from constant-folding calls such as `sign(0)`.
            signature: Signature::uniform(
                1,
                vec![Int64, UInt64, Float64],
                Volatility::Immutable,
            ),
        }
    }
}

impl ScalarUDFImpl for Sign {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"sign"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Int8)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;
let mut int8array_builder = Int8Array::builder(args[0].len());
match args.first() {
Some(values) => {
match values.data_type() {
DataType::Int64 | DataType::Int32 | DataType::Int16 | DataType::Int8 => {
let values = datafusion::common::cast::as_int64_array(values)?;
values.iter().for_each(|value| {
match value {
Some(value) => int8array_builder.append_value(value.cmp(&0_i64) as i8),
None => int8array_builder.append_null()
}
});
},
DataType::UInt64 | DataType::UInt32 | DataType::UInt16 | DataType::UInt8 => {
let values = datafusion::common::cast::as_uint64_array(values)?;
values.iter().for_each(|value| {
match value {
Some(value) => int8array_builder.append_value(value.cmp(&0_u64) as i8),
None => int8array_builder.append_null()
}
});
},
DataType::Float64 | DataType::Float32 | DataType::Float16 => {
let values = datafusion::common::cast::as_float64_array(values)?;
values.iter().for_each(|value| {
match value {
Some(value) => {
if value == 0_f64 {
int8array_builder.append_value(0)
} else if value > 0_f64 {
int8array_builder.append_value(1)
} else {
int8array_builder.append_value(-1)
}
},
None => int8array_builder.append_null()
}
});
},
_ => {
return Err(DataFusionError::Internal(
"No function matches the given name and argument types. You might need to add explicit type casts"
.to_string(),
))
}
}
}
None => {
return Err(DataFusionError::Internal(
"No function matches the given name and argument types. You might need to add explicit type casts"
.to_string(),
))
}
}

Ok(ColumnarValue::Array(
Arc::new(int8array_builder.finish()) as ArrayRef
))
}
}

#[cfg(feature = "postgres")]
#[cfg(test)]
mod tests {
Expand Down Expand Up @@ -1349,6 +1442,58 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_sign() -> Result<()> {
    // Splits a pretty-printed table literal into lines, drops the empty ones
    // and trims the surrounding whitespace of the rest.
    fn table_lines(table: &str) -> Vec<&str> {
        table
            .split('\n')
            .filter(|line| !line.is_empty())
            .map(str::trim)
            .collect()
    }

    let ctx = register_udfs_for_test()?;

    // Column inputs of each supported type; row 3 is expected to render empty
    // for every sign column.
    let batches = ctx
        .sql("select index, sign(uint) as uint, sign(int) as int, sign(float) as float from maths_table ORDER BY index ASC")
        .await?
        .collect()
        .await?;

    let expected = table_lines(
        r#"
+-------+------+-----+-------+
| index | uint | int | float |
+-------+------+-----+-------+
| 1 | 1 | -1 | 1 |
| 2 | 1 | 1 | 1 |
| 3 | | | |
+-------+------+-----+-------+"#,
    );

    assert_batches_sorted_eq!(expected, &batches);

    // Literal zeros must map to a sign of zero.
    let batches = ctx
        .sql("select sign(0) as uint, sign(0.0) as float")
        .await?
        .collect()
        .await?;

    let expected = table_lines(
        r#"
+------+-------+
| uint | float |
+------+-------+
| 0 | 0 |
+------+-------+"#,
    );

    assert_batches_sorted_eq!(expected, &batches);
    Ok(())
}

fn register_udfs_for_test() -> Result<SessionContext> {
let ctx = set_up_maths_data_test()?;
register_postgres_udfs(&ctx)?;
Expand Down
3 changes: 2 additions & 1 deletion src/postgres/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use datafusion::logical_expr::ScalarUDF;
use datafusion::prelude::SessionContext;

use crate::postgres::math_udfs::{
Acosd, Asind, Atand, Ceiling, Cosd, Cotd, Div, Erf, Erfc, Mod, RandomNormal, Sind, Tand,
Acosd, Asind, Atand, Ceiling, Cosd, Cotd, Div, Erf, Erfc, Mod, RandomNormal, Sign, Sind, Tand,
};
use crate::postgres::network_udfs::{
Broadcast, Family, Host, HostMask, InetMerge, InetSameFamily, MaskLen, Netmask, Network,
Expand Down Expand Up @@ -36,6 +36,7 @@ fn register_math_udfs(ctx: &SessionContext) -> Result<()> {
ctx.register_udf(ScalarUDF::from(Erfc::new()));
ctx.register_udf(ScalarUDF::from(RandomNormal::new()));
ctx.register_udf(ScalarUDF::from(Mod::new()));
ctx.register_udf(ScalarUDF::from(Sign::new()));
Ok(())
}

Expand Down
2 changes: 1 addition & 1 deletion supports/postgres.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ https://www.postgresql.org/docs/16/functions-math.html
|| min_scale ( numeric ) → integer | Minimum scale (number of fractional decimal digits) needed to represent the supplied value precisely | min_scale(8.4100) → 2 |
|| mod ( y numeric_type, x numeric_type ) → numeric_type | Remainder of y/x; available for smallint, integer, bigint, and numeric | mod(9, 4) → 1 |
|| scale ( numeric ) → integer | Scale of the argument (the number of decimal digits in the fractional part) | scale(8.4100) → 4 |
| | sign ( numeric ) → numeric | Sign of the argument (-1, 0, or +1) | sign(-8.4) → -1 |
| | sign ( numeric ) → numeric | Sign of the argument (-1, 0, or +1) | sign(-8.4) → -1 |
|| trim_scale ( numeric ) → numeric | Reduces the value's scale (number of fractional decimal digits) by removing trailing zeroes | trim_scale(8.4100) → 8.41 |
|| width_bucket ( operand numeric, low numeric, high numeric, count integer ) → integer | Returns the number of the bucket in which operand falls in a histogram having count equal-width buckets spanning the range low to high. Returns 0 or count+1 for an input outside that range. | width_bucket(5.35, 0.024, 10.06, 5) → 3 |
|| random_normal ( [ mean double precision [, stddev double precision ]] ) → double precision | Returns a random value from the normal distribution with the given parameters; mean defaults to 0.0 and stddev defaults to 1.0 | random_normal(0.0, 1.0) → 0.051285419 |
Expand Down

0 comments on commit 0cd7bc7

Please sign in to comment.