Skip to content

Commit

Permalink
Updated test for sqlite json_valid
Browse files Browse the repository at this point in the history
  • Loading branch information
dadepo committed Nov 18, 2023
1 parent e415e56 commit e2f4ab5
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 29 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[workspace]
resolver = "2"

members = [
"common",
Expand Down
31 changes: 30 additions & 1 deletion common/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use std::sync::Arc;

pub fn set_up_test_datafusion() -> Result<SessionContext> {
pub fn set_up_network_data_test() -> Result<SessionContext> {
// define a schema.
let schema = Arc::new(Schema::new(vec![
Field::new("index", DataType::UInt8, false),
Expand Down Expand Up @@ -43,3 +43,32 @@ pub fn set_up_test_datafusion() -> Result<SessionContext> {
// declare a table in memory.
Ok(ctx)
}
/// Builds an in-memory DataFusion context exposing a `json_table` with two
/// rows: row 1 holds a small JSON document, row 2 holds NULL, so JSON UDF
/// tests can cover both the parse path and NULL propagation.
///
/// # Errors
/// Propagates any error from batch construction or table registration.
pub fn set_up_json_data_test() -> Result<SessionContext> {
    // define a schema: a non-null numeric index plus a nullable JSON string column.
    let schema = Arc::new(Schema::new(vec![
        Field::new("index", DataType::UInt8, false),
        Field::new("json_data", DataType::Utf8, true),
    ]));

    // define data.
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(UInt8Array::from_iter_values([1, 2])),
            Arc::new(StringArray::from(vec![
                Some(r#" { "this" : "is", "a": [ "test" ] } "#),
                // NULL row exercises how JSON UDFs handle missing input.
                None,
            ])),
        ],
    )?;

    // declare a new context and register the batch as an in-memory table.
    let ctx = SessionContext::new();
    ctx.register_batch("json_table", batch)?;
    Ok(ctx)
}
4 changes: 2 additions & 2 deletions df_extras_postgres/src/math_udfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub fn div(args: &[ArrayRef]) -> Result<ArrayRef> {
#[cfg(feature = "postgres")]
#[cfg(test)]
mod tests {
use common::test_utils::set_up_test_datafusion;
use common::test_utils::set_up_network_data_test;
use datafusion::assert_batches_sorted_eq;
use datafusion::prelude::SessionContext;

Expand Down Expand Up @@ -100,7 +100,7 @@ mod tests {
}

/// Builds a test context seeded with the network fixture table and with all
/// of this crate's UDFs registered on it.
fn register_udfs_for_test() -> Result<SessionContext> {
    let ctx = set_up_network_data_test()?;
    register_udfs(&ctx)?;
    Ok(ctx)
}
Expand Down
4 changes: 2 additions & 2 deletions df_extras_postgres/src/network_udfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ mod tests {
use datafusion::assert_batches_sorted_eq;
use datafusion::prelude::SessionContext;

use common::test_utils::set_up_test_datafusion;
use common::test_utils::set_up_network_data_test;

use crate::register_udfs;

Expand Down Expand Up @@ -922,7 +922,7 @@ mod tests {
}

/// Builds a test context seeded with the network fixture table and with all
/// of this crate's UDFs registered on it.
fn register_udfs_for_test() -> Result<SessionContext> {
    let ctx = set_up_network_data_test()?;
    register_udfs(&ctx)?;
    Ok(ctx)
}
Expand Down
59 changes: 35 additions & 24 deletions df_extras_sqlite/src/json_udfs.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
use std::sync::Arc;

use datafusion::arrow::array::{Array, ArrayRef, StringBuilder, UInt8Array};
use datafusion::common::DataFusionError;
use datafusion::error::Result;
use serde_json::Value;
use std::sync::Arc;

pub fn json(args: &[ArrayRef]) -> Result<ArrayRef> {
let json_strings = datafusion::common::cast::as_string_array(&args[0])?;

let mut string_builder = StringBuilder::with_capacity(json_strings.len(), u8::MAX as usize);
json_strings.iter().flatten().try_for_each(|json_string| {
let value: Value = serde_json::from_str(json_string).map_err(|e| {
DataFusionError::Internal(format!("Parsing {json_string} failed with error {e}"))
})?;
let pretty_json = serde_json::to_string(&value).map_err(|e| {
DataFusionError::Internal(format!("Parsing {json_string} failed with error {e}"))
})?;
string_builder.append_value(pretty_json);
Ok::<(), DataFusionError>(())
json_strings.iter().try_for_each(|json_string| {
if let Some(json_string) = json_string {
let value: Value = serde_json::from_str(json_string).map_err(|e| {
DataFusionError::Internal(format!("Parsing {json_string} failed with error {e}"))
})?;
let pretty_json = serde_json::to_string(&value).map_err(|e| {
DataFusionError::Internal(format!("Parsing {json_string} failed with error {e}"))
})?;
string_builder.append_value(pretty_json);
Ok::<(), DataFusionError>(())
} else {
string_builder.append_null();
Ok::<(), DataFusionError>(())
}
})?;

Ok(Arc::new(string_builder.finish()) as ArrayRef)
Expand All @@ -41,10 +47,11 @@ pub fn json_valid(args: &[ArrayRef]) -> Result<ArrayRef> {
#[cfg(feature = "sqlite")]
#[cfg(test)]
mod tests {
use common::test_utils::set_up_test_datafusion;
use datafusion::assert_batches_sorted_eq;
use datafusion::prelude::SessionContext;

use common::test_utils::set_up_json_data_test;

use crate::register_udfs;

use super::*;
Expand All @@ -53,17 +60,20 @@ mod tests {
async fn test_json() -> Result<()> {
let ctx = register_udfs_for_test()?;
let df = ctx
.sql(r#"select json(' { "this" : "is", "a": [ "test" ] } ') as col_result"#)
.sql(
r#"select index, json(json_data) as col_result FROM json_table ORDER BY index ASC"#,
)
.await?;

let batches = df.clone().collect().await?;

let expected: Vec<&str> = r#"
+----------------------------+
| col_result |
+----------------------------+
| {"this":"is","a":["test"]} |
+----------------------------+"#
+-------+----------------------------+
| index | col_result |
+-------+----------------------------+
| 1 | {"this":"is","a":["test"]} |
| 2 | |
+-------+----------------------------+"#
.split('\n')
.filter_map(|input| {
if input.is_empty() {
Expand All @@ -80,16 +90,17 @@ mod tests {
#[tokio::test]
async fn test_json_valid() -> Result<()> {
let ctx = register_udfs_for_test()?;
let df = ctx.sql(r#"select json_valid(null) as col_result"#).await?;
let df = ctx.sql(r#"select index, json_valid(json_data) as col_result FROM json_table ORDER BY index ASC"#).await?;

let batches = df.clone().collect().await?;

let expected: Vec<&str> = r#"
+------------+
| col_result |
+------------+
| |
+------------+"#
+-------+------------+
| index | col_result |
+-------+------------+
| 1 | 1 |
| 2 | |
+-------+------------+"#
.split('\n')
.filter_map(|input| {
if input.is_empty() {
Expand All @@ -104,7 +115,7 @@ mod tests {
}

/// Builds a test context seeded with the JSON fixture table and with all
/// of this crate's UDFs registered on it.
fn register_udfs_for_test() -> Result<SessionContext> {
    let ctx = set_up_json_data_test()?;
    register_udfs(&ctx)?;
    Ok(ctx)
}
Expand Down

0 comments on commit e2f4ab5

Please sign in to comment.