Skip to content

JAVA-5788 Improve ByteBufferBsonOutput::writeCharacters #1629

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions bson/src/main/org/bson/ByteBuf.java
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,13 @@ public interface ByteBuf {
*/
ByteBuf flip();

/**
* States whether this buffer is backed by an accessible byte array.
*
* @return {@code true} if, and only if, this buffer is backed by an array and is not read-only
*/
boolean hasArray();

/**
* <p>Returns the byte array that backs this buffer <em>(optional operation)</em>.</p>
*
Expand Down
5 changes: 5 additions & 0 deletions bson/src/main/org/bson/ByteBufNIO.java
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ public ByteBuf flip() {
return this;
}

@Override
public boolean hasArray() {
return buf.hasArray();
}

@Override
public byte[] array() {
return buf.array();
Expand Down
8 changes: 6 additions & 2 deletions bson/src/main/org/bson/io/OutputBuffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,15 @@ public void writeLong(final long value) {
writeInt64(value);
}

private int writeCharacters(final String str, final boolean checkForNullCharacters) {
protected int writeCharacters(final String str, final boolean checkForNullCharacters) {
return writeCharacters(str, 0, checkForNullCharacters);
}

protected final int writeCharacters(final String str, int start, final boolean checkForNullCharacters) {
int len = str.length();
int total = 0;

for (int i = 0; i < len;) {
for (int i = start; i < len;) {
int c = Character.codePointAt(str, i);

if (checkForNullCharacters && c == 0x0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package com.mongodb.internal.connection;

import com.mongodb.internal.connection.netty.NettyByteBuf;
import org.bson.BsonSerializationException;
import org.bson.ByteBuf;
import org.bson.io.OutputBuffer;

Expand All @@ -27,6 +29,7 @@

import static com.mongodb.assertions.Assertions.assertTrue;
import static com.mongodb.assertions.Assertions.notNull;
import static java.lang.String.format;

/**
* <p>This class is not part of the public API and may be removed or changed at any time</p>
Expand Down Expand Up @@ -273,6 +276,161 @@ public void close() {
}
}

@Override
protected int writeCharacters(final String str, final boolean checkForNullCharacters) {
ensureOpen();
ByteBuf buf = getCurrentByteBuffer();
if ((buf.remaining() >= str.length() + 1)) {
if (buf.hasArray()) {
return writeCharactersOnArray(str, checkForNullCharacters, buf);
} else if (buf instanceof NettyByteBuf) {
return writeCharactersOnNettyByteBuf(str, checkForNullCharacters, buf);
}
}
return super.writeCharacters(str, 0, checkForNullCharacters);
}

private static void validateNoNullSingleByteChars(String str, long chars, int i) {
long tmp = (chars & 0x7F7F7F7F7F7F7F7FL) + 0x7F7F7F7F7F7F7F7FL;
tmp = ~(tmp | chars | 0x7F7F7F7F7F7F7F7FL);
if (tmp != 0) {
int firstZero = Long.numberOfTrailingZeros(tmp) >>> 3;
throw new BsonSerializationException(format("BSON cstring '%s' is not valid because it contains a null character "
+ "at index %d", str, i + firstZero));
}
}

private static void validateNoNullAsciiCharacters(String str, long asciiChars, int i) {
// simplified Hacker's delight search for zero with ASCII chars i.e. which doesn't use the MSB
long tmp = asciiChars + 0x7F7F7F7F7F7F7F7FL;
// MSB is 0 iff the byte is 0x00, 1 otherwise
tmp = ~tmp & 0x8080808080808080L;
// MSB is 1 iff the byte is 0x00, 0 otherwise
if (tmp != 0) {
// there's some 0x00 in the word
int firstZero = Long.numberOfTrailingZeros(tmp) >> 3;
throw new BsonSerializationException(format("BSON cstring '%s' is not valid because it contains a null character "
+ "at index %d", str, i + firstZero));
}
}

private int writeCharactersOnNettyByteBuf(String str, boolean checkForNullCharacters, ByteBuf buf) {
int i = 0;
io.netty.buffer.ByteBuf nettyBuffer = ((NettyByteBuf) buf).asByteBuf();
// readonly buffers, netty buffers and off-heap NIO ByteBuffer
boolean slowPath = false;
int batches = str.length() / 8;
final int writerIndex = nettyBuffer.writerIndex();
// this would avoid resizing the buffer while appending: ASCII length + delimiter required space
nettyBuffer.ensureWritable(str.length() + 1);
for (int b = 0; b < batches; b++) {
i = b * 8;
// read 4 chars at time to preserve the 0x0100 cases
long evenChars = str.charAt(i) |
str.charAt(i + 2) << 16 |
(long) str.charAt(i + 4) << 32 |
(long) str.charAt(i + 6) << 48;
long oddChars = str.charAt(i + 1) |
str.charAt(i + 3) << 16 |
(long) str.charAt(i + 5) << 32 |
(long) str.charAt(i + 7) << 48;
// check that both the second byte and the MSB of the first byte of each pair is 0
// needed for cases like \u0100 and \u0080
long mergedChars = evenChars | oddChars;
if ((mergedChars & 0xFF80FF80FF80FF80L) != 0) {
if (allSingleByteChars(mergedChars)) {
i = tryWriteAsciiChars(str, checkForNullCharacters, oddChars, evenChars, nettyBuffer, writerIndex, i);
}
slowPath = true;
break;
}
// all ASCII - compose them into a single long
long asciiChars = oddChars << 8 | evenChars;
if (checkForNullCharacters) {
validateNoNullAsciiCharacters(str, asciiChars, i);
}
nettyBuffer.setLongLE(writerIndex + i, asciiChars);
}
if (!slowPath) {
i = batches * 8;
// do the rest, if any
for (; i < str.length(); i++) {
char c = str.charAt(i);
if (checkForNullCharacters && c == 0x0) {
throw new BsonSerializationException(format("BSON cstring '%s' is not valid because it contains a null character "
+ "at index %d", str, i));
}
if (c >= 0x80) {
slowPath = true;
break;
}
nettyBuffer.setByte(writerIndex + i, c);
}
}
if (slowPath) {
// ith char is not ASCII:
position += i;
buf.position(writerIndex + i);
return i + super.writeCharacters(str, i, checkForNullCharacters);
} else {
nettyBuffer.setByte(writerIndex + str.length(), 0);
int totalWritten = str.length() + 1;
position += totalWritten;
buf.position(writerIndex + totalWritten);
return totalWritten;
}
}

private static boolean allSingleByteChars(long fourChars) {
return (fourChars & 0xFF00FF00FF00FF00L) == 0;
}

private static int tryWriteAsciiChars(String str, boolean checkForNullCharacters,
long oddChars, long evenChars, io.netty.buffer.ByteBuf nettyByteBuf, int writerIndex, int i) {
// all single byte chars
long latinChars = oddChars << 8 | evenChars;
if (checkForNullCharacters) {
validateNoNullSingleByteChars(str, latinChars, i);
}
long msbSetForNonAscii = latinChars & 0x8080808080808080L;
int firstNonAsciiOffset = Long.numberOfTrailingZeros(msbSetForNonAscii) >> 3;
// that's a bit cheating :P but later phases will patch the wrongly encoded ones
nettyByteBuf.setLongLE(writerIndex + i, latinChars);
i += firstNonAsciiOffset;
return i;
}

private int writeCharactersOnArray(String str, boolean checkForNullCharacters, ByteBuf buf) {
int i = 0;
byte[] array = buf.array();
int pos = buf.position();
int len = str.length();
for (; i < len; i++) {
char c = str.charAt(i);
if (checkForNullCharacters && c == 0x0) {
throw new BsonSerializationException(format("BSON cstring '%s' is not valid because it contains a null character "
+ "at index %d", str, i));
}
if (c >= 0x80) {
break;
}
array[pos + i] = (byte) c;
}
if (i == len) {
int total = len + 1;
array[pos + len] = 0;
position += total;
buf.position(pos + total);
return len + 1;
}
// ith character is not ASCII
if (i > 0) {
position += i;
buf.position(pos + i);
}
return i + super.writeCharacters(str, i, checkForNullCharacters);
}

private static final class BufferPositionPair {
private final int bufferIndex;
private int position;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,11 @@ private int getShort(final int index) {
return (short) (get(index) & 0xff | (get(index + 1) & 0xff) << 8);
}

@Override
public boolean hasArray() {
return false;
}

@Override
public byte[] array() {
throw new UnsupportedOperationException("Not implemented yet!");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ public ByteBuf flip() {
return this;
}

public boolean hasArray() {
return proxied.hasArray();
}

@Override
public byte[] array() {
return proxied.array();
Expand Down
Loading