Skip to content

Commit

Permalink
switch implementation of masklen to use ScalarUDFImpl
Browse files Browse the repository at this point in the history
  • Loading branch information
dadepo committed Apr 20, 2024
1 parent ce23cf2 commit d927c56
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 34 deletions.
19 changes: 3 additions & 16 deletions src/postgres/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ use crate::postgres::math_udfs::{
Acosd, Asind, Atand, Ceiling, Cosd, Cotd, Div, Erf, Erfc, RandomNormal, Sind, Tand,
};
use crate::postgres::network_udfs::{
broadcast, family, host, hostmask, inet_merge, inet_same_family, masklen, SetMasklen, Netmask,
Network,
broadcast, family, host, hostmask, inet_merge, inet_same_family, MaskLen, Netmask, Network,
SetMasklen,
};

mod math_udfs;
Expand Down Expand Up @@ -49,7 +49,7 @@ fn register_network_udfs(ctx: &SessionContext) -> Result<()> {
register_hostmask(ctx);
register_inet_same_family(ctx);
register_inet_merge(ctx);
register_masklen(ctx);
ctx.register_udf(ScalarUDF::from(MaskLen::new()));
ctx.register_udf(ScalarUDF::from(Netmask::new()));
ctx.register_udf(ScalarUDF::from(Network::new()));
ctx.register_udf(ScalarUDF::from(SetMasklen::new()));
Expand Down Expand Up @@ -133,16 +133,3 @@ fn register_inet_merge(ctx: &SessionContext) {

ctx.register_udf(inet_merge_udf);
}

fn register_masklen(ctx: &SessionContext) {
let masklen_udf = make_scalar_function(masklen);
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(UInt8)));
let masklen_udf = ScalarUDF::new(
"masklen",
&Signature::uniform(1, vec![Utf8], Volatility::Immutable),
&return_type,
&masklen_udf,
);

ctx.register_udf(masklen_udf);
}
72 changes: 54 additions & 18 deletions src/postgres/network_udfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::sync::Arc;

use datafusion::arrow::array::{Array, ArrayRef, BooleanArray, StringBuilder, UInt8Array};
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::datatypes::DataType::{Int64, Utf8};
use datafusion::arrow::datatypes::DataType::{Int64, UInt8, Utf8};
use datafusion::common::DataFusionError;
use datafusion::error::Result;
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
Expand Down Expand Up @@ -255,25 +255,61 @@ pub fn inet_same_family(args: &[ArrayRef]) -> Result<ArrayRef> {

/// Extracts netmask length.
/// Returns NULL for columns with NULL values.
pub fn masklen(args: &[ArrayRef]) -> Result<ArrayRef> {
let mut int8array = UInt8Array::builder(args[0].len());
let ip_string = datafusion::common::cast::as_string_array(&args[0])?;
ip_string.iter().try_for_each(|ip_string| {
if let Some(ip_string) = ip_string {
let prefix_len = IpNet::from_str(ip_string)
.map_err(|e| {
DataFusionError::Internal(format!("Parsing {ip_string} failed with error {e}"))
})?
.prefix_len();
int8array.append_value(prefix_len);
Ok::<(), DataFusionError>(())
} else {
int8array.append_null();
Ok::<(), DataFusionError>(())
#[derive(Debug)]
pub struct MaskLen {
signature: Signature,
}

impl MaskLen {
pub fn new() -> Self {
Self {
signature: Signature::uniform(1, vec![Utf8], Volatility::Immutable),
}
})?;
}
}

Ok(Arc::new(int8array.finish()) as ArrayRef)
impl ScalarUDFImpl for MaskLen {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"masklen"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(UInt8)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;
let mut int8array = UInt8Array::builder(args[0].len());
let ip_string = datafusion::common::cast::as_string_array(&args[0])?;
ip_string.iter().try_for_each(|ip_string| {
if let Some(ip_string) = ip_string {
let prefix_len = IpNet::from_str(ip_string)
.map_err(|e| {
DataFusionError::Internal(format!(
"Parsing {ip_string} failed with error {e}"
))
})?
.prefix_len();
int8array.append_value(prefix_len);
Ok::<(), DataFusionError>(())
} else {
int8array.append_null();
Ok::<(), DataFusionError>(())
}
})?;

Ok(ColumnarValue::Array(
Arc::new(int8array.finish()) as ArrayRef
))
}
}

/// Constructs netmask for network.
Expand Down

0 comments on commit d927c56

Please sign in to comment.