diff --git a/bson/src/main/org/bson/ByteBuf.java b/bson/src/main/org/bson/ByteBuf.java index bc37102f0b..13dfa1f95a 100644 --- a/bson/src/main/org/bson/ByteBuf.java +++ b/bson/src/main/org/bson/ByteBuf.java @@ -192,7 +192,7 @@ public interface ByteBuf { * @return {@code true} if, and only if, this buffer is backed by an array and is not read-only * @since 5.5 */ - boolean hasArray(); + boolean isBackedByArray(); /** * Returns the offset of the first byte within the backing byte array of diff --git a/bson/src/main/org/bson/ByteBufNIO.java b/bson/src/main/org/bson/ByteBufNIO.java index ba71625d76..dfcc637907 100644 --- a/bson/src/main/org/bson/ByteBufNIO.java +++ b/bson/src/main/org/bson/ByteBufNIO.java @@ -133,7 +133,7 @@ public byte[] array() { } @Override - public boolean hasArray() { + public boolean isBackedByArray() { return buf.hasArray(); } diff --git a/bson/src/main/org/bson/io/ByteBufferBsonInput.java b/bson/src/main/org/bson/io/ByteBufferBsonInput.java index a5a0e7a542..f8be97ddaa 100644 --- a/bson/src/main/org/bson/io/ByteBufferBsonInput.java +++ b/bson/src/main/org/bson/io/ByteBufferBsonInput.java @@ -33,6 +33,13 @@ public class ByteBufferBsonInput implements BsonInput { private static final String[] ONE_BYTE_ASCII_STRINGS = new String[Byte.MAX_VALUE + 1]; + /* A dynamically sized scratch buffer, that is reused across BSON String reads: + * 1. Reduces garbage collection by avoiding new byte array creation. + * 2. Improves cache utilization through temporal locality. + * 3. Avoids JVM allocation and zeroing cost for new memory allocations. + */ + private byte[] scratchBuffer; + static { for (int b = 0; b < ONE_BYTE_ASCII_STRINGS.length; b++) { @@ -127,15 +134,12 @@ public String readString() { @Override public String readCString() { - int mark = buffer.position(); - skipCString(); - int size = buffer.position() - mark; - buffer.position(mark); + int size = computeCStringLength(buffer.position()); return readString(size); } - private String readString(final int size) { - if (size == 2) { + private String readString(final int bsonStringSize) { + if (bsonStringSize == 2) { byte asciiByte = buffer.get(); // if only one byte in the string, it must be ascii. byte nullByte = buffer.get(); // read null terminator if (nullByte != 0) { @@ -146,26 +150,51 @@ private String readString(final int size) { } return ONE_BYTE_ASCII_STRINGS[asciiByte]; // this will throw if asciiByte is negative } else { - byte[] bytes = new byte[size - 1]; - buffer.get(bytes); - byte nullByte = buffer.get(); - if (nullByte != 0) { - throw new BsonSerializationException("Found a BSON string that is not null-terminated"); + if (buffer.isBackedByArray()) { + int position = buffer.position(); + int arrayOffset = buffer.arrayOffset(); + int newPosition = position + bsonStringSize; + buffer.position(newPosition); + + byte[] array = buffer.array(); + if (array[arrayOffset + newPosition - 1] != 0) { + throw new BsonSerializationException("Found a BSON string that is not null-terminated"); + } + return new String(array, arrayOffset + position, bsonStringSize - 1, StandardCharsets.UTF_8); + } else if (scratchBuffer == null || bsonStringSize > scratchBuffer.length) { + int scratchBufferSize = bsonStringSize + (bsonStringSize >>> 1); //1.5 times the size + scratchBuffer = new byte[scratchBufferSize]; + } + + buffer.get(scratchBuffer, 0, bsonStringSize); + if (scratchBuffer[bsonStringSize - 1] != 0) { + throw new BsonSerializationException("BSON string not null-terminated"); } - return new String(bytes, StandardCharsets.UTF_8); + return new String(scratchBuffer, 0, bsonStringSize - 1, StandardCharsets.UTF_8); } } @Override public void skipCString() { ensureOpen(); - boolean checkNext = true; - while (checkNext) { - if (!buffer.hasRemaining()) { - throw new BsonSerializationException("Found a BSON string that is not null-terminated"); + int pos = buffer.position(); + int length = computeCStringLength(pos); + buffer.position(pos + length); + } + + private int computeCStringLength(final int prevPos) { + ensureOpen(); + int pos = buffer.position(); + int limit = buffer.limit(); + + while (pos < limit) { + if (buffer.get(pos++) == 0) { + return (pos - prevPos); } - checkNext = buffer.get() != 0; } + + buffer.position(pos); + throw new BsonSerializationException("Found a BSON string that is not null-terminated"); } @Override diff --git a/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java b/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java index c684eddf9f..600145db48 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java +++ b/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java @@ -399,7 +399,7 @@ protected int writeCharacters(final String str, final boolean checkNullTerminati int curBufferLimit = curBuffer.limit(); int remaining = curBufferLimit - curBufferPos; - if (curBuffer.hasArray()) { + if (curBuffer.isBackedByArray()) { byte[] dst = curBuffer.array(); int arrayOffset = curBuffer.arrayOffset(); if (remaining >= str.length() + 1) { diff --git a/driver-core/src/main/com/mongodb/internal/connection/CompositeByteBuf.java b/driver-core/src/main/com/mongodb/internal/connection/CompositeByteBuf.java index e7e0186e12..a3ce668040 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/CompositeByteBuf.java +++ b/driver-core/src/main/com/mongodb/internal/connection/CompositeByteBuf.java @@ -214,7 +214,7 @@ public byte[] array() { } @Override - public boolean hasArray() { + public boolean isBackedByArray() { return false; } diff --git a/driver-core/src/main/com/mongodb/internal/connection/ResponseBuffers.java b/driver-core/src/main/com/mongodb/internal/connection/ResponseBuffers.java index e984862fe0..6774b4a50a 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/ResponseBuffers.java +++ b/driver-core/src/main/com/mongodb/internal/connection/ResponseBuffers.java @@ -53,13 +53,15 @@ T getResponseDocument(final int messageId, final Decode } /** - * Returns a read-only buffer containing the response body. Care should be taken to not use the returned buffer after this instance has + * Returns a buffer containing the response body. Care should be taken to not use the returned buffer after this instance has * been closed. * - * @return a read-only buffer containing the response body + * NOTE: do not modify this buffer, it is being made writable for performance reasons to avoid redundant copying. + * + * @return a buffer containing the response body */ public ByteBuf getBodyByteBuffer() { - return bodyByteBuffer.asReadOnly(); + return bodyByteBuffer; } public void reset() { diff --git a/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java b/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java index cbe50aaada..21124d81d3 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java +++ b/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java @@ -125,7 +125,7 @@ public byte[] array() { } @Override - public boolean hasArray() { + public boolean isBackedByArray() { return proxied.hasArray(); } diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonInputTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonInputTest.java new file mode 100644 index 0000000000..0846f7a54f --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonInputTest.java @@ -0,0 +1,719 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.connection; + +import com.google.common.primitives.Ints; +import com.mongodb.internal.connection.netty.NettyByteBuf; +import io.netty.buffer.PooledByteBufAllocator; +import org.bson.BsonSerializationException; +import org.bson.ByteBuf; +import org.bson.ByteBufNIO; +import org.bson.io.ByteBufferBsonInput; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Stream; + +import static java.lang.Character.MAX_CODE_POINT; +import static java.lang.Character.MAX_LOW_SURROGATE; +import static java.lang.Character.MIN_HIGH_SURROGATE; +import static java.lang.Integer.reverseBytes; +import static java.lang.String.join; +import static java.util.Collections.nCopies; +import static java.util.stream.Collectors.toList; +import static java.util.stream.IntStream.range; +import static java.util.stream.IntStream.rangeClosed; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + + +class ByteBufferBsonInputTest { + + private static final List ALL_CODE_POINTS_EXCLUDING_SURROGATES = Stream.concat( + range(1, MIN_HIGH_SURROGATE).boxed(), + rangeClosed(MAX_LOW_SURROGATE + 1, MAX_CODE_POINT).boxed()) + .filter(i -> i < 128 || i % 10 == 0) // only subset of code points to speed up testing + .collect(toList()); + + static Stream bufferProviders() { + return Stream.of( + size -> new NettyByteBuf(PooledByteBufAllocator.DEFAULT.directBuffer(size)), + size -> new NettyByteBuf(PooledByteBufAllocator.DEFAULT.heapBuffer(size)), + new PowerOfTwoBufferPool(), + size -> new ByteBufNIO(ByteBuffer.wrap(new byte[size + 5], 2, size).slice()), //different array offsets + size -> new ByteBufNIO(ByteBuffer.wrap(new byte[size + 4], 3, size).slice()), //different array offsets + size -> new ByteBufNIO(ByteBuffer.allocateDirect(size)), + size -> new ByteBufNIO(ByteBuffer.allocate(size)) { + @Override + public boolean isBackedByArray() { + return false; + } + + @Override + public byte[] array() { + return Assertions.fail("array() is called, when isBackedByArray() returns false"); + } + + @Override + public int arrayOffset() { + return Assertions.fail("arrayOffset() is called, when isBackedByArray() returns false"); + } + } + ); + } + + @ParameterizedTest + @MethodSource("bufferProviders") + void shouldReadEmptyString(final BufferProvider bufferProvider) { + // given + byte[] input = {1, 0, 0, 0, 0}; + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, input); + + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer)) { + // when + String result = bufferInput.readString(); + + // then + assertEquals("", result); + assertEquals(5, bufferInput.getPosition()); + } + } + + @ParameterizedTest + @MethodSource("bufferProviders") + void shouldReadEmptyCString(final BufferProvider bufferProvider) { + // given + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, new byte[]{0}); + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer)) { + // when + String result = bufferInput.readCString(); + + // then + assertEquals("", result); + assertEquals(1, bufferInput.getPosition()); + } + } + + @ParameterizedTest + @MethodSource("bufferProviders") + void shouldReadInvalidOneByteString(final BufferProvider bufferProvider) { + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, new byte[]{2, 0, 0, 0, (byte) 0xFF, 0}); + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer)) { + + // when + String result = bufferInput.readString(); + + // then + assertEquals("\uFFFD", result); + assertEquals(6, bufferInput.getPosition()); + } + } + + @ParameterizedTest + @MethodSource("bufferProviders") + void shouldReadInvalidOneByteCString(final BufferProvider bufferProvider) { + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, new byte[]{-0x01, 0}); + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer)) { + + // when + String result = bufferInput.readCString(); + + // then + assertEquals("\uFFFD", result); + assertEquals(2, bufferInput.getPosition()); + } + } + + + @ParameterizedTest + @MethodSource("bufferProviders") + void shouldReadStringUptoBufferLimit(final BufferProvider bufferProvider) { + // given + for (Integer codePoint : ALL_CODE_POINTS_EXCLUDING_SURROGATES) { + for (int offset = 0; offset < 18; offset++) { + String expectedString = join("", nCopies(offset, "b")) + + String.valueOf(Character.toChars(codePoint)); + byte[] expectedStringEncoding = getExpectedEncodedString(expectedString); + + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, expectedStringEncoding); + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer)) { + + // when + String actualString = bufferInput.readString(); + + // then + assertEquals(expectedString, actualString); + assertEquals(expectedStringEncoding.length, bufferInput.getPosition()); + } + } + } + } + + @ParameterizedTest + @MethodSource("bufferProviders") + void shouldReadStringWithMoreDataInBuffer(final BufferProvider bufferProvider) throws IOException { + // given + for (Integer codePoint : ALL_CODE_POINTS_EXCLUDING_SURROGATES) { + for (int offset = 0; offset < 18; offset++) { + String expectedString = join("", nCopies(offset, "b")) + + String.valueOf(Character.toChars(codePoint)); + byte[] expectedStringEncoding = getExpectedEncodedString(expectedString); + byte[] bufferBytes = mergeArrays( + expectedStringEncoding, + new byte[]{1, 2, 3} + ); + + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, bufferBytes); + + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer)) { + + // when + String actualString = bufferInput.readString(); + + // then + assertEquals(expectedString, actualString); + assertEquals(expectedStringEncoding.length, bufferInput.getPosition()); + } + } + } + } + + @ParameterizedTest + @MethodSource("bufferProviders") + void shouldReadMultipleStringsWithinBuffer(final BufferProvider bufferProvider) throws IOException { + // given + for (Integer codePoint : ALL_CODE_POINTS_EXCLUDING_SURROGATES) { + for (int offset = 0; offset < 18; offset++) { + String expectedString1 = join("", nCopies(offset, "b")) + + String.valueOf(Character.toChars(codePoint)); + String expectedString2 = join("", nCopies(offset, "a")) + + String.valueOf(Character.toChars(codePoint)); + + byte[] expectedStringEncoding1 = getExpectedEncodedString(expectedString1); + byte[] expectedStringEncoding2 = getExpectedEncodedString(expectedString2); + int expectedInteger = 12412; + byte[] bufferBytes = mergeArrays( + new byte[]{1, 2, 3}, + expectedStringEncoding1, + Ints.toByteArray(reverseBytes(expectedInteger)), + expectedStringEncoding2, + new byte[]{1, 2, 3, 4} + ); + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, bufferBytes); + buffer.position(3); + + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer)) { + // when + String actualString1 = bufferInput.readString(); + + // then + assertEquals( + expectedString1, + actualString1); + assertEquals( + 3 + expectedStringEncoding1.length, + bufferInput.getPosition()); + + // when + assertEquals(expectedInteger, bufferInput.readInt32()); + + // then + String actualString2 = bufferInput.readString(); + assertEquals( + expectedString2, + actualString2); + assertEquals( + 3 + expectedStringEncoding1.length + expectedStringEncoding2.length + Integer.BYTES, + bufferInput.getPosition()); + } + } + } + } + + @ParameterizedTest + @MethodSource("bufferProviders") + void shouldReadConsecutiveMultipleStringsWithinBuffer(final BufferProvider bufferProvider) throws IOException { + // given + for (Integer codePoint : ALL_CODE_POINTS_EXCLUDING_SURROGATES) { + for (int offset = 0; offset < 18; offset++) { + String expectedString1 = join("", nCopies(offset, "b")) + + String.valueOf(Character.toChars(codePoint)); + String expectedString2 = join("", nCopies(offset, "a")) + + String.valueOf(Character.toChars(codePoint)); + + byte[] expectedStringEncoding1 = getExpectedEncodedString(expectedString1); + byte[] expectedStringEncoding2 = getExpectedEncodedString(expectedString2); + byte[] bufferBytes = mergeArrays( + new byte[]{1, 2, 3}, + expectedStringEncoding1, + expectedStringEncoding2, + new byte[]{1, 2, 3, 4} + ); + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, bufferBytes); + buffer.position(3); + + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer)) { + // when + String actualString1 = bufferInput.readString(); + + // then + assertEquals( + expectedString1, + actualString1); + assertEquals( + 3 + expectedStringEncoding1.length, + bufferInput.getPosition()); + + // when + String actualString2 = bufferInput.readString(); + + // then + assertEquals( + expectedString2, + actualString2); + assertEquals( + 3 + expectedStringEncoding1.length + expectedStringEncoding2.length, + bufferInput.getPosition()); + } + } + + } + } + + @ParameterizedTest + @MethodSource("bufferProviders") + void shouldReadConsecutiveMultipleCStringsWithinBuffer(final BufferProvider bufferProvider) throws IOException { + // given + for (Integer codePoint : ALL_CODE_POINTS_EXCLUDING_SURROGATES) { + for (int offset = 0; offset < 18; offset++) { + String expectedString1 = join("", nCopies(offset, "b")) + + String.valueOf(Character.toChars(codePoint)); + String expectedString2 = join("", nCopies(offset, "a")) + + String.valueOf(Character.toChars(codePoint)); + + byte[] expectedStringEncoding1 = getExpectedEncodedCString(expectedString1); + byte[] expectedStringEncoding2 = getExpectedEncodedCString(expectedString2); + byte[] bufferBytes = mergeArrays( + new byte[]{1, 2, 3}, + expectedStringEncoding1, + expectedStringEncoding2, + new byte[]{1, 2, 3, 4} + ); + + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, bufferBytes); + buffer.position(3); + + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer)) { + // when + String actualString1 = bufferInput.readCString(); + + // then + assertEquals( + expectedString1, + actualString1); + assertEquals( + 3 + expectedStringEncoding1.length, + bufferInput.getPosition()); + + // when + String actualString2 = bufferInput.readCString(); + + // then + assertEquals( + expectedString2, + actualString2); + assertEquals( + 3 + expectedStringEncoding1.length + expectedStringEncoding2.length, + bufferInput.getPosition()); + } + } + } + } + + @ParameterizedTest + @MethodSource("bufferProviders") + void shouldReadMultipleCStringsWithinBuffer(final BufferProvider bufferProvider) throws IOException { + // given + for (Integer codePoint : ALL_CODE_POINTS_EXCLUDING_SURROGATES) { + for (int offset = 0; offset < 18; offset++) { + String expectedString1 = join("", nCopies(offset, "b")) + + String.valueOf(Character.toChars(codePoint)); + String expectedString2 = join("", nCopies(offset, "a")) + + String.valueOf(Character.toChars(codePoint)); + + byte[] expectedStringEncoding1 = getExpectedEncodedCString(expectedString1); + byte[] expectedStringEncoding2 = getExpectedEncodedCString(expectedString2); + int expectedInteger = 12412; + byte[] bufferBytes = mergeArrays( + new byte[]{1, 2, 3}, + expectedStringEncoding1, + Ints.toByteArray(reverseBytes(expectedInteger)), + expectedStringEncoding2, + new byte[]{1, 2, 3, 4} + ); + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, bufferBytes); + buffer.position(3); + + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer)) { + // when + String actualString1 = bufferInput.readCString(); + + // then + assertEquals( + expectedString1, + actualString1); + assertEquals( + 3 + expectedStringEncoding1.length, + bufferInput.getPosition()); + + // when + int actualInteger = bufferInput.readInt32(); + + // then + assertEquals(expectedInteger, actualInteger); + + // when + String actualString2 = bufferInput.readCString(); + + // then + assertEquals( + expectedString2, + actualString2); + assertEquals( + 3 + expectedStringEncoding1.length + expectedStringEncoding2.length + Integer.BYTES, + bufferInput.getPosition()); + } + } + } + } + + @ParameterizedTest + @MethodSource("bufferProviders") + void shouldReadStringWithinBuffer(final BufferProvider bufferProvider) throws IOException { + // given + for (Integer codePoint : ALL_CODE_POINTS_EXCLUDING_SURROGATES) { + for (int offset = 0; offset < 18; offset++) { + String expectedString = join("", nCopies(offset, "b")) + + String.valueOf(Character.toChars(codePoint)); + + byte[] expectedStringEncoding = getExpectedEncodedString(expectedString); + byte[] bufferBytes = mergeArrays( + new byte[]{1, 2, 3}, + expectedStringEncoding, + new byte[]{4, 5, 6} + ); + + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, bufferBytes); + buffer.position(3); + + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer)) { + + // when + String actualString = bufferInput.readString(); + + // then + assertEquals(expectedString, actualString); + assertEquals(3 + expectedStringEncoding.length, bufferInput.getPosition()); + } + } + } + } + + @ParameterizedTest + @MethodSource("bufferProviders") + void shouldReadCStringUptoBufferLimit(final BufferProvider bufferProvider) { + // given + for (Integer codePoint : ALL_CODE_POINTS_EXCLUDING_SURROGATES) { + for (int offset = 0; offset < 18; offset++) { + String expectedString = join("", nCopies(offset, "b")) + + String.valueOf(Character.toChars(codePoint)); + byte[] expectedStringEncoding = getExpectedEncodedCString(expectedString); + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, expectedStringEncoding); + + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer)) { + + // when + String actualString = bufferInput.readCString(); + + // then + assertEquals(expectedString, actualString); + assertEquals(expectedStringEncoding.length, bufferInput.getPosition()); + } + } + } + } + + @ParameterizedTest + @MethodSource("bufferProviders") + void shouldReadCStringWithMoreDataInBuffer(final BufferProvider bufferProvider) throws IOException { + // given + for (Integer codePoint : ALL_CODE_POINTS_EXCLUDING_SURROGATES) { + for (int offset = 0; offset < 18; offset++) { + String expectedString = join("", nCopies(offset, "b")) + + String.valueOf(Character.toChars(codePoint)); + byte[] expectedStringEncoding = getExpectedEncodedCString(expectedString); + byte[] bufferBytes = mergeArrays( + expectedStringEncoding, + new byte[]{1, 2, 3} + ); + + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, bufferBytes); + + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer)) { + + // when + String actualString = bufferInput.readCString(); + + // then + assertEquals(expectedString, actualString); + assertEquals(expectedStringEncoding.length, bufferInput.getPosition()); + } + } + } + } + + @ParameterizedTest + @MethodSource("bufferProviders") + void shouldReadCStringWithingBuffer(final BufferProvider bufferProvider) throws IOException { + // given + for (Integer codePoint : ALL_CODE_POINTS_EXCLUDING_SURROGATES) { + for (int offset = 0; offset < 18; offset++) { + //given + String expectedString = join("", nCopies(offset, "b")) + + String.valueOf(Character.toChars(codePoint)); + + byte[] expectedStringEncoding = getExpectedEncodedCString(expectedString); + byte[] bufferBytes = mergeArrays( + new byte[]{1, 2, 3}, + expectedStringEncoding, + new byte[]{4, 5, 6} + ); + + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, bufferBytes); + buffer.position(3); + + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer)) { + // when + String actualString = bufferInput.readCString(); + + // then + assertEquals(expectedString, actualString); + assertEquals(3 + expectedStringEncoding.length, bufferInput.getPosition()); + } + } + } + } + + @ParameterizedTest + @MethodSource("bufferProviders") + void shouldThrowIfCStringIsNotNullTerminatedSkip(final BufferProvider bufferProvider) { + // given + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, new byte[]{(byte) 0xe0, (byte) 0xa4, (byte) 0x80}); + try (ByteBufferBsonInput expectedString = new ByteBufferBsonInput(buffer)) { + + // when & then + assertThrows(BsonSerializationException.class, expectedString::skipCString); + } + } + + + public static Stream nonNullTerminatedStringsWithBuffers() { + List arguments = new ArrayList<>(); + List collect = bufferProviders().collect(toList()); + for (BufferProvider bufferProvider : collect) { + arguments.add(Arguments.of(new byte[]{1, 0, 0, 0, 1}, bufferProvider)); + arguments.add(Arguments.of(new byte[]{2, 0, 0, 0, 1, 3}, bufferProvider)); + arguments.add(Arguments.of(new byte[]{3, 0, 0, 1, 2, 3}, bufferProvider)); + arguments.add(Arguments.of(new byte[]{4, 0, 0, 0, 1, 2, 3, 4}, bufferProvider)); + arguments.add(Arguments.of(new byte[]{8, 0, 0, 0, 2, 3, 4, 5, 6, 7, 8, 9}, bufferProvider)); + arguments.add(Arguments.of(new byte[]{9, 0, 0, 0, 2, 3, 4, 5, 6, 7, 8, 9, 1}, bufferProvider)); + } + return arguments.stream(); + } + + @ParameterizedTest + @MethodSource("nonNullTerminatedStringsWithBuffers") + void shouldThrowIfStringIsNotNullTerminated(final byte[] nonNullTerminatedString, final BufferProvider bufferProvider) { + // given + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, nonNullTerminatedString); + try (ByteBufferBsonInput expectedString = new ByteBufferBsonInput(buffer)) { + + // when & then + assertThrows(BsonSerializationException.class, expectedString::readString); + } + } + + public static Stream nonNullTerminatedCStringsWithBuffers() { + List arguments = new ArrayList<>(); + List collect = bufferProviders().collect(toList()); + for (BufferProvider bufferProvider : collect) { + arguments.add(Arguments.of(new byte[]{1}, bufferProvider)); + arguments.add(Arguments.of(new byte[]{1, 2}, bufferProvider)); + arguments.add(Arguments.of(new byte[]{1, 2, 3}, bufferProvider)); + arguments.add(Arguments.of(new byte[]{1, 2, 3, 4}, bufferProvider)); + arguments.add(Arguments.of(new byte[]{2, 3, 4, 5, 6, 7, 8, 9}, bufferProvider)); + arguments.add(Arguments.of(new byte[]{2, 3, 4, 5, 6, 7, 8, 9, 1}, bufferProvider)); + } + return arguments.stream(); + } + + @ParameterizedTest + @MethodSource("nonNullTerminatedCStringsWithBuffers") + void shouldThrowIfCStringIsNotNullTerminated(final byte[] nonNullTerminatedCString, final BufferProvider bufferProvider) { + // given + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, nonNullTerminatedCString); + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer)) { + + // when & then + assertThrows(BsonSerializationException.class, bufferInput::readCString); + } + } + + + @ParameterizedTest + @MethodSource("bufferProviders") + void shouldThrowIfOneByteStringIsNotNullTerminated(final BufferProvider bufferProvider) { + // given + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, new byte[]{2, 0, 0, 0, 1}); + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer)) { + + // when & then + assertThrows(BsonSerializationException.class, bufferInput::readString); + } + } + + @ParameterizedTest + @MethodSource("bufferProviders") + void shouldThrowIfOneByteCStringIsNotNullTerminated(final BufferProvider bufferProvider) { + // given + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, new byte[]{1}); + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer)) { + + // when & then + assertThrows(BsonSerializationException.class, bufferInput::readCString); + } + } + + @ParameterizedTest + @MethodSource("bufferProviders") + void shouldThrowIfLengthOfBsonStringIsNotPositive(final BufferProvider bufferProvider) { + // given + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, new byte[]{-1, -1, -1, -1, 41, 42, 43, 0}); + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer)) { + + // when & then + assertThrows(BsonSerializationException.class, bufferInput::readString); + } + } + + public static Stream shouldSkipCStringWhenMultipleNullTerminationPresent() { + List arguments = new ArrayList<>(); + List collect = bufferProviders().collect(toList()); + for (BufferProvider bufferProvider : collect) { + arguments.add(Arguments.of(new byte[]{0, 8, 0, 0, 0}, bufferProvider)); + arguments.add(Arguments.of(new byte[]{0x4a, 0, 8, 0, 0, 0}, bufferProvider)); + arguments.add(Arguments.of(new byte[]{0x4a, 0x4b, 0, 8, 0, 0, 0}, bufferProvider)); + arguments.add(Arguments.of(new byte[]{0x4a, 0x4b, 0x4c, 0, 8, 0, 0, 0}, bufferProvider)); + arguments.add(Arguments.of(new byte[]{0x4a, 0x61, 0x76, 0x61, 0, 8, 0, 0, 0}, bufferProvider)); + arguments.add(Arguments.of(new byte[]{0x4a, 0x61, 0x76, 0x61, 0x62, 0, 8, 0, 0, 0}, bufferProvider)); + arguments.add(Arguments.of(new byte[]{0x4a, 0x61, 0x76, 0x61, 0x65, 0x62, 0x67, 0, 8, 0, 0, 0}, bufferProvider)); + arguments.add(Arguments.of(new byte[]{0x4a, 0, 8, 0, 0, 0}, bufferProvider)); + } + return arguments.stream(); + } + + @ParameterizedTest + @MethodSource() + void shouldSkipCStringWhenMultipleNullTerminationPresent(final byte[] cStringBytes, final BufferProvider bufferProvider) { + // given + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, cStringBytes); + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer)) { + + // when + bufferInput.skipCString(); + + //then + assertEquals(cStringBytes.length - Integer.BYTES, bufferInput.getPosition()); + assertEquals(8, bufferInput.readInt32()); + } + } + + @ParameterizedTest + @MethodSource("bufferProviders") + void shouldReadSkipCStringWhenMultipleNullTerminationPresentWithinBuffer(final BufferProvider bufferProvider) { + // given + byte[] input = {4, 0, 0, 0, 0x4a, 0x61, 0x76, 0x61, 0, 8, 0, 0, 0}; + ByteBuf buffer = allocateAndWriteToBuffer(bufferProvider, input); + buffer.position(4); + try (ByteBufferBsonInput bufferInput = new ByteBufferBsonInput(buffer)) { + + // when + bufferInput.skipCString(); + + // then + assertEquals(9, bufferInput.getPosition()); + assertEquals(8, bufferInput.readInt32()); + } + } + + + private static ByteBuf allocateAndWriteToBuffer(final BufferProvider bufferProvider, final byte[] input) { + ByteBuf buffer = bufferProvider.getBuffer(input.length); + buffer.put(input, 0, input.length); + buffer.flip(); + return buffer; + } + + + public static byte[] mergeArrays(final byte[]... arrays) throws IOException { + int size = 0; + for (byte[] array : arrays) { + size += array.length; + } + ByteArrayOutputStream baos = new ByteArrayOutputStream(size); + for (byte[] array : arrays) { + baos.write(array); + } + return baos.toByteArray(); + } + + private static byte[] getExpectedEncodedString(final String expectedString) { + byte[] expectedEncoding = expectedString.getBytes(StandardCharsets.UTF_8); + int littleEndianLength = reverseBytes(expectedEncoding.length + "\u0000".length()); + byte[] length = Ints.toByteArray(littleEndianLength); + + byte[] combined = new byte[expectedEncoding.length + length.length + 1]; + System.arraycopy(length, 0, combined, 0, length.length); + System.arraycopy(expectedEncoding, 0, combined, length.length, expectedEncoding.length); + return combined; + } + + private static byte[] getExpectedEncodedCString(final String expectedString) { + byte[] encoding = expectedString.getBytes(StandardCharsets.UTF_8); + byte[] combined = new byte[encoding.length + 1]; + System.arraycopy(encoding, 0, combined, 0, encoding.length); + return combined; + } +} diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonOutputTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonOutputTest.java index bd05546111..4ab076dd5d 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonOutputTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonOutputTest.java @@ -87,18 +87,18 @@ static Stream bufferProviders() { size -> new ByteBufNIO(ByteBuffer.wrap(new byte[size + 4], 3, size).slice()), //different array offsets size -> new ByteBufNIO(ByteBuffer.allocate(size)) { @Override - public boolean hasArray() { + public boolean isBackedByArray() { return false; } @Override public byte[] array() { - return Assertions.fail("array() is called, when hasArray() returns false"); + return Assertions.fail("array() is called, when isBackedByArray() returns false"); } @Override public int arrayOffset() { - return Assertions.fail("arrayOffset() is called, when hasArray() returns false"); + return Assertions.fail("arrayOffset() is called, when isBackedByArray() returns false"); } } );