Skip to content

Commit 2e231fa

Browse files
committed
Fix mysqlConn.{affectedRows,insertedIds} growing on each call to Query(), QueryContext() and Ping().
1 parent 241fd91 commit 2e231fa

File tree

7 files changed

+170
-46
lines changed

7 files changed

+170
-46
lines changed

auth.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
347347
case 1:
348348
switch authData[0] {
349349
case cachingSha2PasswordFastAuthSuccess:
350-
if err = mc.readResultOK(); err == nil {
350+
if err = mc.readResultOK(resultUnchanged); err == nil {
351351
return nil // auth successful
352352
}
353353

@@ -391,7 +391,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
391391
return err
392392
}
393393
}
394-
return mc.readResultOK()
394+
return mc.readResultOK(resultUnchanged)
395395

396396
default:
397397
return ErrMalformPkt
@@ -416,7 +416,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
416416
if err != nil {
417417
return err
418418
}
419-
return mc.readResultOK()
419+
return mc.readResultOK(resultUnchanged)
420420
}
421421

422422
default:

connection.go

+52-14
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@ import (
2323
type mysqlConn struct {
2424
buf buffer
2525
netConn net.Conn
26-
rawConn net.Conn // underlying connection when netConn is TLS connection.
27-
affectedRows []int64
28-
insertIds []int64
26+
rawConn net.Conn // underlying connection when netConn is TLS connection.
27+
result mysqlResult // managed by clearResult() and handleOkPacket().
2928
cfg *Config
3029
maxAllowedPacket int
3130
maxWriteSize int
@@ -45,6 +44,43 @@ type mysqlConn struct {
4544
closed atomicBool // set when conn is closed, before closech is closed
4645
}
4746

47+
// To correctly manage mysqlConn.result (updated by handleOkPacket()), we need
48+
// to ensure all callpaths have either:
49+
//
50+
// 1. cleared it using clearResult() before sending the command, or
51+
// 2. don't need to (eg. in call paths which are accumulating resultsets).
52+
//
53+
// handleOkPacket() takes an argument of this type to ensure exhaustively that
54+
// all callpaths manage this state correctly.
55+
type resultState int
56+
57+
const (
58+
// mysqlConn.result was cleared (ie. a new command or query is being run.)
59+
//
60+
// This value is obtained by calling mysqlConn.clearResult().
61+
resultCleared resultState = iota + 1
62+
// mysqlConn.result was unchanged (ie. additional resultsets are being
63+
// fetched, or the fields did not need to be cleared.)
64+
resultUnchanged
65+
)
66+
67+
// clearResult clears the connection's stored affectedRows and insertIds
68+
// fields.
69+
//
70+
// Ref: https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
71+
//
72+
// It returns a resultCleared status, to be passed directly (or
73+
// indirectly) to handleOkPacket().
74+
//
75+
// All call paths ending in handleOkPacket() must either:
76+
//
77+
// 1. call clearResult(), and pass its result to handleOkPacket().
78+
// 2. pass resultUnchanged to handleOkPacket().
79+
func (mc *mysqlConn) clearResult() resultState {
80+
mc.result = mysqlResult{}
81+
return resultCleared
82+
}
83+
4884
// Handles parameters set in DSN after the connection is established
4985
func (mc *mysqlConn) handleParams() (err error) {
5086
var cmdSet strings.Builder
@@ -124,6 +160,7 @@ func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
124160
func (mc *mysqlConn) Close() (err error) {
125161
// Makes Close idempotent
126162
if !mc.closed.IsSet() {
163+
mc.clearResult()
127164
err = mc.writeCommandPacket(comQuit)
128165
}
129166

@@ -310,28 +347,25 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
310347
}
311348
query = prepared
312349
}
313-
mc.affectedRows = nil
314-
mc.insertIds = nil
315350

316351
err := mc.exec(query)
317352
if err == nil {
318-
return &mysqlResult{
319-
affectedRows: mc.affectedRows,
320-
insertIds: mc.insertIds,
321-
}, err
353+
copied := mc.result
354+
return &copied, err
322355
}
323356
return nil, mc.markBadConn(err)
324357
}
325358

326359
// Internal function to execute commands
327360
func (mc *mysqlConn) exec(query string) error {
361+
resultCleared := mc.clearResult()
328362
// Send command
329363
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
330364
return mc.markBadConn(err)
331365
}
332366

333367
// Read Result
334-
resLen, err := mc.readResultSetHeaderPacket()
368+
resLen, err := mc.readResultSetHeaderPacket(resultCleared)
335369
if err != nil {
336370
return err
337371
}
@@ -348,14 +382,16 @@ func (mc *mysqlConn) exec(query string) error {
348382
}
349383
}
350384

351-
return mc.discardResults()
385+
return mc.discardResults(resultUnchanged)
352386
}
353387

354388
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
355389
return mc.query(query, args)
356390
}
357391

358392
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
393+
resultCleared := mc.clearResult()
394+
359395
if mc.closed.IsSet() {
360396
errLog.Print(ErrInvalidConn)
361397
return nil, driver.ErrBadConn
@@ -376,7 +412,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
376412
if err == nil {
377413
// Read Result
378414
var resLen int
379-
resLen, err = mc.readResultSetHeaderPacket()
415+
resLen, err = mc.readResultSetHeaderPacket(resultCleared)
380416
if err == nil {
381417
rows := new(textRows)
382418
rows.mc = mc
@@ -404,12 +440,13 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
404440
// The returned byte slice is only valid until the next read
405441
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
406442
// Send command
443+
resultCleared := mc.clearResult()
407444
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
408445
return nil, err
409446
}
410447

411448
// Read Result
412-
resLen, err := mc.readResultSetHeaderPacket()
449+
resLen, err := mc.readResultSetHeaderPacket(resultCleared)
413450
if err == nil {
414451
rows := new(textRows)
415452
rows.mc = mc
@@ -460,11 +497,12 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
460497
}
461498
defer mc.finish()
462499

500+
resultCleared := mc.clearResult()
463501
if err = mc.writeCommandPacket(comPing); err != nil {
464502
return mc.markBadConn(err)
465503
}
466504

467-
return mc.readResultOK()
505+
return mc.readResultOK(resultCleared)
468506
}
469507

470508
// BeginTx implements driver.ConnBeginTx interface

driver_test.go

+76
Original file line numberDiff line numberDiff line change
@@ -2155,11 +2155,51 @@ func TestRejectReadOnly(t *testing.T) {
21552155
}
21562156

21572157
func TestPing(t *testing.T) {
2158+
ctx := context.Background()
21582159
runTests(t, dsn, func(dbt *DBTest) {
21592160
if err := dbt.db.Ping(); err != nil {
21602161
dbt.fail("Ping", "Ping", err)
21612162
}
21622163
})
2164+
2165+
runTests(t, dsn, func(dbt *DBTest) {
2166+
conn, err := dbt.db.Conn(ctx)
2167+
if err != nil {
2168+
dbt.fail("db", "Conn", err)
2169+
}
2170+
2171+
// Check that affectedRows and insertIds are cleared after each call.
2172+
conn.Raw(func(conn interface{}) error {
2173+
c := conn.(*mysqlConn)
2174+
2175+
// Issue a query that sets affectedRows and insertIds.
2176+
q, err := c.Query(`SELECT 1`, nil)
2177+
if err != nil {
2178+
dbt.fail("Conn", "Query", err)
2179+
}
2180+
if got, want := c.result.affectedRows, []int64{0}; !reflect.DeepEqual(got, want) {
2181+
dbt.Fatalf("bad affectedRows: got %v, want=%v", got, want)
2182+
}
2183+
if got, want := c.result.insertIds, []int64{0}; !reflect.DeepEqual(got, want) {
2184+
dbt.Fatalf("bad insertIds: got %v, want=%v", got, want)
2185+
}
2186+
q.Close()
2187+
2188+
// Verify that Ping() clears both fields.
2189+
for i := 0; i < 2; i++ {
2190+
if err := c.Ping(ctx); err != nil {
2191+
dbt.fail("Pinger", "Ping", err)
2192+
}
2193+
if got, want := c.result.affectedRows, []int64(nil); !reflect.DeepEqual(got, want) {
2194+
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
2195+
}
2196+
if got, want := c.result.insertIds, []int64(nil); !reflect.DeepEqual(got, want) {
2197+
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
2198+
}
2199+
}
2200+
return nil
2201+
})
2202+
})
21632203
}
21642204

21652205
// See Issue #799
@@ -2436,6 +2476,42 @@ func TestSkipResults(t *testing.T) {
24362476
})
24372477
}
24382478

2479+
func TestQueryMultipleResults(t *testing.T) {
2480+
ctx := context.Background()
2481+
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
2482+
dbt.mustExec(`
2483+
CREATE TABLE test (
2484+
id INT NOT NULL AUTO_INCREMENT,
2485+
value VARCHAR(255),
2486+
PRIMARY KEY (id)
2487+
)`)
2488+
conn, err := dbt.db.Conn(ctx)
2489+
if err != nil {
2490+
t.Fatalf("failed to connect: %v", err)
2491+
}
2492+
conn.Raw(func(conn interface{}) error {
2493+
qr := conn.(driver.Queryer)
2494+
2495+
c := conn.(*mysqlConn)
2496+
2497+
// Demonstrate that repeated queries reset the affectedRows
2498+
for i := 0; i < 2; i++ {
2499+
_, err := qr.Query(`
2500+
INSERT INTO test (value) VALUES ('a'), ('b');
2501+
INSERT INTO test (value) VALUES ('c'), ('d'), ('e');
2502+
`, nil)
2503+
if err != nil {
2504+
t.Fatalf("insert statements failed: %v", err)
2505+
}
2506+
if got, want := c.result.affectedRows, []int64{2, 3}; !reflect.DeepEqual(got, want) {
2507+
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
2508+
}
2509+
}
2510+
return nil
2511+
})
2512+
})
2513+
}
2514+
24392515
func TestPingContext(t *testing.T) {
24402516
runTests(t, dsn, func(dbt *DBTest) {
24412517
ctx, cancel := context.WithCancel(context.Background())

infile.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func deferredClose(err *error, closer io.Closer) {
9393
}
9494
}
9595

96-
func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
96+
func (mc *mysqlConn) handleInFileRequest(name string, resultState resultState) (err error) {
9797
var rdr io.Reader
9898
var data []byte
9999
packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP
@@ -174,7 +174,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
174174

175175
// read OK packet
176176
if err == nil {
177-
return mc.readResultOK()
177+
return mc.readResultOK(resultState)
178178
}
179179

180180
mc.readPacket()

packets.go

+25-15
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
496496
switch data[0] {
497497

498498
case iOK:
499-
return nil, "", mc.handleOkPacket(data)
499+
// resultUnchanged, since auth happens before any queries or
500+
// commands have been executed.
501+
return nil, "", mc.handleOkPacket(data, resultUnchanged)
500502

501503
case iAuthMoreData:
502504
return data[1:], "", err
@@ -520,37 +522,37 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
520522
}
521523

522524
// Returns error if Packet is not an 'Result OK'-Packet
523-
func (mc *mysqlConn) readResultOK() error {
525+
func (mc *mysqlConn) readResultOK(resultState resultState) error {
524526
data, err := mc.readPacket()
525527
if err != nil {
526528
return err
527529
}
528530

529531
if data[0] == iOK {
530-
return mc.handleOkPacket(data)
532+
return mc.handleOkPacket(data, resultState)
531533
}
532534
return mc.handleErrorPacket(data)
533535
}
534536

535537
// Result Set Header Packet
536538
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
537-
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
539+
func (mc *mysqlConn) readResultSetHeaderPacket(resultState resultState) (int, error) {
538540
// handleOkPacket replaces both values; other cases leave the values unchanged.
539-
mc.affectedRows = append(mc.affectedRows, 0)
540-
mc.insertIds = append(mc.insertIds, 0)
541+
mc.result.affectedRows = append(mc.result.affectedRows, 0)
542+
mc.result.insertIds = append(mc.result.insertIds, 0)
541543

542544
data, err := mc.readPacket()
543545
if err == nil {
544546
switch data[0] {
545547

546548
case iOK:
547-
return 0, mc.handleOkPacket(data)
549+
return 0, mc.handleOkPacket(data, resultState)
548550

549551
case iERR:
550552
return 0, mc.handleErrorPacket(data)
551553

552554
case iLocalInFile:
553-
return 0, mc.handleInFileRequest(string(data[1:]))
555+
return 0, mc.handleInFileRequest(string(data[1:]), resultState)
554556
}
555557

556558
// column count
@@ -613,7 +615,11 @@ func readStatus(b []byte) statusFlag {
613615

614616
// Ok Packet
615617
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
616-
func (mc *mysqlConn) handleOkPacket(data []byte) error {
618+
//
619+
// The resultState argument ensures that the caller has either cleared the
620+
// affectedRows and insertIds fields (by calling clearResult()) before
621+
// the call, or intentionally left them unchanged.
622+
func (mc *mysqlConn) handleOkPacket(data []byte, resultState resultState) error {
617623
var n, m int
618624
var affectedRows, insertId uint64
619625

@@ -627,11 +633,11 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
627633

628634
// Update for the current statement result (only used by
629635
// readResultSetHeaderPacket).
630-
if len(mc.affectedRows) > 0 {
631-
mc.affectedRows[len(mc.affectedRows)-1] = int64(affectedRows)
636+
if len(mc.result.affectedRows) > 0 {
637+
mc.result.affectedRows[len(mc.result.affectedRows)-1] = int64(affectedRows)
632638
}
633-
if len(mc.insertIds) > 0 {
634-
mc.insertIds[len(mc.insertIds)-1] = int64(insertId)
639+
if len(mc.result.insertIds) > 0 {
640+
mc.result.insertIds[len(mc.result.insertIds)-1] = int64(insertId)
635641
}
636642

637643
// server_status [2 bytes]
@@ -1165,9 +1171,13 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
11651171

11661172
// For each remaining resultset in the stream, discards its rows and updates
11671173
// mc.affectedRows and mc.insertIds.
1168-
func (mc *mysqlConn) discardResults() error {
1174+
//
1175+
// The resultState argument ensures that the caller has either reset the
1176+
// affectedRows and insertIds counters before the call by calling
1177+
// resetStoredOKPackets(), or intentionally left them unchanged.
1178+
func (mc *mysqlConn) discardResults(resultState resultState) error {
11691179
for mc.status&statusMoreResultsExists != 0 {
1170-
resLen, err := mc.readResultSetHeaderPacket()
1180+
resLen, err := mc.readResultSetHeaderPacket(resultState)
11711181
if err != nil {
11721182
return err
11731183
}

0 commit comments

Comments
 (0)