Skip to content

Commit e88f12e

Browse files
authored
server: prohibit more than MaxConcurrentStreams handlers from running at once (#6703) (#6704)
1 parent be7919c commit e88f12e

File tree

5 files changed

+210
-45
lines changed

5 files changed

+210
-45
lines changed

Diff for: benchmark/primitives/primitives_test.go

+39
Original file line numberDiff line numberDiff line change
@@ -425,3 +425,42 @@ func BenchmarkRLockUnlock(b *testing.B) {
425425
}
426426
})
427427
}
428+
429+
// ifNop is the interface used to benchmark dynamic-dispatch call overhead.
type ifNop interface {
	nop()
}

// alwaysNop is a stateless implementation whose nop is a true no-op.
type alwaysNop struct{}

func (alwaysNop) nop() {}

// concreteNop is a concrete implementation whose nop can be switched into
// no-op mode at runtime via the isNop flag.
type concreteNop struct {
	isNop atomic.Bool // when true, nop does nothing
	i     int         // incremented on each effective nop call
}

// nop increments the counter unless the type has been flipped into
// no-op mode.
func (n *concreteNop) nop() {
	if !n.isNop.Load() {
		n.i++
	}
}
448+
449+
func BenchmarkInterfaceNop(b *testing.B) {
450+
n := ifNop(alwaysNop{})
451+
b.RunParallel(func(pb *testing.PB) {
452+
for pb.Next() {
453+
n.nop()
454+
}
455+
})
456+
}
457+
458+
func BenchmarkConcreteNop(b *testing.B) {
459+
n := &concreteNop{}
460+
n.isNop.Store(true)
461+
b.RunParallel(func(pb *testing.PB) {
462+
for pb.Next() {
463+
n.nop()
464+
}
465+
})
466+
}

Diff for: internal/transport/http2_server.go

+3-8
Original file line numberDiff line numberDiff line change
@@ -171,15 +171,10 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
171171
ID: http2.SettingMaxFrameSize,
172172
Val: http2MaxFrameLen,
173173
}}
174-
// TODO(zhaoq): Have a better way to signal "no limit" because 0 is
175-
// permitted in the HTTP2 spec.
176-
maxStreams := config.MaxStreams
177-
if maxStreams == 0 {
178-
maxStreams = math.MaxUint32
179-
} else {
174+
if config.MaxStreams != math.MaxUint32 {
180175
isettings = append(isettings, http2.Setting{
181176
ID: http2.SettingMaxConcurrentStreams,
182-
Val: maxStreams,
177+
Val: config.MaxStreams,
183178
})
184179
}
185180
dynamicWindow := true
@@ -258,7 +253,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
258253
framer: framer,
259254
readerDone: make(chan struct{}),
260255
writerDone: make(chan struct{}),
261-
maxStreams: maxStreams,
256+
maxStreams: config.MaxStreams,
262257
inTapHandle: config.InTapHandle,
263258
fc: &trInFlow{limit: uint32(icwz)},
264259
state: reachable,

Diff for: internal/transport/transport_test.go

+19-16
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,9 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
337337
return
338338
}
339339
rawConn := conn
340+
if serverConfig.MaxStreams == 0 {
341+
serverConfig.MaxStreams = math.MaxUint32
342+
}
340343
transport, err := NewServerTransport(conn, serverConfig)
341344
if err != nil {
342345
return
@@ -425,8 +428,8 @@ func setUpServerOnly(t *testing.T, port int, sc *ServerConfig, ht hType) *server
425428
return server
426429
}
427430

428-
func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, *http2Client, func()) {
429-
return setUpWithOptions(t, port, &ServerConfig{MaxStreams: maxStreams}, ht, ConnectOptions{})
431+
func setUp(t *testing.T, port int, ht hType) (*server, *http2Client, func()) {
432+
return setUpWithOptions(t, port, &ServerConfig{}, ht, ConnectOptions{})
430433
}
431434

432435
func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) {
@@ -521,7 +524,7 @@ func (s) TestInflightStreamClosing(t *testing.T) {
521524

522525
// Tests that when streamID > MaxStreamId, the current client transport drains.
523526
func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
524-
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
527+
server, ct, cancel := setUp(t, 0, normal)
525528
defer cancel()
526529
defer server.stop()
527530
callHdr := &CallHdr{
@@ -566,7 +569,7 @@ func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
566569
}
567570

568571
func (s) TestClientSendAndReceive(t *testing.T) {
569-
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
572+
server, ct, cancel := setUp(t, 0, normal)
570573
defer cancel()
571574
callHdr := &CallHdr{
572575
Host: "localhost",
@@ -606,7 +609,7 @@ func (s) TestClientSendAndReceive(t *testing.T) {
606609
}
607610

608611
func (s) TestClientErrorNotify(t *testing.T) {
609-
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
612+
server, ct, cancel := setUp(t, 0, normal)
610613
defer cancel()
611614
go server.stop()
612615
// ct.reader should detect the error and activate ct.Error().
@@ -640,7 +643,7 @@ func performOneRPC(ct ClientTransport) {
640643
}
641644

642645
func (s) TestClientMix(t *testing.T) {
643-
s, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
646+
s, ct, cancel := setUp(t, 0, normal)
644647
defer cancel()
645648
time.AfterFunc(time.Second, s.stop)
646649
go func(ct ClientTransport) {
@@ -654,7 +657,7 @@ func (s) TestClientMix(t *testing.T) {
654657
}
655658

656659
func (s) TestLargeMessage(t *testing.T) {
657-
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
660+
server, ct, cancel := setUp(t, 0, normal)
658661
defer cancel()
659662
callHdr := &CallHdr{
660663
Host: "localhost",
@@ -789,7 +792,7 @@ func (s) TestLargeMessageWithDelayRead(t *testing.T) {
789792
// proceed until they complete naturally, while not allowing creation of new
790793
// streams during this window.
791794
func (s) TestGracefulClose(t *testing.T) {
792-
server, ct, cancel := setUp(t, 0, math.MaxUint32, pingpong)
795+
server, ct, cancel := setUp(t, 0, pingpong)
793796
defer cancel()
794797
defer func() {
795798
// Stop the server's listener to make the server's goroutines terminate
@@ -855,7 +858,7 @@ func (s) TestGracefulClose(t *testing.T) {
855858
}
856859

857860
func (s) TestLargeMessageSuspension(t *testing.T) {
858-
server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended)
861+
server, ct, cancel := setUp(t, 0, suspended)
859862
defer cancel()
860863
callHdr := &CallHdr{
861864
Host: "localhost",
@@ -963,7 +966,7 @@ func (s) TestMaxStreams(t *testing.T) {
963966
}
964967

965968
func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) {
966-
server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended)
969+
server, ct, cancel := setUp(t, 0, suspended)
967970
defer cancel()
968971
callHdr := &CallHdr{
969972
Host: "localhost",
@@ -1435,7 +1438,7 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) {
14351438
var encodingTestStatus = status.New(codes.Internal, "\n")
14361439

14371440
func (s) TestEncodingRequiredStatus(t *testing.T) {
1438-
server, ct, cancel := setUp(t, 0, math.MaxUint32, encodingRequiredStatus)
1441+
server, ct, cancel := setUp(t, 0, encodingRequiredStatus)
14391442
defer cancel()
14401443
callHdr := &CallHdr{
14411444
Host: "localhost",
@@ -1463,7 +1466,7 @@ func (s) TestEncodingRequiredStatus(t *testing.T) {
14631466
}
14641467

14651468
func (s) TestInvalidHeaderField(t *testing.T) {
1466-
server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
1469+
server, ct, cancel := setUp(t, 0, invalidHeaderField)
14671470
defer cancel()
14681471
callHdr := &CallHdr{
14691472
Host: "localhost",
@@ -1485,7 +1488,7 @@ func (s) TestInvalidHeaderField(t *testing.T) {
14851488
}
14861489

14871490
func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) {
1488-
server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
1491+
server, ct, cancel := setUp(t, 0, invalidHeaderField)
14891492
defer cancel()
14901493
defer server.stop()
14911494
defer ct.Close(fmt.Errorf("closed manually by test"))
@@ -2153,7 +2156,7 @@ func (s) TestPingPong1MB(t *testing.T) {
21532156

21542157
// This is a stress-test of flow control logic.
21552158
func runPingPongTest(t *testing.T, msgSize int) {
2156-
server, client, cancel := setUp(t, 0, 0, pingpong)
2159+
server, client, cancel := setUp(t, 0, pingpong)
21572160
defer cancel()
21582161
defer server.stop()
21592162
defer client.Close(fmt.Errorf("closed manually by test"))
@@ -2235,7 +2238,7 @@ func (s) TestHeaderTblSize(t *testing.T) {
22352238
}
22362239
}()
22372240

2238-
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
2241+
server, ct, cancel := setUp(t, 0, normal)
22392242
defer cancel()
22402243
defer ct.Close(fmt.Errorf("closed manually by test"))
22412244
defer server.stop()
@@ -2594,7 +2597,7 @@ func TestConnectionError_Unwrap(t *testing.T) {
25942597

25952598
func (s) TestPeerSetInServerContext(t *testing.T) {
25962599
// create client and server transports.
2597-
server, client, cancel := setUp(t, 0, math.MaxUint32, normal)
2600+
server, client, cancel := setUp(t, 0, normal)
25982601
defer cancel()
25992602
defer server.stop()
26002603
defer client.Close(fmt.Errorf("closed manually by test"))

Diff for: server.go

+50-21
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,6 @@ type serviceInfo struct {
115115
mdata any
116116
}
117117

118-
type serverWorkerData struct {
119-
st transport.ServerTransport
120-
wg *sync.WaitGroup
121-
stream *transport.Stream
122-
}
123-
124118
// Server is a gRPC server to serve RPC requests.
125119
type Server struct {
126120
opts serverOptions
@@ -145,7 +139,7 @@ type Server struct {
145139
channelzID *channelz.Identifier
146140
czData *channelzData
147141

148-
serverWorkerChannel chan *serverWorkerData
142+
serverWorkerChannel chan func()
149143
}
150144

151145
type serverOptions struct {
@@ -179,6 +173,7 @@ type serverOptions struct {
179173
}
180174

181175
var defaultServerOptions = serverOptions{
176+
maxConcurrentStreams: math.MaxUint32,
182177
maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
183178
maxSendMessageSize: defaultServerMaxSendMessageSize,
184179
connectionTimeout: 120 * time.Second,
@@ -404,6 +399,9 @@ func MaxSendMsgSize(m int) ServerOption {
404399
// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
405400
// of concurrent streams to each ServerTransport.
406401
func MaxConcurrentStreams(n uint32) ServerOption {
402+
if n == 0 {
403+
n = math.MaxUint32
404+
}
407405
return newFuncServerOption(func(o *serverOptions) {
408406
o.maxConcurrentStreams = n
409407
})
@@ -605,24 +603,19 @@ const serverWorkerResetThreshold = 1 << 16
605603
// [1] https://github.com/golang/go/issues/18138
606604
func (s *Server) serverWorker() {
607605
for completed := 0; completed < serverWorkerResetThreshold; completed++ {
608-
data, ok := <-s.serverWorkerChannel
606+
f, ok := <-s.serverWorkerChannel
609607
if !ok {
610608
return
611609
}
612-
s.handleSingleStream(data)
610+
f()
613611
}
614612
go s.serverWorker()
615613
}
616614

617-
func (s *Server) handleSingleStream(data *serverWorkerData) {
618-
defer data.wg.Done()
619-
s.handleStream(data.st, data.stream)
620-
}
621-
622615
// initServerWorkers creates worker goroutines and a channel to process incoming
623616
// connections to reduce the time spent overall on runtime.morestack.
624617
func (s *Server) initServerWorkers() {
625-
s.serverWorkerChannel = make(chan *serverWorkerData)
618+
s.serverWorkerChannel = make(chan func())
626619
for i := uint32(0); i < s.opts.numServerWorkers; i++ {
627620
go s.serverWorker()
628621
}
@@ -982,21 +975,26 @@ func (s *Server) serveStreams(st transport.ServerTransport) {
982975
defer st.Close(errors.New("finished serving streams for the server transport"))
983976
var wg sync.WaitGroup
984977

978+
streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
985979
st.HandleStreams(func(stream *transport.Stream) {
986980
wg.Add(1)
981+
982+
streamQuota.acquire()
983+
f := func() {
984+
defer streamQuota.release()
985+
defer wg.Done()
986+
s.handleStream(st, stream)
987+
}
988+
987989
if s.opts.numServerWorkers > 0 {
988-
data := &serverWorkerData{st: st, wg: &wg, stream: stream}
989990
select {
990-
case s.serverWorkerChannel <- data:
991+
case s.serverWorkerChannel <- f:
991992
return
992993
default:
993994
// If all stream workers are busy, fallback to the default code path.
994995
}
995996
}
996-
go func() {
997-
defer wg.Done()
998-
s.handleStream(st, stream)
999-
}()
997+
go f()
1000998
})
1001999
wg.Wait()
10021000
}
@@ -2077,3 +2075,34 @@ func validateSendCompressor(name, clientCompressors string) error {
20772075
}
20782076
return fmt.Errorf("client does not support compressor %q", name)
20792077
}
2078+
2079+
// atomicSemaphore is a blocking, counting semaphore backed by an atomic
// counter. acquire must be called synchronously (never concurrently with
// itself); release may be called asynchronously from any goroutine.
type atomicSemaphore struct {
	n    atomic.Int64  // remaining quota; negative means an acquire is waiting
	wait chan struct{} // buffered(1) hand-off used to wake a blocked acquire
}

// acquire consumes one unit of quota, blocking until a release supplies
// one if the quota is exhausted.
func (s *atomicSemaphore) acquire() {
	if s.n.Add(-1) < 0 {
		// Out of quota; park until a release hands over a token.
		<-s.wait
	}
}

// release returns one unit of quota, waking the blocked acquire if one
// exists.
func (s *atomicSemaphore) release() {
	// Because acquires are synchronous, n never drops below -1, so at most
	// one waiter can exist and the "<= 0" check covers it. Supporting
	// concurrent acquirers would raise fairness (queuing) issues this
	// implementation deliberately does not address.
	if s.n.Add(1) <= 0 {
		s.wait <- struct{}{}
	}
}

// newHandlerQuota returns a semaphore that limits the number of
// concurrently running handlers to n.
func newHandlerQuota(n uint32) *atomicSemaphore {
	sem := &atomicSemaphore{wait: make(chan struct{}, 1)}
	sem.n.Store(int64(n))
	return sem
}

0 commit comments

Comments
 (0)