Skip to content

Commit

Permalink
fix(df-repr): join assoc rule expr, rm exploding rules (#223)
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Chi <[email protected]>
  • Loading branch information
skyzh committed Nov 7, 2024
1 parent 4fec4eb commit 4bb311e
Show file tree
Hide file tree
Showing 13 changed files with 338 additions and 341 deletions.
7 changes: 1 addition & 6 deletions optd-datafusion-repr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,9 @@ impl DatafusionOptimizer {
rule_wrappers.push(RuleWrapper::new_cascades(Arc::new(
rules::JoinCommuteRule::new(),
)));
rule_wrappers.push(RuleWrapper::new_cascades(Arc::new(
rules::InnerCrossJoinRule::new(),
)));
rule_wrappers.push(RuleWrapper::new_cascades(Arc::new(
rules::JoinAssocRule::new(),
)));
rule_wrappers.push(RuleWrapper::new_cascades(Arc::new(
rules::JoinAbsorbFilterRule::new(),
)));
rule_wrappers.push(RuleWrapper::new_cascades(Arc::new(
rules::ProjectionPullUpJoin::new(),
)));
Expand Down Expand Up @@ -186,6 +180,7 @@ impl DatafusionOptimizer {
panic_on_budget: false,
partial_explore_iter: Some(1 << 20),
partial_explore_space: Some(1 << 10),
disable_pruning: false,
},
),
heuristic_optimizer: HeuristicsOptimizer::new_with_rules(
Expand Down
57 changes: 12 additions & 45 deletions optd-datafusion-repr/src/rules/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,52 +8,11 @@ use super::macros::{define_impl_rule, define_rule};
use crate::plan_nodes::{
ArcDfPlanNode, BinOpPred, BinOpType, ColumnRefPred, ConstantPred, ConstantType, DfNodeType,
DfPredType, DfReprPlanNode, DfReprPredNode, JoinType, ListPred, LogOpType,
LogicalEmptyRelation, LogicalFilter, LogicalJoin, LogicalProjection, PhysicalHashJoin, PredExt,
LogicalEmptyRelation, LogicalJoin, LogicalProjection, PhysicalHashJoin, PredExt,
};
use crate::properties::schema::Schema;
use crate::OptimizerExt;

// A cross join B -> A inner join B
//
// Declares the rule `InnerCrossJoinRule` and names `apply_inner_cross_join` as the
// function that produces its rewrites. The pattern below targets a `Join` node whose
// type is `JoinType::Cross`, with its two inputs referred to as `left` and `right`.
// NOTE(review): the exact matching/binding semantics are defined by `define_rule!`
// in `super::macros`, which is not visible here -- confirm against the macro.
define_rule!(
InnerCrossJoinRule,
apply_inner_cross_join,
(Join(JoinType::Cross), left, right)
);

/// Rewrites the matched join into a `JoinType::Inner` join.
///
/// The new join reuses the left input, right input and condition of the matched
/// join unchanged; only the join type differs. The optimizer argument is unused.
///
/// Returns a single-element vector holding the rewritten plan node.
///
/// # Panics
///
/// Panics if `LogicalJoin::from_plan_node` does not yield a join for `binding`.
fn apply_inner_cross_join(
    _: &impl Optimizer<DfNodeType>,
    binding: ArcDfPlanNode,
) -> Vec<PlanNodeOrGroup<DfNodeType>> {
    let cross = LogicalJoin::from_plan_node(binding).unwrap();
    // Pull the pieces out in the same order they are handed to the constructor.
    let (left, right, cond) = (cross.left(), cross.right(), cross.cond());
    let inner = LogicalJoin::new_unchecked(left, right, cond, JoinType::Inner);
    vec![inner.into_plan_node().into()]
}

// Filter (A inner join B on true) cond -> A inner join B on cond
//
// Declares the rule `JoinAbsorbFilterRule` and names `apply_join_absorb_filter` as the
// function that produces its rewrites. The pattern below targets a `Filter` node whose
// child is a `Join` of type `JoinType::Inner`, with the join inputs referred to as
// `left` and `right`. The "on true" part of the rewrite is not expressed in the
// pattern; it is checked inside `apply_join_absorb_filter`.
// NOTE(review): the exact matching/binding semantics are defined by `define_rule!`
// in `super::macros`, which is not visible here -- confirm against the macro.
define_rule!(
JoinAbsorbFilterRule,
apply_join_absorb_filter,
(Filter, (Join(JoinType::Inner), left, right))
);

/// Folds a filter into the inner join directly beneath it when that join's
/// condition is the constant `true`.
///
/// When the join condition is a constant predicate whose boolean value is `true`,
/// the result is a single `JoinType::Inner` join over the same inputs that uses the
/// filter's condition as its join condition. In every other case no alternative is
/// produced and an empty vector is returned. The optimizer argument is unused.
///
/// # Panics
///
/// Panics if `binding` cannot be read as a `LogicalFilter`, or if the filter's child
/// cannot be read as a `LogicalJoin`.
/// NOTE(review): `as_bool()` is called on any constant join condition without a
/// type check first -- confirm how it behaves for non-boolean constants.
fn apply_join_absorb_filter(
    _: &impl Optimizer<DfNodeType>,
    binding: ArcDfPlanNode,
) -> Vec<PlanNodeOrGroup<DfNodeType>> {
    let filter = LogicalFilter::from_plan_node(binding).unwrap();
    let join = LogicalJoin::from_plan_node(filter.child().unwrap_plan_node()).unwrap();
    let join_cond = join.cond();
    let filter_cond = filter.cond();

    // Guard: only a constant join condition can be replaced by the filter's.
    let Some(constant) = ConstantPred::from_pred_node(join_cond) else {
        return vec![];
    };
    // Guard: the constant must evaluate to `true`.
    if !constant.value().as_bool() {
        return vec![];
    }

    let absorbed =
        LogicalJoin::new_unchecked(join.left(), join.right(), filter_cond, JoinType::Inner);
    vec![absorbed.into_plan_node().into()]
}

// A join B -> B join A
define_rule!(
JoinCommuteRule,
Expand Down Expand Up @@ -112,7 +71,15 @@ fn apply_eliminate_join(
if let DfPredType::Constant(const_type) = cond.typ {
if const_type == ConstantType::Bool {
if let Some(ref data) = cond.data {
if !data.as_bool() {
if data.as_bool() {
let node = LogicalJoin::new_unchecked(
left,
right,
ConstantPred::bool(true).into_pred_node(),
JoinType::Cross,
);
return vec![node.into_plan_node().into()];
} else {
// No need to handle schema here, as all exprs in the same group
// will have same logical properties
let mut left_fields = optimizer.get_schema_of(left.clone()).fields;
Expand Down Expand Up @@ -146,9 +113,9 @@ fn apply_join_assoc(
let join2 = LogicalJoin::from_plan_node(join1.left().unwrap_plan_node()).unwrap();
let a = join2.left();
let b = join2.right();
let cond2 = join2.cond();
let cond1 = join2.cond();
let a_schema = optimizer.get_schema_of(a.clone());
let cond1 = join1.cond();
let cond2 = join1.cond();

let Some(cond2) = cond2.rewrite_column_refs(&mut |idx| {
if idx < a_schema.len() {
Expand Down
2 changes: 1 addition & 1 deletion optd-sqlplannertest/tests/basic/cross_product.planner.sql
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ LogicalProjection { exprs: [ #0, #1 ] }
└── LogicalJoin { join_type: Cross, cond: true }
├── LogicalScan { table: t1 }
└── LogicalScan { table: t2 }
PhysicalNestedLoopJoin { join_type: Inner, cond: true }
PhysicalNestedLoopJoin { join_type: Cross, cond: true }
├── PhysicalScan { table: t1 }
└── PhysicalScan { table: t2 }
0 0
Expand Down
6 changes: 3 additions & 3 deletions optd-sqlplannertest/tests/basic/filter.planner.sql
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ PhysicalFilter
│ └── Eq
│ ├── #0
│ └── #3
└── PhysicalNestedLoopJoin { join_type: Inner, cond: true }
└── PhysicalNestedLoopJoin { join_type: Cross, cond: true }
├── PhysicalScan { table: t1 }
└── PhysicalScan { table: t2 }
0 0 0 200
Expand All @@ -122,7 +122,7 @@ LogicalProjection { exprs: [ #0, #1, #2, #3 ] }
└── LogicalJoin { join_type: Cross, cond: true }
├── LogicalScan { table: t1 }
└── LogicalScan { table: t2 }
PhysicalNestedLoopJoin { join_type: Inner, cond: true }
PhysicalNestedLoopJoin { join_type: Cross, cond: true }
├── PhysicalScan { table: t1 }
└── PhysicalScan { table: t2 }
0 0 0 200
Expand Down Expand Up @@ -254,7 +254,7 @@ LogicalProjection { exprs: [ #0, #1, #2, #3 ] }
│ └── true
├── LogicalScan { table: t1 }
└── LogicalScan { table: t2 }
PhysicalNestedLoopJoin { join_type: Inner, cond: true }
PhysicalNestedLoopJoin { join_type: Cross, cond: true }
├── PhysicalScan { table: t1 }
└── PhysicalScan { table: t2 }
0 0 0 200
Expand Down
30 changes: 20 additions & 10 deletions optd-sqlplannertest/tests/joins/join_enumerate.planner.sql
Original file line number Diff line number Diff line change
Expand Up @@ -24,44 +24,54 @@ select * from t2, t1 where t1v1 = t2v1;
2 202 2 2
*/

-- Test whether the optimizer enumerates all 3-join orders.
-- Test whether the optimizer enumerates all 3-join orders. (It should)
select * from t2, t1, t3 where t1v1 = t2v1 and t1v1 = t3v2;

/*
(Join t2 (Join t1 t3))
(Join t2 (Join t3 t1))
(Join t3 (Join t1 t2))
(Join t3 (Join t2 t1))
(Join (Join t1 t2) t3)
(Join (Join t1 t3) t2)
(Join (Join t2 t1) t3)
(Join (Join t3 t1) t2)
0 200 0 0 0 300
1 201 1 1 1 301
2 202 2 2 2 302
*/

-- Test whether the optimizer enumerates all 3-join orders. (It doesn't currently)
select * from t2, t1, t3 where t1v1 = t2v1 and t1v2 = t3v2;

/*
(Join t1 (Join t2 t3))
(Join t1 (Join t3 t2))
(Join t2 (Join t1 t3))
(Join t2 (Join t3 t1))
(Join t3 (Join t1 t2))
(Join t3 (Join t2 t1))
(Join (Join t1 t2) t3)
(Join (Join t1 t3) t2)
(Join (Join t2 t1) t3)
(Join (Join t2 t3) t1)
(Join (Join t3 t1) t2)
(Join (Join t3 t2) t1)
0 200 0 0 0 300
1 201 1 1 1 301
2 202 2 2 2 302
*/

-- Test whether the optimizer enumerates all 3-join orders.
-- Test whether the optimizer enumerates all 3-join orders. (It doesn't currently)
select * from t1, t2, t3 where t1v1 = t2v1 and t1v2 = t3v2;

/*
(Join t1 (Join t2 t3))
(Join t1 (Join t3 t2))
(Join t2 (Join t1 t3))
(Join t2 (Join t3 t1))
(Join t3 (Join t1 t2))
(Join t3 (Join t2 t1))
(Join (Join t1 t2) t3)
(Join (Join t1 t3) t2)
(Join (Join t2 t1) t3)
(Join (Join t2 t3) t1)
(Join (Join t3 t1) t2)
(Join (Join t3 t2) t1)
0 0 0 200 0 300
1 1 1 201 1 301
Expand Down
17 changes: 12 additions & 5 deletions optd-sqlplannertest/tests/joins/join_enumerate.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,24 @@
select * from t2, t1 where t1v1 = t2v1;
desc: Test whether the optimizer enumerates all 2-join orders.
tasks:
- explain:logical_join_orders
# Pruning should not actually matter here, because join orders are a logical property; however, we are currently missing the join orders that have t1 as the outer table.
- explain[disable_pruning]:logical_join_orders
- execute
- sql: |
select * from t2, t1, t3 where t1v1 = t2v1 and t1v1 = t3v2;
desc: Test whether the optimizer enumerates all 3-join orders. (It should)
tasks:
- explain[disable_pruning]:logical_join_orders
- execute
- sql: |
select * from t2, t1, t3 where t1v1 = t2v1 and t1v2 = t3v2;
desc: Test whether the optimizer enumerates all 3-join orders.
desc: Test whether the optimizer enumerates all 3-join orders. (It doesn't currently)
tasks:
- explain:logical_join_orders
- explain[disable_pruning]:logical_join_orders
- execute
- sql: |
select * from t1, t2, t3 where t1v1 = t2v1 and t1v2 = t3v2;
desc: Test whether the optimizer enumerates all 3-join orders.
desc: Test whether the optimizer enumerates all 3-join orders. (It doesn't currently)
tasks:
- explain:logical_join_orders
- explain[disable_pruning]:logical_join_orders
- execute
92 changes: 92 additions & 0 deletions optd-sqlplannertest/tests/joins/multi-join.planner.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
-- (no id or description)
create table t1(a int, b int);
create table t2(c int, d int);
create table t3(e int, f int);
create table t4(g int, h int);

/*
*/

-- test 3-way join
select * from t1, t2, t3 where a = c AND d = e;

/*
LogicalProjection { exprs: [ #0, #1, #2, #3, #4, #5 ] }
└── LogicalFilter
├── cond:And
│ ├── Eq
│ │ ├── #0
│ │ └── #2
│ └── Eq
│ ├── #3
│ └── #4
└── LogicalJoin { join_type: Cross, cond: true }
├── LogicalJoin { join_type: Cross, cond: true }
│ ├── LogicalScan { table: t1 }
│ └── LogicalScan { table: t2 }
└── LogicalScan { table: t3 }
PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] }
├── PhysicalScan { table: t1 }
└── PhysicalHashJoin { join_type: Inner, left_keys: [ #1 ], right_keys: [ #0 ] }
├── PhysicalScan { table: t2 }
└── PhysicalScan { table: t3 }
*/

-- test 3-way join
select * from t1, t2, t3 where a = c AND b = e;

/*
LogicalProjection { exprs: [ #0, #1, #2, #3, #4, #5 ] }
└── LogicalFilter
├── cond:And
│ ├── Eq
│ │ ├── #0
│ │ └── #2
│ └── Eq
│ ├── #1
│ └── #4
└── LogicalJoin { join_type: Cross, cond: true }
├── LogicalJoin { join_type: Cross, cond: true }
│ ├── LogicalScan { table: t1 }
│ └── LogicalScan { table: t2 }
└── LogicalScan { table: t3 }
PhysicalHashJoin { join_type: Inner, left_keys: [ #1 ], right_keys: [ #0 ] }
├── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] }
│ ├── PhysicalScan { table: t1 }
│ └── PhysicalScan { table: t2 }
└── PhysicalScan { table: t3 }
*/

-- test 4-way join
select * from t1, t2, t3, t4 where a = c AND b = e AND f = g;

/*
LogicalProjection { exprs: [ #0, #1, #2, #3, #4, #5, #6, #7 ] }
└── LogicalFilter
├── cond:And
│ ├── Eq
│ │ ├── #0
│ │ └── #2
│ ├── Eq
│ │ ├── #1
│ │ └── #4
│ └── Eq
│ ├── #5
│ └── #6
└── LogicalJoin { join_type: Cross, cond: true }
├── LogicalJoin { join_type: Cross, cond: true }
│ ├── LogicalJoin { join_type: Cross, cond: true }
│ │ ├── LogicalScan { table: t1 }
│ │ └── LogicalScan { table: t2 }
│ └── LogicalScan { table: t3 }
└── LogicalScan { table: t4 }
PhysicalHashJoin { join_type: Inner, left_keys: [ #1 ], right_keys: [ #0 ] }
├── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] }
│ ├── PhysicalScan { table: t1 }
│ └── PhysicalScan { table: t2 }
└── PhysicalHashJoin { join_type: Inner, left_keys: [ #1 ], right_keys: [ #0 ] }
├── PhysicalScan { table: t3 }
└── PhysicalScan { table: t4 }
*/

22 changes: 22 additions & 0 deletions optd-sqlplannertest/tests/joins/multi-join.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
- sql: |
create table t1(a int, b int);
create table t2(c int, d int);
create table t3(e int, f int);
create table t4(g int, h int);
tasks:
- execute
- sql: |
select * from t1, t2, t3 where a = c AND d = e;
desc: test 3-way join
tasks:
- explain:logical_optd,physical_optd
- sql: |
select * from t1, t2, t3 where a = c AND b = e;
desc: test 3-way join
tasks:
- explain:logical_optd,physical_optd
- sql: |
select * from t1, t2, t3, t4 where a = c AND b = e AND f = g;
desc: test 4-way join
tasks:
- explain:logical_optd,physical_optd
32 changes: 16 additions & 16 deletions optd-sqlplannertest/tests/subqueries/subquery_unnesting.planner.sql
Original file line number Diff line number Diff line change
Expand Up @@ -135,27 +135,27 @@ LogicalProjection { exprs: [ #0, #1 ] }
└── LogicalJoin { join_type: Cross, cond: true }
├── LogicalScan { table: t2 }
└── LogicalScan { table: t3 }
PhysicalProjection { exprs: [ #2, #3 ], cost: {compute=9023,io=4000}, stat: {row_cnt=1} }
└── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: {compute=9020,io=4000}, stat: {row_cnt=1} }
PhysicalProjection { exprs: [ #2, #3 ], cost: {compute=9021,io=4000}, stat: {row_cnt=1} }
└── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: {compute=9018,io=4000}, stat: {row_cnt=1} }
├── PhysicalAgg
│ ├── aggrs:Agg(Sum)
│ │ └── [ Cast { cast_to: Int64, child: #2 } ]
│ ├── groups: [ #1 ]
│ ├── cost: {compute=8018,io=3000}
│ ├── cost: {compute=8016,io=3000}
│ ├── stat: {row_cnt=1}
│ └── PhysicalProjection { exprs: [ #2, #0, #1, #3, #4 ], cost: {compute=8010,io=3000}, stat: {row_cnt=1} }
── PhysicalHashJoin { join_type: Inner, left_keys: [ #1 ], right_keys: [ #0 ], cost: {compute=8004,io=3000}, stat: {row_cnt=1} }
── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: {compute=7002,io=2000}, stat: {row_cnt=1} }
├── PhysicalFilter
│ ├── cond:Gt
│ │ ├── #0
│ │ └── 100(i64)
│ ├── cost: {compute=3000,io=1000}
│ ├── stat: {row_cnt=1}
│ └── PhysicalScan { table: t2, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
└── PhysicalAgg { aggrs: [], groups: [ #0 ], cost: {compute=3000,io=1000}, stat: {row_cnt=1000} }
└── PhysicalScan { table: t1, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
└── PhysicalScan { table: t3, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
│ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #2 ], right_keys: [ #0 ], cost: {compute=8008,io=3000}, stat: {row_cnt=1} }
── PhysicalProjection { exprs: [ #2, #0, #1 ], cost: {compute=7006,io=2000}, stat: {row_cnt=1} }
── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: {compute=7002,io=2000}, stat: {row_cnt=1} }
├── PhysicalFilter
│ ├── cond:Gt
│ │ ├── #0
│ │ └── 100(i64)
│ ├── cost: {compute=3000,io=1000}
│ ├── stat: {row_cnt=1}
│ └── PhysicalScan { table: t2, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
└── PhysicalAgg { aggrs: [], groups: [ #0 ], cost: {compute=3000,io=1000}, stat: {row_cnt=1000} }
└── PhysicalScan { table: t1, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
│ └── PhysicalScan { table: t3, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
└── PhysicalScan { table: t1, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
*/

Loading

0 comments on commit 4bb311e

Please sign in to comment.