batch.go
package pgutil

import (
	"context"
	"fmt"
)

// BatchInserter bulk-inserts rows into a single table, accumulating values
// and issuing multi-row insert queries that stay within the Postgres
// placeholder limit.
type BatchInserter struct {
	db               DB
	numColumns       int
	maxBatchSize     int
	maxCapacity      int
	queryBuilder     *batchQueryBuilder
	returningScanner ScanFunc
	values           []any
}

// maxNumPostgresParameters is the maximum number of placeholder parameters
// Postgres accepts in a single query.
const maxNumPostgresParameters = 65535

// NewBatchInserter creates a batch inserter targeting the given table and
// columns. maxBatchSize is the largest multiple of numColumns that fits
// under the Postgres parameter limit, and maxCapacity reserves room for one
// additional row's worth of values.
func NewBatchInserter(db DB, tableName string, columnNames []string, configs ...BatchInserterConfigFunc) *BatchInserter {
	var (
		options          = getBatchInserterOptions(configs)
		numColumns       = len(columnNames)
		maxBatchSize     = (maxNumPostgresParameters / numColumns) * numColumns
		maxCapacity      = maxBatchSize + numColumns
		queryBuilder     = newBatchQueryBuilder(tableName, columnNames, options.onConflictClause, options.returningClause)
		returningScanner = options.returningScanner
	)

	return &BatchInserter{
		db:               db,
		numColumns:       numColumns,
		maxBatchSize:     maxBatchSize,
		maxCapacity:      maxCapacity,
		queryBuilder:     queryBuilder,
		returningScanner: returningScanner,
		values:           make([]any, 0, maxCapacity),
	}
}
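
// Sizing example (illustrative): with a 4-column table, maxBatchSize =
// (65535/4)*4 = 65532 parameter values, i.e. 16383 complete rows per
// executed query, and maxCapacity = 65532 + 4 = 65536 reserves capacity
// for one additional row in the values slice.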

// Insert queues a single row of values, which must contain exactly one value
// per configured column. When the number of pending values reaches the
// maximum batch size, the batch is flushed immediately.
func (i *BatchInserter) Insert(ctx context.Context, values ...any) error {
	if len(values) != i.numColumns {
		return fmt.Errorf("received %d values for %d columns", len(values), i.numColumns)
	}

	i.values = append(i.values, values...)

	if len(i.values) >= i.maxBatchSize {
		return i.Flush(ctx)
	}

	return nil
}

// Flush executes an insert for the pending values, if any. At most one full
// batch is sent per call; any values beyond that remain queued for a
// subsequent flush.
func (i *BatchInserter) Flush(ctx context.Context) error {
	if len(i.values) == 0 {
		return nil
	}

	// Slice off at most one full batch and retain the remainder.
	n := i.maxBatchSize
	if len(i.values) < i.maxBatchSize {
		n = len(i.values)
	}
	batch := i.values[:n]
	i.values = append(make([]any, 0, i.maxCapacity), i.values[n:]...)

	batchSize := len(batch)
	query := i.queryBuilder.build(batchSize)
	return NewRowScanner(i.returningScanner)(i.db.Query(ctx, RawQuery(query, batch...)))
}