From 973e80c5b39fcb0fed6e8ccd09af3923d7747261 Mon Sep 17 00:00:00 2001 From: Yuchen Liang Date: Sun, 22 Dec 2024 11:03:29 -0500 Subject: [PATCH] use define_impl_rule_discriminant Signed-off-by: Yuchen Liang --- optd-datafusion-repr/src/lib.rs | 6 +- optd-datafusion-repr/src/rules/joins.rs | 239 +---------------------- optd-datafusion-repr/src/rules/macros.rs | 13 +- 3 files changed, 18 insertions(+), 240 deletions(-) diff --git a/optd-datafusion-repr/src/lib.rs b/optd-datafusion-repr/src/lib.rs index a85bc19c..d7d17eec 100644 --- a/optd-datafusion-repr/src/lib.rs +++ b/optd-datafusion-repr/src/lib.rs @@ -105,9 +105,7 @@ impl DatafusionOptimizer { rule_wrappers.push(Arc::new(rules::FilterInnerJoinTransposeRule::new())); rule_wrappers.push(Arc::new(rules::FilterSortTransposeRule::new())); rule_wrappers.push(Arc::new(rules::FilterAggTransposeRule::new())); - rule_wrappers.push(Arc::new(rules::HashJoinInnerRule::new())); - rule_wrappers.push(Arc::new(rules::HashJoinLeftOuterRule::new())); - rule_wrappers.push(Arc::new(rules::HashJoinLeftMarkRule::new())); + rule_wrappers.push(Arc::new(rules::HashJoinRule::new())); rule_wrappers.push(Arc::new(rules::JoinCommuteRule::new())); rule_wrappers.push(Arc::new(rules::JoinAssocRule::new())); rule_wrappers.push(Arc::new(rules::ProjectionPullUpJoin::new())); @@ -179,7 +177,7 @@ impl DatafusionOptimizer { for rule in rules { rule_wrappers.push(rule); } - rule_wrappers.push(Arc::new(rules::HashJoinInnerRule::new())); + rule_wrappers.push(Arc::new(rules::HashJoinRule::new())); rule_wrappers.insert(0, Arc::new(rules::JoinCommuteRule::new())); rule_wrappers.insert(1, Arc::new(rules::JoinAssocRule::new())); rule_wrappers.insert(2, Arc::new(rules::ProjectionPullUpJoin::new())); diff --git a/optd-datafusion-repr/src/rules/joins.rs b/optd-datafusion-repr/src/rules/joins.rs index a041677c..432c06d3 100644 --- a/optd-datafusion-repr/src/rules/joins.rs +++ b/optd-datafusion-repr/src/rules/joins.rs @@ -9,7 +9,7 @@ use optd_core::nodes::PlanNodeOrGroup; use optd_core::optimizer::Optimizer; use optd_core::rules::{Rule, RuleMatcher}; -use super::macros::{define_impl_rule, define_rule}; +use super::macros::{define_impl_rule_discriminant, define_rule}; use crate::plan_nodes::{ ArcDfPlanNode, BinOpPred, BinOpType, ColumnRefPred, ConstantPred, ConstantType, DfNodeType, DfPredType, DfReprPlanNode, DfReprPredNode, JoinType, ListPred, LogOpType, @@ -140,241 +140,14 @@ fn apply_join_assoc( vec![node.into_plan_node().into()] } -define_impl_rule!( - HashJoinInnerRule, - apply_hash_join_inner, +// Note: this matches all join types despite using `JoinType::Inner` below. +define_impl_rule_discriminant!( + HashJoinRule, + apply_hash_join, (Join(JoinType::Inner), left, right) ); -fn apply_hash_join_inner( - optimizer: &impl Optimizer, - binding: ArcDfPlanNode, -) -> Vec> { - let join = LogicalJoin::from_plan_node(binding).unwrap(); - let cond = join.cond(); - let left = join.left(); - let right = join.right(); - let join_type = join.join_type(); - match cond.typ { - DfPredType::BinOp(BinOpType::Eq) => { - let left_schema = optimizer.get_schema_of(left.clone()); - let op = BinOpPred::from_pred_node(cond.clone()).unwrap(); - let left_expr = op.left_child(); - let right_expr = op.right_child(); - let Some(mut left_expr) = ColumnRefPred::from_pred_node(left_expr) else { - return vec![]; - }; - let Some(mut right_expr) = ColumnRefPred::from_pred_node(right_expr) else { - return vec![]; - }; - let can_convert = if left_expr.index() < left_schema.len() - && right_expr.index() >= left_schema.len() - { - true - } else if right_expr.index() < left_schema.len() - && left_expr.index() >= left_schema.len() - { - (left_expr, right_expr) = (right_expr, left_expr); - true - } else { - false - }; - - if can_convert { - let right_expr = ColumnRefPred::new(right_expr.index() - left_schema.len()); - let node = PhysicalHashJoin::new_unchecked( - left, - right, - ListPred::new(vec![left_expr.into_pred_node()]), - ListPred::new(vec![right_expr.into_pred_node()]), - *join_type, - ); - return vec![node.into_plan_node().into()]; - } - } - DfPredType::LogOp(LogOpType::And) => { - // currently only support consecutive equal queries - let mut is_consecutive_eq = true; - for child in cond.children.clone() { - if let DfPredType::BinOp(BinOpType::Eq) = child.typ { - continue; - } else { - is_consecutive_eq = false; - break; - } - } - if !is_consecutive_eq { - return vec![]; - } - - let left_schema = optimizer.get_schema_of(left.clone()); - let mut left_exprs = vec![]; - let mut right_exprs = vec![]; - for child in &cond.children { - let bin_op = BinOpPred::from_pred_node(child.clone()).unwrap(); - let left_expr = bin_op.left_child(); - let right_expr = bin_op.right_child(); - let Some(mut left_expr) = ColumnRefPred::from_pred_node(left_expr) else { - return vec![]; - }; - let Some(mut right_expr) = ColumnRefPred::from_pred_node(right_expr) else { - return vec![]; - }; - let can_convert = if left_expr.index() < left_schema.len() - && right_expr.index() >= left_schema.len() - { - true - } else if right_expr.index() < left_schema.len() - && left_expr.index() >= left_schema.len() - { - (left_expr, right_expr) = (right_expr, left_expr); - true - } else { - false - }; - if !can_convert { - return vec![]; - } - let right_expr = ColumnRefPred::new(right_expr.index() - left_schema.len()); - right_exprs.push(right_expr.into_pred_node()); - left_exprs.push(left_expr.into_pred_node()); - } - - let node = PhysicalHashJoin::new_unchecked( - left, - right, - ListPred::new(left_exprs), - ListPred::new(right_exprs), - *join_type, - ); - return vec![node.into_plan_node().into()]; - } - _ => {} - } - vec![] -} - -define_impl_rule!( - HashJoinLeftOuterRule, - apply_hash_join_left_outer, - (Join(JoinType::LeftOuter), left, right) -); - -fn apply_hash_join_left_outer( - optimizer: &impl Optimizer, - binding: ArcDfPlanNode, -) -> Vec> { - let join = LogicalJoin::from_plan_node(binding).unwrap(); - let cond = join.cond(); - let left = join.left(); - let right = join.right(); - let join_type = join.join_type(); - match cond.typ { - DfPredType::BinOp(BinOpType::Eq) => { - let left_schema = optimizer.get_schema_of(left.clone()); - let op = BinOpPred::from_pred_node(cond.clone()).unwrap(); - let left_expr = op.left_child(); - let right_expr = op.right_child(); - let Some(mut left_expr) = ColumnRefPred::from_pred_node(left_expr) else { - return vec![]; - }; - let Some(mut right_expr) = ColumnRefPred::from_pred_node(right_expr) else { - return vec![]; - }; - let can_convert = if left_expr.index() < left_schema.len() - && right_expr.index() >= left_schema.len() - { - true - } else if right_expr.index() < left_schema.len() - && left_expr.index() >= left_schema.len() - { - (left_expr, right_expr) = (right_expr, left_expr); - true - } else { - false - }; - - if can_convert { - let right_expr = ColumnRefPred::new(right_expr.index() - left_schema.len()); - let node = PhysicalHashJoin::new_unchecked( - left, - right, - ListPred::new(vec![left_expr.into_pred_node()]), - ListPred::new(vec![right_expr.into_pred_node()]), - *join_type, - ); - return vec![node.into_plan_node().into()]; - } - } - DfPredType::LogOp(LogOpType::And) => { - // currently only support consecutive equal queries - let mut is_consecutive_eq = true; - for child in cond.children.clone() { - if let DfPredType::BinOp(BinOpType::Eq) = child.typ { - continue; - } else { - is_consecutive_eq = false; - break; - } - } - if !is_consecutive_eq { - return vec![]; - } - - let left_schema = optimizer.get_schema_of(left.clone()); - let mut left_exprs = vec![]; - let mut right_exprs = vec![]; - for child in &cond.children { - let bin_op = BinOpPred::from_pred_node(child.clone()).unwrap(); - let left_expr = bin_op.left_child(); - let right_expr = bin_op.right_child(); - let Some(mut left_expr) = ColumnRefPred::from_pred_node(left_expr) else { - return vec![]; - }; - let Some(mut right_expr) = ColumnRefPred::from_pred_node(right_expr) else { - return vec![]; - }; - let can_convert = if left_expr.index() < left_schema.len() - && right_expr.index() >= left_schema.len() - { - true - } else if right_expr.index() < left_schema.len() - && left_expr.index() >= left_schema.len() - { - (left_expr, right_expr) = (right_expr, left_expr); - true - } else { - false - }; - if !can_convert { - return vec![]; - } - let right_expr = ColumnRefPred::new(right_expr.index() - left_schema.len()); - right_exprs.push(right_expr.into_pred_node()); - left_exprs.push(left_expr.into_pred_node()); - } - - let node = PhysicalHashJoin::new_unchecked( - left, - right, - ListPred::new(left_exprs), - ListPred::new(right_exprs), - *join_type, - ); - return vec![node.into_plan_node().into()]; - } - _ => {} - } - vec![] -} - -define_impl_rule!( - HashJoinLeftMarkRule, - apply_hash_join_left_mark, - (Join(JoinType::LeftMark), left, right) -); - -fn apply_hash_join_left_mark( +fn apply_hash_join( optimizer: &impl Optimizer, binding: ArcDfPlanNode, ) -> Vec> { diff --git a/optd-datafusion-repr/src/rules/macros.rs b/optd-datafusion-repr/src/rules/macros.rs index 420e2963..412f83c9 100644 --- a/optd-datafusion-repr/src/rules/macros.rs +++ b/optd-datafusion-repr/src/rules/macros.rs @@ -79,12 +79,19 @@ macro_rules! define_rule_discriminant { }; } -macro_rules! define_impl_rule { +// macro_rules! define_impl_rule { +// ($name:ident, $apply:ident, $($matcher:tt)+) => { +// crate::rules::macros::define_rule_inner! { true, false, $name, $apply, $($matcher)+ } +// }; +// } + +macro_rules! define_impl_rule_discriminant { ($name:ident, $apply:ident, $($matcher:tt)+) => { - crate::rules::macros::define_rule_inner! { true, false, $name, $apply, $($matcher)+ } + crate::rules::macros::define_rule_inner! { true, true, $name, $apply, $($matcher)+ } }; } pub(crate) use { - define_impl_rule, define_matcher, define_rule, define_rule_discriminant, define_rule_inner, + define_impl_rule_discriminant, define_matcher, define_rule, define_rule_discriminant, + define_rule_inner, };