@@ -120,6 +120,30 @@ def _hasher(self):
         return sklearn.feature_extraction.text.FeatureHasher


+def _n_samples(X):
+    """Count the number of samples in dask.array.Array X."""
+    def chunk_n_samples(chunk, axis, keepdims):
+        return np.array([chunk.shape[0]], dtype=np.int64)
+
+    return da.reduction(X,
+                        chunk=chunk_n_samples,
+                        aggregate=np.sum,
+                        concatenate=False,
+                        dtype=np.int64)
+
+
+def _n_features(X):
+    """Count the number of features in dask.array.Array X."""
+    def chunk_n_features(chunk, axis, keepdims):
+        return np.array([chunk.shape[1]], dtype=np.int64)
+
+    return da.reduction(X,
+                        chunk=chunk_n_features,
+                        aggregate=lambda x, axis, keepdims: x[0],
+                        concatenate=True,
+                        dtype=np.int64)
+
+
 def _document_frequency(X, dtype):
     """Count the number of non-zero values for each feature in dask array X."""
     def chunk_doc_freq(chunk, axis, keepdims):
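The two helpers added above follow the same `da.reduction` pattern: each chunk reports a tiny per-block array, and a cheap aggregate combines the reports without ever materializing X. A minimal sketch of the row-count case (my own illustration, not part of the patch; it assumes X is chunked only along rows, since column chunks would each report their row count and inflate the sum):

    import numpy as np
    import dask.array as da

    X = da.ones((10, 5), chunks=(4, 5))  # row-chunked: blocks of 4, 4, and 2 rows

    def chunk_n_samples(chunk, axis, keepdims):
        # each block contributes its own row count
        return np.array([chunk.shape[0]], dtype=np.int64)

    n = da.reduction(X, chunk=chunk_n_samples, aggregate=np.sum,
                     concatenate=False, dtype=np.int64)
    print(n.compute())  # 10, computed without loading X as a whole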
@@ -133,7 +157,7 @@ def chunk_doc_freq(chunk, axis, keepdims):
                         aggregate=np.sum,
                         axis=0,
                         concatenate=False,
-                        dtype=dtype).compute().astype(dtype)
+                        dtype=dtype)


 class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer):
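Dropping the trailing `.compute().astype(dtype)` means `_document_frequency` now hands back a lazy dask array instead of forcing an eager round trip inside the helper. A generic illustration of the distinction (plain dask ops, not this module's API):

    import numpy as np
    import dask.array as da

    x = da.from_array(np.array([[0, 1, 2], [3, 4, 5]]), chunks=(1, 3))
    lazy = (x != 0).sum(axis=0)  # stays in the task graph, like the new return value
    eager = lazy.compute()       # what the old trailing .compute() forced immediately
    print(type(lazy).__name__, eager)  # Array [1 2 2]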
@@ -203,17 +227,19 @@ class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer):
     ['and', 'document', 'first', 'is', 'one', 'second', 'the', 'third', 'this']
     """

-    def fit_transform(self, raw_documents, y=None):
+    def get_params(self):
         # Note that in general 'self' could refer to an instance of either this
         # class or a subclass of this class. Hence it is possible that
         # self.get_params() could get unexpected parameters of an instance of a
         # subclass. Such parameters need to be excluded here:
-        subclass_instance_params = self.get_params()
+        subclass_instance_params = super().get_params()
         excluded_keys = getattr(self, '_non_CountVectorizer_params', [])
-        params = {key: subclass_instance_params[key]
-                  for key in subclass_instance_params
-                  if key not in excluded_keys}
+        return {key: subclass_instance_params[key]
+                for key in subclass_instance_params
+                if key not in excluded_keys}

+    def fit_transform(self, raw_documents, y=None):
+        params = self.get_params()
         vocabulary = params.pop("vocabulary")
         vocabulary_for_transform = vocabulary

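Hoisting the parameter filtering into an overridden `get_params` removes the copy-paste between `fit_transform` and `transform`. A hedged sketch of the subclass pattern this supports (the subclass and its `chunk_size` parameter are hypothetical; only the `_non_CountVectorizer_params` hook comes from the patch, and the import path assumes this module is dask_ml.feature_extraction.text):

    from dask_ml.feature_extraction.text import CountVectorizer  # assumed module path

    class MyVectorizer(CountVectorizer):
        # names listed here are hidden from get_params()
        _non_CountVectorizer_params = ["chunk_size"]

        def __init__(self, lowercase=True, chunk_size=1000):
            self.chunk_size = chunk_size
            super().__init__(lowercase=lowercase)

    vec = MyVectorizer()
    params = vec.get_params()
    assert "chunk_size" not in params and "lowercase" in params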
@@ -227,12 +253,12 @@ def fit_transform(self, raw_documents, y=None):
             # Case 2: learn vocabulary from the data.
             vocabularies = raw_documents.map_partitions(_build_vocabulary, params)
             vocabulary = vocabulary_for_transform = (
-                _merge_vocabulary( * vocabularies.to_delayed() ))
+                _merge_vocabulary(*vocabularies.to_delayed()))
             vocabulary_for_transform = vocabulary_for_transform.persist()
             vocabulary_ = vocabulary.compute()
         n_features = len(vocabulary_)

-        meta = scipy.sparse.eye(0, format="csr", dtype=self.dtype)
+        meta = scipy.sparse.csr_matrix((0, n_features), dtype=self.dtype)
         if isinstance(raw_documents, dd.Series):
             result = raw_documents.map_partitions(
                 _count_vectorizer_transform, vocabulary_for_transform,
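The `meta` change is what allows the chunk-size computation below to go away: `scipy.sparse.eye(0)` is a (0, 0) matrix, so dask could not infer the column count of each block, while an empty CSR matrix built with the real `n_features` carries it. A quick check of the difference (illustration only):

    import scipy.sparse

    old_meta = scipy.sparse.eye(0, format="csr")
    new_meta = scipy.sparse.csr_matrix((0, 4), dtype="float64")
    print(old_meta.shape, new_meta.shape)  # (0, 0) vs (0, 4): column count preserved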
@@ -241,23 +267,14 @@ def fit_transform(self, raw_documents, y=None):
             result = raw_documents.map_partitions(
                 _count_vectorizer_transform, vocabulary_for_transform, params)
             result = build_array(result, n_features, meta)
-        result.compute_chunk_sizes()

         self.vocabulary_ = vocabulary_
         self.fixed_vocabulary_ = fixed_vocabulary

         return result

     def transform(self, raw_documents):
-        # Note that in general 'self' could refer to an instance of either this
-        # class or a subclass of this class. Hence it is possible that
-        # self.get_params() could get unexpected parameters of an instance of a
-        # subclass. Such parameters need to be excluded here:
-        subclass_instance_params = self.get_params()
-        excluded_keys = getattr(self, '_non_CountVectorizer_params', [])
-        params = {key: subclass_instance_params[key]
-                  for key in subclass_instance_params
-                  if key not in excluded_keys}
+        params = self.get_params()
         vocabulary = params.pop("vocabulary")

         if vocabulary is None:
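Removing `result.compute_chunk_sizes()` keeps `fit_transform` fully lazy; the trade-off is that the result's row chunk sizes stay unknown until something actually computes them. A small sketch of that trade-off (generic dask, not this codebase):

    import numpy as np
    import dask.array as da

    x = da.from_array(np.arange(10), chunks=5)
    y = x[x % 2 == 0]            # data-dependent result -> unknown chunk sizes
    print(y.chunks)              # ((nan, nan),)
    y = y.compute_chunk_sizes()  # the eager step the patch removes
    print(y.chunks)              # ((3, 2),)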
@@ -271,14 +288,13 @@ def transform(self, raw_documents):
             except ValueError:
                 vocabulary_for_transform = dask.delayed(vocabulary)
             else:
-                (vocabulary_for_transform,) = client.scatter(
-                    (vocabulary,), broadcast=True
-                )
+                (vocabulary_for_transform,) = client.scatter((vocabulary,),
+                                                             broadcast=True)
         else:
             vocabulary_for_transform = vocabulary

         n_features = vocabulary_length(vocabulary_for_transform)
-        meta = scipy.sparse.eye(0, format="csr", dtype=self.dtype)
+        meta = scipy.sparse.csr_matrix((0, n_features), dtype=self.dtype)
         if isinstance(raw_documents, dd.Series):
             result = raw_documents.map_partitions(
                 _count_vectorizer_transform, vocabulary_for_transform,
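For context, the scatter branch (only reformatted here) broadcasts a fixed vocabulary once so every worker holds a local copy, rather than shipping the dict with each partition's task. A standalone sketch, assuming any running distributed client (the in-process cluster is just for the demo):

    from distributed import Client

    client = Client(processes=False)  # assumption: any reachable cluster works
    vocabulary = {"dask": 0, "ml": 1}
    (vocab_future,) = client.scatter((vocabulary,), broadcast=True)
    # vocab_future is a Future resolving to the dict on every worker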
@@ -287,7 +303,6 @@ def transform(self, raw_documents):
             transformed = raw_documents.map_partitions(
                 _count_vectorizer_transform, vocabulary_for_transform, params)
             result = build_array(transformed, n_features, meta)
-        result.compute_chunk_sizes()
         return result


 class TfidfTransformer(sklearn.feature_extraction.text.TfidfTransformer):
@@ -331,30 +346,23 @@ def fit(self, X, y=None):
         X : sparse matrix of shape (n_samples, n_features)
             A matrix of term/token counts.
         """
-        # X = check_array(X, accept_sparse=('csr', 'csc'))
-        # if not sp.issparse(X):
-        #     X = sp.csr_matrix(X)
-        dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64
-
-        if self.use_idf:
-            n_samples, n_features = X.shape
+        def get_idf_diag(X, dtype):
+            n_samples = _n_samples(X)  # X.shape[0] is not yet known
+            n_features = X.shape[1]
             df = _document_frequency(X, dtype)
-            # df = df.astype(dtype, **_astype_copy_false(df))

             # perform idf smoothing if required
             df += int(self.smooth_idf)
             n_samples += int(self.smooth_idf)

             # log+1 instead of log makes sure terms with zero idf don't get
             # suppressed entirely.
-            idf = np.log(n_samples / df) + 1
-            self._idf_diag = scipy.sparse.diags(
-                idf,
-                offsets=0,
-                shape=(n_features, n_features),
-                format="csr",
-                dtype=dtype,
-            )
+            return np.log(n_samples / df) + 1
+
+        dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64
+
+        if self.use_idf:
+            self._idf_diag = get_idf_diag(X, dtype)

         return self

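`get_idf_diag` now returns the smoothed idf weights as a lazy expression built from `_n_samples` and `_document_frequency`; building the sparse diagonal is deferred to `transform`. The numeric recipe is sklearn's smoothed idf, worked here on a toy example (my numbers, illustration only):

    import numpy as np

    n_samples, df = 3, np.array([2.0, 1.0, 3.0])  # doc frequencies of 3 terms
    idf = np.log((n_samples + 1) / (df + 1)) + 1  # the smooth_idf=True path
    print(idf)  # a term in every document gets idf == 1.0, i.e. log(1) + 1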
@@ -404,8 +412,17 @@ def _dot_idf_diag(chunk):
             # idf_ being a property, the automatic attributes detection
             # does not work as usual and we need to specify the attribute
             # name:
-            check_is_fitted(self, attributes=["idf_"], msg="idf vector is not fitted")
-
+            check_is_fitted(self, attributes=["idf_"],
+                            msg="idf vector is not fitted")
+            if dask.is_dask_collection(self._idf_diag):
+                _idf_diag = self._idf_diag.compute()
+                n_features = len(_idf_diag)
+                self._idf_diag = scipy.sparse.diags(
+                    _idf_diag,
+                    offsets=0,
+                    shape=(n_features, n_features),
+                    format="csr",
+                    dtype=_idf_diag.dtype)
             X = X.map_blocks(_dot_idf_diag, dtype=np.float64, meta=meta)

         if self.norm:
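So the first `transform` call materializes the lazy idf vector, caches it as a CSR diagonal, and each block is then scaled by a plain sparse product. The per-block arithmetic, sketched standalone with toy data:

    import numpy as np
    import scipy.sparse

    idf = np.array([1.0, 1.4, 2.1])
    idf_diag = scipy.sparse.diags(idf, offsets=0, shape=(3, 3), format="csr")
    counts = scipy.sparse.csr_matrix(
        np.array([[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]]))
    print((counts @ idf_diag).toarray())  # column j scaled by idf[j]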
@@ -619,8 +636,7 @@ def fit(self, raw_documents, y=None):
         """
         self._check_params()
         self._warn_for_unused_params()
-        X = super().fit_transform(raw_documents,
-                                  y=self._non_CountVectorizer_params)
+        X = super().fit_transform(raw_documents)
         self._tfidf.fit(X)
         return self

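The old call misused the unrelated `y` slot to smuggle `_non_CountVectorizer_params` into `fit_transform`; with the `get_params` override above, plain `super().fit_transform(raw_documents)` suffices. A minimal end-to-end sketch, assuming this diff's vectorizer is exported from the usual dask-ml path:

    import dask.bag as db
    from dask_ml.feature_extraction.text import TfidfVectorizer  # assumed export

    docs = db.from_sequence(["the first document", "the second one"],
                            npartitions=2)
    vec = TfidfVectorizer()
    vec.fit(docs)            # runs the corrected fit path above
    X = vec.transform(docs)  # lazy dask array of tf-idf rows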