Skip to content

Commit f83b1f6

Browse files
committed
adding cv as parameter to 2-fold ridge
1 parent adeaac6 commit f83b1f6

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

src/skmatter/linear_model/_ridge.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from joblib import Parallel, delayed
33
from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
44
from sklearn.metrics import check_scoring
5-
from sklearn.model_selection import KFold
5+
from sklearn.model_selection import KFold, check_cv
66
from sklearn.utils import check_array
77
from sklearn.utils.validation import check_is_fitted
88

@@ -75,14 +75,20 @@ class RidgeRegression2FoldCV(BaseEstimator, MultiOutputMixin, RegressorMixin):
7575
parameter in e.g. :obj:`numpy.linalg.lstsq`. Be aware that for every case
7676
we always apply a small default cutoff dependent on the numerical
7777
accuracy of the data type of ``X`` in the fitting function.
78+
cv : cross-validation generator or an iterable, default=None
79+
The first yield of the generator is used to determine the two folds.
80+
If None, the data is split into two equally sized folds, controlled by
81+
:param shuffle: and :param random_state:
7882
shuffle : bool, default=True
7983
Whether or not to shuffle the data before splitting.
84+
If :param cv: is not None, this parameter is ignored.
8085
random_state : int or RandomState instance, default=None
8186
Controls the shuffling applied to the data before applying the split.
8287
Pass an int for reproducible output across multiple function calls.
8388
See
8489
`random_state glossary from sklearn (external link) <https://scikit-learn.org/stable/glossary.html#term-random-state>`_
8590
parameter is ignored.
91+
If :param cv: is not None, this parameter is ignored.
8692
scoring : str, callable, default=None
8793
A string (see model evaluation documentation) or
8894
a scorer callable object / function with signature
@@ -115,6 +121,7 @@ def __init__(
115121
alphas=(0.1, 1.0, 10.0),
116122
alpha_type="absolute",
117123
regularization_method="tikhonov",
124+
cv=None,
118125
scoring=None,
119126
random_state=None,
120127
shuffle=True,
@@ -123,6 +130,7 @@ def __init__(
123130
self.alphas = np.asarray(alphas)
124131
self.alpha_type = alpha_type
125132
self.regularization_method = regularization_method
133+
self.cv = cv
126134
self.scoring = scoring
127135
self.random_state = random_state
128136
self.shuffle = shuffle
@@ -171,11 +179,12 @@ def fit(self, X, y):
171179
else:
172180
scorer = check_scoring(self, scoring=self.scoring, allow_none=False)
173181

174-
fold1_idx, fold2_idx = next(
175-
KFold(
176-
n_splits=2, shuffle=self.shuffle, random_state=self.random_state
177-
).split(X)
178-
)
182+
if self.cv is None:
183+
cv = KFold(n_splits=2, shuffle=self.shuffle, random_state=self.random_state)
184+
else:
185+
cv = check_cv(self.cv)
186+
187+
fold1_idx, fold2_idx = next(cv.split(X))
179188
self.coef_ = self._2fold_cv(X, y, fold1_idx, fold2_idx, scorer)
180189
return self
181190

tests/test_linear_model.py

+7
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,13 @@ def test_ridge_regression_2fold_relative_alpha_type_raise_error(self):
133133
self.features_small, self.features_small
134134
)
135135

136+
def test_ridge_regression_2fold_iterable_cv(self):
137+
# tests if we can use iterable as cv parameter
138+
cv = [([0, 1, 2, 3], [4, 5, 6])]
139+
RidgeRegression2FoldCV(alphas=[1], cv=cv).fit(
140+
self.features_small, self.features_small
141+
)
142+
136143
ridge_parameters = [
137144
["absolute_tikhonov", "absolute", "tikhonov"],
138145
["absolute_cutoff", "absolute", "cutoff"],

0 commit comments

Comments
 (0)