|
27 | 27 | except ValueError:
|
28 | 28 | import OnPLS.consts as consts # When run as a program.
|
29 | 29 |
|
30 |
| -__all__ = ["BaseEstimator", |
| 30 | +__all__ = ["BaseEstimator", "BaseUniblock", "BaseTwoblock", "BaseMultiblock", |
31 | 31 |
|
32 | 32 | "PCA", "nPLS", "OnPLS"]
|
33 | 33 |
|
@@ -133,7 +133,19 @@ def warn(self, *strs):
|
133 | 133 | warnings.warn(warning)
|
134 | 134 |
|
135 | 135 |
|
136 |
| -class PCA(BaseEstimator): |
| 136 | +class BaseUniblock(object): |
| 137 | + pass |
| 138 | + |
| 139 | + |
| 140 | +class BaseTwoblock(object): |
| 141 | + pass |
| 142 | + |
| 143 | + |
| 144 | +class BaseMultiblock(object): |
| 145 | + pass |
| 146 | + |
| 147 | + |
| 148 | +class PCA(BaseUniblock, BaseEstimator): |
137 | 149 | """A NIPALS implementation of principal components analysis.
|
138 | 150 |
|
139 | 151 | Parameters
|
@@ -358,7 +370,7 @@ def _deflate(self, X, p, t):
|
358 | 370 | return X - np.dot(t, p.T)
|
359 | 371 |
|
360 | 372 |
|
361 |
| -class nPLS(BaseEstimator): |
| 373 | +class nPLS(BaseMultiblock, BaseEstimator): |
362 | 374 | """The nPLS method for multiblock data analysis.
|
363 | 375 |
|
364 | 376 | Parameters
|
@@ -684,7 +696,7 @@ def generateA(self, X, psd=True):
|
684 | 696 | return A
|
685 | 697 |
|
686 | 698 |
|
687 |
| -class OnPLS(BaseEstimator): |
| 699 | +class OnPLS(BaseMultiblock, BaseEstimator): |
688 | 700 | """The OnPLS method for multiblock data analysis.
|
689 | 701 |
|
690 | 702 | Parameters
|
@@ -1033,7 +1045,7 @@ def predict(self, X, which=[], return_scores=False):
|
1033 | 1045 | if woi is None:
|
1034 | 1046 | continue
|
1035 | 1047 | poi = self.Po[i]
|
1036 |
| - ki = self.Wo[0].shape[1] |
| 1048 | + ki = self.Wo[i].shape[1] |
1037 | 1049 | for k in range(ki):
|
1038 | 1050 | woik = woi[:, [k]]
|
1039 | 1051 | poik = poi[:, [k]]
|
@@ -1075,8 +1087,8 @@ def predict(self, X, which=[], return_scores=False):
|
1075 | 1087 | Ti = T[i]
|
1076 | 1088 | Tik = Ti[:, [k]]
|
1077 | 1089 | Tiks.append(Tik)
|
| 1090 | + Tiks = np.hstack(Tiks) |
1078 | 1091 | if len(Tiks) > 0:
|
1079 |
| - Tiks = np.hstack(Tiks) |
1080 | 1092 | beta = np.dot(np.linalg.pinv(Tiks), Tw[:, [k]])
|
1081 | 1093 | Thatwk = np.dot(Tiks, beta)
|
1082 | 1094 | Xhatw = Xhatw + np.dot(Thatwk, Pw.T)
|
|
0 commit comments