Skip to content

Commit 9d854a6

Browse files
committed
refactor: introduce Invariant levels, and make explicit how the post-optimization checker should be run
1 parent e71ef9f commit 9d854a6

File tree

2 files changed

+157
-24
lines changed

2 files changed

+157
-24
lines changed

datafusion/core/src/physical_planner.rs

+145-22
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ use datafusion_expr::{
8383
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
8484
use datafusion_physical_expr::expressions::Literal;
8585
use datafusion_physical_expr::LexOrdering;
86+
use datafusion_physical_plan::execution_plan::InvariantLevel;
8687
use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
8788
use datafusion_physical_plan::unnest::ListUnnest;
8889
use datafusion_sql::utils::window_expr_common_partition_keys;
@@ -1875,6 +1876,10 @@ impl DefaultPhysicalPlanner {
18751876
displayable(plan.as_ref()).indent(true)
18761877
);
18771878

1879+
// This runs once before any optimization,
1880+
// to verify that the plan fulfills the base requirements.
1881+
InvariantChecker(InvariantLevel::Always).check(&plan)?;
1882+
18781883
let mut new_plan = Arc::clone(&plan);
18791884
for optimizer in optimizers {
18801885
let before_schema = new_plan.schema();
@@ -1884,9 +1889,9 @@ impl DefaultPhysicalPlanner {
18841889
DataFusionError::Context(optimizer.name().to_string(), Box::new(e))
18851890
})?;
18861891

1887-
// confirm optimizer change did not violate invariants
1888-
let mut validator = InvariantChecker::new(optimizer);
1889-
validator.check(&new_plan, before_schema)?;
1892+
// This only checks the schema in release build, and performs additional checks in debug mode.
1893+
OptimizationInvariantChecker::new(optimizer)
1894+
.check(&new_plan, before_schema)?;
18901895

18911896
trace!(
18921897
"Optimized physical plan by {}:\n{}\n",
@@ -1895,6 +1900,11 @@ impl DefaultPhysicalPlanner {
18951900
);
18961901
observer(new_plan.as_ref(), optimizer.as_ref())
18971902
}
1903+
1904+
// This runs once after all optimizer runs are complete,
1905+
// to verify that the plan is executable.
1906+
InvariantChecker(InvariantLevel::Executable).check(&new_plan)?;
1907+
18981908
debug!(
18991909
"Optimized physical plan:\n{}\n",
19001910
displayable(new_plan.as_ref()).indent(false)
@@ -2002,22 +2012,21 @@ fn tuple_err<T, R>(value: (Result<T>, Result<R>)) -> Result<(T, R)> {
20022012
}
20032013
}
20042014

2005-
/// Confirms that a given [`PhysicalOptimizerRule`] run
2006-
/// did not violate the [`ExecutionPlan`] invariants.
2007-
struct InvariantChecker<'a> {
2015+
struct OptimizationInvariantChecker<'a> {
20082016
rule: &'a Arc<dyn PhysicalOptimizerRule + Send + Sync>,
20092017
}
20102018

2011-
impl<'a> InvariantChecker<'a> {
2012-
/// Create an [`InvariantChecker`].
2019+
impl<'a> OptimizationInvariantChecker<'a> {
2020+
/// Create an [`OptimizationInvariantChecker`] that performs checking per rule.
20132021
pub fn new(rule: &'a Arc<dyn PhysicalOptimizerRule + Send + Sync>) -> Self {
20142022
Self { rule }
20152023
}
20162024

20172025
/// Checks that the plan change is permitted, returning an Error if not.
20182026
///
2027+
/// Conditionally performs schema checks per [`PhysicalOptimizerRule::schema_check`].
20192028
/// In debug mode, this recursively walks the entire physical plan
2020-
/// and performs [`ExecutionPlan::check_node_invariants`].
2029+
/// and performs [`ExecutionPlan::check_invariants`].
20212030
pub fn check(
20222031
&mut self,
20232032
plan: &Arc<dyn ExecutionPlan>,
@@ -2032,19 +2041,48 @@ impl<'a> InvariantChecker<'a> {
20322041
)?
20332042
}
20342043

2035-
// check invariants per ExecutionPlan extension
2044+
// check invariants per each ExecutionPlan node
20362045
#[cfg(debug_assertions)]
20372046
plan.visit(self)?;
20382047

20392048
Ok(())
20402049
}
20412050
}
20422051

2043-
impl<'n> TreeNodeVisitor<'n> for InvariantChecker<'_> {
2052+
impl<'n> TreeNodeVisitor<'n> for OptimizationInvariantChecker<'_> {
20442053
type Node = Arc<dyn ExecutionPlan>;
20452054

20462055
fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
2047-
node.check_node_invariants().map_err(|e| e.context(format!("Invariant for ExecutionPlan node '{}' failed for PhysicalOptimizer rule '{}'", node.name(), self.rule.name())))?;
2056+
// Checks for the more permissive `InvariantLevel::Always`.
2057+
// Plans are not guaranteed to be executable after each physical optimizer run.
2058+
node.check_invariants(InvariantLevel::Always).map_err(|e| e.context(format!("Invariant for ExecutionPlan node '{}' failed for PhysicalOptimizer rule '{}'", node.name(), self.rule.name())))?;
2059+
Ok(TreeNodeRecursion::Continue)
2060+
}
2061+
}
2062+
2063+
/// Check [`ExecutionPlan`] invariants per [`InvariantLevel`].
2064+
struct InvariantChecker(InvariantLevel);
2065+
2066+
impl InvariantChecker {
2067+
/// Checks that every node in the plan satisfies the invariants of the configured [`InvariantLevel`], returning an Error if not.
2068+
pub fn check(&mut self, plan: &Arc<dyn ExecutionPlan>) -> Result<()> {
2069+
// check invariants per each ExecutionPlan node
2070+
plan.visit(self)?;
2071+
2072+
Ok(())
2073+
}
2074+
}
2075+
2076+
impl<'n> TreeNodeVisitor<'n> for InvariantChecker {
2077+
type Node = Arc<dyn ExecutionPlan>;
2078+
2079+
fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
2080+
node.check_invariants(self.0).map_err(|e| {
2081+
e.context(format!(
2082+
"Invariant for ExecutionPlan node '{}' failed",
2083+
node.name()
2084+
))
2085+
})?;
20482086
Ok(TreeNodeRecursion::Continue)
20492087
}
20502088
}
@@ -2864,15 +2902,18 @@ digraph {
28642902
}
28652903
}
28662904

2867-
/// Extension Node which fails invariant checks
2905+
/// Extension Node which fails the [`OptimizationInvariantChecker`].
28682906
#[derive(Debug)]
28692907
struct InvariantFailsExtensionNode;
28702908
impl ExecutionPlan for InvariantFailsExtensionNode {
28712909
fn name(&self) -> &str {
28722910
"InvariantFailsExtensionNode"
28732911
}
2874-
fn check_node_invariants(&self) -> Result<()> {
2875-
plan_err!("extension node failed it's user-defined invariant check")
2912+
fn check_invariants(&self, check: InvariantLevel) -> Result<()> {
2913+
match check {
2914+
InvariantLevel::Always => plan_err!("extension node failed it's user-defined always-invariant check"),
2915+
InvariantLevel::Executable => panic!("the OptimizationInvariantChecker should not be checking for executableness"),
2916+
}
28762917
}
28772918
fn schema(&self) -> SchemaRef {
28782919
Arc::new(Schema::empty())
@@ -2926,7 +2967,7 @@ digraph {
29262967
}
29272968

29282969
#[test]
2929-
fn test_invariant_checker() -> Result<()> {
2970+
fn test_optimization_invariant_checker() -> Result<()> {
29302971
let rule: Arc<dyn PhysicalOptimizerRule + Send + Sync> =
29312972
Arc::new(OptimizerRuleWithSchemaCheck);
29322973

@@ -2940,37 +2981,119 @@ digraph {
29402981

29412982
// Test: check should pass with same schema
29422983
let equal_schema = ok_plan.schema();
2943-
InvariantChecker::new(&rule).check(&ok_plan, equal_schema)?;
2984+
OptimizationInvariantChecker::new(&rule).check(&ok_plan, equal_schema)?;
29442985

29452986
// Test: should fail with schema changed
29462987
let different_schema =
29472988
Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, false)]));
2948-
let expected_err = InvariantChecker::new(&rule)
2989+
let expected_err = OptimizationInvariantChecker::new(&rule)
29492990
.check(&ok_plan, different_schema)
29502991
.unwrap_err();
29512992
assert!(expected_err.to_string().contains("PhysicalOptimizer rule 'OptimizerRuleWithSchemaCheck' failed, due to generate a different schema"));
29522993

29532994
// Test: should fail when extension node fails its own invariant check
29542995
let failing_node: Arc<dyn ExecutionPlan> = Arc::new(InvariantFailsExtensionNode);
2955-
let expected_err = InvariantChecker::new(&rule)
2996+
let expected_err = OptimizationInvariantChecker::new(&rule)
29562997
.check(&failing_node, ok_plan.schema())
29572998
.unwrap_err();
29582999
assert!(expected_err
29593000
.to_string()
2960-
.contains("extension node failed it's user-defined invariant check"));
3001+
.contains("extension node failed it's user-defined always-invariant check"));
29613002

29623003
// Test: should fail when descendent extension node fails
29633004
let failing_node: Arc<dyn ExecutionPlan> = Arc::new(InvariantFailsExtensionNode);
29643005
let invalid_plan = ok_node.with_new_children(vec![
29653006
Arc::clone(&child).with_new_children(vec![Arc::clone(&failing_node)])?,
29663007
Arc::clone(&child),
29673008
])?;
2968-
let expected_err = InvariantChecker::new(&rule)
3009+
let expected_err = OptimizationInvariantChecker::new(&rule)
29693010
.check(&invalid_plan, ok_plan.schema())
29703011
.unwrap_err();
29713012
assert!(expected_err
29723013
.to_string()
2973-
.contains("extension node failed it's user-defined invariant check"));
3014+
.contains("extension node failed it's user-defined always-invariant check"));
3015+
3016+
Ok(())
3017+
}
3018+
3019+
/// Extension Node which fails the [`InvariantChecker`]
3020+
/// if, and only if, it is checked at [`InvariantLevel::Executable`].
3021+
#[derive(Debug)]
3022+
struct ExecutableInvariantFails;
3023+
impl ExecutionPlan for ExecutableInvariantFails {
3024+
fn name(&self) -> &str {
3025+
"ExecutableInvariantFails"
3026+
}
3027+
fn check_invariants(&self, check: InvariantLevel) -> Result<()> {
3028+
match check {
3029+
InvariantLevel::Always => Ok(()),
3030+
InvariantLevel::Executable => plan_err!(
3031+
"extension node failed it's user-defined executable-invariant check"
3032+
),
3033+
}
3034+
}
3035+
fn schema(&self) -> SchemaRef {
3036+
Arc::new(Schema::empty())
3037+
}
3038+
fn with_new_children(
3039+
self: Arc<Self>,
3040+
_children: Vec<Arc<dyn ExecutionPlan>>,
3041+
) -> Result<Arc<dyn ExecutionPlan>> {
3042+
unimplemented!()
3043+
}
3044+
fn as_any(&self) -> &dyn Any {
3045+
unimplemented!()
3046+
}
3047+
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
3048+
vec![]
3049+
}
3050+
fn properties(&self) -> &PlanProperties {
3051+
unimplemented!()
3052+
}
3053+
fn execute(
3054+
&self,
3055+
_partition: usize,
3056+
_context: Arc<TaskContext>,
3057+
) -> Result<SendableRecordBatchStream> {
3058+
unimplemented!()
3059+
}
3060+
}
3061+
impl DisplayAs for ExecutableInvariantFails {
3062+
fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
3063+
write!(f, "{}", self.name())
3064+
}
3065+
}
3066+
3067+
#[test]
3068+
fn test_invariant_checker_levels() -> Result<()> {
3069+
// plan that passes the always-invariant, but fails the executable check
3070+
let plan: Arc<dyn ExecutionPlan> = Arc::new(ExecutableInvariantFails);
3071+
3072+
// Test: check should pass with less stringent Always check
3073+
InvariantChecker(InvariantLevel::Always).check(&plan)?;
3074+
3075+
// Test: should fail the executable check
3076+
let expected_err = InvariantChecker(InvariantLevel::Executable)
3077+
.check(&plan)
3078+
.unwrap_err();
3079+
assert!(expected_err.to_string().contains(
3080+
"extension node failed it's user-defined executable-invariant check"
3081+
));
3082+
3083+
// Test: should fail when descendent extension node fails
3084+
let failing_node: Arc<dyn ExecutionPlan> = Arc::new(ExecutableInvariantFails);
3085+
let ok_node: Arc<dyn ExecutionPlan> = Arc::new(OkExtensionNode(vec![]));
3086+
let child = Arc::clone(&ok_node);
3087+
let plan = ok_node.with_new_children(vec![
3088+
Arc::clone(&child).with_new_children(vec![Arc::clone(&failing_node)])?,
3089+
Arc::clone(&child),
3090+
])?;
3091+
let expected_err = InvariantChecker(InvariantLevel::Executable)
3092+
.check(&plan)
3093+
.unwrap_err();
3094+
assert!(expected_err.to_string().contains(
3095+
"extension node failed it's user-defined executable-invariant check"
3096+
));
29743097

29753098
Ok(())
29763099
}

datafusion/physical-plan/src/execution_plan.rs

+12-2
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,7 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync {
115115
///
116116
/// A default set of invariants is provided in the default implementation.
117117
/// Extension nodes can provide their own invariants.
118-
fn check_node_invariants(&self) -> Result<()> {
119-
// TODO
118+
fn check_invariants(&self, _check: InvariantLevel) -> Result<()> {
120119
Ok(())
121120
}
122121

@@ -434,6 +433,17 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync {
434433
}
435434
}
436435

436+
#[derive(Clone, Copy)]
437+
pub enum InvariantLevel {
438+
/// Invariants that are always true for the [`ExecutionPlan`] node
439+
/// such as the number of expected children.
440+
Always,
441+
/// Invariants that must hold true for the [`ExecutionPlan`] node
442+
/// to be "executable", such as ordering and/or distribution requirements
443+
/// being fulfilled.
444+
Executable,
445+
}
446+
437447
/// Extension trait provides an easy API to fetch various properties of
438448
/// [`ExecutionPlan`] objects based on [`ExecutionPlan::properties`].
439449
pub trait ExecutionPlanProperties {

0 commit comments

Comments
 (0)