From 66d45867eb7854cef96ae0bfb8b10b7437356385 Mon Sep 17 00:00:00 2001 From: Andrei Goncear Date: Mon, 19 Aug 2024 12:15:06 +0000 Subject: [PATCH] implemented QueryRow & Exec methods of sql driver interface --- chdb/driver/driver.go | 104 ++++++++++++++++++++++++++++++++++--- chdb/driver/driver_test.go | 90 ++++++++++++++++++++++++++++++++ 2 files changed, 187 insertions(+), 7 deletions(-) diff --git a/chdb/driver/driver.go b/chdb/driver/driver.go index ec1d1fb..22f7a57 100644 --- a/chdb/driver/driver.go +++ b/chdb/driver/driver.go @@ -79,6 +79,60 @@ func init() { sql.Register("chdb", Driver{}) } +// Row is the result of calling [DB.QueryRow] to select a single row. +type singleRow struct { + // One of these two will be non-nil: + err error // deferred error for easy chaining + rows driver.Rows +} + +// Scan copies the columns from the matched row into the values +// pointed at by dest. See the documentation on [Rows.Scan] for details. +// If more than one row matches the query, +// Scan uses the first row and discards the rest. If no row matches +// the query, Scan returns [ErrNoRows]. +func (r *singleRow) Scan(dest ...any) error { + if r.err != nil { + return r.err + } + vals := make([]driver.Value, 0) + for _, v := range dest { + vals = append(vals, v) + } + err := r.rows.Next(vals) + if err != nil { + return err + } + // Make sure the query can be processed to completion with no errors. + return r.rows.Close() +} + +// Err provides a way for wrapping packages to check for +// query errors without calling [Row.Scan]. +// Err returns the error, if any, that was encountered while running the query. +// If this error is not nil, this error will also be returned from [Row.Scan]. +func (r *singleRow) Err() error { + return r.err +} + +type execResult struct { + err error +} + +func (e *execResult) LastInsertId() (int64, error) { + if e.err != nil { + return 0, e.err + } + return -1, fmt.Errorf("does not support LastInsertId") + +} +func (e *execResult) RowsAffected() (int64, error) { + if e.err != nil { + return 0, e.err + } + return -1, fmt.Errorf("does not support RowsAffected") +} + type queryHandle func(string, ...string) (*chdbstable.LocalResult, error) type connector struct { @@ -192,6 +246,18 @@ type conn struct { QueryFun queryHandle } +func prepareValues(values []driver.Value) []driver.NamedValue { + namedValues := make([]driver.NamedValue, len(values)) + for i, value := range values { + namedValues[i] = driver.NamedValue{ + // nb: Name field is optional + Ordinal: i, + Value: value, + } + } + return namedValues +} + func (c *conn) Close() error { return nil } @@ -204,15 +270,39 @@ func (c *conn) SetupQueryFun() { } func (c *conn) Query(query string, values []driver.Value) (driver.Rows, error) { - namedValues := make([]driver.NamedValue, len(values)) - for i, value := range values { - namedValues[i] = driver.NamedValue{ - // nb: Name field is optional - Ordinal: i, - Value: value, + return c.QueryContext(context.Background(), query, prepareValues(values)) +} + +func (c *conn) QueryRow(query string, values []driver.Value) *singleRow { + return c.QueryRowContext(context.Background(), query, values) +} + +func (c *conn) Exec(query string, values []driver.Value) (sql.Result, error) { + return c.ExecContext(context.Background(), query, prepareValues(values)) +} + +func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + _, err := c.QueryContext(ctx, query, args) + if err != nil && err.Error() != "result is nil" { + return nil, err + } + return &execResult{ + err: nil, + }, nil +} + +func (c *conn) QueryRowContext(ctx context.Context, query string, values []driver.Value) *singleRow { + + v, err := c.QueryContext(ctx, query, prepareValues(values)) + if err != nil { + return &singleRow{ + err: err, + rows: nil, } } - return c.QueryContext(context.Background(), query, namedValues) + return &singleRow{ + rows: v, + } } func (c *conn) compileArguments(query string, args []driver.NamedValue) (string, error) { diff --git a/chdb/driver/driver_test.go b/chdb/driver/driver_test.go index fb6b4ae..d16b951 100644 --- a/chdb/driver/driver_test.go +++ b/chdb/driver/driver_test.go @@ -167,3 +167,93 @@ func TestDbWithSession(t *testing.T) { count++ } } + +func TestQueryRow(t *testing.T) { + sessionDir, err := os.MkdirTemp("", "unittest-sessiondata") + if err != nil { + t.Fatalf("create temp directory fail, err: %s", err) + } + defer os.RemoveAll(sessionDir) + session, err := chdb.NewSession(sessionDir) + if err != nil { + t.Fatalf("new session fail, err: %s", err) + } + defer session.Cleanup() + + session.Query("USE testdb; INSERT INTO testtable VALUES (1), (2), (3);") + + ret, err := session.Query("SELECT * FROM testtable;") + if err != nil { + t.Fatalf("Query fail, err: %s", err) + } + if string(ret.Buf()) != "1\n2\n3\n" { + t.Errorf("Query result should be 1\n2\n3\n, got %s", string(ret.Buf())) + } + db, err := sql.Open("chdb", fmt.Sprintf("session=%s", sessionDir)) + if err != nil { + t.Fatalf("open db fail, err: %s", err) + } + if db.Ping() != nil { + t.Fatalf("ping db fail, err: %s", err) + } + rows := db.QueryRow("select * from testtable;") + + var bar = 0 + var count = 1 + err = rows.Scan(&bar) + if err != nil { + t.Fatalf("scan fail, err: %s", err) + } + if bar != count { + t.Fatalf("result is not match, want: %d actual: %d", count, bar) + } + err2 := rows.Scan(&bar) + if err2 == nil { + t.Fatalf("QueryRow method should return only one item") + } + +} + +func TestExec(t *testing.T) { + sessionDir, err := os.MkdirTemp("", "unittest-sessiondata") + if err != nil { + t.Fatalf("create temp directory fail, err: %s", err) + } + defer os.RemoveAll(sessionDir) + session, err := chdb.NewSession(sessionDir) + if err != nil { + t.Fatalf("new session fail, err: %s", err) + } + defer session.Cleanup() + session.Query("CREATE DATABASE IF NOT EXISTS testdb; " + + "CREATE TABLE IF NOT EXISTS testdb.testtable (id UInt32) ENGINE = MergeTree() ORDER BY id;") + + db, err := sql.Open("chdb", fmt.Sprintf("session=%s", sessionDir)) + if err != nil { + t.Fatalf("open db fail, err: %s", err) + } + if db.Ping() != nil { + t.Fatalf("ping db fail, err: %s", err) + } + + _, err = db.Exec("INSERT INTO testdb.testtable VALUES (1), (2), (3);") + if err != nil { + t.Fatalf("exec failed, err: %s", err) + } + rows := db.QueryRow("select * from testdb.testtable;") + + var bar = 0 + var count = 1 + err = rows.Scan(&bar) + if err != nil { + t.Fatalf("scan fail, err: %s", err) + } + if bar != count { + t.Fatalf("result is not match, want: %d actual: %d", count, bar) + } + err2 := rows.Scan(&bar) + if err2 == nil { + t.Fatalf("QueryRow method should return only one item") + } + +}