diff --git a/services/s3/pom.xml b/services/s3/pom.xml index a7defbd75af2..9db8a50368eb 100644 --- a/services/s3/pom.xml +++ b/services/s3/pom.xml @@ -64,7 +64,7 @@ software.amazon.awssdk.crt aws-crt - ${awscrt.version} + 1.0.0-SNAPSHOT software.amazon.awssdk diff --git a/services/s3/src/it/java/software/amazon/awssdk/services/s3/CrtClientIntegrationTest.java b/services/s3/src/it/java/software/amazon/awssdk/services/s3/CrtClientIntegrationTest.java new file mode 100644 index 000000000000..7a34d312e30d --- /dev/null +++ b/services/s3/src/it/java/software/amazon/awssdk/services/s3/CrtClientIntegrationTest.java @@ -0,0 +1,148 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3; + +import static org.assertj.core.api.Assertions.assertThat; +import static software.amazon.awssdk.testutils.service.S3BucketUtils.temporaryBucketName; +import io.reactivex.Flowable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.security.NoSuchAlgorithmException; +import java.util.List; +import java.util.Optional; +import java.util.Random; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.reactivestreams.Subscriber; +import software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.sync.ResponseTransformer; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.utils.ChecksumUtils; +import software.amazon.awssdk.testutils.RandomTempFile; + +public class CrtClientIntegrationTest extends S3IntegrationTestBase { + private static final String TEST_BUCKET = temporaryBucketName(CrtClientIntegrationTest.class); + private static final String TEST_KEY = "8mib_file.dat"; + private static final int OBJ_SIZE = 8 * 1024 * 1024; + + private static RandomTempFile testFile; + + private S3CrtAsyncClient s3Crt; + + @BeforeClass + public static void setup() throws Exception { + S3IntegrationTestBase.setUp(); + createBucket(TEST_BUCKET); + + testFile = new RandomTempFile(TEST_KEY, OBJ_SIZE); + } + + @Before + public void methodSetup() { + s3Crt = S3CrtAsyncClient.builder() + .credentialsProvider(CREDENTIALS_PROVIDER_CHAIN) + .region(DEFAULT_REGION) + .build(); + } + + @After + public void methodTeardown() { + s3Crt.close(); + } + + @AfterClass + public static void teardown() throws IOException { + deleteBucketAndAllContents(TEST_BUCKET); + Files.delete(testFile.toPath()); + } + + @Test + public void putObject_fileRequestBody_objectSentCorrectly() throws IOException, NoSuchAlgorithmException { + AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath()); + s3Crt.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join(); + + ResponseInputStream objContent = s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), + ResponseTransformer.toInputStream()); + + byte[] expectedSum = ChecksumUtils.computeCheckSum(Files.newInputStream(testFile.toPath())); + + assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum); + } + + @Test + public void putObject_byteBufferBody_objectSentCorrectly() throws IOException, NoSuchAlgorithmException { + byte[] data = new byte[16384]; + new Random().nextBytes(data); + ByteBuffer byteBuffer = ByteBuffer.wrap(data); + + AsyncRequestBody body = AsyncRequestBody.fromByteBuffer(byteBuffer); + + s3Crt.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join(); + + ResponseBytes responseBytes = s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), + ResponseTransformer.toBytes()); + + byte[] expectedSum = ChecksumUtils.computeCheckSum(byteBuffer); + + assertThat(ChecksumUtils.computeCheckSum(responseBytes.asByteBuffer())).isEqualTo(expectedSum); + } + + @Test + public void putObject_customRequestBody_objectSentCorrectly() throws IOException, NoSuchAlgorithmException { + Random rng = new Random(); + int bufferSize = 16384; + int nBuffers = 15; + List bodyData = Stream.generate(() -> { + byte[] data = new byte[bufferSize]; + rng.nextBytes(data); + return ByteBuffer.wrap(data); + }).limit(nBuffers).collect(Collectors.toList()); + + long contentLength = bufferSize * nBuffers; + + byte[] expectedSum = ChecksumUtils.computeCheckSum(bodyData); + + Flowable publisher = Flowable.fromIterable(bodyData); + + AsyncRequestBody customRequestBody = new AsyncRequestBody() { + @Override + public Optional contentLength() { + return Optional.of(contentLength); + } + + @Override + public void subscribe(Subscriber subscriber) { + publisher.subscribe(subscriber); + } + }; + + s3Crt.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), customRequestBody).join(); + + ResponseInputStream objContent = s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), + ResponseTransformer.toInputStream()); + + + assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum); + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/DefaultS3CrtAsyncClient.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/DefaultS3CrtAsyncClient.java index b62682eeb76f..f5454b183fb3 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/DefaultS3CrtAsyncClient.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/DefaultS3CrtAsyncClient.java @@ -18,6 +18,7 @@ import static software.amazon.awssdk.services.s3.internal.S3CrtUtils.createCrtCredentialsProvider; +import com.amazonaws.s3.RequestDataSupplier; import com.amazonaws.s3.S3NativeClient; import java.util.concurrent.CompletableFuture; import software.amazon.awssdk.annotations.SdkInternalApi; @@ -26,6 +27,11 @@ import software.amazon.awssdk.services.s3.S3CrtAsyncClient; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import com.amazonaws.s3.model.PutObjectOutput; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.internal.s3crt.RequestDataSupplierAdapter; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; @SdkInternalApi public final class DefaultS3CrtAsyncClient implements S3CrtAsyncClient { @@ -74,6 +80,20 @@ public CompletableFuture getObject( return future; } + public CompletableFuture putObject(PutObjectRequest putObjectRequest, AsyncRequestBody requestBody) { + com.amazonaws.s3.model.PutObjectRequest adaptedRequest = S3CrtUtils.toCrtPutObjectRequest(putObjectRequest); + + if (adaptedRequest.contentLength() == null && requestBody.contentLength().isPresent()) { + adaptedRequest = adaptedRequest.toBuilder().contentLength(requestBody.contentLength().get()) + .build(); + } + + CompletableFuture putObjectOutputCompletableFuture = s3NativeClient.putObject(adaptedRequest, + adaptToDataSupplier(requestBody)); + + return putObjectOutputCompletableFuture.thenApply(S3CrtUtils::fromCrtPutObjectOutput); + } + @Override public String serviceName() { return "s3"; @@ -84,4 +104,8 @@ public void close() { s3NativeClient.close(); configuration.close(); } + + private static RequestDataSupplier adaptToDataSupplier(AsyncRequestBody requestBody) { + return new RequestDataSupplierAdapter(requestBody); + } } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/S3CrtUtils.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/S3CrtUtils.java index 299680e90923..65ca4f2cc790 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/S3CrtUtils.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/S3CrtUtils.java @@ -15,6 +15,13 @@ package software.amazon.awssdk.services.s3.internal; +import com.amazonaws.s3.model.ObjectCannedACL; +import com.amazonaws.s3.model.ObjectLockLegalHoldStatus; +import com.amazonaws.s3.model.ObjectLockMode; +import com.amazonaws.s3.model.PutObjectOutput; +import com.amazonaws.s3.model.RequestPayer; +import com.amazonaws.s3.model.ServerSideEncryption; +import com.amazonaws.s3.model.StorageClass; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.auth.credentials.AwsCredentials; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; @@ -23,6 +30,8 @@ import software.amazon.awssdk.crt.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; @SdkInternalApi public final class S3CrtUtils { @@ -64,21 +73,82 @@ public static com.amazonaws.s3.model.GetObjectRequest adaptGetObjectRequest(GetO // TODO: codegen public static GetObjectResponse adaptGetObjectOutput(com.amazonaws.s3.model.GetObjectOutput response) { return GetObjectResponse.builder() - .bucketKeyEnabled(response.bucketKeyEnabled()) - .acceptRanges(response.acceptRanges()) - .contentDisposition(response.contentDisposition()) - .cacheControl(response.cacheControl()) - .contentEncoding(response.contentEncoding()) - .contentLanguage(response.contentLanguage()) - .contentRange(response.contentRange()) - .contentLength(response.contentLength()) - .contentType(response.contentType()) - .deleteMarker(response.deleteMarker()) - .eTag(response.eTag()) - .expiration(response.expiration()) - .expires(response.expires()) - .lastModified(response.lastModified()) - .metadata(response.metadata()) - .build(); + .bucketKeyEnabled(response.bucketKeyEnabled()) + .acceptRanges(response.acceptRanges()) + .contentDisposition(response.contentDisposition()) + .cacheControl(response.cacheControl()) + .contentEncoding(response.contentEncoding()) + .contentLanguage(response.contentLanguage()) + .contentRange(response.contentRange()) + .contentLength(response.contentLength()) + .contentType(response.contentType()) + .deleteMarker(response.deleteMarker()) + .eTag(response.eTag()) + .expiration(response.expiration()) + .expires(response.expires()) + .lastModified(response.lastModified()) + .metadata(response.metadata()) + .build(); + } + + //TODO: codegen + public static com.amazonaws.s3.model.PutObjectRequest toCrtPutObjectRequest(PutObjectRequest sdkPutObject) { + return com.amazonaws.s3.model.PutObjectRequest.builder() + .contentLength(sdkPutObject.contentLength()) + .aCL(ObjectCannedACL.fromValue(sdkPutObject.aclAsString())) + .bucket(sdkPutObject.bucket()) + .key(sdkPutObject.key()) + .bucketKeyEnabled(sdkPutObject.bucketKeyEnabled()) + .cacheControl(sdkPutObject.cacheControl()) + .contentDisposition(sdkPutObject.contentDisposition()) + .contentEncoding(sdkPutObject.contentEncoding()) + .contentLanguage(sdkPutObject.contentLanguage()) + .contentMD5(sdkPutObject.contentMD5()) + .contentType(sdkPutObject.contentType()) + .expectedBucketOwner(sdkPutObject.expectedBucketOwner()) + .expires(sdkPutObject.expires()) + .grantFullControl(sdkPutObject.grantFullControl()) + .grantRead(sdkPutObject.grantRead()) + .grantReadACP(sdkPutObject.grantReadACP()) + .grantWriteACP(sdkPutObject.grantWriteACP()) + .metadata(sdkPutObject.metadata()) + .objectLockLegalHoldStatus(ObjectLockLegalHoldStatus.fromValue(sdkPutObject.objectLockLegalHoldStatusAsString())) + .objectLockMode(ObjectLockMode.fromValue(sdkPutObject.objectLockModeAsString())) + .objectLockRetainUntilDate(sdkPutObject.objectLockRetainUntilDate()) + .requestPayer(RequestPayer.fromValue(sdkPutObject.requestPayerAsString())) + .serverSideEncryption(ServerSideEncryption.fromValue(sdkPutObject.requestPayerAsString())) + .sSECustomerAlgorithm(sdkPutObject.sseCustomerAlgorithm()) + .sSECustomerKey(sdkPutObject.sseCustomerKey()) + .sSECustomerKeyMD5(sdkPutObject.sseCustomerKeyMD5()) + .sSEKMSEncryptionContext(sdkPutObject.ssekmsEncryptionContext()) + .sSEKMSKeyId(sdkPutObject.ssekmsKeyId()) + .storageClass(StorageClass.fromValue(sdkPutObject.storageClassAsString())) + .tagging(sdkPutObject.tagging()) + .websiteRedirectLocation(sdkPutObject.websiteRedirectLocation()) + .build(); + } + + //TODO: codegen + public static PutObjectResponse fromCrtPutObjectOutput(PutObjectOutput crtPutObjectOutput) { + // TODO: Provide the HTTP request-level data (e.g. response metadata, HTTP response) + PutObjectResponse.Builder builder = PutObjectResponse.builder() + .bucketKeyEnabled(crtPutObjectOutput.bucketKeyEnabled()) + .eTag(crtPutObjectOutput.eTag()) + .expiration(crtPutObjectOutput.expiration()) + .sseCustomerAlgorithm(crtPutObjectOutput.sSECustomerAlgorithm()) + .sseCustomerKeyMD5(crtPutObjectOutput.sSECustomerKeyMD5()) + .ssekmsEncryptionContext(crtPutObjectOutput.sSEKMSEncryptionContext()) + .ssekmsKeyId(crtPutObjectOutput.sSEKMSKeyId()) + .versionId(crtPutObjectOutput.versionId()); + + if (crtPutObjectOutput.requestCharged() != null) { + builder.requestCharged(crtPutObjectOutput.requestCharged().value()); + } + + if (crtPutObjectOutput.serverSideEncryption() != null) { + builder.serverSideEncryption(crtPutObjectOutput.serverSideEncryption().value()); + } + + return builder.build(); } } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/s3crt/RequestDataSupplierAdapter.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/s3crt/RequestDataSupplierAdapter.java new file mode 100644 index 000000000000..0d2bc99d4b1b --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/s3crt/RequestDataSupplierAdapter.java @@ -0,0 +1,338 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.s3crt; + +import com.amazonaws.s3.RequestDataSupplier; +import java.nio.ByteBuffer; +import java.util.Deque; +import java.util.concurrent.BlockingDeque; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.utils.Logger; + +/** + * Adapts an SDK {@link software.amazon.awssdk.core.async.AsyncRequestBody} to CRT's {@link RequestDataSupplier}. + */ +@SdkInternalApi +public final class RequestDataSupplierAdapter implements RequestDataSupplier { + private static final Logger LOG = Logger.loggerFor(RequestDataSupplierAdapter.class); + + static final long DEFAULT_REQUEST_SIZE = 8; + + private final AtomicReference subscriptionStatus = new AtomicReference<>(SubscriptionStatus.NOT_SUBSCRIBED); + private final BlockingQueue subscriptionQueue = new LinkedBlockingQueue<>(1); + private final BlockingDeque eventBuffer = new LinkedBlockingDeque<>(); + + private final Publisher bodyPublisher; + + // Not volatile, we synchronize on the subscriptionQueue + private Subscription subscription; + + // TODO: not volatile since it's read and written only by CRT thread(s). Need to + // ensure that CRT actually ensures consistency across their threads... + private Subscriber subscriber; + private long pending = 0; + + public RequestDataSupplierAdapter(Publisher bodyPublisher) { + this.bodyPublisher = bodyPublisher; + this.subscriber = createSubscriber(); + } + + @Override + public boolean getRequestBytes(ByteBuffer outBuffer) { + LOG.debug(() -> "Getting data to fill buffer of size " + outBuffer.remaining()); + + // Per the spec, onSubscribe is always called before any other + // signal, so we expect a subscription to always be provided; we just + // wait for that to happen + waitForSubscription(); + + // The "event loop". Per the spec, the sequence of events is "onSubscribe onNext* (onError | onComplete)?". + // We don't handle onSubscribe as a discrete event; instead we only enter this loop once we have a + // subscription. + // + // This works by requesting and consuming DATA events until we fill the buffer. We return from the method if + // we encounter either of the terminal events, COMPLETE or ERROR. + while (true) { + // The supplier API requires that we fill the buffer entirely. + if (!outBuffer.hasRemaining()) { + break; + } + + if (eventBuffer.isEmpty() && pending == 0) { + pending = DEFAULT_REQUEST_SIZE; + subscription.request(pending); + } + + Event ev = takeFirstEvent(); + + // Discard the event if it's not for the current subscriber + if (!ev.subscriber().equals(subscriber)) { + LOG.debug(() -> "Received an event for a previous publisher. Discarding. Event was: " + ev); + continue; + } + + switch (ev.type()) { + case DATA: + ByteBuffer srcBuffer = ((DataEvent) ev).data(); + + ByteBuffer bufferToWrite = srcBuffer.duplicate(); + int nBytesToWrite = Math.min(outBuffer.remaining(), srcBuffer.remaining()); + + // src is larger, create a resized view to prevent + // buffer overflow in the subsequent put() call + if (bufferToWrite.remaining() > nBytesToWrite) { + bufferToWrite.limit(bufferToWrite.position() + nBytesToWrite); + } + + outBuffer.put(bufferToWrite); + srcBuffer.position(bufferToWrite.limit()); + + if (!srcBuffer.hasRemaining()) { + --pending; + } else { + eventBuffer.push(ev); + } + + break; + + case COMPLETE: + // Leave this event in the queue so that if getRequestData + // gets call after the stream is already done, we pop it off again. + eventBuffer.push(ev); + pending = 0; + return true; + + case ERROR: + // Leave this event in the queue so that if getRequestData + // gets call after the stream is already done, we pop it off again. + eventBuffer.push(ev); + Throwable t = ((ErrorEvent) ev).error(); + if (t instanceof RuntimeException) { + throw (RuntimeException) t; + } + throw new RuntimeException(t); + + default: + // In case new event types are introduced that this loop doesn't account for + throw new IllegalStateException("Unknown event type: " + ev.type()); + } + } + + return false; + } + + @Override + public boolean resetPosition() { + subscription.cancel(); + subscription = null; + + this.subscriber = createSubscriber(); + subscriptionStatus.set(SubscriptionStatus.NOT_SUBSCRIBED); + + // NOTE: It's possible that even after this happens, eventBuffer gets + // residual events from the canceled subscription if the publisher + // handles cancel asynchronously. That doesn't affect us too much since + // we always ensure the event is for the current subscriber. + eventBuffer.clear(); + pending = 0; + + return true; + } + + private Event takeFirstEvent() { + try { + return eventBuffer.takeFirst(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while waiting for next event", e); + } + } + + public SubscriberImpl createSubscriber() { + return new SubscriberImpl(this::setSubscription, eventBuffer); + } + + private void setSubscription(Subscription subscription) { + if (subscriptionStatus.compareAndSet(SubscriptionStatus.SUBSCRIBING, SubscriptionStatus.SUBSCRIBED)) { + subscriptionQueue.add(subscription); + } else { + LOG.error(() -> "The supplier stopped waiting for the subscription. This is likely because it took " + + "longer than the timeout to arrive. Cancelling the subscription"); + subscription.cancel(); + } + } + + static class SubscriberImpl implements Subscriber { + private final Consumer subscriptionSetter; + private final Deque eventBuffer; + private boolean subscribed = false; + + public SubscriberImpl(Consumer subscriptionSetter, Deque eventBuffer) { + this.subscriptionSetter = subscriptionSetter; + this.eventBuffer = eventBuffer; + } + + @Override + public void onSubscribe(Subscription subscription) { + if (subscription == null) { + throw new NullPointerException("Subscription must not be null"); + } + + if (subscribed) { + subscription.cancel(); + return; + } + + subscriptionSetter.accept(subscription); + subscribed = true; + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + if (byteBuffer == null) { + throw new NullPointerException("byteBuffer must not be null"); + } + LOG.debug(() -> "Received new data of size: " + byteBuffer.remaining()); + eventBuffer.add(new DataEvent(this, byteBuffer)); + } + + @Override + public void onError(Throwable throwable) { + eventBuffer.add(new ErrorEvent(this, throwable)); + } + + @Override + public void onComplete() { + eventBuffer.add(new CompleteEvent(this)); + } + } + + private void waitForSubscription() { + if (!subscriptionStatus.compareAndSet(SubscriptionStatus.NOT_SUBSCRIBED, SubscriptionStatus.SUBSCRIBING)) { + return; + } + + bodyPublisher.subscribe(this.subscriber); + + try { + this.subscription = subscriptionQueue.poll(5, TimeUnit.SECONDS); + if (subscription == null) { + if (!subscriptionStatus.compareAndSet(SubscriptionStatus.SUBSCRIBING, SubscriptionStatus.TIMED_OUT)) { + subscriptionQueue.take().cancel(); + } + + throw new RuntimeException("Publisher did not respond with a subscription within 5 seconds"); + } + } catch (InterruptedException e) { + LOG.error(() -> "Interrupted while waiting for subscription", e); + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while waiting for subscription", e); + } + } + + private enum EventType { + DATA, + COMPLETE, + ERROR + } + + private interface Event { + Subscriber subscriber(); + EventType type(); + } + + private static class DataEvent implements Event { + private final Subscriber subscriber; + private final ByteBuffer data; + + public DataEvent(Subscriber subscriber, ByteBuffer data) { + this.subscriber = subscriber; + this.data = data; + } + + @Override + public Subscriber subscriber() { + return subscriber; + } + + @Override + public EventType type() { + return EventType.DATA; + } + + public final ByteBuffer data() { + return data; + } + } + + private static class CompleteEvent implements Event { + private final Subscriber subscriber; + + public CompleteEvent(Subscriber subscriber) { + this.subscriber = subscriber; + } + + @Override + public Subscriber subscriber() { + return subscriber; + } + + @Override + public EventType type() { + return EventType.COMPLETE; + } + } + + private static class ErrorEvent implements Event { + private final Subscriber subscriber; + private final Throwable error; + + public ErrorEvent(Subscriber subscriber, Throwable error) { + this.subscriber = subscriber; + this.error = error; + } + + @Override + public Subscriber subscriber() { + return subscriber; + } + + @Override + public EventType type() { + return EventType.ERROR; + } + + public final Throwable error() { + return error; + } + } + + private enum SubscriptionStatus { + NOT_SUBSCRIBED, + SUBSCRIBING, + SUBSCRIBED, + TIMED_OUT + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3crt/RequestDataSupplierAdapterTckTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3crt/RequestDataSupplierAdapterTckTest.java new file mode 100644 index 000000000000..8b9bc8f37895 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3crt/RequestDataSupplierAdapterTckTest.java @@ -0,0 +1,75 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.s3crt; + +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.reactivestreams.tck.SubscriberWhiteboxVerification; +import org.reactivestreams.tck.TestEnvironment; + +public class RequestDataSupplierAdapterTckTest extends SubscriberWhiteboxVerification { + private static final byte[] CONTENT = new byte[16]; + + protected RequestDataSupplierAdapterTckTest() { + super(new TestEnvironment()); + } + + @Override + public Subscriber createSubscriber(WhiteboxSubscriberProbe whiteboxSubscriberProbe) { + return new RequestDataSupplierAdapter.SubscriberImpl((s) -> {}, new ArrayDeque<>()) { + @Override + public void onSubscribe(Subscription subscription) { + super.onSubscribe(subscription); + whiteboxSubscriberProbe.registerOnSubscribe(new SubscriberPuppet() { + @Override + public void triggerRequest(long l) { + subscription.request(l); + } + + @Override + public void signalCancel() { + subscription.cancel(); + } + }); + } + + @Override + public void onNext(ByteBuffer bb) { + super.onNext(bb); + whiteboxSubscriberProbe.registerOnNext(bb); + } + + @Override + public void onError(Throwable t) { + super.onError(t); + whiteboxSubscriberProbe.registerOnError(t); + } + + @Override + public void onComplete() { + super.onComplete(); + whiteboxSubscriberProbe.registerOnComplete(); + } + }; + } + + @Override + public ByteBuffer createElement(int i) { + return ByteBuffer.wrap(CONTENT); + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3crt/RequestDataSupplierAdapterTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3crt/RequestDataSupplierAdapterTest.java new file mode 100644 index 000000000000..f194a5471a28 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3crt/RequestDataSupplierAdapterTest.java @@ -0,0 +1,158 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.s3crt; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import io.reactivex.Flowable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.junit.Test; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import software.amazon.awssdk.core.async.AsyncRequestBody; + +public class RequestDataSupplierAdapterTest { + + @Test + public void getRequestData_fillsInputBuffer_publisherBuffersAreSmaller() { + int inputBufferSize = 16; + + List data = Stream.generate(() -> (byte) 42) + .limit(inputBufferSize) + .map(b -> { + ByteBuffer bb = ByteBuffer.allocate(1); + bb.put(b); + bb.flip(); + return bb; + }) + .collect(Collectors.toList()); + + AsyncRequestBody requestBody = AsyncRequestBody.fromPublisher(Flowable.fromIterable(data)); + + RequestDataSupplierAdapter adapter = new RequestDataSupplierAdapter(requestBody); + + ByteBuffer inputBuffer = ByteBuffer.allocate(inputBufferSize); + adapter.getRequestBytes(inputBuffer); + + assertThat(inputBuffer.remaining()).isEqualTo(0); + } + + @Test + public void getRequestData_fillsInputBuffer_publisherBuffersAreLarger() { + int bodySize = 16; + + ByteBuffer data = ByteBuffer.allocate(bodySize); + data.put(new byte[bodySize]); + data.flip(); + + AsyncRequestBody requestBody = AsyncRequestBody.fromPublisher(Flowable.just(data)); + + RequestDataSupplierAdapter adapter = new RequestDataSupplierAdapter(requestBody); + + ByteBuffer inputBuffer = ByteBuffer.allocate(1); + + for (int i = 0; i < bodySize; ++i) { + adapter.getRequestBytes(inputBuffer); + assertThat(inputBuffer.remaining()).isEqualTo(0); + inputBuffer.flip(); + } + } + + @Test + public void getRequestData_publisherThrows_surfacesException() { + Publisher errorPublisher = Flowable.error(new RuntimeException("Something wrong happened")); + + AsyncRequestBody requestBody = AsyncRequestBody.fromPublisher(errorPublisher); + RequestDataSupplierAdapter adapter = new RequestDataSupplierAdapter(requestBody); + + assertThatThrownBy(() -> adapter.getRequestBytes(ByteBuffer.allocate(16))) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("Something wrong happened"); + } + + @Test + public void getRequestData_publisherThrows_wrapsExceptionIfNotRuntimeException() { + Publisher errorPublisher = Flowable.error(new IOException("Some I/O error happened")); + + AsyncRequestBody requestBody = AsyncRequestBody.fromPublisher(errorPublisher); + RequestDataSupplierAdapter adapter = new RequestDataSupplierAdapter(requestBody); + + assertThatThrownBy(() -> adapter.getRequestBytes(ByteBuffer.allocate(16))) + .isInstanceOf(RuntimeException.class) + .hasCauseInstanceOf(IOException.class); + } + + @Test + public void resetMidStream_discardsBufferedData() { + long requestSize = RequestDataSupplierAdapter.DEFAULT_REQUEST_SIZE; + int inputBufferSize = 16; + + Publisher requestBody = new Publisher() { + private byte value = 0; + + @Override + public void subscribe(Subscriber subscriber) { + byte byteVal = value++; + + List dataList = Stream.generate(() -> { + byte[] data = new byte[inputBufferSize]; + Arrays.fill(data, byteVal); + return ByteBuffer.wrap(data); + }) + .limit(requestSize) + .collect(Collectors.toList()); + + Flowable realPublisher = Flowable.fromIterable(dataList); + + realPublisher.subscribe(subscriber); + } + }; + + RequestDataSupplierAdapter adapter = new RequestDataSupplierAdapter(requestBody); + + long resetAfter = requestSize / 2; + + ByteBuffer inputBuffer = ByteBuffer.allocate(inputBufferSize); + + for (long l = 0; l < resetAfter; ++l) { + adapter.getRequestBytes(inputBuffer); + inputBuffer.flip(); + } + + adapter.resetPosition(); + + byte[] expectedBufferContent = new byte[inputBufferSize]; + Arrays.fill(expectedBufferContent, (byte) 1); + + byte[] readBuffer = new byte[inputBufferSize]; + for (int l = 0; l < requestSize; ++l) { + adapter.getRequestBytes(inputBuffer); + // flip for reading + inputBuffer.flip(); + inputBuffer.get(readBuffer); + + // flip for writing + inputBuffer.flip(); + + assertThat(readBuffer).isEqualTo(expectedBufferContent); + } + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/utils/ChecksumUtils.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/utils/ChecksumUtils.java new file mode 100644 index 000000000000..960f7314ee2c --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/utils/ChecksumUtils.java @@ -0,0 +1,61 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.utils; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.List; + +/** + * Utilities for computing the SHA-256 checksums of various binary objects. + */ +public final class ChecksumUtils { + public static byte[] computeCheckSum(InputStream is) throws IOException, NoSuchAlgorithmException { + MessageDigest instance = MessageDigest.getInstance("SHA-256"); + + byte buff[] = new byte[16384]; + int read; + while ((read = is.read(buff)) != -1) { + instance.update(buff, 0, read); + } + + return instance.digest(); + } + + public static byte[] computeCheckSum(ByteBuffer bb) throws NoSuchAlgorithmException { + MessageDigest instance = MessageDigest.getInstance("SHA-256"); + + instance.update(bb); + + bb.rewind(); + + return instance.digest(); + } + + public static byte[] computeCheckSum(List buffers) throws NoSuchAlgorithmException { + MessageDigest instance = MessageDigest.getInstance("SHA-256"); + + buffers.forEach(bb -> { + instance.update(bb); + bb.rewind(); + }); + + return instance.digest(); + } +}