Commit 56de542

Daniël de Kok (danieldk) authored and committed

Add a batch size option to similarity and analogy queries

This allows the user to trade off efficiency for lower memory use.

Parent: 7e66f91

2 files changed: +116 additions, −40 deletions
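Before the per-file diff, here is a minimal caller-side sketch of the new parameter. It mirrors the calls in the updated tests further down; the fixture path is hypothetical, and imports of `Embeddings`/`WordSimilarity` are elided since their module paths do not appear in this diff:

```rust
use std::fs::File;
use std::io::BufReader;

// Hypothetical fixture path; `Embeddings::read_word2vec_binary` and
// `word_similarity` are used exactly as in the tests below.
let f = File::open("testdata/similarity.bin").unwrap();
let mut reader = BufReader::new(f);
let embeddings = Embeddings::read_word2vec_binary(&mut reader).unwrap();

// `None` keeps the old behaviour: one dot product over the whole matrix.
let unbatched = embeddings.word_similarity("Berlin", 10, None).unwrap();

// `Some(n)` scores `n` embeddings at a time, shrinking the temporary
// similarity buffer at some cost in throughput.
let batched = embeddings.word_similarity("Berlin", 10, Some(1024)).unwrap();

// Batching changes how the work is chunked, not which results come back.
assert_eq!(unbatched.len(), batched.len());
```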

src/compat/fasttext/io.rs (2 additions, 2 deletions)

```diff
@@ -620,7 +620,7 @@ mod tests {
     #[test]
     fn test_read_fasttext() {
         let embeddings = read_fasttext();
-        let results = embeddings.word_similarity("über", 3).unwrap();
+        let results = embeddings.word_similarity("über", 3, None).unwrap();
         assert_eq!(results[0].word(), "auf");
         assert_abs_diff_eq!(results[0].cosine_similarity(), 0.568513, epsilon = 1e-6);
         assert_eq!(results[1].word(), "vor");
@@ -632,7 +632,7 @@ mod tests {
     #[test]
     fn test_read_fasttext_unknown() {
         let embeddings = read_fasttext();
-        let results = embeddings.word_similarity("unknown", 3).unwrap();
+        let results = embeddings.word_similarity("unknown", 3, None).unwrap();
         assert_eq!(results[0].word(), "einer");
         assert_abs_diff_eq!(results[0].cosine_similarity(), 0.691177, epsilon = 1e-6);
         assert_eq!(results[1].word(), "und");
```

src/similarity.rs (114 additions, 38 deletions)

```diff
@@ -4,7 +4,7 @@ use std::cmp::Ordering;
 use std::collections::{BinaryHeap, HashSet};
 use std::f32;
 
-use ndarray::{s, ArrayView1, CowArray, Ix1};
+use ndarray::{s, ArrayView1, Axis, CowArray, Ix1};
 use ordered_float::NotNan;
 
 use crate::chunks::storage::{Storage, StorageView};
```
```diff
@@ -82,12 +82,20 @@ pub trait Analogy {
     /// At most, `limit` results are returned. `Result::Err` is returned
     /// when no embedding could be computed for one or more of the tokens,
     /// indicating which of the tokens were present.
+    ///
+    /// If `batch_size` is `None`, the query will be performed on all
+    /// word embeddings at once. This is typically the most efficient, but
+    /// can require a large amount of memory. The query is performed on batches
+    /// of size `n` when `batch_size` is `Some(n)`. Setting this to a smaller
+    /// value than the number of word embeddings reduces memory use at the
+    /// cost of computational efficiency.
     fn analogy(
         &self,
         query: [&str; 3],
         limit: usize,
+        batch_size: Option<usize>,
     ) -> Result<Vec<WordSimilarityResult>, [bool; 3]> {
-        self.analogy_masked(query, [true, true, true], limit)
+        self.analogy_masked(query, [true, true, true], limit, batch_size)
     }
 
     /// Perform an analogy query.
```
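The default `analogy` above forwards to `analogy_masked` with a `[true, true, true]` mask, so none of the three query words can come back as an answer. A short usage sketch with the same query triple the tests below use (`embeddings` as in the sketch at the top):

```rust
// "Paris" is to "Frankreich" as "Berlin" is to ...?
// All three query words are excluded from the candidates by the
// [true, true, true] mask that `analogy` passes along.
let answers = embeddings
    .analogy(["Paris", "Frankreich", "Berlin"], 5, None)
    .unwrap();

// Batched variant; the results are identical, only peak memory differs.
let answers_batched = embeddings
    .analogy(["Paris", "Frankreich", "Berlin"], 5, Some(16_384))
    .unwrap();
```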
```diff
@@ -104,6 +112,13 @@ pub trait Analogy {
     /// output candidates. If `remove[0]` is `true`, `word1` cannot be
     /// returned as an answer to the query.
     ///
+    /// If `batch_size` is `None`, the query will be performed on all
+    /// word embeddings at once. This is typically the most efficient, but
+    /// can require a large amount of memory. The query is performed on batches
+    /// of size `n` when `batch_size` is `Some(n)`. Setting this to a smaller
+    /// value than the number of word embeddings reduces memory use at the
+    /// cost of computational efficiency.
+    ///
     ///`Result::Err` is returned when no embedding could be computed
     /// for one or more of the tokens, indicating which of the tokens
     /// were present.
@@ -112,6 +127,7 @@ pub trait Analogy {
         query: [&str; 3],
         remove: [bool; 3],
         limit: usize,
+        batch_size: Option<usize>,
     ) -> Result<Vec<WordSimilarityResult>, [bool; 3]>;
 }
 
@@ -125,6 +141,7 @@ where
         query: [&str; 3],
         remove: [bool; 3],
         limit: usize,
+        batch_size: Option<usize>,
     ) -> Result<Vec<WordSimilarityResult>, [bool; 3]> {
         {
             let [embedding1, embedding2, embedding3] = lookup_words3(self, query)?;
@@ -139,7 +156,7 @@ where
                 .map(|(word, _)| word.to_owned())
                 .collect();
 
-            Ok(self.similarity_(embedding.view(), &skip, limit))
+            Ok(self.similarity_(embedding.view(), &skip, limit, batch_size))
         }
     }
 }
```
```diff
@@ -152,20 +169,37 @@ pub trait WordSimilarity {
     /// the embeddings. If the vectors are unit vectors (e.g. by virtue of
     /// calling `normalize`), this is the cosine similarity. At most, `limit`
     /// results are returned.
-    fn word_similarity(&self, word: &str, limit: usize) -> Option<Vec<WordSimilarityResult>>;
+    ///
+    /// If `batch_size` is `None`, the query will be performed on all
+    /// word embeddings at once. This is typically the most efficient, but
+    /// can require a large amount of memory. The query is performed on batches
+    /// of size `n` when `batch_size` is `Some(n)`. Setting this to a smaller
+    /// value than the number of word embeddings reduces memory use at the
+    /// cost of computational efficiency.
+    fn word_similarity(
+        &self,
+        word: &str,
+        limit: usize,
+        batch_size: Option<usize>,
+    ) -> Option<Vec<WordSimilarityResult>>;
 }
 
 impl<V, S> WordSimilarity for Embeddings<V, S>
 where
     V: Vocab,
     S: StorageView,
 {
-    fn word_similarity(&self, word: &str, limit: usize) -> Option<Vec<WordSimilarityResult>> {
+    fn word_similarity(
+        &self,
+        word: &str,
+        limit: usize,
+        batch_size: Option<usize>,
+    ) -> Option<Vec<WordSimilarityResult>> {
         let embed = self.embedding(word)?;
         let mut skip = HashSet::new();
         skip.insert(word);
 
-        Some(self.similarity_(embed.view(), &skip, limit))
+        Some(self.similarity_(embed.view(), &skip, limit, batch_size))
     }
 }
 
@@ -177,12 +211,20 @@ pub trait EmbeddingSimilarity {
     /// defined by the dot product of the embeddings. The embeddings in the
     /// storage are l2-normalized, this method l2-normalizes the input query,
     /// therefore the dot product is equivalent to the cosine similarity.
+    ///
+    /// If `batch_size` is `None`, the query will be performed on all
+    /// word embeddings at once. This is typically the most efficient, but
+    /// can require a large amount of memory. The query is performed on batches
+    /// of size `n` when `batch_size` is `Some(n)`. Setting this to a smaller
+    /// value than the number of word embeddings reduces memory use at the
+    /// cost of computational efficiency.
     fn embedding_similarity(
         &self,
         query: ArrayView1<f32>,
         limit: usize,
+        batch_size: Option<usize>,
     ) -> Option<Vec<WordSimilarityResult>> {
-        self.embedding_similarity_masked(query, limit, &HashSet::new())
+        self.embedding_similarity_masked(query, limit, &HashSet::new(), batch_size)
     }
 
     /// Find words that are similar to the query embedding while skipping
@@ -192,11 +234,19 @@ pub trait EmbeddingSimilarity {
     /// defined by the dot product of the embeddings. The embeddings in the
     /// storage are l2-normalized, this method l2-normalizes the input query,
     /// therefore the dot product is equivalent to the cosine similarity.
+    ///
+    /// If `batch_size` is `None`, the query will be performed on all
+    /// word embeddings at once. This is typically the most efficient, but
+    /// can require a large amount of memory. The query is performed on batches
+    /// of size `n` when `batch_size` is `Some(n)`. Setting this to a smaller
+    /// value than the number of word embeddings reduces memory use at the
+    /// cost of computational efficiency.
     fn embedding_similarity_masked(
         &self,
         query: ArrayView1<f32>,
         limit: usize,
         skips: &HashSet<&str>,
+        batch_size: Option<usize>,
     ) -> Option<Vec<WordSimilarityResult>>;
 }
 
```
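The masked variant is what `embedding_similarity` delegates to. A hedged sketch of calling it directly, reusing `embeddings` from the sketch at the top; the skip set plays the same role as the one `word_similarity` builds internally for its own query word:

```rust
use std::collections::HashSet;

// Query with a raw vector rather than a word, excluding chosen words.
let query = embeddings.embedding("Berlin").unwrap();

let mut skips = HashSet::new();
skips.insert("Berlin");

// Score 65,536 embeddings at a time instead of the whole matrix at once.
let neighbours = embeddings
    .embedding_similarity_masked(query.view(), 10, &skips, Some(65_536))
    .unwrap();
```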
```diff
@@ -210,10 +260,11 @@ where
         query: ArrayView1<f32>,
         limit: usize,
         skip: &HashSet<&str>,
+        batch_size: Option<usize>,
     ) -> Option<Vec<WordSimilarityResult>> {
         let mut query = query.to_owned();
         l2_normalize(query.view_mut());
-        Some(self.similarity_(query.view(), skip, limit))
+        Some(self.similarity_(query.view(), skip, limit, batch_size))
     }
 }
```
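A one-line justification for why normalizing only the query suffices in this impl: the doc comment above states the stored embeddings are already l2-normalized, so for a stored row $w$ with $\lVert w \rVert = 1$,

$$\cos\theta \;=\; \frac{q \cdot w}{\lVert q \rVert\,\lVert w \rVert} \;=\; \frac{q}{\lVert q \rVert} \cdot w \;=\; \hat{q} \cdot w,$$

which is exactly the dot product the batched `similarity_` below computes.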
```diff
@@ -223,6 +274,7 @@ trait SimilarityPrivate {
         embed: ArrayView1<f32>,
         skip: &HashSet<&str>,
         limit: usize,
+        batch_size: Option<usize>,
     ) -> Vec<WordSimilarityResult>;
 }
 
@@ -236,35 +288,41 @@ where
         embed: ArrayView1<f32>,
         skip: &HashSet<&str>,
         limit: usize,
+        batch_size: Option<usize>,
     ) -> Vec<WordSimilarityResult> {
-        // ndarray#474
-        #[allow(clippy::deref_addrof)]
-        let sims = self
+        let batch_size = batch_size.unwrap_or_else(|| self.vocab().words_len());
+
+        let mut results = BinaryHeap::with_capacity(limit);
+
+        for (batch_idx, batch) in self
             .storage()
             .view()
             .slice(s![0..self.vocab().words_len(), ..])
-            .dot(&embed.view());
+            .axis_chunks_iter(Axis(0), batch_size)
+            .enumerate()
+        {
+            let sims = batch.dot(&embed.view());
 
-        let mut results = BinaryHeap::with_capacity(limit);
-        for (idx, &sim) in sims.iter().enumerate() {
-            let word = &self.vocab().words()[idx];
+            for (idx, &sim) in sims.iter().enumerate() {
+                let word = &self.vocab().words()[(batch_idx * batch_size) + idx];
 
-            // Don't add words that we are explicitly asked to skip.
-            if skip.contains(word.as_str()) {
-                continue;
-            }
+                // Don't add words that we are explicitly asked to skip.
+                if skip.contains(word.as_str()) {
+                    continue;
+                }
 
-            let word_similarity = WordSimilarityResult {
-                word,
-                similarity: NotNan::new(sim).expect("Encountered NaN"),
-            };
-
-            if results.len() < limit {
-                results.push(word_similarity);
-            } else {
-                let mut peek = results.peek_mut().expect("Cannot peek non-empty heap");
-                if word_similarity < *peek {
-                    *peek = word_similarity
+                let word_similarity = WordSimilarityResult {
+                    word,
+                    similarity: NotNan::new(sim).expect("Encountered NaN"),
+                };
+
+                if results.len() < limit {
+                    results.push(word_similarity);
+                } else {
+                    let mut peek = results.peek_mut().expect("Cannot peek non-empty heap");
+                    if word_similarity < *peek {
+                        *peek = word_similarity
+                    }
                 }
             }
         }
```
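This hunk is the heart of the commit: the single matrix–vector product over all `words_len()` rows becomes a loop over `axis_chunks_iter(Axis(0), batch_size)` chunks, so the temporary `sims` vector holds at most `batch_size` floats instead of one per vocabulary word (roughly 8 MB of f32 scores per query for a 2-million-word vocabulary unbatched, versus 256 KB at `Some(65_536)`). The bounded `BinaryHeap` then carries the running top `limit` across batches. Below is a self-contained sketch of the same pattern, using plain `ndarray` types and an explicit `Reverse` min-heap rather than the crate's own storage/vocab abstractions and result ordering:

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;

use ndarray::{Array1, Array2, Axis};
use ordered_float::NotNan;

/// Indices of the `limit` rows of `matrix` with the highest dot product
/// against `query`, scored `batch_size` rows at a time. The intermediate
/// similarity buffer never exceeds `batch_size` elements.
fn batched_top_k(
    matrix: &Array2<f32>,
    query: &Array1<f32>,
    limit: usize,
    batch_size: usize,
) -> Vec<usize> {
    // `Reverse` turns the max-heap into a min-heap, so the root is always
    // the weakest of the current top-`limit` candidates.
    let mut heap: BinaryHeap<Reverse<(NotNan<f32>, usize)>> =
        BinaryHeap::with_capacity(limit);

    for (batch_idx, batch) in matrix.axis_chunks_iter(Axis(0), batch_size).enumerate() {
        // One f32 score per row in this batch only.
        let sims = batch.dot(query);

        for (idx, &sim) in sims.iter().enumerate() {
            // Recover the global row index, as the commit does with
            // `(batch_idx * batch_size) + idx`.
            let row = batch_idx * batch_size + idx;
            let entry = Reverse((NotNan::new(sim).expect("NaN similarity"), row));

            if heap.len() < limit {
                heap.push(entry);
            } else {
                let mut weakest = heap.peek_mut().expect("non-empty heap");
                if entry < *weakest {
                    *weakest = entry; // evict the weakest candidate
                }
            }
        }
    }

    // Ascending order of `Reverse` entries is descending similarity.
    heap.into_sorted_vec()
        .into_iter()
        .map(|Reverse((_, row))| row)
        .collect()
}
```

Because the heap is carried across batch boundaries, the final top-`limit` set does not depend on `batch_size`, which is what the `Some(17)` assertions in the tests below rely on.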
```diff
@@ -504,7 +562,7 @@ mod tests {
         let mut reader = BufReader::new(f);
         let embeddings = Embeddings::read_word2vec_binary(&mut reader).unwrap();
 
-        let result = embeddings.word_similarity("Berlin", 40);
+        let result = embeddings.word_similarity("Berlin", 40, None);
         assert!(result.is_some());
         let result = result.unwrap();
         assert_eq!(40, result.len());
@@ -513,14 +571,23 @@ mod tests {
             assert_eq!(SIMILARITY_ORDER[idx], word_similarity.word)
         }
 
-        let result = embeddings.word_similarity("Berlin", 10);
+        let result = embeddings.word_similarity("Berlin", 10, None);
         assert!(result.is_some());
         let result = result.unwrap();
         assert_eq!(10, result.len());
 
         for (idx, word_similarity) in result.iter().enumerate() {
             assert_eq!(SIMILARITY_ORDER[idx], word_similarity.word)
         }
+
+        let result = embeddings.word_similarity("Berlin", 40, Some(17));
+        assert!(result.is_some());
+        let result = result.unwrap();
+        assert_eq!(40, result.len());
+
+        for (idx, word_similarity) in result.iter().enumerate() {
+            assert_eq!(SIMILARITY_ORDER[idx], word_similarity.word)
+        }
     }
 
     #[test]
@@ -529,7 +596,7 @@ mod tests {
         let mut reader = BufReader::new(f);
         let embeddings = Embeddings::read_word2vec_binary(&mut reader).unwrap();
         let embedding = embeddings.embedding("Berlin").unwrap();
-        let result = embeddings.embedding_similarity(embedding.view(), 10);
+        let result = embeddings.embedding_similarity(embedding.view(), 10, None);
         assert!(result.is_some());
         let mut result = result.unwrap().into_iter();
         assert_eq!(10, result.len());
@@ -546,7 +613,7 @@ mod tests {
         let mut reader = BufReader::new(f);
         let embeddings = Embeddings::read_word2vec_binary(&mut reader).unwrap();
 
-        let result = embeddings.word_similarity("Stuttgart", 10);
+        let result = embeddings.word_similarity("Stuttgart", 10, None);
         assert!(result.is_some());
         let result = result.unwrap();
         assert_eq!(10, result.len());
@@ -562,7 +629,16 @@ mod tests {
         let mut reader = BufReader::new(f);
         let embeddings = Embeddings::read_word2vec_binary(&mut reader).unwrap();
 
-        let result = embeddings.analogy(["Paris", "Frankreich", "Berlin"], 40);
+        let result = embeddings.analogy(["Paris", "Frankreich", "Berlin"], 40, None);
+        assert!(result.is_ok());
+        let result = result.unwrap();
+        assert_eq!(40, result.len());
+
+        for (idx, word_similarity) in result.iter().enumerate() {
+            assert_eq!(ANALOGY_ORDER[idx], word_similarity.word)
+        }
+
+        let result = embeddings.analogy(["Paris", "Frankreich", "Berlin"], 40, Some(17));
         assert!(result.is_ok());
         let result = result.unwrap();
         assert_eq!(40, result.len());
@@ -579,15 +655,15 @@ mod tests {
         let embeddings = Embeddings::read_word2vec_binary(&mut reader).unwrap();
 
         assert_eq!(
-            embeddings.analogy(["Foo", "Frankreich", "Berlin"], 40),
+            embeddings.analogy(["Foo", "Frankreich", "Berlin"], 40, None),
             Err([false, true, true])
         );
         assert_eq!(
-            embeddings.analogy(["Paris", "Foo", "Berlin"], 40),
+            embeddings.analogy(["Paris", "Foo", "Berlin"], 40, None),
             Err([true, false, true])
        );
         assert_eq!(
-            embeddings.analogy(["Paris", "Frankreich", "Foo"], 40),
+            embeddings.analogy(["Paris", "Frankreich", "Foo"], 40, None),
             Err([true, true, false])
         );
     }
```
