Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix 28/sql3 #33

Merged
merged 17 commits into from
May 18, 2024
75 changes: 43 additions & 32 deletions database/sql3/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,22 @@ import (
"os"
"path/filepath"
"runtime"
"strings"

"github.com/maragudk/migrate"
_ "github.com/mattn/go-sqlite3"
)

const (
driverName = "sqlite3"
// Pragma is a set of default commands used to modify the operation of the SQLite.
//
// - JOURNAL MODE = WAL
// - BUSY TIMEOUT = 5000
// - SYNCHRONOUS = NORMAL
// - CACHE SIZE = 1000000000
// - FOREIGN KEYS = TRUE
// - TXLOCK = IMMEDIATE
// - TEMP STORE = MEMORY
// - MMAP SIZE = 3000000000
PRAGMA = "_journal_mod=wal&_busy_timeout=5000&_synchronous=normal&_cache_size=1000000000&_foreign_keys=true&_txlock=immediate&_temp_store=memory&_mmap_size=3000000000"
)
type Executor interface {
Exec(ctx context.Context, query string, args ...any) (sql.Result, error)
}

type Querier interface {
Query(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRow(ctx context.Context, query string, args ...any) *sql.Row
}

const driverName = "sqlite3"

// DB
type DB struct {
Expand Down Expand Up @@ -87,7 +84,7 @@ func Open(filename string) (*DB, error) {
if err := os.MkdirAll(filepath.Dir(filename), 0755); err != nil {
return nil, fmt.Errorf("create directory for database files: %w", err)
}
dsn := fmt.Sprintf("file:%s?%s", filename, PRAGMA)
dsn := fmt.Sprintf("file:%s?%s", filename, pragma())
wc, err := sql.Open(driverName, dsn)
if err != nil {
return nil, fmt.Errorf("open write pool: %w", err)
Expand All @@ -103,35 +100,25 @@ func Open(filename string) (*DB, error) {

// Tx
type Tx struct {
tx *sql.Tx
ctx context.Context
tx *sql.Tx
}

// Commit the transaction.
func (tx *Tx) Commit() error { return tx.tx.Commit() }

// Exec executes a query without returning any rows. The args are for any placeholder parameters in the query.
func (tx *Tx) Exec(query string, args ...any) (sql.Result, error) {
if tx.ctx == nil {
tx.ctx = context.Background()
}
return tx.tx.ExecContext(tx.ctx, query, args...)
func (tx *Tx) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
return tx.tx.ExecContext(ctx, query, args...)
}

// Query executes a query that returns rows, typically a SELECT. The args are for any placeholder parameters in the query.
func (tx *Tx) Query(query string, args ...any) (*sql.Rows, error) {
if tx.ctx == nil {
tx.ctx = context.Background()
}
return tx.tx.QueryContext(tx.ctx, query, args...)
func (tx *Tx) Query(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
return tx.tx.QueryContext(ctx, query, args...)
}

// QueryRow executes a query that is expected to return at most one row.
func (tx *Tx) QueryRow(query string, args ...any) *sql.Row {
if tx.ctx == nil {
tx.ctx = context.Background()
}
return tx.tx.QueryRowContext(tx.ctx, query, args...)
func (tx *Tx) QueryRow(ctx context.Context, query string, args ...any) *sql.Row {
return tx.tx.QueryRowContext(ctx, query, args...)
}

// Rollback the transaction.
Expand Down Expand Up @@ -174,3 +161,27 @@ func NewFS(fsys fs.FS, dir string) (*FS, error) {
fsys, err := fs.Sub(fsys, dir)
return &FS{fsys: fsys}, err
}

func pragma() string {
v := map[string]string{
"_journal_mode": "wal",
"_busy_timeout": "5000",
"_synchronous": "normal",
"_cache_size": "1000000000",
"_foreign_keys": "true",
"_txlock": "immediate",
"_temp_store": "memory",
"_mmap_size": "3000000000",
}

var buf strings.Builder
for k, val := range v {
if buf.Len() > 0 {
buf.WriteByte('&')
}
buf.WriteString(k)
buf.WriteByte('=')
buf.WriteString(val)
}
return buf.String()
}
19 changes: 17 additions & 2 deletions database/sql3/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package sql3_test
import (
"context"
"embed"
"os"
"path/filepath"
"testing"

"github.com/google/uuid"
Expand All @@ -18,12 +20,16 @@ func Test_Up(t *testing.T) {
t.Run("OK", func(t *testing.T) {
is := is.NewRelaxed(t)

_, err := sqlFS.Up(context.TODO(), t.TempDir()+"/test.db")
_, err := sqlFS.Up(context.TODO(), testFilename(t, "test.db"))
is.NoErr(err) // (sql3.FS).Up
})
}

func Test_DB_Tx(t *testing.T) {
if testing.Short() {
t.Skip("this is a long test")
}

t.Run("OK", testRoundTrip(func(db *DB) {
is := is.NewRelaxed(t)

Expand All @@ -34,7 +40,7 @@ func Test_DB_Tx(t *testing.T) {

for i := range 5_000_000 {
rid := uuid.Must(uuid.NewV7())
_, err = tx.Exec(`insert into tests (id, counter) values (?, ?)`, rid, i)
_, err = tx.Exec(context.TODO(), `insert into tests (id, counter) values (?, ?)`, rid, i)
is.NoErr(err) // (sql3.Tx).Exec

if i%500_000 == 0 {
Expand Down Expand Up @@ -64,3 +70,12 @@ func testRoundTrip(f func(*DB)) func(*testing.T) {
f(db)
}
}

func testFilename(t testing.TB, filename string) string {
t.Helper()
if os.Getenv("DEBUG") != "1" {
return filepath.Join(t.TempDir(), filename)
}
_ = os.Remove(filename)
return filepath.Join(filename)
}
Loading