From 2dd2a318fa78bacf7c9613ac174896510aecfa1f Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Sun, 17 Nov 2024 23:07:33 -0500 Subject: [PATCH] refactor(core): rm option around cost models (#243) The `Option` wrappers were introduced during a transitional period when we thought the cost model could compute the cost solely from the children's costs, but it turned out that we need the optimizer for a few derived logical properties. We can drop them now. Signed-off-by: Alex Chi --- docs/src/cost_model.md | 2 +- .../src/cascades/tasks/optimize_inputs.rs | 10 +-- optd-core/src/cost.rs | 16 +++-- optd-datafusion-repr-adv-cost/src/lib.rs | 62 +++++++------------ .../src/cost/adaptive_cost.rs | 14 ++--- optd-datafusion-repr/src/cost/base_cost.rs | 8 +-- .../src/testing/dummy_cost.rs | 8 +-- 7 files changed, 51 insertions(+), 69 deletions(-) diff --git a/docs/src/cost_model.md b/docs/src/cost_model.md index fae44bc8..bc8e39da 100644 --- a/docs/src/cost_model.md +++ b/docs/src/cost_model.md @@ -11,7 +11,7 @@ pub trait CostModel: 'static + Send + Sync { node: &T, data: &Option, children: &[Cost], - context: Option, + context: RelNodeContext, ) -> Cost; } ``` diff --git a/optd-core/src/cascades/tasks/optimize_inputs.rs b/optd-core/src/cascades/tasks/optimize_inputs.rs index b19705bb..6a5da11f 100644 --- a/optd-core/src/cascades/tasks/optimize_inputs.rs +++ b/optd-core/src/cascades/tasks/optimize_inputs.rs @@ -110,12 +110,12 @@ impl OptimizeInputsTask { .iter() .map(|x| x.expect("child winner should always have statistics?")) .collect::>(), - Some(RelNodeContext { + RelNodeContext { group_id, expr_id: self.expr_id, children_group_ids: expr.children.clone(), - }), - Some(optimizer), + }, + optimizer, ); optimizer.update_group_info( group_id, @@ -197,8 +197,8 @@ impl> Task for OptimizeInputsTask { &expr.typ, &preds, &input_statistics_ref, - Some(context.clone()), - Some(optimizer), + context.clone(), + optimizer, ); let total_cost = cost.sum(&operation_cost, 
&input_cost); diff --git a/optd-core/src/cost.rs b/optd-core/src/cost.rs index f68386d1..fa62f232 100644 --- a/optd-core/src/cost.rs +++ b/optd-core/src/cost.rs @@ -16,25 +16,29 @@ pub struct Statistics(pub Box); pub struct Cost(pub Vec); pub trait CostModel>: 'static + Send + Sync { - /// Compute the cost of a single operation + /// Compute the cost of a single operation. `RelNodeContext` might be + /// optional in the future when we implement physical property enforcers. + /// If we have not decided the winner for a child group yet, the statistics + /// for that group will be `None`. #[allow(clippy::too_many_arguments)] fn compute_operation_cost( &self, node: &T, predicates: &[ArcPredNode], children_stats: &[Option<&Statistics>], - context: Option, - optimizer: Option<&CascadesOptimizer>, + context: RelNodeContext, + optimizer: &CascadesOptimizer, ) -> Cost; - /// Derive the statistics of a single operation + /// Derive the statistics of a single operation. `RelNodeContext` might be + /// optional in the future when we implement physical property enforcers. 
fn derive_statistics( &self, node: &T, predicates: &[ArcPredNode], children_stats: &[&Statistics], - context: Option, - optimizer: Option<&CascadesOptimizer>, + context: RelNodeContext, + optimizer: &CascadesOptimizer, ) -> Statistics; fn explain_cost(&self, cost: &Cost) -> String; diff --git a/optd-datafusion-repr-adv-cost/src/lib.rs b/optd-datafusion-repr-adv-cost/src/lib.rs index 8d409ccd..60ab35e6 100644 --- a/optd-datafusion-repr-adv-cost/src/lib.rs +++ b/optd-datafusion-repr-adv-cost/src/lib.rs @@ -61,8 +61,8 @@ impl CostModel> for AdvancedCostModel { node: &DfNodeType, predicates: &[ArcDfPredNode], children_stats: &[Option<&Statistics>], - context: Option, - optimizer: Option<&CascadesOptimizer>, + context: RelNodeContext, + optimizer: &CascadesOptimizer, ) -> Cost { self.base_model .compute_operation_cost(node, predicates, children_stats, context, optimizer) @@ -73,11 +73,9 @@ impl CostModel> for AdvancedCostModel { node: &DfNodeType, predicates: &[ArcDfPredNode], children_stats: &[&Statistics], - context: Option, - optimizer: Option<&CascadesOptimizer>, + context: RelNodeContext, + optimizer: &CascadesOptimizer, ) -> Statistics { - let context = context.as_ref(); - let optimizer = optimizer.as_ref(); let row_cnts = children_stats .iter() .map(|child| DfCostModel::row_cnt(child)) @@ -100,12 +98,8 @@ impl CostModel> for AdvancedCostModel { DfCostModel::stat(row_cnt) } DfNodeType::PhysicalFilter => { - let output_schema = optimizer - .unwrap() - .get_schema_of(context.unwrap().group_id.into()); - let output_column_ref = optimizer - .unwrap() - .get_column_ref_of(context.unwrap().group_id.into()); + let output_schema = optimizer.get_schema_of(context.group_id.into()); + let output_column_ref = optimizer.get_column_ref_of(context.group_id.into()); let row_cnt = self.stats.get_filter_row_cnt( row_cnts[0], output_schema, @@ -115,18 +109,12 @@ impl CostModel> for AdvancedCostModel { DfCostModel::stat(row_cnt) } DfNodeType::PhysicalNestedLoopJoin(join_typ) => { - 
let output_schema = optimizer - .unwrap() - .get_schema_of(context.unwrap().group_id.into()); - let output_column_ref = optimizer - .unwrap() - .get_column_ref_of(context.unwrap().group_id.into()); - let left_column_ref = optimizer - .unwrap() - .get_column_ref_of(context.unwrap().children_group_ids[0].into()); - let right_column_ref = optimizer - .unwrap() - .get_column_ref_of(context.unwrap().children_group_ids[1].into()); + let output_schema = optimizer.get_schema_of(context.group_id.into()); + let output_column_ref = optimizer.get_column_ref_of(context.group_id.into()); + let left_column_ref = + optimizer.get_column_ref_of(context.children_group_ids[0].into()); + let right_column_ref = + optimizer.get_column_ref_of(context.children_group_ids[1].into()); let row_cnt = self.stats.get_nlj_row_cnt( *join_typ, row_cnts[0], @@ -140,18 +128,12 @@ impl CostModel> for AdvancedCostModel { DfCostModel::stat(row_cnt) } DfNodeType::PhysicalHashJoin(join_typ) => { - let output_schema = optimizer - .unwrap() - .get_schema_of(context.unwrap().group_id.into()); - let output_column_ref = optimizer - .unwrap() - .get_column_ref_of(context.unwrap().group_id.into()); - let left_column_ref = optimizer - .unwrap() - .get_column_ref_of(context.unwrap().children_group_ids[0].into()); - let right_column_ref = optimizer - .unwrap() - .get_column_ref_of(context.unwrap().children_group_ids[1].into()); + let output_schema = optimizer.get_schema_of(context.group_id.into()); + let output_column_ref = optimizer.get_column_ref_of(context.group_id.into()); + let left_column_ref = + optimizer.get_column_ref_of(context.children_group_ids[0].into()); + let right_column_ref = + optimizer.get_column_ref_of(context.children_group_ids[1].into()); let row_cnt = self.stats.get_hash_join_row_cnt( *join_typ, row_cnts[0], @@ -166,9 +148,7 @@ impl CostModel> for AdvancedCostModel { DfCostModel::stat(row_cnt) } DfNodeType::PhysicalAgg => { - let output_column_ref = optimizer - .unwrap() - 
.get_column_ref_of(context.unwrap().group_id.into()); + let output_column_ref = optimizer.get_column_ref_of(context.group_id.into()); let row_cnt = self .stats .get_agg_row_cnt(predicates[1].clone(), output_column_ref); @@ -178,8 +158,8 @@ impl CostModel> for AdvancedCostModel { node, predicates, children_stats, - context.cloned(), - optimizer.copied(), + context, + optimizer, ), } } diff --git a/optd-datafusion-repr/src/cost/adaptive_cost.rs b/optd-datafusion-repr/src/cost/adaptive_cost.rs index 653d48d0..8dfb3ac1 100644 --- a/optd-datafusion-repr/src/cost/adaptive_cost.rs +++ b/optd-datafusion-repr/src/cost/adaptive_cost.rs @@ -28,11 +28,9 @@ pub struct AdaptiveCostModel { } impl AdaptiveCostModel { - fn get_row_cnt(&self, context: &Option) -> f64 { + fn get_row_cnt(&self, context: &RelNodeContext) -> f64 { let guard = self.runtime_row_cnt.lock().unwrap(); - if let Some((runtime_row_cnt, iter)) = - guard.history.get(&context.as_ref().unwrap().group_id) - { + if let Some((runtime_row_cnt, iter)) = guard.history.get(&context.group_id) { if *iter + self.decay >= guard.iter_cnt { return (*runtime_row_cnt).max(1) as f64; } @@ -67,8 +65,8 @@ impl CostModel> for AdaptiveCostModel { node: &DfNodeType, predicates: &[ArcDfPredNode], children: &[Option<&Statistics>], - context: Option, - optimizer: Option<&CascadesOptimizer>, + context: RelNodeContext, + optimizer: &CascadesOptimizer, ) -> Cost { if let DfNodeType::PhysicalScan = node { let row_cnt = self.get_row_cnt(&context); @@ -83,8 +81,8 @@ impl CostModel> for AdaptiveCostModel { node: &DfNodeType, predicates: &[ArcDfPredNode], children: &[&Statistics], - context: Option, - optimizer: Option<&CascadesOptimizer>, + context: RelNodeContext, + optimizer: &CascadesOptimizer, ) -> Statistics { if let DfNodeType::PhysicalScan = node { let row_cnt = self.get_row_cnt(&context); diff --git a/optd-datafusion-repr/src/cost/base_cost.rs b/optd-datafusion-repr/src/cost/base_cost.rs index 9e5dfd88..83607b48 100644 --- 
a/optd-datafusion-repr/src/cost/base_cost.rs +++ b/optd-datafusion-repr/src/cost/base_cost.rs @@ -89,8 +89,8 @@ impl CostModel> for DfCostModel { node: &DfNodeType, predicates: &[ArcDfPredNode], children: &[&Statistics], - _context: Option, - _optimizer: Option<&CascadesOptimizer>, + _context: RelNodeContext, + _optimizer: &CascadesOptimizer, ) -> Statistics { match node { DfNodeType::PhysicalScan => { @@ -132,8 +132,8 @@ impl CostModel> for DfCostModel { node: &DfNodeType, predicates: &[ArcDfPredNode], children: &[Option<&Statistics>], - _context: Option, - _optimizer: Option<&CascadesOptimizer>, + _context: RelNodeContext, + _optimizer: &CascadesOptimizer, ) -> Cost { let row_cnts = children .iter() diff --git a/optd-datafusion-repr/src/testing/dummy_cost.rs b/optd-datafusion-repr/src/testing/dummy_cost.rs index c5f24bd0..6f860fc4 100644 --- a/optd-datafusion-repr/src/testing/dummy_cost.rs +++ b/optd-datafusion-repr/src/testing/dummy_cost.rs @@ -19,8 +19,8 @@ impl CostModel> for DummyCostModel { _: &DfNodeType, _: &[ArcDfPredNode], _: &[Option<&Statistics>], - _: Option, - _: Option<&CascadesOptimizer>, + _: RelNodeContext, + _: &CascadesOptimizer, ) -> Cost { Cost(vec![1.0]) } @@ -31,8 +31,8 @@ impl CostModel> for DummyCostModel { _: &DfNodeType, _: &[ArcDfPredNode], _: &[&Statistics], - _: Option, - _: Option<&CascadesOptimizer>, + _: RelNodeContext, + _: &CascadesOptimizer, ) -> Statistics { Statistics(Box::new(())) }