
Commit a171c4e

Merge pull request #19723 from influxdata/19658/refactor_batcher
chore(write): refactor batcher to reuse code
2 parents 7dcaf5c + e914f66 commit a171c4e


2 files changed (+78 -102 lines)


write/batcher.go (+18 -96)
@@ -35,8 +35,21 @@ type Batcher struct {
     Service platform.WriteService // Service receives batches flushed from Batcher.
 }
 
-// Write reads r in batches and sends to the output.
+// Write reads r in batches and writes to a target specified by org and bucket.
 func (b *Batcher) Write(ctx context.Context, org, bucket platform.ID, r io.Reader) error {
+    return b.writeBytes(ctx, r, func(batch []byte) error {
+        return b.Service.Write(ctx, org, bucket, bytes.NewReader(batch))
+    })
+}
+
+// WriteTo reads r in batches and writes to a target specified by filter.
+func (b *Batcher) WriteTo(ctx context.Context, filter platform.BucketFilter, r io.Reader) error {
+    return b.writeBytes(ctx, r, func(batch []byte) error {
+        return b.Service.WriteTo(ctx, filter, bytes.NewReader(batch))
+    })
+}
+
+func (b *Batcher) writeBytes(ctx context.Context, r io.Reader, writeFn func(batch []byte) error) error {
     ctx, cancel := context.WithCancel(ctx)
     defer cancel()
 
@@ -47,7 +60,7 @@ func (b *Batcher) Write(ctx context.Context, org, bucket platform.ID, r io.Reader) error {
     lines := make(chan []byte)
 
     errC := make(chan error, 2)
-    go b.write(ctx, org, bucket, lines, errC)
+    go b.write(ctx, writeFn, lines, errC)
     go b.read(ctx, r, lines, errC)
 
     // we loop twice to check if both read and write have an error. if read exits
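
For orientation, a minimal caller-side sketch of the two entry points after this refactor. The import paths and the concrete WriteService passed in are assumptions for illustration; only the Batcher fields, the Write/WriteTo signatures, and the BucketFilter usage come from this diff.

// Hypothetical usage sketch; import paths and the svc value are assumed, not part of this commit.
package main

import (
    "context"
    "strings"
    "time"

    platform "github.com/influxdata/influxdb/v2" // assumed module path
    "github.com/influxdata/influxdb/v2/write"    // assumed package path for Batcher
)

func example(ctx context.Context, svc platform.WriteService) error {
    b := &write.Batcher{
        MaxFlushBytes:    1 << 20,          // flush once this many bytes are buffered
        MaxFlushInterval: 10 * time.Second, // or after this interval elapses
        Service:          svc,              // destination write service
    }

    // Write targets an org/bucket pair directly.
    if err := b.Write(ctx, platform.ID(1), platform.ID(2), strings.NewReader("m1,t1=v1 f1=1")); err != nil {
        return err
    }

    // WriteTo targets the bucket matched by a filter; both entry points now
    // share the same batching loop through writeBytes.
    org, bucket := platform.ID(1), platform.ID(2)
    filter := platform.BucketFilter{ID: &bucket, OrganizationID: &org}
    return b.WriteTo(ctx, filter, strings.NewReader("m2,t2=v2 f2=2"))
}
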
@@ -99,95 +112,7 @@ func (b *Batcher) read(ctx context.Context, r io.Reader, lines chan<- []byte, errC chan<- error) {
 // finishes when the lines channel is closed or context is done.
 // if an error occurs while writing data to the write service, the error is send in the
 // errC channel and the function returns.
-func (b *Batcher) write(ctx context.Context, org, bucket platform.ID, lines <-chan []byte, errC chan<- error) {
-    flushInterval := b.MaxFlushInterval
-    if flushInterval == 0 {
-        flushInterval = DefaultInterval
-    }
-
-    maxBytes := b.MaxFlushBytes
-    if maxBytes == 0 {
-        maxBytes = DefaultMaxBytes
-    }
-
-    timer := time.NewTimer(flushInterval)
-    defer func() { _ = timer.Stop() }()
-
-    buf := make([]byte, 0, maxBytes)
-    r := bytes.NewReader(buf)
-
-    var line []byte
-    var more = true
-    // if read closes the channel normally, exit the loop
-    for more {
-        select {
-        case line, more = <-lines:
-            if more {
-                buf = append(buf, line...)
-            }
-            // write if we exceed the max lines OR read routine has finished
-            if len(buf) >= maxBytes || (!more && len(buf) > 0) {
-                r.Reset(buf)
-                timer.Reset(flushInterval)
-                if err := b.Service.Write(ctx, org, bucket, r); err != nil {
-                    errC <- err
-                    return
-                }
-                buf = buf[:0]
-            }
-        case <-timer.C:
-            if len(buf) > 0 {
-                r.Reset(buf)
-                timer.Reset(flushInterval)
-                if err := b.Service.Write(ctx, org, bucket, r); err != nil {
-                    errC <- err
-                    return
-                }
-                buf = buf[:0]
-            }
-        case <-ctx.Done():
-            errC <- ctx.Err()
-            return
-        }
-    }
-
-    errC <- nil
-}
-
-func (b *Batcher) WriteTo(ctx context.Context, filter platform.BucketFilter, r io.Reader) error {
-    ctx, cancel := context.WithCancel(ctx)
-    defer cancel()
-
-    if b.Service == nil {
-        return fmt.Errorf("destination write service required")
-    }
-
-    lines := make(chan []byte)
-
-    errC := make(chan error, 2)
-    go b.writeTo(ctx, filter, lines, errC)
-    go b.read(ctx, r, lines, errC)
-
-    // we loop twice to check if both read and write have an error. if read exits
-    // cleanly, then we still want to wait for write.
-    for i := 0; i < 2; i++ {
-        select {
-        case <-ctx.Done():
-            return ctx.Err()
-        case err := <-errC:
-            // onky if there is any error, exit immediately.
-            if err != nil {
-                return err
-            }
-        }
-    }
-    return nil
-}
-
-// finishes when the lines channel is closed or context is done.
-// if an error occurs while writing data to the write service, the error is send in the
-// errC channel and the function returns.
-func (b *Batcher) writeTo(ctx context.Context, filter platform.BucketFilter, lines <-chan []byte, errC chan<- error) {
+func (b *Batcher) write(ctx context.Context, writeFn func(batch []byte) error, lines <-chan []byte, errC chan<- error) {
     flushInterval := b.MaxFlushInterval
     if flushInterval == 0 {
         flushInterval = DefaultInterval
@@ -202,7 +127,6 @@ func (b *Batcher) writeTo(ctx context.Context, filter platform.BucketFilter, lines <-chan []byte, errC chan<- error) {
     defer func() { _ = timer.Stop() }()
 
     buf := make([]byte, 0, maxBytes)
-    r := bytes.NewReader(buf)
 
     var line []byte
     var more = true
@@ -215,19 +139,17 @@ func (b *Batcher) writeTo(ctx context.Context, filter platform.BucketFilter, lines <-chan []byte, errC chan<- error) {
             }
             // write if we exceed the max lines OR read routine has finished
             if len(buf) >= maxBytes || (!more && len(buf) > 0) {
-                r.Reset(buf)
                 timer.Reset(flushInterval)
-                if err := b.Service.WriteTo(ctx, filter, r); err != nil {
+                if err := writeFn(buf); err != nil {
                     errC <- err
                     return
                 }
                 buf = buf[:0]
             }
         case <-timer.C:
             if len(buf) > 0 {
-                r.Reset(buf)
                 timer.Reset(flushInterval)
-                if err := b.Service.WriteTo(ctx, filter, r); err != nil {
+                if err := writeFn(buf); err != nil {
                     errC <- err
                     return
                 }
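
The reuse hinges on the flush loop taking a writeFn func(batch []byte) error instead of a concrete destination: Write and WriteTo each wrap their Service call in a closure and hand it to the same loop. A standalone sketch of that callback pattern, with illustrative names that are not from the repository:

// Minimal sketch of the "inject the destination as a closure" pattern used above.
package main

import (
    "bytes"
    "fmt"
)

// flushBatches drains lines, buffers them up to maxBytes, and hands every full
// batch (and the final partial one) to writeFn. Because the destination is a
// callback, the same loop can serve any number of write targets.
func flushBatches(lines <-chan []byte, maxBytes int, writeFn func(batch []byte) error) error {
    buf := make([]byte, 0, maxBytes)
    for line := range lines {
        buf = append(buf, line...)
        if len(buf) >= maxBytes {
            if err := writeFn(buf); err != nil {
                return err
            }
            buf = buf[:0]
        }
    }
    if len(buf) > 0 {
        return writeFn(buf)
    }
    return nil
}

func main() {
    lines := make(chan []byte, 2)
    lines <- []byte("m1,t1=v1 f1=1\n")
    lines <- []byte("m2,t2=v2 f2=2\n")
    close(lines)

    // The "destination" here is just an in-memory buffer wrapped in a closure,
    // standing in for Service.Write or Service.WriteTo.
    var sink bytes.Buffer
    if err := flushBatches(lines, 16, func(batch []byte) error {
        _, err := sink.Write(batch)
        return err
    }); err != nil {
        fmt.Println("flush error:", err)
        return
    }
    fmt.Print(sink.String())
}
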

write/batcher_test.go (+60 -6)
@@ -2,6 +2,7 @@ package write
 
 import (
     "bufio"
+    "bytes"
     "context"
     "fmt"
     "io"
@@ -295,8 +296,11 @@ func TestBatcher_write(t *testing.T) {
                MaxFlushInterval: tt.fields.MaxFlushInterval,
                Service:          svc,
            }
+           writeFn := func(batch []byte) error {
+               return svc.Write(ctx, tt.args.org, tt.args.bucket, bytes.NewReader(batch))
+           }
 
-           go b.write(ctx, tt.args.org, tt.args.bucket, tt.args.lines, tt.args.errC)
+           go b.write(ctx, writeFn, tt.args.lines, tt.args.errC)
 
            if cancel != nil {
                cancel()
@@ -325,6 +329,17 @@ func TestBatcher_write(t *testing.T) {
 }
 
 func TestBatcher_Write(t *testing.T) {
+    createReader := func(data string) func() io.Reader {
+        if data == "error" {
+            return func() io.Reader {
+                return &errorReader{}
+            }
+        }
+        return func() io.Reader {
+            return strings.NewReader(data)
+        }
+    }
+
     type fields struct {
         MaxFlushBytes    int
         MaxFlushInterval time.Duration
@@ -333,7 +348,7 @@ func TestBatcher_Write(t *testing.T) {
         writeError bool
         org        platform.ID
         bucket     platform.ID
-        r          io.Reader
+        r          func() io.Reader
     }
     tests := []struct {
         name        string
@@ -351,7 +366,7 @@ func TestBatcher_Write(t *testing.T) {
            args: args{
                org:    platform.ID(1),
                bucket: platform.ID(2),
-               r:      strings.NewReader("m1,t1=v1 f1=1"),
+               r:      createReader("m1,t1=v1 f1=1"),
            },
            want:        "m1,t1=v1 f1=1",
            wantFlushes: 1,
@@ -364,7 +379,7 @@ func TestBatcher_Write(t *testing.T) {
            args: args{
                org:    platform.ID(1),
                bucket: platform.ID(2),
-               r:      strings.NewReader("m1,t1=v1 f1=1\nm2,t2=v2 f2=2\nm3,t3=v3 f3=3"),
+               r:      createReader("m1,t1=v1 f1=1\nm2,t2=v2 f2=2\nm3,t3=v3 f3=3"),
            },
            want:        "m3,t3=v3 f3=3",
            wantFlushes: 3,
@@ -375,7 +390,7 @@ func TestBatcher_Write(t *testing.T) {
            args: args{
                org:    platform.ID(1),
                bucket: platform.ID(2),
-               r:      &errorReader{},
+               r:      createReader("error"),
            },
            wantErr: true,
        },
@@ -407,7 +422,46 @@ func TestBatcher_Write(t *testing.T) {
            }
 
            ctx := context.Background()
-           if err := b.Write(ctx, tt.args.org, tt.args.bucket, tt.args.r); (err != nil) != tt.wantErr {
+           if err := b.Write(ctx, tt.args.org, tt.args.bucket, tt.args.r()); (err != nil) != tt.wantErr {
+               t.Errorf("Batcher.Write() error = %v, wantErr %v", err, tt.wantErr)
+           }
+
+           if gotFlushes != tt.wantFlushes {
+               t.Errorf("%q. Batcher.Write() flushes %d want %d", tt.name, gotFlushes, tt.wantFlushes)
+           }
+           if !cmp.Equal(got, tt.want) {
+               t.Errorf("%q. Batcher.Write() = -got/+want %s", tt.name, cmp.Diff(got, tt.want))
+           }
+       })
+       // test the same data, but now with WriteTo function
+       t.Run("WriteTo_"+tt.name, func(t *testing.T) {
+           // mocking the write service here to either return an error
+           // or get back all the bytes from the reader.
+           var (
+               got        string
+               gotFlushes int
+           )
+           svc := &mock.WriteService{
+               WriteToF: func(ctx context.Context, filter platform.BucketFilter, r io.Reader) error {
+                   if tt.args.writeError {
+                       return fmt.Errorf("error")
+                   }
+                   b, err := ioutil.ReadAll(r)
+                   got = string(b)
+                   gotFlushes++
+                   return err
+               },
+           }
+
+           b := &Batcher{
+               MaxFlushBytes:    tt.fields.MaxFlushBytes,
+               MaxFlushInterval: tt.fields.MaxFlushInterval,
+               Service:          svc,
+           }
+
+           ctx := context.Background()
+           bucketFilter := platform.BucketFilter{ID: &tt.args.bucket, OrganizationID: &tt.args.org}
+           if err := b.WriteTo(ctx, bucketFilter, tt.args.r()); (err != nil) != tt.wantErr {
                t.Errorf("Batcher.Write() error = %v, wantErr %v", err, tt.wantErr)
            }
 
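
In the test, the r field changes from io.Reader to a factory func() io.Reader because every case now runs twice, once through Write and once through WriteTo, and a strings.Reader is exhausted after its first use. A small sketch of why each run needs a fresh reader (illustrative only, not part of this commit):

// Demonstrates why the test switched to a reader factory (createReader above).
package main

import (
    "fmt"
    "io/ioutil"
    "strings"
)

func main() {
    // A single reader is drained by the first consumer; the second sees nothing.
    shared := strings.NewReader("m1,t1=v1 f1=1")
    first, _ := ioutil.ReadAll(shared)
    second, _ := ioutil.ReadAll(shared)
    fmt.Printf("shared reader: first=%q second=%q\n", first, second)

    // A factory hands each consumer its own reader, so both see the full data,
    // which is what createReader gives the Write and WriteTo subtests.
    newReader := func() *strings.Reader { return strings.NewReader("m1,t1=v1 f1=1") }
    a, _ := ioutil.ReadAll(newReader())
    b, _ := ioutil.ReadAll(newReader())
    fmt.Printf("factory: first=%q second=%q\n", a, b)
}
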
