diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/RequestDataSupplierAdapter.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/RequestDataSupplierAdapter.java index 2c65f8e85c5f..14cbecb4f29b 100644 --- a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/RequestDataSupplierAdapter.java +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/RequestDataSupplierAdapter.java @@ -30,6 +30,7 @@ import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.crt.CrtRuntimeException; import software.amazon.awssdk.crt.http.HttpHeader; import software.amazon.awssdk.http.SdkHttpResponse; import software.amazon.awssdk.utils.Logger; @@ -49,8 +50,7 @@ public final class RequestDataSupplierAdapter implements RequestDataSupplier { private final Publisher bodyPublisher; - // Not volatile, we synchronize on the subscriptionQueue - private Subscription subscription; + private volatile Subscription subscription; // TODO: not volatile since it's read and written only by CRT thread(s). Need to // ensure that CRT actually ensures consistency across their threads... @@ -175,6 +175,20 @@ public boolean resetPosition() { return true; } + @Override + public void onException(CrtRuntimeException e) { + if (subscription != null) { + subscription.cancel(); + } + } + + @Override + public void onFinished() { + if (subscription != null) { + subscription.cancel(); + } + } + private Event takeFirstEvent() { try { return eventBuffer.takeFirst(); diff --git a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/RequestDataSupplierAdapterTest.java b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/RequestDataSupplierAdapterTest.java index 87f5d2af82a3..b01cada9e1c6 100644 --- a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/RequestDataSupplierAdapterTest.java +++ b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/RequestDataSupplierAdapterTest.java @@ -17,18 +17,24 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import io.reactivex.Flowable; import java.io.IOException; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; import java.util.stream.Stream; import org.junit.Test; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.crt.CrtRuntimeException; public class RequestDataSupplierAdapterTest { @@ -156,4 +162,56 @@ public void subscribe(Subscriber subscriber) { assertThat(readBuffer).isEqualTo(expectedBufferContent); } } + + @Test + public void onException_cancelsSubscription() { + Subscription subscription = mock(Subscription.class); + + AsyncRequestBody requestBody = new AsyncRequestBody() { + @Override + public Optional contentLength() { + return Optional.empty(); + } + + @Override + public void subscribe(Subscriber subscriber) { + subscriber.onSubscribe(subscription); + } + }; + + RequestDataSupplierAdapter adapter = new RequestDataSupplierAdapter(requestBody); + + // getRequestBytes() triggers a subscribe() on the publisher + adapter.getRequestBytes(ByteBuffer.allocate(0)); + + adapter.onException(new CrtRuntimeException("error")); + + verify(subscription).cancel(); + } + + @Test + public void onFinished_cancelsSubscription() { + Subscription subscription = mock(Subscription.class); + + AsyncRequestBody requestBody = new AsyncRequestBody() { + @Override + public Optional contentLength() { + return Optional.empty(); + } + + @Override + public void subscribe(Subscriber subscriber) { + subscriber.onSubscribe(subscription); + } + }; + + RequestDataSupplierAdapter adapter = new RequestDataSupplierAdapter(requestBody); + + // getRequestBytes() triggers a subscribe() on the publisher + adapter.getRequestBytes(ByteBuffer.allocate(0)); + + adapter.onFinished(); + + verify(subscription).cancel(); + } }