Description
Describe the issue:
I've been trying to use dask-ml to train large models with multidimensional inputs, using Incremental to sequentially pass chunks of the dask array for training. Unfortunately, it seems Incremental, or one of the downstream libraries it calls, cannot handle data that is more than two-dimensional. When X.ndim <= 2, Incremental correctly passes each chunk of X sequentially as the underlying numpy array to the partial_fit of the estimator, which is the advertised behaviour. However, when X.ndim > 2, Incremental instead passes a tuple containing the dask task key string and chunk location rather than the data, and there seems to be no obvious way of retrieving the underlying array from it.
As a workaround, is there a way of retrieving the underlying data using the supplied information?
Alternatively, the obvious workaround is to reshape the multidimensional array to 2D prior to calling fit, and then restore the correct shape inside partial_fit. The array is chunked exclusively along the first dimension (the remaining dimensions are only flattened and unflattened), which from my understanding should not be prohibitively expensive. However, this seems like unnecessary overhead at each training step.
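A minimal sketch of that reshape workaround (the ReshapeWrapper name and structure are my own illustration, not part of dask-ml), assuming the array is chunked only along the first axis so the trailing shape is fixed and known:

```python
import numpy as np

# Hypothetical wrapper: the caller flattens X to 2D before calling fit,
# and partial_fit restores the trailing dimensions before delegating.
class ReshapeWrapper:
    def __init__(self, estimator, trailing_shape):
        self.estimator = estimator
        self.trailing_shape = trailing_shape  # e.g. (100, 100, 10, 10)

    def partial_fit(self, X, y=None):
        # X arrives as (n_samples, prod(trailing_shape)); restore the shape.
        X_nd = X.reshape((X.shape[0],) + self.trailing_shape)
        return self.estimator.partial_fit(X_nd, y)

# The round trip is lossless: reshape only reinterprets the memory layout.
X = np.random.random((10, 4, 5, 6))
X_2d = X.reshape(X.shape[0], -1)                 # 2D view handed to fit
X_back = X_2d.reshape((X.shape[0],) + X.shape[1:])
assert np.array_equal(X, X_back)
```

Since the chunks only span the first dimension, each flattened chunk reshapes back independently, so the wrapper never needs to see more than one chunk at a time.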
Minimal Complete Verifiable Example:
from dask_ml.wrappers import Incremental
import dask.array as da

# Make a minimalist scikit-learn style estimator.
class IncrementalEstimator:
    def __init__(self, model):
        self.model = model

    def partial_fit(self, X, y=None):
        print('X : {}'.format(X))
        print('Type X: {}'.format(type(X)))
        print('y : {}'.format(y))
        print('Type y: {}'.format(type(y)))

    def fit(self, X, y=None):
        raise NotImplementedError('Use partial_fit instead')

    def predict(self, X):
        return self.model.predict(X)

    def score(self, X, y):
        raise NotImplementedError('Use predict instead')

    def get_params(self, deep=True):
        return {'model': self.model}

    def set_params(self, **params):
        for key, value in params.items():
            setattr(self, key, value)
        return self

# Dummy data
y = da.ones((10,), chunks=(1,))
X = da.random.random(size=(10, 100, 100, 10, 10), chunks=(1, 100, 100, 10, 10))

# Subsample such that X.ndim <= 2. This works as expected.
X_in = X[:, :, 0, 0, 0]
estimator = Incremental(estimator=IncrementalEstimator(None))
estimator.fit(X_in, y=y)

# Now subsample such that X.ndim == 3. This fails: partial_fit receives a
# tuple with the dask task graph key instead of the data.
X_in = X[:, :, :, 0, 0]
estimator = Incremental(estimator=IncrementalEstimator(None))
estimator.fit(X_in, y=y)
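For reference, here is a workaround sketch that bypasses Incremental entirely and iterates over the blocks along the first axis by hand (the array sizes are shrunk for illustration; the partial_fit call is commented out since it depends on your estimator):

```python
import dask.array as da

# Small array chunked only along the first axis, mirroring the setup above.
X = da.random.random(size=(10, 8, 8, 4), chunks=(2, 8, 8, 4))
y = da.ones((10,), chunks=(2,))

shapes = []
for i in range(X.numblocks[0]):
    # .blocks[i] selects the i-th chunk along axis 0; .compute() yields the
    # underlying numpy array with full dimensionality preserved.
    X_np = X.blocks[i].compute()
    y_np = y.blocks[i].compute()
    shapes.append(X_np.shape)
    # estimator.partial_fit(X_np, y_np)  # would go here

assert all(s == (2, 8, 8, 4) for s in shapes)
```

This loses Incremental's scheduling conveniences, but each chunk reaches partial_fit as a plain numpy array regardless of dimensionality.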
Anything else we need to know?:
If there is a better way of accomplishing what I'm trying to do using the dask ecosystem, let me know! :)
Environment:
- Dask version: 2023.3.1 (dask-ml 2023.3.24)
- Python version: 3.9.16
- Operating System: Windows
- Install method (conda, pip, source): pip