Skip to content

Commit

Permalink
Added postgres random_normal
Browse files Browse the repository at this point in the history
  • Loading branch information
dadepo committed Jan 13, 2024
1 parent 219a1cf commit f7ed9bd
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 7 deletions.
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ anyhow = "1.0.75"
datafusion = "33.0.0"
ipnet = "2.7.2"
libm = "0.2.8"
rand = "0.8.5"
rand_distr = "0.4.3"
serde = "1.0.192"
serde_json = { version = "1.0.108", features = ["preserve_order"] }
serde_json_path = "0.6.4"
Expand Down
136 changes: 133 additions & 3 deletions src/postgres/math_udfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ use datafusion::arrow::array::{Array, ArrayRef, Float64Array, Int64Array};
use datafusion::arrow::datatypes::DataType;
use datafusion::common::DataFusionError;
use datafusion::error::Result;
use rand::distributions::Distribution;
use rand::thread_rng;
use rand_distr::Normal;

/// Inverse cosine, result in degrees.
pub fn acosd(args: &[ArrayRef]) -> Result<ArrayRef> {
Expand Down Expand Up @@ -307,17 +310,124 @@ pub fn erfc(args: &[ArrayRef]) -> Result<ArrayRef> {
}
};

Ok(Arc::new(float64array_builder.finish()) as ArrayRef)
let array = float64array_builder.finish();
Ok(Arc::new(array) as ArrayRef)
}

/// Returns a random value from the normal distribution with the given parameters;
/// mean defaults to 0.0 and stddev defaults to 1.0.
/// Example random_normal(0.0, 1.0) could return 0.051285419
pub fn random_normal(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() > 2_usize {
return Err(DataFusionError::Internal(
"No function matches the given name and argument types.".to_string(),
));
}

let args = &args
.iter()
.filter(|arg| !matches!(arg.data_type(), &DataType::Null))
.cloned()
.collect::<Vec<ArrayRef>>()[..];

let means = args.first();
let std_devs = args.get(1);

let float64array = match (means, std_devs) {
(Some(means), Some(std_devs)) => {
let mut float64array_builder = Float64Array::builder(means.len());
let means = datafusion::common::cast::as_float64_array(means)?;
let std_devs = datafusion::common::cast::as_float64_array(std_devs)?;
means
.iter()
.zip(std_devs.iter())
.try_for_each(|(mean, std_dev)| {
if let (Some(mean), Some(std_dev)) = (mean, std_dev) {
let normal = Normal::new(mean, std_dev).map_err(|_| {
DataFusionError::Internal(
"Runtime error: Failed to create normal distribution".to_string(),
)
})?;
let mut rng = thread_rng();
let value = normal.sample(&mut rng);
float64array_builder.append_value(value);
Ok::<(), DataFusionError>(())
} else {
float64array_builder.append_null();
Ok::<(), DataFusionError>(())
}
})?;
float64array_builder.finish()
}
(Some(means), None) => {
let mut float64array_builder = Float64Array::builder(means.len());
let means = datafusion::common::cast::as_float64_array(means)?;
means.iter().try_for_each(|mean| {
if let Some(mean) = mean {
let normal = Normal::new(mean, 1.0_f64).map_err(|_| {
DataFusionError::Internal(
"Runtime error: Failed to create normal distribution".to_string(),
)
})?;
let mut rng = thread_rng();
let value = normal.sample(&mut rng);
float64array_builder.append_value(value);
Ok::<(), DataFusionError>(())
} else {
float64array_builder.append_null();
Ok::<(), DataFusionError>(())
}
})?;
float64array_builder.finish()
}
(None, Some(std_devs)) => {
let mut float64array_builder = Float64Array::builder(std_devs.len());
let std_devs = datafusion::common::cast::as_float64_array(std_devs)?;

std_devs.iter().try_for_each(|std_dev| {
if let Some(std_dev) = std_dev {
let normal = Normal::new(0.0_f64, std_dev).map_err(|_| {
DataFusionError::Internal(
"Runtime error: Failed to create normal distribution".to_string(),
)
})?;
let mut rng = thread_rng();
let value = normal.sample(&mut rng);
float64array_builder.append_value(value);
Ok::<(), DataFusionError>(())
} else {
float64array_builder.append_null();
Ok::<(), DataFusionError>(())
}
})?;
float64array_builder.finish()
}
(None, None) => {
let mut float64array_builder = Float64Array::builder(1);
let normal = Normal::new(0.0_f64, 1.0_f64).map_err(|_| {
DataFusionError::Internal(
"Runtime error: Failed to create normal distribution".to_string(),
)
})?;
let mut rng = thread_rng();
let value = normal.sample(&mut rng);
float64array_builder.append_value(value);
float64array_builder.finish()
}
};

Ok(Arc::new(float64array) as ArrayRef)
}

#[cfg(feature = "postgres")]
#[cfg(test)]
mod tests {
use crate::common::test_utils::set_up_maths_data_test;
use crate::postgres::register_postgres_udfs;
use datafusion::assert_batches_sorted_eq;
use datafusion::prelude::SessionContext;

use crate::common::test_utils::set_up_maths_data_test;
use crate::postgres::register_postgres_udfs;

use super::*;

#[tokio::test]
Expand Down Expand Up @@ -774,6 +884,26 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_random_normal() -> Result<()> {
    // The output is random, so there is no fixed expectation to compare against:
    // the test passes as long as planning and execution finish without an error.
    let query = r#"
select random_normal(index) as index,
random_normal(uint) as uint,
random_normal(int) as int,
random_normal(float) as float
from maths_table"#;

    let ctx = register_udfs_for_test()?;
    ctx.sql(query).await?.show().await?;

    Ok(())
}

fn register_udfs_for_test() -> Result<SessionContext> {
let ctx = set_up_maths_data_test()?;
register_postgres_udfs(&ctx)?;
Expand Down
24 changes: 21 additions & 3 deletions src/postgres/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,21 @@
use std::sync::Arc;

use crate::postgres::math_udfs::{
acosd, asind, atand, ceiling, cosd, cotd, div, erf, erfc, sind, tand,
acosd, asind, atand, ceiling, cosd, cotd, div, erf, erfc, random_normal, sind, tand,
};
use crate::postgres::network_udfs::{
broadcast, family, host, hostmask, inet_merge, inet_same_family, masklen, netmask, network,
set_masklen,
};
use datafusion::arrow::datatypes::DataType::{Boolean, Float64, Int64, UInt64, UInt8, Utf8};
use datafusion::error::Result;
use datafusion::logical_expr::{ReturnTypeFunction, ScalarUDF, Signature, Volatility};
use datafusion::logical_expr::TypeSignature::Any;
use datafusion::logical_expr::{
ReturnTypeFunction, ScalarUDF, Signature, TypeSignature, Volatility,
};
use datafusion::physical_expr::functions::make_scalar_function;
use datafusion::prelude::SessionContext;
use TypeSignature::Variadic;

mod math_udfs;
mod network_udfs;
Expand All @@ -33,9 +37,10 @@ fn register_math_udfs(ctx: &SessionContext) -> Result<()> {
register_atand(ctx);
register_tand(ctx);
register_ceiling(ctx);
register_div(ctx);
register_erf(ctx);
register_erfc(ctx);
register_div(ctx);
register_random_normal(ctx);
Ok(())
}

Expand Down Expand Up @@ -182,6 +187,19 @@ fn register_div(ctx: &SessionContext) {
ctx.register_udf(div_udf);
}

/// Registers the `random_normal` scalar UDF on `ctx`.
///
/// The signature accepts either no arguments or a variadic list of `Float64` values,
/// and the return type is always `Float64`.
fn register_random_normal(ctx: &SessionContext) {
    let udf = make_scalar_function(random_normal);
    let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Float64)));
    let random_normal_udf = ScalarUDF::new(
        "random_normal",
        // A random generator returns a different value on every evaluation, so it must be
        // declared volatile. Declaring it immutable would let the planner treat a call
        // with literal arguments as a constant expression.
        &Signature::one_of(vec![Any(0), Variadic(vec![Float64])], Volatility::Volatile),
        &return_type,
        &udf,
    );

    ctx.register_udf(random_normal_udf);
}

fn register_network_udfs(ctx: &SessionContext) -> Result<()> {
register_broadcast(ctx);
register_family(ctx);
Expand Down
2 changes: 1 addition & 1 deletion supports/postgres.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ https://www.postgresql.org/docs/16/functions-math.html
|| sign ( numeric ) → numeric | Sign of the argument (-1, 0, or +1) | sign(-8.4) → -1 |
|| trim_scale ( numeric ) → numeric | Reduces the value's scale (number of fractional decimal digits) by removing trailing zeroes | trim_scale(8.4100) → 8.41 |
|| width_bucket ( operand numeric, low numeric, high numeric, count integer ) → integer | Returns the number of the bucket in which operand falls in a histogram having count equal-width buckets spanning the range low to high. Returns 0 or count+1 for an input outside that range. | width_bucket(5.35, 0.024, 10.06, 5) → 3 |
| | random_normal ( [ mean double precision [, stddev double precision ]] ) → double precision | Returns a random value from the normal distribution with the given parameters; mean defaults to 0.0 and stddev defaults to 1.0 | random_normal(0.0, 1.0) → 0.051285419 |
| | random_normal ( [ mean double precision [, stddev double precision ]] ) → double precision | Returns a random value from the normal distribution with the given parameters; mean defaults to 0.0 and stddev defaults to 1.0 | random_normal(0.0, 1.0) → 0.051285419 |
|| acosd ( double precision ) → double precision | Inverse cosine, result in degrees | acosd(0.5) → 60 |
|| asind ( double precision ) → double precision | Inverse sine, result in degrees | asind(0.5) → 30 |
|| atand ( double precision ) → double precision | Inverse tangent, result in degrees | atand(1) → 45 |
Expand Down

0 comments on commit f7ed9bd

Please sign in to comment.