Preserve constant values across union operations #13805

Draft: wants to merge 18 commits into main
22 changes: 20 additions & 2 deletions datafusion/physical-expr/src/equivalence/class.rs
@@ -24,7 +24,7 @@ use std::fmt::Display;
use std::sync::Arc;

use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::JoinType;
use datafusion_common::{JoinType, ScalarValue};
use datafusion_physical_expr_common::physical_expr::format_physical_expr_list;

use indexmap::{IndexMap, IndexSet};
@@ -62,11 +62,15 @@ pub struct ConstExpr {
/// Does the constant have the same value across all partitions? See
/// struct docs for more details
across_partitions: bool,
/// The value of the constant expression
value: Option<ScalarValue>,
}

impl PartialEq for ConstExpr {
fn eq(&self, other: &Self) -> bool {
self.across_partitions == other.across_partitions && self.expr.eq(&other.expr)
self.across_partitions == other.across_partitions
&& self.expr.eq(&other.expr)
&& self.value == other.value
}
}

@@ -80,9 +84,15 @@ impl ConstExpr {
expr,
// By default, assume constant expressions are not same across partitions.
across_partitions: false,
value: None,
}
}

pub fn with_value(mut self, value: ScalarValue) -> Self {
self.value = Some(value);
self
}

/// Set the `across_partitions` flag
///
/// See struct docs for more details
@@ -106,6 +116,10 @@ impl ConstExpr {
self.expr
}

pub fn value(&self) -> Option<&ScalarValue> {
self.value.as_ref()
}

pub fn map<F>(&self, f: F) -> Option<Self>
where
F: Fn(&Arc<dyn PhysicalExpr>) -> Option<Arc<dyn PhysicalExpr>>,
@@ -114,6 +128,7 @@ impl ConstExpr {
maybe_expr.map(|expr| Self {
expr,
across_partitions: self.across_partitions,
value: self.value.clone(),
})
}

@@ -152,6 +167,9 @@ impl Display for ConstExpr {
if self.across_partitions {
write!(f, "(across_partitions)")?;
}
if let Some(value) = self.value.as_ref() {
write!(f, "({})", value)?;
}
Ok(())
}
}
185 changes: 136 additions & 49 deletions datafusion/physical-expr/src/equivalence/properties.rs
@@ -257,14 +257,32 @@ impl EquivalenceProperties {
if self.is_expr_constant(left) {
// Left expression is constant, add right as constant
if !const_exprs_contains(&self.constants, right) {
self.constants
.push(ConstExpr::from(right).with_across_partitions(true));
// Try to get value from left constant expression
let value = left
.as_any()
.downcast_ref::<Literal>()
.map(|lit| lit.value().clone());

let mut const_expr = ConstExpr::from(right).with_across_partitions(true);
if let Some(val) = value {
const_expr = const_expr.with_value(val);
}
self.constants.push(const_expr);
}
} else if self.is_expr_constant(right) {
// Right expression is constant, add left as constant
if !const_exprs_contains(&self.constants, left) {
self.constants
.push(ConstExpr::from(left).with_across_partitions(true));
// Try to get value from right constant expression
let value = right
.as_any()
.downcast_ref::<Literal>()
.map(|lit| lit.value().clone());

let mut const_expr = ConstExpr::from(left).with_across_partitions(true);
if let Some(val) = value {
const_expr = const_expr.with_value(val);
}
self.constants.push(const_expr);
}
}

@@ -293,30 +311,33 @@ impl EquivalenceProperties {
mut self,
constants: impl IntoIterator<Item = ConstExpr>,
) -> Self {
let (const_exprs, across_partition_flags): (
Vec<Arc<dyn PhysicalExpr>>,
Vec<bool>,
) = constants
let normalized_constants = constants
.into_iter()
.map(|const_expr| {
let across_partitions = const_expr.across_partitions();
let expr = const_expr.owned_expr();
(expr, across_partitions)
.filter_map(|c| {
let across_partitions = c.across_partitions();
let value = c.value().cloned();
let expr = c.owned_expr();
let normalized_expr = self.eq_group.normalize_expr(expr);

if const_exprs_contains(&self.constants, &normalized_expr) {
return None;
}

let mut const_expr = ConstExpr::from(normalized_expr)
.with_across_partitions(across_partitions);

if let Some(value) = value {
const_expr = const_expr.with_value(value);
}

Some(const_expr)
})
.unzip();
for (expr, across_partitions) in self
.eq_group
.normalize_exprs(const_exprs)
.into_iter()
.zip(across_partition_flags)
{
if !const_exprs_contains(&self.constants, &expr) {
let const_expr =
ConstExpr::from(expr).with_across_partitions(across_partitions);
self.constants.push(const_expr);
}
}
.collect::<Vec<_>>();

// Add all new normalized constants
self.constants.extend(normalized_constants);

// Discover any new orderings based on the constants
for ordering in self.normalized_oeq_class().iter() {
if let Err(e) = self.discover_new_orderings(&ordering[0].expr) {
log::debug!("error discovering new orderings: {e}");
@@ -875,19 +896,39 @@ impl EquivalenceProperties {
.constants
.iter()
.flat_map(|const_expr| {
const_expr.map(|expr| self.eq_group.project_expr(mapping, expr))
const_expr
.map(|expr| self.eq_group.project_expr(mapping, expr))
.map(|projected_expr| {
let mut new_const_expr = projected_expr
.with_across_partitions(const_expr.across_partitions());
if let Some(value) = const_expr.value() {
new_const_expr = new_const_expr.with_value(value.clone());
}
new_const_expr
})
})
.collect::<Vec<_>>();

// Add projection expressions that are known to be constant:
for (source, target) in mapping.iter() {
if self.is_expr_constant(source)
&& !const_exprs_contains(&projected_constants, target)
{
let across_partitions = self.is_expr_constant_accross_partitions(source);
// Try to get value from source constant expression
let value = self
.constants
.iter()
.find(|c| c.expr().eq(source))
.and_then(|c| c.value().cloned());

// Expression evaluates to single value
projected_constants.push(
ConstExpr::from(target).with_across_partitions(across_partitions),
);
let mut const_expr =
ConstExpr::from(target).with_across_partitions(across_partitions);
if let Some(val) = value {
const_expr = const_expr.with_value(val);
}
projected_constants.push(const_expr);
}
}
projected_constants
@@ -1099,9 +1140,14 @@ impl EquivalenceProperties {
.into_iter()
.map(|const_expr| {
let across_partitions = const_expr.across_partitions();
let value = const_expr.value().cloned();
let new_const_expr = with_new_schema(const_expr.owned_expr(), &schema)?;
Ok(ConstExpr::new(new_const_expr)
.with_across_partitions(across_partitions))
let mut new_const_expr = ConstExpr::new(new_const_expr)
.with_across_partitions(across_partitions);
if let Some(value) = value {
new_const_expr = new_const_expr.with_value(value.clone());
}
Ok(new_const_expr)
})
.collect::<Result<Vec<_>>>()?;

@@ -1852,7 +1898,7 @@ impl Hash for ExprWrapper {
/// *all* output partitions, that is the same as being true for all *input*
/// partitions
fn calculate_union_binary(
mut lhs: EquivalenceProperties,
lhs: EquivalenceProperties,
mut rhs: EquivalenceProperties,
) -> Result<EquivalenceProperties> {
// Harmonize the schema of the rhs with the schema of the lhs (which is the accumulator schema):
@@ -1861,26 +1907,32 @@ fn calculate_union_binary(
}

// First, calculate valid constants for the union. An expression is constant
// at the output of the union if it is constant in both sides.
let constants: Vec<_> = lhs
// at the output of the union if it is constant in both sides with matching values.
let constants = lhs
.constants()
.iter()
.filter(|const_expr| const_exprs_contains(rhs.constants(), const_expr.expr()))
.map(|const_expr| {
// TODO: When both sides have a constant column, and the actual
// constant value is the same, then the output properties could
// reflect the constant is valid across all partitions. However we
// don't track the actual value that the ConstExpr takes on, so we
// can't determine that yet
ConstExpr::new(Arc::clone(const_expr.expr())).with_across_partitions(false)
})
.collect();
.filter_map(|lhs_const| {
// Find matching constant expression in RHS
rhs.constants()
.iter()
.find(|rhs_const| rhs_const.expr().eq(lhs_const.expr()))
.map(|rhs_const| {
let mut const_expr = ConstExpr::new(Arc::clone(lhs_const.expr()));

// remove any constants that are shared in both outputs (avoid double counting them)
for c in &constants {
lhs = lhs.remove_constant(c);
rhs = rhs.remove_constant(c);
}
// If both sides have matching constant values, preserve the value and set across_partitions=true
if let (Some(lhs_val), Some(rhs_val)) =
(lhs_const.value(), rhs_const.value())
{
if lhs_val == rhs_val {
const_expr = const_expr
.with_across_partitions(true)
.with_value(lhs_val.clone());
}
}
const_expr
})
})
.collect::<Vec<_>>();

// Next, calculate valid orderings for the union by searching for prefixes
// in both sides.
@@ -2146,6 +2198,7 @@ mod tests {

use arrow::datatypes::{DataType, Field, Schema};
use arrow_schema::{Fields, TimeUnit};
use datafusion_common::ScalarValue;
use datafusion_expr::Operator;

#[test]
@@ -3684,4 +3737,38 @@ mod tests {

sort_expr
}

#[test]
fn test_union_constant_value_preservation() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
]));

let col_a = col("a", &schema)?;
let literal_10 = ScalarValue::Int32(Some(10));

// Create first input with a=10
let const_expr1 =
ConstExpr::new(Arc::clone(&col_a)).with_value(literal_10.clone());
let input1 = EquivalenceProperties::new(Arc::clone(&schema))
.with_constants(vec![const_expr1]);

// Create second input with a=10
let const_expr2 =
ConstExpr::new(Arc::clone(&col_a)).with_value(literal_10.clone());
let input2 = EquivalenceProperties::new(Arc::clone(&schema))
.with_constants(vec![const_expr2]);

// Calculate union properties
let union_props = calculate_union(vec![input1, input2], schema)?;

// Verify column 'a' remains constant with value 10
let const_a = &union_props.constants()[0];
assert!(const_a.expr().eq(&col_a));
assert!(const_a.across_partitions());
assert_eq!(const_a.value(), Some(&literal_10));

Ok(())
}
}
Contributor

Is there a way to create an end-to-end .slt test that shows this behavior?

For example, an EXPLAIN PLAN where a Sort is optimized away after the constant value is propagated through the union?

Contributor

Good idea! I have one in mind. Let me add it.

Contributor @berkaysynnada (Dec 17, 2024)

Hey @alamb, I tried it, but after thinking more, we actually need one more step in the planner to see an end-to-end difference. We now have the knowledge, but we are not using it yet. Two possible optimizations come to my mind:
Let's assume we have:

# Constant value tracking across union
query TT
explain
SELECT * FROM(
(
    SELECT * FROM aggregate_test_100 WHERE c1='a'
)
UNION ALL
(
    SELECT * FROM aggregate_test_100 WHERE c1='a'
))
ORDER BY c1
----
+   physical_plan
+   01)SortPreservingMergeExec: [c1@0 ASC NULLS LAST]
+   02)--UnionExec
+   03)----CoalesceBatchesExec: target_batch_size=2
+   04)------FilterExec: c1@0 = a
+   05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+   06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true
+   07)----CoalesceBatchesExec: target_batch_size=2
+   08)------FilterExec: c1@0 = a
+   09)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+   10)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true
  1. At the top of the plan, we see a SortPreservingMergeExec (SPM). However, it could be a CoalescePartitionsExec instead, which would certainly improve performance.
  2. For the same query without an ORDER BY but with another outer filter, we would see a second filter. We could actually remove it. This is another optimization, but it would apply much more rarely than the first one.

The second one may not be very realistic, but the first one could be implemented without much effort with a few changes within the scope of replace_with_order_preserving_variants.

Contributor

Can you take a look at the first check, @gokselk? It should take only a few line changes in the plan_with_order_preserving_variants() function. It first looks at the ordering requirements, and if they are satisfied, it tries to convert the CoalescePartitionsExec into a SortPreservingMergeExec. But before that conversion, you can check the across_partitions flag of the input constants, and if it is true, you can leave the CoalescePartitionsExec as is.
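
A rough sketch of that check (my illustration, not code from this PR; the helper name, import paths, and exact signatures are assumptions), built only on the constants()/across_partitions()/expr() accessors that appear in the diff above:

use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortExpr};

/// Hypothetical helper: returns true when every expression in the required
/// ordering is a constant with the same value in all partitions, in which
/// case a plain CoalescePartitionsExec already satisfies the ordering and no
/// SortPreservingMergeExec is needed.
fn ordering_satisfied_by_constants(
    eq_properties: &EquivalenceProperties,
    required: &[PhysicalSortExpr],
) -> bool {
    required.iter().all(|sort_expr| {
        eq_properties
            .constants()
            .iter()
            .any(|c| c.across_partitions() && c.expr().eq(&sort_expr.expr))
    })
}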

Contributor Author @gokselk (Dec 17, 2024)

> Can you take a look at the first check, @gokselk? It should take only a few line changes in the plan_with_order_preserving_variants() function. It first looks at the ordering requirements, and if they are satisfied, it tries to convert the CoalescePartitionsExec into a SortPreservingMergeExec. But before that conversion, you can check the across_partitions flag of the input constants, and if it is true, you can leave the CoalescePartitionsExec as is.

I've made changes to FilterExec for value extraction and added an initial SLT file. The query now shows CoalescePartitionsExec in the output, so I think your suggested changes to plan_with_order_preserving_variants() might not be needed anymore. However, I'd appreciate your review to confirm this.

Contributor Author

It appears that I broke some ORDER BY queries in my recent commits. I will investigate this further.

Contributor Author

To add more context, some tests are failing non-deterministically, which is why I didn't notice it beforehand.

Contributor

I'm not sure what the actual situation is w.r.t. those tests, but I'd advise taking a look at whether they were underspecified in the first place (i.e. the query itself may not specify a concrete output ordering, which could make the test flaky).

Do the failing queries have top-level ORDER BY clauses? If so, it is probably a bug that was introduced; otherwise, they may have been flaky in the first place.

Contributor Author

I noticed that some parts of the code don't assign values to ConstExpr when they could. I'll add these assignments and check if this resolves the problem.
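
For reference, the pattern being rolled out looks roughly like this (an illustrative helper, not code from the PR; import paths are approximate), mirroring what add_equal_conditions and FilterExec already do in the diffs above:

use std::sync::Arc;

use datafusion_physical_expr::expressions::Literal;
use datafusion_physical_expr::{ConstExpr, PhysicalExpr};

/// Illustrative helper: mark `expr` as constant across partitions and, when
/// the expression it is known to equal is a literal, record that literal's
/// ScalarValue so later rules (such as the union rule) can compare values
/// between inputs.
fn constant_with_value(
    expr: &Arc<dyn PhysicalExpr>,
    equals: &Arc<dyn PhysicalExpr>,
) -> ConstExpr {
    let mut const_expr = ConstExpr::new(Arc::clone(expr)).with_across_partitions(true);
    if let Some(lit) = equals.as_any().downcast_ref::<Literal>() {
        const_expr = const_expr.with_value(lit.value().clone());
    }
    const_expr
}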

39 changes: 31 additions & 8 deletions datafusion/physical-plan/src/filter.rs
@@ -41,7 +41,7 @@ use datafusion_common::{
use datafusion_execution::TaskContext;
use datafusion_expr::Operator;
use datafusion_physical_expr::equivalence::ProjectionMapping;
use datafusion_physical_expr::expressions::BinaryExpr;
use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
use datafusion_physical_expr::intervals::utils::check_support;
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::{
@@ -218,13 +218,29 @@ impl FilterExec {
if binary.op() == &Operator::Eq {
// Filter evaluates to single value for all partitions
if input_eqs.is_expr_constant(binary.left()) {
res_constants.push(
ConstExpr::from(binary.right()).with_across_partitions(true),
)
// When left side is constant, extract value from right side if it's a literal
let (expr, lit) = (
binary.right(),
binary.right().as_any().downcast_ref::<Literal>(),
);
let mut const_expr =
ConstExpr::from(expr).with_across_partitions(true);
if let Some(lit) = lit {
const_expr = const_expr.with_value(lit.value().clone());
}
res_constants.push(const_expr);
} else if input_eqs.is_expr_constant(binary.right()) {
res_constants.push(
ConstExpr::from(binary.left()).with_across_partitions(true),
)
// When right side is constant, extract value from left side if it's a literal
let (expr, lit) = (
binary.left(),
binary.left().as_any().downcast_ref::<Literal>(),
);
let mut const_expr =
ConstExpr::from(expr).with_across_partitions(true);
if let Some(lit) = lit {
const_expr = const_expr.with_value(lit.value().clone());
}
res_constants.push(const_expr);
}
}
}
@@ -252,8 +268,15 @@ impl FilterExec {
.into_iter()
.filter(|column| stats.column_statistics[column.index()].is_singleton())
.map(|column| {
let value = stats.column_statistics[column.index()]
.min_value
.get_value();
let expr = Arc::new(column) as _;
ConstExpr::new(expr).with_across_partitions(true)
let mut const_expr = ConstExpr::new(expr).with_across_partitions(true);
if let Some(value) = value {
const_expr = const_expr.with_value(value.clone());
}
const_expr
});
// This is for statistics
eq_properties = eq_properties.with_constants(constants);