Skip to content

Commit

Permalink
Handle alias when parsing SQL expressions (parse_sql_expr) (apache#12939)
Browse files Browse the repository at this point in the history
* fix: Fix parse_sql_expr not handling alias

* cargo fmt

* fix parse_sql_expr example (remove alias)

* add testing

* add SUM udaf to TestContextProvider and modify test_sql_to_expr_with_alias for function

* revert change on example `parse_sql_expr`
  • Loading branch information
Eason0729 authored Dec 11, 2024
1 parent 6196ff2 commit 93b3d9c
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 18 deletions.
10 changes: 5 additions & 5 deletions datafusion-examples/examples/parse_sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,11 @@ async fn query_parquet_demo() -> Result<()> {

assert_batches_eq!(
&[
"+------------+----------------------+",
"| double_col | sum(?table?.int_col) |",
"+------------+----------------------+",
"| 10.1 | 4 |",
"+------------+----------------------+",
"+------------+-------------+",
"| double_col | sum_int_col |",
"+------------+-------------+",
"| 10.1 | 4 |",
"+------------+-------------+",
],
&result
);
Expand Down
21 changes: 16 additions & 5 deletions datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ use datafusion_sql::planner::{ContextProvider, ParserOptions, PlannerContext, Sq
use itertools::Itertools;
use log::{debug, info};
use object_store::ObjectStore;
use sqlparser::ast::Expr as SQLExpr;
use sqlparser::ast::{Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias};
use sqlparser::dialect::dialect_from_str;
use std::any::Any;
use std::collections::hash_map::Entry;
Expand Down Expand Up @@ -500,11 +500,22 @@ impl SessionState {
sql: &str,
dialect: &str,
) -> datafusion_common::Result<SQLExpr> {
self.sql_to_expr_with_alias(sql, dialect).map(|x| x.expr)
}

/// Parse a SQL string into a sqlparser-rs AST [`SQLExprWithAlias`].
///
/// See [`Self::create_logical_expr`] for parsing sql to [`Expr`].
pub fn sql_to_expr_with_alias(
&self,
sql: &str,
dialect: &str,
) -> datafusion_common::Result<SQLExprWithAlias> {
let dialect = dialect_from_str(dialect).ok_or_else(|| {
plan_datafusion_err!(
"Unsupported SQL dialect: {dialect}. Available dialects: \
Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
MsSQL, ClickHouse, BigQuery, Ansi."
Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
MsSQL, ClickHouse, BigQuery, Ansi."
)
})?;

Expand Down Expand Up @@ -603,15 +614,15 @@ impl SessionState {
) -> datafusion_common::Result<Expr> {
let dialect = self.config.options().sql_parser.dialect.as_str();

let sql_expr = self.sql_to_expr(sql, dialect)?;
let sql_expr = self.sql_to_expr_with_alias(sql, dialect)?;

let provider = SessionContextProvider {
state: self,
tables: HashMap::new(),
};

let query = SqlToRel::new_with_options(&provider, self.get_parser_options());
query.sql_to_expr(sql_expr, df_schema, &mut PlannerContext::new())
query.sql_to_expr_with_alias(sql_expr, df_schema, &mut PlannerContext::new())
}

/// Returns the [`Analyzer`] for this session
Expand Down
60 changes: 56 additions & 4 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ use datafusion_expr::planner::{
use recursive::recursive;
use sqlparser::ast::{
BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, DictionaryField,
Expr as SQLExpr, MapEntry, StructField, Subscript, TrimWhereField, Value,
Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias, MapEntry, StructField, Subscript,
TrimWhereField, Value,
};

use datafusion_common::{
Expand All @@ -50,6 +51,19 @@ mod unary_op;
mod value;

impl<S: ContextProvider> SqlToRel<'_, S> {
/// Plan a SQL expression that may carry an alias into a logical [`Expr`].
///
/// The inner SQL expression is planned with `sql_expr_to_logical_expr`.
/// When the parsed input has an alias, the planned expression is wrapped
/// with that alias (using the identifier's `value`); otherwise it is
/// returned unchanged.
pub(crate) fn sql_expr_to_logical_expr_with_alias(
    &self,
    sql: SQLExprWithAlias,
    schema: &DFSchema,
    planner_context: &mut PlannerContext,
) -> Result<Expr> {
    let planned = self.sql_expr_to_logical_expr(sql.expr, schema, planner_context)?;
    Ok(match sql.alias {
        Some(ident) => planned.alias(ident.value),
        None => planned,
    })
}
pub(crate) fn sql_expr_to_logical_expr(
&self,
sql: SQLExpr,
Expand Down Expand Up @@ -131,6 +145,20 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
)))
}

/// Generate a relational expression from a SQL expression that may carry an
/// alias (`<expr> [AS <alias>]`).
///
/// The input is planned with `sql_expr_to_logical_expr_with_alias`, passed
/// through `rewrite_partial_qualifier`, checked against `schema` with
/// `validate_schema_satisfies_exprs`, and finally has its placeholder types
/// inferred from `schema`.
///
/// # Errors
///
/// Returns an error if planning, schema validation, or placeholder type
/// inference fails.
pub fn sql_to_expr_with_alias(
    &self,
    sql: SQLExprWithAlias,
    schema: &DFSchema,
    planner_context: &mut PlannerContext,
) -> Result<Expr> {
    let expr = self.sql_expr_to_logical_expr_with_alias(sql, schema, planner_context)?;
    let expr = self.rewrite_partial_qualifier(expr, schema);
    // Validation only needs to read the expression: borrow it as a
    // one-element slice rather than cloning the whole expression tree.
    self.validate_schema_satisfies_exprs(schema, std::slice::from_ref(&expr))?;
    let (expr, _) = expr.infer_placeholder_types(schema)?;
    Ok(expr)
}

/// Generate a relational expression from a SQL expression
pub fn sql_to_expr(
&self,
Expand Down Expand Up @@ -1091,8 +1119,11 @@ mod tests {
None
}

fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
None
/// Look up an aggregate function by name for the test provider.
///
/// Only `"sum"` is registered; every other name yields `None`.
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
    (name == "sum").then(datafusion_functions_aggregate::sum::sum_udaf)
}

fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
Expand All @@ -1112,7 +1143,7 @@ mod tests {
}

fn udaf_names(&self) -> Vec<String> {
Vec::new()
vec!["sum".to_string()]
}

fn udwf_names(&self) -> Vec<String> {
Expand Down Expand Up @@ -1167,4 +1198,25 @@ mod tests {
test_stack_overflow!(2048);
test_stack_overflow!(4096);
test_stack_overflow!(8192);
#[test]
fn test_sql_to_expr_with_alias() {
    // Parse an aliased aggregate call with sqlparser's own parser so the
    // planner receives an `ExprWithAlias` AST node.
    let dialect = GenericDialect {};
    let mut parser = Parser::new(&dialect)
        .try_with_sql("SUM(int_col) as sum_int_col")
        .unwrap();
    let aliased_ast = parser.parse_expr_with_alias().unwrap();

    // Plan the parsed node against an empty schema.
    let context_provider = TestContextProvider::new();
    let sql_to_rel = SqlToRel::new(&context_provider);
    let schema = DFSchema::empty();
    let mut planner_context = PlannerContext::default();
    let planned = sql_to_rel
        .sql_expr_to_logical_expr_with_alias(aliased_ast, &schema, &mut planner_context)
        .unwrap();

    // The alias from the SQL text must be kept as the outermost node.
    assert!(matches!(planned, Expr::Alias(_)));
}
}
9 changes: 5 additions & 4 deletions datafusion/sql/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
use std::collections::VecDeque;
use std::fmt;

use sqlparser::ast::ExprWithAlias;
use sqlparser::{
ast::{
ColumnDef, ColumnOptionDef, Expr, ObjectName, OrderByExpr, Query,
ColumnDef, ColumnOptionDef, ObjectName, OrderByExpr, Query,
Statement as SQLStatement, TableConstraint, Value,
},
dialect::{keywords::Keyword, Dialect, GenericDialect},
Expand Down Expand Up @@ -328,7 +329,7 @@ impl<'a> DFParser<'a> {
pub fn parse_sql_into_expr_with_dialect(
sql: &str,
dialect: &dyn Dialect,
) -> Result<Expr, ParserError> {
) -> Result<ExprWithAlias, ParserError> {
let mut parser = DFParser::new_with_dialect(sql, dialect)?;
parser.parse_expr()
}
Expand Down Expand Up @@ -377,7 +378,7 @@ impl<'a> DFParser<'a> {
}
}

pub fn parse_expr(&mut self) -> Result<Expr, ParserError> {
pub fn parse_expr(&mut self) -> Result<ExprWithAlias, ParserError> {
if let Token::Word(w) = self.parser.peek_token().token {
match w.keyword {
Keyword::CREATE | Keyword::COPY | Keyword::EXPLAIN => {
Expand All @@ -387,7 +388,7 @@ impl<'a> DFParser<'a> {
}
}

self.parser.parse_expr()
self.parser.parse_expr_with_alias()
}

/// Parse a SQL `COPY TO` statement
Expand Down

0 comments on commit 93b3d9c

Please sign in to comment.