Skip to content

Commit 397e2f5

Browse files
Exec() now provides access to status of multiple statements. (#1309)
It now reports the last inserted ID and affected row count for all statements, not just the last one. This is useful to execute batches of statements such as UPDATE with minimal roundtrips. Co-authored-by: Inada Naoki <[email protected]>
1 parent f43effa commit 397e2f5

9 files changed

+259
-50
lines changed

README.md

+16
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,22 @@ Allow multiple statements in one query. This can be used to bach multiple querie
305305

306306
When `multiStatements` is used, `?` parameters must only be used in the first statement. [interpolateParams](#interpolateparams) can be used to avoid this limitation unless prepared statement is used explicitly.
307307

308+
It's possible to access the last inserted ID and number of affected rows for multiple statements by using `sql.Conn.Raw()` and the `mysql.Result`. For example:
309+
310+
```go
311+
conn, _ := db.Conn(ctx)
312+
conn.Raw(func(conn interface{}) error {
313+
ex := conn.(driver.Execer)
314+
res, err := ex.Exec(`
315+
UPDATE point SET x = 1 WHERE y = 2;
316+
UPDATE point SET x = 2 WHERE y = 3;
317+
`, nil)
318+
// Both slices have 2 elements.
319+
log.Print(res.(mysql.Result).AllRowsAffected())
320+
log.Print(res.(mysql.Result).AllLastInsertIds())
321+
})
322+
```
323+
308324
##### `parseTime`
309325

310326
```

auth.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
346346
case 1:
347347
switch authData[0] {
348348
case cachingSha2PasswordFastAuthSuccess:
349-
if err = mc.readResultOK(); err == nil {
349+
if err = mc.resultUnchanged().readResultOK(); err == nil {
350350
return nil // auth successful
351351
}
352352

@@ -397,7 +397,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
397397
return err
398398
}
399399
}
400-
return mc.readResultOK()
400+
return mc.resultUnchanged().readResultOK()
401401

402402
default:
403403
return ErrMalformPkt
@@ -426,7 +426,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
426426
if err != nil {
427427
return err
428428
}
429-
return mc.readResultOK()
429+
return mc.resultUnchanged().readResultOK()
430430
}
431431

432432
default:

connection.go

+15-14
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@ import (
2323
type mysqlConn struct {
2424
buf buffer
2525
netConn net.Conn
26-
rawConn net.Conn // underlying connection when netConn is TLS connection.
27-
affectedRows uint64
28-
insertId uint64
26+
rawConn net.Conn // underlying connection when netConn is TLS connection.
27+
result mysqlResult // managed by clearResult() and handleOkPacket().
2928
cfg *Config
3029
connector *connector
3130
maxAllowedPacket int
@@ -155,6 +154,7 @@ func (mc *mysqlConn) cleanup() {
155154
if err := mc.netConn.Close(); err != nil {
156155
mc.cfg.Logger.Print(err)
157156
}
157+
mc.clearResult()
158158
}
159159

160160
func (mc *mysqlConn) error() error {
@@ -316,28 +316,25 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
316316
}
317317
query = prepared
318318
}
319-
mc.affectedRows = 0
320-
mc.insertId = 0
321319

322320
err := mc.exec(query)
323321
if err == nil {
324-
return &mysqlResult{
325-
affectedRows: int64(mc.affectedRows),
326-
insertId: int64(mc.insertId),
327-
}, err
322+
copied := mc.result
323+
return &copied, err
328324
}
329325
return nil, mc.markBadConn(err)
330326
}
331327

332328
// Internal function to execute commands
333329
func (mc *mysqlConn) exec(query string) error {
330+
handleOk := mc.clearResult()
334331
// Send command
335332
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
336333
return mc.markBadConn(err)
337334
}
338335

339336
// Read Result
340-
resLen, err := mc.readResultSetHeaderPacket()
337+
resLen, err := handleOk.readResultSetHeaderPacket()
341338
if err != nil {
342339
return err
343340
}
@@ -354,14 +351,16 @@ func (mc *mysqlConn) exec(query string) error {
354351
}
355352
}
356353

357-
return mc.discardResults()
354+
return handleOk.discardResults()
358355
}
359356

360357
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
361358
return mc.query(query, args)
362359
}
363360

364361
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
362+
handleOk := mc.clearResult()
363+
365364
if mc.closed.Load() {
366365
mc.cfg.Logger.Print(ErrInvalidConn)
367366
return nil, driver.ErrBadConn
@@ -382,7 +381,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
382381
if err == nil {
383382
// Read Result
384383
var resLen int
385-
resLen, err = mc.readResultSetHeaderPacket()
384+
resLen, err = handleOk.readResultSetHeaderPacket()
386385
if err == nil {
387386
rows := new(textRows)
388387
rows.mc = mc
@@ -410,12 +409,13 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
410409
// The returned byte slice is only valid until the next read
411410
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
412411
// Send command
412+
handleOk := mc.clearResult()
413413
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
414414
return nil, err
415415
}
416416

417417
// Read Result
418-
resLen, err := mc.readResultSetHeaderPacket()
418+
resLen, err := handleOk.readResultSetHeaderPacket()
419419
if err == nil {
420420
rows := new(textRows)
421421
rows.mc = mc
@@ -466,11 +466,12 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
466466
}
467467
defer mc.finish()
468468

469+
handleOk := mc.clearResult()
469470
if err = mc.writeCommandPacket(comPing); err != nil {
470471
return mc.markBadConn(err)
471472
}
472473

473-
return mc.readResultOK()
474+
return handleOk.readResultOK()
474475
}
475476

476477
// BeginTx implements driver.ConnBeginTx interface

driver_test.go

+112
Original file line numberDiff line numberDiff line change
@@ -2154,11 +2154,51 @@ func TestRejectReadOnly(t *testing.T) {
21542154
}
21552155

21562156
func TestPing(t *testing.T) {
2157+
ctx := context.Background()
21572158
runTests(t, dsn, func(dbt *DBTest) {
21582159
if err := dbt.db.Ping(); err != nil {
21592160
dbt.fail("Ping", "Ping", err)
21602161
}
21612162
})
2163+
2164+
runTests(t, dsn, func(dbt *DBTest) {
2165+
conn, err := dbt.db.Conn(ctx)
2166+
if err != nil {
2167+
dbt.fail("db", "Conn", err)
2168+
}
2169+
2170+
// Check that affectedRows and insertIds are cleared after each call.
2171+
conn.Raw(func(conn interface{}) error {
2172+
c := conn.(*mysqlConn)
2173+
2174+
// Issue a query that sets affectedRows and insertIds.
2175+
q, err := c.Query(`SELECT 1`, nil)
2176+
if err != nil {
2177+
dbt.fail("Conn", "Query", err)
2178+
}
2179+
if got, want := c.result.affectedRows, []int64{0}; !reflect.DeepEqual(got, want) {
2180+
dbt.Fatalf("bad affectedRows: got %v, want=%v", got, want)
2181+
}
2182+
if got, want := c.result.insertIds, []int64{0}; !reflect.DeepEqual(got, want) {
2183+
dbt.Fatalf("bad insertIds: got %v, want=%v", got, want)
2184+
}
2185+
q.Close()
2186+
2187+
// Verify that Ping() clears both fields.
2188+
for i := 0; i < 2; i++ {
2189+
if err := c.Ping(ctx); err != nil {
2190+
dbt.fail("Pinger", "Ping", err)
2191+
}
2192+
if got, want := c.result.affectedRows, []int64(nil); !reflect.DeepEqual(got, want) {
2193+
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
2194+
}
2195+
if got, want := c.result.insertIds, []int64(nil); !reflect.DeepEqual(got, want) {
2196+
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
2197+
}
2198+
}
2199+
return nil
2200+
})
2201+
})
21622202
}
21632203

21642204
// See Issue #799
@@ -2378,6 +2418,42 @@ func TestMultiResultSetNoSelect(t *testing.T) {
23782418
})
23792419
}
23802420

2421+
func TestExecMultipleResults(t *testing.T) {
2422+
ctx := context.Background()
2423+
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
2424+
dbt.mustExec(`
2425+
CREATE TABLE test (
2426+
id INT NOT NULL AUTO_INCREMENT,
2427+
value VARCHAR(255),
2428+
PRIMARY KEY (id)
2429+
)`)
2430+
conn, err := dbt.db.Conn(ctx)
2431+
if err != nil {
2432+
t.Fatalf("failed to connect: %v", err)
2433+
}
2434+
conn.Raw(func(conn interface{}) error {
2435+
ex := conn.(driver.Execer)
2436+
res, err := ex.Exec(`
2437+
INSERT INTO test (value) VALUES ('a'), ('b');
2438+
INSERT INTO test (value) VALUES ('c'), ('d'), ('e');
2439+
`, nil)
2440+
if err != nil {
2441+
t.Fatalf("insert statements failed: %v", err)
2442+
}
2443+
mres := res.(Result)
2444+
if got, want := mres.AllRowsAffected(), []int64{2, 3}; !reflect.DeepEqual(got, want) {
2445+
t.Errorf("bad AllRowsAffected: got %v, want=%v", got, want)
2446+
}
2447+
// For INSERTs containing multiple rows, LAST_INSERT_ID() returns the
2448+
// first inserted ID, not the last.
2449+
if got, want := mres.AllLastInsertIds(), []int64{1, 3}; !reflect.DeepEqual(got, want) {
2450+
t.Errorf("bad AllLastInsertIds: got %v, want %v", got, want)
2451+
}
2452+
return nil
2453+
})
2454+
})
2455+
}
2456+
23812457
// tests if rows are set in a proper state if some results were ignored before
23822458
// calling rows.NextResultSet.
23832459
func TestSkipResults(t *testing.T) {
@@ -2399,6 +2475,42 @@ func TestSkipResults(t *testing.T) {
23992475
})
24002476
}
24012477

2478+
func TestQueryMultipleResults(t *testing.T) {
2479+
ctx := context.Background()
2480+
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
2481+
dbt.mustExec(`
2482+
CREATE TABLE test (
2483+
id INT NOT NULL AUTO_INCREMENT,
2484+
value VARCHAR(255),
2485+
PRIMARY KEY (id)
2486+
)`)
2487+
conn, err := dbt.db.Conn(ctx)
2488+
if err != nil {
2489+
t.Fatalf("failed to connect: %v", err)
2490+
}
2491+
conn.Raw(func(conn interface{}) error {
2492+
qr := conn.(driver.Queryer)
2493+
2494+
c := conn.(*mysqlConn)
2495+
2496+
// Demonstrate that repeated queries reset the affectedRows
2497+
for i := 0; i < 2; i++ {
2498+
_, err := qr.Query(`
2499+
INSERT INTO test (value) VALUES ('a'), ('b');
2500+
INSERT INTO test (value) VALUES ('c'), ('d'), ('e');
2501+
`, nil)
2502+
if err != nil {
2503+
t.Fatalf("insert statements failed: %v", err)
2504+
}
2505+
if got, want := c.result.affectedRows, []int64{2, 3}; !reflect.DeepEqual(got, want) {
2506+
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
2507+
}
2508+
}
2509+
return nil
2510+
})
2511+
})
2512+
}
2513+
24022514
func TestPingContext(t *testing.T) {
24032515
runTests(t, dsn, func(dbt *DBTest) {
24042516
ctx, cancel := context.WithCancel(context.Background())

infile.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func deferredClose(err *error, closer io.Closer) {
9393

9494
const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP
9595

96-
func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
96+
func (mc *okHandler) handleInFileRequest(name string) (err error) {
9797
var rdr io.Reader
9898
var data []byte
9999
packetSize := defaultPacketSize
@@ -154,7 +154,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
154154
for err == nil {
155155
n, err = rdr.Read(data[4:])
156156
if n > 0 {
157-
if ioErr := mc.writePacket(data[:4+n]); ioErr != nil {
157+
if ioErr := mc.conn().writePacket(data[:4+n]); ioErr != nil {
158158
return ioErr
159159
}
160160
}
@@ -168,7 +168,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
168168
if data == nil {
169169
data = make([]byte, 4)
170170
}
171-
if ioErr := mc.writePacket(data[:4]); ioErr != nil {
171+
if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil {
172172
return ioErr
173173
}
174174

@@ -177,6 +177,6 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
177177
return mc.readResultOK()
178178
}
179179

180-
mc.readPacket()
180+
mc.conn().readPacket()
181181
return err
182182
}

0 commit comments

Comments
 (0)