diff --git a/bson/src/main/org/bson/BsonBinaryWriter.java b/bson/src/main/org/bson/BsonBinaryWriter.java index e6255ea8478..c54fa6d961a 100644 --- a/bson/src/main/org/bson/BsonBinaryWriter.java +++ b/bson/src/main/org/bson/BsonBinaryWriter.java @@ -259,7 +259,7 @@ public void doWriteNull() { public void doWriteObjectId(final ObjectId value) { bsonOutput.writeByte(BsonType.OBJECT_ID.getValue()); writeCurrentName(); - bsonOutput.writeBytes(value.toByteArray()); + bsonOutput.writeObjectId(value); } @Override diff --git a/bson/src/main/org/bson/io/BasicOutputBuffer.java b/bson/src/main/org/bson/io/BasicOutputBuffer.java index 8227940182e..aaff34d6476 100644 --- a/bson/src/main/org/bson/io/BasicOutputBuffer.java +++ b/bson/src/main/org/bson/io/BasicOutputBuffer.java @@ -18,11 +18,14 @@ import org.bson.ByteBuf; import org.bson.ByteBufNIO; +import org.bson.types.ObjectId; import java.io.IOException; import java.io.OutputStream; +import java.nio.Buffer; import java.nio.ByteBuffer; import java.util.Arrays; +import java.util.Collections; import java.util.List; import static java.lang.String.format; @@ -32,8 +35,12 @@ * A BSON output stream that stores the output in a single, un-pooled byte array. */ public class BasicOutputBuffer extends OutputBuffer { - private byte[] buffer; - private int position; + + /** + * This ByteBuffer allows us to write ObjectIDs without allocating a temporary array per object, and enables us + * to leverage JVM intrinsics for writing little-endian numeric values. + */ + private ByteBuffer buffer; /** * Construct an instance with a default initial byte array size. @@ -48,7 +55,8 @@ public BasicOutputBuffer() { * @param initialSize the initial size of the byte array */ public BasicOutputBuffer(final int initialSize) { - buffer = new byte[initialSize]; + // Allocate heap buffer to ensure we can access underlying array + buffer = ByteBuffer.allocate(initialSize).order(LITTLE_ENDIAN); } /** @@ -58,13 +66,46 @@ public BasicOutputBuffer(final int initialSize) { * @since 3.3 */ public byte[] getInternalBuffer() { - return buffer; + return buffer.array(); } @Override public void write(final byte[] b) { + writeBytes(b, 0, b.length); + } + + @Override + public byte[] toByteArray() { + ensureOpen(); + return Arrays.copyOf(buffer.array(), buffer.position()); + } + + @Override + public void writeInt32(final int value) { + ensureOpen(); + ensure(4); + buffer.putInt(value); + } + + @Override + public void writeInt32(final int position, final int value) { + ensureOpen(); + checkPosition(position, 4); + buffer.putInt(position, value); + } + + @Override + public void writeInt64(final long value) { + ensureOpen(); + ensure(8); + buffer.putLong(value); + } + + @Override + public void writeObjectId(final ObjectId value) { ensureOpen(); - write(b, 0, b.length); + ensure(12); + value.putToByteBuffer(buffer); } @Override @@ -72,8 +113,7 @@ public void writeBytes(final byte[] bytes, final int offset, final int length) { ensureOpen(); ensure(length); - System.arraycopy(bytes, offset, buffer, position, length); - position += length; + buffer.put(bytes, offset, length); } @Override @@ -81,27 +121,21 @@ public void writeByte(final int value) { ensureOpen(); ensure(1); - buffer[position++] = (byte) (0xFF & value); + buffer.put((byte) (0xFF & value)); } @Override protected void write(final int absolutePosition, final int value) { ensureOpen(); + checkPosition(absolutePosition, 1); - if (absolutePosition < 0) { - throw new IllegalArgumentException(format("position must be >= 0 but was %d", absolutePosition)); - } - if (absolutePosition > position - 1) { - throw new IllegalArgumentException(format("position must be <= %d but was %d", position - 1, absolutePosition)); - } - - buffer[absolutePosition] = (byte) (0xFF & value); + buffer.put(absolutePosition, (byte) (0xFF & value)); } @Override public int getPosition() { ensureOpen(); - return position; + return buffer.position(); } /** @@ -110,29 +144,32 @@ public int getPosition() { @Override public int getSize() { ensureOpen(); - return position; + return buffer.position(); } @Override public int pipe(final OutputStream out) throws IOException { ensureOpen(); - out.write(buffer, 0, position); - return position; + out.write(buffer.array(), 0, buffer.position()); + return buffer.position(); } @Override public void truncateToPosition(final int newPosition) { ensureOpen(); - if (newPosition > position || newPosition < 0) { + if (newPosition > buffer.position() || newPosition < 0) { throw new IllegalArgumentException(); } - position = newPosition; + // The cast is required for compatibility with JDK 9+ where ByteBuffer's position method is inherited from Buffer. + ((Buffer) buffer).position(newPosition); } @Override public List getByteBuffers() { ensureOpen(); - return Arrays.asList(new ByteBufNIO(ByteBuffer.wrap(buffer, 0, position).duplicate().order(LITTLE_ENDIAN))); + // Create a flipped copy of the buffer for reading. Note that ByteBufNIO overwrites the endian-ness. + ByteBuffer flipped = ByteBuffer.wrap(buffer.array(), 0, buffer.position()); + return Collections.singletonList(new ByteBufNIO(flipped)); } @Override @@ -147,19 +184,32 @@ private void ensureOpen() { } private void ensure(final int more) { - int need = position + more; - if (need <= buffer.length) { + int length = buffer.position(); + int need = length + more; + if (need <= buffer.capacity()) { return; } - int newSize = buffer.length * 2; + int newSize = length * 2; if (newSize < need) { newSize = need + 128; } - byte[] n = new byte[newSize]; - System.arraycopy(buffer, 0, n, 0, position); - buffer = n; + ByteBuffer tmp = ByteBuffer.allocate(newSize).order(LITTLE_ENDIAN); + tmp.put(buffer.array(), 0, length); // Avoids covariant call to flip on jdk8 + this.buffer = tmp; } + /** + * Ensures that `absolutePosition` is a valid index in `this.buffer` and there is room to write at + * least `bytesToWrite` bytes. + */ + private void checkPosition(final int absolutePosition, final int bytesToWrite) { + if (absolutePosition < 0) { + throw new IllegalArgumentException(format("position must be >= 0 but was %d", absolutePosition)); + } + if (absolutePosition > buffer.position() - bytesToWrite) { + throw new IllegalArgumentException(format("position must be <= %d but was %d", buffer.position() - bytesToWrite, absolutePosition)); + } + } } diff --git a/bson/src/test/unit/org/bson/io/BasicOutputBufferSpecification.groovy b/bson/src/test/unit/org/bson/io/BasicOutputBufferSpecification.groovy index 38de06bf8cf..758d4fc1cfd 100644 --- a/bson/src/test/unit/org/bson/io/BasicOutputBufferSpecification.groovy +++ b/bson/src/test/unit/org/bson/io/BasicOutputBufferSpecification.groovy @@ -44,9 +44,22 @@ class BasicOutputBufferSpecification extends Specification { bsonOutput.size == 1 } + def 'writeBytes shorthand should extend buffer'() { + given: + def bsonOutput = new BasicOutputBuffer(3) + + when: + bsonOutput.write([1, 2, 3, 4] as byte[]) + + then: + getBytes(bsonOutput) == [1, 2, 3, 4] as byte[] + bsonOutput.position == 4 + bsonOutput.size == 4 + } + def 'should write bytes'() { given: - def bsonOutput = new BasicOutputBuffer() + def bsonOutput = new BasicOutputBuffer(3) when: bsonOutput.writeBytes([1, 2, 3, 4] as byte[]) @@ -59,7 +72,7 @@ class BasicOutputBufferSpecification extends Specification { def 'should write bytes from offset until length'() { given: - def bsonOutput = new BasicOutputBuffer() + def bsonOutput = new BasicOutputBuffer(5) when: bsonOutput.writeBytes([0, 1, 2, 3, 4, 5] as byte[], 1, 4) @@ -70,9 +83,40 @@ class BasicOutputBufferSpecification extends Specification { bsonOutput.size == 4 } + def 'toByteArray should be idempotent'() { + given: + def bsonOutput = new BasicOutputBuffer(10) + bsonOutput.writeBytes([1, 2, 3, 4] as byte[]) + + when: + def first = bsonOutput.toByteArray() + def second = bsonOutput.toByteArray() + + then: + getBytes(bsonOutput) == [1, 2, 3, 4] as byte[] + first == [1, 2, 3, 4] as byte[] + second == [1, 2, 3, 4] as byte[] + bsonOutput.position == 4 + bsonOutput.size == 4 + } + + def 'toByteArray creates a copy'() { + given: + def bsonOutput = new BasicOutputBuffer(10) + bsonOutput.writeBytes([1, 2, 3, 4] as byte[]) + + when: + def first = bsonOutput.toByteArray() + def second = bsonOutput.toByteArray() + + then: + first !== second + first == [1, 2, 3, 4] as byte[] + second == [1, 2, 3, 4] as byte[] + } def 'should write a little endian Int32'() { given: - def bsonOutput = new BasicOutputBuffer() + def bsonOutput = new BasicOutputBuffer(3) when: bsonOutput.writeInt32(0x1020304) @@ -85,7 +129,7 @@ class BasicOutputBufferSpecification extends Specification { def 'should write a little endian Int64'() { given: - def bsonOutput = new BasicOutputBuffer() + def bsonOutput = new BasicOutputBuffer(7) when: bsonOutput.writeInt64(0x102030405060708L) @@ -98,7 +142,7 @@ class BasicOutputBufferSpecification extends Specification { def 'should write a double'() { given: - def bsonOutput = new BasicOutputBuffer() + def bsonOutput = new BasicOutputBuffer(7) when: bsonOutput.writeDouble(Double.longBitsToDouble(0x102030405060708L)) @@ -112,7 +156,7 @@ class BasicOutputBufferSpecification extends Specification { def 'should write an ObjectId'() { given: def objectIdAsByteArray = [12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1] as byte[] - def bsonOutput = new BasicOutputBuffer() + def bsonOutput = new BasicOutputBuffer(11) when: bsonOutput.writeObjectId(new ObjectId(objectIdAsByteArray)) @@ -123,6 +167,19 @@ class BasicOutputBufferSpecification extends Specification { bsonOutput.size == 12 } + def 'write ObjectId should throw after close'() { + given: + def objectIdAsByteArray = [12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1] as byte[] + def bsonOutput = new BasicOutputBuffer() + bsonOutput.close() + + when: + bsonOutput.writeObjectId(new ObjectId(objectIdAsByteArray)) + + then: + thrown(IllegalStateException) + } + def 'should write an empty string'() { given: def bsonOutput = new BasicOutputBuffer() @@ -151,7 +208,7 @@ class BasicOutputBufferSpecification extends Specification { def 'should write a UTF-8 string'() { given: - def bsonOutput = new BasicOutputBuffer() + def bsonOutput = new BasicOutputBuffer(7) when: bsonOutput.writeString('\u0900') @@ -263,6 +320,46 @@ class BasicOutputBufferSpecification extends Specification { bsonOutput.size == 8 } + def 'absolute write should throw with invalid position'() { + given: + def bsonOutput = new BasicOutputBuffer() + bsonOutput.writeBytes([1, 2, 3, 4] as byte[]) + + when: + bsonOutput.write(-1, 0x1020304) + + then: + thrown(IllegalArgumentException) + + when: + bsonOutput.write(4, 0x1020304) + + then: + thrown(IllegalArgumentException) + } + + def 'absolute write should write lower byte at position'() { + given: + def bsonOutput = new BasicOutputBuffer() + bsonOutput.writeBytes([0, 0, 0, 0, 1, 2, 3, 4] as byte[]) + + when: + bsonOutput.write(0, 0x1020304) + + then: + getBytes(bsonOutput) == [4, 0, 0, 0, 1, 2, 3, 4] as byte[] + bsonOutput.position == 8 + bsonOutput.size == 8 + + when: + bsonOutput.write(7, 0x1020304) + + then: + getBytes(bsonOutput) == [4, 0, 0, 0, 1, 2, 3, 4] as byte[] + bsonOutput.position == 8 + bsonOutput.size == 8 + } + def 'truncate should throw with invalid position'() { given: def bsonOutput = new BasicOutputBuffer() @@ -320,6 +417,20 @@ class BasicOutputBufferSpecification extends Specification { bsonOutput.getByteBuffers()[0].getInt() == 1 } + def 'should get byte buffer with limit'() { + given: + def bsonOutput = new BasicOutputBuffer(8) + bsonOutput.writeBytes([1, 0, 0, 0] as byte[]) + + when: + def buffers = bsonOutput.getByteBuffers() + + then: + buffers.size() == 1 + buffers[0].position() == 0 + buffers[0].limit() == 4 + } + def 'should get internal buffer'() { given: def bsonOutput = new BasicOutputBuffer(4)