22Base module.
33"""
44
5- from sklearn .base import BaseEstimator
5+ from sklearn .base import BaseEstimator , ClassifierMixin
66from sklearn .utils .extmath import stable_cumsum
77from sklearn .utils .validation import _is_arraylike , check_is_fitted
88from sklearn .metrics import roc_auc_score , roc_curve , precision_recall_curve
@@ -464,7 +464,7 @@ def get_mahalanobis_matrix(self):
464464 return self .components_ .T .dot (self .components_ )
465465
466466
467- class _PairsClassifierMixin (BaseMetricLearner ):
467+ class _PairsClassifierMixin (BaseMetricLearner , ClassifierMixin ):
468468 """Base class for pairs learners.
469469
470470 Attributes
@@ -475,6 +475,7 @@ class _PairsClassifierMixin(BaseMetricLearner):
475475 classified as dissimilar.
476476 """
477477
478+ classes_ = np .array ([0 , 1 ])
478479 _tuple_size = 2 # number of points in a tuple, 2 for pairs
479480
480481 def predict (self , pairs ):
@@ -752,11 +753,12 @@ def _validate_calibration_params(strategy='accuracy', min_rate=None,
752753 'Got {} instead.' .format (type (beta )))
753754
754755
755- class _TripletsClassifierMixin (BaseMetricLearner ):
756+ class _TripletsClassifierMixin (BaseMetricLearner , ClassifierMixin ):
756757 """
757758 Base class for triplets learners.
758759 """
759760
761+ classes_ = np .array ([0 , 1 ])
760762 _tuple_size = 3 # number of points in a tuple, 3 for triplets
761763
762764 def predict (self , triplets ):
@@ -837,11 +839,12 @@ def score(self, triplets):
837839 return self .predict (triplets ).mean () / 2 + 0.5
838840
839841
840- class _QuadrupletsClassifierMixin (BaseMetricLearner ):
842+ class _QuadrupletsClassifierMixin (BaseMetricLearner , ClassifierMixin ):
841843 """
842844 Base class for quadruplets learners.
843845 """
844846
847+ classes_ = np .array ([0 , 1 ])
845848 _tuple_size = 4 # number of points in a tuple, 4 for quadruplets
846849
847850 def predict (self , quadruplets ):
0 commit comments