Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

refactor(core): Standardized task graph #234

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion-optd-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ pub async fn main() -> Result<()> {
let args = Args::parse();

tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: revert this

.with_max_level(tracing::Level::TRACE)
.with_target(false)
.with_ansi(false)
.init();
Expand Down
118 changes: 91 additions & 27 deletions optd-core/src/cascades/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

use std::collections::{BTreeSet, HashMap, HashSet, VecDeque};
use std::fmt::Display;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use anyhow::Result;
Expand All @@ -14,13 +15,14 @@ use super::memo::{ArcMemoPlanNode, GroupInfo, Memo};
use super::tasks::OptimizeGroupTask;
use super::{NaiveMemo, Task};
use crate::cascades::memo::Winner;
use crate::cascades::tasks::get_initial_task;
use crate::cost::CostModel;
use crate::nodes::{
ArcPlanNode, ArcPredNode, NodeType, PlanNodeMeta, PlanNodeMetaMap, PlanNodeOrGroup,
};
use crate::optimizer::Optimizer;
use crate::property::{PropertyBuilder, PropertyBuilderAny};
use crate::rules::Rule;
use crate::rules::{Rule, RuleMatcher};

pub type RuleId = usize;

Expand All @@ -43,11 +45,19 @@ pub struct OptimizerProperties {

pub struct CascadesOptimizer<T: NodeType, M: Memo<T> = NaiveMemo<T>> {
memo: M,
pub(super) tasks: VecDeque<Box<dyn Task<T, M>>>,
/// Stack of tasks that are waiting to be executed
tasks: Vec<Box<dyn Task<T, M>>>,
/// Monotonically increasing counter for task invocations
task_counter: AtomicUsize,
explored_group: HashSet<GroupId>,
explored_expr: HashSet<ExprId>,
fired_rules: HashMap<ExprId, HashSet<RuleId>>,
rules: Arc<[Arc<dyn Rule<T, Self>>]>,
applied_rules: HashMap<ExprId, HashSet<RuleId>>,
/// Transformation rules that may be used while exploring
/// (logical -> logical)
transformation_rules: Arc<[(RuleId, Arc<dyn Rule<T, Self>>)]>,
/// Implementation rules that may be used while optimizing
/// (logical -> physical)
implementation_rules: Arc<[(RuleId, Arc<dyn Rule<T, Self>>)]>,
disabled_rules: HashSet<usize>,
cost: Arc<dyn CostModel<T, M>>,
property_builders: Arc<[Box<dyn PropertyBuilderAny<T>>]>,
Expand Down Expand Up @@ -94,29 +104,52 @@ impl Display for PredId {

impl<T: NodeType> CascadesOptimizer<T, NaiveMemo<T>> {
pub fn new(
rules: Vec<Arc<dyn Rule<T, Self>>>,
transformation_rules: Arc<[Arc<dyn Rule<T, Self>>]>,
implementation_rules: Arc<[Arc<dyn Rule<T, Self>>]>,
cost: Box<dyn CostModel<T, NaiveMemo<T>>>,
property_builders: Vec<Box<dyn PropertyBuilderAny<T>>>,
) -> Self {
Self::new_with_prop(rules, cost, property_builders, Default::default())
Self::new_with_prop(
transformation_rules,
implementation_rules,
cost,
property_builders,
Default::default(),
)
}

pub fn new_with_prop(
rules: Vec<Arc<dyn Rule<T, Self>>>,
transformation_rules: Arc<[Arc<dyn Rule<T, Self>>]>,
implementation_rules: Arc<[Arc<dyn Rule<T, Self>>]>,
cost: Box<dyn CostModel<T, NaiveMemo<T>>>,
property_builders: Vec<Box<dyn PropertyBuilderAny<T>>>,
prop: OptimizerProperties,
) -> Self {
let tasks = VecDeque::new();
let tasks = Vec::new();
// Assign rule IDs
let transformation_rules: Arc<[(RuleId, Arc<dyn Rule<T, Self>>)]> = transformation_rules
.into_iter()
.enumerate()
.map(|(i, r)| (i, r.clone()))
.collect();
let implementation_rules: Arc<[(RuleId, Arc<dyn Rule<T, Self>>)]> = implementation_rules
.into_iter()
.enumerate()
.map(|(i, r)| (i + transformation_rules.len(), r.clone()))
.collect();
debug_assert!(transformation_rules.iter().all(|(_, r)| !r.is_impl_rule()));
debug_assert!(implementation_rules.iter().all(|(_, r)| r.is_impl_rule()));
let property_builders: Arc<[_]> = property_builders.into();
let memo = NaiveMemo::new(property_builders.clone());
Self {
memo,
task_counter: AtomicUsize::new(0),
tasks,
explored_group: HashSet::new(),
explored_expr: HashSet::new(),
fired_rules: HashMap::new(),
rules: rules.into(),
applied_rules: HashMap::new(),
transformation_rules,
implementation_rules,
cost: cost.into(),
ctx: OptimizerContext::default(),
property_builders,
Expand All @@ -128,7 +161,7 @@ impl<T: NodeType> CascadesOptimizer<T, NaiveMemo<T>> {
/// Reset the optimizer to a pristine state: clears all per-run bookkeeping
/// (applied rules, explored groups, explored expressions) and replaces the
/// memo table with a fresh one built over the same property builders.
pub fn step_clear(&mut self) {
    self.applied_rules.clear();
    self.explored_group.clear();
    self.explored_expr.clear();
    self.memo = NaiveMemo::new(self.property_builders.clone());
}
Expand All @@ -153,8 +186,12 @@ impl<T: NodeType, M: Memo<T>> CascadesOptimizer<T, M> {
self.cost.clone()
}

pub fn rules(&self) -> Arc<[Arc<dyn Rule<T, Self>>]> {
self.rules.clone()
/// Returns a shared handle to the transformation (logical -> logical)
/// rules, each paired with its assigned `RuleId`.
pub fn transformation_rules(&self) -> Arc<[(RuleId, Arc<dyn Rule<T, Self>>)]> {
    // Arc::clone makes the cheap refcount-bump explicit.
    Arc::clone(&self.transformation_rules)
}

/// Returns a shared handle to the implementation (logical -> physical)
/// rules, each paired with its assigned `RuleId`.
pub fn implementation_rules(&self) -> Arc<[(RuleId, Arc<dyn Rule<T, Self>>)]> {
    // Arc::clone makes the cheap refcount-bump explicit.
    Arc::clone(&self.implementation_rules)
}

pub fn disable_rule(&mut self, rule_id: usize) {
Expand Down Expand Up @@ -215,7 +252,7 @@ impl<T: NodeType, M: Memo<T>> CascadesOptimizer<T, M> {
/// Optimize a `RelNode`.
pub fn step_optimize_rel(&mut self, root_rel: ArcPlanNode<T>) -> Result<GroupId> {
let (group_id, _) = self.add_new_expr(root_rel);
self.fire_optimize_tasks(group_id)?;
self.fire_optimize_tasks(group_id);
Ok(group_id)
}

Expand Down Expand Up @@ -247,17 +284,30 @@ impl<T: NodeType, M: Memo<T>> CascadesOptimizer<T, M> {
res
}

fn fire_optimize_tasks(&mut self, group_id: GroupId) -> Result<()> {
trace!(event = "fire_optimize_tasks", root_group_id = %group_id);
self.tasks
.push_back(Box::new(OptimizeGroupTask::new(group_id)));
/// Returns a fresh, monotonically increasing task invocation ID.
///
/// Backed by an atomic counter so IDs remain unique even though this
/// method only takes `&self`.
// NOTE(review): `AcqRel` is stronger than a pure ID counter needs —
// `Relaxed` already guarantees uniqueness. Confirm no cross-thread
// ordering depends on this before weakening.
pub fn get_next_task_id(&self) -> usize {
self.task_counter.fetch_add(1, Ordering::AcqRel)
}

/// Pushes a task onto the task stack. Tasks are executed LIFO (see
/// `pop_task`), giving depth-first task scheduling.
pub fn push_task(&mut self, task: Box<dyn Task<T, M>>) {
self.tasks.push(task);
}

/// Pops the most recently pushed task, or `None` when the stack is empty
/// (i.e. the optimization run has no work left).
fn pop_task(&mut self) -> Option<Box<dyn Task<T, M>>> {
self.tasks.pop()
}

fn fire_optimize_tasks(&mut self, root_group_id: GroupId) {
trace!(event = "fire_optimize_tasks", root_group_id = %root_group_id);
let initial_task_id = self.get_next_task_id();
self.push_task(get_initial_task(initial_task_id, root_group_id));
// get the task from the stack
self.ctx.budget_used = false;
let plan_space_begin = self.memo.estimated_plan_space();
let mut iter = 0;
while let Some(task) = self.tasks.pop_back() {
let new_tasks = task.execute(self)?;
self.tasks.extend(new_tasks);
while let Some(task) = self.pop_task() {
task.execute(self);

// TODO: Iter is wrong
iter += 1;
if !self.ctx.budget_used {
let plan_space = self.memo.estimated_plan_space();
Expand Down Expand Up @@ -286,12 +336,11 @@ impl<T: NodeType, M: Memo<T>> CascadesOptimizer<T, M> {
}
}
}
Ok(())
}

/// Adds `root_rel` to the memo table, drives the task-based optimization
/// loop rooted at its group, and extracts the best binding for that group.
fn optimize_inner(&mut self, root_rel: ArcPlanNode<T>) -> Result<ArcPlanNode<T>> {
let (group_id, _) = self.add_new_expr(root_rel);
self.fire_optimize_tasks(group_id);
// Errors here surface as "no winner" — e.g. when no physical plan was found.
self.memo.get_best_group_binding(group_id, |_, _, _| {})
}

Expand Down Expand Up @@ -374,15 +423,15 @@ impl<T: NodeType, M: Memo<T>> CascadesOptimizer<T, M> {
self.explored_expr.remove(&expr_id);
}

pub(super) fn is_rule_fired(&self, group_expr_id: ExprId, rule_id: RuleId) -> bool {
self.fired_rules
/// Returns `true` if `rule_id` has already been applied to the expression
/// `group_expr_id`, preventing the same rule from firing twice on one
/// expression.
pub(super) fn is_rule_applied(&self, group_expr_id: ExprId, rule_id: RuleId) -> bool {
    match self.applied_rules.get(&group_expr_id) {
        Some(rules) => rules.contains(&rule_id),
        None => false,
    }
}

pub(super) fn mark_rule_fired(&mut self, group_expr_id: ExprId, rule_id: RuleId) {
self.fired_rules
pub(super) fn mark_rule_applied(&mut self, group_expr_id: ExprId, rule_id: RuleId) {
self.applied_rules
.entry(group_expr_id)
.or_default()
.insert(rule_id);
Expand All @@ -406,3 +455,18 @@ impl<T: NodeType, M: Memo<T>> Optimizer<T> for CascadesOptimizer<T, M> {
self.get_property_by_group::<P>(self.resolve_group_id(root_rel), idx)
}
}

/// Checks whether `rule` could match `expr` by comparing the rule
/// matcher's root against the expression's node type.
///
/// Only the root of the matcher is inspected here; children are checked
/// when the rule is actually bound and applied.
///
/// # Panics
///
/// Panics if the rule's matcher does not begin with `MatchNode` or
/// `MatchDiscriminant` — every rule is required to match on a concrete
/// node type (or type discriminant) at its root.
pub fn rule_matches_expr<T: NodeType, M: Memo<T>>(
    rule: &Arc<dyn Rule<T, CascadesOptimizer<T, M>>>,
    expr: &ArcMemoPlanNode<T>,
) -> bool {
    let matcher = rule.matcher();
    let typ_to_match = &expr.typ;
    match matcher {
        RuleMatcher::MatchNode { typ, .. } => typ == typ_to_match,
        RuleMatcher::MatchDiscriminant {
            typ_discriminant, ..
        } => *typ_discriminant == std::mem::discriminant(typ_to_match),
        _ => panic!("rule matcher must have MatchNode or MatchDiscriminant at its root"),
    }
}
23 changes: 17 additions & 6 deletions optd-core/src/cascades/tasks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,35 @@

use anyhow::Result;

use super::{CascadesOptimizer, Memo};
use super::{CascadesOptimizer, GroupId, Memo};
use crate::nodes::NodeType;

mod apply_rule;
mod explore_expr;
mod explore_group;
mod optimize_expression;
mod optimize_expr;
mod optimize_group;
mod optimize_inputs;

pub use apply_rule::ApplyRuleTask;
pub use explore_expr::ExploreExprTask;
pub use explore_group::ExploreGroupTask;
pub use optimize_expression::OptimizeExpressionTask;
pub use optimize_expr::OptimizeExprTask;
pub use optimize_group::OptimizeGroupTask;
pub use optimize_inputs::OptimizeInputsTask;

pub trait Task<T: NodeType, M: Memo<T>>: 'static + Send + Sync {
fn execute(&self, optimizer: &mut CascadesOptimizer<T, M>) -> Result<Vec<Box<dyn Task<T, M>>>>;
fn execute(&self, optimizer: &mut CascadesOptimizer<T, M>);
}

#[allow(dead_code)]
fn describe(&self) -> String;
/// Builds the root task of an optimization run: an `OptimizeGroupTask`
/// for `root_group_id` with the given task ID.
// NOTE(review): the two `None` arguments presumably mean "no parent task"
// and "no cost budget" — confirm against OptimizeGroupTask::new.
pub fn get_initial_task<T: NodeType, M: Memo<T>>(
    initial_task_id: usize,
    root_group_id: GroupId,
) -> Box<dyn Task<T, M>> {
    let root_task = OptimizeGroupTask::new(None, initial_task_id, root_group_id, None);
    Box::new(root_task)
}
Loading
Loading