Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support order by using the projection columns #1136

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions evadb/binder/binder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,19 @@ def drop_row_id_from_target_list(
continue
filtered_list.append(expr)
return filtered_list


def add_func_expr_outputs_to_binder_context(
    func_expr: FunctionExpression, binder_context: StatementBinderContext
):
    """Register a function expression's outputs as a derived table alias.

    Pairs each entry of ``func_expr.output_objs`` with the column name at
    the same position in ``func_expr.alias.col_names`` and wraps every pair
    in a ``TupleValueExpression`` whose ``col_alias`` is
    ``"<alias_name>.<column name>"``. The resulting list is then added to
    ``binder_context`` under ``func_expr.alias.alias_name``.

    Args:
        func_expr: function expression providing ``output_objs`` and an
            ``alias`` with ``alias_name`` and ``col_names``.
        binder_context: context receiving the derived table alias through
            ``add_derived_table_alias``.
    """
    table_alias = func_expr.alias.alias_name
    # zip() stops at the shorter input, so unmatched outputs or column
    # names are left out of the derived table.
    derived_columns = [
        TupleValueExpression(
            name=col_name,
            table_alias=table_alias,
            col_object=output_obj,
            col_alias="{}.{}".format(table_alias, col_name),
        )
        for output_obj, col_name in zip(
            func_expr.output_objs, func_expr.alias.col_names
        )
    ]
    binder_context.add_derived_table_alias(table_alias, derived_columns)
18 changes: 5 additions & 13 deletions evadb/binder/statement_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from evadb.binder.binder_utils import (
BinderError,
add_func_expr_outputs_to_binder_context,
bind_table_info,
check_column_name_is_string,
check_groupby_pattern,
Expand Down Expand Up @@ -199,6 +200,9 @@ def _bind_select_statement(self, node: SelectStatement):
node.target_list = extend_star(self._binder_context)
for expr in node.target_list:
self.bind(expr)
if isinstance(expr, FunctionExpression):
add_func_expr_outputs_to_binder_context(expr, self._binder_context)
xzdandy marked this conversation as resolved.
Show resolved Hide resolved

if node.groupby_clause:
self.bind(node.groupby_clause)
check_table_object_is_groupable(node.from_table)
Expand Down Expand Up @@ -275,19 +279,7 @@ def _bind_tableref(self, node: TableRef):
func_expr = node.table_valued_expr.func_expr
func_expr.alias = node.alias
self.bind(func_expr)
output_cols = []
for obj, alias in zip(func_expr.output_objs, func_expr.alias.col_names):
col_alias = "{}.{}".format(func_expr.alias.alias_name, alias)
alias_obj = TupleValueExpression(
name=alias,
table_alias=func_expr.alias.alias_name,
col_object=obj,
col_alias=col_alias,
)
output_cols.append(alias_obj)
self._binder_context.add_derived_table_alias(
func_expr.alias.alias_name, output_cols
)
add_func_expr_outputs_to_binder_context(func_expr, self._binder_context)
else:
raise BinderError(f"Unsupported node {type(node)}")

Expand Down
22 changes: 12 additions & 10 deletions evadb/optimizer/statement_to_opr_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ def visit_select(self, statement: SelectStatement):
statement {SelectStatement} - - [input select statement]
"""

# order of evaluation
# from, where, group by, select, order by, limit, union
table_ref = statement.from_table
if table_ref is not None:
self.visit_table_ref(table_ref)
Expand All @@ -133,22 +135,22 @@ def visit_select(self, statement: SelectStatement):
if statement.groupby_clause is not None:
self._visit_groupby(statement.groupby_clause)

if statement.orderby_list is not None:
self._visit_orderby(statement.orderby_list)

if statement.limit_count is not None:
self._visit_limit(statement.limit_count)

# union
if statement.union_link is not None:
self._visit_union(statement.union_link, statement.union_all)

# Projection operator
select_columns = statement.target_list

if select_columns is not None:
self._visit_projection(select_columns)

if statement.orderby_list is not None:
self._visit_orderby(statement.orderby_list)

if statement.limit_count is not None:
self._visit_limit(statement.limit_count)

# union
if statement.union_link is not None:
self._visit_union(statement.union_link, statement.union_all)

def _visit_sample(self, sample_freq, sample_type):
sample_opr = LogicalSample(sample_freq, sample_type)
sample_opr.append_child(self._plan)
Expand Down
2 changes: 1 addition & 1 deletion test/integration_tests/long/test_model_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_forecast(self):
execute_query_fetch_all(self.evadb, create_predict_udf)

predict_query = """
SELECT AirForecast(12);
SELECT AirForecast(12) order by y;
"""
result = execute_query_fetch_all(self.evadb, predict_query)
self.assertEqual(len(result), 12)
Expand Down
2 changes: 1 addition & 1 deletion test/integration_tests/long/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def test_should_run_pytorch_and_facenet(self):
execute_query_fetch_all(self.evadb, create_function_query)

select_query = """SELECT FaceDetector(data) FROM MyVideo
WHERE id < 5;"""
WHERE id < 5 order by scores;"""
actual_batch = execute_query_fetch_all(self.evadb, select_query)
self.assertEqual(len(actual_batch), 5)

Expand Down