Skip to content

Commit

Permalink
ENH: refine classification API
Browse files Browse the repository at this point in the history
  • Loading branch information
Valentin-Laurent committed Jan 27, 2025
1 parent d13c71b commit 6b2f8e3
Showing 1 changed file with 14 additions and 39 deletions.
53 changes: 14 additions & 39 deletions mapie_v1/classification.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Optional, Union, List
from typing import Optional, Union, Tuple, Iterable
from typing_extensions import Self

import numpy as np
Expand All @@ -16,7 +16,7 @@ class SplitConformalClassifier:
def __init__(
self,
estimator: ClassifierMixin = LogisticRegression(),
confidence_level: Union[float, List[float]] = 0.9,
confidence_level: Union[float, Iterable[float]] = 0.9,
conformity_score: Union[str, BaseClassificationScore] = "lac",
prefit: bool = True,
n_jobs: Optional[int] = None,
Expand All @@ -42,44 +42,31 @@ def conformalize(
return self

def predict(self, X: ArrayLike) -> NDArray:
"""
Return
-----
Return ponctual prediction similar to predict method of
scikit-learn classifiers
Shape (n_samples,)
"""
return np.ndarray(0)

def predict_sets(
def predict_set(
self,
X: ArrayLike,
conformity_score_params: Optional[dict] = None,
# Parameters specific to conformal method,
# For example: include_last_label
) -> NDArray:
) -> Tuple[NDArray, NDArray]:
"""
Return
-----
An array containing the prediction sets
Shape (n_samples, n_classes) if confidence_level is float,
Shape (n_samples, n_classes, confidence_level) if confidence_level
is a list of floats
Shape: (n, ), (n, n_class, n_confidence_level)
"""
return np.ndarray(0)
return np.ndarray(0), np.ndarray(0)


class CrossConformalClassifier:
def __init__(
self,
estimator: ClassifierMixin = LogisticRegression(),
confidence_level: Union[float, List[float]] = 0.9,
confidence_level: Union[float, Iterable[float]] = 0.9,
conformity_score: Union[str, BaseClassificationScore] = "lac",
cross_val: Union[BaseCrossValidator, str] = 5,
cv: Union[int, BaseCrossValidator] = 5,
n_jobs: Optional[int] = None,
verbose: int = 0,
random_state: Optional[Union[int, np.random.RandomState]] = None,

) -> None:
pass

Expand All @@ -95,34 +82,22 @@ def conformalize(
self,
X_conformalize: ArrayLike,
y_conformalize: ArrayLike,
groups: Optional[ArrayLike] = None,
predict_params: Optional[dict] = None
) -> Self:
return self

def predict(self,
X: ArrayLike) -> NDArray:
"""
Return
-----
Return ponctual prediction similar to predict method of
scikit-learn classifiers
Shape (n_samples,)
"""
def predict(self, X: ArrayLike) -> NDArray:
return np.ndarray(0)

def predict_sets(
def predict_set(
self,
X: ArrayLike,
aggregation_method: Optional[str] = "mean",
# How to aggregate the scores by the estimators on test data
conformity_score_params: Optional[dict] = None
) -> NDArray:
) -> Tuple[NDArray, NDArray]:
"""
Return
-----
An array containing the prediction sets
Shape (n_samples, n_classes) if confidence_level is float,
Shape (n_samples, n_classes, confidence_level) if confidence_level
is a list of floats
Shape: (n, ), (n, n_class, n_confidence_level)
"""
return np.ndarray(0)
return np.ndarray(0), np.ndarray(0)

0 comments on commit 6b2f8e3

Please sign in to comment.