From 860b603673678ead3ef82c7977d8808122aaa8fc Mon Sep 17 00:00:00 2001 From: Dadepo Aderemi Date: Sat, 20 Apr 2024 18:17:59 +0400 Subject: [PATCH] switch implementation of json_type to use ScalarUDFImpl --- src/sqlite/json_udfs.rs | 125 +++++++++++++++++++++++++--------------- src/sqlite/mod.rs | 17 +----- 2 files changed, 82 insertions(+), 60 deletions(-) diff --git a/src/sqlite/json_udfs.rs b/src/sqlite/json_udfs.rs index a15ac4b..c25667a 100644 --- a/src/sqlite/json_udfs.rs +++ b/src/sqlite/json_udfs.rs @@ -73,61 +73,96 @@ impl ScalarUDFImpl for Json { /// function returns the "type" of the element in X that is selected by path P. /// The "type" returned by json_type() is one of the following SQL /// text values: 'null', 'true', 'false', 'integer', 'real', 'text', 'array', or 'object'. -/// If the path P in json_type(X,P) selects an element that does not exist in X, +/// If the path P in json_type(X, P) selects an element that does not exist in X, /// then this function returns NULL. -pub fn json_type(args: &[ArrayRef]) -> Result { - if args.is_empty() || args.len() > 2 { - return Err(DataFusionError::Internal( - "wrong number of arguments to function json_type()".to_string(), - )); +#[derive(Debug)] +pub struct JsonType { + signature: Signature, +} + +impl JsonType { + pub fn new() -> Self { + Self { + signature: Signature::variadic(vec![Utf8], Volatility::Immutable), + } } +} - let mut string_builder = StringBuilder::with_capacity(args.len(), u8::MAX as usize); - if args.len() == 1 { - //1. Just json and no path - let json_strings = datafusion::common::cast::as_string_array(&args[0])?; - json_strings.iter().try_for_each(|json_string| { - if let Some(json_string) = json_string { - string_builder.append_value( - get_json_string_type(json_string) - .map_err(|err| DataFusionError::Internal(err.to_string()))?, - ); - Ok::<(), DataFusionError>(()) - } else { - string_builder.append_null(); - Ok::<(), DataFusionError>(()) - } - })?; - } else { - //2. Json and path - let json_strings = datafusion::common::cast::as_string_array(&args[0])?; - let paths = datafusion::common::cast::as_string_array(&args[1])?; - - json_strings - .iter() - .zip(paths.iter()) - .try_for_each(|(json_string, path)| { - if let (Some(json_string), Some(path)) = (json_string, path) { - match get_value_at_string(json_string, path) { - Ok(json_at_path) => { - string_builder.append_value( - get_json_type(&json_at_path) - .map_err(|err| DataFusionError::Internal(err.to_string()))?, - ); - } - Err(_) => { - string_builder.append_null(); - } - } +impl ScalarUDFImpl for JsonType { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "json_type" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Utf8) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + if args.is_empty() || args.len() > 2 { + return Err(DataFusionError::Internal( + "wrong number of arguments to function json_type()".to_string(), + )); + } + + let mut string_builder = StringBuilder::with_capacity(args.len(), u8::MAX as usize); + if args.len() == 1 { + //1. Just json and no path + let json_strings = datafusion::common::cast::as_string_array(&args[0])?; + json_strings.iter().try_for_each(|json_string| { + if let Some(json_string) = json_string { + string_builder.append_value( + get_json_string_type(json_string) + .map_err(|err| DataFusionError::Internal(err.to_string()))?, + ); Ok::<(), DataFusionError>(()) } else { string_builder.append_null(); Ok::<(), DataFusionError>(()) } })?; - } + } else { + //2. Json and path + let json_strings = datafusion::common::cast::as_string_array(&args[0])?; + let paths = datafusion::common::cast::as_string_array(&args[1])?; + + json_strings + .iter() + .zip(paths.iter()) + .try_for_each(|(json_string, path)| { + if let (Some(json_string), Some(path)) = (json_string, path) { + match get_value_at_string(json_string, path) { + Ok(json_at_path) => { + string_builder.append_value( + get_json_type(&json_at_path).map_err(|err| { + DataFusionError::Internal(err.to_string()) + })?, + ); + } + Err(_) => { + string_builder.append_null(); + } + } + Ok::<(), DataFusionError>(()) + } else { + string_builder.append_null(); + Ok::<(), DataFusionError>(()) + } + })?; + } - Ok(Arc::new(string_builder.finish()) as ArrayRef) + Ok(ColumnarValue::Array( + Arc::new(string_builder.finish()) as ArrayRef + )) + } } /// The json_valid(X) function return 1 if the argument X is well-formed canonical RFC-7159 JSON diff --git a/src/sqlite/mod.rs b/src/sqlite/mod.rs index 6c841aa..7de36e4 100644 --- a/src/sqlite/mod.rs +++ b/src/sqlite/mod.rs @@ -3,7 +3,7 @@ mod json_udfs; -use crate::sqlite::json_udfs::{json_type, json_valid, Json}; +use crate::sqlite::json_udfs::{json_valid, Json, JsonType}; use datafusion::arrow::datatypes::DataType::{UInt8, Utf8}; use datafusion::error::Result; use datafusion::logical_expr::{ReturnTypeFunction, ScalarUDF, Signature, Volatility}; @@ -13,24 +13,11 @@ use std::sync::Arc; pub fn register_sqlite_udfs(ctx: &SessionContext) -> Result<()> { ctx.register_udf(ScalarUDF::from(Json::new())); - register_json_type(ctx); + ctx.register_udf(ScalarUDF::from(JsonType::new())); register_json_valid(ctx); Ok(()) } -fn register_json_type(ctx: &SessionContext) { - let udf = make_scalar_function(json_type); - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Utf8))); - let div_udf = ScalarUDF::new( - "json_type", - &Signature::variadic(vec![Utf8], Volatility::Immutable), - &return_type, - &udf, - ); - - ctx.register_udf(div_udf); -} - fn register_json_valid(ctx: &SessionContext) { let udf = make_scalar_function(json_valid); let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(UInt8)));