Skip to content

Commit f26575b

Browse files
Resolve ML cluster failures (#957)
* Wrap ML functions with make_pickable_without_dask_sql
* Add skip_if_external_scheduler to relevant functions
* Wrap functions with decorators
* [test-upstream]

Co-authored-by: Ayush Dattagupta <[email protected]>
1 parent e473b03 commit f26575b

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

dask_sql/physical/rel/custom/wrappers.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from sklearn.metrics import make_scorer
1818
from sklearn.utils.validation import check_is_fitted
1919

20+
from dask_sql.utils import make_pickable_without_dask_sql
21+
2022
try:
2123
import sklearn.base
2224
import sklearn.metrics
@@ -198,10 +200,14 @@ def transform(self, X):
198200
if isinstance(X, da.Array):
199201
if output_meta is None:
200202
output_meta = _get_output_dask_ar_meta_for_estimator(
201-
_transform, self._postfit_estimator, X
203+
_transform,
204+
self._postfit_estimator,
205+
X,
202206
)
203207
return X.map_blocks(
204-
_transform, estimator=self._postfit_estimator, meta=output_meta
208+
_transform,
209+
estimator=self._postfit_estimator,
210+
meta=output_meta,
205211
)
206212
elif isinstance(X, dd._Frame):
207213
if output_meta is None:
@@ -219,7 +225,9 @@ def transform(self, X):
219225
# for inferring meta
220226
output_meta = dd.core.no_default
221227
return X.map_partitions(
222-
_transform, estimator=self._postfit_estimator, meta=output_meta
228+
_transform,
229+
estimator=self._postfit_estimator,
230+
meta=output_meta,
223231
)
224232
else:
225233
return _transform(X, estimator=self._postfit_estimator)
@@ -315,7 +323,9 @@ def predict(self, X):
315323
if output_meta is None:
316324
output_meta = dd.core.no_default
317325
return X.map_partitions(
318-
_predict, estimator=self._postfit_estimator, meta=output_meta
326+
_predict,
327+
estimator=self._postfit_estimator,
328+
meta=output_meta,
319329
)
320330
else:
321331
return _predict(X, estimator=self._postfit_estimator)
@@ -553,6 +563,7 @@ def partial_fit(self, X, y=None, **fit_kwargs):
553563
return self._fit_for_estimator(estimator, X, y, **fit_kwargs)
554564

555565

566+
@make_pickable_without_dask_sql
556567
def _predict(part, estimator, output_meta=None):
557568
if part.shape[0] == 0 and output_meta is not None:
558569
empty_output = handle_empty_partitions(output_meta)
@@ -561,6 +572,7 @@ def _predict(part, estimator, output_meta=None):
561572
return estimator.predict(part)
562573

563574

575+
@make_pickable_without_dask_sql
564576
def _predict_proba(part, estimator, output_meta=None):
565577
if part.shape[0] == 0 and output_meta is not None:
566578
empty_output = handle_empty_partitions(output_meta)
@@ -569,6 +581,7 @@ def _predict_proba(part, estimator, output_meta=None):
569581
return estimator.predict_proba(part)
570582

571583

584+
@make_pickable_without_dask_sql
572585
def _transform(part, estimator, output_meta=None):
573586
if part.shape[0] == 0 and output_meta is not None:
574587
empty_output = handle_empty_partitions(output_meta)

tests/unit/test_ml_wrappers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
from dask_sql.physical.rel.custom.wrappers import Incremental, ParallelPostFit
1919

20+
from ..integration.fixtures import skip_if_external_scheduler
21+
2022

2123
def _check_axis_partitioning(chunks, n_features):
2224
c = chunks[1][0]
@@ -123,6 +125,7 @@ def assert_estimator_equal(left, right, exclude=None, **kwargs):
123125
_assert_eq(l, r, name=attr, **kwargs)
124126

125127

128+
@skip_if_external_scheduler
126129
def test_parallelpostfit_basic():
127130
clf = ParallelPostFit(GradientBoostingClassifier())
128131

@@ -194,6 +197,7 @@ def test_transform(kind):
194197
assert_eq_ar(result, expected)
195198

196199

200+
@skip_if_external_scheduler
197201
@pytest.mark.parametrize("dataframes", [False, True])
198202
def test_incremental_basic(dataframes):
199203
# Create observations that we know linear models can recover

0 commit comments

Comments
 (0)