@@ -4,7 +4,7 @@ use std::cmp::Ordering;
4
4
use std:: collections:: { BinaryHeap , HashSet } ;
5
5
use std:: f32;
6
6
7
- use ndarray:: { s, ArrayView1 , CowArray , Ix1 } ;
7
+ use ndarray:: { s, ArrayView1 , Axis , CowArray , Ix1 } ;
8
8
use ordered_float:: NotNan ;
9
9
10
10
use crate :: chunks:: storage:: { Storage , StorageView } ;
@@ -82,12 +82,20 @@ pub trait Analogy {
82
82
/// At most, `limit` results are returned. `Result::Err` is returned
83
83
/// when no embedding could be computed for one or more of the tokens,
84
84
/// indicating which of the tokens were present.
85
+ ///
86
+ /// If `batch_size` is `None`, the query will be performed on all
87
+ /// word embeddings at once. This is typically the most efficient, but
88
+ /// can require a large amount of memory. The query is performed on batches
89
+ /// of size `n` when `batch_size` is `Some(n)`. Setting this to a smaller
90
+ /// value than the number of word embeddings reduces memory use at the
91
+ /// cost of computational efficiency.
85
92
fn analogy (
86
93
& self ,
87
94
query : [ & str ; 3 ] ,
88
95
limit : usize ,
96
+ batch_size : Option < usize > ,
89
97
) -> Result < Vec < WordSimilarityResult > , [ bool ; 3 ] > {
// Convenience wrapper around `analogy_masked`: every `remove` flag is set
// to `true` and `query`, `limit` and `batch_size` are forwarded unchanged.
// The result (including the `Err` array) is whatever `analogy_masked`
// returns.
90
- self . analogy_masked ( query, [ true , true , true ] , limit)
98
+ self . analogy_masked ( query, [ true , true , true ] , limit, batch_size )
91
99
}
92
100
93
101
/// Perform an analogy query.
@@ -104,6 +112,13 @@ pub trait Analogy {
104
112
/// output candidates. If `remove[0]` is `true`, `word1` cannot be
105
113
/// returned as an answer to the query.
106
114
///
115
+ /// If `batch_size` is `None`, the query will be performed on all
116
+ /// word embeddings at once. This is typically the most efficient, but
117
+ /// can require a large amount of memory. The query is performed on batches
118
+ /// of size `n` when `batch_size` is `Some(n)`. Setting this to a smaller
119
+ /// value than the number of word embeddings reduces memory use at the
120
+ /// cost of computational efficiency.
121
+ ///
107
122
/// `Result::Err` is returned when no embedding could be computed
108
123
/// for one or more of the tokens, indicating which of the tokens
109
124
/// were present.
@@ -112,6 +127,7 @@ pub trait Analogy {
112
127
query : [ & str ; 3 ] ,
113
128
remove : [ bool ; 3 ] ,
114
129
limit : usize ,
130
+ batch_size : Option < usize > ,
115
131
) -> Result < Vec < WordSimilarityResult > , [ bool ; 3 ] > ;
116
132
}
117
133
@@ -125,6 +141,7 @@ where
125
141
query : [ & str ; 3 ] ,
126
142
remove : [ bool ; 3 ] ,
127
143
limit : usize ,
144
+ batch_size : Option < usize > ,
128
145
) -> Result < Vec < WordSimilarityResult > , [ bool ; 3 ] > {
129
146
{
130
147
let [ embedding1, embedding2, embedding3] = lookup_words3 ( self , query) ?;
@@ -139,7 +156,7 @@ where
139
156
. map ( |( word, _) | word. to_owned ( ) )
140
157
. collect ( ) ;
141
158
142
- Ok ( self . similarity_ ( embedding. view ( ) , & skip, limit) )
159
+ Ok ( self . similarity_ ( embedding. view ( ) , & skip, limit, batch_size ) )
143
160
}
144
161
}
145
162
}
@@ -152,20 +169,37 @@ pub trait WordSimilarity {
152
169
/// the embeddings. If the vectors are unit vectors (e.g. by virtue of
153
170
/// calling `normalize`), this is the cosine similarity. At most, `limit`
154
171
/// results are returned.
155
- fn word_similarity ( & self , word : & str , limit : usize ) -> Option < Vec < WordSimilarityResult > > ;
172
+ ///
173
+ /// If `batch_size` is `None`, the query will be performed on all
174
+ /// word embeddings at once. This is typically the most efficient, but
175
+ /// can require a large amount of memory. The query is performed on batches
176
+ /// of size `n` when `batch_size` is `Some(n)`. Setting this to a smaller
177
+ /// value than the number of word embeddings reduces memory use at the
178
+ /// cost of computational efficiency.
179
// NOTE(review): `batch_size` is a new *required* parameter, so this changes
// the public signature of the trait method. Every caller and every
// implementor has to be updated; the tests in this file now pass `None` or
// `Some(17)` explicitly.
+ fn word_similarity (
180
+ & self ,
181
+ word : & str ,
182
+ limit : usize ,
183
+ batch_size : Option < usize > ,
184
+ ) -> Option < Vec < WordSimilarityResult > > ;
156
185
}
157
186
158
187
// `WordSimilarity` for `Embeddings`: looks up the query word's embedding and
// delegates the actual search to the private `similarity_` helper.
impl < V , S > WordSimilarity for Embeddings < V , S >
159
188
where
160
189
V : Vocab ,
161
190
S : StorageView ,
162
191
{
163
- fn word_similarity ( & self , word : & str , limit : usize ) -> Option < Vec < WordSimilarityResult > > {
192
+ fn word_similarity (
193
+ & self ,
194
+ word : & str ,
195
+ limit : usize ,
196
+ batch_size : Option < usize > ,
197
+ ) -> Option < Vec < WordSimilarityResult > > {
164
198
// `?` returns `None` early when no embedding is available for `word`.
let embed = self . embedding ( word) ?;
165
199
// The query word itself is put in the skip set so that it is not
// reported as its own neighbour.
let mut skip = HashSet :: new ( ) ;
166
200
skip. insert ( word) ;
167
201
// NOTE(review): `batch_size` is passed through unchecked. `similarity_`
// hands it to ndarray's `axis_chunks_iter`, which is documented to panic
// when the chunk size is zero, so `Some(0)` presumably panics instead of
// returning a result. The `None` fallback is `words_len()`, which would
// also be zero for an empty vocabulary. Confirm, and consider rejecting
// or clamping a zero batch size.
168
- Some ( self . similarity_ ( embed. view ( ) , & skip, limit) )
202
+ Some ( self . similarity_ ( embed. view ( ) , & skip, limit, batch_size ) )
169
203
}
170
204
}
171
205
@@ -177,12 +211,20 @@ pub trait EmbeddingSimilarity {
177
211
/// defined by the dot product of the embeddings. The embeddings in the
178
212
/// storage are l2-normalized, this method l2-normalizes the input query,
179
213
/// therefore the dot product is equivalent to the cosine similarity.
214
+ ///
215
+ /// If `batch_size` is `None`, the query will be performed on all
216
+ /// word embeddings at once. This is typically the most efficient, but
217
+ /// can require a large amount of memory. The query is performed on batches
218
+ /// of size `n` when `batch_size` is `Some(n)`. Setting this to a smaller
219
+ /// value than the number of word embeddings reduces memory use at the
220
+ /// cost of computational efficiency.
180
221
fn embedding_similarity (
181
222
& self ,
182
223
query : ArrayView1 < f32 > ,
183
224
limit : usize ,
225
+ batch_size : Option < usize > ,
184
226
) -> Option < Vec < WordSimilarityResult > > {
// Convenience wrapper around `embedding_similarity_masked`: a fresh, empty
// `HashSet` is passed as the skip set, so this wrapper itself masks out
// nothing. `query`, `limit` and `batch_size` are forwarded unchanged.
185
- self . embedding_similarity_masked ( query, limit, & HashSet :: new ( ) )
227
+ self . embedding_similarity_masked ( query, limit, & HashSet :: new ( ) , batch_size )
186
228
}
187
229
188
230
/// Find words that are similar to the query embedding while skipping
@@ -192,11 +234,19 @@ pub trait EmbeddingSimilarity {
192
234
/// defined by the dot product of the embeddings. The embeddings in the
193
235
/// storage are l2-normalized, this method l2-normalizes the input query,
194
236
/// therefore the dot product is equivalent to the cosine similarity.
237
+ ///
238
+ /// If `batch_size` is `None`, the query will be performed on all
239
+ /// word embeddings at once. This is typically the most efficient, but
240
+ /// can require a large amount of memory. The query is performed on batches
241
+ /// of size `n` when `batch_size` is `Some(n)`. Setting this to a smaller
242
+ /// value than the number of word embeddings reduces memory use at the
243
+ /// cost of computational efficiency.
195
244
// Required method: the declaration ends in `;`, so there is no default body
// and each implementor supplies the search itself.
fn embedding_similarity_masked (
196
245
& self ,
197
246
query : ArrayView1 < f32 > ,
198
247
limit : usize ,
199
248
// Words to leave out of the results. In the `Embeddings` implementation in
// this file the set is handed to `similarity_`, which skips every word it
// contains. NOTE(review): that implementation names the parameter `skip`
// while the trait uses `skips`; harmless, but worth aligning.
skips : & HashSet < & str > ,
249
+ batch_size : Option < usize > ,
200
250
) -> Option < Vec < WordSimilarityResult > > ;
201
251
}
202
252
@@ -210,10 +260,11 @@ where
210
260
query : ArrayView1 < f32 > ,
211
261
limit : usize ,
212
262
skip : & HashSet < & str > ,
263
+ batch_size : Option < usize > ,
213
264
) -> Option < Vec < WordSimilarityResult > > {
214
265
let mut query = query. to_owned ( ) ;
215
266
l2_normalize ( query. view_mut ( ) ) ;
216
- Some ( self . similarity_ ( query. view ( ) , skip, limit) )
267
+ Some ( self . similarity_ ( query. view ( ) , skip, limit, batch_size ) )
217
268
}
218
269
}
219
270
@@ -223,6 +274,7 @@ trait SimilarityPrivate {
223
274
embed : ArrayView1 < f32 > ,
224
275
skip : & HashSet < & str > ,
225
276
limit : usize ,
277
+ batch_size : Option < usize > ,
226
278
) -> Vec < WordSimilarityResult > ;
227
279
}
228
280
@@ -236,35 +288,41 @@ where
236
288
embed : ArrayView1 < f32 > ,
237
289
skip : & HashSet < & str > ,
238
290
limit : usize ,
291
+ batch_size : Option < usize > ,
239
292
) -> Vec < WordSimilarityResult > {
240
- // ndarray#474
241
- #[ allow( clippy:: deref_addrof) ]
242
- let sims = self
293
+ let batch_size = batch_size. unwrap_or_else ( || self . vocab ( ) . words_len ( ) ) ;
294
+
295
+ let mut results = BinaryHeap :: with_capacity ( limit) ;
296
+
297
+ for ( batch_idx, batch) in self
243
298
. storage ( )
244
299
. view ( )
245
300
. slice ( s ! [ 0 ..self . vocab( ) . words_len( ) , ..] )
246
- . dot ( & embed. view ( ) ) ;
301
+ . axis_chunks_iter ( Axis ( 0 ) , batch_size)
302
+ . enumerate ( )
303
+ {
304
+ let sims = batch. dot ( & embed. view ( ) ) ;
247
305
248
- let mut results = BinaryHeap :: with_capacity ( limit) ;
249
- for ( idx, & sim) in sims. iter ( ) . enumerate ( ) {
250
- let word = & self . vocab ( ) . words ( ) [ idx] ;
306
+ for ( idx, & sim) in sims. iter ( ) . enumerate ( ) {
307
+ let word = & self . vocab ( ) . words ( ) [ ( batch_idx * batch_size) + idx] ;
251
308
252
- // Don't add words that we are explicitly asked to skip.
253
- if skip. contains ( word. as_str ( ) ) {
254
- continue ;
255
- }
309
+ // Don't add words that we are explicitly asked to skip.
310
+ if skip. contains ( word. as_str ( ) ) {
311
+ continue ;
312
+ }
256
313
257
- let word_similarity = WordSimilarityResult {
258
- word,
259
- similarity : NotNan :: new ( sim) . expect ( "Encountered NaN" ) ,
260
- } ;
261
-
262
- if results. len ( ) < limit {
263
- results. push ( word_similarity) ;
264
- } else {
265
- let mut peek = results. peek_mut ( ) . expect ( "Cannot peek non-empty heap" ) ;
266
- if word_similarity < * peek {
267
- * peek = word_similarity
314
+ let word_similarity = WordSimilarityResult {
315
+ word,
316
+ similarity : NotNan :: new ( sim) . expect ( "Encountered NaN" ) ,
317
+ } ;
318
+
319
+ if results. len ( ) < limit {
320
+ results. push ( word_similarity) ;
321
+ } else {
322
+ let mut peek = results. peek_mut ( ) . expect ( "Cannot peek non-empty heap" ) ;
323
+ if word_similarity < * peek {
324
+ * peek = word_similarity
325
+ }
268
326
}
269
327
}
270
328
}
@@ -504,7 +562,7 @@ mod tests {
504
562
let mut reader = BufReader :: new ( f) ;
505
563
let embeddings = Embeddings :: read_word2vec_binary ( & mut reader) . unwrap ( ) ;
506
564
507
- let result = embeddings. word_similarity ( "Berlin" , 40 ) ;
565
+ let result = embeddings. word_similarity ( "Berlin" , 40 , None ) ;
508
566
assert ! ( result. is_some( ) ) ;
509
567
let result = result. unwrap ( ) ;
510
568
assert_eq ! ( 40 , result. len( ) ) ;
@@ -513,14 +571,23 @@ mod tests {
513
571
assert_eq ! ( SIMILARITY_ORDER [ idx] , word_similarity. word)
514
572
}
515
573
516
- let result = embeddings. word_similarity ( "Berlin" , 10 ) ;
574
+ let result = embeddings. word_similarity ( "Berlin" , 10 , None ) ;
517
575
assert ! ( result. is_some( ) ) ;
518
576
let result = result. unwrap ( ) ;
519
577
assert_eq ! ( 10 , result. len( ) ) ;
520
578
521
579
for ( idx, word_similarity) in result. iter ( ) . enumerate ( ) {
522
580
assert_eq ! ( SIMILARITY_ORDER [ idx] , word_similarity. word)
523
581
}
582
+
583
+ let result = embeddings. word_similarity ( "Berlin" , 40 , Some ( 17 ) ) ;
584
+ assert ! ( result. is_some( ) ) ;
585
+ let result = result. unwrap ( ) ;
586
+ assert_eq ! ( 40 , result. len( ) ) ;
587
+
588
+ for ( idx, word_similarity) in result. iter ( ) . enumerate ( ) {
589
+ assert_eq ! ( SIMILARITY_ORDER [ idx] , word_similarity. word)
590
+ }
524
591
}
525
592
526
593
#[ test]
@@ -529,7 +596,7 @@ mod tests {
529
596
let mut reader = BufReader :: new ( f) ;
530
597
let embeddings = Embeddings :: read_word2vec_binary ( & mut reader) . unwrap ( ) ;
531
598
let embedding = embeddings. embedding ( "Berlin" ) . unwrap ( ) ;
532
- let result = embeddings. embedding_similarity ( embedding. view ( ) , 10 ) ;
599
+ let result = embeddings. embedding_similarity ( embedding. view ( ) , 10 , None ) ;
533
600
assert ! ( result. is_some( ) ) ;
534
601
let mut result = result. unwrap ( ) . into_iter ( ) ;
535
602
assert_eq ! ( 10 , result. len( ) ) ;
@@ -546,7 +613,7 @@ mod tests {
546
613
let mut reader = BufReader :: new ( f) ;
547
614
let embeddings = Embeddings :: read_word2vec_binary ( & mut reader) . unwrap ( ) ;
548
615
549
- let result = embeddings. word_similarity ( "Stuttgart" , 10 ) ;
616
+ let result = embeddings. word_similarity ( "Stuttgart" , 10 , None ) ;
550
617
assert ! ( result. is_some( ) ) ;
551
618
let result = result. unwrap ( ) ;
552
619
assert_eq ! ( 10 , result. len( ) ) ;
@@ -562,7 +629,16 @@ mod tests {
562
629
let mut reader = BufReader :: new ( f) ;
563
630
let embeddings = Embeddings :: read_word2vec_binary ( & mut reader) . unwrap ( ) ;
564
631
565
- let result = embeddings. analogy ( [ "Paris" , "Frankreich" , "Berlin" ] , 40 ) ;
632
+ let result = embeddings. analogy ( [ "Paris" , "Frankreich" , "Berlin" ] , 40 , None ) ;
633
+ assert ! ( result. is_ok( ) ) ;
634
+ let result = result. unwrap ( ) ;
635
+ assert_eq ! ( 40 , result. len( ) ) ;
636
+
637
+ for ( idx, word_similarity) in result. iter ( ) . enumerate ( ) {
638
+ assert_eq ! ( ANALOGY_ORDER [ idx] , word_similarity. word)
639
+ }
640
+
641
+ let result = embeddings. analogy ( [ "Paris" , "Frankreich" , "Berlin" ] , 40 , Some ( 17 ) ) ;
566
642
assert ! ( result. is_ok( ) ) ;
567
643
let result = result. unwrap ( ) ;
568
644
assert_eq ! ( 40 , result. len( ) ) ;
@@ -579,15 +655,15 @@ mod tests {
579
655
let embeddings = Embeddings :: read_word2vec_binary ( & mut reader) . unwrap ( ) ;
580
656
581
657
assert_eq ! (
582
- embeddings. analogy( [ "Foo" , "Frankreich" , "Berlin" ] , 40 ) ,
658
+ embeddings. analogy( [ "Foo" , "Frankreich" , "Berlin" ] , 40 , None ) ,
583
659
Err ( [ false , true , true ] )
584
660
) ;
585
661
assert_eq ! (
586
- embeddings. analogy( [ "Paris" , "Foo" , "Berlin" ] , 40 ) ,
662
+ embeddings. analogy( [ "Paris" , "Foo" , "Berlin" ] , 40 , None ) ,
587
663
Err ( [ true , false , true ] )
588
664
) ;
589
665
assert_eq ! (
590
- embeddings. analogy( [ "Paris" , "Frankreich" , "Foo" ] , 40 ) ,
666
+ embeddings. analogy( [ "Paris" , "Frankreich" , "Foo" ] , 40 , None ) ,
591
667
Err ( [ true , true , false ] )
592
668
) ;
593
669
}
0 commit comments