@@ -5,6 +5,7 @@ package tarantool
5
5
import (
6
6
"bufio"
7
7
"bytes"
8
+ "context"
8
9
"errors"
9
10
"fmt"
10
11
"io"
@@ -125,8 +126,9 @@ type Connection struct {
125
126
c net.Conn
126
127
mutex sync.Mutex
127
128
// Schema contains schema loaded on connection.
128
- Schema * Schema
129
- requestId uint32
129
+ Schema * Schema
130
+ requestId uint32
131
+ contextRequestId uint32
130
132
// Greeting contains first message sent by Tarantool.
131
133
Greeting * Greeting
132
134
@@ -143,16 +145,57 @@ type Connection struct {
143
145
144
146
var _ = Connector (& Connection {}) // Check compatibility with connector interface.
145
147
148
+ type futureList struct {
149
+ first * Future
150
+ last * * Future
151
+ }
152
+
153
+ func (list * futureList ) findFuture (reqid uint32 , fetch bool ) * Future {
154
+ root := & list .first
155
+ for {
156
+ fut := * root
157
+ if fut == nil {
158
+ return nil
159
+ }
160
+ if fut .requestId == reqid {
161
+ if fetch {
162
+ * root = fut .next
163
+ if fut .next == nil {
164
+ list .last = root
165
+ } else {
166
+ fut .next = nil
167
+ }
168
+ }
169
+ return fut
170
+ }
171
+ root = & fut .next
172
+ }
173
+ }
174
+
175
+ func (list * futureList ) addFuture (fut * Future ) {
176
+ * list .last = fut
177
+ list .last = & fut .next
178
+ }
179
+
180
+ func (list * futureList ) clear (err error , conn * Connection ) {
181
+ fut := list .first
182
+ list .first = nil
183
+ list .last = & list .first
184
+ for fut != nil {
185
+ fut .SetError (err )
186
+ conn .markDone (fut )
187
+ fut , fut .next = fut .next , nil
188
+ }
189
+ }
190
+
146
191
type connShard struct {
147
- rmut sync.Mutex
148
- requests [requestsMap ]struct {
149
- first * Future
150
- last * * Future
151
- }
152
- bufmut sync.Mutex
153
- buf smallWBuf
154
- enc * msgpack.Encoder
155
- _pad [16 ]uint64 //nolint: unused,structcheck
192
+ rmut sync.Mutex
193
+ requests [requestsMap ]futureList
194
+ requestsWithCtx [requestsMap ]futureList
195
+ bufmut sync.Mutex
196
+ buf smallWBuf
197
+ enc * msgpack.Encoder
198
+ _pad [16 ]uint64 //nolint: unused,structcheck
156
199
}
157
200
158
201
// Greeting is a message sent by Tarantool on connect.
@@ -262,12 +305,13 @@ type SslOpts struct {
262
305
// and will not finish to make attempts on authorization failures.
263
306
func Connect (addr string , opts Opts ) (conn * Connection , err error ) {
264
307
conn = & Connection {
265
- addr : addr ,
266
- requestId : 0 ,
267
- Greeting : & Greeting {},
268
- control : make (chan struct {}),
269
- opts : opts ,
270
- dec : msgpack .NewDecoder (& smallBuf {}),
308
+ addr : addr ,
309
+ requestId : 0 ,
310
+ contextRequestId : 1 ,
311
+ Greeting : & Greeting {},
312
+ control : make (chan struct {}),
313
+ opts : opts ,
314
+ dec : msgpack .NewDecoder (& smallBuf {}),
271
315
}
272
316
maxprocs := uint32 (runtime .GOMAXPROCS (- 1 ))
273
317
if conn .opts .Concurrency == 0 || conn .opts .Concurrency > maxprocs * 128 {
@@ -286,6 +330,9 @@ func Connect(addr string, opts Opts) (conn *Connection, err error) {
286
330
for j := range shard .requests {
287
331
shard .requests [j ].last = & shard .requests [j ].first
288
332
}
333
+ for j := range shard .requests {
334
+ shard .requestsWithCtx [j ].last = & shard .requestsWithCtx [j ].first
335
+ }
289
336
}
290
337
291
338
if opts .RateLimit > 0 {
@@ -387,6 +434,17 @@ func (conn *Connection) Handle() interface{} {
387
434
return conn .opts .Handle
388
435
}
389
436
437
+ func (conn * Connection ) cancelFuture (fut * Future , err error ) error {
438
+ if fut == nil {
439
+ return fmt .Errorf ("passed nil future" )
440
+ }
441
+ if fut = conn .fetchFuture (fut .requestId ); fut != nil {
442
+ fut .SetError (err )
443
+ conn .markDone (fut )
444
+ }
445
+ return nil
446
+ }
447
+
390
448
func (conn * Connection ) dial () (err error ) {
391
449
var connection net.Conn
392
450
network := "tcp"
@@ -582,14 +640,11 @@ func (conn *Connection) closeConnection(neterr error, forever bool) (err error)
582
640
conn .shard [i ].buf .Reset ()
583
641
requests := & conn .shard [i ].requests
584
642
for pos := range requests {
585
- fut := requests [pos ].first
586
- requests [pos ].first = nil
587
- requests [pos ].last = & requests [pos ].first
588
- for fut != nil {
589
- fut .SetError (neterr )
590
- conn .markDone (fut )
591
- fut , fut .next = fut .next , nil
592
- }
643
+ requests [pos ].clear (neterr , conn )
644
+ }
645
+ requestsWithCtx := & conn .shard [i ].requestsWithCtx
646
+ for pos := range requestsWithCtx {
647
+ requestsWithCtx [pos ].clear (neterr , conn )
593
648
}
594
649
}
595
650
return
@@ -721,7 +776,7 @@ func (conn *Connection) reader(r *bufio.Reader, c net.Conn) {
721
776
}
722
777
}
723
778
724
- func (conn * Connection ) newFuture () (fut * Future ) {
779
+ func (conn * Connection ) newFuture (ctx context. Context ) (fut * Future ) {
725
780
fut = NewFuture ()
726
781
if conn .rlimit != nil && conn .opts .RLimitAction == RLimitDrop {
727
782
select {
@@ -736,7 +791,7 @@ func (conn *Connection) newFuture() (fut *Future) {
736
791
return
737
792
}
738
793
}
739
- fut .requestId = conn .nextRequestId ()
794
+ fut .requestId = conn .nextRequestId (ctx != nil )
740
795
shardn := fut .requestId & (conn .opts .Concurrency - 1 )
741
796
shard := & conn .shard [shardn ]
742
797
shard .rmut .Lock ()
@@ -761,11 +816,20 @@ func (conn *Connection) newFuture() (fut *Future) {
761
816
return
762
817
}
763
818
pos := (fut .requestId / conn .opts .Concurrency ) & (requestsMap - 1 )
764
- pair := & shard .requests [pos ]
765
- * pair .last = fut
766
- pair .last = & fut .next
767
- if conn .opts .Timeout > 0 {
768
- fut .timeout = time .Since (epoch ) + conn .opts .Timeout
819
+ if ctx != nil {
820
+ select {
821
+ case <- ctx .Done ():
822
+ fut .SetError (fmt .Errorf ("context is done" ))
823
+ shard .rmut .Unlock ()
824
+ return
825
+ default :
826
+ }
827
+ shard .requestsWithCtx [pos ].addFuture (fut )
828
+ } else {
829
+ shard .requests [pos ].addFuture (fut )
830
+ if conn .opts .Timeout > 0 {
831
+ fut .timeout = time .Since (epoch ) + conn .opts .Timeout
832
+ }
769
833
}
770
834
shard .rmut .Unlock ()
771
835
if conn .rlimit != nil && conn .opts .RLimitAction == RLimitWait {
@@ -785,12 +849,40 @@ func (conn *Connection) newFuture() (fut *Future) {
785
849
return
786
850
}
787
851
852
+ func (conn * Connection ) contextWatchdog (fut * Future , ctx context.Context ) {
853
+ select {
854
+ case <- fut .done :
855
+ default :
856
+ select {
857
+ case <- ctx .Done ():
858
+ conn .cancelFuture (fut , fmt .Errorf ("context is done" ))
859
+ default :
860
+ select {
861
+ case <- fut .done :
862
+ case <- ctx .Done ():
863
+ conn .cancelFuture (fut , fmt .Errorf ("context is done" ))
864
+ }
865
+ }
866
+ }
867
+ }
868
+
788
869
func (conn * Connection ) send (req Request ) * Future {
789
- fut := conn .newFuture ()
870
+ fut := conn .newFuture (req . Ctx () )
790
871
if fut .ready == nil {
791
872
return fut
792
873
}
874
+ if req .Ctx () != nil {
875
+ select {
876
+ case <- req .Ctx ().Done ():
877
+ conn .cancelFuture (fut , fmt .Errorf ("context is done" ))
878
+ return fut
879
+ default :
880
+ }
881
+ }
793
882
conn .putFuture (fut , req )
883
+ if req .Ctx () != nil {
884
+ go conn .contextWatchdog (fut , req .Ctx ())
885
+ }
794
886
return fut
795
887
}
796
888
@@ -877,25 +969,10 @@ func (conn *Connection) fetchFuture(reqid uint32) (fut *Future) {
877
969
func (conn * Connection ) getFutureImp (reqid uint32 , fetch bool ) * Future {
878
970
shard := & conn .shard [reqid & (conn .opts .Concurrency - 1 )]
879
971
pos := (reqid / conn .opts .Concurrency ) & (requestsMap - 1 )
880
- pair := & shard .requests [pos ]
881
- root := & pair .first
882
- for {
883
- fut := * root
884
- if fut == nil {
885
- return nil
886
- }
887
- if fut .requestId == reqid {
888
- if fetch {
889
- * root = fut .next
890
- if fut .next == nil {
891
- pair .last = root
892
- } else {
893
- fut .next = nil
894
- }
895
- }
896
- return fut
897
- }
898
- root = & fut .next
972
+ if reqid % 2 == 0 {
973
+ return shard .requests [pos ].findFuture (reqid , fetch )
974
+ } else {
975
+ return shard .requestsWithCtx [pos ].findFuture (reqid , fetch )
899
976
}
900
977
}
901
978
@@ -984,8 +1061,12 @@ func (conn *Connection) read(r io.Reader) (response []byte, err error) {
984
1061
return
985
1062
}
986
1063
987
- func (conn * Connection ) nextRequestId () (requestId uint32 ) {
988
- return atomic .AddUint32 (& conn .requestId , 1 )
1064
+ func (conn * Connection ) nextRequestId (Context bool ) (requestId uint32 ) {
1065
+ if Context {
1066
+ return atomic .AddUint32 (& conn .contextRequestId , 2 )
1067
+ } else {
1068
+ return atomic .AddUint32 (& conn .requestId , 2 )
1069
+ }
989
1070
}
990
1071
991
1072
// Do performs a request asynchronously on the connection.
@@ -1000,6 +1081,15 @@ func (conn *Connection) Do(req Request) *Future {
1000
1081
return fut
1001
1082
}
1002
1083
}
1084
+ if req .Ctx () != nil {
1085
+ select {
1086
+ case <- req .Ctx ().Done ():
1087
+ fut := NewFuture ()
1088
+ fut .SetError (fmt .Errorf ("context is done" ))
1089
+ return fut
1090
+ default :
1091
+ }
1092
+ }
1003
1093
return conn .send (req )
1004
1094
}
1005
1095
0 commit comments