Skip to content

Commit

Permalink
switch implementation of json_type to use ScalarUDFImpl
Browse files Browse the repository at this point in the history
  • Loading branch information
dadepo committed Apr 20, 2024
1 parent 7a16a53 commit 860b603
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 60 deletions.
125 changes: 80 additions & 45 deletions src/sqlite/json_udfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,61 +73,96 @@ impl ScalarUDFImpl for Json {
/// function returns the "type" of the element in X that is selected by path P.
/// The "type" returned by json_type() is one of the following SQL
/// text values: 'null', 'true', 'false', 'integer', 'real', 'text', 'array', or 'object'.
/// If the path P in json_type(X,P) selects an element that does not exist in X,
/// If the path P in json_type(X, P) selects an element that does not exist in X,
/// then this function returns NULL.
pub fn json_type(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.is_empty() || args.len() > 2 {
return Err(DataFusionError::Internal(
"wrong number of arguments to function json_type()".to_string(),
));
#[derive(Debug)]
pub struct JsonType {
/// Argument signature handed back by `ScalarUDFImpl::signature`.
signature: Signature,
}

impl JsonType {
    /// Creates the `json_type` UDF.
    ///
    /// The signature is variadic over `Utf8` arguments and the function is declared
    /// `Volatility::Immutable`.
    pub fn new() -> Self {
        Self {
            signature: Signature::variadic(vec![Utf8], Volatility::Immutable),
        }
    }
}

impl Default for JsonType {
    /// Equivalent to [`JsonType::new`]; provided because `new` takes no arguments.
    fn default() -> Self {
        Self::new()
    }
}

let mut string_builder = StringBuilder::with_capacity(args.len(), u8::MAX as usize);
if args.len() == 1 {
//1. Just json and no path
let json_strings = datafusion::common::cast::as_string_array(&args[0])?;
json_strings.iter().try_for_each(|json_string| {
if let Some(json_string) = json_string {
string_builder.append_value(
get_json_string_type(json_string)
.map_err(|err| DataFusionError::Internal(err.to_string()))?,
);
Ok::<(), DataFusionError>(())
} else {
string_builder.append_null();
Ok::<(), DataFusionError>(())
}
})?;
} else {
//2. Json and path
let json_strings = datafusion::common::cast::as_string_array(&args[0])?;
let paths = datafusion::common::cast::as_string_array(&args[1])?;

json_strings
.iter()
.zip(paths.iter())
.try_for_each(|(json_string, path)| {
if let (Some(json_string), Some(path)) = (json_string, path) {
match get_value_at_string(json_string, path) {
Ok(json_at_path) => {
string_builder.append_value(
get_json_type(&json_at_path)
.map_err(|err| DataFusionError::Internal(err.to_string()))?,
);
}
Err(_) => {
string_builder.append_null();
}
}
// `ScalarUDFImpl` implementation that exposes `JsonType` as the scalar function `json_type`.
impl ScalarUDFImpl for JsonType {
/// Returns `self` as `&dyn Any`.
fn as_any(&self) -> &dyn std::any::Any {
self
}

/// Returns the function name, `json_type`.
fn name(&self) -> &str {
"json_type"
}

/// Returns the signature stored at construction time.
fn signature(&self) -> &Signature {
&self.signature
}

/// The result type is always `Utf8`; the argument types are not inspected.
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(Utf8)
}

/// Evaluates `json_type(X)` or `json_type(X, P)` over the given arguments.
///
/// With one argument, each non-null JSON string is mapped to the value returned by
/// `get_json_string_type`; null inputs produce null outputs.
/// With two arguments, each (JSON, path) pair is resolved with `get_value_at_string` and the
/// resolved value is mapped through `get_json_type`; a null JSON, a null path, or a failed
/// path lookup produces a null output.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` when zero or more than two arguments are supplied, or
/// when `get_json_string_type` / `get_json_type` fails. Errors from converting the arguments
/// to arrays or from casting them to string arrays are propagated unchanged.
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
// Work on arrays throughout, whatever form the arguments arrived in.
let args = ColumnarValue::values_to_arrays(args)?;
if args.is_empty() || args.len() > 2 {
return Err(DataFusionError::Internal(
"wrong number of arguments to function json_type()".to_string(),
));
}

// NOTE(review): the first capacity argument is the number of argument arrays (1 or 2),
// not the number of rows — confirm whether the row count was intended.
let mut string_builder = StringBuilder::with_capacity(args.len(), u8::MAX as usize);
if args.len() == 1 {
//1. Just json and no path
let json_strings = datafusion::common::cast::as_string_array(&args[0])?;
json_strings.iter().try_for_each(|json_string| {
if let Some(json_string) = json_string {
string_builder.append_value(
get_json_string_type(json_string)
.map_err(|err| DataFusionError::Internal(err.to_string()))?,
);
Ok::<(), DataFusionError>(())
} else {
// Null input row: emit a null type.
string_builder.append_null();
Ok::<(), DataFusionError>(())
}
})?;
// NOTE(review): the closing brace below is unbalanced in this view; it looks like a
// context line left over from the diff rendering — confirm against the actual file.
}
} else {
//2. Json and path
let json_strings = datafusion::common::cast::as_string_array(&args[0])?;
let paths = datafusion::common::cast::as_string_array(&args[1])?;

json_strings
.iter()
.zip(paths.iter())
.try_for_each(|(json_string, path)| {
if let (Some(json_string), Some(path)) = (json_string, path) {
match get_value_at_string(json_string, path) {
Ok(json_at_path) => {
string_builder.append_value(
get_json_type(&json_at_path).map_err(|err| {
DataFusionError::Internal(err.to_string())
})?,
);
}
// A failed path lookup is not an error for the query: the row becomes null.
Err(_) => {
string_builder.append_null();
}
}
Ok::<(), DataFusionError>(())
} else {
// Either the JSON or the path is null for this row.
string_builder.append_null();
Ok::<(), DataFusionError>(())
}
})?;
}

// NOTE(review): the next line returns a bare `ArrayRef` and looks like the pre-change return
// expression kept by the diff view; the `ColumnarValue::Array` expression after it matches
// the declared return type — confirm only the latter exists in the actual file.
Ok(Arc::new(string_builder.finish()) as ArrayRef)
Ok(ColumnarValue::Array(
Arc::new(string_builder.finish()) as ArrayRef
))
}
}

/// The json_valid(X) function returns 1 if the argument X is well-formed canonical RFC-7159 JSON
Expand Down
17 changes: 2 additions & 15 deletions src/sqlite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

mod json_udfs;

use crate::sqlite::json_udfs::{json_type, json_valid, Json};
use crate::sqlite::json_udfs::{json_valid, Json, JsonType};
use datafusion::arrow::datatypes::DataType::{UInt8, Utf8};
use datafusion::error::Result;
use datafusion::logical_expr::{ReturnTypeFunction, ScalarUDF, Signature, Volatility};
Expand All @@ -13,24 +13,11 @@ use std::sync::Arc;

/// Registers the SQLite-style JSON scalar UDFs on the given session context.
///
/// No step in the body is fallible, so this always returns `Ok(())`.
pub fn register_sqlite_udfs(ctx: &SessionContext) -> Result<()> {
ctx.register_udf(ScalarUDF::from(Json::new()));
// NOTE(review): the next two lines both register a function named `json_type` (the
// closure-based helper and the `JsonType` struct). This looks like the removed and added
// lines of a diff shown together — confirm that only one of them is present in the file.
register_json_type(ctx);
ctx.register_udf(ScalarUDF::from(JsonType::new()));
register_json_valid(ctx);
Ok(())
}

/// Wraps the `json_type` array function in a `ScalarUDF` named `json_type` and registers it
/// on `ctx`.
///
/// The UDF takes a variadic list of `Utf8` arguments, reports `Utf8` as its return type for
/// any input, and is declared `Volatility::Immutable`.
fn register_json_type(ctx: &SessionContext) {
    let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Utf8)));
    let implementation = make_scalar_function(json_type);
    let signature = Signature::variadic(vec![Utf8], Volatility::Immutable);

    ctx.register_udf(ScalarUDF::new(
        "json_type",
        &signature,
        &return_type,
        &implementation,
    ));
}

fn register_json_valid(ctx: &SessionContext) {
let udf = make_scalar_function(json_valid);
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(UInt8)));
Expand Down

0 comments on commit 860b603

Please sign in to comment.