Skip to content

Commit 100c434

Browse files
divijvaidyaluben
authored andcommitted
Relax the requirement for source and target ByteBuffer in ZstdBufferDecompressingStream
1 parent 4793b0b commit 100c434

5 files changed

+65
-10
lines changed

src/main/java/com/github/luben/zstd/BaseZstdBufferDecompressingStreamNoFinalizer.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@ public abstract class BaseZstdBufferDecompressingStreamNoFinalizer implements Cl
1010
protected boolean closed = false;
1111
private boolean finishedFrame = false;
1212
private boolean streamEnd = false;
13+
/**
14+
* This field is set by the native call to represent the number of bytes consumed from {@link #source} buffer.
15+
*/
1316
private int consumed;
17+
/**
18+
* This field is set by the native call to represent the number of bytes produced into the target buffer.
19+
*/
1420
private int produced;
1521

1622
BaseZstdBufferDecompressingStreamNoFinalizer(ByteBuffer source) {
@@ -27,6 +33,9 @@ protected ByteBuffer refill(ByteBuffer toRefill) {
2733
return toRefill;
2834
}
2935

36+
/**
37+
* @return false if all data is processed and no more data is available from the {@link #source}
38+
*/
3039
public boolean hasRemaining() {
3140
return !streamEnd && (source.hasRemaining() || !finishedFrame);
3241
}
@@ -52,6 +61,15 @@ public BaseZstdBufferDecompressingStreamNoFinalizer setDict(ZstdDictDecompress d
5261
return this;
5362
}
5463

64+
/**
65+
* Set the value of zstd parameter <code>ZSTD_d_windowLogMax</code>.
66+
*
67+
* @param windowLogMax window size in bytes
68+
* @return this instance of {@link BaseZstdBufferDecompressingStreamNoFinalizer}
69+
* @throws ZstdIOException if there is an error while setting the configuration natively.
70+
*
71+
* @see <a href="https://github.com/facebook/zstd/blob/0525d1cec64a8df749ff293ee476f616de79f7b0/lib/zstd.h#L606"> Zstd's ZSTD_d_windowLogMax parameter</a>
72+
*/
5573
public BaseZstdBufferDecompressingStreamNoFinalizer setLongMax(int windowLogMax) throws IOException {
5674
long size = Zstd.setDecompressionLongMax(stream, windowLogMax);
5775
if (Zstd.isError(size)) {
@@ -106,6 +124,23 @@ public void close() {
106124
}
107125
}
108126
}
127+
128+
/**
129+
* Reads the content of the de-compressed stream into the target buffer.
130+
* <p>This method will block until the chunk of compressed data stored in {@link #source} has been decompressed and
131+
* written into the target buffer. After each execution, this method will refill the {@link #source} buffer, using
132+
* {@link #refill(ByteBuffer)}.
133+
*<p>To read the full stream of decompressed data, this method should be called in a loop while {@link #hasRemaining()}
134+
* is <code>true</code>.
135+
*<p>The target buffer will be written starting from {@link ByteBuffer#position()}. The {@link ByteBuffer#position()}
136+
* of source and the target buffers will be modified to represent the data read and written respectively.
137+
*
138+
* @param target buffer to store the read bytes from uncompressed stream.
139+
* @return the number of bytes read into the target buffer.
140+
* @throws ZstdIOException if an error occurs while reading.
141+
* @throws IllegalArgumentException if provided source or target buffers are incorrectly configured.
142+
* @throws IOException if the stream is closed before reading.
143+
*/
109144
public abstract int read(ByteBuffer target) throws IOException;
110145

111146
abstract long createDStream();

src/main/java/com/github/luben/zstd/ZstdBufferDecompressingStreamNoFinalizer.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ public ZstdBufferDecompressingStreamNoFinalizer(ByteBuffer source) {
1515
if (source.isDirect()) {
1616
throw new IllegalArgumentException("Source buffer should be a non-direct buffer");
1717
}
18-
stream = createDStreamNative();
19-
initDStreamNative(stream);
18+
stream = createDStream();
19+
initDStream(stream);
2020
}
2121

2222
@Override
@@ -44,8 +44,14 @@ long initDStream(long stream) {
4444

4545
@Override
4646
long decompressStream(long stream, ByteBuffer dst, int dstOffset, int dstSize, ByteBuffer src, int srcOffset, int srcSize) {
47-
byte[] targetArr = Zstd.extractArray(dst);
48-
byte[] sourceArr = Zstd.extractArray(source);
47+
if (!src.hasArray()) {
48+
throw new IllegalArgumentException("provided source ByteBuffer lacks array");
49+
}
50+
if (!dst.hasArray()) {
51+
throw new IllegalArgumentException("provided destination ByteBuffer lacks array");
52+
}
53+
byte[] targetArr = dst.array();
54+
byte[] sourceArr = src.array();
4955

5056
return decompressStreamNative(stream, targetArr, dstOffset, dstSize, sourceArr, srcOffset, srcSize);
5157
}

src/main/java/com/github/luben/zstd/ZstdDirectBufferDecompressingStreamNoFinalizer.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ public ZstdDirectBufferDecompressingStreamNoFinalizer(ByteBuffer source) {
1616
throw new IllegalArgumentException("Source buffer should be a direct buffer");
1717
}
1818
this.source = source;
19-
stream = createDStreamNative();
20-
initDStreamNative(stream);
19+
stream = createDStream();
20+
initDStream(stream);
2121
}
2222

2323
@Override

src/main/java/com/github/luben/zstd/ZstdInputStreamNoFinalizer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ int readInternal(byte[] dst, int offset, int len) throws IOException {
140140
throw new IOException("Stream closed");
141141
}
142142

143-
// guard agains buffer overflows
143+
// guard against buffer overflows
144144
if (offset < 0 || len > dst.length - offset) {
145145
throw new IndexOutOfBoundsException("Requested length " + len
146146
+ " from offset " + offset + " in buffer of size " + dst.length);

src/test/scala/Zstd.scala

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@ package com.github.luben.zstd
22

33
import org.scalatest.flatspec.AnyFlatSpec
44
import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks
5+
56
import java.io._
67
import java.nio._
78
import java.nio.channels.FileChannel
89
import java.nio.channels.FileChannel.MapMode
10+
import java.nio.charset.Charset
911
import java.nio.file.StandardOpenOption
10-
1112
import scala.io._
1213
import scala.collection.mutable.WrappedArray
1314
import scala.util.Using
@@ -676,21 +677,34 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
676677
val orig = new File("src/test/resources/xml")
677678
val file = new File(s"src/test/resources/xml-$level.zst")
678679
val channel = FileChannel.open(file.toPath, StandardOpenOption.READ)
679-
val readBuffer = ByteBuffer.allocate(channel.size().toInt)
680+
// write some garbage bytes at the beginning of buffer containing compressed data to prove that
681+
// this buffer's position doesn't have to start from 0.
682+
val garbageBytes = "garbage bytes".getBytes(Charset.defaultCharset());
683+
val readBuffer = ByteBuffer.allocate(channel.size().toInt + garbageBytes.length)
684+
readBuffer.put(garbageBytes)
680685
channel.read(readBuffer)
686+
// set pos to 0 and limit to containing bytes
681687
readBuffer.flip()
688+
// advance the position after garbage data
689+
readBuffer.position(garbageBytes.length)
690+
682691
val zis = new ZstdBufferDecompressingStream(readBuffer)
683692
val length = orig.length.toInt
684693
val buff = Array.fill[Byte](length)(0)
685694
var pos = 0
686-
val block = ByteBuffer.allocate(1)
695+
// write some garbage bytes at the beginning of buffer containing uncompressed data to prove that
696+
// this buffer's position doesn't have to start from 0.
697+
val block = ByteBuffer.allocate(1 + garbageBytes.length)
687698
while (pos < length && zis.hasRemaining) {
688699
block.clear
700+
block.put(garbageBytes)
689701
val read = zis.read(block)
690702
if (read != 1) {
691703
sys.error(s"Failed reading compressed file before end. Bytes read: $read")
692704
}
693705
block.flip()
706+
// advance the position after garbage data
707+
block.position(garbageBytes.length);
694708
buff.update(pos, block.get())
695709
pos += 1
696710
}

0 commit comments

Comments
 (0)