diff --git a/.changes/next-release/feature-NettyNIOHTTPClient-8d0be28.json b/.changes/next-release/feature-NettyNIOHTTPClient-8d0be28.json new file mode 100644 index 000000000000..7c87f772028a --- /dev/null +++ b/.changes/next-release/feature-NettyNIOHTTPClient-8d0be28.json @@ -0,0 +1,5 @@ +{ + "category": "Netty NIO HTTP Client", + "type": "feature", + "description": "Add ability to to use HTTP proxies with the Netty async client." +} diff --git a/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/NettyNioAsyncHttpClient.java b/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/NettyNioAsyncHttpClient.java index 3e0c564f47ce..ff76b6720657 100644 --- a/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/NettyNioAsyncHttpClient.java +++ b/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/NettyNioAsyncHttpClient.java @@ -91,6 +91,7 @@ private NettyNioAsyncHttpClient(DefaultBuilder builder, AttributeMap serviceDefa .maxStreams(maxStreams) .sdkEventLoopGroup(sdkEventLoopGroup) .sslProvider(resolveSslProvider(builder)) + .proxyConfiguration(builder.proxyConfiguration) .build(); } @@ -343,6 +344,16 @@ public interface Builder extends SdkAsyncHttpClient.Builder { + private final String scheme; + private final String host; + private final int port; + private final Set nonProxyHosts; + + private ProxyConfiguration(BuilderImpl builder) { + this.scheme = builder.scheme; + this.host = builder.host; + this.port = builder.port; + this.nonProxyHosts = Collections.unmodifiableSet(builder.nonProxyHosts); + } + + /** + * @return The proxy scheme. + */ + public String scheme() { + return scheme; + } + + /** + * @return The proxy host. + */ + public String host() { + return host; + } + + /** + * @return The proxy port. + */ + public int port() { + return port; + } + + /** + * @return The set of hosts that should not be proxied. + */ + public Set nonProxyHosts() { + return nonProxyHosts; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + ProxyConfiguration that = (ProxyConfiguration) o; + + if (port != that.port) { + return false; + } + + if (scheme != null ? !scheme.equals(that.scheme) : that.scheme != null) { + return false; + } + + if (host != null ? !host.equals(that.host) : that.host != null) { + return false; + } + + return nonProxyHosts.equals(that.nonProxyHosts); + + } + + @Override + public int hashCode() { + int result = scheme != null ? scheme.hashCode() : 0; + result = 31 * result + (host != null ? host.hashCode() : 0); + result = 31 * result + port; + result = 31 * result + nonProxyHosts.hashCode(); + return result; + } + + @Override + public Builder toBuilder() { + return new BuilderImpl(this); + } + + public static Builder builder() { + return new BuilderImpl(); + } + + /** + * Builder for {@link ProxyConfiguration}. + */ + public interface Builder extends CopyableBuilder { + + /** + * Set the hostname of the proxy. + * @param host The proxy host. + * @return This object for method chaining. + */ + Builder host(String host); + + /** + * Set the port that the proxy expects connections on. + * @param port The proxy port. + * @return This object for method chaining. + */ + Builder port(int port); + + /** + * The HTTP scheme to use for connecting to the proxy. Valid values are {@code http} and {@code https}. + *

+ * The client defaults to {@code http} if none is given. + * + * @param scheme The proxy scheme. + * @return This object for method chaining. + */ + Builder scheme(String scheme); + + /** + * Set the set of hosts that should not be proxied. Any request whose host portion matches any of the patterns + * given in the set will be sent to the remote host directly instead of through the proxy. + * + * @param nonProxyHosts The set of hosts that should not be proxied. + * @return This object for method chaining. + */ + Builder nonProxyHosts(Set nonProxyHosts); + } + + private static final class BuilderImpl implements Builder { + private String scheme; + private String host; + private int port; + private Set nonProxyHosts = Collections.emptySet(); + + private BuilderImpl() { + } + + private BuilderImpl(ProxyConfiguration proxyConfiguration) { + this.scheme = proxyConfiguration.scheme; + this.host = proxyConfiguration.host; + this.port = proxyConfiguration.port; + this.nonProxyHosts = new HashSet<>(proxyConfiguration.nonProxyHosts); + } + + @Override + public Builder scheme(String scheme) { + this.scheme = scheme; + return this; + } + + @Override + public Builder host(String host) { + this.host = host; + return this; + } + + @Override + public Builder port(int port) { + this.port = port; + return this; + } + + @Override + public Builder nonProxyHosts(Set nonProxyHosts) { + if (nonProxyHosts != null) { + this.nonProxyHosts = new HashSet<>(nonProxyHosts); + } else { + this.nonProxyHosts = Collections.emptySet(); + } + return this; + } + + @Override + public ProxyConfiguration build() { + return new ProxyConfiguration(this); + } + } +} diff --git a/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/AwaitCloseChannelPoolMap.java b/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/AwaitCloseChannelPoolMap.java index 9e8232729228..c75617401b08 100644 --- a/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/AwaitCloseChannelPoolMap.java +++ b/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/AwaitCloseChannelPoolMap.java @@ -21,17 +21,21 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelOption; import io.netty.channel.pool.ChannelPool; +import io.netty.channel.pool.ChannelPoolHandler; import io.netty.handler.codec.http2.Http2SecurityUtil; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.SslProvider; import io.netty.handler.ssl.SupportedCipherSuiteFilter; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; -import io.netty.util.concurrent.Future; -import io.netty.util.concurrent.Promise; + +import java.net.InetSocketAddress; import java.net.URI; +import java.net.URISyntaxException; import java.util.Collection; +import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; @@ -41,6 +45,7 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.annotations.SdkTestInternalApi; import software.amazon.awssdk.http.Protocol; +import software.amazon.awssdk.http.nio.netty.ProxyConfiguration; import software.amazon.awssdk.http.nio.netty.SdkEventLoopGroup; import software.amazon.awssdk.http.nio.netty.internal.http2.HttpOrHttp2ChannelPool; import software.amazon.awssdk.utils.Logger; @@ -49,17 +54,34 @@ * Implementation of {@link SdkChannelPoolMap} that awaits channel pools to be closed upon closing. */ @SdkInternalApi -public final class AwaitCloseChannelPoolMap extends SdkChannelPoolMap { +public final class AwaitCloseChannelPoolMap extends SdkChannelPoolMap { private static final Logger log = Logger.loggerFor(AwaitCloseChannelPoolMap.class); + private static final ChannelPoolHandler NOOP_HANDLER = new ChannelPoolHandler() { + @Override + public void channelReleased(Channel ch) throws Exception { + } + + @Override + public void channelAcquired(Channel ch) throws Exception { + } + + @Override + public void channelCreated(Channel ch) throws Exception { + } + }; + + private final Map shouldProxyForHostCache = new ConcurrentHashMap<>(); + + private final SdkChannelOptions sdkChannelOptions; private final SdkEventLoopGroup sdkEventLoopGroup; private final NettyConfiguration configuration; private final Protocol protocol; private final long maxStreams; private final SslProvider sslProvider; + private final ProxyConfiguration proxyConfiguration; private AwaitCloseChannelPoolMap(Builder builder) { this.sdkChannelOptions = builder.sdkChannelOptions; @@ -68,6 +90,13 @@ private AwaitCloseChannelPoolMap(Builder builder) { this.protocol = builder.protocol; this.maxStreams = builder.maxStreams; this.sslProvider = builder.sslProvider; + this.proxyConfiguration = builder.proxyConfiguration; + } + + @SdkTestInternalApi + AwaitCloseChannelPoolMap(Builder builder, Map shouldProxyForHostCache) { + this(builder); + this.shouldProxyForHostCache.putAll(shouldProxyForHostCache); } public static Builder builder() { @@ -76,24 +105,30 @@ public static Builder builder() { @Override protected SimpleChannelPoolAwareChannelPool newPool(URI key) { - SslContext sslContext = sslContext(key.getScheme()); - Bootstrap bootstrap = - new Bootstrap() - .group(sdkEventLoopGroup.eventLoopGroup()) - .channelFactory(sdkEventLoopGroup.channelFactory()) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, configuration.connectTimeoutMillis()) - // TODO run some performance tests with and without this. - .remoteAddress(key.getHost(), key.getPort()); - sdkChannelOptions.channelOptions().forEach(bootstrap::option); + SslContext sslContext = sslContext(key); + + Bootstrap bootstrap = createBootstrap(key); AtomicReference channelPoolRef = new AtomicReference<>(); - ChannelPipelineInitializer handler = + + ChannelPipelineInitializer pipelineInitializer = new ChannelPipelineInitializer(protocol, sslContext, maxStreams, channelPoolRef, configuration, key); - BetterSimpleChannelPool simpleChannelPool = new BetterSimpleChannelPool(bootstrap, handler); + BetterSimpleChannelPool tcpChannelPool; + ChannelPool baseChannelPool; + if (shouldUseProxyForHost(key)) { + tcpChannelPool = new BetterSimpleChannelPool(bootstrap, NOOP_HANDLER); + baseChannelPool = new Http1TunnelConnectionPool(bootstrap.config().group().next(), tcpChannelPool, + sslContext, proxyAddress(key), key, pipelineInitializer); + } else { + tcpChannelPool = new BetterSimpleChannelPool(bootstrap, pipelineInitializer); + baseChannelPool = tcpChannelPool; + } + + ChannelPool wrappedPool = wrapBaseChannelPool(bootstrap, baseChannelPool); - channelPoolRef.set(wrapSimpleChannelPool(bootstrap, simpleChannelPool)); - return new SimpleChannelPoolAwareChannelPool(simpleChannelPool, channelPoolRef.get()); + channelPoolRef.set(wrappedPool); + return new SimpleChannelPoolAwareChannelPool(wrappedPool, tcpChannelPool); } @Override @@ -110,7 +145,7 @@ public void close() { try { CompletableFuture.allOf(channelPools.stream() - .map(pool -> pool.underlyingSimpleChannelPool.closeFuture()) + .map(pool -> pool.underlyingSimpleChannelPool().closeFuture()) .toArray(CompletableFuture[]::new)) .get(CHANNEL_POOL_CLOSE_TIMEOUT_SECONDS, TimeUnit.SECONDS); } catch (InterruptedException e) { @@ -121,7 +156,67 @@ public void close() { } } - private ChannelPool wrapSimpleChannelPool(Bootstrap bootstrap, ChannelPool channelPool) { + private Bootstrap createBootstrap(URI poolKey) { + String host = bootstrapHost(poolKey); + int port = bootstrapPort(poolKey); + + Bootstrap bootstrap = + new Bootstrap() + .group(sdkEventLoopGroup.eventLoopGroup()) + .channelFactory(sdkEventLoopGroup.channelFactory()) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, configuration.connectTimeoutMillis()) + // TODO run some performance tests with and without this. + .remoteAddress(new InetSocketAddress(host, port)); + sdkChannelOptions.channelOptions().forEach(bootstrap::option); + + return bootstrap; + } + + + private boolean shouldUseProxyForHost(URI remoteAddr) { + if (proxyConfiguration == null) { + return false; + } + + + return shouldProxyForHostCache.computeIfAbsent(remoteAddr, (uri) -> + proxyConfiguration.nonProxyHosts().stream().noneMatch(h -> uri.getHost().matches(h)) + ); + } + + private String bootstrapHost(URI remoteHost) { + if (shouldUseProxyForHost(remoteHost)) { + return proxyConfiguration.host(); + } + return remoteHost.getHost(); + } + + private int bootstrapPort(URI remoteHost) { + if (shouldUseProxyForHost(remoteHost)) { + return proxyConfiguration.port(); + } + return remoteHost.getPort(); + } + + private URI proxyAddress(URI remoteHost) { + if (!shouldUseProxyForHost(remoteHost)) { + return null; + } + + String scheme = proxyConfiguration.scheme(); + if (scheme == null) { + scheme = "http"; + } + + try { + return new URI(scheme, null, proxyConfiguration.host(), proxyConfiguration.port(), null, null, + null); + } catch (URISyntaxException e) { + throw new RuntimeException("Unable to construct proxy URI", e); + } + } + + private ChannelPool wrapBaseChannelPool(Bootstrap bootstrap, ChannelPool channelPool) { // Wrap the channel pool such that the ChannelAttributeKey.CLOSE_ON_RELEASE flag is honored. channelPool = new HonorCloseOnReleaseChannelPool(channelPool); @@ -150,10 +245,16 @@ private ChannelPool wrapSimpleChannelPool(Bootstrap bootstrap, ChannelPool chann return channelPool; } - private SslContext sslContext(String protocol) { - if (!protocol.equalsIgnoreCase("https")) { + private SslContext sslContext(URI targetAddress) { + URI proxyAddress = proxyAddress(targetAddress); + + boolean needContext = targetAddress.getScheme().equalsIgnoreCase("https") + || proxyAddress != null && proxyAddress.getScheme().equalsIgnoreCase("https"); + + if (!needContext) { return null; } + try { return SslContextBuilder.forClient() .sslProvider(sslProvider) @@ -169,47 +270,6 @@ private TrustManagerFactory getTrustManager() { return configuration.trustAllCertificates() ? InsecureTrustManagerFactory.INSTANCE : null; } - static final class SimpleChannelPoolAwareChannelPool implements ChannelPool { - private final BetterSimpleChannelPool underlyingSimpleChannelPool; - private final ChannelPool actualChannelPool; - - private SimpleChannelPoolAwareChannelPool(BetterSimpleChannelPool underlyingSimpleChannelPool, - ChannelPool actualChannelPool) { - this.underlyingSimpleChannelPool = underlyingSimpleChannelPool; - this.actualChannelPool = actualChannelPool; - } - - @Override - public Future acquire() { - return actualChannelPool.acquire(); - } - - @Override - public Future acquire(Promise promise) { - return actualChannelPool.acquire(promise); - } - - @Override - public Future release(Channel channel) { - return actualChannelPool.release(channel); - } - - @Override - public Future release(Channel channel, Promise promise) { - return actualChannelPool.release(channel, promise); - } - - @Override - public void close() { - actualChannelPool.close(); - } - - @SdkTestInternalApi - BetterSimpleChannelPool underlyingSimpleChannelPool() { - return underlyingSimpleChannelPool; - } - } - public static class Builder { private SdkChannelOptions sdkChannelOptions; @@ -218,6 +278,7 @@ public static class Builder { private Protocol protocol; private long maxStreams; private SslProvider sslProvider; + private ProxyConfiguration proxyConfiguration; private Builder() { } @@ -252,6 +313,11 @@ public Builder sslProvider(SslProvider sslProvider) { return this; } + public Builder proxyConfiguration(ProxyConfiguration proxyConfiguration) { + this.proxyConfiguration = proxyConfiguration; + return this; + } + public AwaitCloseChannelPoolMap build() { return new AwaitCloseChannelPoolMap(this); } diff --git a/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/Http1TunnelConnectionPool.java b/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/Http1TunnelConnectionPool.java new file mode 100644 index 000000000000..1c051886090f --- /dev/null +++ b/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/Http1TunnelConnectionPool.java @@ -0,0 +1,165 @@ +/* + * Copyright 2010-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.nio.netty.internal; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler; +import io.netty.channel.EventLoop; +import io.netty.channel.pool.ChannelPool; +import io.netty.channel.pool.ChannelPoolHandler; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslHandler; +import io.netty.util.AttributeKey; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; + +import java.net.URI; + +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.annotations.SdkTestInternalApi; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.StringUtils; + +/** + * Connection pool that knows how to establish a tunnel using the HTTP CONNECT method. + */ +@SdkInternalApi +public class Http1TunnelConnectionPool implements ChannelPool { + static final AttributeKey TUNNEL_ESTABLISHED_KEY = AttributeKey.newInstance( + "aws.http.nio.netty.async.Http1TunnelConnectionPool.tunnelEstablished"); + + private static final Logger log = Logger.loggerFor(Http1TunnelConnectionPool.class); + + private final EventLoop eventLoop; + private final ChannelPool delegate; + private final SslContext sslContext; + private final URI proxyAddress; + private final URI remoteAddress; + private final ChannelPoolHandler handler; + private final InitHandlerSupplier initHandlerSupplier; + + public Http1TunnelConnectionPool(EventLoop eventLoop, ChannelPool delegate, SslContext sslContext, + URI proxyAddress, URI remoteAddress, ChannelPoolHandler handler) { + this(eventLoop, delegate, sslContext, proxyAddress, remoteAddress, handler, ProxyTunnelInitHandler::new); + + } + + @SdkTestInternalApi + Http1TunnelConnectionPool(EventLoop eventLoop, ChannelPool delegate, SslContext sslContext, + URI proxyAddress, URI remoteAddress, ChannelPoolHandler handler, + InitHandlerSupplier initHandlerSupplier) { + this.eventLoop = eventLoop; + this.delegate = delegate; + this.sslContext = sslContext; + this.proxyAddress = proxyAddress; + this.remoteAddress = remoteAddress; + this.handler = handler; + this.initHandlerSupplier = initHandlerSupplier; + } + + @Override + public Future acquire() { + return acquire(eventLoop.newPromise()); + } + + @Override + public Future acquire(Promise promise) { + delegate.acquire(eventLoop.newPromise()).addListener((Future f) -> { + if (f.isSuccess()) { + setupChannel(f.getNow(), promise); + } else { + promise.setFailure(f.cause()); + } + }); + return promise; + } + + @Override + public Future release(Channel channel) { + return release(channel, eventLoop.newPromise()); + } + + @Override + public Future release(Channel channel, Promise promise) { + return delegate.release(channel, promise); + } + + @Override + public void close() { + delegate.close(); + } + + private void setupChannel(Channel ch, Promise acquirePromise) { + if (isTunnelEstablished(ch)) { + log.debug(() -> String.format("Tunnel already established for %s", ch.id().asShortText())); + acquirePromise.setSuccess(ch); + return; + } + + log.debug(() -> String.format("Tunnel not yet established for channel %s. Establishing tunnel now.", + ch.id().asShortText())); + + Promise tunnelEstablishedPromise = eventLoop.newPromise(); + + SslHandler sslHandler = createSslHandlerIfNeeded(ch.alloc()); + if (sslHandler != null) { + ch.pipeline().addLast(sslHandler); + } + ch.pipeline().addLast(initHandlerSupplier.newInitHandler(delegate, remoteAddress, tunnelEstablishedPromise)); + + tunnelEstablishedPromise.addListener((Future f) -> { + if (f.isSuccess()) { + Channel tunnel = f.getNow(); + handler.channelCreated(tunnel); + tunnel.attr(TUNNEL_ESTABLISHED_KEY).set(true); + acquirePromise.setSuccess(tunnel); + } else { + ch.close(); + delegate.release(ch); + + Throwable cause = f.cause(); + log.error(() -> String.format("Unable to establish tunnel for channel %s", ch.id().asShortText()), cause); + acquirePromise.setFailure(cause); + } + }); + } + + private SslHandler createSslHandlerIfNeeded(ByteBufAllocator alloc) { + if (sslContext == null) { + return null; + } + + String scheme = proxyAddress.getScheme(); + + if (!"https".equals(StringUtils.lowerCase(scheme))) { + return null; + } + + return sslContext.newHandler(alloc, proxyAddress.getHost(), proxyAddress.getPort()); + } + + private static boolean isTunnelEstablished(Channel ch) { + Boolean established = ch.attr(TUNNEL_ESTABLISHED_KEY).get(); + return Boolean.TRUE.equals(established); + } + + @SdkTestInternalApi + @FunctionalInterface + interface InitHandlerSupplier { + ChannelHandler newInitHandler(ChannelPool sourcePool, URI remoteAddress, Promise tunnelInitFuture); + } +} diff --git a/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/ProxyTunnelInitHandler.java b/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/ProxyTunnelInitHandler.java new file mode 100644 index 000000000000..ff82a10ff584 --- /dev/null +++ b/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/ProxyTunnelInitHandler.java @@ -0,0 +1,113 @@ +/* + * Copyright 2010-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.nio.netty.internal; + +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.pool.ChannelPool; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.util.concurrent.Promise; +import java.io.IOException; +import java.net.URI; +import java.util.function.Supplier; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.annotations.SdkTestInternalApi; + +/** + * Handler that initializes the HTTP tunnel. + */ +@SdkInternalApi +public final class ProxyTunnelInitHandler extends ChannelDuplexHandler { + private final ChannelPool sourcePool; + private final URI remoteHost; + private final Promise initPromise; + private final Supplier httpCodecSupplier; + + public ProxyTunnelInitHandler(ChannelPool sourcePool, URI remoteHost, Promise initPromise) { + this(sourcePool, remoteHost, initPromise, HttpClientCodec::new); + } + + @SdkTestInternalApi + public ProxyTunnelInitHandler(ChannelPool sourcePool, URI remoteHost, Promise initPromise, + Supplier httpCodecSupplier) { + this.sourcePool = sourcePool; + this.remoteHost = remoteHost; + this.initPromise = initPromise; + this.httpCodecSupplier = httpCodecSupplier; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + ChannelPipeline pipeline = ctx.pipeline(); + pipeline.addBefore(ctx.name(), null, httpCodecSupplier.get()); + HttpRequest connectRequest = connectRequest(); + ctx.channel().writeAndFlush(connectRequest).addListener(f -> { + if (!f.isSuccess()) { + ctx.close(); + sourcePool.release(ctx.channel()); + initPromise.setFailure(new IOException("Unable to send CONNECT request to proxy", f.cause())); + } + }); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + if (ctx.pipeline().get(HttpClientCodec.class) != null) { + ctx.pipeline().remove(HttpClientCodec.class); + } + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof HttpResponse) { + HttpResponse response = (HttpResponse) msg; + if (response.status().code() == 200) { + ctx.pipeline().remove(this); + // Note: we leave the SslHandler here (if we added it) + initPromise.setSuccess(ctx.channel()); + return; + } + } + + // Fail if we received any other type of message or we didn't get a 200 from the proxy + ctx.pipeline().remove(this); + ctx.close(); + sourcePool.release(ctx.channel()); + initPromise.setFailure(new IOException("Could not connect to proxy")); + } + + private HttpRequest connectRequest() { + String uri = getUri(); + HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.CONNECT, uri, + Unpooled.EMPTY_BUFFER, false); + request.headers().add(HttpHeaderNames.HOST, uri); + return request; + } + + private String getUri() { + return remoteHost.getHost() + ":" + remoteHost.getPort(); + } +} + diff --git a/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/SimpleChannelPoolAwareChannelPool.java b/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/SimpleChannelPoolAwareChannelPool.java new file mode 100644 index 000000000000..c8a231b07258 --- /dev/null +++ b/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/SimpleChannelPoolAwareChannelPool.java @@ -0,0 +1,63 @@ +/* + * Copyright 2010-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.nio.netty.internal; + +import io.netty.channel.Channel; +import io.netty.channel.pool.ChannelPool; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; +import software.amazon.awssdk.annotations.SdkInternalApi; + +@SdkInternalApi +final class SimpleChannelPoolAwareChannelPool implements ChannelPool { + private final ChannelPool delegate; + private final BetterSimpleChannelPool simpleChannelPool; + + SimpleChannelPoolAwareChannelPool(ChannelPool delegate, BetterSimpleChannelPool simpleChannelPool) { + this.delegate = delegate; + this.simpleChannelPool = simpleChannelPool; + } + + @Override + public Future acquire() { + return delegate.acquire(); + } + + @Override + public Future acquire(Promise promise) { + return delegate.acquire(promise); + } + + @Override + public Future release(Channel channel) { + return delegate.release(channel); + } + + @Override + public Future release(Channel channel, Promise promise) { + return delegate.release(channel, promise); + } + + @Override + public void close() { + delegate.close(); + } + + public BetterSimpleChannelPool underlyingSimpleChannelPool() { + return simpleChannelPool; + } + +} diff --git a/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/NettyNioAsyncHttpClientWireMockTest.java b/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/NettyNioAsyncHttpClientWireMockTest.java index 8f6585dddc37..07951a02e562 100644 --- a/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/NettyNioAsyncHttpClientWireMockTest.java +++ b/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/NettyNioAsyncHttpClientWireMockTest.java @@ -42,7 +42,6 @@ import static org.mockito.Mockito.when; import com.github.tomakehurst.wiremock.http.Fault; -import com.github.tomakehurst.wiremock.http.trafficlistener.WiremockNetworkTrafficListener; import com.github.tomakehurst.wiremock.junit.WireMockRule; import io.netty.channel.Channel; import io.netty.channel.ChannelFactory; @@ -54,10 +53,8 @@ import io.netty.handler.ssl.SslProvider; import io.netty.util.AttributeKey; import java.io.IOException; -import java.net.Socket; import java.net.URI; import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.ArrayList; import java.util.Collection; @@ -441,7 +438,7 @@ public void requestContentOnlyEqualToContentLengthHeaderFromProvider() throws In // HTTP servers will stop processing the request as soon as it reads // bytes equal to 'Content-Length' so we need to inspect the raw // traffic to ensure that there wasn't anything after that. - assertThat(wiremockTrafficListener.requests.toString()).endsWith(content); + assertThat(wiremockTrafficListener.requests().toString()).endsWith(content); } @Test @@ -654,33 +651,4 @@ private static AttributeMap mapWithTrustAllCerts() { .put(SdkHttpConfigurationOption.TRUST_ALL_CERTIFICATES, true) .build(); } - - private static class RecordingNetworkTrafficListener implements WiremockNetworkTrafficListener { - private final StringBuilder requests = new StringBuilder(); - - - @Override - public void opened(Socket socket) { - - } - - @Override - public void incoming(Socket socket, ByteBuffer byteBuffer) { - requests.append(StandardCharsets.UTF_8.decode(byteBuffer)); - } - - @Override - public void outgoing(Socket socket, ByteBuffer byteBuffer) { - - } - - @Override - public void closed(Socket socket) { - - } - - public void reset() { - requests.setLength(0); - } - } } diff --git a/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/ProxyConfigurationTest.java b/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/ProxyConfigurationTest.java new file mode 100644 index 000000000000..753a2f4b6ac4 --- /dev/null +++ b/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/ProxyConfigurationTest.java @@ -0,0 +1,132 @@ +/* + * Copyright 2010-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.nio.netty; + +import static org.assertj.core.api.Assertions.assertThat; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.HashSet; +import java.util.Random; +import java.util.Set; +import java.util.stream.Stream; +import org.junit.Test; + +/** + * Tests for {@link ProxyConfiguration}. + */ +public class ProxyConfigurationTest { + private static final Random RNG = new Random(); + + @Test + public void build_setsAllProperties() { + verifyAllPropertiesSet(allPropertiesSetConfig()); + } + + @Test + public void toBuilder_roundTrip_producesExactCopy() { + ProxyConfiguration original = allPropertiesSetConfig(); + + ProxyConfiguration copy = original.toBuilder().build(); + + assertThat(copy).isEqualTo(original); + } + + @Test + public void setNonProxyHostsToNull_createsEmptySet() { + ProxyConfiguration cfg = ProxyConfiguration.builder() + .nonProxyHosts(null) + .build(); + + assertThat(cfg.nonProxyHosts()).isEmpty(); + } + + @Test + public void toBuilderModified_doesNotModifySource() { + ProxyConfiguration original = allPropertiesSetConfig(); + + ProxyConfiguration modified = setAllPropertiesToRandomValues(original.toBuilder()).build(); + + assertThat(original).isNotEqualTo(modified); + } + + private ProxyConfiguration allPropertiesSetConfig() { + return setAllPropertiesToRandomValues(ProxyConfiguration.builder()).build(); + } + + private ProxyConfiguration.Builder setAllPropertiesToRandomValues(ProxyConfiguration.Builder builder) { + Stream.of(builder.getClass().getDeclaredMethods()) + .filter(m -> m.getParameterCount() == 1 && m.getReturnType().equals(ProxyConfiguration.Builder.class)) + .forEach(m -> { + try { + m.setAccessible(true); + setRandomValue(builder, m); + } catch (Exception e) { + throw new RuntimeException("Could not create random proxy config", e); + } + }); + return builder; + } + + private void setRandomValue(Object o, Method setter) throws InvocationTargetException, IllegalAccessException { + Class paramClass = setter.getParameterTypes()[0]; + + if (String.class.equals(paramClass)) { + setter.invoke(o, randomString()); + } else if (int.class.equals(paramClass)) { + setter.invoke(o, RNG.nextInt()); + } else if (Set.class.isAssignableFrom(paramClass)) { + setter.invoke(o, randomSet()); + } else { + throw new RuntimeException("Don't know how create random value for type " + paramClass); + } + } + + private void verifyAllPropertiesSet(ProxyConfiguration cfg) { + boolean hasNullProperty = Stream.of(cfg.getClass().getDeclaredMethods()) + .filter(m -> !m.getReturnType().equals(Void.class) && m.getParameterCount() == 0) + .anyMatch(m -> { + m.setAccessible(true); + try { + return m.invoke(cfg) == null; + } catch (Exception e) { + return true; + } + }); + + if (hasNullProperty) { + throw new RuntimeException("Given configuration has unset property"); + } + } + + private String randomString() { + String alpha = "abcdefghijklmnopqrstuwxyz"; + + StringBuilder sb = new StringBuilder(16); + for (int i = 0; i < 16; ++i) { + sb.append(alpha.charAt(RNG.nextInt(16))); + } + + return sb.toString(); + } + + private Set randomSet() { + Set ss = new HashSet<>(16); + for (int i = 0; i < 16; ++i) { + ss.add(randomString()); + } + return ss; + } +} diff --git a/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/ProxyWireMockTest.java b/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/ProxyWireMockTest.java new file mode 100644 index 000000000000..afab5f09a218 --- /dev/null +++ b/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/ProxyWireMockTest.java @@ -0,0 +1,137 @@ +/* + * Copyright 2010-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + + +package software.amazon.awssdk.http.nio.netty; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.get; +import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import com.github.tomakehurst.wiremock.WireMockServer; +import com.github.tomakehurst.wiremock.core.WireMockConfiguration; +import java.io.IOException; +import java.util.concurrent.CompletionException; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.http.async.AsyncExecuteRequest; +import software.amazon.awssdk.http.async.SdkAsyncHttpClient; +import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler; + +/** + * Tests for HTTP proxy functionality in the Netty client. + */ +public class ProxyWireMockTest { + private static SdkAsyncHttpClient client; + + private static ProxyConfiguration proxyCfg; + + private static WireMockServer mockServer = new WireMockServer(new WireMockConfiguration() + .dynamicPort() + .dynamicHttpsPort()); + + private static WireMockServer mockProxy = new WireMockServer(new WireMockConfiguration() + .dynamicPort() + .dynamicHttpsPort()); + + @BeforeClass + public static void setup() { + mockProxy.start(); + mockServer.start(); + + mockServer.stubFor(get(urlPathEqualTo("/")).willReturn(aResponse().withStatus(200).withBody("hello"))); + + proxyCfg = ProxyConfiguration.builder() + .host("localhost") + .port(mockProxy.port()) + .build(); + } + + @AfterClass + public static void teardown() { + mockServer.stop(); + mockProxy.stop(); + } + + @After + public void methodTeardown() { + if (client != null) { + client.close(); + } + client = null; + } + + @Test(expected = IOException.class) + public void proxyConfigured_attemptsToConnect() throws Throwable { + AsyncExecuteRequest req = AsyncExecuteRequest.builder() + .request(testSdkRequest()) + .responseHandler(mock(SdkAsyncHttpResponseHandler.class)) + .build(); + + client = NettyNioAsyncHttpClient.builder() + .proxyConfiguration(proxyCfg) + .build(); + + try { + client.execute(req).join(); + } catch (CompletionException e) { + Throwable cause = e.getCause(); + // WireMock doesn't allow for mocking the CONNECT method so it will just return a 404, causing the client + // to throw an exception. + assertThat(e.getCause().getMessage()).isEqualTo("Could not connect to proxy"); + throw cause; + } + } + + @Test + public void proxyConfigured_hostInNonProxySet_doesNotConnect() { + RecordingResponseHandler responseHandler = new RecordingResponseHandler(); + AsyncExecuteRequest req = AsyncExecuteRequest.builder() + .request(testSdkRequest()) + .responseHandler(responseHandler) + .requestContentPublisher(new EmptyPublisher()) + .build(); + + ProxyConfiguration cfg = proxyCfg.toBuilder() + .nonProxyHosts(Stream.of("localhost").collect(Collectors.toSet())) + .build(); + + client = NettyNioAsyncHttpClient.builder() + .proxyConfiguration(cfg) + .build(); + + client.execute(req).join(); + + responseHandler.completeFuture.join(); + assertThat(responseHandler.fullResponseAsString()).isEqualTo("hello"); + } + + private SdkHttpFullRequest testSdkRequest() { + return SdkHttpFullRequest.builder() + .method(SdkHttpMethod.GET) + .protocol("http") + .host("localhost") + .port(mockServer.port()) + .putHeader("host", "localhost") + .build(); + } +} diff --git a/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/RecordingNetworkTrafficListener.java b/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/RecordingNetworkTrafficListener.java new file mode 100644 index 000000000000..a9b3ba99e15c --- /dev/null +++ b/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/RecordingNetworkTrafficListener.java @@ -0,0 +1,58 @@ +/* + * Copyright 2010-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.nio.netty; + +import com.github.tomakehurst.wiremock.http.trafficlistener.WiremockNetworkTrafficListener; +import java.net.Socket; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; + +/** + * Simple implementation of {@link WiremockNetworkTrafficListener} to record all requests received as a string for later + * verification. + */ +public class RecordingNetworkTrafficListener implements WiremockNetworkTrafficListener { + private final StringBuilder requests = new StringBuilder(); + + + @Override + public void opened(Socket socket) { + + } + + @Override + public void incoming(Socket socket, ByteBuffer byteBuffer) { + requests.append(StandardCharsets.UTF_8.decode(byteBuffer)); + } + + @Override + public void outgoing(Socket socket, ByteBuffer byteBuffer) { + + } + + @Override + public void closed(Socket socket) { + + } + + public void reset() { + requests.setLength(0); + } + + public StringBuilder requests() { + return requests; + } +} \ No newline at end of file diff --git a/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/internal/AwaitCloseChannelPoolMapTest.java b/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/internal/AwaitCloseChannelPoolMapTest.java index 1f46438bad30..b19fe78e0500 100644 --- a/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/internal/AwaitCloseChannelPoolMapTest.java +++ b/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/internal/AwaitCloseChannelPoolMapTest.java @@ -16,51 +16,71 @@ package software.amazon.awssdk.http.nio.netty.internal; +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; import static org.assertj.core.api.Assertions.assertThat; import static software.amazon.awssdk.http.SdkHttpConfigurationOption.GLOBAL_HTTP_DEFAULTS; +import com.github.tomakehurst.wiremock.junit.WireMockRule; +import io.netty.channel.Channel; import io.netty.handler.ssl.SslProvider; +import io.netty.util.concurrent.Future; import java.net.URI; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; -import java.util.concurrent.ExecutionException; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.apache.commons.lang3.RandomStringUtils; -import org.junit.BeforeClass; +import org.junit.After; +import org.junit.Rule; import org.junit.Test; import software.amazon.awssdk.http.Protocol; +import software.amazon.awssdk.http.nio.netty.ProxyConfiguration; +import software.amazon.awssdk.http.nio.netty.RecordingNetworkTrafficListener; import software.amazon.awssdk.http.nio.netty.SdkEventLoopGroup; -import software.amazon.awssdk.http.nio.netty.internal.AwaitCloseChannelPoolMap.SimpleChannelPoolAwareChannelPool; public class AwaitCloseChannelPoolMapTest { - private static AwaitCloseChannelPoolMap channelPoolMap; + private final RecordingNetworkTrafficListener recorder = new RecordingNetworkTrafficListener(); + private AwaitCloseChannelPoolMap channelPoolMap; - @BeforeClass - public static void setup() { - channelPoolMap = AwaitCloseChannelPoolMap.builder() - .sdkChannelOptions(new SdkChannelOptions()) - .sdkEventLoopGroup(SdkEventLoopGroup.builder().build()) - .configuration(new NettyConfiguration(GLOBAL_HTTP_DEFAULTS)) - .protocol(Protocol.HTTP1_1) - .maxStreams(100) - .sslProvider(SslProvider.OPENSSL) - .build(); + @Rule + public WireMockRule mockProxy = new WireMockRule(wireMockConfig() + .dynamicPort() + .networkTrafficListener(recorder)); + + @After + public void methodTeardown() { + if (channelPoolMap != null) { + channelPoolMap.close(); + } + channelPoolMap = null; + + recorder.reset(); } @Test - public void close_underlyingPoolsShouldBeClosed() throws ExecutionException, InterruptedException { + public void close_underlyingPoolsShouldBeClosed() { + channelPoolMap = AwaitCloseChannelPoolMap.builder() + .sdkChannelOptions(new SdkChannelOptions()) + .sdkEventLoopGroup(SdkEventLoopGroup.builder().build()) + .configuration(new NettyConfiguration(GLOBAL_HTTP_DEFAULTS)) + .protocol(Protocol.HTTP1_1) + .maxStreams(100) + .sslProvider(SslProvider.OPENSSL) + .build(); int numberOfChannelPools = 5; List channelPools = new ArrayList<>(); for (int i = 0; i < numberOfChannelPools; i++) { channelPools.add( - channelPoolMap.get(URI.create("http://" + RandomStringUtils.randomAlphabetic(2) + i + "localhost:" + numberOfChannelPools))); + channelPoolMap.get(URI.create("http://" + RandomStringUtils.randomAlphabetic(2) + i + "localhost:" + numberOfChannelPools))); } assertThat(channelPoolMap.pools().size()).isEqualTo(numberOfChannelPools); - channelPoolMap.close(); channelPools.forEach(channelPool -> { assertThat(channelPool.underlyingSimpleChannelPool().closeFuture()).isDone(); @@ -68,4 +88,64 @@ public void close_underlyingPoolsShouldBeClosed() throws ExecutionException, Int }); } + @Test + public void usingProxy_usesCachedValueWhenPresent() { + URI targetUri = URI.create("https://some-awesome-service-1234.amazonaws.com"); + + Map shouldProxyCache = new HashMap<>(); + shouldProxyCache.put(targetUri, true); + + ProxyConfiguration proxyConfiguration = ProxyConfiguration.builder() + .host("localhost") + .port(mockProxy.port()) + // Deliberately set the target host as a non-proxy host to see if it will check the cache first + .nonProxyHosts(Stream.of(targetUri.getHost()).collect(Collectors.toSet())) + .build(); + + AwaitCloseChannelPoolMap.Builder builder = AwaitCloseChannelPoolMap.builder() + .proxyConfiguration(proxyConfiguration) + .sdkChannelOptions(new SdkChannelOptions()) + .sdkEventLoopGroup(SdkEventLoopGroup.builder().build()) + .configuration(new NettyConfiguration(GLOBAL_HTTP_DEFAULTS)) + .protocol(Protocol.HTTP1_1) + .maxStreams(100) + .sslProvider(SslProvider.OPENSSL); + + channelPoolMap = new AwaitCloseChannelPoolMap(builder, shouldProxyCache); + + // The target host does not exist so acquiring a channel should fail unless we're configured to connect to + // the mock proxy host for this URI. + SimpleChannelPoolAwareChannelPool channelPool = channelPoolMap.newPool(targetUri); + Future channelFuture = channelPool.underlyingSimpleChannelPool().acquire().awaitUninterruptibly(); + assertThat(channelFuture.isSuccess()).isTrue(); + channelPool.release(channelFuture.getNow()).awaitUninterruptibly(); + } + + @Test + public void usingProxy_noSchemeGiven_defaultsToHttp() { + ProxyConfiguration proxyConfiguration = ProxyConfiguration.builder() + .host("localhost") + .port(mockProxy.port()) + .build(); + + channelPoolMap = AwaitCloseChannelPoolMap.builder() + .proxyConfiguration(proxyConfiguration) + .sdkChannelOptions(new SdkChannelOptions()) + .sdkEventLoopGroup(SdkEventLoopGroup.builder().build()) + .configuration(new NettyConfiguration(GLOBAL_HTTP_DEFAULTS)) + .protocol(Protocol.HTTP1_1) + .maxStreams(100) + .sslProvider(SslProvider.OPENSSL) + .build(); + + SimpleChannelPoolAwareChannelPool simpleChannelPoolAwareChannelPool = channelPoolMap.newPool( + URI.create("https://some-awesome-service:443")); + + simpleChannelPoolAwareChannelPool.acquire().awaitUninterruptibly(); + + String requests = recorder.requests().toString(); + + assertThat(requests).contains("CONNECT some-awesome-service:443"); + } + } diff --git a/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/internal/Http1TunnelConnectionPoolTest.java b/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/internal/Http1TunnelConnectionPoolTest.java new file mode 100644 index 000000000000..50cc17621c10 --- /dev/null +++ b/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/internal/Http1TunnelConnectionPoolTest.java @@ -0,0 +1,299 @@ +/* + * Copyright 2010-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.nio.netty.internal; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyObject; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static software.amazon.awssdk.http.nio.netty.internal.Http1TunnelConnectionPool.TUNNEL_ESTABLISHED_KEY; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelId; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.pool.ChannelPool; +import io.netty.channel.pool.ChannelPoolHandler; +import io.netty.handler.ssl.ApplicationProtocolNegotiator; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslHandler; +import io.netty.util.Attribute; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; +import java.io.IOException; +import java.net.URI; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLSessionContext; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; + +/** + * Unit tests for {@link Http1TunnelConnectionPool}. + */ +@RunWith(MockitoJUnitRunner.class) +public class Http1TunnelConnectionPoolTest { + private static final NioEventLoopGroup GROUP = new NioEventLoopGroup(1); + + private static final URI HTTP_PROXY_ADDRESS = URI.create("http://localhost:1234"); + + private static final URI HTTPS_PROXY_ADDRESS = URI.create("https://localhost:5678"); + + private static final URI REMOTE_ADDRESS = URI.create("https://s3.amazonaws.com:5678"); + + @Mock + private ChannelPool delegatePool; + + @Mock + private ChannelPoolHandler mockHandler; + + @Mock + public Channel mockChannel; + + @Mock + public ChannelPipeline mockPipeline; + + @Mock + public Attribute mockAttr; + + @Mock + public ChannelHandlerContext mockCtx; + + @Mock + public ChannelId mockId; + + @Before + public void methodSetup() { + Future channelFuture = GROUP.next().newSucceededFuture(mockChannel); + when(delegatePool.acquire(any(Promise.class))).thenReturn(channelFuture); + + when(mockCtx.channel()).thenReturn(mockChannel); + when(mockCtx.pipeline()).thenReturn(mockPipeline); + + when(mockChannel.attr(eq(TUNNEL_ESTABLISHED_KEY))).thenReturn(mockAttr); + when(mockChannel.id()).thenReturn(mockId); + when(mockChannel.pipeline()).thenReturn(mockPipeline); + } + + @AfterClass + public static void teardown() { + GROUP.shutdownGracefully().awaitUninterruptibly(); + } + + @Test + public void tunnelAlreadyEstablished_doesNotAddInitHandler() { + Http1TunnelConnectionPool tunnelPool = new Http1TunnelConnectionPool(GROUP.next(), delegatePool, null, + HTTP_PROXY_ADDRESS, REMOTE_ADDRESS, mockHandler); + + when(mockAttr.get()).thenReturn(true); + + tunnelPool.acquire().awaitUninterruptibly(); + + verify(mockPipeline, never()).addLast(anyObject()); + } + + @Test(timeout = 1000) + public void tunnelNotEstablished_addsInitHandler() throws InterruptedException { + Http1TunnelConnectionPool tunnelPool = new Http1TunnelConnectionPool(GROUP.next(), delegatePool, null, + HTTP_PROXY_ADDRESS, REMOTE_ADDRESS, mockHandler); + + when(mockAttr.get()).thenReturn(false); + + CountDownLatch latch = new CountDownLatch(1); + when(mockPipeline.addLast(any(ChannelHandler.class))).thenAnswer(i -> { + latch.countDown(); + return mockPipeline; + }); + tunnelPool.acquire(); + latch.await(); + verify(mockPipeline, times(1)).addLast(any(ProxyTunnelInitHandler.class)); + } + + @Test + public void tunnelInitFails_acquireFutureFails() { + Http1TunnelConnectionPool.InitHandlerSupplier supplier = (srcPool, remoteAddr, initFuture) -> { + initFuture.setFailure(new IOException("boom")); + return mock(ChannelHandler.class); + }; + + Http1TunnelConnectionPool tunnelPool = new Http1TunnelConnectionPool(GROUP.next(), delegatePool, null, + HTTP_PROXY_ADDRESS, REMOTE_ADDRESS, mockHandler, supplier); + + Future acquireFuture = tunnelPool.acquire(); + + assertThat(acquireFuture.awaitUninterruptibly().cause()).hasMessage("boom"); + } + + @Test + public void tunnelInitSucceeds_acquireFutureSucceeds() { + Http1TunnelConnectionPool.InitHandlerSupplier supplier = (srcPool, remoteAddr, initFuture) -> { + initFuture.setSuccess(mockChannel); + return mock(ChannelHandler.class); + }; + + Http1TunnelConnectionPool tunnelPool = new Http1TunnelConnectionPool(GROUP.next(), delegatePool, null, + HTTP_PROXY_ADDRESS, REMOTE_ADDRESS, mockHandler, supplier); + + Future acquireFuture = tunnelPool.acquire(); + + assertThat(acquireFuture.awaitUninterruptibly().getNow()).isEqualTo(mockChannel); + } + + @Test + public void acquireFromDelegatePoolFails_failsFuture() { + Http1TunnelConnectionPool tunnelPool = new Http1TunnelConnectionPool(GROUP.next(), delegatePool, null, + HTTP_PROXY_ADDRESS, REMOTE_ADDRESS, mockHandler); + + when(delegatePool.acquire(any(Promise.class))).thenReturn(GROUP.next().newFailedFuture(new IOException("boom"))); + + Future acquireFuture = tunnelPool.acquire(); + + assertThat(acquireFuture.awaitUninterruptibly().cause()).hasMessage("boom"); + } + + @Test + public void sslContextProvided_andProxyUsingHttps_addsSslHandler() { + SslHandler mockSslHandler = mock(SslHandler.class); + TestSslContext mockSslCtx = new TestSslContext(mockSslHandler); + + Http1TunnelConnectionPool.InitHandlerSupplier supplier = (srcPool, remoteAddr, initFuture) -> { + initFuture.setSuccess(mockChannel); + return mock(ChannelHandler.class); + }; + + Http1TunnelConnectionPool tunnelPool = new Http1TunnelConnectionPool(GROUP.next(), delegatePool, mockSslCtx, + HTTPS_PROXY_ADDRESS, REMOTE_ADDRESS, mockHandler, supplier); + + tunnelPool.acquire().awaitUninterruptibly(); + + ArgumentCaptor handlersCaptor = ArgumentCaptor.forClass(ChannelHandler.class); + verify(mockPipeline, times(2)).addLast(handlersCaptor.capture()); + + assertThat(handlersCaptor.getAllValues().get(0)).isEqualTo(mockSslHandler); + } + + @Test + public void sslContextProvided_andProxyNotUsingHttps_doesNotAddSslHandler() { + SslHandler mockSslHandler = mock(SslHandler.class); + TestSslContext mockSslCtx = new TestSslContext(mockSslHandler); + + Http1TunnelConnectionPool.InitHandlerSupplier supplier = (srcPool, remoteAddr, initFuture) -> { + initFuture.setSuccess(mockChannel); + return mock(ChannelHandler.class); + }; + + Http1TunnelConnectionPool tunnelPool = new Http1TunnelConnectionPool(GROUP.next(), delegatePool, mockSslCtx, + HTTP_PROXY_ADDRESS, REMOTE_ADDRESS, mockHandler, supplier); + + tunnelPool.acquire().awaitUninterruptibly(); + + ArgumentCaptor handlersCaptor = ArgumentCaptor.forClass(ChannelHandler.class); + verify(mockPipeline).addLast(handlersCaptor.capture()); + + assertThat(handlersCaptor.getAllValues().get(0)).isNotInstanceOf(SslHandler.class); + } + + @Test + public void release_releasedToDelegatePool() { + Http1TunnelConnectionPool tunnelPool = new Http1TunnelConnectionPool(GROUP.next(), delegatePool, null, + HTTP_PROXY_ADDRESS, REMOTE_ADDRESS, mockHandler); + tunnelPool.release(mockChannel); + verify(delegatePool).release(eq(mockChannel), any(Promise.class)); + } + + @Test + public void release_withGivenPromise_releasedToDelegatePool() { + Http1TunnelConnectionPool tunnelPool = new Http1TunnelConnectionPool(GROUP.next(), delegatePool, null, + HTTP_PROXY_ADDRESS, REMOTE_ADDRESS, mockHandler); + Promise mockPromise = mock(Promise.class); + tunnelPool.release(mockChannel, mockPromise); + verify(delegatePool).release(eq(mockChannel), eq(mockPromise)); + } + + @Test + public void close_closesDelegatePool() { + Http1TunnelConnectionPool tunnelPool = new Http1TunnelConnectionPool(GROUP.next(), delegatePool, null, + HTTP_PROXY_ADDRESS, REMOTE_ADDRESS, mockHandler); + tunnelPool.close(); + verify(delegatePool).close(); + } + + private static class TestSslContext extends SslContext { + private final SslHandler handler; + + protected TestSslContext(SslHandler handler) { + this.handler = handler; + } + + @Override + public boolean isClient() { + return false; + } + + @Override + public List cipherSuites() { + return null; + } + + @Override + public long sessionCacheSize() { + return 0; + } + + @Override + public long sessionTimeout() { + return 0; + } + + @Override + public ApplicationProtocolNegotiator applicationProtocolNegotiator() { + return null; + } + + @Override + public SSLEngine newEngine(ByteBufAllocator alloc) { + return null; + } + + @Override + public SSLEngine newEngine(ByteBufAllocator alloc, String peerHost, int peerPort) { + return null; + } + + @Override + public SSLSessionContext sessionContext() { + return null; + } + + @Override + public SslHandler newHandler(ByteBufAllocator alloc, String host, int port, boolean startTls) { + return handler; + } + } +} diff --git a/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/internal/ProxyTunnelInitHandlerTest.java b/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/internal/ProxyTunnelInitHandlerTest.java new file mode 100644 index 000000000000..887d1b2fa3f9 --- /dev/null +++ b/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/internal/ProxyTunnelInitHandlerTest.java @@ -0,0 +1,198 @@ +/* + * Copyright 2010-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.nio.netty.internal; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyObject; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.DefaultChannelPromise; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.pool.ChannelPool; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultHttpResponse; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.ssl.SslHandler; +import io.netty.util.concurrent.Promise; +import java.io.IOException; +import java.net.URI; +import java.util.function.Supplier; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; + +/** + * Unit tests for {@link ProxyTunnelInitHandler}. + */ +@RunWith(MockitoJUnitRunner.class) +public class ProxyTunnelInitHandlerTest { + private static final NioEventLoopGroup GROUP = new NioEventLoopGroup(1); + + private static final URI REMOTE_HOST = URI.create("https://s3.amazonaws.com:1234"); + + @Mock + private ChannelHandlerContext mockCtx; + + @Mock + private Channel mockChannel; + + @Mock + private ChannelPipeline mockPipeline; + + @Mock + private ChannelPool mockChannelPool; + + @Before + public void methodSetup() { + when(mockCtx.channel()).thenReturn(mockChannel); + when(mockCtx.pipeline()).thenReturn(mockPipeline); + when(mockChannel.pipeline()).thenReturn(mockPipeline); + when(mockChannel.writeAndFlush(anyObject())).thenReturn(new DefaultChannelPromise(mockChannel, GROUP.next())); + } + + @AfterClass + public static void teardown() { + GROUP.shutdownGracefully().awaitUninterruptibly(); + } + + @Test + public void addedToPipeline_addsCodec() { + HttpClientCodec codec = new HttpClientCodec(); + Supplier codecSupplier = () -> codec; + when(mockCtx.name()).thenReturn("foo"); + + ProxyTunnelInitHandler handler = new ProxyTunnelInitHandler(mockChannelPool, REMOTE_HOST, null, codecSupplier); + handler.handlerAdded(mockCtx); + + verify(mockPipeline).addBefore(eq("foo"), eq(null), eq(codec)); + } + + @Test + public void successfulProxyResponse_completesFuture() { + Promise promise = GROUP.next().newPromise(); + ProxyTunnelInitHandler handler = new ProxyTunnelInitHandler(mockChannelPool, REMOTE_HOST, promise); + successResponse(handler); + + assertThat(promise.awaitUninterruptibly().getNow()).isEqualTo(mockChannel); + } + + @Test + public void successfulProxyResponse_removesSelfAndCodec() { + Promise promise = GROUP.next().newPromise(); + ProxyTunnelInitHandler handler = new ProxyTunnelInitHandler(mockChannelPool, REMOTE_HOST, promise); + successResponse(handler); + + verify(mockPipeline).remove(eq(handler)); + verify(mockPipeline).remove(any(HttpClientCodec.class)); + } + + @Test + public void successfulProxyResponse_doesNotRemoveSslHandler() { + SslHandler sslHandler = mock(SslHandler.class); + when(mockPipeline.get(eq(SslHandler.class))).thenReturn(sslHandler); + + Promise promise = GROUP.next().newPromise(); + ProxyTunnelInitHandler handler = new ProxyTunnelInitHandler(mockChannelPool, REMOTE_HOST, promise); + successResponse(handler); + + verify(mockPipeline, never()).remove(eq(SslHandler.class)); + } + + @Test + public void unexpectedMessage_failsPromise() { + Promise promise = GROUP.next().newPromise(); + ProxyTunnelInitHandler handler = new ProxyTunnelInitHandler(mockChannelPool, REMOTE_HOST, promise); + handler.channelRead(mockCtx, new Object()); + + assertThat(promise.awaitUninterruptibly().isSuccess()).isFalse(); + } + + @Test + public void unsuccessfulResponse_failsPromise() { + Promise promise = GROUP.next().newPromise(); + ProxyTunnelInitHandler handler = new ProxyTunnelInitHandler(mockChannelPool, REMOTE_HOST, promise); + + DefaultHttpResponse resp = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.FORBIDDEN); + handler.channelRead(mockCtx, resp); + + assertThat(promise.awaitUninterruptibly().isSuccess()).isFalse(); + } + + @Test + public void requestWriteFails_failsPromise() { + DefaultChannelPromise writePromise = new DefaultChannelPromise(mockChannel, GROUP.next()); + writePromise.setFailure(new IOException("boom")); + when(mockChannel.writeAndFlush(anyObject())).thenReturn(writePromise); + + Promise promise = GROUP.next().newPromise(); + ProxyTunnelInitHandler handler = new ProxyTunnelInitHandler(mockChannelPool, REMOTE_HOST, promise); + handler.handlerAdded(mockCtx); + + assertThat(promise.awaitUninterruptibly().isSuccess()).isFalse(); + } + + @Test + public void handlerRemoved_removesCodec() { + HttpClientCodec codec = new HttpClientCodec(); + when(mockPipeline.get(eq(HttpClientCodec.class))).thenReturn(codec); + + Promise promise = GROUP.next().newPromise(); + ProxyTunnelInitHandler handler = new ProxyTunnelInitHandler(mockChannelPool, REMOTE_HOST, promise); + + handler.handlerRemoved(mockCtx); + + verify(mockPipeline).remove(eq(HttpClientCodec.class)); + } + + @Test + public void handledAdded_writesRequest() { + Promise promise = GROUP.next().newPromise(); + ProxyTunnelInitHandler handler = new ProxyTunnelInitHandler(mockChannelPool, REMOTE_HOST, promise); + handler.handlerAdded(mockCtx); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + verify(mockChannel).writeAndFlush(requestCaptor.capture()); + + String uri = REMOTE_HOST.getHost() + ":" + REMOTE_HOST.getPort(); + HttpRequest expectedRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.CONNECT, uri, + Unpooled.EMPTY_BUFFER, false); + expectedRequest.headers().add(HttpHeaderNames.HOST, uri); + + assertThat(requestCaptor.getValue()).isEqualTo(expectedRequest); + } + + private void successResponse(ProxyTunnelInitHandler handler) { + DefaultHttpResponse resp = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + handler.channelRead(mockCtx, resp); + } +}