Skip to content

Commit 2a262bf

Browse files
Morten Grouleff authored and luben committed
Add new constructors to ZstdDictCompress and ZstdDictDecompress that
allow byReference semantics for the provided byte buffer: if you set byReference to true, you avoid copying the dict data into a natively malloc'ed buffer, but you must then guarantee that the byte buffer is not modified before the CTX has been closed.
1 parent a516a43 commit 2a262bf

File tree

4 files changed

+68
-15
lines changed

4 files changed

+68
-15
lines changed

src/main/java/com/github/luben/zstd/ZstdDictCompress.java

+25-3
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@ public class ZstdDictCompress extends SharedDictBase {
1010
}
1111

1212
private long nativePtr = 0;
13+
14+
private ByteBuffer sharedDict = null;
15+
1316
private int level = Zstd.defaultCompressionLevel();
1417

1518
private native void init(byte[] dict, int dict_offset, int dict_size, int level);
1619

17-
private native void initDirect(ByteBuffer dict, int dict_offset, int dict_size, int level);
20+
private native void initDirect(ByteBuffer dict, int dict_offset, int dict_size, int level, int byReference);
1821

1922
private native void free();
2023

@@ -59,6 +62,18 @@ public ZstdDictCompress(byte[] dict, int offset, int length, int level) {
5962
* @param level compression level
6063
*/
6164
public ZstdDictCompress(ByteBuffer dict, int level) {
65+
this(dict, level, false);
66+
}
67+
68+
/**
69+
* Create a new dictionary for use with fast compress.
70+
* If byReference is true, then the native code does not copy the data but keeps a reference to the byte buffer, which must then not be modified before this context has been closed.
71+
*
72+
* @param dict Direct ByteBuffer containing dictionary using position and limit to define range in buffer.
73+
* @param level compression level
74+
* @param byReference tell the native part to use the byte buffer directly and not copy the data when true.
75+
*/
76+
public ZstdDictCompress(ByteBuffer dict, int level, boolean byReference) {
6277
this.level = level;
6378
int length = dict.limit() - dict.position();
6479
if (!dict.isDirect()) {
@@ -67,11 +82,14 @@ public ZstdDictCompress(ByteBuffer dict, int level) {
6782
if (length < 0) {
6883
throw new IllegalArgumentException("dict cannot be empty.");
6984
}
70-
initDirect(dict, dict.position(), length, level);
85+
initDirect(dict, dict.position(), length, level, byReference ? 1 : 0);
7186

7287
if (nativePtr == 0L) {
7388
throw new IllegalStateException("ZSTD_createCDict failed");
7489
}
90+
if (byReference) {
91+
sharedDict = dict; // ensures the dict is not garbage collected while this object remains, and flags that we should not use native free.
92+
}
7593
// Ensures that even if ZstdDictCompress is created and published through a race, no thread could observe
7694
// nativePtr == 0.
7795
storeFence();
@@ -85,7 +103,11 @@ int level() {
85103
@Override
86104
void doClose() {
87105
if (nativePtr != 0) {
88-
free();
106+
if (sharedDict == null) {
107+
free();
108+
} else {
109+
sharedDict = null;
110+
}
89111
nativePtr = 0;
90112
}
91113
}

src/main/java/com/github/luben/zstd/ZstdDictDecompress.java

+23-3
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ public class ZstdDictDecompress extends SharedDictBase {
1111

1212
private long nativePtr = 0L;
1313

14+
private ByteBuffer sharedDict = null;
15+
1416
private native void init(byte[] dict, int dict_offset, int dict_size);
1517

16-
private native void initDirect(ByteBuffer dict, int dict_offset, int dict_size);
18+
private native void initDirect(ByteBuffer dict, int dict_offset, int dict_size, int byReference);
1719

1820
private native void free();
1921

@@ -52,6 +54,17 @@ public ZstdDictDecompress(byte[] dict, int offset, int length) {
5254
* @param dict Direct ByteBuffer containing dictionary using position and limit to define range in buffer.
5355
*/
5456
public ZstdDictDecompress(ByteBuffer dict) {
57+
this(dict, false);
58+
}
59+
60+
/**
61+
* Create a new dictionary for use with fast decompress.
62+
* If byReference is true, then the native code does not copy the data but keeps a reference to the byte buffer, which must then not be modified before this context has been closed.
63+
*
64+
* @param dict Direct ByteBuffer containing dictionary using position and limit to define range in buffer.
65+
* @param byReference tell the native part to use the byte buffer directly and not copy the data when true.
66+
*/
67+
public ZstdDictDecompress(ByteBuffer dict, boolean byReference) {
5568

5669
int length = dict.limit() - dict.position();
5770
if (!dict.isDirect()) {
@@ -60,11 +73,14 @@ public ZstdDictDecompress(ByteBuffer dict) {
6073
if (length < 0) {
6174
throw new IllegalArgumentException("dict cannot be empty.");
6275
}
63-
initDirect(dict, dict.position(), length);
76+
initDirect(dict, dict.position(), length, byReference ? 1 : 0);
6477

6578
if (nativePtr == 0L) {
6679
throw new IllegalStateException("ZSTD_createDDict failed");
6780
}
81+
if (byReference) {
82+
sharedDict = dict; // ensures the dict is not garbage collected while this object remains, and flags that we should not use native free.
83+
}
6884
// Ensures that even if ZstdDictDecompress is created and published through a race, no thread could observe
6985
// nativePtr == 0.
7086
storeFence();
@@ -74,7 +90,11 @@ public ZstdDictDecompress(ByteBuffer dict) {
7490
@Override
7591
void doClose() {
7692
if (nativePtr != 0) {
77-
free();
93+
if (sharedDict == null) {
94+
free();
95+
} else {
96+
sharedDict = null;
97+
}
7898
nativePtr = 0;
7999
}
80100
}

src/main/native/jni_fast_zstd.c

+16-6
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,22 @@ JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictCompress_init
3232
/*
3333
* Class: com_github_luben_zstd_ZstdDictCompress
3434
* Method: init
35-
* Signature: (Ljava/nio/ByteBuffer;III)V
35+
* Signature: (Ljava/nio/ByteBuffer;IIII)V
3636
*/
3737
JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictCompress_initDirect
38-
(JNIEnv *env, jobject obj, jobject dict, jint dict_offset, jint dict_size, jint level)
38+
(JNIEnv *env, jobject obj, jobject dict, jint dict_offset, jint dict_size, jint level, jint byReference)
3939
{
4040
jclass clazz = (*env)->GetObjectClass(env, obj);
4141
compress_dict = (*env)->GetFieldID(env, clazz, "nativePtr", "J");
4242
if (NULL == dict) return;
4343
void *dict_buff = (*env)->GetDirectBufferAddress(env, dict);
4444
if (NULL == dict_buff) return;
45-
ZSTD_CDict* cdict = ZSTD_createCDict(((char *)dict_buff) + dict_offset, dict_size, level);
45+
ZSTD_CDict* cdict = NULL;
46+
if (byReference == 0) {
47+
cdict = ZSTD_createCDict(((char *)dict_buff) + dict_offset, dict_size, level);
48+
} else {
49+
cdict = ZSTD_createCDict_byReference(((char *)dict_buff) + dict_offset, dict_size, level);
50+
}
4651
if (NULL == cdict) return;
4752
(*env)->SetLongField(env, obj, compress_dict, (jlong)(intptr_t) cdict);
4853
}
@@ -85,17 +90,22 @@ JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictDecompress_init
8590
/*
8691
* Class: com_github_luben_zstd_ZstdDictDecompress
8792
* Method: initDirect
88-
* Signature: (Ljava/nio/ByteBuffer;II)V
93+
* Signature: (Ljava/nio/ByteBuffer;III)V
8994
*/
9095
JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictDecompress_initDirect
91-
(JNIEnv *env, jobject obj, jobject dict, jint dict_offset, jint dict_size)
96+
(JNIEnv *env, jobject obj, jobject dict, jint dict_offset, jint dict_size, jint byReference)
9297
{
9398
jclass clazz = (*env)->GetObjectClass(env, obj);
9499
decompress_dict = (*env)->GetFieldID(env, clazz, "nativePtr", "J");
95100
if (NULL == dict) return;
96101
void *dict_buff = (*env)->GetDirectBufferAddress(env, dict);
97102

98-
ZSTD_DDict* ddict = ZSTD_createDDict(((char *)dict_buff) + dict_offset, dict_size);
103+
ZSTD_DDict* ddict = NULL;
104+
if (byReference == 0) {
105+
ddict = ZSTD_createDDict(((char *)dict_buff) + dict_offset, dict_size);
106+
} else {
107+
ddict = ZSTD_createDDict_byReference(((char *)dict_buff) + dict_offset, dict_size);
108+
}
99109

100110
if (NULL == ddict) return;
101111
(*env)->SetLongField(env, obj, decompress_dict, (jlong)(intptr_t) ddict);

src/test/scala/ZstdDict.scala

+4-3
Original file line numberDiff line numberDiff line change
@@ -104,17 +104,18 @@ class ZstdDictSpec extends AnyFlatSpec {
104104
assert(input.toSeq == decompressed.toSeq)
105105
}
106106

107-
it should s"round-trip compression/decompression ByteBuffers with fast dict at level $level with legacy $legacy" in {
107+
it should s"round-trip compression/decompression ByteBuffers with fast dict at level $level with byReference $legacy" in {
108+
val byReference = legacy // Reuse the variance flag here.
108109
val size = input.length
109110
val inBuf = ByteBuffer.allocateDirect(size)
110111
inBuf.put(input)
111112
inBuf.flip()
112-
val cdict = new ZstdDictCompress(dictInDirectByteBuffer, level)
113+
val cdict = new ZstdDictCompress(dictInDirectByteBuffer, level, byReference)
113114
val compressed = ByteBuffer.allocateDirect(Zstd.compressBound(size).toInt);
114115
Zstd.compress(compressed, inBuf, cdict)
115116
compressed.flip()
116117
cdict.close
117-
val ddict = new ZstdDictDecompress(dictInDirectByteBuffer)
118+
val ddict = new ZstdDictDecompress(dictInDirectByteBuffer, byReference)
118119
val decompressed = ByteBuffer.allocateDirect(size)
119120
Zstd.decompress(decompressed, compressed, ddict)
120121
decompressed.flip()

0 commit comments

Comments
 (0)