Skip to content

Commit e9bc661

Browse files
danieldkDaniël de Kok
authored and committed
Add WriteEmbeddings::write_embeddings_len
This method precomputes the size of serialized finalfusion embeddings.
1 parent 9846cfe commit e9bc661

File tree

3 files changed

+87
-3
lines changed

3 files changed

+87
-3
lines changed

src/embeddings.rs

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,40 @@ where
584584

585585
Ok(())
586586
}
587+
588+
/// Compute the serialized length of these embeddings without writing them.
///
/// `offset` is the position in the output at which serialization would
/// start. Every part is asked for its length at its own start position
/// (`offset` plus the bytes accounted for so far).
/// NOTE(review): presumably `chunk_len` depends on the position because of
/// alignment padding -- confirm against the `chunk_len` implementations.
///
/// Returns the sum of the header length and the lengths of all chunks that
/// are present (metadata and norms are optional).
fn write_embeddings_len(&self, offset: u64) -> u64 {
589+
let mut len = 0;
590+
591+
// Gather the chunk identifiers for the header, in the same order in
// which the chunk lengths are summed below: optional metadata, vocab,
// storage, optional norms.
let mut chunks = match self.metadata {
592+
Some(ref metadata) => vec![metadata.chunk_identifier()],
593+
None => vec![],
594+
};
595+
596+
chunks.extend_from_slice(&[
597+
self.vocab.chunk_identifier(),
598+
self.storage.chunk_identifier(),
599+
]);
600+
601+
if let Some(ref norms) = self.norms {
602+
chunks.push(norms.chunk_identifier());
603+
}
604+
605+
// The header is accounted for first; here `len` is still 0, so it is
// measured at `offset` itself.
let header = Header::new(chunks);
606+
len += header.chunk_len(offset + len);
607+
608+
// Each following chunk is measured at the running position
// `offset + len`, i.e. right after everything counted so far.
if let Some(ref metadata) = self.metadata {
609+
len += metadata.chunk_len(offset + len);
610+
}
611+
612+
len += self.vocab.chunk_len(offset + len);
613+
len += self.storage.chunk_len(offset + len);
614+
615+
if let Some(ref norms) = self.norms {
616+
len += norms.chunk_len(offset + len);
617+
}
618+
619+
len
620+
}
587621
}
588622

589623
/// Quantizable embedding matrix.
@@ -736,17 +770,36 @@ mod tests {
736770
use crate::chunks::metadata::Metadata;
737771
use crate::chunks::norms::NdNorms;
738772
use crate::chunks::storage::{NdArray, Storage, StorageView};
739-
use crate::chunks::vocab::{SimpleVocab, Vocab};
773+
use crate::chunks::vocab::{FastTextSubwordVocab, SimpleVocab, Vocab};
740774
use crate::compat::fasttext::ReadFastText;
741775
use crate::compat::word2vec::ReadWord2VecRaw;
742776
use crate::io::{ReadEmbeddings, WriteEmbeddings};
777+
use crate::prelude::StorageWrap;
778+
use crate::storage::QuantizedArray;
743779
use crate::subword::Indexer;
780+
use crate::vocab::VocabWrap;
744781

745782
// Test fixture: embeddings read from `testdata/similarity.bin` with
// `read_word2vec_binary_raw`. Panics if the file cannot be opened or read.
fn test_embeddings() -> Embeddings<SimpleVocab, NdArray> {
746783
let mut reader = BufReader::new(File::open("testdata/similarity.bin").unwrap());
747784
Embeddings::read_word2vec_binary_raw(&mut reader, false).unwrap()
748785
}
749786

787+
// Test fixture: the embeddings from `test_embeddings()` with the metadata
// from `test_metadata()` attached via `set_metadata`.
fn test_embeddings_with_metadata() -> Embeddings<SimpleVocab, NdArray> {
788+
let mut embeds = test_embeddings();
789+
embeds.set_metadata(Some(test_metadata()));
790+
embeds
791+
}
792+
793+
// Test fixture: subword embeddings read from `testdata/fasttext.bin` with
// `read_fasttext`. Panics if the file cannot be opened or read.
fn test_embeddings_fasttext() -> Embeddings<FastTextSubwordVocab, NdArray> {
794+
let mut reader = BufReader::new(File::open("testdata/fasttext.bin").unwrap());
795+
Embeddings::read_fasttext(&mut reader).unwrap()
796+
}
797+
798+
// Test fixture: embeddings with `QuantizedArray` storage read from
// `testdata/quantized.fifu` with `read_embeddings`. Panics if the file
// cannot be opened or read.
fn test_embeddings_quantized() -> Embeddings<SimpleVocab, QuantizedArray> {
799+
let mut reader = BufReader::new(File::open("testdata/quantized.fifu").unwrap());
800+
Embeddings::read_embeddings(&mut reader).unwrap()
801+
}
802+
750803
fn test_metadata() -> Metadata {
751804
Metadata::new(toml! {
752805
[hyperparameters]
@@ -867,12 +920,15 @@ mod tests {
867920
Embeddings::read_embeddings(&mut cursor).unwrap();
868921
assert_eq!(embeds.storage().view(), check_embeds.storage().view());
869922
assert_eq!(embeds.vocab(), check_embeds.vocab());
923+
assert_eq!(
924+
cursor.into_inner().len() as u64,
925+
check_embeds.write_embeddings_len(0)
926+
);
870927
}
871928

872929
#[test]
873930
fn write_read_simple_metadata_roundtrip() {
874-
let mut check_embeds = test_embeddings();
875-
check_embeds.set_metadata(Some(test_metadata()));
931+
let check_embeds = test_embeddings_with_metadata();
876932

877933
let mut cursor = Cursor::new(Vec::new());
878934
check_embeds.write_embeddings(&mut cursor).unwrap();
@@ -881,5 +937,31 @@ mod tests {
881937
Embeddings::read_embeddings(&mut cursor).unwrap();
882938
assert_eq!(embeds.storage().view(), check_embeds.storage().view());
883939
assert_eq!(embeds.vocab(), check_embeds.vocab());
940+
assert_eq!(
941+
cursor.into_inner().len() as u64,
942+
check_embeds.write_embeddings_len(0)
943+
);
944+
}
945+
946+
// Checks that `write_embeddings_len(offset)` matches the number of bytes
// `write_embeddings` actually emits when writing starts at `offset`, for
// every start offset in 0..16 and for each kind of test embeddings.
#[test]
947+
fn embeddings_write_length_different_offsets() {
948+
// Convert every fixture to the wrapper types so they fit in one Vec.
let embeddings: Vec<Embeddings<VocabWrap, StorageWrap>> = vec![
949+
test_embeddings().into(),
950+
test_embeddings_with_metadata().into(),
951+
test_embeddings_fasttext().into(),
952+
test_embeddings_quantized().into(),
953+
];
954+
955+
for check_embeddings in &embeddings {
956+
for offset in 0..16u64 {
957+
// Start writing at `offset` instead of at the beginning of the buffer.
let mut cursor = Cursor::new(Vec::new());
958+
cursor.seek(SeekFrom::Start(offset)).unwrap();
959+
check_embeddings.write_embeddings(&mut cursor).unwrap();
960+
// Subtract `offset` from the buffer length so only the bytes
// written from the seek position onwards are compared.
assert_eq!(
961+
cursor.into_inner().len() as u64 - offset,
962+
check_embeddings.write_embeddings_len(offset)
963+
);
964+
}
965+
}
884966
}
885967
}

src/io.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,6 @@ pub trait WriteEmbeddings {
8585
fn write_embeddings<W>(&self, write: &mut W) -> Result<()>
8686
where
8787
W: Write + Seek;
88+
89+
fn write_embeddings_len(&self, offset: u64) -> u64;
8890
}

testdata/quantized.fifu

15.2 KB
Binary file not shown.

0 commit comments

Comments
 (0)