From 905fde4d79bfb63e6d3ede337280cb4232ad8632 Mon Sep 17 00:00:00 2001 From: Dadepo Aderemi Date: Sat, 20 Apr 2024 21:31:49 +0400 Subject: [PATCH] Added support for json_array_length --- src/common/mod.rs | 4 + src/sqlite/json_udfs.rs | 264 +++++++++++++++++++++++++++++++++++++++- src/sqlite/mod.rs | 3 +- supports/sqlite.md | 2 +- 4 files changed, 266 insertions(+), 7 deletions(-) diff --git a/src/common/mod.rs b/src/common/mod.rs index 6590627..8f536a8 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -67,5 +67,9 @@ mod test { let value = json!({"foo": ["bar", "baz"]}); let result = get_value_at(value, "$.foo").unwrap(); assert_eq!(result, json!(["bar", "baz"])); + + let value = json!({"foo": ["bar", "baz"]}); + let result = get_value_at(value, "$").unwrap(); + assert_eq!(result, json!({"foo": ["bar", "baz"]})); } } diff --git a/src/sqlite/json_udfs.rs b/src/sqlite/json_udfs.rs index c8150c4..2d3f066 100644 --- a/src/sqlite/json_udfs.rs +++ b/src/sqlite/json_udfs.rs @@ -1,14 +1,16 @@ use std::sync::Arc; -use crate::common::{get_json_string_type, get_json_type, get_value_at_string}; -use datafusion::arrow::array::{Array, ArrayRef, StringBuilder, UInt8Array}; +use datafusion::arrow::array::{Array, ArrayRef, StringBuilder, UInt64Array, UInt8Array}; use datafusion::arrow::datatypes::DataType; -use datafusion::arrow::datatypes::DataType::{UInt8, Utf8}; +use datafusion::arrow::datatypes::DataType::{UInt64, UInt8, Utf8}; use datafusion::common::DataFusionError; use datafusion::error::Result; +use datafusion::logical_expr::TypeSignature::Uniform; use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use serde_json::Value; +use crate::common::{get_json_string_type, get_json_type, get_value_at_string}; + /// The json(X) function verifies that its argument X is a valid JSON string and returns a minified /// version of that JSON string (with all unnecessary whitespace removed). /// If X is not a well-formed JSON string, then this routine throws an error. @@ -225,14 +227,115 @@ impl ScalarUDFImpl for JsonValid { } } +/// The json_array_length(X) function returns the number of elements in the JSON array X, or 0 if X is some kind of JSON value other than an array. +/// The json_array_length(X,P) locates the array at path P within X and returns the length of that array, or 0 if path P locates an element in X that is not a JSON array, +/// and NULL if path P does not locate any element of X. Errors are thrown if either X is not well-formed JSON or if P is not a well-formed path. +#[derive(Debug)] +pub struct JsonArrayLength { + signature: Signature, +} + +impl JsonArrayLength { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![Uniform(1, vec![Utf8]), Uniform(2, vec![Utf8])], + Volatility::Volatile, + ), + } + } +} + +impl ScalarUDFImpl for JsonArrayLength { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "json_array_length" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(UInt64) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + let mut iter = args.iter(); + let json_strings = iter.next().ok_or(DataFusionError::Execution( + "First input not set".to_string(), + ))?; + let json_strings = datafusion::common::cast::as_string_array(&json_strings)?; + let paths = iter.next(); + let mut uint_builder = UInt64Array::builder(json_strings.len()); + + match paths { + None => { + json_strings.iter().try_for_each(|json_string| { + if let Some(json_string) = json_string { + if let Ok(value) = get_value_at_string(json_string, "$") { + if let Some(value_array) = value.as_array() { + uint_builder.append_value(value_array.len() as u64); + } else { + uint_builder.append_value(0u64); + } + } else { + uint_builder.append_null(); + } + Ok::<(), DataFusionError>(()) + } else { + uint_builder.append_null(); + Ok::<(), DataFusionError>(()) + } + })?; + } + Some(paths) => { + let paths = datafusion::common::cast::as_string_array(&paths)?; + json_strings + .iter() + .zip(paths.iter()) + .try_for_each(|(json_string, path)| { + if let (Some(json_string), Some(path)) = (json_string, path) { + match get_value_at_string(json_string, path) { + Ok(json_at_path) => { + if let Some(value_array) = json_at_path.as_array() { + uint_builder.append_value(value_array.len() as u64); + } else { + uint_builder.append_value(0u64); + } + } + Err(_) => { + uint_builder.append_null(); + } + } + Ok::<(), DataFusionError>(()) + } else { + uint_builder.append_null(); + Ok::<(), DataFusionError>(()) + } + })?; + } + } + + Ok(ColumnarValue::Array( + Arc::new(uint_builder.finish()) as ArrayRef + )) + } +} + #[cfg(feature = "sqlite")] #[cfg(test)] mod tests { - use crate::common::test_utils::set_up_json_data_test; - use crate::sqlite::register_sqlite_udfs; use datafusion::assert_batches_sorted_eq; use datafusion::prelude::SessionContext; + use crate::common::test_utils::set_up_json_data_test; + use crate::sqlite::register_sqlite_udfs; + use super::*; #[tokio::test] @@ -670,6 +773,157 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_json_array_length() -> Result<()> { + let ctx = register_udfs_for_test()?; + + let df = ctx + .sql(r#"select json_array_length('[1,2,3,4]') as len"#) + .await?; + + let batches = df.clone().collect().await?; + + let expected: Vec<&str> = r#" ++-----+ +| len | ++-----+ +| 4 | ++-----+"# + .split('\n') + .filter_map(|input| { + if input.is_empty() { + None + } else { + Some(input.trim()) + } + }) + .collect(); + + assert_batches_sorted_eq!(expected, &batches); + + let df = ctx + .sql(r#"select json_array_length('[1,2,3,4]', '$') as len"#) + .await?; + + let batches = df.clone().collect().await?; + + let expected: Vec<&str> = r#" ++-----+ +| len | ++-----+ +| 4 | ++-----+"# + .split('\n') + .filter_map(|input| { + if input.is_empty() { + None + } else { + Some(input.trim()) + } + }) + .collect(); + + assert_batches_sorted_eq!(expected, &batches); + + let df = ctx + .sql(r#"select json_array_length('[1,2,3,4]', '$[2]') as len"#) + .await?; + + let batches = df.clone().collect().await?; + + let expected: Vec<&str> = r#" ++-----+ +| len | ++-----+ +| 0 | ++-----+"# + .split('\n') + .filter_map(|input| { + if input.is_empty() { + None + } else { + Some(input.trim()) + } + }) + .collect(); + + assert_batches_sorted_eq!(expected, &batches); + + let df = ctx + .sql(r#"select json_array_length('{"one":[1,2,3]}') as len"#) + .await?; + + let batches = df.clone().collect().await?; + + let expected: Vec<&str> = r#" ++-----+ +| len | ++-----+ +| 0 | ++-----+"# + .split('\n') + .filter_map(|input| { + if input.is_empty() { + None + } else { + Some(input.trim()) + } + }) + .collect(); + + assert_batches_sorted_eq!(expected, &batches); + + let df = ctx + .sql(r#"select json_array_length('{"one":[1,2,3]}', '$.one') as len"#) + .await?; + + let batches = df.clone().collect().await?; + + let expected: Vec<&str> = r#" ++-----+ +| len | ++-----+ +| 3 | ++-----+"# + .split('\n') + .filter_map(|input| { + if input.is_empty() { + None + } else { + Some(input.trim()) + } + }) + .collect(); + + assert_batches_sorted_eq!(expected, &batches); + + let df = ctx + .sql(r#"select json_array_length('{"one":[1,2,3]}', '$.two') as len"#) + .await?; + + let batches = df.clone().collect().await?; + + let expected: Vec<&str> = r#" ++-----+ +| len | ++-----+ +| | ++-----+"# + .split('\n') + .filter_map(|input| { + if input.is_empty() { + None + } else { + Some(input.trim()) + } + }) + .collect(); + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + fn register_udfs_for_test() -> Result { let ctx = set_up_json_data_test()?; register_sqlite_udfs(&ctx)?; diff --git a/src/sqlite/mod.rs b/src/sqlite/mod.rs index 6af74c7..3e94542 100644 --- a/src/sqlite/mod.rs +++ b/src/sqlite/mod.rs @@ -5,7 +5,7 @@ use datafusion::error::Result; use datafusion::logical_expr::ScalarUDF; use datafusion::prelude::SessionContext; -use crate::sqlite::json_udfs::{Json, JsonType, JsonValid}; +use crate::sqlite::json_udfs::{Json, JsonArrayLength, JsonType, JsonValid}; mod json_udfs; @@ -13,5 +13,6 @@ pub fn register_sqlite_udfs(ctx: &SessionContext) -> Result<()> { ctx.register_udf(ScalarUDF::from(Json::new())); ctx.register_udf(ScalarUDF::from(JsonType::new())); ctx.register_udf(ScalarUDF::from(JsonValid::new())); + ctx.register_udf(ScalarUDF::from(JsonArrayLength::new())); Ok(()) } diff --git a/supports/sqlite.md b/supports/sqlite.md index 5b2cbfa..dbd75be 100644 --- a/supports/sqlite.md +++ b/supports/sqlite.md @@ -7,7 +7,7 @@ |-------------|---------------------| | ✅︎ | json | | ✅︎ | json_valid | -| 🚧︎︎ | json_array_length | +| ✅︎ | json_array_length | | 🚧︎ | json_error_position | | 🚧︎ | json_extract | | 🚧︎︎ | json_insert |