@@ -4,7 +4,7 @@ use std::cmp::Ordering;
 use std::collections::{BinaryHeap, HashSet};
 use std::f32;
 
-use ndarray::{s, ArrayView1, CowArray, Ix1};
+use ndarray::{s, ArrayView1, Axis, CowArray, Ix1};
 use ordered_float::NotNan;
 
 use crate::chunks::storage::{Storage, StorageView};
@@ -82,12 +82,20 @@ pub trait Analogy {
     /// At most, `limit` results are returned. `Result::Err` is returned
     /// when no embedding could be computed for one or more of the tokens,
     /// indicating which of the tokens were present.
+    ///
+    /// If `batch_size` is `None`, the query will be performed on all
+    /// word embeddings at once. This is typically the most efficient, but
+    /// can require a large amount of memory. The query is performed on batches
+    /// of size `n` when `batch_size` is `Some(n)`. Setting this to a smaller
+    /// value than the number of word embeddings reduces memory use at the
+    /// cost of computational efficiency.
     fn analogy(
         &self,
         query: [&str; 3],
         limit: usize,
+        batch_size: Option<usize>,
     ) -> Result<Vec<WordSimilarityResult>, [bool; 3]> {
-        self.analogy_masked(query, [true, true, true], limit)
+        self.analogy_masked(query, [true, true, true], limit, batch_size)
     }
 
     /// Perform an analogy query.
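The `batch_size` parameter threads through every public query below. A minimal sketch of the updated call sites; the prelude import and the `VocabWrap`/`StorageViewWrap` type parameters are assumptions for illustration, not part of this diff:

```rust
use finalfusion::prelude::*;
use finalfusion::similarity::Analogy;

// `embeddings` is loaded elsewhere, e.g. via read_word2vec_binary.
fn capital_analogy(embeddings: &Embeddings<VocabWrap, StorageViewWrap>) {
    // Score all word embeddings in one pass: fastest, highest peak memory.
    let unbatched = embeddings.analogy(["Paris", "Frankreich", "Berlin"], 40, None);

    // Score 8192 rows at a time: identical results, bounded memory.
    let batched = embeddings.analogy(["Paris", "Frankreich", "Berlin"], 40, Some(8192));

    assert_eq!(unbatched.map(|r| r.len()), batched.map(|r| r.len()));
}
```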
@@ -104,6 +112,13 @@ pub trait Analogy {
     /// output candidates. If `remove[0]` is `true`, `word1` cannot be
     /// returned as an answer to the query.
     ///
+    /// If `batch_size` is `None`, the query will be performed on all
+    /// word embeddings at once. This is typically the most efficient, but
+    /// can require a large amount of memory. The query is performed on batches
+    /// of size `n` when `batch_size` is `Some(n)`. Setting this to a smaller
+    /// value than the number of word embeddings reduces memory use at the
+    /// cost of computational efficiency.
+    ///
     /// `Result::Err` is returned when no embedding could be computed
     /// for one or more of the tokens, indicating which of the tokens
     /// were present.
@@ -112,6 +127,7 @@ pub trait Analogy {
         query: [&str; 3],
         remove: [bool; 3],
         limit: usize,
+        batch_size: Option<usize>,
     ) -> Result<Vec<WordSimilarityResult>, [bool; 3]>;
 }
 
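For callers that want one of the query words back as a candidate, the `remove` mask above is the knob. A hedged usage sketch against this signature (the trait path is assumed):

```rust
use finalfusion::similarity::Analogy;

// Let the third query word ("Berlin") remain eligible as an answer.
fn analogy_keep_word3(embeddings: &impl Analogy) {
    match embeddings.analogy_masked(
        ["Paris", "Frankreich", "Berlin"],
        [true, true, false], // remove word1 and word2, keep word3
        10,
        Some(4096), // score 4096 embeddings per batch
    ) {
        Ok(results) => println!("{} answers", results.len()),
        Err(present) => eprintln!("tokens present: {:?}", present),
    }
}
```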
@@ -125,6 +141,7 @@ where
         query: [&str; 3],
         remove: [bool; 3],
         limit: usize,
+        batch_size: Option<usize>,
     ) -> Result<Vec<WordSimilarityResult>, [bool; 3]> {
         {
             let [embedding1, embedding2, embedding3] = lookup_words3(self, query)?;
@@ -139,7 +156,7 @@ where
                 .map(|(word, _)| word.to_owned())
                 .collect();
 
-            Ok(self.similarity_(embedding.view(), &skip, limit))
+            Ok(self.similarity_(embedding.view(), &skip, limit, batch_size))
         }
     }
 }
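The `skip` set built above holds exactly the query words whose `remove` flag is `true`. A standalone illustration of that zip-filter chain, using plain `&str`s in place of the crate's owned lookup results:

```rust
use std::collections::HashSet;

fn main() {
    let query = ["Paris", "Frankreich", "Berlin"];
    let remove = [true, true, false];

    // Mirrors the iterator chain in analogy_masked: keep only the
    // query words that are masked out of the results.
    let skip: HashSet<&str> = query
        .iter()
        .zip(remove.iter())
        .filter(|(_, &exclude)| exclude)
        .map(|(word, _)| *word)
        .collect();

    assert!(skip.contains("Paris") && skip.contains("Frankreich"));
    assert!(!skip.contains("Berlin"));
}
```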
@@ -152,20 +169,37 @@ pub trait WordSimilarity {
     /// the embeddings. If the vectors are unit vectors (e.g. by virtue of
     /// calling `normalize`), this is the cosine similarity. At most, `limit`
     /// results are returned.
-    fn word_similarity(&self, word: &str, limit: usize) -> Option<Vec<WordSimilarityResult>>;
+    ///
+    /// If `batch_size` is `None`, the query will be performed on all
+    /// word embeddings at once. This is typically the most efficient, but
+    /// can require a large amount of memory. The query is performed on batches
+    /// of size `n` when `batch_size` is `Some(n)`. Setting this to a smaller
+    /// value than the number of word embeddings reduces memory use at the
+    /// cost of computational efficiency.
+    fn word_similarity(
+        &self,
+        word: &str,
+        limit: usize,
+        batch_size: Option<usize>,
+    ) -> Option<Vec<WordSimilarityResult>>;
 }
 
 impl<V, S> WordSimilarity for Embeddings<V, S>
 where
     V: Vocab,
     S: StorageView,
 {
-    fn word_similarity(&self, word: &str, limit: usize) -> Option<Vec<WordSimilarityResult>> {
+    fn word_similarity(
+        &self,
+        word: &str,
+        limit: usize,
+        batch_size: Option<usize>,
+    ) -> Option<Vec<WordSimilarityResult>> {
         let embed = self.embedding(word)?;
         let mut skip = HashSet::new();
         skip.insert(word);
 
-        Some(self.similarity_(embed.view(), &skip, limit))
+        Some(self.similarity_(embed.view(), &skip, limit, batch_size))
     }
 }
 
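Batching only trades memory for throughput; the neighbours returned are identical, which the tests below also check with `Some(17)`. A sketch of a call site (trait path assumed):

```rust
use finalfusion::similarity::WordSimilarity;

fn neighbours_of_berlin(embeddings: &impl WordSimilarity) {
    // None: one pass over the whole embedding matrix.
    let full = embeddings.word_similarity("Berlin", 10, None);
    // Some(1024): the same ten neighbours, computed 1024 rows at a time.
    let batched = embeddings.word_similarity("Berlin", 10, Some(1024));

    assert_eq!(full.map(|r| r.len()), batched.map(|r| r.len()));
}
```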
@@ -177,12 +211,20 @@ pub trait EmbeddingSimilarity {
     /// defined by the dot product of the embeddings. The embeddings in the
     /// storage are l2-normalized; this method l2-normalizes the input query,
     /// so the dot product is equivalent to the cosine similarity.
+    ///
+    /// If `batch_size` is `None`, the query will be performed on all
+    /// word embeddings at once. This is typically the most efficient, but
+    /// can require a large amount of memory. The query is performed on batches
+    /// of size `n` when `batch_size` is `Some(n)`. Setting this to a smaller
+    /// value than the number of word embeddings reduces memory use at the
+    /// cost of computational efficiency.
     fn embedding_similarity(
         &self,
         query: ArrayView1<f32>,
         limit: usize,
+        batch_size: Option<usize>,
     ) -> Option<Vec<WordSimilarityResult>> {
-        self.embedding_similarity_masked(query, limit, &HashSet::new())
+        self.embedding_similarity_masked(query, limit, &HashSet::new(), batch_size)
     }
 
     /// Find words that are similar to the query embedding while skipping
@@ -192,11 +234,19 @@ pub trait EmbeddingSimilarity {
     /// defined by the dot product of the embeddings. The embeddings in the
     /// storage are l2-normalized; this method l2-normalizes the input query,
     /// so the dot product is equivalent to the cosine similarity.
+    ///
+    /// If `batch_size` is `None`, the query will be performed on all
+    /// word embeddings at once. This is typically the most efficient, but
+    /// can require a large amount of memory. The query is performed on batches
+    /// of size `n` when `batch_size` is `Some(n)`. Setting this to a smaller
+    /// value than the number of word embeddings reduces memory use at the
+    /// cost of computational efficiency.
     fn embedding_similarity_masked(
         &self,
         query: ArrayView1<f32>,
         limit: usize,
         skips: &HashSet<&str>,
+        batch_size: Option<usize>,
     ) -> Option<Vec<WordSimilarityResult>>;
 }
 
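A sketch of masking in practice, skipping the word a query embedding was looked up for so it cannot surface as its own nearest neighbour (trait path and names assumed):

```rust
use std::collections::HashSet;

use finalfusion::similarity::EmbeddingSimilarity;
use ndarray::ArrayView1;

fn neighbours(embeddings: &impl EmbeddingSimilarity, query: ArrayView1<f32>) {
    // Skip "Berlin", the word this query embedding was looked up for.
    let skips: HashSet<&str> = ["Berlin"].iter().copied().collect();

    let result = embeddings.embedding_similarity_masked(query, 10, &skips, Some(2048));
    println!("found {} neighbours", result.map_or(0, |r| r.len()));
}
```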
@@ -210,10 +260,11 @@ where
         query: ArrayView1<f32>,
         limit: usize,
         skip: &HashSet<&str>,
+        batch_size: Option<usize>,
     ) -> Option<Vec<WordSimilarityResult>> {
         let mut query = query.to_owned();
         l2_normalize(query.view_mut());
-        Some(self.similarity_(query.view(), skip, limit))
+        Some(self.similarity_(query.view(), skip, limit, batch_size))
     }
 }
 
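Why normalizing the query first suffices: for unit vectors the dot product equals the cosine similarity. A self-contained illustration with a local stand-in for the crate's private `l2_normalize` helper:

```rust
use ndarray::{array, Array1};

// Stand-in for the crate's private l2_normalize: scale to unit length.
fn l2_normalize(mut v: Array1<f32>) -> Array1<f32> {
    let norm = v.dot(&v).sqrt();
    if norm > 0.0 {
        v /= norm;
    }
    v
}

fn main() {
    let a = l2_normalize(array![3.0, 4.0]);
    let b = l2_normalize(array![4.0, 3.0]);

    // cos(a, b) = a.b / (|a| |b|) = a.b for unit vectors.
    let cosine = a.dot(&b);
    assert!((cosine - 0.96).abs() < 1e-6);
}
```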
@@ -223,6 +274,7 @@ trait SimilarityPrivate {
         embed: ArrayView1<f32>,
         skip: &HashSet<&str>,
         limit: usize,
+        batch_size: Option<usize>,
     ) -> Vec<WordSimilarityResult>;
 }
 
@@ -236,35 +288,41 @@ where
         embed: ArrayView1<f32>,
         skip: &HashSet<&str>,
         limit: usize,
+        batch_size: Option<usize>,
     ) -> Vec<WordSimilarityResult> {
-        // ndarray#474
-        #[allow(clippy::deref_addrof)]
-        let sims = self
+        let batch_size = batch_size.unwrap_or_else(|| self.vocab().words_len());
+
+        let mut results = BinaryHeap::with_capacity(limit);
+
+        for (batch_idx, batch) in self
             .storage()
             .view()
             .slice(s![0..self.vocab().words_len(), ..])
-            .dot(&embed.view());
+            .axis_chunks_iter(Axis(0), batch_size)
+            .enumerate()
+        {
+            let sims = batch.dot(&embed.view());
 
-        let mut results = BinaryHeap::with_capacity(limit);
-        for (idx, &sim) in sims.iter().enumerate() {
-            let word = &self.vocab().words()[idx];
+            for (idx, &sim) in sims.iter().enumerate() {
+                let word = &self.vocab().words()[(batch_idx * batch_size) + idx];
 
-            // Don't add words that we are explicitly asked to skip.
-            if skip.contains(word.as_str()) {
-                continue;
-            }
+                // Don't add words that we are explicitly asked to skip.
+                if skip.contains(word.as_str()) {
+                    continue;
+                }
 
-            let word_similarity = WordSimilarityResult {
-                word,
-                similarity: NotNan::new(sim).expect("Encountered NaN"),
-            };
-
-            if results.len() < limit {
-                results.push(word_similarity);
-            } else {
-                let mut peek = results.peek_mut().expect("Cannot peek non-empty heap");
-                if word_similarity < *peek {
-                    *peek = word_similarity
+                let word_similarity = WordSimilarityResult {
+                    word,
+                    similarity: NotNan::new(sim).expect("Encountered NaN"),
+                };
+
+                if results.len() < limit {
+                    results.push(word_similarity);
+                } else {
+                    let mut peek = results.peek_mut().expect("Cannot peek non-empty heap");
+                    if word_similarity < *peek {
+                        *peek = word_similarity
+                    }
                 }
             }
         }
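For reference, the batched top-`limit` pattern above in self-contained form: chunk the matrix along `Axis(0)`, score each chunk with one matrix-vector product, and keep a bounded heap whose `peek` is the worst retained hit, so `batch_idx * batch_size + idx` recovers the global row. All names here are illustrative, not the crate's API:

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;

use ndarray::{Array1, Array2, Axis};
use ordered_float::NotNan;

/// Return the `k` rows of `embeddings` most similar to `query` by dot
/// product, scoring `batch_size` rows at a time.
fn top_k(embeddings: &Array2<f32>, query: &Array1<f32>, k: usize, batch_size: usize) -> Vec<(usize, f32)> {
    // Reverse makes peek() the *smallest* retained similarity.
    let mut heap: BinaryHeap<Reverse<(NotNan<f32>, usize)>> = BinaryHeap::with_capacity(k);

    for (batch_idx, batch) in embeddings.axis_chunks_iter(Axis(0), batch_size).enumerate() {
        let sims = batch.dot(query); // one matrix-vector product per chunk

        for (idx, &sim) in sims.iter().enumerate() {
            let row = batch_idx * batch_size + idx; // global row index
            let entry = Reverse((NotNan::new(sim).expect("NaN similarity"), row));

            if heap.len() < k {
                heap.push(entry);
            } else {
                let mut worst = heap.peek_mut().expect("non-empty heap");
                // In Reverse order, "smaller" means a higher similarity.
                if entry < *worst {
                    *worst = entry;
                }
            }
        }
    }

    let mut hits: Vec<_> = heap
        .into_iter()
        .map(|Reverse((sim, row))| (row, sim.into_inner()))
        .collect();
    hits.sort_by(|a, b| b.1.total_cmp(&a.1));
    hits
}
```

When `batch_size` equals the number of rows, the loop runs once and degenerates to the original single `dot`, which is what the `unwrap_or_else(|| self.vocab().words_len())` default arranges.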
@@ -504,7 +562,7 @@ mod tests {
         let mut reader = BufReader::new(f);
         let embeddings = Embeddings::read_word2vec_binary(&mut reader).unwrap();
 
-        let result = embeddings.word_similarity("Berlin", 40);
+        let result = embeddings.word_similarity("Berlin", 40, None);
         assert!(result.is_some());
         let result = result.unwrap();
         assert_eq!(40, result.len());
@@ -513,14 +571,23 @@ mod tests {
             assert_eq!(SIMILARITY_ORDER[idx], word_similarity.word)
         }
 
-        let result = embeddings.word_similarity("Berlin", 10);
+        let result = embeddings.word_similarity("Berlin", 10, None);
         assert!(result.is_some());
         let result = result.unwrap();
         assert_eq!(10, result.len());
 
         for (idx, word_similarity) in result.iter().enumerate() {
             assert_eq!(SIMILARITY_ORDER[idx], word_similarity.word)
         }
+
+        let result = embeddings.word_similarity("Berlin", 40, Some(17));
+        assert!(result.is_some());
+        let result = result.unwrap();
+        assert_eq!(40, result.len());
+
+        for (idx, word_similarity) in result.iter().enumerate() {
+            assert_eq!(SIMILARITY_ORDER[idx], word_similarity.word)
+        }
     }
 
     #[test]
@@ -529,7 +596,7 @@ mod tests {
         let mut reader = BufReader::new(f);
         let embeddings = Embeddings::read_word2vec_binary(&mut reader).unwrap();
         let embedding = embeddings.embedding("Berlin").unwrap();
-        let result = embeddings.embedding_similarity(embedding.view(), 10);
+        let result = embeddings.embedding_similarity(embedding.view(), 10, None);
         assert!(result.is_some());
         let mut result = result.unwrap().into_iter();
         assert_eq!(10, result.len());
@@ -546,7 +613,7 @@ mod tests {
         let mut reader = BufReader::new(f);
         let embeddings = Embeddings::read_word2vec_binary(&mut reader).unwrap();
 
-        let result = embeddings.word_similarity("Stuttgart", 10);
+        let result = embeddings.word_similarity("Stuttgart", 10, None);
         assert!(result.is_some());
         let result = result.unwrap();
         assert_eq!(10, result.len());
@@ -562,7 +629,16 @@ mod tests {
         let mut reader = BufReader::new(f);
         let embeddings = Embeddings::read_word2vec_binary(&mut reader).unwrap();
 
-        let result = embeddings.analogy(["Paris", "Frankreich", "Berlin"], 40);
+        let result = embeddings.analogy(["Paris", "Frankreich", "Berlin"], 40, None);
+        assert!(result.is_ok());
+        let result = result.unwrap();
+        assert_eq!(40, result.len());
+
+        for (idx, word_similarity) in result.iter().enumerate() {
+            assert_eq!(ANALOGY_ORDER[idx], word_similarity.word)
+        }
+
+        let result = embeddings.analogy(["Paris", "Frankreich", "Berlin"], 40, Some(17));
         assert!(result.is_ok());
         let result = result.unwrap();
         assert_eq!(40, result.len());
@@ -579,15 +655,15 @@ mod tests {
         let embeddings = Embeddings::read_word2vec_binary(&mut reader).unwrap();
 
         assert_eq!(
-            embeddings.analogy(["Foo", "Frankreich", "Berlin"], 40),
+            embeddings.analogy(["Foo", "Frankreich", "Berlin"], 40, None),
             Err([false, true, true])
         );
         assert_eq!(
-            embeddings.analogy(["Paris", "Foo", "Berlin"], 40),
+            embeddings.analogy(["Paris", "Foo", "Berlin"], 40, None),
             Err([true, false, true])
         );
         assert_eq!(
-            embeddings.analogy(["Paris", "Frankreich", "Foo"], 40),
+            embeddings.analogy(["Paris", "Frankreich", "Foo"], 40, None),
             Err([true, true, false])
         );
     }