Skip to content

Commit fdd79e3

Browse files
danieldkDaniël de Kok
authored andcommitted
Add WriteChunk::chunk_len method
This method returns the full length of the chunk when serialized, including the chunk identifier and (written) chunk length.
1 parent c12769a commit fdd79e3

File tree

11 files changed

+379
-119
lines changed

11 files changed

+379
-119
lines changed

src/chunks/io.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::convert::TryFrom;
22
use std::fmt::{self, Display};
33
use std::fs::File;
44
use std::io::{BufReader, Read, Seek, Write};
5+
use std::mem;
56

67
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
78

@@ -154,6 +155,12 @@ pub trait WriteChunk {
154155
/// Get the identifier of a chunk.
155156
fn chunk_identifier(&self) -> ChunkIdentifier;
156157

158+
/// Get the serialized length of a chunk.
159+
///
160+
/// The `offset` of the chunk in the serialized data is required
161+
/// because some chunks store arrays aligned.
162+
fn chunk_len(&self, offset: u64) -> u64;
163+
157164
fn write_chunk<W>(&self, write: &mut W) -> Result<()>
158165
where
159166
W: Write + Seek;
@@ -181,6 +188,14 @@ impl WriteChunk for Header {
181188
ChunkIdentifier::Header
182189
}
183190

191+
fn chunk_len(&self, _offset: u64) -> u64 {
192+
// magic + model version (u32) + chunk ids len (u32), chunk ids (len * u32)
193+
(MAGIC.len()
194+
+ mem::size_of_val(&MODEL_VERSION)
195+
+ mem::size_of::<u32>()
196+
+ self.chunk_identifiers.len() * mem::size_of::<u32>()) as u64
197+
}
198+
184199
fn write_chunk<W>(&self, write: &mut W) -> Result<()>
185200
where
186201
W: Write + Seek,
@@ -256,6 +271,17 @@ mod tests {
256271

257272
use super::{ChunkIdentifier, Header, ReadChunk, WriteChunk};
258273

274+
#[test]
275+
fn header_chunk_len_is_correct() {
276+
let check_header =
277+
Header::new(vec![ChunkIdentifier::SimpleVocab, ChunkIdentifier::NdArray]);
278+
let mut cursor = Cursor::new(Vec::new());
279+
check_header.write_chunk(&mut cursor).unwrap();
280+
281+
let data = cursor.into_inner();
282+
assert_eq!(data.len() as u64, check_header.chunk_len(0));
283+
}
284+
259285
#[test]
260286
fn header_write_read_roundtrip() {
261287
let check_header =

src/chunks/metadata.rs

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Metadata chunks
22
33
use std::io::{Read, Seek, Write};
4+
use std::mem;
45
use std::ops::{Deref, DerefMut};
56

67
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
@@ -80,6 +81,11 @@ impl WriteChunk for Metadata {
8081
ChunkIdentifier::Metadata
8182
}
8283

84+
fn chunk_len(&self, _offset: u64) -> u64 {
85+
// chunk identifier (u32) + metadata length (u64) + metadata
86+
(mem::size_of::<u32>() + mem::size_of::<u64>() + self.to_string().len()) as u64
87+
}
88+
8389
fn write_chunk<W>(&self, write: &mut W) -> Result<()>
8490
where
8591
W: Write + Seek,
@@ -154,16 +160,24 @@ mod tests {
154160

155161
#[test]
156162
fn metadata_correct_chunk_size() {
157-
let check_metadata = test_metadata();
158-
let mut cursor = Cursor::new(Vec::new());
159-
check_metadata.write_chunk(&mut cursor).unwrap();
160-
cursor.seek(SeekFrom::Start(0)).unwrap();
161-
162-
let chunk_size = read_chunk_size(&mut cursor);
163-
assert_eq!(
164-
cursor.read_to_end(&mut Vec::new()).unwrap(),
165-
chunk_size as usize
166-
);
163+
for offset in 0..16u64 {
164+
let check_metadata = test_metadata();
165+
let mut cursor = Cursor::new(Vec::new());
166+
cursor.seek(SeekFrom::Start(offset)).unwrap();
167+
check_metadata.write_chunk(&mut cursor).unwrap();
168+
cursor.seek(SeekFrom::Start(offset)).unwrap();
169+
170+
// Check remaining chunk size against size written into the chunk.
171+
let chunk_size = read_chunk_size(&mut cursor);
172+
assert_eq!(
173+
cursor.read_to_end(&mut Vec::new()).unwrap() as u64,
174+
chunk_size
175+
);
176+
177+
// Check overall chunk size.
178+
let data = cursor.into_inner();
179+
assert_eq!(data.len() as u64 - offset, check_metadata.chunk_len(offset));
180+
}
167181
}
168182

169183
#[test]

src/chunks/norms.rs

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
use std::convert::TryInto;
44
use std::io::{Read, Seek, SeekFrom, Write};
5+
use std::mem;
56
use std::mem::size_of;
67
use std::ops::Deref;
78

@@ -89,6 +90,18 @@ impl WriteChunk for NdNorms {
8990
ChunkIdentifier::NdNorms
9091
}
9192

93+
fn chunk_len(&self, offset: u64) -> u64 {
94+
let n_padding = padding::<f32>(offset + mem::size_of::<u32>() as u64);
95+
96+
// Chunk identifier (u32) + chunk len (u64) + len (u64) + type id (u32) + padding + vector.
97+
(mem::size_of::<u32>()
98+
+ mem::size_of::<u64>()
99+
+ mem::size_of::<u64>()
100+
+ mem::size_of::<u32>()
101+
+ self.len() * mem::size_of::<f32>()) as u64
102+
+ n_padding
103+
}
104+
92105
fn write_chunk<W>(&self, write: &mut W) -> Result<()>
93106
where
94107
W: Write + Seek,
@@ -156,16 +169,26 @@ mod tests {
156169

157170
#[test]
158171
fn ndnorms_correct_chunk_size() {
159-
let check_arr = test_ndnorms();
160-
let mut cursor = Cursor::new(Vec::new());
161-
check_arr.write_chunk(&mut cursor).unwrap();
162-
cursor.seek(SeekFrom::Start(0)).unwrap();
163-
164-
let chunk_size = read_chunk_size(&mut cursor);
165-
assert_eq!(
166-
cursor.read_to_end(&mut Vec::new()).unwrap(),
167-
chunk_size as usize
168-
);
172+
for offset in 0..16u64 {
173+
let check_arr = test_ndnorms();
174+
let mut cursor = Cursor::new(Vec::new());
175+
cursor.seek(SeekFrom::Start(offset)).unwrap();
176+
check_arr.write_chunk(&mut cursor).unwrap();
177+
cursor.seek(SeekFrom::Start(offset)).unwrap();
178+
179+
// Check size remained chunk against embedded chunk size.
180+
let chunk_size = read_chunk_size(&mut cursor);
181+
assert_eq!(
182+
cursor.read_to_end(&mut Vec::new()).unwrap(),
183+
chunk_size as usize
184+
);
185+
186+
// Check overall chunk size.
187+
assert_eq!(
188+
cursor.into_inner().len() as u64 - offset,
189+
check_arr.chunk_len(offset)
190+
);
191+
}
169192
}
170193

171194
#[test]

src/chunks/storage/array.rs

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::convert::TryInto;
22
use std::io::{Read, Seek, SeekFrom, Write};
3+
use std::mem;
34
use std::mem::size_of;
45

56
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
@@ -170,6 +171,10 @@ mod mmap {
170171
ChunkIdentifier::NdArray
171172
}
172173

174+
fn chunk_len(&self, offset: u64) -> u64 {
175+
NdArray::chunk_len(self.view(), offset)
176+
}
177+
173178
fn write_chunk<W>(&self, write: &mut W) -> Result<()>
174179
where
175180
W: Write + Seek,
@@ -193,6 +198,19 @@ impl NdArray {
193198
NdArray { inner: arr }
194199
}
195200

201+
fn chunk_len(data: ArrayView2<f32>, offset: u64) -> u64 {
202+
let n_padding = padding::<f32>(offset + mem::size_of::<u32>() as u64);
203+
204+
// Chunk identifier (u32) + chunk len (u64) + rows (u64) + cols (u32) + type id (u32) + padding + matrix.
205+
(mem::size_of::<u32>()
206+
+ mem::size_of::<u64>()
207+
+ mem::size_of::<u64>()
208+
+ mem::size_of::<u32>()
209+
+ mem::size_of::<u32>()
210+
+ data.len() * mem::size_of::<f32>()) as u64
211+
+ n_padding
212+
}
213+
196214
fn write_ndarray_chunk<W>(data: ArrayView2<f32>, write: &mut W) -> Result<()>
197215
where
198216
W: Write + Seek,
@@ -346,6 +364,10 @@ impl WriteChunk for NdArray {
346364
ChunkIdentifier::NdArray
347365
}
348366

367+
fn chunk_len(&self, offset: u64) -> u64 {
368+
Self::chunk_len(self.inner.view(), offset)
369+
}
370+
349371
fn write_chunk<W>(&self, write: &mut W) -> Result<()>
350372
where
351373
W: Write + Seek,
@@ -356,13 +378,13 @@ impl WriteChunk for NdArray {
356378

357379
#[cfg(test)]
358380
mod tests {
359-
use std::io::{Cursor, Read, Seek, SeekFrom};
381+
use std::io::{Cursor, Seek, SeekFrom};
360382

361-
use byteorder::{LittleEndian, ReadBytesExt};
362383
use ndarray::Array2;
363384

364385
use crate::chunks::io::{ReadChunk, WriteChunk};
365386
use crate::chunks::storage::{NdArray, Storage, StorageView};
387+
use crate::storage::tests::test_storage_chunk_len;
366388

367389
const N_ROWS: usize = 100;
368390
const N_COLS: usize = 100;
@@ -375,14 +397,6 @@ mod tests {
375397
NdArray::new(test_data)
376398
}
377399

378-
fn read_chunk_size(read: &mut impl Read) -> u64 {
379-
// Skip identifier.
380-
read.read_u32::<LittleEndian>().unwrap();
381-
382-
// Return chunk length.
383-
read.read_u64::<LittleEndian>().unwrap()
384-
}
385-
386400
#[test]
387401
fn embeddings_returns_expected_embeddings() {
388402
const CHECK_INDICES: &[usize] = &[0, 50, 99, 0];
@@ -398,16 +412,7 @@ mod tests {
398412

399413
#[test]
400414
fn ndarray_correct_chunk_size() {
401-
let check_arr = test_ndarray();
402-
let mut cursor = Cursor::new(Vec::new());
403-
check_arr.write_chunk(&mut cursor).unwrap();
404-
cursor.seek(SeekFrom::Start(0)).unwrap();
405-
406-
let chunk_size = read_chunk_size(&mut cursor);
407-
assert_eq!(
408-
cursor.read_to_end(&mut Vec::new()).unwrap(),
409-
chunk_size as usize
410-
);
415+
test_storage_chunk_len(test_ndarray().into());
411416
}
412417

413418
#[test]

src/chunks/storage/mod.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,40 @@ pub(crate) trait StorageViewMut: Storage {
5454
/// Get a view of the embedding matrix.
5555
fn view_mut(&mut self) -> ArrayViewMut2<f32>;
5656
}
57+
58+
#[cfg(test)]
59+
mod tests {
60+
use std::io::{Cursor, Read, Seek, SeekFrom};
61+
62+
use crate::chunks::io::WriteChunk;
63+
use byteorder::{LittleEndian, ReadBytesExt};
64+
65+
use crate::storage::StorageWrap;
66+
67+
fn read_chunk_size(read: &mut impl Read) -> u64 {
68+
// Skip identifier.
69+
read.read_u32::<LittleEndian>().unwrap();
70+
71+
// Return chunk length.
72+
read.read_u64::<LittleEndian>().unwrap()
73+
}
74+
75+
#[cfg(test)]
76+
pub(crate) fn test_storage_chunk_len(check_storage: StorageWrap) {
77+
for offset in 0..16u64 {
78+
let mut cursor = Cursor::new(Vec::new());
79+
cursor.seek(SeekFrom::Start(offset)).unwrap();
80+
check_storage.write_chunk(&mut cursor).unwrap();
81+
cursor.seek(SeekFrom::Start(offset)).unwrap();
82+
83+
let chunk_size = read_chunk_size(&mut cursor);
84+
assert_eq!(
85+
cursor.read_to_end(&mut Vec::new()).unwrap(),
86+
chunk_size as usize
87+
);
88+
89+
let data = cursor.into_inner();
90+
assert_eq!(data.len() as u64 - offset, check_storage.chunk_len(offset));
91+
}
92+
}
93+
}

0 commit comments

Comments
 (0)