diff --git a/bwt.rs b/bwt.rs index d8e4605..47071f3 100644 --- a/bwt.rs +++ b/bwt.rs @@ -167,7 +167,6 @@ pub fn encode_brute(input: &[u8], suf: &mut [Suffix], fn_out: |u8|) -> Suffix { } } - assert!( origin.is_some() ); origin.unwrap() } @@ -335,7 +334,7 @@ impl Reader for Decoder { self.header = true; } let mut amt = dst.len(); - let len = amt; + let dst_len = amt; while amt > 0 { if self.output.len() == self.start { @@ -344,19 +343,19 @@ impl Reader for Decoder { break } } - let n = num::min( amt, self.output.len() - self.start ); + let n = num::min(amt, self.output.len() - self.start); vec::bytes::copy_memory( - dst.mut_slice_from(len - amt), - self.output.slice_from(self.start) + dst.mut_slice_from(dst_len - amt), + self.output.slice(self.start, self.start + n) ); self.start += n; amt -= n; } - if len == amt { + if dst_len == amt { Err(io::standard_error(io::EndOfFile)) } else { - Ok(len - amt) + Ok(dst_len - amt) } } } diff --git a/entropy/ari.rs b/entropy/ari.rs index 74db340..851cde6 100644 --- a/entropy/ari.rs +++ b/entropy/ari.rs @@ -19,11 +19,11 @@ use compress::entropy::ari; let text = "some text"; let mut e = ari::ByteEncoder::new(MemWriter::new()); e.write_str(text); -let (encoded, _) = e.encoder.finish(); +let (encoded, _) = e.finish(); // Decode the encoded text let mut d = ari::ByteDecoder::new(MemReader::new(encoded.unwrap())); -let decoded = d.read_bytes(text.len()).unwrap(); +let decoded = d.read_to_end().unwrap(); ``` # Credit @@ -435,6 +435,7 @@ impl Model for FrequencyTable { /// A basic byte-encoding arithmetic +/// uses a special terminator code to end the stream pub struct ByteEncoder { /// A lower level encoder encoder: Encoder, @@ -448,9 +449,16 @@ impl ByteEncoder { let freq_max = range_default_threshold >> 2; ByteEncoder { encoder: Encoder::new(w), - freq: FrequencyTable::new_flat(symbol_total, freq_max), + freq: FrequencyTable::new_flat(symbol_total+1, freq_max), } } + + /// Finish encoding & write the terminator symbol + pub fn finish(mut self) -> (W, io::IoResult<()>) { + let ret = self.encoder.encode(symbol_total, &self.freq); + let (w,r2) = self.encoder.finish(); + (w, ret.and(r2)) + } } impl Writer for ByteEncoder { @@ -470,21 +478,24 @@ impl Writer for ByteEncoder { /// A basic byte-decoding arithmetic +/// expects a special terminator code for the end of the stream pub struct ByteDecoder { /// A lower level decoder decoder: Decoder, /// A basic frequency table freq: FrequencyTable, + /// Remember if we found the terminator code + priv is_eof: bool, } impl ByteDecoder { /// Create a decoder on top of a given Reader - /// requires the output size to be known pub fn new(r: R) -> ByteDecoder { let freq_max = range_default_threshold >> 2; ByteDecoder { decoder: Decoder::new(r), - freq: FrequencyTable::new_flat(symbol_total, freq_max), + freq: FrequencyTable::new_flat(symbol_total+1, freq_max), + is_eof: false, } } } @@ -494,20 +505,21 @@ impl Reader for ByteDecoder { if self.decoder.tell() == 0 { if_ok!(self.decoder.start()); } - let mut ret = Ok(dst.len()); + if self.is_eof { + return Err(io::standard_error(io::EndOfFile)) + } + let mut amount = 0u; for out_byte in dst.mut_iter() { - match self.decoder.decode(&self.freq) { - Ok(value) => { - self.freq.update(value, 10, 1); - *out_byte = value as u8; - }, - Err(e) => { - ret = Err(e); - break - } + let value = if_ok!(self.decoder.decode(&self.freq)); + if value == symbol_total { + self.is_eof = true; + break } + self.freq.update(value, 10, 1); + *out_byte = value as u8; + amount += 1; } - ret + Ok(amount) } } @@ -523,12 +535,12 @@ mod test { info!("Roundtrip Ari of size {}", bytes.len()); let mut e = ByteEncoder::new(MemWriter::new()); e.write(bytes).unwrap(); - let (e, r) = e.encoder.finish(); + let (e, r) = e.finish(); r.unwrap(); let encoded = e.unwrap(); debug!("Roundtrip input {:?} encoded {:?}", bytes, encoded); let mut d = ByteDecoder::new(BufReader::new(encoded)); - let decoded = d.read_bytes(bytes.len()).unwrap(); + let decoded = d.read_to_end().unwrap(); assert_eq!(bytes.as_slice(), decoded.as_slice()); }