1616from sklearn .utils .validation import check_is_fitted , validate_data
1717from skmatter .decomposition import _BasePCov
1818from skmatter .utils import check_cl_fit
19+ from skmatter .preprocessing import StandardFlexibleScaler
20+ import warnings
1921
2022
2123class PCovC (LinearClassifierMixin , _BasePCov ):
@@ -96,6 +98,14 @@ class PCovC(LinearClassifierMixin, _BasePCov):
9698 Tolerance for singular values computed by svd_solver == 'arpack'.
9799 Must be of range [0.0, infinity).
98100
101+ z_mean_tol: float, default=1e-12
102+ Tolerance for the column means of Z. A warning is issued in ``fit`` when the largest absolute column mean of Z exceeds this value.
103+ Must be of range [0.0, infinity).
104+
105+ z_var_tol: float, default=1.5
106+ Tolerance for the column variances of Z. A warning is issued in ``fit`` when the largest column variance of Z exceeds this value.
107+ Must be of range [0.0, infinity).
108+
99109 space: {'feature', 'sample', 'auto'}, default='auto'
100110 Whether to compute the PCovC in ``sample`` or ``feature`` space.
101111 The default is equal to ``sample`` when :math:`{n_{samples} < n_{features}}`
@@ -123,6 +133,9 @@ class PCovC(LinearClassifierMixin, _BasePCov):
123133 If None, ``sklearn.linear_model.LogisticRegression()``
124134 is used as the classifier.
125135
136+ scale_z: bool, default=False
137+ Whether to scale Z prior to eigendecomposition.
138+
126139 iterated_power : int or 'auto', default='auto'
127140 Number of iterations for the power method computed by
128141 svd_solver == 'randomized'.
@@ -143,6 +156,14 @@ class PCovC(LinearClassifierMixin, _BasePCov):
143156 Tolerance for singular values computed by svd_solver == 'arpack'.
144157 Must be of range [0.0, infinity).
145158
159+ z_mean_tol: float
160+ Tolerance for the column means of Z. A warning is issued in ``fit`` when the largest absolute column mean of Z exceeds this value.
161+ Must be of range [0.0, infinity).
162+
163+ z_var_tol: float
164+ Tolerance for the column variances of Z. A warning is issued in ``fit`` when the largest column variance of Z exceeds this value.
165+ Must be of range [0.0, infinity).
166+
146167 space: {'feature', 'sample', 'auto'}, default='auto'
147168 Whether to compute the PCovC in ``sample`` or ``feature`` space.
148169 The default is equal to ``sample`` when :math:`{n_{samples} < n_{features}}`
@@ -174,6 +195,9 @@ class PCovC(LinearClassifierMixin, _BasePCov):
174195 the projector, or weights, from the latent-space projection
175196 :math:`\mathbf{T}` to the class confidence scores :math:`\mathbf{Z}`
176197
198+ scale_z: bool
199+ Whether Z is scaled prior to eigendecomposition.
200+
177201 explained_variance_ : numpy.ndarray of shape (n_components,)
178202 The amount of variance explained by each of the selected components.
179203 Equal to n_components largest eigenvalues
@@ -208,8 +232,11 @@ def __init__(
208232 n_components = None ,
209233 svd_solver = "auto" ,
210234 tol = 1e-12 ,
235+ z_mean_tol = 1e-12 ,
236+ z_var_tol = 1.5 ,
211237 space = "auto" ,
212238 classifier = None ,
239+ scale_z = False ,
213240 iterated_power = "auto" ,
214241 random_state = None ,
215242 whiten = False ,
@@ -225,6 +252,9 @@ def __init__(
225252 whiten = whiten ,
226253 )
227254 self .classifier = classifier
255+ self .scale_z = scale_z
256+ self .z_mean_tol = z_mean_tol
257+ self .z_var_tol = z_var_tol
228258
229259 def fit (self , X , Y , W = None ):
230260 r"""Fit the model with X and Y.
@@ -291,7 +321,7 @@ def fit(self, X, Y, W=None):
291321 classifier = self .classifier
292322
293323 self .z_classifier_ = check_cl_fit (classifier , X , Y )
294- W = self .z_classifier_ .coef_ .T
324+ W = self .z_classifier_ .coef_ .T . copy ()
295325
296326 else :
297327 # If precomputed, use default classifier to predict Y from T
@@ -301,6 +331,28 @@ def fit(self, X, Y, W=None):
301331
302332 Z = X @ W
303333
334+ if self .scale_z :
335+ z_scaler = StandardFlexibleScaler ().fit (Z )
336+ Z = z_scaler .transform (Z )
337+ W /= z_scaler .scale_ .reshape (1 , - 1 )
338+
339+ z_means_ = np .mean (Z , axis = 0 )
340+ z_vars_ = np .var (Z , axis = 0 )
341+
342+ if np .max (np .abs (z_means_ )) > self .z_mean_tol :
343+ warnings .warn (
344+ "This class does not automatically center Z, and the column means "
345+ "of Z are greater than the supplied tolerance. We recommend scaling "
346+ "Z (and the weights) by setting `scale_z=True`."
347+ )
348+
349+ if np .max (z_vars_ ) > self .z_var_tol :
350+ warnings .warn (
351+ "This class does not automatically scale Z, and the column variances "
352+ "of Z are greater than the supplied tolerance. We recommend scaling "
353+ "Z (and the weights) by setting `scale_z=True`."
354+ )
355+
304356 if self .space_ == "feature" :
305357 self ._fit_feature_space (X , Y , Z )
306358 else :
0 commit comments