Skip to content

Commit

Permalink
[SNOW-980540] Fix for sql simplifier that incorrectly pushed filter d…
Browse files Browse the repository at this point in the history
…own during flatten with window function (#1183)

* fix sql simplifier

* fix error

* update change log

* fix change log
  • Loading branch information
sfc-gh-yzou authored Dec 15, 2023
1 parent d27ebcd commit 5f208fd
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 6 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Release History

## 1.12.0 (TBD)

### Bug Fixes
- Fixed a bug where the SQL simplifier incorrectly pushed a filter down during flattening when the select contains window function columns.

## 1.11.1 (2023-12-07)

### Bug Fixes
Expand Down
6 changes: 2 additions & 4 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ def __init__(
self.pre_actions: Optional[List["Query"]] = None
self.post_actions: Optional[List["Query"]] = None
self.flatten_disabled: bool = False
self.has_data_generator_exp: bool = False
self._column_states: Optional[ColumnStateDict] = None
self._snowflake_plan: Optional[SnowflakePlan] = None
self.expr_to_alias = {}
Expand Down Expand Up @@ -670,7 +669,6 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
new.from_ = self.from_.to_subqueryable()
new.pre_actions = new.from_.pre_actions
new.post_actions = new.from_.post_actions
new.has_data_generator_exp = has_data_generator_exp(cols)
else:
new = SelectStatement(
projection=cols, from_=self.to_subqueryable(), analyzer=self.analyzer
Expand All @@ -691,7 +689,7 @@ def filter(self, col: Expression) -> "SelectStatement":
and can_clause_dependent_columns_flatten(
derive_dependent_columns(col), self.column_states
)
and not self.has_data_generator_exp
and not has_data_generator_exp(self.projection)
)
if can_be_flattened:
new = copy(self)
Expand All @@ -713,7 +711,7 @@ def sort(self, cols: List[Expression]) -> "SelectStatement":
and can_clause_dependent_columns_flatten(
derive_dependent_columns(*cols), self.column_states
)
and not self.has_data_generator_exp
and not has_data_generator_exp(self.projection)
)
if can_be_flattened:
new = copy(self)
Expand Down
5 changes: 4 additions & 1 deletion src/snowflake/snowpark/mock/_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@
Sample,
)
from snowflake.snowpark._internal.type_utils import infer_type
from snowflake.snowpark._internal.utils import generate_random_alphanumeric, parse_table_name
from snowflake.snowpark._internal.utils import (
generate_random_alphanumeric,
parse_table_name,
)
from snowflake.snowpark.column import Column
from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark.mock._functions import _MOCK_FUNCTION_IMPLEMENTATION_MAP
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/test_pandas_to_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import pytest

from snowflake.snowpark.types import LongType, TimestampTimeZone, TimestampType
from snowflake.snowpark.types import TimestampTimeZone, TimestampType

try:
from pandas import DataFrame as PandasDF, to_datetime
Expand Down
22 changes: 22 additions & 0 deletions tests/integ/test_simplifier_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from snowflake.snowpark.functions import (
avg,
col,
iff,
lit,
min as min_,
sql_expr,
sum as sum_,
table_function,
Expand Down Expand Up @@ -1211,3 +1213,23 @@ def test_select_after_orderby(session, operation, simplified_query, execute_sql)
assert operation(df2).queries["queries"][0] == simplified_query
if execute_sql:
Utils.check_answer(operation(df1), operation(df2))


def test_window_with_filter(session):
    """Check that filtering after a window-function column gives the same
    answer with the SQL simplifier disabled and enabled.

    Two DataFrames are built from identical data, one created while
    ``session.sql_simplifier_enabled`` is ``False`` and one while it is
    ``True``. Both get a derived column ``B``, a windowed ``min(B)`` column
    ``C``, and then a filter on ``A``; their results are compared.

    The session's ``sql_simplifier_enabled`` value is restored on exit, even
    if the test fails, so the toggle does not leak to later users of the
    same session object.
    """
    previous_setting = session.sql_simplifier_enabled
    try:
        # NOTE(review): presumably the simplifier setting is captured when
        # the DataFrame is created, which is why only creation happens while
        # the flag is False -- confirm against Session.create_dataframe.
        session.sql_simplifier_enabled = False
        df1 = session.create_dataframe([[0], [1]], schema=["A"])

        session.sql_simplifier_enabled = True
        df2 = session.create_dataframe([[0], [1]], schema=["A"])

        df1 = (
            df1.with_column("B", iff(df1.A == 0, 10, 11))
            .with_column("C", min_("B").over())
            .filter(df1.A == 1)
        )
        df2 = (
            df2.with_column("B", iff(df2.A == 0, 10, 11))
            .with_column("C", min_("B").over())
            .filter(df2.A == 1)
        )
        # The flag is still True here, as in the original test; it is only
        # reset in the ``finally`` clause after the comparison has run.
        Utils.check_answer(df1, df2, sort=False)
    finally:
        session.sql_simplifier_enabled = previous_setting

0 comments on commit 5f208fd

Please sign in to comment.