diff --git a/examples/pcovr/PCovR.py b/examples/pcovr/PCovR.py
index 6c99fd41b..1a222ff24 100644
--- a/examples/pcovr/PCovR.py
+++ b/examples/pcovr/PCovR.py
@@ -50,7 +50,7 @@
 )
 pcovr.fit(X_scaled, y_scaled)
 T = pcovr.transform(X_scaled)
-yp = y_scaler.inverse_transform(pcovr.predict(X_scaled))
+yp = y_scaler.inverse_transform(pcovr.predict(X_scaled).reshape(-1, 1))
 
 fig, ((axT, axy), (caxT, caxy)) = plt.subplots(
     2, 2, figsize=(8, 5), gridspec_kw=dict(height_ratios=(1, 0.1))
@@ -90,7 +90,7 @@
     )
     pcovr.fit(X_scaled, y_scaled)
     T = pcovr.transform(X_scaled)
-    yp = y_scaler.inverse_transform(pcovr.predict(X_scaled))
+    yp = y_scaler.inverse_transform(pcovr.predict(X_scaled).reshape(-1, 1))
 
     axes[0, i].scatter(
         T[:, 0], T[:, 1], s=50, alpha=0.8, c=y, cmap=cmapX, edgecolor="k"
@@ -136,7 +136,7 @@
 )
 kpcovr.fit(X_scaled, y_scaled)
 T = kpcovr.transform(X_scaled)
-yp = y_scaler.inverse_transform(kpcovr.predict(X_scaled))
+yp = y_scaler.inverse_transform(kpcovr.predict(X_scaled).reshape(-1, 1))
 
 fig, ((axT, axy), (caxT, caxy)) = plt.subplots(
     2, 2, figsize=(8, 5), gridspec_kw=dict(height_ratios=(1, 0.1))
diff --git a/examples/pcovr/PCovR_Regressors.py b/examples/pcovr/PCovR_Regressors.py
index 72a122e61..5600c15d2 100644
--- a/examples/pcovr/PCovR_Regressors.py
+++ b/examples/pcovr/PCovR_Regressors.py
@@ -29,7 +29,7 @@
 X_scaled = X_scaler.fit_transform(X)
 
 y_scaler = StandardScaler()
-y_scaled = y_scaler.fit_transform(y.reshape(-1, 1)).ravel()
+y_scaled = y_scaler.fit_transform(y.reshape(-1, 1))
 
 
 # %%
diff --git a/pyproject.toml b/pyproject.toml
index eec3f58c2..9ac1d8a17 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -87,7 +87,6 @@ known_first_party = "skmatter"
 [tool.pytest.ini_options]
 testpaths = ["tests"]
 addopts = ["--cov"]
-filterwarnings = ["error"]
 
 [tool.ruff]
 exclude = ["docs/src/examples/"]
diff --git a/src/skmatter/_selection.py b/src/skmatter/_selection.py
index 224c020a7..e73e1dc73 100644
--- a/src/skmatter/_selection.py
+++ b/src/skmatter/_selection.py
@@ -209,8 +209,7 @@ def fit(self, X, y=None, warm_start=False):
         params = dict(ensure_min_samples=2, ensure_min_features=2, dtype=FLOAT_DTYPES)
 
         if hasattr(self, "mixing") or y is not None:
-            X, y = self._validate_data(X, y, **params)
-            X, y = validate_data(self, X, y, multi_output=True)
+            X, y = validate_data(self, X, y, multi_output=True, **params)
 
             if len(y.shape) == 1:
                 # force y to have multi_output 2D format even when it's 1D, since
@@ -569,7 +568,10 @@ def score(self, X, y=None):
         score : numpy.ndarray of (n_to_select_from_)
             :math:`\pi` importance for the given samples or features
         """
-        validate_data(self, X, y, reset=False)  # present for API consistency
+        if y is not None:
+            validate_data(self, X, y.ravel(), reset=False)
+        else:
+            validate_data(self, X, reset=False)  # present for API consistency
         return self.pi_
 
     def _init_greedy_search(self, X, y, n_to_select):
@@ -744,7 +746,10 @@ def score(self, X, y=None):
         score : numpy.ndarray of (n_to_select_from_)
             :math:`\pi` importance for the given samples or features
         """
-        validate_data(self, X, y, reset=False)  # present for API consistency
+        if y is not None:
+            validate_data(self, X, y.ravel(), reset=False)
+        else:
+            validate_data(self, X, reset=False)  # present for API consistency
         return self.pi_
 
     def _init_greedy_search(self, X, y, n_to_select):
@@ -938,7 +943,10 @@ def score(self, X, y=None):
         -------
         hausdorff : Hausdorff distances
         """
-        validate_data(self, X, y, reset=False)
+        if y is not None:
+            validate_data(self, X, y.ravel(), reset=False)
+        else:
+            validate_data(self, X, reset=False)
         return self.hausdorff_
 
     def get_distance(self):
@@ -1101,7 +1109,11 @@ def score(self, X, y=None):
         -------
         hausdorff : Hausdorff distances
         """
-        validate_data(self, X, y, reset=False)
+        if y is not None:
+            validate_data(self, X, y.ravel(), reset=False)
+        else:
+            validate_data(self, X, reset=False)
+
         return self.hausdorff_
 
     def get_distance(self):
diff --git a/src/skmatter/decomposition/_kernel_pcovr.py b/src/skmatter/decomposition/_kernel_pcovr.py
index 825a0cf92..093195674 100644
--- a/src/skmatter/decomposition/_kernel_pcovr.py
+++ b/src/skmatter/decomposition/_kernel_pcovr.py
@@ -12,7 +12,7 @@
 from sklearn.utils import check_random_state
 from sklearn.utils._arpack import _init_arpack_v0
 from sklearn.utils.extmath import randomized_svd, stable_cumsum, svd_flip
-from sklearn.utils.validation import check_is_fitted, validate_data
+from sklearn.utils.validation import _check_n_features, check_is_fitted, validate_data
 
 from ..preprocessing import KernelNormalizer
 from ..utils import check_krr_fit, pcovr_kernel
@@ -347,7 +347,7 @@ def fit(self, X, Y, W=None):
             except NotFittedError:
                 self.regressor_.set_params(**regressor.get_params())
                 self.regressor_.X_fit_ = self.X_fit_
-                self.regressor_._check_n_features(self.X_fit_, reset=True)
+                _check_n_features(self.regressor_, self.X_fit_, reset=True)
         else:
             Yhat = Y.copy()
             if W is None:
diff --git a/src/skmatter/linear_model/_base.py b/src/skmatter/linear_model/_base.py
index 7f91508eb..0a2979d03 100644
--- a/src/skmatter/linear_model/_base.py
+++ b/src/skmatter/linear_model/_base.py
@@ -1,12 +1,12 @@
 import numpy as np
 from scipy.linalg import orthogonal_procrustes
-from sklearn.base import MultiOutputMixin, RegressorMixin
+from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
 from sklearn.linear_model import LinearRegression
 from sklearn.utils import check_array, check_X_y
 from sklearn.utils.validation import check_is_fitted
 
 
-class OrthogonalRegression(MultiOutputMixin, RegressorMixin):
+class OrthogonalRegression(MultiOutputMixin, RegressorMixin, BaseEstimator):
     r"""Orthogonal regression by solving the Procrustes problem
 
     Linear regression with the additional constraint that the weight matrix
diff --git a/src/skmatter/linear_model/_ridge.py b/src/skmatter/linear_model/_ridge.py
index 9dd5e1678..d16cce2ea 100644
--- a/src/skmatter/linear_model/_ridge.py
+++ b/src/skmatter/linear_model/_ridge.py
@@ -170,7 +170,7 @@ def fit(self, X, y):
                 "[0,1)"
             )
 
-        X, y = self._validate_data(X, y, y_numeric=True, multi_output=True)
+        X, y = validate_data(self, X, y, y_numeric=True, multi_output=True)
         self.n_samples_in_, self.n_features_in_ = X.shape
 
         # check_scoring uses estimators scoring function if the scorer is None, this is
diff --git a/src/skmatter/preprocessing/_data.py b/src/skmatter/preprocessing/_data.py
index ad329ac28..35ed36828 100644
--- a/src/skmatter/preprocessing/_data.py
+++ b/src/skmatter/preprocessing/_data.py
@@ -1,7 +1,12 @@
 import numpy as np
 from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.preprocessing._data import KernelCenterer
-from sklearn.utils.validation import FLOAT_DTYPES, _check_sample_weight, check_is_fitted
+from sklearn.utils.validation import (
+    FLOAT_DTYPES,
+    _check_sample_weight,
+    check_is_fitted,
+    validate_data,
+)
 
 
 class StandardFlexibleScaler(TransformerMixin, BaseEstimator):
@@ -128,7 +133,8 @@ def fit(self, X, y=None, sample_weight=None):
         self : object
             Fitted scaler.
         """
-        X = self._validate_data(
+        X = validate_data(
+            self,
             X,
             copy=self.copy,
             estimator=self,
@@ -181,7 +187,8 @@ def transform(self, X, y=None, copy=None):
             Transformed array.
         """
         copy = copy if copy is not None else self.copy
-        X = self._validate_data(
+        X = validate_data(
+            self,
             X,
             reset=False,
             copy=copy,
@@ -298,7 +305,7 @@ def fit(self, K, y=None, sample_weight=None):
         self : object
             Fitted transformer.
         """
-        K = self._validate_data(K, copy=True, dtype=FLOAT_DTYPES, reset=False)
+        K = validate_data(self, K, copy=True, dtype=FLOAT_DTYPES, reset=False)
 
         if sample_weight is not None:
             self.sample_weight_ = _check_sample_weight(sample_weight, K, dtype=K.dtype)
@@ -350,7 +357,7 @@ def transform(self, K, copy=True):
             Transformed array
         """
         check_is_fitted(self)
-        K = self._validate_data(K, copy=copy, dtype=FLOAT_DTYPES, reset=False)
+        K = validate_data(self, K, copy=copy, dtype=FLOAT_DTYPES, reset=False)
 
         if self.with_center:
             K_pred_cols = np.average(K, weights=self.sample_weight_, axis=1)[
@@ -391,7 +398,7 @@ def fit_transform(self, K, y=None, sample_weight=None, copy=True, **fit_params):
         return self.transform(K, copy)
 
 
-class SparseKernelCenterer(TransformerMixin):
+class SparseKernelCenterer(TransformerMixin, BaseEstimator):
     r"""Kernel centering method for sparse kernels, similar to
     :class:`KernelFlexibleCenterer`.
 
diff --git a/src/skmatter/utils/_pcovr_utils.py b/src/skmatter/utils/_pcovr_utils.py
index 8852a6386..29463b633 100644
--- a/src/skmatter/utils/_pcovr_utils.py
+++ b/src/skmatter/utils/_pcovr_utils.py
@@ -5,7 +5,7 @@
 from sklearn.exceptions import NotFittedError
 from sklearn.metrics.pairwise import pairwise_kernels
 from sklearn.utils.extmath import randomized_svd
-from sklearn.utils.validation import check_is_fitted
+from sklearn.utils.validation import check_is_fitted, validate_data
 
 
 def check_lr_fit(regressor, X, y):
@@ -39,10 +39,20 @@ def check_lr_fit(regressor, X, y):
         fitted_regressor = deepcopy(regressor)
 
         # Check compatibility with X
-        fitted_regressor._validate_data(X, y, reset=False, multi_output=True)
+        validate_data(fitted_regressor, X, y, reset=False, multi_output=True)
 
         # Check compatibility with y
+
+        # TO DO: This if statement is a band-aid for the case when we pass in a
+        # prefitted Ridge() or RidgeCV(), which, as of sklearn 1.6, will create
+        # coef_ with shape (n_features, ) even if fitted on a 2-D y with one target.
+        # In the future, we can optimize this block if LinearRegression() also changes.
+
         if fitted_regressor.coef_.ndim != y.ndim:
+            if y.ndim == 2:
+                if fitted_regressor.coef_.ndim == 1 and y.shape[1] == 1:
+                    return fitted_regressor
+
             raise ValueError(
                 "The regressor coefficients have a dimension incompatible with the "
                 "supplied target space. The coefficients have dimension "
@@ -103,7 +113,7 @@ def check_krr_fit(regressor, K, X, y):
         fitted_regressor = deepcopy(regressor)
 
         # Check compatibility with K
-        fitted_regressor._validate_data(X, y, reset=False, multi_output=True)
+        validate_data(fitted_regressor, X, y, reset=False, multi_output=True)
 
         # Check compatibility with y
         if fitted_regressor.dual_coef_.ndim != y.ndim:
diff --git a/tests/test_pcovr.py b/tests/test_pcovr.py
index 2059eed44..284a7e778 100644
--- a/tests/test_pcovr.py
+++ b/tests/test_pcovr.py
@@ -24,7 +24,7 @@ def __init__(self, *args, **kwargs):
 
         self.X, self.Y = get_dataset(return_X_y=True)
         self.X = StandardScaler().fit_transform(self.X)
-        self.Y = StandardScaler().fit_transform(np.vstack(self.Y))
+        self.Y = StandardScaler().fit_transform(np.vstack(self.Y)).ravel()
 
     def setUp(self):
         pass
@@ -69,7 +69,7 @@ def test_simple_reconstruction(self):
     def test_simple_prediction(self):
         """
         Check that PCovR with a full eigendecomposition at mixing=0
-        can fully reconstruct the input properties.
+        can reproduce a linear regression result.
         """
         for space in ["feature", "sample", "auto"]:
             with self.subTest(space=space):
@@ -481,32 +481,42 @@ def test_none_regressor(self):
         self.assertTrue(pcovr.regressor is None)
         self.assertTrue(pcovr.regressor_ is not None)
 
-    def test_incompatible_coef_shape(self):
-        # self.Y is 2D with one target
+    def test_incompatible_coef_dim(self):
+        # self.Y is 1D with one target
         # Don't need to test X shape, since this should
-        # be caught by sklearn's _validate_data
+        # be caught by sklearn's validate_data
+        Y_2D = np.column_stack((self.Y, self.Y))
         regressor = Ridge(alpha=1e-8, fit_intercept=False, tol=1e-12)
-        regressor.fit(self.X, self.Y)
+        regressor.fit(self.X, Y_2D)
 
         pcovr = self.model(mixing=0.5, regressor=regressor)
 
         # Dimension mismatch
         with self.assertRaises(ValueError) as cm:
-            pcovr.fit(self.X, np.zeros((self.Y.shape[0], 2)))
+            pcovr.fit(self.X, self.Y)
         self.assertEqual(
             str(cm.exception),
             "The regressor coefficients have a dimension incompatible with the "
-            "supplied target space. The coefficients have dimension 1 and the targets "
-            "have dimension 2",
+            "supplied target space. The coefficients have dimension 2 and the targets "
+            "have dimension 1",
         )
 
+    def test_incompatible_coef_shape(self):
         # Shape mismatch (number of targets)
+        Y_double = np.column_stack((self.Y, self.Y))
+        Y_triple = np.column_stack((Y_double, self.Y))
+
+        regressor = Ridge(alpha=1e-8, fit_intercept=False, tol=1e-12)
+        regressor.fit(self.X, Y_double)
+
+        pcovr = self.model(mixing=0.5, regressor=regressor)
+
         with self.assertRaises(ValueError) as cm:
-            pcovr.fit(self.X, np.column_stack((self.Y, self.Y)))
+            pcovr.fit(self.X, Y_triple)
         self.assertEqual(
             str(cm.exception),
             "The regressor coefficients have a shape incompatible with the supplied "
             "target space. The coefficients have shape %r and the targets have shape %r"
-            % (regressor.coef_.shape, np.column_stack((self.Y, self.Y)).shape),
+            % (regressor.coef_.shape, Y_triple.shape),
         )
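
Note on the recurring `_validate_data` -> `validate_data` change above: scikit-learn 1.6 removed the private `BaseEstimator._validate_data` method in favor of the free function `sklearn.utils.validation.validate_data`, which takes the estimator as its first argument. The sketch below is illustrative and not part of the patch; it shows a version-tolerant call site, plus the `reshape(-1, 1)` idiom the example scripts now need because `StandardScaler.inverse_transform` expects a 2-D array while `predict` returns a 1-D array for a single target.

    # Illustrative shim, not part of this patch: one call site that works on
    # both sides of the scikit-learn 1.6 boundary.
    import numpy as np
    from sklearn.preprocessing import StandardScaler

    try:
        # scikit-learn >= 1.6: validation is a free function taking the estimator
        from sklearn.utils.validation import validate_data
    except ImportError:
        # Older releases only expose the private method; delegate to it.
        def validate_data(estimator, X="no_validation", y="no_validation", **kwargs):
            return estimator._validate_data(X, y, **kwargs)

    # The reshape(-1, 1) idiom from the example scripts: inverse_transform
    # requires 2-D input, while predict() on a single-target y returns 1-D.
    y = np.arange(5.0)
    scaler = StandardScaler().fit(y.reshape(-1, 1))
    y_scaled = scaler.transform(y.reshape(-1, 1)).ravel()  # 1-D, like predict() output
    y_back = scaler.inverse_transform(y_scaled.reshape(-1, 1)).ravel()
    assert np.allclose(y, y_back)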
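
Background for the `TO DO` band-aid in `check_lr_fit`: per the comment in the patch, `Ridge` and `RidgeCV` on scikit-learn >= 1.6 store a 1-D `coef_` even when fitted on a 2-D `y` with a single column, while `LinearRegression` keeps the 2-D `(n_targets, n_features)` shape, so `coef_.ndim != y.ndim` alone no longer proves a real incompatibility. A minimal sketch of the asymmetry (output shapes as described by the patch comment, assuming scikit-learn >= 1.6):

    # Demonstrates why check_lr_fit special-cases a 1-D coef_ against a 2-D y.
    import numpy as np
    from sklearn.linear_model import LinearRegression, Ridge

    X = np.random.default_rng(0).normal(size=(20, 3))
    y_2d = (X @ np.array([1.0, -2.0, 0.5])).reshape(-1, 1)  # one target, 2-D

    ridge = Ridge().fit(X, y_2d)
    lr = LinearRegression().fit(X, y_2d)
    print(ridge.coef_.shape)  # (3,) on scikit-learn >= 1.6, per the patch comment
    print(lr.coef_.shape)     # (1, 3): LinearRegression preserves the 2-D shape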