Added support for json_array_length
dadepo committed Apr 20, 2024
1 parent 660ccca commit 905fde4
Showing 4 changed files with 266 additions and 7 deletions.
4 changes: 4 additions & 0 deletions src/common/mod.rs
@@ -67,5 +67,9 @@ mod test {
        let value = json!({"foo": ["bar", "baz"]});
        let result = get_value_at(value, "$.foo").unwrap();
        assert_eq!(result, json!(["bar", "baz"]));

        let value = json!({"foo": ["bar", "baz"]});
        let result = get_value_at(value, "$").unwrap();
        assert_eq!(result, json!({"foo": ["bar", "baz"]}));
    }
}
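
The new assertions pin down that a bare `$` path yields the whole document while `$.foo` descends one key. A rough standalone sketch of that behaviour, assuming only serde_json (lookup_dollar_path below is a hypothetical stand-in for illustration, not the crate's actual get_value_at):

use serde_json::{json, Value};

// Hypothetical helper: resolves "$" to the whole document and "$.key" chains
// by descending object keys. Illustration only; the real get_value_at may differ.
fn lookup_dollar_path<'a>(value: &'a Value, path: &str) -> Option<&'a Value> {
    let rest = path.strip_prefix('$')?;
    rest.split('.')
        .filter(|segment| !segment.is_empty())
        .try_fold(value, |current, segment| current.get(segment))
}

fn main() {
    let doc = json!({"foo": ["bar", "baz"]});
    assert_eq!(lookup_dollar_path(&doc, "$"), Some(&doc));
    assert_eq!(lookup_dollar_path(&doc, "$.foo"), Some(&json!(["bar", "baz"])));
    assert_eq!(lookup_dollar_path(&doc, "$.missing"), None);
}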
264 changes: 259 additions & 5 deletions src/sqlite/json_udfs.rs
@@ -1,14 +1,16 @@
use std::sync::Arc;

use crate::common::{get_json_string_type, get_json_type, get_value_at_string};
use datafusion::arrow::array::{Array, ArrayRef, StringBuilder, UInt8Array};
use datafusion::arrow::array::{Array, ArrayRef, StringBuilder, UInt64Array, UInt8Array};
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::datatypes::DataType::{UInt8, Utf8};
use datafusion::arrow::datatypes::DataType::{UInt64, UInt8, Utf8};
use datafusion::common::DataFusionError;
use datafusion::error::Result;
use datafusion::logical_expr::TypeSignature::Uniform;
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use serde_json::Value;

use crate::common::{get_json_string_type, get_json_type, get_value_at_string};

/// The json(X) function verifies that its argument X is a valid JSON string and returns a minified
/// version of that JSON string (with all unnecessary whitespace removed).
/// If X is not a well-formed JSON string, then this routine throws an error.
@@ -225,14 +227,115 @@ impl ScalarUDFImpl for JsonValid {
    }
}

/// The json_array_length(X) function returns the number of elements in the JSON array X, or 0 if X is some kind of JSON value other than an array.
/// The json_array_length(X,P) locates the array at path P within X and returns the length of that array, or 0 if path P locates an element in X that is not a JSON array,
/// and NULL if path P does not locate any element of X. Errors are thrown if either X is not well-formed JSON or if P is not a well-formed path.
#[derive(Debug)]
pub struct JsonArrayLength {
    signature: Signature,
}

impl JsonArrayLength {
    pub fn new() -> Self {
        Self {
            signature: Signature::one_of(
                vec![Uniform(1, vec![Utf8]), Uniform(2, vec![Utf8])],
                Volatility::Volatile,
            ),
        }
    }
}

impl ScalarUDFImpl for JsonArrayLength {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "json_array_length"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(UInt64)
    }

    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        let args = ColumnarValue::values_to_arrays(args)?;
        let mut iter = args.iter();
        let json_strings = iter.next().ok_or(DataFusionError::Execution(
            "First input not set".to_string(),
        ))?;
        let json_strings = datafusion::common::cast::as_string_array(&json_strings)?;
        let paths = iter.next();
        let mut uint_builder = UInt64Array::builder(json_strings.len());

        match paths {
            None => {
                json_strings.iter().try_for_each(|json_string| {
                    if let Some(json_string) = json_string {
                        if let Ok(value) = get_value_at_string(json_string, "$") {
                            if let Some(value_array) = value.as_array() {
                                uint_builder.append_value(value_array.len() as u64);
                            } else {
                                uint_builder.append_value(0u64);
                            }
                        } else {
                            uint_builder.append_null();
                        }
                        Ok::<(), DataFusionError>(())
                    } else {
                        uint_builder.append_null();
                        Ok::<(), DataFusionError>(())
                    }
                })?;
            }
            Some(paths) => {
                let paths = datafusion::common::cast::as_string_array(&paths)?;
                json_strings
                    .iter()
                    .zip(paths.iter())
                    .try_for_each(|(json_string, path)| {
                        if let (Some(json_string), Some(path)) = (json_string, path) {
                            match get_value_at_string(json_string, path) {
                                Ok(json_at_path) => {
                                    if let Some(value_array) = json_at_path.as_array() {
                                        uint_builder.append_value(value_array.len() as u64);
                                    } else {
                                        uint_builder.append_value(0u64);
                                    }
                                }
                                Err(_) => {
                                    uint_builder.append_null();
                                }
                            }
                            Ok::<(), DataFusionError>(())
                        } else {
                            uint_builder.append_null();
                            Ok::<(), DataFusionError>(())
                        }
                    })?;
            }
        }

        Ok(ColumnarValue::Array(
            Arc::new(uint_builder.finish()) as ArrayRef
        ))
    }
}

#[cfg(feature = "sqlite")]
#[cfg(test)]
mod tests {
    use crate::common::test_utils::set_up_json_data_test;
    use crate::sqlite::register_sqlite_udfs;
    use datafusion::assert_batches_sorted_eq;
    use datafusion::prelude::SessionContext;

    use crate::common::test_utils::set_up_json_data_test;
    use crate::sqlite::register_sqlite_udfs;

    use super::*;

    #[tokio::test]
@@ -670,6 +773,157 @@ mod tests {
        Ok(())
    }

    #[tokio::test]
    async fn test_json_array_length() -> Result<()> {
        let ctx = register_udfs_for_test()?;

        let df = ctx
            .sql(r#"select json_array_length('[1,2,3,4]') as len"#)
            .await?;

        let batches = df.clone().collect().await?;

        let expected: Vec<&str> = r#"
+-----+
| len |
+-----+
| 4   |
+-----+"#
            .split('\n')
            .filter_map(|input| {
                if input.is_empty() {
                    None
                } else {
                    Some(input.trim())
                }
            })
            .collect();

        assert_batches_sorted_eq!(expected, &batches);

        let df = ctx
            .sql(r#"select json_array_length('[1,2,3,4]', '$') as len"#)
            .await?;

        let batches = df.clone().collect().await?;

        let expected: Vec<&str> = r#"
+-----+
| len |
+-----+
| 4   |
+-----+"#
            .split('\n')
            .filter_map(|input| {
                if input.is_empty() {
                    None
                } else {
                    Some(input.trim())
                }
            })
            .collect();

        assert_batches_sorted_eq!(expected, &batches);

        let df = ctx
            .sql(r#"select json_array_length('[1,2,3,4]', '$[2]') as len"#)
            .await?;

        let batches = df.clone().collect().await?;

        let expected: Vec<&str> = r#"
+-----+
| len |
+-----+
| 0   |
+-----+"#
            .split('\n')
            .filter_map(|input| {
                if input.is_empty() {
                    None
                } else {
                    Some(input.trim())
                }
            })
            .collect();

        assert_batches_sorted_eq!(expected, &batches);

        let df = ctx
            .sql(r#"select json_array_length('{"one":[1,2,3]}') as len"#)
            .await?;

        let batches = df.clone().collect().await?;

        let expected: Vec<&str> = r#"
+-----+
| len |
+-----+
| 0   |
+-----+"#
            .split('\n')
            .filter_map(|input| {
                if input.is_empty() {
                    None
                } else {
                    Some(input.trim())
                }
            })
            .collect();

        assert_batches_sorted_eq!(expected, &batches);

        let df = ctx
            .sql(r#"select json_array_length('{"one":[1,2,3]}', '$.one') as len"#)
            .await?;

        let batches = df.clone().collect().await?;

        let expected: Vec<&str> = r#"
+-----+
| len |
+-----+
| 3   |
+-----+"#
            .split('\n')
            .filter_map(|input| {
                if input.is_empty() {
                    None
                } else {
                    Some(input.trim())
                }
            })
            .collect();

        assert_batches_sorted_eq!(expected, &batches);

        let df = ctx
            .sql(r#"select json_array_length('{"one":[1,2,3]}', '$.two') as len"#)
            .await?;

        let batches = df.clone().collect().await?;

        let expected: Vec<&str> = r#"
+-----+
| len |
+-----+
|     |
+-----+"#
            .split('\n')
            .filter_map(|input| {
                if input.is_empty() {
                    None
                } else {
                    Some(input.trim())
                }
            })
            .collect();

        assert_batches_sorted_eq!(expected, &batches);

        Ok(())
    }

    fn register_udfs_for_test() -> Result<SessionContext> {
        let ctx = set_up_json_data_test()?;
        register_sqlite_udfs(&ctx)?;
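The new invoke method applies one rule per input row: a JSON array yields its element count, any other successfully resolved JSON value yields 0, and a row whose value or path cannot be resolved yields NULL. A minimal standalone sketch of that mapping and of the UInt64Array builder pattern it relies on, assuming only serde_json and the arrow types re-exported by DataFusion (array_len_or_zero is illustrative, not part of the commit):

use datafusion::arrow::array::{Array, UInt64Array};
use serde_json::Value;

// Mirrors the append_value / append_null decisions made in JsonArrayLength::invoke,
// with a parse failure standing in for the cases where get_value_at_string errors.
fn array_len_or_zero(json: &str) -> Option<u64> {
    let value: Value = serde_json::from_str(json).ok()?;
    Some(value.as_array().map(|items| items.len() as u64).unwrap_or(0))
}

fn main() {
    let inputs = [r#"[1,2,3,4]"#, r#"{"one":[1,2,3]}"#, "not json"];
    let mut builder = UInt64Array::builder(inputs.len());
    for input in inputs {
        match array_len_or_zero(input) {
            Some(len) => builder.append_value(len),
            None => builder.append_null(),
        }
    }
    let lengths = builder.finish();
    assert_eq!(lengths.value(0), 4); // a JSON array: its element count
    assert_eq!(lengths.value(1), 0); // valid JSON, but not an array: 0
    assert!(lengths.is_null(2)); // unresolvable input: NULL
}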
3 changes: 2 additions & 1 deletion src/sqlite/mod.rs
@@ -5,13 +5,14 @@ use datafusion::error::Result;
use datafusion::logical_expr::ScalarUDF;
use datafusion::prelude::SessionContext;

use crate::sqlite::json_udfs::{Json, JsonType, JsonValid};
use crate::sqlite::json_udfs::{Json, JsonArrayLength, JsonType, JsonValid};

mod json_udfs;

pub fn register_sqlite_udfs(ctx: &SessionContext) -> Result<()> {
    ctx.register_udf(ScalarUDF::from(Json::new()));
    ctx.register_udf(ScalarUDF::from(JsonType::new()));
    ctx.register_udf(ScalarUDF::from(JsonValid::new()));
    ctx.register_udf(ScalarUDF::from(JsonArrayLength::new()));
    Ok(())
}
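
Once register_sqlite_udfs has run, json_array_length is callable from any SQL statement executed on that context. A hypothetical end-to-end sketch, assuming a tokio runtime as in the tests; my_crate is a placeholder for the real crate path, and only register_sqlite_udfs and the query semantics come from the changes above:

use datafusion::error::Result;
use datafusion::prelude::SessionContext;

// `my_crate` is a placeholder; import register_sqlite_udfs from wherever the
// crate actually exposes it.
use my_crate::sqlite::register_sqlite_udfs;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    register_sqlite_udfs(&ctx)?;

    let df = ctx
        .sql(r#"select json_array_length('{"one":[1,2,3]}', '$.one') as len"#)
        .await?;
    df.show().await?; // expected to print a single row where len = 3
    Ok(())
}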
2 changes: 1 addition & 1 deletion supports/sqlite.md
@@ -7,7 +7,7 @@
|-------------|---------------------|
| ✅︎ | json |
| ✅︎ | json_valid |
| 🚧︎︎ | json_array_length |
| ✅︎ | json_array_length |
| 🚧︎ | json_error_position |
| 🚧︎ | json_extract |
| 🚧︎︎ | json_insert |
