From 2c254fee893f1db69b5fcc3d00eff51854ff8d36 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Tue, 30 May 2017 13:21:25 +0900 Subject: [PATCH] defer getWarnings() after fetching resultsets. --- connection.go | 1 + driver.go | 4 ++++ errors.go | 1 + infile.go | 3 +++ packets.go | 12 ++++++++++++ 5 files changed, 21 insertions(+) diff --git a/connection.go b/connection.go index cdce3e30f..dc399f62b 100644 --- a/connection.go +++ b/connection.go @@ -31,6 +31,7 @@ type mysqlConn struct { sequence uint8 parseTime bool strict bool + warningCount uint16 } // Handles parameters set in DSN after the connection is established diff --git a/driver.go b/driver.go index e51d98a3c..6a690d133 100644 --- a/driver.go +++ b/driver.go @@ -175,6 +175,10 @@ func handleAuthResult(mc *mysqlConn, oldCipher []byte) error { } _, err = mc.readResultOK() } + + if err == nil && mc.strict && mc.warningCount > 0 { + return mc.getWarnings() + } return err } diff --git a/errors.go b/errors.go index 857854e14..bc6450d10 100644 --- a/errors.go +++ b/errors.go @@ -89,6 +89,7 @@ type MySQLWarning struct { } func (mc *mysqlConn) getWarnings() (err error) { + mc.warningCount = 0 rows, err := mc.Query("SHOW WARNINGS", nil) if err != nil { return diff --git a/infile.go b/infile.go index 4020f9192..6a89cea26 100644 --- a/infile.go +++ b/infile.go @@ -175,6 +175,9 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { // read OK packet if err == nil { _, err = mc.readResultOK() + if err == nil && mc.strict && mc.warningCount > 0 { + err = mc.getWarnings() + } return err } diff --git a/packets.go b/packets.go index 303405a17..88a37f2d8 100644 --- a/packets.go +++ b/packets.go @@ -606,10 +606,12 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { // warning count [2 bytes] if !mc.strict { + mc.warningCount = 0 return nil } pos := 1 + n + m + 2 + mc.warningCount = binary.LittleEndian.Uint16(data[pos : pos+2]) if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { return mc.getWarnings() } @@ -729,7 +731,14 @@ func (rows *textRows) readRow(dest []driver.Value) error { rows.mc.status = readStatus(data[3:]) rows.rs.done = true if !rows.HasNextResultSet() { + mc := rows.mc rows.mc = nil + + if mc.strict && mc.warningCount > 0 { + if err := mc.getWarnings(); err != nil { + return err + } + } } return io.EOF } @@ -1115,6 +1124,9 @@ func (mc *mysqlConn) discardResults() error { } } } + if mc.strict && mc.warningCount > 0 { + return mc.getWarnings() + } return nil }