Skip to content

Commit 3922b93

Browse files
exchange protocol info on dial
1 parent 0081afb commit 3922b93

File tree

2 files changed

+51
-36
lines changed

2 files changed

+51
-36
lines changed

connection.go

+51-28
Original file line numberDiff line numberDiff line change
@@ -393,13 +393,6 @@ func Connect(addr string, opts Opts) (conn *Connection, err error) {
393393
}
394394
}
395395

396-
if err = conn.loadProtocolInfo(); err != nil {
397-
conn.mutex.Lock()
398-
defer conn.mutex.Unlock()
399-
conn.closeConnection(err, true)
400-
return nil, err
401-
}
402-
403396
return conn, err
404397
}
405398

@@ -511,6 +504,13 @@ func (conn *Connection) dial() (err error) {
511504
conn.Greeting.Version = bytes.NewBuffer(greeting[:64]).String()
512505
conn.Greeting.auth = bytes.NewBuffer(greeting[64:108]).String()
513506

507+
// IPROTO_ID requests can be processed without authentication.
508+
// https://www.tarantool.io/en/doc/latest/dev_guide/internals/iproto/requests/#iproto-id
509+
if err = conn.loadProtocolInfo(w, r); err != nil {
510+
connection.Close()
511+
return err
512+
}
513+
514514
// Auth
515515
if opts.User != "" {
516516
scr, err := scramble(conn.Greeting.auth, opts.Pass)
@@ -590,43 +590,64 @@ func pack(h *smallWBuf, enc *encoder, reqid uint32,
590590
return
591591
}
592592

593-
func (conn *Connection) writeAuthRequest(w *bufio.Writer, scramble []byte) (err error) {
593+
func (conn *Connection) writeRequestRaw(w *bufio.Writer, req Request,
594+
reqName string) (err error) {
594595
var packet smallWBuf
595-
req := newAuthRequest(conn.opts.User, string(scramble))
596596
err = pack(&packet, newEncoder(&packet), 0, req, ignoreStreamId, conn.Schema)
597597

598598
if err != nil {
599-
return errors.New("auth: pack error " + err.Error())
599+
return errors.New(reqName + ": pack error " + err.Error())
600600
}
601601
if err := write(w, packet.b); err != nil {
602-
return errors.New("auth: write error " + err.Error())
602+
return errors.New(reqName + ": write error " + err.Error())
603603
}
604604
if err = w.Flush(); err != nil {
605-
return errors.New("auth: flush error " + err.Error())
605+
return errors.New(reqName + ": flush error " + err.Error())
606606
}
607607
return
608608
}
609609

610-
func (conn *Connection) readAuthResponse(r io.Reader) (err error) {
610+
func (conn *Connection) writeAuthRequest(w *bufio.Writer, scramble []byte) (err error) {
611+
req := newAuthRequest(conn.opts.User, string(scramble))
612+
return conn.writeRequestRaw(w, req, "auth")
613+
}
614+
615+
func (conn *Connection) writeProtocolInfoRequest(w *bufio.Writer, version ProtocolVersion,
616+
features []ProtocolFeature) (err error) {
617+
req := newProtocolInfoRequest(version, features)
618+
return conn.writeRequestRaw(w, req, "iproto id")
619+
}
620+
621+
func (conn *Connection) readResponseRaw(r io.Reader,
622+
reqName string) (resp Response, err error) {
611623
respBytes, err := conn.read(r)
612624
if err != nil {
613-
return errors.New("auth: read error " + err.Error())
625+
return resp, errors.New(reqName + ": read error " + err.Error())
614626
}
615-
resp := Response{buf: smallBuf{b: respBytes}}
627+
resp = Response{buf: smallBuf{b: respBytes}}
616628
err = resp.decodeHeader(conn.dec)
617629
if err != nil {
618-
return errors.New("auth: decode response header error " + err.Error())
630+
return resp, errors.New(reqName + ": decode response header error " + err.Error())
619631
}
620632
err = resp.decodeBody()
621633
if err != nil {
622634
switch err.(type) {
623635
case Error:
624-
return err
636+
return resp, err
625637
default:
626-
return errors.New("auth: decode response body error " + err.Error())
638+
return resp, errors.New(reqName + ": decode response body error " + err.Error())
627639
}
628640
}
629-
return
641+
return resp, nil
642+
}
643+
644+
func (conn *Connection) readAuthResponse(r io.Reader) (err error) {
645+
_, err = conn.readResponseRaw(r, "auth")
646+
return err
647+
}
648+
649+
func (conn *Connection) readProtocolInfoResponse(r io.Reader) (resp Response, err error) {
650+
return conn.readResponseRaw(r, "iproto id")
630651
}
631652

632653
func (conn *Connection) createConnection(reconnect bool) (err error) {
@@ -696,10 +717,6 @@ func (conn *Connection) reconnect(neterr error, c net.Conn) {
696717
conn.closeConnection(neterr, false)
697718
if err := conn.createConnection(true); err != nil {
698719
conn.closeConnection(err, true)
699-
} else {
700-
if err = conn.loadProtocolInfo(); err != nil {
701-
conn.closeConnection(err, true)
702-
}
703720
}
704721
}
705722
} else {
@@ -1180,15 +1197,21 @@ func (conn *Connection) NewStream() (*Stream, error) {
11801197
// loadProtocolInfo sends info about client protocol,
11811198
// receives info about server protocol in response
11821199
// and store in in connection serverProtocolInfo.
1183-
func (conn *Connection) loadProtocolInfo() error {
1200+
func (conn *Connection) loadProtocolInfo(w *bufio.Writer, r *bufio.Reader) error {
11841201
var ok bool
1202+
var resp Response
1203+
var err error
1204+
1205+
err = conn.writeProtocolInfoRequest(w, ClientProtocolVersion, ClientProtocolFeatures)
1206+
if err != nil {
1207+
return err
1208+
}
11851209

1186-
resp, err := conn.exchangeProtocolInfo(
1187-
ClientProtocolVersion,
1188-
ClientProtocolFeatures)
1210+
resp, err = conn.readProtocolInfoResponse(r)
11891211

11901212
if err != nil {
1191-
if resp.Code == ErrUnknownRequestType {
1213+
tarantoolError, ok := err.(Error)
1214+
if ok && tarantoolError.Code == ErrUnknownRequestType {
11921215
// IPROTO_ID requests are not supported by server.
11931216
conn.serverProtocolInfo = protocolInfo{
11941217
version: ProtocolVersionUnsupported,

request.go

-8
Original file line numberDiff line numberDiff line change
@@ -1124,14 +1124,6 @@ func newProtocolInfoRequest(protocolVersion ProtocolVersion,
11241124
return req
11251125
}
11261126

1127-
// exchangeProtocolInfo sends info about client protocol
1128-
// and receives info about server protocol in response.
1129-
func (conn *Connection) exchangeProtocolInfo(version ProtocolVersion,
1130-
features []ProtocolFeature) (resp *Response, err error) {
1131-
req := newProtocolInfoRequest(version, features)
1132-
return conn.Do(req).Get()
1133-
}
1134-
11351127
// Body fills an encoder with the protocol version request body.
11361128
func (req *protocolInfoRequest) Body(res SchemaResolver, enc *encoder) error {
11371129
return req.fillProtocolInfoRequest(enc)

0 commit comments

Comments
 (0)