From 1bac0e9d03fbb3119d3653683805177b2c5d8030 Mon Sep 17 00:00:00 2001 From: Reinier Schoof Date: Thu, 10 Feb 2022 13:47:16 +0100 Subject: [PATCH 1/5] removed unused arguments --- server/handshake_resp.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/handshake_resp.go b/server/handshake_resp.go index 3c4c320a5..b1c999a98 100644 --- a/server/handshake_resp.go +++ b/server/handshake_resp.go @@ -30,7 +30,7 @@ func (c *Conn) readHandshakeResponse() error { pos = c.readPluginName(data, pos) - cont, err := c.handleAuthMatch(authData, pos) + cont, err := c.handleAuthMatch() if err != nil { return err } @@ -191,7 +191,7 @@ func (c *Conn) handlePublicKeyRetrieval(authData []byte) (bool, error) { return true, nil } -func (c *Conn) handleAuthMatch(authData []byte, pos int) (bool, error) { +func (c *Conn) handleAuthMatch() (bool, error) { // if the client responds the handshake with a different auth method, the server will send the AuthSwitchRequest packet // to the client to ask the client to switch. From d22c8cfb9968af22ebdf82c73845ca4aa65b5182 Mon Sep 17 00:00:00 2001 From: Reinier Schoof Date: Thu, 10 Feb 2022 13:47:47 +0100 Subject: [PATCH 2/5] auth plugin name is NUL-terminated --- server/handshake_resp.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/handshake_resp.go b/server/handshake_resp.go index b1c999a98..446fe0099 100644 --- a/server/handshake_resp.go +++ b/server/handshake_resp.go @@ -130,7 +130,7 @@ func (c *Conn) readDb(data []byte, pos int) (int, error) { func (c *Conn) readPluginName(data []byte, pos int) int { if c.capability&CLIENT_PLUGIN_AUTH != 0 { c.authPluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)]) - pos += len(c.authPluginName) + pos += len(c.authPluginName) + 1 } else { // The method used is Native Authentication if both CLIENT_PROTOCOL_41 and CLIENT_SECURE_CONNECTION are set, // but CLIENT_PLUGIN_AUTH is not set, so we fallback to 'mysql_native_password' From 01495ed2326ca069359711549e6290af37c11824 Mon Sep 17 00:00:00 2001 From: Reinier Schoof Date: Thu, 10 Feb 2022 14:17:15 +0100 Subject: [PATCH 3/5] properly read connection attributes --- server/handshake_resp.go | 16 ++++++++++++++-- server/handshake_resp_test.go | 2 +- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/server/handshake_resp.go b/server/handshake_resp.go index 446fe0099..8fc4ff7b7 100644 --- a/server/handshake_resp.go +++ b/server/handshake_resp.go @@ -207,17 +207,29 @@ func (c *Conn) handleAuthMatch() (bool, error) { } func (c *Conn) readAttributes(data []byte, pos int) (int, error) { - attrs := make(map[string]string) + // read length of attribute data + attrLen, isNull, skip := LengthEncodedInt(data[pos:]) + pos += skip + if isNull { + return pos, nil + } + + if len(data) < pos+int(attrLen) { + return pos, errors.New("corrupt attributes data") + } + i := 0 + attrs := make(map[string]string) var key string + // read until end of data or NUL for atrribute key/values for { str, isNull, strLen, err := LengthEncodedString(data[pos:]) - if err != nil { return -1, err } + // end of data if isNull { break } diff --git a/server/handshake_resp_test.go b/server/handshake_resp_test.go index 5724422fb..b4232e612 100644 --- a/server/handshake_resp_test.go +++ b/server/handshake_resp_test.go @@ -75,7 +75,7 @@ func TestReadAttributes(t *testing.T) { 0x6d, 0x06, 0x78, 0x38, 0x36, 0x5f, 0x36, 0x34, 0x03, 0x66, 0x6f, 0x6f, 0x03, 0x62, 0x61, 0x72, } - pos := 85 + pos := 84 c := &Conn{} From aea4caa2164179f75067b38039464d67bca6d0ae Mon Sep 17 00:00:00 2001 From: Reinier Schoof Date: Fri, 11 Feb 2022 10:15:11 +0100 Subject: [PATCH 4/5] added test for readPluginName --- server/handshake_resp_test.go | 70 +++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/server/handshake_resp_test.go b/server/handshake_resp_test.go index b4232e612..e7780641d 100644 --- a/server/handshake_resp_test.go +++ b/server/handshake_resp_test.go @@ -53,6 +53,76 @@ func TestDecodeFirstPart(t *testing.T) { } } +func TestReadPluginName(t *testing.T) { + // example data from + // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse41 + mysqlNativePassword := []byte{ + 0x54, 0x00, 0x00, 0x01, 0x8d, 0xa6, 0x0f, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x70, 0x61, 0x6d, 0x00, 0x14, 0xab, 0x09, 0xee, + 0xf6, 0xbc, 0xb1, 0x32, 0x3e, 0x61, 0x14, 0x38, 0x65, 0xc0, 0x99, + 0x1d, 0x95, 0x7d, 0x75, 0xd4, 0x47, 0x74, 0x65, 0x73, 0x74, 0x00, + 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, + 0x65, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, + } + + // altered example data so it has different auth plugin + otherPlugin := []byte{ + 0x54, 0x00, 0x00, 0x01, 0x8d, 0xa6, 0x0f, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x70, 0x61, 0x6d, 0x00, 0x14, 0xab, 0x09, 0xee, + 0xf6, 0xbc, 0xb1, 0x32, 0x3e, 0x61, 0x14, 0x38, 0x65, 0xc0, 0x99, + 0x1d, 0x95, 0x7d, 0x75, 0xd4, 0x47, 0x74, 0x65, 0x73, 0x74, 0x00, + 0x66, 0x6f, 0x6f, 0x62, 0x61, 0x72, 0x00, + } + + t.Run("mysql_native_password from plugin name", func(t *testing.T) { + c := &Conn{} + c.SetCapability(mysql.CLIENT_PLUGIN_AUTH) + pos := 66 + + pos = c.readPluginName(mysqlNativePassword, pos) + if pos != 88 { + t.Fatalf("unexpected pos, got %d", pos) + } + + if c.authPluginName != "mysql_native_password" { + t.Fatalf("unexpected plugin name, got %s", c.authPluginName) + } + }) + + t.Run("other plugin", func(t *testing.T) { + c := &Conn{} + c.SetCapability(mysql.CLIENT_PLUGIN_AUTH) + pos := 66 + + pos = c.readPluginName(otherPlugin, pos) + if pos != 73 { + t.Fatalf("unexpected pos, got %d", pos) + } + + if c.authPluginName != "foobar" { + t.Fatalf("unexpected plugin name, got %s", c.authPluginName) + } + }) + + t.Run("mysql_native_password as default", func(t *testing.T) { + c := &Conn{} + pos := 123 // can be anything + + pos = c.readPluginName(mysqlNativePassword, pos) + if pos != 123 { + t.Fatalf("unexpected pos, got %d", pos) + } + + if c.authPluginName != mysql.AUTH_NATIVE_PASSWORD { + t.Fatalf("unexpected plugin name, got %s", c.authPluginName) + } + }) +} + func TestReadAttributes(t *testing.T) { var err error // example data from From 0a0d6f26dcca9f89e2fc4af17db04070d9bfa777 Mon Sep 17 00:00:00 2001 From: Reinier Schoof Date: Fri, 11 Feb 2022 10:41:50 +0100 Subject: [PATCH 5/5] added test for readDB --- go.mod | 1 + go.sum | 6 +- mocks/Handler.go | 161 ++++++++++++++++++++++++++++++++++ server/handshake_resp_test.go | 52 ++++++++++- 4 files changed, 216 insertions(+), 4 deletions(-) create mode 100644 mocks/Handler.go diff --git a/go.mod b/go.mod index 97c99cd36..49c3ed7bc 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726 github.com/siddontang/go-log v0.0.0-20180807004314-8d05993dda07 + github.com/stretchr/testify v1.7.0 go.uber.org/multierr v1.6.0 // indirect go.uber.org/zap v1.16.0 // indirect golang.org/x/text v0.3.6 // indirect diff --git a/go.sum b/go.sum index c943996c7..fd01e57e1 100644 --- a/go.sum +++ b/go.sum @@ -51,10 +51,12 @@ github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726 h1:xT+JlYxNGqyT+XcU8 github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726/go.mod h1:3yhqj7WBBfRhbBlzyOC3gUxftwsU0u8gqevxwIHQpMw= github.com/siddontang/go-log v0.0.0-20180807004314-8d05993dda07 h1:oI+RNwuC9jF2g2lP0u0cVEEZrc/AYBCuFdvwrLWM/6Q= github.com/siddontang/go-log v0.0.0-20180807004314-8d05993dda07/go.mod h1:yFdBgwXP24JziuRl2NMUahT7nGLNOKi1SIiFxMttVD4= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= @@ -112,5 +114,7 @@ gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXL gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/mocks/Handler.go b/mocks/Handler.go new file mode 100644 index 000000000..a871e84fe --- /dev/null +++ b/mocks/Handler.go @@ -0,0 +1,161 @@ +// Code generated by mockery v2.9.4. DO NOT EDIT. + +package mocks + +import ( + mysql "github.com/go-mysql-org/go-mysql/mysql" + mock "github.com/stretchr/testify/mock" +) + +// Handler is an autogenerated mock type for the Handler type +type Handler struct { + mock.Mock +} + +// HandleFieldList provides a mock function with given fields: table, fieldWildcard +func (_m *Handler) HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) { + ret := _m.Called(table, fieldWildcard) + + var r0 []*mysql.Field + if rf, ok := ret.Get(0).(func(string, string) []*mysql.Field); ok { + r0 = rf(table, fieldWildcard) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*mysql.Field) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(table, fieldWildcard) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// HandleOtherCommand provides a mock function with given fields: cmd, data +func (_m *Handler) HandleOtherCommand(cmd byte, data []byte) error { + ret := _m.Called(cmd, data) + + var r0 error + if rf, ok := ret.Get(0).(func(byte, []byte) error); ok { + r0 = rf(cmd, data) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// HandleQuery provides a mock function with given fields: query +func (_m *Handler) HandleQuery(query string) (*mysql.Result, error) { + ret := _m.Called(query) + + var r0 *mysql.Result + if rf, ok := ret.Get(0).(func(string) *mysql.Result); ok { + r0 = rf(query) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*mysql.Result) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(query) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// HandleStmtClose provides a mock function with given fields: context +func (_m *Handler) HandleStmtClose(context interface{}) error { + ret := _m.Called(context) + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(context) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// HandleStmtExecute provides a mock function with given fields: context, query, args +func (_m *Handler) HandleStmtExecute(context interface{}, query string, args []interface{}) (*mysql.Result, error) { + ret := _m.Called(context, query, args) + + var r0 *mysql.Result + if rf, ok := ret.Get(0).(func(interface{}, string, []interface{}) *mysql.Result); ok { + r0 = rf(context, query, args) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*mysql.Result) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(interface{}, string, []interface{}) error); ok { + r1 = rf(context, query, args) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// HandleStmtPrepare provides a mock function with given fields: query +func (_m *Handler) HandleStmtPrepare(query string) (int, int, interface{}, error) { + ret := _m.Called(query) + + var r0 int + if rf, ok := ret.Get(0).(func(string) int); ok { + r0 = rf(query) + } else { + r0 = ret.Get(0).(int) + } + + var r1 int + if rf, ok := ret.Get(1).(func(string) int); ok { + r1 = rf(query) + } else { + r1 = ret.Get(1).(int) + } + + var r2 interface{} + if rf, ok := ret.Get(2).(func(string) interface{}); ok { + r2 = rf(query) + } else { + if ret.Get(2) != nil { + r2 = ret.Get(2).(interface{}) + } + } + + var r3 error + if rf, ok := ret.Get(3).(func(string) error); ok { + r3 = rf(query) + } else { + r3 = ret.Error(3) + } + + return r0, r1, r2, r3 +} + +// UseDB provides a mock function with given fields: dbName +func (_m *Handler) UseDB(dbName string) error { + ret := _m.Called(dbName) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(dbName) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/server/handshake_resp_test.go b/server/handshake_resp_test.go index e7780641d..1cb9cd237 100644 --- a/server/handshake_resp_test.go +++ b/server/handshake_resp_test.go @@ -4,7 +4,9 @@ import ( "bytes" "testing" + "github.com/go-mysql-org/go-mysql/mocks" "github.com/go-mysql-org/go-mysql/mysql" + "github.com/stretchr/testify/mock" ) func TestReadAuthData(t *testing.T) { @@ -53,6 +55,50 @@ func TestDecodeFirstPart(t *testing.T) { } } +func TestReadDB(t *testing.T) { + handler := &mocks.Handler{} + c := &Conn{ + h: handler, + } + c.SetCapability(mysql.CLIENT_CONNECT_WITH_DB) + var dbName string + + // when handler's UseDB is called, copy dbName to local variable + handler.On("UseDB", mock.IsType("")).Return(nil).Once().RunFn = func(args mock.Arguments) { + dbName = args[0].(string) + } + + // example data from + // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse41 + data := []byte{ + 0x54, 0x00, 0x00, 0x01, 0x8d, 0xa6, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x70, 0x61, 0x6d, 0x00, 0x14, 0xab, 0x09, 0xee, 0xf6, 0xbc, 0xb1, 0x32, + 0x3e, 0x61, 0x14, 0x38, 0x65, 0xc0, 0x99, 0x1d, 0x95, 0x7d, 0x75, 0xd4, + 0x47, 0x74, 0x65, 0x73, 0x74, 0x00, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, + 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, + 0x6f, 0x72, 0x64, 0x00, + } + pos := 61 + + var err error + pos, err = c.readDb(data, pos) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + + if pos != 66 { // 61 + len("test") + 1 + t.Fatalf("unexpected pos, got %d", pos) + } + + if dbName != "test" { + t.Fatalf("unexpected db, got %s", dbName) + } + + handler.AssertExpectations(t) +} + func TestReadPluginName(t *testing.T) { // example data from // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse41 @@ -84,7 +130,7 @@ func TestReadPluginName(t *testing.T) { pos := 66 pos = c.readPluginName(mysqlNativePassword, pos) - if pos != 88 { + if pos != 88 { // 66 + len("mysql_native_password") + 1 t.Fatalf("unexpected pos, got %d", pos) } @@ -99,7 +145,7 @@ func TestReadPluginName(t *testing.T) { pos := 66 pos = c.readPluginName(otherPlugin, pos) - if pos != 73 { + if pos != 73 { // 66 + len("foobar") + 1 t.Fatalf("unexpected pos, got %d", pos) } @@ -113,7 +159,7 @@ func TestReadPluginName(t *testing.T) { pos := 123 // can be anything pos = c.readPluginName(mysqlNativePassword, pos) - if pos != 123 { + if pos != 123 { // capability not set, so same as initial pos t.Fatalf("unexpected pos, got %d", pos) }