15
15
package openssl
16
16
17
17
import (
18
+ "context"
18
19
"errors"
19
20
"net"
20
21
"time"
@@ -89,8 +90,44 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
89
90
// parameters.
90
91
func DialTimeout (network , addr string , timeout time.Duration , ctx * Ctx ,
91
92
flags DialFlags ) (* Conn , error ) {
92
- d := net.Dialer {Timeout : timeout }
93
- return dialSession (d , network , addr , ctx , flags , nil )
93
+ dialContext , cancel := context .WithTimeout (context .Background (), timeout )
94
+ defer cancel ()
95
+ host , conn , err := createConnection (dialContext , network , addr )
96
+ if err != nil {
97
+ return nil , err
98
+ }
99
+ ctx , err = prepareCtx (ctx )
100
+ if err != nil {
101
+ return nil , err
102
+ }
103
+ client , err := createSession (conn , flags , host , ctx , nil )
104
+ if err != nil {
105
+ conn .Close ()
106
+ }
107
+ return client , err
108
+ }
109
+
110
+ // DialContext acts like Dial but takes a context for network dial.
111
+ //
112
+ // The context includes only network dial. It does not include OpenSSL calls.
113
+ //
114
+ // See func Dial for a description of the network, addr, ctx and flags
115
+ // parameters.
116
+ func DialContext (context context.Context , network , addr string ,
117
+ ctx * Ctx , flags DialFlags ) (* Conn , error ) {
118
+ host , conn , err := createConnection (context , network , addr )
119
+ if err != nil {
120
+ return nil , err
121
+ }
122
+ ctx , err = prepareCtx (ctx )
123
+ if err != nil {
124
+ return nil , err
125
+ }
126
+ client , err := createSession (conn , flags , host , ctx , nil )
127
+ if err != nil {
128
+ conn .Close ()
129
+ }
130
+ return client , err
94
131
}
95
132
96
133
// DialSession will connect to network/address and then wrap the corresponding
@@ -108,16 +145,22 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
108
145
// can be retrieved from the GetSession method on the Conn.
109
146
func DialSession (network , addr string , ctx * Ctx , flags DialFlags ,
110
147
session []byte ) (* Conn , error ) {
111
- var d net.Dialer
112
- return dialSession (d , network , addr , ctx , flags , session )
113
- }
114
-
115
- func dialSession (d net.Dialer , network , addr string , ctx * Ctx , flags DialFlags ,
116
- session []byte ) (* Conn , error ) {
117
- host , _ , err := net .SplitHostPort (addr )
148
+ host , conn , err := createConnection (context .Background (), network , addr )
118
149
if err != nil {
119
150
return nil , err
120
151
}
152
+ ctx , err = prepareCtx (ctx )
153
+ if err != nil {
154
+ return nil , err
155
+ }
156
+ client , err := createSession (conn , flags , host , ctx , session )
157
+ if err != nil {
158
+ conn .Close ()
159
+ }
160
+ return client , err
161
+ }
162
+
163
+ func prepareCtx (ctx * Ctx ) (* Ctx , error ) {
121
164
if ctx == nil {
122
165
var err error
123
166
ctx , err = NewCtx ()
@@ -126,41 +169,57 @@ func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
126
169
}
127
170
// TODO: use operating system default certificate chain?
128
171
}
172
+ return ctx , nil
173
+ }
129
174
130
- c , err := d .Dial (network , addr )
131
- if err != nil {
132
- return nil , err
133
- }
134
- conn , err := Client (c , ctx )
175
+ func createConnection (context context.Context , network , addr string ) (string , net.Conn , error ) {
176
+ host , _ , err := net .SplitHostPort (addr )
135
177
if err != nil {
136
- c .Close ()
137
- return nil , err
138
- }
139
- if session != nil {
140
- err := conn .setSession (session )
141
- if err != nil {
142
- c .Close ()
143
- return nil , err
144
- }
178
+ return "" , nil , err
145
179
}
180
+
181
+ dialer := net.Dialer {}
182
+ conn , err := dialer .DialContext (context , network , addr )
183
+ return host , conn , err
184
+ }
185
+
186
+ func handshake (conn * Conn , host string , flags DialFlags ) error {
187
+ var err error
146
188
if flags & DisableSNI == 0 {
147
189
err = conn .SetTlsExtHostName (host )
148
190
if err != nil {
149
- conn .Close ()
150
- return nil , err
191
+ return err
151
192
}
152
193
}
153
194
err = conn .Handshake ()
154
195
if err != nil {
155
- conn .Close ()
156
- return nil , err
196
+ return err
157
197
}
158
198
if flags & InsecureSkipHostVerification == 0 {
159
199
err = conn .VerifyHostname (host )
200
+ if err != nil {
201
+ return err
202
+ }
203
+ }
204
+ return nil
205
+ }
206
+
207
+ func createSession (c net.Conn , flags DialFlags , host string , ctx * Ctx ,
208
+ session []byte ) (* Conn , error ) {
209
+ conn , err := Client (c , ctx )
210
+ if err != nil {
211
+ return nil , err
212
+ }
213
+ if session != nil {
214
+ err := conn .setSession (session )
160
215
if err != nil {
161
216
conn .Close ()
162
217
return nil , err
163
218
}
164
219
}
220
+ if err := handshake (conn , host , flags ); err != nil {
221
+ conn .Close ()
222
+ return nil , err
223
+ }
165
224
return conn , nil
166
225
}
0 commit comments