diff --git a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/ServerTlsChannel.java b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/ServerTlsChannel.java index dc2827b37c9..b9b66c9429f 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/ServerTlsChannel.java +++ b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/ServerTlsChannel.java @@ -35,7 +35,7 @@ import javax.net.ssl.SSLSession; import javax.net.ssl.StandardConstants; import java.io.IOException; -import java.nio.Buffer; +import java.nio.BufferUnderflowException; import java.nio.ByteBuffer; import java.nio.channels.ByteChannel; import java.nio.channels.Channel; @@ -47,7 +47,9 @@ import java.util.function.Consumer; import java.util.function.Function; -/** A server-side {@link TlsChannel}. */ +/** + * A server-side {@link TlsChannel}. + */ public class ServerTlsChannel implements TlsChannel { private static final Logger LOGGER = Loggers.getLogger("connection.tls"); @@ -83,7 +85,7 @@ public SSLContext getSslContext(SniReader sniReader) throws IOException, EofExce throw new TlsChannelCallbackException("SNI callback failed", e); } return chosenContext.orElseThrow( - () -> new SSLHandshakeException("No ssl context available for received SNI: " + nameOpt)); + () -> new SSLHandshakeException("No ssl context available for received SNI: " + nameOpt)); } } @@ -111,12 +113,13 @@ private static SSLEngine defaultSSLEngineFactory(SSLContext sslContext) { return engine; } - /** Builder of {@link ServerTlsChannel} */ + /** + * Builder of {@link ServerTlsChannel} + */ public static class Builder extends TlsChannelBuilder { private final SslContextStrategy internalSslContextFactory; - private Function sslEngineFactory = - ServerTlsChannel::defaultSSLEngineFactory; + private Function sslEngineFactory = ServerTlsChannel::defaultSSLEngineFactory; private Builder(ByteChannel underlying, SSLContext sslContext) { super(underlying); @@ -140,20 +143,20 @@ public Builder withEngineFactory(Function sslEngineFactor public ServerTlsChannel build() { return new ServerTlsChannel( - underlying, - internalSslContextFactory, - sslEngineFactory, - sessionInitCallback, - runTasks, - plainBufferAllocator, - encryptedBufferAllocator, - releaseBuffers, - waitForCloseConfirmation); + underlying, + internalSslContextFactory, + sslEngineFactory, + sessionInitCallback, + runTasks, + plainBufferAllocator, + encryptedBufferAllocator, + releaseBuffers, + waitForCloseConfirmation); } } /** - * Create a new {@link Builder}, configured with a underlying {@link Channel} and a fixed {@link + * Create a new {@link Builder}, configured with an underlying {@link Channel} and a fixed {@link * SSLContext}, which will be used to create the {@link SSLEngine}. * * @param underlying a reference to the underlying {@link ByteChannel} @@ -165,16 +168,16 @@ public static Builder newBuilder(ByteChannel underlying, SSLContext sslContext) } /** - * Create a new {@link Builder}, configured with a underlying {@link Channel} and a custom {@link + * Create a new {@link Builder}, configured with an underlying {@link Channel} and a custom {@link * SSLContext} factory, which will be used to create the context (in turn used to create the - * {@link SSLEngine}, as a function of the SNI received at the TLS connection start. + * {@link SSLEngine}), as a function of the SNI received at the TLS connection start. * *

Implementation note:
* Due to limitations of {@link SSLEngine}, configuring a {@link ServerTlsChannel} to select the * {@link SSLContext} based on the SNI value implies parsing the first TLS frame (ClientHello) * independently of the SSLEngine. * - * @param underlying a reference to the underlying {@link ByteChannel} + * @param underlying a reference to the underlying {@link ByteChannel} * @param sslContextFactory a function from an optional SNI to the {@link SSLContext} to be used * @return the new builder * @see Server Name Indication @@ -203,15 +206,15 @@ public static Builder newBuilder(ByteChannel underlying, SniSslContextFactory ss // @formatter:off private ServerTlsChannel( - ByteChannel underlying, - SslContextStrategy internalSslContextFactory, - Function engineFactory, - Consumer sessionInitCallback, - boolean runTasks, - BufferAllocator plainBufAllocator, - BufferAllocator encryptedBufAllocator, - boolean releaseBuffers, - boolean waitForCloseConfirmation) { + ByteChannel underlying, + SslContextStrategy internalSslContextFactory, + Function engineFactory, + Consumer sessionInitCallback, + boolean runTasks, + BufferAllocator plainBufAllocator, + BufferAllocator encryptedBufAllocator, + boolean releaseBuffers, + boolean waitForCloseConfirmation) { this.underlying = underlying; this.sslContextStrategy = internalSslContextFactory; this.engineFactory = engineFactory; @@ -221,8 +224,7 @@ private ServerTlsChannel( this.encryptedBufAllocator = new TrackingAllocator(encryptedBufAllocator); this.releaseBuffers = releaseBuffers; this.waitForCloseConfirmation = waitForCloseConfirmation; - inEncrypted = - new BufferHolder( + inEncrypted = new BufferHolder( "inEncrypted", Optional.empty(), encryptedBufAllocator, @@ -242,7 +244,7 @@ public ByteChannel getUnderlying() { /** * Return the used {@link SSLContext}. * - * @return if context if present, of null if the TLS connection as not been initializer, or the + * @return context if present, or null if the TLS connection as not been initializer, or the * SNI not received yet. */ public SSLContext getSslContext() { @@ -347,8 +349,12 @@ public void handshake() throws IOException { @Override public void close() throws IOException { - if (impl != null) impl.close(); - if (inEncrypted != null) inEncrypted.dispose(); + if (impl != null) { + impl.close(); + } + if (inEncrypted != null) { + inEncrypted.dispose(); + } underlying.close(); } @@ -370,8 +376,7 @@ private void initEngine() throws IOException, EofException { LOGGER.trace("client threw exception in SSLEngine factory", e); throw new TlsChannelCallbackException("SSLEngine creation callback failed", e); } - impl = - new TlsChannelImpl( + impl = new TlsChannelImpl( underlying, underlying, engine, @@ -393,41 +398,33 @@ private void initEngine() throws IOException, EofException { private Optional getServerNameIndication() throws IOException, EofException { inEncrypted.prepare(); try { - int recordHeaderSize = readRecordHeaderSize(); - while (inEncrypted.buffer.position() < recordHeaderSize) { - if (!inEncrypted.buffer.hasRemaining()) { - inEncrypted.enlarge(); + // loop finishes using return statements + while (true) { + try { + inEncrypted.buffer.flip(); + try { + Map serverNames = TlsExplorer.exploreTlsRecord(inEncrypted.buffer); + SNIServerName hostName = serverNames.get(StandardConstants.SNI_HOST_NAME); + if (hostName instanceof SNIHostName) { + return Optional.of(hostName); + } else { + return Optional.empty(); + } + } finally { + inEncrypted.buffer.compact(); + } + } catch (BufferUnderflowException e) { + if (!inEncrypted.buffer.hasRemaining()) { + inEncrypted.enlarge(); + } + TlsChannelImpl.callChannelRead(underlying, inEncrypted.buffer); // IO block } - TlsChannelImpl.readFromChannel(underlying, inEncrypted.buffer); // IO block - } - ((Buffer) inEncrypted.buffer).flip(); - Map serverNames = TlsExplorer.explore(inEncrypted.buffer); - inEncrypted.buffer.compact(); - SNIServerName hostName = serverNames.get(StandardConstants.SNI_HOST_NAME); - if (hostName != null && hostName instanceof SNIHostName) { - SNIHostName sniHostName = (SNIHostName) hostName; - return Optional.of(sniHostName); - } else { - return Optional.empty(); } } finally { inEncrypted.release(); } } - private int readRecordHeaderSize() throws IOException, EofException { - while (inEncrypted.buffer.position() < TlsExplorer.RECORD_HEADER_SIZE) { - if (!inEncrypted.buffer.hasRemaining()) { - throw new IllegalStateException("inEncrypted too small"); - } - TlsChannelImpl.readFromChannel(underlying, inEncrypted.buffer); // IO block - } - ((Buffer) inEncrypted.buffer).flip(); - int recordHeaderSize = TlsExplorer.getRequiredSize(inEncrypted.buffer); - inEncrypted.buffer.compact(); - return recordHeaderSize; - } - @Override public boolean shutdown() throws IOException { return impl != null && impl.shutdown(); @@ -442,4 +439,4 @@ public boolean shutdownReceived() { public boolean shutdownSent() { return impl != null && impl.shutdownSent(); } -} +} \ No newline at end of file diff --git a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/ByteBufferSet.java b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/ByteBufferSet.java index cf95a90801b..2906104c770 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/ByteBufferSet.java +++ b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/ByteBufferSet.java @@ -29,11 +29,19 @@ public class ByteBufferSet { public final int length; public ByteBufferSet(ByteBuffer[] array, int offset, int length) { - if (array == null) throw new NullPointerException(); - if (array.length < offset) throw new IndexOutOfBoundsException(); - if (array.length < offset + length) throw new IndexOutOfBoundsException(); + if (array == null) { + throw new NullPointerException(); + } + if (array.length < offset) { + throw new IndexOutOfBoundsException(); + } + if (array.length < offset + length) { + throw new IndexOutOfBoundsException(); + } for (int i = offset; i < offset + length; i++) { - if (array[i] == null) throw new NullPointerException(); + if (array[i] == null) { + throw new NullPointerException(); + } } this.array = array; this.offset = offset; @@ -56,10 +64,20 @@ public long remaining() { return ret; } + public long position() { + long ret = 0; + for (int i = offset; i < offset + length; i++) { + ret += array[i].position(); + } + return ret; + } + public int putRemaining(ByteBuffer from) { int totalBytes = 0; for (int i = offset; i < offset + length; i++) { - if (!from.hasRemaining()) break; + if (!from.hasRemaining()) { + break; + } ByteBuffer dstBuffer = array[i]; int bytes = Math.min(from.remaining(), dstBuffer.remaining()); ByteBufferUtil.copy(from, dstBuffer, bytes); @@ -78,7 +96,9 @@ public ByteBufferSet put(ByteBuffer from, int length) { int totalBytes = 0; for (int i = offset; i < offset + this.length; i++) { int pending = length - totalBytes; - if (pending == 0) break; + if (pending == 0) { + break; + } int bytes = Math.min(pending, (int) remaining()); ByteBuffer dstBuffer = array[i]; ByteBufferUtil.copy(from, dstBuffer, bytes); @@ -90,7 +110,9 @@ public ByteBufferSet put(ByteBuffer from, int length) { public int getRemaining(ByteBuffer dst) { int totalBytes = 0; for (int i = offset; i < offset + length; i++) { - if (!dst.hasRemaining()) break; + if (!dst.hasRemaining()) { + break; + } ByteBuffer srcBuffer = array[i]; int bytes = Math.min(dst.remaining(), srcBuffer.remaining()); ByteBufferUtil.copy(srcBuffer, dst, bytes); @@ -109,7 +131,9 @@ public ByteBufferSet get(ByteBuffer dst, int length) { int totalBytes = 0; for (int i = offset; i < offset + this.length; i++) { int pending = length - totalBytes; - if (pending == 0) break; + if (pending == 0) { + break; + } ByteBuffer srcBuffer = array[i]; int bytes = Math.min(pending, srcBuffer.remaining()); ByteBufferUtil.copy(srcBuffer, dst, bytes); @@ -124,19 +148,15 @@ public boolean hasRemaining() { public boolean isReadOnly() { for (int i = offset; i < offset + length; i++) { - if (array[i].isReadOnly()) return true; + if (array[i].isReadOnly()) { + return true; + } } return false; } @Override public String toString() { - return "ByteBufferSet[array=" - + Arrays.toString(array) - + ", offset=" - + offset - + ", length=" - + length - + "]"; + return "ByteBufferSet[" + Arrays.toString(array) + ":" + offset + ":" + length + "]"; } -} +} \ No newline at end of file diff --git a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java index 3c845ce6d08..28212c5be76 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java +++ b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java @@ -36,7 +36,6 @@ import javax.net.ssl.SSLException; import javax.net.ssl.SSLSession; import java.io.IOException; -import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.channels.ByteChannel; import java.nio.channels.ClosedChannelException; @@ -48,7 +47,6 @@ import java.util.function.Consumer; import static java.lang.String.format; -import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING; public class TlsChannelImpl implements ByteChannel { @@ -56,36 +54,20 @@ public class TlsChannelImpl implements ByteChannel { public static final int buffersInitialSize = 4096; - /** Official TLS max data size is 2^14 = 16k. Use 1024 more to account for the overhead */ + /** + * Official TLS max data size is 2^14 = 16k. Use 1024 more to account for the overhead + */ public static final int maxTlsPacketSize = 17 * 1024; - private static class UnwrapResult { - public final int bytesProduced; - public final HandshakeStatus lastHandshakeStatus; - public final boolean wasClosed; - - public UnwrapResult(int bytesProduced, HandshakeStatus lastHandshakeStatus, boolean wasClosed) { - this.bytesProduced = bytesProduced; - this.lastHandshakeStatus = lastHandshakeStatus; - this.wasClosed = wasClosed; - } - } - - private static class WrapResult { - public final int bytesConsumed; - public final HandshakeStatus lastHandshakeStatus; - - public WrapResult(int bytesConsumed, HandshakeStatus lastHandshakeStatus) { - this.bytesConsumed = bytesConsumed; - this.lastHandshakeStatus = lastHandshakeStatus; - } - } - - /** Used to signal EOF conditions from the underlying channel */ + /** + * Used to signal EOF conditions from the underlying channel + */ public static class EofException extends Exception { private static final long serialVersionUID = -3859156713994602991L; - /** For efficiency, override this method to do nothing. */ + /** + * For efficiency, override this method to do nothing. + */ @Override public Throwable fillInStackTrace() { return this; @@ -95,7 +77,7 @@ public Throwable fillInStackTrace() { private final ReadableByteChannel readChannel; private final WritableByteChannel writeChannel; private final SSLEngine engine; - private BufferHolder inEncrypted; + private final BufferHolder inEncrypted; private final Consumer initSessionCallback; private final boolean runTasks; @@ -105,38 +87,34 @@ public Throwable fillInStackTrace() { // @formatter:off public TlsChannelImpl( - ReadableByteChannel readChannel, - WritableByteChannel writeChannel, - SSLEngine engine, - Optional inEncrypted, - Consumer initSessionCallback, - boolean runTasks, - TrackingAllocator plainBufAllocator, - TrackingAllocator encryptedBufAllocator, - boolean releaseBuffers, - boolean waitForCloseConfirmation) { + ReadableByteChannel readChannel, + WritableByteChannel writeChannel, + SSLEngine engine, + Optional inEncrypted, + Consumer initSessionCallback, + boolean runTasks, + TrackingAllocator plainBufAllocator, + TrackingAllocator encryptedBufAllocator, + boolean releaseBuffers, + boolean waitForCloseConfirmation) { // @formatter:on this.readChannel = readChannel; this.writeChannel = writeChannel; this.engine = engine; - this.inEncrypted = - inEncrypted.orElseGet( - () -> - new BufferHolder( - "inEncrypted", - Optional.empty(), - encryptedBufAllocator, - buffersInitialSize, - maxTlsPacketSize, - false /* plainData */, - releaseBuffers)); + this.inEncrypted = inEncrypted.orElseGet(() -> new BufferHolder( + "inEncrypted", + Optional.empty(), + encryptedBufAllocator, + buffersInitialSize, + maxTlsPacketSize, + false /* plainData */, + releaseBuffers)); this.initSessionCallback = initSessionCallback; this.runTasks = runTasks; this.plainBufAllocator = plainBufAllocator; this.encryptedBufAllocator = encryptedBufAllocator; this.waitForCloseConfirmation = waitForCloseConfirmation; - inPlain = - new BufferHolder( + inPlain = new BufferHolder( "inPlain", Optional.empty(), plainBufAllocator, @@ -144,8 +122,7 @@ public TlsChannelImpl( maxTlsPacketSize, true /* plainData */, releaseBuffers); - outEncrypted = - new BufferHolder( + outEncrypted = new BufferHolder( "outEncrypted", Optional.empty(), encryptedBufAllocator, @@ -159,7 +136,9 @@ public TlsChannelImpl( private final Lock readLock = new ReentrantLock(); private final Lock writeLock = new ReentrantLock(); - private volatile boolean negotiated = false; + private boolean handshakeStarted = false; + + private volatile boolean handshakeCompleted = false; /** * Whether a IOException was received from the underlying channel or from the {@link SSLEngine}. @@ -172,11 +151,27 @@ public TlsChannelImpl( /** Whether a close_notify was already received. */ private volatile boolean shutdownReceived = false; - // decrypted data from inEncrypted - private BufferHolder inPlain; + /** + * Decrypted data from inEncrypted + */ + private final BufferHolder inPlain; - // contains data encrypted to send to the underlying channel - private BufferHolder outEncrypted; + /** + * Contains data encrypted to send to the underlying channel + */ + private final BufferHolder outEncrypted; + + /** + * Reference to the current read buffer supplied by the client this field is only valid during a + * read operation. This field is used instead of {@link #inPlain} in order to avoid copying + * returned bytes when possible. + */ + private ByteBufferSet suppliedInPlain; + + /** + * Bytes produced by the current read operation + */ + private int bytesToReturn; /** * Handshake wrap() method calls need a buffer to read from, even when they actually do not read @@ -185,8 +180,7 @@ public TlsChannelImpl( *

Note: standard SSLEngine is happy with no buffers, the empty buffer is here to make this * work with Netty's OpenSSL's wrapper. */ - private final ByteBufferSet dummyOut = - new ByteBufferSet(new ByteBuffer[] {ByteBuffer.allocate(0)}); + private final ByteBufferSet dummyOut = new ByteBufferSet(new ByteBuffer[] {ByteBuffer.allocate(0)}); public Consumer getSessionInitCallback() { return initSessionCallback; @@ -204,45 +198,53 @@ public TrackingAllocator getEncryptedBufferAllocator() { public long read(ByteBufferSet dest) throws IOException { checkReadBuffer(dest); - if (!dest.hasRemaining()) return 0; + if (!dest.hasRemaining()) { + return 0; + } handshake(); readLock.lock(); try { if (invalid || shutdownSent) { throw new ClosedChannelException(); } - HandshakeStatus handshakeStatus = engine.getHandshakeStatus(); - int bytesToReturn = inPlain.nullOrEmpty() ? 0 :inPlain.buffer.position(); + + long originalDestPosition = dest.position(); + suppliedInPlain = dest; + bytesToReturn = inPlain.nullOrEmpty() ? 0 : inPlain.buffer.position(); + while (true) { + + // return bytes are soon as we have them if (bytesToReturn > 0) { if (inPlain.nullOrEmpty()) { + // if there is not in internal buffer, that means that the bytes must be in the supplied + // buffer + Util.assertTrue(dest.position() == originalDestPosition + bytesToReturn); return bytesToReturn; } else { + Util.assertTrue(inPlain.buffer.position() == bytesToReturn); return transferPendingPlain(dest); } } + if (shutdownReceived) { return -1; } Util.assertTrue(inPlain.nullOrEmpty()); - switch (handshakeStatus) { + switch (engine.getHandshakeStatus()) { case NEED_UNWRAP: case NEED_WRAP: - bytesToReturn = handshake(Optional.of(dest), Optional.of(handshakeStatus)); - handshakeStatus = NOT_HANDSHAKING; + writeAndHandshake(); break; case NOT_HANDSHAKING: case FINISHED: - UnwrapResult res = readAndUnwrap(Optional.of(dest)); - if (res.wasClosed) { + readAndUnwrap(); + if (shutdownReceived) { return -1; } - bytesToReturn = res.bytesProduced; - handshakeStatus = res.lastHandshakeStatus; break; case NEED_TASK: handleTask(); - handshakeStatus = engine.getHandshakeStatus(); break; default: // Unsupported stage eg: NEED_UNWRAP_AGAIN @@ -252,20 +254,28 @@ public long read(ByteBufferSet dest) throws IOException { } catch (EofException e) { return -1; } finally { + bytesToReturn = 0; + suppliedInPlain = null; readLock.unlock(); } } private void handleTask() throws NeedsTaskException { + Runnable task = engine.getDelegatedTask(); if (runTasks) { - engine.getDelegatedTask().run(); + LOGGER.trace("delegating in task: " + task); + task.run(); } else { - throw new NeedsTaskException(engine.getDelegatedTask()); + if (LOGGER.isTraceEnabled()) { + LOGGER.trace("task needed, throwing exception: " + task); + } + throw new NeedsTaskException(task); } } + /** Copies bytes from the internal input plain buffer to the supplied buffer. */ private int transferPendingPlain(ByteBufferSet dstBuffers) { - ((Buffer) inPlain.buffer).flip(); // will read + inPlain.buffer.flip(); // will read int bytes = dstBuffers.putRemaining(inPlain.buffer); inPlain.buffer.compact(); // will write boolean disposed = inPlain.release(); @@ -275,40 +285,54 @@ private int transferPendingPlain(ByteBufferSet dstBuffers) { return bytes; } - private UnwrapResult unwrapLoop(Optional dest, HandshakeStatus originalStatus) - throws SSLException { - ByteBufferSet effDest = - dest.orElseGet( - () -> { - inPlain.prepare(); - return new ByteBufferSet(inPlain.buffer); - }); + private SSLEngineResult unwrapLoop() throws SSLException { + ByteBufferSet effDest; + if (suppliedInPlain != null) { + effDest = suppliedInPlain; + } else { + inPlain.prepare(); + effDest = new ByteBufferSet(inPlain.buffer); + } + while (true) { Util.assertTrue(inPlain.nullOrEmpty()); SSLEngineResult result = callEngineUnwrap(effDest); + HandshakeStatus status = engine.getHandshakeStatus(); + /* * Note that data can be returned even in case of overflow, in that * case, just return the data. */ - if (result.bytesProduced() > 0 - || result.getStatus() == Status.BUFFER_UNDERFLOW - || result.getStatus() == Status.CLOSED - || result.getHandshakeStatus() != originalStatus) { - boolean wasClosed = result.getStatus() == Status.CLOSED; - return new UnwrapResult(result.bytesProduced(), result.getHandshakeStatus(), wasClosed); + if (result.bytesProduced() > 0) { + return result; + } + if (result.getStatus() == Status.CLOSED) { + return result; + } + if (result.getStatus() == Status.BUFFER_UNDERFLOW) { + return result; + } + + if (result.getHandshakeStatus() == HandshakeStatus.FINISHED + || status == HandshakeStatus.NEED_TASK + || status == HandshakeStatus.NEED_WRAP) { + return result; } if (result.getStatus() == Status.BUFFER_OVERFLOW) { - if (dest.isPresent() && effDest == dest.get()) { + if (effDest == suppliedInPlain) { /* * The client-supplier buffer is not big enough. Use the - * internal inPlain buffer, also ensure that it is bigger + * internal inPlain buffer. Also ensure that it is bigger * than the too-small supplied one. */ inPlain.prepare(); - ensureInPlainCapacity(Math.min(((int) dest.get().remaining()) * 2, maxTlsPacketSize)); + if (inPlain.buffer.capacity() <= suppliedInPlain.remaining()) { + inPlain.enlarge(); + } } else { inPlain.enlarge(); } + // inPlain changed, re-create the wrapper effDest = new ByteBufferSet(inPlain.buffer); } @@ -316,7 +340,7 @@ private UnwrapResult unwrapLoop(Optional dest, HandshakeStatus or } private SSLEngineResult callEngineUnwrap(ByteBufferSet dest) throws SSLException { - ((Buffer) inEncrypted.buffer).flip(); + inEncrypted.buffer.flip(); try { SSLEngineResult result = engine.unwrap(inEncrypted.buffer, dest.array, dest.offset, dest.length); @@ -339,9 +363,9 @@ private SSLEngineResult callEngineUnwrap(ByteBufferSet dest) throws SSLException } } - private int readFromChannel() throws IOException, EofException { + private void readFromChannel() throws IOException, EofException { try { - return readFromChannel(readChannel, inEncrypted.buffer); + callChannelRead(readChannel, inEncrypted.buffer); } catch (WouldBlockException e) { throw e; } catch (IOException e) { @@ -350,8 +374,8 @@ private int readFromChannel() throws IOException, EofException { } } - public static int readFromChannel(ReadableByteChannel readChannel, ByteBuffer buffer) - throws IOException, EofException { + public static void callChannelRead(ReadableByteChannel readChannel, ByteBuffer buffer) + throws IOException, EofException { Util.assertTrue(buffer.hasRemaining()); LOGGER.trace("Reading from channel"); int c = readChannel.read(buffer); // IO block @@ -364,7 +388,6 @@ public static int readFromChannel(ReadableByteChannel readChannel, ByteBuffer bu if (c == 0) { throw new NeedsReadException(); } - return c; } // write @@ -389,27 +412,33 @@ public long write(ByteBufferSet source) throws IOException { private long wrapAndWrite(ByteBufferSet source) throws IOException { long bytesToConsume = source.remaining(); - long bytesConsumed = 0; outEncrypted.prepare(); try { while (true) { - writeToChannel(); - if (bytesConsumed == bytesToConsume) return bytesToConsume; - WrapResult res = wrapLoop(source); - bytesConsumed += res.bytesConsumed; + writeToChannel(); // IO block + if (source.remaining() == 0) { + return bytesToConsume; + } + SSLEngineResult result = wrapLoop(source); + if (result.getStatus() == Status.CLOSED) { + return bytesToConsume - source.remaining(); + } } } finally { outEncrypted.release(); } } - private WrapResult wrapLoop(ByteBufferSet source) throws SSLException { + /** + * Returns last {@link HandshakeStatus} of the loop + */ + private SSLEngineResult wrapLoop(ByteBufferSet source) throws SSLException { while (true) { SSLEngineResult result = callEngineWrap(source); switch (result.getStatus()) { case OK: case CLOSED: - return new WrapResult(result.bytesConsumed(), result.getHandshakeStatus()); + return result; case BUFFER_OVERFLOW: Util.assertTrue(result.bytesConsumed() == 0); outEncrypted.enlarge(); @@ -439,26 +468,14 @@ private SSLEngineResult callEngineWrap(ByteBufferSet source) throws SSLException } } - private void ensureInPlainCapacity(int newCapacity) { - if (inPlain.buffer.capacity() < newCapacity) { - if (LOGGER.isTraceEnabled()) { - LOGGER.trace(format( - "inPlain buffer too small, increasing from %s to %s", - inPlain.buffer.capacity(), - newCapacity)); - } - inPlain.resize(newCapacity); - } - } - private void writeToChannel() throws IOException { if (outEncrypted.buffer.position() == 0) { return; } - ((Buffer) outEncrypted.buffer).flip(); + outEncrypted.buffer.flip(); try { try { - writeToChannel(writeChannel, outEncrypted.buffer); + callChannelWrite(writeChannel, outEncrypted.buffer); // IO block } catch (WouldBlockException e) { throw e; } catch (IOException e) { @@ -470,13 +487,12 @@ private void writeToChannel() throws IOException { } } - private static void writeToChannel(WritableByteChannel channel, ByteBuffer src) - throws IOException { + private static void callChannelWrite(WritableByteChannel channel, ByteBuffer src) throws IOException { while (src.hasRemaining()) { if (LOGGER.isTraceEnabled()) { LOGGER.trace("Writing to channel: " + src); } - int c = channel.write(src); + int c = channel.write(src); // IO block if (c == 0) { /* * If no bytesProduced were written, it means that the socket is @@ -485,7 +501,7 @@ private static void writeToChannel(WritableByteChannel channel, ByteBuffer src) throw new NeedsWriteException(); } // blocking SocketChannels can write less than all the bytesProduced - // just before an error the loop forces the exception + // just before an error, the loop forces the exception } } @@ -526,14 +542,34 @@ public void handshake() throws IOException { } private void doHandshake(boolean force) throws IOException, EofException { - if (!force && negotiated) return; + if (!force && handshakeCompleted) { + return; + } initLock.lock(); try { - if (invalid || shutdownSent) throw new ClosedChannelException(); - if (force || !negotiated) { - engine.beginHandshake(); - LOGGER.trace("Called engine.beginHandshake()"); - handshake(Optional.empty(), Optional.empty()); + if (invalid || shutdownSent) { + throw new ClosedChannelException(); + } + if (force || !handshakeCompleted) { + + if (!handshakeStarted) { + LOGGER.trace("Called engine.beginHandshake()"); + engine.beginHandshake(); + + // Some engines that do not support renegotiations may be sensitive to calling + // SSLEngine.beginHandshake() more than once. This guard prevents that. + // See: https://github.com/marianobarrios/tls-channel/issues/197 + handshakeStarted = true; + } + + writeAndHandshake(); + + if (engine.getSession().getProtocol().startsWith("DTLS")) { + throw new IllegalArgumentException("DTLS not supported"); + } + + handshakeCompleted = true; + // call client code try { initSessionCallback.accept(engine.getSession()); @@ -541,27 +577,22 @@ private void doHandshake(boolean force) throws IOException, EofException { LOGGER.trace("client code threw exception in session initialization callback", e); throw new TlsChannelCallbackException("session initialization callback failed", e); } - negotiated = true; } } finally { initLock.unlock(); } } - private int handshake(Optional dest, Optional handshakeStatus) - throws IOException, EofException { + private void writeAndHandshake() throws IOException, EofException { readLock.lock(); try { writeLock.lock(); try { - if (invalid || shutdownSent) { - throw new ClosedChannelException(); - } Util.assertTrue(inPlain.nullOrEmpty()); outEncrypted.prepare(); try { writeToChannel(); // IO block - return handshakeLoop(dest, handshakeStatus); + handshakeLoop(); } finally { outEncrypted.release(); } @@ -573,58 +604,56 @@ private int handshake(Optional dest, Optional ha } } - private int handshakeLoop(Optional dest, Optional handshakeStatus) - throws IOException, EofException { + private void handshakeLoop() throws IOException, EofException { Util.assertTrue(inPlain.nullOrEmpty()); - HandshakeStatus status = handshakeStatus.orElseGet(() -> engine.getHandshakeStatus()); while (true) { - switch (status) { + switch (engine.getHandshakeStatus()) { case NEED_WRAP: Util.assertTrue(outEncrypted.nullOrEmpty()); - WrapResult wrapResult = wrapLoop(dummyOut); - status = wrapResult.lastHandshakeStatus; + wrapLoop(dummyOut); writeToChannel(); // IO block break; case NEED_UNWRAP: - UnwrapResult res = readAndUnwrap(dest); - status = res.lastHandshakeStatus; - if (res.bytesProduced > 0) return res.bytesProduced; + readAndUnwrap(); + if (bytesToReturn > 0) { + return; + } break; case NOT_HANDSHAKING: - /* - * This should not really happen using SSLEngine, because - * handshaking ends with a FINISHED status. However, we accept - * this value to permit the use of a pass-through stub engine - * with no encryption. - */ - return 0; + return; case NEED_TASK: handleTask(); - status = engine.getHandshakeStatus(); break; case FINISHED: - return 0; + // this status is never returned by SSLEngine.getHandshakeStatus() + throw new IllegalStateException(); default: // Unsupported stage eg: NEED_UNWRAP_AGAIN - return 0; + throw new IllegalStateException(); } } } - private UnwrapResult readAndUnwrap(Optional dest) - throws IOException, EofException { + private void readAndUnwrap() throws IOException, EofException { // Save status before operation: use it to stop when status changes - HandshakeStatus orig = engine.getHandshakeStatus(); inEncrypted.prepare(); try { while (true) { Util.assertTrue(inPlain.nullOrEmpty()); - UnwrapResult res = unwrapLoop(dest, orig); - if (res.bytesProduced > 0 || res.lastHandshakeStatus != orig || res.wasClosed) { - if (res.wasClosed) { - shutdownReceived = true; - } - return res; + SSLEngineResult result = unwrapLoop(); + HandshakeStatus status = engine.getHandshakeStatus(); + if (result.bytesProduced() > 0) { + bytesToReturn = result.bytesProduced(); + return; + } + if (result.getStatus() == Status.CLOSED) { + shutdownReceived = true; + return; + } + if (result.getHandshakeStatus() == HandshakeStatus.FINISHED + || status == HandshakeStatus.NEED_TASK + || status == HandshakeStatus.NEED_WRAP) { + return; } if (!inEncrypted.buffer.hasRemaining()) { inEncrypted.enlarge(); @@ -720,7 +749,7 @@ public boolean shutdown() throws IOException { if (!shutdownReceived) { try { // IO block - readAndUnwrap(Optional.empty()); + readAndUnwrap(); Util.assertTrue(shutdownReceived); } catch (EofException e) { throw new ClosedChannelException(); @@ -737,18 +766,9 @@ public boolean shutdown() throws IOException { } private void freeBuffers() { - if (inEncrypted != null) { - inEncrypted.dispose(); - inEncrypted = null; - } - if (inPlain != null) { - inPlain.dispose(); - inPlain = null; - } - if (outEncrypted != null) { - outEncrypted.dispose(); - outEncrypted = null; - } + inEncrypted.dispose(); + inPlain.dispose(); + outEncrypted.dispose(); } public boolean isOpen() { diff --git a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsExplorer.java b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsExplorer.java index 0f75c6b33f5..baad50bfb51 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsExplorer.java +++ b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsExplorer.java @@ -29,158 +29,178 @@ import java.util.HashMap; import java.util.Map; -/* - * Implement basic TLS parsing, just to read the SNI (as this is not done by - * {@link SSLEngine}. - */ +//** Implement basic TLS parsing, just to read the SNI. */ public final class TlsExplorer { private TlsExplorer() {} - /** The header size of TLS/SSL records. */ - public static final int RECORD_HEADER_SIZE = 5; - - /** - * Returns the required number of bytesProduced in the {@code source} {@link ByteBuffer} necessary - * to explore SSL/TLS connection. + /* + * struct { + * uint8 major; + * uint8 minor; + * } ProtocolVersion; * - *

This method tries to parse as few bytesProduced as possible from {@code source} byte buffer - * to get the length of an SSL/TLS record. + * enum { + * change_cipher_spec(20), + * alert(21), + * handshake(22), + * application_data(23), + * (255) + * } ContentType; * - * @param source source buffer - * @return the required size + * struct { + * ContentType type; + * ProtocolVersion version; + * uint16 length; + * opaque fragment[TLSPlaintext.length]; + * } TLSPlaintext; */ - public static int getRequiredSize(ByteBuffer source) { - if (source.remaining() < RECORD_HEADER_SIZE) throw new BufferUnderflowException(); - ((Buffer) source).mark(); - try { - byte firstByte = source.get(); - source.get(); // second byte discarded - byte thirdByte = source.get(); - if ((firstByte & 0x80) != 0 && thirdByte == 0x01) { - // looks like a V2ClientHello - return RECORD_HEADER_SIZE; // Only need the header fields - } else { - return (((source.get() & 0xFF) << 8) | (source.get() & 0xFF)) + 5; - } - } finally { - ((Buffer) source).reset(); - } - } + /** Explores a TLS record in search to the SNI. This method does not consume buffer. */ + public static Map exploreTlsRecord(ByteBuffer input) throws SSLProtocolException { - public static Map explore(ByteBuffer source) throws SSLProtocolException { - if (source.remaining() < RECORD_HEADER_SIZE) throw new BufferUnderflowException(); - ((Buffer) source).mark(); + input.mark(); try { - byte firstByte = source.get(); - ignore(source, 1); // ignore second byte - byte thirdByte = source.get(); - if ((firstByte & 0x80) != 0 && thirdByte == 0x01) { - // looks like a V2ClientHello - return new HashMap<>(); - } else if (firstByte == 22) { + byte firstByte = input.get(); + if (firstByte != 22) { // 22: handshake record - return exploreTLSRecord(source, firstByte); - } else { - throw new SSLProtocolException("Not handshake record"); + throw new SSLProtocolException("Not a handshake record"); + } + + ignore(input, 2); // ignore version + + // Is there enough data for a full record? + int recordLength = getInt16(input); + if (recordLength > input.remaining()) { + throw new BufferUnderflowException(); } + + return exploreHandshake(input, recordLength); } finally { - ((Buffer) source).reset(); + input.reset(); } } /* - * struct { uint8 major; uint8 minor; } ProtocolVersion; + * enum { + * hello_request(0), + * client_hello(1), + * server_hello(2), + * certificate(11), + * server_key_exchange (12), + * certificate_request(13), + * server_hello_done(14), + * certificate_verify(15), + * client_key_exchange(16), + * finished(20), + * (255) + * } HandshakeType; * - * enum { change_cipher_spec(20), alert(21), handshake(22), - * application_data(23), (255) } ContentType; - * - * struct { ContentType type; ProtocolVersion version; uint16 length; opaque - * fragment[TLSPlaintext.length]; } TLSPlaintext; - */ - private static Map exploreTLSRecord(ByteBuffer input, byte firstByte) - throws SSLProtocolException { - // Is it a handshake message? - if (firstByte != 22) // 22: handshake record - throw new SSLProtocolException("Not handshake record"); - // Is there enough data for a full record? - int recordLength = getInt16(input); - if (recordLength > input.remaining()) throw new BufferUnderflowException(); - return exploreHandshake(input, recordLength); - } - - /* - * enum { hello_request(0), client_hello(1), server_hello(2), - * certificate(11), server_key_exchange (12), certificate_request(13), - * server_hello_done(14), certificate_verify(15), client_key_exchange(16), - * finished(20) (255) } HandshakeType; - * - * struct { HandshakeType msg_type; uint24 length; select (HandshakeType) { - * case hello_request: HelloRequest; case client_hello: ClientHello; case - * server_hello: ServerHello; case certificate: Certificate; case - * server_key_exchange: ServerKeyExchange; case certificate_request: - * CertificateRequest; case server_hello_done: ServerHelloDone; case - * certificate_verify: CertificateVerify; case client_key_exchange: - * ClientKeyExchange; case finished: Finished; } body; } Handshake; + * struct { + * HandshakeType msg_type; + * uint24 length; + * select (HandshakeType) { + * case hello_request: HelloRequest; + * case client_hello: ClientHello; + * case server_hello: ServerHello; + * case certificate: Certificate; + * case server_key_exchange: ServerKeyExchange; + * case certificate_request: CertificateRequest; + * case server_hello_done: ServerHelloDone; + * case certificate_verify: CertificateVerify; + * case client_key_exchange: ClientKeyExchange; + * case finished: Finished; + * } body; + * } Handshake; */ private static Map exploreHandshake(ByteBuffer input, int recordLength) - throws SSLProtocolException { - // What is the handshake type? + throws SSLProtocolException { byte handshakeType = input.get(); - if (handshakeType != 0x01) // 0x01: client_hello message - throw new SSLProtocolException("Not initial handshaking"); + if (handshakeType != 0x01) { + // 0x01: client_hello message + throw new SSLProtocolException("Not an initial handshaking"); + } + // What is the handshake body length? int handshakeLength = getInt24(input); + // Theoretically, a single handshake message might span multiple // records, but in practice this does not occur. - if (handshakeLength > recordLength - 4) // 4: handshake header size - throw new SSLProtocolException("Handshake message spans multiple records"); - ((Buffer) input).limit(handshakeLength + input.position()); + if (handshakeLength > recordLength - 4) { + // 4: handshake header size + throw new SSLProtocolException("Handshake message spans multiple records"); + } + input.limit(handshakeLength + input.position()); + return exploreClientHello(input); } /* - * struct { uint32 gmt_unix_time; opaque random_bytes[28]; } Random; + * struct { + * uint32 gmt_unix_time; + * opaque random_bytes[28]; + * } Random; * * opaque SessionID<0..32>; * * uint8 CipherSuite[2]; * - * enum { null(0), (255) } CompressionMethod; + * enum { + * null(0), + * (255) + * } CompressionMethod; * - * struct { ProtocolVersion client_version; Random random; SessionID - * session_id; CipherSuite cipher_suites<2..2^16-2>; CompressionMethod - * compression_methods<1..2^8-1>; select (extensions_present) { case false: - * struct {}; case true: Extension extensions<0..2^16-1>; }; } ClientHello; + * struct { + * ProtocolVersion client_version; + * Random random; + * SessionID session_id; + * CipherSuite cipher_suites<2..2^16-2>; + * CompressionMethod compression_methods<1..2^8-1>; + * select (extensions_present) { + * case false: struct {}; + * case true: Extension extensions<0..2^16-1>; + * }; + * } ClientHello; */ - private static Map exploreClientHello(ByteBuffer input) - throws SSLProtocolException { + private static Map exploreClientHello(ByteBuffer input) throws SSLProtocolException { ignore(input, 2); // ignore version ignore(input, 32); // ignore random; 32: the length of Random ignoreByteVector8(input); // ignore session id ignoreByteVector16(input); // ignore cipher_suites ignoreByteVector8(input); // ignore compression methods - if (input.remaining() > 0) return exploreExtensions(input); - else return new HashMap<>(); + if (input.hasRemaining()) { + return exploreExtensions(input); + } else { + return new HashMap<>(); + } } /* - * struct { ExtensionType extension_type; opaque extension_data<0..2^16-1>; + * struct { + * ExtensionType extension_type; + * opaque extension_data<0..2^16-1>; * } Extension; * - * enum { server_name(0), max_fragment_length(1), client_certificate_url(2), - * trusted_ca_keys(3), truncated_hmac(4), status_request(5), (65535) } + * enum { + * server_name(0), + * max_fragment_length(1), + * client_certificate_url(2), + * trusted_ca_keys(3), + * truncated_hmac(4), + * status_request(5), + * (65535) + * } * ExtensionType; */ - private static Map exploreExtensions(ByteBuffer input) - throws SSLProtocolException { + private static Map exploreExtensions(ByteBuffer input) throws SSLProtocolException { int length = getInt16(input); // length of extensions while (length > 0) { int extType = getInt16(input); // extension type int extLen = getInt16(input); // length of extension data - if (extType == 0x00) { // 0x00: type of server name indication + if (extType == 0x00) { + // 0x00: type of server name indication return exploreSNIExt(input, extLen); - } else { // ignore other extensions + } else { + // ignore other extensions ignore(input, extLen); } length -= extLen + 4; @@ -189,49 +209,65 @@ private static Map exploreExtensions(ByteBuffer input) } /* - * struct { NameType name_type; select (name_type) { case host_name: - * HostName; } name; } ServerName; + * struct { + * NameType name_type; + * select (name_type) { + * case host_name: HostName; + * } name; + * } ServerName; * - * enum { host_name(0), (255) } NameType; + * enum { + * host_name(0), + * (255) + * } NameType; * * opaque HostName<1..2^16-1>; * - * struct { ServerName server_name_list<1..2^16-1> } ServerNameList; + * struct { + * ServerName server_name_list<1..2^16-1> + * } ServerNameList; */ - private static Map exploreSNIExt(ByteBuffer input, int extLen) - throws SSLProtocolException { + private static Map exploreSNIExt(ByteBuffer input, int extLen) throws SSLProtocolException { Map sniMap = new HashMap<>(); int remains = extLen; - if (extLen >= 2) { // "server_name" extension in ClientHello + if (extLen >= 2) { + // "server_name" extension in ClientHello int listLen = getInt16(input); // length of server_name_list - if (listLen == 0 || listLen + 2 != extLen) + if (listLen == 0 || listLen + 2 != extLen) { throw new SSLProtocolException("Invalid server name indication extension"); + } remains -= 2; // 2: the length field of server_name_list while (remains > 0) { int code = getInt8(input); // name_type int snLen = getInt16(input); // length field of server name - if (snLen > remains) + if (snLen > remains) { throw new SSLProtocolException("Not enough data to fill declared vector size"); + } byte[] encoded = new byte[snLen]; input.get(encoded); SNIServerName serverName; if (code == StandardConstants.SNI_HOST_NAME) { - if (encoded.length == 0) + if (encoded.length == 0) { throw new SSLProtocolException("Empty HostName in server name indication"); + } serverName = new SNIHostName(encoded); } else { serverName = new UnknownServerName(code, encoded); } // check for duplicated server name type - if (sniMap.put(serverName.getType(), serverName) != null) + if (sniMap.put(serverName.getType(), serverName) != null) { throw new SSLProtocolException("Duplicated server name of type " + serverName.getType()); + } remains -= encoded.length + 3; // NameType: 1 byte; HostName; // length: 2 bytesProduced } - } else if (extLen == 0) { // "server_name" extension in ServerHello + } else if (extLen == 0) { + // "server_name" extension in ServerHello throw new SSLProtocolException("Not server name indication extension in client"); } - if (remains != 0) throw new SSLProtocolException("Invalid server name indication extension"); + if (remains != 0) { + throw new SSLProtocolException("Invalid server name indication extension"); + } return sniMap; } @@ -257,7 +293,10 @@ private static void ignoreByteVector16(ByteBuffer input) { private static void ignore(ByteBuffer input, int length) { if (length != 0) { - ((Buffer) input).position(input.position() + length); + if (input.remaining() < length) { + throw new BufferUnderflowException(); + } + input.position(input.position() + length); } } @@ -267,4 +306,4 @@ private static class UnknownServerName extends SNIServerName { super(code, encoded); } } -} +} \ No newline at end of file