Skip to content

Add AuthToken to Bolt layer #1609

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.neo4j.driver.internal.bolt.api.AccessMode;
import org.neo4j.driver.internal.bolt.api.AuthToken;
import org.neo4j.driver.internal.bolt.api.BoltAgent;
import org.neo4j.driver.internal.bolt.api.BoltConnection;
import org.neo4j.driver.internal.bolt.api.BoltConnectionProvider;
Expand All @@ -44,7 +45,6 @@
import org.neo4j.driver.internal.bolt.api.RoutingContext;
import org.neo4j.driver.internal.bolt.api.SecurityPlan;
import org.neo4j.driver.internal.bolt.api.exception.MinVersionAcquisitionException;
import org.neo4j.driver.internal.bolt.api.values.Value;
import org.neo4j.driver.internal.bolt.api.values.ValueFactory;
import org.neo4j.driver.internal.bolt.basicimpl.impl.BoltConnectionImpl;
import org.neo4j.driver.internal.bolt.basicimpl.impl.ConnectionProvider;
Expand Down Expand Up @@ -110,7 +110,7 @@ public CompletionStage<Void> init(
public CompletionStage<BoltConnection> connect(
SecurityPlan securityPlan,
DatabaseName databaseName,
Supplier<CompletionStage<Map<String, Value>>> authMapStageSupplier,
Supplier<CompletionStage<AuthToken>> authTokenStageSupplier,
AccessMode mode,
Set<String> bookmarks,
String impersonatedUser,
Expand All @@ -125,17 +125,17 @@ public CompletionStage<BoltConnection> connect(
}

var latestAuthMillisFuture = new CompletableFuture<Long>();
var authMapRef = new AtomicReference<Map<String, Value>>();
return authMapStageSupplier
var authMapRef = new AtomicReference<AuthToken>();
return authTokenStageSupplier
.get()
.thenCompose(authMap -> {
authMapRef.set(authMap);
.thenCompose(authToken -> {
authMapRef.set(authToken);
return this.connectionProvider.acquireConnection(
address,
securityPlan,
routingContext,
databaseName != null ? databaseName.databaseName().orElse(null) : null,
authMap,
authToken.asMap(),
boltAgent,
userAgent,
mode,
Expand Down Expand Up @@ -180,11 +180,11 @@ public CompletionStage<BoltConnection> connect(
}

@Override
public CompletionStage<Void> verifyConnectivity(SecurityPlan securityPlan, Map<String, Value> authMap) {
public CompletionStage<Void> verifyConnectivity(SecurityPlan securityPlan, AuthToken authToken) {
return connect(
securityPlan,
null,
() -> CompletableFuture.completedStage(authMap),
() -> CompletableFuture.completedStage(authToken),
AccessMode.WRITE,
Collections.emptySet(),
null,
Expand All @@ -196,11 +196,11 @@ public CompletionStage<Void> verifyConnectivity(SecurityPlan securityPlan, Map<S
}

@Override
public CompletionStage<Boolean> supportsMultiDb(SecurityPlan securityPlan, Map<String, Value> authMap) {
public CompletionStage<Boolean> supportsMultiDb(SecurityPlan securityPlan, AuthToken authToken) {
return connect(
securityPlan,
null,
() -> CompletableFuture.completedStage(authMap),
() -> CompletableFuture.completedStage(authToken),
AccessMode.WRITE,
Collections.emptySet(),
null,
Expand All @@ -215,11 +215,11 @@ public CompletionStage<Boolean> supportsMultiDb(SecurityPlan securityPlan, Map<S
}

@Override
public CompletionStage<Boolean> supportsSessionAuth(SecurityPlan securityPlan, Map<String, Value> authMap) {
public CompletionStage<Boolean> supportsSessionAuth(SecurityPlan securityPlan, AuthToken authToken) {
return connect(
securityPlan,
null,
() -> CompletableFuture.completedStage(authMap),
() -> CompletableFuture.completedStage(authToken),
AccessMode.WRITE,
Collections.emptySet(),
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
import java.util.function.Function;
import java.util.stream.Collectors;
import org.neo4j.driver.internal.bolt.api.AccessMode;
import org.neo4j.driver.internal.bolt.api.AuthData;
import org.neo4j.driver.internal.bolt.api.AuthInfo;
import org.neo4j.driver.internal.bolt.api.AuthToken;
import org.neo4j.driver.internal.bolt.api.BoltConnection;
import org.neo4j.driver.internal.bolt.api.BoltConnectionState;
import org.neo4j.driver.internal.bolt.api.BoltProtocolVersion;
Expand Down Expand Up @@ -82,7 +83,7 @@ public final class BoltConnectionImpl implements BoltConnection {
private final boolean telemetrySupported;
private final boolean serverSideRouting;
private final AtomicReference<BoltConnectionState> stateRef = new AtomicReference<>(BoltConnectionState.OPEN);
private final AtomicReference<CompletableFuture<AuthData>> authDataRef;
private final AtomicReference<CompletableFuture<AuthInfo>> authDataRef;
private final Map<String, Value> routingContext;
private final Queue<Function<ResponseHandler, CompletionStage<Void>>> messageWriters;
private final Clock clock;
Expand All @@ -92,7 +93,7 @@ public BoltConnectionImpl(
BoltProtocol protocol,
Connection connection,
EventLoop eventLoop,
Map<String, Value> authMap,
AuthToken authToken,
CompletableFuture<Long> latestAuthMillisFuture,
RoutingContext routingContext,
Clock clock,
Expand All @@ -107,7 +108,7 @@ public BoltConnectionImpl(
this.telemetrySupported = connection.isTelemetryEnabled();
this.serverSideRouting = connection.isSsrEnabled();
this.authDataRef = new AtomicReference<>(
CompletableFuture.completedFuture(new AuthDataImpl(authMap, latestAuthMillisFuture.join())));
CompletableFuture.completedFuture(new AuthInfoImpl(authToken, latestAuthMillisFuture.join())));
this.valueFactory = Objects.requireNonNull(valueFactory);
this.routingContext = routingContext.toMap().entrySet().stream()
.collect(Collectors.toUnmodifiableMap(
Expand Down Expand Up @@ -369,10 +370,10 @@ public void onSummary(Void summary) {
}

@Override
public CompletionStage<BoltConnection> logon(Map<String, Value> authMap) {
public CompletionStage<BoltConnection> logon(AuthToken authToken) {
return executeInEventLoop(() -> messageWriters.add(handler -> protocol.logon(
connection,
authMap,
authToken.asMap(),
clock,
new MessageHandler<>() {
@Override
Expand All @@ -383,7 +384,7 @@ public void onError(Throwable throwable) {

@Override
public void onSummary(Void summary) {
authDataRef.get().complete(new AuthDataImpl(authMap, clock.millis()));
authDataRef.get().complete(new AuthInfoImpl(authToken, clock.millis()));
handler.onLogonSummary(null);
}
},
Expand Down Expand Up @@ -498,7 +499,7 @@ public BoltConnectionState state() {
}

@Override
public CompletionStage<AuthData> authData() {
public CompletionStage<AuthInfo> authInfo() {
return authDataRef.get();
}

Expand Down Expand Up @@ -572,7 +573,7 @@ private void updateState(Throwable throwable) {
}
}

private record AuthDataImpl(Map<String, Value> authMap, long authAckMillis) implements AuthData {}
private record AuthInfoImpl(AuthToken authToken, long authAckMillis) implements AuthInfo {}

private static class ResponseHandleImpl implements ResponseHandler {
private final ResponseHandler delegate;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.function.Function;
import java.util.function.Supplier;
import org.neo4j.driver.internal.bolt.api.AccessMode;
import org.neo4j.driver.internal.bolt.api.AuthToken;
import org.neo4j.driver.internal.bolt.api.BasicResponseHandler;
import org.neo4j.driver.internal.bolt.api.BoltAgent;
import org.neo4j.driver.internal.bolt.api.BoltConnection;
Expand All @@ -51,7 +52,6 @@
import org.neo4j.driver.internal.bolt.api.SecurityPlan;
import org.neo4j.driver.internal.bolt.api.exception.BoltTransientException;
import org.neo4j.driver.internal.bolt.api.exception.MinVersionAcquisitionException;
import org.neo4j.driver.internal.bolt.api.values.Value;
import org.neo4j.driver.internal.bolt.pooledimpl.impl.PooledBoltConnection;
import org.neo4j.driver.internal.bolt.pooledimpl.impl.util.FutureUtil;

Expand Down Expand Up @@ -129,7 +129,7 @@ public CompletionStage<Void> init(
public CompletionStage<BoltConnection> connect(
SecurityPlan securityPlan,
DatabaseName databaseName,
Supplier<CompletionStage<Map<String, Value>>> authMapStageSupplier,
Supplier<CompletionStage<AuthToken>> authTokenStageSupplier,
AccessMode mode,
Set<String> bookmarks,
String impersonatedUser,
Expand All @@ -145,7 +145,7 @@ public CompletionStage<BoltConnection> connect(

var acquisitionFuture = new CompletableFuture<PooledBoltConnection>();

authMapStageSupplier.get().whenComplete((authMap, authThrowable) -> {
authTokenStageSupplier.get().whenComplete((authToken, authThrowable) -> {
if (authThrowable != null) {
acquisitionFuture.completeExceptionally(authThrowable);
return;
Expand All @@ -168,8 +168,8 @@ public CompletionStage<BoltConnection> connect(
acquisitionFuture,
securityPlan,
databaseName,
authMap,
authMapStageSupplier,
authToken,
authTokenStageSupplier,
mode,
bookmarks,
impersonatedUser,
Expand All @@ -191,8 +191,8 @@ private void connect(
CompletableFuture<PooledBoltConnection> acquisitionFuture,
SecurityPlan securityPlan,
DatabaseName databaseName,
Map<String, Value> authMap,
Supplier<CompletionStage<Map<String, Value>>> authMapStageSupplier,
AuthToken authToken,
Supplier<CompletionStage<AuthToken>> authTokenStageSupplier,
AccessMode mode,
Set<String> bookmarks,
String impersonatedUser,
Expand All @@ -207,7 +207,7 @@ private void connect(
empty.set(pooledConnectionEntries.isEmpty());
try {
// go over existing entries first
connectionEntryWithMetadata = acquireExistingEntry(authMap, minVersion);
connectionEntryWithMetadata = acquireExistingEntry(authToken, minVersion);
} catch (MinVersionAcquisitionException e) {
acquisitionFuture.completeExceptionally(e);
return;
Expand Down Expand Up @@ -284,8 +284,8 @@ private void connect(
acquisitionFuture,
securityPlan,
databaseName,
authMap,
authMapStageSupplier,
authToken,
authTokenStageSupplier,
mode,
bookmarks,
impersonatedUser,
Expand All @@ -305,7 +305,7 @@ private void connect(
purge(entry);
metricsListener.afterConnectionReleased(poolId, inUseEvent);
});
reauthStage(entryWithMetadata, authMap).whenComplete((ignored2, throwable2) -> {
reauthStage(entryWithMetadata, authToken).whenComplete((ignored2, throwable2) -> {
if (!acquisitionFuture.complete(pooledConnection)) {
// acquisition timed out
CompletableFuture<PooledBoltConnection> pendingAcquisition;
Expand Down Expand Up @@ -336,7 +336,9 @@ private void connect(
.connect(
securityPlan,
databaseName,
empty.get() ? () -> CompletableFuture.completedStage(authMap) : authMapStageSupplier,
empty.get()
? () -> CompletableFuture.completedStage(authToken)
: authTokenStageSupplier,
mode,
bookmarks,
impersonatedUser,
Expand Down Expand Up @@ -395,7 +397,7 @@ private void connect(
}

private synchronized ConnectionEntryWithMetadata acquireExistingEntry(
Map<String, Value> authMap, BoltProtocolVersion minVersion) {
AuthToken authToken, BoltProtocolVersion minVersion) {
ConnectionEntryWithMetadata connectionEntryWithMetadata = null;
var iterator = pooledConnectionEntries.iterator();
while (iterator.hasNext()) {
Expand Down Expand Up @@ -431,10 +433,10 @@ private synchronized ConnectionEntryWithMetadata acquireExistingEntry(
}

// the pool must not have unauthenticated connections
var authData = connection.authData().toCompletableFuture().getNow(null);
var authInfo = connection.authInfo().toCompletableFuture().getNow(null);

var expiredByError = minAuthTimestamp > 0 && authData.authAckMillis() <= minAuthTimestamp;
var authMatches = authMap.equals(authData.authMap());
var expiredByError = minAuthTimestamp > 0 && authInfo.authAckMillis() <= minAuthTimestamp;
var authMatches = authToken.equals(authInfo.authToken());
var reauthNeeded = expiredByError || !authMatches;

if (reauthNeeded) {
Expand All @@ -461,14 +463,14 @@ private synchronized ConnectionEntryWithMetadata acquireExistingEntry(
}

private CompletionStage<Void> reauthStage(
ConnectionEntryWithMetadata connectionEntryWithMetadata, Map<String, Value> authMap) {
ConnectionEntryWithMetadata connectionEntryWithMetadata, AuthToken authToken) {
CompletionStage<Void> stage;
if (connectionEntryWithMetadata.reauthNeeded) {
stage = connectionEntryWithMetadata
.connectionEntry
.connection
.logoff()
.thenCompose(conn -> conn.logon(authMap))
.thenCompose(conn -> conn.logon(authToken))
.handle((ignored, throwable) -> {
if (throwable != null) {
connectionEntryWithMetadata.connectionEntry.connection.close();
Expand Down Expand Up @@ -500,11 +502,11 @@ private CompletionStage<Void> livenessCheckStage(ConnectionEntry entry) {
}

@Override
public CompletionStage<Void> verifyConnectivity(SecurityPlan securityPlan, Map<String, Value> authMap) {
public CompletionStage<Void> verifyConnectivity(SecurityPlan securityPlan, AuthToken authToken) {
return connect(
securityPlan,
null,
() -> CompletableFuture.completedStage(authMap),
() -> CompletableFuture.completedStage(authToken),
AccessMode.WRITE,
Collections.emptySet(),
null,
Expand All @@ -516,11 +518,11 @@ public CompletionStage<Void> verifyConnectivity(SecurityPlan securityPlan, Map<S
}

@Override
public CompletionStage<Boolean> supportsMultiDb(SecurityPlan securityPlan, Map<String, Value> authMap) {
public CompletionStage<Boolean> supportsMultiDb(SecurityPlan securityPlan, AuthToken authToken) {
return connect(
securityPlan,
null,
() -> CompletableFuture.completedStage(authMap),
() -> CompletableFuture.completedStage(authToken),
AccessMode.WRITE,
Collections.emptySet(),
null,
Expand All @@ -535,11 +537,11 @@ public CompletionStage<Boolean> supportsMultiDb(SecurityPlan securityPlan, Map<S
}

@Override
public CompletionStage<Boolean> supportsSessionAuth(SecurityPlan securityPlan, Map<String, Value> authMap) {
public CompletionStage<Boolean> supportsSessionAuth(SecurityPlan securityPlan, AuthToken authToken) {
return connect(
securityPlan,
null,
() -> CompletableFuture.completedStage(authMap),
() -> CompletableFuture.completedStage(authToken),
AccessMode.WRITE,
Collections.emptySet(),
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import org.neo4j.driver.internal.bolt.api.AccessMode;
import org.neo4j.driver.internal.bolt.api.AuthData;
import org.neo4j.driver.internal.bolt.api.AuthInfo;
import org.neo4j.driver.internal.bolt.api.AuthToken;
import org.neo4j.driver.internal.bolt.api.BasicResponseHandler;
import org.neo4j.driver.internal.bolt.api.BoltConnection;
import org.neo4j.driver.internal.bolt.api.BoltConnectionState;
Expand Down Expand Up @@ -162,8 +163,8 @@ public CompletionStage<BoltConnection> logoff() {
}

@Override
public CompletionStage<BoltConnection> logon(Map<String, Value> authMap) {
return delegate.logon(authMap).thenApply(ignored -> this);
public CompletionStage<BoltConnection> logon(AuthToken authToken) {
return delegate.logon(authToken).thenApply(ignored -> this);
}

@Override
Expand Down Expand Up @@ -321,8 +322,8 @@ public BoltConnectionState state() {
}

@Override
public CompletionStage<AuthData> authData() {
return delegate.authData();
public CompletionStage<AuthInfo> authInfo() {
return delegate.authInfo();
}

@Override
Expand Down
Loading