Skip to content

Commit c03cdba

Browse files
committed
CORS support in HTTP method predicate
This commit introduces CORS support for the HttpMethodPredicate in WebMvc.fn and WebFlux.fn. Closes gh-24564
1 parent fc12891 commit c03cdba

File tree

4 files changed

+156
-48
lines changed

4 files changed

+156
-48
lines changed

spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2020 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -52,6 +52,7 @@
5252
import org.springframework.lang.Nullable;
5353
import org.springframework.util.Assert;
5454
import org.springframework.util.MultiValueMap;
55+
import org.springframework.web.cors.reactive.CorsUtils;
5556
import org.springframework.web.reactive.function.BodyExtractor;
5657
import org.springframework.web.server.ServerWebExchange;
5758
import org.springframework.web.server.WebSession;
@@ -449,11 +450,25 @@ public HttpMethodPredicate(HttpMethod... httpMethods) {
449450

450451
@Override
451452
public boolean test(ServerRequest request) {
452-
boolean match = this.httpMethods.contains(request.method());
453-
traceMatch("Method", this.httpMethods, request.method(), match);
453+
HttpMethod method = method(request);
454+
boolean match = this.httpMethods.contains(method);
455+
traceMatch("Method", this.httpMethods, method, match);
454456
return match;
455457
}
456458

459+
@Nullable
460+
private static HttpMethod method(ServerRequest request) {
461+
if (CorsUtils.isPreFlightRequest(request.exchange().getRequest())) {
462+
String accessControlRequestMethod =
463+
request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD);
464+
return HttpMethod.resolve(accessControlRequestMethod);
465+
}
466+
else {
467+
return request.method();
468+
}
469+
}
470+
471+
457472
@Override
458473
public void accept(Visitor visitor) {
459474
visitor.method(Collections.unmodifiableSet(this.httpMethods));

spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java

Lines changed: 103 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2020 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -22,109 +22,148 @@
2222

2323
import org.junit.jupiter.api.Test;
2424

25+
import org.springframework.http.HttpHeaders;
2526
import org.springframework.http.HttpMethod;
2627
import org.springframework.http.MediaType;
28+
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest;
29+
import org.springframework.web.testfixture.server.MockServerWebExchange;
2730
import org.springframework.web.util.pattern.PathPatternParser;
2831

32+
import static java.util.Collections.emptyList;
2933
import static org.assertj.core.api.Assertions.assertThat;
3034

3135
/**
3236
* @author Arjen Poutsma
3337
*/
3438
public class RequestPredicatesTests {
3539

40+
3641
@Test
3742
public void all() {
43+
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com").build();
44+
MockServerWebExchange mockExchange = MockServerWebExchange.from(mockRequest);
3845
RequestPredicate predicate = RequestPredicates.all();
39-
MockServerRequest request = MockServerRequest.builder().build();
46+
ServerRequest request = new DefaultServerRequest(mockExchange, Collections.emptyList());
4047
assertThat(predicate.test(request)).isTrue();
4148
}
4249

4350
@Test
4451
public void method() {
52+
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com").build();
53+
4554
HttpMethod httpMethod = HttpMethod.GET;
4655
RequestPredicate predicate = RequestPredicates.method(httpMethod);
47-
MockServerRequest request = MockServerRequest.builder().method(httpMethod).build();
56+
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
57+
assertThat(predicate.test(request)).isTrue();
58+
59+
mockRequest = MockServerHttpRequest.post("https://example.com").build();
60+
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
61+
assertThat(predicate.test(request)).isFalse();
62+
}
63+
64+
@Test
65+
public void methodCorsPreFlight() {
66+
RequestPredicate predicate = RequestPredicates.method(HttpMethod.PUT);
67+
68+
MockServerHttpRequest mockRequest = MockServerHttpRequest.options("https://example.com")
69+
.header("Origin", "https://example.com")
70+
.header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "PUT")
71+
.build();
72+
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
73+
4874
assertThat(predicate.test(request)).isTrue();
4975

50-
request = MockServerRequest.builder().method(HttpMethod.POST).build();
76+
mockRequest = MockServerHttpRequest.options("https://example.com")
77+
.header("Origin", "https://example.com")
78+
.header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "POST")
79+
.build();
80+
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
5181
assertThat(predicate.test(request)).isFalse();
5282
}
5383

84+
5485
@Test
5586
public void methods() {
5687
RequestPredicate predicate = RequestPredicates.methods(HttpMethod.GET, HttpMethod.HEAD);
57-
MockServerRequest request = MockServerRequest.builder().method(HttpMethod.GET).build();
88+
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com").build();
89+
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
5890
assertThat(predicate.test(request)).isTrue();
5991

60-
request = MockServerRequest.builder().method(HttpMethod.HEAD).build();
92+
mockRequest = MockServerHttpRequest.head("https://example.com").build();
93+
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
6194
assertThat(predicate.test(request)).isTrue();
6295

63-
request = MockServerRequest.builder().method(HttpMethod.POST).build();
96+
mockRequest = MockServerHttpRequest.post("https://example.com").build();
97+
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
6498
assertThat(predicate.test(request)).isFalse();
6599
}
66100

67101
@Test
68102
public void allMethods() {
69-
URI uri = URI.create("http://localhost/path");
70-
71103
RequestPredicate predicate = RequestPredicates.GET("/p*");
72-
MockServerRequest request = MockServerRequest.builder().method(HttpMethod.GET).uri(uri).build();
104+
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com/path").build();
105+
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
73106
assertThat(predicate.test(request)).isTrue();
74107

75108
predicate = RequestPredicates.HEAD("/p*");
76-
request = MockServerRequest.builder().method(HttpMethod.HEAD).uri(uri).build();
109+
mockRequest = MockServerHttpRequest.head("https://example.com/path").build();
110+
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
77111
assertThat(predicate.test(request)).isTrue();
78112

79113
predicate = RequestPredicates.POST("/p*");
80-
request = MockServerRequest.builder().method(HttpMethod.POST).uri(uri).build();
114+
mockRequest = MockServerHttpRequest.post("https://example.com/path").build();
115+
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
81116
assertThat(predicate.test(request)).isTrue();
82117

83118
predicate = RequestPredicates.PUT("/p*");
84-
request = MockServerRequest.builder().method(HttpMethod.PUT).uri(uri).build();
119+
mockRequest = MockServerHttpRequest.put("https://example.com/path").build();
120+
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
85121
assertThat(predicate.test(request)).isTrue();
86122

87123
predicate = RequestPredicates.PATCH("/p*");
88-
request = MockServerRequest.builder().method(HttpMethod.PATCH).uri(uri).build();
124+
mockRequest = MockServerHttpRequest.patch("https://example.com/path").build();
125+
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
89126
assertThat(predicate.test(request)).isTrue();
90127

91128
predicate = RequestPredicates.DELETE("/p*");
92-
request = MockServerRequest.builder().method(HttpMethod.DELETE).uri(uri).build();
129+
mockRequest = MockServerHttpRequest.delete("https://example.com/path").build();
130+
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
93131
assertThat(predicate.test(request)).isTrue();
94132

95133
predicate = RequestPredicates.OPTIONS("/p*");
96-
request = MockServerRequest.builder().method(HttpMethod.OPTIONS).uri(uri).build();
134+
mockRequest = MockServerHttpRequest.options("https://example.com/path").build();
135+
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
97136
assertThat(predicate.test(request)).isTrue();
98137
}
99138

100139
@Test
101140
public void path() {
102-
URI uri = URI.create("http://localhost/path");
141+
URI uri = URI.create("https://localhost/path");
103142
RequestPredicate predicate = RequestPredicates.path("/p*");
104-
MockServerRequest request = MockServerRequest.builder().uri(uri).build();
143+
MockServerHttpRequest mockRequest = MockServerHttpRequest.get(uri.toString()).build();
144+
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList());
105145
assertThat(predicate.test(request)).isTrue();
106146

107-
request = MockServerRequest.builder().build();
147+
mockRequest = MockServerHttpRequest.head("https://example.com").build();
148+
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
108149
assertThat(predicate.test(request)).isFalse();
109150
}
110151

111152
@Test
112153
public void pathNoLeadingSlash() {
113-
URI uri = URI.create("http://localhost/path");
114154
RequestPredicate predicate = RequestPredicates.path("p*");
115-
MockServerRequest request = MockServerRequest.builder().uri(uri).build();
155+
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com/path").build();
156+
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
116157
assertThat(predicate.test(request)).isTrue();
117158
}
118159

119160
@Test
120161
public void pathEncoded() {
121-
URI uri = URI.create("http://localhost/foo%20bar");
162+
URI uri = URI.create("https://localhost/foo%20bar");
122163
RequestPredicate predicate = RequestPredicates.path("/foo bar");
123-
MockServerRequest request = MockServerRequest.builder().uri(uri).build();
164+
MockServerHttpRequest mockRequest = MockServerHttpRequest.method(HttpMethod.GET, uri).build();
165+
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
124166
assertThat(predicate.test(request)).isTrue();
125-
126-
request = MockServerRequest.builder().build();
127-
assertThat(predicate.test(request)).isFalse();
128167
}
129168

130169
@Test
@@ -133,9 +172,9 @@ public void pathPredicates() {
133172
parser.setCaseSensitive(false);
134173
Function<String, RequestPredicate> pathPredicates = RequestPredicates.pathPredicates(parser);
135174

136-
URI uri = URI.create("http://localhost/path");
137175
RequestPredicate predicate = pathPredicates.apply("/P*");
138-
MockServerRequest request = MockServerRequest.builder().uri(uri).build();
176+
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com/path").build();
177+
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
139178
assertThat(predicate.test(request)).isTrue();
140179
}
141180

@@ -146,58 +185,81 @@ public void headers() {
146185
RequestPredicate predicate =
147186
RequestPredicates.headers(
148187
headers -> headers.header(name).equals(Collections.singletonList(value)));
149-
MockServerRequest request = MockServerRequest.builder().header(name, value).build();
188+
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com")
189+
.header(name, value)
190+
.build();
191+
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
150192
assertThat(predicate.test(request)).isTrue();
151193

152-
request = MockServerRequest.builder().build();
194+
mockRequest = MockServerHttpRequest.get("https://example.com")
195+
.header(name, "bar")
196+
.build();
197+
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
153198
assertThat(predicate.test(request)).isFalse();
154199
}
155200

156201
@Test
157202
public void contentType() {
158203
MediaType json = MediaType.APPLICATION_JSON;
159204
RequestPredicate predicate = RequestPredicates.contentType(json);
160-
MockServerRequest request = MockServerRequest.builder().header("Content-Type", json.toString()).build();
205+
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com")
206+
.header(HttpHeaders.CONTENT_TYPE, json.toString())
207+
.build();
208+
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
161209
assertThat(predicate.test(request)).isTrue();
162210

163-
request = MockServerRequest.builder().build();
211+
mockRequest = MockServerHttpRequest.get("https://example.com")
212+
.header(HttpHeaders.CONTENT_TYPE, "foo/bar")
213+
.build();
214+
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
164215
assertThat(predicate.test(request)).isFalse();
165216
}
166217

167218
@Test
168219
public void accept() {
169220
MediaType json = MediaType.APPLICATION_JSON;
170221
RequestPredicate predicate = RequestPredicates.accept(json);
171-
MockServerRequest request = MockServerRequest.builder().header("Accept", json.toString()).build();
222+
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com")
223+
.header(HttpHeaders.ACCEPT, json.toString())
224+
.build();
225+
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
172226
assertThat(predicate.test(request)).isTrue();
173227

174-
request = MockServerRequest.builder().header("Accept", MediaType.TEXT_XML_VALUE).build();
228+
mockRequest = MockServerHttpRequest.get("https://example.com")
229+
.header(HttpHeaders.ACCEPT, "foo/bar")
230+
.build();
231+
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
175232
assertThat(predicate.test(request)).isFalse();
176233
}
177234

178235
@Test
179236
public void pathExtension() {
180237
RequestPredicate predicate = RequestPredicates.pathExtension("txt");
181238

182-
URI uri = URI.create("http://localhost/file.txt");
183-
MockServerRequest request = MockServerRequest.builder().uri(uri).build();
239+
URI uri = URI.create("https://localhost/file.txt");
240+
MockServerHttpRequest mockRequest = MockServerHttpRequest.method(HttpMethod.GET, uri).build();
241+
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
184242
assertThat(predicate.test(request)).isTrue();
185243

186-
uri = URI.create("http://localhost/FILE.TXT");
187-
request = MockServerRequest.builder().uri(uri).build();
244+
uri = URI.create("https://localhost/FILE.TXT");
245+
mockRequest = MockServerHttpRequest.method(HttpMethod.GET, uri).build();
246+
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
188247
assertThat(predicate.test(request)).isTrue();
189248

190249
predicate = RequestPredicates.pathExtension("bar");
191250
assertThat(predicate.test(request)).isFalse();
192251

193-
uri = URI.create("http://localhost/file.foo");
194-
request = MockServerRequest.builder().uri(uri).build();
252+
uri = URI.create("https://localhost/file.foo");
253+
mockRequest = MockServerHttpRequest.method(HttpMethod.GET, uri).build();
254+
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
195255
assertThat(predicate.test(request)).isFalse();
196256
}
197257

198258
@Test
199259
public void queryParam() {
200-
MockServerRequest request = MockServerRequest.builder().queryParam("foo", "bar").build();
260+
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com")
261+
.queryParam("foo", "bar").build();
262+
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
201263
RequestPredicate predicate = RequestPredicates.queryParam("foo", s -> s.equals("bar"));
202264
assertThat(predicate.test(request)).isTrue();
203265

spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2020 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -53,6 +53,7 @@
5353
import org.springframework.lang.Nullable;
5454
import org.springframework.util.Assert;
5555
import org.springframework.util.MultiValueMap;
56+
import org.springframework.web.cors.CorsUtils;
5657
import org.springframework.web.util.UriBuilder;
5758
import org.springframework.web.util.UriUtils;
5859
import org.springframework.web.util.pattern.PathPattern;
@@ -444,11 +445,24 @@ public HttpMethodPredicate(HttpMethod... httpMethods) {
444445

445446
@Override
446447
public boolean test(ServerRequest request) {
447-
boolean match = this.httpMethods.contains(request.method());
448-
traceMatch("Method", this.httpMethods, request.method(), match);
448+
HttpMethod method = method(request);
449+
boolean match = this.httpMethods.contains(method);
450+
traceMatch("Method", this.httpMethods, method, match);
449451
return match;
450452
}
451453

454+
@Nullable
455+
private static HttpMethod method(ServerRequest request) {
456+
if (CorsUtils.isPreFlightRequest(request.servletRequest())) {
457+
String accessControlRequestMethod =
458+
request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD);
459+
return HttpMethod.resolve(accessControlRequestMethod);
460+
}
461+
else {
462+
return request.method();
463+
}
464+
}
465+
452466
@Override
453467
public void accept(Visitor visitor) {
454468
visitor.method(Collections.unmodifiableSet(this.httpMethods));

0 commit comments

Comments
 (0)