diff --git a/src/macest/classification/models.py b/src/macest/classification/models.py index bf5d545..3e1a2ec 100644 --- a/src/macest/classification/models.py +++ b/src/macest/classification/models.py @@ -78,9 +78,9 @@ class ModelWithConfidence: def __init__( self, - point_pred_model: _ClassificationPointPredictionModel, x_train: np.ndarray, y_train: Iterable[int], + point_pred_model: Optional[_ClassificationPointPredictionModel] = None, macest_model_params: MacestConfModelParams = MacestConfModelParams(), precomputed_neighbour_info: Optional[PrecomputedNeighbourInfo] = None, graph: Optional[Dict[int, nmslib.dist.FloatIndex]] = None, @@ -110,6 +110,9 @@ def __init__( :param empirical_conflict_constant: Constant to set confidence conflicting predictions calculated during calibration """ + if point_pred_model is None and training_preds_by_class is None: + raise ValueError("One of 'point_pred_model' or 'training_preds_by_class'" + " must be specified") self.point_pred_model = point_pred_model self.x_train = x_train self.y_train = y_train @@ -161,6 +164,8 @@ def predict(self, x_star: np.ndarray) -> np.ndarray: :param x_star: The point(s) at which we want to predict :return: A point prediction for the given x_star """ + if self.point_pred_model is None: + raise ValueError("Cannot predict as no 'point_pred_model' has been initialized") return self.point_pred_model.predict(x_star) def build_class_graphs(self) -> Dict[int, nmslib.dist.FloatIndex]: @@ -286,19 +291,26 @@ def predict_proba( return relative_conf def predict_confidence_of_point_prediction( - self, x_star: np.ndarray, change_conflicts: bool = False, + self, + x_star: np.ndarray, + prec_point_preds: Optional[np.ndarray] = None, + change_conflicts: bool = False, ) -> np.ndarray: """ Estimate a single confidence score, this represents the confidence of the point prediction being correct rather than a confidence score for each class. 
:param x_star: The point to predict confidently + :param prec_point_preds: The pre-computed model predictions :param change_conflicts: Boolean, true means conflicting predictions between macest and point prediction are set to an empirical constant :return: The confidence in the point prediction being correct """ + if prec_point_preds is not None: + self.point_preds = prec_point_preds + if self.point_preds is not None: point_prediction = self.point_preds else: @@ -377,6 +389,7 @@ def fit( self, x_cal: np.ndarray, y_cal: np.ndarray, + prec_point_preds: Optional[np.ndarray] = None, param_range: SearchBounds = SearchBounds(), optimiser_args: Optional[Dict[Any, Any]] = None, ) -> None: @@ -385,11 +398,15 @@ def fit( :param x_cal: Calibration data :param y_cal: Target values + :param prec_point_preds: The pre-computed model predictions :param param_range: The bounds within which to search for MACEst parameters :param optimiser_args: Any arguments for the optimiser (see scipy.optimize) :return: None """ + if prec_point_preds is not None: + self.point_preds = prec_point_preds + if optimiser_args is None: optimiser_args = {} @@ -471,7 +488,8 @@ def __init__( self.precomputed_index = self.precomputed_neighbours[1] self.precomputed_error = self.precomputed_neighbours[2] self._n_classes = len(np.unique(self.model.y_train)) - self.model.point_preds = self.model.predict(self.x_cal) + if self.model.point_preds is None: + self.model.point_preds = self.model.predict(self.x_cal) self.model.distance_to_neighbours = self.precomputed_distance self.model.index_of_neighbours = self.precomputed_index self.model.error_on_neighbours = self.precomputed_error @@ -576,12 +594,15 @@ def fit( self, optimiser: Literal["de"] = "de", optimiser_args: Optional[Dict[Any, Any]] = None, + update_empirical_conflict_constant: bool = True, ) -> ModelWithConfidence: """ Fit MACEst model using the calibration data. 
:param optimiser: The optimisation method :param optimiser_args: Any arguments for the optimisation strategy + :param update_empirical_conflict_constant: Boolean, true means the constant to set + confidence conflicting predictions will be updated at the end of fit :return: A ModelWithConfidence object with the parameters that minimises the loss function """ @@ -619,11 +640,12 @@ def fit( self.model.macest_model_params = self.set_macest_model_params() - point_preds = self.model.predict(self.x_cal) - conflicts = self.model.find_conflicting_predictions(self.x_cal) - self.model.empirical_conflict_constant = np.array( - point_preds[conflicts] == self.y_cal[conflicts] - ).mean() + if update_empirical_conflict_constant: + point_preds = self.model.predict(self.x_cal) + conflicts = self.model.find_conflicting_predictions(self.x_cal) + self.model.empirical_conflict_constant = np.array( + point_preds[conflicts] == self.y_cal[conflicts] + ).mean() self.model.distance_to_neighbours = None self.model.index_of_neighbours = None diff --git a/src/macest/regression/models.py b/src/macest/regression/models.py index be78b92..6360823 100644 --- a/src/macest/regression/models.py +++ b/src/macest/regression/models.py @@ -68,9 +68,9 @@ class ModelWithPredictionInterval: def __init__( self, - model: _RegressionPointPredictionModel, x_train: np.ndarray, train_err: np.ndarray, + model: Optional[_RegressionPointPredictionModel] = None, macest_model_params: MacestPredIntervalModelParams = MacestPredIntervalModelParams(), error_dist: Literal["normal", "laplace"] = "normal", dist_func: Literal["linear", "error_weighted_poly"] = "linear", @@ -96,6 +96,8 @@ def __init__( :param prec_ind_of_nn: The pre-computed nearest neighbour indices for the calibration and test data :param prec_graph: The pre-computed graph to use for online hnsw search """ + if model is None and prec_point_preds is None: + raise ValueError("One of 'model' or 'prec_point_preds' must be specified") self.model = model 
self.x_train = x_train self.train_err = train_err @@ -126,6 +128,8 @@ def predict(self, x_star: np.ndarray) -> np.ndarray: :return: pred_star : The point prediction for x_star """ + if self.model is None: + raise ValueError("Cannot predict as no 'model' has been initialized") pred_star = self.model.predict(x_star) return pred_star @@ -280,16 +284,23 @@ def _distribution(self, x_star: np.ndarray) -> laplace_gen: return dist def predict_interval( - self, x_star: np.ndarray, conf_level: Union[np.ndarray, int, float] = 90, + self, + x_star: np.ndarray, + prec_point_preds: Optional[np.ndarray] = None, + conf_level: Union[np.ndarray, int, float] = 90, ) -> np.ndarray: """ Predict the upper and lower prediction interval bounds for a given confidence level. :param x_star: The position for which we would like to predict + :param prec_point_preds: The pre-computed model predictions :param conf_level: :return: The confidence bounds for each x_star for each confidence level """ + if prec_point_preds is not None: + self.point_preds = prec_point_preds + dist = self._distribution(x_star) lower_perc = (100 - conf_level) / 2 upper_perc = 100 - lower_perc @@ -330,6 +341,7 @@ def fit( self, x_cal: np.ndarray, y_cal: np.ndarray, + prec_point_preds: Optional[np.ndarray] = None, param_range: SearchBounds = SearchBounds(), optimiser_args: Optional[Dict[Any, Any]] = None, ) -> None: @@ -338,11 +350,15 @@ def fit( :param x_cal: Calibration data :param y_cal: Target values + :param prec_point_preds: The pre-computed model predictions :param param_range: The bounds within which to search for MACEst parameters :param optimiser_args: Any arguments for the optimiser (see scipy.optimize) :return: None """ + if prec_point_preds is not None: + self.point_preds = prec_point_preds + if optimiser_args is None: optimiser_args = {} @@ -418,7 +434,8 @@ def __init__( self.prec_graph = self.model.build_graph() self.model.prec_graph = self.prec_graph self.prec_dist, self.prec_ind = self._prec_neighbours() 
- self.model.point_preds = self.model.predict(self.x_cal) + if self.model.point_preds is None: + self.model.point_preds = self.model.predict(self.x_cal) def _prec_neighbours(self) -> Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray]]: """