@@ -32,6 +32,7 @@ import (
32
32
"strconv"
33
33
"strings"
34
34
"sync"
35
+ "sync/atomic"
35
36
"testing"
36
37
"time"
37
38
@@ -2424,7 +2425,7 @@ func (s) TestClientHandshakeInfo(t *testing.T) {
2424
2425
TransportCredentials : creds ,
2425
2426
ChannelzParent : channelzSubChannel (t ),
2426
2427
}
2427
- tr , err := NewClientTransport (ctx , context . Background () , addr , copts , func (GoAwayReason ) {})
2428
+ tr , err := NewClientTransport (ctx , ctx , addr , copts , func (GoAwayReason ) {})
2428
2429
if err != nil {
2429
2430
t .Fatalf ("NewClientTransport(): %v" , err )
2430
2431
}
@@ -2465,7 +2466,7 @@ func (s) TestClientHandshakeInfoDialer(t *testing.T) {
2465
2466
Dialer : dialer ,
2466
2467
ChannelzParent : channelzSubChannel (t ),
2467
2468
}
2468
- tr , err := NewClientTransport (ctx , context . Background () , addr , copts , func (GoAwayReason ) {})
2469
+ tr , err := NewClientTransport (ctx , ctx , addr , copts , func (GoAwayReason ) {})
2469
2470
if err != nil {
2470
2471
t .Fatalf ("NewClientTransport(): %v" , err )
2471
2472
}
@@ -2725,7 +2726,7 @@ func (s) TestClientSendsAGoAwayFrame(t *testing.T) {
2725
2726
}
2726
2727
}()
2727
2728
2728
- ct , err := NewClientTransport (ctx , context . Background () , resolver.Address {Addr : lis .Addr ().String ()}, ConnectOptions {}, func (GoAwayReason ) {})
2729
+ ct , err := NewClientTransport (ctx , ctx , resolver.Address {Addr : lis .Addr ().String ()}, ConnectOptions {}, func (GoAwayReason ) {})
2729
2730
if err != nil {
2730
2731
t .Fatalf ("Error while creating client transport: %v" , err )
2731
2732
}
@@ -2746,3 +2747,68 @@ func (s) TestClientSendsAGoAwayFrame(t *testing.T) {
2746
2747
t .Errorf ("Context timed out" )
2747
2748
}
2748
2749
}
2750
+
2751
+ // hangingConn is a net.Conn wrapper for testing, simulating hanging connections
2752
+ // after a GOAWAY frame is sent, of which Write operations pause until explicitly
2753
+ // signaled or a timeout occurs.
2754
+ type hangingConn struct {
2755
+ net.Conn
2756
+ hangConn chan struct {}
2757
+ startHanging * atomic.Bool
2758
+ }
2759
+
2760
+ func (hc * hangingConn ) Write (b []byte ) (n int , err error ) {
2761
+ n , err = hc .Conn .Write (b )
2762
+ if hc .startHanging .Load () {
2763
+ <- hc .hangConn
2764
+ }
2765
+ return n , err
2766
+ }
2767
+
2768
+ // Tests the scenario where a client transport is closed and writing of the
2769
+ // GOAWAY frame as part of the close does not complete because of a network
2770
+ // hang. The test verifies that the client transport is closed without waiting
2771
+ // for too long.
2772
+ func (s ) TestClientCloseReturnsEarlyWhenGoAwayWriteHangs (t * testing.T ) {
2773
+ // Override timer for writing GOAWAY to 0 so that the connection write
2774
+ // always times out. It is equivalent of real network hang when conn
2775
+ // write for goaway doesn't finish in specified deadline
2776
+ origGoAwayLoopyTimeout := goAwayLoopyWriterTimeout
2777
+ goAwayLoopyWriterTimeout = time .Millisecond
2778
+ defer func () {
2779
+ goAwayLoopyWriterTimeout = origGoAwayLoopyTimeout
2780
+ }()
2781
+
2782
+ // Create the server set up.
2783
+ connectCtx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
2784
+ defer cancel ()
2785
+ server := setUpServerOnly (t , 0 , & ServerConfig {}, normal )
2786
+ defer server .stop ()
2787
+ addr := resolver.Address {Addr : "localhost:" + server .port }
2788
+ isGreetingDone := & atomic.Bool {}
2789
+ hangConn := make (chan struct {})
2790
+ defer close (hangConn )
2791
+ dialer := func (_ context.Context , addr string ) (net.Conn , error ) {
2792
+ conn , err := net .Dial ("tcp" , addr )
2793
+ if err != nil {
2794
+ return nil , err
2795
+ }
2796
+ return & hangingConn {Conn : conn , hangConn : hangConn , startHanging : isGreetingDone }, nil
2797
+ }
2798
+ copts := ConnectOptions {Dialer : dialer }
2799
+ copts .ChannelzParent = channelzSubChannel (t )
2800
+ // Create client transport with custom dialer
2801
+ ct , connErr := NewClientTransport (connectCtx , context .Background (), addr , copts , func (GoAwayReason ) {})
2802
+ if connErr != nil {
2803
+ t .Fatalf ("failed to create transport: %v" , connErr )
2804
+ }
2805
+
2806
+ ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
2807
+ defer cancel ()
2808
+ if _ , err := ct .NewStream (ctx , & CallHdr {}); err != nil {
2809
+ t .Fatalf ("Failed to open stream: %v" , err )
2810
+ }
2811
+
2812
+ isGreetingDone .Store (true )
2813
+ ct .Close (errors .New ("manually closed by client" ))
2814
+ }
0 commit comments