From 6b18d28e927a1ea06295dd09ee50ef0ad0d1d68e Mon Sep 17 00:00:00 2001 From: Dadepo Aderemi Date: Sat, 20 Apr 2024 17:22:21 +0400 Subject: [PATCH] switch implementation of netmask to use ScalarUDFImpl --- src/postgres/mod.rs | 17 ++------- src/postgres/network_udfs.rs | 70 +++++++++++++++++++++++++++--------- 2 files changed, 55 insertions(+), 32 deletions(-) diff --git a/src/postgres/mod.rs b/src/postgres/mod.rs index 3adce3d..a6b1067 100644 --- a/src/postgres/mod.rs +++ b/src/postgres/mod.rs @@ -13,7 +13,7 @@ use crate::postgres::math_udfs::{ Acosd, Asind, Atand, Ceiling, Cosd, Cotd, Div, Erf, Erfc, RandomNormal, Sind, Tand, }; use crate::postgres::network_udfs::{ - broadcast, family, host, hostmask, inet_merge, inet_same_family, masklen, netmask, Masklen, + broadcast, family, host, hostmask, inet_merge, inet_same_family, masklen, Masklen, Netmask, Network, }; @@ -50,7 +50,7 @@ fn register_network_udfs(ctx: &SessionContext) -> Result<()> { register_inet_same_family(ctx); register_inet_merge(ctx); register_masklen(ctx); - register_netmask(ctx); + ctx.register_udf(ScalarUDF::from(Netmask::new())); ctx.register_udf(ScalarUDF::from(Network::new())); ctx.register_udf(ScalarUDF::from(Masklen::new())); Ok(()) @@ -146,16 +146,3 @@ fn register_masklen(ctx: &SessionContext) { ctx.register_udf(masklen_udf); } - -fn register_netmask(ctx: &SessionContext) { - let netmask_udf = make_scalar_function(netmask); - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Utf8))); - let netmask_udf = ScalarUDF::new( - "netmask", - &Signature::uniform(1, vec![Utf8], Volatility::Immutable), - &return_type, - &netmask_udf, - ); - - ctx.register_udf(netmask_udf); -} diff --git a/src/postgres/network_udfs.rs b/src/postgres/network_udfs.rs index 4b7ade6..7a5f5db 100644 --- a/src/postgres/network_udfs.rs +++ b/src/postgres/network_udfs.rs @@ -278,25 +278,61 @@ pub fn masklen(args: &[ArrayRef]) -> Result { /// Constructs netmask for network. /// Returns NULL for columns with NULL values. -pub fn netmask(args: &[ArrayRef]) -> Result { - let mut string_builder = StringBuilder::with_capacity(args[0].len(), u8::MAX as usize); - let ip_string = datafusion::common::cast::as_string_array(&args[0])?; - ip_string.iter().try_for_each(|ip_string| { - if let Some(ip_string) = ip_string { - let netmask = IpNet::from_str(ip_string) - .map_err(|e| { - DataFusionError::Internal(format!("Parsing {ip_string} failed with error {e}")) - })? - .netmask(); - string_builder.append_value(netmask.to_string()); - Ok::<(), DataFusionError>(()) - } else { - string_builder.append_null(); - Ok::<(), DataFusionError>(()) +#[derive(Debug)] +pub struct Netmask { + signature: Signature, +} + +impl Netmask { + pub fn new() -> Self { + Self { + signature: Signature::uniform(1, vec![Utf8], Volatility::Immutable), } - })?; + } +} - Ok(Arc::new(string_builder.finish()) as ArrayRef) +impl ScalarUDFImpl for Netmask { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "netmask" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Utf8) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + let mut string_builder = StringBuilder::with_capacity(args[0].len(), u8::MAX as usize); + let ip_string = datafusion::common::cast::as_string_array(&args[0])?; + ip_string.iter().try_for_each(|ip_string| { + if let Some(ip_string) = ip_string { + let netmask = IpNet::from_str(ip_string) + .map_err(|e| { + DataFusionError::Internal(format!( + "Parsing {ip_string} failed with error {e}" + )) + })? + .netmask(); + string_builder.append_value(netmask.to_string()); + Ok::<(), DataFusionError>(()) + } else { + string_builder.append_null(); + Ok::<(), DataFusionError>(()) + } + })?; + + Ok(ColumnarValue::Array( + Arc::new(string_builder.finish()) as ArrayRef + )) + } } /// Extracts network part of address.