1717from sklearn .metrics import make_scorer
1818from sklearn .utils .validation import check_is_fitted
1919
20+ from dask_sql .utils import make_pickable_without_dask_sql
21+
2022try :
2123 import sklearn .base
2224 import sklearn .metrics
@@ -198,10 +200,14 @@ def transform(self, X):
198200 if isinstance (X , da .Array ):
199201 if output_meta is None :
200202 output_meta = _get_output_dask_ar_meta_for_estimator (
201- _transform , self ._postfit_estimator , X
203+ _transform ,
204+ self ._postfit_estimator ,
205+ X ,
202206 )
203207 return X .map_blocks (
204- _transform , estimator = self ._postfit_estimator , meta = output_meta
208+ _transform ,
209+ estimator = self ._postfit_estimator ,
210+ meta = output_meta ,
205211 )
206212 elif isinstance (X , dd ._Frame ):
207213 if output_meta is None :
@@ -219,7 +225,9 @@ def transform(self, X):
219225 # for infering meta
220226 output_meta = dd .core .no_default
221227 return X .map_partitions (
222- _transform , estimator = self ._postfit_estimator , meta = output_meta
228+ _transform ,
229+ estimator = self ._postfit_estimator ,
230+ meta = output_meta ,
223231 )
224232 else :
225233 return _transform (X , estimator = self ._postfit_estimator )
@@ -315,7 +323,9 @@ def predict(self, X):
315323 if output_meta is None :
316324 output_meta = dd .core .no_default
317325 return X .map_partitions (
318- _predict , estimator = self ._postfit_estimator , meta = output_meta
326+ _predict ,
327+ estimator = self ._postfit_estimator ,
328+ meta = output_meta ,
319329 )
320330 else :
321331 return _predict (X , estimator = self ._postfit_estimator )
@@ -553,6 +563,7 @@ def partial_fit(self, X, y=None, **fit_kwargs):
553563 return self ._fit_for_estimator (estimator , X , y , ** fit_kwargs )
554564
555565
566+ @make_pickable_without_dask_sql
556567def _predict (part , estimator , output_meta = None ):
557568 if part .shape [0 ] == 0 and output_meta is not None :
558569 empty_output = handle_empty_partitions (output_meta )
@@ -561,6 +572,7 @@ def _predict(part, estimator, output_meta=None):
561572 return estimator .predict (part )
562573
563574
575+ @make_pickable_without_dask_sql
564576def _predict_proba (part , estimator , output_meta = None ):
565577 if part .shape [0 ] == 0 and output_meta is not None :
566578 empty_output = handle_empty_partitions (output_meta )
@@ -569,6 +581,7 @@ def _predict_proba(part, estimator, output_meta=None):
569581 return estimator .predict_proba (part )
570582
571583
584+ @make_pickable_without_dask_sql
572585def _transform (part , estimator , output_meta = None ):
573586 if part .shape [0 ] == 0 and output_meta is not None :
574587 empty_output = handle_empty_partitions (output_meta )
0 commit comments