@@ -584,6 +584,40 @@ where
584
584
585
585
Ok ( ( ) )
586
586
}
587
+
588
+ fn write_embeddings_len ( & self , offset : u64 ) -> u64 {
589
+ let mut len = 0 ;
590
+
591
+ let mut chunks = match self . metadata {
592
+ Some ( ref metadata) => vec ! [ metadata. chunk_identifier( ) ] ,
593
+ None => vec ! [ ] ,
594
+ } ;
595
+
596
+ chunks. extend_from_slice ( & [
597
+ self . vocab . chunk_identifier ( ) ,
598
+ self . storage . chunk_identifier ( ) ,
599
+ ] ) ;
600
+
601
+ if let Some ( ref norms) = self . norms {
602
+ chunks. push ( norms. chunk_identifier ( ) ) ;
603
+ }
604
+
605
+ let header = Header :: new ( chunks) ;
606
+ len += header. chunk_len ( offset + len) ;
607
+
608
+ if let Some ( ref metadata) = self . metadata {
609
+ len += metadata. chunk_len ( offset + len) ;
610
+ }
611
+
612
+ len += self . vocab . chunk_len ( offset + len) ;
613
+ len += self . storage . chunk_len ( offset + len) ;
614
+
615
+ if let Some ( ref norms) = self . norms {
616
+ len += norms. chunk_len ( offset + len) ;
617
+ }
618
+
619
+ len
620
+ }
587
621
}
588
622
589
623
/// Quantizable embedding matrix.
@@ -736,17 +770,36 @@ mod tests {
736
770
use crate :: chunks:: metadata:: Metadata ;
737
771
use crate :: chunks:: norms:: NdNorms ;
738
772
use crate :: chunks:: storage:: { NdArray , Storage , StorageView } ;
739
- use crate :: chunks:: vocab:: { SimpleVocab , Vocab } ;
773
+ use crate :: chunks:: vocab:: { FastTextSubwordVocab , SimpleVocab , Vocab } ;
740
774
use crate :: compat:: fasttext:: ReadFastText ;
741
775
use crate :: compat:: word2vec:: ReadWord2VecRaw ;
742
776
use crate :: io:: { ReadEmbeddings , WriteEmbeddings } ;
777
+ use crate :: prelude:: StorageWrap ;
778
+ use crate :: storage:: QuantizedArray ;
743
779
use crate :: subword:: Indexer ;
780
+ use crate :: vocab:: VocabWrap ;
744
781
745
782
fn test_embeddings ( ) -> Embeddings < SimpleVocab , NdArray > {
746
783
let mut reader = BufReader :: new ( File :: open ( "testdata/similarity.bin" ) . unwrap ( ) ) ;
747
784
Embeddings :: read_word2vec_binary_raw ( & mut reader, false ) . unwrap ( )
748
785
}
749
786
787
+ fn test_embeddings_with_metadata ( ) -> Embeddings < SimpleVocab , NdArray > {
788
+ let mut embeds = test_embeddings ( ) ;
789
+ embeds. set_metadata ( Some ( test_metadata ( ) ) ) ;
790
+ embeds
791
+ }
792
+
793
+ fn test_embeddings_fasttext ( ) -> Embeddings < FastTextSubwordVocab , NdArray > {
794
+ let mut reader = BufReader :: new ( File :: open ( "testdata/fasttext.bin" ) . unwrap ( ) ) ;
795
+ Embeddings :: read_fasttext ( & mut reader) . unwrap ( )
796
+ }
797
+
798
+ fn test_embeddings_quantized ( ) -> Embeddings < SimpleVocab , QuantizedArray > {
799
+ let mut reader = BufReader :: new ( File :: open ( "testdata/quantized.fifu" ) . unwrap ( ) ) ;
800
+ Embeddings :: read_embeddings ( & mut reader) . unwrap ( )
801
+ }
802
+
750
803
fn test_metadata ( ) -> Metadata {
751
804
Metadata :: new ( toml ! {
752
805
[ hyperparameters]
@@ -867,12 +920,15 @@ mod tests {
867
920
Embeddings :: read_embeddings ( & mut cursor) . unwrap ( ) ;
868
921
assert_eq ! ( embeds. storage( ) . view( ) , check_embeds. storage( ) . view( ) ) ;
869
922
assert_eq ! ( embeds. vocab( ) , check_embeds. vocab( ) ) ;
923
+ assert_eq ! (
924
+ cursor. into_inner( ) . len( ) as u64 ,
925
+ check_embeds. write_embeddings_len( 0 )
926
+ ) ;
870
927
}
871
928
872
929
#[ test]
873
930
fn write_read_simple_metadata_roundtrip ( ) {
874
- let mut check_embeds = test_embeddings ( ) ;
875
- check_embeds. set_metadata ( Some ( test_metadata ( ) ) ) ;
931
+ let check_embeds = test_embeddings_with_metadata ( ) ;
876
932
877
933
let mut cursor = Cursor :: new ( Vec :: new ( ) ) ;
878
934
check_embeds. write_embeddings ( & mut cursor) . unwrap ( ) ;
@@ -881,5 +937,31 @@ mod tests {
881
937
Embeddings :: read_embeddings ( & mut cursor) . unwrap ( ) ;
882
938
assert_eq ! ( embeds. storage( ) . view( ) , check_embeds. storage( ) . view( ) ) ;
883
939
assert_eq ! ( embeds. vocab( ) , check_embeds. vocab( ) ) ;
940
+ assert_eq ! (
941
+ cursor. into_inner( ) . len( ) as u64 ,
942
+ check_embeds. write_embeddings_len( 0 )
943
+ ) ;
944
+ }
945
+
946
#[test]
fn embeddings_write_length_different_offsets() {
    // One sample per fixture, all converted to the wrapper types so they
    // can share a single vector.
    let embeddings: Vec<Embeddings<VocabWrap, StorageWrap>> = vec![
        test_embeddings().into(),
        test_embeddings_with_metadata().into(),
        test_embeddings_fasttext().into(),
        test_embeddings_quantized().into(),
    ];

    for check_embeddings in &embeddings {
        for offset in 0..16u64 {
            // Begin writing `offset` bytes into the buffer; the bytes in
            // front of the embeddings are not part of the serialized data.
            let mut cursor = Cursor::new(Vec::new());
            cursor.seek(SeekFrom::Start(offset)).unwrap();
            check_embeddings.write_embeddings(&mut cursor).unwrap();

            // The predicted length must equal the number of bytes that
            // were actually written after the starting position.
            let written = cursor.into_inner().len() as u64 - offset;
            assert_eq!(written, check_embeddings.write_embeddings_len(offset));
        }
    }
}
885
967
}
0 commit comments