
Commit d56ccbd

add ridge example about regularization method
* compares ridge with sklearn LOO CV in terms of numerical accuracy and timings
1 parent f83b1f6 commit d56ccbd

File tree

2 files changed: +359 −2 lines changed

examples/regression/README.rst

+2 −2
@@ -1,2 +1,2 @@
-Orthogonal Regression
-=====================
+Regression
+==========
@@ -0,0 +1,357 @@
# %%

r"""
RidgeRegression2FoldCV for data with low effective rank
=======================================================
In this notebook we explain in more detail how
:class:`skmatter.linear_model.RidgeRegression2FoldCV` speeds up the
cross-validation that optimizes the regularization parameter ``alpha``, and
compare it with the existing solution in scikit-learn,
:class:`sklearn.linear_model.RidgeCV`.
:class:`skmatter.linear_model.RidgeRegression2FoldCV` was designed to
efficiently predict feature matrices, but it can also be useful for predicting
single targets.
"""
# %%
#

import time

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import RidgeCV
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold, train_test_split

from skmatter.linear_model import RidgeRegression2FoldCV


# %%

SEED = 12616
N_REPEAT_MICRO_BENCH = 5

# %%
# Numerical instabilities of sklearn leave-one-out CV
# ---------------------------------------------------
#
# In linear regression, the complexity of computing the weight matrix is
# theoretically bounded by the inversion of the covariance matrix. This is
# more costly when conducting regularized regression, wherein we need to
# optimize the regularization parameter in a cross-validation (CV) scheme,
# thereby recomputing the inverse for each parameter. scikit-learn offers an
# efficient leave-one-out CV (LOO CV) for its ridge regression which avoids
# these repeated computations [loocv]_. Because we needed an efficient ridge
# regression for the reconstruction measures in :py:mod:`skmatter.metrics`,
# we implemented :class:`skmatter.linear_model.RidgeRegression2FoldCV`, an
# efficient 2-fold CV ridge regression that computes a singular value
# decomposition (SVD) once and reuses it for all regularization parameters
# :math:`\lambda`. Assume we have the standard regression problem optimizing
# the weight matrix in
#
# .. math::
#
#     \begin{align}
#         \|\mathbf{X}\mathbf{W} - \mathbf{Y}\|
#     \end{align}
#
# Here :math:`\mathbf{Y}` can also be a matrix, as in the case of multi-target
# learning. In 2-fold cross-validation we first predict the targets of fold 2
# using fold 1 to estimate the weight matrix, and vice versa. Using the SVD,
# the estimation on fold 1 looks like this:
#
# .. math::
#
#     \begin{align}
#         &\mathbf{X}_1 = \mathbf{U}_1\mathbf{S}_1\mathbf{V}_1^T,
#             \qquad\qquad\qquad\quad
#             \textrm{feature matrix }\mathbf{X}\textrm{ for fold 1} \\
#         &\mathbf{W}_1(\lambda) = \mathbf{V}_1
#             \tilde{\mathbf{S}}_1(\lambda)^{-1} \mathbf{U}_1^T \mathbf{Y}_1,
#             \qquad
#             \textrm{weight matrix fitted on fold 1}\\
#         &\tilde{\mathbf{Y}}_2 = \mathbf{X}_2 \mathbf{W}_1,
#             \qquad\qquad\qquad\qquad
#             \textrm{prediction of }\mathbf{Y}\textrm{ for fold 2}
#     \end{align}
#
# The efficient 2-fold scheme in ``RidgeRegression2FoldCV`` reuses the matrices
#
# .. math::
#
#     \begin{align}
#         &\mathbf{A}_1 = \mathbf{X}_2 \mathbf{V}_1, \quad
#         \mathbf{B}_1 = \mathbf{U}_1^T \mathbf{Y}_1,
#     \end{align}
#
# for each fold, so the SVD does not have to be recomputed. The computational
# complexity after the initial SVD is thereby reduced to that of matrix
# multiplications.


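# %%
# As a purely illustrative sketch (this is *not* the skmatter implementation,
# and the helper ``sketch_2fold_predictions`` below is hypothetical), the
# reuse of the SVD across regularization parameters can be written in plain
# NumPy. We assume here a common Tikhonov convention in which the inverted
# singular values become :math:`s / (s^2 + \lambda)`.


def sketch_2fold_predictions(X1, Y1, X2, lambdas):
    """Predict the targets of fold 2 from models fitted on fold 1."""
    if Y1.ndim == 1:
        Y1 = Y1[:, None]  # treat a single target as a one-column matrix
    # SVD of the training fold, computed only once
    U1, S1, Vt1 = np.linalg.svd(X1, full_matrices=False)
    # matrices reused for every regularization parameter
    A1 = X2 @ Vt1.T  # X_2 V_1
    B1 = U1.T @ Y1  # U_1^T Y_1
    predictions = []
    for lam in lambdas:
        # regularized inverse of the singular values (assumed convention)
        S_inv = S1 / (S1**2 + lam)
        predictions.append(A1 @ (S_inv[:, None] * B1))
    return predictions


# a tiny demonstration on random data; only the shapes matter here
rng = np.random.default_rng(0)
X1_demo, X2_demo = rng.normal(size=(10, 4)), rng.normal(size=(10, 4))
Y1_demo = rng.normal(size=(10, 2))
preds = sketch_2fold_predictions(X1_demo, Y1_demo, X2_demo, lambdas=[1e-3, 1e-1])
print([p.shape for p in preds])

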
# %%
# We first create an artificial dataset.


X, y = make_regression(
    n_samples=1000,
    n_features=400,
    random_state=SEED,
)


# %%

# regularization parameters
alphas = np.geomspace(1e-12, 1e-1, 12)

# 2 folds for train and validation split
cv = KFold(n_splits=2, shuffle=True, random_state=SEED)

skmatter_ridge_2foldcv_cutoff = RidgeRegression2FoldCV(
    alphas=alphas, regularization_method="cutoff", cv=cv
)

skmatter_ridge_2foldcv_tikhonov = RidgeRegression2FoldCV(
    alphas=alphas, regularization_method="tikhonov", cv=cv
)

sklearn_ridge_2foldcv_tikhonov = RidgeCV(
    alphas=alphas, cv=cv, fit_intercept=False  # remove the influence of learning a bias
)

sklearn_ridge_loocv_tikhonov = RidgeCV(
    alphas=alphas, cv=None, fit_intercept=False  # remove the influence of learning a bias
)

# %%
# Now we run some simple benchmarks.


def micro_bench(ridge):
    global N_REPEAT_MICRO_BENCH, X, y
    timings = []
    train_mse = []
    test_mse = []
    for _ in range(N_REPEAT_MICRO_BENCH):
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, train_size=0.5, random_state=SEED
        )
        start = time.time()
        ridge.fit(X_train, y_train)
        end = time.time()
        timings.append(end - start)
        train_mse.append(mean_squared_error(y_train, ridge.predict(X_train)))
        test_mse.append(mean_squared_error(y_test, ridge.predict(X_test)))

    print(f" Time: {np.mean(timings)}s")
    print(f" Train MSE: {np.mean(train_mse)}")
    print(f" Test MSE: {np.mean(test_mse)}")


print("skmatter 2-fold CV cutoff")
micro_bench(skmatter_ridge_2foldcv_cutoff)
print()
print("skmatter 2-fold CV tikhonov")
micro_bench(skmatter_ridge_2foldcv_tikhonov)
print()
print("sklearn 2-fold CV tikhonov")
micro_bench(sklearn_ridge_2foldcv_tikhonov)
print()
print("sklearn leave-one-out CV")
micro_bench(sklearn_ridge_loocv_tikhonov)


# %%
# We can see that the leave-one-out CV estimate is completely off. Let us
# manually check each regularization parameter individually and compare it
# with the stored mean squared errors (MSE).


results = {}
results["sklearn 2-fold CV Tikhonov"] = {"MSE train": [], "MSE test": []}
results["sklearn LOO CV Tikhonov"] = {"MSE train": [], "MSE test": []}

X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=0.5, random_state=SEED
)


def get_train_test_error(estimator):
    global X_train, y_train, X_test, y_test
    estimator = estimator.fit(X_train, y_train)
    return (
        mean_squared_error(y_train, estimator.predict(X_train)),
        mean_squared_error(y_test, estimator.predict(X_test)),
    )


for i in range(len(alphas)):
    print(f"Computing step={i} using alpha={alphas[i]}")

    train_error, test_error = get_train_test_error(RidgeCV(alphas=[alphas[i]], cv=2))
    results["sklearn 2-fold CV Tikhonov"]["MSE train"].append(train_error)
    results["sklearn 2-fold CV Tikhonov"]["MSE test"].append(test_error)
    train_error, test_error = get_train_test_error(RidgeCV(alphas=[alphas[i]], cv=None))

    results["sklearn LOO CV Tikhonov"]["MSE train"].append(train_error)
    results["sklearn LOO CV Tikhonov"]["MSE test"].append(test_error)


# returns an array of errors, one error per fold/sample
# ndarray of shape (n_samples, n_alphas)
loocv_cv_train_error = (
    RidgeCV(
        alphas=alphas,
        cv=None,
        store_cv_values=True,
        scoring=None,  # defaults to mean squared error
        fit_intercept=False,
    )
    .fit(X_train, y_train)
    .cv_values_
)

results["sklearn LOO CV Tikhonov"]["MSE validation"] = np.mean(
    loocv_cv_train_error, axis=0
).tolist()


# %%

# We plot all the results.
plt.figure(figsize=(12, 8))
for i, items in enumerate(results.items()):
    method_name, errors = items

    plt.loglog(
        alphas,
        errors["MSE test"],
        label=f"{method_name} MSE test",
        color=f"C{i}",
        lw=3,
        alpha=0.9,
    )
    plt.loglog(
        alphas,
        errors["MSE train"],
        label=f"{method_name} MSE train",
        color=f"C{i}",
        lw=4,
        alpha=0.9,
        linestyle="--",
    )
    if "MSE validation" in errors.keys():
        plt.loglog(
            alphas,
            errors["MSE validation"],
            label=f"{method_name} MSE validation",
            color=f"C{i}",
            linestyle="dotted",
            lw=5,
        )
plt.ylim(1e-16, 1)
plt.xlabel("alphas (regularization parameter)")
plt.ylabel("MSE")

plt.legend()
plt.show()

# %%
# We can see that leave-one-out CV estimates the error incorrectly for low
# alpha values. That seems to be a numerical instability of the method. If we
# had limited our alphas to 1e-5, LOO CV would have reached accuracies similar
# to those of the 2-fold method.

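# %%
# As a quick, purely illustrative check (not part of the original benchmark),
# we can refit the LOO CV estimator on the restricted alpha grid suggested
# above and verify that it stays in a numerically stable regime:

stable_alphas = alphas[alphas >= 1e-5]
loocv_restricted = RidgeCV(alphas=stable_alphas, cv=None, fit_intercept=False)
loocv_restricted.fit(X_train, y_train)
print("Restricted LOO CV selected alpha:", loocv_restricted.alpha_)
print(
    "Restricted LOO CV test MSE:",
    mean_squared_error(y_test, loocv_restricted.predict(X_test)),
)
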
# %%
# **Important:** this is not a fully encompassing comparison that covers the
# parameter space sufficiently. We just want to note that in cases with a
# large number of features and low effective rank, the ridge solvers in
# skmatter can be numerically more stable while running at a comparable speed.

# %%
# Cutoff and Tikhonov regularization
# ----------------------------------
# When using a hard threshold as regularization (the ``cutoff`` parameter),
# the singular values below :math:`\lambda` are cut off, and the size of the
# matrices :math:`\mathbf{A}_1` and :math:`\mathbf{B}_1` can then be reduced,
# resulting in further computation time savings. This performance advantage of
# ``cutoff`` over ``tikhonov`` is visible when we predict multiple targets and
# use a regularization range that cuts off a lot of singular values. For that,
# we lower the effective rank of the feature matrix and use as regression task
# the prediction of a shuffled version of :math:`\mathbf{X}`, which gives us
# many targets.

X, y = make_regression(
    n_samples=1000,
    n_features=400,
    n_informative=400,
    effective_rank=5,  # decreasing effective rank
    tail_strength=1e-9,
    random_state=SEED,
)

idx = np.arange(X.shape[1])
np.random.seed(SEED)
np.random.shuffle(idx)
y = X.copy()[:, idx]

singular_values = np.linalg.svd(X, full_matrices=False)[1]

# %%

plt.loglog(singular_values)
plt.title("Singular values of our feature matrix X")
plt.axhline(1e-8, color="gray")
plt.xlabel("feature index")
plt.ylabel("singular value")
plt.show()

# %%
# We can see that a regularization value of 1e-8 cuts off a lot of singular
# values. This is crucial for the computational speed-up of the ``cutoff``
# regularization method.

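# %%
# To make this concrete (a small illustrative check, not part of the original
# example), we can count how many singular values survive the 1e-8 threshold
# drawn in the plot above:

n_kept = int(np.sum(singular_values >= 1e-8))
print(f"{n_kept} of {len(singular_values)} singular values lie above the cutoff")
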
# %%

# we use a denser range of regularization parameters to make
# the speed-up more visible
alphas = np.geomspace(1e-8, 1e-1, 20)

cv = KFold(n_splits=2, shuffle=True, random_state=SEED)

skmatter_ridge_2foldcv_cutoff = RidgeRegression2FoldCV(
    alphas=alphas,
    regularization_method="cutoff",
    cv=cv,
)

skmatter_ridge_2foldcv_tikhonov = RidgeRegression2FoldCV(
    alphas=alphas,
    regularization_method="tikhonov",
    cv=cv,
)

sklearn_ridge_loocv_tikhonov = RidgeCV(
    alphas=alphas, cv=None, fit_intercept=False  # remove the influence of learning a bias
)

print("skmatter 2-fold CV cutoff")
micro_bench(skmatter_ridge_2foldcv_cutoff)
print()
print("skmatter 2-fold CV tikhonov")
micro_bench(skmatter_ridge_2foldcv_tikhonov)
print()
print("sklearn LOO CV tikhonov")
micro_bench(sklearn_ridge_loocv_tikhonov)


# %%
# We also want to note that these benchmarks have large run-to-run deviations
# and that more robust benchmarking methods would be appropriate here.
# However, we do not do this, as we try to keep the computation of these
# examples as minimal as possible.

# %%
# References
# ----------
# .. [loocv] Rifkin, "Regularized Least Squares."
#    https://www.mit.edu/~9.520/spring07/Classes/rlsslides.pdf
