From 0b4a059f9406a00b760f4762b10bcdc7f8954183 Mon Sep 17 00:00:00 2001 From: Dadepo Aderemi Date: Sat, 20 Apr 2024 17:51:23 +0400 Subject: [PATCH] switch implementation of hostmask to use ScalarUDFImpl --- src/postgres/mod.rs | 17 ++------- src/postgres/network_udfs.rs | 70 +++++++++++++++++++++++++++--------- 2 files changed, 55 insertions(+), 32 deletions(-) diff --git a/src/postgres/mod.rs b/src/postgres/mod.rs index cce6455..231709c 100644 --- a/src/postgres/mod.rs +++ b/src/postgres/mod.rs @@ -13,7 +13,7 @@ use crate::postgres::math_udfs::{ Acosd, Asind, Atand, Ceiling, Cosd, Cotd, Div, Erf, Erfc, RandomNormal, Sind, Tand, }; use crate::postgres::network_udfs::{ - broadcast, family, host, hostmask, InetMerge, InetSameFamily, MaskLen, Netmask, Network, + broadcast, family, host, HostMask, InetMerge, InetSameFamily, MaskLen, Netmask, Network, SetMaskLen, }; @@ -46,7 +46,7 @@ fn register_network_udfs(ctx: &SessionContext) -> Result<()> { register_broadcast(ctx); register_family(ctx); register_host(ctx); - register_hostmask(ctx); + ctx.register_udf(ScalarUDF::from(HostMask::new())); ctx.register_udf(ScalarUDF::from(InetSameFamily::new())); ctx.register_udf(ScalarUDF::from(InetMerge::new())); ctx.register_udf(ScalarUDF::from(MaskLen::new())); @@ -94,16 +94,3 @@ fn register_host(ctx: &SessionContext) { ctx.register_udf(host_udf); } - -fn register_hostmask(ctx: &SessionContext) { - let hostmask_udf = make_scalar_function(hostmask); - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Utf8))); - let hostmask_udf = ScalarUDF::new( - "hostmask", - &Signature::uniform(1, vec![Utf8], Volatility::Immutable), - &return_type, - &hostmask_udf, - ); - - ctx.register_udf(hostmask_udf); -} diff --git a/src/postgres/network_udfs.rs b/src/postgres/network_udfs.rs index 0697274..1e8b2f9 100644 --- a/src/postgres/network_udfs.rs +++ b/src/postgres/network_udfs.rs @@ -85,25 +85,61 @@ pub fn family(args: &[ArrayRef]) -> Result { /// Constructs host mask for network. /// Returns NULL for columns with NULL values. -pub fn hostmask(args: &[ArrayRef]) -> Result { - let mut string_builder = StringBuilder::with_capacity(args[0].len(), u8::MAX as usize); - let ip_string = datafusion::common::cast::as_string_array(&args[0])?; - ip_string.iter().try_for_each(|ip_string| { - if let Some(ip_string) = ip_string { - let hostmask = IpNet::from_str(ip_string) - .map_err(|e| { - DataFusionError::Internal(format!("Parsing {ip_string} failed with error {e}")) - })? - .hostmask(); - string_builder.append_value(hostmask.to_string()); - Ok::<(), DataFusionError>(()) - } else { - string_builder.append_null(); - Ok::<(), DataFusionError>(()) +#[derive(Debug)] +pub struct HostMask { + signature: Signature, +} + +impl HostMask { + pub fn new() -> Self { + Self { + signature: Signature::uniform(1, vec![Utf8], Volatility::Immutable), } - })?; + } +} - Ok(Arc::new(string_builder.finish()) as ArrayRef) +impl ScalarUDFImpl for HostMask { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "hostmask" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Utf8) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + let mut string_builder = StringBuilder::with_capacity(args[0].len(), u8::MAX as usize); + let ip_string = datafusion::common::cast::as_string_array(&args[0])?; + ip_string.iter().try_for_each(|ip_string| { + if let Some(ip_string) = ip_string { + let hostmask = IpNet::from_str(ip_string) + .map_err(|e| { + DataFusionError::Internal(format!( + "Parsing {ip_string} failed with error {e}" + )) + })? + .hostmask(); + string_builder.append_value(hostmask.to_string()); + Ok::<(), DataFusionError>(()) + } else { + string_builder.append_null(); + Ok::<(), DataFusionError>(()) + } + })?; + + Ok(ColumnarValue::Array( + Arc::new(string_builder.finish()) as ArrayRef + )) + } } /// Checks if IP address is from the same family.