Skip to content

Commit c88bd51

Browse files
ENH: V1 conformalized quantile regressor implementation (#579)
ENH: V1 CQR implmentation --------- Co-authored-by: Valentin Laurent <[email protected]>
1 parent e42b265 commit c88bd51

File tree

3 files changed

+130
-45
lines changed

3 files changed

+130
-45
lines changed

mapie_v1/_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,11 @@ def check_if_X_y_different_from_fit(
5454

5555
def make_intervals_single_if_single_alpha(
5656
intervals: NDArray,
57-
alphas: List[float]
57+
alphas: Union[float, List[float]]
5858
) -> NDArray:
59-
if len(alphas) == 1:
59+
if isinstance(alphas, float):
60+
return intervals[:, :, 0]
61+
if isinstance(alphas, list) and len(alphas) == 1:
6062
return intervals[:, :, 0]
6163
return intervals
6264

mapie_v1/integration_tests/tests/test_regression.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from sklearn.ensemble import RandomForestRegressor
1111
from sklearn.linear_model import QuantileRegressor
1212
from sklearn.ensemble import GradientBoostingRegressor
13+
from sklearn.model_selection import train_test_split
1314

1415
from mapie.subsample import Subsample
1516
from mapie._typing import ArrayLike
@@ -306,22 +307,29 @@ def test_intervals_and_predictions_exact_equality_jackknife(params_jackknife):
306307
)
307308
gbr_models.append(estimator_)
308309

310+
sample_weight_train = train_test_split(
311+
X,
312+
y,
313+
sample_weight,
314+
test_size=0.4,
315+
random_state=RANDOM_STATE
316+
)[-2]
309317

310318
params_test_cases_quantile = [
311319
{
312320
"v0": {
313321
"alpha": 0.2,
314322
"cv": "split",
315323
"method": "quantile",
316-
"calib_size": 0.3,
324+
"calib_size": 0.4,
317325
"sample_weight": sample_weight,
318326
"random_state": RANDOM_STATE,
319327
},
320328
"v1": {
321329
"confidence_level": 0.8,
322330
"prefit": False,
323-
"test_size": 0.3,
324-
"fit_params": {"sample_weight": sample_weight},
331+
"test_size": 0.4,
332+
"fit_params": {"sample_weight": sample_weight_train},
325333
"random_state": RANDOM_STATE,
326334
},
327335
},
@@ -330,15 +338,15 @@ def test_intervals_and_predictions_exact_equality_jackknife(params_jackknife):
330338
"estimator": gbr_models,
331339
"cv": "prefit",
332340
"method": "quantile",
333-
"calib_size": 0.3,
341+
"calib_size": 0.2,
334342
"sample_weight": sample_weight,
335343
"optimize_beta": True,
336344
"random_state": RANDOM_STATE,
337345
},
338346
"v1": {
339347
"estimator": gbr_models,
340348
"prefit": True,
341-
"test_size": 0.3,
349+
"test_size": 0.2,
342350
"fit_params": {"sample_weight": sample_weight},
343351
"minimize_interval_width": True,
344352
"random_state": RANDOM_STATE,
@@ -418,12 +426,16 @@ def compare_model_predictions_and_intervals(
418426
v1_params: Dict = {},
419427
prefit: bool = False,
420428
test_size: Optional[float] = None,
429+
sample_weight: Optional[ArrayLike] = None,
421430
random_state: int = 42,
422431
) -> None:
423432

424433
if test_size is not None:
425434
X_train, X_conf, y_train, y_conf = train_test_split_shuffle(
426-
X, y, test_size=test_size, random_state=random_state
435+
X,
436+
y,
437+
test_size=test_size,
438+
random_state=random_state,
427439
)
428440
else:
429441
X_train, X_conf, y_train, y_conf = X, X, y, y

mapie_v1/regression.py

Lines changed: 108 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
from typing_extensions import Self
44

55
import numpy as np
6-
from sklearn.linear_model import LinearRegression, QuantileRegressor
6+
from sklearn.linear_model import LinearRegression
77
from sklearn.base import RegressorMixin, clone
88
from sklearn.model_selection import BaseCrossValidator
9+
from sklearn.pipeline import Pipeline
910

1011
from mapie.subsample import Subsample
1112
from mapie._typing import ArrayLike, NDArray
1213
from mapie.conformity_scores import BaseRegressionScore
13-
from mapie.regression import MapieRegressor
14+
from mapie.regression import MapieRegressor, MapieQuantileRegressor
1415
from mapie.utils import check_estimator_fit_predict
1516
from mapie_v1.conformity_scores._utils import (
1617
check_and_select_regression_conformity_score,
@@ -833,43 +834,54 @@ def predict(
833834

834835
class ConformalizedQuantileRegressor:
835836
"""
836-
A conformal quantile regression model that generates prediction intervals
837-
using quantile regression as the base estimator.
837+
A model that combines quantile regression with conformal prediction to
838+
generate reliable prediction intervals with specified coverage levels.
838839
839-
This approach provides prediction intervals by leveraging
840-
quantile predictions and applying conformal adjustments to ensure coverage.
840+
The `ConformalizedQuantileRegressor` leverages quantile regression as its
841+
base estimator to predict conditional quantiles of the target variable,
842+
and applies conformal adjustments to ensure prediction intervals achieve
843+
the desired confidence levels. This approach is particularly useful in
844+
uncertainty quantification for regression tasks.
841845
842846
Parameters
843847
----------
844-
estimator : RegressorMixin, default=QuantileRegressor()
845-
The base quantile regression estimator used to generate point and
846-
interval predictions.
847-
848-
confidence_level : Union[float, List[float]], default=0.9
848+
estimator : Union[`RegressorMixin`, `Pipeline`, \
849+
`List[Union[RegressorMixin, Pipeline]]`]
850+
The base quantile regression model(s) for estimating target quantiles.
851+
852+
- When `prefit=False` (default):
853+
A single quantile regression estimator (e.g., `QuantileRegressor`)
854+
or a pipeline that combines preprocessing and regression.
855+
Supported Regression estimators:
856+
857+
- ``sklearn.linear_model.QuantileRegressor``
858+
- ``sklearn.ensemble.GradientBoostingRegressor``
859+
- ``sklearn.ensemble.HistGradientBoostingRegressor``
860+
- ``lightgbm.LGBMRegressor``
861+
862+
- When `prefit=True`:
863+
A list of three fitted quantile regression estimators corresponding
864+
to lower, upper, and median quantiles. These estimators should be
865+
pre-trained with consistent quantile settings:
866+
867+
* ``lower quantile = 1 - confidence_level / 2``
868+
* ``upper quantile = confidence_level / 2``
869+
* ``median quantile = 0.5``
870+
871+
confidence_level : float default=0.9
849872
The confidence level(s) for the prediction intervals, indicating the
850-
desired coverage probability of the prediction intervals. If a float
851-
is provided, it represents a single confidence level. If a list,
852-
multiple prediction intervals for each specified confidence level
853-
are returned.
873+
desired coverage probability of the prediction intervals.
854874
855-
conformity_score : Union[str, BaseRegressionScore], default="absolute"
856-
The conformity score method used to calculate the conformity error.
857-
Valid options: TODO : reference here the valid options, once the list
858-
has been be created during the implementation.
859-
See: TODO : reference conformity score classes or documentation
860-
861-
A custom score function inheriting from BaseRegressionScore may also
862-
be provided.
863-
864-
random_state : Optional[Union[int, np.random.RandomState]], default=None
865-
A seed or random state instance to ensure reproducibility in any random
866-
operations within the regressor.
875+
prefit : bool, default=False
876+
If `True`, assumes the base estimators are already fitted.
877+
When set to `True`, the `fit` method cannot be called and the
878+
provided estimators should be pre-trained.
867879
868880
Methods
869881
-------
870882
fit(X_train, y_train, fit_params=None) -> Self
871-
Fits the base estimator to the training data and initializes internal
872-
parameters required for conformal prediction.
883+
Trains the base quantile regression estimator on the provided data.
884+
Not applicable if `prefit=True`.
873885
874886
conformalize(X_conf, y_conf, predict_params=None) -> Self
875887
Calibrates the model on provided data, adjusting the prediction
@@ -904,12 +916,29 @@ class ConformalizedQuantileRegressor:
904916

905917
def __init__(
906918
self,
907-
estimator: RegressorMixin = QuantileRegressor(),
908-
confidence_level: Union[float, List[float]] = 0.9,
909-
conformity_score: Union[str, BaseRegressionScore] = "absolute",
910-
random_state: Optional[Union[int, np.random.RandomState]] = None,
919+
estimator: Optional[
920+
Union[
921+
RegressorMixin,
922+
Pipeline,
923+
List[Union[RegressorMixin, Pipeline]]
924+
]
925+
] = None,
926+
confidence_level: float = 0.9,
927+
prefit: bool = False,
911928
) -> None:
912-
pass
929+
930+
self._alpha = 1 - confidence_level
931+
self.prefit = prefit
932+
933+
cv: str = "prefit" if prefit else "split"
934+
self._mapie_quantile_regressor = MapieQuantileRegressor(
935+
estimator=estimator,
936+
method="quantile",
937+
cv=cv,
938+
alpha=self._alpha,
939+
)
940+
941+
self._sample_weight: Optional[NDArray] = None
913942

914943
def fit(
915944
self,
@@ -937,6 +966,27 @@ def fit(
937966
Self
938967
The fitted ConformalizedQuantileRegressor instance.
939968
"""
969+
970+
if self.prefit:
971+
raise ValueError(
972+
"The estimators are already fitted, the .fit() method should"
973+
" not be called with prefit=True."
974+
)
975+
976+
if fit_params:
977+
fit_params_ = copy.deepcopy(fit_params)
978+
self._sample_weight = fit_params_.pop("sample_weight", None)
979+
else:
980+
fit_params_ = {}
981+
982+
self._mapie_quantile_regressor._initialize_fit_conformalize()
983+
self._mapie_quantile_regressor._fit_estimators(
984+
X=X_train,
985+
y=y_train,
986+
sample_weight=self._sample_weight,
987+
**fit_params_,
988+
)
989+
940990
return self
941991

942992
def conformalize(
@@ -948,7 +998,7 @@ def conformalize(
948998
"""
949999
Calibrates the model on the provided data, adjusting the prediction
9501000
intervals based on quantile predictions and specified confidence
951-
levels. This step analyzes the conformity scores and refines the
1001+
level. This step analyzes the conformity scores and refines the
9521002
intervals to ensure desired coverage.
9531003
9541004
Parameters
@@ -969,6 +1019,14 @@ def conformalize(
9691019
The ConformalizedQuantileRegressor instance with calibrated
9701020
prediction intervals.
9711021
"""
1022+
self.predict_params = predict_params if predict_params else {}
1023+
1024+
self._mapie_quantile_regressor.conformalize(
1025+
X_conf,
1026+
y_conf,
1027+
**self.predict_params
1028+
)
1029+
9721030
return self
9731031

9741032
def predict_set(
@@ -1007,7 +1065,18 @@ def predict_set(
10071065
Prediction intervals with shape `(n_samples, 2)`, with lower
10081066
and upper bounds for each sample.
10091067
"""
1010-
return np.ndarray(0)
1068+
_, intervals = self._mapie_quantile_regressor.predict(
1069+
X,
1070+
optimize_beta=minimize_interval_width,
1071+
allow_infinite_bounds=allow_infinite_bounds,
1072+
symmetry=symmetric_intervals,
1073+
**self.predict_params
1074+
)
1075+
1076+
return make_intervals_single_if_single_alpha(
1077+
intervals,
1078+
self._alpha
1079+
)
10111080

10121081
def predict(
10131082
self,
@@ -1026,7 +1095,9 @@ def predict(
10261095
NDArray
10271096
Array of point predictions with shape `(n_samples,)`.
10281097
"""
1029-
return np.ndarray(0)
1098+
estimator = self._mapie_quantile_regressor
1099+
predictions, _ = estimator.predict(X, **self.predict_params)
1100+
return predictions
10301101

10311102

10321103
class GibbsConformalRegressor:

0 commit comments

Comments
 (0)