Skip to content

Commit 15e5eee

Browse files
charlesconnellluben
authored andcommitted
Support decompression from byte array to ByteBuffer and vice-versa
1 parent 6a6b1b1 commit 15e5eee

File tree

4 files changed

+172
-16
lines changed

4 files changed

+172
-16
lines changed

src/main/java/com/github/luben/zstd/Zstd.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,15 @@ public static long decompress(byte[] dst, byte[] src) {
418418
}
419419
}
420420

421+
public static int decompress(byte[] dst, ByteBuffer srcBuf) {
422+
ZstdDecompressCtx ctx = new ZstdDecompressCtx();
423+
try {
424+
return ctx.decompress(dst, srcBuf);
425+
} finally {
426+
ctx.close();
427+
}
428+
}
429+
421430
/**
422431
* Decompresses buffer 'src' into buffer 'dst'.
423432
*
@@ -1343,6 +1352,15 @@ public static int decompress(ByteBuffer dstBuf, ByteBuffer srcBuf) {
13431352
}
13441353
}
13451354

1355+
public static int decompress(ByteBuffer dstBuf, byte[] src) {
1356+
ZstdDecompressCtx ctx = new ZstdDecompressCtx();
1357+
try {
1358+
return ctx.decompress(dstBuf, src);
1359+
} finally {
1360+
ctx.close();
1361+
}
1362+
}
1363+
13461364
/**
13471365
* Decompress data
13481366
*

src/main/java/com/github/luben/zstd/ZstdDecompressCtx.java

Lines changed: 81 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,60 @@ public int decompressByteArray(byte[] dstBuff, int dstOffset, int dstSize, byte[
234234

235235
private static native long decompressByteArray0(long nativePtr, byte[] dst, int dstOffset, int dstSize, byte[] src, int srcOffset, int srcSize);
236236

237+
public int decompressByteArrayToDirectByteBuffer(ByteBuffer dstBuff, int dstOffset, int dstSize, byte[] srcBuff, int srcOffset, int srcSize) {
238+
if (!dstBuff.isDirect()) {
239+
throw new IllegalArgumentException("dstBuff must be a direct buffer");
240+
}
241+
242+
Objects.checkFromIndexSize(srcOffset, srcSize, srcBuff.length);
243+
Objects.checkFromIndexSize(dstOffset, dstSize, dstBuff.limit());
244+
245+
ensureOpen();
246+
acquireSharedLock();
247+
248+
try {
249+
long size = decompressByteArrayToDirectByteBuffer0(nativePtr, dstBuff, dstOffset, dstSize, srcBuff, srcOffset, srcSize);
250+
if (Zstd.isError(size)) {
251+
throw new ZstdException(size);
252+
}
253+
if (size > Integer.MAX_VALUE) {
254+
throw new ZstdException(Zstd.errGeneric(), "Output size is greater than MAX_INT");
255+
}
256+
return (int) size;
257+
} finally {
258+
releaseSharedLock();
259+
}
260+
}
261+
262+
private static native long decompressByteArrayToDirectByteBuffer0(long nativePtr, ByteBuffer dst, int dstOffset, int dstSize, byte[] src, int srcOffset, int srcSize);
263+
264+
public int decompressDirectByteBufferToByteArray(byte[] dstBuff, int dstOffset, int dstSize, ByteBuffer srcBuff, int srcOffset, int srcSize) {
265+
if (!srcBuff.isDirect()) {
266+
throw new IllegalArgumentException("srcBuff must be a direct buffer");
267+
}
268+
269+
Objects.checkFromIndexSize(srcOffset, srcSize, srcBuff.limit());
270+
Objects.checkFromIndexSize(dstOffset, dstSize, dstBuff.length);
271+
272+
ensureOpen();
273+
acquireSharedLock();
274+
275+
try {
276+
long size = decompressDirectByteBufferToByteArray0(nativePtr, dstBuff, dstOffset, dstSize, srcBuff, srcOffset, srcSize);
277+
if (Zstd.isError(size)) {
278+
throw new ZstdException(size);
279+
}
280+
if (size > Integer.MAX_VALUE) {
281+
throw new ZstdException(Zstd.errGeneric(), "Output size is greater than MAX_INT");
282+
}
283+
return (int) size;
284+
} finally {
285+
releaseSharedLock();
286+
}
287+
}
288+
289+
private static native long decompressDirectByteBufferToByteArray0(long nativePtr, byte[] dst, int dstOffset, int dstSize, ByteBuffer src, int srcOffset, int srcSize);
290+
237291
/* Covenience methods */
238292

239293
/**
@@ -257,16 +311,38 @@ public int decompressByteArray(byte[] dstBuff, int dstOffset, int dstSize, byte[
257311
*/
258312
public int decompress(ByteBuffer dstBuf, ByteBuffer srcBuf) throws ZstdException {
259313
int size = decompressDirectByteBuffer(dstBuf, // decompress into dstBuf
260-
dstBuf.position(), // write decompressed data at offset position()
261-
dstBuf.limit() - dstBuf.position(), // write no more than limit() - position()
262-
srcBuf, // read compressed data from srcBuf
263-
srcBuf.position(), // read starting at offset position()
264-
srcBuf.limit() - srcBuf.position()); // read no more than limit() - position()
314+
dstBuf.position(), // write decompressed data at offset position()
315+
dstBuf.limit() - dstBuf.position(), // write no more than limit() - position()
316+
srcBuf, // read compressed data from srcBuf
317+
srcBuf.position(), // read starting at offset position()
318+
srcBuf.limit() - srcBuf.position()); // read no more than limit() - position()
265319
srcBuf.position(srcBuf.limit());
266320
dstBuf.position(dstBuf.position() + size);
267321
return size;
268322
}
269323

324+
public int decompress(ByteBuffer dstBuf, byte[] src) throws ZstdException {
325+
int size = decompressByteArrayToDirectByteBuffer(dstBuf, // decompress into dstBuf
326+
dstBuf.position(), // write decompressed data at offset position()
327+
dstBuf.limit() - dstBuf.position(), // write no more than limit() - position()
328+
src, // read compressed data from src
329+
0,
330+
src.length);
331+
dstBuf.position(dstBuf.position() + size);
332+
return size;
333+
}
334+
335+
public int decompress(byte[] dst, ByteBuffer srcBuf) throws ZstdException {
336+
int size = decompressDirectByteBufferToByteArray(dst, // decompress into dst
337+
0,
338+
dst.length,
339+
srcBuf, // read compressed data from srcBuf
340+
srcBuf.position(), // read starting at offset position()
341+
srcBuf.limit() - srcBuf.position()); // read no more than limit() - position()
342+
srcBuf.position(srcBuf.limit());
343+
return size;
344+
}
345+
270346
public ByteBuffer decompress(ByteBuffer srcBuf, int originalSize) throws ZstdException {
271347
ByteBuffer dstBuf = ByteBuffer.allocateDirect(originalSize);
272348
int size = decompressDirectByteBuffer(dstBuf, 0, originalSize, srcBuf, srcBuf.position(), srcBuf.limit() - srcBuf.position());

src/main/native/jni_fast_zstd.c

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,3 +689,69 @@ JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_decompressB
689689
E2: (*env)->ReleasePrimitiveArrayCritical(env, dst, dst_buff, 0);
690690
E1: return size;
691691
}
692+
693+
/*
694+
* Class: com_github_luben_zstd_ZstdDecompressCtx
695+
* Method: decompressByteArrayToDirectByteBuffer0
696+
* Signature: (Ljava/nio/ByteBuffer;II[BII)I
697+
*/
698+
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_decompressByteArrayToDirectByteBuffer0
699+
(JNIEnv *env, jclass jclazz, jlong ptr, jobject dst, jint dst_offset, jint dst_size, jbyteArray src, jint src_offset, jint src_size) {
700+
size_t size = -ZSTD_error_memory_allocation;
701+
702+
if (NULL == dst) return -ZSTD_error_dstSize_tooSmall;
703+
if (NULL == src) return -ZSTD_error_srcSize_wrong;
704+
if (0 > dst_offset) return -ZSTD_error_dstSize_tooSmall;
705+
if (0 > src_offset) return -ZSTD_error_srcSize_wrong;
706+
if (0 > src_size) return -ZSTD_error_srcSize_wrong;
707+
708+
if (src_offset + src_size > (*env)->GetArrayLength(env, src)) return -ZSTD_error_srcSize_wrong;
709+
jsize dst_cap = (*env)->GetDirectBufferCapacity(env, dst);
710+
if (dst_offset + dst_size > dst_cap) return -ZSTD_error_dstSize_tooSmall;
711+
712+
ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)ptr;
713+
714+
char *dst_buff = (char*)(*env)->GetDirectBufferAddress(env, dst);
715+
if (dst_buff == NULL) return -ZSTD_error_memory_allocation;
716+
void *src_buff = (*env)->GetPrimitiveArrayCritical(env, src, NULL);
717+
if (src_buff == NULL) goto E1;
718+
719+
ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only);
720+
size = ZSTD_decompressDCtx(dctx, ((char *)dst_buff) + dst_offset, (size_t) dst_size, ((char *)src_buff) + src_offset, (size_t) src_size);
721+
722+
(*env)->ReleasePrimitiveArrayCritical(env, src, src_buff, JNI_ABORT);
723+
E1: return size;
724+
}
725+
726+
/*
727+
* Class: com_github_luben_zstd_ZstdDecompressCtx
728+
* Method: decompressDirectByteBufferToByteArray0
729+
* Signature: ([BIILjava/nio/ByteBuffer;II)I
730+
*/
731+
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_decompressDirectByteBufferToByteArray0
732+
(JNIEnv *env, jclass jclazz, jlong ptr, jbyteArray dst, jint dst_offset, jint dst_size, jobject src, jint src_offset, jint src_size) {
733+
size_t size = -ZSTD_error_memory_allocation;
734+
735+
if (NULL == dst) return -ZSTD_error_dstSize_tooSmall;
736+
if (NULL == src) return -ZSTD_error_srcSize_wrong;
737+
if (0 > dst_offset) return -ZSTD_error_dstSize_tooSmall;
738+
if (0 > src_offset) return -ZSTD_error_srcSize_wrong;
739+
if (0 > src_size) return -ZSTD_error_srcSize_wrong;
740+
741+
if (dst_offset + dst_size > (*env)->GetArrayLength(env, dst)) return -ZSTD_error_dstSize_tooSmall;
742+
jsize src_cap = (*env)->GetDirectBufferCapacity(env, src);
743+
if (src_offset + src_size > src_cap) return -ZSTD_error_srcSize_wrong;
744+
745+
ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)ptr;
746+
747+
void *dst_buff = (*env)->GetPrimitiveArrayCritical(env, dst, NULL);
748+
if (dst_buff == NULL) goto E1;
749+
char *src_buff = (char*)(*env)->GetDirectBufferAddress(env, src);
750+
if (src_buff == NULL) return -ZSTD_error_memory_allocation;
751+
752+
ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only);
753+
size = ZSTD_decompressDCtx(dctx, ((char *)dst_buff) + dst_offset, (size_t) dst_size, ((char *)src_buff) + src_offset, (size_t) src_size);
754+
755+
(*env)->ReleasePrimitiveArrayCritical(env, dst, dst_buff, 0);
756+
E1: return size;
757+
}

src/test/scala/Zstd.scala

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,10 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
8686
val size = input.length
8787
val compressed = Zstd.compress(input, level)
8888

89-
val compressedBuffer = ByteBuffer.allocateDirect(Zstd.compressBound(size.toLong).toInt)
90-
compressedBuffer.put(compressed)
91-
compressedBuffer.limit(compressedBuffer.position())
92-
compressedBuffer.flip()
93-
94-
val decompressedBuffer = Zstd.decompress(compressedBuffer, size)
95-
val decompressed = new Array[Byte](size)
89+
val decompressedBuffer = ByteBuffer.allocateDirect(size)
90+
val decompressedSize = Zstd.decompress(decompressedBuffer, compressed);
91+
val decompressed = new Array[Byte](decompressedSize)
92+
decompressedBuffer.flip();
9693
decompressedBuffer.get(decompressed)
9794
input.toSeq == decompressed.toSeq
9895
}
@@ -104,11 +101,10 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
104101
val inputBuffer = ByteBuffer.allocateDirect(size)
105102
inputBuffer.put(input)
106103
inputBuffer.flip()
107-
val compressedBuffer = Zstd.compress(inputBuffer, level)
108-
val compressed = new Array[Byte](compressedBuffer.limit() - compressedBuffer.position())
109-
compressedBuffer.get(compressed)
104+
val compressedBuffer = Zstd.compress(inputBuffer, level)
110105

111-
val decompressed = Zstd.decompress(compressed, size)
106+
val decompressed = new Array[Byte](size)
107+
Zstd.decompress(decompressed, compressedBuffer)
112108
input.toSeq == decompressed.toSeq
113109
}
114110
}

0 commit comments

Comments
 (0)