Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[!] add support for pgx.Batch, closes #152 #179

Closed
Closed
134 changes: 134 additions & 0 deletions batch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package pgxmock

import (
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)

// BatchQuery helps to create batches for testing
type BatchQuery struct {
sql string
rewrittenSQL string
args []interface{}
}

// NewBatchQuery creates new query that will be queued in batch.
// Function accepts sql string and optionally arguments
func NewBatchQuery(sql string, args ...interface{}) *BatchQuery {
return &BatchQuery{
sql: sql,
args: args,
}
}

// WithRewrittenSQL will match given expected expression to a rewritten SQL statement by
// a pgx.QueryRewriter argument
func (be *BatchQuery) WithRewrittenSQL(sql string) *BatchQuery {
be.rewrittenSQL = sql
return be
}

// Batch is a batch mock that helps to create batches for testing
type Batch struct {
elements []*BatchQuery
}

// AddBatchQueries adds any number of BatchQuery to Batch
// that is used for mocking pgx.SendBatch()
func (b *Batch) AddBatchQueries(bes ...*BatchQuery) *Batch {
b.elements = append(b.elements, bes...)
return b
}

// NewBatch creates a structure that helps to combine multiple queries
// that will be used in batches. This function should be used in .ExpectSendBatch()
func NewBatch() *Batch {
return &Batch{}
}

// batchResults is a subsidiary structure for mocking BatchResults interface response
type batchResults struct {
br *BatchResults
ex *ExpectedBatch
}

// Query is a mock for Query() function in pgx.BatchResults interface
func (b *batchResults) Query() (pgx.Rows, error) {
if b.br.queryErr != nil {
return nil, b.br.queryErr
}
rs := b.br.rows.Kind()
rs.Next()
return rs, nil
}

// Exec is a mock for Exec() function in pgx.BatchResults interface
func (b *batchResults) Exec() (pgconn.CommandTag, error) {
if b.br.execErr != nil {
return pgconn.CommandTag{}, b.br.execErr
}
return b.br.commandTag, nil
}

// QueryRow is a mock for QueryRow() function in pgx.BatchResults interface
func (b *batchResults) QueryRow() pgx.Row {
rs := b.br.rows.Kind()
rs.Next()
return rs
}

// Close is a mock for Close() function in pgx.BatchResults interface
func (b *batchResults) Close() error {
b.ex.batchWasClosed = true
return b.br.closeErr
}

// BatchResults is a subsidiary structure for mocking SendBatch() function
// response. There is an option to mock returned Rows, errors and commandTag
type BatchResults struct {
commandTag pgconn.CommandTag
rows *Rows
queryErr error
execErr error
closeErr error
}

// NewBatchResults returns a mock response for SendBatch() function
func NewBatchResults() *BatchResults {
return &BatchResults{}
}

// QueryError sets the error that will be returned by Query() function
// called using pgx.BatchResults interface
func (b *BatchResults) QueryError(err error) *BatchResults {
b.queryErr = err
return b
}

// ExecError sets the error that will be returned by Exec() function
// called using pgx.BatchResults interface
func (b *BatchResults) ExecError(err error) *BatchResults {
b.execErr = err
return b
}

// CloseError sets the error that will be returned by Close() function
// called using pgx.BatchResults interface
func (b *BatchResults) CloseError(err error) *BatchResults {
b.closeErr = err
return b
}

// WillReturnRows allows to return mocked Rows by Query() and QueryRow()
// functions in pgx.BatchResults interface
func (b *BatchResults) WillReturnRows(rows *Rows) *BatchResults {
b.rows = rows
return b
}

// AddCommandTag allows to add pgconn.CommandTag to batchResults struct
// that will be returned in Exec() function
func (b *BatchResults) AddCommandTag(ct pgconn.CommandTag) *BatchResults {
b.commandTag = ct
return b
}
206 changes: 206 additions & 0 deletions batch_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
package pgxmock

import (
"context"
"fmt"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"github.com/stretchr/testify/assert"
"testing"
)

func TestBatchClosed(t *testing.T) {
t.Parallel()
a := assert.New(t)
mock, err := NewConn()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer mock.Close(context.Background())

expectedBatch := mock.NewBatch().
AddBatchQueries(
NewBatchQuery("SELECT *", 1),
NewBatchQuery("SELECT *"),
)

batchResultsMock := NewBatchResults()

batch := new(pgx.Batch)
batch.Queue("SELECT * FROM TABLE", 1)
batch.Queue("SELECT * FROM TABLE")

mock.ExpectSendBatch(expectedBatch).
WillReturnBatchResults(batchResultsMock)

br := mock.SendBatch(context.Background(), batch)
a.NotNil(br)
a.NoError(br.Close())

a.NoError(mock.ExpectationsWereMet())
}

func TestBatchWithRewrittenSQL(t *testing.T) {
t.Parallel()
mock, err := NewConn(QueryMatcherOption(QueryMatcherEqual))
a := assert.New(t)
a.NoError(err)
defer mock.Close(context.Background())

u := user{name: "John", email: pgtype.Text{String: "[email protected]", Valid: true}}

expectedBatch := mock.NewBatch().
AddBatchQueries(
//first batch query is correct
NewBatchQuery("INSERT", &u).
WithRewrittenSQL("INSERT INTO users (username, email) VALUES ($1, $2) RETURNING id"),
//second batch query is not correct
NewBatchQuery("INSERT INTO users(username, password) VALUES (@user, @password)", pgx.NamedArgs{"user": "John", "password": "strong"}).
WithRewrittenSQL("INSERT INTO users(username, password) VALUES ($1)"),
)
batchResultsMock := NewBatchResults()

mock.ExpectSendBatch(expectedBatch).
WillReturnBatchResults(batchResultsMock).
BatchResultsWillBeClosed()

batch := new(pgx.Batch)
batch.Queue("INSERT", &u)
batch.Queue("INSERT INTO users(username) VALUES (@user)", pgx.NamedArgs{"user": "John", "password": "strong"})

br := mock.SendBatch(context.Background(), batch)
a.Nil(br)
a.Error(mock.ExpectationsWereMet())
}

func TestBatchQuery(t *testing.T) {
t.Parallel()
a := assert.New(t)
mock, err := NewConn()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer mock.Close(context.Background())

expectedBatch := mock.NewBatch().
AddBatchQueries(
NewBatchQuery("SELECT *", 1),
NewBatchQuery("SELECT *"),
)

rows := NewRows([]string{"id", "name", "email"}).
AddRow("some-id-1", "some-name-1", "some-email-1").
AddRow("some-id-2", "some-name-2", "some-email-2")

batchResultsMock := NewBatchResults().WillReturnRows(rows).AddCommandTag(pgconn.NewCommandTag("SELECT 2"))

batch := new(pgx.Batch)
batch.Queue("SELECT * FROM TABLE", 1)
batch.Queue("SELECT * FROM TABLE")

mock.ExpectSendBatch(expectedBatch).
WillReturnBatchResults(batchResultsMock)

br := mock.SendBatch(context.Background(), batch)
a.NotNil(br)
r, err := br.Query()
a.NoError(err)

//assert rows are returned correctly
var id, name, email string
err = r.Scan(&id, &name, &email)
a.NoError(err)
a.Equal("some-id-1", id)
a.Equal("some-name-1", name)
a.Equal("some-email-1", email)

a.True(r.Next())
a.NoError(mock.ExpectationsWereMet())
}

func TestBatchErrors(t *testing.T) {
t.Parallel()
a := assert.New(t)
mock, err := NewConn()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer mock.Close(context.Background())

expectedBatch := mock.NewBatch().
AddBatchQueries(
NewBatchQuery("SELECT *", 1),
NewBatchQuery("SELECT *"),
)

batchResultsMock := NewBatchResults().
QueryError(fmt.Errorf("query returned error")).
ExecError(fmt.Errorf("exec returned error")).
CloseError(fmt.Errorf("close returned error"))

batch := new(pgx.Batch)
batch.Queue("SELECT * FROM TABLE", 1)
batch.Queue("SELECT * FROM TABLE")

mock.ExpectSendBatch(expectedBatch).
WillReturnBatchResults(batchResultsMock)

br := mock.SendBatch(context.Background(), batch)
a.NotNil(br)

_, err = br.Query()
a.Error(err)

_, err = br.Exec()
a.Error(err)

err = br.Close()
a.Error(err)

a.NoError(mock.ExpectationsWereMet())
}

func TestBatchQueryRow(t *testing.T) {
t.Parallel()
a := assert.New(t)
mock, err := NewConn()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer mock.Close(context.Background())

expectedBatch := mock.NewBatch().
AddBatchQueries(
NewBatchQuery("SELECT *", 1),
NewBatchQuery("SELECT *"),
)

rows := NewRows([]string{"id", "name", "email"}).
AddRow("some-id-1", "some-name-1", "some-email-1").
AddRow("some-id-2", "some-name-2", "some-email-2")

batchResultsMock := NewBatchResults().WillReturnRows(rows)

batch := new(pgx.Batch)
batch.Queue("SELECT * FROM TABLE", 1)
batch.Queue("SELECT * FROM TABLE")

mock.ExpectSendBatch(expectedBatch).
WillReturnBatchResults(batchResultsMock)

br := mock.SendBatch(context.Background(), batch)
a.NotNil(br)

r := br.QueryRow()

//assert rows are returned correctly
var id, name, email string
err = r.Scan(&id, &name, &email)
a.NoError(err)
a.Equal("some-id-1", id)
a.Equal("some-name-1", name)
a.Equal("some-email-1", email)

a.NoError(mock.ExpectationsWereMet())
}
46 changes: 46 additions & 0 deletions expectations.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,49 @@ func (e *ExpectedRollback) String() string {
}
return msg
}

// ExpectedBatch is used to manage pgx.SendBatch, pgx.Tx.SendBatch expectations.
// Returned by pgxmock.ExpectBatch
type ExpectedBatch struct {
commonExpectation
expectedBatch []queryBasedExpectation
expectedBatchResponse pgx.BatchResults
batchMustBeClosed bool
batchWasClosed bool
}

// String returns string representation
func (e *ExpectedBatch) String() string {
msg := "ExpectedBatch => expecting call to SendBatch():\n"
for _, b := range e.expectedBatch {
msg += "\texpecting query that:\n"
msg += fmt.Sprintf("\t\t- matches sql: '%s'\n", b.expectSQL)

if len(b.args) == 0 {
msg += "\t\t- is without arguments\n"
} else {
msg += "\t\t- is with arguments:\n"
for i, arg := range b.args {
msg += fmt.Sprintf("\t\t\t%d - %+v\n", i, arg)
}
}
}

if e.expectedBatch != nil {
msg += fmt.Sprintf("%v\n", e.expectedBatch)
}
return msg + e.commonExpectation.String()
}

// WillReturnBatchResults arranges for an expected SendBatch() to return given batch results
func (e *ExpectedBatch) WillReturnBatchResults(br *BatchResults) *ExpectedBatch {
e.expectedBatchResponse = &batchResults{br: br, ex: e}
return e
}

// BatchResultsWillBeClosed indicates that batchResults has to be closed.
// batchMustBeClosed will be checked in pgxmock.ExpectationsWereMet()
func (e *ExpectedBatch) BatchResultsWillBeClosed() *ExpectedBatch {
e.batchMustBeClosed = true
return e
}
Loading
Loading