Skip to content

Commit

Permalink
adding cv as parameter to 2-fold ridge
Browse files Browse the repository at this point in the history
  • Loading branch information
agoscinski committed Aug 7, 2023
1 parent adeaac6 commit f83b1f6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
21 changes: 15 additions & 6 deletions src/skmatter/linear_model/_ridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from joblib import Parallel, delayed
from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
from sklearn.metrics import check_scoring
from sklearn.model_selection import KFold
from sklearn.model_selection import KFold, check_cv
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted

Expand Down Expand Up @@ -75,14 +75,20 @@ class RidgeRegression2FoldCV(BaseEstimator, MultiOutputMixin, RegressorMixin):
parameter in e.g. :obj:`numpy.linalg.lstsq`. Be aware that for every case
we always apply a small default cutoff dependent on the numerical
accuracy of the data type of ``X`` in the fitting function.
cv : cross-validation generator or an iterable, default=None
The first yield of the generator is used to determine the two folds.
If None, the data is split 0.5/0.5 into two folds using the arguments
:param shuffle: and :param random_state:.
shuffle : bool, default=True
Whether or not to shuffle the data before splitting.
If :param cv: is not None, this parameter is ignored.
random_state : int or RandomState instance, default=None
Controls the shuffling applied to the data before applying the split.
Pass an int for reproducible output across multiple function calls.
See
`random_state glossary from sklearn (external link) <https://scikit-learn.org/stable/glossary.html#term-random-state>`_
If :param cv: is not None, this parameter is ignored.
scoring : str, callable, default=None
A string (see model evaluation documentation) or
a scorer callable object / function with signature
Expand Down Expand Up @@ -115,6 +121,7 @@ def __init__(
alphas=(0.1, 1.0, 10.0),
alpha_type="absolute",
regularization_method="tikhonov",
cv=None,
scoring=None,
random_state=None,
shuffle=True,
Expand All @@ -123,6 +130,7 @@ def __init__(
self.alphas = np.asarray(alphas)
self.alpha_type = alpha_type
self.regularization_method = regularization_method
self.cv = cv
self.scoring = scoring
self.random_state = random_state
self.shuffle = shuffle
Expand Down Expand Up @@ -171,11 +179,12 @@ def fit(self, X, y):
else:
scorer = check_scoring(self, scoring=self.scoring, allow_none=False)

fold1_idx, fold2_idx = next(
KFold(
n_splits=2, shuffle=self.shuffle, random_state=self.random_state
).split(X)
)
if self.cv is None:
cv = KFold(n_splits=2, shuffle=self.shuffle, random_state=self.random_state)
else:
cv = check_cv(self.cv)

fold1_idx, fold2_idx = next(cv.split(X))
self.coef_ = self._2fold_cv(X, y, fold1_idx, fold2_idx, scorer)
return self

Expand Down
7 changes: 7 additions & 0 deletions tests/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,13 @@ def test_ridge_regression_2fold_relative_alpha_type_raise_error(self):
self.features_small, self.features_small
)

def test_ridge_regression_2fold_iterable_cv(self):
    """Fitting must work when ``cv`` is a plain iterable of index pairs."""
    # A single (fold 1, fold 2) pair of sample indices; only the first
    # yielded split is consumed by the estimator.
    folds = [([0, 1, 2, 3], [4, 5, 6])]
    model = RidgeRegression2FoldCV(alphas=[1], cv=folds)
    model.fit(self.features_small, self.features_small)

ridge_parameters = [
["absolute_tikhonov", "absolute", "tikhonov"],
["absolute_cutoff", "absolute", "cutoff"],
Expand Down

0 comments on commit f83b1f6

Please sign in to comment.