diff --git a/bigquery/storage/managedwriter/managed_stream.go b/bigquery/storage/managedwriter/managed_stream.go
index 9a1f627a825d..0a4d1567252c 100644
--- a/bigquery/storage/managedwriter/managed_stream.go
+++ b/bigquery/storage/managedwriter/managed_stream.go
@@ -242,10 +242,76 @@ func (ms *ManagedStream) openWithRetry() (storagepb.BigQueryWrite_AppendRowsClie
 	}
 }
 
-// append handles the details of adding sending an append request on a stream. Appends are sent on a long
+// lockingAppend handles a single append attempt. When successful, it returns the number of rows
+// in the request for metrics tracking.
+func (ms *ManagedStream) lockingAppend(requestCtx context.Context, pw *pendingWrite) (int64, error) {
+
+	// Don't bother calling/retrying if this append's context is already expired.
+	if err := requestCtx.Err(); err != nil {
+		return 0, err
+	}
+
+	// Critical section. Things that need to happen while holding the lock:
+	//
+	// * Getting the stream connection (in case of reconnects)
+	// * Issuing the append request
+	// * Adding the pending write to the channel to keep ordering correct on response
+	ms.mu.Lock()
+	defer ms.mu.Unlock()
+
+	var arc *storagepb.BigQueryWrite_AppendRowsClient
+	var ch chan *pendingWrite
+	var err error
+
+	// If an updated schema is present, we need to reconnect the stream and update the reference
+	// schema for the stream.
+	reconnect := false
+	if pw.newSchema != nil && !proto.Equal(pw.newSchema, ms.schemaDescriptor) {
+		reconnect = true
+		ms.schemaDescriptor = proto.Clone(pw.newSchema).(*descriptorpb.DescriptorProto)
+	}
+	arc, ch, err = ms.getStream(arc, reconnect)
+	if err != nil {
+		return 0, err
+	}
+
+	// Resolve the special work for the first append on a stream.
+	var req *storagepb.AppendRowsRequest
+	ms.streamSetup.Do(func() {
+		reqCopy := proto.Clone(pw.request).(*storagepb.AppendRowsRequest)
+		reqCopy.WriteStream = ms.streamSettings.streamID
+		reqCopy.GetProtoRows().WriterSchema = &storagepb.ProtoSchema{
+			ProtoDescriptor: ms.schemaDescriptor,
+		}
+		if ms.streamSettings.TraceID != "" {
+			reqCopy.TraceId = ms.streamSettings.TraceID
+		}
+		req = reqCopy
+	})
+
+	if req != nil {
+		// First append in a new connection needs properties like schema and stream name set.
+		err = (*arc).Send(req)
+	} else {
+		// Subsequent requests need no modification.
+		err = (*arc).Send(pw.request)
+	}
+	if err != nil {
+		return 0, err
+	}
+	// Compute numRows before sending; once we pass ownership to the channel the request may be
+	// cleared.
+	numRows := int64(len(pw.request.GetProtoRows().Rows.GetSerializedRows()))
+	ch <- pw
+	return numRows, nil
+}
+
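Aside: the new lockingAppend compresses the whole attempt into one critical section released via defer, with sync.Once handling first-append setup and a channel preserving response ordering. Below is a minimal standalone sketch of that shape; the sender type, lockingSend, and the string payload are illustrative stand-ins, not part of this patch or the managedwriter API.

package main

import (
	"context"
	"fmt"
	"sync"
)

type sender struct {
	mu    sync.Mutex
	setup sync.Once
	ch    chan string
}

func (s *sender) lockingSend(ctx context.Context, item string) error {
	// Fail fast if the caller's context is already done.
	if err := ctx.Err(); err != nil {
		return err
	}
	// One critical section covers setup, send, and enqueue, so the
	// ordering channel always matches the send order.
	s.mu.Lock()
	defer s.mu.Unlock()
	s.setup.Do(func() {
		// The first send on a connection carries extra settings
		// (schema, stream name); here it is just recorded.
		fmt.Println("first-send setup")
	})
	select {
	case s.ch <- item: // pass ownership of the item to the receiver
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

func main() {
	s := &sender{ch: make(chan string, 1)}
	if err := s.lockingSend(context.Background(), "row-1"); err != nil {
		fmt.Println("append failed:", err)
	}
	fmt.Println("queued:", <-s.ch)
}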
+// appendWithRetry handles the details of sending an append request on a stream. Appends are sent on a long
 // lived bidirectional network stream, with it's own managed context (ms.ctx). requestCtx is checked
 // for expiry to enable faster failures, it is not propagated more deeply.
-func (ms *ManagedStream) append(requestCtx context.Context, pw *pendingWrite, opts ...gax.CallOption) error {
+func (ms *ManagedStream) appendWithRetry(requestCtx context.Context, pw *pendingWrite, opts ...gax.CallOption) error {
+
+	// Resolve retry settings.
 	var settings gax.CallSettings
 	for _, opt := range opts {
 		opt.Resolve(&settings)
@@ -255,104 +321,43 @@ func (ms *ManagedStream) append(requestCtx context.Context, pw *pendingWrite, op
 		r = settings.Retry()
 	}
 
-	var arc *storagepb.BigQueryWrite_AppendRowsClient
-	var ch chan *pendingWrite
-	var err error
-
 	for {
-		// critical section:  Things that need to happen inside the critical section:
-		//
-		// * Getting the stream connection (in case of reconnects)
-		// * Issuing the append request
-		// * Adding the pending write to the channel to keep ordering correct on response
-		ms.mu.Lock()
-
-		// Don't both calling/retrying if this append's context is already expired.
-		if err = requestCtx.Err(); err != nil {
-			return err
-		}
-
-		// If an updated schema is present, we need to reconnect the stream and update the reference
-		// schema for the stream.
-		reconnect := false
-		if pw.newSchema != nil && !proto.Equal(pw.newSchema, ms.schemaDescriptor) {
-			reconnect = true
-			ms.schemaDescriptor = proto.Clone(pw.newSchema).(*descriptorpb.DescriptorProto)
-		}
-		arc, ch, err = ms.getStream(arc, reconnect)
-		if err != nil {
-			return err
-		}
-
-		// Resolve the special work for the first append on a stream.
-		var req *storagepb.AppendRowsRequest
-		ms.streamSetup.Do(func() {
-			reqCopy := proto.Clone(pw.request).(*storagepb.AppendRowsRequest)
-			reqCopy.WriteStream = ms.streamSettings.streamID
-			reqCopy.GetProtoRows().WriterSchema = &storagepb.ProtoSchema{
-				ProtoDescriptor: ms.schemaDescriptor,
+		numRows, appendErr := ms.lockingAppend(requestCtx, pw)
+		if appendErr != nil {
+			// Append yielded an error. Retry by continuing the loop, or return the error.
+			status := grpcstatus.Convert(appendErr)
+			if status != nil {
+				ctx, _ := tag.New(ms.ctx, tag.Insert(keyError, status.Code().String()))
+				recordStat(ctx, AppendRequestErrors, 1)
 			}
-			if ms.streamSettings.TraceID != "" {
-				reqCopy.TraceId = ms.streamSettings.TraceID
+			bo, shouldRetry := r.Retry(appendErr)
+			if shouldRetry {
+				if err := gax.Sleep(ms.ctx, bo); err != nil {
+					return err
+				}
+				continue
 			}
-			req = reqCopy
-		})
-
-		if req != nil {
-			// First append in a new connection needs properties like schema and stream name set.
-			err = (*arc).Send(req)
-		} else {
-			// Subsequent requests need no modification.
-			err = (*arc).Send(pw.request)
-		}
-		if err == nil {
-			// Compute numRows, once we pass ownership to the channel the request may be
-			// cleared.
-			numRows := int64(len(pw.request.GetProtoRows().Rows.GetSerializedRows()))
-			ch <- pw
-			// We've passed ownership of the pending write to the channel.
-			// It's now responsible for marking the request done, we're done
-			// with the critical section.
+			// We've got a non-retriable error; propagate it up and mark the write done.
+			ms.mu.Lock()
+			ms.err = appendErr
+			pw.markDone(NoStreamOffset, appendErr, ms.fc)
 			ms.mu.Unlock()
-
-			// Record stats and return.
-			recordStat(ms.ctx, AppendRequests, 1)
-			recordStat(ms.ctx, AppendRequestBytes, int64(pw.reqSize))
-			recordStat(ms.ctx, AppendRequestRows, numRows)
-			return nil
-		}
-		// Unlock the mutex for error cases.
-		ms.mu.Unlock()
-
-		// Append yielded an error. Retry by continuing or return.
-		status := grpcstatus.Convert(err)
-		if status != nil {
-			ctx, _ := tag.New(ms.ctx, tag.Insert(keyError, status.Code().String()))
-			recordStat(ctx, AppendRequestErrors, 1)
-		}
-		bo, shouldRetry := r.Retry(err)
-		if shouldRetry {
-			if err := gax.Sleep(ms.ctx, bo); err != nil {
-				return err
-			}
-			continue
+			return appendErr
 		}
-		// We've got a non-retriable error, so propagate that up. and mark the write done.
-		ms.mu.Lock()
-		ms.err = err
-		pw.markDone(NoStreamOffset, err, ms.fc)
-		ms.mu.Unlock()
-		return err
+		recordStat(ms.ctx, AppendRequests, 1)
+		recordStat(ms.ctx, AppendRequestBytes, int64(pw.reqSize))
+		recordStat(ms.ctx, AppendRequestRows, numRows)
+		return nil
 	}
 }
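With the critical section factored out, appendWithRetry is reduced to a classify-and-backoff loop that never holds the lock while sleeping. A minimal sketch of that loop shape follows, assuming the gax-go/v2 module for Backoff and context-aware Sleep; withRetry, attempt, and retriable are hypothetical stand-ins for this illustration.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/googleapis/gax-go/v2"
)

var errTransient = errors.New("transient failure")

// withRetry mirrors the loop shape: attempt, classify, sleep (respecting
// ctx), and loop.
func withRetry(ctx context.Context, attempt func() error, retriable func(error) bool) error {
	bo := gax.Backoff{Initial: 10 * time.Millisecond, Max: time.Second, Multiplier: 2}
	for {
		err := attempt()
		if err == nil {
			return nil
		}
		if !retriable(err) {
			// Non-retriable: propagate to the caller.
			return err
		}
		// gax.Sleep returns early if ctx is cancelled, like the patch's
		// gax.Sleep(ms.ctx, bo).
		if err := gax.Sleep(ctx, bo.Pause()); err != nil {
			return err
		}
	}
}

func main() {
	tries := 0
	err := withRetry(context.Background(),
		func() error {
			tries++
			if tries < 3 {
				return errTransient
			}
			return nil
		},
		func(err error) bool { return errors.Is(err, errTransient) })
	fmt.Println("tries:", tries, "err:", err)
}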
 
 // Close closes a managed stream.
 func (ms *ManagedStream) Close() error {
-
-	var arc *storagepb.BigQueryWrite_AppendRowsClient
-	// Critical section: get connection, close, mark closed.
 	ms.mu.Lock()
+	defer ms.mu.Unlock()
+
+	var arc *storagepb.BigQueryWrite_AppendRowsClient
 	arc, ch, err := ms.getStream(arc, false)
 	if err != nil {
 		return err
@@ -361,18 +366,22 @@ func (ms *ManagedStream) Close() error {
 		return fmt.Errorf("no stream exists")
 	}
 	err = (*arc).CloseSend()
-	if err == nil {
-		close(ch)
-	}
-	ms.err = io.EOF
-
-	// Done with the critical section.
-	ms.mu.Unlock()
-	// Propagate cancellation.
+	// Regardless of the outcome of CloseSend(), we're done with this channel.
+	close(ch)
+	// Additionally, cancel the underlying context for the stream; we don't allow re-opening.
 	if ms.cancel != nil {
 		ms.cancel()
+		ms.cancel = nil
 	}
-	return err
+
+	if err != nil {
+		// If CloseSend errored, save that as the stream error and return.
+		ms.err = err
+		return err
+	}
+	// For normal operation, mark the stream error as io.EOF and return.
+	ms.err = io.EOF
+	return nil
 }
 
 // AppendRows sends the append requests to the service, and returns a single AppendResult for tracking
@@ -401,7 +410,7 @@ func (ms *ManagedStream) AppendRows(ctx context.Context, data [][]byte, opts ...
 	var appendErr error
 	go func() {
 		select {
-		case errCh <- ms.append(ctx, pw):
+		case errCh <- ms.appendWithRetry(ctx, pw):
 		case <-ctx.Done():
 		}
 		close(errCh)
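The reworked Close takes the lock once with a deferred unlock, closes the ordering channel regardless of the CloseSend outcome, and nils out cancel so the stream can't be re-opened. A standalone sketch of that idempotent-close pattern follows; the stream type is illustrative, and it guards repeated closes with nil checks rather than the getStream plumbing the real code uses.

package main

import (
	"fmt"
	"io"
	"sync"
)

type stream struct {
	mu     sync.Mutex
	ch     chan struct{}
	cancel func()
	err    error
}

func (s *stream) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock() // deferred unlock: no early-return path can leak the lock
	if s.ch != nil {
		close(s.ch) // done with the channel regardless of other outcomes
		s.ch = nil
	}
	if s.cancel != nil {
		s.cancel() // the stream cannot be re-opened after this
		s.cancel = nil
	}
	if s.err == nil {
		s.err = io.EOF // terminal marker for a clean shutdown
	}
	return nil
}

func main() {
	s := &stream{ch: make(chan struct{}), cancel: func() { fmt.Println("cancelled") }}
	s.Close()
	s.Close() // a second close must not panic or deadlock
	fmt.Println("terminal err:", s.err)
}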
diff --git a/bigquery/storage/managedwriter/managed_stream_test.go b/bigquery/storage/managedwriter/managed_stream_test.go
index f523afbdb035..c36ed0c150e1 100644
--- a/bigquery/storage/managedwriter/managed_stream_test.go
+++ b/bigquery/storage/managedwriter/managed_stream_test.go
@@ -16,6 +16,7 @@ package managedwriter
 
 import (
 	"context"
+	"errors"
 	"runtime"
 	"testing"
 	"time"
@@ -94,6 +95,7 @@ type testAppendRowsClient struct {
 	requests []*storagepb.AppendRowsRequest
 	sendF    func(*storagepb.AppendRowsRequest) error
 	recvF    func() (*storagepb.AppendRowsResponse, error)
+	closeF   func() error
 }
 
 func (tarc *testAppendRowsClient) Send(req *storagepb.AppendRowsRequest) error {
@@ -104,6 +106,10 @@ func (tarc *testAppendRowsClient) Recv() (*storagepb.AppendRowsResponse, error)
 	return tarc.recvF()
 }
 
+func (tarc *testAppendRowsClient) CloseSend() error {
+	return tarc.closeF()
+}
+
 // openTestArc handles wiring in a test AppendRowsClient into a managedstream by providing the open function.
 func openTestArc(testARC *testAppendRowsClient, sendF func(req *storagepb.AppendRowsRequest) error, recvF func() (*storagepb.AppendRowsResponse, error)) func(s string, opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
 	sF := func(req *storagepb.AppendRowsRequest) error {
@@ -123,6 +129,9 @@ func openTestArc(testARC *testAppendRowsClient, sendF func(req *storagepb.Append
 	}
 	testARC.sendF = sF
 	testARC.recvF = rF
+	testARC.closeF = func() error {
+		return nil
+	}
 	return func(s string, opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
 		testARC.openCount = testARC.openCount + 1
 		return testARC, nil
@@ -291,6 +300,89 @@ func TestManagedStream_AppendWithDeadline(t *testing.T) {
 
 }
 
+func TestManagedStream_AppendDeadlocks(t *testing.T) {
+	// Ensure we don't deadlock by issuing two appends.
+	testCases := []struct {
+		desc       string
+		openErrors []error
+		ctx        context.Context
+		respErr    error
+	}{
+		{
+			desc:       "no errors",
+			openErrors: []error{nil, nil},
+			ctx:        context.Background(),
+			respErr:    nil,
+		},
+		{
+			desc:       "cancelled caller context",
+			openErrors: []error{nil, nil},
+			ctx: func() context.Context {
+				cctx, cancel := context.WithCancel(context.Background())
+				cancel()
+				return cctx
+			}(),
+			respErr: context.Canceled,
+		},
+		{
+			desc:       "expired caller context",
+			openErrors: []error{nil, nil},
+			ctx: func() context.Context {
+				cctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+				defer cancel()
+				time.Sleep(2 * time.Millisecond)
+				return cctx
+			}(),
+			respErr: context.DeadlineExceeded,
+		},
+		{
+			desc:       "errored getstream",
+			openErrors: []error{status.Errorf(codes.ResourceExhausted, "some error"), status.Errorf(codes.ResourceExhausted, "some error")},
+			ctx:        context.Background(),
+			respErr:    status.Errorf(codes.ResourceExhausted, "some error"),
+		},
+	}
+
+	for _, tc := range testCases {
+		openF := openTestArc(&testAppendRowsClient{}, nil, nil)
+		ms := &ManagedStream{
+			ctx: context.Background(),
+			open: func(s string, opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
+				if len(tc.openErrors) == 0 {
+					panic("out of open errors")
+				}
+				curErr := tc.openErrors[0]
+				tc.openErrors = tc.openErrors[1:]
+				if curErr == nil {
+					return openF(s, opts...)
+				}
+				return nil, curErr
+			},
+			streamSettings: &streamSettings{
+				streamID: "foo",
+			},
+		}
+
+		// First append.
+		pw := newPendingWrite([][]byte{[]byte("foo")})
+		gotErr := ms.appendWithRetry(tc.ctx, pw)
+		if !errors.Is(gotErr, tc.respErr) {
+			t.Errorf("%s first response: got %v, want %v", tc.desc, gotErr, tc.respErr)
+		}
+		// Second append.
+		pw = newPendingWrite([][]byte{[]byte("bar")})
+		gotErr = ms.appendWithRetry(tc.ctx, pw)
+		if !errors.Is(gotErr, tc.respErr) {
+			t.Errorf("%s second response: got %v, want %v", tc.desc, gotErr, tc.respErr)
+		}
+
+		// Issue two closes to ensure we're not deadlocking there either.
+		ms.Close()
+		ms.Close()
+	}
+
+}
+
 func TestManagedStream_LeakingGoroutines(t *testing.T) {
 	ctx := context.Background()
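The new test exercises every case twice — two appends, then two closes — because lock-ordering bugs typically surface only on the second call, once a mutex or channel has been left held. Below is a condensed standalone sketch of the same table-driven pattern, with a hypothetical doWork standing in for appendWithRetry.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// doWork stands in for the code under test; it only reports context state.
func doWork(ctx context.Context) error { return ctx.Err() }

func main() {
	cases := []struct {
		desc    string
		ctx     context.Context
		wantErr error
	}{
		{"live context", context.Background(), nil},
		{"cancelled", func() context.Context {
			cctx, cancel := context.WithCancel(context.Background())
			cancel()
			return cctx
		}(), context.Canceled},
		{"expired", func() context.Context {
			cctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
			defer cancel()
			time.Sleep(2 * time.Millisecond)
			return cctx
		}(), context.DeadlineExceeded},
	}
	for _, tc := range cases {
		// Run twice: a hang on the second call indicates a held lock.
		for i := 0; i < 2; i++ {
			if got := doWork(tc.ctx); !errors.Is(got, tc.wantErr) {
				fmt.Printf("%s call %d: got %v, want %v\n", tc.desc, i+1, got, tc.wantErr)
			}
		}
	}
}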
diff --git a/bigquery/storage/managedwriter/retry.go b/bigquery/storage/managedwriter/retry.go
index 1f272a999327..e598a2d806bd 100644
--- a/bigquery/storage/managedwriter/retry.go
+++ b/bigquery/storage/managedwriter/retry.go
@@ -15,6 +15,8 @@ package managedwriter
 
 import (
+	"context"
+	"errors"
 	"time"
 
 	"github.com/googleapis/gax-go/v2"
@@ -31,7 +33,11 @@ func (r *defaultRetryer) Retry(err error) (pause time.Duration, shouldRetry bool
 	// retry predicates in addition to statuscode-based.
 	s, ok := status.FromError(err)
 	if !ok {
-		// non-status based errors as retryable
+		// Treat context errors as non-retriable.
+		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+			return r.bo.Pause(), false
+		}
+		// Any other non-status-based errors are treated as retryable.
 		return r.bo.Pause(), true
 	}
 	switch s.Code() {
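The retryer now classifies in a fixed order: gRPC status codes first, then context errors (never worth retrying), then a retriable default for everything else. A standalone sketch of that classification follows; the two retriable codes are chosen purely for illustration and are not the library's actual predicate set.

package main

import (
	"context"
	"errors"
	"fmt"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

func shouldRetry(err error) bool {
	// gRPC status errors are classified by code.
	if s, ok := status.FromError(err); ok {
		switch s.Code() {
		case codes.Unavailable, codes.ResourceExhausted:
			return true
		default:
			return false
		}
	}
	// Context expiry/cancellation will not improve on retry.
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return false
	}
	// Other non-status errors default to retryable.
	return true
}

func main() {
	fmt.Println(shouldRetry(status.Error(codes.Unavailable, "try again"))) // true
	fmt.Println(shouldRetry(context.Canceled))                            // false
	fmt.Println(shouldRetry(errors.New("connection reset")))              // true
}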