Skip to content

Commit 37881ae

Browse files
committed
Exec() now provides access to the last inserted ID and number of affected rows
for all statements, not just the last one. This is useful to execute batches of statements such as UPDATE with minimal roundtrips. The approach taken is to track last insert id and affected rows using []int64 instead of a int64. Both are set in `mysqlResult`, and a new `mysql.Result` interface makes them accessible to callers calling `Exec()` via `sql.Conn.Raw`. For example: ``` conn.Raw(func(conn interface{}) error { ex := conn.(driver.Execer) res, err := ex.Exec(` UPDATE point SET x = 1 WHERE y = 2; UPDATE point SET x = 2 WHERE y = 3; `, nil) // Both slices have 2 elements. log.Print(res.(mysql.Result).AllRowsAffected()) log.Print(res.(mysql.Result).AllLastInsertIds()) }) ```
1 parent 6cf3092 commit 37881ae

File tree

6 files changed

+113
-16
lines changed

6 files changed

+113
-16
lines changed

README.md

+16
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,22 @@ Allow multiple statements in one query. While this allows batch queries, it also
288288

289289
When `multiStatements` is used, `?` parameters must only be used in the first statement.
290290

291+
It's possible to access the last inserted ID and number of affected rows for multiple statements by using `sql.Conn.Raw()` and the `mysql.Result`. For example:
292+
293+
```
294+
conn, _ := db.Conn(ctx)
295+
conn.Raw(func(conn interface{}) error {
296+
ex := conn.(driver.Execer)
297+
res, err := ex.Exec(`
298+
UPDATE point SET x = 1 WHERE y = 2;
299+
UPDATE point SET x = 2 WHERE y = 3;
300+
`, nil)
301+
// Both slices have 2 elements.
302+
log.Print(res.(mysql.Result).AllRowsAffected())
303+
log.Print(res.(mysql.Result).AllLastInsertIds())
304+
})
305+
```
306+
291307
##### `parseTime`
292308

293309
```

connection.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ type mysqlConn struct {
2424
buf buffer
2525
netConn net.Conn
2626
rawConn net.Conn // underlying connection when netConn is TLS connection.
27-
affectedRows uint64
28-
insertId uint64
27+
affectedRows []int64
28+
insertIds []int64
2929
cfg *Config
3030
maxAllowedPacket int
3131
maxWriteSize int
@@ -310,14 +310,14 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
310310
}
311311
query = prepared
312312
}
313-
mc.affectedRows = 0
314-
mc.insertId = 0
313+
mc.affectedRows = nil
314+
mc.insertIds = nil
315315

316316
err := mc.exec(query)
317317
if err == nil {
318318
return &mysqlResult{
319-
affectedRows: int64(mc.affectedRows),
320-
insertId: int64(mc.insertId),
319+
affectedRows: mc.affectedRows,
320+
insertIds: mc.insertIds,
321321
}, err
322322
}
323323
return nil, mc.markBadConn(err)

driver_test.go

+36
Original file line numberDiff line numberDiff line change
@@ -2379,6 +2379,42 @@ func TestMultiResultSetNoSelect(t *testing.T) {
23792379
})
23802380
}
23812381

2382+
func TestExecMultipleResults(t *testing.T) {
2383+
ctx := context.Background()
2384+
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
2385+
dbt.mustExec(`
2386+
CREATE TABLE test (
2387+
id INT NOT NULL AUTO_INCREMENT,
2388+
value VARCHAR(255),
2389+
PRIMARY KEY (id)
2390+
)`)
2391+
conn, err := dbt.db.Conn(ctx)
2392+
if err != nil {
2393+
t.Fatalf("failed to connect: %v", err)
2394+
}
2395+
conn.Raw(func(conn interface{}) error {
2396+
ex := conn.(driver.Execer)
2397+
res, err := ex.Exec(`
2398+
INSERT INTO test (value) VALUES ('a'), ('b');
2399+
INSERT INTO test (value) VALUES ('c'), ('d'), ('e');
2400+
`, nil)
2401+
if err != nil {
2402+
t.Fatalf("insert statements failed: %v", err)
2403+
}
2404+
mres := res.(Result)
2405+
if got, want := mres.AllRowsAffected(), []int64{2, 3}; !reflect.DeepEqual(got, want) {
2406+
t.Errorf("bad AllRowsAffected: got %v, want=%v", got, want)
2407+
}
2408+
// For INSERTs containing multiple rows, LAST_INSERT_ID() returns the
2409+
// first inserted ID, not the last.
2410+
if got, want := mres.AllLastInsertIds(), []int64{1, 3}; !reflect.DeepEqual(got, want) {
2411+
t.Errorf("bad AllLastInsertIds: got %v, want %v", got, want)
2412+
}
2413+
return nil
2414+
})
2415+
})
2416+
}
2417+
23822418
// tests if rows are set in a proper state if some results were ignored before
23832419
// calling rows.NextResultSet.
23842420
func TestSkipResults(t *testing.T) {

packets.go

+18-2
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,10 @@ func (mc *mysqlConn) readResultOK() error {
535535
// Result Set Header Packet
536536
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
537537
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
538+
// handleOkPacket replaces both values; other cases leave the values unchanged.
539+
mc.affectedRows = append(mc.affectedRows, 0)
540+
mc.insertIds = append(mc.insertIds, 0)
541+
538542
data, err := mc.readPacket()
539543
if err == nil {
540544
switch data[0] {
@@ -611,14 +615,24 @@ func readStatus(b []byte) statusFlag {
611615
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
612616
func (mc *mysqlConn) handleOkPacket(data []byte) error {
613617
var n, m int
618+
var affectedRows, insertID uint64
614619

615620
// 0x00 [1 byte]
616621

617622
// Affected rows [Length Coded Binary]
618-
mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
623+
affectedRows, _, n = readLengthEncodedInteger(data[1:])
619624

620625
// Insert id [Length Coded Binary]
621-
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
626+
insertID, _, m = readLengthEncodedInteger(data[1+n:])
627+
628+
// Update for the current statement result (only used by
629+
// readResultSetHeaderPacket).
630+
if len(mc.affectedRows) > 0 {
631+
mc.affectedRows[len(mc.affectedRows)-1] = int64(affectedRows)
632+
}
633+
if len(mc.insertIds) > 0 {
634+
mc.insertIds[len(mc.insertIds)-1] = int64(insertID)
635+
}
622636

623637
// server_status [2 bytes]
624638
mc.status = readStatus(data[1+n+m : 1+n+m+2])
@@ -1149,6 +1163,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
11491163
return mc.writePacket(data)
11501164
}
11511165

1166+
// For each remaining resultset in the stream, discards its rows and updates
1167+
// mc.affectedRows and mc.insertIds.
11521168
func (mc *mysqlConn) discardResults() error {
11531169
for mc.status&statusMoreResultsExists != 0 {
11541170
resLen, err := mc.readResultSetHeaderPacket()

result.go

+33-4
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,44 @@
88

99
package mysql
1010

11+
import "database/sql/driver"
12+
13+
// Result exposes data not available through *connection.Result.
14+
//
15+
// This is accessible by executing statements using sql.Conn.Raw() and
16+
// downcasting the returned result:
17+
//
18+
// res, err := rawConn.Exec(...)
19+
// res.(mysql.Result).AllRowsAffected()
20+
//
21+
type Result interface {
22+
driver.Result
23+
// AllRowsAffected returns a slice containing the affected rows for each
24+
// executed statement.
25+
AllRowsAffected() []int64
26+
// AllLastInsertIds returns a slice containing the last inserted ID for each
27+
// executed statement.
28+
AllLastInsertIds() []int64
29+
}
30+
1131
type mysqlResult struct {
12-
affectedRows int64
13-
insertId int64
32+
// One entry in both slices is created for every executed statement result.
33+
affectedRows []int64
34+
insertIds []int64
1435
}
1536

1637
func (res *mysqlResult) LastInsertId() (int64, error) {
17-
return res.insertId, nil
38+
return res.insertIds[len(res.insertIds)-1], nil
1839
}
1940

2041
func (res *mysqlResult) RowsAffected() (int64, error) {
21-
return res.affectedRows, nil
42+
return res.affectedRows[len(res.affectedRows)-1], nil
43+
}
44+
45+
func (res *mysqlResult) AllLastInsertIds() []int64 {
46+
return append([]int64{}, res.insertIds...) // defensive copy
47+
}
48+
49+
func (res *mysqlResult) AllRowsAffected() []int64 {
50+
return append([]int64{}, res.affectedRows...) // defensive copy
2251
}

statement.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
6262

6363
mc := stmt.mc
6464

65-
mc.affectedRows = 0
66-
mc.insertId = 0
65+
mc.affectedRows = nil
66+
mc.insertIds = nil
6767

6868
// Read Result
6969
resLen, err := mc.readResultSetHeaderPacket()
@@ -88,8 +88,8 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
8888
}
8989

9090
return &mysqlResult{
91-
affectedRows: int64(mc.affectedRows),
92-
insertId: int64(mc.insertId),
91+
affectedRows: mc.affectedRows,
92+
insertIds: mc.insertIds,
9393
}, nil
9494
}
9595

0 commit comments

Comments
 (0)