Skip to content

Commit 5427a8d

Browse files
author
dvilaverde
committed
allow setting the collation in auth handshake
1 parent 7c31dc4 commit 5427a8d

File tree

3 files changed

+47
-2
lines changed

3 files changed

+47
-2
lines changed

client/auth.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"crypto/tls"
66
"encoding/binary"
77
"fmt"
8+
"github.com/pingcap/tidb/pkg/parser/charset"
89

910
. "github.com/go-mysql-org/go-mysql/mysql"
1011
"github.com/go-mysql-org/go-mysql/packet"
@@ -269,7 +270,16 @@ func (c *Conn) writeAuthHandshake() error {
269270

270271
// Charset [1 byte]
271272
// use default collation id 33 here, is utf-8
272-
data[12] = DEFAULT_COLLATION_ID
273+
collationName := c.collation
274+
if len(collationName) == 0 {
275+
collationName = DEFAULT_COLLATION_NAME
276+
}
277+
collation, err := charset.GetCollationByName(collationName)
278+
if err != nil {
279+
return fmt.Errorf("invalid collation name %s", collationName)
280+
}
281+
282+
data[12] = byte(collation.ID)
273283

274284
// SSL Connection Request Packet
275285
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest

client/client_test.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@ func TestClientSuite(t *testing.T) {
3131
func (s *clientTestSuite) SetupSuite() {
3232
var err error
3333
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
34-
s.c, err = Connect(addr, *testUser, *testPassword, "")
34+
s.c, err = Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
35+
// test the collation logic, but this is essentially a no-op since
36+
// the collation set is the default value
37+
_ = conn.SetCollation(mysql.DEFAULT_COLLATION_NAME)
38+
})
3539
require.NoError(s.T(), err)
3640

3741
var result *mysql.Result
@@ -228,6 +232,21 @@ func (s *clientTestSuite) TestConn_SetCharset() {
228232
require.NoError(s.T(), err)
229233
}
230234

235+
func (s *clientTestSuite) TestConn_SetCollationAfterConnect() {
236+
err := s.c.SetCollation("latin1_swedish_ci")
237+
require.Error(s.T(), err)
238+
}
239+
240+
func (s *clientTestSuite) TestConn_SetCollation() {
241+
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
242+
_, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
243+
// test the collation logic
244+
_ = conn.SetCollation("invalid_collation")
245+
})
246+
247+
require.Error(s.T(), err)
248+
}
249+
231250
func (s *clientTestSuite) testStmt_DropTable() {
232251
str := `drop table if exists mixer_test_stmt`
233252

client/conn.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ type Conn struct {
3737
status uint16
3838

3939
charset string
40+
// sets the collation to be set on the auth handshake, this does not issue a 'set names' command
41+
collation string
4042

4143
salt []byte
4244
authPluginName string
@@ -357,6 +359,20 @@ func (c *Conn) SetCharset(charset string) error {
357359
}
358360
}
359361

362+
func (c *Conn) SetCollation(collation string) error {
363+
if c.status == 0 {
364+
c.collation = collation
365+
} else {
366+
return errors.Trace(errors.Errorf("cannot set collation after connection is established"))
367+
}
368+
369+
return nil
370+
}
371+
372+
func (c *Conn) GetCollation() string {
373+
return c.collation
374+
}
375+
360376
func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) {
361377
if err := c.writeCommandStrStr(COM_FIELD_LIST, table, wildcard); err != nil {
362378
return nil, errors.Trace(err)

0 commit comments

Comments
 (0)