Skip to content

Commit

Permalink
switch implementation of host to use ScalarUDFImpl
Browse files Browse the repository at this point in the history
  • Loading branch information
dadepo committed Apr 20, 2024
1 parent 0b4a059 commit 6653d96
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 38 deletions.
17 changes: 2 additions & 15 deletions src/postgres/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::postgres::math_udfs::{
Acosd, Asind, Atand, Ceiling, Cosd, Cotd, Div, Erf, Erfc, RandomNormal, Sind, Tand,
};
use crate::postgres::network_udfs::{
broadcast, family, host, HostMask, InetMerge, InetSameFamily, MaskLen, Netmask, Network,
broadcast, family, Host, HostMask, InetMerge, InetSameFamily, MaskLen, Netmask, Network,
SetMaskLen,
};

Expand Down Expand Up @@ -45,7 +45,7 @@ fn register_math_udfs(ctx: &SessionContext) -> Result<()> {
fn register_network_udfs(ctx: &SessionContext) -> Result<()> {
register_broadcast(ctx);
register_family(ctx);
register_host(ctx);
ctx.register_udf(ScalarUDF::from(Host::new()));
ctx.register_udf(ScalarUDF::from(HostMask::new()));
ctx.register_udf(ScalarUDF::from(InetSameFamily::new()));
ctx.register_udf(ScalarUDF::from(InetMerge::new()));
Expand Down Expand Up @@ -81,16 +81,3 @@ fn register_family(ctx: &SessionContext) {

ctx.register_udf(family_udf);
}

fn register_host(ctx: &SessionContext) {
let host_udf = make_scalar_function(host);
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Utf8)));
let host_udf = ScalarUDF::new(
"host",
&Signature::uniform(1, vec![Utf8], Volatility::Immutable),
&return_type,
&host_udf,
);

ctx.register_udf(host_udf);
}
82 changes: 59 additions & 23 deletions src/postgres/network_udfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,29 +33,6 @@ pub fn broadcast(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(string_builder.finish()) as ArrayRef)
}

/// Gives the host address for network.
/// Returns NULL for columns with NULL values.
pub fn host(args: &[ArrayRef]) -> Result<ArrayRef> {
let mut string_builder = StringBuilder::with_capacity(args[0].len(), u8::MAX as usize);
let ip_string = datafusion::common::cast::as_string_array(&args[0])?;
ip_string.iter().try_for_each(|ip_string| {
if let Some(ip_string) = ip_string {
let host_address = IpNet::from_str(ip_string)
.map_err(|e| {
DataFusionError::Internal(format!("Parsing {ip_string} failed with error {e}"))
})?
.network();
string_builder.append_value(host_address.to_string());
Ok::<(), DataFusionError>(())
} else {
string_builder.append_null();
Ok::<(), DataFusionError>(())
}
})?;

Ok(Arc::new(string_builder.finish()) as ArrayRef)
}

/// Returns the address's family: 4 for IPv4, 6 for IPv6.
/// Returns NULL for columns with NULL values.
pub fn family(args: &[ArrayRef]) -> Result<ArrayRef> {
Expand Down Expand Up @@ -83,6 +60,65 @@ pub fn family(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(int8array.finish()) as ArrayRef)
}

/// Gives the host address for the network.
/// Returns NULL for columns with NULL values.
#[derive(Debug)]
pub struct Host {
signature: Signature,
}

impl Host {
pub fn new() -> Self {
Self {
signature: Signature::uniform(1, vec![Utf8], Volatility::Immutable),
}
}
}

impl ScalarUDFImpl for Host {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"host"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(Utf8)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;
let mut string_builder = StringBuilder::with_capacity(args[0].len(), u8::MAX as usize);
let ip_string = datafusion::common::cast::as_string_array(&args[0])?;
ip_string.iter().try_for_each(|ip_string| {
if let Some(ip_string) = ip_string {
let host_address = IpNet::from_str(ip_string)
.map_err(|e| {
DataFusionError::Internal(format!(
"Parsing {ip_string} failed with error {e}"
))
})?
.network();
string_builder.append_value(host_address.to_string());
Ok::<(), DataFusionError>(())
} else {
string_builder.append_null();
Ok::<(), DataFusionError>(())
}
})?;

Ok(ColumnarValue::Array(
Arc::new(string_builder.finish()) as ArrayRef
))
}
}

/// Constructs host mask for network.
/// Returns NULL for columns with NULL values.
#[derive(Debug)]
Expand Down

0 comments on commit 6653d96

Please sign in to comment.