22
22
import static org .neo4j .driver .internal .async .connection .ChannelAttributes .poolId ;
23
23
import static org .neo4j .driver .internal .async .connection .ChannelAttributes .setTerminationReason ;
24
24
import static org .neo4j .driver .internal .util .Futures .asCompletionStage ;
25
+ import static org .neo4j .driver .internal .util .LockUtil .executeWithLock ;
25
26
26
27
import io .netty .channel .Channel ;
27
28
import io .netty .channel .ChannelHandler ;
28
29
import java .time .Clock ;
29
30
import java .util .concurrent .CompletableFuture ;
30
31
import java .util .concurrent .CompletionStage ;
31
32
import java .util .concurrent .TimeUnit ;
32
- import java .util .concurrent .atomic .AtomicReference ;
33
+ import java .util .concurrent .locks .Lock ;
34
+ import java .util .concurrent .locks .ReentrantLock ;
33
35
import org .neo4j .driver .Logger ;
34
36
import org .neo4j .driver .Logging ;
35
37
import org .neo4j .driver .internal .BoltServerAddress ;
41
43
import org .neo4j .driver .internal .handlers .ResetResponseHandler ;
42
44
import org .neo4j .driver .internal .messaging .BoltProtocol ;
43
45
import org .neo4j .driver .internal .messaging .Message ;
46
+ import org .neo4j .driver .internal .messaging .request .DiscardAllMessage ;
47
+ import org .neo4j .driver .internal .messaging .request .DiscardMessage ;
48
+ import org .neo4j .driver .internal .messaging .request .PullAllMessage ;
49
+ import org .neo4j .driver .internal .messaging .request .PullMessage ;
44
50
import org .neo4j .driver .internal .messaging .request .ResetMessage ;
51
+ import org .neo4j .driver .internal .messaging .request .RunWithMetadataMessage ;
45
52
import org .neo4j .driver .internal .metrics .ListenerEvent ;
46
53
import org .neo4j .driver .internal .metrics .MetricsListener ;
47
54
import org .neo4j .driver .internal .spi .Connection ;
53
60
*/
54
61
public class NetworkConnection implements Connection {
55
62
private final Logger log ;
63
+ private final Lock lock ;
56
64
private final Channel channel ;
57
65
private final InboundMessageDispatcher messageDispatcher ;
58
66
private final String serverAgent ;
@@ -61,12 +69,13 @@ public class NetworkConnection implements Connection {
61
69
private final ExtendedChannelPool channelPool ;
62
70
private final CompletableFuture <Void > releaseFuture ;
63
71
private final Clock clock ;
64
-
65
- private final AtomicReference <Status > status = new AtomicReference <>(Status .OPEN );
66
72
private final MetricsListener metricsListener ;
67
73
private final ListenerEvent <?> inUseEvent ;
68
74
69
75
private final Long connectionReadTimeout ;
76
+
77
+ private Status status = Status .OPEN ;
78
+ private TerminationAwareStateLockingExecutor terminationAwareStateLockingExecutor ;
70
79
private ChannelHandler connectionReadTimeoutHandler ;
71
80
72
81
public NetworkConnection (
@@ -76,6 +85,7 @@ public NetworkConnection(
76
85
MetricsListener metricsListener ,
77
86
Logging logging ) {
78
87
this .log = logging .getLog (getClass ());
88
+ this .lock = new ReentrantLock ();
79
89
this .channel = channel ;
80
90
this .messageDispatcher = ChannelAttributes .messageDispatcher (channel );
81
91
this .serverAgent = ChannelAttributes .serverAgent (channel );
@@ -93,7 +103,7 @@ public NetworkConnection(
93
103
94
104
@ Override
95
105
public boolean isOpen () {
96
- return status . get () == Status .OPEN ;
106
+ return executeWithLock ( lock , () -> status == Status .OPEN ) ;
97
107
}
98
108
99
109
@ Override
@@ -110,52 +120,31 @@ public void disableAutoRead() {
110
120
}
111
121
}
112
122
113
- @ Override
114
- public void flush () {
115
- if (verifyOpen (null , null )) {
116
- flushInEventLoop ();
117
- }
118
- }
119
-
120
123
@ Override
121
124
public void write (Message message , ResponseHandler handler ) {
122
- if (verifyOpen (handler , null )) {
125
+ if (verifyOpen (handler )) {
123
126
writeMessageInEventLoop (message , handler , false );
124
127
}
125
128
}
126
129
127
- @ Override
128
- public void write (Message message1 , ResponseHandler handler1 , Message message2 , ResponseHandler handler2 ) {
129
- if (verifyOpen (handler1 , handler2 )) {
130
- writeMessagesInEventLoop (message1 , handler1 , message2 , handler2 , false );
131
- }
132
- }
133
-
134
130
@ Override
135
131
public void writeAndFlush (Message message , ResponseHandler handler ) {
136
- if (verifyOpen (handler , null )) {
132
+ if (verifyOpen (handler )) {
137
133
writeMessageInEventLoop (message , handler , true );
138
134
}
139
135
}
140
136
141
137
@ Override
142
- public void writeAndFlush (Message message1 , ResponseHandler handler1 , Message message2 , ResponseHandler handler2 ) {
143
- if (verifyOpen (handler1 , handler2 )) {
144
- writeMessagesInEventLoop (message1 , handler1 , message2 , handler2 , true );
145
- }
146
- }
147
-
148
- @ Override
149
- public CompletionStage <Void > reset () {
150
- CompletableFuture <Void > result = new CompletableFuture <>();
151
- ResetResponseHandler handler = new ResetResponseHandler (messageDispatcher , result );
138
+ public CompletionStage <Void > reset (Throwable throwable ) {
139
+ var result = new CompletableFuture <Void >();
140
+ var handler = new ResetResponseHandler (messageDispatcher , result , throwable );
152
141
writeResetMessageIfNeeded (handler , true );
153
142
return result ;
154
143
}
155
144
156
145
@ Override
157
146
public CompletionStage <Void > release () {
158
- if (status . compareAndSet ( Status . OPEN , Status .RELEASED )) {
147
+ if (executeWithLock ( lock , () -> updateStateIfOpen ( Status .RELEASED ) )) {
159
148
ChannelReleasingResetResponseHandler handler = new ChannelReleasingResetResponseHandler (
160
149
channel , channelPool , messageDispatcher , clock , releaseFuture );
161
150
@@ -167,7 +156,7 @@ public CompletionStage<Void> release() {
167
156
168
157
@ Override
169
158
public void terminateAndRelease (String reason ) {
170
- if (status . compareAndSet ( Status . OPEN , Status .TERMINATED )) {
159
+ if (executeWithLock ( lock , () -> updateStateIfOpen ( Status .TERMINATED ) )) {
171
160
setTerminationReason (channel , reason );
172
161
asCompletionStage (channel .close ())
173
162
.exceptionally (throwable -> null )
@@ -194,6 +183,25 @@ public BoltProtocol protocol() {
194
183
return protocol ;
195
184
}
196
185
186
+ @ Override
187
+ public void bindTerminationAwareStateLockingExecutor (TerminationAwareStateLockingExecutor executor ) {
188
+ executeWithLock (lock , () -> {
189
+ if (this .terminationAwareStateLockingExecutor != null ) {
190
+ throw new IllegalStateException ("terminationAwareStateLockingExecutor is already set" );
191
+ }
192
+ this .terminationAwareStateLockingExecutor = executor ;
193
+ });
194
+ }
195
+
196
+ private boolean updateStateIfOpen (Status newStatus ) {
197
+ if (Status .OPEN .equals (status )) {
198
+ status = newStatus ;
199
+ return true ;
200
+ } else {
201
+ return false ;
202
+ }
203
+ }
204
+
197
205
private void writeResetMessageIfNeeded (ResponseHandler resetHandler , boolean isSessionReset ) {
198
206
channel .eventLoop ().execute (() -> {
199
207
if (isSessionReset && !isOpen ()) {
@@ -208,73 +216,49 @@ private void writeResetMessageIfNeeded(ResponseHandler resetHandler, boolean isS
208
216
});
209
217
}
210
218
211
- private void flushInEventLoop () {
212
- channel .eventLoop ().execute (() -> {
213
- channel .flush ();
214
- registerConnectionReadTimeout (channel );
215
- });
216
- }
217
-
218
219
private void writeMessageInEventLoop (Message message , ResponseHandler handler , boolean flush ) {
219
- channel .eventLoop ().execute (() -> {
220
- messageDispatcher .enqueue (handler );
221
-
222
- if (flush ) {
223
- channel .writeAndFlush (message ).addListener (future -> registerConnectionReadTimeout (channel ));
224
- } else {
225
- channel .write (message , channel .voidPromise ());
226
- }
227
- });
228
- }
229
-
230
- private void writeMessagesInEventLoop (
231
- Message message1 , ResponseHandler handler1 , Message message2 , ResponseHandler handler2 , boolean flush ) {
232
- channel .eventLoop ().execute (() -> {
233
- messageDispatcher .enqueue (handler1 );
234
- messageDispatcher .enqueue (handler2 );
235
-
236
- channel .write (message1 , channel .voidPromise ());
237
-
238
- if (flush ) {
239
- channel .writeAndFlush (message2 ).addListener (future -> registerConnectionReadTimeout (channel ));
240
- } else {
241
- channel .write (message2 , channel .voidPromise ());
242
- }
243
- });
220
+ channel .eventLoop ()
221
+ .execute (() -> terminationAwareStateLockingExecutor (message ).execute (causeOfTermination -> {
222
+ if (causeOfTermination == null ) {
223
+ messageDispatcher .enqueue (handler );
224
+
225
+ if (flush ) {
226
+ channel .writeAndFlush (message )
227
+ .addListener (future -> registerConnectionReadTimeout (channel ));
228
+ } else {
229
+ channel .write (message , channel .voidPromise ());
230
+ }
231
+ } else {
232
+ handler .onFailure (causeOfTermination );
233
+ }
234
+ }));
244
235
}
245
236
246
237
private void setAutoRead (boolean value ) {
247
238
channel .config ().setAutoRead (value );
248
239
}
249
240
250
- private boolean verifyOpen (ResponseHandler handler1 , ResponseHandler handler2 ) {
251
- Status connectionStatus = this .status .get ();
252
- switch (connectionStatus ) {
253
- case OPEN :
254
- return true ;
255
- case RELEASED :
241
+ private boolean verifyOpen (ResponseHandler handler ) {
242
+ var connectionStatus = executeWithLock (lock , () -> status );
243
+ return switch (connectionStatus ) {
244
+ case OPEN -> true ;
245
+ case RELEASED -> {
256
246
Exception error =
257
247
new IllegalStateException ("Connection has been released to the pool and can't be used" );
258
- if (handler1 != null ) {
259
- handler1 .onFailure (error );
248
+ if (handler != null ) {
249
+ handler .onFailure (error );
260
250
}
261
- if (handler2 != null ) {
262
- handler2 .onFailure (error );
263
- }
264
- return false ;
265
- case TERMINATED :
251
+ yield false ;
252
+ }
253
+ case TERMINATED -> {
266
254
Exception terminatedError =
267
255
new IllegalStateException ("Connection has been terminated and can't be used" );
268
- if (handler1 != null ) {
269
- handler1 .onFailure (terminatedError );
270
- }
271
- if (handler2 != null ) {
272
- handler2 .onFailure (terminatedError );
256
+ if (handler != null ) {
257
+ handler .onFailure (terminatedError );
273
258
}
274
- return false ;
275
- default :
276
- throw new IllegalStateException ("Unknown status: " + connectionStatus );
277
- }
259
+ yield false ;
260
+ }
261
+ };
278
262
}
279
263
280
264
private void registerConnectionReadTimeout (Channel channel ) {
@@ -295,6 +279,25 @@ private void registerConnectionReadTimeout(Channel channel) {
295
279
}
296
280
}
297
281
282
+ private TerminationAwareStateLockingExecutor terminationAwareStateLockingExecutor (Message message ) {
283
+ var result = (TerminationAwareStateLockingExecutor ) consumer -> consumer .accept (null );
284
+ if (isQueryMessage (message )) {
285
+ var lockingExecutor = executeWithLock (lock , () -> this .terminationAwareStateLockingExecutor );
286
+ if (lockingExecutor != null ) {
287
+ result = lockingExecutor ;
288
+ }
289
+ }
290
+ return result ;
291
+ }
292
+
293
+ private boolean isQueryMessage (Message message ) {
294
+ return message instanceof RunWithMetadataMessage
295
+ || message instanceof PullMessage
296
+ || message instanceof PullAllMessage
297
+ || message instanceof DiscardMessage
298
+ || message instanceof DiscardAllMessage ;
299
+ }
300
+
298
301
private enum Status {
299
302
OPEN ,
300
303
RELEASED ,
0 commit comments