// Amazon Kinesis producer
// A KPL-like batch producer for Amazon Kinesis, built on top of the official AWS SDK for Go
// and using the same aggregation format that the KPL uses.
//
// Note: this project started as a fork of `tj/go-kinesis`. If you are not interested in the
// KPL aggregation logic, you probably want to check it out.
package producer

import (
"context"
"crypto/md5"
"errors"
"fmt"
"sync"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
k "github.com/aws/aws-sdk-go-v2/service/kinesis"
ktypes "github.com/aws/aws-sdk-go-v2/service/kinesis/types"
"github.com/jpillora/backoff"
)

// Errors
var (
ErrStoppedProducer = errors.New("Unable to Put record. Producer is already stopped")
ErrIllegalPartitionKey = errors.New("Invalid partition key. Length must be at least 1 and at most 256")
ErrRecordSizeExceeded = errors.New("Data must be less than or equal to 1MB in size")
)

// Producer batches records.
type Producer struct {
sync.RWMutex
*Config
aggregator *Aggregator
semaphore semaphore
records chan *ktypes.PutRecordsRequestEntry
failure chan *FailureRecord
done chan struct{}
// Current state of the Producer
// notify is set to true after calling `NotifyFailures`
notify bool
// stopped is set to true after `Stop`ping the Producer.
// This prevents the user from `Put`ting any new data.
stopped bool
}

// New creates a new producer with the given config.
func New(config *Config) *Producer {
config.defaults()
return &Producer{
Config: config,
done: make(chan struct{}),
records: make(chan *ktypes.PutRecordsRequestEntry, config.BacklogCount),
semaphore: make(chan struct{}, config.MaxConnections),
aggregator: new(Aggregator),
}
}
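
// A minimal usage sketch. It assumes an aws.Config value `cfg` loaded elsewhere
// (e.g. via config.LoadDefaultConfig), that Config.Client accepts the client
// returned by kinesis.NewFromConfig, and placeholder stream name and values:
//
//	client := kinesis.NewFromConfig(cfg)
//	pr := producer.New(&producer.Config{
//		StreamName:   "my-stream",
//		BacklogCount: 2000,
//		Client:       client,
//	})
//	pr.Start()
//	if err := pr.Put([]byte("hello"), "key-1"); err != nil {
//		log.Fatal(err)
//	}
//	pr.Stop()
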
// Put `data` using `partitionKey` asynchronously. This method is thread-safe.
//
// Under the covers, the Producer will automatically re-attempt puts in case of
// transient errors.
// When an unrecoverable error is detected (e.g. trying to put to a stream that
// doesn't exist), the message will be returned by the Producer.
// Add a listener with `Producer.NotifyFailures` to handle undeliverable messages.
func (p *Producer) Put(data []byte, partitionKey string) error {
p.RLock()
stopped := p.stopped
p.RUnlock()
if stopped {
return ErrStoppedProducer
}
if len(data) > maxRecordSize {
return ErrRecordSizeExceeded
}
if l := len(partitionKey); l < 1 || l > 256 {
return ErrIllegalPartitionKey
}
nbytes := len(data) + len([]byte(partitionKey))
// if the record size is bigger than the aggregation size,
// handle it as a simple Kinesis record
if nbytes > p.AggregateBatchSize {
p.records <- &ktypes.PutRecordsRequestEntry{
Data: data,
PartitionKey: &partitionKey,
}
} else {
p.Lock()
needToDrain := nbytes+p.aggregator.Size()+md5.Size+len(magicNumber)+partitionKeyIndexSize > maxRecordSize || p.aggregator.Count() >= p.AggregateBatchCount
var (
record *ktypes.PutRecordsRequestEntry
err error
)
if needToDrain {
if record, err = p.aggregator.Drain(); err != nil {
p.Logger.Error("drain aggregator", err)
}
}
p.aggregator.Put(data, partitionKey)
p.Unlock()
// release the lock before piping the record to the records channel:
// the send operation blocks when the backlog is full, and that could
// deadlock if we were still holding the lock
if needToDrain && record != nil {
p.records <- record
}
}
return nil
}
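
// A minimal sketch of handling Put's sentinel errors, assuming a producer
// `pr` like the one constructed above and the standard `errors` and `log`
// packages:
//
//	if err := pr.Put(data, key); err != nil {
//		switch {
//		case errors.Is(err, producer.ErrStoppedProducer):
//			// the producer was stopped; drop or re-route the record
//		case errors.Is(err, producer.ErrRecordSizeExceeded):
//			// the record is over the 1MB limit; split or compress it
//		default:
//			log.Println("put:", err)
//		}
//	}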

// FailureRecord wraps the error with the original Data and PartitionKey of a
// record that could not be delivered.
type FailureRecord struct {
error
Data []byte
PartitionKey string
}

// NotifyFailures registers and returns a listener for handling undeliverable messages.
// Each incoming struct carries a copy of the Data and the PartitionKey, along with
// error information about why publishing failed.
func (p *Producer) NotifyFailures() <-chan *FailureRecord {
p.Lock()
defer p.Unlock()
if !p.notify {
p.notify = true
p.failure = make(chan *FailureRecord, p.BacklogCount)
}
return p.failure
}
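
// A minimal sketch of consuming the failures channel, assuming the producer
// `pr` from the sketch above; the channel is closed by `Stop`, so the range
// loop exits on shutdown:
//
//	go func() {
//		for r := range pr.NotifyFailures() {
//			// r wraps the error and keeps the original Data and PartitionKey
//			log.Printf("failed to put record with key %q: %v", r.PartitionKey, r)
//		}
//	}()
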
// Start the producer
func (p *Producer) Start() {
p.Logger.Info("starting producer", LogValue{"stream", p.StreamName})
go p.loop()
}

// Stop the producer gracefully. Flushes any in-flight data.
func (p *Producer) Stop() {
p.Lock()
p.stopped = true
p.Unlock()
p.Logger.Info("stopping producer", LogValue{"backlog", len(p.records)})
// drain
if record, ok := p.drainIfNeed(); ok {
p.records <- record
}
p.done <- struct{}{}
close(p.records)
// wait
<-p.done
p.semaphore.wait()
// close the failures channel if we notify
p.RLock()
if p.notify {
close(p.failure)
}
p.RUnlock()
p.Logger.Info("stopped producer")
}

// loop and flush at the configured interval, or when the buffer is exceeded.
func (p *Producer) loop() {
size := 0
drain := false
buf := make([]ktypes.PutRecordsRequestEntry, 0, p.BatchCount)
tick := time.NewTicker(p.FlushInterval)
flush := func(msg string) {
p.semaphore.acquire()
go p.flush(buf, msg)
buf = nil
size = 0
}
bufAppend := func(record *ktypes.PutRecordsRequestEntry) {
// the record size limit applies to the total size of the
// partition key and data blob.
rsize := len(record.Data) + len([]byte(*record.PartitionKey))
if size+rsize > p.BatchSize {
flush("batch size")
}
size += rsize
buf = append(buf, *record)
if len(buf) >= p.BatchCount {
flush("batch length")
}
}
defer tick.Stop()
defer close(p.done)
for {
select {
case record, ok := <-p.records:
if drain && !ok {
if size > 0 {
flush("drain")
}
p.Logger.Info("backlog drained")
return
}
bufAppend(record)
case <-tick.C:
if record, ok := p.drainIfNeed(); ok {
bufAppend(record)
}
// if the buffer still contains records
if size > 0 {
flush("interval")
}
case <-p.done:
drain = true
}
}
}
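
// A tuning sketch for the flush behaviour above, using illustrative values
// (these numbers are assumptions; the real defaults come from config.defaults()):
//
//	pr := producer.New(&producer.Config{
//		StreamName:          "my-stream",
//		Client:              client,
//		FlushInterval:       time.Second, // how often the loop flushes on the "interval" path
//		BatchCount:          500,         // flush once the buffer holds this many entries ("batch length")
//		BatchSize:           4 << 20,     // flush before buffered bytes exceed this ("batch size")
//		AggregateBatchCount: 4096,        // drain the aggregator after this many user records
//		AggregateBatchSize:  50 << 10,    // records larger than this bypass aggregation in Put
//	})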

// drainIfNeed drains the aggregator into a single aggregated record
// when it currently holds any data.
func (p *Producer) drainIfNeed() (*ktypes.PutRecordsRequestEntry, bool) {
p.RLock()
needToDrain := p.aggregator.Size() > 0
p.RUnlock()
if needToDrain {
p.Lock()
record, err := p.aggregator.Drain()
p.Unlock()
if err != nil {
p.Logger.Error("drain aggregator", err)
} else {
return record, true
}
}
return nil, false
}

// flush records and retry failures if necessary,
// for example when we get a "ProvisionedThroughputExceededException".
func (p *Producer) flush(records []ktypes.PutRecordsRequestEntry, reason string) {
b := &backoff.Backoff{
Jitter: true,
}
defer p.semaphore.release()
for {
p.Logger.Info("flushing records", LogValue{"reason", reason}, LogValue{"records", len(records)})
out, err := p.Client.PutRecords(context.Background(), &k.PutRecordsInput{
StreamName: aws.String(p.StreamName),
Records: records,
})
if err != nil {
p.Logger.Error("flush", err)
p.RLock()
notify := p.notify
p.RUnlock()
if notify {
p.dispatchFailures(records, err)
}
return
}
if p.Verbose {
for i, r := range out.Records {
values := make([]LogValue, 2)
if r.ErrorCode != nil {
values[0] = LogValue{"ErrorCode", *r.ErrorCode}
values[1] = LogValue{"ErrorMessage", *r.ErrorMessage}
} else {
values[0] = LogValue{"ShardId", *r.ShardId}
values[1] = LogValue{"SequenceNumber", *r.SequenceNumber}
}
p.Logger.Info(fmt.Sprintf("Result[%d]", i), values...)
}
}
failed := *out.FailedRecordCount
if failed == 0 {
return
}
duration := b.Duration()
p.Logger.Info(
"put failures",
LogValue{"failures", failed},
LogValue{"backoff", duration.String()},
)
time.Sleep(duration)
// change the logging state for the next iteration
reason = "retry"
records = failures(records, out.Records)
}
}

// dispatchFailures takes a batch of records, extracts the aggregated ones,
// and pushes them into the failure channel.
func (p *Producer) dispatchFailures(records []ktypes.PutRecordsRequestEntry, err error) {
for _, r := range records {
if isAggregated(&r) {
p.dispatchFailures(extractRecords(&r), err)
} else {
p.failure <- &FailureRecord{err, r.Data, *r.PartitionKey}
}
}
}

// failures returns the failed records as indicated in the response.
func failures(records []ktypes.PutRecordsRequestEntry,
response []ktypes.PutRecordsResultEntry) (out []ktypes.PutRecordsRequestEntry) {
for i, record := range response {
if record.ErrorCode != nil {
out = append(out, records[i])
}
}
return
}