Skip to content

Commit 3e559a8

Browse files
author
Brigitte Lamarche
committed
saving work with SimpleReader present
1 parent f339392 commit 3e559a8

8 files changed

+99
-48
lines changed

benchmark_test.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,10 @@ func BenchmarkInterpolation(b *testing.B) {
231231
},
232232
maxAllowedPacket: maxPacketSize,
233233
maxWriteSize: maxPacketSize - 1,
234-
buf: newBuffer(nil),
235234
}
236-
mc.reader = &mc.buf
235+
236+
buf := newBuffer(nil)
237+
mc.reader = newSimpleReader(&buf)
237238

238239
args := []driver.Value{
239240
int64(42424242),

buffer.go

+10-4
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ const defaultBufSize = 4096
2121
// In other words, we can't write and read simultaneously on the same connection.
2222
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
2323
// Also highly optimized for this particular use case.
24-
type buffer struct {
24+
type buffer struct { //PROBLEM: figure this all out better
2525
buf []byte
2626
nc net.Conn
2727
idx int
@@ -49,7 +49,7 @@ func (b *buffer) fill(need int) error {
4949
// grow buffer if necessary
5050
// TODO: let the buffer shrink again at some point
5151
// Maybe keep the org buf slice and swap back?
52-
if need > len(b.buf) {
52+
if need > len(b.buf) { //look up what len and cap mean again!
5353
// Round up to the next multiple of the default size
5454
newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
5555
copy(newBuf, b.buf)
@@ -92,6 +92,10 @@ func (b *buffer) fill(need int) error {
9292
// returns next N bytes from buffer.
9393
// The returned slice is only guaranteed to be valid until the next read
9494
func (b *buffer) readNext(need int) ([]byte, error) {
95+
if need == -1 {
96+
return b.takeCompleteBuffer()
97+
}
98+
9599
if b.length < need {
96100
// refill
97101
if err := b.fill(need); err != nil {
@@ -110,7 +114,7 @@ func (b *buffer) readNext(need int) ([]byte, error) {
110114
// Otherwise a bigger buffer is made.
111115
// Only one buffer (total) can be used at a time.
112116
func (b *buffer) takeBuffer(length int) []byte {
113-
if b.length > 0 {
117+
if b.length > 0 { //assume its empty
114118
return nil
115119
}
116120

@@ -126,15 +130,17 @@ func (b *buffer) takeBuffer(length int) []byte {
126130
return make([]byte, length)
127131
}
128132

133+
/*
129134
// shortcut which can be used if the requested buffer is guaranteed to be
130135
// smaller than defaultBufSize
131136
// Only one buffer (total) can be used at a time.
132137
func (b *buffer) takeSmallBuffer(length int) []byte {
133-
if b.length == 0 {
138+
if b.length == 0 { //assume its empty
134139
return b.buf[:length]
135140
}
136141
return nil
137142
}
143+
*/
138144

139145
// takeCompleteBuffer returns the complete existing buffer.
140146
// This can be used if the necessary buffer size is unknown.

compress.go

+17-2
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,37 @@ const (
1111
)
1212

1313
type compressedReader struct {
14-
buf packetReader
14+
buf *buffer //packetReader
1515
bytesBuf []byte
1616
mc *mysqlConn
1717
zr io.ReadCloser
1818
}
1919

20+
21+
type simpleReader struct {
22+
buf *buffer //packetReader
23+
}
24+
2025
type compressedWriter struct {
2126
connWriter io.Writer
2227
mc *mysqlConn
2328
zw *zlib.Writer
2429
}
2530

26-
func newCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader {
31+
func newCompressedReader(buf *buffer, mc *mysqlConn) *compressedReader {
2732
return &compressedReader{
2833
buf: buf,
2934
bytesBuf: make([]byte, 0),
3035
mc: mc,
3136
}
3237
}
3338

39+
func newSimpleReader(buf *buffer) *simpleReader {
40+
return &simpleReader{
41+
buf: buf,
42+
}
43+
}
44+
3445
func newCompressedWriter(connWriter io.Writer, mc *mysqlConn) *compressedWriter {
3546
return &compressedWriter{
3647
connWriter: connWriter,
@@ -52,6 +63,10 @@ func (cr *compressedReader) readNext(need int) ([]byte, error) {
5263
return data, nil
5364
}
5465

66+
func (sr *simpleReader) readNext(need int) ([]byte, error) {
67+
return sr.buf.readNext(need)
68+
}
69+
5570
func (cr *compressedReader) uncompressPacket() error {
5671
header, err := cr.buf.readNext(7) // size of compressed header
5772

connection.go

+17-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ type mysqlContext interface {
2828
}
2929

3030
type mysqlConn struct {
31-
buf buffer
3231
netConn net.Conn
3332
reader packetReader
3433
writer io.Writer
@@ -57,6 +56,18 @@ type packetReader interface {
5756
readNext(need int) ([]byte, error)
5857
}
5958

59+
/*
60+
type packetReadCloser interface{
61+
Read(n int) ([]byte, error)
62+
Close() error // PROBLEM: is there a way to do this?
63+
}
64+
65+
type packetWriteCloser interface{
66+
Write([]byte) (int, error)
67+
Close() error
68+
}
69+
*/
70+
6071
// Handles parameters set in DSN after the connection is established
6172
func (mc *mysqlConn) handleParams() (err error) {
6273
for param, val := range mc.cfg.Params {
@@ -197,7 +208,11 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
197208
return "", driver.ErrSkip
198209
}
199210

200-
buf := mc.buf.takeCompleteBuffer()
211+
//https://stackoverflow.com/questions/29684609/how-to-check-if-an-object-has-a-particular-method
212+
213+
//reader has &buf which is a packetreader but also always a buffer
214+
buf, _ := mc.reader.readNext(-1) //PROBLEM uncompressed so this works, what if compressed
215+
201216
if buf == nil {
202217
// can not take the buffer. Something must be wrong with the connection
203218
errLog.Print(ErrBusyBuffer)

connection_test.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ import (
1515

1616
func TestInterpolateParams(t *testing.T) {
1717
mc := &mysqlConn{
18-
buf: newBuffer(nil),
1918
maxAllowedPacket: maxPacketSize,
2019
cfg: &Config{
2120
InterpolateParams: true,
2221
},
2322
}
24-
mc.reader = &mc.buf
23+
buf := newBuffer(nil)
24+
mc.reader = newSimpleReader(&buf)
2525

2626
q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"})
2727
if err != nil {
@@ -36,13 +36,13 @@ func TestInterpolateParams(t *testing.T) {
3636

3737
func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
3838
mc := &mysqlConn{
39-
buf: newBuffer(nil),
4039
maxAllowedPacket: maxPacketSize,
4140
cfg: &Config{
4241
InterpolateParams: true,
4342
},
4443
}
45-
mc.reader = &mc.buf
44+
buf := newBuffer(nil)
45+
mc.reader = newSimpleReader(&buf)
4646

4747
q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)})
4848
if err != driver.ErrSkip {
@@ -54,14 +54,14 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
5454
// https://github.com/go-sql-driver/mysql/pull/490
5555
func TestInterpolateParamsPlaceholderInString(t *testing.T) {
5656
mc := &mysqlConn{
57-
buf: newBuffer(nil),
5857
maxAllowedPacket: maxPacketSize,
5958
cfg: &Config{
6059
InterpolateParams: true,
6160
},
6261
}
6362

64-
mc.reader = &mc.buf
63+
buf := newBuffer(nil)
64+
mc.reader = newSimpleReader(&buf)
6565

6666
q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
6767
// When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`

driver.go

+7-6
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,16 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
9191
s.startWatcher()
9292
}
9393

94-
mc.buf = newBuffer(mc.netConn)
94+
buf := newBuffer(mc.netConn)
95+
96+
// Set I/O timeouts
97+
buf.timeout = mc.cfg.ReadTimeout
98+
mc.writeTimeout = mc.cfg.WriteTimeout
9599

96100
// packet reader and writer in handshake are never compressed
97-
mc.reader = &mc.buf
101+
mc.reader = newSimpleReader(&buf)
98102
mc.writer = mc.netConn
99103

100-
// Set I/O timeouts
101-
mc.buf.timeout = mc.cfg.ReadTimeout
102-
mc.writeTimeout = mc.cfg.WriteTimeout
103104

104105
// Reading Handshake Initialization Packet
105106
cipher, err := mc.readInitPacket()
@@ -124,7 +125,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
124125
}
125126

126127
if mc.cfg.Compress {
127-
mc.reader = newCompressedReader(&mc.buf, mc)
128+
mc.reader = newCompressedReader(&buf, mc)
128129
mc.writer = newCompressedWriter(mc.writer, mc)
129130
}
130131

packets.go

+25-13
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
2828
var prevData []byte
2929
for {
3030
// read packet header
31-
data, err := mc.reader.readNext(4)
31+
data, err := mc.reader.readNext(4)
3232
if err != nil {
3333
if cerr := mc.canceled.Value(); cerr != nil {
3434
return nil, cerr
@@ -64,7 +64,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
6464
}
6565

6666
// read packet body [pktLen bytes]
67-
data, err = mc.reader.readNext(pktLen)
67+
data, err = mc.reader.readNext(pktLen)
6868
if err != nil {
6969
if cerr := mc.canceled.Value(); cerr != nil {
7070
return nil, cerr
@@ -283,7 +283,8 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
283283
}
284284

285285
// Calculate packet length and get buffer with that size
286-
data := mc.buf.takeSmallBuffer(pktLen + 4)
286+
data, _ := mc.reader.readNext(pktLen + 4)
287+
287288
if data == nil {
288289
// can not take the buffer. Something must be wrong with the connection
289290
errLog.Print(ErrBusyBuffer)
@@ -326,8 +327,12 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
326327
return err
327328
}
328329
mc.netConn = tlsConn
329-
mc.buf.nc = tlsConn
330+
nc := tlsConn
330331

332+
// make newBuffer with tls conn, clean slate bc handshake
333+
newBuf := newBuffer(nc)
334+
mc.reader = newSimpleReader(&newBuf)
335+
331336
mc.writer = mc.netConn
332337
}
333338

@@ -373,7 +378,8 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
373378

374379
// Calculate the packet length and add a tailing 0
375380
pktLen := len(scrambleBuff) + 1
376-
data := mc.buf.takeSmallBuffer(4 + pktLen)
381+
data, _ := mc.reader.readNext(4 + pktLen)
382+
377383
if data == nil {
378384
// can not take the buffer. Something must be wrong with the connection
379385
errLog.Print(ErrBusyBuffer)
@@ -392,7 +398,8 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
392398
func (mc *mysqlConn) writeClearAuthPacket() error {
393399
// Calculate the packet length and add a tailing 0
394400
pktLen := len(mc.cfg.Passwd) + 1
395-
data := mc.buf.takeSmallBuffer(4 + pktLen)
401+
data, _ := mc.reader.readNext(4 + pktLen)
402+
396403
if data == nil {
397404
// can not take the buffer. Something must be wrong with the connection
398405
errLog.Print(ErrBusyBuffer)
@@ -415,7 +422,8 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
415422

416423
// Calculate the packet length and add a tailing 0
417424
pktLen := len(scrambleBuff)
418-
data := mc.buf.takeSmallBuffer(4 + pktLen)
425+
data, _ := mc.reader.readNext(4 + pktLen)
426+
419427
if data == nil {
420428
// can not take the buffer. Something must be wrong with the connection
421429
errLog.Print(ErrBusyBuffer)
@@ -437,7 +445,8 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
437445
mc.sequence = 0
438446
mc.compressionSequence = 0
439447

440-
data := mc.buf.takeSmallBuffer(4 + 1)
448+
data, _ := mc.reader.readNext(4+1)
449+
441450
if data == nil {
442451
// can not take the buffer. Something must be wrong with the connection
443452
errLog.Print(ErrBusyBuffer)
@@ -457,7 +466,8 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
457466
mc.compressionSequence = 0
458467

459468
pktLen := 1 + len(arg)
460-
data := mc.buf.takeBuffer(pktLen + 4)
469+
data, _ := mc.reader.readNext(pktLen + 4)
470+
461471
if data == nil {
462472
// can not take the buffer. Something must be wrong with the connection
463473
errLog.Print(ErrBusyBuffer)
@@ -479,7 +489,8 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
479489
mc.sequence = 0
480490
mc.compressionSequence = 0
481491

482-
data := mc.buf.takeSmallBuffer(4 + 1 + 4)
492+
data, _ := mc.reader.readNext(4 + 1 + 4)
493+
483494
if data == nil {
484495
// can not take the buffer. Something must be wrong with the connection
485496
errLog.Print(ErrBusyBuffer)
@@ -946,9 +957,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
946957
var data []byte
947958

948959
if len(args) == 0 {
949-
data = mc.buf.takeBuffer(minPktLen)
960+
data, _ = mc.reader.readNext(minPktLen)
950961
} else {
951-
data = mc.buf.takeCompleteBuffer()
962+
data, _ = mc.reader.readNext(-1) //how does this work out with compressed?
952963
}
953964
if data == nil {
954965
// can not take the buffer. Something must be wrong with the connection
@@ -1127,7 +1138,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
11271138
// In that case we must build the data packet with the new values buffer
11281139
if valuesCap != cap(paramValues) {
11291140
data = append(data[:pos], paramValues...)
1130-
mc.buf.buf = data
1141+
readerBuffer := mc.reader.getBuffer() //PROBLEM: what the fuck is going on here??
1142+
readerBuffer.buf = data
11311143
}
11321144

11331145
pos += len(paramValues)

0 commit comments

Comments
 (0)