diff --git a/optd-core/src/cascades.rs b/optd-core/src/cascades.rs index 8ee12e4a..c28af5b5 100644 --- a/optd-core/src/cascades.rs +++ b/optd-core/src/cascades.rs @@ -9,7 +9,9 @@ mod memo; mod optimizer; mod tasks; -pub use memo::{ArcMemoPlanNode, Group, GroupInfo, Memo, MemoPlanNode, NaiveMemo, Winner}; +pub use memo::{ + ArcMemoPlanNode, Group, GroupInfo, Memo, MemoPlanNode, NaiveMemo, Winner, WinnerInfo, +}; pub use optimizer::{ CascadesOptimizer, ExprId, GroupId, OptimizerProperties, PredId, RelNodeContext, }; diff --git a/optd-core/src/cascades/memo.rs b/optd-core/src/cascades/memo.rs index a4a0e13f..5a10bc59 100644 --- a/optd-core/src/cascades/memo.rs +++ b/optd-core/src/cascades/memo.rs @@ -9,6 +9,7 @@ use std::sync::Arc; use anyhow::{bail, Context, Result}; use itertools::Itertools; +use serde::{Deserialize, Serialize}; use tracing::trace; use super::optimizer::{ExprId, GroupId, PredId}; diff --git a/optd-core/src/cost.rs b/optd-core/src/cost.rs index f68386d1..1226fc1f 100644 --- a/optd-core/src/cost.rs +++ b/optd-core/src/cost.rs @@ -3,6 +3,8 @@ // Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +use serde::{Deserialize, Serialize}; + use crate::cascades::{CascadesOptimizer, Memo, RelNodeContext}; use crate::nodes::{ArcPredNode, NodeType}; @@ -12,7 +14,7 @@ pub struct Statistics(pub Box); /// The cost of an operation. The cost is represented as a vector of double values. /// For example, it can be represented as `[compute_cost, io_cost]`. /// A lower value means a better cost. -#[derive(Default, Clone, Debug, PartialOrd, PartialEq)] +#[derive(Default, Clone, Debug, PartialOrd, PartialEq, Serialize, Deserialize)] pub struct Cost(pub Vec); pub trait CostModel>: 'static + Send + Sync { diff --git a/optd-persistent-memo/src/backend_manager.rs b/optd-persistent-memo/src/backend_manager.rs index c2de3ddb..5ac82595 100644 --- a/optd-persistent-memo/src/backend_manager.rs +++ b/optd-persistent-memo/src/backend_manager.rs @@ -1,7 +1,8 @@ +use optd_core::{cascades::WinnerInfo, cost::Cost}; use optd_persistent::{ entities::{ cascades_group, group_winner, logical_children, logical_expression, physical_expression, - predicate, predicate_children, predicate_logical_expression_junction, + plan_cost, predicate, predicate_children, predicate_logical_expression_junction, predicate_physical_expression_junction, }, StorageResult, @@ -10,17 +11,23 @@ use sea_orm::{ ActiveModelTrait, ColumnTrait, Database, DatabaseConnection, EntityTrait, PaginatorTrait, QueryFilter, Set, }; - -pub struct BackendWinnerInfo { - pub expr_id: i32, - pub cost: f64, -} +use serde::{Deserialize, Serialize}; pub struct BackendGroupInfo { pub group_exprs: Vec, pub winner: Option, } +#[derive(Clone, Serialize, Deserialize)] +pub struct BackendWinnerInfo { + pub physical_expr_id: i32, + pub total_weighted_cost: f64, + pub operation_weighted_cost: f64, + pub total_cost: Cost, + pub operation_cost: Cost, + // TODO: Statistics +} + #[derive(serde::Serialize, serde::Deserialize)] pub struct PredicateData { pub children_ids: Vec, @@ -40,13 +47,20 @@ impl MemoBackendManager { }) } - async fn get_winner(&self, winner_id: i32) -> StorageResult> { - let winner = group_winner::Entity::find() - .filter(group_winner::Column::Id.eq(winner_id)) - .one(&self.db) - .await?; + async fn get_winner( + &self, + winner_id: Option, + ) -> StorageResult> { + if let Some(winner_id) = winner_id { + let winner = group_winner::Entity::find() + .filter(group_winner::Column::Id.eq(winner_id)) + .one(&self.db) + .await?; - Ok(winner) + Ok(Some(winner.unwrap())) + } else { + Ok(None) + } } pub async fn get_group(&self, group_id: i32) -> StorageResult { @@ -66,19 +80,59 @@ impl MemoBackendManager { .map(|child| child.logical_expression_id) .collect(); - let winner = self.get_winner(group.latest_winner.unwrap()).await?; + let winner = self.get_winner(group.latest_winner).await?; Ok(BackendGroupInfo { group_exprs: children, winner: winner.map(|winner| BackendWinnerInfo { - expr_id: winner.physical_expression_id, - cost: 100.0, // TODO + physical_expr_id: winner.physical_expression_id, + total_weighted_cost: 0.0, // TODO: Cost is not saved + operation_weighted_cost: 0.0, + total_cost: Cost(vec![]), + operation_cost: Cost(vec![]), }), }) } - pub async fn update_winner(&self) { - todo!() + pub async fn update_winner( + &self, + group_id: i32, + winner: Option, + ) -> StorageResult<()> { + // TODO: Do we need garbage collection? + let new_winner = if let Some(winner) = winner { + let cost = plan_cost::ActiveModel { + physical_expression_id: Set(winner.physical_expr_id), + ..Default::default() + }; // TODO: how can we store cost?? + + let cost_res = cost.insert(&self.db).await?; + + let winner = group_winner::ActiveModel { + group_id: Set(group_id), + physical_expression_id: Set(winner.physical_expr_id), + cost_id: Set(cost_res.id), + ..Default::default() + }; + let winner_res = winner.insert(&self.db).await?; + + Some(winner_res.id) + } else { + None + }; + + // update + let mut group_res: cascades_group::ActiveModel = cascades_group::Entity::find() + .filter(cascades_group::Column::Id.eq(group_id)) + .one(&self.db) + .await? + .unwrap() + .into(); + + group_res.latest_winner = Set(new_winner); + group_res.update(&self.db).await?; + + return Ok(()); } pub async fn get_expr_count(&self) -> StorageResult { diff --git a/optd-persistent-memo/src/lib.rs b/optd-persistent-memo/src/lib.rs index 641ad456..bc6afe49 100644 --- a/optd-persistent-memo/src/lib.rs +++ b/optd-persistent-memo/src/lib.rs @@ -1,11 +1,13 @@ mod backend_manager; +use core::panic; use std::{collections::HashMap, marker::PhantomData, sync::Arc}; -use backend_manager::{MemoBackendManager, PredicateData}; +use backend_manager::{BackendWinnerInfo, MemoBackendManager, PredicateData}; use futures_lite::future; use optd_core::{ - cascades::{self, ExprId, GroupId, GroupInfo, Memo, MemoPlanNode, PredId}, + cascades::{self, ExprId, GroupId, GroupInfo, Memo, MemoPlanNode, PredId, WinnerInfo}, + cost::Statistics, nodes::{ self, ArcPlanNode, NodeType, PlanNodeOrGroup, PredNode, SerializedNodeTag, SerializedPredTag, @@ -316,7 +318,17 @@ impl Memo for PersistentMemo { .map(|expr_id| ExprId((*expr_id).try_into().unwrap())) .collect(), info: GroupInfo { - winner: cascades::Winner::Unknown, // TODO + winner: match orm_group.winner { + Some(winner) => cascades::Winner::Full(WinnerInfo { + expr_id: ExprId(self.phys_id_to_expr_id[&winner.physical_expr_id].0), + total_weighted_cost: winner.total_weighted_cost, + operation_weighted_cost: winner.operation_weighted_cost, + total_cost: winner.total_cost, + operation_cost: winner.operation_cost, + statistics: todo!(), + }), + None => cascades::Winner::Unknown, + }, }, properties: Arc::new([]), // TODO }; @@ -329,8 +341,33 @@ impl Memo for PersistentMemo { } fn update_group_info(&mut self, group_id: GroupId, group_info: cascades::GroupInfo) { - // TODO: This might require a bigger redesign - todo!() + // TODO: What of in_progress, is_optimized + + match group_info.winner { + cascades::Winner::Unknown => {} + cascades::Winner::Impossible => { + panic!("Impossible winner not supported in persistent memo yet"); + } + cascades::Winner::Full(mut info) => { + let phys_expr_id = self + .expr_id_to_log_phys_id + .get(&info.expr_id) + .expect("winner expr id not found in expr->logphys mapping") + .0; + + let backend_info = BackendWinnerInfo { + physical_expr_id: phys_expr_id, + total_weighted_cost: info.total_weighted_cost, + operation_weighted_cost: info.operation_weighted_cost, + total_cost: info.total_cost, + operation_cost: info.operation_cost, + // TODO: Statistics + }; + + let group_id = group_id.0.try_into().unwrap(); + future::block_on(self.storage.update_winner(group_id, Some(backend_info))).unwrap(); + } + } } fn estimated_plan_space(&self) -> usize {