Skip to content

Commit d3af0af

Browse files
Add shouldConvertGetRequests property
Signed-off-by: Tran Ngoc Nhan <[email protected]>
1 parent 6e793e8 commit d3af0af

File tree

3 files changed

+38
-37
lines changed

3 files changed

+38
-37
lines changed

config/src/test/java/org/springframework/security/config/http/Saml2LoginBeanDefinitionParserTests.java

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2025 the original author or authors.
2+
* Copyright 2002-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -47,6 +47,7 @@
4747
import org.springframework.security.saml2.core.Saml2Utils;
4848
import org.springframework.security.saml2.core.TestSaml2X509Credentials;
4949
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
50+
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal;
5051
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
5152
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
5253
import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest;
@@ -74,7 +75,6 @@
7475
import static org.mockito.BDDMockito.given;
7576
import static org.mockito.Mockito.atLeastOnce;
7677
import static org.mockito.Mockito.mock;
77-
import static org.mockito.Mockito.never;
7878
import static org.mockito.Mockito.verify;
7979
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
8080
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
@@ -210,11 +210,12 @@ public void authenticateWhenAuthenticationResponseValidThenAuthenticate() throws
210210
// @formatter:off
211211
this.mvc.perform(post("/login/saml2/sso/" + relyingPartyRegistration.getRegistrationId()).param(Saml2ParameterNames.SAML_RESPONSE, SIGNED_RESPONSE))
212212
.andDo(MockMvcResultHandlers.print())
213-
.andExpect(status().is3xxRedirection());
213+
.andExpect(status().is2xxSuccessful());
214214
// @formatter:on
215215
ArgumentCaptor<Authentication> authenticationCaptor = ArgumentCaptor.forClass(Authentication.class);
216-
verify(this.authenticationSuccessHandler, never()).onAuthenticationSuccess(any(), any(),
217-
authenticationCaptor.capture());
216+
verify(this.authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), authenticationCaptor.capture());
217+
Authentication authentication = authenticationCaptor.getValue();
218+
assertThat(authentication.getPrincipal()).isInstanceOf(Saml2AuthenticatedPrincipal.class);
218219
}
219220

220221
@Test
@@ -224,11 +225,12 @@ public void authenticateWhenCustomSecurityContextHolderStrategyThenUses() throws
224225
// @formatter:off
225226
this.mvc.perform(post("/login/saml2/sso/" + relyingPartyRegistration.getRegistrationId()).param(Saml2ParameterNames.SAML_RESPONSE, SIGNED_RESPONSE))
226227
.andDo(MockMvcResultHandlers.print())
227-
.andExpect(status().is3xxRedirection());
228+
.andExpect(status().is2xxSuccessful());
228229
// @formatter:on
229230
ArgumentCaptor<Authentication> authenticationCaptor = ArgumentCaptor.forClass(Authentication.class);
230-
verify(this.authenticationSuccessHandler, never()).onAuthenticationSuccess(any(), any(),
231-
authenticationCaptor.capture());
231+
verify(this.authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), authenticationCaptor.capture());
232+
Authentication authentication = authenticationCaptor.getValue();
233+
assertThat(authentication.getPrincipal()).isInstanceOf(Saml2AuthenticatedPrincipal.class);
232234
SecurityContextHolderStrategy strategy = this.spring.getContext().getBean(SecurityContextHolderStrategy.class);
233235
verify(strategy, atLeastOnce()).getContext();
234236
}
@@ -240,8 +242,9 @@ public void authenticateWhenAuthenticationResponseValidThenAuthenticationSuccess
240242
// @formatter:off
241243
this.mvc.perform(post("/login/saml2/sso/" + relyingPartyRegistration.getRegistrationId()).param(Saml2ParameterNames.SAML_RESPONSE, SIGNED_RESPONSE))
242244
.andDo(MockMvcResultHandlers.print())
243-
.andExpect(status().is3xxRedirection());
245+
.andExpect(status().is2xxSuccessful());
244246
// @formatter:on
247+
verify(this.authenticationSuccessListener).onApplicationEvent(any(AuthenticationSuccessEvent.class));
245248
}
246249

247250
@Test
@@ -274,8 +277,8 @@ public void authenticateWhenCustomAuthenticationManagerThenUses() throws Excepti
274277
MockHttpServletRequestBuilder request = post("/login/saml2/sso/" + relyingPartyRegistration.getRegistrationId())
275278
.param("SAMLResponse", SIGNED_RESPONSE);
276279
// @formatter:on
277-
this.mvc.perform(request).andExpect(status().is3xxRedirection()).andExpect(redirectedUrl("/login?error"));
278-
verify(authenticationManager, never()).authenticate(any());
280+
this.mvc.perform(request).andExpect(status().is3xxRedirection()).andExpect(redirectedUrl("/"));
281+
verify(authenticationManager).authenticate(any());
279282
}
280283

281284
@Test
@@ -317,6 +320,8 @@ public void authenticateWhenCustomAuthnRequestRepositoryThenUses() throws Except
317320
SIGNED_RESPONSE);
318321
this.mvc.perform(request);
319322
verify(this.authenticationRequestRepository).loadAuthenticationRequest(any(HttpServletRequest.class));
323+
verify(this.authenticationRequestRepository).removeAuthenticationRequest(any(HttpServletRequest.class),
324+
any(HttpServletResponse.class));
320325
}
321326

322327
@Test

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import jakarta.servlet.http.HttpServletRequest;
2020

21+
import org.springframework.http.HttpMethod;
2122
import org.springframework.security.saml2.core.Saml2Error;
2223
import org.springframework.security.saml2.core.Saml2ErrorCodes;
2324
import org.springframework.security.saml2.core.Saml2ParameterNames;
@@ -42,7 +43,7 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
4243

4344
private Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository;
4445

45-
private boolean shouldInflateResponse = true;
46+
private boolean shouldConvertGetRequests = true;
4647

4748
/**
4849
* Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for
@@ -88,22 +89,26 @@ public void setAuthenticationRequestRepository(
8889
}
8990

9091
/**
91-
* Use the given {@code shouldInflateResponse} to inflate request. Default is
92-
* {@code true}.
93-
* @param shouldInflateResponse the {@code shouldInflateResponse} to use
92+
* Use the given {@code shouldConvertGetRequests} to convert {@code GET} requests.
93+
* Default is {@code true}.
94+
* @param shouldConvertGetRequests the {@code shouldConvertGetRequests} to use
9495
* @since 7.0
9596
*/
96-
public void setShouldInflateResponse(boolean shouldInflateResponse) {
97-
this.shouldInflateResponse = shouldInflateResponse;
97+
public void setShouldConvertGetRequests(boolean shouldConvertGetRequests) {
98+
this.shouldConvertGetRequests = shouldConvertGetRequests;
9899
}
99100

100101
private String decode(HttpServletRequest request) {
101102
String encoded = request.getParameter(Saml2ParameterNames.SAML_RESPONSE);
102103
if (encoded == null) {
103104
return null;
104105
}
106+
boolean isGet = HttpMethod.GET.matches(request.getMethod());
107+
if (!this.shouldConvertGetRequests && isGet) {
108+
return null;
109+
}
105110
try {
106-
return Saml2Utils.withEncoded(encoded).requireBase64(true).inflate(this.shouldInflateResponse).decode();
111+
return Saml2Utils.withEncoded(encoded).requireBase64(true).inflate(isGet).decode();
107112
}
108113
catch (Exception ex) {
109114
throw new Saml2AuthenticationException(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, ex.getMessage()),

saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ public void convertWhenSamlResponseThenToken() {
6767
request.setParameter(Saml2ParameterNames.SAML_RESPONSE,
6868
Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8)));
6969
Saml2AuthenticationToken token = converter.convert(request);
70-
assertThat(token.getSaml2Response())
71-
.isEqualTo(Saml2Utils.samlInflate("response".getBytes(StandardCharsets.UTF_8)));
70+
assertThat(token.getSaml2Response()).isEqualTo("response");
7271
assertThat(token.getRelyingPartyRegistration().getRegistrationId())
7372
.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
7473
}
@@ -82,8 +81,7 @@ public void convertWhenSamlResponseWithRelyingPartyRegistrationResolver(
8281
request.setParameter(Saml2ParameterNames.SAML_RESPONSE,
8382
Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8)));
8483
Saml2AuthenticationToken token = converter.convert(request);
85-
assertThat(token.getSaml2Response())
86-
.isEqualTo(Saml2Utils.samlInflate("response".getBytes(StandardCharsets.UTF_8)));
84+
assertThat(token.getSaml2Response()).isEqualTo("response");
8785
assertThat(token.getRelyingPartyRegistration().getRegistrationId())
8886
.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
8987
verify(resolver).resolve(any(), isNull());
@@ -160,18 +158,15 @@ public void convertWhenGetRequestInvalidDeflatedThenSaml2AuthenticationException
160158
}
161159

162160
@Test
163-
public void convertWhenUsingSamlUtilsBase64ThenSaml2AuthenticationException() throws Exception {
161+
public void convertWhenUsingSamlUtilsBase64ThenXmlIsValid() throws Exception {
164162
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
165163
this.relyingPartyRegistrationResolver);
166164
given(this.relyingPartyRegistrationResolver.resolve(any(HttpServletRequest.class), any()))
167165
.willReturn(this.relyingPartyRegistration);
168166
MockHttpServletRequest request = new MockHttpServletRequest();
169167
request.setParameter(Saml2ParameterNames.SAML_RESPONSE, getSsoCircleEncodedXml());
170-
assertThatExceptionOfType(Saml2AuthenticationException.class).isThrownBy(() -> converter.convert(request))
171-
.withRootCauseInstanceOf(IOException.class)
172-
.satisfies(
173-
(ex) -> assertThat(ex.getSaml2Error().getErrorCode()).isEqualTo(Saml2ErrorCodes.INVALID_RESPONSE))
174-
.satisfies((ex) -> assertThat(ex.getSaml2Error().getDescription()).isEqualTo("Unable to inflate string"));
168+
Saml2AuthenticationToken token = converter.convert(request);
169+
validateSsoCircleXml(token.getSaml2Response());
175170
}
176171

177172
@Test
@@ -192,8 +187,7 @@ public void convertWhenSavedAuthenticationRequestThenToken() {
192187
request.setParameter(Saml2ParameterNames.SAML_RESPONSE,
193188
Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8)));
194189
Saml2AuthenticationToken token = converter.convert(request);
195-
assertThat(token.getSaml2Response())
196-
.isEqualTo(Saml2Utils.samlInflate("response".getBytes(StandardCharsets.UTF_8)));
190+
assertThat(token.getSaml2Response()).isEqualTo("response");
197191
assertThat(token.getRelyingPartyRegistration().getRegistrationId())
198192
.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
199193
assertThat(token.getAuthenticationRequest()).isEqualTo(authenticationRequest);
@@ -216,8 +210,7 @@ public void convertWhenSavedAuthenticationRequestThenTokenWithRelyingPartyRegist
216210
request.setParameter(Saml2ParameterNames.SAML_RESPONSE,
217211
Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8)));
218212
Saml2AuthenticationToken token = converter.convert(request);
219-
assertThat(token.getSaml2Response())
220-
.isEqualTo(Saml2Utils.samlInflate("response".getBytes(StandardCharsets.UTF_8)));
213+
assertThat(token.getSaml2Response()).isEqualTo("response");
221214
assertThat(token.getRelyingPartyRegistration().getRegistrationId())
222215
.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
223216
assertThat(token.getAuthenticationRequest()).isEqualTo(authenticationRequest);
@@ -238,20 +231,18 @@ public void setAuthenticationRequestRepositoryWhenNullThenIllegalArgument() {
238231
}
239232

240233
@Test
241-
public void convertWhenGetRequestAndShouldNotInflateResponse() {
234+
public void shouldNotConvertGetRequests() {
242235
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
243236
this.relyingPartyRegistrationResolver);
244-
converter.setShouldInflateResponse(false);
237+
converter.setShouldConvertGetRequests(false);
245238
given(this.relyingPartyRegistrationResolver.resolve(any(HttpServletRequest.class), any()))
246239
.willReturn(this.relyingPartyRegistration);
247240
MockHttpServletRequest request = new MockHttpServletRequest();
248241
request.setMethod("GET");
249242
request.setParameter(Saml2ParameterNames.SAML_RESPONSE,
250243
Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8)));
251244
Saml2AuthenticationToken token = converter.convert(request);
252-
assertThat(token.getSaml2Response()).isEqualTo("response");
253-
assertThat(token.getRelyingPartyRegistration().getRegistrationId())
254-
.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
245+
assertThat(token).isNull();
255246
}
256247

257248
private void validateSsoCircleXml(String xml) {

0 commit comments

Comments
 (0)