Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(core): Standardized task graph #234

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion-optd-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ pub async fn main() -> Result<()> {
let args = Args::parse();

tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: revert this

.with_max_level(tracing::Level::TRACE)
.with_target(false)
.with_ansi(false)
.init();
Expand Down
118 changes: 91 additions & 27 deletions optd-core/src/cascades/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

use std::collections::{BTreeSet, HashMap, HashSet, VecDeque};
use std::fmt::Display;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use anyhow::Result;
Expand All @@ -14,13 +15,14 @@ use super::memo::{ArcMemoPlanNode, GroupInfo, Memo};
use super::tasks::OptimizeGroupTask;
use super::{NaiveMemo, Task};
use crate::cascades::memo::Winner;
use crate::cascades::tasks::get_initial_task;
use crate::cost::CostModel;
use crate::nodes::{
ArcPlanNode, ArcPredNode, NodeType, PlanNodeMeta, PlanNodeMetaMap, PlanNodeOrGroup,
};
use crate::optimizer::Optimizer;
use crate::property::{PropertyBuilder, PropertyBuilderAny};
use crate::rules::Rule;
use crate::rules::{Rule, RuleMatcher};

pub type RuleId = usize;

Expand All @@ -43,11 +45,19 @@ pub struct OptimizerProperties {

pub struct CascadesOptimizer<T: NodeType, M: Memo<T> = NaiveMemo<T>> {
memo: M,
pub(super) tasks: VecDeque<Box<dyn Task<T, M>>>,
/// Stack of tasks that are waiting to be executed
tasks: Vec<Box<dyn Task<T, M>>>,
/// Monotonically increasing counter for task invocations
task_counter: AtomicUsize,
explored_group: HashSet<GroupId>,
explored_expr: HashSet<ExprId>,
fired_rules: HashMap<ExprId, HashSet<RuleId>>,
rules: Arc<[Arc<dyn Rule<T, Self>>]>,
applied_rules: HashMap<ExprId, HashSet<RuleId>>,
/// Transformation rules that may be used while exploring
/// (logical -> logical)
transformation_rules: Arc<[(RuleId, Arc<dyn Rule<T, Self>>)]>,
/// Implementation rules that may be used while optimizing
/// (logical -> physical)
implementation_rules: Arc<[(RuleId, Arc<dyn Rule<T, Self>>)]>,
disabled_rules: HashSet<usize>,
cost: Arc<dyn CostModel<T, M>>,
property_builders: Arc<[Box<dyn PropertyBuilderAny<T>>]>,
Expand Down Expand Up @@ -94,29 +104,52 @@ impl Display for PredId {

impl<T: NodeType> CascadesOptimizer<T, NaiveMemo<T>> {
pub fn new(
rules: Vec<Arc<dyn Rule<T, Self>>>,
transformation_rules: Arc<[Arc<dyn Rule<T, Self>>]>,
implementation_rules: Arc<[Arc<dyn Rule<T, Self>>]>,
cost: Box<dyn CostModel<T, NaiveMemo<T>>>,
property_builders: Vec<Box<dyn PropertyBuilderAny<T>>>,
) -> Self {
Self::new_with_prop(rules, cost, property_builders, Default::default())
Self::new_with_prop(
transformation_rules,
implementation_rules,
cost,
property_builders,
Default::default(),
)
}

pub fn new_with_prop(
rules: Vec<Arc<dyn Rule<T, Self>>>,
transformation_rules: Arc<[Arc<dyn Rule<T, Self>>]>,
implementation_rules: Arc<[Arc<dyn Rule<T, Self>>]>,
cost: Box<dyn CostModel<T, NaiveMemo<T>>>,
property_builders: Vec<Box<dyn PropertyBuilderAny<T>>>,
prop: OptimizerProperties,
) -> Self {
let tasks = VecDeque::new();
let tasks = Vec::new();
// Assign rule IDs
let transformation_rules: Arc<[(RuleId, Arc<dyn Rule<T, Self>>)]> = transformation_rules
.into_iter()
.enumerate()
.map(|(i, r)| (i, r.clone()))
.collect();
let implementation_rules: Arc<[(RuleId, Arc<dyn Rule<T, Self>>)]> = implementation_rules
.into_iter()
.enumerate()
.map(|(i, r)| (i + transformation_rules.len(), r.clone()))
.collect();
debug_assert!(transformation_rules.iter().all(|(_, r)| !r.is_impl_rule()));
debug_assert!(implementation_rules.iter().all(|(_, r)| r.is_impl_rule()));
let property_builders: Arc<[_]> = property_builders.into();
let memo = NaiveMemo::new(property_builders.clone());
Self {
memo,
task_counter: AtomicUsize::new(0),
tasks,
explored_group: HashSet::new(),
explored_expr: HashSet::new(),
fired_rules: HashMap::new(),
rules: rules.into(),
applied_rules: HashMap::new(),
transformation_rules,
implementation_rules,
cost: cost.into(),
ctx: OptimizerContext::default(),
property_builders,
Expand All @@ -128,7 +161,7 @@ impl<T: NodeType> CascadesOptimizer<T, NaiveMemo<T>> {
/// Clear the memo table and all optimizer states.
pub fn step_clear(&mut self) {
self.memo = NaiveMemo::new(self.property_builders.clone());
self.fired_rules.clear();
self.applied_rules.clear();
self.explored_group.clear();
self.explored_expr.clear();
}
Expand All @@ -153,8 +186,12 @@ impl<T: NodeType, M: Memo<T>> CascadesOptimizer<T, M> {
self.cost.clone()
}

pub fn rules(&self) -> Arc<[Arc<dyn Rule<T, Self>>]> {
self.rules.clone()
pub fn transformation_rules(&self) -> Arc<[(RuleId, Arc<dyn Rule<T, Self>>)]> {
self.transformation_rules.clone()
}

pub fn implementation_rules(&self) -> Arc<[(RuleId, Arc<dyn Rule<T, Self>>)]> {
self.implementation_rules.clone()
}

pub fn disable_rule(&mut self, rule_id: usize) {
Expand Down Expand Up @@ -215,7 +252,7 @@ impl<T: NodeType, M: Memo<T>> CascadesOptimizer<T, M> {
/// Optimize a `RelNode`.
pub fn step_optimize_rel(&mut self, root_rel: ArcPlanNode<T>) -> Result<GroupId> {
let (group_id, _) = self.add_new_expr(root_rel);
self.fire_optimize_tasks(group_id)?;
self.fire_optimize_tasks(group_id);
Ok(group_id)
}

Expand Down Expand Up @@ -247,17 +284,30 @@ impl<T: NodeType, M: Memo<T>> CascadesOptimizer<T, M> {
res
}

fn fire_optimize_tasks(&mut self, group_id: GroupId) -> Result<()> {
trace!(event = "fire_optimize_tasks", root_group_id = %group_id);
self.tasks
.push_back(Box::new(OptimizeGroupTask::new(group_id)));
pub fn get_next_task_id(&self) -> usize {
self.task_counter.fetch_add(1, Ordering::AcqRel)
}

pub fn push_task(&mut self, task: Box<dyn Task<T, M>>) {
self.tasks.push(task);
}

fn pop_task(&mut self) -> Option<Box<dyn Task<T, M>>> {
self.tasks.pop()
}

fn fire_optimize_tasks(&mut self, root_group_id: GroupId) {
trace!(event = "fire_optimize_tasks", root_group_id = %root_group_id);
let initial_task_id = self.get_next_task_id();
self.push_task(get_initial_task(initial_task_id, root_group_id));
// get the task from the stack
self.ctx.budget_used = false;
let plan_space_begin = self.memo.estimated_plan_space();
let mut iter = 0;
while let Some(task) = self.tasks.pop_back() {
let new_tasks = task.execute(self)?;
self.tasks.extend(new_tasks);
while let Some(task) = self.pop_task() {
task.execute(self);

// TODO: Iter is wrong
iter += 1;
if !self.ctx.budget_used {
let plan_space = self.memo.estimated_plan_space();
Expand Down Expand Up @@ -286,12 +336,11 @@ impl<T: NodeType, M: Memo<T>> CascadesOptimizer<T, M> {
}
}
}
Ok(())
}

fn optimize_inner(&mut self, root_rel: ArcPlanNode<T>) -> Result<ArcPlanNode<T>> {
let (group_id, _) = self.add_new_expr(root_rel);
self.fire_optimize_tasks(group_id)?;
self.fire_optimize_tasks(group_id);
self.memo.get_best_group_binding(group_id, |_, _, _| {})
}

Expand Down Expand Up @@ -374,15 +423,15 @@ impl<T: NodeType, M: Memo<T>> CascadesOptimizer<T, M> {
self.explored_expr.remove(&expr_id);
}

pub(super) fn is_rule_fired(&self, group_expr_id: ExprId, rule_id: RuleId) -> bool {
self.fired_rules
pub(super) fn is_rule_applied(&self, group_expr_id: ExprId, rule_id: RuleId) -> bool {
self.applied_rules
.get(&group_expr_id)
.map(|rules| rules.contains(&rule_id))
.unwrap_or(false)
}

pub(super) fn mark_rule_fired(&mut self, group_expr_id: ExprId, rule_id: RuleId) {
self.fired_rules
pub(super) fn mark_rule_applied(&mut self, group_expr_id: ExprId, rule_id: RuleId) {
self.applied_rules
.entry(group_expr_id)
.or_default()
.insert(rule_id);
Expand All @@ -406,3 +455,18 @@ impl<T: NodeType, M: Memo<T>> Optimizer<T> for CascadesOptimizer<T, M> {
self.get_property_by_group::<P>(self.resolve_group_id(root_rel), idx)
}
}

pub fn rule_matches_expr<T: NodeType, M: Memo<T>>(
rule: &Arc<dyn Rule<T, CascadesOptimizer<T, M>>>,
expr: &ArcMemoPlanNode<T>,
) -> bool {
let matcher = rule.matcher();
let typ_to_match = &expr.typ;
match matcher {
RuleMatcher::MatchNode { typ, .. } => typ == typ_to_match,
RuleMatcher::MatchDiscriminant {
typ_discriminant, ..
} => *typ_discriminant == std::mem::discriminant(typ_to_match),
_ => panic!("IR should have root node of match"), // TODO: what does this mean? replace text
}
}
23 changes: 17 additions & 6 deletions optd-core/src/cascades/tasks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,35 @@

use anyhow::Result;

use super::{CascadesOptimizer, Memo};
use super::{CascadesOptimizer, GroupId, Memo};
use crate::nodes::NodeType;

mod apply_rule;
mod explore_expr;
mod explore_group;
mod optimize_expression;
mod optimize_expr;
mod optimize_group;
mod optimize_inputs;

pub use apply_rule::ApplyRuleTask;
pub use explore_expr::ExploreExprTask;
pub use explore_group::ExploreGroupTask;
pub use optimize_expression::OptimizeExpressionTask;
pub use optimize_expr::OptimizeExprTask;
pub use optimize_group::OptimizeGroupTask;
pub use optimize_inputs::OptimizeInputsTask;

pub trait Task<T: NodeType, M: Memo<T>>: 'static + Send + Sync {
fn execute(&self, optimizer: &mut CascadesOptimizer<T, M>) -> Result<Vec<Box<dyn Task<T, M>>>>;
fn execute(&self, optimizer: &mut CascadesOptimizer<T, M>);
}

#[allow(dead_code)]
fn describe(&self) -> String;
pub fn get_initial_task<T: NodeType, M: Memo<T>>(
initial_task_id: usize,
root_group_id: GroupId,
) -> Box<dyn Task<T, M>> {
Box::new(OptimizeGroupTask::new(
None,
initial_task_id,
root_group_id,
None,
))
}
Loading
Loading