diff --git a/internal/transport/controlbuf.go b/internal/transport/controlbuf.go index 83c3829826ae..ce8fb90655e8 100644 --- a/internal/transport/controlbuf.go +++ b/internal/transport/controlbuf.go @@ -193,7 +193,7 @@ type goAway struct { code http2.ErrCode debugData []byte headsUp bool - closeConn error // if set, loopyWriter will exit, resulting in conn closure + closeConn error // if set, loopyWriter will exit with this error } func (*goAway) isTransportResponseFrame() bool { return false } @@ -495,21 +495,22 @@ type loopyWriter struct { ssGoAwayHandler func(*goAway) (bool, error) } -func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator, conn net.Conn, logger *grpclog.PrefixLogger) *loopyWriter { +func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator, conn net.Conn, logger *grpclog.PrefixLogger, goAwayHandler func(*goAway) (bool, error)) *loopyWriter { var buf bytes.Buffer l := &loopyWriter{ - side: s, - cbuf: cbuf, - sendQuota: defaultWindowSize, - oiws: defaultWindowSize, - estdStreams: make(map[uint32]*outStream), - activeStreams: newOutStreamList(), - framer: fr, - hBuf: &buf, - hEnc: hpack.NewEncoder(&buf), - bdpEst: bdpEst, - conn: conn, - logger: logger, + side: s, + cbuf: cbuf, + sendQuota: defaultWindowSize, + oiws: defaultWindowSize, + estdStreams: make(map[uint32]*outStream), + activeStreams: newOutStreamList(), + framer: fr, + hBuf: &buf, + hEnc: hpack.NewEncoder(&buf), + bdpEst: bdpEst, + conn: conn, + logger: logger, + ssGoAwayHandler: goAwayHandler, } return l } diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index deba0c4d9ef4..fe621f991f79 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -408,10 +408,10 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts readerErrCh := make(chan error, 1) go t.reader(readerErrCh) defer func() { - if err == nil { - err = <-readerErrCh - } if err != nil { + // writerDone should be closed since the loopy goroutine + // wouldn't have started in the case this function returns an error. + close(t.writerDone) t.Close(err) } }() @@ -458,8 +458,12 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts if err := t.framer.writer.Flush(); err != nil { return nil, err } + // Block until the server preface is received successfully or an error occurs. + if err = <-readerErrCh; err != nil { + return nil, err + } go func() { - t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger) + t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger, t.outgoingGoAwayHandler) if err := t.loopy.run(); !isIOError(err) { // Immediately close the connection, as the loopy writer returns // when there are no more active streams and we were draining (the @@ -517,6 +521,17 @@ func (t *http2Client) getPeer() *peer.Peer { } } +// OutgoingGoAwayHandler writes a GOAWAY to the connection. Always returns (false, err) as we want the GoAway +// to be the last frame loopy writes to the transport. +func (t *http2Client) outgoingGoAwayHandler(g *goAway) (bool, error) { + t.mu.Lock() + defer t.mu.Unlock() + if err := t.framer.fr.WriteGoAway(t.nextID-2, http2.ErrCodeNo, g.debugData); err != nil { + return false, err + } + return false, g.closeConn +} + func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) ([]hpack.HeaderField, error) { aud := t.createAudience(callHdr) ri := credentials.RequestInfo{ @@ -966,7 +981,7 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2. // Close kicks off the shutdown process of the transport. This should be called // only once on a transport. Once it is called, the transport should not be -// accessed any more. +// accessed anymore. func (t *http2Client) Close(err error) { t.mu.Lock() // Make sure we only close once. @@ -991,7 +1006,10 @@ func (t *http2Client) Close(err error) { t.kpDormancyCond.Signal() } t.mu.Unlock() - t.controlBuf.finish() + // Per HTTP/2 spec, a GOAWAY frame must be sent before closing the + // connection. See https://httpwg.org/specs/rfc7540.html#GOAWAY. + t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte("client transport shutdown"), closeConn: err}) + <-t.writerDone t.cancel() t.conn.Close() channelz.RemoveEntry(t.channelz.ID) diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 1e1585ed210f..8a554ce408bf 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -330,8 +330,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, t.handleSettings(sf) go func() { - t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger) - t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler + t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger, t.outgoingGoAwayHandler) err := t.loopy.run() close(t.loopyWriterDone) if !isIOError(err) { diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index b0be89210564..6c1ee6b6283c 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -2659,3 +2659,94 @@ func TestConnectionError_Unwrap(t *testing.T) { t.Error("ConnectionError does not unwrap") } } + +// Test that in the event of a graceful client transport shutdown, i.e., +// clientTransport.Close(), client sends a goaway to the server with the correct +// error code and debug data. +func (s) TestClientSendsAGoAwayFrame(t *testing.T) { + // Create a server. + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Error while listening: %v", err) + } + defer lis.Close() + // greetDone is used to notify when server is done greeting the client. + greetDone := make(chan struct{}) + // errorCh verifies that desired GOAWAY not received by server + errorCh := make(chan error) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + // Launch the server. + go func() { + sconn, err := lis.Accept() + if err != nil { + t.Errorf("Error while accepting: %v", err) + } + defer sconn.Close() + if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil { + t.Errorf("Error while writing settings ack: %v", err) + return + } + sfr := http2.NewFramer(sconn, sconn) + if err := sfr.WriteSettings(); err != nil { + t.Errorf("Error while writing settings %v", err) + return + } + fr, _ := sfr.ReadFrame() + if _, ok := fr.(*http2.SettingsFrame); !ok { + t.Errorf("Expected settings frame, got %v", fr) + } + fr, _ = sfr.ReadFrame() + if fr, ok := fr.(*http2.SettingsFrame); !ok && fr.IsAck() { + t.Errorf("Expected settings ACK frame, got %v", fr) + } + fr, _ = sfr.ReadFrame() + if fr, ok := fr.(*http2.HeadersFrame); !ok && fr.Flags.Has(http2.FlagHeadersEndStream) { + t.Errorf("Expected Headers frame with END_HEADERS frame, got %v", fr) + } + close(greetDone) + + frame, err := sfr.ReadFrame() + if err != nil { + return + } + switch fr := frame.(type) { + case *http2.GoAwayFrame: + // Records that the server successfully received a GOAWAY frame. + goAwayFrame := fr + if goAwayFrame.ErrCode == http2.ErrCodeNo { + t.Logf("Received goAway frame from client") + close(errorCh) + } else { + errorCh <- fmt.Errorf("received unexpected goAway frame: %v", err) + close(errorCh) + } + return + default: + errorCh <- fmt.Errorf("server received a frame other than GOAWAY: %v", err) + close(errorCh) + return + } + }() + + ct, err := NewClientTransport(ctx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, ConnectOptions{}, func(GoAwayReason) {}) + if err != nil { + t.Fatalf("Error while creating client transport: %v", err) + } + _, err = ct.NewStream(ctx, &CallHdr{}) + if err != nil { + t.Fatalf("failed to open stream: %v", err) + } + // Wait until server receives the headers and settings frame as part of greet. + <-greetDone + ct.Close(errors.New("manually closed by client")) + t.Logf("Closed the client connection") + select { + case err := <-errorCh: + if err != nil { + t.Errorf("Error receiving the GOAWAY frame: %v", err) + } + case <-ctx.Done(): + t.Errorf("Context timed out") + } +} diff --git a/test/goaway_test.go b/test/goaway_test.go index 2a8ff0bfcc04..07cd0c915b70 100644 --- a/test/goaway_test.go +++ b/test/goaway_test.go @@ -20,6 +20,7 @@ package test import ( "context" + "fmt" "io" "net" "strings" @@ -761,3 +762,67 @@ func (s) TestTwoGoAwayPingFrames(t *testing.T) { t.Fatalf("Error waiting for graceful shutdown of the server: %v", err) } } + +// TestClientSendsAGoAway tests the scenario where you get a go away ping +// frames from the client during graceful shutdown. +func (s) TestClientSendsAGoAway(t *testing.T) { + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("error listening: %v", err) + } + ctCh := testutils.NewChannel() + go func() { + conn, err := lis.Accept() + if err != nil { + t.Errorf("error in lis.Accept(): %v", err) + } + ct := newClientTester(t, conn) + ctCh.Send(ct) + }() + defer lis.Close() + + cc, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + val, err := ctCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout waiting for client transport (should be given after http2 creation)") + } + ct := val.(*clientTester) + goAwayReceived := make(chan struct{}) + errCh := make(chan error) + go func() { + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return + } + switch fr := f.(type) { + case *http2.GoAwayFrame: + fr = f.(*http2.GoAwayFrame) + if fr.ErrCode == http2.ErrCodeNo { + t.Logf("GoAway received from client") + close(goAwayReceived) + } + default: + t.Errorf("server tester received unexpected frame type %T", f) + errCh <- fmt.Errorf("server tester received unexpected frame type %T", f) + close(errCh) + } + } + }() + cc.Close() + defer ct.conn.Close() + select { + case <-goAwayReceived: + case err := <-errCh: + t.Errorf("Error receiving the goAway: %v", err) + case <-ctx.Done(): + t.Errorf("Context timed out") + } +}