15
15
package openssl
16
16
17
17
import (
18
+ "context"
18
19
"errors"
19
20
"net"
20
21
"time"
@@ -89,8 +90,53 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
89
90
// parameters.
90
91
func DialTimeout (network , addr string , timeout time.Duration , ctx * Ctx ,
91
92
flags DialFlags ) (* Conn , error ) {
92
- d := net.Dialer {Timeout : timeout }
93
- return dialSession (d , network , addr , ctx , flags , nil )
93
+ host , err := parseHost (addr )
94
+ if err != nil {
95
+ return nil , err
96
+ }
97
+
98
+ conn , err := net .DialTimeout (network , addr , timeout )
99
+ if err != nil {
100
+ return nil , err
101
+ }
102
+ ctx , err = prepareCtx (ctx )
103
+ if err != nil {
104
+ return nil , err
105
+ }
106
+ client , err := createSession (conn , flags , host , ctx , nil )
107
+ if err != nil {
108
+ conn .Close ()
109
+ }
110
+ return client , err
111
+ }
112
+
113
+ // DialContext acts like Dial but takes a context for network dial.
114
+ //
115
+ // The context includes only network dial. It does not include OpenSSL calls.
116
+ //
117
+ // See func Dial for a description of the network, addr, ctx and flags
118
+ // parameters.
119
+ func DialContext (context context.Context , network , addr string ,
120
+ ctx * Ctx , flags DialFlags ) (* Conn , error ) {
121
+ host , err := parseHost (addr )
122
+ if err != nil {
123
+ return nil , err
124
+ }
125
+
126
+ dialer := net.Dialer {}
127
+ conn , err := dialer .DialContext (context , network , addr )
128
+ if err != nil {
129
+ return nil , err
130
+ }
131
+ ctx , err = prepareCtx (ctx )
132
+ if err != nil {
133
+ return nil , err
134
+ }
135
+ client , err := createSession (conn , flags , host , ctx , nil )
136
+ if err != nil {
137
+ conn .Close ()
138
+ }
139
+ return client , err
94
140
}
95
141
96
142
// DialSession will connect to network/address and then wrap the corresponding
@@ -108,59 +154,80 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
108
154
// can be retrieved from the GetSession method on the Conn.
109
155
func DialSession (network , addr string , ctx * Ctx , flags DialFlags ,
110
156
session []byte ) (* Conn , error ) {
111
- var d net.Dialer
112
- return dialSession (d , network , addr , ctx , flags , session )
113
- }
114
-
115
- func dialSession (d net.Dialer , network , addr string , ctx * Ctx , flags DialFlags ,
116
- session []byte ) (* Conn , error ) {
117
- host , _ , err := net .SplitHostPort (addr )
157
+ host , err := parseHost (addr )
118
158
if err != nil {
119
159
return nil , err
120
160
}
121
- if ctx == nil {
122
- var err error
123
- ctx , err = NewCtx ()
124
- if err != nil {
125
- return nil , err
126
- }
127
- // TODO: use operating system default certificate chain?
128
- }
129
161
130
- c , err := d .Dial (network , addr )
162
+ conn , err := net .Dial (network , addr )
131
163
if err != nil {
132
164
return nil , err
133
165
}
134
- conn , err := Client ( c , ctx )
166
+ ctx , err = prepareCtx ( ctx )
135
167
if err != nil {
136
- c .Close ()
137
168
return nil , err
138
169
}
139
- if session != nil {
140
- err := conn .setSession (session )
170
+ client , err := createSession (conn , flags , host , ctx , session )
171
+ if err != nil {
172
+ conn .Close ()
173
+ }
174
+ return client , err
175
+ }
176
+
177
+ func prepareCtx (ctx * Ctx ) (* Ctx , error ) {
178
+ if ctx == nil {
179
+ var err error
180
+ ctx , err = NewCtx ()
141
181
if err != nil {
142
- c .Close ()
143
182
return nil , err
144
183
}
184
+ // TODO: use operating system default certificate chain?
145
185
}
186
+ return ctx , nil
187
+ }
188
+
189
+ func parseHost (addr string ) (string , error ) {
190
+ host , _ , err := net .SplitHostPort (addr )
191
+ return host , err
192
+ }
193
+
194
+ func handshake (conn * Conn , host string , flags DialFlags ) error {
195
+ var err error
146
196
if flags & DisableSNI == 0 {
147
197
err = conn .SetTlsExtHostName (host )
148
198
if err != nil {
149
- conn .Close ()
150
- return nil , err
199
+ return err
151
200
}
152
201
}
153
202
err = conn .Handshake ()
154
203
if err != nil {
155
- conn .Close ()
156
- return nil , err
204
+ return err
157
205
}
158
206
if flags & InsecureSkipHostVerification == 0 {
159
207
err = conn .VerifyHostname (host )
208
+ if err != nil {
209
+ return err
210
+ }
211
+ }
212
+ return nil
213
+ }
214
+
215
+ func createSession (c net.Conn , flags DialFlags , host string , ctx * Ctx ,
216
+ session []byte ) (* Conn , error ) {
217
+ conn , err := Client (c , ctx )
218
+ if err != nil {
219
+ return nil , err
220
+ }
221
+ if session != nil {
222
+ err := conn .setSession (session )
160
223
if err != nil {
161
224
conn .Close ()
162
225
return nil , err
163
226
}
164
227
}
228
+ if err := handshake (conn , host , flags ); err != nil {
229
+ conn .Close ()
230
+ return nil , err
231
+ }
165
232
return conn , nil
166
233
}
0 commit comments