From 6b2f8e345a9db4a268868d5554686b5a98018abd Mon Sep 17 00:00:00 2001
From: Valentin Laurent <valentin.laurent.fr@gmail.com>
Date: Mon, 27 Jan 2025 16:26:10 +0100
Subject: [PATCH] ENH: refine classification API

---
 mapie_v1/classification.py | 53 ++++++++++----------------------------
 1 file changed, 14 insertions(+), 39 deletions(-)

diff --git a/mapie_v1/classification.py b/mapie_v1/classification.py
index 183c37ac..710f27c4 100644
--- a/mapie_v1/classification.py
+++ b/mapie_v1/classification.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Optional, Union, List
+from typing import Optional, Union, Tuple, Iterable
 from typing_extensions import Self
 
 import numpy as np
@@ -16,7 +16,7 @@ class SplitConformalClassifier:
     def __init__(
         self,
         estimator: ClassifierMixin = LogisticRegression(),
-        confidence_level: Union[float, List[float]] = 0.9,
+        confidence_level: Union[float, Iterable[float]] = 0.9,
         conformity_score: Union[str, BaseClassificationScore] = "lac",
         prefit: bool = True,
         n_jobs: Optional[int] = None,
@@ -42,44 +42,31 @@ def conformalize(
         return self
 
     def predict(self, X: ArrayLike) -> NDArray:
-        """
-        Return
-        -----
-        Return ponctual prediction similar to predict method of
-        scikit-learn classifiers
-        Shape (n_samples,)
-        """
         return np.ndarray(0)
 
-    def predict_sets(
+    def predict_set(
         self,
         X: ArrayLike,
         conformity_score_params: Optional[dict] = None,
         # Parameters specific to conformal method,
         # For example: include_last_label
-    ) -> NDArray:
+    ) -> Tuple[NDArray, NDArray]:
         """
-        Return
-        -----
-        An array containing the prediction sets
-        Shape (n_samples, n_classes) if confidence_level is float,
-        Shape (n_samples, n_classes, confidence_level) if confidence_level
-        is a list of floats
+        Shape: (n, ), (n, n_class, n_confidence_level)
         """
-        return np.ndarray(0)
+        return np.ndarray(0), np.ndarray(0)
 
 
 class CrossConformalClassifier:
     def __init__(
         self,
         estimator: ClassifierMixin = LogisticRegression(),
-        confidence_level: Union[float, List[float]] = 0.9,
+        confidence_level: Union[float, Iterable[float]] = 0.9,
         conformity_score: Union[str, BaseClassificationScore] = "lac",
-        cross_val: Union[BaseCrossValidator, str] = 5,
+        cv: Union[int, BaseCrossValidator] = 5,
         n_jobs: Optional[int] = None,
         verbose: int = 0,
         random_state: Optional[Union[int, np.random.RandomState]] = None,
-
     ) -> None:
         pass
 
@@ -95,34 +82,22 @@ def conformalize(
         self,
         X_conformalize: ArrayLike,
         y_conformalize: ArrayLike,
+        groups: Optional[ArrayLike] = None,
         predict_params: Optional[dict] = None
     ) -> Self:
         return self
 
-    def predict(self,
-                X: ArrayLike) -> NDArray:
-        """
-        Return
-        -----
-        Return ponctual prediction similar to predict method of
-        scikit-learn classifiers
-        Shape (n_samples,)
-        """
+    def predict(self, X: ArrayLike) -> NDArray:
         return np.ndarray(0)
 
-    def predict_sets(
+    def predict_set(
         self,
         X: ArrayLike,
         aggregation_method: Optional[str] = "mean",
         # How to aggregate the scores by the estimators on test data
         conformity_score_params: Optional[dict] = None
-    ) -> NDArray:
+    ) -> Tuple[NDArray, NDArray]:
         """
-        Return
-        -----
-        An array containing the prediction sets
-        Shape (n_samples, n_classes) if confidence_level is float,
-        Shape (n_samples, n_classes, confidence_level) if confidence_level
-        is a list of floats
+        Shape: (n, ), (n, n_class, n_confidence_level)
         """
-        return np.ndarray(0)
+        return np.ndarray(0), np.ndarray(0)