Skip to content

Commit

Permalink
use define_impl_rule_discriminant
Browse files Browse the repository at this point in the history
Signed-off-by: Yuchen Liang <[email protected]>
  • Loading branch information
yliang412 committed Dec 22, 2024
1 parent fa66727 commit 973e80c
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 240 deletions.
6 changes: 2 additions & 4 deletions optd-datafusion-repr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ impl DatafusionOptimizer {
rule_wrappers.push(Arc::new(rules::FilterInnerJoinTransposeRule::new()));
rule_wrappers.push(Arc::new(rules::FilterSortTransposeRule::new()));
rule_wrappers.push(Arc::new(rules::FilterAggTransposeRule::new()));
rule_wrappers.push(Arc::new(rules::HashJoinInnerRule::new()));
rule_wrappers.push(Arc::new(rules::HashJoinLeftOuterRule::new()));
rule_wrappers.push(Arc::new(rules::HashJoinLeftMarkRule::new()));
rule_wrappers.push(Arc::new(rules::HashJoinRule::new()));
rule_wrappers.push(Arc::new(rules::JoinCommuteRule::new()));
rule_wrappers.push(Arc::new(rules::JoinAssocRule::new()));
rule_wrappers.push(Arc::new(rules::ProjectionPullUpJoin::new()));
Expand Down Expand Up @@ -179,7 +177,7 @@ impl DatafusionOptimizer {
for rule in rules {
rule_wrappers.push(rule);
}
rule_wrappers.push(Arc::new(rules::HashJoinInnerRule::new()));
rule_wrappers.push(Arc::new(rules::HashJoinRule::new()));
rule_wrappers.insert(0, Arc::new(rules::JoinCommuteRule::new()));
rule_wrappers.insert(1, Arc::new(rules::JoinAssocRule::new()));
rule_wrappers.insert(2, Arc::new(rules::ProjectionPullUpJoin::new()));
Expand Down
239 changes: 6 additions & 233 deletions optd-datafusion-repr/src/rules/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use optd_core::nodes::PlanNodeOrGroup;
use optd_core::optimizer::Optimizer;
use optd_core::rules::{Rule, RuleMatcher};

use super::macros::{define_impl_rule, define_rule};
use super::macros::{define_impl_rule_discriminant, define_rule};
use crate::plan_nodes::{
ArcDfPlanNode, BinOpPred, BinOpType, ColumnRefPred, ConstantPred, ConstantType, DfNodeType,
DfPredType, DfReprPlanNode, DfReprPredNode, JoinType, ListPred, LogOpType,
Expand Down Expand Up @@ -140,241 +140,14 @@ fn apply_join_assoc(
vec![node.into_plan_node().into()]
}

define_impl_rule!(
HashJoinInnerRule,
apply_hash_join_inner,
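// Note: the matcher below names `JoinType::Inner`, but because this rule is declared with
// `define_impl_rule_discriminant!` it matches joins of every join type, not only inner joins.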
define_impl_rule_discriminant!(
HashJoinRule,
apply_hash_join,
(Join(JoinType::Inner), left, right)
);

/// Tries to implement a logical join as a `PhysicalHashJoin`.
///
/// The join condition must be either a single `=` between two column
/// references, or an `AND` whose children are all such equalities. In every
/// equality one column index must be below the left child's schema length and
/// the other at or above it; the latter is rebased by subtracting the left
/// schema length before it is used as a right-side key. The join type of the
/// matched join is carried over unchanged.
///
/// Returns a one-element vector holding the hash join, or an empty vector when
/// the condition does not have the shape described above.
///
/// # Panics
///
/// Panics if `binding` cannot be converted into a `LogicalJoin`.
fn apply_hash_join_inner(
    optimizer: &impl Optimizer<DfNodeType>,
    binding: ArcDfPlanNode,
) -> Vec<PlanNodeOrGroup<DfNodeType>> {
    let join = LogicalJoin::from_plan_node(binding).unwrap();
    let cond = join.cond();
    let left = join.left();
    let right = join.right();
    let join_type = join.join_type();

    // Case 1: the whole condition is a single equality.
    if matches!(cond.typ, DfPredType::BinOp(BinOpType::Eq)) {
        let left_width = optimizer.get_schema_of(left.clone()).len();
        let eq = BinOpPred::from_pred_node(cond.clone()).unwrap();
        let lhs_pred = eq.left_child();
        let rhs_pred = eq.right_child();
        let Some(lhs) = ColumnRefPred::from_pred_node(lhs_pred) else {
            return vec![];
        };
        let Some(rhs) = ColumnRefPred::from_pred_node(rhs_pred) else {
            return vec![];
        };

        // Exactly one operand must refer to the left input; orient the pair so
        // that the left-side column comes first.
        let lhs_on_left = lhs.index() < left_width;
        let rhs_on_left = rhs.index() < left_width;
        let (left_key, right_key) = match (lhs_on_left, rhs_on_left) {
            (true, false) => (lhs, rhs),
            (false, true) => (rhs, lhs),
            _ => return vec![],
        };

        // Right-side keys are expressed relative to the right child's schema.
        let right_key = ColumnRefPred::new(right_key.index() - left_width);
        let node = PhysicalHashJoin::new_unchecked(
            left,
            right,
            ListPred::new(vec![left_key.into_pred_node()]),
            ListPred::new(vec![right_key.into_pred_node()]),
            *join_type,
        );
        return vec![node.into_plan_node().into()];
    }

    // Case 2: a conjunction; currently only supported when every child is an
    // equality.
    if matches!(cond.typ, DfPredType::LogOp(LogOpType::And)) {
        let only_equalities = cond
            .children
            .iter()
            .all(|child| matches!(child.typ, DfPredType::BinOp(BinOpType::Eq)));
        if !only_equalities {
            return vec![];
        }

        let left_width = optimizer.get_schema_of(left.clone()).len();
        let mut left_keys = Vec::new();
        let mut right_keys = Vec::new();
        for child in &cond.children {
            let eq = BinOpPred::from_pred_node(child.clone()).unwrap();
            let lhs_pred = eq.left_child();
            let rhs_pred = eq.right_child();
            let Some(lhs) = ColumnRefPred::from_pred_node(lhs_pred) else {
                return vec![];
            };
            let Some(rhs) = ColumnRefPred::from_pred_node(rhs_pred) else {
                return vec![];
            };

            let lhs_on_left = lhs.index() < left_width;
            let rhs_on_left = rhs.index() < left_width;
            let (left_key, right_key) = match (lhs_on_left, rhs_on_left) {
                (true, false) => (lhs, rhs),
                (false, true) => (rhs, lhs),
                // Both columns come from the same side: give up on the rule.
                _ => return vec![],
            };

            let right_key = ColumnRefPred::new(right_key.index() - left_width);
            right_keys.push(right_key.into_pred_node());
            left_keys.push(left_key.into_pred_node());
        }

        let node = PhysicalHashJoin::new_unchecked(
            left,
            right,
            ListPred::new(left_keys),
            ListPred::new(right_keys),
            *join_type,
        );
        return vec![node.into_plan_node().into()];
    }

    // Any other condition shape cannot be turned into a hash join here.
    vec![]
}

// Declares the `HashJoinLeftOuterRule` rule type via `define_impl_rule!`, pairing the matcher
// pattern `(Join(JoinType::LeftOuter), left, right)` with `apply_hash_join_left_outer` as the
// function that produces the rule's output.
// NOTE(review): the exact matching semantics are defined by `define_impl_rule!` in
// `rules/macros.rs` — confirm there.
define_impl_rule!(
HashJoinLeftOuterRule,
apply_hash_join_left_outer,
(Join(JoinType::LeftOuter), left, right)
);

/// Tries to implement a logical join as a `PhysicalHashJoin`, keeping the join
/// type of the matched join.
///
/// Accepted join conditions are a single `=` between two column references, or
/// an `AND` whose children are all such equalities. Each equality must relate
/// one column whose index is below the left child's schema length to one whose
/// index is at or above it; the latter index is rebased by subtracting the
/// left schema length before being used as a right-side key.
///
/// Returns a one-element vector with the hash join on success and an empty
/// vector for every other condition shape.
///
/// # Panics
///
/// Panics if `binding` cannot be converted into a `LogicalJoin`.
fn apply_hash_join_left_outer(
    optimizer: &impl Optimizer<DfNodeType>,
    binding: ArcDfPlanNode,
) -> Vec<PlanNodeOrGroup<DfNodeType>> {
    let join = LogicalJoin::from_plan_node(binding).unwrap();
    let cond = join.cond();
    let left = join.left();
    let right = join.right();
    let join_type = join.join_type();

    // Single-equality condition.
    if matches!(cond.typ, DfPredType::BinOp(BinOpType::Eq)) {
        let left_width = optimizer.get_schema_of(left.clone()).len();
        let eq = BinOpPred::from_pred_node(cond.clone()).unwrap();
        let lhs_pred = eq.left_child();
        let rhs_pred = eq.right_child();
        let Some(lhs) = ColumnRefPred::from_pred_node(lhs_pred) else {
            return vec![];
        };
        let Some(rhs) = ColumnRefPred::from_pred_node(rhs_pred) else {
            return vec![];
        };

        // Orient the pair as (left-side column, right-side column); bail out
        // when both operands fall on the same side.
        let lhs_on_left = lhs.index() < left_width;
        let rhs_on_left = rhs.index() < left_width;
        let (left_key, right_key) = match (lhs_on_left, rhs_on_left) {
            (true, false) => (lhs, rhs),
            (false, true) => (rhs, lhs),
            _ => return vec![],
        };

        // Rebase the right-side column onto the right child's own schema.
        let right_key = ColumnRefPred::new(right_key.index() - left_width);
        let node = PhysicalHashJoin::new_unchecked(
            left,
            right,
            ListPred::new(vec![left_key.into_pred_node()]),
            ListPred::new(vec![right_key.into_pred_node()]),
            *join_type,
        );
        return vec![node.into_plan_node().into()];
    }

    // Conjunction of equalities; any non-equality child disables the rule.
    if matches!(cond.typ, DfPredType::LogOp(LogOpType::And)) {
        let only_equalities = cond
            .children
            .iter()
            .all(|child| matches!(child.typ, DfPredType::BinOp(BinOpType::Eq)));
        if !only_equalities {
            return vec![];
        }

        let left_width = optimizer.get_schema_of(left.clone()).len();
        let mut left_keys = Vec::new();
        let mut right_keys = Vec::new();
        for child in &cond.children {
            let eq = BinOpPred::from_pred_node(child.clone()).unwrap();
            let lhs_pred = eq.left_child();
            let rhs_pred = eq.right_child();
            let Some(lhs) = ColumnRefPred::from_pred_node(lhs_pred) else {
                return vec![];
            };
            let Some(rhs) = ColumnRefPred::from_pred_node(rhs_pred) else {
                return vec![];
            };

            let lhs_on_left = lhs.index() < left_width;
            let rhs_on_left = rhs.index() < left_width;
            let (left_key, right_key) = match (lhs_on_left, rhs_on_left) {
                (true, false) => (lhs, rhs),
                (false, true) => (rhs, lhs),
                _ => return vec![],
            };

            let right_key = ColumnRefPred::new(right_key.index() - left_width);
            right_keys.push(right_key.into_pred_node());
            left_keys.push(left_key.into_pred_node());
        }

        let node = PhysicalHashJoin::new_unchecked(
            left,
            right,
            ListPred::new(left_keys),
            ListPred::new(right_keys),
            *join_type,
        );
        return vec![node.into_plan_node().into()];
    }

    // Unsupported condition shape: produce no alternative.
    vec![]
}

// Declares the `HashJoinLeftMarkRule` rule type via `define_impl_rule!`, pairing the matcher
// pattern `(Join(JoinType::LeftMark), left, right)` with `apply_hash_join_left_mark` as the
// function that produces the rule's output.
// NOTE(review): the exact matching semantics are defined by `define_impl_rule!` in
// `rules/macros.rs` — confirm there.
define_impl_rule!(
HashJoinLeftMarkRule,
apply_hash_join_left_mark,
(Join(JoinType::LeftMark), left, right)
);

fn apply_hash_join_left_mark(
fn apply_hash_join(
optimizer: &impl Optimizer<DfNodeType>,
binding: ArcDfPlanNode,
) -> Vec<PlanNodeOrGroup<DfNodeType>> {
Expand Down
13 changes: 10 additions & 3 deletions optd-datafusion-repr/src/rules/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,19 @@ macro_rules! define_rule_discriminant {
};
}

macro_rules! define_impl_rule {
// macro_rules! define_impl_rule {
// ($name:ident, $apply:ident, $($matcher:tt)+) => {
// crate::rules::macros::define_rule_inner! { true, false, $name, $apply, $($matcher)+ }
// };
// }

macro_rules! define_impl_rule_discriminant {
($name:ident, $apply:ident, $($matcher:tt)+) => {
crate::rules::macros::define_rule_inner! { true, false, $name, $apply, $($matcher)+ }
crate::rules::macros::define_rule_inner! { true, true, $name, $apply, $($matcher)+ }
};
}

pub(crate) use {
define_impl_rule, define_matcher, define_rule, define_rule_discriminant, define_rule_inner,
define_impl_rule_discriminant, define_matcher, define_rule, define_rule_discriminant,
define_rule_inner,
};

0 comments on commit 973e80c

Please sign in to comment.