Skip to content

Commit

Permalink
feat: support type expr
Browse files Browse the repository at this point in the history
  • Loading branch information
Gun9niR committed Feb 17, 2024
1 parent 214848c commit 0f119f3
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 65 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 6 additions & 8 deletions optd-datafusion-bridge/src/from_optd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use optd_datafusion_repr::{
PhysicalSort, PlanNode, SortOrderExpr, SortOrderType,
},
properties::schema::Schema as OptdSchema,
PhysicalCollector, Value,
PhysicalCollector,
};

use crate::{physical_collector::CollectorExec, OptdPlanContext};
Expand Down Expand Up @@ -250,14 +250,12 @@ impl OptdPlanContext<'_> {
OptRelNodeTyp::Cast => {
let expr = CastExpr::from_rel_node(expr.into_rel_node()).unwrap();
let child = Self::conv_from_optd_expr(expr.child(), context)?;
let data_type = match expr.cast_to() {
Value::Bool(_) => DataType::Boolean,
Value::Decimal128(_) => DataType::Decimal128(15, 2), /* TODO: AVOID HARD CODE PRECISION */
Value::Date32(_) => DataType::Date32,
other => unimplemented!("{}", other),
};
Ok(Arc::new(
datafusion::physical_plan::expressions::CastExpr::new(child, data_type, None),
datafusion::physical_plan::expressions::CastExpr::new(
child,
expr.cast_to(),
None,
),
))
}
OptRelNodeTyp::Like => {
Expand Down
30 changes: 6 additions & 24 deletions optd-datafusion-bridge/src/into_optd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,11 @@ use datafusion::{
};
use datafusion_expr::Expr as DFExpr;
use optd_core::rel_node::RelNode;
use optd_datafusion_repr::{
plan_nodes::{
BetweenExpr, BinOpExpr, BinOpType, CastExpr, ColumnRefExpr, ConstantExpr, Expr, ExprList,
FuncExpr, FuncType, JoinType, LikeExpr, LogOpExpr, LogOpType, LogicalAgg,
LogicalEmptyRelation, LogicalFilter, LogicalJoin, LogicalLimit, LogicalProjection,
LogicalScan, LogicalSort, OptRelNode, OptRelNodeRef, OptRelNodeTyp, PlanNode,
SortOrderExpr, SortOrderType,
},
Value,
use optd_datafusion_repr::plan_nodes::{
BetweenExpr, BinOpExpr, BinOpType, CastExpr, ColumnRefExpr, ConstantExpr, Expr, ExprList,
FuncExpr, FuncType, JoinType, LikeExpr, LogOpExpr, LogOpType, LogicalAgg, LogicalEmptyRelation,
LogicalFilter, LogicalJoin, LogicalLimit, LogicalProjection, LogicalScan, LogicalSort,
OptRelNode, OptRelNodeRef, OptRelNodeTyp, PlanNode, SortOrderExpr, SortOrderType,
};

use crate::OptdPlanContext;
Expand Down Expand Up @@ -170,21 +166,7 @@ impl OptdPlanContext<'_> {
}
Expr::Cast(x) => {
let expr = self.conv_into_optd_expr(x.expr.as_ref(), context)?;
let data_type = x.data_type.clone();
let val = match data_type {
arrow_schema::DataType::Int8 => Value::Int8(0),
arrow_schema::DataType::Int16 => Value::Int16(0),
arrow_schema::DataType::Int32 => Value::Int32(0),
arrow_schema::DataType::Int64 => Value::Int64(0),
arrow_schema::DataType::UInt8 => Value::UInt8(0),
arrow_schema::DataType::UInt16 => Value::UInt16(0),
arrow_schema::DataType::UInt32 => Value::UInt32(0),
arrow_schema::DataType::UInt64 => Value::UInt64(0),
arrow_schema::DataType::Date32 => Value::Date32(0),
arrow_schema::DataType::Decimal128(_, _) => Value::Decimal128(0),
other => unimplemented!("unimplemented datatype {:?}", other),
};
Ok(CastExpr::new(expr, val).into_expr())
Ok(CastExpr::new(expr, x.data_type.clone()).into_expr())
}
Expr::Like(x) => {
let expr = self.conv_into_optd_expr(x.expr.as_ref(), context)?;
Expand Down
1 change: 1 addition & 0 deletions optd-datafusion-repr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ edition = "2021"

[dependencies]
anyhow = "1"
arrow-schema = "47.0.0"
num-traits = "0.2"
num-derive = "0.2"
tracing = "0.1"
Expand Down
10 changes: 8 additions & 2 deletions optd-datafusion-repr/src/plan_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ mod sort;

use std::sync::Arc;

use arrow_schema::DataType;
use optd_core::{
cascades::{CascadesOptimizer, GroupId},
rel_node::{RelNode, RelNodeRef, RelNodeTyp},
Expand All @@ -24,8 +25,8 @@ pub use apply::{ApplyType, LogicalApply};
pub use empty_relation::{LogicalEmptyRelation, PhysicalEmptyRelation};
pub use expr::{
BetweenExpr, BinOpExpr, BinOpType, CastExpr, ColumnRefExpr, ConstantExpr, ConstantType,
ExprList, FuncExpr, FuncType, LikeExpr, LogOpExpr, LogOpType, SortOrderExpr, SortOrderType,
UnOpExpr, UnOpType,
DataTypeExpr, ExprList, FuncExpr, FuncType, LikeExpr, LogOpExpr, LogOpType, SortOrderExpr,
SortOrderType, UnOpExpr, UnOpType,
};
pub use filter::{LogicalFilter, PhysicalFilter};
pub use join::{JoinType, LogicalJoin, PhysicalHashJoin, PhysicalNestedLoopJoin};
Expand Down Expand Up @@ -77,6 +78,7 @@ pub enum OptRelNodeTyp {
Between,
Cast,
Like,
DataType(DataType),
}

impl OptRelNodeTyp {
Expand Down Expand Up @@ -118,6 +120,7 @@ impl OptRelNodeTyp {
| Self::Between
| Self::Cast
| Self::Like
| Self::DataType(_)
)
}
}
Expand Down Expand Up @@ -387,6 +390,9 @@ pub fn explain(rel_node: OptRelNodeRef) -> Pretty<'static> {
OptRelNodeTyp::Like => LikeExpr::from_rel_node(rel_node)
.unwrap()
.dispatch_explain(),
OptRelNodeTyp::DataType(_) => DataTypeExpr::from_rel_node(rel_node)
.unwrap()
.dispatch_explain(),
}
}

Expand Down
62 changes: 53 additions & 9 deletions optd-datafusion-repr/src/plan_nodes/expr.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{fmt::Display, sync::Arc};

use arrow_schema::DataType;
use itertools::Itertools;
use pretty_xmlish::Pretty;

Expand Down Expand Up @@ -636,19 +637,60 @@ impl OptRelNode for BetweenExpr {
}
}

#[derive(Clone, Debug)]
pub struct DataTypeExpr(pub Expr);

impl DataTypeExpr {
pub fn new(typ: DataType) -> Self {
DataTypeExpr(Expr(
RelNode {
typ: OptRelNodeTyp::DataType(typ),
children: vec![],
data: None,
}
.into(),
))
}

pub fn data_type(&self) -> DataType {
if let OptRelNodeTyp::DataType(data_type) = self.0.typ() {
data_type
} else {
panic!("not a data type")
}
}
}

impl OptRelNode for DataTypeExpr {
fn into_rel_node(self) -> OptRelNodeRef {
self.0.into_rel_node()
}

fn from_rel_node(rel_node: OptRelNodeRef) -> Option<Self> {
if !matches!(rel_node.typ, OptRelNodeTyp::DataType(_)) {
return None;
}
Expr::from_rel_node(rel_node).map(Self)
}

fn dispatch_explain(&self) -> Pretty<'static> {
Pretty::display(&self.data_type().to_string())
}
}

#[derive(Clone, Debug)]
pub struct CastExpr(pub Expr);

impl CastExpr {
pub fn new(
expr: Expr,
cast_to: Value, /* TODO: have a `type` relnode for representing type */
) -> Self {
pub fn new(expr: Expr, cast_to: DataType) -> Self {
CastExpr(Expr(
RelNode {
typ: OptRelNodeTyp::Cast,
children: vec![expr.into_rel_node()],
data: Some(cast_to),
children: vec![
expr.into_rel_node(),
DataTypeExpr::new(cast_to).into_rel_node(),
],
data: None,
}
.into(),
))
Expand All @@ -658,8 +700,10 @@ impl CastExpr {
Expr(self.0.child(0))
}

pub fn cast_to(&self) -> Value {
self.0 .0.data.clone().unwrap()
pub fn cast_to(&self) -> DataType {
DataTypeExpr::from_rel_node(self.0.child(1))
.unwrap()
.data_type()
}
}

Expand All @@ -679,7 +723,7 @@ impl OptRelNode for CastExpr {
Pretty::simple_record(
"Cast",
vec![
("cast_to", format!("{:?}", self.cast_to()).into()),
("cast_to", format!("{}", self.cast_to()).into()),
("expr", self.child().explain()),
],
vec![],
Expand Down
1 change: 1 addition & 0 deletions optd-datafusion-repr/src/properties/column_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ impl PropertyBuilder<OptRelNodeTyp> for ColumnRefPropertyBuilder {
OptRelNodeTyp::Constant(_)
| OptRelNodeTyp::Func(_)
| OptRelNodeTyp::BinOp(_)
| OptRelNodeTyp::DataType(_)
| OptRelNodeTyp::Between
| OptRelNodeTyp::EmptyRelation => {
vec![ColumnRef::Derived]
Expand Down
44 changes: 22 additions & 22 deletions optd-sqlplannertest/tests/tpch.planner.sql
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ LogicalSort
│ └── Mul
│ ├── #22
│ └── Sub
│ ├── Cast { cast_to: Decimal128(0), expr: 1 }
│ ├── Cast { cast_to: Decimal128(20, 0), expr: 1 }
│ └── #23
├── groups: [ #41 ]
└── LogicalFilter
Expand Down Expand Up @@ -159,10 +159,10 @@ LogicalSort
│ │ │ └── "Asia"
│ │ └── Geq
│ │ ├── #12
│ │ └── Cast { cast_to: Date32(0), expr: "2023-01-01" }
│ │ └── Cast { cast_to: Date32, expr: "2023-01-01" }
│ └── Lt
│ ├── #12
│ └── Cast { cast_to: Date32(0), expr: "2024-01-01" }
│ └── Cast { cast_to: Date32, expr: "2024-01-01" }
└── LogicalJoin { join_type: Cross, cond: true }
├── LogicalJoin { join_type: Cross, cond: true }
│ ├── LogicalJoin { join_type: Cross, cond: true }
Expand All @@ -183,7 +183,7 @@ PhysicalSort
│ └── Mul
│ ├── #22
│ └── Sub
│ ├── Cast { cast_to: Decimal128(0), expr: 1 }
│ ├── Cast { cast_to: Decimal128(20, 0), expr: 1 }
│ └── #23
├── groups: [ #41 ]
└── PhysicalFilter
Expand Down Expand Up @@ -218,10 +218,10 @@ PhysicalSort
│ │ │ └── "Asia"
│ │ └── Geq
│ │ ├── #12
│ │ └── Cast { cast_to: Date32(0), expr: "2023-01-01" }
│ │ └── Cast { cast_to: Date32, expr: "2023-01-01" }
│ └── Lt
│ ├── #12
│ └── Cast { cast_to: Date32(0), expr: "2024-01-01" }
│ └── Cast { cast_to: Date32, expr: "2024-01-01" }
└── PhysicalNestedLoopJoin { join_type: Cross, cond: true }
├── PhysicalNestedLoopJoin { join_type: Cross, cond: true }
│ ├── PhysicalNestedLoopJoin { join_type: Cross, cond: true }
Expand Down Expand Up @@ -260,14 +260,14 @@ LogicalProjection { exprs: [ #0 ] }
│ │ ├── And
│ │ │ ├── Geq
│ │ │ │ ├── #10
│ │ │ │ └── Cast { cast_to: Date32(0), expr: "2023-01-01" }
│ │ │ │ └── Cast { cast_to: Date32, expr: "2023-01-01" }
│ │ │ └── Lt
│ │ │ ├── #10
│ │ │ └── Cast { cast_to: Date32(0), expr: "2024-01-01" }
│ │ └── Between { expr: Cast { cast_to: Decimal128(0), expr: #6 }, lower: Cast { cast_to: Decimal128(0), expr: 0.05 }, upper: Cast { cast_to: Decimal128(0), expr: 0.07 } }
│ │ │ └── Cast { cast_to: Date32, expr: "2024-01-01" }
│ │ └── Between { expr: Cast { cast_to: Decimal128(30, 15), expr: #6 }, lower: Cast { cast_to: Decimal128(30, 15), expr: 0.05 }, upper: Cast { cast_to: Decimal128(30, 15), expr: 0.07 } }
│ └── Lt
│ ├── Cast { cast_to: Decimal128(0), expr: #4 }
│ └── Cast { cast_to: Decimal128(0), expr: 24 }
│ ├── Cast { cast_to: Decimal128(22, 2), expr: #4 }
│ └── Cast { cast_to: Decimal128(22, 2), expr: 24 }
└── LogicalScan { table: lineitem }
PhysicalProjection { exprs: [ #0 ] }
└── PhysicalAgg
Expand All @@ -282,14 +282,14 @@ PhysicalProjection { exprs: [ #0 ] }
│ │ ├── And
│ │ │ ├── Geq
│ │ │ │ ├── #10
│ │ │ │ └── Cast { cast_to: Date32(0), expr: "2023-01-01" }
│ │ │ │ └── Cast { cast_to: Date32, expr: "2023-01-01" }
│ │ │ └── Lt
│ │ │ ├── #10
│ │ │ └── Cast { cast_to: Date32(0), expr: "2024-01-01" }
│ │ └── Between { expr: Cast { cast_to: Decimal128(0), expr: #6 }, lower: Cast { cast_to: Decimal128(0), expr: 0.05 }, upper: Cast { cast_to: Decimal128(0), expr: 0.07 } }
│ │ │ └── Cast { cast_to: Date32, expr: "2024-01-01" }
│ │ └── Between { expr: Cast { cast_to: Decimal128(30, 15), expr: #6 }, lower: Cast { cast_to: Decimal128(30, 15), expr: 0.05 }, upper: Cast { cast_to: Decimal128(30, 15), expr: 0.07 } }
│ └── Lt
│ ├── Cast { cast_to: Decimal128(0), expr: #4 }
│ └── Cast { cast_to: Decimal128(0), expr: 24 }
│ ├── Cast { cast_to: Decimal128(22, 2), expr: #4 }
│ └── Cast { cast_to: Decimal128(22, 2), expr: 24 }
└── PhysicalScan { table: lineitem }
*/

Expand Down Expand Up @@ -351,7 +351,7 @@ LogicalSort
│ │ │ ├── #2
│ │ │ └── "IRAQ"
│ │ ├── #1
│ │ └── Cast { cast_to: Decimal128(0), expr: 0 }
│ │ └── Cast { cast_to: Decimal128(38, 4), expr: 0 }
│ └── Agg(Sum)
│ └── [ #1 ]
├── groups: [ #0 ]
Expand All @@ -362,7 +362,7 @@ LogicalSort
│ ├── Mul
│ │ ├── #21
│ │ └── Sub
│ │ ├── Cast { cast_to: Decimal128(0), expr: 1 }
│ │ ├── Cast { cast_to: Decimal128(20, 0), expr: 1 }
│ │ └── #22
│ └── #54
└── LogicalFilter
Expand Down Expand Up @@ -399,7 +399,7 @@ LogicalSort
│ │ │ └── Eq
│ │ │ ├── #12
│ │ │ └── #53
│ │ └── Between { expr: #36, lower: Cast { cast_to: Date32(0), expr: "1995-01-01" }, upper: Cast { cast_to: Date32(0), expr: "1996-12-31" } }
│ │ └── Between { expr: #36, lower: Cast { cast_to: Date32, expr: "1995-01-01" }, upper: Cast { cast_to: Date32, expr: "1996-12-31" } }
│ └── Eq
│ ├── #4
│ └── "ECONOMY ANODIZED STEEL"
Expand Down Expand Up @@ -436,7 +436,7 @@ PhysicalSort
│ │ │ ├── #2
│ │ │ └── "IRAQ"
│ │ ├── #1
│ │ └── Cast { cast_to: Decimal128(0), expr: 0 }
│ │ └── Cast { cast_to: Decimal128(38, 4), expr: 0 }
│ └── Agg(Sum)
│ └── [ #1 ]
├── groups: [ #0 ]
Expand All @@ -447,7 +447,7 @@ PhysicalSort
│ ├── Mul
│ │ ├── #21
│ │ └── Sub
│ │ ├── Cast { cast_to: Decimal128(0), expr: 1 }
│ │ ├── Cast { cast_to: Decimal128(20, 0), expr: 1 }
│ │ └── #22
│ └── #54
└── PhysicalFilter
Expand Down Expand Up @@ -484,7 +484,7 @@ PhysicalSort
│ │ │ └── Eq
│ │ │ ├── #12
│ │ │ └── #53
│ │ └── Between { expr: #36, lower: Cast { cast_to: Date32(0), expr: "1995-01-01" }, upper: Cast { cast_to: Date32(0), expr: "1996-12-31" } }
│ │ └── Between { expr: #36, lower: Cast { cast_to: Date32, expr: "1995-01-01" }, upper: Cast { cast_to: Date32, expr: "1996-12-31" } }
│ └── Eq
│ ├── #4
│ └── "ECONOMY ANODIZED STEEL"
Expand Down

0 comments on commit 0f119f3

Please sign in to comment.