Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Winner support for orm memo table #250

Draft
wants to merge 2 commits into
base: yuchen/storage-backed-memo
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion optd-core/src/cascades.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ mod memo;
mod optimizer;
mod tasks;

pub use memo::{ArcMemoPlanNode, Group, GroupInfo, Memo, MemoPlanNode, NaiveMemo, Winner};
pub use memo::{
ArcMemoPlanNode, Group, GroupInfo, Memo, MemoPlanNode, NaiveMemo, Winner, WinnerInfo,
};
pub use optimizer::{
CascadesOptimizer, ExprId, GroupId, OptimizerProperties, PredId, RelNodeContext,
};
Expand Down
1 change: 1 addition & 0 deletions optd-core/src/cascades/memo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::sync::Arc;

use anyhow::{bail, Context, Result};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use tracing::trace;

use super::optimizer::{ExprId, GroupId, PredId};
Expand Down
4 changes: 3 additions & 1 deletion optd-core/src/cost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

use serde::{Deserialize, Serialize};

use crate::cascades::{CascadesOptimizer, Memo, RelNodeContext};
use crate::nodes::{ArcPredNode, NodeType};

Expand All @@ -12,7 +14,7 @@ pub struct Statistics(pub Box<dyn std::any::Any + Send + Sync + 'static>);
/// The cost of an operation. The cost is represented as a vector of double values.
/// For example, it can be represented as `[compute_cost, io_cost]`.
/// A lower value means a better cost.
///
/// Ordering comes from the derived `PartialOrd`, i.e. the component vectors are compared
/// lexicographically. `Serialize`/`Deserialize` are derived so that a cost can be embedded in
/// other serde-serializable types.
#[derive(Default, Clone, Debug, PartialOrd, PartialEq, Serialize, Deserialize)]
pub struct Cost(pub Vec<f64>);

pub trait CostModel<T: NodeType, M: Memo<T>>: 'static + Send + Sync {
Expand Down
88 changes: 71 additions & 17 deletions optd-persistent-memo/src/backend_manager.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use optd_core::{cascades::WinnerInfo, cost::Cost};
use optd_persistent::{
entities::{
cascades_group, group_winner, logical_children, logical_expression, physical_expression,
predicate, predicate_children, predicate_logical_expression_junction,
plan_cost, predicate, predicate_children, predicate_logical_expression_junction,
predicate_physical_expression_junction,
},
StorageResult,
Expand All @@ -10,17 +11,23 @@ use sea_orm::{
ActiveModelTrait, ColumnTrait, Database, DatabaseConnection, EntityTrait, PaginatorTrait,
QueryFilter, Set,
};

use serde::{Deserialize, Serialize};

/// Group data as read back from the ORM backend.
pub struct BackendGroupInfo {
    /// Ids of the logical expressions that belong to the group.
    pub group_exprs: Vec<i32>,
    /// The winner the group currently points at, or `None` when no winner is recorded.
    pub winner: Option<BackendWinnerInfo>,
}

/// Backend-side description of a group's winner.
///
/// The cost fields carry the same names as the corresponding `WinnerInfo` fields in
/// `optd_core::cascades`; the expression is identified by its backend physical expression id.
#[derive(Clone, Serialize, Deserialize)]
pub struct BackendWinnerInfo {
    /// Id of the winning physical expression in the backend.
    pub physical_expr_id: i32,
    /// Weighted cost of the whole plan rooted at the winning expression.
    pub total_weighted_cost: f64,
    /// Weighted cost of the winning operation alone.
    pub operation_weighted_cost: f64,
    /// Cost vector of the whole plan rooted at the winning expression.
    pub total_cost: Cost,
    /// Cost vector of the winning operation alone.
    pub operation_cost: Cost,
    // TODO: Statistics
}

#[derive(serde::Serialize, serde::Deserialize)]
pub struct PredicateData {
pub children_ids: Vec<i32>,
Expand All @@ -40,13 +47,20 @@ impl MemoBackendManager {
})
}

async fn get_winner(&self, winner_id: i32) -> StorageResult<Option<group_winner::Model>> {
let winner = group_winner::Entity::find()
.filter(group_winner::Column::Id.eq(winner_id))
.one(&self.db)
.await?;
async fn get_winner(
&self,
winner_id: Option<i32>,
) -> StorageResult<Option<group_winner::Model>> {
if let Some(winner_id) = winner_id {
let winner = group_winner::Entity::find()
.filter(group_winner::Column::Id.eq(winner_id))
.one(&self.db)
.await?;

Ok(winner)
Ok(Some(winner.unwrap()))
} else {
Ok(None)
}
}

pub async fn get_group(&self, group_id: i32) -> StorageResult<BackendGroupInfo> {
Expand All @@ -66,19 +80,59 @@ impl MemoBackendManager {
.map(|child| child.logical_expression_id)
.collect();

let winner = self.get_winner(group.latest_winner.unwrap()).await?;
let winner = self.get_winner(group.latest_winner).await?;

Ok(BackendGroupInfo {
group_exprs: children,
winner: winner.map(|winner| BackendWinnerInfo {
expr_id: winner.physical_expression_id,
cost: 100.0, // TODO
physical_expr_id: winner.physical_expression_id,
total_weighted_cost: 0.0, // TODO: Cost is not saved
operation_weighted_cost: 0.0,
total_cost: Cost(vec![]),
operation_cost: Cost(vec![]),
}),
})
}

pub async fn update_winner(&self) {
todo!()
pub async fn update_winner(
&self,
group_id: i32,
winner: Option<BackendWinnerInfo>,
) -> StorageResult<()> {
// TODO: Do we need garbage collection?
let new_winner = if let Some(winner) = winner {
let cost = plan_cost::ActiveModel {
physical_expression_id: Set(winner.physical_expr_id),
..Default::default()
}; // TODO: how can we store cost??

let cost_res = cost.insert(&self.db).await?;

let winner = group_winner::ActiveModel {
group_id: Set(group_id),
physical_expression_id: Set(winner.physical_expr_id),
cost_id: Set(cost_res.id),
..Default::default()
};
let winner_res = winner.insert(&self.db).await?;

Some(winner_res.id)
} else {
None
};

// update
let mut group_res: cascades_group::ActiveModel = cascades_group::Entity::find()
.filter(cascades_group::Column::Id.eq(group_id))
.one(&self.db)
.await?
.unwrap()
.into();

group_res.latest_winner = Set(new_winner);
group_res.update(&self.db).await?;

return Ok(());
}

pub async fn get_expr_count(&self) -> StorageResult<u64> {
Expand Down
47 changes: 42 additions & 5 deletions optd-persistent-memo/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
mod backend_manager;

use core::panic;
use std::{collections::HashMap, marker::PhantomData, sync::Arc};

use backend_manager::{MemoBackendManager, PredicateData};
use backend_manager::{BackendWinnerInfo, MemoBackendManager, PredicateData};
use futures_lite::future;
use optd_core::{
cascades::{self, ExprId, GroupId, GroupInfo, Memo, MemoPlanNode, PredId},
cascades::{self, ExprId, GroupId, GroupInfo, Memo, MemoPlanNode, PredId, WinnerInfo},
cost::Statistics,
nodes::{
self, ArcPlanNode, NodeType, PlanNodeOrGroup, PredNode, SerializedNodeTag,
SerializedPredTag,
Expand Down Expand Up @@ -316,7 +318,17 @@ impl<T: NodeType> Memo<T> for PersistentMemo<T> {
.map(|expr_id| ExprId((*expr_id).try_into().unwrap()))
.collect(),
info: GroupInfo {
winner: cascades::Winner::Unknown, // TODO
winner: match orm_group.winner {
Some(winner) => cascades::Winner::Full(WinnerInfo {
expr_id: ExprId(self.phys_id_to_expr_id[&winner.physical_expr_id].0),
total_weighted_cost: winner.total_weighted_cost,
operation_weighted_cost: winner.operation_weighted_cost,
total_cost: winner.total_cost,
operation_cost: winner.operation_cost,
statistics: todo!(),
}),
None => cascades::Winner::Unknown,
},
},
properties: Arc::new([]), // TODO
};
Expand All @@ -329,8 +341,33 @@ impl<T: NodeType> Memo<T> for PersistentMemo<T> {
}

/// Persists the winner carried by `group_info` for the group `group_id`.
///
/// Only `Winner::Full` is written to storage: its expression id is translated to the backend
/// physical expression id and handed to `update_winner` together with the cost fields.
///
/// # Panics
///
/// Panics on `Winner::Impossible`, when the winner's expression id has no entry in
/// `expr_id_to_log_phys_id`, when `group_id` does not fit the backend id type, or when the
/// storage update fails.
fn update_group_info(&mut self, group_id: GroupId, group_info: cascades::GroupInfo) {
    // TODO: What of in_progress, is_optimized

    match group_info.winner {
        // NOTE(review): an unknown winner leaves any previously stored winner untouched;
        // confirm whether it should clear the stored winner instead.
        cascades::Winner::Unknown => {}
        cascades::Winner::Impossible => {
            panic!("Impossible winner not supported in persistent memo yet");
        }
        cascades::Winner::Full(info) => {
            let phys_expr_id = self
                .expr_id_to_log_phys_id
                .get(&info.expr_id)
                .expect("winner expr id not found in expr->logphys mapping")
                .0;

            let backend_info = BackendWinnerInfo {
                physical_expr_id: phys_expr_id,
                total_weighted_cost: info.total_weighted_cost,
                operation_weighted_cost: info.operation_weighted_cost,
                total_cost: info.total_cost,
                operation_cost: info.operation_cost,
                // TODO: Statistics
            };

            let group_id = group_id.0.try_into().unwrap();
            // The `Memo` trait method is synchronous, so drive the async storage call to
            // completion right here.
            future::block_on(self.storage.update_winner(group_id, Some(backend_info))).unwrap();
        }
    }
}

fn estimated_plan_space(&self) -> usize {
Expand Down