@@ -21,6 +21,7 @@ const bufSize = 128 * 1024
21
21
// Greeting is a message sent by Tarantool on connect.
22
22
type Greeting struct {
23
23
Version string
24
+ Salt string
24
25
}
25
26
26
27
// writeFlusher is the interface that groups the basic Write and Flush methods.
@@ -80,21 +81,48 @@ type tntConn struct {
80
81
81
82
// rawDial does basic dial operations:
82
83
// reads greeting, identifies a protocol and validates it.
83
- func rawDial (conn * tntConn , requiredProto ProtocolInfo ) ( string , error ) {
84
+ func rawDial (conn * tntConn ) error {
84
85
version , salt , err := readGreeting (conn .reader )
85
86
if err != nil {
86
- return "" , fmt .Errorf ("failed to read greeting: %w" , err )
87
+ return fmt .Errorf ("failed to read greeting: %w" , err )
87
88
}
88
89
conn .greeting .Version = version
90
+ conn .greeting .Salt = salt
89
91
90
92
if conn .protocol , err = identify (conn .writer , conn .reader ); err != nil {
91
- return "" , fmt .Errorf ("failed to identify: %w" , err )
93
+ return fmt .Errorf ("failed to identify: %w" , err )
92
94
}
93
95
94
- if err = checkProtocolInfo (requiredProto , conn .protocol ); err != nil {
95
- return "" , fmt .Errorf ("invalid server protocol: %w" , err )
96
+ return nil
97
+ }
98
+
99
+ type netDialer struct {
100
+ address string
101
+ }
102
+
103
+ func (d netDialer ) Dial (ctx context.Context , opts DialOpts ) (Conn , error ) {
104
+ var err error
105
+ conn := new (tntConn )
106
+
107
+ network , address := parseAddress (d .address )
108
+ dialer := net.Dialer {}
109
+ conn .net , err = dialer .DialContext (ctx , network , address )
110
+ if err != nil {
111
+ return nil , fmt .Errorf ("failed to dial: %w" , err )
112
+ }
113
+
114
+ dc := & deadlineIO {to : opts .IoTimeout , c : conn .net }
115
+ conn .reader = bufio .NewReaderSize (dc , bufSize )
116
+ conn .writer = bufio .NewWriterSize (dc , bufSize )
117
+
118
+ err = rawDial (conn )
119
+ if err != nil {
120
+ conn .net .Close ()
121
+ return nil , err
96
122
}
97
- return salt , err
123
+
124
+ conn .protocol .Auth = ChapSha1Auth
125
+ return conn , nil
98
126
}
99
127
100
128
// NetDialer is a basic Dialer implementation.
@@ -121,12 +149,45 @@ type NetDialer struct {
121
149
122
150
// Dial makes NetDialer satisfy the Dialer interface.
123
151
func (d NetDialer ) Dial (ctx context.Context , opts DialOpts ) (Conn , error ) {
152
+ dialer := AuthDialer {
153
+ Dialer : ProtocolDialer {
154
+ Dialer : netDialer {
155
+ address : d .Address ,
156
+ },
157
+ RequiredProtocolInfo : d .RequiredProtocolInfo ,
158
+ },
159
+ Auth : ChapSha1Auth ,
160
+ Username : d .User ,
161
+ Password : d .Password ,
162
+ }
163
+
164
+ return dialer .Dial (ctx , opts )
165
+ }
166
+
167
+ type openSslDialer struct {
168
+ address string
169
+ auth Auth
170
+ sslKeyFile string
171
+ sslCertFile string
172
+ sslCaFile string
173
+ sslCiphers string
174
+ sslPassword string
175
+ sslPasswordFile string
176
+ }
177
+
178
+ func (d openSslDialer ) Dial (ctx context.Context , opts DialOpts ) (Conn , error ) {
124
179
var err error
125
180
conn := new (tntConn )
126
181
127
- network , address := parseAddress (d .Address )
128
- dialer := net.Dialer {}
129
- conn .net , err = dialer .DialContext (ctx , network , address )
182
+ network , address := parseAddress (d .address )
183
+ conn .net , err = sslDialContext (ctx , network , address , sslOpts {
184
+ KeyFile : d .sslKeyFile ,
185
+ CertFile : d .sslCertFile ,
186
+ CaFile : d .sslCaFile ,
187
+ Ciphers : d .sslCiphers ,
188
+ Password : d .sslPassword ,
189
+ PasswordFile : d .sslPasswordFile ,
190
+ })
130
191
if err != nil {
131
192
return nil , fmt .Errorf ("failed to dial: %w" , err )
132
193
}
@@ -135,20 +196,18 @@ func (d NetDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
135
196
conn .reader = bufio .NewReaderSize (dc , bufSize )
136
197
conn .writer = bufio .NewWriterSize (dc , bufSize )
137
198
138
- salt , err : = rawDial (conn , d . RequiredProtocolInfo )
199
+ err = rawDial (conn )
139
200
if err != nil {
140
201
conn .net .Close ()
141
202
return nil , err
142
203
}
143
204
144
- if d .User == "" {
145
- return conn , nil
146
- }
147
-
148
- conn .protocol .Auth = ChapSha1Auth
149
- if err = authenticate (conn , ChapSha1Auth , d .User , d .Password , salt ); err != nil {
150
- conn .net .Close ()
151
- return nil , fmt .Errorf ("failed to authenticate: %w" , err )
205
+ if d .auth == AutoAuth {
206
+ if conn .protocol .Auth == AutoAuth {
207
+ conn .protocol .Auth = ChapSha1Auth
208
+ }
209
+ } else {
210
+ conn .protocol .Auth = d .auth
152
211
}
153
212
154
213
return conn , nil
@@ -206,51 +265,26 @@ type OpenSslDialer struct {
206
265
207
266
// Dial makes OpenSslDialer satisfy the Dialer interface.
208
267
func (d OpenSslDialer ) Dial (ctx context.Context , opts DialOpts ) (Conn , error ) {
209
- var err error
210
- conn := new (tntConn )
211
-
212
- network , address := parseAddress (d .Address )
213
- conn .net , err = sslDialContext (ctx , network , address , sslOpts {
214
- KeyFile : d .SslKeyFile ,
215
- CertFile : d .SslCertFile ,
216
- CaFile : d .SslCaFile ,
217
- Ciphers : d .SslCiphers ,
218
- Password : d .SslPassword ,
219
- PasswordFile : d .SslPasswordFile ,
220
- })
221
- if err != nil {
222
- return nil , fmt .Errorf ("failed to dial: %w" , err )
223
- }
224
-
225
- dc := & deadlineIO {to : opts .IoTimeout , c : conn .net }
226
- conn .reader = bufio .NewReaderSize (dc , bufSize )
227
- conn .writer = bufio .NewWriterSize (dc , bufSize )
228
-
229
- salt , err := rawDial (conn , d .RequiredProtocolInfo )
230
- if err != nil {
231
- conn .net .Close ()
232
- return nil , err
233
- }
234
-
235
- if d .User == "" {
236
- return conn , nil
237
- }
238
-
239
- if d .Auth == AutoAuth {
240
- if conn .protocol .Auth != AutoAuth {
241
- d .Auth = conn .protocol .Auth
242
- } else {
243
- d .Auth = ChapSha1Auth
244
- }
245
- }
246
- conn .protocol .Auth = d .Auth
247
-
248
- if err = authenticate (conn , d .Auth , d .User , d .Password , salt ); err != nil {
249
- conn .net .Close ()
250
- return nil , fmt .Errorf ("failed to authenticate: %w" , err )
251
- }
252
-
253
- return conn , nil
268
+ dialer := AuthDialer {
269
+ Dialer : ProtocolDialer {
270
+ Dialer : openSslDialer {
271
+ address : d .Address ,
272
+ auth : d .Auth ,
273
+ sslKeyFile : d .SslKeyFile ,
274
+ sslCertFile : d .SslCertFile ,
275
+ sslCaFile : d .SslCaFile ,
276
+ sslCiphers : d .SslCiphers ,
277
+ sslPassword : d .SslPassword ,
278
+ sslPasswordFile : d .SslPasswordFile ,
279
+ },
280
+ RequiredProtocolInfo : d .RequiredProtocolInfo ,
281
+ },
282
+ Auth : d .Auth ,
283
+ Username : d .User ,
284
+ Password : d .Password ,
285
+ }
286
+
287
+ return dialer .Dial (ctx , opts )
254
288
}
255
289
256
290
// FdDialer allows to use an existing socket fd for connection.
@@ -263,6 +297,10 @@ type FdDialer struct {
263
297
RequiredProtocolInfo ProtocolInfo
264
298
}
265
299
300
+ type fdDialer struct {
301
+ fd uintptr
302
+ }
303
+
266
304
type fdAddr struct {
267
305
Fd uintptr
268
306
}
@@ -284,22 +322,21 @@ func (c *fdConn) RemoteAddr() net.Addr {
284
322
return c .Addr
285
323
}
286
324
287
- // Dial makes FdDialer satisfy the Dialer interface.
288
- func (d FdDialer ) Dial (ctx context.Context , opts DialOpts ) (Conn , error ) {
289
- file := os .NewFile (d .Fd , "" )
325
+ func (d fdDialer ) Dial (ctx context.Context , opts DialOpts ) (Conn , error ) {
326
+ file := os .NewFile (d .fd , "" )
290
327
c , err := net .FileConn (file )
291
328
if err != nil {
292
329
return nil , fmt .Errorf ("failed to dial: %w" , err )
293
330
}
294
331
295
332
conn := new (tntConn )
296
- conn .net = & fdConn {Conn : c , Addr : fdAddr {Fd : d .Fd }}
333
+ conn .net = & fdConn {Conn : c , Addr : fdAddr {Fd : d .fd }}
297
334
298
335
dc := & deadlineIO {to : opts .IoTimeout , c : conn .net }
299
336
conn .reader = bufio .NewReaderSize (dc , bufSize )
300
337
conn .writer = bufio .NewWriterSize (dc , bufSize )
301
338
302
- _ , err = rawDial (conn , d . RequiredProtocolInfo )
339
+ err = rawDial (conn )
303
340
if err != nil {
304
341
conn .net .Close ()
305
342
return nil , err
@@ -308,6 +345,84 @@ func (d FdDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
308
345
return conn , nil
309
346
}
310
347
348
+ // Dial makes FdDialer satisfy the Dialer interface.
349
+ func (d FdDialer ) Dial (ctx context.Context , opts DialOpts ) (Conn , error ) {
350
+ dialer := ProtocolDialer {
351
+ Dialer : fdDialer {
352
+ fd : d .Fd ,
353
+ },
354
+ RequiredProtocolInfo : d .RequiredProtocolInfo ,
355
+ }
356
+
357
+ return dialer .Dial (ctx , opts )
358
+ }
359
+
360
+ type AuthDialer struct {
361
+ // Dialer is a base dialer.
362
+ Dialer Dialer
363
+ // Authentication options.
364
+ Auth Auth
365
+ Username string
366
+ Password string
367
+ }
368
+
369
+ // Dial makes AuthDialer satisfy the Dialer interface.
370
+ func (d AuthDialer ) Dial (ctx context.Context , opts DialOpts ) (Conn , error ) {
371
+ conn , err := d .Dialer .Dial (ctx , opts )
372
+ if err != nil {
373
+ return conn , err
374
+ }
375
+ greeting := conn .Greeting ()
376
+ if greeting .Salt == "" {
377
+ conn .Close ()
378
+ return nil , fmt .Errorf ("failed to authenticate: " +
379
+ "an invalid connection without salt" )
380
+ }
381
+
382
+ if d .Username == "" {
383
+ return conn , nil
384
+ }
385
+
386
+ protocolAuth := conn .ProtocolInfo ().Auth
387
+ if d .Auth == AutoAuth {
388
+ if protocolAuth != AutoAuth {
389
+ d .Auth = protocolAuth
390
+ } else {
391
+ d .Auth = ChapSha1Auth
392
+ }
393
+ }
394
+
395
+ if err := authenticate (conn , d .Auth , d .Username , d .Password ,
396
+ conn .Greeting ().Salt ); err != nil {
397
+ conn .Close ()
398
+ return nil , fmt .Errorf ("failed to authenticate: %w" , err )
399
+ }
400
+ return conn , nil
401
+ }
402
+
403
+ type ProtocolDialer struct {
404
+ // Dialer is a base dialer.
405
+ Dialer Dialer
406
+ // RequiredProtocol contains minimal protocol version and
407
+ // list of protocol features that should be supported by
408
+ // Tarantool server. By default, there are no restrictions.
409
+ RequiredProtocolInfo ProtocolInfo
410
+ }
411
+
412
+ // Dial makes ProtocolDialer satisfy the Dialer interface.
413
+ func (d ProtocolDialer ) Dial (ctx context.Context , opts DialOpts ) (Conn , error ) {
414
+ conn , err := d .Dialer .Dial (ctx , opts )
415
+ if err != nil {
416
+ return conn , err
417
+ }
418
+ err = checkProtocolInfo (d .RequiredProtocolInfo , conn .ProtocolInfo ())
419
+ if err != nil {
420
+ conn .Close ()
421
+ return nil , fmt .Errorf ("invalid server protocol: %w" , err )
422
+ }
423
+ return conn , nil
424
+ }
425
+
311
426
// Addr makes tntConn satisfy the Conn interface.
312
427
func (c * tntConn ) Addr () net.Addr {
313
428
return c .net .RemoteAddr ()
0 commit comments