Skip to content

Commit

Permalink
switch implementation of json to use ScalarUDFImpl
Browse files Browse the repository at this point in the history
  • Loading branch information
dadepo committed Apr 20, 2024
1 parent ebb1a77 commit 7a16a53
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 33 deletions.
73 changes: 55 additions & 18 deletions src/sqlite/json_udfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,71 @@ use std::sync::Arc;

use crate::common::{get_json_string_type, get_json_type, get_value_at_string};
use datafusion::arrow::array::{Array, ArrayRef, StringBuilder, UInt8Array};
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::datatypes::DataType::Utf8;
use datafusion::common::DataFusionError;
use datafusion::error::Result;
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use serde_json::Value;

/// The json(X) function verifies that its argument X is a valid JSON string and returns a minified
/// version of that JSON string (with all unnecessary whitespace removed).
/// If X is not a well-formed JSON string, then this routine throws an error.
pub fn json(args: &[ArrayRef]) -> Result<ArrayRef> {
let json_strings = datafusion::common::cast::as_string_array(&args[0])?;
#[derive(Debug)]
pub struct Json {
signature: Signature,
}

let mut string_builder = StringBuilder::with_capacity(json_strings.len(), u8::MAX as usize);
json_strings.iter().try_for_each(|json_string| {
if let Some(json_string) = json_string {
let value: Value = serde_json::from_str(json_string).map_err(|_| {
DataFusionError::Internal("Runtime error: malformed JSON".to_string())
})?;
let pretty_json = serde_json::to_string(&value).map_err(|_| {
DataFusionError::Internal("Runtime error: malformed JSON".to_string())
})?;
string_builder.append_value(pretty_json);
Ok::<(), DataFusionError>(())
} else {
string_builder.append_null();
Ok::<(), DataFusionError>(())
impl Json {
pub fn new() -> Self {
Self {
signature: Signature::uniform(1, vec![Utf8], Volatility::Immutable),
}
})?;
}
}

Ok(Arc::new(string_builder.finish()) as ArrayRef)
impl ScalarUDFImpl for Json {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"json"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(Utf8)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;
let json_strings = datafusion::common::cast::as_string_array(&args[0])?;

let mut string_builder = StringBuilder::with_capacity(json_strings.len(), u8::MAX as usize);
json_strings.iter().try_for_each(|json_string| {
if let Some(json_string) = json_string {
let value: Value = serde_json::from_str(json_string).map_err(|_| {
DataFusionError::Internal("Runtime error: malformed JSON".to_string())
})?;
let pretty_json = serde_json::to_string(&value).map_err(|_| {
DataFusionError::Internal("Runtime error: malformed JSON".to_string())
})?;
string_builder.append_value(pretty_json);
Ok::<(), DataFusionError>(())
} else {
string_builder.append_null();
Ok::<(), DataFusionError>(())
}
})?;

Ok(ColumnarValue::Array(
Arc::new(string_builder.finish()) as ArrayRef
))
}
}

/// The json_type(X) function returns the "type" of the outermost element of X. The json_type(X,P)
Expand Down
17 changes: 2 additions & 15 deletions src/sqlite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

mod json_udfs;

use crate::sqlite::json_udfs::{json, json_type, json_valid};
use crate::sqlite::json_udfs::{json_type, json_valid, Json};
use datafusion::arrow::datatypes::DataType::{UInt8, Utf8};
use datafusion::error::Result;
use datafusion::logical_expr::{ReturnTypeFunction, ScalarUDF, Signature, Volatility};
Expand All @@ -12,25 +12,12 @@ use datafusion::prelude::SessionContext;
use std::sync::Arc;

pub fn register_sqlite_udfs(ctx: &SessionContext) -> Result<()> {
register_json(ctx);
ctx.register_udf(ScalarUDF::from(Json::new()));
register_json_type(ctx);
register_json_valid(ctx);
Ok(())
}

fn register_json(ctx: &SessionContext) {
let udf = make_scalar_function(json);
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Utf8)));
let div_udf = ScalarUDF::new(
"json",
&Signature::uniform(1, vec![Utf8], Volatility::Immutable),
&return_type,
&udf,
);

ctx.register_udf(div_udf);
}

fn register_json_type(ctx: &SessionContext) {
let udf = make_scalar_function(json_type);
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(Utf8)));
Expand Down

0 comments on commit 7a16a53

Please sign in to comment.