Skip to content

Commit

Permalink
Filter/Agg transpose bugfix (#252)
Browse files Browse the repository at this point in the history
The check for whether or not a filter expression could be pushed beyond
an agg node was incorrect.

It was checking whether the column index was a member of the group-by
columns (comparing the index values for equality), when it should be
checking, by position, that the filter only refers to columns that are
*emitted* from the agg node as group-by columns.

For example, if we see:
```
Filter #1 > 100
    Agg { groups: [#1], agg: Sum() }
```

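We should *not* push down because the agg node emits its group columns
first, followed by its aggregate expressions, so `#1` in the filter refers
to the sum column. In the current main branch, it is pushed down because it
sees that `#1` equals a column in the `groups` field. It should be checking
that every column referenced by the filter is `< groups.len()` instead.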
  • Loading branch information
jurplel authored Dec 6, 2024
1 parent 2dd2a31 commit 5f26d36
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 39 deletions.
5 changes: 4 additions & 1 deletion optd-datafusion-repr/src/rules/filter_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,10 @@ fn apply_filter_agg_transpose(
let mut group_by_cols_only = true;
for child in children {
if let Some(col_ref) = ColumnRefPred::from_pred_node(child.clone()) {
if !group_cols.contains(&col_ref.index()) {
// The agg schema is (group columns) + (expr columns),
// so if the column ref is < group_cols.len(), it is
// a group column.
if col_ref.index() >= group_cols.len() {
group_by_cols_only = false;
break;
}
Expand Down
11 changes: 11 additions & 0 deletions optd-sqllogictest/slt/_basic_tables.slt.part
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
statement ok
create table t1(v1 int, v2 int);

statement ok
create table t2(v3 int, v4 int);

statement ok
insert into t1 values (1, 100), (2, 200), (2, 250), (3, 300), (3, 300);

statement ok
insert into t2 values (2, 200), (2, 250), (3, 300);
9 changes: 9 additions & 0 deletions optd-sqllogictest/slt/unnest-dup.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
include _basic_tables.slt.part

query
select * from t1 where (select sum(v4) from t2 where v3 = v1) > 100;
----
2 200
2 250
3 300
3 300
74 changes: 36 additions & 38 deletions optd-sqlplannertest/tests/subqueries/subquery_unnesting.planner.sql
Original file line number Diff line number Diff line change
Expand Up @@ -53,25 +53,24 @@ LogicalProjection { exprs: [ #0, #1 ] }
├── LogicalAgg { exprs: [], groups: [ #0 ] }
│ └── LogicalScan { table: t1 }
└── LogicalScan { table: t2 }
PhysicalProjection { exprs: [ #2, #3 ], cost: {compute=8019,io=3000}, stat: {row_cnt=1} }
└── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: {compute=8016,io=3000}, stat: {row_cnt=1} }
├── PhysicalAgg
│ ├── aggrs:Agg(Sum)
│ │ └── [ Cast { cast_to: Int64, child: #2 } ]
│ ├── groups: [ #1 ]
│ ├── cost: {compute=7014,io=2000}
PhysicalProjection { exprs: [ #2, #3 ], cost: {compute=18005,io=3000}, stat: {row_cnt=1} }
└── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: {compute=18002,io=3000}, stat: {row_cnt=1} }
├── PhysicalFilter
│ ├── cond:Gt
│ │ ├── #1
│ │ └── 100(i64)
│ ├── cost: {compute=17000,io=2000}
│ ├── stat: {row_cnt=1}
│ └── PhysicalProjection { exprs: [ #2, #0, #1 ], cost: {compute=7006,io=2000}, stat: {row_cnt=1} }
│ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: {compute=7002,io=2000}, stat: {row_cnt=1} }
│ ├── PhysicalFilter
│ │ ├── cond:Gt
│ │ │ ├── #0
│ │ │ └── 100(i64)
│ │ ├── cost: {compute=3000,io=1000}
│ │ ├── stat: {row_cnt=1}
│ │ └── PhysicalScan { table: t2, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
│ └── PhysicalAgg { aggrs: [], groups: [ #0 ], cost: {compute=3000,io=1000}, stat: {row_cnt=1000} }
│ └── PhysicalScan { table: t1, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
│ └── PhysicalAgg
│ ├── aggrs:Agg(Sum)
│ │ └── [ Cast { cast_to: Int64, child: #2 } ]
│ ├── groups: [ #1 ]
│ ├── cost: {compute=14000,io=2000}
│ ├── stat: {row_cnt=1000}
│ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: {compute=6000,io=2000}, stat: {row_cnt=1000} }
│ ├── PhysicalAgg { aggrs: [], groups: [ #0 ], cost: {compute=3000,io=1000}, stat: {row_cnt=1000} }
│ │ └── PhysicalScan { table: t1, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
│ └── PhysicalScan { table: t2, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
└── PhysicalScan { table: t1, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
*/

Expand Down Expand Up @@ -135,27 +134,26 @@ LogicalProjection { exprs: [ #0, #1 ] }
└── LogicalJoin { join_type: Cross, cond: true }
├── LogicalScan { table: t2 }
└── LogicalScan { table: t3 }
PhysicalProjection { exprs: [ #2, #3 ], cost: {compute=9021,io=4000}, stat: {row_cnt=1} }
└── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: {compute=9018,io=4000}, stat: {row_cnt=1} }
├── PhysicalAgg
│ ├── aggrs:Agg(Sum)
│ │ └── [ Cast { cast_to: Int64, child: #2 } ]
│ ├── groups: [ #1 ]
│ ├── cost: {compute=8016,io=3000}
PhysicalProjection { exprs: [ #2, #3 ], cost: {compute=21005,io=4000}, stat: {row_cnt=1} }
└── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: {compute=21002,io=4000}, stat: {row_cnt=1} }
├── PhysicalFilter
│ ├── cond:Gt
│ │ ├── #1
│ │ └── 100(i64)
│ ├── cost: {compute=20000,io=3000}
│ ├── stat: {row_cnt=1}
│ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #2 ], right_keys: [ #0 ], cost: {compute=8008,io=3000}, stat: {row_cnt=1} }
│ ├── PhysicalProjection { exprs: [ #2, #0, #1 ], cost: {compute=7006,io=2000}, stat: {row_cnt=1} }
│ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: {compute=7002,io=2000}, stat: {row_cnt=1} }
│ │ ├── PhysicalFilter
│ │ │ ├── cond:Gt
│ │ │ │ ├── #0
│ │ │ │ └── 100(i64)
│ │ │ ├── cost: {compute=3000,io=1000}
│ │ │ ├── stat: {row_cnt=1}
│ │ │ └── PhysicalScan { table: t2, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
│ │ └── PhysicalAgg { aggrs: [], groups: [ #0 ], cost: {compute=3000,io=1000}, stat: {row_cnt=1000} }
│ │ └── PhysicalScan { table: t1, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
│ └── PhysicalScan { table: t3, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
│ └── PhysicalAgg
│ ├── aggrs:Agg(Sum)
│ │ └── [ Cast { cast_to: Int64, child: #2 } ]
│ ├── groups: [ #1 ]
│ ├── cost: {compute=17000,io=3000}
│ ├── stat: {row_cnt=1000}
│ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #2 ], right_keys: [ #0 ], cost: {compute=9000,io=3000}, stat: {row_cnt=1000} }
│ ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: {compute=6000,io=2000}, stat: {row_cnt=1000} }
│ │ ├── PhysicalAgg { aggrs: [], groups: [ #0 ], cost: {compute=3000,io=1000}, stat: {row_cnt=1000} }
│ │ │ └── PhysicalScan { table: t1, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
│ │ └── PhysicalScan { table: t2, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
│ └── PhysicalScan { table: t3, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
└── PhysicalScan { table: t1, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
*/

0 comments on commit 5f26d36

Please sign in to comment.