Skip to content

Commit c6185aa

Browse files
benburkertbradfitz
authored andcommitted
crypto/tls: add CloseWrite method to Conn
The CloseWrite method sends a close_notify alert record to the other side of the connection. This record indicates that the sender has finished sending on the connection. Unlike the Close method, the sender may still read from the connection until it recieves a close_notify record (or the underlying connection is closed). This is analogous to a TCP half-close. Updates #8579 Change-Id: I9c6bc193efcb25cc187f7735ee07170afa7fdde3 Reviewed-on: https://go-review.googlesource.com/25159 Reviewed-by: Brad Fitzpatrick <[email protected]> Run-TryBot: Brad Fitzpatrick <[email protected]> TryBot-Result: Gobot Gobot <[email protected]>
1 parent b97b753 commit c6185aa

File tree

2 files changed

+128
-2
lines changed

2 files changed

+128
-2
lines changed

src/crypto/tls/conn.go

+42-2
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@ type Conn struct {
6464
// the first transmitted Finished message is the tls-unique
6565
// channel-binding value.
6666
clientFinishedIsFirst bool
67+
68+
// closeNotifyErr is any error from sending the alertCloseNotify record.
69+
closeNotifyErr error
70+
// closeNotifySent is true if the Conn attempted to send an
71+
// alertCloseNotify record.
72+
closeNotifySent bool
73+
6774
// clientFinished and serverFinished contain the Finished message sent
6875
// by the client or server in the most recent handshake. This is
6976
// retained to support the renegotiation extension and tls-unique
@@ -992,7 +999,10 @@ func (c *Conn) readHandshake() (interface{}, error) {
992999
return m, nil
9931000
}
9941001

995-
var errClosed = errors.New("tls: use of closed connection")
1002+
var (
1003+
errClosed = errors.New("tls: use of closed connection")
1004+
errShutdown = errors.New("tls: protocol is shutdown")
1005+
)
9961006

9971007
// Write writes data to the connection.
9981008
func (c *Conn) Write(b []byte) (int, error) {
@@ -1023,6 +1033,10 @@ func (c *Conn) Write(b []byte) (int, error) {
10231033
return 0, alertInternalError
10241034
}
10251035

1036+
if c.closeNotifySent {
1037+
return 0, errShutdown
1038+
}
1039+
10261040
// SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext
10271041
// attack when using block mode ciphers due to predictable IVs.
10281042
// This can be prevented by splitting each Application Data
@@ -1186,7 +1200,7 @@ func (c *Conn) Close() error {
11861200
c.handshakeMutex.Lock()
11871201
defer c.handshakeMutex.Unlock()
11881202
if c.handshakeComplete {
1189-
alertErr = c.sendAlert(alertCloseNotify)
1203+
alertErr = c.closeNotify()
11901204
}
11911205

11921206
if err := c.conn.Close(); err != nil {
@@ -1195,6 +1209,32 @@ func (c *Conn) Close() error {
11951209
return alertErr
11961210
}
11971211

1212+
var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
1213+
1214+
// CloseWrite shuts down the writing side of the connection. It should only be
1215+
// called once the handshake has completed and does not call CloseWrite on the
1216+
// underlying connection. Most callers should just use Close.
1217+
func (c *Conn) CloseWrite() error {
1218+
c.handshakeMutex.Lock()
1219+
defer c.handshakeMutex.Unlock()
1220+
if !c.handshakeComplete {
1221+
return errEarlyCloseWrite
1222+
}
1223+
1224+
return c.closeNotify()
1225+
}
1226+
1227+
func (c *Conn) closeNotify() error {
1228+
c.out.Lock()
1229+
defer c.out.Unlock()
1230+
1231+
if !c.closeNotifySent {
1232+
c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
1233+
c.closeNotifySent = true
1234+
}
1235+
return c.closeNotifyErr
1236+
}
1237+
11981238
// Handshake runs the client or server handshake
11991239
// protocol if it has not yet been run.
12001240
// Most uses of this package need not call Handshake

src/crypto/tls/tls_test.go

+86
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"fmt"
1212
"internal/testenv"
1313
"io"
14+
"io/ioutil"
1415
"math"
1516
"math/rand"
1617
"net"
@@ -458,6 +459,91 @@ func TestConnCloseBreakingWrite(t *testing.T) {
458459
}
459460
}
460461

462+
func TestConnCloseWrite(t *testing.T) {
463+
ln := newLocalListener(t)
464+
defer ln.Close()
465+
466+
go func() {
467+
sconn, err := ln.Accept()
468+
if err != nil {
469+
t.Fatal(err)
470+
}
471+
472+
serverConfig := testConfig.Clone()
473+
srv := Server(sconn, serverConfig)
474+
if err := srv.Handshake(); err != nil {
475+
t.Fatalf("handshake: %v", err)
476+
}
477+
defer srv.Close()
478+
479+
data, err := ioutil.ReadAll(srv)
480+
if err != nil {
481+
t.Fatal(err)
482+
}
483+
if len(data) > 0 {
484+
t.Errorf("Read data = %q; want nothing", data)
485+
}
486+
487+
if err = srv.CloseWrite(); err != nil {
488+
t.Errorf("server CloseWrite: %v", err)
489+
}
490+
}()
491+
492+
clientConfig := testConfig.Clone()
493+
conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
494+
if err != nil {
495+
t.Fatal(err)
496+
}
497+
if err = conn.Handshake(); err != nil {
498+
t.Fatal(err)
499+
}
500+
defer conn.Close()
501+
502+
if err = conn.CloseWrite(); err != nil {
503+
t.Errorf("client CloseWrite: %v", err)
504+
}
505+
506+
if _, err := conn.Write([]byte{0}); err != errShutdown {
507+
t.Errorf("CloseWrite error = %v; want errShutdown", err)
508+
}
509+
510+
data, err := ioutil.ReadAll(conn)
511+
if err != nil {
512+
t.Fatal(err)
513+
}
514+
if len(data) > 0 {
515+
t.Errorf("Read data = %q; want nothing", data)
516+
}
517+
518+
// test CloseWrite called before handshake finished
519+
520+
ln2 := newLocalListener(t)
521+
defer ln2.Close()
522+
523+
go func() {
524+
sconn, err := ln2.Accept()
525+
if err != nil {
526+
t.Fatal(err)
527+
}
528+
529+
serverConfig := testConfig.Clone()
530+
srv := Server(sconn, serverConfig)
531+
532+
srv.Handshake()
533+
srv.Close()
534+
}()
535+
536+
netConn, err := net.Dial("tcp", ln2.Addr().String())
537+
if err != nil {
538+
t.Fatal(err)
539+
}
540+
conn = Client(netConn, clientConfig)
541+
542+
if err = conn.CloseWrite(); err != errEarlyCloseWrite {
543+
t.Errorf("CloseWrite error = %v; want errEarlyCloseWrite", err)
544+
}
545+
}
546+
461547
func TestClone(t *testing.T) {
462548
var c1 Config
463549
v := reflect.ValueOf(&c1).Elem()

0 commit comments

Comments
 (0)