Skip to content

Commit dd2a3d2

Browse files
committed
Tests and docs for AuthenticationWebSocketInterceptor
Closes gh-268
1 parent 1171aee commit dd2a3d2

File tree

10 files changed

+328
-30
lines changed

10 files changed

+328
-30
lines changed

spring-graphql-docs/modules/ROOT/pages/transports.adoc

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,15 +148,16 @@ called to process a request.
148148
[[server.interception.web]]
149149
=== `WebGraphQlInterceptor`
150150

151-
xref:transports.adoc#server.transports.http[HTTP] and xref:transports.adoc#server.transports.websocket[WebSocket] transports invoke a chain of
152-
0 or more `WebGraphQlInterceptor`, followed by an `ExecutionGraphQlService` that calls
153-
the GraphQL Java engine. `WebGraphQlInterceptor` allows an application to intercept
154-
incoming requests and do one of the following:
151+
xref:transports.adoc#server.transports.http[HTTP] and xref:transports.adoc#server.transports.websocket[WebSocket]
152+
transports invoke a chain of 0 or more `WebGraphQlInterceptor`, followed by an
153+
`ExecutionGraphQlService` that calls the GraphQL Java engine.
154+
Interceptors allow applications to intercept incoming requests in order to:
155155

156156
- Check HTTP request details
157157
- Customize the `graphql.ExecutionInput`
158158
- Add HTTP response headers
159159
- Customize the `graphql.ExecutionResult`
160+
- and more
160161

161162
For example, an interceptor can pass an HTTP request header to a `DataFetcher`:
162163

@@ -184,6 +185,26 @@ by the xref:boot-starter.adoc[Boot Starter], see
184185
{spring-boot-ref-docs}/web.html#web.graphql.transports.http-websocket[Web Endpoints].
185186

186187

188+
[[server.interception.websocket]]
189+
=== `WebSocketGraphQlInterceptor`
190+
191+
`WebSocketGraphQlInterceptor` extends `WebGraphQlInterceptor` with additional callbacks
192+
to handle the start and end of a WebSocket connection, in addition to client-side
193+
cancellation of subscriptions. The same also intercepts every GraphQL request on the
194+
WebSocket connection.
195+
196+
Use `WebGraphQlHandler` to configure the `WebGraphQlInterceptor` chain. This is supported
197+
by the xref:boot-starter.adoc[Boot Starter], see
198+
{spring-boot-ref-docs}/web.html#web.graphql.transports.http-websocket[Web Endpoints].
199+
There can be at most one `WebSocketGraphQlInterceptor` in a chain of interceptors.
200+
201+
There are two built-in WebSocket interceptors called `AuthenticationWebSocketInterceptor`,
202+
one for the WebMVC and one for the WebFlux transports. These help to extract authentication
203+
details from the payload of a `"connection_init"` GraphQL over WebSocket message, authenticate,
204+
and then propagate the `SecurityContext` to subsequent requests on the WebSocket connection.
205+
206+
207+
187208
[[server.interception.rsocket]]
188209
=== `RSocketQlInterceptor`
189210

spring-graphql/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ dependencies {
6969
testImplementation 'org.testcontainers:neo4j'
7070
testImplementation 'org.testcontainers:junit-jupiter'
7171
testImplementation 'org.springframework.security:spring-security-core'
72+
testImplementation 'org.springframework.security:spring-security-oauth2-resource-server'
7273
testImplementation 'com.querydsl:querydsl-core'
7374
testImplementation 'com.querydsl:querydsl-collections'
7475
testImplementation 'jakarta.servlet:jakarta.servlet-api'

spring-graphql/src/main/java/org/springframework/graphql/server/WebSocketGraphQlInterceptor.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -23,7 +23,8 @@
2323

2424
/**
2525
* An extension of {@link WebGraphQlInterceptor} with additional methods
26-
* to handle the start and end of a WebSocket connection.
26+
* to handle the start and end of a WebSocket connection, as well as client-side
27+
* cancellation of subscriptions.
2728
*
2829
* <p>Use {@link WebGraphQlHandler.Builder#interceptor(WebGraphQlInterceptor...)}
2930
* to configure the interceptor chain. Only one interceptor in the chain may be

spring-graphql/src/main/java/org/springframework/graphql/server/support/AbstractAuthenticationWebSocketInterceptor.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.springframework.graphql.server.WebSocketSessionInfo;
2929
import org.springframework.security.core.Authentication;
3030
import org.springframework.security.core.context.SecurityContext;
31+
import org.springframework.security.core.context.SecurityContextImpl;
3132

3233
/**
3334
* Base class for interceptors that extract an {@link Authentication} from
@@ -41,8 +42,7 @@
4142
*/
4243
public abstract class AbstractAuthenticationWebSocketInterceptor implements WebSocketGraphQlInterceptor {
4344

44-
private static final String AUTHENTICATION_ATTRIBUTE =
45-
AbstractAuthenticationWebSocketInterceptor.class.getName() + ".AUTHENTICATION";
45+
private final String authenticationAttribute = getClass().getName() + ".AUTHENTICATION";
4646

4747

4848
private final AuthenticationExtractor authenticationExtractor;
@@ -60,8 +60,11 @@ public AbstractAuthenticationWebSocketInterceptor(AuthenticationExtractor authEx
6060
@Override
6161
public Mono<Object> handleConnectionInitialization(WebSocketSessionInfo info, Map<String, Object> payload) {
6262
return this.authenticationExtractor.getAuthentication(payload)
63-
.flatMap(this::getSecurityContext)
64-
.doOnNext((securityContext) -> info.getAttributes().put(AUTHENTICATION_ATTRIBUTE, securityContext))
63+
.flatMap(this::authenticate)
64+
.doOnNext((authentication) -> {
65+
SecurityContext securityContext = new SecurityContextImpl(authentication);
66+
info.getAttributes().put(this.authenticationAttribute, securityContext);
67+
})
6568
.then(Mono.empty());
6669
}
6770

@@ -70,15 +73,15 @@ public Mono<Object> handleConnectionInitialization(WebSocketSessionInfo info, Ma
7073
* {@link SecurityContext} or an error.
7174
* @param authentication the authentication value extracted from the payload
7275
*/
73-
protected abstract Mono<SecurityContext> getSecurityContext(Authentication authentication);
76+
protected abstract Mono<Authentication> authenticate(Authentication authentication);
7477

7578
@Override
7679
public Mono<WebGraphQlResponse> intercept(WebGraphQlRequest request, Chain chain) {
7780
if (!(request instanceof WebSocketGraphQlRequest webSocketRequest)) {
7881
return chain.next(request);
7982
}
8083
Map<String, Object> attributes = webSocketRequest.getSessionInfo().getAttributes();
81-
SecurityContext securityContext = (SecurityContext) attributes.get(AUTHENTICATION_ATTRIBUTE);
84+
SecurityContext securityContext = (SecurityContext) attributes.get(this.authenticationAttribute);
8285
ContextView contextView = getContextToWrite(securityContext);
8386
return chain.next(request).contextWrite(contextView);
8487
}

spring-graphql/src/main/java/org/springframework/graphql/server/support/BearerTokenAuthenticationExtractor.java

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
*/
4040
public final class BearerTokenAuthenticationExtractor implements AuthenticationExtractor {
4141

42+
/** Default key to access Authorization value in {@code connection_init} payload. */
43+
public static final String AUTHORIZATION_KEY = "Authorization";
44+
4245
private static final Pattern authorizationPattern =
4346
Pattern.compile("^Bearer (?<token>[a-zA-Z0-9-._~+/]+=*)$", Pattern.CASE_INSENSITIVE);
4447

@@ -47,10 +50,10 @@ public final class BearerTokenAuthenticationExtractor implements AuthenticationE
4750

4851

4952
/**
50-
* Constructor that defaults the payload key to use to "Authorization".
53+
* Constructor that defaults to {@link #AUTHORIZATION_KEY} for the payload key.
5154
*/
5255
public BearerTokenAuthenticationExtractor() {
53-
this("Authorization");
56+
this(AUTHORIZATION_KEY);
5457
}
5558

5659
/**
@@ -66,18 +69,23 @@ public BearerTokenAuthenticationExtractor(String authorizationKey) {
6669
@Override
6770
public Mono<Authentication> getAuthentication(Map<String, Object> payload) {
6871
String authorizationValue = (String) payload.get(this.authorizationKey);
69-
if (!StringUtils.startsWithIgnoreCase(authorizationValue, "bearer")) {
72+
if (authorizationValue == null) {
7073
return Mono.empty();
7174
}
7275

76+
if (!StringUtils.startsWithIgnoreCase(authorizationValue, "bearer")) {
77+
BearerTokenError error = BearerTokenErrors.invalidRequest("Not a bearer token");
78+
return Mono.error(new OAuth2AuthenticationException(error));
79+
}
80+
7381
Matcher matcher = authorizationPattern.matcher(authorizationValue);
74-
if (matcher.matches()) {
75-
String token = matcher.group("token");
76-
return Mono.just(new BearerTokenAuthenticationToken(token));
82+
if (!matcher.matches()) {
83+
BearerTokenError error = BearerTokenErrors.invalidToken("Bearer token is malformed");
84+
return Mono.error(new OAuth2AuthenticationException(error));
7785
}
7886

79-
BearerTokenError error = BearerTokenErrors.invalidToken("Bearer token is malformed");
80-
return Mono.error(new OAuth2AuthenticationException(error));
87+
String token = matcher.group("token");
88+
return Mono.just(new BearerTokenAuthenticationToken(token));
8189
}
8290

8391
}

spring-graphql/src/main/java/org/springframework/graphql/server/webflux/AuthenticationWebSocketInterceptor.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import org.springframework.security.core.Authentication;
2626
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
2727
import org.springframework.security.core.context.SecurityContext;
28-
import org.springframework.security.core.context.SecurityContextImpl;
2928

3029
/**
3130
* Extension of {@link AbstractAuthenticationWebSocketInterceptor} for use with
@@ -35,7 +34,7 @@
3534
* @author Rossen Stoyanchev
3635
* @since 1.3.0
3736
*/
38-
public class AuthenticationWebSocketInterceptor extends AbstractAuthenticationWebSocketInterceptor {
37+
public final class AuthenticationWebSocketInterceptor extends AbstractAuthenticationWebSocketInterceptor {
3938

4039
private final ReactiveAuthenticationManager authenticationManager;
4140

@@ -48,8 +47,8 @@ public AuthenticationWebSocketInterceptor(
4847
}
4948

5049
@Override
51-
protected Mono<SecurityContext> getSecurityContext(Authentication authentication) {
52-
return this.authenticationManager.authenticate(authentication).map(SecurityContextImpl::new);
50+
protected Mono<Authentication> authenticate(Authentication authentication) {
51+
return this.authenticationManager.authenticate(authentication);
5352
}
5453

5554
@Override

spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/AuthenticationWebSocketInterceptor.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import org.springframework.security.authentication.AuthenticationManager;
2626
import org.springframework.security.core.Authentication;
2727
import org.springframework.security.core.context.SecurityContext;
28-
import org.springframework.security.core.context.SecurityContextImpl;
2928

3029
/**
3130
* Extension of {@link AbstractAuthenticationWebSocketInterceptor} for use with
@@ -35,22 +34,21 @@
3534
* @author Rossen Stoyanchev
3635
* @since 1.3.0
3736
*/
38-
public class AuthenticationWebSocketInterceptor extends AbstractAuthenticationWebSocketInterceptor {
37+
public final class AuthenticationWebSocketInterceptor extends AbstractAuthenticationWebSocketInterceptor {
3938

4039
private final AuthenticationManager authenticationManager;
4140

4241

4342
public AuthenticationWebSocketInterceptor(
44-
AuthenticationManager authManager, AuthenticationExtractor authExtractor) {
43+
AuthenticationExtractor authExtractor, AuthenticationManager authManager) {
4544

4645
super(authExtractor);
4746
this.authenticationManager = authManager;
4847
}
4948

5049
@Override
51-
protected Mono<SecurityContext> getSecurityContext(Authentication authentication) {
52-
Authentication authenticate = this.authenticationManager.authenticate(authentication);
53-
return Mono.just(new SecurityContextImpl(authenticate));
50+
protected Mono<Authentication> authenticate(Authentication authentication) {
51+
return Mono.just(this.authenticationManager.authenticate(authentication));
5452
}
5553

5654
@Override
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Copyright 2002-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.graphql.server.support;
18+
19+
import java.util.Collections;
20+
import java.util.Map;
21+
22+
import org.junit.jupiter.api.Test;
23+
24+
import org.springframework.lang.Nullable;
25+
import org.springframework.security.core.Authentication;
26+
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
27+
28+
import static org.assertj.core.api.Assertions.assertThat;
29+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
30+
31+
/**
32+
* Unit tests for {@link BearerTokenAuthenticationExtractorTests}.
33+
*
34+
* @author Rossen Stoyanchev
35+
*/
36+
public class BearerTokenAuthenticationExtractorTests {
37+
38+
private static final BearerTokenAuthenticationExtractor extractor = new BearerTokenAuthenticationExtractor();
39+
40+
41+
@Test
42+
void extract() {
43+
Authentication auth = getAuthentication("Bearer 123456789");
44+
45+
assertThat(auth).isNotNull();
46+
assertThat(auth.getName()).isEqualTo("123456789");
47+
}
48+
49+
@Test
50+
void noToken() {
51+
Authentication auth = getAuthentication(Collections.emptyMap());
52+
assertThat(auth).isNull();
53+
}
54+
55+
@Test
56+
void notBearerToken() {
57+
assertThatThrownBy(() -> getAuthentication("abc"))
58+
.isInstanceOf(OAuth2AuthenticationException.class)
59+
.hasMessage("Not a bearer token");
60+
}
61+
62+
@Test
63+
void invalidToken() {
64+
assertThatThrownBy(() -> getAuthentication("Bearer ???"))
65+
.isInstanceOf(OAuth2AuthenticationException.class)
66+
.hasMessage("Bearer token is malformed");
67+
}
68+
69+
@Nullable
70+
private static Authentication getAuthentication(String value) {
71+
return getAuthentication(Map.of(BearerTokenAuthenticationExtractor.AUTHORIZATION_KEY, value));
72+
}
73+
74+
@Nullable
75+
private static Authentication getAuthentication(Map<String, Object> map) {
76+
return extractor.getAuthentication(map).block();
77+
}
78+
79+
}

0 commit comments

Comments
 (0)