Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(optimizer): consider impure expression in predicate push down of Project #9133

Merged
merged 17 commits into from
Apr 18, 2023
3 changes: 0 additions & 3 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@ option java_package = "com.risingwave.proto";
option optimize_for = SPEED;

message ExprNode {
// a "pure function" will be defined as having `1 < expr_node as i32 <= 600`.
// Please modify this definition if adding a pure function that does not belong
// to this range.
enum Type {
UNSPECIFIED = 0;
INPUT_REF = 1;
Expand Down
5 changes: 0 additions & 5 deletions src/frontend/src/expr/function_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,11 +267,6 @@ impl FunctionCall {
self.func_type
}

/// Refer to [`ExprType`] for details.
///
/// A function call counts as pure when the numeric code of its function type
/// lies in the inclusive range `1..=600`.
pub fn is_pure(&self) -> bool {
    let type_code = self.func_type as i32;
    (1..=600).contains(&type_code)
}

/// Get a reference to the function call's inputs.
pub fn inputs(&self) -> &[ExprImpl] {
self.inputs.as_ref()
Expand Down
11 changes: 11 additions & 0 deletions src/frontend/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ mod function_call;
mod input_ref;
mod literal;
mod parameter;
mod pure;
mod subquery;
mod table_function;
mod user_defined_function;
Expand All @@ -53,6 +54,7 @@ pub use function_call::{is_row_function, FunctionCall, FunctionCallDisplay};
pub use input_ref::{input_ref_to_column_indices, InputRef, InputRefDisplay};
pub use literal::Literal;
pub use parameter::Parameter;
pub use pure::*;
pub use risingwave_pb::expr::expr_node::Type as ExprType;
pub use session_timezone::SessionTimezone;
pub use subquery::{Subquery, SubqueryKind};
Expand Down Expand Up @@ -170,6 +172,15 @@ impl ExprImpl {
visitor.into()
}

/// Check if the expression has no side effects and its output is deterministic.
///
/// Thin convenience wrapper: forwards `self` to the free function `is_pure`
/// that this module re-exports from the `pure` submodule.
pub fn is_pure(&self) -> bool {
is_pure(self)
}

/// Check if the expression is impure.
///
/// Thin convenience wrapper: forwards `self` to the free function `is_impure`
/// that this module re-exports from the `pure` submodule.
pub fn is_impure(&self) -> bool {
is_impure(self)
}

/// Count `Now`s in the expression.
pub fn count_nows(&self) -> usize {
let mut visitor = CountNow::default();
Expand Down
163 changes: 163 additions & 0 deletions src/frontend/src/expr/pure.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_pb::expr::expr_node;

use super::{ExprImpl, ExprVisitor};
/// Stateless expression visitor whose visit result is `true` when the visited
/// expression is considered impure.
struct ImpureAnalyzer;

impl ExprVisitor<bool> for ImpureAnalyzer {
    /// Combines the results of two sub-expressions: the combination is impure
    /// as soon as either side is impure.
    fn merge(a: bool, b: bool) -> bool {
        // the expr will be impure if any of its input is impure
        a || b
    }

    /// User defined functions are always treated as impure.
    fn visit_user_defined_function(&mut self, _func_call: &super::UserDefinedFunction) -> bool {
        true
    }

    /// Classifies a function call.
    ///
    /// The call is impure when its own function type is non-deterministic, or
    /// when any of its inputs is impure. The `match` is intentionally
    /// exhaustive (no `_` arm) so that adding a new expression type forces an
    /// explicit purity decision here.
    fn visit_function_call(&mut self, func_call: &super::FunctionCall) -> bool {
        match func_call.get_expr_type() {
            expr_node::Type::Unspecified
            | expr_node::Type::InputRef
            | expr_node::Type::ConstantValue
            | expr_node::Type::Add
            | expr_node::Type::Subtract
            | expr_node::Type::Multiply
            | expr_node::Type::Divide
            | expr_node::Type::Modulus
            | expr_node::Type::Equal
            | expr_node::Type::NotEqual
            | expr_node::Type::LessThan
            | expr_node::Type::LessThanOrEqual
            | expr_node::Type::GreaterThan
            | expr_node::Type::GreaterThanOrEqual
            | expr_node::Type::And
            | expr_node::Type::Or
            | expr_node::Type::Not
            | expr_node::Type::In
            | expr_node::Type::Some
            | expr_node::Type::All
            | expr_node::Type::BitwiseAnd
            | expr_node::Type::BitwiseOr
            | expr_node::Type::BitwiseXor
            | expr_node::Type::BitwiseNot
            | expr_node::Type::BitwiseShiftLeft
            | expr_node::Type::BitwiseShiftRight
            | expr_node::Type::Extract
            | expr_node::Type::DatePart
            | expr_node::Type::TumbleStart
            | expr_node::Type::ToTimestamp
            | expr_node::Type::AtTimeZone
            | expr_node::Type::DateTrunc
            | expr_node::Type::ToTimestamp1
            | expr_node::Type::CastWithTimeZone
            | expr_node::Type::Cast
            | expr_node::Type::Substr
            | expr_node::Type::Length
            | expr_node::Type::Like
            | expr_node::Type::Upper
            | expr_node::Type::Lower
            | expr_node::Type::Trim
            | expr_node::Type::Replace
            | expr_node::Type::Position
            | expr_node::Type::Ltrim
            | expr_node::Type::Rtrim
            | expr_node::Type::Case
            | expr_node::Type::RoundDigit
            | expr_node::Type::Round
            | expr_node::Type::Ascii
            | expr_node::Type::Translate
            | expr_node::Type::Coalesce
            | expr_node::Type::ConcatWs
            | expr_node::Type::Abs
            | expr_node::Type::SplitPart
            | expr_node::Type::Ceil
            | expr_node::Type::Floor
            | expr_node::Type::ToChar
            | expr_node::Type::Md5
            | expr_node::Type::CharLength
            | expr_node::Type::Repeat
            | expr_node::Type::ConcatOp
            | expr_node::Type::BoolOut
            | expr_node::Type::OctetLength
            | expr_node::Type::BitLength
            | expr_node::Type::Overlay
            | expr_node::Type::RegexpMatch
            | expr_node::Type::Pow
            | expr_node::Type::Exp
            | expr_node::Type::Chr
            | expr_node::Type::StartsWith
            | expr_node::Type::Initcap
            | expr_node::Type::Lpad
            | expr_node::Type::Rpad
            | expr_node::Type::Reverse
            | expr_node::Type::Strpos
            | expr_node::Type::ToAscii
            | expr_node::Type::ToHex
            | expr_node::Type::QuoteIdent
            | expr_node::Type::Sin
            | expr_node::Type::Cos
            | expr_node::Type::Tan
            | expr_node::Type::Cot
            | expr_node::Type::Asin
            | expr_node::Type::Acos
            | expr_node::Type::Atan
            | expr_node::Type::Atan2
            | expr_node::Type::Sqrt
            | expr_node::Type::Degrees
            | expr_node::Type::Radians
            | expr_node::Type::IsTrue
            | expr_node::Type::IsNotTrue
            | expr_node::Type::IsFalse
            | expr_node::Type::IsNotFalse
            | expr_node::Type::IsNull
            | expr_node::Type::IsNotNull
            | expr_node::Type::IsDistinctFrom
            | expr_node::Type::IsNotDistinctFrom
            | expr_node::Type::Neg
            | expr_node::Type::Field
            | expr_node::Type::Array
            | expr_node::Type::ArrayAccess
            | expr_node::Type::Row
            | expr_node::Type::ArrayToString
            | expr_node::Type::ArrayCat
            | expr_node::Type::ArrayAppend
            | expr_node::Type::ArrayPrepend
            | expr_node::Type::FormatType
            | expr_node::Type::ArrayDistinct
            | expr_node::Type::ArrayLength
            | expr_node::Type::Cardinality
            | expr_node::Type::ArrayRemove
            | expr_node::Type::HexToInt256
            | expr_node::Type::JsonbAccessInner
            | expr_node::Type::JsonbAccessStr
            | expr_node::Type::JsonbTypeof
            | expr_node::Type::JsonbArrayLength
            | expr_node::Type::Pi => {
                // The function itself is deterministic, so the call is impure
                // only if one of its inputs is. This override is the only code
                // that runs for a function call, so it has to walk the inputs
                // itself; otherwise `Add(x, Now())` would be reported as pure.
                // `any` is the boolean-OR fold that `merge` describes.
                func_call
                    .inputs()
                    .iter()
                    .any(|input| self.visit_expr(input))
            }
            // The output of these functions is not a function of their inputs
            // alone, so the call is impure regardless of what it is applied to.
            expr_node::Type::Vnode
            | expr_node::Type::Now
            | expr_node::Type::Proctime
            | expr_node::Type::Udf => true,
        }
    }
}

/// Returns `true` when [`is_impure`] does not flag `expr`, i.e. this is the
/// exact negation of [`is_impure`].
pub fn is_pure(expr: &ExprImpl) -> bool {
    let impure = is_impure(expr);
    !impure
}
pub fn is_impure(expr: &ExprImpl) -> bool {
let mut a = ImpureAnalyzer {};
a.visit_expr(expr)
}
13 changes: 1 addition & 12 deletions src/frontend/src/optimizer/plan_node/logical_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -649,17 +649,6 @@ impl ExprRewritable for LogicalJoin {
}
}

/// Returns `true` when `expr` consists only of literals, input references and
/// function calls that are pure at every level of nesting.
///
/// Any other kind of expression makes the whole expression count as not pure.
fn is_pure_fn_except_for_input_ref(expr: &ExprImpl) -> bool {
    match expr {
        // Leaves: constants and column references never introduce impurity.
        ExprImpl::Literal(_) | ExprImpl::InputRef(_) => true,
        // A call is pure only if the function and all of its inputs are.
        ExprImpl::FunctionCall(inner) => {
            inner.is_pure() && inner.inputs().iter().all(is_pure_fn_except_for_input_ref)
        }
        _ => false,
    }
}

/// We are trying to derive a predicate to apply to the other side of a join if all
/// the `InputRef`s in the predicate are eq condition columns, and can hence be substituted
/// with the corresponding eq condition columns of the other side.
Expand All @@ -683,7 +672,7 @@ fn derive_predicate_from_eq_condition(
col_num: usize,
expr_is_left: bool,
) -> Option<ExprImpl> {
if !is_pure_fn_except_for_input_ref(expr) {
if !expr.is_pure() {
return None;
}
let eq_indices = if expr_is_left {
Expand Down
14 changes: 12 additions & 2 deletions src/frontend/src/optimizer/plan_node/logical_project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,19 @@ impl PredicatePushdown for LogicalProject {
let mut subst = Substitute {
mapping: self.exprs().clone(),
};
let predicate = predicate.rewrite_expr(&mut subst);

gen_filter_and_pushdown(self, Condition::true_cond(), predicate, ctx)
let impure_mask = {
let mut impure_mask = FixedBitSet::with_capacity(self.exprs().len());
for (i, e) in self.exprs().iter().enumerate() {
impure_mask.set(i, e.is_impure())
}
impure_mask
};
// (with impure input, with pure input)
let (remained_cond, pushed_cond) = predicate.split_disjoint(&impure_mask);
let pushed_cond = pushed_cond.rewrite_expr(&mut subst);

gen_filter_and_pushdown(self, remained_cond, pushed_cond, ctx)
}
}

Expand Down