From de6855b9207dafd1911769722af4a3305148db14 Mon Sep 17 00:00:00 2001 From: Florent Rambaud Date: Fri, 10 Dec 2021 11:58:01 +0100 Subject: [PATCH] fix: choose lighter dtype for distance and error stored values Signed-off-by: Florent Rambaud --- src/macest/classification/models.py | 12 ++++++++---- src/macest/regression/models.py | 6 ++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/macest/classification/models.py b/src/macest/classification/models.py index bf5d545..11a67b0 100644 --- a/src/macest/classification/models.py +++ b/src/macest/classification/models.py @@ -205,7 +205,8 @@ def calc_dist_to_neighbours( neighbours = np.array( self.graph[cls].knnQueryBatch( # type: ignore x_star, k=self._num_neighbours, num_threads=num_threads_available - ) + ), + dtype='float32', ) class_dist = neighbours[:, 1, :].clip(min=10 ** -15) class_ind = neighbours[:, 0, :].astype(int) @@ -213,7 +214,8 @@ def calc_dist_to_neighbours( raise ValueError("training_preds_by_class has already been cached") class_preds = self.training_preds_by_class[cls] class_error = np.array( - [class_preds[class_ind[j]] != cls for j in range(x_star.shape[0])] + [class_preds[class_ind[j]] != cls for j in range(x_star.shape[0])], + dtype='bool', ) else: if self.distance_to_neighbours is None: @@ -507,7 +509,8 @@ def _precompute_neighbours(self) -> PrecomputedNeighbourInfo: max_neighbours = np.array( self.model.graph[class_num].knnQueryBatch( # type: ignore self.x_cal, k=max_nbrs, num_threads=num_threads_available - ) + ), + dtype='float32', ) max_dist = max_neighbours[x_cal_len_array, 1] max_ind = max_neighbours[x_cal_len_array, 0] @@ -519,7 +522,8 @@ def _precompute_neighbours(self) -> PrecomputedNeighbourInfo: [ cls_preds[ind[j].astype(int)] != class_num for j in range(self.x_cal.shape[0]) - ] + ], + dtype='bool', ) # type: ignore dist_dict[k] = dist diff --git a/src/macest/regression/models.py b/src/macest/regression/models.py index be78b92..c7e0719 100644 --- a/src/macest/regression/models.py +++ b/src/macest/regression/models.py @@ -156,7 +156,8 @@ def calc_nn_dist(self, x_star: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: neighbours = np.array( self.prec_graph.knnQueryBatch( x_star, k=self._num_neighbours, num_threads=num_threads_available - ) + ), + dtype='float32', ) dist = neighbours[:, 1, :] ind = neighbours[:, 0, :].astype(int) @@ -437,7 +438,8 @@ def _prec_neighbours(self) -> Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray] max_neighbours = np.array( self.prec_graph.knnQueryBatch( self.x_cal, k=int(max_nbrs), num_threads=num_threads_available - ) + ), + dtype='float32', ) max_dist = max_neighbours[x_cal_len_array, 1]