Skip to content

Commit

Permalink
implement parquet rows to read data from clickhouse in parquet format
Browse files Browse the repository at this point in the history
  • Loading branch information
agoncear-mwb committed Jul 18, 2024
1 parent ecf0a25 commit dd038d2
Show file tree
Hide file tree
Showing 8 changed files with 729 additions and 223 deletions.
175 changes: 175 additions & 0 deletions chdb/driver/arrow.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
package chdbdriver

import (
"database/sql/driver"
"fmt"
"reflect"
"time"

"github.com/apache/arrow/go/v15/arrow"
"github.com/apache/arrow/go/v15/arrow/array"
"github.com/apache/arrow/go/v15/arrow/decimal128"
"github.com/apache/arrow/go/v15/arrow/decimal256"
"github.com/apache/arrow/go/v15/arrow/ipc"
"github.com/chdb-io/chdb-go/chdbstable"
)

// arrowRows walks a query result one row at a time, pulling record batches
// from an Arrow IPC file reader and exposing them through the methods that
// database/sql expects from a driver's row set.
type arrowRows struct {
// localResult is the chdb result tied to this row set; Close drops the reference.
// NOTE(review): presumably it owns the bytes that reader decodes — confirm where arrowRows is constructed.
localResult *chdbstable.LocalResult
// reader supplies the schema and the record batches.
reader *ipc.FileReader
// curRecord is the batch Next is currently walking; nil whenever a new batch has to be read.
curRecord arrow.Record
// curRow is the index inside curRecord of the row the next call to Next hands out.
curRow int64
}

// Columns lists the field names of the reader's schema in schema order.
// The result is nil when the schema declares no fields.
func (r *arrowRows) Columns() (out []string) {
	schema := r.reader.Schema()
	for idx, total := 0, schema.NumFields(); idx < total; idx++ {
		field := schema.Field(idx)
		out = append(out, field.Name)
	}
	return out
}

// Close drops the current batch, closes the underlying reader and releases
// the reference to the chdb result. It always returns nil.
//
// Close is idempotent: once the reader has been released, further calls are
// no-ops instead of dereferencing a nil reader.
func (r *arrowRows) Close() error {
	r.curRecord = nil
	if r.reader != nil {
		// The reader's close error is deliberately ignored: the row set is
		// being torn down either way and there is nothing a caller could do
		// with it.
		_ = r.reader.Close()
		r.reader = nil
	}
	r.localResult = nil
	return nil
}

// Next copies the values of the next row into dest, one entry per column of
// the current record batch, and advances the cursor.
//
// When the current batch is exhausted the next non-empty batch is read from
// the reader. Any error reported by the reader is returned unchanged;
// database/sql detects the end of a result set through io.EOF, so this
// relies on the reader signalling end of stream that way.
//
// Null cells are stored as nil. Date, time and timestamp cells are converted
// to time.Time; every other supported column type is stored as the Go value
// returned by the Arrow array. A column of an unsupported type yields an
// error and leaves the cursor on the same row.
func (r *arrowRows) Next(dest []driver.Value) error {
	// A fully consumed batch is discarded so the loop below fetches a new one.
	if r.curRecord != nil && r.curRow == r.curRecord.NumRows() {
		r.curRecord = nil
	}
	for r.curRecord == nil {
		record, err := r.reader.Read()
		if err != nil {
			return err
		}
		// Empty batches carry no rows; skip them.
		if record.NumRows() == 0 {
			continue
		}
		r.curRecord = record
		r.curRow = 0
	}

	row := int(r.curRow)
	for i, col := range r.curRecord.Columns() {
		if col.IsNull(row) {
			dest[i] = nil
			continue
		}
		switch col := col.(type) {
		case *array.Boolean:
			dest[i] = col.Value(row)
		case *array.Int8:
			dest[i] = col.Value(row)
		case *array.Uint8:
			dest[i] = col.Value(row)
		case *array.Int16:
			dest[i] = col.Value(row)
		case *array.Uint16:
			dest[i] = col.Value(row)
		case *array.Int32:
			dest[i] = col.Value(row)
		case *array.Uint32:
			dest[i] = col.Value(row)
		case *array.Int64:
			dest[i] = col.Value(row)
		case *array.Uint64:
			dest[i] = col.Value(row)
		case *array.Float32:
			dest[i] = col.Value(row)
		case *array.Float64:
			dest[i] = col.Value(row)
		case *array.String:
			dest[i] = col.Value(row)
		case *array.LargeString:
			dest[i] = col.Value(row)
		case *array.Binary:
			dest[i] = col.Value(row)
		case *array.LargeBinary:
			dest[i] = col.Value(row)
		case *array.Date32:
			dest[i] = col.Value(row).ToTime()
		case *array.Date64:
			dest[i] = col.Value(row).ToTime()
		case *array.Time32:
			// The unit (seconds, milliseconds, ...) lives on the column type.
			dest[i] = col.Value(row).ToTime(col.DataType().(*arrow.Time32Type).Unit)
		case *array.Time64:
			dest[i] = col.Value(row).ToTime(col.DataType().(*arrow.Time64Type).Unit)
		case *array.Timestamp:
			dest[i] = col.Value(row).ToTime(col.DataType().(*arrow.TimestampType).Unit)
		case *array.Decimal128:
			dest[i] = col.Value(row)
		case *array.Decimal256:
			dest[i] = col.Value(row)
		default:
			// Pass the type name as an argument, not as part of the format
			// string, so a '%' in the name cannot be read as a verb.
			return fmt.Errorf(
				"not yet implemented populating from columns of type %s",
				col.DataType().String(),
			)
		}
	}

	r.curRow++
	return nil
}

// ColumnTypeDatabaseTypeName reports the string form of the Arrow data type
// declared for the schema field at index.
func (r *arrowRows) ColumnTypeDatabaseTypeName(index int) string {
	field := r.reader.Schema().Field(index)
	return field.Type.String()
}

// ColumnTypeNullable reports the Nullable flag of the schema field at index.
// ok is always true because the schema always carries that flag.
func (r *arrowRows) ColumnTypeNullable(index int) (nullable, ok bool) {
	field := r.reader.Schema().Field(index)
	nullable = field.Nullable
	ok = true
	return nullable, ok
}

// ColumnTypePrecisionScale reports precision and scale for the schema field
// at index when its type is a 128-bit or 256-bit decimal. For every other
// type it returns zeros with ok set to false.
func (r *arrowRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
	fieldType := r.reader.Schema().Field(index).Type
	if dec, isDec128 := fieldType.(*arrow.Decimal128Type); isDec128 {
		return int64(dec.Precision), int64(dec.Scale), true
	}
	if dec, isDec256 := fieldType.(*arrow.Decimal256Type); isDec256 {
		return int64(dec.Precision), int64(dec.Scale), true
	}
	return 0, 0, false
}

// ColumnTypeScanType reports the Go type that Next produces for the schema
// field at index.
//
// Large string and large binary columns map to string and []byte, matching
// what Next stores for them. For Arrow types that have no mapping here the
// empty-interface type is returned rather than nil, so that callers doing
// reflect.New(ct.ScanType()) do not panic; this is the same fallback
// database/sql documents for drivers without scan-type support.
func (r *arrowRows) ColumnTypeScanType(index int) reflect.Type {
	switch r.reader.Schema().Field(index).Type.ID() {
	case arrow.BOOL:
		return reflect.TypeOf(false)
	case arrow.INT8:
		return reflect.TypeOf(int8(0))
	case arrow.UINT8:
		return reflect.TypeOf(uint8(0))
	case arrow.INT16:
		return reflect.TypeOf(int16(0))
	case arrow.UINT16:
		return reflect.TypeOf(uint16(0))
	case arrow.INT32:
		return reflect.TypeOf(int32(0))
	case arrow.UINT32:
		return reflect.TypeOf(uint32(0))
	case arrow.INT64:
		return reflect.TypeOf(int64(0))
	case arrow.UINT64:
		return reflect.TypeOf(uint64(0))
	case arrow.FLOAT32:
		return reflect.TypeOf(float32(0))
	case arrow.FLOAT64:
		return reflect.TypeOf(float64(0))
	case arrow.DECIMAL128:
		return reflect.TypeOf(decimal128.Num{})
	case arrow.DECIMAL256:
		return reflect.TypeOf(decimal256.Num{})
	case arrow.BINARY, arrow.LARGE_BINARY:
		return reflect.TypeOf([]byte{})
	case arrow.STRING, arrow.LARGE_STRING:
		return reflect.TypeOf("")
	case arrow.TIME32, arrow.TIME64, arrow.DATE32, arrow.DATE64, arrow.TIMESTAMP:
		return reflect.TypeOf(time.Time{})
	}
	// No specific mapping: report interface{} so the value can be scanned
	// into an empty interface.
	return reflect.TypeOf((*interface{})(nil)).Elem()
}
106 changes: 106 additions & 0 deletions chdb/driver/arrow_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package chdbdriver

import (
"database/sql"
"fmt"
"os"
"testing"

"github.com/chdb-io/chdb-go/chdb"
)

// TestDbWithArrow opens the chdb driver in ARROW mode without a session,
// runs a constant SELECT and checks that exactly one row with the expected
// two values comes back.
func TestDbWithArrow(t *testing.T) {
	db, err := sql.Open("chdb", fmt.Sprintf("driverType=%s", "ARROW"))
	if err != nil {
		// Nothing below can work without a handle, so stop here.
		t.Fatalf("open db fail, err: %s", err)
	}
	defer db.Close()

	if err := db.Ping(); err != nil {
		t.Fatalf("ping db fail, err: %s", err)
	}

	rows, err := db.Query(`SELECT 1,'abc'`)
	if err != nil {
		// rows is nil on error; continuing would panic.
		t.Fatalf("run Query fail, err: %s", err)
	}
	defer rows.Close()

	cols, err := rows.Columns()
	if err != nil {
		t.Fatalf("get result columns fail, err: %s", err)
	}
	if len(cols) != 2 {
		t.Fatalf("select result columns length should be 2, actual: %d", len(cols))
	}

	var (
		bar int
		foo string
	)
	rowCount := 0
	for rows.Next() {
		rowCount++
		if err := rows.Scan(&bar, &foo); err != nil {
			t.Fatalf("scan fail, err: %s", err)
		}
		if bar != 1 {
			t.Errorf("first column mismatch, want: 1 actual: %d", bar)
		}
		if foo != "abc" {
			t.Errorf("second column mismatch, want: %q actual: %q", "abc", foo)
		}
	}
	// A failure inside the driver's Next ends the loop silently; surface it.
	if err := rows.Err(); err != nil {
		t.Fatalf("iterate rows fail, err: %s", err)
	}
	if rowCount != 1 {
		t.Errorf("row count mismatch, want: 1 actual: %d", rowCount)
	}
}

// TestDBWithArrowSession creates a table through a chdb session, then reads
// it back through database/sql with the driver in ARROW mode bound to the
// same session directory, expecting the rows 1, 2, 3 in order.
func TestDBWithArrowSession(t *testing.T) {
	sessionDir, err := os.MkdirTemp("", "unittest-sessiondata")
	if err != nil {
		t.Fatalf("create temp directory fail, err: %s", err)
	}
	defer os.RemoveAll(sessionDir)

	session, err := chdb.NewSession(sessionDir)
	if err != nil {
		t.Fatalf("new session fail, err: %s", err)
	}
	defer session.Cleanup()

	// Setup failures are reported directly instead of surfacing later as a
	// confusing mismatch in the SELECT.
	if _, err := session.Query("CREATE DATABASE IF NOT EXISTS testdb; " +
		"CREATE TABLE IF NOT EXISTS testdb.testtable (id UInt32) ENGINE = MergeTree() ORDER BY id;"); err != nil {
		t.Fatalf("create table fail, err: %s", err)
	}
	if _, err := session.Query("INSERT INTO testdb.testtable VALUES (1), (2), (3);"); err != nil {
		t.Fatalf("insert fail, err: %s", err)
	}

	ret, err := session.Query("SELECT * FROM testdb.testtable;")
	if err != nil {
		t.Fatalf("Query fail, err: %s", err)
	}
	if got := string(ret.Buf()); got != "1\n2\n3\n" {
		t.Errorf("Query result should be %q, got %q", "1\n2\n3\n", got)
	}

	db, err := sql.Open("chdb", fmt.Sprintf("session=%s;driverType=%s", sessionDir, "ARROW"))
	if err != nil {
		t.Fatalf("open db fail, err: %s", err)
	}
	defer db.Close()

	// Capture Ping's own error; the err from sql.Open is nil at this point.
	if err := db.Ping(); err != nil {
		t.Fatalf("ping db fail, err: %s", err)
	}

	rows, err := db.Query("select * from testdb.testtable;")
	if err != nil {
		t.Fatalf("run Query fail, err: %s", err)
	}
	defer rows.Close()

	cols, err := rows.Columns()
	if err != nil {
		t.Fatalf("get result columns fail, err: %s", err)
	}
	if len(cols) != 1 {
		t.Fatalf("result columns length should be 1, actual: %d", len(cols))
	}

	const wantRows = 3
	var bar = 0
	var count = 1 // next expected id; the table holds 1..wantRows
	for rows.Next() {
		err = rows.Scan(&bar)
		if err != nil {
			t.Fatalf("scan fail, err: %s", err)
		}
		if bar != count {
			t.Fatalf("result is not match, want: %d actual: %d", count, bar)
		}
		count++
	}
	// A failure inside the driver's Next ends the loop silently; surface it.
	if err := rows.Err(); err != nil {
		t.Fatalf("iterate rows fail, err: %s", err)
	}
	if got := count - 1; got != wantRows {
		t.Fatalf("row count mismatch, want: %d actual: %d", wantRows, got)
	}
}
Loading

0 comments on commit dd038d2

Please sign in to comment.