Skip to content

Commit 543e2d4

Browse files
committed
Update
1 parent 3e35174 commit 543e2d4

File tree

1 file changed

+156
-125
lines changed

1 file changed

+156
-125
lines changed

driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java

Lines changed: 156 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import static org.neo4j.driver.internal.util.Futures.completionExceptionCause;
2222

2323
import java.util.Collections;
24+
import java.util.HashMap;
2425
import java.util.HashSet;
2526
import java.util.Map;
2627
import java.util.Objects;
@@ -30,7 +31,6 @@
3031
import java.util.concurrent.CompletionStage;
3132
import java.util.concurrent.TimeoutException;
3233
import java.util.concurrent.atomic.AtomicBoolean;
33-
import java.util.concurrent.atomic.AtomicReference;
3434
import java.util.function.Consumer;
3535
import java.util.function.Function;
3636
import java.util.function.Supplier;
@@ -63,6 +63,7 @@
6363
import org.neo4j.driver.internal.bolt.api.DatabaseName;
6464
import org.neo4j.driver.internal.bolt.api.DatabaseNameUtil;
6565
import org.neo4j.driver.internal.bolt.api.NotificationConfig;
66+
import org.neo4j.driver.internal.bolt.api.SecurityPlan;
6667
import org.neo4j.driver.internal.bolt.api.TelemetryApi;
6768
import org.neo4j.driver.internal.bolt.api.exception.MinVersionAcquisitionException;
6869
import org.neo4j.driver.internal.bolt.api.summary.RunSummary;
@@ -413,22 +414,126 @@ protected CompletionStage<Boolean> currentConnectionIsOpen() {
413414
connection.isOpen()); // and it's still open
414415
}
415416

416-
private org.neo4j.driver.internal.bolt.api.AccessMode asBoltAccessMode(AccessMode mode) {
417-
return switch (mode) {
418-
case WRITE -> org.neo4j.driver.internal.bolt.api.AccessMode.WRITE;
419-
case READ -> org.neo4j.driver.internal.bolt.api.AccessMode.READ;
420-
};
421-
}
422-
423417
private void handleDatabaseName(String name) {
424418
connectionContext.databaseNameFuture.complete(DatabaseNameUtil.database(name));
425419
homeDatabaseCache.put(homeDatabaseKey, name);
426420
}
427421

428422
private CompletionStage<BoltConnectionWithCloseTracking> acquireConnection(AccessMode mode) {
429-
var currentConnectionStage = connectionStage;
423+
var overrideAuthToken = connectionContext.overrideAuthToken();
424+
var authTokenManager = overrideAuthToken != null ? NoopAuthTokenManager.INSTANCE : this.authTokenManager;
425+
var newConnectionStage = pulledResultCursorStage(connectionStage)
426+
.thenCompose(ignored -> securityPlanManager.plan())
427+
.thenCompose(securityPlan -> acquireConnection(securityPlan, mode)
428+
.thenApply(connection -> (DriverBoltConnection)
429+
new BoltConnectionWithAuthTokenManager(connection, authTokenManager))
430+
.thenApply(BoltConnectionWithCloseTracking::new)
431+
.exceptionally(this::mapAcquisitionError));
432+
connectionStage = newConnectionStage.exceptionally(error -> null);
433+
return newConnectionStage;
434+
}
435+
436+
private BoltConnectionWithCloseTracking mapAcquisitionError(Throwable throwable) {
437+
throwable = Futures.completionExceptionCause(throwable);
438+
if (throwable instanceof TimeoutException) {
439+
throw new ClientException(
440+
GqlStatusError.UNKNOWN.getStatus(),
441+
GqlStatusError.UNKNOWN.getStatusDescription(throwable.getMessage()),
442+
"N/A",
443+
throwable.getMessage(),
444+
GqlStatusError.DIAGNOSTIC_RECORD,
445+
throwable);
446+
}
447+
if (throwable instanceof MinVersionAcquisitionException minVersionAcquisitionException) {
448+
if (connectionContext.impersonatedUser == null && connectionContext.impersonatedUser() != null) {
449+
var message =
450+
"Detected connection that does not support impersonation, please make sure to have all servers running 4.4 version or above and communicating"
451+
+ " over Bolt version 4.4 or above when using impersonation feature";
452+
throw new ClientException(
453+
GqlStatusError.UNKNOWN.getStatus(),
454+
GqlStatusError.UNKNOWN.getStatusDescription(message),
455+
"N/A",
456+
message,
457+
GqlStatusError.DIAGNOSTIC_RECORD,
458+
null);
459+
} else {
460+
throw new CompletionException(new UnsupportedFeatureException(String.format(
461+
"Detected Bolt %s connection that does not support the auth token override feature, please make sure to have all servers communicating over Bolt 5.1 or above to use the feature",
462+
minVersionAcquisitionException.version())));
463+
}
464+
} else {
465+
throw new CompletionException(throwable);
466+
}
467+
}
468+
469+
private CompletionStage<DriverBoltConnection> acquireConnection(SecurityPlan securityPlan, AccessMode mode) {
470+
var databaseName = connectionContext.databaseNameFuture().getNow(null);
471+
var impersonatedUser = connectionContext.impersonatedUser();
472+
var minVersion = minBoltVersion(connectionContext);
473+
var overrideAuthToken = connectionContext.overrideAuthToken();
474+
var tokenStageSupplier = tokenStageSupplier(overrideAuthToken, authTokenManager);
475+
var accessMode = asBoltAccessMode(mode);
476+
var bookmarks = connectionContext.rediscoveryBookmarks().stream()
477+
.map(Bookmark::value)
478+
.collect(Collectors.toSet());
479+
480+
Map<String, Object> additionalParameters = new HashMap<>();
481+
if (databaseName == null) {
482+
homeDatabaseCache.get(homeDatabaseKey).ifPresent(name -> additionalParameters.put("homeDatabase", name));
483+
}
484+
485+
Consumer<DatabaseName> databaseNameConsumer = (name) -> {
486+
if (name != null) {
487+
if (databaseName == null) {
488+
name.databaseName().ifPresent(n -> homeDatabaseCache.put(homeDatabaseKey, n));
489+
}
490+
} else {
491+
name = DatabaseNameUtil.defaultDatabase();
492+
}
493+
connectionContext.databaseNameFuture().complete(name);
494+
};
495+
496+
return boltConnectionProvider
497+
.connect(
498+
securityPlan,
499+
databaseName,
500+
tokenStageSupplier,
501+
accessMode,
502+
bookmarks,
503+
impersonatedUser,
504+
minVersion,
505+
driverNotificationConfig,
506+
databaseNameConsumer,
507+
additionalParameters)
508+
.thenCompose(boltConnection -> {
509+
if (databaseName == null
510+
&& additionalParameters.containsKey("homeDatabase")
511+
&& !boltConnection.serverSideRoutingEnabled()
512+
&& !connectionContext.databaseNameFuture.isDone()) {
513+
// home database was requested with hint, but the returned connection does not have SSR enabled
514+
additionalParameters.remove("homeDatabase");
515+
return boltConnection
516+
.close()
517+
.thenCompose(ignored -> boltConnectionProvider.connect(
518+
securityPlan,
519+
null,
520+
tokenStageSupplier,
521+
accessMode,
522+
bookmarks,
523+
impersonatedUser,
524+
minVersion,
525+
driverNotificationConfig,
526+
databaseNameConsumer,
527+
additionalParameters));
528+
} else {
529+
return CompletableFuture.completedStage(boltConnection);
530+
}
531+
});
532+
}
430533

431-
var newConnectionStage = resultCursorStage
534+
private CompletionStage<Void> pulledResultCursorStage(
535+
CompletionStage<BoltConnectionWithCloseTracking> connectionStage) {
536+
return resultCursorStage
432537
.thenCompose(cursor -> {
433538
if (cursor == null) {
434539
return completedWithNull();
@@ -443,126 +548,13 @@ private CompletionStage<BoltConnectionWithCloseTracking> acquireConnection(Acces
443548
// 2) previous result has been successful and is fully consumed
444549
// 3) previous result failed and error has been consumed
445550

446-
// return existing connection, which should've been released back to the pool by now
447-
return currentConnectionStage.exceptionally(ignore -> null);
551+
// the existing connection should've been released back to the pool by now
552+
return connectionStage.handle((ignored, throwable) -> null);
448553
} else {
449554
// there exists unconsumed error, re-throw it
450555
throw new CompletionException(error);
451556
}
452-
})
453-
.thenCompose(ignored -> {
454-
var databaseName = connectionContext.databaseNameFuture.getNow(null);
455-
456-
Supplier<CompletionStage<Map<String, Value>>> tokenStageSupplier;
457-
var minVersion = new AtomicReference<BoltProtocolVersion>();
458-
if (connectionContext.impersonatedUser() != null) {
459-
minVersion.set(new BoltProtocolVersion(4, 4));
460-
}
461-
var overrideAuthToken = connectionContext.overrideAuthToken();
462-
if (overrideAuthToken != null) {
463-
tokenStageSupplier = () -> CompletableFuture.completedStage(connectionContext.authToken)
464-
.thenApply(token -> ((InternalAuthToken) token).toMap());
465-
minVersion.set(new BoltProtocolVersion(5, 1));
466-
} else {
467-
tokenStageSupplier = () ->
468-
authTokenManager.getToken().thenApply(token -> ((InternalAuthToken) token).toMap());
469-
}
470-
return securityPlanManager.plan().thenCompose(securityPlan -> {
471-
;
472-
var additionalParams = homeDatabaseCache
473-
.get(homeDatabaseKey)
474-
.map(name -> Map.<String, Object>of("homeDatabase", name))
475-
.orElse(Collections.emptyMap());
476-
return boltConnectionProvider
477-
.connect(
478-
securityPlan,
479-
databaseName,
480-
tokenStageSupplier,
481-
switch (mode) {
482-
case WRITE -> org.neo4j.driver.internal.bolt.api.AccessMode.WRITE;
483-
case READ -> org.neo4j.driver.internal.bolt.api.AccessMode.READ;
484-
},
485-
connectionContext.rediscoveryBookmarks().stream()
486-
.map(Bookmark::value)
487-
.collect(Collectors.toSet()),
488-
connectionContext.impersonatedUser(),
489-
minVersion.get(),
490-
driverNotificationConfig,
491-
(name) -> {
492-
connectionContext
493-
.databaseNameFuture()
494-
.complete(name == null ? DatabaseNameUtil.defaultDatabase() : name);
495-
},
496-
additionalParams)
497-
.thenApply(connection -> {
498-
if (connection.serverSideRoutingEnabled()) {
499-
if (databaseName == null) {
500-
// home database was requested
501-
connectionContext
502-
.databaseNameFuture()
503-
.getNow(DatabaseNameUtil.defaultDatabase())
504-
.databaseName()
505-
.ifPresent(name -> homeDatabaseCache.put(homeDatabaseKey, name));
506-
}
507-
}
508-
return connection;
509-
})
510-
.thenApply(connection -> (DriverBoltConnection) new BoltConnectionWithAuthTokenManager(
511-
connection,
512-
overrideAuthToken != null
513-
? new AuthTokenManager() {
514-
@Override
515-
public CompletionStage<AuthToken> getToken() {
516-
return null;
517-
}
518-
519-
@Override
520-
public boolean handleSecurityException(
521-
AuthToken authToken, SecurityException exception) {
522-
return false;
523-
}
524-
}
525-
: authTokenManager))
526-
.thenApply(BoltConnectionWithCloseTracking::new)
527-
.exceptionally(throwable -> {
528-
throwable = Futures.completionExceptionCause(throwable);
529-
if (throwable instanceof TimeoutException) {
530-
throw new ClientException(
531-
GqlStatusError.UNKNOWN.getStatus(),
532-
GqlStatusError.UNKNOWN.getStatusDescription(throwable.getMessage()),
533-
"N/A",
534-
throwable.getMessage(),
535-
GqlStatusError.DIAGNOSTIC_RECORD,
536-
throwable);
537-
}
538-
if (throwable
539-
instanceof MinVersionAcquisitionException minVersionAcquisitionException) {
540-
if (overrideAuthToken == null && connectionContext.impersonatedUser() != null) {
541-
var message =
542-
"Detected connection that does not support impersonation, please make sure to have all servers running 4.4 version or above and communicating"
543-
+ " over Bolt version 4.4 or above when using impersonation feature";
544-
throw new ClientException(
545-
GqlStatusError.UNKNOWN.getStatus(),
546-
GqlStatusError.UNKNOWN.getStatusDescription(message),
547-
"N/A",
548-
message,
549-
GqlStatusError.DIAGNOSTIC_RECORD,
550-
null);
551-
} else {
552-
throw new CompletionException(new UnsupportedFeatureException(String.format(
553-
"Detected Bolt %s connection that does not support the auth token override feature, please make sure to have all servers communicating over Bolt 5.1 or above to use the feature",
554-
minVersionAcquisitionException.version())));
555-
}
556-
} else {
557-
throw new CompletionException(throwable);
558-
}
559-
});
560-
});
561557
});
562-
563-
connectionStage = newConnectionStage.exceptionally(error -> null);
564-
565-
return newConnectionStage;
566558
}
567559

568560
private CompletionStage<Throwable> closeTransactionAndReleaseConnection() {
@@ -645,6 +637,31 @@ private void assertDatabaseNameFutureIsDone() {
645637
}
646638
}
647639

640+
private static BoltProtocolVersion minBoltVersion(NetworkSessionConnectionContext connectionContext) {
641+
BoltProtocolVersion minBoltVersion = null;
642+
if (connectionContext.overrideAuthToken() != null) {
643+
minBoltVersion = new BoltProtocolVersion(5, 1);
644+
} else if (connectionContext.impersonatedUser() != null) {
645+
minBoltVersion = new BoltProtocolVersion(4, 4);
646+
}
647+
return minBoltVersion;
648+
}
649+
650+
private static Supplier<CompletionStage<Map<String, Value>>> tokenStageSupplier(
651+
AuthToken overrideAuthToken, AuthTokenManager authTokenManager) {
652+
return overrideAuthToken != null
653+
? () -> CompletableFuture.completedStage(overrideAuthToken)
654+
.thenApply(token -> ((InternalAuthToken) token).toMap())
655+
: () -> authTokenManager.getToken().thenApply(token -> ((InternalAuthToken) token).toMap());
656+
}
657+
658+
private static org.neo4j.driver.internal.bolt.api.AccessMode asBoltAccessMode(AccessMode mode) {
659+
return switch (mode) {
660+
case WRITE -> org.neo4j.driver.internal.bolt.api.AccessMode.WRITE;
661+
case READ -> org.neo4j.driver.internal.bolt.api.AccessMode.READ;
662+
};
663+
}
664+
648665
/**
649666
* The {@link NetworkSessionConnectionContext#mode} can be mutable for a session connection context
650667
*/
@@ -763,4 +780,18 @@ public void onComplete() {
763780
}
764781
}
765782
}
783+
784+
private static final class NoopAuthTokenManager implements AuthTokenManager {
785+
static final NoopAuthTokenManager INSTANCE = new NoopAuthTokenManager();
786+
787+
@Override
788+
public CompletionStage<AuthToken> getToken() {
789+
return null;
790+
}
791+
792+
@Override
793+
public boolean handleSecurityException(AuthToken authToken, SecurityException exception) {
794+
return false;
795+
}
796+
}
766797
}

0 commit comments

Comments
 (0)