Skip to content

Commit 9292448

Browse files
committed
api: create AuthDialer and ProtocolDialer
To disable SSL by default we want to transfer `OpenSslDialer` to the go-openssl repository. In order to do so, we need to minimize the amount of copy-paste of the private functions. `AuthDialer` is created as a dialer-wrapper, that calls authentication methods. `ProtoDialer` is created to check expected `ProtocolInfo` with the actual (in the created connection). Part of #301
1 parent 6ba01ff commit 9292448

File tree

5 files changed

+315
-69
lines changed

5 files changed

+315
-69
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release.
3737
the response (#237)
3838
- Ability to mock connections for tests (#237). Added new types `MockDoer`,
3939
`MockRequest` to `test_helpers`.
40+
- `AuthDialer` and `ProtocolDialer` types for creating a dialer with
41+
authentication and `ProtocolInfo` check (#301)
4042

4143
### Changed
4244

connection.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,9 @@ func (conn *Connection) dial(ctx context.Context) error {
440440
}
441441

442442
conn.addr = c.Addr()
443-
conn.Greeting.Version = c.Greeting().Version
443+
connGreeting := c.Greeting()
444+
conn.Greeting.Version = connGreeting.Version
445+
conn.Greeting.Salt = connGreeting.Salt
444446
conn.serverProtocolInfo = c.ProtocolInfo()
445447

446448
spaceAndIndexNamesSupported :=

dial.go

+183-68
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ const bufSize = 128 * 1024
2121
// Greeting is a message sent by Tarantool on connect.
2222
type Greeting struct {
2323
Version string
24+
Salt string
2425
}
2526

2627
// writeFlusher is the interface that groups the basic Write and Flush methods.
@@ -80,21 +81,48 @@ type tntConn struct {
8081

8182
// rawDial does basic dial operations:
8283
// reads greeting, identifies a protocol and validates it.
83-
func rawDial(conn *tntConn, requiredProto ProtocolInfo) (string, error) {
84+
func rawDial(conn *tntConn) error {
8485
version, salt, err := readGreeting(conn.reader)
8586
if err != nil {
86-
return "", fmt.Errorf("failed to read greeting: %w", err)
87+
return fmt.Errorf("failed to read greeting: %w", err)
8788
}
8889
conn.greeting.Version = version
90+
conn.greeting.Salt = salt
8991

9092
if conn.protocol, err = identify(conn.writer, conn.reader); err != nil {
91-
return "", fmt.Errorf("failed to identify: %w", err)
93+
return fmt.Errorf("failed to identify: %w", err)
9294
}
9395

94-
if err = checkProtocolInfo(requiredProto, conn.protocol); err != nil {
95-
return "", fmt.Errorf("invalid server protocol: %w", err)
96+
return nil
97+
}
98+
99+
type netDialer struct {
100+
address string
101+
}
102+
103+
func (d netDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
104+
var err error
105+
conn := new(tntConn)
106+
107+
network, address := parseAddress(d.address)
108+
dialer := net.Dialer{}
109+
conn.net, err = dialer.DialContext(ctx, network, address)
110+
if err != nil {
111+
return nil, fmt.Errorf("failed to dial: %w", err)
112+
}
113+
114+
dc := &deadlineIO{to: opts.IoTimeout, c: conn.net}
115+
conn.reader = bufio.NewReaderSize(dc, bufSize)
116+
conn.writer = bufio.NewWriterSize(dc, bufSize)
117+
118+
err = rawDial(conn)
119+
if err != nil {
120+
conn.net.Close()
121+
return nil, err
96122
}
97-
return salt, err
123+
124+
conn.protocol.Auth = ChapSha1Auth
125+
return conn, nil
98126
}
99127

100128
// NetDialer is a basic Dialer implementation.
@@ -121,12 +149,45 @@ type NetDialer struct {
121149

122150
// Dial makes NetDialer satisfy the Dialer interface.
123151
func (d NetDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
152+
dialer := AuthDialer{
153+
Dialer: ProtocolDialer{
154+
Dialer: netDialer{
155+
address: d.Address,
156+
},
157+
RequiredProtocolInfo: d.RequiredProtocolInfo,
158+
},
159+
Auth: ChapSha1Auth,
160+
Username: d.User,
161+
Password: d.Password,
162+
}
163+
164+
return dialer.Dial(ctx, opts)
165+
}
166+
167+
type openSslDialer struct {
168+
address string
169+
auth Auth
170+
sslKeyFile string
171+
sslCertFile string
172+
sslCaFile string
173+
sslCiphers string
174+
sslPassword string
175+
sslPasswordFile string
176+
}
177+
178+
func (d openSslDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
124179
var err error
125180
conn := new(tntConn)
126181

127-
network, address := parseAddress(d.Address)
128-
dialer := net.Dialer{}
129-
conn.net, err = dialer.DialContext(ctx, network, address)
182+
network, address := parseAddress(d.address)
183+
conn.net, err = sslDialContext(ctx, network, address, sslOpts{
184+
KeyFile: d.sslKeyFile,
185+
CertFile: d.sslCertFile,
186+
CaFile: d.sslCaFile,
187+
Ciphers: d.sslCiphers,
188+
Password: d.sslPassword,
189+
PasswordFile: d.sslPasswordFile,
190+
})
130191
if err != nil {
131192
return nil, fmt.Errorf("failed to dial: %w", err)
132193
}
@@ -135,20 +196,18 @@ func (d NetDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
135196
conn.reader = bufio.NewReaderSize(dc, bufSize)
136197
conn.writer = bufio.NewWriterSize(dc, bufSize)
137198

138-
salt, err := rawDial(conn, d.RequiredProtocolInfo)
199+
err = rawDial(conn)
139200
if err != nil {
140201
conn.net.Close()
141202
return nil, err
142203
}
143204

144-
if d.User == "" {
145-
return conn, nil
146-
}
147-
148-
conn.protocol.Auth = ChapSha1Auth
149-
if err = authenticate(conn, ChapSha1Auth, d.User, d.Password, salt); err != nil {
150-
conn.net.Close()
151-
return nil, fmt.Errorf("failed to authenticate: %w", err)
205+
if d.auth == AutoAuth {
206+
if conn.protocol.Auth == AutoAuth {
207+
conn.protocol.Auth = ChapSha1Auth
208+
}
209+
} else {
210+
conn.protocol.Auth = d.auth
152211
}
153212

154213
return conn, nil
@@ -206,51 +265,26 @@ type OpenSslDialer struct {
206265

207266
// Dial makes OpenSslDialer satisfy the Dialer interface.
208267
func (d OpenSslDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
209-
var err error
210-
conn := new(tntConn)
211-
212-
network, address := parseAddress(d.Address)
213-
conn.net, err = sslDialContext(ctx, network, address, sslOpts{
214-
KeyFile: d.SslKeyFile,
215-
CertFile: d.SslCertFile,
216-
CaFile: d.SslCaFile,
217-
Ciphers: d.SslCiphers,
218-
Password: d.SslPassword,
219-
PasswordFile: d.SslPasswordFile,
220-
})
221-
if err != nil {
222-
return nil, fmt.Errorf("failed to dial: %w", err)
223-
}
224-
225-
dc := &deadlineIO{to: opts.IoTimeout, c: conn.net}
226-
conn.reader = bufio.NewReaderSize(dc, bufSize)
227-
conn.writer = bufio.NewWriterSize(dc, bufSize)
228-
229-
salt, err := rawDial(conn, d.RequiredProtocolInfo)
230-
if err != nil {
231-
conn.net.Close()
232-
return nil, err
233-
}
234-
235-
if d.User == "" {
236-
return conn, nil
237-
}
238-
239-
if d.Auth == AutoAuth {
240-
if conn.protocol.Auth != AutoAuth {
241-
d.Auth = conn.protocol.Auth
242-
} else {
243-
d.Auth = ChapSha1Auth
244-
}
245-
}
246-
conn.protocol.Auth = d.Auth
247-
248-
if err = authenticate(conn, d.Auth, d.User, d.Password, salt); err != nil {
249-
conn.net.Close()
250-
return nil, fmt.Errorf("failed to authenticate: %w", err)
251-
}
252-
253-
return conn, nil
268+
dialer := AuthDialer{
269+
Dialer: ProtocolDialer{
270+
Dialer: openSslDialer{
271+
address: d.Address,
272+
auth: d.Auth,
273+
sslKeyFile: d.SslKeyFile,
274+
sslCertFile: d.SslCertFile,
275+
sslCaFile: d.SslCaFile,
276+
sslCiphers: d.SslCiphers,
277+
sslPassword: d.SslPassword,
278+
sslPasswordFile: d.SslPasswordFile,
279+
},
280+
RequiredProtocolInfo: d.RequiredProtocolInfo,
281+
},
282+
Auth: d.Auth,
283+
Username: d.User,
284+
Password: d.Password,
285+
}
286+
287+
return dialer.Dial(ctx, opts)
254288
}
255289

256290
// FdDialer allows to use an existing socket fd for connection.
@@ -263,6 +297,10 @@ type FdDialer struct {
263297
RequiredProtocolInfo ProtocolInfo
264298
}
265299

300+
type fdDialer struct {
301+
fd uintptr
302+
}
303+
266304
type fdAddr struct {
267305
Fd uintptr
268306
}
@@ -284,22 +322,21 @@ func (c *fdConn) RemoteAddr() net.Addr {
284322
return c.Addr
285323
}
286324

287-
// Dial makes FdDialer satisfy the Dialer interface.
288-
func (d FdDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
289-
file := os.NewFile(d.Fd, "")
325+
func (d fdDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
326+
file := os.NewFile(d.fd, "")
290327
c, err := net.FileConn(file)
291328
if err != nil {
292329
return nil, fmt.Errorf("failed to dial: %w", err)
293330
}
294331

295332
conn := new(tntConn)
296-
conn.net = &fdConn{Conn: c, Addr: fdAddr{Fd: d.Fd}}
333+
conn.net = &fdConn{Conn: c, Addr: fdAddr{Fd: d.fd}}
297334

298335
dc := &deadlineIO{to: opts.IoTimeout, c: conn.net}
299336
conn.reader = bufio.NewReaderSize(dc, bufSize)
300337
conn.writer = bufio.NewWriterSize(dc, bufSize)
301338

302-
_, err = rawDial(conn, d.RequiredProtocolInfo)
339+
err = rawDial(conn)
303340
if err != nil {
304341
conn.net.Close()
305342
return nil, err
@@ -308,6 +345,84 @@ func (d FdDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
308345
return conn, nil
309346
}
310347

348+
// Dial makes FdDialer satisfy the Dialer interface.
349+
func (d FdDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
350+
dialer := ProtocolDialer{
351+
Dialer: fdDialer{
352+
fd: d.Fd,
353+
},
354+
RequiredProtocolInfo: d.RequiredProtocolInfo,
355+
}
356+
357+
return dialer.Dial(ctx, opts)
358+
}
359+
360+
type AuthDialer struct {
361+
// Dialer is a base dialer.
362+
Dialer Dialer
363+
// Authentication options.
364+
Auth Auth
365+
Username string
366+
Password string
367+
}
368+
369+
// Dial makes AuthDialer satisfy the Dialer interface.
370+
func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
371+
conn, err := d.Dialer.Dial(ctx, opts)
372+
if err != nil {
373+
return conn, err
374+
}
375+
greeting := conn.Greeting()
376+
if greeting.Salt == "" {
377+
conn.Close()
378+
return nil, fmt.Errorf("failed to authenticate: " +
379+
"an invalid connection without salt")
380+
}
381+
382+
if d.Username == "" {
383+
return conn, nil
384+
}
385+
386+
protocolAuth := conn.ProtocolInfo().Auth
387+
if d.Auth == AutoAuth {
388+
if protocolAuth != AutoAuth {
389+
d.Auth = protocolAuth
390+
} else {
391+
d.Auth = ChapSha1Auth
392+
}
393+
}
394+
395+
if err := authenticate(conn, d.Auth, d.Username, d.Password,
396+
conn.Greeting().Salt); err != nil {
397+
conn.Close()
398+
return nil, fmt.Errorf("failed to authenticate: %w", err)
399+
}
400+
return conn, nil
401+
}
402+
403+
type ProtocolDialer struct {
404+
// Dialer is a base dialer.
405+
Dialer Dialer
406+
// RequiredProtocol contains minimal protocol version and
407+
// list of protocol features that should be supported by
408+
// Tarantool server. By default, there are no restrictions.
409+
RequiredProtocolInfo ProtocolInfo
410+
}
411+
412+
// Dial makes ProtocolDialer satisfy the Dialer interface.
413+
func (d ProtocolDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
414+
conn, err := d.Dialer.Dial(ctx, opts)
415+
if err != nil {
416+
return conn, err
417+
}
418+
err = checkProtocolInfo(d.RequiredProtocolInfo, conn.ProtocolInfo())
419+
if err != nil {
420+
conn.Close()
421+
return nil, fmt.Errorf("invalid server protocol: %w", err)
422+
}
423+
return conn, nil
424+
}
425+
311426
// Addr makes tntConn satisfy the Conn interface.
312427
func (c *tntConn) Addr() net.Addr {
313428
return c.net.RemoteAddr()

0 commit comments

Comments
 (0)