diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index cdffa8c645ea5..e84fd828a649b 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -17,7 +17,6 @@ //! [`DecorrelatePredicateSubquery`] converts `IN`/`EXISTS` subquery predicates to `SEMI`/`ANTI` joins use std::collections::BTreeSet; -use std::iter; use std::ops::Deref; use std::sync::Arc; @@ -34,11 +33,10 @@ use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::{conjunction, split_conjunction_owned}; use datafusion_expr::{ - exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, + exists, in_subquery, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, }; -use itertools::chain; use log::debug; /// Optimizer rule for rewriting predicate(IN/EXISTS) subquery to left semi/anti joins @@ -138,17 +136,14 @@ fn rewrite_inner_subqueries( Expr::Exists(Exists { subquery: Subquery { subquery, .. }, negated, - }) => { - match existence_join(&cur_input, Arc::clone(&subquery), None, negated, alias)? - { - Some((plan, exists_expr)) => { - cur_input = plan; - Ok(Transformed::yes(exists_expr)) - } - None if negated => Ok(Transformed::no(not_exists(subquery))), - None => Ok(Transformed::no(exists(subquery))), + }) => match mark_join(&cur_input, Arc::clone(&subquery), None, negated, alias)? { + Some((plan, exists_expr)) => { + cur_input = plan; + Ok(Transformed::yes(exists_expr)) } - } + None if negated => Ok(Transformed::no(not_exists(subquery))), + None => Ok(Transformed::no(exists(subquery))), + }, Expr::InSubquery(InSubquery { expr, subquery: Subquery { subquery, .. }, @@ -159,7 +154,7 @@ fn rewrite_inner_subqueries( .map_or(plan_err!("single expression required."), |output_expr| { Ok(Expr::eq(*expr.clone(), output_expr)) })?; - match existence_join( + match mark_join( &cur_input, Arc::clone(&subquery), Some(in_predicate), @@ -283,10 +278,6 @@ fn build_join_top( build_join(left, subquery, in_predicate_opt, join_type, subquery_alias) } -/// Existence join is emulated by adding a non-nullable column to the subquery and using a left join -/// and checking if the column is null or not. If native support is added for Existence/Mark then -/// we should use that instead. -/// /// This is used to handle the case when the subquery is embedded in a more complex boolean /// expression like and OR. For example /// @@ -296,37 +287,26 @@ fn build_join_top( /// /// ```text /// Projection: t1.id -/// Filter: t1.id < 0 OR __correlated_sq_1.__exists IS NOT NULL +/// Filter: t1.id < 0 OR __correlated_sq_1.mark /// Left Join: Filter: t1.id = __correlated_sq_1.id /// TableScan: t1 /// SubqueryAlias: __correlated_sq_1 -/// Projection: t2.id, true as __exists +/// Projection: t2.id /// TableScan: t2 -fn existence_join( +fn mark_join( left: &LogicalPlan, subquery: Arc, in_predicate_opt: Option, negated: bool, alias_generator: &Arc, ) -> Result> { - // Add non nullable column to emulate existence join - let always_true_expr = lit(true).alias("__exists"); - let cols = chain( - subquery.schema().columns().into_iter().map(Expr::Column), - iter::once(always_true_expr), - ); - let subquery = LogicalPlanBuilder::from(subquery).project(cols)?.build()?; let alias = alias_generator.next("__correlated_sq"); - let exists_col = Expr::Column(Column::new(Some(alias.clone()), "__exists")); - let exists_expr = if negated { - exists_col.is_null() - } else { - exists_col.is_not_null() - }; + let exists_col = Expr::Column(Column::new(Some(alias.clone()), "mark")); + let exists_expr = if negated { !exists_col } else { exists_col }; Ok( - build_join(left, &subquery, in_predicate_opt, JoinType::Left, alias)? + build_join(left, &subquery, in_predicate_opt, JoinType::LeftMark, alias)? .map(|plan| (plan, exists_expr)), ) } diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 26b5d8b952f6c..f5dcf903e7ff0 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -1056,13 +1056,11 @@ where t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0) ---- logical_plan 01)Projection: t1.t1_id, t1.t1_name, t1.t1_int -02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NOT NULL -03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists -04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id Filter: t1.t1_int > Int32(0) -05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] -06)--------SubqueryAlias: __correlated_sq_1 -07)----------Projection: t2.t2_id, Boolean(true) AS __exists -08)------------TableScan: t2 projection=[t2_id] +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.mark +03)----LeftMark Join: t1.t1_id = __correlated_sq_1.t2_id Filter: t1.t1_int > Int32(0) +04)------TableScan: t1 projection=[t1_id, t1_name, t1_int] +05)------SubqueryAlias: __correlated_sq_1 +06)--------TableScan: t2 projection=[t2_id] query ITI rowsort select t1.t1_id, @@ -1085,13 +1083,12 @@ where t1.t1_id = 11 or t1.t1_id + 12 not in (select t2.t2_id + 1 from t2 where t ---- logical_plan 01)Projection: t1.t1_id, t1.t1_name, t1.t1_int -02)--Filter: t1.t1_id = Int32(11) OR __correlated_sq_1.__exists IS NULL -03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists -04)------Left Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.t2.t2_id + Int64(1) Filter: t1.t1_int > Int32(0) -05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] -06)--------SubqueryAlias: __correlated_sq_1 -07)----------Projection: CAST(t2.t2_id AS Int64) + Int64(1), Boolean(true) AS __exists -08)------------TableScan: t2 projection=[t2_id] +02)--Filter: t1.t1_id = Int32(11) OR NOT __correlated_sq_1.mark +03)----LeftMark Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.t2.t2_id + Int64(1) Filter: t1.t1_int > Int32(0) +04)------TableScan: t1 projection=[t1_id, t1_name, t1_int] +05)------SubqueryAlias: __correlated_sq_1 +06)--------Projection: CAST(t2.t2_id AS Int64) + Int64(1) +07)----------TableScan: t2 projection=[t2_id] query ITI rowsort select t1.t1_id, @@ -1113,13 +1110,11 @@ where t1.t1_id > 40 or exists (select * from t2 where t1.t1_id = t2.t2_id) ---- logical_plan 01)Projection: t1.t1_id, t1.t1_name, t1.t1_int -02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NOT NULL -03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists -04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id -05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] -06)--------SubqueryAlias: __correlated_sq_1 -07)----------Projection: t2.t2_id, Boolean(true) AS __exists -08)------------TableScan: t2 projection=[t2_id] +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.mark +03)----LeftMark Join: t1.t1_id = __correlated_sq_1.t2_id +04)------TableScan: t1 projection=[t1_id, t1_name, t1_int] +05)------SubqueryAlias: __correlated_sq_1 +06)--------TableScan: t2 projection=[t2_id] query ITI rowsort select t1.t1_id, @@ -1142,13 +1137,11 @@ where t1.t1_id > 40 or not exists (select * from t2 where t1.t1_id = t2.t2_id) ---- logical_plan 01)Projection: t1.t1_id, t1.t1_name, t1.t1_int -02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NULL -03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists -04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id -05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] -06)--------SubqueryAlias: __correlated_sq_1 -07)----------Projection: t2.t2_id, Boolean(true) AS __exists -08)------------TableScan: t2 projection=[t2_id] +02)--Filter: t1.t1_id > Int32(40) OR NOT __correlated_sq_1.mark +03)----LeftMark Join: t1.t1_id = __correlated_sq_1.t2_id +04)------TableScan: t1 projection=[t1_id, t1_name, t1_int] +05)------SubqueryAlias: __correlated_sq_1 +06)--------TableScan: t2 projection=[t2_id] query ITI rowsort select t1.t1_id, @@ -1170,16 +1163,14 @@ where t1.t1_id in (select t3.t3_id from t3) and (t1.t1_id > 40 or t1.t1_id in (s ---- logical_plan 01)Projection: t1.t1_id, t1.t1_name, t1.t1_int -02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_2.__exists IS NOT NULL -03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_2.__exists -04)------Left Join: t1.t1_id = __correlated_sq_2.t2_id Filter: t1.t1_int > Int32(0) -05)--------LeftSemi Join: t1.t1_id = __correlated_sq_1.t3_id -06)----------TableScan: t1 projection=[t1_id, t1_name, t1_int] -07)----------SubqueryAlias: __correlated_sq_1 -08)------------TableScan: t3 projection=[t3_id] -09)--------SubqueryAlias: __correlated_sq_2 -10)----------Projection: t2.t2_id, Boolean(true) AS __exists -11)------------TableScan: t2 projection=[t2_id] +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_2.mark +03)----LeftMark Join: t1.t1_id = __correlated_sq_2.t2_id Filter: t1.t1_int > Int32(0) +04)------LeftSemi Join: t1.t1_id = __correlated_sq_1.t3_id +05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] +06)--------SubqueryAlias: __correlated_sq_1 +07)----------TableScan: t3 projection=[t3_id] +08)------SubqueryAlias: __correlated_sq_2 +09)--------TableScan: t2 projection=[t2_id] query ITI rowsort select t1.t1_id, @@ -1192,6 +1183,18 @@ where t1.t1_id in (select t3.t3_id from t3) and (t1.t1_id > 40 or t1.t1_id in (s 22 b 2 44 d 4 +# Handle duplicate values in exists query +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or exists (select * from t2 cross join t3 where t1.t1_id = t2.t2_id) +---- +11 a 1 +22 b 2 +44 d 4 + # Nested subqueries query ITI rowsort select t1.t1_id, diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 06a047b108bd3..1f0157ce50097 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -473,15 +473,15 @@ async fn roundtrip_inlist_5() -> Result<()> { // on roundtrip there is an additional projection during TableScan which includes all column of the table, // using assert_expected_plan here as a workaround assert_expected_plan( - "SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))", - "Projection: data.a, data.f\ - \n Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR Boolean(true) IS NOT NULL\ - \n Projection: data.a, data.f, Boolean(true)\ - \n Left Join: data.a = data2.a\ - \n TableScan: data projection=[a, f]\ - \n Projection: data2.a, Boolean(true)\ - \n Filter: data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")\ - \n TableScan: data2 projection=[a, f], partial_filters=[data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")]", + "SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))", + + "Projection: data.a, data.f\ + \n Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data2.mark\ + \n LeftMark Join: data.a = data2.a\ + \n TableScan: data projection=[a, f]\ + \n Projection: data2.a\ + \n Filter: data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")\ + \n TableScan: data2 projection=[a, f], partial_filters=[data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")]", true).await }