22
22
import static org .neo4j .driver .internal .async .connection .ChannelAttributes .poolId ;
23
23
import static org .neo4j .driver .internal .async .connection .ChannelAttributes .setTerminationReason ;
24
24
import static org .neo4j .driver .internal .util .Futures .asCompletionStage ;
25
+ import static org .neo4j .driver .internal .util .LockUtil .executeWithLock ;
25
26
26
27
import io .netty .channel .Channel ;
27
28
import io .netty .channel .ChannelHandler ;
28
29
import java .time .Clock ;
29
30
import java .util .concurrent .CompletableFuture ;
30
31
import java .util .concurrent .CompletionStage ;
31
32
import java .util .concurrent .TimeUnit ;
32
- import java .util .concurrent .atomic .AtomicReference ;
33
+ import java .util .concurrent .locks .Lock ;
34
+ import java .util .concurrent .locks .ReentrantLock ;
35
+ import java .util .function .Consumer ;
33
36
import org .neo4j .driver .Logger ;
34
37
import org .neo4j .driver .Logging ;
35
38
import org .neo4j .driver .internal .BoltServerAddress ;
41
44
import org .neo4j .driver .internal .handlers .ResetResponseHandler ;
42
45
import org .neo4j .driver .internal .messaging .BoltProtocol ;
43
46
import org .neo4j .driver .internal .messaging .Message ;
47
+ import org .neo4j .driver .internal .messaging .request .DiscardAllMessage ;
48
+ import org .neo4j .driver .internal .messaging .request .DiscardMessage ;
49
+ import org .neo4j .driver .internal .messaging .request .PullAllMessage ;
50
+ import org .neo4j .driver .internal .messaging .request .PullMessage ;
44
51
import org .neo4j .driver .internal .messaging .request .ResetMessage ;
52
+ import org .neo4j .driver .internal .messaging .request .RunWithMetadataMessage ;
45
53
import org .neo4j .driver .internal .metrics .ListenerEvent ;
46
54
import org .neo4j .driver .internal .metrics .MetricsListener ;
47
55
import org .neo4j .driver .internal .spi .Connection ;
53
61
*/
54
62
public class NetworkConnection implements Connection {
55
63
private final Logger log ;
64
+ private final Lock lock ;
56
65
private final Channel channel ;
57
66
private final InboundMessageDispatcher messageDispatcher ;
58
67
private final String serverAgent ;
@@ -61,12 +70,13 @@ public class NetworkConnection implements Connection {
61
70
private final ExtendedChannelPool channelPool ;
62
71
private final CompletableFuture <Void > releaseFuture ;
63
72
private final Clock clock ;
64
-
65
- private final AtomicReference <Status > status = new AtomicReference <>(Status .OPEN );
66
73
private final MetricsListener metricsListener ;
67
74
private final ListenerEvent <?> inUseEvent ;
68
75
69
76
private final Long connectionReadTimeout ;
77
+
78
+ private Status status = Status .OPEN ;
79
+ private UnmanagedTransaction transaction ;
70
80
private ChannelHandler connectionReadTimeoutHandler ;
71
81
72
82
public NetworkConnection (
@@ -76,6 +86,7 @@ public NetworkConnection(
76
86
MetricsListener metricsListener ,
77
87
Logging logging ) {
78
88
this .log = logging .getLog (getClass ());
89
+ this .lock = new ReentrantLock ();
79
90
this .channel = channel ;
80
91
this .messageDispatcher = ChannelAttributes .messageDispatcher (channel );
81
92
this .serverAgent = ChannelAttributes .serverAgent (channel );
@@ -93,7 +104,7 @@ public NetworkConnection(
93
104
94
105
@ Override
95
106
public boolean isOpen () {
96
- return status . get () == Status .OPEN ;
107
+ return executeWithLock ( lock , () -> status == Status .OPEN ) ;
97
108
}
98
109
99
110
@ Override
@@ -110,52 +121,31 @@ public void disableAutoRead() {
110
121
}
111
122
}
112
123
113
- @ Override
114
- public void flush () {
115
- if (verifyOpen (null , null )) {
116
- flushInEventLoop ();
117
- }
118
- }
119
-
120
124
@ Override
121
125
public void write (Message message , ResponseHandler handler ) {
122
- if (verifyOpen (handler , null )) {
126
+ if (verifyOpen (handler )) {
123
127
writeMessageInEventLoop (message , handler , false );
124
128
}
125
129
}
126
130
127
- @ Override
128
- public void write (Message message1 , ResponseHandler handler1 , Message message2 , ResponseHandler handler2 ) {
129
- if (verifyOpen (handler1 , handler2 )) {
130
- writeMessagesInEventLoop (message1 , handler1 , message2 , handler2 , false );
131
- }
132
- }
133
-
134
131
@ Override
135
132
public void writeAndFlush (Message message , ResponseHandler handler ) {
136
- if (verifyOpen (handler , null )) {
133
+ if (verifyOpen (handler )) {
137
134
writeMessageInEventLoop (message , handler , true );
138
135
}
139
136
}
140
137
141
138
@ Override
142
- public void writeAndFlush (Message message1 , ResponseHandler handler1 , Message message2 , ResponseHandler handler2 ) {
143
- if (verifyOpen (handler1 , handler2 )) {
144
- writeMessagesInEventLoop (message1 , handler1 , message2 , handler2 , true );
145
- }
146
- }
147
-
148
- @ Override
149
- public CompletionStage <Void > reset () {
150
- CompletableFuture <Void > result = new CompletableFuture <>();
151
- ResetResponseHandler handler = new ResetResponseHandler (messageDispatcher , result );
139
+ public CompletionStage <Void > reset (Throwable throwable ) {
140
+ var result = new CompletableFuture <Void >();
141
+ var handler = new ResetResponseHandler (messageDispatcher , result , throwable );
152
142
writeResetMessageIfNeeded (handler , true );
153
143
return result ;
154
144
}
155
145
156
146
@ Override
157
147
public CompletionStage <Void > release () {
158
- if (status . compareAndSet ( Status . OPEN , Status .RELEASED )) {
148
+ if (executeWithLock ( lock , () -> updateStateIfOpen ( Status .RELEASED ) )) {
159
149
ChannelReleasingResetResponseHandler handler = new ChannelReleasingResetResponseHandler (
160
150
channel , channelPool , messageDispatcher , clock , releaseFuture );
161
151
@@ -167,7 +157,7 @@ public CompletionStage<Void> release() {
167
157
168
158
@ Override
169
159
public void terminateAndRelease (String reason ) {
170
- if (status . compareAndSet ( Status . OPEN , Status .TERMINATED )) {
160
+ if (executeWithLock ( lock , () -> updateStateIfOpen ( Status .TERMINATED ) )) {
171
161
setTerminationReason (channel , reason );
172
162
asCompletionStage (channel .close ())
173
163
.exceptionally (throwable -> null )
@@ -194,6 +184,25 @@ public BoltProtocol protocol() {
194
184
return protocol ;
195
185
}
196
186
187
+ @ Override
188
+ public void bindTransaction (UnmanagedTransaction transaction ) {
189
+ executeWithLock (lock , () -> {
190
+ if (this .transaction != null ) {
191
+ throw new IllegalStateException ("transaction is already set" );
192
+ }
193
+ this .transaction = transaction ;
194
+ });
195
+ }
196
+
197
+ private boolean updateStateIfOpen (Status newStatus ) {
198
+ if (Status .OPEN .equals (status )) {
199
+ status = newStatus ;
200
+ return true ;
201
+ } else {
202
+ return false ;
203
+ }
204
+ }
205
+
197
206
private void writeResetMessageIfNeeded (ResponseHandler resetHandler , boolean isSessionReset ) {
198
207
channel .eventLoop ().execute (() -> {
199
208
if (isSessionReset && !isOpen ()) {
@@ -208,73 +217,49 @@ private void writeResetMessageIfNeeded(ResponseHandler resetHandler, boolean isS
208
217
});
209
218
}
210
219
211
- private void flushInEventLoop () {
212
- channel .eventLoop ().execute (() -> {
213
- channel .flush ();
214
- registerConnectionReadTimeout (channel );
215
- });
216
- }
217
-
218
220
private void writeMessageInEventLoop (Message message , ResponseHandler handler , boolean flush ) {
219
- channel .eventLoop ().execute (() -> {
220
- messageDispatcher .enqueue (handler );
221
-
222
- if (flush ) {
223
- channel .writeAndFlush (message ).addListener (future -> registerConnectionReadTimeout (channel ));
224
- } else {
225
- channel .write (message , channel .voidPromise ());
226
- }
227
- });
228
- }
229
-
230
- private void writeMessagesInEventLoop (
231
- Message message1 , ResponseHandler handler1 , Message message2 , ResponseHandler handler2 , boolean flush ) {
232
- channel .eventLoop ().execute (() -> {
233
- messageDispatcher .enqueue (handler1 );
234
- messageDispatcher .enqueue (handler2 );
235
-
236
- channel .write (message1 , channel .voidPromise ());
237
-
238
- if (flush ) {
239
- channel .writeAndFlush (message2 ).addListener (future -> registerConnectionReadTimeout (channel ));
240
- } else {
241
- channel .write (message2 , channel .voidPromise ());
242
- }
243
- });
221
+ channel .eventLoop ()
222
+ .execute (() -> transactionTerminationAwareExecutor (message ).accept (causeOfTermination -> {
223
+ if (causeOfTermination == null ) {
224
+ messageDispatcher .enqueue (handler );
225
+
226
+ if (flush ) {
227
+ channel .writeAndFlush (message )
228
+ .addListener (future -> registerConnectionReadTimeout (channel ));
229
+ } else {
230
+ channel .write (message , channel .voidPromise ());
231
+ }
232
+ } else {
233
+ handler .onFailure (causeOfTermination );
234
+ }
235
+ }));
244
236
}
245
237
246
238
private void setAutoRead (boolean value ) {
247
239
channel .config ().setAutoRead (value );
248
240
}
249
241
250
- private boolean verifyOpen (ResponseHandler handler1 , ResponseHandler handler2 ) {
251
- Status connectionStatus = this .status .get ();
252
- switch (connectionStatus ) {
253
- case OPEN :
254
- return true ;
255
- case RELEASED :
242
+ private boolean verifyOpen (ResponseHandler handler ) {
243
+ var connectionStatus = executeWithLock (lock , () -> status );
244
+ return switch (connectionStatus ) {
245
+ case OPEN -> true ;
246
+ case RELEASED -> {
256
247
Exception error =
257
248
new IllegalStateException ("Connection has been released to the pool and can't be used" );
258
- if (handler1 != null ) {
259
- handler1 .onFailure (error );
249
+ if (handler != null ) {
250
+ handler .onFailure (error );
260
251
}
261
- if (handler2 != null ) {
262
- handler2 .onFailure (error );
263
- }
264
- return false ;
265
- case TERMINATED :
252
+ yield false ;
253
+ }
254
+ case TERMINATED -> {
266
255
Exception terminatedError =
267
256
new IllegalStateException ("Connection has been terminated and can't be used" );
268
- if (handler1 != null ) {
269
- handler1 .onFailure (terminatedError );
270
- }
271
- if (handler2 != null ) {
272
- handler2 .onFailure (terminatedError );
257
+ if (handler != null ) {
258
+ handler .onFailure (terminatedError );
273
259
}
274
- return false ;
275
- default :
276
- throw new IllegalStateException ("Unknown status: " + connectionStatus );
277
- }
260
+ yield false ;
261
+ }
262
+ };
278
263
}
279
264
280
265
private void registerConnectionReadTimeout (Channel channel ) {
@@ -295,6 +280,25 @@ private void registerConnectionReadTimeout(Channel channel) {
295
280
}
296
281
}
297
282
283
+ private Consumer <Consumer <Throwable >> transactionTerminationAwareExecutor (Message message ) {
284
+ var result = (Consumer <Consumer <Throwable >>) consumer -> consumer .accept (null );
285
+ if (isQueryMessage (message )) {
286
+ var transaction = executeWithLock (lock , () -> this .transaction );
287
+ if (transaction != null ) {
288
+ result = transaction ::executeWithLockedState ;
289
+ }
290
+ }
291
+ return result ;
292
+ }
293
+
294
+ private boolean isQueryMessage (Message message ) {
295
+ return message instanceof RunWithMetadataMessage
296
+ || message instanceof PullMessage
297
+ || message instanceof PullAllMessage
298
+ || message instanceof DiscardMessage
299
+ || message instanceof DiscardAllMessage ;
300
+ }
301
+
298
302
private enum Status {
299
303
OPEN ,
300
304
RELEASED ,
0 commit comments