diff --git a/src/macest/classification/models.py b/src/macest/classification/models.py index bf5d545..f7a6d8d 100644 --- a/src/macest/classification/models.py +++ b/src/macest/classification/models.py @@ -77,17 +77,18 @@ class ModelWithConfidence: """This class creates a model which returns a prediction and a confidence interval.""" def __init__( - self, - point_pred_model: _ClassificationPointPredictionModel, - x_train: np.ndarray, - y_train: Iterable[int], - macest_model_params: MacestConfModelParams = MacestConfModelParams(), - precomputed_neighbour_info: Optional[PrecomputedNeighbourInfo] = None, - graph: Optional[Dict[int, nmslib.dist.FloatIndex]] = None, - search_method_args: HnswGraphArgs = HnswGraphArgs(), - training_preds_by_class: Optional[Dict[int, np.ndarray]] = None, - verbose_training: bool = True, - empirical_conflict_constant: float = 0.5, + self, + point_pred_model: _ClassificationPointPredictionModel, + x_train: np.ndarray, + y_train: Iterable[int], + macest_model_params: MacestConfModelParams = MacestConfModelParams(), + precomputed_neighbour_info: Optional[PrecomputedNeighbourInfo] = None, + graph: Optional[Dict[int, nmslib.dist.FloatIndex]] = None, + search_method_args: HnswGraphArgs = HnswGraphArgs(), + training_preds_by_class: Optional[Dict[int, np.ndarray]] = None, + verbose_training: bool = True, + empirical_conflict_constant: float = 0.5, + num_threads: Optional[int] = num_threads_available, ): """ Init. @@ -109,6 +110,7 @@ def __init__( :param verbose_training: If true, information such as training progress will be shown :param empirical_conflict_constant: Constant to set confidence conflicting predictions calculated during calibration + :param num_threads: Number of threads to use to query the HNSW graph """ self.point_pred_model = point_pred_model self.x_train = x_train @@ -122,6 +124,7 @@ def __init__( self.search_method_args = search_method_args self._check_consistent_search_method_args() self._check_data_consistent_with_search_args() + self.num_threads = num_threads self.training_preds_by_class = training_preds_by_class if training_preds_by_class is None: @@ -187,7 +190,7 @@ def build_class_graphs(self) -> Dict[int, nmslib.dist.FloatIndex]: return self.graph def calc_dist_to_neighbours( - self, x_star: np.ndarray, cls: int + self, x_star: np.ndarray, cls: int ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Calculate the distance to nearest neighbours, the index of them within the class graph and @@ -204,7 +207,9 @@ def calc_dist_to_neighbours( self.build_class_graphs() neighbours = np.array( self.graph[cls].knnQueryBatch( # type: ignore - x_star, k=self._num_neighbours, num_threads=num_threads_available + x_star, + k=self._num_neighbours, + num_threads=self.num_threads, ) ) class_dist = neighbours[:, 1, :].clip(min=10 ** -15) @@ -239,7 +244,7 @@ def calc_dist_to_neighbours( return neighbour_info def calc_linear_distance_error_func( - self, local_distance: np.ndarray, local_error: np.ndarray, + self, local_distance: np.ndarray, local_error: np.ndarray, ) -> Tuple[np.ndarray, np.ndarray]: """ Calculate the parametric linear distance function using the local error and distance. @@ -260,7 +265,7 @@ def calc_linear_distance_error_func( return dist, error def predict_proba( - self, x_star: np.ndarray, change_conflicts: bool = False, + self, x_star: np.ndarray, change_conflicts: bool = False, ) -> np.ndarray: """ Compute a confidence score for each class for a given points(s) x_star. @@ -286,7 +291,7 @@ def predict_proba( return relative_conf def predict_confidence_of_point_prediction( - self, x_star: np.ndarray, change_conflicts: bool = False, + self, x_star: np.ndarray, change_conflicts: bool = False, ) -> np.ndarray: """ Estimate a single confidence score, this represents the confidence of the point prediction @@ -312,7 +317,7 @@ def predict_confidence_of_point_prediction( return point_prediction_confidence def _calc_relative_distance_softmax_normalisation( - self, average_distance_error_func: np.ndarray, + self, average_distance_error_func: np.ndarray, ) -> np.ndarray: """ Take a vector of distance functions, we then scale these by the mean distance across @@ -332,7 +337,7 @@ def _calc_relative_distance_softmax_normalisation( return relative_conf def _renormalise_conf_with_empirical_constant( - self, x_star: np.ndarray, conf_array: np.ndarray + self, x_star: np.ndarray, conf_array: np.ndarray ) -> np.ndarray: """ Change conflicting predictions to the empirically learnt constant probability learnt during \ @@ -374,11 +379,11 @@ def find_conflicting_predictions(self, x_star: np.ndarray) -> np.ndarray: return conflicting_predictions def fit( - self, - x_cal: np.ndarray, - y_cal: np.ndarray, - param_range: SearchBounds = SearchBounds(), - optimiser_args: Optional[Dict[Any, Any]] = None, + self, + x_cal: np.ndarray, + y_cal: np.ndarray, + param_range: SearchBounds = SearchBounds(), + optimiser_args: Optional[Dict[Any, Any]] = None, ) -> None: """ Fit MACEst model using the calibration data. @@ -447,11 +452,11 @@ class _TrainingHelper(object): """Class which provides methods used when fitting MACEst model.""" def __init__( - self, - init_conf_model: ModelWithConfidence, - x_cal: np.ndarray, - y_cal: np.ndarray, - param_range: SearchBounds = SearchBounds(), + self, + init_conf_model: ModelWithConfidence, + x_cal: np.ndarray, + y_cal: np.ndarray, + param_range: SearchBounds = SearchBounds(), ): """ Init. @@ -506,7 +511,9 @@ def _precompute_neighbours(self) -> PrecomputedNeighbourInfo: max_neighbours = np.array( self.model.graph[class_num].knnQueryBatch( # type: ignore - self.x_cal, k=max_nbrs, num_threads=num_threads_available + self.x_cal, + k=max_nbrs, + num_threads=self.model.num_threads, ) ) max_dist = max_neighbours[x_cal_len_array, 1] @@ -573,9 +580,9 @@ def loss(self, params: MacestConfModelParams) -> float: return expected_calibration_error(self.model.point_preds, self.y_cal, pred_conf) def fit( - self, - optimiser: Literal["de"] = "de", - optimiser_args: Optional[Dict[Any, Any]] = None, + self, + optimiser: Literal["de"] = "de", + optimiser_args: Optional[Dict[Any, Any]] = None, ) -> ModelWithConfidence: """ Fit MACEst model using the calibration data. diff --git a/src/macest/regression/models.py b/src/macest/regression/models.py index be78b92..9ce5d5d 100644 --- a/src/macest/regression/models.py +++ b/src/macest/regression/models.py @@ -78,6 +78,7 @@ def __init__( prec_point_preds: Optional[np.ndarray] = None, prec_graph: Optional[nmslib.dist.FloatIndex] = None, search_method_args: HnswGraphArgs = HnswGraphArgs(), + num_threads: Optional[int] = num_threads_available, ): """ Init. @@ -95,6 +96,7 @@ def __init__( :param prec_distance_to_nn: The pre-computed nearest neighbour distances for the calibration and test data :param prec_ind_of_nn: The pre-computed nearest neighbour indices for the calibration and test data :param prec_graph: The pre-computed graph to use for online hnsw search + :param num_threads: Number of threads to use to query the HNSW graph """ self.model = model self.x_train = x_train @@ -117,6 +119,7 @@ def __init__( self.search_method_args = search_method_args self._check_consistent_search_method_args() self._check_data_consistent_with_search_args() + self.num_threads = num_threads def predict(self, x_star: np.ndarray) -> np.ndarray: """ @@ -155,7 +158,9 @@ def calc_nn_dist(self, x_star: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: neighbours = np.array( self.prec_graph.knnQueryBatch( - x_star, k=self._num_neighbours, num_threads=num_threads_available + x_star, + k=self._num_neighbours, + num_threads=self.num_threads, ) ) dist = neighbours[:, 1, :] @@ -436,7 +441,9 @@ def _prec_neighbours(self) -> Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray] max_neighbours = np.array( self.prec_graph.knnQueryBatch( - self.x_cal, k=int(max_nbrs), num_threads=num_threads_available + self.x_cal, + k=int(max_nbrs), + num_threads=self.model.num_threads, ) )