diff --git a/.circleci/config.yml b/.circleci/config.yml
index f146307..72faf51 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -16,7 +16,7 @@ jobs:
             conda config --add channels conda-forge
             conda create -q -n test-environment python=${PYTHON}
             source activate test-environment
-            conda install -q coverage flake8 pytest pytest-cov pytest-xdist numpy pandas xgboost dask distributed scikit-learn sparse scipy
+            conda install -q coverage flake8 pytest pytest-cov numpy pandas xgboost dask distributed scikit-learn sparse scipy
             pip install -e .
             conda list test-environment
       - run:
diff --git a/dask_xgboost/core.py b/dask_xgboost/core.py
index 6bf29d7..2a5c04c 100644
--- a/dask_xgboost/core.py
+++ b/dask_xgboost/core.py
@@ -73,14 +73,21 @@ def train_part(env, param, list_of_parts, dmatrix_kwargs=None, **kwargs):
     -------
     model if rank zero, None otherwise
     """
-    data, labels = zip(*list_of_parts)  # Prepare data
+    # Prepare data
+    if len(list_of_parts[0]) == 3:
+        data, labels, weight = zip(*list_of_parts)
+        weight = concat(weight)
+    else:
+        data, labels = zip(*list_of_parts)
+        weight = None
+
     data = concat(data)  # Concatenate many parts into one
     labels = concat(labels)
     if dmatrix_kwargs is None:
         dmatrix_kwargs = {}
 
     dmatrix_kwargs["feature_names"] = getattr(data, 'columns', None)
-    dtrain = xgb.DMatrix(data, labels, **dmatrix_kwargs)
+    dtrain = xgb.DMatrix(data, labels, weight=weight, **dmatrix_kwargs)
 
     args = [('%s=%s' % item).encode() for item in env.items()]
     xgb.rabit.init(args)
@@ -99,7 +106,8 @@ def train_part(env, param, list_of_parts, dmatrix_kwargs=None, **kwargs):
 
 
 @gen.coroutine
-def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
+def _train(client, params, data, labels,
+           sample_weight, dmatrix_kwargs={}, **kwargs):
     """
     Asynchronous version of train
 
@@ -117,8 +125,18 @@ def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
         assert label_parts.ndim == 1 or label_parts.shape[1] == 1
         label_parts = label_parts.flatten().tolist()
 
-    # Arrange parts into pairs. This enforces co-locality
-    parts = list(map(delayed, zip(data_parts, label_parts)))
+    if sample_weight is not None:
+        sample_weight_parts = sample_weight.to_delayed()
+        if isinstance(sample_weight_parts, np.ndarray):
+            assert sample_weight_parts.ndim == 1 or sample_weight_parts.shape[1] == 1
+            sample_weight_parts = sample_weight_parts.flatten().tolist()
+
+        # Arrange parts into triples. This enforces co-locality
+        parts = list(map(delayed, zip(data_parts, label_parts, sample_weight_parts)))
+    else:
+        # Arrange parts into pairs. This enforces co-locality
+        parts = list(map(delayed, zip(data_parts, label_parts)))
+
     parts = client.compute(parts)  # Start computation in the background
     yield wait(parts)
 
@@ -158,7 +176,7 @@ def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
     raise gen.Return(result)
 
 
-def train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
+def train(client, params, data, labels, sample_weight=None, dmatrix_kwargs={}, **kwargs):
     """ Train an XGBoost model on a Dask Cluster
 
     This starts XGBoost on all Dask workers, moves input data to those workers,
@@ -188,7 +206,7 @@ def train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
     predict
     """
     return client.sync(_train, client, params, data,
-                       labels, dmatrix_kwargs, **kwargs)
+                       labels, sample_weight, dmatrix_kwargs, **kwargs)
 
 
 def _predict_part(part, model=None):
@@ -258,7 +276,7 @@ def predict(client, model, data):
 
 
 class XGBRegressor(xgb.XGBRegressor):
-    def fit(self, X, y=None):
+    def fit(self, X, y=None, sample_weight=None):
         """Fit the gradient boosting model
 
         Parameters
@@ -279,6 +297,7 @@ def fit(self, X, y=None):
         client = default_client()
         xgb_options = self.get_xgb_params()
         self._Booster = train(client, xgb_options, X, y,
+                              sample_weight,
                               num_boost_round=self.n_estimators)
         return self
 
@@ -289,7 +308,7 @@ def predict(self, X):
 
 
 class XGBClassifier(xgb.XGBClassifier):
-    def fit(self, X, y=None, classes=None):
+    def fit(self, X, y=None, classes=None, sample_weight=None):
         """Fit a gradient boosting classifier
 
         Parameters
@@ -301,6 +320,8 @@ def fit(self, X, y=None, classes=None):
         classes : sequence, optional
             The unique values in `y`. If no specified, this will be
             eagerly computed from `y` before training.
+        sample_weight : array-like [n_samples]
+            Weights for each training sample
 
         Returns
         -------
@@ -345,9 +366,9 @@ def fit(self, X, y=None, classes=None):
 
         # TODO: auto label-encode y
         # that will require a dependency on dask-ml
-        # TODO: sample weight
 
         self._Booster = train(client, xgb_options, X, y,
+                              sample_weight,
                               num_boost_round=self.n_estimators)
         return self
 
diff --git a/dask_xgboost/tests/test_core.py b/dask_xgboost/tests/test_core.py
index 22ca104..205b593 100644
--- a/dask_xgboost/tests/test_core.py
+++ b/dask_xgboost/tests/test_core.py
@@ -1,3 +1,5 @@
+from unittest.mock import Mock
+
 import numpy as np
 import pandas as pd
 import xgboost as xgb
@@ -16,6 +18,7 @@
 
 import dask_xgboost as dxgb
 
+
 # Workaround for conflict with distributed 1.23.0
 # https://github.com/dask/dask-xgboost/pull/27#issuecomment-417474734
 from concurrent.futures import ThreadPoolExecutor
@@ -32,25 +35,58 @@
 X = df.values
 y = labels.values
 
+weight = np.random.rand(10)
+
+
+@pytest.yield_fixture(scope="function")  # noqa
+def xgboost_loop(loop, monkeypatch):
+    xgb.rabit.init()
+    fake_xgb = xgb
+
+    init_mock = Mock()
+    fake_xgb.rabit.init = init_mock
+    finalize_mock = Mock()
+    fake_xgb.rabit.finalize = finalize_mock
+
+    monkeypatch.setattr(dxgb.core, 'xgb', fake_xgb)
+    yield loop
+    xgb.rabit.finalize()
+
+
+@pytest.yield_fixture(scope="function")
+def xgboost_gen_cluster():
+    xgb.rabit.init()
+    yield
+    xgb.rabit.finalize()
+
+
+def xgboost_fixture_deco(func):
+    # this decorator adds another layer over gen_cluster, allowing fixtures to be used
+    def outer_wrapper(xgboost_gen_cluster):
+        wrapper = gen_cluster(client=True, timeout=None,
+                              check_new_threads=False)
+        return wrapper(func)
+    return outer_wrapper
 
 
-def test_classifier(loop):  # noqa
+def test_classifier(xgboost_loop):  # noqa
     with cluster() as (s, [a, b]):
-        with Client(s['address'], loop=loop):
+        with Client(s['address'], loop=xgboost_loop):
             a = dxgb.XGBClassifier()
             X2 = da.from_array(X, 5)
             y2 = da.from_array(y, 5)
-            a.fit(X2, y2)
+            weight1 = da.from_array(weight, 5)
+            a.fit(X2, y2, sample_weight=weight1)
             p1 = a.predict(X2)
 
     b = xgb.XGBClassifier()
-    b.fit(X, y)
+    b.fit(X, y, sample_weight=weight)
     np.testing.assert_array_almost_equal(a.feature_importances_,
                                          b.feature_importances_)
     assert_eq(p1, b.predict(X))
 
 
-def test_multiclass_classifier(loop):  # noqa
+def test_multiclass_classifier(xgboost_loop):  # noqa
     # data
     iris = load_iris()
     X, y = iris.data, iris.target
@@ -68,7 +104,7 @@ def test_multiclass_classifier(loop):  # noqa
     d = dxgb.XGBClassifier()
 
     with cluster() as (s, [_, _]):
-        with Client(s['address'], loop=loop):
+        with Client(s['address'], loop=xgboost_loop):
             # fit
             a.fit(X, y)  # array
             b.fit(dX, dy, classes=[0, 1, 2])
@@ -83,8 +119,7 @@ def test_multiclass_classifier(loop):  # noqa
 
 
 @pytest.mark.parametrize("kind", ['array', 'dataframe'])  # noqa
-def test_classifier_multi(kind, loop):
-
+def test_classifier_multi(kind, xgboost_loop):
     if kind == 'array':
         X2 = da.from_array(X, 5)
         y2 = da.from_array(
@@ -96,7 +131,7 @@ def test_classifier_multi(kind, loop):
         y2 = dd.from_pandas(labels, npartitions=2)
 
     with cluster() as (s, [a, b]):
-        with Client(s['address'], loop=loop):
+        with Client(s['address'], loop=xgboost_loop):
             a = dxgb.XGBClassifier(num_class=3, n_estimators=10,
                                    objective="multi:softprob")
             a.fit(X2, y2)
@@ -119,17 +154,21 @@ def test_classifier_multi(kind, loop):
             assert p2.compute().shape == (10, 3)
 
 
-def test_regressor(loop):  # noqa
+def test_regressor(xgboost_loop):  # noqa
     with cluster() as (s, [a, b]):
-        with Client(s['address'], loop=loop):
+        with Client(s['address'], loop=xgboost_loop):
             a = dxgb.XGBRegressor()
             X2 = da.from_array(X, 5)
             y2 = da.from_array(y, 5)
-            a.fit(X2, y2)
+            weight1 = da.from_array(weight, 5)
+            a.fit(X2, y2, sample_weight=weight1)
             p1 = a.predict(X2)
 
     b = xgb.XGBRegressor()
-    b.fit(X, y)
+    b.fit(X, y, sample_weight=weight)
+
+    np.testing.assert_array_almost_equal(a.feature_importances_,
+                                         b.feature_importances_)
     assert_eq(p1, b.predict(X))
 
 
@@ -158,12 +197,11 @@ def test_basic(c, s, a, b):
     assert ((predictions > 0.5) != labels).sum() < 2
 
 
-@gen_cluster(client=True, timeout=None, check_new_threads=False)
+@xgboost_fixture_deco
 def test_dmatrix_kwargs(c, s, a, b):
-    xgb.rabit.init()  # workaround for "Doing rabit call after Finalize"
     dX = da.from_array(X, chunks=(2, 2))
     dy = da.from_array(y, chunks=(2,))
-    dbst = yield dxgb.train(c, param, dX, dy, {"missing": 0.0})
+    dbst = yield dxgb.train(c, param, dX, dy, dmatrix_kwargs={"missing": 0.0})
 
     # Distributed model matches local model with dmatrix kwargs
     dtrain = xgb.DMatrix(X, label=y, missing=0.0)
@@ -193,9 +231,8 @@ def _test_container(dbst, predictions, X_type):
     assert ((predictions > 0.5) != labels).sum() < 2
 
 
-@gen_cluster(client=True, timeout=None, check_new_threads=False)
+@xgboost_fixture_deco
 def test_numpy(c, s, a, b):
-    xgb.rabit.init()  # workaround for "Doing rabit call after Finalize"
     dX = da.from_array(X, chunks=(2, 2))
     dy = da.from_array(y, chunks=(2,))
     dbst = yield dxgb.train(c, param, dX, dy)
@@ -207,9 +244,8 @@ def test_numpy(c, s, a, b):
     _test_container(dbst, predictions, np.array)
 
 
-@gen_cluster(client=True, timeout=None, check_new_threads=False)
+@xgboost_fixture_deco
 def test_scipy_sparse(c, s, a, b):
-    xgb.rabit.init()  # workaround for "Doing rabit call after Finalize"
     dX = da.from_array(X, chunks=(2, 2)).map_blocks(scipy.sparse.csr_matrix)
     dy = da.from_array(y, chunks=(2,))
     dbst = yield dxgb.train(c, param, dX, dy)
@@ -222,9 +258,8 @@ def test_scipy_sparse(c, s, a, b):
     _test_container(dbst, predictions_result, scipy.sparse.csr_matrix)
 
 
-@gen_cluster(client=True, timeout=None, check_new_threads=False)
+@xgboost_fixture_deco
 def test_sparse(c, s, a, b):
-    xgb.rabit.init()  # workaround for "Doing rabit call after Finalize"
     dX = da.from_array(X, chunks=(2, 2)).map_blocks(sparse.COO)
     dy = da.from_array(y, chunks=(2,))
     dbst = yield dxgb.train(c, param, dX, dy)
diff --git a/requirements.txt b/requirements.txt
index aae7dc9..a236cab 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,3 @@
-xgboost
+xgboost >= 0.7
 dask
 distributed >= 1.15.2
diff --git a/setup.cfg b/setup.cfg
index 2348f49..4fe7a30 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -3,6 +3,7 @@ universal=1
 
 [flake8]
 exclude = tests/data,docs,benchmarks,scripts
+max-line-length = 120
 
 [tool:pytest]
-addopts = -rsx -v -n 1 --boxed
+addopts = -rsx -v
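
For reference, a minimal usage sketch of the sample_weight support added by this patch. This is not part of the diff: it assumes a running distributed scheduler reachable through Client(), and the synthetic data, chunk sizes, and booster parameters are illustrative only. Weight chunks must line up with the data and label chunks so that each worker receives co-located (data, label, weight) partitions.

    import dask.array as da
    from dask.distributed import Client

    import dask_xgboost as dxgb

    client = Client()  # start or connect to a Dask cluster

    # Synthetic data; weights are chunked the same way as X and y so the
    # (data, label, weight) parts stay together on each worker.
    X = da.random.random((1000, 10), chunks=(100, 10))
    y = da.random.randint(0, 2, size=(1000,), chunks=(100,))
    w = da.random.random((1000,), chunks=(100,))

    # Low-level API: pass the weights to train() alongside data and labels.
    params = {'objective': 'binary:logistic'}
    bst = dxgb.train(client, params, X, y, sample_weight=w, num_boost_round=10)

    # scikit-learn style wrapper: fit() now accepts sample_weight as well.
    clf = dxgb.XGBClassifier(n_estimators=10)
    clf.fit(X, y, sample_weight=w)
    predictions = clf.predict(X)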