Skip to content

Commit 26d211b

Browse files
committed
dial: add DialContext function
In order to replace timeouts with contexts in `Connect` instance creation (go-tarantool), I need a `DialContext` function. It accepts context, and cancels, if context is canceled by user. Part of tarantool/go-tarantool#136
1 parent b452431 commit 26d211b

File tree

2 files changed

+180
-27
lines changed

2 files changed

+180
-27
lines changed

net.go

Lines changed: 86 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package openssl
1616

1717
import (
18+
"context"
1819
"errors"
1920
"net"
2021
"time"
@@ -89,8 +90,44 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
8990
// parameters.
9091
func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
9192
flags DialFlags) (*Conn, error) {
92-
d := net.Dialer{Timeout: timeout}
93-
return dialSession(d, network, addr, ctx, flags, nil)
93+
dialContext, cancel := context.WithTimeout(context.Background(), timeout)
94+
defer cancel()
95+
host, conn, err := createConnection(dialContext, network, addr)
96+
if err != nil {
97+
return nil, err
98+
}
99+
ctx, err = prepareCtx(ctx)
100+
if err != nil {
101+
return nil, err
102+
}
103+
client, err := createSession(conn, flags, host, ctx, nil)
104+
if err != nil {
105+
conn.Close()
106+
}
107+
return client, err
108+
}
109+
110+
// DialContext acts like Dial but takes a context for network dial.
111+
//
112+
// The context includes only network dial. It does not include OpenSSL calls.
113+
//
114+
// See func Dial for a description of the network, addr, ctx and flags
115+
// parameters.
116+
func DialContext(context context.Context, network, addr string,
117+
ctx *Ctx, flags DialFlags) (*Conn, error) {
118+
host, conn, err := createConnection(context, network, addr)
119+
if err != nil {
120+
return nil, err
121+
}
122+
ctx, err = prepareCtx(ctx)
123+
if err != nil {
124+
return nil, err
125+
}
126+
client, err := createSession(conn, flags, host, ctx, nil)
127+
if err != nil {
128+
conn.Close()
129+
}
130+
return client, err
94131
}
95132

96133
// DialSession will connect to network/address and then wrap the corresponding
@@ -108,16 +145,22 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
108145
// can be retrieved from the GetSession method on the Conn.
109146
func DialSession(network, addr string, ctx *Ctx, flags DialFlags,
110147
session []byte) (*Conn, error) {
111-
var d net.Dialer
112-
return dialSession(d, network, addr, ctx, flags, session)
113-
}
114-
115-
func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
116-
session []byte) (*Conn, error) {
117-
host, _, err := net.SplitHostPort(addr)
148+
host, conn, err := createConnection(context.Background(), network, addr)
118149
if err != nil {
119150
return nil, err
120151
}
152+
ctx, err = prepareCtx(ctx)
153+
if err != nil {
154+
return nil, err
155+
}
156+
client, err := createSession(conn, flags, host, ctx, session)
157+
if err != nil {
158+
conn.Close()
159+
}
160+
return client, err
161+
}
162+
163+
func prepareCtx(ctx *Ctx) (*Ctx, error) {
121164
if ctx == nil {
122165
var err error
123166
ctx, err = NewCtx()
@@ -126,41 +169,57 @@ func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
126169
}
127170
// TODO: use operating system default certificate chain?
128171
}
172+
return ctx, nil
173+
}
129174

130-
c, err := d.Dial(network, addr)
131-
if err != nil {
132-
return nil, err
133-
}
134-
conn, err := Client(c, ctx)
175+
func createConnection(context context.Context, network, addr string) (string, net.Conn, error) {
176+
host, _, err := net.SplitHostPort(addr)
135177
if err != nil {
136-
c.Close()
137-
return nil, err
138-
}
139-
if session != nil {
140-
err := conn.setSession(session)
141-
if err != nil {
142-
c.Close()
143-
return nil, err
144-
}
178+
return "", nil, err
145179
}
180+
181+
dialer := net.Dialer{}
182+
conn, err := dialer.DialContext(context, network, addr)
183+
return host, conn, err
184+
}
185+
186+
func handshake(conn *Conn, host string, flags DialFlags) error {
187+
var err error
146188
if flags&DisableSNI == 0 {
147189
err = conn.SetTlsExtHostName(host)
148190
if err != nil {
149-
conn.Close()
150-
return nil, err
191+
return err
151192
}
152193
}
153194
err = conn.Handshake()
154195
if err != nil {
155-
conn.Close()
156-
return nil, err
196+
return err
157197
}
158198
if flags&InsecureSkipHostVerification == 0 {
159199
err = conn.VerifyHostname(host)
200+
if err != nil {
201+
return err
202+
}
203+
}
204+
return nil
205+
}
206+
207+
func createSession(c net.Conn, flags DialFlags, host string, ctx *Ctx,
208+
session []byte) (*Conn, error) {
209+
conn, err := Client(c, ctx)
210+
if err != nil {
211+
return nil, err
212+
}
213+
if session != nil {
214+
err := conn.setSession(session)
160215
if err != nil {
161216
conn.Close()
162217
return nil, err
163218
}
164219
}
220+
if err := handshake(conn, host, flags); err != nil {
221+
conn.Close()
222+
return nil, err
223+
}
165224
return conn, nil
166225
}

net_test.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package openssl
2+
3+
import (
4+
"context"
5+
"crypto/rand"
6+
"io"
7+
"net"
8+
"sync"
9+
"testing"
10+
"time"
11+
)
12+
13+
var conn net.Conn
14+
15+
func sslConnect(t *testing.T, ssl_listener net.Listener) {
16+
for {
17+
var err error
18+
conn, err = ssl_listener.Accept()
19+
if err != nil {
20+
t.Errorf("failed accept: %s", err)
21+
continue
22+
}
23+
io.Copy(conn, io.LimitReader(rand.Reader, 1024))
24+
break
25+
}
26+
}
27+
28+
func TestDial(t *testing.T) {
29+
ctx := getCtx(t)
30+
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
31+
t.Fatal(err)
32+
}
33+
ssl_listener, err := Listen("tcp", "localhost:0", ctx)
34+
if err != nil {
35+
t.Fatal(err)
36+
}
37+
38+
wg := sync.WaitGroup{}
39+
wg.Add(1)
40+
go func() {
41+
sslConnect(t, ssl_listener)
42+
wg.Done()
43+
}()
44+
45+
client, err := Dial(ssl_listener.Addr().Network(),
46+
ssl_listener.Addr().String(), ctx, InsecureSkipHostVerification)
47+
48+
wg.Wait()
49+
50+
if err != nil {
51+
t.Fatalf("unexpected err: %v", err)
52+
}
53+
if client.is_shutdown {
54+
t.Fatal("client is closed after creation")
55+
}
56+
}
57+
58+
func TestDialTimeout(t *testing.T) {
59+
ctx := getCtx(t)
60+
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
61+
t.Fatal(err)
62+
}
63+
ssl_listener, err := Listen("tcp", "localhost:0", ctx)
64+
if err != nil {
65+
t.Fatal(err)
66+
}
67+
68+
client, err := DialTimeout(ssl_listener.Addr().Network(),
69+
ssl_listener.Addr().String(), time.Nanosecond, ctx, 0)
70+
71+
if client != nil || err == nil {
72+
t.Fatalf("expected error")
73+
}
74+
}
75+
76+
func TestDialContext(t *testing.T) {
77+
ctx := getCtx(t)
78+
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
79+
t.Fatal(err)
80+
}
81+
ssl_listener, err := Listen("tcp", "localhost:0", ctx)
82+
if err != nil {
83+
t.Fatal(err)
84+
}
85+
86+
cancelCtx, cancel := context.WithCancel(context.Background())
87+
cancel()
88+
client, err := DialContext(cancelCtx, ssl_listener.Addr().Network(),
89+
ssl_listener.Addr().String(), ctx, 0)
90+
91+
if client != nil || err == nil {
92+
t.Fatalf("expected error")
93+
}
94+
}

0 commit comments

Comments
 (0)