Skip to content

Commit 35e915e

Browse files
authored
cherry-pick: transport: add timeout for writing GOAWAY on http2Client.Close() #7371 (#7540)
1 parent 63853fd commit 35e915e

File tree

2 files changed

+84
-5
lines changed

2 files changed

+84
-5
lines changed

internal/transport/http2_client.go

+15-2
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ import (
5959
// atomically.
6060
var clientConnectionCounter uint64
6161

62+
var goAwayLoopyWriterTimeout = 5 * time.Second
63+
6264
var metadataFromOutgoingContextRaw = internal.FromOutgoingContextRaw.(func(context.Context) (metadata.MD, [][]string, bool))
6365

6466
// http2Client implements the ClientTransport interface with HTTP2.
@@ -983,6 +985,7 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.
983985
// only once on a transport. Once it is called, the transport should not be
984986
// accessed anymore.
985987
func (t *http2Client) Close(err error) {
988+
t.conn.SetWriteDeadline(time.Now().Add(time.Second * 10))
986989
t.mu.Lock()
987990
// Make sure we only close once.
988991
if t.state == closing {
@@ -1006,10 +1009,20 @@ func (t *http2Client) Close(err error) {
10061009
t.kpDormancyCond.Signal()
10071010
}
10081011
t.mu.Unlock()
1012+
10091013
// Per HTTP/2 spec, a GOAWAY frame must be sent before closing the
1010-
// connection. See https://httpwg.org/specs/rfc7540.html#GOAWAY.
1014+
// connection. See https://httpwg.org/specs/rfc7540.html#GOAWAY. It
1015+
// also waits for loopyWriter to be closed with a timer to avoid the
1016+
// long blocking in case the connection is blackholed, i.e. TCP is
1017+
// just stuck.
10111018
t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte("client transport shutdown"), closeConn: err})
1012-
<-t.writerDone
1019+
timer := time.NewTimer(goAwayLoopyWriterTimeout)
1020+
defer timer.Stop()
1021+
select {
1022+
case <-t.writerDone: // success
1023+
case <-timer.C:
1024+
t.logger.Infof("Failed to write a GOAWAY frame as part of connection close after %s. Giving up and closing the transport.", goAwayLoopyWriterTimeout)
1025+
}
10131026
t.cancel()
10141027
t.conn.Close()
10151028
channelz.RemoveEntry(t.channelz.ID)

internal/transport/transport_test.go

+69-3
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import (
3232
"strconv"
3333
"strings"
3434
"sync"
35+
"sync/atomic"
3536
"testing"
3637
"time"
3738

@@ -2424,7 +2425,7 @@ func (s) TestClientHandshakeInfo(t *testing.T) {
24242425
TransportCredentials: creds,
24252426
ChannelzParent: channelzSubChannel(t),
24262427
}
2427-
tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {})
2428+
tr, err := NewClientTransport(ctx, ctx, addr, copts, func(GoAwayReason) {})
24282429
if err != nil {
24292430
t.Fatalf("NewClientTransport(): %v", err)
24302431
}
@@ -2465,7 +2466,7 @@ func (s) TestClientHandshakeInfoDialer(t *testing.T) {
24652466
Dialer: dialer,
24662467
ChannelzParent: channelzSubChannel(t),
24672468
}
2468-
tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {})
2469+
tr, err := NewClientTransport(ctx, ctx, addr, copts, func(GoAwayReason) {})
24692470
if err != nil {
24702471
t.Fatalf("NewClientTransport(): %v", err)
24712472
}
@@ -2725,7 +2726,7 @@ func (s) TestClientSendsAGoAwayFrame(t *testing.T) {
27252726
}
27262727
}()
27272728

2728-
ct, err := NewClientTransport(ctx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, ConnectOptions{}, func(GoAwayReason) {})
2729+
ct, err := NewClientTransport(ctx, ctx, resolver.Address{Addr: lis.Addr().String()}, ConnectOptions{}, func(GoAwayReason) {})
27292730
if err != nil {
27302731
t.Fatalf("Error while creating client transport: %v", err)
27312732
}
@@ -2746,3 +2747,68 @@ func (s) TestClientSendsAGoAwayFrame(t *testing.T) {
27462747
t.Errorf("Context timed out")
27472748
}
27482749
}
2750+
2751+
// hangingConn is a net.Conn wrapper for testing, simulating hanging connections
2752+
// after a GOAWAY frame is sent, of which Write operations pause until explicitly
2753+
// signaled or a timeout occurs.
2754+
type hangingConn struct {
2755+
net.Conn
2756+
hangConn chan struct{}
2757+
startHanging *atomic.Bool
2758+
}
2759+
2760+
func (hc *hangingConn) Write(b []byte) (n int, err error) {
2761+
n, err = hc.Conn.Write(b)
2762+
if hc.startHanging.Load() {
2763+
<-hc.hangConn
2764+
}
2765+
return n, err
2766+
}
2767+
2768+
// Tests the scenario where a client transport is closed and writing of the
2769+
// GOAWAY frame as part of the close does not complete because of a network
2770+
// hang. The test verifies that the client transport is closed without waiting
2771+
// for too long.
2772+
func (s) TestClientCloseReturnsEarlyWhenGoAwayWriteHangs(t *testing.T) {
2773+
// Override timer for writing GOAWAY to 0 so that the connection write
2774+
// always times out. It is equivalent of real network hang when conn
2775+
// write for goaway doesn't finish in specified deadline
2776+
origGoAwayLoopyTimeout := goAwayLoopyWriterTimeout
2777+
goAwayLoopyWriterTimeout = time.Millisecond
2778+
defer func() {
2779+
goAwayLoopyWriterTimeout = origGoAwayLoopyTimeout
2780+
}()
2781+
2782+
// Create the server set up.
2783+
connectCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
2784+
defer cancel()
2785+
server := setUpServerOnly(t, 0, &ServerConfig{}, normal)
2786+
defer server.stop()
2787+
addr := resolver.Address{Addr: "localhost:" + server.port}
2788+
isGreetingDone := &atomic.Bool{}
2789+
hangConn := make(chan struct{})
2790+
defer close(hangConn)
2791+
dialer := func(_ context.Context, addr string) (net.Conn, error) {
2792+
conn, err := net.Dial("tcp", addr)
2793+
if err != nil {
2794+
return nil, err
2795+
}
2796+
return &hangingConn{Conn: conn, hangConn: hangConn, startHanging: isGreetingDone}, nil
2797+
}
2798+
copts := ConnectOptions{Dialer: dialer}
2799+
copts.ChannelzParent = channelzSubChannel(t)
2800+
// Create client transport with custom dialer
2801+
ct, connErr := NewClientTransport(connectCtx, context.Background(), addr, copts, func(GoAwayReason) {})
2802+
if connErr != nil {
2803+
t.Fatalf("failed to create transport: %v", connErr)
2804+
}
2805+
2806+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
2807+
defer cancel()
2808+
if _, err := ct.NewStream(ctx, &CallHdr{}); err != nil {
2809+
t.Fatalf("Failed to open stream: %v", err)
2810+
}
2811+
2812+
isGreetingDone.Store(true)
2813+
ct.Close(errors.New("manually closed by client"))
2814+
}

0 commit comments

Comments
 (0)