Skip to content

Optimize String write #1651

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Apr 17, 2025
20 changes: 20 additions & 0 deletions bson/src/main/org/bson/ByteBuf.java
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,26 @@ public interface ByteBuf {
*/
byte[] array();

/**
* <p>States whether this buffer is backed by an accessible byte array.</p>
*
* <p>If this method returns {@code true} then the {@link #array()} and {@link #arrayOffset()} methods may safely be invoked.</p>
*
* @return {@code true} if, and only if, this buffer is backed by an array and is not read-only
* @since 5.5
*/
boolean hasArray();

/**
* Returns the offset of the first byte within the backing byte array of
* this buffer.
*
* @throws java.nio.ReadOnlyBufferException If this buffer is backed by an array but is read-only
* @throws UnsupportedOperationException if this buffer is not backed by an accessible array
* @since 5.5
*/
int arrayOffset();

/**
* Returns this buffer's limit.
*
Expand Down
10 changes: 10 additions & 0 deletions bson/src/main/org/bson/ByteBufNIO.java
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,16 @@ public byte[] array() {
return buf.array();
}

@Override
public boolean hasArray() {
return buf.hasArray();
}

@Override
public int arrayOffset() {
return buf.arrayOffset();
}

@Override
public int limit() {
return buf.limit();
Expand Down
2 changes: 1 addition & 1 deletion bson/src/main/org/bson/io/OutputBuffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ public void writeLong(final long value) {
writeInt64(value);
}

private int writeCharacters(final String str, final boolean checkForNullCharacters) {
protected int writeCharacters(final String str, final boolean checkForNullCharacters) {
int len = str.length();
int total = 0;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.mongodb.internal.connection;

import org.bson.BsonSerializationException;
import org.bson.ByteBuf;
import org.bson.io.OutputBuffer;

Expand All @@ -25,8 +26,10 @@
import java.util.ArrayList;
import java.util.List;

import static com.mongodb.assertions.Assertions.assertFalse;
import static com.mongodb.assertions.Assertions.assertTrue;
import static com.mongodb.assertions.Assertions.notNull;
import static java.lang.String.format;

/**
* <p>This class is not part of the public API and may be removed or changed at any time</p>
Expand Down Expand Up @@ -178,11 +181,17 @@ private ByteBuf getCurrentByteBuffer() {
return getByteBufferAtIndex(curBufferIndex);
}

private ByteBuf getNextByteBuffer() {
assertFalse(bufferList.get(curBufferIndex).hasRemaining());
return getByteBufferAtIndex(++curBufferIndex);
}

private ByteBuf getByteBufferAtIndex(final int index) {
if (bufferList.size() < index + 1) {
bufferList.add(bufferProvider.getBuffer(index >= (MAX_SHIFT - INITIAL_SHIFT)
? MAX_BUFFER_SIZE
: Math.min(INITIAL_BUFFER_SIZE << index, MAX_BUFFER_SIZE)));
ByteBuf buffer = bufferProvider.getBuffer(index >= (MAX_SHIFT - INITIAL_SHIFT)
? MAX_BUFFER_SIZE
: Math.min(INITIAL_BUFFER_SIZE << index, MAX_BUFFER_SIZE));
bufferList.add(buffer);
}
return bufferList.get(index);
}
Expand Down Expand Up @@ -225,6 +234,16 @@ public List<ByteBuf> getByteBuffers() {
return buffers;
}

public List<ByteBuf> getDuplicateByteBuffers() {
ensureOpen();

List<ByteBuf> buffers = new ArrayList<>(bufferList.size());
for (final ByteBuf cur : bufferList) {
buffers.add(cur.duplicate().order(ByteOrder.LITTLE_ENDIAN));
}
return buffers;
}


@Override
public int pipe(final OutputStream out) throws IOException {
Expand All @@ -233,14 +252,18 @@ public int pipe(final OutputStream out) throws IOException {
byte[] tmp = new byte[INITIAL_BUFFER_SIZE];

int total = 0;
for (final ByteBuf cur : getByteBuffers()) {
ByteBuf dup = cur.duplicate();
while (dup.hasRemaining()) {
int numBytesToCopy = Math.min(dup.remaining(), tmp.length);
dup.get(tmp, 0, numBytesToCopy);
out.write(tmp, 0, numBytesToCopy);
List<ByteBuf> byteBuffers = getByteBuffers();
try {
for (final ByteBuf cur : byteBuffers) {
while (cur.hasRemaining()) {
int numBytesToCopy = Math.min(cur.remaining(), tmp.length);
cur.get(tmp, 0, numBytesToCopy);
out.write(tmp, 0, numBytesToCopy);
}
total += cur.limit();
}
total += dup.limit();
} finally {
byteBuffers.forEach(ByteBuf::release);
}
return total;
}
Expand Down Expand Up @@ -360,4 +383,165 @@ private static final class BufferPositionPair {
this.position = position;
}
}

protected int writeCharacters(final String str, final boolean checkNullTermination) {
int stringLength = str.length();
int sp = 0;
int prevPos = position;

ByteBuf curBuffer = getCurrentByteBuffer();
int curBufferPos = curBuffer.position();
int curBufferLimit = curBuffer.limit();
int remaining = curBufferLimit - curBufferPos;

if (curBuffer.hasArray()) {
byte[] dst = curBuffer.array();
int arrayOffset = curBuffer.arrayOffset();
if (remaining >= str.length() + 1) {
// Write ASCII characters directly to the array until we hit a non-ASCII character.
sp = writeOnArrayAscii(str, dst, arrayOffset + curBufferPos, checkNullTermination);
curBufferPos += sp;
// If the whole string was written as ASCII, append the null terminator.
if (sp == stringLength) {
dst[arrayOffset + curBufferPos++] = 0;
position += sp + 1;
curBuffer.position(curBufferPos);
return sp + 1;
}
// Otherwise, update the position to reflect the partial write.
position += sp;
curBuffer.position(curBufferPos);
}
}

// We get here, when the buffer is not backed by an array, or when the string contains at least one non-ASCII characters.
return writeOnBuffers(str,
Copy link

@franz1981 franz1981 Apr 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we have within this a fast PATH for ASCII too?
It will grant more chances to get inlined (since is a smaller method) and be more unrolled...
If we can have a JMH bench it would be fairly easy (I can do it) to peek into the assembly produced to verify it

Copy link
Member Author

@vbabanin vbabanin Apr 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’d expect the fast path for buffers to be in the else branch of if (curBuffer.hasArray()).
However, once we detect UTF-8 characters there, we call a fallback writeOnBuffers(maybe we could rename it to writeUtf8OnBuffers).

Are you suggesting we add a fast path similar to writeOnArrayAscii, but using dynamic buffer allocation and falling back to writeOnBuffers/writeUtf8OnBuffers when a UTF-8 character is encountered?

Copy link

@franz1981 franz1981 Apr 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you suggesting we add a fast path similar to writeOnArrayAscii, but using dynamic buffer allocation and falling back to writeOnBuffers/writeUtf8OnBuffers when a UTF-8 character is encountered?

Yep, since I see the ascii path there is already taking care to change the buffer to write against instead of performing a lookup per each byte to write.
A tighter loop increase the chance it to be loop unrolled, although...the fact we can change the buffer where to write during the loop, can affect this - both for the array case and this.

Copy link
Member Author

@vbabanin vbabanin Apr 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it makes sense. Thanks for the suggestion. I implemented writeOnBuffersAsccii and ran local benchmarks.

The implementation of writeOnBuffersAsccii
private int writeOnBuffersAsccii(final String str,
                                     final boolean checkNullTermination,
                                     final int stringPointer,
                                     final int bufferLimit,
                                     final int bufferPos,
                                     final ByteBuf buffer) {
        int remaining;
        int sp = stringPointer;
        int curBufferPos = bufferPos;
        int curBufferLimit = bufferLimit;
        ByteBuf curBuffer = buffer;
        final int length = str.length();

        while (sp < length) {
            remaining = curBufferLimit - curBufferPos;
            char c = str.charAt(sp);
            if (checkNullTermination && c == 0) {
                throw new BsonSerializationException(
                        format("BSON cstring '%s' is not valid because it contains a null character " + "at index %d", str, sp));
            }
            if (c >= 0x80) {
                break;
            }
            if (remaining == 0) {
                curBuffer = getNextByteBuffer();
                curBufferPos = 0;
                curBufferLimit = curBuffer.limit();
            }
            curBuffer.put((byte) c);
            position++;
            sp++;
            curBufferPos++;
        }
        return sp;
    }

I printed the assembly to see the JIT’s behavior. It looks like the loop wasn’t unrolled by JIT - there’s only one charAt and put per iteration in the main loop. However, there seems to be an additional charAt in the buffer allocation path (after getNextByteBuffer), but it’s a separate code path and mostly an edge-case.

I’ve shared a GitHub Gist with the shortened assembly (keeping the key parts) and a pseudo-Java interpretation to show how the assembly might map back to the logic: Gist. Local perf showed modest gains likely limited by dynamic buffer allocation, as you noted. I’ll run more tests on a dedicated perf instance to confirm. If I missed anything in the assembly, please let me know!

I’m merging this PR for the current improvements, but I agree tighter loops or manual unrolling could be further explored, keeping in mind the maintainability trade-off.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I cannot see the assembly there , but the not decoded binary instead - did you miss the https://blogs.oracle.com/javamagazine/post/java-hotspot-hsdis-disassembler so in your class path?
that will help me a lot

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean something which uses compile command too as TechEmpower/FrameworkBenchmarks#9800 (comment)

Copy link
Member Author

@vbabanin vbabanin Apr 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You’re right - my earlier Gist showed raw hex. I recompiled on Oracle JDK 17.0.7 with hsdis for readable assembly. Thanks!

The main loop (0x0000000113a12c40–0x0000000113a12d3c) seems to have no unrolling:

Condition: 0x0000000113a12c40 (cmp w4, w15, line 453: sp < length).
charAt: One instance at 0x0000000113a12c60–0x0000000113a12c88 (line 455).

Checks: 
Null termination (0x0000000113a12c8c, line 456), ASCII (0x0000000113a12c94, line 460),
buffer space (0x0000000113a12c98, line 463).

put: One instance at 0x0000000113a12ca4–0x0000000113a12d20 (line 468).
Index Updates: 0x0000000113a12d24–0x0000000113a12d34 (lines 469–471).
Branch Back: 0x0000000113a12d3c: b.lt 0x0000000113a12c60, looping to 0x0000000113a12c40.

A second charAt seem to appear in the getNextByteBuffer path, not the main loop after compilation. I’ve created new Gist with the readable assembly.

Copy link

@franz1981 franz1981 Apr 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mmm It looks to me that the unrolling was having a factor of two, I will re read it again since I am more used to x86 asm :)

Anyway, I suggest to look at the PR I sent for this same optimization: having the check for remaining buffer space in the loop would bloat the loop body, reducing the chances that C2 will unroll it many times.
You should keep the loop as simple and branch free as possible

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another reason why the Netty version loop body is too fat is because the JIT doesn't trust final fields and since you can get a new mongo ByteBuf at each iteration, the required amount of pointer chase (mongo buf, Netty swapped big, Netty buf, Unsafe..) is way to much; this prevent massively to have unrolling.
The reason why I have unwrapped till Netty or NIO, the buffers, is to be as per byte[], at the level in which the loop body can just use whatever is saved into a register (the reference of the address field in the Netty buffer) assuming it to be constant, and trust it.

checkNullTermination,
sp,
stringLength,
curBufferLimit,
curBufferPos,
curBuffer,
prevPos);
}

private int writeOnBuffers(final String str,
final boolean checkNullTermination,
final int stringPointer,
final int stringLength,
final int bufferLimit,
final int bufferPos,
final ByteBuf buffer,
final int prevPos) {
int remaining;
int sp = stringPointer;
int curBufferPos = bufferPos;
int curBufferLimit = bufferLimit;
ByteBuf curBuffer = buffer;
while (sp < stringLength) {
remaining = curBufferLimit - curBufferPos;
int c = str.charAt(sp);

if (checkNullTermination && c == 0x0) {
throw new BsonSerializationException(
format("BSON cstring '%s' is not valid because it contains a null character " + "at index %d", str, sp));
}

if (c < 0x80) {
if (remaining == 0) {
curBuffer = getNextByteBuffer();
curBufferPos = 0;
curBufferLimit = curBuffer.limit();
}
curBuffer.put((byte) c);
curBufferPos++;
position++;
} else if (c < 0x800) {
if (remaining < 2) {
// Not enough space: use write() to handle buffer boundary
write((byte) (0xc0 + (c >> 6)));
write((byte) (0x80 + (c & 0x3f)));

curBuffer = getCurrentByteBuffer();
curBufferPos = curBuffer.position();
curBufferLimit = curBuffer.limit();
} else {
curBuffer.put((byte) (0xc0 + (c >> 6)));
curBuffer.put((byte) (0x80 + (c & 0x3f)));
curBufferPos += 2;
position += 2;
}
} else {
// Handle multibyte characters (may involve surrogate pairs).
c = Character.codePointAt(str, sp);
/*
Malformed surrogate pairs are encoded as-is (3 byte code unit) without substituting any code point.
This known deviation from the spec and current functionality remains for backward compatibility.
Ticket: JAVA-5575
*/
if (c < 0x10000) {
if (remaining < 3) {
write((byte) (0xe0 + (c >> 12)));
write((byte) (0x80 + ((c >> 6) & 0x3f)));
write((byte) (0x80 + (c & 0x3f)));

curBuffer = getCurrentByteBuffer();
curBufferPos = curBuffer.position();
curBufferLimit = curBuffer.limit();
} else {
curBuffer.put((byte) (0xe0 + (c >> 12)));
curBuffer.put((byte) (0x80 + ((c >> 6) & 0x3f)));
curBuffer.put((byte) (0x80 + (c & 0x3f)));
curBufferPos += 3;
position += 3;
}
} else {
if (remaining < 4) {
write((byte) (0xf0 + (c >> 18)));
write((byte) (0x80 + ((c >> 12) & 0x3f)));
write((byte) (0x80 + ((c >> 6) & 0x3f)));
write((byte) (0x80 + (c & 0x3f)));

curBuffer = getCurrentByteBuffer();
curBufferPos = curBuffer.position();
curBufferLimit = curBuffer.limit();
} else {
curBuffer.put((byte) (0xf0 + (c >> 18)));
curBuffer.put((byte) (0x80 + ((c >> 12) & 0x3f)));
curBuffer.put((byte) (0x80 + ((c >> 6) & 0x3f)));
curBuffer.put((byte) (0x80 + (c & 0x3f)));
curBufferPos += 4;
position += 4;
}
}
}
sp += Character.charCount(c);
}

getCurrentByteBuffer().put((byte) 0);
position++;
return position - prevPos;
}

private static int writeOnArrayAscii(final String str,
final byte[] dst,
final int arrayPosition,
final boolean checkNullTermination) {
int pos = arrayPosition;
int sp = 0;
// Fast common path: This tight loop is JIT-friendly (simple, no calls, few branches),
// It might be unrolled for performance.
for (; sp < str.length(); sp++, pos++) {
char c = str.charAt(sp);
if (checkNullTermination && c == 0) {
throw new BsonSerializationException(
format("BSON cstring '%s' is not valid because it contains a null character " + "at index %d", str, sp));
}
if (c >= 0x80) {
break;
}
dst[pos] = (byte) c;
}
return sp;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,16 @@ public byte[] array() {
throw new UnsupportedOperationException("Not implemented yet!");
}

@Override
public boolean hasArray() {
return false;
}

@Override
public int arrayOffset() {
throw new UnsupportedOperationException("Not implemented yet!");
}

@Override
public ByteBuf limit(final int newLimit) {
if (newLimit < 0 || newLimit > capacity()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,16 @@ public byte[] array() {
return proxied.array();
}

@Override
public boolean hasArray() {
return proxied.hasArray();
}

@Override
public int arrayOffset() {
return proxied.arrayOffset();
}

@Override
public int limit() {
if (isWriting) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,7 @@ class ByteBufSpecification extends Specification {
@Override
ByteBuf getBuffer(final int size) {
io.netty.buffer.ByteBuf buffer = allocator.directBuffer(size, size)
try {
new NettyByteBuf(buffer.retain())
} finally {
buffer.release();
}
new NettyByteBuf(buffer)
}
}
}
Loading