From 82e0e38839fe9488a4deb668e514f21177afb9f1 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 24 Apr 2025 12:11:45 -0400 Subject: [PATCH 1/9] add in-memory memo table Implements the in-memory memo table for the optimizeer. Co-authored-by: Sarvesh Tandon Co-authored-by: Connor Tsui --- optd/src/core/memo/memory.rs | 1120 ++++++++++++++++++++++++++++++ optd/src/core/memo/merge_repr.rs | 14 +- optd/src/core/memo/mod.rs | 165 +++-- 3 files changed, 1248 insertions(+), 51 deletions(-) create mode 100644 optd/src/core/memo/memory.rs diff --git a/optd/src/core/memo/memory.rs b/optd/src/core/memo/memory.rs new file mode 100644 index 00000000..9b1ad972 --- /dev/null +++ b/optd/src/core/memo/memory.rs @@ -0,0 +1,1120 @@ +use std::collections::hash_map::Entry; +use std::collections::{HashMap, HashSet, VecDeque}; + +use async_recursion::async_recursion; + +use crate::cir::Child; + +use super::Memoize; +use super::merge_repr::Representative; +use super::*; + +/// An in-memory implementation of the memo table. +#[derive(Default)] +pub struct MemoryMemo { + /// Group id to state. + groups: HashMap, + + /// Logical expression id to node. + logical_exprs: HashMap, + /// Logical expression node to id. + logical_expr_node_to_id_index: HashMap, + /// A mapping from logical expression id to group id. + logical_expr_group_index: HashMap, + + /// Dependent logical expression ids for each group id. + /// This is used to quickly find all the logical expressions that have a child equal to the group id, which is the key. + /// Dependent here does not mean the dependency stuff that we have in the memo table + group_dependent_logical_exprs: HashMap>, + + /// Physical expression id to node. + physical_exprs: HashMap)>, + /// Physical expression node to id. + physical_expr_node_to_id_index: HashMap, + + /// Dependent physical expression ids for each goal id. + /// This is used to quickly find all the physical expressions that have a child equal to the goal id, which is the key. + /// Dependent here does not mean the dependency stuff that we have in the memo table + goal_dependent_physical_exprs: HashMap>, + + /// Dependent physical expression ids for each physical expression id. + /// This is used to quickly find all the physical expressions that have a child equal to the physical expression id, which is the key. + physical_expr_dependent_physical_exprs: + HashMap>, + + /// Goal id to state. + goals: HashMap, + /// Goal node to id. + goal_node_to_id_index: HashMap, + + /// A mapping from goal member to the set of goal ids that depend on it. + member_subscribers: HashMap>, + + /// best optimized physical expression for each goal id. + best_optimized_physical_expr_index: HashMap, + + /// The shared next unique id to be used for goals, groups, logical expressions, and physical expressions. + next_shared_id: i64, + + repr_group: Representative, + repr_goal: Representative, + repr_logical_expr: Representative, + repr_physical_expr: Representative, + + transform_dependency: HashMap>, + implement_dependency: + HashMap>, + cost_dependency: HashMap, +} + +struct RuleDependency { + group_ids: HashSet, + status: Status, +} + +impl RuleDependency { + fn new(status: Status) -> Self { + let group_ids = HashSet::new(); + Self { group_ids, status } + } +} + +struct CostDependency { + goal_ids: HashSet, + status: Status, +} + +impl CostDependency { + fn new(status: Status) -> Self { + let goal_ids = HashSet::new(); + Self { goal_ids, status } + } +} + +/// State of a group in the memo structure. +struct GroupState { + /// The logical properties of the group, might be `None` if it hasn't been derived yet. + properties: Option, + logical_exprs: HashSet, + goals: HashSet, +} + +impl GroupState { + fn new(logical_expr_id: LogicalExpressionId) -> Self { + let mut logical_exprs = HashSet::new(); + logical_exprs.insert(logical_expr_id); + Self { + properties: None, + logical_exprs, + goals: HashSet::new(), + } + } +} + +struct GoalState { + /// The set of members that are part of this goal. + goal: Goal, + members: HashSet, +} + +impl GoalState { + fn new(goal: Goal) -> Self { + Self { + goal, + members: HashSet::new(), + } + } +} + +impl Memoize for MemoryMemo { + async fn merge_groups( + &mut self, + group_id_1: GroupId, + group_id_2: GroupId, + ) -> MemoizeResult> { + self.merge_groups_helper(group_id_1, group_id_2).await + } + + async fn get_logical_properties( + &self, + group_id: GroupId, + ) -> MemoizeResult> { + let group_id = self.find_repr_group(group_id).await?; + let group = self + .groups + .get(&group_id) + .ok_or(MemoizeError::GroupNotFound(group_id))?; + + Ok(group.properties.clone()) + } + + async fn set_logical_properties( + &mut self, + group_id: GroupId, + props: LogicalProperties, + ) -> MemoizeResult<()> { + let group_id = self.find_repr_group(group_id).await?; + let group = self + .groups + .get_mut(&group_id) + .ok_or(MemoizeError::GroupNotFound(group_id))?; + + group.properties = Some(props); + Ok(()) + } + + async fn get_all_logical_exprs( + &self, + group_id: GroupId, + ) -> MemoizeResult> { + let group_id = self.find_repr_group(group_id).await?; + let group = self + .groups + .get(&group_id) + .ok_or(MemoizeError::GroupNotFound(group_id))?; + + Ok(group.logical_exprs.iter().cloned().collect()) + } + + async fn get_any_logical_expr(&self, group_id: GroupId) -> MemoizeResult { + let group_id = self.find_repr_group(group_id).await?; + let group = self + .groups + .get(&group_id) + .ok_or(MemoizeError::GroupNotFound(group_id))?; + + group + .logical_exprs + .iter() + .next() + .cloned() + .ok_or(MemoizeError::NoLogicalExprInGroup(group_id)) + } + + async fn find_logical_expr_group( + &self, + logical_expr_id: LogicalExpressionId, + ) -> MemoizeResult> { + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; + let maybe_group_id = self.logical_expr_group_index.get(&logical_expr_id).cloned(); + Ok(maybe_group_id) + } + + async fn create_group( + &mut self, + logical_expr_id: LogicalExpressionId, + ) -> MemoizeResult { + let group_id = self.next_group_id(); + let group = GroupState::new(logical_expr_id); + self.groups.insert(group_id, group); + self.logical_expr_group_index + .insert(logical_expr_id, group_id); + + Ok(group_id) + } + + async fn get_best_optimized_physical_expr( + &self, + goal_id: GoalId, + ) -> MemoizeResult> { + let goal_id = self.find_repr_goal(goal_id).await?; + let maybe_best_costed = self + .best_optimized_physical_expr_index + .get(&goal_id) + .cloned(); + Ok(maybe_best_costed) + } + + async fn get_all_goal_members(&self, goal_id: GoalId) -> MemoizeResult> { + let goal_id = self.find_repr_goal(goal_id).await?; + let goal_state = self.goals.get(&goal_id).unwrap(); + Ok(goal_state.members.iter().cloned().collect()) + } + + async fn add_goal_member( + &mut self, + goal_id: GoalId, + member: GoalMemberId, + ) -> MemoizeResult> { + let goal_id = self.find_repr_goal(goal_id).await?; + let member = self.find_repr_goal_member(member).await?; + let goal_state = self.goals.get_mut(&goal_id).unwrap(); + + let is_new = goal_state.members.insert(member); + if is_new { + // Create a new subscriber for the member (initialize the set if it doesn't exist). + self.member_subscribers + .entry(member) + .or_default() + .insert(goal_id); + + let new_member_cost = match member { + GoalMemberId::PhysicalExpressionId(physical_expr_id) => self + .get_physical_expr_cost(physical_expr_id) + .await? + .map(|c| (physical_expr_id, c)), + GoalMemberId::GoalId(member_goal_id) => { + self.get_best_optimized_physical_expr(member_goal_id) + .await? + } + }; + + let mut subscribers = VecDeque::new(); + subscribers.push_back(goal_id); + + let Some((physical_expr_id, cost)) = new_member_cost else { + return Ok(None); + }; + let mut subscribers = VecDeque::new(); + subscribers.push_back(goal_id); + let mut result = ForwardResult::new(physical_expr_id, cost); + // propagate the new cost to all subscribers. + self.propagate_new_member_cost(subscribers, &mut result) + .await?; + if result.goals_forwarded.is_empty() { + // No goals were forwarded, so we can return None. + Ok(None) + } else { + // Some goals were forwarded, so we return the result. + Ok(Some(result)) + } + } else { + Ok(None) + } + } + + async fn get_physical_expr_cost( + &self, + physical_expr_id: PhysicalExpressionId, + ) -> MemoizeResult> { + let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; + let (_, maybe_cost) = self + .physical_exprs + .get(&physical_expr_id) + .ok_or(MemoizeError::PhysicalExprNotFound(physical_expr_id))?; + Ok(*maybe_cost) + } + + async fn update_physical_expr_cost( + &mut self, + physical_expr_id: PhysicalExpressionId, + new_cost: Cost, + ) -> MemoizeResult> { + let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; + let (_, cost_mut) = self + .physical_exprs + .get_mut(&physical_expr_id) + .ok_or(MemoizeError::PhysicalExprNotFound(physical_expr_id))?; + let is_better = cost_mut + .replace(new_cost) + .map(|old_cost| new_cost < old_cost) + .unwrap_or(true); + + if is_better { + let mut subscribers = VecDeque::new(); + // keep propagating the new cost to all subscribers. + if let Some(subscriber_goal_ids) = self + .member_subscribers + .get(&GoalMemberId::PhysicalExpressionId(physical_expr_id)) + .map(|goals| goals.iter().cloned()) + { + subscribers.extend(subscriber_goal_ids); + } + + let mut result = ForwardResult::new(physical_expr_id, new_cost); + // propagate the new cost to all subscribers. + self.propagate_new_member_cost(subscribers, &mut result) + .await?; + if result.goals_forwarded.is_empty() { + // No goals were forwarded, so we can return None. + Ok(None) + } else { + // Some goals were forwarded, so we return the result. + Ok(Some(result)) + } + } else { + Ok(None) + } + } + + async fn get_transformation_status( + &self, + logical_expr_id: LogicalExpressionId, + rule: &TransformationRule, + ) -> MemoizeResult { + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; + let status = self + .transform_dependency + .get(&logical_expr_id) + .and_then(|status_map| status_map.get(rule)) + .map(|dep| dep.status.clone()) + .unwrap_or(Status::Dirty); + Ok(status) + } + + async fn set_transformation_clean( + &mut self, + logical_expr_id: LogicalExpressionId, + rule: &TransformationRule, + ) -> MemoizeResult<()> { + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; + let status_map = self + .transform_dependency + .entry(logical_expr_id) + .or_default(); + match status_map.entry(rule.clone()) { + Entry::Occupied(occupied_entry) => { + let dep = occupied_entry.into_mut(); + dep.status = Status::Clean; + } + Entry::Vacant(vacant) => { + vacant.insert(RuleDependency::new(Status::Clean)); + } + } + Ok(()) + } + + async fn get_implementation_status( + &self, + logical_expr_id: LogicalExpressionId, + goal_id: GoalId, + rule: &ImplementationRule, + ) -> MemoizeResult { + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; + let goal_id = self.find_repr_goal(goal_id).await?; + let status = self + .implement_dependency + .get(&logical_expr_id) + .and_then(|status_map| status_map.get(&(goal_id, rule.clone()))) + .map(|dep| dep.status.clone()) + .unwrap_or(Status::Dirty); + Ok(status) + } + + async fn set_implementation_clean( + &mut self, + logical_expr_id: LogicalExpressionId, + goal_id: GoalId, + rule: &ImplementationRule, + ) -> MemoizeResult<()> { + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; + let status_map = self + .implement_dependency + .entry(logical_expr_id) + .or_default(); + match status_map.entry((goal_id, rule.clone())) { + Entry::Occupied(occupied_entry) => { + let dep = occupied_entry.into_mut(); + dep.status = Status::Clean; + } + Entry::Vacant(vacant) => { + vacant.insert(RuleDependency::new(Status::Clean)); + } + } + Ok(()) + } + + async fn get_cost_status( + &self, + physical_expr_id: PhysicalExpressionId, + ) -> MemoizeResult { + let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; + let status = self + .cost_dependency + .get(&physical_expr_id) + .map(|dep| dep.status.clone()) + .unwrap_or(Status::Dirty); + Ok(status) + } + + async fn set_cost_clean( + &mut self, + physical_expr_id: PhysicalExpressionId, + ) -> MemoizeResult<()> { + let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; + + let entry = self.cost_dependency.entry(physical_expr_id); + + match entry { + Entry::Occupied(occupied) => { + let dep = occupied.into_mut(); + dep.status = Status::Clean; + } + Entry::Vacant(vacant) => { + vacant.insert(CostDependency::new(Status::Clean)); + } + } + + Ok(()) + } + + async fn add_transformation_dependency( + &mut self, + logical_expr_id: LogicalExpressionId, + rule: &TransformationRule, + group_id: GroupId, + ) -> MemoizeResult<()> { + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; + let group_id = self.find_repr_group(group_id).await?; + let status_map = self + .transform_dependency + .entry(logical_expr_id) + .or_default(); + + match status_map.entry(rule.clone()) { + Entry::Occupied(occupied_entry) => { + let dep = occupied_entry.into_mut(); + dep.group_ids.insert(group_id); + } + Entry::Vacant(vacant) => { + let mut dep = RuleDependency::new(Status::Dirty); + dep.group_ids.insert(group_id); + vacant.insert(dep); + } + } + + Ok(()) + } + + async fn add_implementation_dependency( + &mut self, + logical_expr_id: LogicalExpressionId, + goal_id: GoalId, + rule: &ImplementationRule, + group_id: GroupId, + ) -> MemoizeResult<()> { + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; + let group_id = self.find_repr_group(group_id).await?; + let goal_id = self.find_repr_goal(goal_id).await?; + + let status_map = self + .implement_dependency + .entry(logical_expr_id) + .or_default(); + + match status_map.entry((goal_id, rule.clone())) { + Entry::Occupied(occupied) => { + let dep = occupied.into_mut(); + dep.group_ids.insert(group_id); + } + Entry::Vacant(vacant) => { + let mut dep = RuleDependency::new(Status::Dirty); + dep.group_ids.insert(group_id); + vacant.insert(dep); + } + } + + Ok(()) + } + + async fn add_cost_dependency( + &mut self, + physical_expr_id: PhysicalExpressionId, + goal_id: GoalId, + ) -> MemoizeResult<()> { + let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; + let goal_id = self.find_repr_goal(goal_id).await?; + + match self.cost_dependency.entry(physical_expr_id) { + Entry::Occupied(occupied) => { + let dep = occupied.into_mut(); + dep.goal_ids.insert(goal_id); + } + Entry::Vacant(vacant) => { + let mut dep = CostDependency::new(Status::Dirty); + dep.goal_ids.insert(goal_id); + vacant.insert(dep); + } + } + + Ok(()) + } + + async fn get_goal_id(&mut self, goal: &Goal) -> MemoizeResult { + if let Some(goal_id) = self.goal_node_to_id_index.get(goal).cloned() { + return Ok(goal_id); + } + let goal_id = self.next_goal_id(); + self.goal_node_to_id_index.insert(goal.clone(), goal_id); + self.goals.insert(goal_id, GoalState::new(goal.clone())); + + let Goal(group_id, _) = goal; + self.groups.get_mut(group_id).unwrap().goals.insert(goal_id); + Ok(goal_id) + } + + async fn materialize_goal(&self, goal_id: GoalId) -> MemoizeResult { + let state = self + .goals + .get(&goal_id) + .ok_or(MemoizeError::GoalNotFound(goal_id))?; + + Ok(state.goal.clone()) + } + + async fn get_logical_expr_id( + &mut self, + logical_expr: &LogicalExpression, + ) -> MemoizeResult { + if let Some(logical_expr_id) = self + .logical_expr_node_to_id_index + .get(logical_expr) + .cloned() + { + return Ok(logical_expr_id); + } + let logical_expr_id = self.next_logical_expr_id(); + self.logical_expr_node_to_id_index + .insert(logical_expr.clone(), logical_expr_id); + self.logical_exprs + .insert(logical_expr_id, logical_expr.clone()); + + for child in logical_expr.children.iter() { + match child { + Child::Singleton(group_id) => { + self.group_dependent_logical_exprs + .entry(group_id.clone()) + .or_default() + .insert(logical_expr_id); + } + Child::VarLength(group_ids) => { + for group_id in group_ids.iter() { + self.group_dependent_logical_exprs + .entry(group_id.clone()) + .or_default() + .insert(logical_expr_id); + } + } + } + } + Ok(logical_expr_id) + } + + async fn materialize_logical_expr( + &self, + logical_expr_id: LogicalExpressionId, + ) -> MemoizeResult { + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; + let logical_expr = self + .logical_exprs + .get(&logical_expr_id) + .ok_or(MemoizeError::LogicalExprNotFound(logical_expr_id))?; + Ok(logical_expr.clone()) + } + + async fn get_physical_expr_id( + &mut self, + physical_expr: &PhysicalExpression, + ) -> MemoizeResult { + if let Some(physical_expr_id) = self + .physical_expr_node_to_id_index + .get(physical_expr) + .cloned() + { + return Ok(physical_expr_id); + } + let physical_expr_id = self.next_physical_expr_id(); + self.physical_expr_node_to_id_index + .insert(physical_expr.clone(), physical_expr_id); + self.physical_exprs + .insert(physical_expr_id, (physical_expr.clone(), None)); + + for child in physical_expr.children.iter() { + match child { + Child::Singleton(goal_member_id) => { + if let GoalMemberId::GoalId(goal_id) = goal_member_id { + self.goal_dependent_physical_exprs + .entry(goal_id.clone()) + .or_default() + .insert(physical_expr_id); + } + } + Child::VarLength(goal_member_ids) => { + for goal_member_id in goal_member_ids.iter() { + match goal_member_id { + GoalMemberId::GoalId(goal_id) => { + self.goal_dependent_physical_exprs + .entry(goal_id.clone()) + .or_default() + .insert(physical_expr_id); + } + GoalMemberId::PhysicalExpressionId(child_physical_expr_id) => { + self.physical_expr_dependent_physical_exprs + .entry(child_physical_expr_id.clone()) + .or_default() + .insert(physical_expr_id); + } + } + } + } + } + } + Ok(physical_expr_id) + } + + async fn materialize_physical_expr( + &self, + physical_expr_id: PhysicalExpressionId, + ) -> MemoizeResult { + let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; + let (physical_expr, _) = self + .physical_exprs + .get(&physical_expr_id) + .ok_or(MemoizeError::PhysicalExprNotFound(physical_expr_id))?; + Ok(physical_expr.clone()) + } + + async fn find_repr_group(&self, group_id: GroupId) -> MemoizeResult { + let repr_group_id = self.repr_group.find(&group_id); + Ok(repr_group_id) + } + + async fn find_repr_goal(&self, goal_id: GoalId) -> MemoizeResult { + let repr_goal_id = self.repr_goal.find(&goal_id); + Ok(repr_goal_id) + } + + async fn find_repr_logical_expr( + &self, + logical_expr_id: LogicalExpressionId, + ) -> MemoizeResult { + let repr_expr_id = self.repr_logical_expr.find(&logical_expr_id); + Ok(repr_expr_id) + } + + async fn find_repr_physical_expr( + &self, + physical_expr_id: PhysicalExpressionId, + ) -> MemoizeResult { + let repr_expr_id = self.repr_physical_expr.find(&physical_expr_id); + Ok(repr_expr_id) + } +} + +impl MemoryMemo { + /// Creates a new logical expression with the same children but with the children being the representative group ids. + async fn create_repr_logical_expr( + &mut self, + logical_expr: LogicalExpression, + ) -> MemoizeResult { + let mut repr_logical_expr = logical_expr.clone(); + let mut new_children = Vec::new(); + + for child in repr_logical_expr.children.iter() { + match child { + Child::Singleton(group_id) => { + let repr_group_id = self.find_repr_group(group_id.clone()).await?; + new_children.push(Child::Singleton(repr_group_id)); + } + Child::VarLength(group_ids) => { + let new_group_ids = group_ids + .iter() + .map(|group_id| { + let group_id = group_id.clone(); + let self_ref = &self; + // TODO(Sarvesh): this is a hack to get the repr group id, i'm sure there's a better way to do this. + async move { self_ref.find_repr_group(group_id).await } + }) + .collect::>(); + + let new_group_ids = futures::future::join_all(new_group_ids) + .await + .into_iter() + .collect::, _>>()?; + + new_children.push(Child::VarLength(new_group_ids)); + } + } + } + repr_logical_expr.children = new_children; + Ok(repr_logical_expr) + } + + /// Creates a new physical expression with the same children but with the children being the representative group ids. + async fn create_repr_physical_expr( + &mut self, + physical_expr: PhysicalExpression, + ) -> MemoizeResult { + let mut repr_physical_expr = physical_expr.clone(); + let mut new_children = Vec::new(); + + for child in repr_physical_expr.children.iter() { + match child { + Child::Singleton(goal_member_id) => { + if let GoalMemberId::GoalId(goal_id) = goal_member_id { + let repr_goal_id = self.find_repr_goal(goal_id.clone()).await?; + new_children.push(Child::Singleton(GoalMemberId::GoalId(repr_goal_id))); + } else { + new_children.push(Child::Singleton(goal_member_id.clone())); + } + } + Child::VarLength(goal_member_ids) => { + let mut new_goal_member_ids = Vec::new(); + for goal_member_id in goal_member_ids.iter() { + match goal_member_id { + GoalMemberId::GoalId(goal_id) => { + let repr_goal_id = self.find_repr_goal(goal_id.clone()).await?; + new_goal_member_ids.push(GoalMemberId::GoalId(repr_goal_id)); + } + GoalMemberId::PhysicalExpressionId(physical_expr_id) => { + let repr_physical_expr_id = self + .find_repr_physical_expr(physical_expr_id.clone()) + .await?; + new_goal_member_ids.push(GoalMemberId::PhysicalExpressionId( + repr_physical_expr_id, + )); + } + } + } + new_children.push(Child::VarLength(new_goal_member_ids)); + } + } + } + repr_physical_expr.children = new_children; + Ok(repr_physical_expr) + } + + /// Recursively merges physical expressions. + #[async_recursion] + async fn merge_physical_exprs( + &mut self, + physical_expr_id: PhysicalExpressionId, + ) -> MemoizeResult> { + let (physical_expr, cost) = self.physical_exprs.get(&physical_expr_id).unwrap(); + let repr_physical_expr = self + .create_repr_physical_expr(physical_expr.clone()) + .await?; + let repr_physical_expr_id = self.get_physical_expr_id(&repr_physical_expr).await?; + + // merge the physical exprs + self.repr_physical_expr + .merge(&physical_expr_id, &repr_physical_expr_id); + + let mut stale_physical_exprs = HashSet::new(); + stale_physical_exprs.insert(physical_expr_id); + + let mut results = Vec::new(); + results.push(MergePhysicalExprResult { + repr_physical_expr: repr_physical_expr_id, + stale_physical_exprs: stale_physical_exprs, + }); + + let dependent_physical_exprs = self + .physical_expr_dependent_physical_exprs + .get(&physical_expr_id); + if let Some(dependent_physical_exprs) = dependent_physical_exprs { + let dependent_physical_exprs = + dependent_physical_exprs.iter().cloned().collect::>(); + for dependent_physical_expr_id in dependent_physical_exprs { + // TODO(Sarvesh): handle async recursion + let merge_physical_expr_result = self + .merge_physical_exprs(dependent_physical_expr_id.clone()) + .await?; + results.extend(merge_physical_expr_result); + } + } + + Ok(results) + } + + /// Merges two goals into a single goal. + async fn merge_goals_helper( + &mut self, + goal_id1: GoalId, + goal_id2: GoalId, + ) -> MemoizeResult<(MergeGoalResult, Vec)> { + let goal_2 = self.goals.remove(&goal_id2).unwrap(); + let goal_1 = self.goals.get(&goal_id1).unwrap(); + let goal_1_props = &goal_1.goal.1; + let goal_2_props = &goal_2.goal.1; + self.repr_goal.merge(&goal_id2, &goal_id1); + + let mut merged_goal_result = MergeGoalResult { + merged_goals: HashMap::new(), + best_expr: None, + new_repr_goal_id: goal_id1, + }; + + let best_expr_goal1 = self.get_best_optimized_physical_expr(goal_id1).await?; + let best_expr_goal2 = self.get_best_optimized_physical_expr(goal_id2).await?; + + let best_expr = match (best_expr_goal1, best_expr_goal2) { + (Some(best_expr_goal1), Some(best_expr_goal2)) => { + Some(if best_expr_goal1.1 < best_expr_goal2.1 { + best_expr_goal1 + } else { + best_expr_goal2 + }) + } + (Some(best_expr_goal1), None) => Some(best_expr_goal1), + (None, Some(best_expr_goal2)) => Some(best_expr_goal2), + (None, None) => None, + }; + + if let Some(best_expr) = best_expr { + merged_goal_result.best_expr = Some(best_expr); + } + + let mut merged_goal_info_1 = MergedGoalInfo { + goal_id: goal_id1.clone(), + members: goal_1.members.iter().cloned().collect(), + seen_best_expr_before_merge: { + if let Some(best_expr_goal1) = best_expr_goal1 { + if let Some(best_expr_goal2) = best_expr_goal2 { + // goal 1 and goal 2 both had expr, return true if goal 1's is better or equal to goal 2's + best_expr_goal1.1 <= best_expr_goal2.1 + } else { + // goal 1 had a best expr before merge but goal 2 didn't + true + } + } else { + // neither goal had a best expr before merge + false + } + }, + }; + + let mut merged_goal_info_2 = MergedGoalInfo { + goal_id: goal_id2.clone(), + members: goal_2.members.iter().cloned().collect(), + seen_best_expr_before_merge: { + if let Some(best_expr_goal2) = best_expr_goal2 { + if let Some(best_expr_goal1) = best_expr_goal1 { + // goal 2 and goal 1 both had expr, return true if goal 2's is better or equal to goal 1's + best_expr_goal2.1 <= best_expr_goal1.1 + } else { + // goal 2 had a best expr before merge but goal 1 didn't + true + } + } else { + // neither goal had a best expr before merge + false + } + }, + }; + + merged_goal_result + .merged_goals + .insert(goal_id1.clone(), merged_goal_info_1); + merged_goal_result + .merged_goals + .insert(goal_id2.clone(), merged_goal_info_2); + + // Now, we need to update all the physical exprs that depend on goal 2 to now depend on goal 1. + let goal_2_dependent_physical_exprs = self + .goal_dependent_physical_exprs + .get(&goal_id2) + .unwrap() + .iter() + .cloned() + .collect::>(); + + let mut results = Vec::new(); + for physical_expr_id in goal_2_dependent_physical_exprs { + let merge_physical_expr_result = self.merge_physical_exprs(physical_expr_id).await?; + results.extend(merge_physical_expr_result); + } + + Ok((merged_goal_result, results)) + } + + #[async_recursion] + async fn merge_groups_helper( + &mut self, + group_id_1: GroupId, + group_id_2: GroupId, + ) -> MemoizeResult> { + // our strategy is to always merge group 2 into group 1. + let group_id_1 = self.find_repr_group(group_id_1).await?; + let group_id_2 = self.find_repr_group(group_id_2).await?; + + if group_id_1 == group_id_2 { + return Ok(None); + } + let mut result = MergeResult::default(); + + let group_2_state = self.groups.remove(&group_id_2).unwrap(); + let group_2_exprs = group_2_state.logical_exprs.iter().cloned().collect(); + + let group1_state = self.groups.get_mut(&group_id_1).unwrap(); + let group1_exprs = group1_state.logical_exprs.iter().cloned().collect(); + + for logical_expr_id in group_2_state.logical_exprs { + // Update the logical expression to point to the new group id. + let old_group_id = self + .logical_expr_group_index + .insert(logical_expr_id, group_id_1); + assert!(old_group_id.is_some()); + group1_state.logical_exprs.insert(logical_expr_id); + } + let mut merge_group_result = MergeGroupResult::new(group_id_1); + merge_group_result + .merged_groups + .insert(group_id_1, group1_exprs); + merge_group_result + .merged_groups + .insert(group_id_2, group_2_exprs); + + self.repr_group.merge(&group_id_2, &group_id_1); + + result.group_merges.push(merge_group_result); + + // So now, we have to find out all the goals that belong to both groups but contain the same properties. + + let group_1_goals = group1_state.goals.iter().cloned().collect::>(); + let group_2_goals = group_2_state.goals.iter().cloned().collect::>(); + + for goal_id1 in group_1_goals.iter() { + for goal_id2 in group_2_goals.iter() { + let goal_1 = self.goals.get(&goal_id1).unwrap(); + let goal_2 = self.goals.get(&goal_id2).unwrap(); + let goal_1_props = &goal_1.goal.1; + let goal_2_props = &goal_2.goal.1; + if goal_1_props == goal_2_props { + let (merged_goal_result, merge_physical_expr_results) = self + .merge_goals_helper(goal_id1.clone(), goal_id2.clone()) + .await?; + result.goal_merges.push(merged_goal_result); + result + .physical_expr_merges + .extend(merge_physical_expr_results); + } + } + } + + // Let's check for cascading merges now. + let logical_expr_with_group_2_as_child = self + .group_dependent_logical_exprs + .get(&group_id_2) + .unwrap() + .clone(); + + for logical_expr_id in logical_expr_with_group_2_as_child.iter() { + let logical_expr = self.logical_exprs.get(logical_expr_id).unwrap(); + let repr_logical_expr = self.create_repr_logical_expr(logical_expr.clone()).await?; + let repr_logical_expr_id = self.get_logical_expr_id(&repr_logical_expr).await?; + // merge the logical exprs + self.repr_logical_expr + .merge(&logical_expr_id, &repr_logical_expr_id); + + let parent_group_id = self.logical_expr_group_index.get(logical_expr_id).unwrap(); + let parent_group_state = self.groups.get_mut(parent_group_id).unwrap(); + // We remove the stale logical expr from the parent group. + parent_group_state.logical_exprs.remove(logical_expr_id); + + // is the repr logical expr already part of a group? + if let Some(repr_parent_group_id) = + self.logical_expr_group_index.get(&repr_logical_expr_id) + { + // the repr logical expr is part of a group, so + let parent_group_id = self.logical_expr_group_index.get(logical_expr_id).unwrap(); + if repr_parent_group_id != parent_group_id { + // we have another merge to do + // do a cascading merge between repr_parent_group_id and parent_group_id + let merge_result = self + .merge_groups_helper(repr_parent_group_id.clone(), parent_group_id.clone()) + .await?; + // merge the cascading merge result with the current result. + if let Some(merge_result) = merge_result { + result.group_merges.extend(merge_result.group_merges); + result + .physical_expr_merges + .extend(merge_result.physical_expr_merges); + result.goal_merges.extend(merge_result.goal_merges); + } + } + } else { + // the repr logical expr is not part of a group, so we add it to the parent group. + // We add the new repr logical expr to the parent group. + parent_group_state + .logical_exprs + .insert(repr_logical_expr_id); + // we update the index + self.logical_expr_group_index + .insert(repr_logical_expr_id, parent_group_id.clone()); + } + } + + Ok(Some(result)) + } + + /// Generates a new group id. + fn next_group_id(&mut self) -> GroupId { + let group_id = GroupId(self.next_shared_id); + self.next_shared_id += 1; + group_id + } + + /// Generates a new physical expression id. + fn next_physical_expr_id(&mut self) -> PhysicalExpressionId { + let physical_expr_id = PhysicalExpressionId(self.next_shared_id); + self.next_shared_id += 1; + physical_expr_id + } + + /// Generates a new logical expression id. + fn next_logical_expr_id(&mut self) -> LogicalExpressionId { + let logical_expr_id = LogicalExpressionId(self.next_shared_id); + self.next_shared_id += 1; + logical_expr_id + } + + /// Generates a new goal id. + fn next_goal_id(&mut self) -> GoalId { + let goal_id = GoalId(self.next_shared_id); + self.next_shared_id += 1; + goal_id + } + + /// Propagates the new costed member physical expression to all subscribers. + async fn propagate_new_member_cost( + &mut self, + mut subscribers: VecDeque, + result: &mut ForwardResult, + ) -> MemoizeResult<()> { + while let Some(goal_id) = subscribers.pop_front() { + let current_best = self.get_best_optimized_physical_expr(goal_id).await?; + + let is_better = current_best + .map(|(_, cost)| result.best_cost < cost) + .unwrap_or(true); + + if is_better { + // Update the best cost for the goal. + self.best_optimized_physical_expr_index + .insert(goal_id, (result.physical_expr_id, result.best_cost)); + + result.goals_forwarded.insert(goal_id); + + // keep propagating the new cost to all subscribers. + if let Some(subscriber_goal_ids) = self + .member_subscribers + .get(&GoalMemberId::GoalId(goal_id)) + .map(|goals| goals.iter().cloned().collect::>()) + { + for subscriber_goal_id in subscriber_goal_ids { + subscribers.push_back(subscriber_goal_id); + } + } + } + } + + Ok(()) + } + + /// Find the representative of a goal member. + /// + /// This reduces down to finding representative physical expr or goal id. + async fn find_repr_goal_member(&self, member: GoalMemberId) -> MemoizeResult { + match member { + GoalMemberId::PhysicalExpressionId(physical_expr_id) => { + let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; + Ok(GoalMemberId::PhysicalExpressionId(physical_expr_id)) + } + GoalMemberId::GoalId(goal_id) => { + let goal_id = self.find_repr_goal(goal_id).await?; + Ok(GoalMemberId::GoalId(goal_id)) + } + } + } +} diff --git a/optd/src/core/memo/merge_repr.rs b/optd/src/core/memo/merge_repr.rs index e551a5ed..90648ce2 100644 --- a/optd/src/core/memo/merge_repr.rs +++ b/optd/src/core/memo/merge_repr.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use std::collections::HashMap; use std::hash::Hash; @@ -9,12 +11,18 @@ pub struct Representative { parents: HashMap, } +impl Default for Representative { + fn default() -> Self { + Representative { + parents: HashMap::new(), + } + } +} + impl Representative { /// Creates a new empty Representative pub(super) fn new() -> Self { - Self { - parents: HashMap::new(), - } + Self::default() } /// Finds the representative of the set containing `x` diff --git a/optd/src/core/memo/mod.rs b/optd/src/core/memo/mod.rs index 192bd170..8440f1c2 100644 --- a/optd/src/core/memo/mod.rs +++ b/optd/src/core/memo/mod.rs @@ -1,14 +1,37 @@ -use crate::core::{ - cir::{ - Cost, Goal, GoalId, GoalMemberId, GroupId, ImplementationRule, LogicalExpression, - LogicalExpressionId, LogicalProperties, PhysicalExpression, PhysicalExpressionId, - TransformationRule, - }, - error::Error, +#[cfg(test)] +pub mod memory; +mod merge_repr; + +use std::collections::{HashMap, HashSet}; + +use async_recursion::async_recursion; + +use crate::cir::{ + Cost, Goal, GoalId, GoalMemberId, GroupId, ImplementationRule, LogicalExpression, + LogicalExpressionId, LogicalProperties, PhysicalExpression, PhysicalExpressionId, + TransformationRule, }; /// Type alias for results returned by Memoize trait methods -pub type MemoizeResult = Result; +pub type MemoizeResult = Result; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MemoizeError { + /// Error indicating that a group ID was not found in the memo. + GroupNotFound(GroupId), + + /// Error indicating that a goal ID was not found in the memo. + GoalNotFound(GoalId), + + /// Error indicating that a logical expression ID was not found in the memo. + LogicalExprNotFound(LogicalExpressionId), + + /// Error indicating that a physical expression ID was not found in the memo. + PhysicalExprNotFound(PhysicalExpressionId), + + /// Error indicating that there is no logical expression in the group. + NoLogicalExprInGroup(GroupId), +} /// Status of a rule application or costing operation #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -20,24 +43,27 @@ pub enum Status { Clean, } -/// Information about a merged group, including its ID and expressions -#[derive(Debug)] -pub struct MergedGroupInfo { - /// ID of the merged group - pub group_id: GroupId, - - /// All logical expressions in this group - pub expressions: Vec, -} - /// Result of merging two groups. #[derive(Debug)] pub struct MergeGroupResult { - /// Groups that were merged along with their expressions. - pub merged_groups: Vec, - /// ID of the new representative group id. pub new_repr_group_id: GroupId, + /// Groups that were merged along with their expressions. + pub merged_groups: HashMap>, +} + +impl MergeGroupResult { + /// Creates a new MergeGroupResult instance. + /// + /// # Parameters + /// * `merged_groups` - Groups that were merged along with their expressions. + /// * `new_repr_group_id` - ID of the new representative group id. + pub fn new(new_repr_group_id: GroupId) -> Self { + Self { + new_repr_group_id, + merged_groups: HashMap::new(), + } + } } /// Information about a merged goal, including its ID and expressions @@ -46,8 +72,8 @@ pub struct MergedGoalInfo { /// ID of the merged goal pub goal_id: GoalId, - /// The best costed expression for this goal, if any - pub best_expr: Option<(PhysicalExpressionId, Cost)>, + /// Whether this goal contained the best costed expression before merging. + pub seen_best_expr_before_merge: bool, /// All members in this goal, which can be physical expressions or references to other goals pub members: Vec, @@ -57,14 +83,27 @@ pub struct MergedGoalInfo { #[derive(Debug)] pub struct MergeGoalResult { /// Goals that were merged along with their potential best costed expression. - pub merged_goals: Vec, + pub merged_goals: HashMap, + + /// The best costed expression for all merged goals combined. + pub best_expr: Option<(PhysicalExpressionId, Cost)>, /// ID of the new representative goal id. pub new_repr_goal_id: GoalId, } -/// Results of merge operations with newly dirtied expressions. +/// Result of merging two cost expressions. #[derive(Debug)] +pub struct MergePhysicalExprResult { + /// The new representative physical expression id. + pub repr_physical_expr: PhysicalExpressionId, + + /// Physical expressions that were stale + pub stale_physical_exprs: HashSet, +} + +/// Results of merge operations with newly dirtied expressions. +#[derive(Debug, Default)] pub struct MergeResult { /// Group merge results. pub group_merges: Vec, @@ -72,14 +111,32 @@ pub struct MergeResult { /// Goal merge results. pub goal_merges: Vec, - /// Transformations that were marked as dirty and need new application. - pub dirty_transformations: Vec<(LogicalExpressionId, TransformationRule)>, + /// Physical expression merge results. + pub physical_expr_merges: Vec, + // /// Transformations that were marked as dirty and need new application. + // pub dirty_transformations: Vec<(LogicalExpressionId, TransformationRule)>, - /// Implementations that were marked as dirty and need new application. - pub dirty_implementations: Vec<(LogicalExpressionId, GoalId, ImplementationRule)>, + // /// Implementations that were marked as dirty and need new application. + // pub dirty_implementations: Vec<(LogicalExpressionId, GoalId, ImplementationRule)>, + + // /// Costings that were marked as dirty and need recomputation. + // pub dirty_costings: Vec, +} - /// Costings that were marked as dirty and need recomputation. - pub dirty_costings: Vec, +pub struct ForwardResult { + pub physical_expr_id: PhysicalExpressionId, + pub best_cost: Cost, + pub goals_forwarded: HashSet, +} + +impl ForwardResult { + pub fn new(physical_expr_id: PhysicalExpressionId, best_cost: Cost) -> Self { + Self { + physical_expr_id, + best_cost, + goals_forwarded: HashSet::new(), + } + } } /// Core interface for memo-based query optimization. @@ -101,7 +158,24 @@ pub trait Memoize: Send + Sync + 'static { /// /// # Returns /// The properties associated with the group or an error if not found. - async fn get_logical_properties(&self, group_id: GroupId) -> MemoizeResult; + async fn get_logical_properties( + &self, + group_id: GroupId, + ) -> MemoizeResult>; + + /// Sets logical properties for a group ID. + /// + /// # Parameters + /// * `group_id` - ID of the group to set properties for. + /// * `props` - The logical properties to associate with the group. + /// + /// # Returns + /// A result indicating success or failure of the operation. + async fn set_logical_properties( + &mut self, + group_id: GroupId, + props: LogicalProperties, + ) -> MemoizeResult<()>; /// Gets all logical expression IDs in a group (only representative IDs). /// @@ -115,6 +189,9 @@ pub trait Memoize: Send + Sync + 'static { group_id: GroupId, ) -> MemoizeResult>; + /// Gets any logical expression ID in a group. + async fn get_any_logical_expr(&self, group_id: GroupId) -> MemoizeResult; + /// Finds group containing a logical expression ID, if it exists. /// /// # Parameters @@ -138,7 +215,6 @@ pub trait Memoize: Send + Sync + 'static { async fn create_group( &mut self, logical_expr_id: LogicalExpressionId, - props: &LogicalProperties, ) -> MemoizeResult; /// Merges groups 1 and 2, unifying them under a common representative. @@ -156,7 +232,7 @@ pub trait Memoize: Send + Sync + 'static { &mut self, group_id_1: GroupId, group_id_2: GroupId, - ) -> MemoizeResult; + ) -> MemoizeResult>; // // Physical expression and goal operations. @@ -175,18 +251,6 @@ pub trait Memoize: Send + Sync + 'static { goal_id: GoalId, ) -> MemoizeResult>; - /// Gets all physical expression IDs in a goal (only representative IDs). - /// - /// # Parameters - /// * `goal_id` - ID of the goal to retrieve expressions from. - /// - /// # Returns - /// A vector of physical expression IDs in the specified goal. - async fn get_all_physical_exprs( - &self, - goal_id: GoalId, - ) -> MemoizeResult>; - /// Gets all members of a goal, which can be physical expressions or other goals. /// /// # Parameters @@ -208,7 +272,7 @@ pub trait Memoize: Send + Sync + 'static { &mut self, goal_id: GoalId, member: GoalMemberId, - ) -> MemoizeResult; + ) -> MemoizeResult>; /// Updates the cost of a physical expression ID. /// @@ -222,7 +286,12 @@ pub trait Memoize: Send + Sync + 'static { &mut self, physical_expr_id: PhysicalExpressionId, new_cost: Cost, - ) -> MemoizeResult; + ) -> MemoizeResult>; + + async fn get_physical_expr_cost( + &self, + physical_expr_id: PhysicalExpressionId, + ) -> MemoizeResult>; // // Rule and costing status operations. From e8edf3a819c334b0a9c5148109b340ff8dc287ae Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 24 Apr 2025 14:30:32 -0400 Subject: [PATCH 2/9] fix clippy warnings --- optd/src/core/memo/memory.rs | 79 ++++++++++++++++-------------------- optd/src/core/memo/mod.rs | 14 ++----- 2 files changed, 38 insertions(+), 55 deletions(-) diff --git a/optd/src/core/memo/memory.rs b/optd/src/core/memo/memory.rs index 9b1ad972..ee4b2310 100644 --- a/optd/src/core/memo/memory.rs +++ b/optd/src/core/memo/memory.rs @@ -1,13 +1,9 @@ -use std::collections::hash_map::Entry; -use std::collections::{HashMap, HashSet, VecDeque}; - -use async_recursion::async_recursion; - -use crate::cir::Child; - -use super::Memoize; use super::merge_repr::Representative; use super::*; +use crate::core::cir::Child; +use async_recursion::async_recursion; +use std::collections::hash_map::Entry; +use std::collections::{HashMap, HashSet, VecDeque}; /// An in-memory implementation of the memo table. #[derive(Default)] @@ -347,7 +343,7 @@ impl Memoize for MemoryMemo { .transform_dependency .get(&logical_expr_id) .and_then(|status_map| status_map.get(rule)) - .map(|dep| dep.status.clone()) + .map(|dep| dep.status) .unwrap_or(Status::Dirty); Ok(status) } @@ -386,7 +382,7 @@ impl Memoize for MemoryMemo { .implement_dependency .get(&logical_expr_id) .and_then(|status_map| status_map.get(&(goal_id, rule.clone()))) - .map(|dep| dep.status.clone()) + .map(|dep| dep.status) .unwrap_or(Status::Dirty); Ok(status) } @@ -422,7 +418,7 @@ impl Memoize for MemoryMemo { let status = self .cost_dependency .get(&physical_expr_id) - .map(|dep| dep.status.clone()) + .map(|dep| dep.status) .unwrap_or(Status::Dirty); Ok(status) } @@ -573,14 +569,14 @@ impl Memoize for MemoryMemo { match child { Child::Singleton(group_id) => { self.group_dependent_logical_exprs - .entry(group_id.clone()) + .entry(*group_id) .or_default() .insert(logical_expr_id); } Child::VarLength(group_ids) => { for group_id in group_ids.iter() { self.group_dependent_logical_exprs - .entry(group_id.clone()) + .entry(*group_id) .or_default() .insert(logical_expr_id); } @@ -624,7 +620,7 @@ impl Memoize for MemoryMemo { Child::Singleton(goal_member_id) => { if let GoalMemberId::GoalId(goal_id) = goal_member_id { self.goal_dependent_physical_exprs - .entry(goal_id.clone()) + .entry(*goal_id) .or_default() .insert(physical_expr_id); } @@ -634,13 +630,13 @@ impl Memoize for MemoryMemo { match goal_member_id { GoalMemberId::GoalId(goal_id) => { self.goal_dependent_physical_exprs - .entry(goal_id.clone()) + .entry(*goal_id) .or_default() .insert(physical_expr_id); } GoalMemberId::PhysicalExpressionId(child_physical_expr_id) => { self.physical_expr_dependent_physical_exprs - .entry(child_physical_expr_id.clone()) + .entry(*child_physical_expr_id) .or_default() .insert(physical_expr_id); } @@ -703,17 +699,16 @@ impl MemoryMemo { for child in repr_logical_expr.children.iter() { match child { Child::Singleton(group_id) => { - let repr_group_id = self.find_repr_group(group_id.clone()).await?; + let repr_group_id = self.find_repr_group(*group_id).await?; new_children.push(Child::Singleton(repr_group_id)); } Child::VarLength(group_ids) => { let new_group_ids = group_ids .iter() .map(|group_id| { - let group_id = group_id.clone(); let self_ref = &self; // TODO(Sarvesh): this is a hack to get the repr group id, i'm sure there's a better way to do this. - async move { self_ref.find_repr_group(group_id).await } + async move { self_ref.find_repr_group(*group_id).await } }) .collect::>(); @@ -742,10 +737,10 @@ impl MemoryMemo { match child { Child::Singleton(goal_member_id) => { if let GoalMemberId::GoalId(goal_id) = goal_member_id { - let repr_goal_id = self.find_repr_goal(goal_id.clone()).await?; + let repr_goal_id = self.find_repr_goal(*goal_id).await?; new_children.push(Child::Singleton(GoalMemberId::GoalId(repr_goal_id))); } else { - new_children.push(Child::Singleton(goal_member_id.clone())); + new_children.push(Child::Singleton(*goal_member_id)); } } Child::VarLength(goal_member_ids) => { @@ -753,13 +748,12 @@ impl MemoryMemo { for goal_member_id in goal_member_ids.iter() { match goal_member_id { GoalMemberId::GoalId(goal_id) => { - let repr_goal_id = self.find_repr_goal(goal_id.clone()).await?; + let repr_goal_id = self.find_repr_goal(*goal_id).await?; new_goal_member_ids.push(GoalMemberId::GoalId(repr_goal_id)); } GoalMemberId::PhysicalExpressionId(physical_expr_id) => { - let repr_physical_expr_id = self - .find_repr_physical_expr(physical_expr_id.clone()) - .await?; + let repr_physical_expr_id = + self.find_repr_physical_expr(*physical_expr_id).await?; new_goal_member_ids.push(GoalMemberId::PhysicalExpressionId( repr_physical_expr_id, )); @@ -780,7 +774,7 @@ impl MemoryMemo { &mut self, physical_expr_id: PhysicalExpressionId, ) -> MemoizeResult> { - let (physical_expr, cost) = self.physical_exprs.get(&physical_expr_id).unwrap(); + let (physical_expr, _cost) = self.physical_exprs.get(&physical_expr_id).unwrap(); let repr_physical_expr = self .create_repr_physical_expr(physical_expr.clone()) .await?; @@ -796,7 +790,7 @@ impl MemoryMemo { let mut results = Vec::new(); results.push(MergePhysicalExprResult { repr_physical_expr: repr_physical_expr_id, - stale_physical_exprs: stale_physical_exprs, + stale_physical_exprs, }); let dependent_physical_exprs = self @@ -808,7 +802,7 @@ impl MemoryMemo { for dependent_physical_expr_id in dependent_physical_exprs { // TODO(Sarvesh): handle async recursion let merge_physical_expr_result = self - .merge_physical_exprs(dependent_physical_expr_id.clone()) + .merge_physical_exprs(dependent_physical_expr_id) .await?; results.extend(merge_physical_expr_result); } @@ -825,8 +819,6 @@ impl MemoryMemo { ) -> MemoizeResult<(MergeGoalResult, Vec)> { let goal_2 = self.goals.remove(&goal_id2).unwrap(); let goal_1 = self.goals.get(&goal_id1).unwrap(); - let goal_1_props = &goal_1.goal.1; - let goal_2_props = &goal_2.goal.1; self.repr_goal.merge(&goal_id2, &goal_id1); let mut merged_goal_result = MergeGoalResult { @@ -855,8 +847,8 @@ impl MemoryMemo { merged_goal_result.best_expr = Some(best_expr); } - let mut merged_goal_info_1 = MergedGoalInfo { - goal_id: goal_id1.clone(), + let merged_goal_info_1 = MergedGoalInfo { + goal_id: goal_id1, members: goal_1.members.iter().cloned().collect(), seen_best_expr_before_merge: { if let Some(best_expr_goal1) = best_expr_goal1 { @@ -874,8 +866,8 @@ impl MemoryMemo { }, }; - let mut merged_goal_info_2 = MergedGoalInfo { - goal_id: goal_id2.clone(), + let merged_goal_info_2 = MergedGoalInfo { + goal_id: goal_id2, members: goal_2.members.iter().cloned().collect(), seen_best_expr_before_merge: { if let Some(best_expr_goal2) = best_expr_goal2 { @@ -895,10 +887,10 @@ impl MemoryMemo { merged_goal_result .merged_goals - .insert(goal_id1.clone(), merged_goal_info_1); + .insert(goal_id1, merged_goal_info_1); merged_goal_result .merged_goals - .insert(goal_id2.clone(), merged_goal_info_2); + .insert(goal_id2, merged_goal_info_2); // Now, we need to update all the physical exprs that depend on goal 2 to now depend on goal 1. let goal_2_dependent_physical_exprs = self @@ -966,14 +958,13 @@ impl MemoryMemo { for goal_id1 in group_1_goals.iter() { for goal_id2 in group_2_goals.iter() { - let goal_1 = self.goals.get(&goal_id1).unwrap(); - let goal_2 = self.goals.get(&goal_id2).unwrap(); + let goal_1 = self.goals.get(goal_id1).unwrap(); + let goal_2 = self.goals.get(goal_id2).unwrap(); let goal_1_props = &goal_1.goal.1; let goal_2_props = &goal_2.goal.1; if goal_1_props == goal_2_props { - let (merged_goal_result, merge_physical_expr_results) = self - .merge_goals_helper(goal_id1.clone(), goal_id2.clone()) - .await?; + let (merged_goal_result, merge_physical_expr_results) = + self.merge_goals_helper(*goal_id1, *goal_id2).await?; result.goal_merges.push(merged_goal_result); result .physical_expr_merges @@ -995,7 +986,7 @@ impl MemoryMemo { let repr_logical_expr_id = self.get_logical_expr_id(&repr_logical_expr).await?; // merge the logical exprs self.repr_logical_expr - .merge(&logical_expr_id, &repr_logical_expr_id); + .merge(logical_expr_id, &repr_logical_expr_id); let parent_group_id = self.logical_expr_group_index.get(logical_expr_id).unwrap(); let parent_group_state = self.groups.get_mut(parent_group_id).unwrap(); @@ -1012,7 +1003,7 @@ impl MemoryMemo { // we have another merge to do // do a cascading merge between repr_parent_group_id and parent_group_id let merge_result = self - .merge_groups_helper(repr_parent_group_id.clone(), parent_group_id.clone()) + .merge_groups_helper(*repr_parent_group_id, *parent_group_id) .await?; // merge the cascading merge result with the current result. if let Some(merge_result) = merge_result { @@ -1031,7 +1022,7 @@ impl MemoryMemo { .insert(repr_logical_expr_id); // we update the index self.logical_expr_group_index - .insert(repr_logical_expr_id, parent_group_id.clone()); + .insert(repr_logical_expr_id, *parent_group_id); } } diff --git a/optd/src/core/memo/mod.rs b/optd/src/core/memo/mod.rs index 8440f1c2..3dff5fbb 100644 --- a/optd/src/core/memo/mod.rs +++ b/optd/src/core/memo/mod.rs @@ -1,16 +1,8 @@ -#[cfg(test)] -pub mod memory; -mod merge_repr; - +use crate::core::cir::*; use std::collections::{HashMap, HashSet}; -use async_recursion::async_recursion; - -use crate::cir::{ - Cost, Goal, GoalId, GoalMemberId, GroupId, ImplementationRule, LogicalExpression, - LogicalExpressionId, LogicalProperties, PhysicalExpression, PhysicalExpressionId, - TransformationRule, -}; +mod memory; +mod merge_repr; /// Type alias for results returned by Memoize trait methods pub type MemoizeResult = Result; From 249f0ee1742d43b819036b76cb03083c2d72131a Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Tue, 29 Apr 2025 18:40:39 -0400 Subject: [PATCH 3/9] refactor monolith trait into subtraits --- optd/src/core/memo/memory.rs | 342 ++++++++++++++++++----------------- optd/src/core/memo/mod.rs | 272 ++++++++++++++-------------- 2 files changed, 311 insertions(+), 303 deletions(-) diff --git a/optd/src/core/memo/memory.rs b/optd/src/core/memo/memory.rs index ee4b2310..4e63d10f 100644 --- a/optd/src/core/memo/memory.rs +++ b/optd/src/core/memo/memory.rs @@ -122,15 +122,9 @@ impl GoalState { } } -impl Memoize for MemoryMemo { - async fn merge_groups( - &mut self, - group_id_1: GroupId, - group_id_2: GroupId, - ) -> MemoizeResult> { - self.merge_groups_helper(group_id_1, group_id_2).await - } +impl OptimizerState for MemoryMemo {} +impl Memo for MemoryMemo { async fn get_logical_properties( &self, group_id: GroupId, @@ -209,6 +203,14 @@ impl Memoize for MemoryMemo { Ok(group_id) } + async fn merge_groups( + &mut self, + group_id_1: GroupId, + group_id_2: GroupId, + ) -> MemoizeResult> { + self.merge_groups_helper(group_id_1, group_id_2).await + } + async fn get_best_optimized_physical_expr( &self, goal_id: GoalId, @@ -333,6 +335,170 @@ impl Memoize for MemoryMemo { } } + async fn find_repr_group(&self, group_id: GroupId) -> MemoizeResult { + let repr_group_id = self.repr_group.find(&group_id); + Ok(repr_group_id) + } + + async fn find_repr_goal(&self, goal_id: GoalId) -> MemoizeResult { + let repr_goal_id = self.repr_goal.find(&goal_id); + Ok(repr_goal_id) + } + + async fn find_repr_logical_expr( + &self, + logical_expr_id: LogicalExpressionId, + ) -> MemoizeResult { + let repr_expr_id = self.repr_logical_expr.find(&logical_expr_id); + Ok(repr_expr_id) + } + + async fn find_repr_physical_expr( + &self, + physical_expr_id: PhysicalExpressionId, + ) -> MemoizeResult { + let repr_expr_id = self.repr_physical_expr.find(&physical_expr_id); + Ok(repr_expr_id) + } +} + +impl Materialize for MemoryMemo { + async fn get_goal_id(&mut self, goal: &Goal) -> MemoizeResult { + if let Some(goal_id) = self.goal_node_to_id_index.get(goal).cloned() { + return Ok(goal_id); + } + let goal_id = self.next_goal_id(); + self.goal_node_to_id_index.insert(goal.clone(), goal_id); + self.goals.insert(goal_id, GoalState::new(goal.clone())); + + let Goal(group_id, _) = goal; + self.groups.get_mut(group_id).unwrap().goals.insert(goal_id); + Ok(goal_id) + } + + async fn materialize_goal(&self, goal_id: GoalId) -> MemoizeResult { + let state = self + .goals + .get(&goal_id) + .ok_or(MemoizeError::GoalNotFound(goal_id))?; + + Ok(state.goal.clone()) + } + + async fn get_logical_expr_id( + &mut self, + logical_expr: &LogicalExpression, + ) -> MemoizeResult { + if let Some(logical_expr_id) = self + .logical_expr_node_to_id_index + .get(logical_expr) + .cloned() + { + return Ok(logical_expr_id); + } + let logical_expr_id = self.next_logical_expr_id(); + self.logical_expr_node_to_id_index + .insert(logical_expr.clone(), logical_expr_id); + self.logical_exprs + .insert(logical_expr_id, logical_expr.clone()); + + for child in logical_expr.children.iter() { + match child { + Child::Singleton(group_id) => { + self.group_dependent_logical_exprs + .entry(*group_id) + .or_default() + .insert(logical_expr_id); + } + Child::VarLength(group_ids) => { + for group_id in group_ids.iter() { + self.group_dependent_logical_exprs + .entry(*group_id) + .or_default() + .insert(logical_expr_id); + } + } + } + } + Ok(logical_expr_id) + } + + async fn materialize_logical_expr( + &self, + logical_expr_id: LogicalExpressionId, + ) -> MemoizeResult { + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; + let logical_expr = self + .logical_exprs + .get(&logical_expr_id) + .ok_or(MemoizeError::LogicalExprNotFound(logical_expr_id))?; + Ok(logical_expr.clone()) + } + + async fn get_physical_expr_id( + &mut self, + physical_expr: &PhysicalExpression, + ) -> MemoizeResult { + if let Some(physical_expr_id) = self + .physical_expr_node_to_id_index + .get(physical_expr) + .cloned() + { + return Ok(physical_expr_id); + } + let physical_expr_id = self.next_physical_expr_id(); + self.physical_expr_node_to_id_index + .insert(physical_expr.clone(), physical_expr_id); + self.physical_exprs + .insert(physical_expr_id, (physical_expr.clone(), None)); + + for child in physical_expr.children.iter() { + match child { + Child::Singleton(goal_member_id) => { + if let GoalMemberId::GoalId(goal_id) = goal_member_id { + self.goal_dependent_physical_exprs + .entry(*goal_id) + .or_default() + .insert(physical_expr_id); + } + } + Child::VarLength(goal_member_ids) => { + for goal_member_id in goal_member_ids.iter() { + match goal_member_id { + GoalMemberId::GoalId(goal_id) => { + self.goal_dependent_physical_exprs + .entry(*goal_id) + .or_default() + .insert(physical_expr_id); + } + GoalMemberId::PhysicalExpressionId(child_physical_expr_id) => { + self.physical_expr_dependent_physical_exprs + .entry(*child_physical_expr_id) + .or_default() + .insert(physical_expr_id); + } + } + } + } + } + } + Ok(physical_expr_id) + } + + async fn materialize_physical_expr( + &self, + physical_expr_id: PhysicalExpressionId, + ) -> MemoizeResult { + let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; + let (physical_expr, _) = self + .physical_exprs + .get(&physical_expr_id) + .ok_or(MemoizeError::PhysicalExprNotFound(physical_expr_id))?; + Ok(physical_expr.clone()) + } +} + +impl TaskState for MemoryMemo { async fn get_transformation_status( &self, logical_expr_id: LogicalExpressionId, @@ -525,166 +691,6 @@ impl Memoize for MemoryMemo { Ok(()) } - - async fn get_goal_id(&mut self, goal: &Goal) -> MemoizeResult { - if let Some(goal_id) = self.goal_node_to_id_index.get(goal).cloned() { - return Ok(goal_id); - } - let goal_id = self.next_goal_id(); - self.goal_node_to_id_index.insert(goal.clone(), goal_id); - self.goals.insert(goal_id, GoalState::new(goal.clone())); - - let Goal(group_id, _) = goal; - self.groups.get_mut(group_id).unwrap().goals.insert(goal_id); - Ok(goal_id) - } - - async fn materialize_goal(&self, goal_id: GoalId) -> MemoizeResult { - let state = self - .goals - .get(&goal_id) - .ok_or(MemoizeError::GoalNotFound(goal_id))?; - - Ok(state.goal.clone()) - } - - async fn get_logical_expr_id( - &mut self, - logical_expr: &LogicalExpression, - ) -> MemoizeResult { - if let Some(logical_expr_id) = self - .logical_expr_node_to_id_index - .get(logical_expr) - .cloned() - { - return Ok(logical_expr_id); - } - let logical_expr_id = self.next_logical_expr_id(); - self.logical_expr_node_to_id_index - .insert(logical_expr.clone(), logical_expr_id); - self.logical_exprs - .insert(logical_expr_id, logical_expr.clone()); - - for child in logical_expr.children.iter() { - match child { - Child::Singleton(group_id) => { - self.group_dependent_logical_exprs - .entry(*group_id) - .or_default() - .insert(logical_expr_id); - } - Child::VarLength(group_ids) => { - for group_id in group_ids.iter() { - self.group_dependent_logical_exprs - .entry(*group_id) - .or_default() - .insert(logical_expr_id); - } - } - } - } - Ok(logical_expr_id) - } - - async fn materialize_logical_expr( - &self, - logical_expr_id: LogicalExpressionId, - ) -> MemoizeResult { - let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; - let logical_expr = self - .logical_exprs - .get(&logical_expr_id) - .ok_or(MemoizeError::LogicalExprNotFound(logical_expr_id))?; - Ok(logical_expr.clone()) - } - - async fn get_physical_expr_id( - &mut self, - physical_expr: &PhysicalExpression, - ) -> MemoizeResult { - if let Some(physical_expr_id) = self - .physical_expr_node_to_id_index - .get(physical_expr) - .cloned() - { - return Ok(physical_expr_id); - } - let physical_expr_id = self.next_physical_expr_id(); - self.physical_expr_node_to_id_index - .insert(physical_expr.clone(), physical_expr_id); - self.physical_exprs - .insert(physical_expr_id, (physical_expr.clone(), None)); - - for child in physical_expr.children.iter() { - match child { - Child::Singleton(goal_member_id) => { - if let GoalMemberId::GoalId(goal_id) = goal_member_id { - self.goal_dependent_physical_exprs - .entry(*goal_id) - .or_default() - .insert(physical_expr_id); - } - } - Child::VarLength(goal_member_ids) => { - for goal_member_id in goal_member_ids.iter() { - match goal_member_id { - GoalMemberId::GoalId(goal_id) => { - self.goal_dependent_physical_exprs - .entry(*goal_id) - .or_default() - .insert(physical_expr_id); - } - GoalMemberId::PhysicalExpressionId(child_physical_expr_id) => { - self.physical_expr_dependent_physical_exprs - .entry(*child_physical_expr_id) - .or_default() - .insert(physical_expr_id); - } - } - } - } - } - } - Ok(physical_expr_id) - } - - async fn materialize_physical_expr( - &self, - physical_expr_id: PhysicalExpressionId, - ) -> MemoizeResult { - let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; - let (physical_expr, _) = self - .physical_exprs - .get(&physical_expr_id) - .ok_or(MemoizeError::PhysicalExprNotFound(physical_expr_id))?; - Ok(physical_expr.clone()) - } - - async fn find_repr_group(&self, group_id: GroupId) -> MemoizeResult { - let repr_group_id = self.repr_group.find(&group_id); - Ok(repr_group_id) - } - - async fn find_repr_goal(&self, goal_id: GoalId) -> MemoizeResult { - let repr_goal_id = self.repr_goal.find(&goal_id); - Ok(repr_goal_id) - } - - async fn find_repr_logical_expr( - &self, - logical_expr_id: LogicalExpressionId, - ) -> MemoizeResult { - let repr_expr_id = self.repr_logical_expr.find(&logical_expr_id); - Ok(repr_expr_id) - } - - async fn find_repr_physical_expr( - &self, - physical_expr_id: PhysicalExpressionId, - ) -> MemoizeResult { - let repr_expr_id = self.repr_physical_expr.find(&physical_expr_id); - Ok(repr_expr_id) - } } impl MemoryMemo { diff --git a/optd/src/core/memo/mod.rs b/optd/src/core/memo/mod.rs index 3dff5fbb..bde8fb2b 100644 --- a/optd/src/core/memo/mod.rs +++ b/optd/src/core/memo/mod.rs @@ -131,18 +131,16 @@ impl ForwardResult { } } -/// Core interface for memo-based query optimization. -/// -/// This trait defines the operations needed to store, retrieve, and manipulate -/// the memo data structure that powers the dynamic programming approach to -/// query optimization. The memo stores logical and physical expressions by their IDs, -/// manages expression properties, and tracks optimization status. +pub trait OptimizerState: Memo + Materialize + TaskState {} + +// +// Logical expression and group operations. +// +// +// Physical expression and goal operations. +// #[trait_variant::make(Send)] -pub trait Memoize: Send + Sync + 'static { - // - // Logical expression and group operations. - // - +pub trait Memo { /// Retrieves logical properties for a group ID. /// /// # Parameters @@ -226,10 +224,6 @@ pub trait Memoize: Send + Sync + 'static { group_id_2: GroupId, ) -> MemoizeResult>; - // - // Physical expression and goal operations. - // - /// Gets the best optimized physical expression ID for a goal ID. /// /// # Parameters @@ -285,6 +279,134 @@ pub trait Memoize: Send + Sync + 'static { physical_expr_id: PhysicalExpressionId, ) -> MemoizeResult>; + /// Finds the representative group ID for a given group ID. + /// + /// # Parameters + /// * `group_id` - The group ID to find the representative for. + /// + /// # Returns + /// The representative group ID (which may be the same as the input if + /// it's already the representative). + async fn find_repr_group(&self, group_id: GroupId) -> MemoizeResult; + + /// Finds the representative goal ID for a given goal ID. + /// + /// # Parameters + /// * `goal_id` - The goal ID to find the representative for. + /// + /// # Returns + /// The representative goal ID (which may be the same as the input if + /// it's already the representative). + async fn find_repr_goal(&self, goal_id: GoalId) -> MemoizeResult; + + /// Finds the representative logical expression ID for a given logical expression ID. + /// + /// # Parameters + /// * `logical_expr_id` - The logical expression ID to find the representative for. + /// + /// # Returns + /// The representative logical expression ID (which may be the same as the input if + /// it's already the representative). + async fn find_repr_logical_expr( + &self, + logical_expr_id: LogicalExpressionId, + ) -> MemoizeResult; + + /// Finds the representative physical expression ID for a given physical expression ID. + /// + /// # Parameters + /// * `physical_expr_id` - The physical expression ID to find the representative for. + /// + /// # Returns + /// The representative physical expression ID (which may be the same as the input if + /// it's already the representative). + async fn find_repr_physical_expr( + &self, + physical_expr_id: PhysicalExpressionId, + ) -> MemoizeResult; +} + +#[trait_variant::make(Send)] +pub trait Materialize { + // + // ID conversion and materialization operations. + // + + /// Gets or creates a goal ID for a given goal. + /// + /// # Parameters + /// * `goal` - The goal to get or create an ID for. + /// + /// # Returns + /// The ID of the goal. + async fn get_goal_id(&mut self, goal: &Goal) -> MemoizeResult; + + /// Materializes a goal from its ID. + /// + /// # Parameters + /// * `goal_id` - ID of the goal to materialize. + /// + /// # Returns + /// The materialized goal. + async fn materialize_goal(&self, goal_id: GoalId) -> MemoizeResult; + + /// Gets or creates a logical expression ID for a given logical expression. + /// + /// # Parameters + /// * `logical_expr` - The logical expression to get or create an ID for. + /// + /// # Returns + /// The ID of the logical expression. + async fn get_logical_expr_id( + &mut self, + logical_expr: &LogicalExpression, + ) -> MemoizeResult; + + /// Materializes a logical expression from its ID. + /// + /// # Parameters + /// * `logical_expr_id` - ID of the logical expression to materialize. + /// + /// # Returns + /// The materialized logical expression. + async fn materialize_logical_expr( + &self, + logical_expr_id: LogicalExpressionId, + ) -> MemoizeResult; + + /// Gets or creates a physical expression ID for a given physical expression. + /// + /// # Parameters + /// * `physical_expr` - The physical expression to get or create an ID for. + /// + /// # Returns + /// The ID of the physical expression. + async fn get_physical_expr_id( + &mut self, + physical_expr: &PhysicalExpression, + ) -> MemoizeResult; + + /// Materializes a physical expression from its ID. + /// + /// # Parameters + /// * `physical_expr_id` - ID of the physical expression to materialize. + /// + /// # Returns + /// The materialized physical expression. + async fn materialize_physical_expr( + &self, + physical_expr_id: PhysicalExpressionId, + ) -> MemoizeResult; +} + +/// Core interface for memo-based query optimization. +/// +/// This trait defines the operations needed to store, retrieve, and manipulate +/// the memo data structure that powers the dynamic programming approach to +/// query optimization. The memo stores logical and physical expressions by their IDs, +/// manages expression properties, and tracks optimization status. +#[trait_variant::make(Send)] +pub trait TaskState { // // Rule and costing status operations. // @@ -413,124 +535,4 @@ pub trait Memoize: Send + Sync + 'static { physical_expr_id: PhysicalExpressionId, goal_id: GoalId, ) -> MemoizeResult<()>; - - // - // ID conversion and materialization operations. - // - - /// Gets or creates a goal ID for a given goal. - /// - /// # Parameters - /// * `goal` - The goal to get or create an ID for. - /// - /// # Returns - /// The ID of the goal. - async fn get_goal_id(&mut self, goal: &Goal) -> MemoizeResult; - - /// Materializes a goal from its ID. - /// - /// # Parameters - /// * `goal_id` - ID of the goal to materialize. - /// - /// # Returns - /// The materialized goal. - async fn materialize_goal(&self, goal_id: GoalId) -> MemoizeResult; - - /// Gets or creates a logical expression ID for a given logical expression. - /// - /// # Parameters - /// * `logical_expr` - The logical expression to get or create an ID for. - /// - /// # Returns - /// The ID of the logical expression. - async fn get_logical_expr_id( - &mut self, - logical_expr: &LogicalExpression, - ) -> MemoizeResult; - - /// Materializes a logical expression from its ID. - /// - /// # Parameters - /// * `logical_expr_id` - ID of the logical expression to materialize. - /// - /// # Returns - /// The materialized logical expression. - async fn materialize_logical_expr( - &self, - logical_expr_id: LogicalExpressionId, - ) -> MemoizeResult; - - /// Gets or creates a physical expression ID for a given physical expression. - /// - /// # Parameters - /// * `physical_expr` - The physical expression to get or create an ID for. - /// - /// # Returns - /// The ID of the physical expression. - async fn get_physical_expr_id( - &mut self, - physical_expr: &PhysicalExpression, - ) -> MemoizeResult; - - /// Materializes a physical expression from its ID. - /// - /// # Parameters - /// * `physical_expr_id` - ID of the physical expression to materialize. - /// - /// # Returns - /// The materialized physical expression. - async fn materialize_physical_expr( - &self, - physical_expr_id: PhysicalExpressionId, - ) -> MemoizeResult; - - // - // Representative ID operations. - // - - /// Finds the representative group ID for a given group ID. - /// - /// # Parameters - /// * `group_id` - The group ID to find the representative for. - /// - /// # Returns - /// The representative group ID (which may be the same as the input if - /// it's already the representative). - async fn find_repr_group(&self, group_id: GroupId) -> MemoizeResult; - - /// Finds the representative goal ID for a given goal ID. - /// - /// # Parameters - /// * `goal_id` - The goal ID to find the representative for. - /// - /// # Returns - /// The representative goal ID (which may be the same as the input if - /// it's already the representative). - async fn find_repr_goal(&self, goal_id: GoalId) -> MemoizeResult; - - /// Finds the representative logical expression ID for a given logical expression ID. - /// - /// # Parameters - /// * `logical_expr_id` - The logical expression ID to find the representative for. - /// - /// # Returns - /// The representative logical expression ID (which may be the same as the input if - /// it's already the representative). - async fn find_repr_logical_expr( - &self, - logical_expr_id: LogicalExpressionId, - ) -> MemoizeResult; - - /// Finds the representative physical expression ID for a given physical expression ID. - /// - /// # Parameters - /// * `physical_expr_id` - The physical expression ID to find the representative for. - /// - /// # Returns - /// The representative physical expression ID (which may be the same as the input if - /// it's already the representative). - async fn find_repr_physical_expr( - &self, - physical_expr_id: PhysicalExpressionId, - ) -> MemoizeResult; } From 864c4dabd05bc7809fe59bfcfa046cbff7407cab Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 30 Apr 2025 10:01:15 -0400 Subject: [PATCH 4/9] refactor memo module --- optd/src/core/memo/error.rs | 22 + optd/src/core/memo/memory.rs | 16 +- optd/src/core/memo/mod.rs | 545 +----------------- optd/src/core/memo/traits.rs | 408 +++++++++++++ optd/src/core/memo/types.rs | 108 ++++ .../memo/{merge_repr.rs => union_find.rs} | 30 +- 6 files changed, 571 insertions(+), 558 deletions(-) create mode 100644 optd/src/core/memo/error.rs create mode 100644 optd/src/core/memo/traits.rs create mode 100644 optd/src/core/memo/types.rs rename optd/src/core/memo/{merge_repr.rs => union_find.rs} (89%) diff --git a/optd/src/core/memo/error.rs b/optd/src/core/memo/error.rs new file mode 100644 index 00000000..92d55953 --- /dev/null +++ b/optd/src/core/memo/error.rs @@ -0,0 +1,22 @@ +use crate::core::cir::*; + +/// Type alias for results returned by Memoize trait methods +pub type MemoizeResult = Result; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MemoizeError { + /// Error indicating that a group ID was not found in the memo. + GroupNotFound(GroupId), + + /// Error indicating that a goal ID was not found in the memo. + GoalNotFound(GoalId), + + /// Error indicating that a logical expression ID was not found in the memo. + LogicalExprNotFound(LogicalExpressionId), + + /// Error indicating that a physical expression ID was not found in the memo. + PhysicalExprNotFound(PhysicalExpressionId), + + /// Error indicating that there is no logical expression in the group. + NoLogicalExprInGroup(GroupId), +} diff --git a/optd/src/core/memo/memory.rs b/optd/src/core/memo/memory.rs index 4e63d10f..a620a805 100644 --- a/optd/src/core/memo/memory.rs +++ b/optd/src/core/memo/memory.rs @@ -1,9 +1,7 @@ -use super::merge_repr::Representative; -use super::*; -use crate::core::cir::Child; +use super::{union_find::UnionFind, *}; +use crate::core::cir::*; use async_recursion::async_recursion; -use std::collections::hash_map::Entry; -use std::collections::{HashMap, HashSet, VecDeque}; +use std::collections::{HashMap, HashSet, VecDeque, hash_map::Entry}; /// An in-memory implementation of the memo table. #[derive(Default)] @@ -52,10 +50,10 @@ pub struct MemoryMemo { /// The shared next unique id to be used for goals, groups, logical expressions, and physical expressions. next_shared_id: i64, - repr_group: Representative, - repr_goal: Representative, - repr_logical_expr: Representative, - repr_physical_expr: Representative, + repr_group: UnionFind, + repr_goal: UnionFind, + repr_logical_expr: UnionFind, + repr_physical_expr: UnionFind, transform_dependency: HashMap>, implement_dependency: diff --git a/optd/src/core/memo/mod.rs b/optd/src/core/memo/mod.rs index bde8fb2b..d4bf7d7b 100644 --- a/optd/src/core/memo/mod.rs +++ b/optd/src/core/memo/mod.rs @@ -1,538 +1,13 @@ -use crate::core::cir::*; -use std::collections::{HashMap, HashSet}; +mod error; +mod traits; +mod types; -mod memory; -mod merge_repr; - -/// Type alias for results returned by Memoize trait methods -pub type MemoizeResult = Result; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum MemoizeError { - /// Error indicating that a group ID was not found in the memo. - GroupNotFound(GroupId), - - /// Error indicating that a goal ID was not found in the memo. - GoalNotFound(GoalId), - - /// Error indicating that a logical expression ID was not found in the memo. - LogicalExprNotFound(LogicalExpressionId), - - /// Error indicating that a physical expression ID was not found in the memo. - PhysicalExprNotFound(PhysicalExpressionId), - - /// Error indicating that there is no logical expression in the group. - NoLogicalExprInGroup(GroupId), -} - -/// Status of a rule application or costing operation -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Status { - /// There exist ongoing jobs that may generate more expressions or costs from this expression. - Dirty, - - /// Expression is fully explored or costed with no pending jobs that could add anything new. - Clean, -} - -/// Result of merging two groups. -#[derive(Debug)] -pub struct MergeGroupResult { - /// ID of the new representative group id. - pub new_repr_group_id: GroupId, - /// Groups that were merged along with their expressions. - pub merged_groups: HashMap>, -} - -impl MergeGroupResult { - /// Creates a new MergeGroupResult instance. - /// - /// # Parameters - /// * `merged_groups` - Groups that were merged along with their expressions. - /// * `new_repr_group_id` - ID of the new representative group id. - pub fn new(new_repr_group_id: GroupId) -> Self { - Self { - new_repr_group_id, - merged_groups: HashMap::new(), - } - } -} - -/// Information about a merged goal, including its ID and expressions -#[derive(Debug)] -pub struct MergedGoalInfo { - /// ID of the merged goal - pub goal_id: GoalId, - - /// Whether this goal contained the best costed expression before merging. - pub seen_best_expr_before_merge: bool, - - /// All members in this goal, which can be physical expressions or references to other goals - pub members: Vec, -} - -/// Result of merging two goals. -#[derive(Debug)] -pub struct MergeGoalResult { - /// Goals that were merged along with their potential best costed expression. - pub merged_goals: HashMap, - - /// The best costed expression for all merged goals combined. - pub best_expr: Option<(PhysicalExpressionId, Cost)>, - - /// ID of the new representative goal id. - pub new_repr_goal_id: GoalId, -} - -/// Result of merging two cost expressions. -#[derive(Debug)] -pub struct MergePhysicalExprResult { - /// The new representative physical expression id. - pub repr_physical_expr: PhysicalExpressionId, - - /// Physical expressions that were stale - pub stale_physical_exprs: HashSet, -} - -/// Results of merge operations with newly dirtied expressions. -#[derive(Debug, Default)] -pub struct MergeResult { - /// Group merge results. - pub group_merges: Vec, - - /// Goal merge results. - pub goal_merges: Vec, - - /// Physical expression merge results. - pub physical_expr_merges: Vec, - // /// Transformations that were marked as dirty and need new application. - // pub dirty_transformations: Vec<(LogicalExpressionId, TransformationRule)>, - - // /// Implementations that were marked as dirty and need new application. - // pub dirty_implementations: Vec<(LogicalExpressionId, GoalId, ImplementationRule)>, - - // /// Costings that were marked as dirty and need recomputation. - // pub dirty_costings: Vec, -} - -pub struct ForwardResult { - pub physical_expr_id: PhysicalExpressionId, - pub best_cost: Cost, - pub goals_forwarded: HashSet, -} - -impl ForwardResult { - pub fn new(physical_expr_id: PhysicalExpressionId, best_cost: Cost) -> Self { - Self { - physical_expr_id, - best_cost, - goals_forwarded: HashSet::new(), - } - } -} - -pub trait OptimizerState: Memo + Materialize + TaskState {} - -// -// Logical expression and group operations. -// -// -// Physical expression and goal operations. -// -#[trait_variant::make(Send)] -pub trait Memo { - /// Retrieves logical properties for a group ID. - /// - /// # Parameters - /// * `group_id` - ID of the group to retrieve properties for. - /// - /// # Returns - /// The properties associated with the group or an error if not found. - async fn get_logical_properties( - &self, - group_id: GroupId, - ) -> MemoizeResult>; - - /// Sets logical properties for a group ID. - /// - /// # Parameters - /// * `group_id` - ID of the group to set properties for. - /// * `props` - The logical properties to associate with the group. - /// - /// # Returns - /// A result indicating success or failure of the operation. - async fn set_logical_properties( - &mut self, - group_id: GroupId, - props: LogicalProperties, - ) -> MemoizeResult<()>; - - /// Gets all logical expression IDs in a group (only representative IDs). - /// - /// # Parameters - /// * `group_id` - ID of the group to retrieve expressions from. - /// - /// # Returns - /// A vector of logical expression IDs in the specified group. - async fn get_all_logical_exprs( - &self, - group_id: GroupId, - ) -> MemoizeResult>; +pub use error::*; +pub use traits::*; +pub use types::*; - /// Gets any logical expression ID in a group. - async fn get_any_logical_expr(&self, group_id: GroupId) -> MemoizeResult; +/// A generic implementation of the Union-Find algorithm. +mod union_find; - /// Finds group containing a logical expression ID, if it exists. - /// - /// # Parameters - /// * `logical_expr_id` - ID of the logical expression to find. - /// - /// # Returns - /// The group ID if the expression exists, None otherwise. - async fn find_logical_expr_group( - &self, - logical_expr_id: LogicalExpressionId, - ) -> MemoizeResult>; - - /// Creates a new group with a logical expression ID and properties. - /// - /// # Parameters - /// * `logical_expr_id` - ID of the logical expression to add to the group. - /// * `props` - Logical properties for the group. - /// - /// # Returns - /// The ID of the newly created group. - async fn create_group( - &mut self, - logical_expr_id: LogicalExpressionId, - ) -> MemoizeResult; - - /// Merges groups 1 and 2, unifying them under a common representative. - /// - /// May trigger cascading merges of parent groups & goals. - /// - /// # Parameters - /// * `group_id_1` - ID of the first group to merge. - /// * `group_id_2` - ID of the second group to merge. - /// - /// # Returns - /// Merge results for all affected entities including newly dirtied - /// transformations, implementations and costings. - async fn merge_groups( - &mut self, - group_id_1: GroupId, - group_id_2: GroupId, - ) -> MemoizeResult>; - - /// Gets the best optimized physical expression ID for a goal ID. - /// - /// # Parameters - /// * `goal_id` - ID of the goal to retrieve the best expression for. - /// - /// # Returns - /// The ID of the lowest-cost physical implementation found so far for the goal, - /// along with its cost. Returns None if no optimized expression exists. - async fn get_best_optimized_physical_expr( - &self, - goal_id: GoalId, - ) -> MemoizeResult>; - - /// Gets all members of a goal, which can be physical expressions or other goals. - /// - /// # Parameters - /// * `goal_id` - ID of the goal to retrieve members from. - /// - /// # Returns - /// A vector of goal members, each being either a physical expression ID or another goal ID. - async fn get_all_goal_members(&self, goal_id: GoalId) -> MemoizeResult>; - - /// Adds a member to a goal. - /// - /// # Parameters - /// * `goal_id` - ID of the goal to add the member to. - /// * `member` - The member to add, either a physical expression ID or another goal ID. - /// - /// # Returns - /// True if the member was added to the goal, or false if it already existed. - async fn add_goal_member( - &mut self, - goal_id: GoalId, - member: GoalMemberId, - ) -> MemoizeResult>; - - /// Updates the cost of a physical expression ID. - /// - /// # Parameters - /// * `physical_expr_id` - ID of the physical expression to update. - /// * `new_cost` - New cost to assign to the physical expression. - /// - /// # Returns - /// Whether the cost of the expression has improved. - async fn update_physical_expr_cost( - &mut self, - physical_expr_id: PhysicalExpressionId, - new_cost: Cost, - ) -> MemoizeResult>; - - async fn get_physical_expr_cost( - &self, - physical_expr_id: PhysicalExpressionId, - ) -> MemoizeResult>; - - /// Finds the representative group ID for a given group ID. - /// - /// # Parameters - /// * `group_id` - The group ID to find the representative for. - /// - /// # Returns - /// The representative group ID (which may be the same as the input if - /// it's already the representative). - async fn find_repr_group(&self, group_id: GroupId) -> MemoizeResult; - - /// Finds the representative goal ID for a given goal ID. - /// - /// # Parameters - /// * `goal_id` - The goal ID to find the representative for. - /// - /// # Returns - /// The representative goal ID (which may be the same as the input if - /// it's already the representative). - async fn find_repr_goal(&self, goal_id: GoalId) -> MemoizeResult; - - /// Finds the representative logical expression ID for a given logical expression ID. - /// - /// # Parameters - /// * `logical_expr_id` - The logical expression ID to find the representative for. - /// - /// # Returns - /// The representative logical expression ID (which may be the same as the input if - /// it's already the representative). - async fn find_repr_logical_expr( - &self, - logical_expr_id: LogicalExpressionId, - ) -> MemoizeResult; - - /// Finds the representative physical expression ID for a given physical expression ID. - /// - /// # Parameters - /// * `physical_expr_id` - The physical expression ID to find the representative for. - /// - /// # Returns - /// The representative physical expression ID (which may be the same as the input if - /// it's already the representative). - async fn find_repr_physical_expr( - &self, - physical_expr_id: PhysicalExpressionId, - ) -> MemoizeResult; -} - -#[trait_variant::make(Send)] -pub trait Materialize { - // - // ID conversion and materialization operations. - // - - /// Gets or creates a goal ID for a given goal. - /// - /// # Parameters - /// * `goal` - The goal to get or create an ID for. - /// - /// # Returns - /// The ID of the goal. - async fn get_goal_id(&mut self, goal: &Goal) -> MemoizeResult; - - /// Materializes a goal from its ID. - /// - /// # Parameters - /// * `goal_id` - ID of the goal to materialize. - /// - /// # Returns - /// The materialized goal. - async fn materialize_goal(&self, goal_id: GoalId) -> MemoizeResult; - - /// Gets or creates a logical expression ID for a given logical expression. - /// - /// # Parameters - /// * `logical_expr` - The logical expression to get or create an ID for. - /// - /// # Returns - /// The ID of the logical expression. - async fn get_logical_expr_id( - &mut self, - logical_expr: &LogicalExpression, - ) -> MemoizeResult; - - /// Materializes a logical expression from its ID. - /// - /// # Parameters - /// * `logical_expr_id` - ID of the logical expression to materialize. - /// - /// # Returns - /// The materialized logical expression. - async fn materialize_logical_expr( - &self, - logical_expr_id: LogicalExpressionId, - ) -> MemoizeResult; - - /// Gets or creates a physical expression ID for a given physical expression. - /// - /// # Parameters - /// * `physical_expr` - The physical expression to get or create an ID for. - /// - /// # Returns - /// The ID of the physical expression. - async fn get_physical_expr_id( - &mut self, - physical_expr: &PhysicalExpression, - ) -> MemoizeResult; - - /// Materializes a physical expression from its ID. - /// - /// # Parameters - /// * `physical_expr_id` - ID of the physical expression to materialize. - /// - /// # Returns - /// The materialized physical expression. - async fn materialize_physical_expr( - &self, - physical_expr_id: PhysicalExpressionId, - ) -> MemoizeResult; -} - -/// Core interface for memo-based query optimization. -/// -/// This trait defines the operations needed to store, retrieve, and manipulate -/// the memo data structure that powers the dynamic programming approach to -/// query optimization. The memo stores logical and physical expressions by their IDs, -/// manages expression properties, and tracks optimization status. -#[trait_variant::make(Send)] -pub trait TaskState { - // - // Rule and costing status operations. - // - - /// Checks the status of applying a transformation rule on a logical expression ID. - /// - /// # Parameters - /// * `logical_expr_id` - ID of the logical expression to check. - /// * `rule` - Transformation rule to check status for. - /// - /// # Returns - /// `Status::Dirty` if there are ongoing events that may affect the transformation, - /// `Status::Clean` if the transformation does not need to be re-evaluated. - async fn get_transformation_status( - &self, - logical_expr_id: LogicalExpressionId, - rule: &TransformationRule, - ) -> MemoizeResult; - - /// Sets the status of a transformation rule as clean on a logical expression ID. - /// - /// # Parameters - /// * `logical_expr_id` - ID of the logical expression to update. - /// * `rule` - Transformation rule to set status for. - async fn set_transformation_clean( - &mut self, - logical_expr_id: LogicalExpressionId, - rule: &TransformationRule, - ) -> MemoizeResult<()>; - - /// Checks the status of applying an implementation rule on a logical expression ID and goal ID. - /// - /// # Parameters - /// * `logical_expr_id` - ID of the logical expression to check. - /// * `goal_id` - ID of the goal to check against. - /// * `rule` - Implementation rule to check status for. - /// - /// # Returns - /// `Status::Dirty` if there are ongoing events that may affect the implementation, - /// `Status::Clean` if the implementation does not need to be re-evaluated. - async fn get_implementation_status( - &self, - logical_expr_id: LogicalExpressionId, - goal_id: GoalId, - rule: &ImplementationRule, - ) -> MemoizeResult; - - /// Sets the status of an implementation rule as clean on a logical expression ID and goal ID. - /// - /// # Parameters - /// * `logical_expr_id` - ID of the logical expression to update. - /// * `goal_id` - ID of the goal to update against. - /// * `rule` - Implementation rule to set status for. - async fn set_implementation_clean( - &mut self, - logical_expr_id: LogicalExpressionId, - goal_id: GoalId, - rule: &ImplementationRule, - ) -> MemoizeResult<()>; - - /// Checks the status of costing a physical expression ID. - /// - /// # Parameters - /// * `physical_expr_id` - ID of the physical expression to check. - /// - /// # Returns - /// `Status::Dirty` if there are ongoing events that may affect the costing, - /// `Status::Clean` if the costing does not need to be re-evaluated. - async fn get_cost_status( - &self, - physical_expr_id: PhysicalExpressionId, - ) -> MemoizeResult; - - /// Sets the status of costing a physical expression ID as clean. - /// - /// # Parameters - /// * `physical_expr_id` - ID of the physical expression to update. - async fn set_cost_clean(&mut self, physical_expr_id: PhysicalExpressionId) - -> MemoizeResult<()>; - - /// Adds a dependency between a transformation rule application and a group. - /// - /// This registers that the application of the transformation rule on the logical expression - /// depends on the group. When the group changes, the transformation status should be set to dirty. - /// - /// # Parameters - /// * `logical_expr_id` - ID of the logical expression the rule is applied to. - /// * `rule` - Transformation rule that depends on the group. - /// * `group_id` - ID of the group that the transformation depends on. - async fn add_transformation_dependency( - &mut self, - logical_expr_id: LogicalExpressionId, - rule: &TransformationRule, - group_id: GroupId, - ) -> MemoizeResult<()>; - - /// Adds a dependency between an implementation rule application and a group. - /// - /// This registers that the application of the implementation rule on the logical expression - /// for a specific goal depends on the group. When the group changes, the implementation status - /// should be set to dirty. - /// - /// # Parameters - /// * `logical_expr_id` - ID of the logical expression the rule is applied to. - /// * `goal_id` - ID of the goal the implementation targets. - /// * `rule` - Implementation rule that depends on the group. - /// * `group_id` - ID of the group that the implementation depends on. - async fn add_implementation_dependency( - &mut self, - logical_expr_id: LogicalExpressionId, - goal_id: GoalId, - rule: &ImplementationRule, - group_id: GroupId, - ) -> MemoizeResult<()>; - - /// Adds a dependency between costing a physical expression and a goal. - /// - /// This registers that the costing of the physical expression depends on the goal. - /// When the goal changes, the costing status should be set to dirty. - /// - /// # Parameters - /// * `physical_expr_id` - ID of the physical expression to cost. - /// * `goal_id` - ID of the goal that the costing depends on. - async fn add_cost_dependency( - &mut self, - physical_expr_id: PhysicalExpressionId, - goal_id: GoalId, - ) -> MemoizeResult<()>; -} +/// In-memory implementation of the optimizer state (including the memo table). +mod memory; diff --git a/optd/src/core/memo/traits.rs b/optd/src/core/memo/traits.rs new file mode 100644 index 00000000..a316c7cd --- /dev/null +++ b/optd/src/core/memo/traits.rs @@ -0,0 +1,408 @@ +use super::{ForwardResult, MemoizeResult, MergeResult, Status}; +use crate::core::cir::*; + +pub trait OptimizerState: Memo + Materialize + TaskState {} + +// +// Logical expression and group operations. +// +// +// Physical expression and goal operations. +// +#[trait_variant::make(Send)] +pub trait Memo { + /// Retrieves logical properties for a group ID. + /// + /// # Parameters + /// * `group_id` - ID of the group to retrieve properties for. + /// + /// # Returns + /// The properties associated with the group or an error if not found. + async fn get_logical_properties( + &self, + group_id: GroupId, + ) -> MemoizeResult>; + + /// Sets logical properties for a group ID. + /// + /// # Parameters + /// * `group_id` - ID of the group to set properties for. + /// * `props` - The logical properties to associate with the group. + /// + /// # Returns + /// A result indicating success or failure of the operation. + async fn set_logical_properties( + &mut self, + group_id: GroupId, + props: LogicalProperties, + ) -> MemoizeResult<()>; + + /// Gets all logical expression IDs in a group (only representative IDs). + /// + /// # Parameters + /// * `group_id` - ID of the group to retrieve expressions from. + /// + /// # Returns + /// A vector of logical expression IDs in the specified group. + async fn get_all_logical_exprs( + &self, + group_id: GroupId, + ) -> MemoizeResult>; + + /// Gets any logical expression ID in a group. + async fn get_any_logical_expr(&self, group_id: GroupId) -> MemoizeResult; + + /// Finds group containing a logical expression ID, if it exists. + /// + /// # Parameters + /// * `logical_expr_id` - ID of the logical expression to find. + /// + /// # Returns + /// The group ID if the expression exists, None otherwise. + async fn find_logical_expr_group( + &self, + logical_expr_id: LogicalExpressionId, + ) -> MemoizeResult>; + + /// Creates a new group with a logical expression ID and properties. + /// + /// # Parameters + /// * `logical_expr_id` - ID of the logical expression to add to the group. + /// * `props` - Logical properties for the group. + /// + /// # Returns + /// The ID of the newly created group. + async fn create_group( + &mut self, + logical_expr_id: LogicalExpressionId, + ) -> MemoizeResult; + + /// Merges groups 1 and 2, unifying them under a common representative. + /// + /// May trigger cascading merges of parent groups & goals. + /// + /// # Parameters + /// * `group_id_1` - ID of the first group to merge. + /// * `group_id_2` - ID of the second group to merge. + /// + /// # Returns + /// Merge results for all affected entities including newly dirtied + /// transformations, implementations and costings. + async fn merge_groups( + &mut self, + group_id_1: GroupId, + group_id_2: GroupId, + ) -> MemoizeResult>; + + /// Gets the best optimized physical expression ID for a goal ID. + /// + /// # Parameters + /// * `goal_id` - ID of the goal to retrieve the best expression for. + /// + /// # Returns + /// The ID of the lowest-cost physical implementation found so far for the goal, + /// along with its cost. Returns None if no optimized expression exists. + async fn get_best_optimized_physical_expr( + &self, + goal_id: GoalId, + ) -> MemoizeResult>; + + /// Gets all members of a goal, which can be physical expressions or other goals. + /// + /// # Parameters + /// * `goal_id` - ID of the goal to retrieve members from. + /// + /// # Returns + /// A vector of goal members, each being either a physical expression ID or another goal ID. + async fn get_all_goal_members(&self, goal_id: GoalId) -> MemoizeResult>; + + /// Adds a member to a goal. + /// + /// # Parameters + /// * `goal_id` - ID of the goal to add the member to. + /// * `member` - The member to add, either a physical expression ID or another goal ID. + /// + /// # Returns + /// True if the member was added to the goal, or false if it already existed. + async fn add_goal_member( + &mut self, + goal_id: GoalId, + member: GoalMemberId, + ) -> MemoizeResult>; + + /// Updates the cost of a physical expression ID. + /// + /// # Parameters + /// * `physical_expr_id` - ID of the physical expression to update. + /// * `new_cost` - New cost to assign to the physical expression. + /// + /// # Returns + /// Whether the cost of the expression has improved. + async fn update_physical_expr_cost( + &mut self, + physical_expr_id: PhysicalExpressionId, + new_cost: Cost, + ) -> MemoizeResult>; + + async fn get_physical_expr_cost( + &self, + physical_expr_id: PhysicalExpressionId, + ) -> MemoizeResult>; + + /// Finds the representative group ID for a given group ID. + /// + /// # Parameters + /// * `group_id` - The group ID to find the representative for. + /// + /// # Returns + /// The representative group ID (which may be the same as the input if + /// it's already the representative). + async fn find_repr_group(&self, group_id: GroupId) -> MemoizeResult; + + /// Finds the representative goal ID for a given goal ID. + /// + /// # Parameters + /// * `goal_id` - The goal ID to find the representative for. + /// + /// # Returns + /// The representative goal ID (which may be the same as the input if + /// it's already the representative). + async fn find_repr_goal(&self, goal_id: GoalId) -> MemoizeResult; + + /// Finds the representative logical expression ID for a given logical expression ID. + /// + /// # Parameters + /// * `logical_expr_id` - The logical expression ID to find the representative for. + /// + /// # Returns + /// The representative logical expression ID (which may be the same as the input if + /// it's already the representative). + async fn find_repr_logical_expr( + &self, + logical_expr_id: LogicalExpressionId, + ) -> MemoizeResult; + + /// Finds the representative physical expression ID for a given physical expression ID. + /// + /// # Parameters + /// * `physical_expr_id` - The physical expression ID to find the representative for. + /// + /// # Returns + /// The representative physical expression ID (which may be the same as the input if + /// it's already the representative). + async fn find_repr_physical_expr( + &self, + physical_expr_id: PhysicalExpressionId, + ) -> MemoizeResult; +} + +#[trait_variant::make(Send)] +pub trait Materialize { + // + // ID conversion and materialization operations. + // + + /// Gets or creates a goal ID for a given goal. + /// + /// # Parameters + /// * `goal` - The goal to get or create an ID for. + /// + /// # Returns + /// The ID of the goal. + async fn get_goal_id(&mut self, goal: &Goal) -> MemoizeResult; + + /// Materializes a goal from its ID. + /// + /// # Parameters + /// * `goal_id` - ID of the goal to materialize. + /// + /// # Returns + /// The materialized goal. + async fn materialize_goal(&self, goal_id: GoalId) -> MemoizeResult; + + /// Gets or creates a logical expression ID for a given logical expression. + /// + /// # Parameters + /// * `logical_expr` - The logical expression to get or create an ID for. + /// + /// # Returns + /// The ID of the logical expression. + async fn get_logical_expr_id( + &mut self, + logical_expr: &LogicalExpression, + ) -> MemoizeResult; + + /// Materializes a logical expression from its ID. + /// + /// # Parameters + /// * `logical_expr_id` - ID of the logical expression to materialize. + /// + /// # Returns + /// The materialized logical expression. + async fn materialize_logical_expr( + &self, + logical_expr_id: LogicalExpressionId, + ) -> MemoizeResult; + + /// Gets or creates a physical expression ID for a given physical expression. + /// + /// # Parameters + /// * `physical_expr` - The physical expression to get or create an ID for. + /// + /// # Returns + /// The ID of the physical expression. + async fn get_physical_expr_id( + &mut self, + physical_expr: &PhysicalExpression, + ) -> MemoizeResult; + + /// Materializes a physical expression from its ID. + /// + /// # Parameters + /// * `physical_expr_id` - ID of the physical expression to materialize. + /// + /// # Returns + /// The materialized physical expression. + async fn materialize_physical_expr( + &self, + physical_expr_id: PhysicalExpressionId, + ) -> MemoizeResult; +} + +/// Core interface for memo-based query optimization. +/// +/// This trait defines the operations needed to store, retrieve, and manipulate +/// the memo data structure that powers the dynamic programming approach to +/// query optimization. The memo stores logical and physical expressions by their IDs, +/// manages expression properties, and tracks optimization status. +#[trait_variant::make(Send)] +pub trait TaskState { + // + // Rule and costing status operations. + // + + /// Checks the status of applying a transformation rule on a logical expression ID. + /// + /// # Parameters + /// * `logical_expr_id` - ID of the logical expression to check. + /// * `rule` - Transformation rule to check status for. + /// + /// # Returns + /// `Status::Dirty` if there are ongoing events that may affect the transformation, + /// `Status::Clean` if the transformation does not need to be re-evaluated. + async fn get_transformation_status( + &self, + logical_expr_id: LogicalExpressionId, + rule: &TransformationRule, + ) -> MemoizeResult; + + /// Sets the status of a transformation rule as clean on a logical expression ID. + /// + /// # Parameters + /// * `logical_expr_id` - ID of the logical expression to update. + /// * `rule` - Transformation rule to set status for. + async fn set_transformation_clean( + &mut self, + logical_expr_id: LogicalExpressionId, + rule: &TransformationRule, + ) -> MemoizeResult<()>; + + /// Checks the status of applying an implementation rule on a logical expression ID and goal ID. + /// + /// # Parameters + /// * `logical_expr_id` - ID of the logical expression to check. + /// * `goal_id` - ID of the goal to check against. + /// * `rule` - Implementation rule to check status for. + /// + /// # Returns + /// `Status::Dirty` if there are ongoing events that may affect the implementation, + /// `Status::Clean` if the implementation does not need to be re-evaluated. + async fn get_implementation_status( + &self, + logical_expr_id: LogicalExpressionId, + goal_id: GoalId, + rule: &ImplementationRule, + ) -> MemoizeResult; + + /// Sets the status of an implementation rule as clean on a logical expression ID and goal ID. + /// + /// # Parameters + /// * `logical_expr_id` - ID of the logical expression to update. + /// * `goal_id` - ID of the goal to update against. + /// * `rule` - Implementation rule to set status for. + async fn set_implementation_clean( + &mut self, + logical_expr_id: LogicalExpressionId, + goal_id: GoalId, + rule: &ImplementationRule, + ) -> MemoizeResult<()>; + + /// Checks the status of costing a physical expression ID. + /// + /// # Parameters + /// * `physical_expr_id` - ID of the physical expression to check. + /// + /// # Returns + /// `Status::Dirty` if there are ongoing events that may affect the costing, + /// `Status::Clean` if the costing does not need to be re-evaluated. + async fn get_cost_status( + &self, + physical_expr_id: PhysicalExpressionId, + ) -> MemoizeResult; + + /// Sets the status of costing a physical expression ID as clean. + /// + /// # Parameters + /// * `physical_expr_id` - ID of the physical expression to update. + async fn set_cost_clean(&mut self, physical_expr_id: PhysicalExpressionId) + -> MemoizeResult<()>; + + /// Adds a dependency between a transformation rule application and a group. + /// + /// This registers that the application of the transformation rule on the logical expression + /// depends on the group. When the group changes, the transformation status should be set to dirty. + /// + /// # Parameters + /// * `logical_expr_id` - ID of the logical expression the rule is applied to. + /// * `rule` - Transformation rule that depends on the group. + /// * `group_id` - ID of the group that the transformation depends on. + async fn add_transformation_dependency( + &mut self, + logical_expr_id: LogicalExpressionId, + rule: &TransformationRule, + group_id: GroupId, + ) -> MemoizeResult<()>; + + /// Adds a dependency between an implementation rule application and a group. + /// + /// This registers that the application of the implementation rule on the logical expression + /// for a specific goal depends on the group. When the group changes, the implementation status + /// should be set to dirty. + /// + /// # Parameters + /// * `logical_expr_id` - ID of the logical expression the rule is applied to. + /// * `goal_id` - ID of the goal the implementation targets. + /// * `rule` - Implementation rule that depends on the group. + /// * `group_id` - ID of the group that the implementation depends on. + async fn add_implementation_dependency( + &mut self, + logical_expr_id: LogicalExpressionId, + goal_id: GoalId, + rule: &ImplementationRule, + group_id: GroupId, + ) -> MemoizeResult<()>; + + /// Adds a dependency between costing a physical expression and a goal. + /// + /// This registers that the costing of the physical expression depends on the goal. + /// When the goal changes, the costing status should be set to dirty. + /// + /// # Parameters + /// * `physical_expr_id` - ID of the physical expression to cost. + /// * `goal_id` - ID of the goal that the costing depends on. + async fn add_cost_dependency( + &mut self, + physical_expr_id: PhysicalExpressionId, + goal_id: GoalId, + ) -> MemoizeResult<()>; +} diff --git a/optd/src/core/memo/types.rs b/optd/src/core/memo/types.rs new file mode 100644 index 00000000..7967fdab --- /dev/null +++ b/optd/src/core/memo/types.rs @@ -0,0 +1,108 @@ +use crate::core::cir::*; +use std::collections::{HashMap, HashSet}; + +/// Status of a rule application or costing operation +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Status { + /// There exist ongoing jobs that may generate more expressions or costs from this expression. + Dirty, + + /// Expression is fully explored or costed with no pending jobs that could add anything new. + Clean, +} + +/// Result of merging two groups. +#[derive(Debug)] +pub struct MergeGroupResult { + /// ID of the new representative group id. + pub new_repr_group_id: GroupId, + /// Groups that were merged along with their expressions. + pub merged_groups: HashMap>, +} + +impl MergeGroupResult { + /// Creates a new MergeGroupResult instance. + /// + /// # Parameters + /// * `merged_groups` - Groups that were merged along with their expressions. + /// * `new_repr_group_id` - ID of the new representative group id. + pub fn new(new_repr_group_id: GroupId) -> Self { + Self { + new_repr_group_id, + merged_groups: HashMap::new(), + } + } +} + +/// Information about a merged goal, including its ID and expressions +#[derive(Debug)] +pub struct MergedGoalInfo { + /// ID of the merged goal + pub goal_id: GoalId, + + /// Whether this goal contained the best costed expression before merging. + pub seen_best_expr_before_merge: bool, + + /// All members in this goal, which can be physical expressions or references to other goals + pub members: Vec, +} + +/// Result of merging two goals. +#[derive(Debug)] +pub struct MergeGoalResult { + /// Goals that were merged along with their potential best costed expression. + pub merged_goals: HashMap, + + /// The best costed expression for all merged goals combined. + pub best_expr: Option<(PhysicalExpressionId, Cost)>, + + /// ID of the new representative goal id. + pub new_repr_goal_id: GoalId, +} + +/// Result of merging two cost expressions. +#[derive(Debug)] +pub struct MergePhysicalExprResult { + /// The new representative physical expression id. + pub repr_physical_expr: PhysicalExpressionId, + + /// Physical expressions that were stale + pub stale_physical_exprs: HashSet, +} + +/// Results of merge operations with newly dirtied expressions. +#[derive(Debug, Default)] +pub struct MergeResult { + /// Group merge results. + pub group_merges: Vec, + + /// Goal merge results. + pub goal_merges: Vec, + + /// Physical expression merge results. + pub physical_expr_merges: Vec, + // /// Transformations that were marked as dirty and need new application. + // pub dirty_transformations: Vec<(LogicalExpressionId, TransformationRule)>, + + // /// Implementations that were marked as dirty and need new application. + // pub dirty_implementations: Vec<(LogicalExpressionId, GoalId, ImplementationRule)>, + + // /// Costings that were marked as dirty and need recomputation. + // pub dirty_costings: Vec, +} + +pub struct ForwardResult { + pub physical_expr_id: PhysicalExpressionId, + pub best_cost: Cost, + pub goals_forwarded: HashSet, +} + +impl ForwardResult { + pub fn new(physical_expr_id: PhysicalExpressionId, best_cost: Cost) -> Self { + Self { + physical_expr_id, + best_cost, + goals_forwarded: HashSet::new(), + } + } +} diff --git a/optd/src/core/memo/merge_repr.rs b/optd/src/core/memo/union_find.rs similarity index 89% rename from optd/src/core/memo/merge_repr.rs rename to optd/src/core/memo/union_find.rs index 90648ce2..86d93800 100644 --- a/optd/src/core/memo/merge_repr.rs +++ b/optd/src/core/memo/union_find.rs @@ -1,3 +1,5 @@ +//! TODO(connor) replace this with a proper third-party crate. + #![allow(dead_code)] use std::collections::HashMap; @@ -7,19 +9,19 @@ use std::hash::Hash; /// caused by merges /// /// Implements union-find with path compression for O(α(n)) amortized time complexity -pub struct Representative { +pub struct UnionFind { parents: HashMap, } -impl Default for Representative { +impl Default for UnionFind { fn default() -> Self { - Representative { + UnionFind { parents: HashMap::new(), } } } -impl Representative { +impl UnionFind { /// Creates a new empty Representative pub(super) fn new() -> Self { Self::default() @@ -83,20 +85,20 @@ mod tests { #[test] fn test_find_nonexistent() { - let repr = Representative::::new(); + let repr = UnionFind::::new(); assert_eq!(repr.find(&42), 42); } #[test] fn test_find_self() { - let mut repr = Representative::::new(); + let mut repr = UnionFind::::new(); repr.parents.insert(42, 42); assert_eq!(repr.find(&42), 42); } #[test] fn test_find_without_compression() { - let mut repr = Representative::::new(); + let mut repr = UnionFind::::new(); repr.parents.insert(1, 2); repr.parents.insert(2, 3); repr.parents.insert(3, 4); @@ -112,7 +114,7 @@ mod tests { #[test] fn test_merge_with_compression() { - let mut repr = Representative::::new(); + let mut repr = UnionFind::::new(); repr.parents.insert(1, 2); repr.parents.insert(2, 3); repr.parents.insert(3, 4); @@ -128,7 +130,7 @@ mod tests { #[test] fn test_merge_basic() { - let mut repr = Representative::::new(); + let mut repr = UnionFind::::new(); let result = repr.merge(&1, &2); assert_eq!(result, 2); assert_eq!(repr.find(&1), 2); @@ -136,7 +138,7 @@ mod tests { #[test] fn test_merge_existing() { - let mut repr = Representative::::new(); + let mut repr = UnionFind::::new(); repr.parents.insert(1, 1); repr.parents.insert(2, 2); @@ -147,7 +149,7 @@ mod tests { #[test] fn test_merge_already_merged() { - let mut repr = Representative::::new(); + let mut repr = UnionFind::::new(); repr.parents.insert(1, 2); repr.parents.insert(2, 2); @@ -158,7 +160,7 @@ mod tests { #[test] fn test_merge_chains() { - let mut repr = Representative::::new(); + let mut repr = UnionFind::::new(); // Create chain 1->2->3 repr.merge(&1, &2); @@ -183,7 +185,7 @@ mod tests { #[test] fn test_merge_with_string_keys() { - let mut repr = Representative::::new(); + let mut repr = UnionFind::::new(); let result = repr.merge(&"old".to_string(), &"new".to_string()); assert_eq!(result, "new"); @@ -192,7 +194,7 @@ mod tests { #[test] fn test_complex_merges() { - let mut repr = Representative::::new(); + let mut repr = UnionFind::::new(); // First set of merges repr.merge(&1, &2); From dcb4823c876715147508af43ef992af1e3ae9783 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 30 Apr 2025 10:06:46 -0400 Subject: [PATCH 5/9] rename the error type to OptimizeStateError --- optd/src/core/memo/error.rs | 21 ++++---- optd/src/core/memo/memory.rs | 96 ++++++++++++++++++------------------ optd/src/core/memo/mod.rs | 6 +++ optd/src/core/memo/traits.rs | 64 ++++++++++++------------ 4 files changed, 98 insertions(+), 89 deletions(-) diff --git a/optd/src/core/memo/error.rs b/optd/src/core/memo/error.rs index 92d55953..6453faaf 100644 --- a/optd/src/core/memo/error.rs +++ b/optd/src/core/memo/error.rs @@ -1,22 +1,25 @@ use crate::core::cir::*; -/// Type alias for results returned by Memoize trait methods -pub type MemoizeResult = Result; +/// A type alias for results returned by the different memo table trait methods. +/// +/// See the private `traits.rs` module for more information (note that the traits are re-exported). +pub type OptimizeStateResult = Result; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum MemoizeError { - /// Error indicating that a group ID was not found in the memo. +/// The possible kinds of errors that optimize state operations can run into. +#[derive(Debug, Clone, Copy)] +pub enum OptimizeStateError { + /// A [`GroupId`] does not exist in the memo. GroupNotFound(GroupId), - /// Error indicating that a goal ID was not found in the memo. + /// A [`GoalId`] does not exist in the memo. GoalNotFound(GoalId), - /// Error indicating that a logical expression ID was not found in the memo. + /// A [`LogicalExpressionId`] does not exist in the memo. LogicalExprNotFound(LogicalExpressionId), - /// Error indicating that a physical expression ID was not found in the memo. + /// A [`PhysicalExpressionId`] does not exist in the memo. PhysicalExprNotFound(PhysicalExpressionId), - /// Error indicating that there is no logical expression in the group. + /// A group does not contain any logical expressions. NoLogicalExprInGroup(GroupId), } diff --git a/optd/src/core/memo/memory.rs b/optd/src/core/memo/memory.rs index a620a805..754488d5 100644 --- a/optd/src/core/memo/memory.rs +++ b/optd/src/core/memo/memory.rs @@ -126,12 +126,12 @@ impl Memo for MemoryMemo { async fn get_logical_properties( &self, group_id: GroupId, - ) -> MemoizeResult> { + ) -> OptimizeStateResult> { let group_id = self.find_repr_group(group_id).await?; let group = self .groups .get(&group_id) - .ok_or(MemoizeError::GroupNotFound(group_id))?; + .ok_or(OptimizeStateError::GroupNotFound(group_id))?; Ok(group.properties.clone()) } @@ -140,12 +140,12 @@ impl Memo for MemoryMemo { &mut self, group_id: GroupId, props: LogicalProperties, - ) -> MemoizeResult<()> { + ) -> OptimizeStateResult<()> { let group_id = self.find_repr_group(group_id).await?; let group = self .groups .get_mut(&group_id) - .ok_or(MemoizeError::GroupNotFound(group_id))?; + .ok_or(OptimizeStateError::GroupNotFound(group_id))?; group.properties = Some(props); Ok(()) @@ -154,35 +154,35 @@ impl Memo for MemoryMemo { async fn get_all_logical_exprs( &self, group_id: GroupId, - ) -> MemoizeResult> { + ) -> OptimizeStateResult> { let group_id = self.find_repr_group(group_id).await?; let group = self .groups .get(&group_id) - .ok_or(MemoizeError::GroupNotFound(group_id))?; + .ok_or(OptimizeStateError::GroupNotFound(group_id))?; Ok(group.logical_exprs.iter().cloned().collect()) } - async fn get_any_logical_expr(&self, group_id: GroupId) -> MemoizeResult { + async fn get_any_logical_expr(&self, group_id: GroupId) -> OptimizeStateResult { let group_id = self.find_repr_group(group_id).await?; let group = self .groups .get(&group_id) - .ok_or(MemoizeError::GroupNotFound(group_id))?; + .ok_or(OptimizeStateError::GroupNotFound(group_id))?; group .logical_exprs .iter() .next() .cloned() - .ok_or(MemoizeError::NoLogicalExprInGroup(group_id)) + .ok_or(OptimizeStateError::NoLogicalExprInGroup(group_id)) } async fn find_logical_expr_group( &self, logical_expr_id: LogicalExpressionId, - ) -> MemoizeResult> { + ) -> OptimizeStateResult> { let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; let maybe_group_id = self.logical_expr_group_index.get(&logical_expr_id).cloned(); Ok(maybe_group_id) @@ -191,7 +191,7 @@ impl Memo for MemoryMemo { async fn create_group( &mut self, logical_expr_id: LogicalExpressionId, - ) -> MemoizeResult { + ) -> OptimizeStateResult { let group_id = self.next_group_id(); let group = GroupState::new(logical_expr_id); self.groups.insert(group_id, group); @@ -205,14 +205,14 @@ impl Memo for MemoryMemo { &mut self, group_id_1: GroupId, group_id_2: GroupId, - ) -> MemoizeResult> { + ) -> OptimizeStateResult> { self.merge_groups_helper(group_id_1, group_id_2).await } async fn get_best_optimized_physical_expr( &self, goal_id: GoalId, - ) -> MemoizeResult> { + ) -> OptimizeStateResult> { let goal_id = self.find_repr_goal(goal_id).await?; let maybe_best_costed = self .best_optimized_physical_expr_index @@ -221,7 +221,7 @@ impl Memo for MemoryMemo { Ok(maybe_best_costed) } - async fn get_all_goal_members(&self, goal_id: GoalId) -> MemoizeResult> { + async fn get_all_goal_members(&self, goal_id: GoalId) -> OptimizeStateResult> { let goal_id = self.find_repr_goal(goal_id).await?; let goal_state = self.goals.get(&goal_id).unwrap(); Ok(goal_state.members.iter().cloned().collect()) @@ -231,7 +231,7 @@ impl Memo for MemoryMemo { &mut self, goal_id: GoalId, member: GoalMemberId, - ) -> MemoizeResult> { + ) -> OptimizeStateResult> { let goal_id = self.find_repr_goal(goal_id).await?; let member = self.find_repr_goal_member(member).await?; let goal_state = self.goals.get_mut(&goal_id).unwrap(); @@ -282,12 +282,12 @@ impl Memo for MemoryMemo { async fn get_physical_expr_cost( &self, physical_expr_id: PhysicalExpressionId, - ) -> MemoizeResult> { + ) -> OptimizeStateResult> { let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; let (_, maybe_cost) = self .physical_exprs .get(&physical_expr_id) - .ok_or(MemoizeError::PhysicalExprNotFound(physical_expr_id))?; + .ok_or(OptimizeStateError::PhysicalExprNotFound(physical_expr_id))?; Ok(*maybe_cost) } @@ -295,12 +295,12 @@ impl Memo for MemoryMemo { &mut self, physical_expr_id: PhysicalExpressionId, new_cost: Cost, - ) -> MemoizeResult> { + ) -> OptimizeStateResult> { let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; let (_, cost_mut) = self .physical_exprs .get_mut(&physical_expr_id) - .ok_or(MemoizeError::PhysicalExprNotFound(physical_expr_id))?; + .ok_or(OptimizeStateError::PhysicalExprNotFound(physical_expr_id))?; let is_better = cost_mut .replace(new_cost) .map(|old_cost| new_cost < old_cost) @@ -333,12 +333,12 @@ impl Memo for MemoryMemo { } } - async fn find_repr_group(&self, group_id: GroupId) -> MemoizeResult { + async fn find_repr_group(&self, group_id: GroupId) -> OptimizeStateResult { let repr_group_id = self.repr_group.find(&group_id); Ok(repr_group_id) } - async fn find_repr_goal(&self, goal_id: GoalId) -> MemoizeResult { + async fn find_repr_goal(&self, goal_id: GoalId) -> OptimizeStateResult { let repr_goal_id = self.repr_goal.find(&goal_id); Ok(repr_goal_id) } @@ -346,7 +346,7 @@ impl Memo for MemoryMemo { async fn find_repr_logical_expr( &self, logical_expr_id: LogicalExpressionId, - ) -> MemoizeResult { + ) -> OptimizeStateResult { let repr_expr_id = self.repr_logical_expr.find(&logical_expr_id); Ok(repr_expr_id) } @@ -354,14 +354,14 @@ impl Memo for MemoryMemo { async fn find_repr_physical_expr( &self, physical_expr_id: PhysicalExpressionId, - ) -> MemoizeResult { + ) -> OptimizeStateResult { let repr_expr_id = self.repr_physical_expr.find(&physical_expr_id); Ok(repr_expr_id) } } impl Materialize for MemoryMemo { - async fn get_goal_id(&mut self, goal: &Goal) -> MemoizeResult { + async fn get_goal_id(&mut self, goal: &Goal) -> OptimizeStateResult { if let Some(goal_id) = self.goal_node_to_id_index.get(goal).cloned() { return Ok(goal_id); } @@ -374,11 +374,11 @@ impl Materialize for MemoryMemo { Ok(goal_id) } - async fn materialize_goal(&self, goal_id: GoalId) -> MemoizeResult { + async fn materialize_goal(&self, goal_id: GoalId) -> OptimizeStateResult { let state = self .goals .get(&goal_id) - .ok_or(MemoizeError::GoalNotFound(goal_id))?; + .ok_or(OptimizeStateError::GoalNotFound(goal_id))?; Ok(state.goal.clone()) } @@ -386,7 +386,7 @@ impl Materialize for MemoryMemo { async fn get_logical_expr_id( &mut self, logical_expr: &LogicalExpression, - ) -> MemoizeResult { + ) -> OptimizeStateResult { if let Some(logical_expr_id) = self .logical_expr_node_to_id_index .get(logical_expr) @@ -424,19 +424,19 @@ impl Materialize for MemoryMemo { async fn materialize_logical_expr( &self, logical_expr_id: LogicalExpressionId, - ) -> MemoizeResult { + ) -> OptimizeStateResult { let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; let logical_expr = self .logical_exprs .get(&logical_expr_id) - .ok_or(MemoizeError::LogicalExprNotFound(logical_expr_id))?; + .ok_or(OptimizeStateError::LogicalExprNotFound(logical_expr_id))?; Ok(logical_expr.clone()) } async fn get_physical_expr_id( &mut self, physical_expr: &PhysicalExpression, - ) -> MemoizeResult { + ) -> OptimizeStateResult { if let Some(physical_expr_id) = self .physical_expr_node_to_id_index .get(physical_expr) @@ -486,12 +486,12 @@ impl Materialize for MemoryMemo { async fn materialize_physical_expr( &self, physical_expr_id: PhysicalExpressionId, - ) -> MemoizeResult { + ) -> OptimizeStateResult { let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; let (physical_expr, _) = self .physical_exprs .get(&physical_expr_id) - .ok_or(MemoizeError::PhysicalExprNotFound(physical_expr_id))?; + .ok_or(OptimizeStateError::PhysicalExprNotFound(physical_expr_id))?; Ok(physical_expr.clone()) } } @@ -501,7 +501,7 @@ impl TaskState for MemoryMemo { &self, logical_expr_id: LogicalExpressionId, rule: &TransformationRule, - ) -> MemoizeResult { + ) -> OptimizeStateResult { let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; let status = self .transform_dependency @@ -516,7 +516,7 @@ impl TaskState for MemoryMemo { &mut self, logical_expr_id: LogicalExpressionId, rule: &TransformationRule, - ) -> MemoizeResult<()> { + ) -> OptimizeStateResult<()> { let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; let status_map = self .transform_dependency @@ -539,7 +539,7 @@ impl TaskState for MemoryMemo { logical_expr_id: LogicalExpressionId, goal_id: GoalId, rule: &ImplementationRule, - ) -> MemoizeResult { + ) -> OptimizeStateResult { let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; let goal_id = self.find_repr_goal(goal_id).await?; let status = self @@ -556,7 +556,7 @@ impl TaskState for MemoryMemo { logical_expr_id: LogicalExpressionId, goal_id: GoalId, rule: &ImplementationRule, - ) -> MemoizeResult<()> { + ) -> OptimizeStateResult<()> { let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; let status_map = self .implement_dependency @@ -577,7 +577,7 @@ impl TaskState for MemoryMemo { async fn get_cost_status( &self, physical_expr_id: PhysicalExpressionId, - ) -> MemoizeResult { + ) -> OptimizeStateResult { let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; let status = self .cost_dependency @@ -590,7 +590,7 @@ impl TaskState for MemoryMemo { async fn set_cost_clean( &mut self, physical_expr_id: PhysicalExpressionId, - ) -> MemoizeResult<()> { + ) -> OptimizeStateResult<()> { let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; let entry = self.cost_dependency.entry(physical_expr_id); @@ -613,7 +613,7 @@ impl TaskState for MemoryMemo { logical_expr_id: LogicalExpressionId, rule: &TransformationRule, group_id: GroupId, - ) -> MemoizeResult<()> { + ) -> OptimizeStateResult<()> { let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; let group_id = self.find_repr_group(group_id).await?; let status_map = self @@ -642,7 +642,7 @@ impl TaskState for MemoryMemo { goal_id: GoalId, rule: &ImplementationRule, group_id: GroupId, - ) -> MemoizeResult<()> { + ) -> OptimizeStateResult<()> { let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; let group_id = self.find_repr_group(group_id).await?; let goal_id = self.find_repr_goal(goal_id).await?; @@ -671,7 +671,7 @@ impl TaskState for MemoryMemo { &mut self, physical_expr_id: PhysicalExpressionId, goal_id: GoalId, - ) -> MemoizeResult<()> { + ) -> OptimizeStateResult<()> { let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; let goal_id = self.find_repr_goal(goal_id).await?; @@ -696,7 +696,7 @@ impl MemoryMemo { async fn create_repr_logical_expr( &mut self, logical_expr: LogicalExpression, - ) -> MemoizeResult { + ) -> OptimizeStateResult { let mut repr_logical_expr = logical_expr.clone(); let mut new_children = Vec::new(); @@ -733,7 +733,7 @@ impl MemoryMemo { async fn create_repr_physical_expr( &mut self, physical_expr: PhysicalExpression, - ) -> MemoizeResult { + ) -> OptimizeStateResult { let mut repr_physical_expr = physical_expr.clone(); let mut new_children = Vec::new(); @@ -777,7 +777,7 @@ impl MemoryMemo { async fn merge_physical_exprs( &mut self, physical_expr_id: PhysicalExpressionId, - ) -> MemoizeResult> { + ) -> OptimizeStateResult> { let (physical_expr, _cost) = self.physical_exprs.get(&physical_expr_id).unwrap(); let repr_physical_expr = self .create_repr_physical_expr(physical_expr.clone()) @@ -820,7 +820,7 @@ impl MemoryMemo { &mut self, goal_id1: GoalId, goal_id2: GoalId, - ) -> MemoizeResult<(MergeGoalResult, Vec)> { + ) -> OptimizeStateResult<(MergeGoalResult, Vec)> { let goal_2 = self.goals.remove(&goal_id2).unwrap(); let goal_1 = self.goals.get(&goal_id1).unwrap(); self.repr_goal.merge(&goal_id2, &goal_id1); @@ -919,7 +919,7 @@ impl MemoryMemo { &mut self, group_id_1: GroupId, group_id_2: GroupId, - ) -> MemoizeResult> { + ) -> OptimizeStateResult> { // our strategy is to always merge group 2 into group 1. let group_id_1 = self.find_repr_group(group_id_1).await?; let group_id_2 = self.find_repr_group(group_id_2).await?; @@ -1066,7 +1066,7 @@ impl MemoryMemo { &mut self, mut subscribers: VecDeque, result: &mut ForwardResult, - ) -> MemoizeResult<()> { + ) -> OptimizeStateResult<()> { while let Some(goal_id) = subscribers.pop_front() { let current_best = self.get_best_optimized_physical_expr(goal_id).await?; @@ -1100,7 +1100,7 @@ impl MemoryMemo { /// Find the representative of a goal member. /// /// This reduces down to finding representative physical expr or goal id. - async fn find_repr_goal_member(&self, member: GoalMemberId) -> MemoizeResult { + async fn find_repr_goal_member(&self, member: GoalMemberId) -> OptimizeStateResult { match member { GoalMemberId::PhysicalExpressionId(physical_expr_id) => { let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; diff --git a/optd/src/core/memo/mod.rs b/optd/src/core/memo/mod.rs index d4bf7d7b..d097923e 100644 --- a/optd/src/core/memo/mod.rs +++ b/optd/src/core/memo/mod.rs @@ -1,3 +1,9 @@ +//! Definitions and implementations of components related to optimizer state, which we refer to +//! generally as the memo table. +//! +//! TODO(connor): Explain the distinction between the memo table and the other things that the +//! optimizer needs to store / remember (task graph state as well). + mod error; mod traits; mod types; diff --git a/optd/src/core/memo/traits.rs b/optd/src/core/memo/traits.rs index a316c7cd..d2de2cd1 100644 --- a/optd/src/core/memo/traits.rs +++ b/optd/src/core/memo/traits.rs @@ -1,4 +1,4 @@ -use super::{ForwardResult, MemoizeResult, MergeResult, Status}; +use super::{ForwardResult, OptimizeStateResult, MergeResult, Status}; use crate::core::cir::*; pub trait OptimizerState: Memo + Materialize + TaskState {} @@ -21,7 +21,7 @@ pub trait Memo { async fn get_logical_properties( &self, group_id: GroupId, - ) -> MemoizeResult>; + ) -> OptimizeStateResult>; /// Sets logical properties for a group ID. /// @@ -35,7 +35,7 @@ pub trait Memo { &mut self, group_id: GroupId, props: LogicalProperties, - ) -> MemoizeResult<()>; + ) -> OptimizeStateResult<()>; /// Gets all logical expression IDs in a group (only representative IDs). /// @@ -47,10 +47,10 @@ pub trait Memo { async fn get_all_logical_exprs( &self, group_id: GroupId, - ) -> MemoizeResult>; + ) -> OptimizeStateResult>; /// Gets any logical expression ID in a group. - async fn get_any_logical_expr(&self, group_id: GroupId) -> MemoizeResult; + async fn get_any_logical_expr(&self, group_id: GroupId) -> OptimizeStateResult; /// Finds group containing a logical expression ID, if it exists. /// @@ -62,7 +62,7 @@ pub trait Memo { async fn find_logical_expr_group( &self, logical_expr_id: LogicalExpressionId, - ) -> MemoizeResult>; + ) -> OptimizeStateResult>; /// Creates a new group with a logical expression ID and properties. /// @@ -75,7 +75,7 @@ pub trait Memo { async fn create_group( &mut self, logical_expr_id: LogicalExpressionId, - ) -> MemoizeResult; + ) -> OptimizeStateResult; /// Merges groups 1 and 2, unifying them under a common representative. /// @@ -92,7 +92,7 @@ pub trait Memo { &mut self, group_id_1: GroupId, group_id_2: GroupId, - ) -> MemoizeResult>; + ) -> OptimizeStateResult>; /// Gets the best optimized physical expression ID for a goal ID. /// @@ -105,7 +105,7 @@ pub trait Memo { async fn get_best_optimized_physical_expr( &self, goal_id: GoalId, - ) -> MemoizeResult>; + ) -> OptimizeStateResult>; /// Gets all members of a goal, which can be physical expressions or other goals. /// @@ -114,7 +114,7 @@ pub trait Memo { /// /// # Returns /// A vector of goal members, each being either a physical expression ID or another goal ID. - async fn get_all_goal_members(&self, goal_id: GoalId) -> MemoizeResult>; + async fn get_all_goal_members(&self, goal_id: GoalId) -> OptimizeStateResult>; /// Adds a member to a goal. /// @@ -128,7 +128,7 @@ pub trait Memo { &mut self, goal_id: GoalId, member: GoalMemberId, - ) -> MemoizeResult>; + ) -> OptimizeStateResult>; /// Updates the cost of a physical expression ID. /// @@ -142,12 +142,12 @@ pub trait Memo { &mut self, physical_expr_id: PhysicalExpressionId, new_cost: Cost, - ) -> MemoizeResult>; + ) -> OptimizeStateResult>; async fn get_physical_expr_cost( &self, physical_expr_id: PhysicalExpressionId, - ) -> MemoizeResult>; + ) -> OptimizeStateResult>; /// Finds the representative group ID for a given group ID. /// @@ -157,7 +157,7 @@ pub trait Memo { /// # Returns /// The representative group ID (which may be the same as the input if /// it's already the representative). - async fn find_repr_group(&self, group_id: GroupId) -> MemoizeResult; + async fn find_repr_group(&self, group_id: GroupId) -> OptimizeStateResult; /// Finds the representative goal ID for a given goal ID. /// @@ -167,7 +167,7 @@ pub trait Memo { /// # Returns /// The representative goal ID (which may be the same as the input if /// it's already the representative). - async fn find_repr_goal(&self, goal_id: GoalId) -> MemoizeResult; + async fn find_repr_goal(&self, goal_id: GoalId) -> OptimizeStateResult; /// Finds the representative logical expression ID for a given logical expression ID. /// @@ -180,7 +180,7 @@ pub trait Memo { async fn find_repr_logical_expr( &self, logical_expr_id: LogicalExpressionId, - ) -> MemoizeResult; + ) -> OptimizeStateResult; /// Finds the representative physical expression ID for a given physical expression ID. /// @@ -193,7 +193,7 @@ pub trait Memo { async fn find_repr_physical_expr( &self, physical_expr_id: PhysicalExpressionId, - ) -> MemoizeResult; + ) -> OptimizeStateResult; } #[trait_variant::make(Send)] @@ -209,7 +209,7 @@ pub trait Materialize { /// /// # Returns /// The ID of the goal. - async fn get_goal_id(&mut self, goal: &Goal) -> MemoizeResult; + async fn get_goal_id(&mut self, goal: &Goal) -> OptimizeStateResult; /// Materializes a goal from its ID. /// @@ -218,7 +218,7 @@ pub trait Materialize { /// /// # Returns /// The materialized goal. - async fn materialize_goal(&self, goal_id: GoalId) -> MemoizeResult; + async fn materialize_goal(&self, goal_id: GoalId) -> OptimizeStateResult; /// Gets or creates a logical expression ID for a given logical expression. /// @@ -230,7 +230,7 @@ pub trait Materialize { async fn get_logical_expr_id( &mut self, logical_expr: &LogicalExpression, - ) -> MemoizeResult; + ) -> OptimizeStateResult; /// Materializes a logical expression from its ID. /// @@ -242,7 +242,7 @@ pub trait Materialize { async fn materialize_logical_expr( &self, logical_expr_id: LogicalExpressionId, - ) -> MemoizeResult; + ) -> OptimizeStateResult; /// Gets or creates a physical expression ID for a given physical expression. /// @@ -254,7 +254,7 @@ pub trait Materialize { async fn get_physical_expr_id( &mut self, physical_expr: &PhysicalExpression, - ) -> MemoizeResult; + ) -> OptimizeStateResult; /// Materializes a physical expression from its ID. /// @@ -266,7 +266,7 @@ pub trait Materialize { async fn materialize_physical_expr( &self, physical_expr_id: PhysicalExpressionId, - ) -> MemoizeResult; + ) -> OptimizeStateResult; } /// Core interface for memo-based query optimization. @@ -294,7 +294,7 @@ pub trait TaskState { &self, logical_expr_id: LogicalExpressionId, rule: &TransformationRule, - ) -> MemoizeResult; + ) -> OptimizeStateResult; /// Sets the status of a transformation rule as clean on a logical expression ID. /// @@ -305,7 +305,7 @@ pub trait TaskState { &mut self, logical_expr_id: LogicalExpressionId, rule: &TransformationRule, - ) -> MemoizeResult<()>; + ) -> OptimizeStateResult<()>; /// Checks the status of applying an implementation rule on a logical expression ID and goal ID. /// @@ -322,7 +322,7 @@ pub trait TaskState { logical_expr_id: LogicalExpressionId, goal_id: GoalId, rule: &ImplementationRule, - ) -> MemoizeResult; + ) -> OptimizeStateResult; /// Sets the status of an implementation rule as clean on a logical expression ID and goal ID. /// @@ -335,7 +335,7 @@ pub trait TaskState { logical_expr_id: LogicalExpressionId, goal_id: GoalId, rule: &ImplementationRule, - ) -> MemoizeResult<()>; + ) -> OptimizeStateResult<()>; /// Checks the status of costing a physical expression ID. /// @@ -348,14 +348,14 @@ pub trait TaskState { async fn get_cost_status( &self, physical_expr_id: PhysicalExpressionId, - ) -> MemoizeResult; + ) -> OptimizeStateResult; /// Sets the status of costing a physical expression ID as clean. /// /// # Parameters /// * `physical_expr_id` - ID of the physical expression to update. async fn set_cost_clean(&mut self, physical_expr_id: PhysicalExpressionId) - -> MemoizeResult<()>; + -> OptimizeStateResult<()>; /// Adds a dependency between a transformation rule application and a group. /// @@ -371,7 +371,7 @@ pub trait TaskState { logical_expr_id: LogicalExpressionId, rule: &TransformationRule, group_id: GroupId, - ) -> MemoizeResult<()>; + ) -> OptimizeStateResult<()>; /// Adds a dependency between an implementation rule application and a group. /// @@ -390,7 +390,7 @@ pub trait TaskState { goal_id: GoalId, rule: &ImplementationRule, group_id: GroupId, - ) -> MemoizeResult<()>; + ) -> OptimizeStateResult<()>; /// Adds a dependency between costing a physical expression and a goal. /// @@ -404,5 +404,5 @@ pub trait TaskState { &mut self, physical_expr_id: PhysicalExpressionId, goal_id: GoalId, - ) -> MemoizeResult<()>; + ) -> OptimizeStateResult<()>; } From 5d8e94421184c2214e65619fd98fdcf442685dbd Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 30 Apr 2025 10:19:12 -0400 Subject: [PATCH 6/9] rename traits and remove constructors --- optd/src/core/memo/error.rs | 2 +- optd/src/core/memo/memory.rs | 86 ++++++++++++++++++++++-------------- optd/src/core/memo/mod.rs | 2 +- optd/src/core/memo/traits.rs | 26 ++++++----- optd/src/core/memo/types.rs | 39 +++++----------- 5 files changed, 84 insertions(+), 71 deletions(-) diff --git a/optd/src/core/memo/error.rs b/optd/src/core/memo/error.rs index 6453faaf..2e388f63 100644 --- a/optd/src/core/memo/error.rs +++ b/optd/src/core/memo/error.rs @@ -1,7 +1,7 @@ use crate::core::cir::*; /// A type alias for results returned by the different memo table trait methods. -/// +/// /// See the private `traits.rs` module for more information (note that the traits are re-exported). pub type OptimizeStateResult = Result; diff --git a/optd/src/core/memo/memory.rs b/optd/src/core/memo/memory.rs index 754488d5..4c7355bb 100644 --- a/optd/src/core/memo/memory.rs +++ b/optd/src/core/memo/memory.rs @@ -63,11 +63,11 @@ pub struct MemoryMemo { struct RuleDependency { group_ids: HashSet, - status: Status, + status: TaskStatus, } impl RuleDependency { - fn new(status: Status) -> Self { + fn new(status: TaskStatus) -> Self { let group_ids = HashSet::new(); Self { group_ids, status } } @@ -75,11 +75,11 @@ impl RuleDependency { struct CostDependency { goal_ids: HashSet, - status: Status, + status: TaskStatus, } impl CostDependency { - fn new(status: Status) -> Self { + fn new(status: TaskStatus) -> Self { let goal_ids = HashSet::new(); Self { goal_ids, status } } @@ -164,7 +164,10 @@ impl Memo for MemoryMemo { Ok(group.logical_exprs.iter().cloned().collect()) } - async fn get_any_logical_expr(&self, group_id: GroupId) -> OptimizeStateResult { + async fn get_any_logical_expr( + &self, + group_id: GroupId, + ) -> OptimizeStateResult { let group_id = self.find_repr_group(group_id).await?; let group = self .groups @@ -221,7 +224,10 @@ impl Memo for MemoryMemo { Ok(maybe_best_costed) } - async fn get_all_goal_members(&self, goal_id: GoalId) -> OptimizeStateResult> { + async fn get_all_goal_members( + &self, + goal_id: GoalId, + ) -> OptimizeStateResult> { let goal_id = self.find_repr_goal(goal_id).await?; let goal_state = self.goals.get(&goal_id).unwrap(); Ok(goal_state.members.iter().cloned().collect()) @@ -263,10 +269,16 @@ impl Memo for MemoryMemo { }; let mut subscribers = VecDeque::new(); subscribers.push_back(goal_id); - let mut result = ForwardResult::new(physical_expr_id, cost); + + let mut result = ForwardResult { + physical_expr_id, + best_cost: cost, + goals_forwarded: HashSet::new(), + }; // propagate the new cost to all subscribers. self.propagate_new_member_cost(subscribers, &mut result) .await?; + if result.goals_forwarded.is_empty() { // No goals were forwarded, so we can return None. Ok(None) @@ -317,7 +329,12 @@ impl Memo for MemoryMemo { subscribers.extend(subscriber_goal_ids); } - let mut result = ForwardResult::new(physical_expr_id, new_cost); + let mut result = ForwardResult { + physical_expr_id, + best_cost: new_cost, + goals_forwarded: HashSet::new(), + }; + // propagate the new cost to all subscribers. self.propagate_new_member_cost(subscribers, &mut result) .await?; @@ -496,19 +513,19 @@ impl Materialize for MemoryMemo { } } -impl TaskState for MemoryMemo { +impl TaskGraphState for MemoryMemo { async fn get_transformation_status( &self, logical_expr_id: LogicalExpressionId, rule: &TransformationRule, - ) -> OptimizeStateResult { + ) -> OptimizeStateResult { let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; let status = self .transform_dependency .get(&logical_expr_id) .and_then(|status_map| status_map.get(rule)) .map(|dep| dep.status) - .unwrap_or(Status::Dirty); + .unwrap_or(TaskStatus::Dirty); Ok(status) } @@ -525,10 +542,10 @@ impl TaskState for MemoryMemo { match status_map.entry(rule.clone()) { Entry::Occupied(occupied_entry) => { let dep = occupied_entry.into_mut(); - dep.status = Status::Clean; + dep.status = TaskStatus::Clean; } Entry::Vacant(vacant) => { - vacant.insert(RuleDependency::new(Status::Clean)); + vacant.insert(RuleDependency::new(TaskStatus::Clean)); } } Ok(()) @@ -539,7 +556,7 @@ impl TaskState for MemoryMemo { logical_expr_id: LogicalExpressionId, goal_id: GoalId, rule: &ImplementationRule, - ) -> OptimizeStateResult { + ) -> OptimizeStateResult { let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; let goal_id = self.find_repr_goal(goal_id).await?; let status = self @@ -547,7 +564,7 @@ impl TaskState for MemoryMemo { .get(&logical_expr_id) .and_then(|status_map| status_map.get(&(goal_id, rule.clone()))) .map(|dep| dep.status) - .unwrap_or(Status::Dirty); + .unwrap_or(TaskStatus::Dirty); Ok(status) } @@ -565,10 +582,10 @@ impl TaskState for MemoryMemo { match status_map.entry((goal_id, rule.clone())) { Entry::Occupied(occupied_entry) => { let dep = occupied_entry.into_mut(); - dep.status = Status::Clean; + dep.status = TaskStatus::Clean; } Entry::Vacant(vacant) => { - vacant.insert(RuleDependency::new(Status::Clean)); + vacant.insert(RuleDependency::new(TaskStatus::Clean)); } } Ok(()) @@ -577,13 +594,13 @@ impl TaskState for MemoryMemo { async fn get_cost_status( &self, physical_expr_id: PhysicalExpressionId, - ) -> OptimizeStateResult { + ) -> OptimizeStateResult { let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; let status = self .cost_dependency .get(&physical_expr_id) .map(|dep| dep.status) - .unwrap_or(Status::Dirty); + .unwrap_or(TaskStatus::Dirty); Ok(status) } @@ -598,10 +615,10 @@ impl TaskState for MemoryMemo { match entry { Entry::Occupied(occupied) => { let dep = occupied.into_mut(); - dep.status = Status::Clean; + dep.status = TaskStatus::Clean; } Entry::Vacant(vacant) => { - vacant.insert(CostDependency::new(Status::Clean)); + vacant.insert(CostDependency::new(TaskStatus::Clean)); } } @@ -627,7 +644,7 @@ impl TaskState for MemoryMemo { dep.group_ids.insert(group_id); } Entry::Vacant(vacant) => { - let mut dep = RuleDependency::new(Status::Dirty); + let mut dep = RuleDependency::new(TaskStatus::Dirty); dep.group_ids.insert(group_id); vacant.insert(dep); } @@ -658,7 +675,7 @@ impl TaskState for MemoryMemo { dep.group_ids.insert(group_id); } Entry::Vacant(vacant) => { - let mut dep = RuleDependency::new(Status::Dirty); + let mut dep = RuleDependency::new(TaskStatus::Dirty); dep.group_ids.insert(group_id); vacant.insert(dep); } @@ -681,7 +698,7 @@ impl TaskState for MemoryMemo { dep.goal_ids.insert(goal_id); } Entry::Vacant(vacant) => { - let mut dep = CostDependency::new(Status::Dirty); + let mut dep = CostDependency::new(TaskStatus::Dirty); dep.goal_ids.insert(goal_id); vacant.insert(dep); } @@ -943,13 +960,15 @@ impl MemoryMemo { assert!(old_group_id.is_some()); group1_state.logical_exprs.insert(logical_expr_id); } - let mut merge_group_result = MergeGroupResult::new(group_id_1); - merge_group_result - .merged_groups - .insert(group_id_1, group1_exprs); - merge_group_result - .merged_groups - .insert(group_id_2, group_2_exprs); + + let mut merged_groups = HashMap::with_capacity(2); + merged_groups.insert(group_id_1, group1_exprs); + merged_groups.insert(group_id_2, group_2_exprs); + + let merge_group_result = MergeGroupResult { + new_repr_group_id: group_id_1, + merged_groups, + }; self.repr_group.merge(&group_id_2, &group_id_1); @@ -1100,7 +1119,10 @@ impl MemoryMemo { /// Find the representative of a goal member. /// /// This reduces down to finding representative physical expr or goal id. - async fn find_repr_goal_member(&self, member: GoalMemberId) -> OptimizeStateResult { + async fn find_repr_goal_member( + &self, + member: GoalMemberId, + ) -> OptimizeStateResult { match member { GoalMemberId::PhysicalExpressionId(physical_expr_id) => { let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; diff --git a/optd/src/core/memo/mod.rs b/optd/src/core/memo/mod.rs index d097923e..7f33ba56 100644 --- a/optd/src/core/memo/mod.rs +++ b/optd/src/core/memo/mod.rs @@ -1,6 +1,6 @@ //! Definitions and implementations of components related to optimizer state, which we refer to //! generally as the memo table. -//! +//! //! TODO(connor): Explain the distinction between the memo table and the other things that the //! optimizer needs to store / remember (task graph state as well). diff --git a/optd/src/core/memo/traits.rs b/optd/src/core/memo/traits.rs index d2de2cd1..e445c318 100644 --- a/optd/src/core/memo/traits.rs +++ b/optd/src/core/memo/traits.rs @@ -1,7 +1,7 @@ -use super::{ForwardResult, OptimizeStateResult, MergeResult, Status}; +use super::{ForwardResult, MergeResult, OptimizeStateResult, TaskStatus}; use crate::core::cir::*; -pub trait OptimizerState: Memo + Materialize + TaskState {} +pub trait OptimizerState: Memo + Materialize + TaskGraphState {} // // Logical expression and group operations. @@ -50,7 +50,10 @@ pub trait Memo { ) -> OptimizeStateResult>; /// Gets any logical expression ID in a group. - async fn get_any_logical_expr(&self, group_id: GroupId) -> OptimizeStateResult; + async fn get_any_logical_expr( + &self, + group_id: GroupId, + ) -> OptimizeStateResult; /// Finds group containing a logical expression ID, if it exists. /// @@ -114,7 +117,8 @@ pub trait Memo { /// /// # Returns /// A vector of goal members, each being either a physical expression ID or another goal ID. - async fn get_all_goal_members(&self, goal_id: GoalId) -> OptimizeStateResult>; + async fn get_all_goal_members(&self, goal_id: GoalId) + -> OptimizeStateResult>; /// Adds a member to a goal. /// @@ -276,7 +280,7 @@ pub trait Materialize { /// query optimization. The memo stores logical and physical expressions by their IDs, /// manages expression properties, and tracks optimization status. #[trait_variant::make(Send)] -pub trait TaskState { +pub trait TaskGraphState { // // Rule and costing status operations. // @@ -294,7 +298,7 @@ pub trait TaskState { &self, logical_expr_id: LogicalExpressionId, rule: &TransformationRule, - ) -> OptimizeStateResult; + ) -> OptimizeStateResult; /// Sets the status of a transformation rule as clean on a logical expression ID. /// @@ -322,7 +326,7 @@ pub trait TaskState { logical_expr_id: LogicalExpressionId, goal_id: GoalId, rule: &ImplementationRule, - ) -> OptimizeStateResult; + ) -> OptimizeStateResult; /// Sets the status of an implementation rule as clean on a logical expression ID and goal ID. /// @@ -348,14 +352,16 @@ pub trait TaskState { async fn get_cost_status( &self, physical_expr_id: PhysicalExpressionId, - ) -> OptimizeStateResult; + ) -> OptimizeStateResult; /// Sets the status of costing a physical expression ID as clean. /// /// # Parameters /// * `physical_expr_id` - ID of the physical expression to update. - async fn set_cost_clean(&mut self, physical_expr_id: PhysicalExpressionId) - -> OptimizeStateResult<()>; + async fn set_cost_clean( + &mut self, + physical_expr_id: PhysicalExpressionId, + ) -> OptimizeStateResult<()>; /// Adds a dependency between a transformation rule application and a group. /// diff --git a/optd/src/core/memo/types.rs b/optd/src/core/memo/types.rs index 7967fdab..559feee5 100644 --- a/optd/src/core/memo/types.rs +++ b/optd/src/core/memo/types.rs @@ -1,17 +1,16 @@ use crate::core::cir::*; use std::collections::{HashMap, HashSet}; -/// Status of a rule application or costing operation +/// The status of rule application or costing operation in the task graph. #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Status { +pub enum TaskStatus { /// There exist ongoing jobs that may generate more expressions or costs from this expression. Dirty, - /// Expression is fully explored or costed with no pending jobs that could add anything new. Clean, } -/// Result of merging two groups. +/// The result of merging two groups. #[derive(Debug)] pub struct MergeGroupResult { /// ID of the new representative group id. @@ -20,20 +19,6 @@ pub struct MergeGroupResult { pub merged_groups: HashMap>, } -impl MergeGroupResult { - /// Creates a new MergeGroupResult instance. - /// - /// # Parameters - /// * `merged_groups` - Groups that were merged along with their expressions. - /// * `new_repr_group_id` - ID of the new representative group id. - pub fn new(new_repr_group_id: GroupId) -> Self { - Self { - new_repr_group_id, - merged_groups: HashMap::new(), - } - } -} - /// Information about a merged goal, including its ID and expressions #[derive(Debug)] pub struct MergedGoalInfo { @@ -97,12 +82,12 @@ pub struct ForwardResult { pub goals_forwarded: HashSet, } -impl ForwardResult { - pub fn new(physical_expr_id: PhysicalExpressionId, best_cost: Cost) -> Self { - Self { - physical_expr_id, - best_cost, - goals_forwarded: HashSet::new(), - } - } -} +// impl ForwardResult { +// pub fn new(physical_expr_id: PhysicalExpressionId, best_cost: Cost) -> Self { +// Self { +// physical_expr_id, +// best_cost, +// goals_forwarded: HashSet::new(), +// } +// } +// } From 03718012e4e7a17a34be92cafdf9b50bbb1686b5 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 30 Apr 2025 10:21:24 -0400 Subject: [PATCH 7/9] clean up mod.rs --- optd/src/core/memo/mod.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/optd/src/core/memo/mod.rs b/optd/src/core/memo/mod.rs index 7f33ba56..d64b96f1 100644 --- a/optd/src/core/memo/mod.rs +++ b/optd/src/core/memo/mod.rs @@ -4,12 +4,16 @@ //! TODO(connor): Explain the distinction between the memo table and the other things that the //! optimizer needs to store / remember (task graph state as well). +/// Error and Result defintions. mod error; -mod traits; -mod types; - pub use error::*; + +/// Trait definitions. +mod traits; pub use traits::*; + +/// Type definitions. +mod types; pub use types::*; /// A generic implementation of the Union-Find algorithm. From 133ebeeb4d09860733712fd2f016e1c6061128cc Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 30 Apr 2025 12:01:43 -0400 Subject: [PATCH 8/9] first attempt refactor draft --- optd/src/core/memo/memory.rs | 1 + optd/src/core/memo/mod.rs | 4 +- optd/src/core/memo/traits.rs | 201 +++++++++++++++-------------------- 3 files changed, 91 insertions(+), 115 deletions(-) diff --git a/optd/src/core/memo/memory.rs b/optd/src/core/memo/memory.rs index 4c7355bb..f36667c0 100644 --- a/optd/src/core/memo/memory.rs +++ b/optd/src/core/memo/memory.rs @@ -105,6 +105,7 @@ impl GroupState { } } +/// TODO(connor): Why are the members not part of a goal normally? struct GoalState { /// The set of members that are part of this goal. goal: Goal, diff --git a/optd/src/core/memo/mod.rs b/optd/src/core/memo/mod.rs index d64b96f1..235383dd 100644 --- a/optd/src/core/memo/mod.rs +++ b/optd/src/core/memo/mod.rs @@ -19,5 +19,5 @@ pub use types::*; /// A generic implementation of the Union-Find algorithm. mod union_find; -/// In-memory implementation of the optimizer state (including the memo table). -mod memory; +// In-memory implementation of the optimizer state (including the memo table). +// mod memory; diff --git a/optd/src/core/memo/traits.rs b/optd/src/core/memo/traits.rs index e445c318..9f5ec3b3 100644 --- a/optd/src/core/memo/traits.rs +++ b/optd/src/core/memo/traits.rs @@ -1,101 +1,122 @@ use super::{ForwardResult, MergeResult, OptimizeStateResult, TaskStatus}; use crate::core::cir::*; +/// The main interface for tracking state needed by the optimizer. This includes the memo table and +/// state needed for the task graph. pub trait OptimizerState: Memo + Materialize + TaskGraphState {} -// -// Logical expression and group operations. -// -// -// Physical expression and goal operations. -// +/// The interface for a `Group` of logical expressions. +/// +/// Implementors of this trait should be able to track the logical expressions belonging to this +/// `Group` via [`LogicalExpressionId`], as well as the derived [`LogicalProperties`] and related +/// [`Goal`]s via [`GoalId`]s. +pub trait Group { + /// Creates a new `Group` from a new [`LogicalExpressionId`]. + fn new_from_logical_expression(id: LogicalExpressionId) -> Self; + + /// Retrieves an iterator of [`LogicalExpressionId`] contained in the `Group`. + fn logical_expressions(&self) -> impl Iterator; + + /// Checks if the `Group` contains a logical expression by ID. + fn contains_logical_expression(&self, id: LogicalExpressionId) -> bool; + + /// Adds a logical expression to a `Group`. + fn add_logical_expression(&mut self, id: LogicalExpressionId); + + /// Removes a logical expression to a `Group`. + fn remove_logical_expression(&mut self, id: LogicalExpressionId); + + /// Retrieves the logical properties of a `Group`. + fn logical_properties(&self) -> Option; + + /// Replaces the logical properties for a `Group`. + fn replace_logical_properties(&mut self, props: LogicalProperties) + -> Option; + + /// The IDs of the [`Goal`]s that are dependent on this `Group`. + fn goals(&self) -> impl Iterator; + + /// Add a related [`GoalId`] to this `Group`. + fn add_goal(&mut self, goal_id: GoalId); +} + +/// The interface for an optimizer memoization (memo) table. +/// +/// This trait mainly describes operations related to groups, goals, logical and physical +/// expressions, and finding representative nodes of the union-find substructures. #[trait_variant::make(Send)] pub trait Memo { - /// Retrieves logical properties for a group ID. - /// - /// # Parameters - /// * `group_id` - ID of the group to retrieve properties for. - /// - /// # Returns - /// The properties associated with the group or an error if not found. - async fn get_logical_properties( - &self, - group_id: GroupId, - ) -> OptimizeStateResult>; + /// The associated type needed for managing `Group` data. + type GroupState: Group; - /// Sets logical properties for a group ID. - /// - /// # Parameters - /// * `group_id` - ID of the group to set properties for. - /// * `props` - The logical properties to associate with the group. + /// Retrives the `GroupState` data given the group's ID. + async fn get_group(&self, group_id: GroupId) -> &Self::GroupState; + + /// Mutably retrives the `GroupState` data given the group's ID. + async fn get_group_mut(&mut self, group_id: GroupId) -> &mut Self::GroupState; + + /// Finds the representative group of a given group. The representative is usually tracked via a + /// Union-Find data structure. /// - /// # Returns - /// A result indicating success or failure of the operation. - async fn set_logical_properties( - &mut self, - group_id: GroupId, - props: LogicalProperties, - ) -> OptimizeStateResult<()>; + /// If the input group is already the representative, then the returned [`GroupId`] is equal to + /// the input [`GroupId`]. + async fn find_repr_group(&self, group_id: GroupId) -> GroupId; - /// Gets all logical expression IDs in a group (only representative IDs). + /// Finds the representative goal of a given goal. The representative is usually tracked via a + /// Union-Find data structure. /// - /// # Parameters - /// * `group_id` - ID of the group to retrieve expressions from. + /// If the input goal is already the representative, then the returned [`GoalId`] is equal to + /// the input [`GoalId`]. + async fn find_repr_goal(&self, goal_id: GoalId) -> GoalId; + + /// Finds the representative logical expression of a given expression. The representative is + /// usually tracked via a Union-Find data structure. /// - /// # Returns - /// A vector of logical expression IDs in the specified group. - async fn get_all_logical_exprs( + /// If the input expression is already the representative, then the returned + /// [`LogicalExpressionId`] is equal to the input [`LogicalExpressionId`]. + async fn find_repr_logical_expr( &self, - group_id: GroupId, - ) -> OptimizeStateResult>; + logical_expr_id: LogicalExpressionId, + ) -> LogicalExpressionId; - /// Gets any logical expression ID in a group. - async fn get_any_logical_expr( + /// Finds the representative physical expression of a given expression. The representative is + /// usually tracked via a Union-Find data structure. + /// + /// If the input expression is already the representative, then the returned + /// [`PhysicalExpressionId`] is equal to the input [`PhysicalExpressionId`]. + async fn find_repr_physical_expr( &self, - group_id: GroupId, - ) -> OptimizeStateResult; + physical_expr_id: PhysicalExpressionId, + ) -> PhysicalExpressionId; - /// Finds group containing a logical expression ID, if it exists. - /// - /// # Parameters - /// * `logical_expr_id` - ID of the logical expression to find. + /// Finds the ID of the representative group containing the given logical expression ID. /// - /// # Returns - /// The group ID if the expression exists, None otherwise. - async fn find_logical_expr_group( + /// If there is no `Group` that contains the input logical expression ID, this returns `None`. + async fn find_group_of_logical_expression( &self, logical_expr_id: LogicalExpressionId, - ) -> OptimizeStateResult>; + ) -> Option; - /// Creates a new group with a logical expression ID and properties. + /// Creates a new group given a new [`LogicalExpressionId`]. /// - /// # Parameters - /// * `logical_expr_id` - ID of the logical expression to add to the group. - /// * `props` - Logical properties for the group. - /// - /// # Returns - /// The ID of the newly created group. - async fn create_group( - &mut self, - logical_expr_id: LogicalExpressionId, - ) -> OptimizeStateResult; + /// Returns The ID of the newly created group. + async fn create_group(&mut self, logical_expr_id: LogicalExpressionId) -> GroupId; - /// Merges groups 1 and 2, unifying them under a common representative. + /// Merges two groups, unifying them under a common representative group. /// - /// May trigger cascading merges of parent groups & goals. - /// - /// # Parameters - /// * `group_id_1` - ID of the first group to merge. - /// * `group_id_2` - ID of the second group to merge. + /// This function can trigger cascading (recursive) merges of parent groups & goals. /// + /// TODO(connor): Clean up /// # Returns /// Merge results for all affected entities including newly dirtied /// transformations, implementations and costings. + /// + /// Should panic if the groups are equal (instead of returning an option) async fn merge_groups( &mut self, group_id_1: GroupId, group_id_2: GroupId, - ) -> OptimizeStateResult>; + ) -> OptimizeStateResult; /// Gets the best optimized physical expression ID for a goal ID. /// @@ -152,52 +173,6 @@ pub trait Memo { &self, physical_expr_id: PhysicalExpressionId, ) -> OptimizeStateResult>; - - /// Finds the representative group ID for a given group ID. - /// - /// # Parameters - /// * `group_id` - The group ID to find the representative for. - /// - /// # Returns - /// The representative group ID (which may be the same as the input if - /// it's already the representative). - async fn find_repr_group(&self, group_id: GroupId) -> OptimizeStateResult; - - /// Finds the representative goal ID for a given goal ID. - /// - /// # Parameters - /// * `goal_id` - The goal ID to find the representative for. - /// - /// # Returns - /// The representative goal ID (which may be the same as the input if - /// it's already the representative). - async fn find_repr_goal(&self, goal_id: GoalId) -> OptimizeStateResult; - - /// Finds the representative logical expression ID for a given logical expression ID. - /// - /// # Parameters - /// * `logical_expr_id` - The logical expression ID to find the representative for. - /// - /// # Returns - /// The representative logical expression ID (which may be the same as the input if - /// it's already the representative). - async fn find_repr_logical_expr( - &self, - logical_expr_id: LogicalExpressionId, - ) -> OptimizeStateResult; - - /// Finds the representative physical expression ID for a given physical expression ID. - /// - /// # Parameters - /// * `physical_expr_id` - The physical expression ID to find the representative for. - /// - /// # Returns - /// The representative physical expression ID (which may be the same as the input if - /// it's already the representative). - async fn find_repr_physical_expr( - &self, - physical_expr_id: PhysicalExpressionId, - ) -> OptimizeStateResult; } #[trait_variant::make(Send)] From e692a944712ddcf3d5c2a86f278ed2241f19b805 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 30 Apr 2025 12:27:12 -0400 Subject: [PATCH 9/9] some more refactors --- optd/src/core/bridge/from_cir.rs | 4 +-- optd/src/core/bridge/into_cir.rs | 7 +++- optd/src/core/cir/goal.rs | 13 +++++-- optd/src/core/cir/group.rs | 13 +++++++ optd/src/core/memo/traits.rs | 60 ++++++-------------------------- 5 files changed, 42 insertions(+), 55 deletions(-) diff --git a/optd/src/core/bridge/from_cir.rs b/optd/src/core/bridge/from_cir.rs index 6c84327f..aca962a7 100644 --- a/optd/src/core/bridge/from_cir.rs +++ b/optd/src/core/bridge/from_cir.rs @@ -79,8 +79,8 @@ pub(crate) fn physical_properties_to_value(properties: &PhysicalProperties) -> V /// Converts a CIR [`Goal`] to a HIR [`Goal`](hir::Goal). fn cir_goal_to_hir(goal: &Goal) -> hir::Goal { - let group_id = cir_group_id_to_hir(&goal.0); - let properties = physical_properties_to_value(&goal.1); + let group_id = cir_group_id_to_hir(&goal.group_id); + let properties = physical_properties_to_value(&goal.properties); hir::Goal { group_id, diff --git a/optd/src/core/bridge/into_cir.rs b/optd/src/core/bridge/into_cir.rs index 7bd5c956..858afa3f 100644 --- a/optd/src/core/bridge/into_cir.rs +++ b/optd/src/core/bridge/into_cir.rs @@ -6,6 +6,7 @@ use Child::*; use CoreData::*; use Literal::*; use Materializable::*; +use std::collections::HashSet; use std::sync::Arc; /// Converts a [`Value`] into a [`PartialLogicalPlan`]. @@ -99,7 +100,11 @@ pub(crate) fn hir_group_id_to_cir(hir_group_id: &hir::GroupId) -> GroupId { pub(crate) fn hir_goal_to_cir(hir_goal: &hir::Goal) -> Goal { let group_id = hir_group_id_to_cir(&hir_goal.group_id); let properties = value_to_physical_properties(&hir_goal.properties); - Goal(group_id, properties) + Goal { + group_id, + properties, + members: HashSet::new(), + } } /// Converts a [`Value`] into a fully materialized [`LogicalPlan`]. diff --git a/optd/src/core/cir/goal.rs b/optd/src/core/cir/goal.rs index 1873d89a..79981165 100644 --- a/optd/src/core/cir/goal.rs +++ b/optd/src/core/cir/goal.rs @@ -1,4 +1,6 @@ use super::{PhysicalExpressionId, group::GroupId, properties::PhysicalProperties}; +use std::collections::HashSet; +use std::hash::Hash; /// A physical optimization goal, consisting of a group to optimize and the required physical /// properties. @@ -7,8 +9,15 @@ use super::{PhysicalExpressionId, group::GroupId, properties::PhysicalProperties /// that satisfies specific physical property requirements (like sort order, distribution, etc.). /// /// Goals can be thought of as the physical counterpart to logical [`GroupId`]s. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct Goal(pub GroupId, pub PhysicalProperties); +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Goal { + /// The [`GroupId`] of the group this `Goal` is based on. + pub group_id: GroupId, + /// The physical properties of this `Goal`. + pub properties: PhysicalProperties, + /// The set of members that are part of this goal. + pub members: HashSet, +} /// Represents a member of a goal, which can be either a physical expression /// or a reference to another goal diff --git a/optd/src/core/cir/group.rs b/optd/src/core/cir/group.rs index f1d277e0..8068c13a 100644 --- a/optd/src/core/cir/group.rs +++ b/optd/src/core/cir/group.rs @@ -1,3 +1,6 @@ +use super::{GoalId, LogicalExpressionId, LogicalProperties}; +use std::collections::HashSet; + /// A unique identifier for a group in the memo structure. /// /// A group in a represents a set of logically equivalent expressions. All expressions in a group @@ -10,3 +13,13 @@ /// expressions. optd instead represents physical expressions via [`Goal`]s. #[derive(Debug, Clone, Copy, PartialEq, Hash, Eq)] pub struct GroupId(pub i64); + +/// The representation of a `Group` of logical expressions in the memo table. +pub struct Group { + /// The logical expression belonging to this `Group`, tracked via [`LogicalExpressionId`]. + pub logical_exprs: HashSet, + /// The logical properties of the group, might be `None` if it hasn't been derived yet. + pub properties: Option, + /// The `Goal`s that are dependent on this `Group`. + pub goals: HashSet, +} diff --git a/optd/src/core/memo/traits.rs b/optd/src/core/memo/traits.rs index 9f5ec3b3..37b82a46 100644 --- a/optd/src/core/memo/traits.rs +++ b/optd/src/core/memo/traits.rs @@ -5,55 +5,23 @@ use crate::core::cir::*; /// state needed for the task graph. pub trait OptimizerState: Memo + Materialize + TaskGraphState {} -/// The interface for a `Group` of logical expressions. -/// -/// Implementors of this trait should be able to track the logical expressions belonging to this -/// `Group` via [`LogicalExpressionId`], as well as the derived [`LogicalProperties`] and related -/// [`Goal`]s via [`GoalId`]s. -pub trait Group { - /// Creates a new `Group` from a new [`LogicalExpressionId`]. - fn new_from_logical_expression(id: LogicalExpressionId) -> Self; - - /// Retrieves an iterator of [`LogicalExpressionId`] contained in the `Group`. - fn logical_expressions(&self) -> impl Iterator; - - /// Checks if the `Group` contains a logical expression by ID. - fn contains_logical_expression(&self, id: LogicalExpressionId) -> bool; - - /// Adds a logical expression to a `Group`. - fn add_logical_expression(&mut self, id: LogicalExpressionId); - - /// Removes a logical expression to a `Group`. - fn remove_logical_expression(&mut self, id: LogicalExpressionId); - - /// Retrieves the logical properties of a `Group`. - fn logical_properties(&self) -> Option; - - /// Replaces the logical properties for a `Group`. - fn replace_logical_properties(&mut self, props: LogicalProperties) - -> Option; - - /// The IDs of the [`Goal`]s that are dependent on this `Group`. - fn goals(&self) -> impl Iterator; - - /// Add a related [`GoalId`] to this `Group`. - fn add_goal(&mut self, goal_id: GoalId); -} - /// The interface for an optimizer memoization (memo) table. /// /// This trait mainly describes operations related to groups, goals, logical and physical /// expressions, and finding representative nodes of the union-find substructures. #[trait_variant::make(Send)] pub trait Memo { - /// The associated type needed for managing `Group` data. - type GroupState: Group; - /// Retrives the `GroupState` data given the group's ID. - async fn get_group(&self, group_id: GroupId) -> &Self::GroupState; + async fn get_group(&self, group_id: GroupId) -> &Group; /// Mutably retrives the `GroupState` data given the group's ID. - async fn get_group_mut(&mut self, group_id: GroupId) -> &mut Self::GroupState; + async fn get_group_mut(&mut self, group_id: GroupId) -> &mut Group; + + /// Retrives the `GroupState` data given the group's ID. + async fn get_goal(&self, goal_id: GoalId) -> &Goal; + + /// Mutably retrives the `GroupState` data given the goal's ID. + async fn get_goal_mut(&mut self, goal_id: GoalId) -> &mut Goal; /// Finds the representative group of a given group. The representative is usually tracked via a /// Union-Find data structure. @@ -131,16 +99,6 @@ pub trait Memo { goal_id: GoalId, ) -> OptimizeStateResult>; - /// Gets all members of a goal, which can be physical expressions or other goals. - /// - /// # Parameters - /// * `goal_id` - ID of the goal to retrieve members from. - /// - /// # Returns - /// A vector of goal members, each being either a physical expression ID or another goal ID. - async fn get_all_goal_members(&self, goal_id: GoalId) - -> OptimizeStateResult>; - /// Adds a member to a goal. /// /// # Parameters @@ -149,6 +107,8 @@ pub trait Memo { /// /// # Returns /// True if the member was added to the goal, or false if it already existed. + /// + /// TODO(connor): This is clearly doing much more than adding a goal member async fn add_goal_member( &mut self, goal_id: GoalId,