Skip to content

Commit a46839f

Browse files
committed
dial: add DialContext function
In order to replace timeouts with contexts in `Connect` instance creation (go-tarantool), I need a `DialContext` function. It accepts context, and cancels, if context is canceled by user. Part of tarantool/go-tarantool#136
1 parent b452431 commit a46839f

File tree

2 files changed

+188
-27
lines changed

2 files changed

+188
-27
lines changed

net.go

Lines changed: 94 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package openssl
1616

1717
import (
18+
"context"
1819
"errors"
1920
"net"
2021
"time"
@@ -89,8 +90,53 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
8990
// parameters.
9091
func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
9192
flags DialFlags) (*Conn, error) {
92-
d := net.Dialer{Timeout: timeout}
93-
return dialSession(d, network, addr, ctx, flags, nil)
93+
host, err := parseHost(addr)
94+
if err != nil {
95+
return nil, err
96+
}
97+
98+
conn, err := net.DialTimeout(network, addr, timeout)
99+
if err != nil {
100+
return nil, err
101+
}
102+
ctx, err = prepareCtx(ctx)
103+
if err != nil {
104+
return nil, err
105+
}
106+
client, err := createSession(conn, flags, host, ctx, nil)
107+
if err != nil {
108+
conn.Close()
109+
}
110+
return client, err
111+
}
112+
113+
// DialContext acts like Dial but takes a context for network dial.
114+
//
115+
// The context includes only network dial. It does not include OpenSSL calls.
116+
//
117+
// See func Dial for a description of the network, addr, ctx and flags
118+
// parameters.
119+
func DialContext(context context.Context, network, addr string,
120+
ctx *Ctx, flags DialFlags) (*Conn, error) {
121+
host, err := parseHost(addr)
122+
if err != nil {
123+
return nil, err
124+
}
125+
126+
dialer := net.Dialer{}
127+
conn, err := dialer.DialContext(context, network, addr)
128+
if err != nil {
129+
return nil, err
130+
}
131+
ctx, err = prepareCtx(ctx)
132+
if err != nil {
133+
return nil, err
134+
}
135+
client, err := createSession(conn, flags, host, ctx, nil)
136+
if err != nil {
137+
conn.Close()
138+
}
139+
return client, err
94140
}
95141

96142
// DialSession will connect to network/address and then wrap the corresponding
@@ -108,59 +154,80 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
108154
// can be retrieved from the GetSession method on the Conn.
109155
func DialSession(network, addr string, ctx *Ctx, flags DialFlags,
110156
session []byte) (*Conn, error) {
111-
var d net.Dialer
112-
return dialSession(d, network, addr, ctx, flags, session)
113-
}
114-
115-
func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
116-
session []byte) (*Conn, error) {
117-
host, _, err := net.SplitHostPort(addr)
157+
host, err := parseHost(addr)
118158
if err != nil {
119159
return nil, err
120160
}
121-
if ctx == nil {
122-
var err error
123-
ctx, err = NewCtx()
124-
if err != nil {
125-
return nil, err
126-
}
127-
// TODO: use operating system default certificate chain?
128-
}
129161

130-
c, err := d.Dial(network, addr)
162+
conn, err := net.Dial(network, addr)
131163
if err != nil {
132164
return nil, err
133165
}
134-
conn, err := Client(c, ctx)
166+
ctx, err = prepareCtx(ctx)
135167
if err != nil {
136-
c.Close()
137168
return nil, err
138169
}
139-
if session != nil {
140-
err := conn.setSession(session)
170+
client, err := createSession(conn, flags, host, ctx, session)
171+
if err != nil {
172+
conn.Close()
173+
}
174+
return client, err
175+
}
176+
177+
func prepareCtx(ctx *Ctx) (*Ctx, error) {
178+
if ctx == nil {
179+
var err error
180+
ctx, err = NewCtx()
141181
if err != nil {
142-
c.Close()
143182
return nil, err
144183
}
184+
// TODO: use operating system default certificate chain?
145185
}
186+
return ctx, nil
187+
}
188+
189+
func parseHost(addr string) (string, error) {
190+
host, _, err := net.SplitHostPort(addr)
191+
return host, err
192+
}
193+
194+
func handshake(conn *Conn, host string, flags DialFlags) error {
195+
var err error
146196
if flags&DisableSNI == 0 {
147197
err = conn.SetTlsExtHostName(host)
148198
if err != nil {
149-
conn.Close()
150-
return nil, err
199+
return err
151200
}
152201
}
153202
err = conn.Handshake()
154203
if err != nil {
155-
conn.Close()
156-
return nil, err
204+
return err
157205
}
158206
if flags&InsecureSkipHostVerification == 0 {
159207
err = conn.VerifyHostname(host)
208+
if err != nil {
209+
return err
210+
}
211+
}
212+
return nil
213+
}
214+
215+
func createSession(c net.Conn, flags DialFlags, host string, ctx *Ctx,
216+
session []byte) (*Conn, error) {
217+
conn, err := Client(c, ctx)
218+
if err != nil {
219+
return nil, err
220+
}
221+
if session != nil {
222+
err := conn.setSession(session)
160223
if err != nil {
161224
conn.Close()
162225
return nil, err
163226
}
164227
}
228+
if err := handshake(conn, host, flags); err != nil {
229+
conn.Close()
230+
return nil, err
231+
}
165232
return conn, nil
166233
}

net_test.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package openssl
2+
3+
import (
4+
"context"
5+
"crypto/rand"
6+
"io"
7+
"net"
8+
"sync"
9+
"testing"
10+
"time"
11+
)
12+
13+
var conn net.Conn
14+
15+
func sslConnect(t *testing.T, ssl_listener net.Listener) {
16+
for {
17+
var err error
18+
conn, err = ssl_listener.Accept()
19+
if err != nil {
20+
t.Errorf("failed accept: %s", err)
21+
continue
22+
}
23+
io.Copy(conn, io.LimitReader(rand.Reader, 1024))
24+
break
25+
}
26+
}
27+
28+
func TestDial(t *testing.T) {
29+
ctx := getCtx(t)
30+
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
31+
t.Fatal(err)
32+
}
33+
ssl_listener, err := Listen("tcp", "localhost:0", ctx)
34+
if err != nil {
35+
t.Fatal(err)
36+
}
37+
38+
wg := sync.WaitGroup{}
39+
wg.Add(1)
40+
go func() {
41+
sslConnect(t, ssl_listener)
42+
wg.Done()
43+
}()
44+
45+
client, err := Dial(ssl_listener.Addr().Network(),
46+
ssl_listener.Addr().String(), ctx, InsecureSkipHostVerification)
47+
48+
wg.Wait()
49+
50+
if err != nil {
51+
t.Fatalf("unexpected err: %v", err)
52+
}
53+
if client.is_shutdown {
54+
t.Fatal("client is closed after creation")
55+
}
56+
}
57+
58+
func TestDialTimeout(t *testing.T) {
59+
ctx := getCtx(t)
60+
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
61+
t.Fatal(err)
62+
}
63+
ssl_listener, err := Listen("tcp", "localhost:0", ctx)
64+
if err != nil {
65+
t.Fatal(err)
66+
}
67+
68+
client, err := DialTimeout(ssl_listener.Addr().Network(),
69+
ssl_listener.Addr().String(), time.Nanosecond, ctx, 0)
70+
71+
if client != nil || err == nil {
72+
t.Fatalf("expected error")
73+
}
74+
}
75+
76+
func TestDialContext(t *testing.T) {
77+
ctx := getCtx(t)
78+
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
79+
t.Fatal(err)
80+
}
81+
ssl_listener, err := Listen("tcp", "localhost:0", ctx)
82+
if err != nil {
83+
t.Fatal(err)
84+
}
85+
86+
cancelCtx, cancel := context.WithCancel(context.Background())
87+
cancel()
88+
client, err := DialContext(cancelCtx, ssl_listener.Addr().Network(),
89+
ssl_listener.Addr().String(), ctx, 0)
90+
91+
if client != nil || err == nil {
92+
t.Fatalf("expected error")
93+
}
94+
}

0 commit comments

Comments
 (0)