
Commit c40c9ba

Authored Oct 10, 2023
server: prohibit more than MaxConcurrentStreams handlers from running at once (#6703) (#6705)
1 parent dd9270d commit c40c9ba
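With this change, the gRPC server itself stops dispatching new handlers once MaxConcurrentStreams handlers are already running on a connection, rather than relying only on the advertised HTTP/2 SETTINGS_MAX_CONCURRENT_STREAMS value. For reference, a minimal sketch of a server that opts into the limit via the public option (the address, the limit value of 100, and the service registration are illustrative, not part of this commit):

package main

import (
	"log"
	"net"

	"google.golang.org/grpc"
)

func main() {
	lis, err := net.Listen("tcp", "localhost:50051")
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}
	// Each connection may run at most 100 concurrent stream handlers.
	// After this commit, the option also caps handler goroutines on the
	// server, not just the advertised HTTP/2 setting.
	srv := grpc.NewServer(grpc.MaxConcurrentStreams(100))
	// ... register services on srv here ...
	if err := srv.Serve(lis); err != nil {
		log.Fatalf("Serve: %v", err)
	}
}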

File tree: 5 files changed, +210 -45 lines


benchmark/primitives/primitives_test.go (+39 lines)

@@ -425,3 +425,42 @@ func BenchmarkRLockUnlock(b *testing.B) {
 		}
 	})
 }
+
+type ifNop interface {
+	nop()
+}
+
+type alwaysNop struct{}
+
+func (alwaysNop) nop() {}
+
+type concreteNop struct {
+	isNop atomic.Bool
+	i     int
+}
+
+func (c *concreteNop) nop() {
+	if c.isNop.Load() {
+		return
+	}
+	c.i++
+}
+
+func BenchmarkInterfaceNop(b *testing.B) {
+	n := ifNop(alwaysNop{})
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			n.nop()
+		}
+	})
+}
+
+func BenchmarkConcreteNop(b *testing.B) {
+	n := &concreteNop{}
+	n.isNop.Store(true)
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			n.nop()
+		}
+	})
+}
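The benchmarks above compare a method call through an interface value with a call on a concrete type that first checks an atomic.Bool; they appear intended to gauge the cost of the dispatch options considered for the new stream-quota path. One way to run just these benchmarks from the repository root (the exact flags and regex are a suggestion, not part of the commit):

go test ./benchmark/primitives -run '^$' -bench 'Nop$' -benchmem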

internal/transport/http2_server.go (+3, -8 lines)

@@ -171,15 +171,10 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
 		ID:  http2.SettingMaxFrameSize,
 		Val: http2MaxFrameLen,
 	}}
-	// TODO(zhaoq): Have a better way to signal "no limit" because 0 is
-	// permitted in the HTTP2 spec.
-	maxStreams := config.MaxStreams
-	if maxStreams == 0 {
-		maxStreams = math.MaxUint32
-	} else {
+	if config.MaxStreams != math.MaxUint32 {
 		isettings = append(isettings, http2.Setting{
 			ID:  http2.SettingMaxConcurrentStreams,
-			Val: maxStreams,
+			Val: config.MaxStreams,
 		})
 	}
 	dynamicWindow := true
@@ -258,7 +253,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
 		framer:            framer,
 		readerDone:        make(chan struct{}),
 		writerDone:        make(chan struct{}),
-		maxStreams:        maxStreams,
+		maxStreams:        config.MaxStreams,
 		inTapHandle:       config.InTapHandle,
 		fc:                &trInFlow{limit: uint32(icwz)},
 		state:             reachable,

internal/transport/transport_test.go (+19, -16 lines)

@@ -337,6 +337,9 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
 			return
 		}
 		rawConn := conn
+		if serverConfig.MaxStreams == 0 {
+			serverConfig.MaxStreams = math.MaxUint32
+		}
 		transport, err := NewServerTransport(conn, serverConfig)
 		if err != nil {
 			return
@@ -443,8 +446,8 @@ func setUpServerOnly(t *testing.T, port int, sc *ServerConfig, ht hType) *server
 	return server
 }

-func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, *http2Client, func()) {
-	return setUpWithOptions(t, port, &ServerConfig{MaxStreams: maxStreams}, ht, ConnectOptions{})
+func setUp(t *testing.T, port int, ht hType) (*server, *http2Client, func()) {
+	return setUpWithOptions(t, port, &ServerConfig{}, ht, ConnectOptions{})
 }

 func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) {
@@ -539,7 +542,7 @@ func (s) TestInflightStreamClosing(t *testing.T) {

 // Tests that when streamID > MaxStreamId, the current client transport drains.
 func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	defer server.stop()
 	callHdr := &CallHdr{
@@ -584,7 +587,7 @@ func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
 }

 func (s) TestClientSendAndReceive(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -624,7 +627,7 @@ func (s) TestClientSendAndReceive(t *testing.T) {
 }

 func (s) TestClientErrorNotify(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	go server.stop()
 	// ct.reader should detect the error and activate ct.Error().
@@ -658,7 +661,7 @@ func performOneRPC(ct ClientTransport) {
 }

 func (s) TestClientMix(t *testing.T) {
-	s, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	s, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	time.AfterFunc(time.Second, s.stop)
 	go func(ct ClientTransport) {
@@ -672,7 +675,7 @@ func (s) TestClientMix(t *testing.T) {
 }

 func (s) TestLargeMessage(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -807,7 +810,7 @@ func (s) TestLargeMessageWithDelayRead(t *testing.T) {
 // proceed until they complete naturally, while not allowing creation of new
 // streams during this window.
 func (s) TestGracefulClose(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, pingpong)
+	server, ct, cancel := setUp(t, 0, pingpong)
 	defer cancel()
 	defer func() {
 		// Stop the server's listener to make the server's goroutines terminate
@@ -873,7 +876,7 @@ func (s) TestGracefulClose(t *testing.T) {
 }

 func (s) TestLargeMessageSuspension(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended)
+	server, ct, cancel := setUp(t, 0, suspended)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -981,7 +984,7 @@ func (s) TestMaxStreams(t *testing.T) {
 }

 func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended)
+	server, ct, cancel := setUp(t, 0, suspended)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -1453,7 +1456,7 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) {
 var encodingTestStatus = status.New(codes.Internal, "\n")

 func (s) TestEncodingRequiredStatus(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, encodingRequiredStatus)
+	server, ct, cancel := setUp(t, 0, encodingRequiredStatus)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -1481,7 +1484,7 @@ func (s) TestEncodingRequiredStatus(t *testing.T) {
 }

 func (s) TestInvalidHeaderField(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
+	server, ct, cancel := setUp(t, 0, invalidHeaderField)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -1503,7 +1506,7 @@ func (s) TestInvalidHeaderField(t *testing.T) {
 }

 func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
+	server, ct, cancel := setUp(t, 0, invalidHeaderField)
 	defer cancel()
 	defer server.stop()
 	defer ct.Close(fmt.Errorf("closed manually by test"))
@@ -2171,7 +2174,7 @@ func (s) TestPingPong1MB(t *testing.T) {

 // This is a stress-test of flow control logic.
 func runPingPongTest(t *testing.T, msgSize int) {
-	server, client, cancel := setUp(t, 0, 0, pingpong)
+	server, client, cancel := setUp(t, 0, pingpong)
 	defer cancel()
 	defer server.stop()
 	defer client.Close(fmt.Errorf("closed manually by test"))
@@ -2253,7 +2256,7 @@ func (s) TestHeaderTblSize(t *testing.T) {
 		}
 	}()

-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	defer ct.Close(fmt.Errorf("closed manually by test"))
 	defer server.stop()
@@ -2612,7 +2615,7 @@ func TestConnectionError_Unwrap(t *testing.T) {

 func (s) TestPeerSetInServerContext(t *testing.T) {
 	// create client and server transports.
-	server, client, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, client, cancel := setUp(t, 0, normal)
 	defer cancel()
 	defer server.stop()
 	defer client.Close(fmt.Errorf("closed manually by test"))

server.go (+50, -21 lines)

@@ -115,12 +115,6 @@ type serviceInfo struct {
 	mdata any
 }

-type serverWorkerData struct {
-	st     transport.ServerTransport
-	wg     *sync.WaitGroup
-	stream *transport.Stream
-}
-
 // Server is a gRPC server to serve RPC requests.
 type Server struct {
 	opts serverOptions
@@ -145,7 +139,7 @@ type Server struct {
 	channelzID *channelz.Identifier
 	czData     *channelzData

-	serverWorkerChannel chan *serverWorkerData
+	serverWorkerChannel chan func()
 }

 type serverOptions struct {
@@ -179,6 +173,7 @@ type serverOptions struct {
 }

 var defaultServerOptions = serverOptions{
+	maxConcurrentStreams:  math.MaxUint32,
 	maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
 	maxSendMessageSize:    defaultServerMaxSendMessageSize,
 	connectionTimeout:     120 * time.Second,
@@ -404,6 +399,9 @@ func MaxSendMsgSize(m int) ServerOption {
 // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
 // of concurrent streams to each ServerTransport.
 func MaxConcurrentStreams(n uint32) ServerOption {
+	if n == 0 {
+		n = math.MaxUint32
+	}
 	return newFuncServerOption(func(o *serverOptions) {
 		o.maxConcurrentStreams = n
 	})
@@ -605,24 +603,19 @@ const serverWorkerResetThreshold = 1 << 16
 // [1] https://github.com/golang/go/issues/18138
 func (s *Server) serverWorker() {
 	for completed := 0; completed < serverWorkerResetThreshold; completed++ {
-		data, ok := <-s.serverWorkerChannel
+		f, ok := <-s.serverWorkerChannel
 		if !ok {
 			return
 		}
-		s.handleSingleStream(data)
+		f()
 	}
 	go s.serverWorker()
 }

-func (s *Server) handleSingleStream(data *serverWorkerData) {
-	defer data.wg.Done()
-	s.handleStream(data.st, data.stream, s.traceInfo(data.st, data.stream))
-}
-
 // initServerWorkers creates worker goroutines and a channel to process incoming
 // connections to reduce the time spent overall on runtime.morestack.
 func (s *Server) initServerWorkers() {
-	s.serverWorkerChannel = make(chan *serverWorkerData)
+	s.serverWorkerChannel = make(chan func())
 	for i := uint32(0); i < s.opts.numServerWorkers; i++ {
 		go s.serverWorker()
 	}
@@ -982,21 +975,26 @@ func (s *Server) serveStreams(st transport.ServerTransport) {
 	defer st.Close(errors.New("finished serving streams for the server transport"))
 	var wg sync.WaitGroup

+	streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
 	st.HandleStreams(func(stream *transport.Stream) {
 		wg.Add(1)
+
+		streamQuota.acquire()
+		f := func() {
+			defer streamQuota.release()
+			defer wg.Done()
+			s.handleStream(st, stream, s.traceInfo(st, stream))
+		}
+
 		if s.opts.numServerWorkers > 0 {
-			data := &serverWorkerData{st: st, wg: &wg, stream: stream}
 			select {
-			case s.serverWorkerChannel <- data:
+			case s.serverWorkerChannel <- f:
 				return
 			default:
 				// If all stream workers are busy, fallback to the default code path.
 			}
 		}
-		go func() {
-			defer wg.Done()
-			s.handleStream(st, stream, s.traceInfo(st, stream))
-		}()
+		go f()
 	}, func(ctx context.Context, method string) context.Context {
 		if !EnableTracing {
 			return ctx
@@ -2091,3 +2089,34 @@ func validateSendCompressor(name, clientCompressors string) error {
 	}
 	return fmt.Errorf("client does not support compressor %q", name)
 }
+
+// atomicSemaphore implements a blocking, counting semaphore. acquire should be
+// called synchronously; release may be called asynchronously.
+type atomicSemaphore struct {
+	n    atomic.Int64
+	wait chan struct{}
+}
+
+func (q *atomicSemaphore) acquire() {
+	if q.n.Add(-1) < 0 {
+		// We ran out of quota.  Block until a release happens.
+		<-q.wait
+	}
+}
+
+func (q *atomicSemaphore) release() {
+	// N.B. the "<= 0" check below should allow for this to work with multiple
+	// concurrent calls to acquire, but also note that with synchronous calls to
+	// acquire, as our system does, n will never be less than -1.  There are
+	// fairness issues (queuing) to consider if this was to be generalized.
+	if q.n.Add(1) <= 0 {
+		// An acquire was waiting on us.  Unblock it.
+		q.wait <- struct{}{}
+	}
+}
+
+func newHandlerQuota(n uint32) *atomicSemaphore {
+	a := &atomicSemaphore{wait: make(chan struct{}, 1)}
+	a.n.Store(int64(n))
+	return a
+}
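In serveStreams above, streamQuota.acquire() runs synchronously in the HandleStreams callback, so the transport stops dispatching new streams once maxConcurrentStreams handlers are in flight, and each handler's deferred release() lets the next stream proceed. A self-contained sketch of the same semaphore pattern, outside gRPC (the names and the fake workload are illustrative, not gRPC APIs):

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// quota mirrors the shape of the unexported atomicSemaphore in server.go.
type quota struct {
	n    atomic.Int64
	wait chan struct{}
}

func newQuota(n int64) *quota {
	q := &quota{wait: make(chan struct{}, 1)}
	q.n.Store(n)
	return q
}

func (q *quota) acquire() {
	if q.n.Add(-1) < 0 {
		<-q.wait // out of quota: block until a release happens
	}
}

func (q *quota) release() {
	if q.n.Add(1) <= 0 {
		q.wait <- struct{}{} // an acquire is blocked: wake it up
	}
}

func main() {
	q := newQuota(2) // allow at most 2 concurrent "handlers"
	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		q.acquire() // called synchronously, like the HandleStreams callback
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			defer q.release()
			fmt.Println("handler", id, "running")
			time.Sleep(50 * time.Millisecond)
		}(i)
	}
	wg.Wait()
}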

server_ext_test.go (new file, +99 lines)

@@ -0,0 +1,99 @@
/*
 *
 * Copyright 2023 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package grpc_test

import (
	"context"
	"io"
	"testing"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/internal/grpcsync"
	"google.golang.org/grpc/internal/stubserver"

	testgrpc "google.golang.org/grpc/interop/grpc_testing"
)

// TestServer_MaxHandlers ensures that no more than MaxConcurrentStreams server
// handlers are active at one time.
func (s) TestServer_MaxHandlers(t *testing.T) {
	started := make(chan struct{})
	blockCalls := grpcsync.NewEvent()

	// This stub server does not properly respect the stream context, so it will
	// not exit when the context is canceled.
	ss := stubserver.StubServer{
		FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
			started <- struct{}{}
			<-blockCalls.Done()
			return nil
		},
	}
	if err := ss.Start([]grpc.ServerOption{grpc.MaxConcurrentStreams(1)}); err != nil {
		t.Fatal("Error starting server:", err)
	}
	defer ss.Stop()

	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()

	// Start one RPC to the server.
	ctx1, cancel1 := context.WithCancel(ctx)
	_, err := ss.Client.FullDuplexCall(ctx1)
	if err != nil {
		t.Fatal("Error staring call:", err)
	}

	// Wait for the handler to be invoked.
	select {
	case <-started:
	case <-ctx.Done():
		t.Fatalf("Timed out waiting for RPC to start on server.")
	}

	// Cancel it on the client. The server handler will still be running.
	cancel1()

	ctx2, cancel2 := context.WithCancel(ctx)
	defer cancel2()
	s, err := ss.Client.FullDuplexCall(ctx2)
	if err != nil {
		t.Fatal("Error staring call:", err)
	}

	// After 100ms, allow the first call to unblock. That should allow the
	// second RPC to run and finish.
	select {
	case <-started:
		blockCalls.Fire()
		t.Fatalf("RPC started unexpectedly.")
	case <-time.After(100 * time.Millisecond):
		blockCalls.Fire()
	}

	select {
	case <-started:
	case <-ctx.Done():
		t.Fatalf("Timed out waiting for second RPC to start on server.")
	}
	if _, err := s.Recv(); err != io.EOF {
		t.Fatal("Received unexpected RPC error:", err)
	}
}
