Skip to content

Commit

Permalink
Fix segfaults in main example
Browse files Browse the repository at this point in the history
  • Loading branch information
ywelsch committed Mar 3, 2024
1 parent a79ef2c commit a1b156d
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 13 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ add_library(${EXTENSION_NAME} STATIC ${EXTENSION_SOURCES})

add_subdirectory(corrosion)

corrosion_import_crate(MANIFEST_PATH prql/prqlc/bindings/prqlc-c/Cargo.toml CRATES prqlc-c
CRATE_TYPES staticlib)
corrosion_import_crate(MANIFEST_PATH prql/prqlc/bindings/prqlc-c/Cargo.toml
CRATES prqlc-c CRATE_TYPES staticlib)

set(PARAMETERS "-warnings")
build_loadable_extension(${TARGET_NAME} ${PARAMETERS} ${EXTENSION_SOURCES})
Expand Down
16 changes: 5 additions & 11 deletions src/prql_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ ParserExtensionParseResult prql_parse(ParserExtensionInfo *,
bool failed = false;
string sql_query_or_error;
{
prqlc::CompileResult compile_result = compile(trimmed_string.c_str(), &options);
prqlc::CompileResult compile_result =
compile(trimmed_string.c_str(), &options);
std::stringstream ss;

for (int i = 0; i < compile_result.messages_len; i++) {
prqlc::Message const *e = &compile_result.messages[i];
if (e->kind == prqlc::MessageKind::Error) {
Expand All @@ -68,7 +69,7 @@ ParserExtensionParseResult prql_parse(ParserExtensionInfo *,
sql_query_or_error = ss.str();
prqlc::result_destroy(compile_result);
}

if (failed) {
// sql_query_or_error contains error string
// TODO: decide when to consider it a PRQL failure vs this parser extension
Expand All @@ -81,13 +82,6 @@ ParserExtensionParseResult prql_parse(ParserExtensionInfo *,
return ParserExtensionParseResult(std::move(sql_query_or_error));
}

// if (sql_query_or_error.find("WITH table_0 AS") != std::string::npos) {
// sql_query_or_error = "WITH table_0 AS (SELECT customer_id, total - 0.8 AS _expr_0, total FROM invoices WHERE invoice_date >= DATE '1970-01-16') SELECT customer_id, AVG(total), COALESCE(SUM(_expr_0), 0) AS sum_income, COUNT(*) AS ct FROM table_0 WHERE _expr_0 > 1 GROUP BY customer_id";
// sql_query_or_error = "WITH table_0 AS (SELECT customer_id, total FROM invoices WHERE invoice_date < today()) SELECT customer_id, AVG(total) FROM table_0 GROUP BY customer_id";
// // sql_query_or_error = "WITH table_0 AS (SELECT customer_id, total FROM invoices WHERE invoice_date < today()) SELECT customer_id FROM table_0";
// }


// printf("%s\n", sql_query_or_error.c_str());

Parser parser; // TODO Pass (ClientContext.GetParserOptions());
Expand Down Expand Up @@ -119,7 +113,7 @@ BoundStatement prql_bind(ClientContext &context, Binder &binder,
auto lookup = context.registered_state.find("prql");
if (lookup != context.registered_state.end()) {
auto prql_state = (PrqlState *)lookup->second.get();
auto prql_binder = Binder::CreateBinder(context);
auto prql_binder = Binder::CreateBinder(context, &binder);
auto prql_parse_data =
dynamic_cast<PrqlParseData *>(prql_state->parse_data.get());
auto statement = prql_binder->Bind(*(prql_parse_data->statement));
Expand Down
17 changes: 17 additions & 0 deletions test/sql/prql.test
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,20 @@ statement error
from t1 | srt j
----
Parser Error

statement ok
INSTALL httpfs;

statement ok
LOAD httpfs;

statement ok
CREATE TABLE invoices AS SELECT * FROM
read_csv_auto('https://raw.githubusercontent.com/PRQL/prql/0.8.0/prql-compiler/tests/integration/data/chinook/invoices.csv');

statement ok
CREATE TABLE customers AS SELECT * FROM
read_csv_auto('https://raw.githubusercontent.com/PRQL/prql/0.8.0/prql-compiler/tests/integration/data/chinook/customers.csv');

statement ok
from invoices | filter invoice_date >= @1970-01-16 | derive { transaction_fees = 0.8, income = total - transaction_fees } | filter income > 1 | group customer_id ( aggregate { average total, sum_income = sum income, ct = count total, }) | sort {-sum_income} | take 10 | join c=customers (==customer_id) | derive name = f"{c.last_name}, {c.first_name}" | select { c.customer_id, name, sum_income } | derive db_version = s"version()";

0 comments on commit a1b156d

Please sign in to comment.