diff --git a/aeon/clustering/_clarans.py b/aeon/clustering/_clarans.py index 2ee497366b..6d2d8298f6 100644 --- a/aeon/clustering/_clarans.py +++ b/aeon/clustering/_clarans.py @@ -128,7 +128,7 @@ def _fit_one_init(self, X: np.ndarray, max_neighbours: int): j = 0 X_indexes = np.arange(X.shape[0], dtype=int) if isinstance(self._init, Callable): - best_medoids = self._init(X) + best_medoids = self._init(X=X) else: best_medoids = self._init best_non_medoids = np.setdiff1d(X_indexes, best_medoids) diff --git a/aeon/clustering/_cluster_initialisation.py b/aeon/clustering/_cluster_initialisation.py new file mode 100644 index 0000000000..d950eed5cc --- /dev/null +++ b/aeon/clustering/_cluster_initialisation.py @@ -0,0 +1,437 @@ +"""Initialisation strategies for clustering. + +This file contains the various initialisation algorithms that can be used +with clustering algorithms. + +The functions with "indexes" in their names return the indexes of the +initialised clusters, while the functions with "centers" return the +actual centers. 
+""" + +from collections.abc import Callable +from functools import partial + +import numpy as np +from numpy.random import RandomState + +from aeon.distances import pairwise_distance + + +def _random_center_initialiser_indexes( + *, X: np.ndarray, n_clusters: int, random_state: RandomState +) -> np.ndarray: + return random_state.choice(X.shape[0], n_clusters, replace=False) + + +def _random_center_initialiser( + *, X: np.ndarray, n_clusters: int, random_state: RandomState +) -> np.ndarray: + return X[ + _random_center_initialiser_indexes( + X=X, n_clusters=n_clusters, random_state=random_state + ) + ] + + +def _first_center_initialiser_indexes( + *, X: np.ndarray, n_clusters: int, random_state: RandomState, **kwargs +) -> np.ndarray: + return np.arange(n_clusters) + + +def _first_center_initialiser( + *, X: np.ndarray, n_clusters: int, random_state: RandomState +) -> np.ndarray: + return X[ + _first_center_initialiser_indexes( + X=X, n_clusters=n_clusters, random_state=random_state + ) + ] + + +def _random_values_center_initialiser( + *, X: np.ndarray, n_clusters: int, random_state: RandomState, **kwargs +): + return random_state.rand(n_clusters, X.shape[-2], X.shape[-1]) + + +def _kmeans_plus_plus_center_initialiser_indexes( + *, + X: np.ndarray, + n_clusters: int, + random_state: RandomState, + distance: str | Callable, + distance_params: dict, + n_jobs: int = 1, + return_distance_and_labels: bool = False, + **kwargs, +) -> np.ndarray: + n_samples = X.shape[0] + initial_center_idx = random_state.randint(n_samples) + indexes = [initial_center_idx] + + min_distances = pairwise_distance( + X, + X[[initial_center_idx]], + method=distance, + n_jobs=n_jobs, + **distance_params, + ).reshape(n_samples) + + labels = np.zeros(n_samples, dtype=int) + + for i in range(1, n_clusters): + d = min_distances.copy() + chosen = np.asarray(indexes, dtype=int) + finite_mask = np.isfinite(d) + if not np.any(finite_mask): + candidates = np.setdiff1d(np.arange(n_samples), chosen, 
assume_unique=False) + next_center_idx = random_state.choice(candidates) + indexes.append(next_center_idx) + + new_distances = pairwise_distance( + X, + X[[next_center_idx]], + method=distance, + n_jobs=n_jobs, + **distance_params, + ).reshape(n_samples) + + closer_points = new_distances < min_distances + min_distances[closer_points] = new_distances[closer_points] + labels[closer_points] = i + continue + + min_val = d[finite_mask].min() + w = d - min_val + w[~np.isfinite(w)] = 0.0 + w = np.clip(w, 0.0, None) + w[chosen] = 0.0 + + total = w.sum() + if total <= 0.0: + candidates = np.setdiff1d(np.arange(n_samples), chosen, assume_unique=False) + next_center_idx = random_state.choice(candidates) + else: + p = w / total + p = np.clip(p, 0.0, None) + p_sum = p.sum() + if p_sum <= 0.0: + candidates = np.setdiff1d( + np.arange(n_samples), chosen, assume_unique=False + ) + next_center_idx = random_state.choice(candidates) + else: + p = p / p_sum + next_center_idx = random_state.choice(n_samples, p=p) + + indexes.append(next_center_idx) + + new_distances = pairwise_distance( + X, + X[[next_center_idx]], + method=distance, + n_jobs=n_jobs, + **distance_params, + ).reshape(n_samples) + + closer_points = new_distances < min_distances + min_distances[closer_points] = new_distances[closer_points] + labels[closer_points] = i + + if return_distance_and_labels: + return np.array(indexes), labels, min_distances + else: + return np.array(indexes) + + +def _kmeans_plus_plus_center_initialiser( + *, + X: np.ndarray, + n_clusters: int, + random_state: RandomState, + distance: str | Callable, + distance_params: dict, + n_jobs: int = 1, + return_distance_and_labels: bool = False, + **kwargs, +) -> np.ndarray: + indexes, labels, min_distances = _kmeans_plus_plus_center_initialiser_indexes( + X=X, + n_clusters=n_clusters, + random_state=random_state, + distance=distance, + distance_params=distance_params, + n_jobs=n_jobs, + return_distance_and_labels=True, + ) + if 
return_distance_and_labels: + return X[indexes], labels, min_distances + return X[indexes] + + +def _kmedoids_plus_plus_center_initialiser_indexes( + *, + X: np.ndarray, + n_clusters: int, + random_state: RandomState, + distance: str | Callable, + distance_params: dict, + n_jobs: int = 1, + **kwargs, +) -> np.ndarray: + """K-medoids++ initialisation that returns indexes. + + This uses a k-means++-style seeding procedure, but with medoids, + and supports potentially negative distances by shifting the + distance distribution to be non-negative. + """ + n_samples = X.shape[0] + initial_center_idx = random_state.randint(n_samples) + indexes = [initial_center_idx] + + # Initial minimum distances to the first medoid + min_distances = pairwise_distance( + X, + X[[initial_center_idx]], + method=distance, + n_jobs=n_jobs, + **distance_params, + ).reshape(n_samples) + + for _ in range(1, n_clusters): + d = min_distances.copy() + chosen = np.asarray(indexes, dtype=int) + + finite_mask = np.isfinite(d) + if not np.any(finite_mask): + candidates = np.setdiff1d(np.arange(n_samples), chosen, assume_unique=False) + next_center_idx = random_state.choice(candidates) + else: + min_val = d[finite_mask].min() + w = d - min_val + + w[~np.isfinite(w)] = 0.0 + + w = np.clip(w, 0.0, None) + + w[chosen] = 0.0 + + total = w.sum() + if total <= 0.0: + candidates = np.setdiff1d( + np.arange(n_samples), chosen, assume_unique=False + ) + next_center_idx = random_state.choice(candidates) + else: + p = w / total + p = np.clip(p, 0.0, None) + p_sum = p.sum() + if p_sum <= 0.0: + candidates = np.setdiff1d( + np.arange(n_samples), chosen, assume_unique=False + ) + next_center_idx = random_state.choice(candidates) + else: + p = p / p_sum + next_center_idx = random_state.choice(n_samples, p=p) + + indexes.append(next_center_idx) + + new_distances = pairwise_distance( + X, + X[[next_center_idx]], + method=distance, + n_jobs=n_jobs, + **distance_params, + ).reshape(n_samples) + + closer_points = 
new_distances < min_distances + min_distances[closer_points] = new_distances[closer_points] + + return np.array(indexes) + + +def _kmedoids_plus_plus_center_initialiser( + *, + X: np.ndarray, + n_clusters: int, + random_state: RandomState, + distance: str | Callable, + distance_params: dict, + n_jobs: int = 1, + **kwargs, +) -> np.ndarray: + """K-medoids++ initialisation that returns centers.""" + indexes = _kmedoids_plus_plus_center_initialiser_indexes( + X=X, + n_clusters=n_clusters, + random_state=random_state, + distance=distance, + distance_params=distance_params, + n_jobs=n_jobs, + ) + return X[indexes] + + +def resolve_center_initialiser( + init: str | np.ndarray, + X: np.ndarray, + n_clusters: int, + random_state: RandomState, + distance: str | Callable | None = None, + distance_params: dict | None = None, + n_jobs: int = 1, + custom_init_handlers: dict | None = None, + use_indexes: bool = False, +) -> Callable | np.ndarray: + """Resolve the center initialiser function or array from init parameter. + + Parameters + ---------- + X : 3D np.ndarray + Input data, any number of channels, equal length series of shape ``( + n_cases, n_channels, n_timepoints)`` + or 2D np.array (univariate, equal length series) of shape + ``(n_cases, n_timepoints)`` + or list of numpy arrays (any number of channels, unequal length series) + of shape ``[n_cases]``, 2D np.array ``(n_channels, n_timepoints_i)``, + where ``n_timepoints_i`` is length of series ``i``. Other types are + allowed and converted into one of the above. + n_clusters : int + The number of clusters to form as well as the number of centroids to generate. + random_state : RandomState + Random number generator that is bound to the returned initialiser function. + distance : str or callable, optional + Distance method to compute similarity between time series. A list of valid + strings for measures can be found in the documentation for + :func:`aeon.distances.get_distance_function`. 
If a callable is passed it must be + a function that takes two 2d numpy arrays as input and returns a float. + distance_params : dict, default=None + Dictionary containing kwargs for the distance being used. For example if you + wanted to specify a window for DTW you would pass + distance_params={"window": 0.2}. See documentation of aeon.distances for more + details. + n_jobs : int, default=1 + The number of jobs to run in parallel. If -1, then the number of jobs is set + to the number of CPU cores. If 1, then the function is executed in a single + thread. If greater than 1, then the function is executed in parallel. + custom_init_handlers : dict, default=None + Custom initialisation functions keyed by init string; checked before built-ins. + use_indexes : bool, default=False + When True, return the initialiser that yields centre indexes; when False, + return the initialiser that yields the centres themselves. + + Returns + ------- + Callable | np.ndarray + If an ndarray is specified as init then the validated ndarray is returned. + If a string is passed then the corresponding function is returned. + """ + initialisers_dict = ( + _CENTRE_INITIALISERS if not use_indexes else _CENTRE_INITIALISER_INDEXES + ) + valid_init_methods = ", ".join(sorted(initialisers_dict.keys())) + + if isinstance(init, str): + if custom_init_handlers and init in custom_init_handlers: + return custom_init_handlers[init] + + if init not in initialisers_dict: + raise ValueError( + f"The value provided for init: {init} is invalid. " + f"The following are a list of valid init algorithms " + f"strings: {valid_init_methods}. You can also pass a " + f"np.ndarray of appropriate shape." 
+ ) + + initialiser_func = initialisers_dict[init] + + if init in ("kmeans++", "kmedoids++"): + if distance is None or distance_params is None: + raise ValueError( + f"distance and distance_params are required for {init} " + f"initialisation" + ) + return partial( + initialiser_func, + n_clusters=n_clusters, + random_state=random_state, + distance=distance, + distance_params=distance_params, + n_jobs=n_jobs, + ) + + if init == "random_values": + return partial( + initialiser_func, + n_clusters=n_clusters, + random_state=random_state, + ) + + return partial( + initialiser_func, + n_clusters=n_clusters, + random_state=random_state, + ) + elif isinstance(init, np.ndarray): + if len(init) != n_clusters: + raise ValueError( + f"The value provided for init: {init} is invalid. " + f"Expected length {n_clusters}, got {len(init)}." + ) + + if use_indexes: + if init.ndim != 1: + raise ValueError( + f"The value provided for init: {init} is invalid. " + f"Expected 1D array of shape ({n_clusters},), " + f"got {init.shape}." + ) + if not np.issubdtype(init.dtype, np.integer): + raise ValueError( + f"The value provided for init: {init} is invalid. " + f"Expected an array of integers, got {init.dtype}." + ) + if init.min() < 0 or init.max() >= X.shape[0]: + raise ValueError( + f"The value provided for init: {init} is invalid. " + f"Values must be in the range [0, {X.shape[0]})." + ) + return init + else: + if init.ndim == 1: + raise ValueError( + f"The value provided for init: {init} is invalid. " + f"Expected multi-dimensional array of shape " + f"({n_clusters}, {X.shape[1]}, {X.shape[2]}), " + f"got {init.shape}." + ) + if init.shape[1:] != X.shape[1:]: + raise ValueError( + f"The value provided for init: {init} is invalid. " + f"Expected shape ({n_clusters}, {X.shape[1]}, " + f"{X.shape[2]}), got {init.shape}." + ) + return init.copy() + else: + raise ValueError( + f"The value provided for init: {init} is invalid. " + f"Expected a string or np.ndarray." 
+ ) + + +_CENTRE_INITIALISERS = { + "random": _random_center_initialiser, + "first": _first_center_initialiser, + "random_values": _random_values_center_initialiser, + "kmeans++": _kmeans_plus_plus_center_initialiser, + "kmedoids++": _kmedoids_plus_plus_center_initialiser, +} + +_CENTRE_INITIALISER_INDEXES = { + "random": _random_center_initialiser_indexes, + "first": _first_center_initialiser_indexes, + "kmeans++": _kmeans_plus_plus_center_initialiser_indexes, + "kmedoids++": _kmedoids_plus_plus_center_initialiser_indexes, +} diff --git a/aeon/clustering/_elastic_som.py b/aeon/clustering/_elastic_som.py index 361979a6d8..40fc87135f 100644 --- a/aeon/clustering/_elastic_som.py +++ b/aeon/clustering/_elastic_som.py @@ -9,6 +9,9 @@ from numpy.random import RandomState from sklearn.utils.random import check_random_state +from aeon.clustering._cluster_initialisation import ( + resolve_center_initialiser, +) from aeon.clustering.base import BaseClusterer from aeon.distances import get_alignment_path_function, pairwise_distance @@ -203,7 +206,7 @@ def _fit(self, X, y=None): self._check_params(X) if isinstance(self._init, Callable): - weights = self._init(X) + weights = self._init(X=X) else: weights = self._init.copy() @@ -257,30 +260,22 @@ def _update_iteration(self, x, weights, decay_rate, num_iterations): def _check_params(self, X): self._random_state = check_random_state(self.random_state) - # random initialization - if isinstance(self.init, str): - if self.init == "random": - self._init = self._random_center_initializer - elif self.init == "kmeans++": - self._init = self._kmeans_plus_plus_center_initializer - elif self.init == "first": - self._init = self._first_center_initializer - else: - raise ValueError( - f"The value provided for init: {self.init} is " - f"invalid. 
The following are a list of valid init algorithms " - f"strings: random, kmedoids++, first" - ) + + if self.distance_params is None: + self._distance_params = {} else: - if isinstance(self.init, np.ndarray) and len(self.init) == self.n_clusters: - self._init = self.init.copy() - else: - raise ValueError( - f"The value provided for init: {self.init} is " - f"invalid. The following are a list of valid init algorithms " - f"strings: random, kmedoids++, first. You can also pass a" - f"np.ndarray of size (n_clusters, n_channels, n_timepoints)" - ) + self._distance_params = self.distance_params + + self._init = resolve_center_initialiser( + init=self.init, + X=X, + n_clusters=self.n_clusters, + random_state=self._random_state, + distance=self.distance, + distance_params=self._distance_params, + n_jobs=1, + use_indexes=False, + ) self._neuron_position = np.arange(self.n_clusters) @@ -331,11 +326,6 @@ def _check_params(self, X): else: self._alignment_path_callable = None - if self.distance_params is None: - self._distance_params = {} - else: - self._distance_params = self.distance_params - def _elastic_update(self, x, y, w): best_path, distance = self._alignment_path_callable( x, y, **self._distance_params @@ -362,28 +352,6 @@ def _elastic_update(self, x, y, w): return s3 - def _random_center_initializer(self, X: np.ndarray) -> np.ndarray: - return X[self._random_state.choice(X.shape[0], self.n_clusters, replace=False)] - - def _kmeans_plus_plus_center_initializer(self, X: np.ndarray): - initial_center_idx = self._random_state.randint(X.shape[0]) - indexes = [initial_center_idx] - - for _ in range(1, self.n_clusters): - pw_dist = pairwise_distance( - X, X[indexes], method=self.distance, **self._distance_params - ) - min_distances = pw_dist.min(axis=1) - probabilities = min_distances / min_distances.sum() - next_center_idx = self._random_state.choice(X.shape[0], p=probabilities) - indexes.append(next_center_idx) - - centers = X[indexes] - return centers - - def 
_first_center_initializer(self, X: np.ndarray) -> np.ndarray: - return X[list(range(self.n_clusters))] - @classmethod def _get_test_params(cls, parameter_set="default"): """Return testing parameter settings for the estimator. diff --git a/aeon/clustering/_k_means.py b/aeon/clustering/_k_means.py index 54ce0d7b95..78b40b55ec 100644 --- a/aeon/clustering/_k_means.py +++ b/aeon/clustering/_k_means.py @@ -9,6 +9,9 @@ from numpy.random import RandomState from sklearn.utils import check_random_state +from aeon.clustering._cluster_initialisation import ( + resolve_center_initialiser, +) from aeon.clustering.averaging import ( VALID_BA_DISTANCE_METHODS, elastic_barycenter_average, @@ -239,7 +242,7 @@ def _fit(self, X: np.ndarray, y=None): def _fit_one_init(self, X: np.ndarray) -> tuple: if isinstance(self._init, Callable): - cluster_centres = self._init(X) + cluster_centres = self._init(X=X) else: cluster_centres = self._init.copy() prev_inertia = np.inf @@ -299,36 +302,22 @@ def _check_params(self, X: np.ndarray) -> None: self._random_state = check_random_state(self.random_state) self._n_jobs = check_n_jobs(self.n_jobs) - _incorrect_init_str = ( - f"The value provided for init: {self.init} is " - f"invalid. The following are a list of valid init algorithms " - f"strings: random, kmeans++, first. 
You can also pass a " - f"np.ndarray of size (n_clusters, n_channels, n_timepoints)" - ) - - if isinstance(self.init, str): - if self.init == "random": - self._init = self._random_center_initializer - elif self.init == "kmeans++": - self._init = self._kmeans_plus_plus_center_initializer - elif self.init == "first": - self._init = self._first_center_initializer - else: - raise ValueError(_incorrect_init_str) - else: - if ( - isinstance(self.init, np.ndarray) - and len(self.init) == self.n_clusters - and self.init.shape[1:] == X.shape[1:] - ): - self._init = self.init.copy() - else: - raise ValueError(_incorrect_init_str) - + # Set up distance_params before init logic (needed for kmeans++ initializer) if self.distance_params is None: self._distance_params = {} else: self._distance_params = self.distance_params.copy() + + self._init = resolve_center_initialiser( + init=self.init, + X=X, + n_clusters=self.n_clusters, + random_state=self._random_state, + distance=self.distance, + distance_params=self._distance_params, + n_jobs=self._n_jobs, + use_indexes=False, + ) if self.average_params is None: self._average_params = {} else: @@ -375,32 +364,6 @@ def _check_params(self, X: np.ndarray) -> None: if isinstance(self.averaging_method, str) and self.averaging_method != "mean": self._average_params["n_jobs"] = self._n_jobs - def _random_center_initializer(self, X: np.ndarray) -> np.ndarray: - return X[self._random_state.choice(X.shape[0], self.n_clusters, replace=False)] - - def _first_center_initializer(self, X: np.ndarray) -> np.ndarray: - return X[list(range(self.n_clusters))] - - def _kmeans_plus_plus_center_initializer(self, X: np.ndarray): - initial_center_idx = self._random_state.randint(X.shape[0]) - indexes = [initial_center_idx] - - for _ in range(1, self.n_clusters): - pw_dist = pairwise_distance( - X, - X[indexes], - method=self.distance, - n_jobs=self._n_jobs, - **self._distance_params, - ) - min_distances = pw_dist.min(axis=1) - probabilities = min_distances / 
min_distances.sum() - next_center_idx = self._random_state.choice(X.shape[0], p=probabilities) - indexes.append(next_center_idx) - - centers = X[indexes] - return centers - def _handle_empty_cluster( self, X: np.ndarray, diff --git a/aeon/clustering/_k_medoids.py b/aeon/clustering/_k_medoids.py index 1b9d02618c..0825e31cd1 100644 --- a/aeon/clustering/_k_medoids.py +++ b/aeon/clustering/_k_medoids.py @@ -11,6 +11,9 @@ from sklearn.exceptions import ConvergenceWarning from sklearn.utils import check_random_state +from aeon.clustering._cluster_initialisation import ( + resolve_center_initialiser, +) from aeon.clustering.base import BaseClusterer from aeon.distances import get_distance_function, pairwise_distance @@ -235,7 +238,6 @@ def _compute_new_cluster_centers( return np.array(new_center_indexes) def _compute_distance(self, X: np.ndarray, first_index: int, second_index: int): - # Check cache if np.isfinite(self._distance_cache[first_index, second_index]): return self._distance_cache[first_index, second_index] if np.isfinite(self._distance_cache[second_index, first_index]): @@ -243,7 +245,6 @@ def _compute_distance(self, X: np.ndarray, first_index: int, second_index: int): dist = self._distance_callable( X[first_index], X[second_index], **self._distance_params ) - # Update cache self._distance_cache[first_index, second_index] = dist self._distance_cache[second_index, first_index] = dist return dist @@ -271,7 +272,7 @@ def _pam_fit(self, X: np.ndarray): n_cases = X.shape[0] if isinstance(self._init, Callable): - medoids_idxs = self._init(X) + medoids_idxs = self._init(X=X) else: medoids_idxs = self._init not_medoid_idxs = np.arange(n_cases, dtype=int) @@ -282,7 +283,6 @@ def _pam_fit(self, X: np.ndarray): not_medoid_idxs = np.delete(np.arange(n_cases, dtype=int), medoids_idxs) for i in range(self.max_iter): - # Initialize best cost change and the associated swap couple. 
old_medoid_idxs = np.copy(medoids_idxs) best_cost_change = self._compute_optimal_swaps( distance_matrix, @@ -293,7 +293,6 @@ def _pam_fit(self, X: np.ndarray): ) inertia = np.inf - # If one of the swap decrease the objective, return that swap. if best_cost_change is not None and best_cost_change[2] < 0: first, second, _ = best_cost_change medoids_idxs[medoids_idxs == first] = second @@ -393,7 +392,7 @@ def _compute_optimal_swaps( def _alternate_fit(self, X) -> tuple[np.ndarray, np.ndarray, float, int]: cluster_center_indexes = self._init if isinstance(self._init, Callable): - cluster_center_indexes = self._init(X) + cluster_center_indexes = self._init(X=X) old_inertia = np.inf old_indexes = None for i in range(self.max_iter): @@ -431,32 +430,22 @@ def _assign_clusters( def _check_params(self, X: np.ndarray) -> None: self._random_state = check_random_state(self.random_state) - _incorrect_init_str = ( - f"The value provided for init: {self.init} is " - f"invalid. The following are a list of valid init algorithms " - f"strings: random, kmedoids++, first, build. 
You can also pass a " - f"np.ndarray of size (n_clusters, n_channels, n_timepoints)" - ) - - if isinstance(self.init, str): - if self.init == "random": - self._init = self._random_center_initializer - elif self.init == "kmedoids++": - self._init = self._kmedoids_plus_plus_center_initializer - elif self.init == "first": - self._init = self._first_center_initializer - elif self.init == "build": - self._init = self._pam_build_center_initializer - else: - raise ValueError(_incorrect_init_str) - else: - if isinstance(self.init, np.ndarray) and len(self.init) == self.n_clusters: - self._init = self.init - else: - raise ValueError(_incorrect_init_str) - if self.distance_params is not None: self._distance_params = self.distance_params + else: + self._distance_params = {} + + self._init = resolve_center_initialiser( + init=self.init, + X=X, + n_clusters=self.n_clusters, + random_state=self._random_state, + distance=self.distance, + distance_params=self._distance_params, + n_jobs=1, + custom_init_handlers={"build": self._pam_build_center_initializer}, + use_indexes=True, + ) if self.n_clusters > X.shape[0]: raise ValueError( @@ -481,28 +470,6 @@ def _check_params(self, X: np.ndarray) -> None: stacklevel=1, ) - def _random_center_initializer(self, X: np.ndarray) -> np.ndarray: - return self._random_state.choice(X.shape[0], self.n_clusters, replace=False) - - def _first_center_initializer(self, _) -> np.ndarray: - return np.array(list(range(self.n_clusters))) - - def _kmedoids_plus_plus_center_initializer(self, X: np.ndarray): - initial_center_idx = self._random_state.randint(X.shape[0]) - indexes = [initial_center_idx] - - for _ in range(1, self.n_clusters): - pw_dist = pairwise_distance( - X, X[indexes], method=self.distance, **self._distance_params - ) - min_distances = pw_dist.min(axis=1) - probabilities = min_distances / min_distances.sum() - next_center_idx = self._random_state.choice(X.shape[0], p=probabilities) - indexes.append(next_center_idx) - - centers = X[indexes] 
- return centers - def _pam_build_center_initializer( self, X: np.ndarray, diff --git a/aeon/clustering/_kasba.py b/aeon/clustering/_kasba.py index 63fdd53115..addbef5038 100644 --- a/aeon/clustering/_kasba.py +++ b/aeon/clustering/_kasba.py @@ -9,6 +9,7 @@ from numpy.random import RandomState from sklearn.utils import check_random_state +from aeon.clustering._cluster_initialisation import _kmeans_plus_plus_center_initialiser from aeon.clustering._k_means import EmptyClusterError from aeon.clustering.averaging import kasba_average from aeon.clustering.base import BaseClusterer @@ -142,8 +143,16 @@ def __init__( def _fit(self, X: np.ndarray, y=None): self._check_params(X) - cluster_centers, distances_to_centers, labels = self._elastic_kmeans_plus_plus( - X, + cluster_centers, labels, distances_to_centers = ( + _kmeans_plus_plus_center_initialiser( + X=X, + n_clusters=self.n_clusters, + random_state=self._random_state, + distance=self.distance, + distance_params=self._distance_params, + n_jobs=1, + return_distance_and_labels=True, + ) ) self.labels_, self.cluster_centers_, self.inertia_, self.n_iter_ = self._kasba( X, @@ -321,34 +330,6 @@ def _handle_empty_cluster( return labels, cluster_centers, distances_to_centers - def _elastic_kmeans_plus_plus( - self, - X, - ): - initial_center_idx = self._random_state.randint(X.shape[0]) - indexes = [initial_center_idx] - - min_distances = pairwise_distance( - X, X[initial_center_idx], method=self.distance, **self._distance_params - ).flatten() - labels = np.zeros(X.shape[0], dtype=int) - - for i in range(1, self.n_clusters): - probabilities = min_distances / min_distances.sum() - next_center_idx = self._random_state.choice(X.shape[0], p=probabilities) - indexes.append(next_center_idx) - - new_distances = pairwise_distance( - X, X[next_center_idx], method=self.distance, **self._distance_params - ).flatten() - - closer_points = new_distances < min_distances - min_distances[closer_points] = new_distances[closer_points] - 
labels[closer_points] = i - - centers = X[indexes] - return centers, min_distances, labels - def _check_params(self, X: np.ndarray) -> None: self._random_state = check_random_state(self.random_state) diff --git a/aeon/clustering/averaging/tests/test_kasba.py b/aeon/clustering/averaging/tests/test_kasba.py index 8770dfe4ec..ed98095c6c 100644 --- a/aeon/clustering/averaging/tests/test_kasba.py +++ b/aeon/clustering/averaging/tests/test_kasba.py @@ -133,7 +133,7 @@ def test_kasba_ba_multi(distance, init_barycenter): assert average_ts_multi.shape == X_train_multi[0].shape assert np.allclose(average_ts_multi, call_directly_average_ts_multi) # EDR and shape_dtw with random values don't update the barycenter so skipping - if distance not in ["shape_dtw", "edr"]: + if distance not in ["shape_dtw", "edr", "soft_dtw"]: # Test not just returning the init barycenter assert not np.array_equal(average_ts_multi, init_barycenter) diff --git a/aeon/clustering/tests/test_clarans.py b/aeon/clustering/tests/test_clarans.py index a1da285cf3..d6968d6ccd 100644 --- a/aeon/clustering/tests/test_clarans.py +++ b/aeon/clustering/tests/test_clarans.py @@ -2,12 +2,9 @@ import numpy as np from sklearn import metrics -from sklearn.utils import check_random_state from aeon.clustering._clarans import TimeSeriesCLARANS -from aeon.clustering.tests.test_k_medoids import check_value_in_every_cluster from aeon.datasets import load_basic_motions, load_gunpoint -from aeon.distances import euclidean_distance def test_clarans_uni(): @@ -95,45 +92,3 @@ def test_clara_multi(): assert isinstance(clarans.cluster_centers_, np.ndarray) for val in proba: assert np.count_nonzero(val == 1.0) == 1 - - -def test_medoids_init(): - """Test init algorithms.""" - X_train, _ = load_gunpoint(split="train") - X_train = X_train[:10] - - num_clusters = 8 - kmedoids = TimeSeriesCLARANS( - random_state=1, - n_init=1, - init="first", - distance="euclidean", - n_clusters=num_clusters, - ) - - kmedoids._random_state = 
check_random_state(kmedoids.random_state) - kmedoids._distance_cache = np.full((len(X_train), len(X_train)), np.inf) - kmedoids._distance_callable = euclidean_distance - first_medoids_result = kmedoids._first_center_initializer(X_train) - check_value_in_every_cluster(num_clusters, first_medoids_result) - random_medoids_result = kmedoids._random_center_initializer(X_train) - check_value_in_every_cluster(num_clusters, random_medoids_result) - kmedoids_plus_plus_medoids_result = kmedoids._kmedoids_plus_plus_center_initializer( - X_train - ) - check_value_in_every_cluster(num_clusters, kmedoids_plus_plus_medoids_result) - kmedoids_build_result = kmedoids._pam_build_center_initializer(X_train) - check_value_in_every_cluster(num_clusters, kmedoids_build_result) - - # Test setting manual init centres - num_clusters = 8 - custom_init_centres = np.array([1, 2, 3, 4, 5, 6, 7, 8]) - kmedoids = TimeSeriesCLARANS( - random_state=1, - n_init=1, - init=custom_init_centres, - distance="euclidean", - n_clusters=num_clusters, - ) - kmedoids.fit(X_train) - assert np.array_equal(kmedoids.cluster_centers_, X_train[custom_init_centres]) diff --git a/aeon/clustering/tests/test_cluster_initialisation.py b/aeon/clustering/tests/test_cluster_initialisation.py new file mode 100644 index 0000000000..922e36d33d --- /dev/null +++ b/aeon/clustering/tests/test_cluster_initialisation.py @@ -0,0 +1,111 @@ +"""Tests for cluster initialisation functions.""" + +from collections.abc import Callable + +import numpy as np +import pytest +from numpy.random import RandomState + +from aeon.clustering._cluster_initialisation import ( + _CENTRE_INITIALISER_INDEXES, + _CENTRE_INITIALISERS, +) +from aeon.distances._distance import ELASTIC_DISTANCES, POINTWISE_DISTANCES +from aeon.testing.data_generation import make_example_3d_numpy + +NON_RANDOM_INIT = ["first"] + + +def _run_initialisation_test( + key: str, + init_func: Callable, + init_func_indexes: Callable = None, + init_func_params=None, +): + if 
init_func_params is None: + init_func_params = {} + + X = make_example_3d_numpy(10, 1, 10, random_state=1, return_y=False) + n_clusters = 3 + init_func_params = { + "X": X, + "n_clusters": n_clusters, + **init_func_params, + } + + values = init_func(**init_func_params, random_state=RandomState(1)) + + assert len(values) == n_clusters + assert values.shape[1:] == X.shape[1:] + + assert np.allclose( + values, init_func(**init_func_params, random_state=RandomState(1)) + ) + + if key not in NON_RANDOM_INIT: + diff_random_state_values = init_func( + **init_func_params, random_state=RandomState(2) + ) + assert not np.allclose(values, diff_random_state_values) + + if init_func_indexes: + indexes = init_func_indexes(**init_func_params, random_state=RandomState(1)) + value_from_indexes = X[indexes] + assert np.allclose(values, value_from_indexes) + + +@pytest.mark.parametrize("init_key", _CENTRE_INITIALISERS.keys()) +def test_center_initialisers(init_key): + """Test all center initialisers.""" + params = {} + if init_key == "kmeans++" or init_key == "kmedoids++": + params["distance"] = "euclidean" + params["distance_params"] = {} + + _run_initialisation_test( + key=init_key, + init_func=_CENTRE_INITIALISERS[init_key], + init_func_indexes=_CENTRE_INITIALISER_INDEXES.get(init_key, None), + init_func_params=params, + ) + + +@pytest.mark.parametrize("init_key", ["kmeans++", "kmedoids++"]) +@pytest.mark.parametrize("dist", POINTWISE_DISTANCES + ELASTIC_DISTANCES) +def test_distance_center_initialisers(init_key, dist): + """Test all center initialisers with distance.""" + params = { + "distance": dist, + "distance_params": {}, + } + _run_initialisation_test( + key=init_key, + init_func=_CENTRE_INITIALISERS[init_key], + init_func_indexes=_CENTRE_INITIALISER_INDEXES.get(init_key, None), + init_func_params=params, + ) + + +@pytest.mark.parametrize("init_key", ["kmeans++", "kmedoids++"]) +def test_distance_center_initialisers_params(init_key): + """Test all center initialisers with 
distance.""" + n_clusters = 3 + X = make_example_3d_numpy(50, 1, 10, random_state=1, return_y=False) + + init_func_no_window = _CENTRE_INITIALISERS[init_key]( + X=X, + n_clusters=n_clusters, + distance_params={}, + distance="soft_dtw", + random_state=RandomState(1), + ) + + init_func_window = _CENTRE_INITIALISERS[init_key]( + X=X, + n_clusters=n_clusters, + distance_params={"gamma": 0.00001}, + distance="soft_dtw", + random_state=RandomState(1), + ) + + assert not np.array_equal(init_func_no_window, init_func_window) diff --git a/aeon/clustering/tests/test_elastic_som.py b/aeon/clustering/tests/test_elastic_som.py index 5d5ef47630..530ba7116c 100644 --- a/aeon/clustering/tests/test_elastic_som.py +++ b/aeon/clustering/tests/test_elastic_som.py @@ -4,6 +4,7 @@ import pytest from aeon.clustering import ElasticSOM +from aeon.clustering._cluster_initialisation import _CENTRE_INITIALISERS from aeon.distances import dtw_distance, msm_alignment_path from aeon.distances._distance import ELASTIC_DISTANCES from aeon.testing.data_generation import make_example_3d_numpy @@ -37,26 +38,29 @@ def test_elastic_som_multivariate(): assert preds.shape == (10,) -def test_elastic_som_init(): +@pytest.mark.parametrize("init", list(_CENTRE_INITIALISERS.keys()) + ["ndarray"]) +def test_elastic_som_init(init): """Test ElasticSOM with a custom initialization.""" X = make_example_3d_numpy( n_cases=10, n_channels=5, n_timepoints=20, return_y=False, random_state=1 ) - labels = [] - for init in ["random", "kmeans++", "first"]: - clst = ElasticSOM(n_clusters=3, init=init, random_state=1, num_iterations=10) - clst.fit(X) - assert clst.labels_.shape == (10,) - assert clst.cluster_centers_.shape == (3, 5, 20) - labels.append(clst.labels_) + if init == "ndarray": + init = X[:3] - preds = clst.predict(X) - assert preds.shape == (10,) + clst = ElasticSOM(n_clusters=3, init=init, random_state=1, num_iterations=10) + clst.fit(X) + assert clst.labels_.shape == (10,) + assert clst.cluster_centers_.shape 
== (3, 5, 20) - # Check that the labels are different - assert not np.array_equal(labels[0], labels[1]) - assert not np.array_equal(labels[0], labels[2]) - assert not np.array_equal(labels[1], labels[2]) + preds = clst.predict(X) + assert preds.shape == (10,) + + +def test_elastic_som_init_invalid(): + """Test ElasticSOM with invalid initialization.""" + X = make_example_3d_numpy( + n_cases=10, n_channels=5, n_timepoints=20, return_y=False, random_state=1 + ) # Test invalid init with pytest.raises(ValueError): clst = ElasticSOM( @@ -64,18 +68,6 @@ def test_elastic_som_init(): ) clst.fit(X) - # Test custom ndarray init - clst = ElasticSOM(n_clusters=3, init=X[:3], random_state=1, num_iterations=10) - clst.fit(X) - assert clst.labels_.shape == (10,) - assert clst.cluster_centers_.shape == (3, 5, 20) - - # Last labels is for "first" init - assert np.array_equal(clst.labels_, labels[-1]) - - preds = clst.predict(X) - assert preds.shape == (10,) - # Test more ndarrays than clusters with pytest.raises(ValueError): clst = ElasticSOM(n_clusters=3, init=X[:4], random_state=1, num_iterations=10) diff --git a/aeon/clustering/tests/test_k_means.py b/aeon/clustering/tests/test_k_means.py index ace08a5530..67bafee9f7 100644 --- a/aeon/clustering/tests/test_k_means.py +++ b/aeon/clustering/tests/test_k_means.py @@ -7,6 +7,7 @@ from sklearn import metrics from aeon.clustering import TimeSeriesKMeans +from aeon.clustering._cluster_initialisation import _CENTRE_INITIALISERS from aeon.datasets import load_basic_motions from aeon.distances._distance import ELASTIC_DISTANCES from aeon.testing.data_generation import make_example_3d_numpy @@ -169,18 +170,18 @@ def test_k_mean_distances(distance): "random_state": 1, "n_init": 1, "n_clusters": 3, - "init": "kmeans++", + "init": "random", "distance": dist, "distance_params": {key: params[key]}, } # Univariate test with_param_kmeans = _run_kmeans_test( - kmeans_params=curr_params, n_cases=40, n_channels=1, n_timepoints=10 + 
kmeans_params=curr_params, n_cases=80, n_channels=1, n_timepoints=10 ) # Multivariate test _run_kmeans_test( - kmeans_params=curr_params, n_cases=40, n_channels=3, n_timepoints=10 + kmeans_params=curr_params, n_cases=80, n_channels=3, n_timepoints=10 ) if dist in ELASTIC_DISTANCES: @@ -194,13 +195,15 @@ def test_k_mean_distances(distance): continue default_param_kmeans = _run_kmeans_test( - kmeans_params=curr_params, n_cases=40, n_channels=1, n_timepoints=10 + kmeans_params=curr_params, n_cases=80, n_channels=1, n_timepoints=10 ) - # Test parameters passed through kmeans - assert not np.array_equal( - with_param_kmeans.cluster_centers_, default_param_kmeans.cluster_centers_ - ) + if not isinstance(dist, Callable): + # Test parameters passed through kmeans + assert not np.array_equal( + with_param_kmeans.cluster_centers_, + default_param_kmeans.cluster_centers_, + ) @pytest.mark.parametrize("distance", TEST_DISTANCES_WITH_FULL_ALIGNMENT_PATH) @@ -248,15 +251,11 @@ def test_k_mean_ba(distance, averaging_method): @pytest.mark.parametrize("distance", TEST_DISTANCE_WITH_CUSTOM_DISTANCE) -@pytest.mark.parametrize("init", ["random", "kmeans++", "first", "ndarray"]) +@pytest.mark.parametrize("init", list(_CENTRE_INITIALISERS.keys()) + ["ndarray"]) def test_k_mean_init(distance, init): """Test implementation of Kmeans.""" distance, params = distance - # Only kmeans++ needs test with different distances - if init != "kmeans++" and distance != "euclidean": - return - n_cases = 10 n_timepoints = 10 n_clusters = 4 @@ -289,7 +288,7 @@ def test_k_mean_init(distance, init): kmeans._check_params(X_train_uni) if isinstance(kmeans._init, Callable): - uni_init_vals = kmeans._init(X_train_uni) + uni_init_vals = kmeans._init(X=X_train_uni) else: uni_init_vals = kmeans._init @@ -324,7 +323,7 @@ def test_k_mean_init(distance, init): kmeans._check_params(X_train_multi) if isinstance(kmeans._init, Callable): - multi_init_vals = kmeans._init(X_train_multi) + multi_init_vals = 
kmeans._init(X=X_train_multi) else: multi_init_vals = kmeans._init @@ -408,6 +407,49 @@ def test_empty_cluster(): kmeans.fit(np.array([first, first, first, first, first])) +def test_center_initialisers(): + """Test that CENTER_INITIALISERS work correctly.""" + from numpy.random import RandomState + + X_train = make_example_3d_numpy( + n_cases=20, n_channels=1, n_timepoints=10, random_state=1, return_y=False + ) + n_clusters = 3 + random_state = RandomState(1) + + # Test all available initializers + for init_name, initialiser_func in _CENTRE_INITIALISERS.items(): + if init_name == "kmeans++" or init_name == "kmedoids++": + # kmeans++ and kmedoids++ needs additional parameters + centers = initialiser_func( + X=X_train, + n_clusters=n_clusters, + random_state=random_state, + distance="euclidean", + distance_params={}, + n_jobs=1, + ) + else: + # Other initializers only need basic parameters + centers = initialiser_func( + X=X_train, + n_clusters=n_clusters, + random_state=random_state, + ) + + # Verify output shape + # random_values returns (n_clusters, n_channels) - it generates random values + # not based on the input data structure + if init_name == "random_values": + assert centers.shape == (n_clusters, X_train.shape[1], X_train.shape[2]) + else: + assert centers.shape == (n_clusters, X_train.shape[1], X_train.shape[2]) + # Verify no duplicate centers + for i in range(n_clusters): + for j in range(i + 1, n_clusters): + assert not np.array_equal(centers[i], centers[j]) + + def test_invalid_params(): """Test invalid parameters for k-mean.""" uni_data = make_example_3d_numpy( diff --git a/aeon/clustering/tests/test_k_medoids.py b/aeon/clustering/tests/test_k_medoids.py index 0fea3ead19..6864a09465 100644 --- a/aeon/clustering/tests/test_k_medoids.py +++ b/aeon/clustering/tests/test_k_medoids.py @@ -1,12 +1,14 @@ """Tests for time series k-medoids.""" +from collections.abc import Callable + import numpy as np +import pytest from sklearn import metrics -from 
sklearn.utils import check_random_state +from aeon.clustering._cluster_initialisation import _CENTRE_INITIALISER_INDEXES from aeon.clustering._k_medoids import TimeSeriesKMedoids from aeon.datasets import load_basic_motions, load_gunpoint -from aeon.distances import euclidean_distance def test_kmedoids_uni(): @@ -159,47 +161,34 @@ def check_value_in_every_cluster(num_clusters, initial_medoids): assert original_length == len(set(initial_medoids)) -def test_medoids_init(): +@pytest.mark.parametrize("init", list(_CENTRE_INITIALISER_INDEXES.keys()) + ["indexes"]) +def test_medoids_init(init): """Test implementation of Kmedoids.""" X_train, _ = load_gunpoint(split="train") X_train = X_train[:10] - num_clusters = 8 - kmedoids = TimeSeriesKMedoids( - random_state=1, - n_init=1, - max_iter=5, - init="first", - distance="euclidean", - n_clusters=num_clusters, - ) - kmedoids._random_state = check_random_state(kmedoids.random_state) - kmedoids._distance_cache = np.full((len(X_train), len(X_train)), np.inf) - kmedoids._distance_callable = euclidean_distance - first_medoids_result = kmedoids._first_center_initializer(X_train) - check_value_in_every_cluster(num_clusters, first_medoids_result) - random_medoids_result = kmedoids._random_center_initializer(X_train) - check_value_in_every_cluster(num_clusters, random_medoids_result) - kmedoids_plus_plus_medoids_result = kmedoids._kmedoids_plus_plus_center_initializer( - X_train - ) - check_value_in_every_cluster(num_clusters, kmedoids_plus_plus_medoids_result) - kmedoids_build_result = kmedoids._pam_build_center_initializer(X_train) - check_value_in_every_cluster(num_clusters, kmedoids_build_result) + num_clusters = 3 - # Test setting manual init centres - num_clusters = 8 - custom_init_centres = np.array([1, 2, 3, 4, 5, 6, 7, 8]) + if init == "indexes": + # Generate random indexes + rng = np.random.RandomState(1) + init = rng.choice(X_train.shape[0], num_clusters, replace=False) + + # Test initializer kmedoids = TimeSeriesKMedoids( 
random_state=1, n_init=1, max_iter=5, - init=custom_init_centres, + init=init, distance="euclidean", n_clusters=num_clusters, ) - kmedoids.fit(X_train) - assert np.array_equal(kmedoids.cluster_centers_, X_train[custom_init_centres]) + kmedoids._check_params(X_train) + if isinstance(kmedoids._init, Callable): + medoids_result = kmedoids._init(X=X_train) + else: + medoids_result = kmedoids._init + check_value_in_every_cluster(num_clusters, medoids_result) def _get_model_centres(data, distance, method="pam", distance_params=None): @@ -230,3 +219,28 @@ def test_custom_distance_params(): data, distance="msm", distance_params={"window": 0.01} ) assert not np.array_equal(default_dist, custom_params_dist) + + +def test_medoids_init_invalid(): + """Test implementation of Kmedoids with invalid init.""" + X_train, _ = load_gunpoint(split="train") + X_train = X_train[:10] + num_clusters = 3 + + # Test float array + with pytest.raises(ValueError, match="Expected an array of integers"): + kmedoids = TimeSeriesKMedoids( + n_clusters=num_clusters, + init=np.array([0.5, 1.5, 2.5]), + random_state=1, + ) + kmedoids.fit(X_train) + + # Test out of bounds + with pytest.raises(ValueError, match="Values must be in the range"): + kmedoids = TimeSeriesKMedoids( + n_clusters=num_clusters, + init=np.array([0, 1, 100]), + random_state=1, + ) + kmedoids.fit(X_train) diff --git a/aeon/distances/elastic/__init__.py b/aeon/distances/elastic/__init__.py index 8e5d1aa9dd..40816f109d 100644 --- a/aeon/distances/elastic/__init__.py +++ b/aeon/distances/elastic/__init__.py @@ -54,6 +54,8 @@ "soft_dtw_pairwise_distance", "soft_dtw_alignment_path", "soft_dtw_cost_matrix", + "soft_dtw_grad_x", + "soft_dtw_alignment_matrix", ] from aeon.distances.elastic._adtw import ( @@ -111,12 +113,6 @@ shape_dtw_distance, shape_dtw_pairwise_distance, ) -from aeon.distances.elastic._soft_dtw import ( - soft_dtw_alignment_path, - soft_dtw_cost_matrix, - soft_dtw_distance, - soft_dtw_pairwise_distance, -) from 
aeon.distances.elastic._twe import ( twe_alignment_path, twe_cost_matrix, @@ -135,3 +131,11 @@ wdtw_distance, wdtw_pairwise_distance, ) +from aeon.distances.elastic.soft._soft_dtw import ( + soft_dtw_alignment_matrix, + soft_dtw_alignment_path, + soft_dtw_cost_matrix, + soft_dtw_distance, + soft_dtw_grad_x, + soft_dtw_pairwise_distance, +) diff --git a/aeon/distances/elastic/soft/__init__.py b/aeon/distances/elastic/soft/__init__.py new file mode 100644 index 0000000000..51807f085a --- /dev/null +++ b/aeon/distances/elastic/soft/__init__.py @@ -0,0 +1,19 @@ +"""Soft elastic distance functions.""" + +__all__ = [ + "soft_dtw_alignment_matrix", + "soft_dtw_alignment_path", + "soft_dtw_cost_matrix", + "soft_dtw_distance", + "soft_dtw_pairwise_distance", + "soft_dtw_grad_x", +] + +from aeon.distances.elastic.soft._soft_dtw import ( + soft_dtw_alignment_matrix, + soft_dtw_alignment_path, + soft_dtw_cost_matrix, + soft_dtw_distance, + soft_dtw_grad_x, + soft_dtw_pairwise_distance, +) diff --git a/aeon/distances/elastic/_soft_dtw.py b/aeon/distances/elastic/soft/_soft_dtw.py similarity index 62% rename from aeon/distances/elastic/_soft_dtw.py rename to aeon/distances/elastic/soft/_soft_dtw.py index 2c921ad41a..14a49001c9 100644 --- a/aeon/distances/elastic/_soft_dtw.py +++ b/aeon/distances/elastic/soft/_soft_dtw.py @@ -2,50 +2,19 @@ __maintainer__ = [] - import numpy as np from numba import njit, prange from numba.typed import List as NumbaList from aeon.distances.elastic._alignment_paths import compute_min_return_path from aeon.distances.elastic._bounding_matrix import create_bounding_matrix -from aeon.distances.elastic._dtw import _dtw_cost_matrix +from aeon.distances.elastic.soft._utils import _softmin3 from aeon.distances.pointwise._squared import _univariate_squared_distance from aeon.utils.conversion._convert_collection import _convert_collection_to_numba_list from aeon.utils.numba._threading import threaded from aeon.utils.validation.collection import 
_is_numpy_list_multivariate -@njit(fastmath=True, cache=True) -def _softmin3(a, b, c, gamma): - r"""Compute softmin of 3 input variables with parameter gamma. - - This code is adapted from tslearn. - - Parameters - ---------- - a : float - First input variable. - b : float - Second input variable. - c : float - Third input variable. - gamma : float - Softmin parameter. - - Returns - ------- - float - Softmin of a, b, c. - """ - a /= -gamma - b /= -gamma - c /= -gamma - max_val = max(a, b, c) - tmp = np.exp(a - max_val) + np.exp(b - max_val) + np.exp(c - max_val) - return -gamma * (np.log(tmp) + max_val) - - @njit(cache=True, fastmath=True) def soft_dtw_distance( x: np.ndarray, @@ -180,16 +149,46 @@ def soft_dtw_cost_matrix( >>> x = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) >>> y = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) >>> soft_dtw_cost_matrix(x, y) - array([[ 0., 1., 5., 14., 30., 55., 91., 140., 204., 285.], - [ 1., 0., 1., 5., 14., 30., 55., 91., 140., 204.], - [ 5., 1., 0., 1., 5., 14., 30., 55., 91., 140.], - [ 14., 5., 1., 0., 1., 5., 14., 30., 55., 91.], - [ 30., 14., 5., 1., 0., 1., 5., 14., 30., 55.], - [ 55., 30., 14., 5., 1., 0., 1., 5., 14., 30.], - [ 91., 55., 30., 14., 5., 1., 0., 1., 5., 14.], - [140., 91., 55., 30., 14., 5., 1., 0., 1., 5.], - [204., 140., 91., 55., 30., 14., 5., 1., 0., 1.], - [285., 204., 140., 91., 55., 30., 14., 5., 1., 0.]]) + array([[ 0.00000000e+00, 1.00000000e+00, 5.00000000e+00, + 1.40000000e+01, 3.00000000e+01, 5.50000000e+01, + 9.10000000e+01, 1.40000000e+02, 2.04000000e+02, + 2.85000000e+02], + [ 1.00000000e+00, -5.51444714e-01, 2.53133741e-01, + 4.24449127e+00, 1.32444333e+01, 2.92444332e+01, + 5.42444332e+01, 9.02444332e+01, 1.39244433e+02, + 2.03244433e+02], + [ 5.00000000e+00, 2.53133741e-01, -1.19042757e+00, + -4.05899430e-01, 3.58458692e+00, 1.25845231e+01, + 2.85845231e+01, 5.35845231e+01, 8.95845231e+01, + 1.38584523e+02], + [ 1.40000000e+01, 4.24449127e+00, -4.05899430e-01, + -1.83892772e+00, 
-1.05645307e+00, 2.93394434e+00, + 1.19338800e+01, 2.79338799e+01, 5.29338799e+01, + 8.89338799e+01], + [ 3.00000000e+01, 1.32444333e+01, 3.58458692e+00, + -1.05645307e+00, -2.48840826e+00, -1.70614379e+00, + 2.28424451e+00, 1.12841801e+01, 2.72841800e+01, + 5.22841800e+01], + [ 5.50000000e+01, 2.92444332e+01, 1.25845231e+01, + 2.93394434e+00, -1.70614379e+00, -3.13798920e+00, + -2.35574625e+00, 1.63464112e+00, 1.06345767e+01, + 2.66345766e+01], + [ 9.10000000e+01, 5.42444332e+01, 2.85845231e+01, + 1.19338800e+01, 2.28424451e+00, -2.35574625e+00, + -3.78758043e+00, -3.00533968e+00, 9.85047595e-01, + 9.98498314e+00], + [ 1.40000000e+02, 9.02444332e+01, 5.35845231e+01, + 2.79338799e+01, 1.12841801e+01, 1.63464112e+00, + -3.00533968e+00, -4.43717271e+00, -3.65493218e+00, + 3.35455083e-01], + [ 2.04000000e+02, 1.39244433e+02, 8.95845231e+01, + 5.29338799e+01, 2.72841800e+01, 1.06345767e+01, + 9.85047595e-01, -3.65493218e+00, -5.08676509e+00, + -4.30452459e+00], + [ 2.85000000e+02, 2.03244433e+02, 1.38584523e+02, + 8.89338799e+01, 5.22841800e+01, 2.66345766e+01, + 9.98498314e+00, 3.35455083e-01, -4.30452459e+00, + -5.73635749e+00]]) """ if x.ndim == 1 and y.ndim == 1: _x = x.reshape((1, x.shape[0])) @@ -210,20 +209,15 @@ def soft_dtw_cost_matrix( def _soft_dtw_distance( x: np.ndarray, y: np.ndarray, bounding_matrix: np.ndarray, gamma: float ) -> float: - return abs( - _soft_dtw_cost_matrix(x, y, bounding_matrix, gamma)[ - x.shape[1] - 1, y.shape[1] - 1 - ] - ) + return _soft_dtw_cost_matrix(x, y, bounding_matrix, gamma)[ + x.shape[1] - 1, y.shape[1] - 1 + ] @njit(cache=True, fastmath=True) def _soft_dtw_cost_matrix( x: np.ndarray, y: np.ndarray, bounding_matrix: np.ndarray, gamma: float -) -> np.ndarray: - if gamma == 0.0 or np.array_equal(x, y): - return _dtw_cost_matrix(x, y, bounding_matrix) - +) -> np.ndarray | tuple[np.ndarray, np.ndarray]: x_size = x.shape[1] y_size = y.shape[1] cost_matrix = np.full((x_size + 1, y_size + 1), np.inf) @@ -296,9 +290,9 @@ def 
soft_dtw_pairwise_distance( >>> # Distance between each time series in a collection of time series >>> X = np.array([[[1, 2, 3]],[[4, 5, 6]], [[7, 8, 9]]]) >>> soft_dtw_pairwise_distance(X) - array([[ 0. , 25.44075098, 107.99999917], - [ 25.44075098, 0. , 25.44075098], - [107.99999917, 25.44075098, 0. ]]) + array([[ -1.19042757, 25.44075098, 107.99999917], + [ 25.44075098, -1.19042757, 25.44075098], + [107.99999917, 25.44075098, -1.19042757]]) >>> # Distance between two collections of time series >>> X = np.array([[[1, 2, 3]],[[4, 5, 6]], [[7, 8, 9]]]) @@ -318,9 +312,9 @@ def soft_dtw_pairwise_distance( >>> # Distance between each TS in a collection of unequal-length time series >>> X = [np.array([1, 2, 3]), np.array([4, 5, 6, 7]), np.array([8, 9, 10, 11, 12])] >>> soft_dtw_pairwise_distance(X) - array([[ 0. , 41.44055555, 291.99999969], - [ 41.44055555, 0. , 82.43894439], - [291.99999969, 82.43894439, 0. ]]) + array([[ -1.19042757, 41.44055555, 291.99999969], + [ 41.44055555, -1.83892772, 82.43894439], + [291.99999969, 82.43894439, -2.48840826]]) """ multivariate_conversion = _is_numpy_list_multivariate(X, y) _X, unequal_length = _convert_collection_to_numba_list( @@ -357,7 +351,7 @@ def _soft_dtw_pairwise_distance( n_timepoints, n_timepoints, window, itakura_max_slope ) for i in prange(n_cases): - for j in range(i + 1, n_cases): + for j in range(n_cases): x1, x2 = X[i], X[j] if unequal_length: bounding_matrix = create_bounding_matrix( @@ -397,6 +391,30 @@ def _soft_dtw_from_multiple_to_multiple_distance( return distances +@njit(cache=True, fastmath=True) +def _soft_dtw_cost_matrix_return_dist_matrix( + x: np.ndarray, y: np.ndarray, bounding_matrix: np.ndarray, gamma: float +) -> tuple[np.ndarray, np.ndarray]: + x_size = x.shape[1] + y_size = y.shape[1] + cost_matrix = np.full((x_size + 1, y_size + 1), np.inf) + cost_matrix[0, 0] = 0.0 + dist_matrix = np.zeros((x_size, y_size)) + + for i in range(1, x_size + 1): + for j in range(1, y_size + 1): + if 
bounding_matrix[i - 1, j - 1]: + dist = _univariate_squared_distance(x[:, i - 1], y[:, j - 1]) + dist_matrix[i - 1, j - 1] = dist + cost_matrix[i, j] = dist + _softmin3( + cost_matrix[i - 1, j], + cost_matrix[i - 1, j - 1], + cost_matrix[i, j - 1], + gamma, + ) + return cost_matrix[1:, 1:], dist_matrix + + @njit(cache=True, fastmath=True) def soft_dtw_alignment_path( x: np.ndarray, @@ -448,5 +466,161 @@ def soft_dtw_alignment_path( cost_matrix = soft_dtw_cost_matrix(x, y, gamma, window, itakura_max_slope) return ( compute_min_return_path(cost_matrix), - abs(cost_matrix[x.shape[-1] - 1, y.shape[-1] - 1]), + cost_matrix[x.shape[-1] - 1, y.shape[-1] - 1], ) + + +@njit(cache=True, fastmath=True) +def soft_dtw_alignment_matrix( + x: np.ndarray, + y: np.ndarray, + gamma: float = 1.0, + window: float | None = None, + itakura_max_slope: float | None = None, +) -> tuple[np.ndarray, float]: + + if x.ndim == 1 and y.ndim == 1: + _x = x.reshape((1, x.shape[0])) + _y = y.reshape((1, y.shape[0])) + bounding_matrix = create_bounding_matrix( + _x.shape[1], _y.shape[1], window, itakura_max_slope + ) + cost_matrix, dist_matrix = _soft_dtw_cost_matrix_return_dist_matrix( + _x, _y, bounding_matrix, gamma # <- was (x, y) + ) + return ( + _soft_gradient(dist_matrix, cost_matrix, gamma), + cost_matrix[_x.shape[-1] - 1, _y.shape[-1] - 1], + ) + + if x.ndim == 2 and y.ndim == 2: + bounding_matrix = create_bounding_matrix( + x.shape[1], y.shape[1], window, itakura_max_slope + ) + cost_matrix, dist_matrix = _soft_dtw_cost_matrix_return_dist_matrix( + x, y, bounding_matrix, gamma + ) + return ( + _soft_gradient(dist_matrix, cost_matrix, gamma), + cost_matrix[x.shape[-1] - 1, y.shape[-1] - 1], + ) + + return (np.zeros((0, 0)), 0.0) + + +@njit(cache=True, fastmath=True) +def _soft_gradient( + distance_matrix: np.ndarray, cost_matrix: np.ndarray, gamma: float +) -> np.ndarray: + m, n = distance_matrix.shape + E = np.zeros((m, n), dtype=float) + + E[m - 1, n - 1] = 1.0 + + for i in range(m - 1, 
-1, -1): + for j in range(n - 1, -1, -1): + r_ij = cost_matrix[i, j] + E_ij = E[i, j] + + if i + 1 < m: + w_horizontal = np.exp( + (cost_matrix[i + 1, j] - r_ij - distance_matrix[i + 1, j]) / gamma + ) + E_ij += E[i + 1, j] * w_horizontal + + if j + 1 < n: + w_vertical = np.exp( + (cost_matrix[i, j + 1] - r_ij - distance_matrix[i, j + 1]) / gamma + ) + E_ij += E[i, j + 1] * w_vertical + + if (i + 1 < m) and (j + 1 < n): + w_diag = np.exp( + (cost_matrix[i + 1, j + 1] - r_ij - distance_matrix[i + 1, j + 1]) + / gamma + ) + E_ij += E[i + 1, j + 1] * w_diag + + E[i, j] = E_ij + + return E + + +def soft_dtw_grad_x( + x: np.ndarray, + y: np.ndarray, + gamma: float = 1.0, + window: float | None = None, + itakura_max_slope: float | None = None, +): + """ + Gradient (Jacobian) of soft-DTW distance w.r.t. the first series x. + + Returns (dx, distance); dx has shape (len(x),) for univariate or (C, T) for + multivariate. + + Parameters + ---------- + x : np.ndarray + First time series, shape ``(n_channels, n_timepoints)`` or ``(n_timepoints,)``. + y : np.ndarray + Second time series, shape ``(m_channels, m_timepoints)`` or ``(m_timepoints,)``. + gamma : float, default=1.0 + Controls the smoothness of the warping. A value of 0.0 is equivalent to DTW. + window : float, default=None + The window to use for the bounding matrix. If None, no bounding matrix + is used. + itakura_max_slope : float, default=None + Maximum slope as a proportion of the number of time points used to create + Itakura parallelogram on the bounding matrix. Must be between 0. and 1. + + Returns + ------- + dx : np.ndarray + The gradient of the soft-DTW distance with respect to ``x``. + If ``x`` is univariate (``x.ndim == 1``), the shape is ``(n_timepoints,)``. + If ``x`` is multivariate (``x.ndim == 2``), the shape is + ``(n_channels, n_timepoints)``. + distance: float + The soft-DTW distance between the two time series. 
+ """ + if gamma <= 0: + raise ValueError("gamma must be > 0 for a differentiable soft minimum.") + if x.ndim == 1 and y.ndim == 1: + X = x.reshape((1, x.shape[0])) + Y = y.reshape((1, y.shape[0])) + else: + X = x + Y = y + dx, s_xy = _soft_dtw_grad_x(X, Y, gamma, window, itakura_max_slope) + return (dx.ravel(), s_xy) if x.ndim == 1 else (dx, s_xy) + + +@njit(cache=True, fastmath=True) +def _soft_dtw_grad_x( + X: np.ndarray, + Y: np.ndarray, + gamma: float = 1.0, + window: float | None = None, + itakura_max_slope: float | None = None, +): + # bounding + forward DP + bm = create_bounding_matrix(X.shape[1], Y.shape[1], window, itakura_max_slope) + cost_matrix, dist_matrix = _soft_dtw_cost_matrix_return_dist_matrix(X, Y, bm, gamma) + s_xy = cost_matrix[X.shape[1] - 1, Y.shape[1] - 1] + + # backward expected-alignment (node occupancy) + E = _soft_gradient(dist_matrix, cost_matrix, gamma) # shape (T, U) + + C, T = X.shape[0], X.shape[1] + U = Y.shape[1] + dx = np.zeros_like(X) + + # ∂s/∂X[:, i] = 2 * sum_j (X[:, i] - Y[:, j]) * E[i, j] + for i in range(T): + acc = np.zeros(C) + for j in range(U): + acc += (X[:, i] - Y[:, j]) * E[i, j] + dx[:, i] = 2.0 * acc + + return dx, s_xy diff --git a/aeon/distances/elastic/soft/_utils.py b/aeon/distances/elastic/soft/_utils.py new file mode 100644 index 0000000000..7d1bc3d873 --- /dev/null +++ b/aeon/distances/elastic/soft/_utils.py @@ -0,0 +1,57 @@ +import numpy as np +from numba import njit + + +@njit(fastmath=True, cache=True) +def _softmin3(a: float, b: float, c: float, gamma: float) -> float: + r"""Compute softmin of 3 input variables with parameter gamma. + + Parameters + ---------- + a : float + First input variable. + b : float + Second input variable. + c : float + Third input variable. + gamma : float + Softmin parameter. + + Returns + ------- + float + Softmin of a, b, c. 
+ """ + a /= -gamma + b /= -gamma + c /= -gamma + max_val = max(a, b, c) + exp_sum = np.exp(a - max_val) + np.exp(b - max_val) + np.exp(c - max_val) + return -gamma * (np.log(exp_sum) + max_val) + + +@njit(fastmath=True, cache=True) +def _softmin2(a: float, b: float, gamma: float) -> float: + return _soft_min_arr([a, b], gamma) + + +@njit(fastmath=True, cache=True) +def _soft_min_arr(values: list, gamma: float) -> float: + n = len(values) + if n == 0: + return np.inf + if n == 1: + return values[0] + + neg_gamma = -gamma + max_val = -np.inf + for i in range(n): + vi = values[i] / neg_gamma + if vi > max_val: + max_val = vi + + exp_sum = 0.0 + for i in range(n): + exp_sum += np.exp(values[i] / neg_gamma - max_val) + + return neg_gamma * (np.log(exp_sum) + max_val) diff --git a/aeon/distances/elastic/tests/test_cost_matrix.py b/aeon/distances/elastic/tests/test_cost_matrix.py index 1241e7f390..b6cd8cb785 100644 --- a/aeon/distances/elastic/tests/test_cost_matrix.py +++ b/aeon/distances/elastic/tests/test_cost_matrix.py @@ -64,8 +64,6 @@ def _validate_cost_matrix_result( cost_matrix_result[-1, -1] / max(x.shape[-1], y.shape[-1]) ) assert_almost_equal(curr_distance, distance_result) - elif name == "soft_dtw": - assert_almost_equal(abs(cost_matrix_result[-1, -1]), distance_result) else: assert_almost_equal(cost_matrix_result[-1, -1], distance_result) diff --git a/aeon/distances/tests/test_pairwise.py b/aeon/distances/tests/test_pairwise.py index 9037186f67..5b53161587 100644 --- a/aeon/distances/tests/test_pairwise.py +++ b/aeon/distances/tests/test_pairwise.py @@ -67,9 +67,16 @@ def _validate_pairwise_result( assert_almost_equal( pairwise_result, compute_pairwise_distance(x, method=name, symmetric=symmetric) ) + + computed_pw = compute_pairwise_distance(x, method=distance, symmetric=symmetric) + if name == "soft_dtw": + # Soft to self dont return 0 on diagonal + computed_pw[np.diag_indices_from(computed_pw)] = pairwise_result[ + np.diag_indices_from(pairwise_result) + 
] assert_almost_equal( pairwise_result, - compute_pairwise_distance(x, method=distance, symmetric=symmetric), + computed_pw, ) if isinstance(x, np.ndarray): @@ -552,7 +559,11 @@ def test_single_to_multiple_distances(dist): def test_pairwise_distance_non_negative(dist, seed): """Most estimators require distances to be non-negative.""" # Skip for now - if dist["name"] in MIN_DISTANCES or dist["name"] in MP_DISTANCES: + if ( + dist["name"] in MIN_DISTANCES + or dist["name"] in MP_DISTANCES + or dist["name"] in ["soft_dtw"] + ): return X = make_example_3d_numpy( n_cases=5, n_channels=1, n_timepoints=10, random_state=seed, return_y=False diff --git a/aeon/distances/tests/test_sklearn_compatibility.py b/aeon/distances/tests/test_sklearn_compatibility.py index 5d68e4114f..a59a9c45c5 100644 --- a/aeon/distances/tests/test_sklearn_compatibility.py +++ b/aeon/distances/tests/test_sklearn_compatibility.py @@ -18,7 +18,11 @@ def test_function_transformer(dist): """Test all distances work with FunctionTransformer in a pipeline.""" # Skip for now - if dist["name"] in MIN_DISTANCES or dist["name"] in MP_DISTANCES: + if ( + dist["name"] in MIN_DISTANCES + or dist["name"] in MP_DISTANCES + or dist["name"] in ["soft_dtw"] + ): return X = make_example_3d_numpy( n_cases=5, n_channels=1, n_timepoints=10, return_y=False, random_state=1 @@ -40,7 +44,11 @@ def test_function_transformer(dist): def test_distance_based(dist): """Test all distances work with KNN in a pipeline.""" # Skip for now - if dist["name"] in MIN_DISTANCES or dist["name"] in MP_DISTANCES: + if ( + dist["name"] in MIN_DISTANCES + or dist["name"] in MP_DISTANCES + or dist["name"] in ["soft_dtw"] + ): return X, y = make_example_3d_numpy( n_cases=6, n_channels=1, n_timepoints=10, regression_target=True @@ -65,7 +73,11 @@ def test_distance_based(dist): def test_clusterer(dist): """Test all distances work with DBSCAN.""" # Skip for now - if dist["name"] in MIN_DISTANCES or dist["name"] in MP_DISTANCES: + if ( + dist["name"] 
in MIN_DISTANCES + or dist["name"] in MP_DISTANCES + or dist["name"] in ["soft_dtw"] + ): return X = make_example_3d_numpy(n_cases=5, n_channels=1, n_timepoints=10, return_y=False) db = DBSCAN(metric="precomputed", eps=2.5) @@ -86,7 +98,11 @@ def test_univariate(dist, k, task): """Test all distances work with sklearn nearest neighbours.""" # TODO: when solved the issue with lcss and edr, remove this condition # Skip for now - if dist["name"] in MIN_DISTANCES or dist["name"] in MP_DISTANCES: + if ( + dist["name"] in MIN_DISTANCES + or dist["name"] in MP_DISTANCES + or dist["name"] in ["soft_dtw"] + ): return # https://github.com/aeon-toolkit/aeon/issues/882 @@ -155,7 +171,11 @@ def test_univariate(dist, k, task): def test_multivariate(dist, k, task): """Test all distances work with sklearn nearest neighbours.""" # Skip for now - if dist["name"] in MIN_DISTANCES or dist["name"] in MP_DISTANCES: + if ( + dist["name"] in MIN_DISTANCES + or dist["name"] in MP_DISTANCES + or dist["name"] in ["soft_dtw"] + ): return # TODO: when solved the issue with lcss and edr, remove this condition # https://github.com/aeon-toolkit/aeon/issues/882 diff --git a/aeon/testing/expected_results/expected_distance_results.py b/aeon/testing/expected_results/expected_distance_results.py index 5eb5107fdb..d801252d50 100644 --- a/aeon/testing/expected_results/expected_distance_results.py +++ b/aeon/testing/expected_results/expected_distance_results.py @@ -105,10 +105,10 @@ 25.0, ], "soft_dtw": [ - 12.948921674222193, - 12.948921674222193, - 7.572515659641623, - 1.972517138600932, + -12.948921674222193, + -12.948921674222193, + -7.572515659641623, + -1.972517138600932, 25.0, ], "sbd": [ @@ -181,9 +181,9 @@ [0.5328555761145305, 13.072037194954508], ], "soft_dtw": [ - [12.25477773906269, 6.893330315245519], - [8.602610210695161, 8.645028399102344], - [1.750534284134988, 12.516745017325773], + [-12.25477773906269, 6.893330315245519], + [-8.602610210695161, 8.645028399102344], + [-1.750534284134988, 
12.516745017325773], ], "sbd": [[0.2435580798173309, 0.18613277150939772]], "shift_scale": [ diff --git a/aeon/testing/utils/_distance_parameters.py b/aeon/testing/utils/_distance_parameters.py index c82dd14314..5281db60f1 100644 --- a/aeon/testing/utils/_distance_parameters.py +++ b/aeon/testing/utils/_distance_parameters.py @@ -37,3 +37,7 @@ def _custom_distance_measure(x, y, custom_param=1): TEST_DISTANCE_WITH_CUSTOM_DISTANCE = TEST_DISTANCE_WITH_PARAMS + [ (_custom_distance_measure, {"custom_param": 10}), ] + +TEST_SOFT_DISTANCES_WITH_PARAMS = [ + ("soft_dtw", {"gamma": 0.1}), +]