Skip to content

Commit f340129

Browse files
committed
ENH: Added cross-validation to the module OnPLs.resampling.
1 parent 4ef5e44 commit f340129

File tree

4 files changed

+73
-7
lines changed

4 files changed

+73
-7
lines changed

OnPLS/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
"""
1414
from OnPLS import consts
1515
from OnPLS import estimators
16+
from OnPLS import resampling
1617
from OnPLS import utils
1718

1819
__version__ = "0.0.1"
1920

20-
__all__ = ["consts", "estimators", "utils"]
21+
__all__ = ["consts", "estimators", "resampling", "utils"]

OnPLS/estimators.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
except ValueError:
2828
import OnPLS.consts as consts # When run as a program.
2929

30-
__all__ = ["BaseEstimator",
30+
__all__ = ["BaseEstimator", "BaseUniblock", "BaseTwoblock", "BaseMultiblock",
3131

3232
"PCA", "nPLS", "OnPLS"]
3333

@@ -133,7 +133,19 @@ def warn(self, *strs):
133133
warnings.warn(warning)
134134

135135

136-
class PCA(BaseEstimator):
136+
class BaseUniblock(object):
137+
pass
138+
139+
140+
class BaseTwoblock(object):
141+
pass
142+
143+
144+
class BaseMultiblock(object):
145+
pass
146+
147+
148+
class PCA(BaseUniblock, BaseEstimator):
137149
"""A NIPALS implementation of principal components analysis.
138150
139151
Parameters
@@ -358,7 +370,7 @@ def _deflate(self, X, p, t):
358370
return X - np.dot(t, p.T)
359371

360372

361-
class nPLS(BaseEstimator):
373+
class nPLS(BaseMultiblock, BaseEstimator):
362374
"""The nPLS method for multiblock data analysis.
363375
364376
Parameters
@@ -684,7 +696,7 @@ def generateA(self, X, psd=True):
684696
return A
685697

686698

687-
class OnPLS(BaseEstimator):
699+
class OnPLS(BaseMultiblock, BaseEstimator):
688700
"""The OnPLS method for multiblock data analysis.
689701
690702
Parameters
@@ -1033,7 +1045,7 @@ def predict(self, X, which=[], return_scores=False):
10331045
if woi is None:
10341046
continue
10351047
poi = self.Po[i]
1036-
ki = self.Wo[0].shape[1]
1048+
ki = self.Wo[i].shape[1]
10371049
for k in range(ki):
10381050
woik = woi[:, [k]]
10391051
poik = poi[:, [k]]
@@ -1075,8 +1087,8 @@ def predict(self, X, which=[], return_scores=False):
10751087
Ti = T[i]
10761088
Tik = Ti[:, [k]]
10771089
Tiks.append(Tik)
1090+
Tiks = np.hstack(Tiks)
10781091
if len(Tiks) > 0:
1079-
Tiks = np.hstack(Tiks)
10801092
beta = np.dot(np.linalg.pinv(Tiks), Tw[:, [k]])
10811093
Thatwk = np.dot(Tiks, beta)
10821094
Xhatw = Xhatw + np.dot(Thatwk, Pw.T)

OnPLS/resampling.py

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
The :mod:`OnPLS.resampling` module contains resampling methods that can be used
4+
to determine statistics from an estimator.
5+
6+
Created on Wed Nov 23 21:02:03 2016
7+
8+
Copyright (c) 2016, Tommy Löfstedt. All rights reserved.
9+
10+
@author: Tommy Löfstedt
11+
12+
@license: BSD 3-clause.
13+
"""
14+
import numpy as np
15+
16+
import OnPLS.estimators as estimators
17+
18+
__all__ = ["cross_validation"]
19+
20+
21+
def cross_validation(estimator, X, cv_rounds=7, random_state=None):
22+
23+
if isinstance(estimator, estimators.BaseUniblock):
24+
if isinstance(X, np.ndarray):
25+
X = [X]
26+
27+
n = len(X)
28+
29+
N = X[0].shape[0]
30+
31+
cv_rounds = min(max(1, N), int(cv_rounds))
32+
33+
scores = []
34+
for k in range(cv_rounds):
35+
36+
test_samples = list(range(k, N, cv_rounds))
37+
train_samples = list(set(range(0, N)).difference(test_samples))
38+
39+
Xtest = [0] * n
40+
Xtrain = [0] * n
41+
for i in range(n):
42+
Xi = X[i]
43+
Xtest[i] = Xi[test_samples, :]
44+
Xtrain[i] = Xi[train_samples, :]
45+
46+
estimator.fit(Xtrain)
47+
48+
score = estimator.score(Xtest)
49+
scores.append(score)
50+
51+
return scores

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,6 @@ Xhat = onpls.predict([X1, X2, X3])
8484

8585
# Compute prediction score
8686
score = onpls.score([X1, X2, X3])
87+
88+
cv_scores = OnPLS.resampling.cross_validation(onpls, [X1, X2, X3], cv_rounds=4)
8789
```

0 commit comments

Comments
 (0)