Skip to content

Commit 0a20f5c

Browse files
removed all unnecessary calls to compute()
1 parent 1ac8e38 commit 0a20f5c

File tree

2 files changed

+59
-45
lines changed

2 files changed

+59
-45
lines changed

.gitignore

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,5 +122,3 @@ docs/source/auto_examples/
122122
docs/source/examples/mydask.png
123123

124124
dask-worker-space
125-
/.project
126-
/.pydevproject

dask_ml/feature_extraction/text.py

Lines changed: 59 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,30 @@ def _hasher(self):
120120
return sklearn.feature_extraction.text.FeatureHasher
121121

122122

123+
def _n_samples(X):
124+
"""Count the number of samples in dask.array.Array X."""
125+
def chunk_n_samples(chunk, axis, keepdims):
126+
return np.array([chunk.shape[0]], dtype=np.int64)
127+
128+
return da.reduction(X,
129+
chunk=chunk_n_samples,
130+
aggregate=np.sum,
131+
concatenate=False,
132+
dtype=np.int64)
133+
134+
135+
def _n_features(X):
136+
"""Count the number of features in dask.array.Array X."""
137+
def chunk_n_features(chunk, axis, keepdims):
138+
return np.array([chunk.shape[1]], dtype=np.int64)
139+
140+
return da.reduction(X,
141+
chunk=chunk_n_features,
142+
aggregate=lambda x, axis, keepdims: x[0],
143+
concatenate=True,
144+
dtype=np.int64)
145+
146+
123147
def _document_frequency(X, dtype):
124148
"""Count the number of non-zero values for each feature in dask array X."""
125149
def chunk_doc_freq(chunk, axis, keepdims):
@@ -133,7 +157,7 @@ def chunk_doc_freq(chunk, axis, keepdims):
133157
aggregate=np.sum,
134158
axis=0,
135159
concatenate=False,
136-
dtype=dtype).compute().astype(dtype)
160+
dtype=dtype)
137161

138162

139163
class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer):
@@ -203,17 +227,19 @@ class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer):
203227
['and', 'document', 'first', 'is', 'one', 'second', 'the', 'third', 'this']
204228
"""
205229

206-
def fit_transform(self, raw_documents, y=None):
230+
def get_params(self):
207231
# Note that in general 'self' could refer to an instance of either this
208232
# class or a subclass of this class. Hence it is possible that
209233
# self.get_params() could get unexpected parameters of an instance of a
210234
# subclass. Such parameters need to be excluded here:
211-
subclass_instance_params = self.get_params()
235+
subclass_instance_params = super().get_params()
212236
excluded_keys = getattr(self, '_non_CountVectorizer_params', [])
213-
params = {key: subclass_instance_params[key]
214-
for key in subclass_instance_params
215-
if key not in excluded_keys}
237+
return {key: subclass_instance_params[key]
238+
for key in subclass_instance_params
239+
if key not in excluded_keys}
216240

241+
def fit_transform(self, raw_documents, y=None):
242+
params = self.get_params()
217243
vocabulary = params.pop("vocabulary")
218244
vocabulary_for_transform = vocabulary
219245

@@ -227,12 +253,12 @@ def fit_transform(self, raw_documents, y=None):
227253
# Case 2: learn vocabulary from the data.
228254
vocabularies = raw_documents.map_partitions(_build_vocabulary, params)
229255
vocabulary = vocabulary_for_transform = (
230-
_merge_vocabulary( *vocabularies.to_delayed() ))
256+
_merge_vocabulary(*vocabularies.to_delayed()))
231257
vocabulary_for_transform = vocabulary_for_transform.persist()
232258
vocabulary_ = vocabulary.compute()
233259
n_features = len(vocabulary_)
234260

235-
meta = scipy.sparse.eye(0, format="csr", dtype=self.dtype)
261+
meta = scipy.sparse.csr_matrix((0, n_features), dtype=self.dtype)
236262
if isinstance(raw_documents, dd.Series):
237263
result = raw_documents.map_partitions(
238264
_count_vectorizer_transform, vocabulary_for_transform,
@@ -241,23 +267,14 @@ def fit_transform(self, raw_documents, y=None):
241267
result = raw_documents.map_partitions(
242268
_count_vectorizer_transform, vocabulary_for_transform, params)
243269
result = build_array(result, n_features, meta)
244-
result.compute_chunk_sizes()
245270

246271
self.vocabulary_ = vocabulary_
247272
self.fixed_vocabulary_ = fixed_vocabulary
248273

249274
return result
250275

251276
def transform(self, raw_documents):
252-
# Note that in general 'self' could refer to an instance of either this
253-
# class or a subclass of this class. Hence it is possible that
254-
# self.get_params() could get unexpected parameters of an instance of a
255-
# subclass. Such parameters need to be excluded here:
256-
subclass_instance_params = self.get_params()
257-
excluded_keys = getattr(self, '_non_CountVectorizer_params', [])
258-
params = {key: subclass_instance_params[key]
259-
for key in subclass_instance_params
260-
if key not in excluded_keys}
277+
params = self.get_params()
261278
vocabulary = params.pop("vocabulary")
262279

263280
if vocabulary is None:
@@ -271,14 +288,13 @@ def transform(self, raw_documents):
271288
except ValueError:
272289
vocabulary_for_transform = dask.delayed(vocabulary)
273290
else:
274-
(vocabulary_for_transform,) = client.scatter(
275-
(vocabulary,), broadcast=True
276-
)
291+
(vocabulary_for_transform,) = client.scatter((vocabulary,),
292+
broadcast=True)
277293
else:
278294
vocabulary_for_transform = vocabulary
279295

280296
n_features = vocabulary_length(vocabulary_for_transform)
281-
meta = scipy.sparse.eye(0, format="csr", dtype=self.dtype)
297+
meta = scipy.sparse.csr_matrix((0, n_features), dtype=self.dtype)
282298
if isinstance(raw_documents, dd.Series):
283299
result = raw_documents.map_partitions(
284300
_count_vectorizer_transform, vocabulary_for_transform,
@@ -287,7 +303,6 @@ def transform(self, raw_documents):
287303
transformed = raw_documents.map_partitions(
288304
_count_vectorizer_transform, vocabulary_for_transform, params)
289305
result = build_array(transformed, n_features, meta)
290-
result.compute_chunk_sizes()
291306
return result
292307

293308
class TfidfTransformer(sklearn.feature_extraction.text.TfidfTransformer):
@@ -331,30 +346,23 @@ def fit(self, X, y=None):
331346
X : sparse matrix of shape (n_samples, n_features)
332347
A matrix of term/token counts.
333348
"""
334-
# X = check_array(X, accept_sparse=('csr', 'csc'))
335-
# if not sp.issparse(X):
336-
# X = sp.csr_matrix(X)
337-
dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64
338-
339-
if self.use_idf:
340-
n_samples, n_features = X.shape
349+
def get_idf_diag(X, dtype):
350+
n_samples = _n_samples(X) # X.shape[0] is not yet known
351+
n_features = X.shape[1]
341352
df = _document_frequency(X, dtype)
342-
# df = df.astype(dtype, **_astype_copy_false(df))
343353

344354
# perform idf smoothing if required
345355
df += int(self.smooth_idf)
346356
n_samples += int(self.smooth_idf)
347357

348358
# log+1 instead of log makes sure terms with zero idf don't get
349359
# suppressed entirely.
350-
idf = np.log(n_samples / df) + 1
351-
self._idf_diag = scipy.sparse.diags(
352-
idf,
353-
offsets=0,
354-
shape=(n_features, n_features),
355-
format="csr",
356-
dtype=dtype,
357-
)
360+
return np.log(n_samples / df) + 1
361+
362+
dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64
363+
364+
if self.use_idf:
365+
self._idf_diag = get_idf_diag(X, dtype)
358366

359367
return self
360368

@@ -404,8 +412,17 @@ def _dot_idf_diag(chunk):
404412
# idf_ being a property, the automatic attributes detection
405413
# does not work as usual and we need to specify the attribute
406414
# name:
407-
check_is_fitted(self, attributes=["idf_"], msg="idf vector is not fitted")
408-
415+
check_is_fitted(self, attributes=["idf_"],
416+
msg="idf vector is not fitted")
417+
if dask.is_dask_collection(self._idf_diag):
418+
_idf_diag = self._idf_diag.compute()
419+
n_features = len(_idf_diag)
420+
self._idf_diag = scipy.sparse.diags(
421+
_idf_diag,
422+
offsets=0,
423+
shape=(n_features, n_features),
424+
format="csr",
425+
dtype=_idf_diag.dtype)
409426
X = X.map_blocks(_dot_idf_diag, dtype=np.float64, meta=meta)
410427

411428
if self.norm:
@@ -619,8 +636,7 @@ def fit(self, raw_documents, y=None):
619636
"""
620637
self._check_params()
621638
self._warn_for_unused_params()
622-
X = super().fit_transform(raw_documents,
623-
y=self._non_CountVectorizer_params)
639+
X = super().fit_transform(raw_documents)
624640
self._tfidf.fit(X)
625641
return self
626642

0 commit comments

Comments
 (0)