Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP fix: project rel to and from substrait to include pass through columns #133

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@
interval_t interval {};
interval.months = 0;
interval.days = literal.interval_day_to_second().days();
interval.micros = literal.interval_day_to_second().microseconds();

Check warning on line 165 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 165 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 165 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 165 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]
return Value::INTERVAL(interval);
}
default:
Expand Down Expand Up @@ -498,6 +498,12 @@
vector<unique_ptr<ParsedExpression>> expressions;
RootNameIterator iterator(names);

auto input_rel = TransformOp(sop.project().input());
auto num_input_columns = input_rel->Columns().size();
for (int i = 1; i <= num_input_columns; i++) {
expressions.push_back(make_uniq<PositionalReferenceExpression>(i));
}

for (auto &sexpr : sop.project().expressions()) {
expressions.push_back(TransformExpr(sexpr, &iterator));
}
Expand All @@ -506,8 +512,7 @@
for (size_t i = 0; i < expressions.size(); i++) {
mock_aliases.push_back("expr_" + to_string(i));
}
return make_shared_ptr<ProjectionRelation>(TransformOp(sop.project().input()), std::move(expressions),
std::move(mock_aliases));
return make_shared_ptr<ProjectionRelation>(input_rel, std::move(expressions), std::move(mock_aliases));
}

shared_ptr<Relation> SubstraitToDuckDB::TransformAggregateOp(const substrait::Rel &sop) {
Expand All @@ -515,7 +520,7 @@

if (sop.aggregate().groupings_size() > 0) {
for (auto &sgrp : sop.aggregate().groupings()) {
for (auto &sgrpexpr : sgrp.grouping_expressions()) {

Check warning on line 523 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'grouping_expressions' is deprecated [-Wdeprecated-declarations]

Check warning on line 523 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'grouping_expressions' is deprecated [-Wdeprecated-declarations]
groups.push_back(TransformExpr(sgrpexpr));
expressions.push_back(TransformExpr(sgrpexpr));
}
Expand Down Expand Up @@ -615,8 +620,8 @@
scan = rel->Alias(name);
} else if (sget.has_virtual_table()) {
// We need to handle a virtual table as a LogicalExpressionGet
if (!sget.virtual_table().values().empty()) {

Check warning on line 623 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'values' is deprecated [-Wdeprecated-declarations]

Check warning on line 623 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'values' is deprecated [-Wdeprecated-declarations]
auto literal_values = sget.virtual_table().values();

Check warning on line 624 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'values' is deprecated [-Wdeprecated-declarations]

Check warning on line 624 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'values' is deprecated [-Wdeprecated-declarations]
vector<vector<Value>> expression_rows;
for (auto &row : literal_values) {
auto values = row.fields();
Expand Down
41 changes: 38 additions & 3 deletions src/to_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@
} else {
auto interval_day = make_uniq<substrait::Expression_Literal_IntervalDayToSecond>();
interval_day->set_days(dval.GetValue<interval_t>().days);
interval_day->set_microseconds(static_cast<int32_t>(dval.GetValue<interval_t>().micros));

Check warning on line 204 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'set_microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 204 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'set_microseconds' is deprecated [-Wdeprecated-declarations]
sval.set_allocated_interval_day_to_second(interval_day.release());
}
}
Expand Down Expand Up @@ -857,11 +857,46 @@
substrait::Rel *DuckDBToSubstrait::TransformProjection(LogicalOperator &dop) {
auto res = new substrait::Rel();
auto &dproj = dop.Cast<LogicalProjection>();

auto child_column_count = dop.children[0]->types.size();
auto num_passthrough_columns = 0;
auto need_output_mapping = true;
if (child_column_count <= dproj.expressions.size()) {
// check if the projection is just pass through of input columns with no reordering
auto exp_col_idx = 0;
auto is_passthrough = true;
for (auto &dexpr : dproj.expressions) {
if (dexpr->type != ExpressionType::BOUND_REF) {
is_passthrough = false;
break;
}
auto &dref = dexpr.get()->Cast<BoundReferenceExpression>();
if (dref.index != exp_col_idx) {
is_passthrough = false;
break;
}
exp_col_idx++;
}
if (is_passthrough && child_column_count == exp_col_idx) {
// skip the projection
return TransformOp(*dop.children[0]);
}
if (child_column_count == exp_col_idx) {
// all input columns are projected, no need for output mapping
num_passthrough_columns = child_column_count;
need_output_mapping = false;
}
}

// num_passthrough_columns = 0; // TODO remove this
// add remaining columns as expressions. These are other than the input columns in same order
auto sproj = res->mutable_project();
sproj->set_allocated_input(TransformOp(*dop.children[0]));

for (auto &dexpr : dproj.expressions) {
TransformExpr(*dexpr, *sproj->add_expressions());
for (int i = num_passthrough_columns; i < dproj.expressions.size(); i++) {
TransformExpr(*dproj.expressions[i], *sproj->add_expressions());
}
if (need_output_mapping) {
// TODO
}
return res;
}
Expand Down Expand Up @@ -1012,7 +1047,7 @@
// TODO push projection or push substrait to allow expressions here
throw NotImplementedException("No expressions in groupings yet");
}
TransformExpr(*dgrp, *sgrp->add_grouping_expressions());

Check warning on line 1050 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'add_grouping_expressions' is deprecated [-Wdeprecated-declarations]

Check warning on line 1050 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'add_grouping_expressions' is deprecated [-Wdeprecated-declarations]
}
for (auto &dmeas : daggr.expressions) {
auto smeas = saggr->add_measures()->mutable_measure();
Expand Down Expand Up @@ -1280,7 +1315,7 @@
auto virtual_table = sget->mutable_virtual_table();

// Add a dummy value to emit one row
auto dummy_value = virtual_table->add_values();

Check warning on line 1318 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'add_values' is deprecated [-Wdeprecated-declarations]

Check warning on line 1318 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'add_values' is deprecated [-Wdeprecated-declarations]
dummy_value->add_fields()->set_i32(42);
return get_rel;
}
Expand Down
2 changes: 1 addition & 1 deletion test/c/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ include_directories(../../duckdb/src/include)
include_directories(../../duckdb/test/include)
include_directories(../../duckdb/third_party/catch)

set(ALL_SOURCES test_substrait_c_api.cpp)
set(ALL_SOURCES test_substrait_c_api.cpp test_substrait_c_utils.cpp test_projection.cpp)


add_library_unity(test_substrait OBJECT ${ALL_SOURCES})
Expand Down
55 changes: 55 additions & 0 deletions test/c/test_projection.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#include "catch.hpp"
#include "test_helpers.hpp"
#include "duckdb/main/connection_manager.hpp"
#include "test_substrait_c_utils.hpp"

#include <chrono>
#include <thread>
#include <iostream>

using namespace duckdb;
using namespace std;

TEST_CASE("Test C Project input columns with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers(i INTEGER)"));
REQUIRE_NO_FAIL(con.Query("INSERT INTO integers VALUES (10), (20), (30)"));
CreateEmployeeTable(con);

auto expected_json_str = R"({"relations":[{"root":{"input":{"read":{"baseSchema":{"names":["i"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{}]},"maintainSingularStruct":true},"namedTable":{"names":["integers"]}}},"names":["i"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
auto json_str = con.GetSubstraitJSON("SELECT i FROM integers");
REQUIRE(json_str == expected_json_str);
auto result = con.FromSubstraitJSON(json_str);
REQUIRE(CHECK_COLUMN(result, 0, {10, 20, 30}));
}

TEST_CASE("Test C Project 1 input column 1 transformation with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers(i INTEGER)"));
REQUIRE_NO_FAIL(con.Query("INSERT INTO integers VALUES (10), (20), (30)"));
CreateEmployeeTable(con);

auto expected_json_str = R"({"extensionUris":[{"extensionUriAnchor":1,"uri":"https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"}],"extensions":[{"extensionFunction":{"extensionUriReference":1,"functionAnchor":1,"name":"multiply:i32_i32"}}],"relations":[{"root":{"input":{"project":{"input":{"read":{"baseSchema":{"names":["i"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{}]},"maintainSingularStruct":true},"namedTable":{"names":["integers"]}}},"expressions":[{"scalarFunction":{"functionReference":1,"outputType":{"i32":{"nullability":"NULLABILITY_NULLABLE"}},"arguments":[{"value":{"selection":{"directReference":{"structField":{}},"rootReference":{}}}},{"value":{"selection":{"directReference":{"structField":{}},"rootReference":{}}}}]}}]}},"names":["i","isquare"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
auto json_str = con.GetSubstraitJSON("SELECT i, i *i as isquare FROM integers");
REQUIRE(json_str == expected_json_str);
auto result = con.FromSubstraitJSON(json_str);
REQUIRE(CHECK_COLUMN(result, 0, {10, 20, 30}));
REQUIRE(CHECK_COLUMN(result, 1, {100, 400, 900}));
}

TEST_CASE("Test C Project two columns with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);

// TODO should this have any projection?
auto json_str = con.GetSubstraitJSON("SELECT name, salary FROM employees");
// REQUIRE(json_str == expected_json_str);
auto result = con.FromSubstraitJSON(json_str);
REQUIRE(CHECK_COLUMN(result, 0, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
}
26 changes: 1 addition & 25 deletions test/c/test_substrait_c_api.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "catch.hpp"
#include "test_helpers.hpp"
#include "duckdb/main/connection_manager.hpp"
#include "test_substrait_c_utils.hpp"

#include <chrono>
#include <thread>
Expand Down Expand Up @@ -47,31 +48,6 @@ TEST_CASE("Test C Get and To Json-Substrait API", "[substrait-api]") {
REQUIRE_THROWS(con.FromSubstraitJSON("this is not valid"));
}

duckdb::unique_ptr<QueryResult> ExecuteViaSubstrait(Connection &con, const string &sql) {
auto proto = con.GetSubstrait(sql);
return con.FromSubstrait(proto);
}

duckdb::unique_ptr<QueryResult> ExecuteViaSubstraitJSON(Connection &con, const string &sql) {
auto json_str = con.GetSubstraitJSON(sql);
return con.FromSubstraitJSON(json_str);
}

void CreateEmployeeTable(Connection& con) {
REQUIRE_NO_FAIL(con.Query("CREATE TABLE employees ("
"employee_id INTEGER PRIMARY KEY, "
"name VARCHAR(100), "
"department_id INTEGER, "
"salary DECIMAL(10, 2))"));

REQUIRE_NO_FAIL(con.Query("INSERT INTO employees VALUES "
"(1, 'John Doe', 1, 120000), "
"(2, 'Jane Smith', 2, 80000), "
"(3, 'Alice Johnson', 1, 50000), "
"(4, 'Bob Brown', 3, 95000), "
"(5, 'Charlie Black', 2, 60000)"));
}

void CreatePartTimeEmployeeTable(Connection& con) {
REQUIRE_NO_FAIL(con.Query("CREATE TABLE part_time_employees ("
"id INTEGER PRIMARY KEY, "
Expand Down
31 changes: 31 additions & 0 deletions test/c/test_substrait_c_utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#include "test_helpers.hpp"
#include "test_substrait_c_utils.hpp"

using namespace duckdb;
using namespace std;


duckdb::unique_ptr<QueryResult> ExecuteViaSubstrait(Connection &con, const string &sql) {
auto proto = con.GetSubstrait(sql);
return con.FromSubstrait(proto);
}

duckdb::unique_ptr<QueryResult> ExecuteViaSubstraitJSON(Connection &con, const string &sql) {
auto json_str = con.GetSubstraitJSON(sql);
return con.FromSubstraitJSON(json_str);
}

void CreateEmployeeTable(Connection &con) {
REQUIRE_NO_FAIL(con.Query("CREATE TABLE employees ("
"employee_id INTEGER PRIMARY KEY, "
"name VARCHAR(100), "
"department_id INTEGER, "
"salary DECIMAL(10, 2))"));

REQUIRE_NO_FAIL(con.Query("INSERT INTO employees VALUES "
"(1, 'John Doe', 1, 120000), "
"(2, 'Jane Smith', 2, 80000), "
"(3, 'Alice Johnson', 1, 50000), "
"(4, 'Bob Brown', 3, 95000), "
"(5, 'Charlie Black', 2, 60000)"));
}
13 changes: 13 additions & 0 deletions test/c/test_substrait_c_utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#ifndef TEST_SUBSTRAIT_C_UTILS_HPP
#define TEST_SUBSTRAIT_C_UTILS_HPP

#include "duckdb.hpp"
#include "duckdb/main/connection_manager.hpp"

using namespace duckdb;
void CreateEmployeeTable(Connection& con);

duckdb::unique_ptr<QueryResult> ExecuteViaSubstraitJSON(Connection &con, const std::string &query);
duckdb::unique_ptr<QueryResult> ExecuteViaSubstrait(Connection &con, const std::string &query);

#endif
Loading