From a1b156ddcdb8cc3d1abf53622e3c3d214c136d11 Mon Sep 17 00:00:00 2001 From: Yannick Welsch Date: Sun, 3 Mar 2024 12:18:00 +0100 Subject: [PATCH] Fix segfaults in main example --- CMakeLists.txt | 4 ++-- src/prql_extension.cpp | 16 +++++----------- test/sql/prql.test | 17 +++++++++++++++++ 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index abe0894..7bc9698 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,8 +14,8 @@ add_library(${EXTENSION_NAME} STATIC ${EXTENSION_SOURCES}) add_subdirectory(corrosion) -corrosion_import_crate(MANIFEST_PATH prql/prqlc/bindings/prqlc-c/Cargo.toml CRATES prqlc-c - CRATE_TYPES staticlib) +corrosion_import_crate(MANIFEST_PATH prql/prqlc/bindings/prqlc-c/Cargo.toml + CRATES prqlc-c CRATE_TYPES staticlib) set(PARAMETERS "-warnings") build_loadable_extension(${TARGET_NAME} ${PARAMETERS} ${EXTENSION_SOURCES}) diff --git a/src/prql_extension.cpp b/src/prql_extension.cpp index 82d2419..2a8a9a1 100644 --- a/src/prql_extension.cpp +++ b/src/prql_extension.cpp @@ -42,9 +42,10 @@ ParserExtensionParseResult prql_parse(ParserExtensionInfo *, bool failed = false; string sql_query_or_error; { - prqlc::CompileResult compile_result = compile(trimmed_string.c_str(), &options); + prqlc::CompileResult compile_result = + compile(trimmed_string.c_str(), &options); std::stringstream ss; - + for (int i = 0; i < compile_result.messages_len; i++) { prqlc::Message const *e = &compile_result.messages[i]; if (e->kind == prqlc::MessageKind::Error) { @@ -68,7 +69,7 @@ ParserExtensionParseResult prql_parse(ParserExtensionInfo *, sql_query_or_error = ss.str(); prqlc::result_destroy(compile_result); } - + if (failed) { // sql_query_or_error contains error string // TODO: decide when to consider it a PRQL failure vs this parser extension @@ -81,13 +82,6 @@ ParserExtensionParseResult prql_parse(ParserExtensionInfo *, return ParserExtensionParseResult(std::move(sql_query_or_error)); } - // if (sql_query_or_error.find("WITH table_0 AS") != std::string::npos) { - // sql_query_or_error = "WITH table_0 AS (SELECT customer_id, total - 0.8 AS _expr_0, total FROM invoices WHERE invoice_date >= DATE '1970-01-16') SELECT customer_id, AVG(total), COALESCE(SUM(_expr_0), 0) AS sum_income, COUNT(*) AS ct FROM table_0 WHERE _expr_0 > 1 GROUP BY customer_id"; - // sql_query_or_error = "WITH table_0 AS (SELECT customer_id, total FROM invoices WHERE invoice_date < today()) SELECT customer_id, AVG(total) FROM table_0 GROUP BY customer_id"; - // // sql_query_or_error = "WITH table_0 AS (SELECT customer_id, total FROM invoices WHERE invoice_date < today()) SELECT customer_id FROM table_0"; - // } - - // printf("%s\n", sql_query_or_error.c_str()); Parser parser; // TODO Pass (ClientContext.GetParserOptions()); @@ -119,7 +113,7 @@ BoundStatement prql_bind(ClientContext &context, Binder &binder, auto lookup = context.registered_state.find("prql"); if (lookup != context.registered_state.end()) { auto prql_state = (PrqlState *)lookup->second.get(); - auto prql_binder = Binder::CreateBinder(context); + auto prql_binder = Binder::CreateBinder(context, &binder); auto prql_parse_data = dynamic_cast(prql_state->parse_data.get()); auto statement = prql_binder->Bind(*(prql_parse_data->statement)); diff --git a/test/sql/prql.test b/test/sql/prql.test index f67c7d4..054c973 100644 --- a/test/sql/prql.test +++ b/test/sql/prql.test @@ -27,3 +27,20 @@ statement error from t1 | srt j ---- Parser Error + +statement ok +INSTALL httpfs; + +statement ok +LOAD httpfs; + +statement ok +CREATE TABLE invoices AS SELECT * FROM + read_csv_auto('https://raw.githubusercontent.com/PRQL/prql/0.8.0/prql-compiler/tests/integration/data/chinook/invoices.csv'); + +statement ok +CREATE TABLE customers AS SELECT * FROM + read_csv_auto('https://raw.githubusercontent.com/PRQL/prql/0.8.0/prql-compiler/tests/integration/data/chinook/customers.csv'); + +statement ok +from invoices | filter invoice_date >= @1970-01-16 | derive { transaction_fees = 0.8, income = total - transaction_fees } | filter income > 1 | group customer_id ( aggregate { average total, sum_income = sum income, ct = count total, }) | sort {-sum_income} | take 10 | join c=customers (==customer_id) | derive name = f"{c.last_name}, {c.first_name}" | select { c.customer_id, name, sum_income } | derive db_version = s"version()";