Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support left-outer and left-mark hash join impl rules #274

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions optd-datafusion-bridge/src/from_optd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,8 @@ impl OptdPlanContext<'_> {
let right_exec = self.conv_from_optd_plan_node(node.right(), meta).await?;
let join_type = match node.join_type() {
JoinType::Inner => datafusion::logical_expr::JoinType::Inner,
JoinType::LeftOuter => datafusion::logical_expr::JoinType::Left,
JoinType::LeftMark => datafusion::logical_expr::JoinType::LeftMark,
_ => unimplemented!(),
};
let left_exprs = node.left_keys().to_vec();
Expand Down
8 changes: 6 additions & 2 deletions optd-datafusion-repr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,11 @@ impl DatafusionOptimizer {
rule_wrappers.push(Arc::new(rules::FilterInnerJoinTransposeRule::new()));
rule_wrappers.push(Arc::new(rules::FilterSortTransposeRule::new()));
rule_wrappers.push(Arc::new(rules::FilterAggTransposeRule::new()));
rule_wrappers.push(Arc::new(rules::HashJoinRule::new()));
rule_wrappers.push(Arc::new(rules::JoinInnerSplitFilterRule::new()));
rule_wrappers.push(Arc::new(rules::JoinLeftOuterSplitFilterRule::new()));
rule_wrappers.push(Arc::new(rules::HashJoinInnerRule::new()));
rule_wrappers.push(Arc::new(rules::HashJoinLeftOuterRule::new()));
rule_wrappers.push(Arc::new(rules::HashJoinLeftMarkRule::new()));
rule_wrappers.push(Arc::new(rules::JoinCommuteRule::new()));
rule_wrappers.push(Arc::new(rules::JoinAssocRule::new()));
rule_wrappers.push(Arc::new(rules::ProjectionPullUpJoin::new()));
Expand Down Expand Up @@ -178,7 +182,7 @@ impl DatafusionOptimizer {
for rule in rules {
rule_wrappers.push(rule);
}
rule_wrappers.push(Arc::new(rules::HashJoinRule::new()));
rule_wrappers.push(Arc::new(rules::HashJoinInnerRule::new()));
rule_wrappers.insert(0, Arc::new(rules::JoinCommuteRule::new()));
rule_wrappers.insert(1, Arc::new(rules::JoinAssocRule::new()));
rule_wrappers.insert(2, Arc::new(rules::ProjectionPullUpJoin::new()));
Expand Down
42 changes: 18 additions & 24 deletions optd-datafusion-repr/src/rules/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,41 +38,35 @@ pub(crate) fn simplify_log_expr(log_expr: ArcDfPredNode, changed: &mut bool) ->
if let DfPredType::Constant(ConstantType::Bool) = new_child.typ {
let data = ConstantPred::from_pred_node(new_child).unwrap().value();
*changed = true;
// TrueExpr
if data.as_bool() {
if op == LogOpType::And {
// skip True in And
continue;
}
if op == LogOpType::Or {

match (data.as_bool(), op) {
(true, LogOpType::Or) => {
// replace whole exprList with True
return ConstantPred::bool(true).into_pred_node();
}
unreachable!("no other type in logOp");
}
// FalseExpr
if op == LogOpType::And {
// replace whole exprList with False
return ConstantPred::bool(false).into_pred_node();
}
if op == LogOpType::Or {
// skip False in Or
continue;
(false, LogOpType::And) => {
// replace whole exprList with False
return ConstantPred::bool(false).into_pred_node();
}
_ => {
// skip True in `And`, and False in `Or`
continue;
}
}
unreachable!("no other type in logOp");
} else if !new_children_set.contains(&new_child) {
new_children_set.insert(new_child.clone());
new_children.push(new_child);
}
}
if new_children.is_empty() {
if op == LogOpType::And {
return ConstantPred::bool(true).into_pred_node();
}
if op == LogOpType::Or {
return ConstantPred::bool(false).into_pred_node();
match op {
LogOpType::And => {
return ConstantPred::bool(true).into_pred_node();
}
LogOpType::Or => {
return ConstantPred::bool(false).into_pred_node();
}
}
unreachable!("no other type in logOp");
}
if new_children.len() == 1 {
*changed = true;
Expand Down
174 changes: 174 additions & 0 deletions optd-datafusion-repr/src/rules/filter_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,87 @@ fn apply_filter_merge(
vec![new_filter.into_plan_node().into()]
}

// Rule to split predicates in a join condition into those that can be pushed down as filters.
define_rule!(
JoinInnerSplitFilterRule,
apply_join_split_filter,
(Join(JoinType::Inner), child_a, child_b)
);

define_rule!(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this rule is correct. You cannot move the outer join condition into a filter in some cases.

Consider select * from a left join b on a.x = b.y and b.z = 1. The result is different from select * from a left join b on a.x = b.y where b.z = 1. Assume left table is x=1, right table is y=1,z=2, the correct result is 1, NULL, NULL, versus the rule will produce zero rows.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahh, I realized that this is a filter pushdown, then it might be correct; I will do a review later :)

JoinLeftOuterSplitFilterRule,
apply_join_split_filter,
(Join(JoinType::LeftOuter), child_a, child_b)
);

fn apply_join_split_filter(
optimizer: &impl Optimizer<DfNodeType>,
binding: ArcDfPlanNode,
) -> Vec<PlanNodeOrGroup<DfNodeType>> {
let join = LogicalJoin::from_plan_node(binding.clone()).unwrap();
let left_child = join.left();
let right_child = join.right();
let join_cond = join.cond();
let join_typ = join.join_type();

let left_schema_size = optimizer.get_schema_of(left_child.clone()).len();
let right_schema_size = optimizer.get_schema_of(right_child.clone()).len();

// Conditions that only involve the left relation.
let mut left_conds = vec![];
// Conditions that only involve the right relation.
let mut right_conds = vec![];
// Conditions that involve both relations.
let mut keep_conds = vec![];

let categorization_fn = |expr: ArcDfPredNode, children: &[ArcDfPredNode]| {
let location = determine_join_cond_dep(children, left_schema_size, right_schema_size);
match location {
JoinCondDependency::Left => left_conds.push(expr),
JoinCondDependency::Right => right_conds.push(
expr.rewrite_column_refs(|idx| {
Some(LogicalJoin::map_through_join(
idx,
left_schema_size,
right_schema_size,
))
})
.unwrap(),
),
JoinCondDependency::Both | JoinCondDependency::None => {
// JoinCondDependency::None could happy if there are no column refs in the predicate.
// e.g. true for CrossJoin.
keep_conds.push(expr);
}
}
};
categorize_conds(categorization_fn, join_cond);

let new_left = if !left_conds.is_empty() {
let new_filter_node =
LogicalFilter::new_unchecked(left_child, and_expr_list_to_expr(left_conds));
PlanNodeOrGroup::PlanNode(new_filter_node.into_plan_node())
} else {
left_child
};

let new_right = if !right_conds.is_empty() {
let new_filter_node =
LogicalFilter::new_unchecked(right_child, and_expr_list_to_expr(right_conds));
PlanNodeOrGroup::PlanNode(new_filter_node.into_plan_node())
} else {
right_child
};

let new_join = LogicalJoin::new_unchecked(
new_left,
new_right,
and_expr_list_to_expr(keep_conds),
*join_typ,
);

vec![new_join.into_plan_node().into()]
}
define_rule!(
FilterInnerJoinTransposeRule,
apply_filter_inner_join_transpose,
Expand Down Expand Up @@ -369,6 +450,8 @@ fn apply_filter_agg_transpose(
mod tests {
use std::sync::Arc;

use optd_core::nodes::Value;

use super::*;
use crate::plan_nodes::{BinOpPred, BinOpType, ConstantPred, LogicalScan};
use crate::testing::new_test_optimizer;
Expand Down Expand Up @@ -442,6 +525,97 @@ mod tests {
assert_eq!(col_4.value().as_i32(), 1);
}

#[test]
fn join_split_filter() {
let mut test_optimizer = new_test_optimizer(Arc::new(JoinLeftOuterSplitFilterRule::new()));

let scan1 = LogicalScan::new("customer".into());

let scan2 = LogicalScan::new("orders".into());

let join_cond = LogOpPred::new(
LogOpType::And,
vec![
BinOpPred::new(
// This one should be pushed to the left child
ColumnRefPred::new(0).into_pred_node(),
ConstantPred::int32(5).into_pred_node(),
BinOpType::Eq,
)
.into_pred_node(),
BinOpPred::new(
// This one should be pushed to the right child
ColumnRefPred::new(11).into_pred_node(),
ConstantPred::int32(6).into_pred_node(),
BinOpType::Eq,
)
.into_pred_node(),
BinOpPred::new(
// This one stays in the join condition.
ColumnRefPred::new(2).into_pred_node(),
ColumnRefPred::new(8).into_pred_node(),
BinOpType::Eq,
)
.into_pred_node(),
// This one stays in the join condition.
ConstantPred::bool(true).into_pred_node(),
],
);

let join = LogicalJoin::new(
scan1.into_plan_node(),
scan2.into_plan_node(),
join_cond.into_pred_node(),
super::JoinType::LeftOuter,
);

let plan = test_optimizer.optimize(join.into_plan_node()).unwrap();
let join = LogicalJoin::from_plan_node(plan.clone()).unwrap();

assert_eq!(join.join_type(), &JoinType::LeftOuter);

{
// Examine join conditions.
let join_conds = LogOpPred::from_pred_node(join.cond()).unwrap();
assert!(matches!(join_conds.op_type(), LogOpType::And));
assert_eq!(join_conds.children().len(), 2);
let bin_op_with_both_ref =
BinOpPred::from_pred_node(join_conds.children()[0].clone()).unwrap();
assert!(matches!(bin_op_with_both_ref.op_type(), BinOpType::Eq));
let col_2 = ColumnRefPred::from_pred_node(bin_op_with_both_ref.left_child()).unwrap();
let col_8 = ColumnRefPred::from_pred_node(bin_op_with_both_ref.right_child()).unwrap();
assert_eq!(col_2.index(), 2);
assert_eq!(col_8.index(), 8);
let constant_true =
ConstantPred::from_pred_node(join_conds.children()[1].clone()).unwrap();
assert_eq!(constant_true.value(), Value::Bool(true));
}

{
// Examine left child filter + condition
let filter_left =
LogicalFilter::from_plan_node(join.left().unwrap_plan_node()).unwrap();
let bin_op = BinOpPred::from_pred_node(filter_left.cond()).unwrap();
assert!(matches!(bin_op.op_type(), BinOpType::Eq));
let col = ColumnRefPred::from_pred_node(bin_op.left_child()).unwrap();
let constant = ConstantPred::from_pred_node(bin_op.right_child()).unwrap();
assert_eq!(col.index(), 0);
assert_eq!(constant.value().as_i32(), 5);
}

{
// Examine right child filter + condition
let filter_right =
LogicalFilter::from_plan_node(join.right().unwrap_plan_node()).unwrap();
let bin_op = BinOpPred::from_pred_node(filter_right.cond()).unwrap();
assert!(matches!(bin_op.op_type(), BinOpType::Eq));
let col = ColumnRefPred::from_pred_node(bin_op.left_child()).unwrap();
let constant = ConstantPred::from_pred_node(bin_op.right_child()).unwrap();
assert_eq!(col.index(), 3);
assert_eq!(constant.value().as_i32(), 6);
}
}

#[test]
fn push_past_join_conjunction() {
// Test pushing a complex filter past a join, where one clause can
Expand Down
19 changes: 16 additions & 3 deletions optd-datafusion-repr/src/rules/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,23 @@ fn apply_join_assoc(
}

define_impl_rule!(
HashJoinRule,
HashJoinInnerRule,
apply_hash_join,
(Join(JoinType::Inner), left, right)
);

define_impl_rule!(
HashJoinLeftOuterRule,
apply_hash_join,
(Join(JoinType::LeftOuter), left, right)
);

define_impl_rule!(
HashJoinLeftMarkRule,
apply_hash_join,
(Join(JoinType::LeftMark), left, right)
);

fn apply_hash_join(
optimizer: &impl Optimizer<DfNodeType>,
binding: ArcDfPlanNode,
Expand All @@ -154,6 +166,7 @@ fn apply_hash_join(
let cond = join.cond();
let left = join.left();
let right = join.right();
let join_type = join.join_type();
match cond.typ {
DfPredType::BinOp(BinOpType::Eq) => {
let left_schema = optimizer.get_schema_of(left.clone());
Expand Down Expand Up @@ -186,7 +199,7 @@ fn apply_hash_join(
right,
ListPred::new(vec![left_expr.into_pred_node()]),
ListPred::new(vec![right_expr.into_pred_node()]),
JoinType::Inner,
*join_type,
);
return vec![node.into_plan_node().into()];
}
Expand Down Expand Up @@ -244,7 +257,7 @@ fn apply_hash_join(
right,
ListPred::new(left_exprs),
ListPred::new(right_exprs),
JoinType::Inner,
*join_type,
);
return vec![node.into_plan_node().into()];
}
Expand Down
Loading
Loading