diff --git a/sknn/backend/lasagne/mlp.py b/sknn/backend/lasagne/mlp.py index f7e528c..1947dd9 100644 --- a/sknn/backend/lasagne/mlp.py +++ b/sknn/backend/lasagne/mlp.py @@ -246,10 +246,19 @@ def _initialize_impl(self, X, y=None, w=None): if self.valid_size > 0.0: assert self.valid_set is None, "Can't specify valid_size and valid_set together." - X, X_v, y, y_v = sklearn.cross_validation.train_test_split( - X, y, - test_size=self.valid_size, - random_state=self.random_state) + indices = numpy.arange(X.shape[0]) + indices, indices_v = sklearn.cross_validation.train_test_split( + indices, + test_size=self.valid_size, + random_state=self.random_state) + X_v = X[indices_v] + y_v = y[indices_v] + + X = X[indices] + y = y[indices] + if w is not None: + w = w[indices] + self.valid_set = X_v, y_v if self.valid_set and self.is_convolution(): @@ -263,7 +272,7 @@ def _initialize_impl(self, X, y=None, w=None): params.extend(mlp_layer.get_params()) self.trainer, self.validator = self._create_mlp_trainer(params) - return X, y + return X, y, w def _predict_impl(self, X): if self.is_convolution(): diff --git a/sknn/mlp.py b/sknn/mlp.py index c1b2b41..7b459fc 100644 --- a/sknn/mlp.py +++ b/sknn/mlp.py @@ -210,7 +210,7 @@ def _fit(self, X, y, w=None): X, y = self._reshape(X, y) if not self.is_initialized: - X, y = self._initialize(X, y, w) + X, y, w = self._initialize(X, y, w) log.info("Training on dataset of {:,} samples with {} total size.".format(data_shape[0], data_size)) if data_shape[1:] != X.shape[1:]: