diff --git a/src/from_substrait.cpp b/src/from_substrait.cpp
index f995233..dfb6a63 100644
--- a/src/from_substrait.cpp
+++ b/src/from_substrait.cpp
@@ -37,6 +37,8 @@
 #include "duckdb/main/relation/projection_relation.hpp"
 #include "duckdb/main/relation/setop_relation.hpp"
 
+#include
+
 namespace duckdb {
 const std::unordered_map SubstraitToDuckDB::function_names_remap = {
     {"modulus", "mod"}, {"std_dev", "stddev"}, {"starts_with", "prefix"},
@@ -492,22 +494,74 @@ shared_ptr SubstraitToDuckDB::TransformFilterOp(const substrait::Rel &
 	return make_shared_ptr(TransformOp(sfilter.input()), TransformExpr(sfilter.condition()));
 }
 
+//! Returns the emit output mapping of a join or project relation. When the relation carries
+//! no emit, a shared empty mapping is returned instead.
+const google::protobuf::RepeatedField<int32_t> &GetOutputMapping(const substrait::Rel &sop) {
+	const substrait::RelCommon *common = nullptr;
+	switch (sop.rel_type_case()) {
+	case substrait::Rel::RelTypeCase::kJoin:
+		common = &sop.join().common();
+		break;
+	case substrait::Rel::RelTypeCase::kProject:
+		common = &sop.project().common();
+		break;
+	default:
+		throw InternalException("Unsupported relation type " + to_string(sop.rel_type_case()));
+	}
+	if (!common->has_emit()) {
+		static const google::protobuf::RepeatedField<int32_t> empty_mapping;
+		return empty_mapping;
+	}
+	return common->emit().output_mapping();
+}
+
 shared_ptr SubstraitToDuckDB::TransformProjectOp(const substrait::Rel &sop,
                                                  const google::protobuf::RepeatedPtrField *names) {
 	vector> expressions;
 	RootNameIterator iterator(names);
-	for (auto &sexpr : sop.project().expressions()) {
-		expressions.push_back(TransformExpr(sexpr, &iterator));
+	auto input_rel = TransformOp(sop.project().input());
+
+	// GetOutputMapping returns a reference, so bind to it rather than copying the mapping.
+	const auto &mapping = GetOutputMapping(sop);
+	const idx_t num_input_columns = input_rel->Columns().size();
+	if (mapping.empty()) {
+		// No output mapping: emit every input column followed by every project expression.
+		for (idx_t i = 1; i <= num_input_columns; i++) {
+			expressions.push_back(make_uniq<PositionalReferenceExpression>(i));
+		}
+
+		for (auto &sexpr : sop.project().expressions()) {
+			expressions.push_back(TransformExpr(sexpr, &iterator));
+		}
+	} else {
+		// Mapping entries index the input columns first and the project expressions after them.
+		// The mapping comes from the serialized plan, so reject indexes outside that range
+		// instead of reading past the expression list.
+		const idx_t num_candidates = num_input_columns + static_cast<idx_t>(sop.project().expressions_size());
+		expressions.reserve(mapping.size());
+		for (const auto source_idx : mapping) {
+			if (source_idx < 0 || static_cast<idx_t>(source_idx) >= num_candidates) {
+				throw InvalidInputException("Project output mapping index " + to_string(source_idx) +
+				                            " is out of range");
+			}
+			const auto column_idx = static_cast<idx_t>(source_idx);
+			if (column_idx < num_input_columns) {
+				// Input columns are referenced by position, offset by one as in the branch above.
+				expressions.push_back(make_uniq<PositionalReferenceExpression>(column_idx + 1));
+			} else {
+				const auto expression_idx = static_cast<int>(column_idx - num_input_columns);
+				expressions.push_back(TransformExpr(sop.project().expressions(expression_idx), &iterator));
+			}
+		}
 	}
 
 	vector mock_aliases;
 	for (size_t i = 0; i < expressions.size(); i++) {
 		mock_aliases.push_back("expr_" + to_string(i));
 	}
-	return make_shared_ptr(TransformOp(sop.project().input()), std::move(expressions),
-	                       std::move(mock_aliases));
+	return make_shared_ptr<ProjectionRelation>(input_rel, std::move(expressions), std::move(mock_aliases));
 }
 
 shared_ptr SubstraitToDuckDB::TransformAggregateOp(const substrait::Rel &sop) {
diff --git a/src/include/to_substrait.hpp b/src/include/to_substrait.hpp
index 7466395..fdbe03a 100644
--- a/src/include/to_substrait.hpp
+++ b/src/include/to_substrait.hpp
@@ -73,6 +73,8 @@ class DuckDBToSubstrait {
 	substrait::Rel *TransformInsertTable(LogicalOperator &dop);
 	substrait::Rel *TransformDeleteTable(LogicalOperator &dop);
 	static substrait::Rel *TransformDummyScan();
+	//! Creates a RelCommon whose emit output mapping lists the given column indexes
+	substrait::RelCommon *CreateOutputMapping(vector<int32_t> vector);
 	//! Methods to transform different LogicalGet Types (e.g., Table, Parquet)
 	//! To Substrait;
 	void TransformTableScanToSubstrait(LogicalGet &dget, substrait::ReadRel *sget) const;
@@ -135,6 +137,7 @@ class DuckDBToSubstrait {
 
 	static std::string &RemapFunctionName(std::string &function_name);
 	static bool IsExtractFunction(const string &function_name);
+
 	//! Creates a Conjunction
 	template
 	substrait::Expression *CreateConjunction(T &source, const FUNC f) {
diff --git a/src/to_substrait.cpp b/src/to_substrait.cpp
index 6ff0b0a..263916c 100644
--- a/src/to_substrait.cpp
+++ b/src/to_substrait.cpp
@@ -19,6 +19,8 @@
 #include "duckdb/parser/constraints/not_null_constraint.hpp"
 #include "duckdb/execution/index/art/art_key.hpp"
 
+#include
+
 namespace duckdb {
 const std::unordered_map DuckDBToSubstrait::function_names_remap = {
     {"mod", "modulus"},
@@ -854,14 +856,62 @@ substrait::Rel *DuckDBToSubstrait::TransformFilter(LogicalOperator &dop) {
 	return res;
 }
 
+//! Builds a RelCommon whose emit output mapping holds the given column indexes, in order.
+//! The result is allocated with new and returned as a raw pointer; the callers in this file
+//! hand it to set_allocated_common.
+substrait::RelCommon *DuckDBToSubstrait::CreateOutputMapping(vector<int32_t> vector) {
+	auto rel_common = new substrait::RelCommon();
+	auto output_mapping = rel_common->mutable_emit()->mutable_output_mapping();
+	output_mapping->Reserve(static_cast<int>(vector.size()));
+	for (auto &col_idx : vector) {
+		output_mapping->Add(col_idx);
+	}
+	return rel_common;
+}
+
+//! Transforms a projection into a Substrait project relation, or returns the transformed
+//! child directly when the projection only forwards every child column in order.
 substrait::Rel *DuckDBToSubstrait::TransformProjection(LogicalOperator &dop) {
-	auto res = new substrait::Rel();
 	auto &dproj = dop.Cast();
+	const auto child_column_count = dop.children[0]->types.size();
+
+	// Classify every output column before emitting anything: a bound reference maps to the
+	// child column it reads, any other expression maps to its slot after the child columns.
+	vector<int32_t> output_mapping;
+	idx_t num_transformations = 0;
+	bool is_identity = true;
+	for (auto &dexpr : dproj.expressions) {
+		idx_t source_idx;
+		if (dexpr->type == ExpressionType::BOUND_REF) {
+			source_idx = dexpr->Cast<BoundReferenceExpression>().index;
+		} else {
+			source_idx = child_column_count + num_transformations;
+			num_transformations++;
+		}
+		is_identity = is_identity && source_idx == output_mapping.size();
+		output_mapping.push_back(static_cast<int32_t>(source_idx));
+	}
+	// Without an emit a project relation outputs all child columns followed by all expressions,
+	// so the mapping is only redundant when it is exactly that sequence.
+	is_identity = is_identity && output_mapping.size() == child_column_count + num_transformations;
+	if (is_identity && num_transformations == 0) {
+		// Pure in-order passthrough of every child column: skip the projection.
+		return TransformOp(*dop.children[0]);
+	}
+
+	// Allocate only after the early return above so the skipped case does not leak a Rel.
+	auto res = new substrait::Rel();
 	auto sproj = res->mutable_project();
 	sproj->set_allocated_input(TransformOp(*dop.children[0]));
-
+	// Bound references are expressed through the output mapping; only the remaining
+	// expressions become project expressions.
 	for (auto &dexpr : dproj.expressions) {
-		TransformExpr(*dexpr, *sproj->add_expressions());
+		if (dexpr->type != ExpressionType::BOUND_REF) {
+			TransformExpr(*dexpr, *sproj->add_expressions());
+		}
 	}
+	if (!is_identity) {
+		sproj->set_allocated_common(CreateOutputMapping(output_mapping));
+	}
 	return res;
 }
@@ -996,6 +1046,13 @@ substrait::Rel *DuckDBToSubstrait::TransformComparisonJoin(LogicalOperator &dop)
 		}
 	}
 
+	auto child_column_count = dop.children[0]->types.size() + dop.children[1]->types.size();
+	vector<int32_t> output_mapping;
+	for (idx_t i = 0; i < projection->expressions_size(); i++) {
+		output_mapping.push_back(child_column_count + i);
+	}
+	auto rel_common = CreateOutputMapping(output_mapping);
+	projection->set_allocated_common(rel_common);
 	projection->set_allocated_input(res);
 	return proj_rel;
 }
diff --git a/test/c/CMakeLists.txt b/test/c/CMakeLists.txt
index 7a01a17..a230c7f 100644
--- a/test/c/CMakeLists.txt
+++ b/test/c/CMakeLists.txt
@@ -12,7 +12,7 @@ include_directories(../../duckdb/src/include)
 include_directories(../../duckdb/test/include)
 include_directories(../../duckdb/third_party/catch)
 
-set(ALL_SOURCES test_substrait_c_api.cpp)
+set(ALL_SOURCES test_substrait_c_api.cpp test_substrait_c_utils.cpp test_projection.cpp)
 
 add_library_unity(test_substrait OBJECT ${ALL_SOURCES})
 
diff --git a/test/c/test_projection.cpp b/test/c/test_projection.cpp
new file mode 100644
index 0000000..d7141d0
--- /dev/null
+++ b/test/c/test_projection.cpp
@@ -0,0 +1,136 @@
+#include "catch.hpp"
+#include "test_helpers.hpp"
+#include "duckdb/main/connection_manager.hpp"
+#include "test_substrait_c_utils.hpp"
+
+#include
+#include
+#include
+
+using namespace duckdb;
+using namespace std;
+// SELECT of the table's only column: the expected plan is a bare read relation without a project.
+TEST_CASE("Test C Project input columns with Substrait API", "[substrait-api]") {
+	DuckDB db(nullptr);
+	Connection con(db);
+
+	REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers(i INTEGER)"));
+	REQUIRE_NO_FAIL(con.Query("INSERT INTO integers VALUES (10), (20), (30)"));
+	CreateEmployeeTable(con); // NOTE(review): the query below only reads "integers"; this table looks unused here
+
+	auto expected_json_str = R"({"relations":[{"root":{"input":{"read":{"baseSchema":{"names":["i"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{}]},"maintainSingularStruct":true},"namedTable":{"names":["integers"]}}},"names":["i"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
+	auto json_str = con.GetSubstraitJSON("SELECT i FROM integers");
+	REQUIRE(json_str == expected_json_str);
+	auto result = con.FromSubstraitJSON(json_str);
+	REQUIRE(CHECK_COLUMN(result, 0, {10, 20, 30}));
+}
+// One passthrough column plus one computed column: the expected project has a single multiply expression and no emit.
+TEST_CASE("Test C Project 1 input column 1 transformation with Substrait API", "[substrait-api]") {
+	DuckDB db(nullptr);
+	Connection con(db);
+
+	REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers(i INTEGER)"));
+	REQUIRE_NO_FAIL(con.Query("INSERT INTO integers VALUES (10), (20), (30)"));
+	CreateEmployeeTable(con); // NOTE(review): the query below only reads "integers"; this table looks unused here
+
+	auto expected_json_str = R"({"extensionUris":[{"extensionUriAnchor":1,"uri":"https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"}],"extensions":[{"extensionFunction":{"extensionUriReference":1,"functionAnchor":1,"name":"multiply:i32_i32"}}],"relations":[{"root":{"input":{"project":{"input":{"read":{"baseSchema":{"names":["i"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{}]},"maintainSingularStruct":true},"namedTable":{"names":["integers"]}}},"expressions":[{"scalarFunction":{"functionReference":1,"outputType":{"i32":{"nullability":"NULLABILITY_NULLABLE"}},"arguments":[{"value":{"selection":{"directReference":{"structField":{}},"rootReference":{}}}},{"value":{"selection":{"directReference":{"structField":{}},"rootReference":{}}}}]}}]}},"names":["i","isquare"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
+	auto json_str = con.GetSubstraitJSON("SELECT i, i *i as isquare FROM integers");
+	REQUIRE(json_str == expected_json_str);
+	auto result = con.FromSubstraitJSON(json_str);
+	REQUIRE(CHECK_COLUMN(result, 0, {10, 20, 30}));
+	REQUIRE(CHECK_COLUMN(result, 1, {100, 400, 900}));
+}
+// SELECT * over employees: the expected plan is the read relation alone.
+TEST_CASE("Test C Project all columns with Substrait API", "[substrait-api]") {
+	DuckDB db(nullptr);
+	Connection con(db);
+
+	CreateEmployeeTable(con);
+
+	// This should not have a ProjectRel node
+	auto json_str = con.GetSubstraitJSON("SELECT * FROM employees");
+	auto expected_json_str = R"({"relations":[{"root":{"input":{"read":{"baseSchema":{"names":["employee_id","name","department_id","salary"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"i32":{"nullability":"NULLABILITY_NULLABLE"}},{"decimal":{"scale":2,"precision":10,"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{},{"field":1},{"field":2},{"field":3}]},"maintainSingularStruct":true},"namedTable":{"names":["employees"]}}},"names":["employee_id","name","department_id","salary"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
+	REQUIRE(json_str == expected_json_str);
+	auto result = con.FromSubstraitJSON(json_str);
+	REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5}));
+	REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
+	REQUIRE(CHECK_COLUMN(result, 2, {1, 2, 1, 3, 2}));
+	REQUIRE(CHECK_COLUMN(result, 3, {120000, 80000, 50000, 95000, 60000}));
+}
+// Two passthrough columns: the expected plan carries the column selection inside the read relation.
+TEST_CASE("Test C Project two passthrough columns with Substrait API", "[substrait-api]") {
+	DuckDB db(nullptr);
+	Connection con(db);
+
+	CreateEmployeeTable(con);
+
+	// This should not have a ProjectRel node
+	auto json_str = con.GetSubstraitJSON("SELECT name, salary FROM employees");
+	auto expected_json_str = R"({"relations":[{"root":{"input":{"read":{"baseSchema":{"names":["employee_id","name","department_id","salary"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"i32":{"nullability":"NULLABILITY_NULLABLE"}},{"decimal":{"scale":2,"precision":10,"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{"field":1},{"field":3}]},"maintainSingularStruct":true},"namedTable":{"names":["employees"]}}},"names":["name","salary"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
+	REQUIRE(json_str == expected_json_str);
+	auto result = con.FromSubstraitJSON(json_str);
+	REQUIRE(CHECK_COLUMN(result, 0, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
+	REQUIRE(CHECK_COLUMN(result, 1, {120000, 80000, 50000, 95000, 60000}));
+}
+// Two passthrough columns with a WHERE clause: the expected read relation carries both the filter and the column selection.
+TEST_CASE("Test C Project two passthrough columns with filter", "[substrait-api]") {
+	DuckDB db(nullptr);
+	Connection con(db);
+
+	CreateEmployeeTable(con);
+
+	// This should not have a ProjectRel node
+	auto json_str = con.GetSubstraitJSON("SELECT name, salary FROM employees where department_id = 1");
+	auto expected_json_str = R"({"extensionUris":[{"extensionUriAnchor":1,"uri":"https://github.com/substrait-io/substrait/blob/main/extensions/"}],"extensions":[{"extensionFunction":{"extensionUriReference":1,"functionAnchor":1,"name":"equal:i32_i32"}}],"relations":[{"root":{"input":{"read":{"baseSchema":{"names":["employee_id","name","department_id","salary"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"i32":{"nullability":"NULLABILITY_NULLABLE"}},{"decimal":{"scale":2,"precision":10,"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"filter":{"scalarFunction":{"functionReference":1,"outputType":{"i32":{"nullability":"NULLABILITY_NULLABLE"}},"arguments":[{"value":{"selection":{"directReference":{"structField":{"field":2}},"rootReference":{}}}},{"value":{"literal":{"i32":1}}}]}},"projection":{"select":{"structItems":[{"field":1},{"field":3}]},"maintainSingularStruct":true},"namedTable":{"names":["employees"]}}},"names":["name","salary"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
+	REQUIRE(json_str == expected_json_str);
+	auto result = con.FromSubstraitJSON(json_str);
+	REQUIRE(CHECK_COLUMN(result, 0, {"John Doe", "Alice Johnson" }));
+	REQUIRE(CHECK_COLUMN(result, 1, {120000, 50000 }));
+}
+// Passthrough column plus a computed column over a two-column read: the expected project carries emit outputMapping [0,2].
+TEST_CASE("Test C Project 1 passthrough column, 1 transformation with column elimination", "[substrait-api]") {
+	DuckDB db(nullptr);
+	Connection con(db);
+
+	CreateEmployeeTable(con);
+
+	auto json_str = con.GetSubstraitJSON("SELECT name, salary * 1.2 as new_salary FROM employees");
+	auto expected_json_str = R"({"extensionUris":[{"extensionUriAnchor":1,"uri":"https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic_decimal.yaml"}],"extensions":[{"extensionFunction":{"extensionUriReference":1,"functionAnchor":1,"name":"multiply:decimal_decimal"}}],"relations":[{"root":{"input":{"project":{"common":{"emit":{"outputMapping":[0,2]}},"input":{"read":{"baseSchema":{"names":["employee_id","name","department_id","salary"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"i32":{"nullability":"NULLABILITY_NULLABLE"}},{"decimal":{"scale":2,"precision":10,"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{"field":1},{"field":3}]},"maintainSingularStruct":true},"namedTable":{"names":["employees"]}}},"expressions":[{"scalarFunction":{"functionReference":1,"outputType":{"decimal":{"scale":3,"precision":12,"nullability":"NULLABILITY_NULLABLE"}},"arguments":[{"value":{"selection":{"directReference":{"structField":{"field":1}},"rootReference":{}}}},{"value":{"literal":{"decimal":{"value":"DAAAAAAAAAAAAAAAAAAAAA==","precision":12,"scale":1}}}}]}}]}},"names":["name","new_salary"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
+	REQUIRE(json_str == expected_json_str);
+	auto result = con.FromSubstraitJSON(json_str);
+	REQUIRE(CHECK_COLUMN(result, 0, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
+	REQUIRE(CHECK_COLUMN(result, 1, {144000, 96000, 60000, 114000, 72000}));
+}
+// GROUP BY with AVG: the expected plan is an aggregate directly over the read relation, with no project.
+TEST_CASE("Test C Project 1 passthrough column and 1 aggregate transformation", "[substrait-api]") {
+	DuckDB db(nullptr);
+	Connection con(db);
+
+	CreateEmployeeTable(con);
+
+	auto json_str = con.GetSubstraitJSON("SELECT department_id, AVG(salary) AS avg_salary FROM employees GROUP BY department_id");
+	auto expected_json_str = R"({"extensionUris":[{"extensionUriAnchor":1,"uri":"https://github.com/substrait-io/substrait/blob/main/extensions/"}],"extensions":[{"extensionFunction":{"extensionUriReference":1,"functionAnchor":1,"name":"avg:decimal"}}],"relations":[{"root":{"input":{"aggregate":{"input":{"read":{"baseSchema":{"names":["employee_id","name","department_id","salary"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"i32":{"nullability":"NULLABILITY_NULLABLE"}},{"decimal":{"scale":2,"precision":10,"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{"field":2},{"field":3}]},"maintainSingularStruct":true},"namedTable":{"names":["employees"]}}},"groupings":[{"groupingExpressions":[{"selection":{"directReference":{"structField":{}},"rootReference":{}}}]}],"measures":[{"measure":{"functionReference":1,"outputType":{"fp64":{"nullability":"NULLABILITY_NULLABLE"}},"arguments":[{"value":{"selection":{"directReference":{"structField":{"field":1}},"rootReference":{}}}}]}}]}},"names":["department_id","avg_salary"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
+	REQUIRE(json_str == expected_json_str);
+	auto result = con.FromSubstraitJSON(json_str);
+	REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3}));
+	REQUIRE(CHECK_COLUMN(result, 1, {85000, 70000, 95000}));
+}
+// Join query executed through a Substrait JSON round trip; only the result columns are checked, not the plan text.
+TEST_CASE("Test C Project on Join with Substrait API", "[substrait-api]") {
+	DuckDB db(nullptr);
+	Connection con(db);
+
+	CreateEmployeeTable(con);
+	CreateDepartmentsTable(con);
+
+	auto result = ExecuteViaSubstraitJSON(con,
+		"SELECT e.employee_id, e.name, d.department_name "
+		"FROM employees e "
+		"JOIN departments d "
+		"ON e.department_id = d.department_id"
+	);
+
+	REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5}));
+	REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
+	REQUIRE(CHECK_COLUMN(result, 2, {"HR", "Engineering", "HR", "Finance", "Engineering"}));
+}
\ No newline at end of file
diff --git a/test/c/test_substrait_c_api.cpp b/test/c/test_substrait_c_api.cpp
index 27ab432..5f61caa 100644
--- a/test/c/test_substrait_c_api.cpp
+++ b/test/c/test_substrait_c_api.cpp
@@ -1,6 +1,7 @@
 #include "catch.hpp"
 #include "test_helpers.hpp"
 #include "duckdb/main/connection_manager.hpp"
+#include "test_substrait_c_utils.hpp"
 
 #include
 #include
@@ -47,52 +48,6 @@ TEST_CASE("Test C Get and To Json-Substrait API", "[substrait-api]") {
 	REQUIRE_THROWS(con.FromSubstraitJSON("this is not valid"));
 }
 
-duckdb::unique_ptr ExecuteViaSubstrait(Connection &con, const string &sql) {
-	auto proto = con.GetSubstrait(sql);
-	return con.FromSubstrait(proto);
-}
-
-duckdb::unique_ptr ExecuteViaSubstraitJSON(Connection &con, const string &sql) {
-	auto json_str = con.GetSubstraitJSON(sql);
-	return con.FromSubstraitJSON(json_str);
-}
-
-void CreateEmployeeTable(Connection& con) {
-	REQUIRE_NO_FAIL(con.Query("CREATE TABLE employees ("
-	                          "employee_id INTEGER PRIMARY KEY, "
-	                          "name VARCHAR(100), "
-	                          "department_id INTEGER, "
-	                          "salary DECIMAL(10, 2))"));
-
-	REQUIRE_NO_FAIL(con.Query("INSERT INTO employees VALUES "
-	                          "(1, 'John Doe', 1, 120000), "
-	                          "(2, 'Jane Smith', 2, 80000), "
-	                          "(3, 'Alice Johnson', 1, 50000), "
-	                          "(4, 'Bob Brown', 3, 95000), "
-	                          "(5, 'Charlie Black', 2, 60000)"));
-}
-
-void CreatePartTimeEmployeeTable(Connection& con) {
-	REQUIRE_NO_FAIL(con.Query("CREATE TABLE part_time_employees ("
-	                          "id INTEGER PRIMARY KEY, "
-	                          "name VARCHAR(100), "
-	                          "department_id INTEGER, "
-	                          "hourly_rate DECIMAL(10, 2))"));
-
-	REQUIRE_NO_FAIL(con.Query("INSERT INTO part_time_employees VALUES "
-	                          "(6, 'David White', 1, 30000), "
-	                          "(7, 'Eve Green', 2, 40000)"));
-}
-
-void CreateDepartmentsTable(Connection& con) {
-	REQUIRE_NO_FAIL(con.Query("CREATE TABLE departments (department_id INTEGER PRIMARY KEY, department_name VARCHAR(100))"));
-
-	REQUIRE_NO_FAIL(con.Query("INSERT INTO departments VALUES "
-	                          "(1, 'HR'), "
-	                          "(2, 'Engineering'), "
-	                          "(3, 'Finance')"));
-}
-
 TEST_CASE("Test C CTAS Select columns with Substrait API", "[substrait-api]") {
 	DuckDB db(nullptr);
 	Connection con(db);
diff --git a/test/c/test_substrait_c_utils.cpp b/test/c/test_substrait_c_utils.cpp
new file mode 100644
index 0000000..e3c1501
--- /dev/null
+++ b/test/c/test_substrait_c_utils.cpp
@@ -0,0 +1,52 @@
+#include "test_helpers.hpp"
+#include "test_substrait_c_utils.hpp"
+
+using namespace duckdb;
+using namespace std;
+
+// Runs sql by converting it with GetSubstrait and executing the result with FromSubstrait.
+duckdb::unique_ptr ExecuteViaSubstrait(Connection &con, const string &sql) {
+	auto proto = con.GetSubstrait(sql);
+	return con.FromSubstrait(proto);
+}
+// Same round trip as above, but through GetSubstraitJSON / FromSubstraitJSON.
+duckdb::unique_ptr ExecuteViaSubstraitJSON(Connection &con, const string &sql) {
+	auto json_str = con.GetSubstraitJSON(sql);
+	return con.FromSubstraitJSON(json_str);
+}
+// Creates the employees table and inserts five rows.
+void CreateEmployeeTable(Connection &con) {
+	REQUIRE_NO_FAIL(con.Query("CREATE TABLE employees ("
+	                          "employee_id INTEGER PRIMARY KEY, "
+	                          "name VARCHAR(100), "
+	                          "department_id INTEGER, "
+	                          "salary DECIMAL(10, 2))"));
+
+	REQUIRE_NO_FAIL(con.Query("INSERT INTO employees VALUES "
+	                          "(1, 'John Doe', 1, 120000), "
+	                          "(2, 'Jane Smith', 2, 80000), "
+	                          "(3, 'Alice Johnson', 1, 50000), "
+	                          "(4, 'Bob Brown', 3, 95000), "
+	                          "(5, 'Charlie Black', 2, 60000)"));
+}
+// Creates the part_time_employees table and inserts two rows.
+void CreatePartTimeEmployeeTable(Connection& con) {
+	REQUIRE_NO_FAIL(con.Query("CREATE TABLE part_time_employees ("
+	                          "id INTEGER PRIMARY KEY, "
+	                          "name VARCHAR(100), "
+	                          "department_id INTEGER, "
+	                          "hourly_rate DECIMAL(10, 2))"));
+
+	REQUIRE_NO_FAIL(con.Query("INSERT INTO part_time_employees VALUES "
+	                          "(6, 'David White', 1, 30000), "
+	                          "(7, 'Eve Green', 2, 40000)"));
+}
+// Creates the departments table and inserts three rows.
+void CreateDepartmentsTable(Connection& con) {
+	REQUIRE_NO_FAIL(con.Query("CREATE TABLE departments (department_id INTEGER PRIMARY KEY, department_name VARCHAR(100))"));
+
+	REQUIRE_NO_FAIL(con.Query("INSERT INTO departments VALUES "
+	                          "(1, 'HR'), "
+	                          "(2, 'Engineering'), "
+	                          "(3, 'Finance')"));
+}
diff --git a/test/c/test_substrait_c_utils.hpp b/test/c/test_substrait_c_utils.hpp
new file mode 100644
index 0000000..52665f4
--- /dev/null
+++ b/test/c/test_substrait_c_utils.hpp
@@ -0,0 +1,15 @@
+#ifndef TEST_SUBSTRAIT_C_UTILS_HPP
+#define TEST_SUBSTRAIT_C_UTILS_HPP
+
+#include "duckdb.hpp"
+#include "duckdb/main/connection_manager.hpp"
+// NOTE(review): a using-directive at header scope is applied to every file that includes this header.
+using namespace duckdb;
+void CreateEmployeeTable(Connection& con);
+void CreatePartTimeEmployeeTable(Connection& con);
+void CreateDepartmentsTable(Connection& con);
+// Helpers that run a SQL string through a Substrait round trip (JSON and non-JSON variants).
+duckdb::unique_ptr ExecuteViaSubstraitJSON(Connection &con, const std::string &query);
+duckdb::unique_ptr ExecuteViaSubstrait(Connection &con, const std::string &query);
+
+#endif