Skip to content

Commit 510bbd6

Browse files
benjaminpluben
authored andcommitted
Add methods for streaming (de)compression of direct ByteBuffers.
The goal of this change is to enable the highest-performance (de)compression of streams. `ZSTD_compressStream2` and `ZSTD_decompressStream` are exposed on `ZstdCompressCtx` and `ZstdDecompressCtx` respectively. They work in terms of direct `ByteBuffer`s and allow non-blocking, incremental handling of streams. The new `reset()` methods on `ZstdCompressCtx` and `ZstdDecompressCtx` allow reusing contexts, avoiding unnecessary allocations and deallocations.
1 parent 62b9dad commit 510bbd6

File tree

5 files changed

+308
-1
lines changed

5 files changed

+308
-1
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package com.github.luben.zstd;
2+
3+
/** Enum that expresses desired flushing for a streaming compression call.
4+
*
5+
* @see ZstdCompressCtx#compressDirectByteBufferStream
6+
*/
7+
public enum EndDirective {
8+
CONTINUE(0),
9+
FLUSH(1),
10+
END(2);
11+
12+
private final int value;
13+
private EndDirective(int value) {
14+
this.value = value;
15+
}
16+
17+
int value() {
18+
return value;
19+
}
20+
}

src/main/java/com/github/luben/zstd/ZstdCompressCtx.java

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,72 @@ public ZstdCompressCtx loadDict(byte[] dict) {
177177
}
178178
private native long loadCDict0(byte[] dict);
179179

180+
private void ensureOpen() {
181+
if (nativePtr == 0) {
182+
throw new IllegalStateException("Compression context is closed");
183+
}
184+
}
185+
186+
/**
187+
* Clear all state and parameters from the compression context. This leaves the object in a
188+
* state identical to a newly created compression context.
189+
*/
190+
public void reset() {
191+
ensureOpen();
192+
long result = reset0();
193+
if (Zstd.isError(result)) {
194+
throw new ZstdException(result);
195+
}
196+
}
197+
private native long reset0();
198+
199+
/**
200+
* Promise to compress a certain number of source bytes. Knowing the number of bytes to compress
201+
* up front helps to choose proper compression settings and size internal buffers. Additionally,
202+
* the pledged size is stored in the header of the output stream, allowing decompressors to know
203+
* how much uncompressed data to expect.
204+
*
205+
* Attempting to compress more or less than than the pledged size will result in an error.
206+
*/
207+
public void setPledgedSrcSize(long srcSize) {
208+
ensureOpen();
209+
long result = setPledgedSrcSize0(srcSize);
210+
if (Zstd.isError(result)) {
211+
throw new ZstdException(result);
212+
}
213+
}
214+
private native long setPledgedSrcSize0(long srcSize);
215+
216+
/**
217+
* Compress as much of the <code>src</code> {@link ByteBuffer} into the <code>dst</code> {@link
218+
* ByteBuffer} as possible.
219+
*
220+
* @param dst destination of compressed data
221+
* @param src buffer to compress
222+
* @param endOp directive for handling the end of the stream
223+
* @return true if all state has been flushed from internal buffers
224+
*/
225+
public boolean compressDirectByteBufferStream(ByteBuffer dst, ByteBuffer src, EndDirective endOp) {
226+
ensureOpen();
227+
long result = compressDirectByteBufferStream0(dst, dst.position(), dst.limit(), src, src.position(), src.limit(), endOp.value());
228+
if ((result & 0x80000000L) != 0) {
229+
long code = result & 0xFF;
230+
throw new ZstdException(code, Zstd.getErrorName(code));
231+
}
232+
src.position((int)(result & 0x7FFFFFFF));
233+
dst.position((int)(result >>> 32) & 0x7FFFFFFF);
234+
return (result >>> 63) == 1;
235+
}
236+
237+
/**
238+
* 4 pieces of information are packed into the return value of this method, which must be
239+
* treated as an unsigned long. The highest bit is set if all data has been flushed from
240+
* internal buffers. The next 31 bits are the new position of the destination buffer. The next
241+
* bit is set if an error occurred. If an error occurred, the lowest 31 bits encode a zstd error
242+
* code. Otherwise, the lowest 31 bits are the new position of the source buffer.
243+
*/
244+
private native long compressDirectByteBufferStream0(ByteBuffer dst, int dstOffset, int dstSize, ByteBuffer src, int srcSize, int srcOffset, int endOp);
245+
180246
/**
181247
* Compresses buffer 'srcBuff' into buffer 'dstBuff' reusing this ZstdCompressCtx.
182248
*

src/main/java/com/github/luben/zstd/ZstdDecompressCtx.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,51 @@ public ZstdDecompressCtx loadDict(byte[] dict) {
8787
}
8888
private native long loadDDict0(byte[] dict);
8989

90+
/**
91+
* Clear all state and parameters from the decompression context. This leaves the object in a
92+
* state identical to a newly created decompression context.
93+
*/
94+
public void reset() {
95+
ensureOpen();
96+
reset0();
97+
}
98+
private native void reset0();
99+
100+
private void ensureOpen() {
101+
if (nativePtr == 0) {
102+
throw new IllegalStateException("Decompression context is closed");
103+
}
104+
}
105+
106+
/**
107+
* Decompress as much of the <code>src</code> {@link ByteBuffer} into the <code>dst</code> {@link
108+
* ByteBuffer} as possible.
109+
*
110+
* @param dst destination of uncompressed data
111+
* @param src buffer to decompress
112+
* @return true if all state has been flushed from internal buffers
113+
*/
114+
public boolean decompressDirectByteBufferStream(ByteBuffer dst, ByteBuffer src) {
115+
ensureOpen();
116+
long result = decompressDirectByteBufferStream0(dst, dst.position(), dst.limit(), src, src.position(), src.limit());
117+
if ((result & 0x80000000L) != 0) {
118+
long code = result & 0xFF;
119+
throw new ZstdException(code, Zstd.getErrorName(code));
120+
}
121+
src.position((int)(result & 0x7FFFFFFF));
122+
dst.position((int)(result >>> 32) & 0x7FFFFFFF);
123+
return (result >>> 63) == 1;
124+
}
125+
126+
/**
127+
* 4 pieces of information are packed into the return value of this method, which must be
128+
* treated as an unsigned long. The highest bit is set if all data has been flushed from
129+
* internal buffers. The next 31 bits are the new position of the destination buffer. The next
130+
* bit is set if an error occurred. If an error occurred, the lowest 31 bits encode a zstd error
131+
* code. Otherwise, the lowest 31 bits are the new position of the source buffer.
132+
*/
133+
private native long decompressDirectByteBufferStream0(ByteBuffer dst, int dstOffset, int dstSize, ByteBuffer src, int srcOffset, int srcSize);
134+
90135
/**
91136
* Decompresses buffer 'srcBuff' into buffer 'dstBuff' using this ZstdDecompressCtx.
92137
*

src/main/native/jni_fast_zstd.c

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,75 @@ JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdCompressCtx_loadCDict0
319319
return result;
320320
}
321321

322+
/*
323+
* Class: com_github_luben_zstd_ZstdCompressCtx
324+
* Method: reset0
325+
* Signature: (L)J
326+
*/
327+
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdCompressCtx_reset0
328+
(JNIEnv *env, jclass jctx) {
329+
ZSTD_CCtx* cctx = (ZSTD_CCtx*)(intptr_t)(*env)->GetLongField(env, jctx, compress_ctx_nativePtr);
330+
return ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters);
331+
}
332+
333+
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdCompressCtx_setPledgedSrcSize0
334+
(JNIEnv *env, jclass jctx, jlong src_size) {
335+
if (src_size < 0) {
336+
return ZSTD_error_srcSize_wrong;
337+
}
338+
ZSTD_CCtx* cctx = (ZSTD_CCtx*)(intptr_t)(*env)->GetLongField(env, jctx, compress_ctx_nativePtr);
339+
return ZSTD_CCtx_setPledgedSrcSize(cctx, (unsigned long long)src_size);
340+
}
341+
342+
static size_t compress_direct_buffer_stream
343+
(JNIEnv *env, jclass jctx, jobject dst, jint *dst_offset, jint dst_size, jobject src, jint *src_offset, jint src_size, jint end_op) {
344+
if (NULL == dst) return ZSTD_ERROR(dstSize_tooSmall);
345+
if (NULL == src) return ZSTD_ERROR(srcSize_wrong);
346+
if (0 > *dst_offset) return ZSTD_ERROR(dstSize_tooSmall);
347+
if (0 > *src_offset) return ZSTD_ERROR(srcSize_wrong);
348+
if (0 > src_size) return ZSTD_ERROR(srcSize_wrong);
349+
350+
jsize dst_cap = (*env)->GetDirectBufferCapacity(env, dst);
351+
if (dst_size > dst_cap) return ZSTD_ERROR(dstSize_tooSmall);
352+
jsize src_cap = (*env)->GetDirectBufferCapacity(env, src);
353+
if (src_size > src_cap) return ZSTD_ERROR(srcSize_wrong);
354+
ZSTD_CCtx* cctx = (ZSTD_CCtx*)(intptr_t)(*env)->GetLongField(env, jctx, compress_ctx_nativePtr);
355+
356+
ZSTD_outBuffer out;
357+
out.pos = *dst_offset;
358+
out.size = dst_size;
359+
out.dst = (*env)->GetDirectBufferAddress(env, dst);
360+
if (out.dst == NULL) return ZSTD_ERROR(memory_allocation);
361+
ZSTD_inBuffer in;
362+
in.pos = *src_offset;
363+
in.size = src_size;
364+
in.src = (*env)->GetDirectBufferAddress(env, src);
365+
if (in.src == NULL) return ZSTD_ERROR(memory_allocation);
366+
367+
size_t result = ZSTD_compressStream2(cctx, &out, &in, end_op);
368+
*dst_offset = out.pos;
369+
*src_offset = in.pos;
370+
return result;
371+
}
372+
373+
/*
374+
* Class: com_github_luben_zstd_ZstdCompressCtx
375+
* Method: compressDirectByteBufferStream0
376+
* Signature: (Ljava/nio/ByteBuffer;IILjava/nio/ByteBuffer;III)J
377+
*/
378+
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdCompressCtx_compressDirectByteBufferStream0
379+
(JNIEnv *env, jclass jctx, jobject dst, jint dst_offset, jint dst_size, jobject src, jint src_offset, jint src_size, jint end_op) {
380+
size_t result = compress_direct_buffer_stream(env, jctx, dst, &dst_offset, dst_size, src, &src_offset, src_size, end_op);
381+
if (ZSTD_isError(result)) {
382+
return (1ULL << 31) | ZSTD_getErrorCode(result);
383+
}
384+
jlong encoded_result = ((jlong)dst_offset << 32) | src_offset;
385+
if (result == 0) {
386+
encoded_result |= 1ULL << 63;
387+
}
388+
return encoded_result;
389+
}
390+
322391
/*
323392
* Class: com_github_luben_zstd_ZstdCompressCtx
324393
* Method: compressDirectByteBuffer0
@@ -450,10 +519,74 @@ JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_loadDDict0
450519
return result;
451520
}
452521

522+
/*
523+
* Class: com_github_luben_zstd_ZstdDecompressCtx
524+
* Method: reset0
525+
* Signature: (L)J
526+
*/
527+
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_reset0
528+
(JNIEnv *env, jclass jctx) {
529+
ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)(*env)->GetLongField(env, jctx, compress_ctx_nativePtr);
530+
return ZSTD_DCtx_reset(dctx, ZSTD_reset_session_and_parameters);
531+
}
532+
533+
static size_t decompress_direct_buffer_stream
534+
(JNIEnv *env, jclass jctx, jobject dst, jint *dst_offset, jint dst_size, jobject src, jint *src_offset, jint src_size)
535+
{
536+
if (NULL == dst) return ZSTD_ERROR(dstSize_tooSmall);
537+
if (NULL == src) return ZSTD_ERROR(srcSize_wrong);
538+
if (0 > *dst_offset) return ZSTD_ERROR(dstSize_tooSmall);
539+
if (0 > *src_offset) return ZSTD_ERROR(srcSize_wrong);
540+
if (0 > dst_size) return ZSTD_ERROR(dstSize_tooSmall);
541+
if (0 > src_size) return ZSTD_ERROR(srcSize_wrong);
542+
543+
jsize dst_cap = (*env)->GetDirectBufferCapacity(env, dst);
544+
if (dst_size > dst_cap) return ZSTD_ERROR(dstSize_tooSmall);
545+
jsize src_cap = (*env)->GetDirectBufferCapacity(env, src);
546+
if (src_size > src_cap) return ZSTD_ERROR(srcSize_wrong);
547+
548+
ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)(*env)->GetLongField(env, jctx, decompress_ctx_nativePtr);
549+
550+
ZSTD_outBuffer out;
551+
out.pos = *dst_offset;
552+
out.size = dst_size;
553+
out.dst = (*env)->GetDirectBufferAddress(env, dst);
554+
if (out.dst == NULL) return ZSTD_ERROR(memory_allocation);
555+
ZSTD_inBuffer in;
556+
in.pos = *src_offset;
557+
in.size = src_size;
558+
in.src = (*env)->GetDirectBufferAddress(env, src);
559+
if (in.src == NULL) return ZSTD_ERROR(memory_allocation);
560+
561+
size_t result = ZSTD_decompressStream(dctx, &out, &in);
562+
*dst_offset = out.pos;
563+
*src_offset = in.pos;
564+
return result;
565+
}
566+
567+
/*
568+
* Class: com_github_luben_zstd_ZstdDecompressCtx
569+
* Method: decompressDirectByteBufferStream0
570+
* Signature: (Ljava/nio/ByteBuffer;IILjava/nio/ByteBuffer;II)J
571+
*/
572+
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_decompressDirectByteBufferStream0
573+
(JNIEnv *env, jclass jctx, jobject dst, jint dst_offset, jint dst_size, jobject src, jint src_offset, jint src_size)
574+
{
575+
size_t result = decompress_direct_buffer_stream(env, jctx, dst, &dst_offset, dst_size, src, &src_offset, src_size);
576+
if (ZSTD_isError(result)) {
577+
return (1ULL << 31) | ZSTD_getErrorCode(result);
578+
}
579+
jlong encoded_result = ((jlong)dst_offset << 32) | src_offset;
580+
if (result == 0) {
581+
encoded_result |= 1ULL << 63;
582+
}
583+
return encoded_result;
584+
}
585+
453586

454587
/*
455588
* Class: com_github_luben_zstd_ZstdDecompressCtx
456-
* Method: decompressDirectByteBuffe0
589+
* Method: decompressDirectByteBuffer0
457590
* Signature: (Ljava/nio/ByteBuffer;IILjava/nio/ByteBuffer;II)J
458591
*/
459592
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_decompressDirectByteBuffer0

src/test/scala/Zstd.scala

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import java.nio.file.StandardOpenOption
1212

1313
import scala.io._
1414
import scala.collection.mutable.WrappedArray
15+
import scala.util.Using
1516

1617
class ZstdSpec extends FlatSpec with Checkers {
1718

@@ -878,4 +879,46 @@ class ZstdSpec extends FlatSpec with Checkers {
878879
Zstd.extractArray(buf.slice)
879880
}
880881
}
882+
883+
"streaming compressiong and decompression" should "roundtrip" in {
884+
Using.Manager { use =>
885+
val cctx = use(new ZstdCompressCtx())
886+
val dctx = use(new ZstdDecompressCtx())
887+
check { input: Array[Byte] =>
888+
{
889+
val size = input.length
890+
val inputBuffer = ByteBuffer.allocateDirect(size)
891+
inputBuffer.put(input)
892+
inputBuffer.flip()
893+
cctx.reset()
894+
cctx.setPledgedSrcSize(size)
895+
val compressedBuffer = ByteBuffer.allocateDirect(Zstd.compressBound(size).toInt)
896+
while (inputBuffer.hasRemaining) {
897+
compressedBuffer.limit(compressedBuffer.position() + 1)
898+
cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective.CONTINUE)
899+
}
900+
compressedBuffer.limit(compressedBuffer.capacity())
901+
val done = cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective.END)
902+
assert(done)
903+
904+
compressedBuffer.flip()
905+
val decompressedBuffer = ByteBuffer.allocateDirect(size)
906+
dctx.reset()
907+
while (compressedBuffer.hasRemaining) {
908+
if (decompressedBuffer.limit() < decompressedBuffer.position()) {
909+
decompressedBuffer.limit(compressedBuffer.position() + 1)
910+
}
911+
dctx.decompressDirectByteBufferStream(decompressedBuffer, compressedBuffer)
912+
}
913+
914+
inputBuffer.rewind()
915+
compressedBuffer.rewind()
916+
decompressedBuffer.flip()
917+
918+
val comparison = inputBuffer.compareTo(decompressedBuffer)
919+
comparison == 0 && Zstd.decompressedSize(compressedBuffer) == size
920+
}
921+
}
922+
}.get
923+
}
881924
}

0 commit comments

Comments
 (0)