diff --git a/src/sqlite/json_udfs.rs b/src/sqlite/json_udfs.rs index c25667a..c8150c4 100644 --- a/src/sqlite/json_udfs.rs +++ b/src/sqlite/json_udfs.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use crate::common::{get_json_string_type, get_json_type, get_value_at_string}; use datafusion::arrow::array::{Array, ArrayRef, StringBuilder, UInt8Array}; use datafusion::arrow::datatypes::DataType; -use datafusion::arrow::datatypes::DataType::Utf8; +use datafusion::arrow::datatypes::DataType::{UInt8, Utf8}; use datafusion::common::DataFusionError; use datafusion::error::Result; use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; @@ -171,23 +171,58 @@ impl ScalarUDFImpl for JsonType { /// /// Examples: /// -/// json_valid('{"x":35}') → 1 -/// json_valid('{"x":35') → 0 +/// json_valid('{"x": 35}') → 1 +/// json_valid('{"x": 35') → 0 /// json_valid(NULL) → NULL -pub fn json_valid(args: &[ArrayRef]) -> Result { - let json_strings = datafusion::common::cast::as_string_array(&args[0])?; - let mut uint_builder = UInt8Array::builder(json_strings.len()); - - json_strings.iter().for_each(|json_string| { - if let Some(json_string) = json_string { - let json_value: serde_json::error::Result = serde_json::from_str(json_string); - uint_builder.append_value(json_value.is_ok() as u8); - } else { - uint_builder.append_null(); +#[derive(Debug)] +pub struct JsonValid { + signature: Signature, +} + +impl JsonValid { + pub fn new() -> Self { + Self { + signature: Signature::uniform(1, vec![Utf8], Volatility::Immutable), } - }); + } +} + +impl ScalarUDFImpl for JsonValid { + fn as_any(&self) -> &dyn std::any::Any { + self + } - Ok(Arc::new(uint_builder.finish()) as ArrayRef) + fn name(&self) -> &str { + "json_valid" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(UInt8) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + let json_strings = datafusion::common::cast::as_string_array(&args[0])?; + let mut uint_builder = UInt8Array::builder(json_strings.len()); + + json_strings.iter().for_each(|json_string| { + if let Some(json_string) = json_string { + let json_value: serde_json::error::Result = + serde_json::from_str(json_string); + uint_builder.append_value(json_value.is_ok() as u8); + } else { + uint_builder.append_null(); + } + }); + + Ok(ColumnarValue::Array( + Arc::new(uint_builder.finish()) as ArrayRef + )) + } } #[cfg(feature = "sqlite")] diff --git a/src/sqlite/mod.rs b/src/sqlite/mod.rs index 7de36e4..6af74c7 100644 --- a/src/sqlite/mod.rs +++ b/src/sqlite/mod.rs @@ -1,32 +1,17 @@ #![cfg(feature = "sqlite")] #![allow(deprecated)] -mod json_udfs; - -use crate::sqlite::json_udfs::{json_valid, Json, JsonType}; -use datafusion::arrow::datatypes::DataType::{UInt8, Utf8}; use datafusion::error::Result; -use datafusion::logical_expr::{ReturnTypeFunction, ScalarUDF, Signature, Volatility}; -use datafusion::physical_expr::functions::make_scalar_function; +use datafusion::logical_expr::ScalarUDF; use datafusion::prelude::SessionContext; -use std::sync::Arc; + +use crate::sqlite::json_udfs::{Json, JsonType, JsonValid}; + +mod json_udfs; pub fn register_sqlite_udfs(ctx: &SessionContext) -> Result<()> { ctx.register_udf(ScalarUDF::from(Json::new())); ctx.register_udf(ScalarUDF::from(JsonType::new())); - register_json_valid(ctx); + ctx.register_udf(ScalarUDF::from(JsonValid::new())); Ok(()) } - -fn register_json_valid(ctx: &SessionContext) { - let udf = make_scalar_function(json_valid); - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(UInt8))); - let json_valid_udf = ScalarUDF::new( - "json_valid", - &Signature::uniform(1, vec![Utf8], Volatility::Immutable), - &return_type, - &udf, - ); - - ctx.register_udf(json_valid_udf); -}