Skip to content

Commit 0bc8145

Browse files
committed
allow returning ErrBadConn on compression
1 parent 25cf587 commit 0bc8145

File tree

2 files changed

+10
-11
lines changed

2 files changed

+10
-11
lines changed

compress.go

+9-7
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,10 @@ func (c *compIO) writePackets(packets []byte) (int, error) {
199199
}
200200
}
201201

202-
if err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil {
203-
return 0, err
202+
if n, err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil {
203+
// To allow returning ErrBadConn when sending really 0 bytes, we sum
204+
// up compressed bytes that is returned by underlying Write().
205+
return totalBytes - len(packets) + n, err
204206
}
205207
dataLen -= payloadLen
206208
packets = packets[payloadLen:]
@@ -211,7 +213,7 @@ func (c *compIO) writePackets(packets []byte) (int, error) {
211213

212214
// writeCompressedPacket writes a compressed packet with header.
213215
// data should start with 7 size space for header followed by payload.
214-
func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) error {
216+
func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) (int, error) {
215217
mc := c.mc
216218
comprLength := len(data) - 7
217219
if debugTrace {
@@ -225,11 +227,11 @@ func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) error {
225227
data[3] = mc.compressSequence
226228
putUint24(data[4:7], uncompressedLen)
227229

228-
if _, err := mc.writeWithTimeout(data); err != nil {
229-
mc.log("writing compressed packet:", err)
230-
return err
230+
if n, err := mc.writeWithTimeout(data); err != nil {
231+
// mc.log("writing compressed packet:", err)
232+
return n, err
231233
}
232234

233235
mc.compressSequence++
234-
return nil
236+
return n, nil
235237
}

connection_test.go

+1-4
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,9 @@ func TestCleanCancel(t *testing.T) {
159159

160160
func TestPingMarkBadConnection(t *testing.T) {
161161
nc := badConnection{err: errors.New("boom")}
162-
163-
buf := newBuffer()
164162
mc := &mysqlConn{
165163
netConn: nc,
166-
buf: buf,
164+
buf: newBuffer(),
167165
maxAllowedPacket: defaultMaxAllowedPacket,
168166
closech: make(chan struct{}),
169167
cfg: NewConfig(),
@@ -178,7 +176,6 @@ func TestPingMarkBadConnection(t *testing.T) {
178176

179177
func TestPingErrInvalidConn(t *testing.T) {
180178
nc := badConnection{err: errors.New("failed to write"), n: 10}
181-
182179
mc := &mysqlConn{
183180
netConn: nc,
184181
buf: newBuffer(),

0 commit comments

Comments
 (0)