From 6b2f8e345a9db4a268868d5554686b5a98018abd Mon Sep 17 00:00:00 2001 From: Valentin Laurent <valentin.laurent.fr@gmail.com> Date: Mon, 27 Jan 2025 16:26:10 +0100 Subject: [PATCH] ENH: refine classification API --- mapie_v1/classification.py | 53 ++++++++++---------------------------- 1 file changed, 14 insertions(+), 39 deletions(-) diff --git a/mapie_v1/classification.py b/mapie_v1/classification.py index 183c37ac..710f27c4 100644 --- a/mapie_v1/classification.py +++ b/mapie_v1/classification.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Optional, Union, List +from typing import Optional, Union, Tuple, Iterable from typing_extensions import Self import numpy as np @@ -16,7 +16,7 @@ class SplitConformalClassifier: def __init__( self, estimator: ClassifierMixin = LogisticRegression(), - confidence_level: Union[float, List[float]] = 0.9, + confidence_level: Union[float, Iterable[float]] = 0.9, conformity_score: Union[str, BaseClassificationScore] = "lac", prefit: bool = True, n_jobs: Optional[int] = None, @@ -42,44 +42,31 @@ def conformalize( return self def predict(self, X: ArrayLike) -> NDArray: - """ - Return - ----- - Return ponctual prediction similar to predict method of - scikit-learn classifiers - Shape (n_samples,) - """ return np.ndarray(0) - def predict_sets( + def predict_set( self, X: ArrayLike, conformity_score_params: Optional[dict] = None, # Parameters specific to conformal method, # For example: include_last_label - ) -> NDArray: + ) -> Tuple[NDArray, NDArray]: """ - Return - ----- - An array containing the prediction sets - Shape (n_samples, n_classes) if confidence_level is float, - Shape (n_samples, n_classes, confidence_level) if confidence_level - is a list of floats + Shape: (n, ), (n, n_class, n_confidence_level) """ - return np.ndarray(0) + return np.ndarray(0), np.ndarray(0) class CrossConformalClassifier: def __init__( self, estimator: ClassifierMixin = LogisticRegression(), - confidence_level: Union[float, List[float]] = 0.9, + confidence_level: Union[float, Iterable[float]] = 0.9, conformity_score: Union[str, BaseClassificationScore] = "lac", - cross_val: Union[BaseCrossValidator, str] = 5, + cv: Union[int, BaseCrossValidator] = 5, n_jobs: Optional[int] = None, verbose: int = 0, random_state: Optional[Union[int, np.random.RandomState]] = None, - ) -> None: pass @@ -95,34 +82,22 @@ def conformalize( self, X_conformalize: ArrayLike, y_conformalize: ArrayLike, + groups: Optional[ArrayLike] = None, predict_params: Optional[dict] = None ) -> Self: return self - def predict(self, - X: ArrayLike) -> NDArray: - """ - Return - ----- - Return ponctual prediction similar to predict method of - scikit-learn classifiers - Shape (n_samples,) - """ + def predict(self, X: ArrayLike) -> NDArray: return np.ndarray(0) - def predict_sets( + def predict_set( self, X: ArrayLike, aggregation_method: Optional[str] = "mean", # How to aggregate the scores by the estimators on test data conformity_score_params: Optional[dict] = None - ) -> NDArray: + ) -> Tuple[NDArray, NDArray]: """ - Return - ----- - An array containing the prediction sets - Shape (n_samples, n_classes) if confidence_level is float, - Shape (n_samples, n_classes, confidence_level) if confidence_level - is a list of floats + Shape: (n, ), (n, n_class, n_confidence_level) """ - return np.ndarray(0) + return np.ndarray(0), np.ndarray(0)