Skip to content

Commit

Permalink
Use mark join in decorrelate subqueries
Browse files Browse the repository at this point in the history
This fixes a correctness issue in the current approach.
  • Loading branch information
eejbyfeldt committed Oct 27, 2024
1 parent 9213260 commit 6ddce2e
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 82 deletions.
50 changes: 15 additions & 35 deletions datafusion/optimizer/src/decorrelate_predicate_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

//! [`DecorrelatePredicateSubquery`] converts `IN`/`EXISTS` subquery predicates to `SEMI`/`ANTI` joins
use std::collections::BTreeSet;
use std::iter;
use std::ops::Deref;
use std::sync::Arc;

Expand All @@ -34,11 +33,10 @@ use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
use datafusion_expr::logical_plan::{JoinType, Subquery};
use datafusion_expr::utils::{conjunction, split_conjunction_owned};
use datafusion_expr::{
exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
exists, in_subquery, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
LogicalPlan, LogicalPlanBuilder, Operator,
};

use itertools::chain;
use log::debug;

/// Optimizer rule for rewriting predicate(IN/EXISTS) subquery to left semi/anti joins
Expand Down Expand Up @@ -138,17 +136,14 @@ fn rewrite_inner_subqueries(
Expr::Exists(Exists {
subquery: Subquery { subquery, .. },
negated,
}) => {
match existence_join(&cur_input, Arc::clone(&subquery), None, negated, alias)?
{
Some((plan, exists_expr)) => {
cur_input = plan;
Ok(Transformed::yes(exists_expr))
}
None if negated => Ok(Transformed::no(not_exists(subquery))),
None => Ok(Transformed::no(exists(subquery))),
}) => match mark_join(&cur_input, Arc::clone(&subquery), None, negated, alias)? {
Some((plan, exists_expr)) => {
cur_input = plan;
Ok(Transformed::yes(exists_expr))
}
}
None if negated => Ok(Transformed::no(not_exists(subquery))),
None => Ok(Transformed::no(exists(subquery))),
},
Expr::InSubquery(InSubquery {
expr,
subquery: Subquery { subquery, .. },
Expand All @@ -159,7 +154,7 @@ fn rewrite_inner_subqueries(
.map_or(plan_err!("single expression required."), |output_expr| {
Ok(Expr::eq(*expr.clone(), output_expr))
})?;
match existence_join(
match mark_join(
&cur_input,
Arc::clone(&subquery),
Some(in_predicate),
Expand Down Expand Up @@ -283,10 +278,6 @@ fn build_join_top(
build_join(left, subquery, in_predicate_opt, join_type, subquery_alias)
}

/// Existence join is emulated by adding a non-nullable column to the subquery and using a left join
/// and checking if the column is null or not. If native support is added for Existence/Mark then
/// we should use that instead.
///
/// This is used to handle the case when the subquery is embedded in a more complex boolean
/// expression like an OR. For example
///
Expand All @@ -296,37 +287,26 @@ fn build_join_top(
///
/// ```text
/// Projection: t1.id
/// Filter: t1.id < 0 OR __correlated_sq_1.__exists IS NOT NULL
/// Filter: t1.id < 0 OR __correlated_sq_1.mark
/// LeftMark Join: Filter: t1.id = __correlated_sq_1.id
/// TableScan: t1
/// SubqueryAlias: __correlated_sq_1
/// Projection: t2.id, true as __exists
/// Projection: t2.id
/// TableScan: t2
fn existence_join(
fn mark_join(
left: &LogicalPlan,
subquery: Arc<LogicalPlan>,
in_predicate_opt: Option<Expr>,
negated: bool,
alias_generator: &Arc<AliasGenerator>,
) -> Result<Option<(LogicalPlan, Expr)>> {
// Add non nullable column to emulate existence join
let always_true_expr = lit(true).alias("__exists");
let cols = chain(
subquery.schema().columns().into_iter().map(Expr::Column),
iter::once(always_true_expr),
);
let subquery = LogicalPlanBuilder::from(subquery).project(cols)?.build()?;
let alias = alias_generator.next("__correlated_sq");

let exists_col = Expr::Column(Column::new(Some(alias.clone()), "__exists"));
let exists_expr = if negated {
exists_col.is_null()
} else {
exists_col.is_not_null()
};
let exists_col = Expr::Column(Column::new(Some(alias.clone()), "mark"));
let exists_expr = if negated { !exists_col } else { exists_col };

Ok(
build_join(left, &subquery, in_predicate_opt, JoinType::Left, alias)?
build_join(left, &subquery, in_predicate_opt, JoinType::LeftMark, alias)?
.map(|plan| (plan, exists_expr)),
)
}
Expand Down
79 changes: 41 additions & 38 deletions datafusion/sqllogictest/test_files/subquery.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1056,13 +1056,11 @@ where t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0)
----
logical_plan
01)Projection: t1.t1_id, t1.t1_name, t1.t1_int
02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NOT NULL
03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists
04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id Filter: t1.t1_int > Int32(0)
05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int]
06)--------SubqueryAlias: __correlated_sq_1
07)----------Projection: t2.t2_id, Boolean(true) AS __exists
08)------------TableScan: t2 projection=[t2_id]
02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.mark
03)----LeftMark Join: t1.t1_id = __correlated_sq_1.t2_id Filter: t1.t1_int > Int32(0)
04)------TableScan: t1 projection=[t1_id, t1_name, t1_int]
05)------SubqueryAlias: __correlated_sq_1
06)--------TableScan: t2 projection=[t2_id]

query ITI rowsort
select t1.t1_id,
Expand All @@ -1085,13 +1083,12 @@ where t1.t1_id = 11 or t1.t1_id + 12 not in (select t2.t2_id + 1 from t2 where t
----
logical_plan
01)Projection: t1.t1_id, t1.t1_name, t1.t1_int
02)--Filter: t1.t1_id = Int32(11) OR __correlated_sq_1.__exists IS NULL
03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists
04)------Left Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.t2.t2_id + Int64(1) Filter: t1.t1_int > Int32(0)
05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int]
06)--------SubqueryAlias: __correlated_sq_1
07)----------Projection: CAST(t2.t2_id AS Int64) + Int64(1), Boolean(true) AS __exists
08)------------TableScan: t2 projection=[t2_id]
02)--Filter: t1.t1_id = Int32(11) OR NOT __correlated_sq_1.mark
03)----LeftMark Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.t2.t2_id + Int64(1) Filter: t1.t1_int > Int32(0)
04)------TableScan: t1 projection=[t1_id, t1_name, t1_int]
05)------SubqueryAlias: __correlated_sq_1
06)--------Projection: CAST(t2.t2_id AS Int64) + Int64(1)
07)----------TableScan: t2 projection=[t2_id]

query ITI rowsort
select t1.t1_id,
Expand All @@ -1113,13 +1110,11 @@ where t1.t1_id > 40 or exists (select * from t2 where t1.t1_id = t2.t2_id)
----
logical_plan
01)Projection: t1.t1_id, t1.t1_name, t1.t1_int
02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NOT NULL
03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists
04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id
05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int]
06)--------SubqueryAlias: __correlated_sq_1
07)----------Projection: t2.t2_id, Boolean(true) AS __exists
08)------------TableScan: t2 projection=[t2_id]
02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.mark
03)----LeftMark Join: t1.t1_id = __correlated_sq_1.t2_id
04)------TableScan: t1 projection=[t1_id, t1_name, t1_int]
05)------SubqueryAlias: __correlated_sq_1
06)--------TableScan: t2 projection=[t2_id]

query ITI rowsort
select t1.t1_id,
Expand All @@ -1142,13 +1137,11 @@ where t1.t1_id > 40 or not exists (select * from t2 where t1.t1_id = t2.t2_id)
----
logical_plan
01)Projection: t1.t1_id, t1.t1_name, t1.t1_int
02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NULL
03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists
04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id
05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int]
06)--------SubqueryAlias: __correlated_sq_1
07)----------Projection: t2.t2_id, Boolean(true) AS __exists
08)------------TableScan: t2 projection=[t2_id]
02)--Filter: t1.t1_id > Int32(40) OR NOT __correlated_sq_1.mark
03)----LeftMark Join: t1.t1_id = __correlated_sq_1.t2_id
04)------TableScan: t1 projection=[t1_id, t1_name, t1_int]
05)------SubqueryAlias: __correlated_sq_1
06)--------TableScan: t2 projection=[t2_id]

query ITI rowsort
select t1.t1_id,
Expand All @@ -1170,16 +1163,14 @@ where t1.t1_id in (select t3.t3_id from t3) and (t1.t1_id > 40 or t1.t1_id in (s
----
logical_plan
01)Projection: t1.t1_id, t1.t1_name, t1.t1_int
02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_2.__exists IS NOT NULL
03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_2.__exists
04)------Left Join: t1.t1_id = __correlated_sq_2.t2_id Filter: t1.t1_int > Int32(0)
05)--------LeftSemi Join: t1.t1_id = __correlated_sq_1.t3_id
06)----------TableScan: t1 projection=[t1_id, t1_name, t1_int]
07)----------SubqueryAlias: __correlated_sq_1
08)------------TableScan: t3 projection=[t3_id]
09)--------SubqueryAlias: __correlated_sq_2
10)----------Projection: t2.t2_id, Boolean(true) AS __exists
11)------------TableScan: t2 projection=[t2_id]
02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_2.mark
03)----LeftMark Join: t1.t1_id = __correlated_sq_2.t2_id Filter: t1.t1_int > Int32(0)
04)------LeftSemi Join: t1.t1_id = __correlated_sq_1.t3_id
05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int]
06)--------SubqueryAlias: __correlated_sq_1
07)----------TableScan: t3 projection=[t3_id]
08)------SubqueryAlias: __correlated_sq_2
09)--------TableScan: t2 projection=[t2_id]

query ITI rowsort
select t1.t1_id,
Expand All @@ -1192,6 +1183,18 @@ where t1.t1_id in (select t3.t3_id from t3) and (t1.t1_id > 40 or t1.t1_id in (s
22 b 2
44 d 4

# Handle duplicate values in exists query
query ITI rowsort
select t1.t1_id,
t1.t1_name,
t1.t1_int
from t1
where t1.t1_id > 40 or exists (select * from t2 cross join t3 where t1.t1_id = t2.t2_id)
----
11 a 1
22 b 2
44 d 4

# Nested subqueries
query ITI rowsort
select t1.t1_id,
Expand Down
18 changes: 9 additions & 9 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,15 +473,15 @@ async fn roundtrip_inlist_5() -> Result<()> {
// on roundtrip there is an additional projection during TableScan which includes all columns of the table,
// using assert_expected_plan here as a workaround
assert_expected_plan(
"SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))",
"Projection: data.a, data.f\
\n Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR Boolean(true) IS NOT NULL\
\n Projection: data.a, data.f, Boolean(true)\
\n Left Join: data.a = data2.a\
\n TableScan: data projection=[a, f]\
\n Projection: data2.a, Boolean(true)\
\n Filter: data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")\
\n TableScan: data2 projection=[a, f], partial_filters=[data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")]",
"SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))",

"Projection: data.a, data.f\
\n Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data2.mark\
\n LeftMark Join: data.a = data2.a\
\n TableScan: data projection=[a, f]\
\n Projection: data2.a\
\n Filter: data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")\
\n TableScan: data2 projection=[a, f], partial_filters=[data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")]",
true).await
}

Expand Down

0 comments on commit 6ddce2e

Please sign in to comment.