From bcc9dde23e300d1a5ddadd5d94d625a25523620c Mon Sep 17 00:00:00 2001 From: kevinstrijbos Date: Sat, 25 May 2019 18:23:02 +0200 Subject: [PATCH] Make it easier to set bufferRequestBody on ClientHttpRequestFactory Enables easily setting the bufferRequestBody value on a ClientHttpRequestFactory using the RestTemplateBuilder. Fixes gh-16538 --- .../boot/web/client/RestTemplateBuilder.java | 76 ++++++++++++++++++- .../web/client/RestTemplateBuilderTests.java | 64 ++++++++++++++++ 2 files changed, 136 insertions(+), 4 deletions(-) diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java index 4a5a09820e21..72614cbd1855 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java @@ -58,6 +58,7 @@ * @author Phillip Webb * @author Andy Wilkinson * @author Brian Clozel + * @author Kevin Strijbos * @since 1.4.0 */ public class RestTemplateBuilder { @@ -487,6 +488,22 @@ public RestTemplateBuilder setReadTimeout(Duration readTimeout) { this.interceptors); } + /** + * Sets the bufferrequestbody value on the underlying + * {@link ClientHttpRequestFactory}. + * @param bufferRequestBody value of the bufferRequestBody parameter + * @return a new builder instance. + * @since 2.1.0 + */ + public RestTemplateBuilder setBufferRequestBody(boolean bufferRequestBody) { + return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, + this.messageConverters, this.requestFactorySupplier, + this.uriTemplateHandler, this.errorHandler, this.basicAuthentication, + this.restTemplateCustomizers, + this.requestFactoryCustomizer.bufferRequestBody(bufferRequestBody), + this.interceptors); + } + /** * Build a new {@link RestTemplate} instance and configure it using this builder. * @return a configured {@link RestTemplate} instance. @@ -574,21 +591,35 @@ private static class RequestFactoryCustomizer private final Duration readTimeout; + private final boolean bufferRequestBody; + + private final boolean bufferRequestBodyFlag; + RequestFactoryCustomizer() { - this(null, null); + this(null, null, true, false); } - private RequestFactoryCustomizer(Duration connectTimeout, Duration readTimeout) { + private RequestFactoryCustomizer(Duration connectTimeout, Duration readTimeout, + boolean bufferRequestBody, boolean bufferRequestBodyFlag) { this.connectTimeout = connectTimeout; this.readTimeout = readTimeout; + this.bufferRequestBody = bufferRequestBody; + this.bufferRequestBodyFlag = bufferRequestBodyFlag; } public RequestFactoryCustomizer connectTimeout(Duration connectTimeout) { - return new RequestFactoryCustomizer(connectTimeout, this.readTimeout); + return new RequestFactoryCustomizer(connectTimeout, this.readTimeout, + this.bufferRequestBody, this.bufferRequestBodyFlag); } public RequestFactoryCustomizer readTimeout(Duration readTimeout) { - return new RequestFactoryCustomizer(this.connectTimeout, readTimeout); + return new RequestFactoryCustomizer(this.connectTimeout, readTimeout, + this.bufferRequestBody, this.bufferRequestBodyFlag); + } + + public RequestFactoryCustomizer bufferRequestBody(boolean bufferRequestBody) { + return new RequestFactoryCustomizer(this.connectTimeout, this.readTimeout, + bufferRequestBody, true); } @Override @@ -603,6 +634,10 @@ public void accept(ClientHttpRequestFactory requestFactory) { new TimeoutRequestFactoryCustomizer(this.readTimeout, "setReadTimeout") .customize(unwrappedRequestFactory); } + if (this.bufferRequestBodyFlag) { + new BufferRequestBodyFactoryCustomizer(this.bufferRequestBody, + "setBufferRequestBody").customize(unwrappedRequestFactory); + } } private ClientHttpRequestFactory unwrapRequestFactoryIfNecessary( @@ -653,6 +688,39 @@ private Method findMethod(ClientHttpRequestFactory factory) { } + /** + * {@link ClientHttpRequestFactory} customizer to call a "set buffer request body" + * method. + */ + private static final class BufferRequestBodyFactoryCustomizer { + + private final boolean bufferRequestBody; + + private final String methodName; + + BufferRequestBodyFactoryCustomizer(boolean bufferRequestBody, + String methodName) { + this.bufferRequestBody = bufferRequestBody; + this.methodName = methodName; + } + + void customize(ClientHttpRequestFactory factory) { + ReflectionUtils.invokeMethod(findMethod(factory), factory, + this.bufferRequestBody); + } + + private Method findMethod(ClientHttpRequestFactory factory) { + Method method = ReflectionUtils.findMethod(factory.getClass(), + this.methodName, boolean.class); + if (method != null) { + return method; + } + throw new IllegalStateException("Request factory " + factory.getClass() + + " does not have a " + this.methodName + "(boolean) method"); + } + + } + } } diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java index f2070c7f9a29..e085c37fa64b 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java @@ -47,6 +47,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; @@ -63,6 +64,7 @@ * @author Phillip Webb * @author Andy Wilkinson * @author Dmytro Nosan + * @author Kevin Strijbos */ public class RestTemplateBuilderTests { @@ -477,6 +479,23 @@ public void readTimeoutCanBeConfiguredOnHttpComponentsRequestFactory() { "requestConfig")).getSocketTimeout()).isEqualTo(1234); } + @Test + public void bufferRequestBodyCanBeConfiguredOnHttpComponentsRequestFactory() { + ClientHttpRequestFactory requestFactory = this.builder + .requestFactory(HttpComponentsClientHttpRequestFactory.class) + .setBufferRequestBody(false).build().getRequestFactory(); + assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", + false); + requestFactory = this.builder + .requestFactory(HttpComponentsClientHttpRequestFactory.class) + .setBufferRequestBody(true).build().getRequestFactory(); + assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", true); + requestFactory = this.builder + .requestFactory(HttpComponentsClientHttpRequestFactory.class).build() + .getRequestFactory(); + assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", true); + } + @Test public void connectTimeoutCanBeConfiguredOnSimpleRequestFactory() { ClientHttpRequestFactory requestFactory = this.builder @@ -493,6 +512,21 @@ public void readTimeoutCanBeConfiguredOnSimpleRequestFactory() { assertThat(requestFactory).hasFieldOrPropertyWithValue("readTimeout", 1234); } + @Test + public void bufferRequestBodyCanBeConfiguredOnSimpleRequestFactory() { + ClientHttpRequestFactory requestFactory = this.builder + .requestFactory(SimpleClientHttpRequestFactory.class) + .setBufferRequestBody(false).build().getRequestFactory(); + assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", + false); + requestFactory = this.builder.requestFactory(SimpleClientHttpRequestFactory.class) + .setBufferRequestBody(true).build().getRequestFactory(); + assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", true); + requestFactory = this.builder.requestFactory(SimpleClientHttpRequestFactory.class) + .build().getRequestFactory(); + assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", true); + } + @Test public void connectTimeoutCanBeConfiguredOnOkHttp3RequestFactory() { ClientHttpRequestFactory requestFactory = this.builder @@ -513,6 +547,15 @@ public void readTimeoutCanBeConfiguredOnOkHttp3RequestFactory() { .isEqualTo(1234); } + @Test + public void bufferRequestBodyCanNotBeConfiguredOnOkHttp3RequestFactory() { + assertThatIllegalStateException() + .isThrownBy(() -> this.builder + .requestFactory(OkHttp3ClientHttpRequestFactory.class) + .setBufferRequestBody(false).build().getRequestFactory()) + .withMessageContaining(OkHttp3ClientHttpRequestFactory.class.getName()); + } + @Test public void connectTimeoutCanBeConfiguredOnAWrappedRequestFactory() { SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); @@ -533,6 +576,27 @@ public void readTimeoutCanBeConfiguredOnAWrappedRequestFactory() { assertThat(requestFactory).hasFieldOrPropertyWithValue("readTimeout", 1234); } + @Test + public void bufferRequestBodyCanBeConfiguredOnAWrappedRequestFactory() { + SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); + this.builder + .requestFactory( + () -> new BufferingClientHttpRequestFactory(requestFactory)) + .setBufferRequestBody(false).build(); + assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", + false); + this.builder + .requestFactory( + () -> new BufferingClientHttpRequestFactory(requestFactory)) + .setBufferRequestBody(true).build(); + assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", true); + this.builder + .requestFactory( + () -> new BufferingClientHttpRequestFactory(requestFactory)) + .build(); + assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", true); + } + @Test public void unwrappingDoesNotAffectRequestFactoryThatIsSetOnTheBuiltTemplate() { SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();