Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions dask_sql/physical/rel/custom/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from sklearn.metrics import make_scorer
from sklearn.utils.validation import check_is_fitted

from dask_sql.utils import make_pickable_without_dask_sql

try:
import sklearn.base
import sklearn.metrics
Expand Down Expand Up @@ -198,10 +200,14 @@ def transform(self, X):
if isinstance(X, da.Array):
if output_meta is None:
output_meta = _get_output_dask_ar_meta_for_estimator(
_transform, self._postfit_estimator, X
_transform,
self._postfit_estimator,
X,
)
return X.map_blocks(
_transform, estimator=self._postfit_estimator, meta=output_meta
_transform,
estimator=self._postfit_estimator,
meta=output_meta,
)
elif isinstance(X, dd._Frame):
if output_meta is None:
Expand All @@ -219,7 +225,9 @@ def transform(self, X):
# for inferring meta
output_meta = dd.core.no_default
return X.map_partitions(
_transform, estimator=self._postfit_estimator, meta=output_meta
_transform,
estimator=self._postfit_estimator,
meta=output_meta,
)
else:
return _transform(X, estimator=self._postfit_estimator)
Expand Down Expand Up @@ -315,7 +323,9 @@ def predict(self, X):
if output_meta is None:
output_meta = dd.core.no_default
return X.map_partitions(
_predict, estimator=self._postfit_estimator, meta=output_meta
_predict,
estimator=self._postfit_estimator,
meta=output_meta,
)
else:
return _predict(X, estimator=self._postfit_estimator)
Expand Down Expand Up @@ -553,6 +563,7 @@ def partial_fit(self, X, y=None, **fit_kwargs):
return self._fit_for_estimator(estimator, X, y, **fit_kwargs)


@make_pickable_without_dask_sql
def _predict(part, estimator, output_meta=None):
if part.shape[0] == 0 and output_meta is not None:
empty_output = handle_empty_partitions(output_meta)
Expand All @@ -561,6 +572,7 @@ def _predict(part, estimator, output_meta=None):
return estimator.predict(part)


@make_pickable_without_dask_sql
def _predict_proba(part, estimator, output_meta=None):
if part.shape[0] == 0 and output_meta is not None:
empty_output = handle_empty_partitions(output_meta)
Expand All @@ -569,6 +581,7 @@ def _predict_proba(part, estimator, output_meta=None):
return estimator.predict_proba(part)


@make_pickable_without_dask_sql
def _transform(part, estimator, output_meta=None):
if part.shape[0] == 0 and output_meta is not None:
empty_output = handle_empty_partitions(output_meta)
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/test_ml_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from dask_sql.physical.rel.custom.wrappers import Incremental, ParallelPostFit

from ..integration.fixtures import skip_if_external_scheduler


def _check_axis_partitioning(chunks, n_features):
c = chunks[1][0]
Expand Down Expand Up @@ -123,6 +125,7 @@ def assert_estimator_equal(left, right, exclude=None, **kwargs):
_assert_eq(l, r, name=attr, **kwargs)


@skip_if_external_scheduler
def test_parallelpostfit_basic():
clf = ParallelPostFit(GradientBoostingClassifier())

Expand Down Expand Up @@ -194,6 +197,7 @@ def test_transform(kind):
assert_eq_ar(result, expected)


@skip_if_external_scheduler
@pytest.mark.parametrize("dataframes", [False, True])
def test_incremental_basic(dataframes):
# Create observations that we know linear models can recover
Expand Down