@@ -9,8 +9,9 @@ import java.nio.channels.FileChannel
9
9
import java .nio .channels .FileChannel .MapMode
10
10
import java .nio .charset .Charset
11
11
import java .nio .file .StandardOpenOption
12
- import scala .io . _
12
+ import scala .annotation . unused
13
13
import scala .collection .mutable .WrappedArray
14
+ import scala .io ._
14
15
import scala .util .Using
15
16
16
17
class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
@@ -1105,7 +1106,7 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
1105
1106
}
1106
1107
}
1107
1108
1108
- " streaming compressiong and decompression" should " roundtrip" in {
1109
+ " streaming compression and decompression" should " roundtrip" in {
1109
1110
Using .Manager { use =>
1110
1111
val cctx = use(new ZstdCompressCtx ())
1111
1112
val dctx = use(new ZstdDecompressCtx ())
@@ -1149,7 +1150,7 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
1149
1150
decompressedBuffer.flip()
1150
1151
1151
1152
val comparison = inputBuffer.compareTo(decompressedBuffer)
1152
- comparison == 0 && Zstd .decompressedSize(compressedBuffer) == size && Zstd .getFrameContentSize(compressedBuffer) == size
1153
+ assert( comparison == 0 && Zstd .decompressedSize(compressedBuffer) == size && Zstd .getFrameContentSize(compressedBuffer) == size)
1153
1154
}
1154
1155
}
1155
1156
}.get
@@ -1211,4 +1212,180 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
1211
1212
}
1212
1213
}
1213
1214
}.get
1215
+
1216
+ it should " be able to use a sequence producer" in {
1217
+ Using .Manager { use =>
1218
+ val cctx = use(new ZstdCompressCtx ())
1219
+ val cctx2 = use(new ZstdCompressCtx ())
1220
+ val dctx = use(new ZstdDecompressCtx ())
1221
+
1222
+ forAll { input : Array [Byte ] =>
1223
+ {
1224
+ val size = input.length
1225
+ val inputBuffer = ByteBuffer .allocateDirect(size)
1226
+ inputBuffer.put(input)
1227
+ inputBuffer.flip()
1228
+ cctx.reset()
1229
+ cctx.setLevel(9 )
1230
+ val seqProd = new SequenceProducer {
1231
+ def getFunctionPointer (): Long = {
1232
+ Zstd .getBuiltinSequenceProducer()
1233
+ }
1234
+
1235
+ def createState (): Long = {
1236
+ cctx2.getNativePtr()
1237
+ }
1238
+
1239
+ def freeState (@ unused state : Long ) = {}
1240
+ }
1241
+ cctx.registerSequenceProducer(seqProd)
1242
+ cctx.setValidateSequences(true )
1243
+ cctx.setSequenceProducerFallback(false )
1244
+ cctx.setPledgedSrcSize(size)
1245
+ val compressedBuffer = ByteBuffer .allocateDirect(Zstd .compressBound(size).toInt)
1246
+ while (inputBuffer.hasRemaining) {
1247
+ compressedBuffer.limit(compressedBuffer.position() + 1 )
1248
+ cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective .CONTINUE )
1249
+ }
1250
+
1251
+ var frameProgression = cctx.getFrameProgression()
1252
+ assert(frameProgression.getIngested() == size)
1253
+ assert(frameProgression.getFlushed() == compressedBuffer.position())
1254
+
1255
+ compressedBuffer.limit(compressedBuffer.capacity())
1256
+ val done = cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective .END )
1257
+ assert(done)
1258
+
1259
+ frameProgression = cctx.getFrameProgression()
1260
+ assert(frameProgression.getConsumed() == size)
1261
+
1262
+ compressedBuffer.flip()
1263
+ val decompressedBuffer = ByteBuffer .allocateDirect(size)
1264
+ dctx.reset()
1265
+ while (compressedBuffer.hasRemaining) {
1266
+ if (decompressedBuffer.limit() < decompressedBuffer.position()) {
1267
+ decompressedBuffer.limit(compressedBuffer.position() + 1 )
1268
+ }
1269
+ dctx.decompressDirectByteBufferStream(decompressedBuffer, compressedBuffer)
1270
+ }
1271
+
1272
+ inputBuffer.rewind()
1273
+ compressedBuffer.rewind()
1274
+ decompressedBuffer.flip()
1275
+
1276
+ val comparison = inputBuffer.compareTo(decompressedBuffer)
1277
+ assert(comparison == 0 && Zstd .decompressedSize(compressedBuffer) == size && Zstd .getFrameContentSize(compressedBuffer) == size)
1278
+ }
1279
+ }
1280
+ }.get
1281
+ }
1282
+
1283
+ it should " fail with a stub sequence producer" in {
1284
+ Using .Manager { use =>
1285
+ val cctx = use(new ZstdCompressCtx ())
1286
+
1287
+ forAll(minSize(32 )) { input : Array [Byte ] =>
1288
+ {
1289
+ val size = input.length
1290
+ val inputBuffer = ByteBuffer .allocateDirect(size)
1291
+ inputBuffer.put(input)
1292
+ inputBuffer.flip()
1293
+ cctx.reset()
1294
+ cctx.setLevel(9 )
1295
+
1296
+ val seqProd = new SequenceProducer {
1297
+ def getFunctionPointer (): Long = {
1298
+ Zstd .getStubSequenceProducer()
1299
+ }
1300
+
1301
+ def createState (): Long = { 0 }
1302
+ def freeState (@ unused state : Long ) = { 0 }
1303
+ }
1304
+
1305
+ cctx.registerSequenceProducer(seqProd)
1306
+ cctx.setValidateSequences(true )
1307
+ cctx.setSequenceProducerFallback(false )
1308
+ cctx.setPledgedSrcSize(size)
1309
+
1310
+ val compressedBuffer = ByteBuffer .allocateDirect(Zstd .compressBound(size).toInt)
1311
+ try {
1312
+ while (inputBuffer.hasRemaining) {
1313
+ compressedBuffer.limit(compressedBuffer.position() + 1 )
1314
+ cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective .CONTINUE )
1315
+ }
1316
+ cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective .END )
1317
+ fail(" compression succeeded, but should have failed" )
1318
+ } catch {
1319
+ case _ : ZstdException => // compression should throw a ZstdException
1320
+ }
1321
+ }
1322
+ }
1323
+ }.get
1324
+ }
1325
+
1326
+ it should " succeed with a stub sequence producer and software fallback" in {
1327
+ Using .Manager { use =>
1328
+ val cctx = use(new ZstdCompressCtx ())
1329
+ val dctx = use(new ZstdDecompressCtx ())
1330
+
1331
+ forAll { input : Array [Byte ] =>
1332
+ {
1333
+ val size = input.length
1334
+ val inputBuffer = ByteBuffer .allocateDirect(size)
1335
+ inputBuffer.put(input)
1336
+ inputBuffer.flip()
1337
+ cctx.reset()
1338
+ cctx.setLevel(9 )
1339
+
1340
+ val seqProd = new SequenceProducer {
1341
+ def getFunctionPointer (): Long = {
1342
+ Zstd .getStubSequenceProducer()
1343
+ }
1344
+
1345
+ def createState (): Long = { 0 }
1346
+ def freeState (@ unused state : Long ) = { 0 }
1347
+ }
1348
+
1349
+ cctx.registerSequenceProducer(seqProd)
1350
+ cctx.setValidateSequences(true )
1351
+ cctx.setSequenceProducerFallback(true ) // !!
1352
+ cctx.setPledgedSrcSize(size)
1353
+
1354
+ val compressedBuffer = ByteBuffer .allocateDirect(Zstd .compressBound(size).toInt)
1355
+ while (inputBuffer.hasRemaining) {
1356
+ compressedBuffer.limit(compressedBuffer.position() + 1 )
1357
+ cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective .CONTINUE )
1358
+ }
1359
+
1360
+ var frameProgression = cctx.getFrameProgression()
1361
+ assert(frameProgression.getIngested() == size)
1362
+ assert(frameProgression.getFlushed() == compressedBuffer.position())
1363
+
1364
+ compressedBuffer.limit(compressedBuffer.capacity())
1365
+ val done = cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective .END )
1366
+ assert(done)
1367
+
1368
+ frameProgression = cctx.getFrameProgression()
1369
+ assert(frameProgression.getConsumed() == size)
1370
+
1371
+ compressedBuffer.flip()
1372
+ val decompressedBuffer = ByteBuffer .allocateDirect(size)
1373
+ dctx.reset()
1374
+ while (compressedBuffer.hasRemaining) {
1375
+ if (decompressedBuffer.limit() < decompressedBuffer.position()) {
1376
+ decompressedBuffer.limit(compressedBuffer.position() + 1 )
1377
+ }
1378
+ dctx.decompressDirectByteBufferStream(decompressedBuffer, compressedBuffer)
1379
+ }
1380
+
1381
+ inputBuffer.rewind()
1382
+ compressedBuffer.rewind()
1383
+ decompressedBuffer.flip()
1384
+
1385
+ val comparison = inputBuffer.compareTo(decompressedBuffer)
1386
+ assert(comparison == 0 && Zstd .decompressedSize(compressedBuffer) == size && Zstd .getFrameContentSize(compressedBuffer) == size)
1387
+ }
1388
+ }
1389
+ }.get
1390
+ }
1214
1391
}
0 commit comments