Skip to content

Commit

Permalink
switch implementation of set_masklen to use ScalarUDFImpl
Browse files Browse the repository at this point in the history
  • Loading branch information
dadepo committed Apr 20, 2024
1 parent 0efa0ab commit 9837e6b
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 70 deletions.
19 changes: 3 additions & 16 deletions src/postgres/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use std::sync::Arc;

use datafusion::arrow::datatypes::DataType::{Boolean, Int64, UInt8, Utf8};
use datafusion::arrow::datatypes::DataType::{Boolean, UInt8, Utf8};
use datafusion::error::Result;
use datafusion::logical_expr::{ReturnTypeFunction, ScalarUDF, Signature, Volatility};
use datafusion::physical_expr::functions::make_scalar_function;
Expand All @@ -14,7 +14,7 @@ use crate::postgres::math_udfs::{
};
use crate::postgres::network_udfs::{
broadcast, family, host, hostmask, inet_merge, inet_same_family, masklen, netmask, network,
set_masklen,
Masklen,
};

mod math_udfs;
Expand Down Expand Up @@ -52,7 +52,7 @@ fn register_network_udfs(ctx: &SessionContext) -> Result<()> {
register_masklen(ctx);
register_netmask(ctx);
register_network(ctx);
register_set_masklen(ctx);
ctx.register_udf(ScalarUDF::from(Masklen::new()));
Ok(())
}

Expand Down Expand Up @@ -172,16 +172,3 @@ fn register_network(ctx: &SessionContext) {

ctx.register_udf(network_udf);
}

fn register_set_masklen(ctx: &SessionContext) {
let set_masklen_udf = make_scalar_function(set_masklen);
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Utf8)));
let set_masklen_udf = ScalarUDF::new(
"set_masklen",
&Signature::exact(vec![Utf8, Int64], Volatility::Immutable),
&return_type,
&set_masklen_udf,
);

ctx.register_udf(set_masklen_udf);
}
145 changes: 91 additions & 54 deletions src/postgres/network_udfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ use std::str::FromStr;
use std::sync::Arc;

use datafusion::arrow::array::{Array, ArrayRef, BooleanArray, StringBuilder, UInt8Array};
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::datatypes::DataType::{Int64, Utf8};
use datafusion::common::DataFusionError;
use datafusion::error::Result;
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use ipnet::{IpNet, Ipv4Net, Ipv6Net};

/// Gives the broadcast address for network.
/// Gives the broadcast address for the network.
/// Returns NULL for columns with NULL values.
pub fn broadcast(args: &[ArrayRef]) -> Result<ArrayRef> {
let mut string_builder = StringBuilder::with_capacity(args[0].len(), u8::MAX as usize);
Expand Down Expand Up @@ -323,75 +326,109 @@ pub fn network(args: &[ArrayRef]) -> Result<ArrayRef> {
/// If input is IP, The address part does not change.
/// If the input is a CIDR, Address bits to the right of the new netmask are set to zero.
/// Returns NULL if any of the columns contain NULL values.
pub fn set_masklen(args: &[ArrayRef]) -> Result<ArrayRef> {
let mut string_builder = StringBuilder::with_capacity(args[0].len(), u8::MAX as usize);
let cidr_strings = datafusion::common::cast::as_string_array(&args[0])?;
let prefix_lengths = datafusion::common::cast::as_int64_array(&args[1])?;
#[derive(Debug)]
pub struct Masklen {
signature: Signature,
}

if cidr_strings.len() != prefix_lengths.len() {
return Err(DataFusionError::Internal(
"Cidr count do not match prefix length count".to_string(),
));
impl Masklen {
pub fn new() -> Self {
Self {
signature: Signature::exact(vec![Utf8, Int64], Volatility::Immutable),
}
}
}

for i in 0..cidr_strings.len() {
let input_string = cidr_strings.value(i);
let prefix: u8 = prefix_lengths.value(i) as u8;
impl ScalarUDFImpl for Masklen {
fn as_any(&self) -> &dyn std::any::Any {
self
}

if input_string.is_empty() || prefix == 0 {
string_builder.append_null();
continue;
fn name(&self) -> &str {
"set_masklen"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(Utf8)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;
let mut string_builder = StringBuilder::with_capacity(args[0].len(), u8::MAX as usize);
let cidr_strings = datafusion::common::cast::as_string_array(&args[0])?;
let prefix_lengths = datafusion::common::cast::as_int64_array(&args[1])?;

if cidr_strings.len() != prefix_lengths.len() {
return Err(DataFusionError::Internal(
"Cidr count do not match prefix length count".to_string(),
));
}

let is_cidr = input_string.contains('/');
for i in 0..cidr_strings.len() {
let input_string = cidr_strings.value(i);
let prefix: u8 = prefix_lengths.value(i) as u8;

let addr = if is_cidr {
IpNet::from_str(input_string)
.map_err(|e| {
if input_string.is_empty() || prefix == 0 {
string_builder.append_null();
continue;
}

let is_cidr = input_string.contains('/');

let addr = if is_cidr {
IpNet::from_str(input_string)
.map_err(|e| {
DataFusionError::Internal(format!(
"Parsing {input_string} into CIDR failed with error {e}"
))
})?
.network()
} else {
IpAddr::from_str(input_string).map_err(|e| {
DataFusionError::Internal(format!(
"Parsing {input_string} into CIDR failed with error {e}"
"Parsing {input_string} into IP address failed with error {e}"
))
})?
.network()
} else {
IpAddr::from_str(input_string).map_err(|e| {
DataFusionError::Internal(format!(
"Parsing {input_string} into IP address failed with error {e}"
))
})?
};

match addr {
IpAddr::V4(_) => {
if prefix > 32 {
return Err(DataFusionError::Internal(format!(
"ERROR: invalid mask length: {prefix}"
)));
};

match addr {
IpAddr::V4(_) => {
if prefix > 32 {
return Err(DataFusionError::Internal(format!(
"ERROR: invalid mask length: {prefix}"
)));
}
}
}
IpAddr::V6(_) => {
if prefix > 128 {
return Err(DataFusionError::Internal(format!(
"ERROR: invalid mask length: {prefix}"
)));
IpAddr::V6(_) => {
if prefix > 128 {
return Err(DataFusionError::Internal(format!(
"ERROR: invalid mask length: {prefix}"
)));
}
}
}
};
};

let mut new_cidr = IpNet::new(addr, prefix).map_err(|e| {
DataFusionError::Internal(format!(
"Creating CIDR from {addr} and prefix {prefix} failed with error {e}"
))
})?;
let mut new_cidr = IpNet::new(addr, prefix).map_err(|e| {
DataFusionError::Internal(format!(
"Creating CIDR from {addr} and prefix {prefix} failed with error {e}"
))
})?;

if is_cidr {
new_cidr = new_cidr.trunc();
};
if is_cidr {
new_cidr = new_cidr.trunc();
};

string_builder.append_value(new_cidr.to_string());
}
string_builder.append_value(new_cidr.to_string());
}

Ok(Arc::new(string_builder.finish()) as ArrayRef)
Ok(ColumnarValue::Array(
Arc::new(string_builder.finish()) as ArrayRef
))
}
}

fn bit_in_common(l: &[u8], r: &[u8], n: usize) -> usize {
Expand Down

0 comments on commit 9837e6b

Please sign in to comment.