Skip to content

Commit ae1ad52

Browse files
Coder-256luben
authored andcommitted
Create sequence producer tests
1 parent e4ad211 commit ae1ad52

File tree

4 files changed

+263
-3
lines changed

4 files changed

+263
-3
lines changed

src/main/java/com/github/luben/zstd/Zstd.java

+4
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,9 @@ public static long decompressDirectByteBufferFastDict(ByteBuffer dst, int dstOff
561561
public static native int loadDictCompress(long stream, byte[] dict, int dict_size);
562562
public static native int loadFastDictCompress(long stream, ZstdDictCompress dict);
563563
public static native void registerSequenceProducer(long stream, long seqProdState, long seqProdFunction);
564+
public static native void generateSequences(long stream, long outSeqs, long outSeqsSize, long src, long srcSize);
565+
static native long getBuiltinSequenceProducer(); // Used in tests
566+
static native long getStubSequenceProducer(); // Used in tests
564567
public static native int setCompressionChecksums(long stream, boolean useChecksums);
565568
public static native int setCompressionMagicless(long stream, boolean useMagicless);
566569
public static native int setCompressionLevel(long stream, int level);
@@ -578,6 +581,7 @@ public static long decompressDirectByteBufferFastDict(ByteBuffer dst, int dstOff
578581
public static native int setDecompressionLongMax(long stream, int windowLogMax);
579582
public static native int setDecompressionMagicless(long stream, boolean useMagicless);
580583
public static native int setRefMultipleDDicts(long stream, boolean useMultiple);
584+
public static native int setValidateSequences(long stream, boolean validateSequences);
581585

582586
/* Utility methods */
583587
/**

src/main/java/com/github/luben/zstd/ZstdCompressCtx.java

+18
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,24 @@ public ZstdCompressCtx setSequenceProducerFallback(boolean fallbackFlag){
312312
}
313313
private static native void setSequenceProducerFallback0(long ptr, boolean fallbackFlag);
314314

315+
public ZstdCompressCtx setValidateSequences(boolean validateSequences) {
316+
ensureOpen();
317+
acquireSharedLock();
318+
try {
319+
long result = Zstd.setValidateSequences(nativePtr, validateSequences);
320+
if (Zstd.isError(result)) {
321+
throw new ZstdException(result);
322+
}
323+
} finally {
324+
releaseSharedLock();
325+
}
326+
return this;
327+
}
328+
329+
// Used in tests
330+
long getNativePtr() {
331+
return nativePtr;
332+
}
315333

316334
/**
317335
* Load compression dictionary to be used for subsequently compressed frames.

src/main/native/jni_zstd.c

+61
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,57 @@ JNIEXPORT jint JNICALL Java_com_github_luben_zstd_Zstd_loadFastDictCompress
293293
return ZSTD_CCtx_refCDict((ZSTD_CCtx *)(intptr_t) stream, cdict);
294294
}
295295

296+
size_t builtinSequenceProducer(
297+
void* sequenceProducerState,
298+
ZSTD_Sequence* outSeqs, size_t outSeqsCapacity,
299+
const void* src, size_t srcSize,
300+
const void* dict, size_t dictSize,
301+
int compressionLevel,
302+
size_t windowSize
303+
) {
304+
ZSTD_CCtx *zc = (ZSTD_CCtx *)sequenceProducerState;
305+
int windowLog = 0;
306+
while (windowSize > 1) {
307+
windowLog++;
308+
windowSize >>= 1;
309+
}
310+
ZSTD_CCtx_setParameter(zc, ZSTD_c_compressionLevel, compressionLevel);
311+
ZSTD_CCtx_setParameter(zc, ZSTD_c_windowLog, windowSize);
312+
size_t numSeqs = ZSTD_generateSequences((ZSTD_CCtx *)sequenceProducerState, outSeqs, outSeqsCapacity, src, srcSize);
313+
return ZSTD_isError(numSeqs) ? ZSTD_SEQUENCE_PRODUCER_ERROR : numSeqs;
314+
}
315+
316+
size_t stubSequenceProducer(
317+
void* sequenceProducerState,
318+
ZSTD_Sequence* outSeqs, size_t outSeqsCapacity,
319+
const void* src, size_t srcSize,
320+
const void* dict, size_t dictSize,
321+
int compressionLevel,
322+
size_t windowSize
323+
) {
324+
return ZSTD_SEQUENCE_PRODUCER_ERROR;
325+
}
326+
327+
/*
328+
* Class: com_github_luben_zstd_Zstd
329+
* Method: getBuiltinSequenceProducer
330+
* Signature: ()J
331+
*/
332+
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_Zstd_getBuiltinSequenceProducer
333+
(JNIEnv *env, jclass obj) {
334+
return (jlong)(intptr_t)&builtinSequenceProducer;
335+
}
336+
337+
/*
338+
* Class: com_github_luben_zstd_Zstd
339+
* Method: getBuiltinSequenceProducer
340+
* Signature: ()J
341+
*/
342+
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_Zstd_getStubSequenceProducer
343+
(JNIEnv *env, jclass obj) {
344+
return (jlong)(intptr_t)&stubSequenceProducer;
345+
}
346+
296347
/*
297348
* Class: com_github_luben_zstd_Zstd
298349
* Method: registerSequenceProducer
@@ -489,6 +540,16 @@ JNIEXPORT jint JNICALL Java_com_github_luben_zstd_Zstd_setRefMultipleDDicts
489540
return ZSTD_DCtx_setParameter((ZSTD_DCtx *)(intptr_t) stream, ZSTD_d_refMultipleDDicts, value);
490541
}
491542

543+
/*
544+
* Class: com_github_luben_zstd_Zstd
545+
* Method: setValidateSequences
546+
* Signature: (JZ)I
547+
*/
548+
JNIEXPORT jint JNICALL Java_com_github_luben_zstd_Zstd_setValidateSequences
549+
(JNIEnv *env, jclass obj, jlong stream, jboolean validateSequences) {
550+
return ZSTD_CCtx_setParameter((ZSTD_CCtx *)(intptr_t) stream, ZSTD_c_validateSequences, validateSequences);
551+
}
552+
492553
/*
493554
* Class: com_github_luben_zstd_Zstd
494555
* Methods: header constants access

src/test/scala/Zstd.scala

+180-3
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@ import java.nio.channels.FileChannel
99
import java.nio.channels.FileChannel.MapMode
1010
import java.nio.charset.Charset
1111
import java.nio.file.StandardOpenOption
12-
import scala.io._
12+
import scala.annotation.unused
1313
import scala.collection.mutable.WrappedArray
14+
import scala.io._
1415
import scala.util.Using
1516

1617
class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
@@ -1105,7 +1106,7 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
11051106
}
11061107
}
11071108

1108-
"streaming compressiong and decompression" should "roundtrip" in {
1109+
"streaming compression and decompression" should "roundtrip" in {
11091110
Using.Manager { use =>
11101111
val cctx = use(new ZstdCompressCtx())
11111112
val dctx = use(new ZstdDecompressCtx())
@@ -1149,7 +1150,7 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
11491150
decompressedBuffer.flip()
11501151

11511152
val comparison = inputBuffer.compareTo(decompressedBuffer)
1152-
comparison == 0 && Zstd.decompressedSize(compressedBuffer) == size && Zstd.getFrameContentSize(compressedBuffer) == size
1153+
assert(comparison == 0 && Zstd.decompressedSize(compressedBuffer) == size && Zstd.getFrameContentSize(compressedBuffer) == size)
11531154
}
11541155
}
11551156
}.get
@@ -1211,4 +1212,180 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
12111212
}
12121213
}
12131214
}.get
1215+
1216+
it should "be able to use a sequence producer" in {
1217+
Using.Manager { use =>
1218+
val cctx = use(new ZstdCompressCtx())
1219+
val cctx2 = use(new ZstdCompressCtx())
1220+
val dctx = use(new ZstdDecompressCtx())
1221+
1222+
forAll { input: Array[Byte] =>
1223+
{
1224+
val size = input.length
1225+
val inputBuffer = ByteBuffer.allocateDirect(size)
1226+
inputBuffer.put(input)
1227+
inputBuffer.flip()
1228+
cctx.reset()
1229+
cctx.setLevel(9)
1230+
val seqProd = new SequenceProducer {
1231+
def getFunctionPointer(): Long = {
1232+
Zstd.getBuiltinSequenceProducer()
1233+
}
1234+
1235+
def createState(): Long = {
1236+
cctx2.getNativePtr()
1237+
}
1238+
1239+
def freeState(@unused state: Long) = {}
1240+
}
1241+
cctx.registerSequenceProducer(seqProd)
1242+
cctx.setValidateSequences(true)
1243+
cctx.setSequenceProducerFallback(false)
1244+
cctx.setPledgedSrcSize(size)
1245+
val compressedBuffer = ByteBuffer.allocateDirect(Zstd.compressBound(size).toInt)
1246+
while (inputBuffer.hasRemaining) {
1247+
compressedBuffer.limit(compressedBuffer.position() + 1)
1248+
cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective.CONTINUE)
1249+
}
1250+
1251+
var frameProgression = cctx.getFrameProgression()
1252+
assert(frameProgression.getIngested() == size)
1253+
assert(frameProgression.getFlushed() == compressedBuffer.position())
1254+
1255+
compressedBuffer.limit(compressedBuffer.capacity())
1256+
val done = cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective.END)
1257+
assert(done)
1258+
1259+
frameProgression = cctx.getFrameProgression()
1260+
assert(frameProgression.getConsumed() == size)
1261+
1262+
compressedBuffer.flip()
1263+
val decompressedBuffer = ByteBuffer.allocateDirect(size)
1264+
dctx.reset()
1265+
while (compressedBuffer.hasRemaining) {
1266+
if (decompressedBuffer.limit() < decompressedBuffer.position()) {
1267+
decompressedBuffer.limit(compressedBuffer.position() + 1)
1268+
}
1269+
dctx.decompressDirectByteBufferStream(decompressedBuffer, compressedBuffer)
1270+
}
1271+
1272+
inputBuffer.rewind()
1273+
compressedBuffer.rewind()
1274+
decompressedBuffer.flip()
1275+
1276+
val comparison = inputBuffer.compareTo(decompressedBuffer)
1277+
assert(comparison == 0 && Zstd.decompressedSize(compressedBuffer) == size && Zstd.getFrameContentSize(compressedBuffer) == size)
1278+
}
1279+
}
1280+
}.get
1281+
}
1282+
1283+
it should "fail with a stub sequence producer" in {
1284+
Using.Manager { use =>
1285+
val cctx = use(new ZstdCompressCtx())
1286+
1287+
forAll(minSize(32)) { input: Array[Byte] =>
1288+
{
1289+
val size = input.length
1290+
val inputBuffer = ByteBuffer.allocateDirect(size)
1291+
inputBuffer.put(input)
1292+
inputBuffer.flip()
1293+
cctx.reset()
1294+
cctx.setLevel(9)
1295+
1296+
val seqProd = new SequenceProducer {
1297+
def getFunctionPointer(): Long = {
1298+
Zstd.getStubSequenceProducer()
1299+
}
1300+
1301+
def createState(): Long = { 0 }
1302+
def freeState(@unused state: Long) = { 0 }
1303+
}
1304+
1305+
cctx.registerSequenceProducer(seqProd)
1306+
cctx.setValidateSequences(true)
1307+
cctx.setSequenceProducerFallback(false)
1308+
cctx.setPledgedSrcSize(size)
1309+
1310+
val compressedBuffer = ByteBuffer.allocateDirect(Zstd.compressBound(size).toInt)
1311+
try {
1312+
while (inputBuffer.hasRemaining) {
1313+
compressedBuffer.limit(compressedBuffer.position() + 1)
1314+
cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective.CONTINUE)
1315+
}
1316+
cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective.END)
1317+
fail("compression succeeded, but should have failed")
1318+
} catch {
1319+
case _: ZstdException => // compression should throw a ZstdException
1320+
}
1321+
}
1322+
}
1323+
}.get
1324+
}
1325+
1326+
it should "succeed with a stub sequence producer and software fallback" in {
1327+
Using.Manager { use =>
1328+
val cctx = use(new ZstdCompressCtx())
1329+
val dctx = use(new ZstdDecompressCtx())
1330+
1331+
forAll { input: Array[Byte] =>
1332+
{
1333+
val size = input.length
1334+
val inputBuffer = ByteBuffer.allocateDirect(size)
1335+
inputBuffer.put(input)
1336+
inputBuffer.flip()
1337+
cctx.reset()
1338+
cctx.setLevel(9)
1339+
1340+
val seqProd = new SequenceProducer {
1341+
def getFunctionPointer(): Long = {
1342+
Zstd.getStubSequenceProducer()
1343+
}
1344+
1345+
def createState(): Long = { 0 }
1346+
def freeState(@unused state: Long) = { 0 }
1347+
}
1348+
1349+
cctx.registerSequenceProducer(seqProd)
1350+
cctx.setValidateSequences(true)
1351+
cctx.setSequenceProducerFallback(true) // !!
1352+
cctx.setPledgedSrcSize(size)
1353+
1354+
val compressedBuffer = ByteBuffer.allocateDirect(Zstd.compressBound(size).toInt)
1355+
while (inputBuffer.hasRemaining) {
1356+
compressedBuffer.limit(compressedBuffer.position() + 1)
1357+
cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective.CONTINUE)
1358+
}
1359+
1360+
var frameProgression = cctx.getFrameProgression()
1361+
assert(frameProgression.getIngested() == size)
1362+
assert(frameProgression.getFlushed() == compressedBuffer.position())
1363+
1364+
compressedBuffer.limit(compressedBuffer.capacity())
1365+
val done = cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective.END)
1366+
assert(done)
1367+
1368+
frameProgression = cctx.getFrameProgression()
1369+
assert(frameProgression.getConsumed() == size)
1370+
1371+
compressedBuffer.flip()
1372+
val decompressedBuffer = ByteBuffer.allocateDirect(size)
1373+
dctx.reset()
1374+
while (compressedBuffer.hasRemaining) {
1375+
if (decompressedBuffer.limit() < decompressedBuffer.position()) {
1376+
decompressedBuffer.limit(compressedBuffer.position() + 1)
1377+
}
1378+
dctx.decompressDirectByteBufferStream(decompressedBuffer, compressedBuffer)
1379+
}
1380+
1381+
inputBuffer.rewind()
1382+
compressedBuffer.rewind()
1383+
decompressedBuffer.flip()
1384+
1385+
val comparison = inputBuffer.compareTo(decompressedBuffer)
1386+
assert(comparison == 0 && Zstd.decompressedSize(compressedBuffer) == size && Zstd.getFrameContentSize(compressedBuffer) == size)
1387+
}
1388+
}
1389+
}.get
1390+
}
12141391
}

0 commit comments

Comments
 (0)