diff --git a/stumpy/aamp.py b/stumpy/aamp.py
index ca253c9af..4928baaf7 100644
--- a/stumpy/aamp.py
+++ b/stumpy/aamp.py
@@ -15,7 +15,7 @@
 @njit(
     # "(f8[:], f8[:], i8, b1[:], b1[:], f8, i8[:], i8, i8, i8, f8[:, :, :],"
-    # "i8[:, :, :], b1)",
+    # "f8[:, :], f8[:, :], i8[:, :, :], i8[:, :], i8[:, :], b1)",
     fastmath=True,
 )
 def _compute_diagonal(
@@ -30,12 +30,17 @@ def _compute_diagonal(
     diags_stop_idx,
     thread_idx,
     P,
+    PL,
+    PR,
     I,
+    IL,
+    IR,
     ignore_trivial,
 ):
     """
-    Compute (Numba JIT-compiled) and update P, I along a single diagonal using a single
-    thread and avoiding race conditions
+    Compute (Numba JIT-compiled) and update the (top-k) matrix profile P,
+    PL, PR, I, IL, and IR sequentially along individual diagonals using a single
+    thread and avoiding race conditions.

     Parameters
     ----------
@@ -49,12 +54,6 @@ def _compute_diagonal(
     m : int
         Window size

-    P : numpy.ndarray
-        Matrix profile
-
-    I : numpy.ndarray
-        Matrix profile indices
-
     T_A_subseq_isfinite : numpy.ndarray
         A boolean array that indicates whether a subsequence in `T_A` contains a
         `np.nan`/`np.inf` value (False)
@@ -78,6 +77,24 @@ def _compute_diagonal(
     thread_idx : int
         The thread index

+    P : numpy.ndarray
+        The (top-k) matrix profile, sorted in ascending order per row
+
+    PL : numpy.ndarray
+        The top-1 left matrix profile
+
+    PR : numpy.ndarray
+        The top-1 right matrix profile
+
+    I : numpy.ndarray
+        The (top-k) matrix profile indices
+
+    IL : numpy.ndarray
+        The top-1 left matrix profile indices
+
+    IR : numpy.ndarray
+        The top-1 right matrix profile indices
+
     ignore_trivial : bool
         Set to `True` if this is a self-join. Otherwise, for AB-join, set this to
         `False`. Default is `True`.
@@ -92,16 +109,16 @@ def _compute_diagonal(
     uint64_1 = np.uint64(1)

     for diag_idx in range(diags_start_idx, diags_stop_idx):
-        k = diags[diag_idx]
+        g = diags[diag_idx]

-        if k >= 0:
-            iter_range = range(0, min(n_A - m + 1, n_B - m + 1 - k))
+        if g >= 0:
+            iter_range = range(0, min(n_A - m + 1, n_B - m + 1 - g))
         else:
-            iter_range = range(-k, min(n_A - m + 1, n_B - m + 1 - k))
+            iter_range = range(-g, min(n_A - m + 1, n_B - m + 1 - g))

         for i in iter_range:
             uint64_i = np.uint64(i)
-            uint64_j = np.uint64(i + k)
+            uint64_j = np.uint64(i + g)

             if uint64_i == 0 or uint64_j == 0:
                 p_norm = (
@@ -129,36 +146,59 @@ def _compute_diagonal(

             if T_A_subseq_isfinite[uint64_i] and T_B_subseq_isfinite[uint64_j]:
                 # Neither subsequence contains NaNs
-                if p_norm < P[thread_idx, uint64_i, 0]:
-                    P[thread_idx, uint64_i, 0] = p_norm
-                    I[thread_idx, uint64_i, 0] = uint64_j
-
-                if ignore_trivial:
-                    if p_norm < P[thread_idx, uint64_j, 0]:
-                        P[thread_idx, uint64_j, 0] = p_norm
-                        I[thread_idx, uint64_j, 0] = uint64_i
+                # `P[thread_idx, i, :]` is sorted in ascending order and MUST be
+                # updated when the newly-calculated `p_norm` value becomes smaller
+                # than the last (i.e., greatest) element in this array. Note that
+                # the goal is to have the top-k smallest distances for each
+                # subsequence.
+ if p_norm < P[thread_idx, uint64_i, -1]: + idx = np.searchsorted(P[thread_idx, uint64_i], p_norm) + core._shift_insert_at_index( + P[thread_idx, uint64_i], idx, p_norm, shift="right" + ) + core._shift_insert_at_index( + I[thread_idx, uint64_i], idx, uint64_j, shift="right" + ) + + if ignore_trivial: # self-joins only + if p_norm < P[thread_idx, uint64_j, -1]: + idx = np.searchsorted(P[thread_idx, uint64_j], p_norm) + core._shift_insert_at_index( + P[thread_idx, uint64_j], idx, p_norm, shift="right" + ) + core._shift_insert_at_index( + I[thread_idx, uint64_j], idx, uint64_i, shift="right" + ) if uint64_i < uint64_j: # left matrix profile and left matrix profile index - if p_norm < P[thread_idx, uint64_j, 1]: - P[thread_idx, uint64_j, 1] = p_norm - I[thread_idx, uint64_j, 1] = uint64_i + if p_norm < PL[thread_idx, uint64_j]: + PL[thread_idx, uint64_j] = p_norm + IL[thread_idx, uint64_j] = uint64_i # right matrix profile and right matrix profile index - if p_norm < P[thread_idx, uint64_i, 2]: - P[thread_idx, uint64_i, 2] = p_norm - I[thread_idx, uint64_i, 2] = uint64_j + if p_norm < PR[thread_idx, uint64_i]: + PR[thread_idx, uint64_i] = p_norm + IR[thread_idx, uint64_i] = uint64_j return @njit( - # "(f8[:], f8[:], i8, b1[:], b1[:], i8[:], b1)", + # "(f8[:], f8[:], i8, b1[:], b1[:], i8[:], b1, i8)", parallel=True, fastmath=True, ) def _aamp( - T_A, T_B, m, T_A_subseq_isfinite, T_B_subseq_isfinite, p, diags, ignore_trivial + T_A, + T_B, + m, + T_A_subseq_isfinite, + T_B_subseq_isfinite, + p, + diags, + ignore_trivial, + k, ): """ A Numba JIT-compiled version of AAMP for parallel computation of the matrix @@ -194,13 +234,30 @@ def _aamp( Set to `True` if this is a self-join. Otherwise, for AB-join, set this to `False`. Default is `True`. + k : int + The number of top `k` smallest distances used to construct the matrix profile. + Note that this will increase the total computational time and memory usage + when k > 1. 
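The top-k bookkeeping in `_compute_diagonal` above hinges on `np.searchsorted` plus `core._shift_insert_at_index` to keep each row of `P` sorted in ascending order. Below is a minimal NumPy sketch of that insertion pattern; the `shift_insert_at_index` helper is a hypothetical stand-in whose semantics are inferred from the call sites rather than copied from `core.py`:

```python
import numpy as np

def shift_insert_at_index(a, idx, v):
    # Hypothetical stand-in for core._shift_insert_at_index(..., shift="right"):
    # shift a[idx:] one slot to the right, drop the last element, and write
    # `v` into the freed slot, so `a` stays sorted and keeps length k.
    a[idx + 1 :] = a[idx:-1]
    a[idx] = v

# Maintain the top-k smallest distances for one subsequence (k = 3)
P_row = np.array([1.0, 4.0, np.inf])  # sorted in ascending order
I_row = np.array([7, 2, -1])

p_norm, j = 2.5, 11                   # newly computed distance and its index
if p_norm < P_row[-1]:                # only update when it beats the worst entry
    idx = np.searchsorted(P_row, p_norm)
    shift_insert_at_index(P_row, idx, p_norm)
    shift_insert_at_index(I_row, idx, j)

print(P_row)  # [1.  2.5 4. ]
print(I_row)  # [ 7 11  2]
```

Capping each row at length `k` this way is what keeps the extra work and memory proportional to `k`.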
+ Returns ------- - P : numpy.ndarray - Matrix profile + out1 : numpy.ndarray + The (top-k) matrix profile - I : numpy.ndarray - Matrix profile indices + out2 : numpy.ndarray + The (top-1) left matrix profile + + out3 : numpy.ndarray + The (top-1) right matrix profile + + out4 : numpy.ndarray + The (top-k) matrix profile indices + + out5 : numpy.ndarray + The (top-1) left matrix profile indices + + out6 : numpy.ndarray + The (top-1) right matrix profile indices Notes ----- @@ -213,8 +270,15 @@ def _aamp( n_B = T_B.shape[0] l = n_A - m + 1 n_threads = numba.config.NUMBA_NUM_THREADS - P = np.full((n_threads, l, 3), np.inf, dtype=np.float64) - I = np.full((n_threads, l, 3), -1, dtype=np.int64) + + P = np.full((n_threads, l, k), np.inf, dtype=np.float64) + I = np.full((n_threads, l, k), -1, dtype=np.int64) + + PL = np.full((n_threads, l), np.inf, dtype=np.float64) + IL = np.full((n_threads, l), -1, dtype=np.int64) + + PR = np.full((n_threads, l), np.inf, dtype=np.float64) + IR = np.full((n_threads, l), -1, dtype=np.int64) ndist_counts = core._count_diagonal_ndist(diags, m, n_A, n_B) diags_ranges = core._get_array_ranges(ndist_counts, n_threads, False) @@ -233,26 +297,37 @@ def _aamp( diags_ranges[thread_idx, 1], thread_idx, P, + PL, + PR, I, + IL, + IR, ignore_trivial, ) # Reduction of results from all threads for thread_idx in range(1, n_threads): - for i in prange(l): - if P[0, i, 0] > P[thread_idx, i, 0]: - P[0, i, 0] = P[thread_idx, i, 0] - I[0, i, 0] = I[thread_idx, i, 0] - # left matrix profile and left matrix profile indices - if P[0, i, 1] > P[thread_idx, i, 1]: - P[0, i, 1] = P[thread_idx, i, 1] - I[0, i, 1] = I[thread_idx, i, 1] - # right matrix profile and right matrix profile indices - if P[0, i, 2] > P[thread_idx, i, 2]: - P[0, i, 2] = P[thread_idx, i, 2] - I[0, i, 2] = I[thread_idx, i, 2] - - return np.power(P[0, :, :], 1.0 / p), I[0, :, :] + # update top-k arrays + core._merge_topk_PI(P[0], P[thread_idx], I[0], I[thread_idx]) + + # update left matrix profile and matrix profile indices + mask = PL[0] > PL[thread_idx] + PL[0][mask] = PL[thread_idx][mask] + IL[0][mask] = IL[thread_idx][mask] + + # update right matrix profile and matrix profile indices + mask = PR[0] > PR[thread_idx] + PR[0][mask] = PR[thread_idx][mask] + IR[0][mask] = IR[thread_idx][mask] + + return ( + np.power(P[0], 1.0 / p), + np.power(PL[0], 1.0 / p), + np.power(PR[0], 1.0 / p), + I[0], + IL[0], + IR[0], + ) def aamp(T_A, m, T_B=None, ignore_trivial=True, p=2.0, k=1): @@ -291,8 +366,16 @@ def aamp(T_A, m, T_B=None, ignore_trivial=True, p=2.0, k=1): Returns ------- out : numpy.ndarray - The first column consists of the matrix profile, the second column - consists of the matrix profile indices. + When k = 1 (default), the first column consists of the matrix profile, + the second column consists of the matrix profile indices, the third column + consists of the left matrix profile indices, and the fourth column consists + of the right matrix profile indices. However, when k > 1, the output array + will contain exactly 2 * k + 2 columns. The first k columns (i.e., out[:, :k]) + consists of the top-k matrix profile, the next set of k columns + (i.e., out[:, k:2k]) consists of the corresponding top-k matrix profile + indices, and the last two columns (i.e., out[:, 2k] and out[:, 2k+1] or, + equivalently, out[:, -2] and out[:, -1]) correspond to the top-1 left + matrix profile indices and the top-1 right matrix profile indices, respectively. 
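For reference, here is how a caller might unpack the widened `out` array described above; the slicing mirrors what `aampi.__init__` does further down in this diff, and the time series and window size are placeholders:

```python
import numpy as np
import stumpy

T = np.random.rand(1000)
m, k = 50, 3

out = stumpy.aamp(T, m, k=k)  # out.shape == (len(T) - m + 1, 2 * k + 2)

# `out` has dtype object, hence the explicit casts
P = out[:, :k].astype(np.float64)       # top-k matrix profile
I = out[:, k : 2 * k].astype(np.int64)  # top-k matrix profile indices
left_I = out[:, -2].astype(np.int64)    # top-1 left matrix profile indices
right_I = out[:, -1].astype(np.int64)   # top-1 right matrix profile indices
```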
Notes ----- @@ -331,19 +414,26 @@ def aamp(T_A, m, T_B=None, ignore_trivial=True, p=2.0, k=1): l = n_A - m + 1 excl_zone = int(np.ceil(m / config.STUMPY_EXCL_ZONE_DENOM)) - out = np.empty((l, 4), dtype=object) - if ignore_trivial: diags = np.arange(excl_zone + 1, n_A - m + 1, dtype=np.int64) else: diags = np.arange(-(n_A - m + 1) + 1, n_B - m + 1, dtype=np.int64) - P, I = _aamp( - T_A, T_B, m, T_A_subseq_isfinite, T_B_subseq_isfinite, p, diags, ignore_trivial + P, PL, PR, I, IL, IR = _aamp( + T_A, + T_B, + m, + T_A_subseq_isfinite, + T_B_subseq_isfinite, + p, + diags, + ignore_trivial, + k, ) - out[:, 0] = P[:, 0] - out[:, 1:] = I[:, :] + out = np.empty((l, 2 * k + 2), dtype=object) + out[:, :k] = P + out[:, k:] = np.column_stack((I, IL, IR)) core._check_P(out[:, 0]) diff --git a/stumpy/aamped.py b/stumpy/aamped.py index 3dd346275..97205c53d 100644 --- a/stumpy/aamped.py +++ b/stumpy/aamped.py @@ -55,8 +55,16 @@ def aamped(dask_client, T_A, m, T_B=None, ignore_trivial=True, p=2.0, k=1): Returns ------- out : numpy.ndarray - The first column consists of the matrix profile, the second column - consists of the matrix profile indices. + When k = 1 (default), the first column consists of the matrix profile, + the second column consists of the matrix profile indices, the third column + consists of the left matrix profile indices, and the fourth column consists + of the right matrix profile indices. However, when k > 1, the output array + will contain exactly 2 * k + 2 columns. The first k columns (i.e., out[:, :k]) + consists of the top-k matrix profile, the next set of k columns + (i.e., out[:, k:2k]) consists of the corresponding top-k matrix profile + indices, and the last two columns (i.e., out[:, 2k] and out[:, 2k+1] or, + equivalently, out[:, -2] and out[:, -1]) correspond to the top-1 left + matrix profile indices and the top-1 right matrix profile indices, respectively. 
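The thread reduction in `_aamp` above and the Dask reduction below both funnel through `core._merge_topk_PI`, which this diff calls but never shows. Here is a plain-NumPy sketch of the behavior the call sites imply (merge two row-wise ascending arrays in place, keeping the `k` smallest values and their indices per row); this is an assumption about the helper, not its actual implementation:

```python
import numpy as np

def merge_topk_PI(PA, PB, IA, IB):
    # For each row, merge the ascending rows PA[i] and PB[i] and keep the
    # k smallest values (and their matching indices) in PA[i]/IA[i], in place.
    k = PA.shape[1]
    for i in range(PA.shape[0]):
        tmp_P = np.concatenate((PA[i], PB[i]))
        tmp_I = np.concatenate((IA[i], IB[i]))
        order = np.argsort(tmp_P, kind="stable")[:k]
        PA[i] = tmp_P[order]
        IA[i] = tmp_I[order]

PA = np.array([[1.0, 3.0, 9.0]])
PB = np.array([[2.0, 4.0, 8.0]])
IA = np.array([[10, 30, 90]])
IB = np.array([[20, 40, 80]])
merge_topk_PI(PA, PB, IA, IB)
print(PA)  # [[1. 2. 3.]]
print(IA)  # [[10 20 30]]
```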
Notes ----- @@ -94,12 +102,10 @@ def aamped(dask_client, T_A, m, T_B=None, ignore_trivial=True, p=2.0, k=1): n_B = T_B.shape[0] l = n_A - m + 1 - excl_zone = int(np.ceil(m / config.STUMPY_EXCL_ZONE_DENOM)) - out = np.empty((l, 4), dtype=object) - hosts = list(dask_client.ncores().keys()) nworkers = len(hosts) + excl_zone = int(np.ceil(m / config.STUMPY_EXCL_ZONE_DENOM)) if ignore_trivial: diags = np.arange(excl_zone + 1, n_A - m + 1, dtype=np.int64) else: @@ -141,20 +147,30 @@ def aamped(dask_client, T_A, m, T_B=None, ignore_trivial=True, p=2.0, k=1): p, diags_futures[i], ignore_trivial, + k, ) ) results = dask_client.gather(futures) - profile, indices = results[0] + profile, profile_L, profile_R, indices, indices_L, indices_R = results[0] for i in range(1, len(hosts)): - P, I = results[i] - for col in range(P.shape[1]): # pragma: no cover - cond = P[:, col] < profile[:, col] - profile[:, col] = np.where(cond, P[:, col], profile[:, col]) - indices[:, col] = np.where(cond, I[:, col], indices[:, col]) - - out[:, 0] = profile[:, 0] - out[:, 1:4] = indices + P, PL, PR, I, IL, IR = results[i] + # Update top-k matrix profile and matrix profile indices + core._merge_topk_PI(profile, P, indices, I) + + # Update top-1 left matrix profile and matrix profile index + mask = PL < profile_L + profile_L[mask] = PL[mask] + indices_L[mask] = IL[mask] + + # Update top-1 right matrix profile and matrix profile index + mask = PR < profile_R + profile_R[mask] = PR[mask] + indices_R[mask] = IR[mask] + + out = np.empty((l, 2 * k + 2), dtype=object) + out[:, :k] = profile + out[:, k : 2 * k + 2] = np.column_stack((indices, indices_L, indices_R)) core._check_P(out[:, 0]) diff --git a/stumpy/aampi.py b/stumpy/aampi.py index 621e20293..0dc7bd5cd 100644 --- a/stumpy/aampi.py +++ b/stumpy/aampi.py @@ -101,12 +101,13 @@ def __init__(self, T, m, egress=True, p=2.0, k=1): self._excl_zone = int(np.ceil(self._m / config.STUMPY_EXCL_ZONE_DENOM)) self._egress = egress self._p = p + self._k = k - mp = aamp(self._T, self._m, p=self._p) - self._P = mp[:, 0].astype(np.float64) - self._I = mp[:, 1].astype(np.int64) - self._left_I = mp[:, 2].astype(np.int64) - self._left_P = np.empty(self._P.shape, dtype=np.float64) + mp = aamp(self._T, self._m, p=self._p, k=self._k) + self._P = mp[:, : self._k].astype(np.float64) + self._I = mp[:, self._k : 2 * self._k].astype(np.int64) + self._left_I = mp[:, 2 * self._k].astype(np.int64) + self._left_P = np.full_like(self._left_I, np.inf, dtype=np.float64) self._left_P[:] = np.inf self._T_isfinite = np.isfinite(self._T) @@ -120,8 +121,8 @@ def __init__(self, T, m, egress=True, p=2.0, k=1): # matrix profile values, we can save time by re-computing only the left matrix # profile value when the matrix profile index is equal to the right matrix # profile index. 
-        mask = self._left_I == self._I
-        self._left_P[mask] = self._P[mask]
+        mask = self._left_I == self._I[:, 0]
+        self._left_P[mask] = self._P[mask, 0]

         # Only re-compute the `i`-th left matrix profile value, `self._left_P[i]`,
         # when `self._I[i] != self._left_I[i]`
@@ -132,7 +133,7 @@ def __init__(self, T, m, egress=True, p=2.0, k=1):
                 self._T[i : i + self._m] - self._T[j : j + self._m], ord=self._p
             )

-        Q = self._T[-m:]
+        Q = self._T[-self._m :]
         self._p_norm = core.mass_absolute(Q, self._T, p=self._p) ** self._p
         if self._egress:
             self._p_norm_new = np.empty(self._p_norm.shape[0], dtype=np.float64)
@@ -202,9 +203,10 @@ def _update_egress(self, t):
         self._p_norm_new[0] = (
             np.linalg.norm(self._T[: self._m] - S[: self._m], ord=self._p) ** self._p
         )
-        self._p_norm_new[:] = np.where(
-            self._p_norm_new < config.STUMPY_P_NORM_THRESHOLD, 0, self._p_norm_new
-        )
+
+        mask = self._p_norm_new < config.STUMPY_P_NORM_THRESHOLD
+        self._p_norm_new[mask] = 0
+
         D = np.power(self._p_norm_new, 1.0 / self._p)
         D[~self._T_subseq_isfinite] = np.inf
         if np.any(~self._T_isfinite[-self._m :]):
@@ -212,30 +214,37 @@ def _update_egress(self, t):

         core.apply_exclusion_zone(D, D.shape[0] - 1, self._excl_zone, np.inf)

-        update_idx = np.argwhere(D < self._P).flatten()
-        self._I[update_idx] = D.shape[0] + self._n_appended - 1  # D.shape[0] is base-1
-        self._P[update_idx] = D[update_idx]
-
-        I_last = np.argmin(D)
-
-        if np.isinf(D[I_last]):
-            self._I[-1] = -1
-            self._P[-1] = np.inf
-        else:
-            self._I[-1] = I_last + self._n_appended
-            self._P[-1] = D[I_last]
-
-        # Regarding the last subsequence, the left profile (index) value is the
-        # same as the profile (index) value.
-        self._left_I[-1] = self._I[-1]
-        self._left_P[-1] = self._P[-1]
+        update_idx = np.argwhere(D < self._P[:, -1]).flatten()
+        for i in update_idx:
+            idx = np.searchsorted(self._P[i], D[i], side="right")
+            core._shift_insert_at_index(self._P[i], idx, D[i])
+            core._shift_insert_at_index(
+                self._I[i], idx, D.shape[0] + self._n_appended - 1
+            )
+            # D.shape[0] is base-1
+
+        # Calculate the (top-k) matrix profile values/indices for the last
+        # subsequence by using its corresponding distance profile `D`
+        self._P[-1] = np.inf
+        self._I[-1] = -1
+        for i, d in enumerate(D):
+            if d < self._P[-1, -1]:
+                idx = np.searchsorted(self._P[-1], d, side="right")
+                core._shift_insert_at_index(self._P[-1], idx, d)
+                core._shift_insert_at_index(self._I[-1], idx, i + self._n_appended)
+
+        # All neighbors of the last subsequence are on its left. So, its (top-1)
+        # matrix profile value/index and its left matrix profile value/index must
+        # be equal.
+        self._left_P[-1] = self._P[-1, 0]
+        self._left_I[-1] = self._I[-1, 0]

         self._p_norm[:] = self._p_norm_new

     def _update(self, t):
         """
-        Ingress a new data point and update the matrix profile and matrix profile
-        indices without egressing the oldest data point
+        Ingress a new data point and update the (top-k) matrix profile and matrix
+        profile indices without egressing the oldest data point
         """
         self._n = self._T.shape[0]
         l = self._n - self._m + 1
@@ -264,9 +273,10 @@ def _update(self, t):
         p_norm_new[0] = (
             np.linalg.norm(T_new[: self._m] - S[: self._m], ord=self._p) ** self._p
         )
-        p_norm_new[:] = np.where(
-            p_norm_new < config.STUMPY_P_NORM_THRESHOLD, 0, p_norm_new
-        )
+
+        mask = p_norm_new < config.STUMPY_P_NORM_THRESHOLD
+        p_norm_new[mask] = 0
+
         D = np.power(p_norm_new, 1.0 / self._p)
         D[~self._T_subseq_isfinite] = np.inf
         if np.any(~self._T_isfinite[-self._m :]):
@@ -274,55 +284,69 @@ def _update(self, t):

         core.apply_exclusion_zone(D, D.shape[0] - 1, self._excl_zone, np.inf)

-        update_idx = np.argwhere(D[:l] < self._P[:l]).flatten()
-        self._I[update_idx] = l
-        self._P[update_idx] = D[update_idx]
-
-        I_last = np.argmin(D)
-        if np.isinf(D[I_last]):
-            I_new = np.append(self._I, -1)
-            P_new = np.append(self._P, np.inf)
-        else:
-            I_new = np.append(self._I, I_last)
-            P_new = np.append(self._P, D[I_last])
-
-        # Regarding the last subsequence, the left profile (index) value is the
-        # same as the profile (index) value.
-        left_I_new = np.append(self._left_I, I_new[-1])
-        left_P_new = np.append(self._left_P, P_new[-1])
+        update_idx = np.argwhere(D[:l] < self._P[:l, -1]).flatten()
+        for i in update_idx:
+            idx = np.searchsorted(self._P[i], D[i], side="right")
+            core._shift_insert_at_index(self._P[i], idx, D[i])
+            core._shift_insert_at_index(self._I[i], idx, l)
+
+        # Calculate the top-k matrix profile and the (top-1) left matrix profile
+        # (and their corresponding indices) for the new subsequence whose distance
+        # profile is `D`
+        P_new = np.full(self._k, np.inf, dtype=np.float64)
+        I_new = np.full(self._k, -1, dtype=np.int64)
+        for i, d in enumerate(D):
+            if d < P_new[-1]:  # maximum value in the sorted array `P_new`
+                idx = np.searchsorted(P_new, d, side="right")
+                core._shift_insert_at_index(P_new, idx, d)
+                core._shift_insert_at_index(I_new, idx, i)
+
+        left_I_new = I_new[0]
+        left_P_new = P_new[0]

         self._T = T_new
-        self._P = P_new
-        self._I = I_new
-        self._left_I = left_I_new
-        self._left_P = left_P_new
+        self._P = np.append(self._P, P_new.reshape(1, -1), axis=0)
+        self._I = np.append(self._I, I_new.reshape(1, -1), axis=0)
+        self._left_P = np.append(self._left_P, left_P_new)
+        self._left_I = np.append(self._left_I, left_I_new)

         self._p_norm = p_norm_new

     @property
     def P_(self):
         """
-        Get the matrix profile
+        Get the (top-k) matrix profile. When `k=1` (default), the output is
+        a 1D array consisting of the matrix profile. When `k > 1`, the
+        output is a 2D array that has exactly `k` columns consisting of the
+        top-k matrix profile.
         """
-        return self._P.astype(np.float64)
+        if self._k == 1:
+            return self._P.flatten().astype(np.float64)
+        else:
+            return self._P.astype(np.float64)

     @property
     def I_(self):
         """
-        Get the matrix profile indices
+        Get the (top-k) matrix profile indices. When `k=1` (default), the output is
+        a 1D array consisting of the matrix profile indices. When `k > 1`, the
+        output is a 2D array that has exactly `k` columns consisting of the
+        top-k matrix profile indices.
""" - return self._I.astype(np.int64) + if self._k == 1: + return self._I.flatten().astype(np.int64) + else: + return self._I.astype(np.int64) @property def left_P_(self): """ - Get the left matrix profile + Get the (top-1) left matrix profile """ return self._left_P.astype(np.float64) @property def left_I_(self): """ - Get the left matrix profile indices + Get the (top-1) left matrix profile indices """ return self._left_I.astype(np.int64) diff --git a/stumpy/core.py b/stumpy/core.py index ef99a0746..77d5bcd54 100644 --- a/stumpy/core.py +++ b/stumpy/core.py @@ -7,7 +7,7 @@ import inspect import numpy as np -from numba import njit +from numba import njit, cuda from scipy.signal import convolve from scipy.ndimage import maximum_filter1d, minimum_filter1d from scipy import linalg @@ -2885,3 +2885,93 @@ def max_distance(D): candidate_idx = np.argmin(D) return np.array(matches, dtype=object) + + +@cuda.jit(device=True) +def _gpu_searchsorted_left(a, v, bfs, nlevel): + """ + A device function, equivalent to numpy.searchsorted(a, v, side='left') + + Parameters + ---------- + a : numpy.ndarray + 1-dim array sorted in ascending order. + + v : float + Value to insert into array `a` + + bfs : numpy.ndarray + The breadth-first-search indices where the missing leaves of its corresponding + binary search tree are filled with -1. + + nlevel : int + The number of levels in the binary search tree from which the array + `bfs` is obtained. + + Returns + ------- + idx : int + The index of the insertion point + """ + n = a.shape[0] + idx = 0 + for level in range(nlevel): + if v <= a[bfs[idx]]: + next_idx = 2 * idx + 1 + else: + next_idx = 2 * idx + 2 + + if level == nlevel - 1 or bfs[next_idx] < 0: + if v <= a[bfs[idx]]: + idx = max(bfs[idx], 0) + else: + idx = min(bfs[idx] + 1, n) + break + idx = next_idx + + return idx + + +@cuda.jit(device=True) +def _gpu_searchsorted_right(a, v, bfs, nlevel): + """ + A device function, equivalent to numpy.searchsorted(a, v, side='right') + + Parameters + ---------- + a : numpy.ndarray + 1-dim array sorted in ascending order. + + v : float + Value to insert into array `a` + + bfs : numpy.ndarray + The breadth-first-search indices where the missing leaves of its corresponding + binary search tree are filled with -1. + + nlevel : int + The number of levels in the binary search tree from which the array + `bfs` is obtained. 
+ + Returns + ------- + idx : int + The index of the insertion point + """ + n = a.shape[0] + idx = 0 + for level in range(nlevel): + if v < a[bfs[idx]]: + next_idx = 2 * idx + 1 + else: + next_idx = 2 * idx + 2 + + if level == nlevel - 1 or bfs[next_idx] < 0: + if v < a[bfs[idx]]: + idx = max(bfs[idx], 0) + else: + idx = min(bfs[idx] + 1, n) + break + idx = next_idx + + return idx diff --git a/stumpy/gpu_aamp.py b/stumpy/gpu_aamp.py index 7bf783525..8150009af 100644 --- a/stumpy/gpu_aamp.py +++ b/stumpy/gpu_aamp.py @@ -16,7 +16,8 @@ @cuda.jit( "(i8, f8[:], f8[:], i8, f8, f8[:], f8[:], f8[:], b1[:], b1[:]," - "i8, b1, i8, f8[:, :], i8[:, :], b1)" + "i8, b1, i8, f8[:, :], f8[:], f8[:], i8[:, :], i8[:], i8[:], b1," + "i8[:], i8, i8)" ) def _compute_and_update_PI_kernel( i, @@ -29,12 +30,19 @@ def _compute_and_update_PI_kernel( p_norm_first, T_A_subseq_isfinite, T_B_subseq_isfinite, - k, + w, ignore_trivial, excl_zone, profile, + profile_L, + profile_R, indices, + indices_L, + indices_R, compute_p_norm, + bfs, + nlevel, + k, ): """ A Numba CUDA kernel to update the non-normalized (i.e., without z-normalization) @@ -75,7 +83,7 @@ def _compute_and_update_PI_kernel( A boolean array that indicates whether a subsequence in `T_B` contains a `np.nan`/`np.inf` value (False) - k : int + w : int The total number of sliding windows to iterate over ignore_trivial : bool @@ -87,18 +95,31 @@ def _compute_and_update_PI_kernel( sliding window profile : numpy.ndarray - Matrix profile. The first column consists of the global matrix profile, - the second column consists of the left matrix profile, and the third - column consists of the right matrix profile. + The (top-k) matrix profile, sorted in ascending order per row + + profile_L : numpy.ndarray + The (top-1) left matrix profile + + profile_R : numpy.ndarray + The (top-1) right matrix profile indices : numpy.ndarray - The first column consists of the matrix profile indices, the second - column consists of the left matrix profile indices, and the third - column consists of the right matrix profile indices. + The (top-k) matrix profile indices + + indices_L : numpy.ndarray + The (top-1) left matrix profile indices + + indices_R : numpy.ndarray + The (top-1) right matrix profile indices compute_p_norm : bool A boolean flag for whether or not to compute the p-norm + k : int + The number of top `k` smallest distances used to construct the matrix profile. + Note that this will increase the total computational time and memory usage + when k > 1. 
+ Returns ------- None @@ -129,7 +150,7 @@ def _compute_and_update_PI_kernel( for j in range(start, p_norm_out.shape[0], stride): zone_start = max(0, j - excl_zone) - zone_stop = min(k, j + excl_zone) + zone_stop = min(w, j + excl_zone) if compute_p_norm: p_norm_out[j] = ( @@ -149,16 +170,21 @@ def _compute_and_update_PI_kernel( if ignore_trivial: if i <= zone_stop and i >= zone_start: p_norm = np.inf - if p_norm < profile[j, 1] and i < j: - profile[j, 1] = p_norm - indices[j, 1] = i - if p_norm < profile[j, 2] and i > j: - profile[j, 2] = p_norm - indices[j, 2] = i + if p_norm < profile_L[j] and i < j: + profile_L[j] = p_norm + indices_L[j] = i + if p_norm < profile_R[j] and i > j: + profile_R[j] = p_norm + indices_R[j] = i - if p_norm < profile[j, 0]: - profile[j, 0] = p_norm - indices[j, 0] = i + if p_norm < profile[j, -1]: + idx = core._gpu_searchsorted_right(profile[j], p_norm, bfs, nlevel) + for g in range(k - 1, idx, -1): + profile[j, g] = profile[j, g - 1] + indices[j, g] = indices[j, g - 1] + + profile[j, idx] = p_norm + indices[j, idx] = i def _gpu_aamp( @@ -172,10 +198,11 @@ def _gpu_aamp( p, p_norm_fname, p_norm_first_fname, - k, + w, ignore_trivial=True, range_start=1, device_id=0, + k=1, ): """ A Numba CUDA version of AAMP for parallel computation of the non-normalized (i.e., @@ -223,7 +250,7 @@ def _gpu_aamp( The file name for the p-norm for the first window relative to the current sliding window - k : int + w : int The total number of sliding windows to iterate over ignore_trivial : bool, default True @@ -237,16 +264,30 @@ def _gpu_aamp( device_id : int, default 0 The (GPU) device number to use. The default value is `0`. + k : int + The number of top `k` smallest distances used to construct the matrix profile. + Note that this will increase the total computational time and memory usage + when k > 1. + Returns ------- profile_fname : str - The file name for the matrix profile + The file name for the (top-k) matrix profile + + profile_L_fname : str + The file name for the (top-1) left matrix profile + + profile_R_fname : str + The file name for the (top-1) right matrix profile indices_fname : str - The file name for the matrix profile indices. The first column of the - array consists of the matrix profile indices, the second column consists - of the left matrix profile indices, and the third column consists of the - right matrix profile indices. + The file name for the (top-k) matrix profile indices + + indices_L_fname : str + The file name for the (top-1) left matrix profile indices + + indices_R_fname : str + The file name for the (top-1) right matrix profile indices Notes ----- @@ -263,7 +304,7 @@ def _gpu_aamp( See Table II, Figure 5, and Figure 6 """ threads_per_block = config.STUMPY_THREADS_PER_BLOCK - blocks_per_grid = math.ceil(k / threads_per_block) + blocks_per_grid = math.ceil(w / threads_per_block) T_A = np.load(T_A_fname, allow_pickle=False) T_B = np.load(T_B_fname, allow_pickle=False) @@ -272,6 +313,10 @@ def _gpu_aamp( T_A_subseq_isfinite = np.load(T_A_subseq_isfinite_fname, allow_pickle=False) T_B_subseq_isfinite = np.load(T_B_subseq_isfinite_fname, allow_pickle=False) + device_bfs = cuda.to_device(core._bfs_indices(k, fill_value=-1)) + nlevel = np.floor(np.log2(k) + 1).astype(np.int64) + # number of levels in binary search tree from which `bfs` is constructed. 
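The `device_bfs`/`nlevel` pair computed above feeds the `_gpu_searchsorted_*` device functions added to `core.py` earlier in this diff. The host-side sketch below re-implements the same walk and checks it against `np.searchsorted`; the `bfs_indices` helper is a hypothetical midpoint-splitting construction, and the real `core._bfs_indices` may differ in its details:

```python
import numpy as np

def bfs_indices(k):
    # Level-order (BFS) positions of 0..k-1 arranged as a balanced BST,
    # with missing nodes encoded as -1 (mirroring fill_value=-1). This
    # construction is an assumption, not core._bfs_indices itself.
    nlevel = int(np.floor(np.log2(k) + 1))
    bfs = np.full(2**nlevel - 1, -1, dtype=np.int64)
    queue = [(0, 0, k)]  # (BFS position, lo, hi) over the index range [lo, hi)
    while queue:
        pos, lo, hi = queue.pop(0)
        if lo >= hi:
            continue
        mid = (lo + hi) // 2
        bfs[pos] = mid
        queue.append((2 * pos + 1, lo, mid))
        queue.append((2 * pos + 2, mid + 1, hi))
    return bfs, nlevel

def searchsorted_right(a, v, bfs, nlevel):
    # Host-side transcription of the `_gpu_searchsorted_right` walk above
    n = a.shape[0]
    idx = 0
    for level in range(nlevel):
        next_idx = 2 * idx + 1 if v < a[bfs[idx]] else 2 * idx + 2
        if level == nlevel - 1 or bfs[next_idx] < 0:
            idx = max(bfs[idx], 0) if v < a[bfs[idx]] else min(bfs[idx] + 1, n)
            break
        idx = next_idx
    return idx

for k in (1, 2, 5, 8):
    a = np.sort(np.random.rand(k))
    bfs, nlevel = bfs_indices(k)
    for v in np.random.rand(100):
        assert searchsorted_right(a, v, bfs, nlevel) == np.searchsorted(
            a, v, side="right"
        )
```

Replacing a linear scan with this implicit binary search keeps the search portion of each top-k insertion at O(log k) comparisons per CUDA thread, though the subsequent element shift is still O(k).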
+ with cuda.gpus[device_id]: device_T_A = cuda.to_device(T_A) device_T_A_subseq_isfinite = cuda.to_device(T_A_subseq_isfinite) @@ -285,11 +330,22 @@ def _gpu_aamp( device_T_B = cuda.to_device(T_B) device_T_B_subseq_isfinite = cuda.to_device(T_B_subseq_isfinite) - profile = np.full((k, 3), np.inf, dtype=np.float64) - indices = np.full((k, 3), -1, dtype=np.int64) + profile = np.full((w, k), np.inf, dtype=np.float64) + indices = np.full((w, k), -1, dtype=np.int64) + + profile_L = np.full(w, np.inf, dtype=np.float64) + indices_L = np.full(w, -1, dtype=np.int64) + + profile_R = np.full(w, np.inf, dtype=np.float64) + indices_R = np.full(w, -1, dtype=np.int64) device_profile = cuda.to_device(profile) + device_profile_L = cuda.to_device(profile_L) + device_profile_R = cuda.to_device(profile_R) device_indices = cuda.to_device(indices) + device_indices_L = cuda.to_device(indices_L) + device_indices_R = cuda.to_device(indices_R) + _compute_and_update_PI_kernel[blocks_per_grid, threads_per_block]( range_start - 1, device_T_A, @@ -301,12 +357,19 @@ def _gpu_aamp( device_p_norm_first, device_T_A_subseq_isfinite, device_T_B_subseq_isfinite, - k, + w, ignore_trivial, excl_zone, device_profile, + device_profile_L, + device_profile_R, device_indices, + device_indices_L, + device_indices_R, False, + device_bfs, + nlevel, + k, ) for i in range(range_start, range_stop): @@ -321,36 +384,61 @@ def _gpu_aamp( device_p_norm_first, device_T_A_subseq_isfinite, device_T_B_subseq_isfinite, - k, + w, ignore_trivial, excl_zone, device_profile, + device_profile_L, + device_profile_R, device_indices, + device_indices_L, + device_indices_R, True, + device_bfs, + nlevel, + k, ) profile = device_profile.copy_to_host() + profile_L = device_profile_L.copy_to_host() + profile_R = device_profile_R.copy_to_host() indices = device_indices.copy_to_host() - profile = np.power(profile, 1.0 / p) + indices_L = device_indices_L.copy_to_host() + indices_R = device_indices_R.copy_to_host() + + profile[:, :] = np.power(profile, 1.0 / p) + profile_L[:] = np.power(profile_L, 1.0 / p) + profile_R[:] = np.power(profile_R, 1.0 / p) profile_fname = core.array_to_temp_file(profile) + profile_L_fname = core.array_to_temp_file(profile_L) + profile_R_fname = core.array_to_temp_file(profile_R) indices_fname = core.array_to_temp_file(indices) + indices_L_fname = core.array_to_temp_file(indices_L) + indices_R_fname = core.array_to_temp_file(indices_R) - return profile_fname, indices_fname + return ( + profile_fname, + profile_L_fname, + profile_R_fname, + indices_fname, + indices_L_fname, + indices_R_fname, + ) def gpu_aamp(T_A, m, T_B=None, ignore_trivial=True, device_id=0, p=2.0, k=1): # function needs to be revised to return (top-k) matrix profile and # matrix profile indices """ - Compute the non-normalized (i.e., without z-normalization) matrix profile with one - or more GPU devices + Compute the non-normalized (i.e., without z-normalization) matrix profile with + one or more GPU devices This is a convenience wrapper around the Numba `cuda.jit` `_gpu_aamp` function - which computes the non-normalized matrix profile according to modified version - GPU-STOMP. The default number of threads-per-block is set to `512` and may be - changed by setting the global parameter `config.STUMPY_THREADS_PER_BLOCK` to an - appropriate number based on your GPU hardware. + which computes the non-normalized (top-k) matrix profile according to modified + version GPU-STOMP. 
The default number of threads-per-block is set to `512` and + may be changed by setting the global parameter `config.STUMPY_THREADS_PER_BLOCK` + to an appropriate number based on your GPU hardware. Parameters ---------- @@ -385,10 +473,16 @@ def gpu_aamp(T_A, m, T_B=None, ignore_trivial=True, device_id=0, p=2.0, k=1): Returns ------- out : numpy.ndarray - The first column consists of the matrix profile, the second column - consists of the matrix profile indices, the third column consists of - the left matrix profile indices, and the fourth column consists of - the right matrix profile indices. + When k = 1 (default), the first column consists of the matrix profile, + the second column consists of the matrix profile indices, the third column + consists of the left matrix profile indices, and the fourth column consists + of the right matrix profile indices. However, when k > 1, the output array + will contain exactly 2 * k + 2 columns. The first k columns (i.e., out[:, :k]) + consists of the top-k matrix profile, the next set of k columns + (i.e., out[:, k:2k]) consists of the corresponding top-k matrix profile + indices, and the last two columns (i.e., out[:, 2k] and out[:, 2k+1] or, + equivalently, out[:, -2] and out[:, -1]) correspond to the top-1 left + matrix profile indices and the top-1 right matrix profile indices, respectively. Notes ----- @@ -434,7 +528,7 @@ def gpu_aamp(T_A, m, T_B=None, ignore_trivial=True, device_id=0, p=2.0, k=1): logger.warning("Try setting `ignore_trivial = False`.") n = T_B.shape[0] - k = T_A.shape[0] - m + 1 + w = T_A.shape[0] - m + 1 l = n - m + 1 excl_zone = int( np.ceil(m / config.STUMPY_EXCL_ZONE_DENOM) @@ -445,8 +539,6 @@ def gpu_aamp(T_A, m, T_B=None, ignore_trivial=True, device_id=0, p=2.0, k=1): T_A_subseq_isfinite_fname = core.array_to_temp_file(T_A_subseq_isfinite) T_B_subseq_isfinite_fname = core.array_to_temp_file(T_B_subseq_isfinite) - out = np.empty((k, 4), dtype=object) - if isinstance(device_id, int): device_ids = [device_id] else: @@ -455,6 +547,12 @@ def gpu_aamp(T_A, m, T_B=None, ignore_trivial=True, device_id=0, p=2.0, k=1): profile = [None] * len(device_ids) indices = [None] * len(device_ids) + profile_L = [None] * len(device_ids) + indices_L = [None] * len(device_ids) + + profile_R = [None] * len(device_ids) + indices_R = [None] * len(device_ids) + for _id in device_ids: with cuda.gpus[_id]: if ( @@ -499,16 +597,24 @@ def gpu_aamp(T_A, m, T_B=None, ignore_trivial=True, device_id=0, p=2.0, k=1): p, p_norm_fname, p_norm_first_fname, - k, + w, ignore_trivial, start + 1, device_ids[idx], + k, ), ) else: # Execute last chunk in parent process # Only parent process is executed when a single GPU is requested - profile[idx], indices[idx] = _gpu_aamp( + ( + profile[idx], + profile_L[idx], + profile_R[idx], + indices[idx], + indices_L[idx], + indices_R[idx], + ) = _gpu_aamp( T_A_fname, T_B_fname, m, @@ -519,10 +625,11 @@ def gpu_aamp(T_A, m, T_B=None, ignore_trivial=True, device_id=0, p=2.0, k=1): p, p_norm_fname, p_norm_first_fname, - k, + w, ignore_trivial, start + 1, device_ids[idx], + k, ) # Clean up process pool for multi-GPU request @@ -533,7 +640,14 @@ def gpu_aamp(T_A, m, T_B=None, ignore_trivial=True, device_id=0, p=2.0, k=1): # Collect results from spawned child processes if they exist for idx, result in enumerate(results): if result is not None: - profile[idx], indices[idx] = result.get() + ( + profile[idx], + profile_L[idx], + profile_R[idx], + indices[idx], + indices_L[idx], + indices_R[idx], + ) = result.get() os.remove(T_A_fname) 
os.remove(T_B_fname) @@ -546,22 +660,44 @@ def gpu_aamp(T_A, m, T_B=None, ignore_trivial=True, device_id=0, p=2.0, k=1): for idx in range(len(device_ids)): profile_fname = profile[idx] + profile_L_fname = profile_L[idx] + profile_R_fname = profile_R[idx] indices_fname = indices[idx] + indices_L_fname = indices_L[idx] + indices_R_fname = indices_R[idx] + profile[idx] = np.load(profile_fname, allow_pickle=False) + profile_L[idx] = np.load(profile_L_fname, allow_pickle=False) + profile_R[idx] = np.load(profile_R_fname, allow_pickle=False) indices[idx] = np.load(indices_fname, allow_pickle=False) + indices_L[idx] = np.load(indices_L_fname, allow_pickle=False) + indices_R[idx] = np.load(indices_R_fname, allow_pickle=False) + os.remove(profile_fname) + os.remove(profile_L_fname) + os.remove(profile_R_fname) os.remove(indices_fname) - - for i in range(1, len(device_ids)): - # Update all matrix profiles and matrix profile indices - # (global, left, right) and store in profile[0] and indices[0] - for col in range(profile[0].shape[1]): # pragma: no cover - cond = profile[0][:, col] < profile[i][:, col] - profile[0][:, col] = np.where(cond, profile[0][:, col], profile[i][:, col]) - indices[0][:, col] = np.where(cond, indices[0][:, col], indices[i][:, col]) - - out[:, 0] = profile[0][:, 0] - out[:, 1:4] = indices[0][:, :] + os.remove(indices_L_fname) + os.remove(indices_R_fname) + + for i in range(1, len(device_ids)): # pragma: no cover + # Update (top-k) matrix profile and matrix profile indices + core._merge_topk_PI(profile[0], profile[i], indices[0], indices[i]) + + # Update (top-1) left matrix profile and matrix profile indices + mask = profile_L[0] > profile_L[i] + profile_L[0][mask] = profile_L[i][mask] + indices_L[0][mask] = indices_L[i][mask] + + # Update (top-1) right matrix profile and matrix profile indices + mask = profile_R[0] > profile_R[i] + profile_R[0][mask] = profile_R[i][mask] + indices_R[0][mask] = indices_R[i][mask] + + out = np.empty((w, 2 * k + 2), dtype=object) # last two columns are to store + # (top-1) left/right matrix profile indices + out[:, :k] = profile[0] + out[:, k:] = np.column_stack((indices[0], indices_L[0], indices_R[0])) core._check_P(out[:, 0]) diff --git a/stumpy/gpu_stump.py b/stumpy/gpu_stump.py index d0890cceb..5538e8696 100644 --- a/stumpy/gpu_stump.py +++ b/stumpy/gpu_stump.py @@ -15,100 +15,10 @@ logger = logging.getLogger(__name__) -@cuda.jit(device=True) -def _gpu_searchsorted_left(a, v, bfs, nlevel): - """ - A device function, equivalent to numpy.searchsorted(a, v, side='left') - - Parameters - ---------- - a : numpy.ndarray - 1-dim array sorted in ascending order. - - v : float - Value to insert into array `a` - - bfs : numpy.ndarray - The breadth-first-search indices where the missing leaves of its corresponding - binary search tree are filled with -1. - - nlevel : int - The number of levels in the binary search tree from which the array - `bfs` is obtained. 
- - Returns - ------- - idx : int - The index of the insertion point - """ - n = a.shape[0] - idx = 0 - for level in range(nlevel): - if v <= a[bfs[idx]]: - next_idx = 2 * idx + 1 - else: - next_idx = 2 * idx + 2 - - if level == nlevel - 1 or bfs[next_idx] < 0: - if v <= a[bfs[idx]]: - idx = max(bfs[idx], 0) - else: - idx = min(bfs[idx] + 1, n) - break - idx = next_idx - - return idx - - -@cuda.jit(device=True) -def _gpu_searchsorted_right(a, v, bfs, nlevel): - """ - A device function, equivalent to numpy.searchsorted(a, v, side='right') - - Parameters - ---------- - a : numpy.ndarray - 1-dim array sorted in ascending order. - - v : float - Value to insert into array `a` - - bfs : numpy.ndarray - The breadth-first-search indices where the missing leaves of its corresponding - binary search tree are filled with -1. - - nlevel : int - The number of levels in the binary search tree from which the array - `bfs` is obtained. - - Returns - ------- - idx : int - The index of the insertion point - """ - n = a.shape[0] - idx = 0 - for level in range(nlevel): - if v < a[bfs[idx]]: - next_idx = 2 * idx + 1 - else: - next_idx = 2 * idx + 2 - - if level == nlevel - 1 or bfs[next_idx] < 0: - if v < a[bfs[idx]]: - idx = max(bfs[idx], 0) - else: - idx = min(bfs[idx] + 1, n) - break - idx = next_idx - - return idx - - @cuda.jit( "(i8, f8[:], f8[:], i8, f8[:], f8[:], f8[:], f8[:], f8[:]," "f8[:], f8[:], i8, b1, i8, f8[:, :], f8[:], f8[:], i8[:, :], i8[:], i8[:]," - "b1, i8[:], i8, i2)" + "b1, i8[:], i8, i8)" ) def _compute_and_update_PI_kernel( i, @@ -284,7 +194,7 @@ def _compute_and_update_PI_kernel( indices_R[j] = i if p_norm < profile[j, -1]: - idx = _gpu_searchsorted_right(profile[j], p_norm, bfs, nlevel) + idx = core._gpu_searchsorted_right(profile[j], p_norm, bfs, nlevel) for g in range(k - 1, idx, -1): profile[j, g] = profile[j, g - 1] indices[j, g] = indices[j, g - 1] @@ -854,12 +764,12 @@ def gpu_stump( core._merge_topk_PI(profile[0], profile[i], indices[0], indices[i]) # Update (top-1) left matrix profile and matrix profile indices - mask = profile_L[0] < profile_L[i] + mask = profile_L[0] > profile_L[i] profile_L[0][mask] = profile_L[i][mask] indices_L[0][mask] = indices_L[i][mask] # Update (top-1) right matrix profile and matrix profile indices - mask = profile_R[0] < profile_R[i] + mask = profile_R[0] > profile_R[i] profile_R[0][mask] = profile_R[i][mask] indices_R[0][mask] = indices_R[i][mask] diff --git a/stumpy/scraamp.py b/stumpy/scraamp.py index e81c267c3..846985978 100644 --- a/stumpy/scraamp.py +++ b/stumpy/scraamp.py @@ -174,65 +174,116 @@ def _compute_PI( p_norm_profile[~T_B_subseq_isfinite] = np.inf # Update P[i] relative to all T[j : j + m] if excl_zone is not None: - zone_start = max(0, i - excl_zone) - zone_stop = min(l, i + excl_zone) - p_norm_profile[zone_start : zone_stop + 1] = np.inf - - # only for self-join - mask = p_norm_profile < P_NORM[thread_idx] - P_NORM[thread_idx][mask] = p_norm_profile[mask] - I[thread_idx][mask] = i - - I[thread_idx, i] = np.argmin(p_norm_profile) - P_NORM[thread_idx, i] = p_norm_profile[I[thread_idx, i]] - if P_NORM[thread_idx, i] == np.inf: # pragma: no cover - I[thread_idx, i] = -1 - else: - j = I[thread_idx, i] - # Given the squared distance, work backwards and compute QT - p_norm_j = P_NORM[thread_idx, i] - p_norm_j_prime = p_norm_j - for k in range(1, min(s, l - i, w - j)): - p_norm_j = ( - p_norm_j - - abs(T_B[j + k - 1] - T_A[i + k - 1]) ** p - + abs(T_B[j + k + m - 1] - T_A[i + k + m - 1]) ** p - ) - if ( - not T_A_subseq_isfinite[i + k] or 
not T_B_subseq_isfinite[j + k] - ): # pragma: no cover - p_norm = np.inf - else: - p_norm = p_norm_j - if p_norm < config.STUMPY_P_NORM_THRESHOLD: # pragma: no cover - p_norm = 0.0 - if p_norm < P_NORM[thread_idx, i + k]: - P_NORM[thread_idx, i + k] = p_norm - I[thread_idx, i + k] = j + k - if excl_zone is not None and p_norm < P_NORM[thread_idx, j + k]: - P_NORM[thread_idx, j + k] = p_norm - I[thread_idx, j + k] = i + k - p_norm_j = p_norm_j_prime - for k in range(1, min(s, i + 1, j + 1)): - p_norm_j = ( - p_norm_j - - abs(T_B[j - k + m] - T_A[i - k + m]) ** p - + abs(T_B[j - k] - T_A[i - k]) ** p - ) - if ( - not T_A_subseq_isfinite[i - k] or not T_B_subseq_isfinite[j - k] - ): # pragma: no cover - p_norm = np.inf - else: - p_norm = p_norm_j - if p_norm < config.STUMPY_P_NORM_THRESHOLD: # pragma: no cover - p_norm = 0.0 - if p_norm < P_NORM[thread_idx, i - k]: - P_NORM[thread_idx, i - k] = p_norm - I[thread_idx, i - k] = j - k - if excl_zone is not None and p_norm < P_NORM[thread_idx, j - k]: - P_NORM[thread_idx, j - k] = p_norm - I[thread_idx, j - k] = i - k + core._apply_exclusion_zone(p_norm_profile, i, excl_zone, np.inf) + + nn_i = np.argmin(p_norm_profile) + if ( + p_norm_profile[nn_i] < P_NORM[thread_idx, i, -1] + and nn_i not in I[thread_idx, i] + ): + idx = np.searchsorted( + P_NORM[thread_idx, i], + p_norm_profile[nn_i], + side="right", + ) + core._shift_insert_at_index( + P_NORM[thread_idx, i], idx, p_norm_profile[nn_i] + ) + core._shift_insert_at_index(I[thread_idx, i], idx, nn_i) + + # this if is not needed as it is probably never executed + if P_NORM[thread_idx, i, 0] == np.inf: # pragma: no cover + I[thread_idx, i, 0] = -1 + continue + + j = nn_i + p_norm_j = P_NORM[thread_idx, i, 0] + p_norm_j_prime = p_norm_j + for g in range(1, min(s, l - i, w - j)): + p_norm_j = ( + p_norm_j + - abs(T_B[j + g - 1] - T_A[i + g - 1]) ** p + + abs(T_B[j + g + m - 1] - T_A[i + g + m - 1]) ** p + ) + if ( + not T_A_subseq_isfinite[i + g] or not T_B_subseq_isfinite[j + g] + ): # pragma: no cover + p_norm = np.inf + else: + p_norm = p_norm_j + if p_norm < config.STUMPY_P_NORM_THRESHOLD: # pragma: no cover + p_norm = 0.0 + + if ( + p_norm < P_NORM[thread_idx, i + g, -1] + and (j + g) not in I[thread_idx, i + g] + ): + idx = np.searchsorted(P_NORM[thread_idx, i + g], p_norm, side="right") + core._shift_insert_at_index(P_NORM[thread_idx, i + g], idx, p_norm) + core._shift_insert_at_index(I[thread_idx, i + g], idx, j + g) + + if ( + excl_zone is not None + and p_norm < P_NORM[thread_idx, j + g, -1] + and (i + g) not in I[thread_idx, j + g] + ): + idx = np.searchsorted(P_NORM[thread_idx, j + g], p_norm, side="right") + core._shift_insert_at_index(P_NORM[thread_idx, j + g], idx, p_norm) + core._shift_insert_at_index(I[thread_idx, j + g], idx, i + g) + + p_norm_j = p_norm_j_prime + for g in range(1, min(s, i + 1, j + 1)): + p_norm_j = ( + p_norm_j + - abs(T_B[j - g + m] - T_A[i - g + m]) ** p + + abs(T_B[j - g] - T_A[i - g]) ** p + ) + if ( + not T_A_subseq_isfinite[i - g] or not T_B_subseq_isfinite[j - g] + ): # pragma: no cover + p_norm = np.inf + else: + p_norm = p_norm_j + if p_norm < config.STUMPY_P_NORM_THRESHOLD: # pragma: no cover + p_norm = 0.0 + + if ( + p_norm < P_NORM[thread_idx, i - g, -1] + and (j - g) not in I[thread_idx, i - g] + ): + idx = np.searchsorted(P_NORM[thread_idx, i - g], p_norm, side="right") + core._shift_insert_at_index(P_NORM[thread_idx, i - g], idx, p_norm) + core._shift_insert_at_index(I[thread_idx, i - g], idx, j - g) + + if ( + excl_zone is not None + and p_norm < 
P_NORM[thread_idx, j - g, -1] + and (i - g) not in I[thread_idx, j - g] + ): + idx = np.searchsorted(P_NORM[thread_idx, j - g], p_norm, side="right") + core._shift_insert_at_index(P_NORM[thread_idx, j - g], idx, p_norm) + core._shift_insert_at_index(I[thread_idx, j - g], idx, i - g) + + # In the case of a self-join, the calculated profile can also be used + # to refine the top-k for all non-trivial subsequences + if excl_zone is not None: + # Note that `p_norm_profile[j]`, the distance between subsequences + # `S_i = T[i : i + m]` and `S_j = T[j : j + m]` can be used to update + # the top-k for BOTH subsequence `i` and subsequence `j`. We update + # the latter here. + + indices = np.flatnonzero(p_norm_profile < P_NORM[thread_idx, :, -1]) + for j in indices: + if i not in I[thread_idx, j]: + idx = np.searchsorted( + P_NORM[thread_idx, j], + p_norm_profile[j], + side="right", + ) + core._shift_insert_at_index( + P_NORM[thread_idx, j], idx, p_norm_profile[j] + ) + core._shift_insert_at_index(I[thread_idx, j], idx, i) @njit( @@ -251,6 +302,7 @@ def _prescraamp( indices, s, excl_zone=None, + k=1, ): """ A Numba JIT-compiled implementation of the non-normalized (i.e., without @@ -298,13 +350,24 @@ def _prescraamp( excl_zone : int The half width for the exclusion zone relative to the `i`. + k : int, default 1 + The number of top `k` smallest distances used to construct the matrix profile. + Note that this will increase the total computational time and memory usage + when k > 1. + Returns ------- out1 : numpy.ndarray - Matrix profile + The (top-k) matrix profile. When k=1 (default), the first (and only) column + in this 2D array consists of the matrix profile. When k > 1, the output + has exactly `k` columns consisting of the top-k matrix profile. out2 : numpy.ndarray - Matrix profile indices + The (top-k) matrix profile indices. When k=1 (default), the first (and only) + column in this 2D array consists of the matrix profile indices. When k > 1, + the output has exactly `k` columns consisting of the top-k matrix profile + indices. 
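Since the private `_prescraamp` above always returns 2D arrays while the public `prescraamp` wrapper below flattens them when `k == 1`, a short usage sketch of the wrapper may help (array sizes and parameter values are illustrative only):

```python
import numpy as np
from stumpy.scraamp import prescraamp

T = np.random.rand(1000)
m = 50

# k=1 (default): 1D outputs of length len(T) - m + 1
P, I = prescraamp(T, m, s=m)

# k=3: 2D outputs with exactly k columns
P3, I3 = prescraamp(T, m, s=m, k=3)
```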
+
     Notes
     -----
     `DOI: 10.1109/ICDM.2018.00099 \
     <https://www.cs.ucr.edu/~eamonn/SCRIMP_ICDM_camera_ready_updated.pdf>`__

     See Algorithm 2
     """
     n_threads = numba.config.NUMBA_NUM_THREADS
     l = T_A.shape[0] - m + 1
-    P_NORM = np.full((n_threads, l), np.inf, dtype=np.float64)
-    I = np.full((n_threads, l), -1, dtype=np.int64)
+    P_NORM = np.full((n_threads, l, k), np.inf, dtype=np.float64)
+    I = np.full((n_threads, l, k), -1, dtype=np.int64)

     idx_ranges = core._get_ranges(len(indices), n_threads, truncate=False)
     for thread_idx in prange(n_threads):
@@ -337,10 +400,7 @@ def _prescraamp(
         )

     for thread_idx in range(1, n_threads):
-        for i in range(l):
-            if P_NORM[thread_idx, i] < P_NORM[0, i]:
-                P_NORM[0, i] = P_NORM[thread_idx, i]
-                I[0, i] = I[thread_idx, i]
+        core._merge_topk_PI(P_NORM[0], P_NORM[thread_idx], I[0], I[thread_idx])

     return np.power(P_NORM[0], 1.0 / p), I[0]


@@ -349,8 +409,8 @@ def prescraamp(T_A, m, T_B=None, s=None, p=2.0, k=1):
     # this function should be modified so that it can return top-k matrix profile
     """
     A convenience wrapper around the Numba JIT-compiled parallelized `_prescraamp`
-    function which computes the approximate matrix profile according to the
-    non-normalized (i.e., without z-normalization) preSCRIMP algorithm
+    function which computes the approximate (top-k) matrix profile according to
+    the non-normalized (i.e., without z-normalization) preSCRIMP algorithm

     Parameters
     ----------
@@ -379,10 +439,15 @@ def prescraamp(T_A, m, T_B=None, s=None, p=2.0, k=1):
     Returns
     -------
     P : numpy.ndarray
-        Matrix profile
+        The (top-k) matrix profile. When k = 1 (default), this is a 1D array
+        consisting of the matrix profile. When k > 1, the output is a 2D array that
+        has exactly `k` columns consisting of the top-k matrix profile.

     I : numpy.ndarray
-        Matrix profile indices
+        The (top-k) matrix profile indices. When k = 1 (default), this is a 1D array
+        consisting of the matrix profile indices. When k > 1, the output is a 2D
+        array that has exactly `k` columns consisting of the top-k matrix profile
+        indices.

     Notes
     -----
@@ -411,9 +476,13 @@ def prescraamp(T_A, m, T_B=None, s=None, p=2.0, k=1):
         indices,
         s,
         excl_zone,
+        k,
     )

-    return P, I
+    if k == 1:
+        return P.flatten().astype(np.float64), I.flatten().astype(np.int64)
+    else:
+        return P, I


 class scraamp:
@@ -465,19 +534,31 @@ class scraamp:
     Attributes
     ----------
     P_ : numpy.ndarray
-        The updated matrix profile
+        The updated (top-k) matrix profile. When `k=1` (default), this output is
+        a 1D array consisting of the matrix profile. When `k > 1`, the output
+        is a 2D array that has exactly `k` columns consisting of the top-k matrix
+        profile.

     I_ : numpy.ndarray
-        The updated matrix profile indices
+        The updated (top-k) matrix profile indices. When `k=1` (default), this
+        output is a 1D array consisting of the matrix profile indices. When
+        `k > 1`, the output is a 2D array that has exactly `k` columns consisting
+        of the top-k matrix profile indices.
+
+    left_I_ : numpy.ndarray
+        The updated left (top-1) matrix profile indices
+
+    right_I_ : numpy.ndarray
+        The updated right (top-1) matrix profile indices

     Methods
     -------
     update()
         Update the matrix profile and the matrix profile indices by computing
         additional new distances (limited by `percentage`) that make up the full
-        distance matrix. Each output contains three columns that correspond to
-        the matrix profile, the left matrix profile, and the right matrix profile,
-        respectively.
+        distance matrix.
It updates the (top-k) matrix profile, (top-1) left + matrix profile, (top-1) right matrix profile, (top-k) matrix profile indices, + (top-1) left matrix profile indices, and (top-1) right matrix profile indices. Notes ----- @@ -586,11 +667,15 @@ def __init__( self._n_A = self._T_A.shape[0] self._n_B = self._T_B.shape[0] self._l = self._n_A - self._m + 1 + self._k = k + + self._P = np.full((self._l, self._k), np.inf, dtype=np.float64) + self._PL = np.full(self._l, np.inf, dtype=np.float64) + self._PR = np.full(self._l, np.inf, dtype=np.float64) - self._P = np.empty((self._l, 3), dtype=np.float64) - self._I = np.empty((self._l, 3), dtype=np.int64) - self._P[:, :] = np.inf - self._I[:, :] = -1 + self._I = np.full((self._l, self._k), -1, dtype=np.int64) + self._IL = np.full(self._l, -1, dtype=np.int64) + self._IR = np.full(self._l, -1, dtype=np.int64) self._excl_zone = int(np.ceil(self._m / config.STUMPY_EXCL_ZONE_DENOM)) if s is None: @@ -631,12 +716,9 @@ def __init__( indices, s, excl_zone, + k, ) - - for i in range(P.shape[0]): - if self._P[i, 0] > P[i]: - self._P[i, 0] = P[i] - self._I[i, 0] = I[i] + core._merge_topk_PI(self._P, P, self._I, I) if self._ignore_trivial: self._diags = np.random.permutation( @@ -667,14 +749,14 @@ def __init__( def update(self): """ - Update the matrix profile and the matrix profile indices by computing - additional new distances (limited by `percentage`) that make up the full - distance matrix. + Update the (top-k) matrix profile and the (top-k) matrix profile indices by + computing additional new distances (limited by `percentage`) that make up + the full distance matrix. """ if self._chunk_idx < self._n_chunks: start_idx, stop_idx = self._chunk_diags_ranges[self._chunk_idx] - P, I = _aamp( + P, PL, PR, I, IL, IR = _aamp( self._T_A, self._T_B, self._m, @@ -683,48 +765,60 @@ def update(self): self._p, self._diags[start_idx:stop_idx], self._ignore_trivial, + self._k, ) - # Update matrix profile and indices - for i in range(self._P.shape[0]): - if self._P[i, 0] > P[i, 0]: - self._P[i, 0] = P[i, 0] - self._I[i, 0] = I[i, 0] - # left matrix profile and left matrix profile indices - if self._P[i, 1] > P[i, 1]: - self._P[i, 1] = P[i, 1] - self._I[i, 1] = I[i, 1] - # right matrix profile and right matrix profile indices - if self._P[i, 2] > P[i, 2]: - self._P[i, 2] = P[i, 2] - self._I[i, 2] = I[i, 2] + # Update (top-k) matrix profile and indices + core._merge_topk_PI(self._P, P, self._I, I) + + # update left matrix profile and indices + mask = PL < self._PL + self._PL[mask] = PL[mask] + self._IL[mask] = IL[mask] + + # update right matrix profile and indices + mask = PR < self._PR + self._PR[mask] = PR[mask] + self._IR[mask] = IR[mask] self._chunk_idx += 1 @property def P_(self): """ - Get the updated matrix profile + Get the updated (top-k) matrix profile. When `k=1` (default), this output + is a 1D array consisting of the updated matrix profile. When `k > 1`, the + output is a 2D array that has exactly `k` columns consisting of the updated + top-k matrix profile. """ - return self._P[:, 0].astype(np.float64) + if self._k == 1: + return self._P.flatten().astype(np.float64) + else: + return self._P.astype(np.float64) @property def I_(self): """ - Get the updated matrix profile indices + Get the updated (top-k) matrix profile indices. When `k=1` (default), this + output is a 1D array consisting of the updated matrix profile indices. When + `k > 1`, the output is a 2D array that has exactly `k` columns consisting + of the updated top-k matrix profile indices. 
""" - return self._I[:, 0].astype(np.int64) + if self._k == 1: + return self._I.flatten().astype(np.int64) + else: + return self._I.astype(np.int64) @property def left_I_(self): """ - Get the updated left matrix profile indices + Get the updated left (top-1) matrix profile indices """ - return self._I[:, 1].astype(np.int64) + return self._IL.astype(np.int64) @property def right_I_(self): """ - Get the updated right matrix profile indices + Get the updated right (top-1) matrix profile indices """ - return self._I[:, 2].astype(np.int64) + return self._IR.astype(np.int64) diff --git a/stumpy/stump.py b/stumpy/stump.py index 33c573d42..327b843db 100644 --- a/stumpy/stump.py +++ b/stumpy/stump.py @@ -142,11 +142,6 @@ def _compute_diagonal( Set to `True` if this is a self-join. Otherwise, for AB-join, set this to `False`. Default is `True`. - k : int - The number of top `k` smallest distances used to construct the matrix profile. - Note that this will increase the total computational time and memory usage - when k > 1. - Returns ------- None @@ -345,22 +340,22 @@ def _stump( Returns ------- - profile : numpy.ndarray + out1 : numpy.ndarray The (top-k) matrix profile - indices : numpy.ndarray - The (top-k) matrix profile indices - - left profile : numpy.ndarray + out2 : numpy.ndarray The (top-1) left matrix profile - left indices : numpy.ndarray - The (top-1) left matrix profile indices - - right profile : numpy.ndarray + out3 : numpy.ndarray The (top-1) right matrix profile - right indices : numpy.ndarray + out4 : numpy.ndarray + The (top-k) matrix profile indices + + out5 : numpy.ndarray + The (top-1) left matrix profile indices + + out6 : numpy.ndarray The (top-1) right matrix profile indices Notes @@ -508,11 +503,14 @@ def _stump( if p_norm_R[i] < config.STUMPY_P_NORM_THRESHOLD: p_norm_R[i] = 0.0 - P = np.sqrt(p_norm) - PL = np.sqrt(p_norm_L) - PR = np.sqrt(p_norm_R) - - return P, PL, PR, I, IL[0], IR[0] + return ( + np.sqrt(p_norm), + np.sqrt(p_norm_L), + np.sqrt(p_norm_R), + I, + IL[0], + IR[0], + ) @core.non_normalized(aamp) diff --git a/stumpy/stumpi.py b/stumpy/stumpi.py index 82fac2da8..0bd6cc765 100644 --- a/stumpy/stumpi.py +++ b/stumpy/stumpi.py @@ -134,10 +134,10 @@ def __init__(self, T, m, egress=True, normalize=True, p=2.0, k=1): self._egress = egress mp = stump(self._T, self._m, k=self._k) - self._P = mp[:, :k].astype(np.float64) - self._I = mp[:, k : 2 * k].astype(np.int64) + self._P = mp[:, : self._k].astype(np.float64) + self._I = mp[:, self._k : 2 * self._k].astype(np.int64) - self._left_I = mp[:, 2 * k].astype(np.int64) + self._left_I = mp[:, 2 * self._k].astype(np.int64) self._left_P = np.full_like(self._left_I, np.inf, dtype=np.float64) self._T, self._M_T, self._Σ_T = core.preprocess(self._T, self._m) @@ -165,7 +165,7 @@ def __init__(self, T, m, egress=True, normalize=True, p=2.0, k=1): ) self._left_P[i] = np.sqrt(D_square) - Q = self._T[-m:] + Q = self._T[-self._m :] self._QT = core.sliding_dot_product(Q, self._T) if self._egress: self._QT_new = np.empty(self._QT.shape[0], dtype=np.float64) @@ -320,7 +320,7 @@ def _update(self, t): core._shift_insert_at_index(self._P[i], idx, D[i]) core._shift_insert_at_index(self._I[i], idx, l) - # Calculating top-k matrix profile and (top-1) left matrix profile (and thier + # Calculating top-k matrix profile and (top-1) left matrix profile (and their # corresponding indices) for new subsequence whose distance profie is `D` P_new = np.full(self._k, np.inf, dtype=np.float64) I_new = np.full(self._k, -1, dtype=np.int64) diff --git 
a/tests/naive.py b/tests/naive.py index 992125d02..a7ab79c6a 100644 --- a/tests/naive.py +++ b/tests/naive.py @@ -267,17 +267,22 @@ def stump(T_A, m, T_B=None, exclusion_zone=None, row_wise=False, k=1): return result -def aamp(T_A, m, T_B=None, exclusion_zone=None, p=2.0): +def aamp(T_A, m, T_B=None, exclusion_zone=None, p=2.0, row_wise=False, k=1): + """ + Traverse the distance matrix diagonally and update the top-k matrix profile + and matrix profile indices when the parameter `row_wise` is `False`. When + `row_wise` is `True`, traverse the distance matrix row by row instead. + """ T_A = np.asarray(T_A) T_A = T_A.copy() if T_B is None: - T_B = T_A.copy() ignore_trivial = True + T_B = T_A.copy() else: + ignore_trivial = False T_B = np.asarray(T_B) T_B = T_B.copy() - ignore_trivial = False T_A[np.isinf(T_A)] = np.nan T_B[np.isinf(T_B)] = np.nan @@ -285,53 +290,84 @@ def aamp(T_A, m, T_B=None, exclusion_zone=None, p=2.0): rolling_T_A = core.rolling_window(T_A, m) rolling_T_B = core.rolling_window(T_B, m) + distance_matrix = cdist(rolling_T_A, rolling_T_B, metric="minkowski", p=p) + n_A = T_A.shape[0] n_B = T_B.shape[0] l = n_A - m + 1 if exclusion_zone is None: exclusion_zone = int(np.ceil(m / config.STUMPY_EXCL_ZONE_DENOM)) - distance_matrix = cdist(rolling_T_A, rolling_T_B, metric="minkowski", p=p) + P = np.full((l, k + 2), np.inf, dtype=np.float64) + I = np.full((l, k + 2), -1, dtype=np.int64) # the two extra columns store the + # ... left and right top-1 matrix profile indices - if ignore_trivial: - diags = np.arange(exclusion_zone + 1, n_A - m + 1) - else: - diags = np.arange(-(n_A - m + 1) + 1, n_B - m + 1) + if row_wise: + if ignore_trivial: # self-join + for i in range(l): + apply_exclusion_zone(distance_matrix[i], i, exclusion_zone, np.inf) + + for i, D in enumerate(distance_matrix): # D: distance profile + # self-join / AB-join: matrix profile and indices + indices = np.argsort(D)[:k] + P[i, :k] = D[indices] + indices[P[i, :k] == np.inf] = -1 + I[i, :k] = indices - P = np.full((l, 3), np.inf) - I = np.full((l, 3), -1, dtype=np.int64) + # self-join: left matrix profile index (top-1) + if ignore_trivial and i > 0: + IL = np.argmin(D[:i]) + if D[IL] == np.inf: + IL = -1 + I[i, k] = IL + + # self-join: right matrix profile index (top-1) + if ignore_trivial and i < D.shape[0]: + IR = i + np.argmin(D[i:]) # offset by `i` to get true index + if D[IR] == np.inf: + IR = -1 + I[i, k + 1] = IR - for k in diags: - if k >= 0: - iter_range = range(0, min(n_A - m + 1, n_B - m + 1 - k)) + else: + if ignore_trivial: + diags = np.arange(exclusion_zone + 1, n_A - m + 1) else: - iter_range = range(-k, min(n_A - m + 1, n_B - m + 1 - k)) - - for i in iter_range: - D = distance_matrix[i, i + k] - if D < P[i, 0]: - P[i, 0] = D - I[i, 0] = i + k - - if ignore_trivial: # Self-joins only - if D < P[i + k, 0]: - P[i + k, 0] = D - I[i + k, 0] = i - - if i < i + k: - # Left matrix profile and left matrix profile index - if D < P[i + k, 1]: - P[i + k, 1] = D - I[i + k, 1] = i - - if D < P[i, 2]: - # right matrix profile and right matrix profile index - P[i, 2] = D - I[i, 2] = i + k - - result = np.empty((l, 4), dtype=object) - result[:, 0] = P[:, 0] - result[:, 1:4] = I[:, :] + diags = np.arange(-(n_A - m + 1) + 1, n_B - m + 1) + + for g in diags: + if g >= 0: + iter_range = range(0, min(n_A - m + 1, n_B - m + 1 - g)) + else: + iter_range = range(-g, min(n_A - m + 1, n_B - m + 1 - g)) + + for i in iter_range: + d = distance_matrix[i, i + g] + if d < P[i, k - 1]: + idx = searchsorted_right(P[i], d) + # to keep the
top-k, we must discard the last element. + P[i, :k] = np.insert(P[i, :k], idx, d)[:-1] + I[i, :k] = np.insert(I[i, :k], idx, i + g)[:-1] + + if ignore_trivial: # Self-joins only + if d < P[i + g, k - 1]: + idx = searchsorted_right(P[i + g], d) + P[i + g, :k] = np.insert(P[i + g, :k], idx, d)[:-1] + I[i + g, :k] = np.insert(I[i + g, :k], idx, i)[:-1] + + if i < i + g: + # Left matrix profile and left matrix profile index + if d < P[i + g, k]: + P[i + g, k] = d + I[i + g, k] = i + + if d < P[i, k + 1]: + # right matrix profile and right matrix profile index + P[i, k + 1] = d + I[i, k + 1] = i + g + + result = np.empty((l, 2 * k + 2), dtype=object) + result[:, :k] = P[:, :k] + result[:, k:] = I[:, :] return result @@ -713,27 +749,30 @@ def get_array_ranges(a, n_chunks, truncate): class aampi_egress(object): - def __init__(self, T, m, excl_zone=None, p=2.0): + def __init__(self, T, m, excl_zone=None, p=2.0, k=1): self._T = np.asarray(T) self._T = self._T.copy() self._T_isfinite = np.isfinite(self._T) self._m = m self._p = p + self._k = k self._excl_zone = excl_zone if self._excl_zone is None: self._excl_zone = int(np.ceil(self._m / config.STUMPY_EXCL_ZONE_DENOM)) self._l = self._T.shape[0] - m + 1 - mp = aamp(T, m, exclusion_zone=self._excl_zone, p=p) - self.P_ = mp[:, 0] - self.I_ = mp[:, 1].astype(np.int64) - self.left_P_ = np.full(self.P_.shape, np.inf) - self.left_I_ = mp[:, 2].astype(np.int64) - for i, j in enumerate(self.left_I_): - if j >= 0: - self.left_P_[i] = np.linalg.norm( - self._T[i : i + self._m] - self._T[j : j + self._m], ord=self._p + mp = aamp(T, m, exclusion_zone=self._excl_zone, p=p, k=self._k) + self._P = mp[:, :k].astype(np.float64) + self._I = mp[:, k : 2 * k].astype(np.int64) + + self._left_I = mp[:, 2 * k].astype(np.int64) + self._left_P = np.full_like(self._left_I, np.inf, dtype=np.float64) + for idx, nn_idx in enumerate(self._left_I): + if nn_idx >= 0: + self._left_P[idx] = np.linalg.norm( + self._T[idx : idx + self._m] - self._T[nn_idx : nn_idx + self._m], + ord=self._p, ) self._n_appended = 0 @@ -749,12 +788,11 @@ def update(self, t): self._T[-1] = 0 self._n_appended += 1 - self.P_[:] = np.roll(self.P_, -1) - self.I_[:] = np.roll(self.I_, -1) - self.left_P_[:] = np.roll(self.left_P_, -1) - self.left_I_[:] = np.roll(self.left_I_, -1) + self._P = np.roll(self._P, -1, axis=0) + self._I = np.roll(self._I, -1, axis=0) + self._left_P[:] = np.roll(self._left_P, -1) + self._left_I[:] = np.roll(self._left_I, -1) - D = core.mass_absolute(self._T[-self._m :], self._T) D = cdist( core.rolling_window(self._T[-self._m :], self._m), core.rolling_window(self._T, self._m), @@ -770,21 +808,45 @@ def update(self, t): apply_exclusion_zone(D, D.shape[0] - 1, self._excl_zone, np.inf) for j in range(D.shape[0]): - if D[j] < self.P_[j]: - self.I_[j] = D.shape[0] - 1 + self._n_appended - self.P_[j] = D[j] + if D[j] < self._P[j, -1]: + pos = np.searchsorted(self._P[j], D[j], side="right") + self._P[j] = np.insert(self._P[j], pos, D[j])[:-1] + self._I[j] = np.insert( + self._I[j], pos, D.shape[0] - 1 + self._n_appended )[:-1] - I_last = np.argmin(D) + # update the top-k for the last, newly appended subsequence + I_last_topk = np.argsort(D, kind="mergesort")[: self._k] + self._P[-1] = D[I_last_topk] + self._I[-1] = I_last_topk + self._n_appended + self._I[-1][self._P[-1] == np.inf] = -1 - if np.isinf(D[I_last]): - self.I_[-1] = -1 - self.P_[-1] = np.inf + # for the last index, the left matrix profile value is self._P[-1, 0] + # and the same goes for the left matrix profile index + self._left_P[-1] =
self._P[-1, 0] + self._left_I[-1] = self._I[-1, 0] + + @property + def P_(self): + if self._k == 1: + return self._P.flatten().astype(np.float64) else: - self.I_[-1] = I_last + self._n_appended - self.P_[-1] = D[I_last] + return self._P.astype(np.float64) - self.left_I_[-1] = self.I_[-1] - self.left_P_[-1] = self.P_[-1] + @property + def I_(self): + if self._k == 1: + return self._I.flatten().astype(np.int64) + else: + return self._I.astype(np.int64) + + @property + def left_P_(self): + return self._left_P.astype(np.float64) + + @property + def left_I_(self): + return self._left_I.astype(np.int64) class stumpi_egress(object): @@ -1576,55 +1638,81 @@ def scrump(T_A, m, T_B, percentage, exclusion_zone, pre_scrump, s, k=1): return P, I, IL, IR -def prescraamp(T_A, m, T_B, s, exclusion_zone=None, p=2.0): +def prescraamp(T_A, m, T_B, s, exclusion_zone=None, p=2.0, k=1): distance_matrix = aamp_distance_matrix(T_A, T_B, m, p) - l = T_A.shape[0] - m + 1 # length of matrix profile - w = T_B.shape[0] - m + 1 # length of each distance profile + l = T_A.shape[0] - m + 1 # matrix profile length + w = T_B.shape[0] - m + 1 # distance profile length - P = np.empty(l) - I = np.empty(l, dtype=np.int64) - P[:] = np.inf - I[:] = -1 + P = np.full((l, k), np.inf, dtype=np.float64) + I = np.full((l, k), -1, dtype=np.int64) for i in np.random.permutation(range(0, l, s)): distance_profile = distance_matrix[i] if exclusion_zone is not None: apply_exclusion_zone(distance_profile, i, exclusion_zone, np.inf) - # only for self-join - mask = distance_profile < P - P[mask] = distance_profile[mask] - I[mask] = i + nn_idx = np.argmin(distance_profile) + if distance_profile[nn_idx] < P[i, -1] and nn_idx not in I[i]: + pos = np.searchsorted(P[i], distance_profile[nn_idx], side="right") + P[i] = np.insert(P[i], pos, distance_profile[nn_idx])[:-1] + I[i] = np.insert(I[i], pos, nn_idx)[:-1] + + if P[i, 0] == np.inf: + I[i, 0] = -1 + continue - I[i] = np.argmin(distance_profile) - P[i] = distance_profile[I[i]] - if P[i] == np.inf: # pragma: no cover - I[i] = -1 - else: - j = I[i] - for k in range(1, min(s, l - i, w - j)): - d = distance_matrix[i + k, j + k] - if d < P[i + k]: - P[i + k] = d - I[i + k] = j + k - if exclusion_zone is not None and d < P[j + k]: - P[j + k] = d - I[j + k] = i + k - - for k in range(1, min(s, i + 1, j + 1)): - d = distance_matrix[i - k, j - k] - if d < P[i - k]: - P[i - k] = d - I[i - k] = j - k - if exclusion_zone is not None and d < P[j - k]: - P[j - k] = d - I[j - k] = i - k + j = nn_idx + for g in range(1, min(s, l - i, w - j)): + d = distance_matrix[i + g, j + g] + # Do NOT optimize the `condition` in the following if statement + # and similar ones in this naive function. This ensures + # that we avoid duplicates in each row of I.
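+            # (concretely, membership tests such as `(j + g) not in I[i + g]`
+            # below are what keep each row of I free of duplicate indices, so
+            # they must remain paired with the distance comparisons)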
+ if d < P[i + g, -1] and (j + g) not in I[i + g]: + pos = np.searchsorted(P[i + g], d, side="right") + P[i + g] = np.insert(P[i + g], pos, d)[:-1] + I[i + g] = np.insert(I[i + g], pos, j + g)[:-1] + if ( + exclusion_zone is not None + and d < P[j + g, -1] + and (i + g) not in I[j + g] + ): + pos = np.searchsorted(P[j + g], d, side="right") + P[j + g] = np.insert(P[j + g], pos, d)[:-1] + I[j + g] = np.insert(I[j + g], pos, i + g)[:-1] + + for g in range(1, min(s, i + 1, j + 1)): + d = distance_matrix[i - g, j - g] + if d < P[i - g, -1] and (j - g) not in I[i - g]: + pos = np.searchsorted(P[i - g], d, side="right") + P[i - g] = np.insert(P[i - g], pos, d)[:-1] + I[i - g] = np.insert(I[i - g], pos, j - g)[:-1] + if ( + exclusion_zone is not None + and d < P[j - g, -1] + and (i - g) not in I[j - g] + ): + pos = np.searchsorted(P[j - g], d, side="right") + P[j - g] = np.insert(P[j - g], pos, d)[:-1] + I[j - g] = np.insert(I[j - g], pos, i - g)[:-1] + + # In the case of a self-join, the calculated distance profile can also be + # used to refine the top-k for all non-trivial subsequences + if exclusion_zone is not None: + for idx in np.flatnonzero(distance_profile < P[:, -1]): + if i not in I[idx]: + pos = np.searchsorted(P[idx], distance_profile[idx], side="right") + P[idx] = np.insert(P[idx], pos, distance_profile[idx])[:-1] + I[idx] = np.insert(I[idx], pos, i)[:-1] + + if k == 1: + P = P.flatten() + I = I.flatten() return P, I -def scraamp(T_A, m, T_B, percentage, exclusion_zone, pre_scraamp, s, p=2.0): +def scraamp(T_A, m, T_B, percentage, exclusion_zone, pre_scraamp, s, p=2.0, k=1): distance_matrix = aamp_distance_matrix(T_A, T_B, m, p) n_A = T_A.shape[0] @@ -1646,47 +1734,48 @@ def scraamp(T_A, m, T_B, percentage, exclusion_zone, pre_scraamp, s, p=2.0): diags_ranges_start = diags_ranges[0, 0] diags_ranges_stop = diags_ranges[0, 1] - out = np.full((l, 4), np.inf, dtype=object) - out[:, 1:] = -1 - left_P = np.full(l, np.inf, dtype=np.float64) - right_P = np.full(l, np.inf, dtype=np.float64) + P = np.full((l, k), np.inf, dtype=np.float64) # Topk + PL = np.full(l, np.inf, dtype=np.float64) + PR = np.full(l, np.inf, dtype=np.float64) + + I = np.full((l, k), -1, dtype=np.int64) + IL = np.full(l, -1, dtype=np.int64) + IR = np.full(l, -1, dtype=np.int64) for diag_idx in range(diags_ranges_start, diags_ranges_stop): - k = diags[diag_idx] + g = diags[diag_idx] for i in range(n_A - m + 1): for j in range(n_B - m + 1): - if j - i == k: - if distance_matrix[i, j] < out[i, 0]: - out[i, 0] = distance_matrix[i, j] - out[i, 1] = i + k - - if ( - exclusion_zone is not None - and distance_matrix[i, j] < out[i + k, 0] - ): - out[i + k, 0] = distance_matrix[i, j] - out[i + k, 1] = i + if j - i == g: + d = distance_matrix[i, j] + if d < P[i, -1]: + idx = searchsorted_right(P[i], d) + if (i + g) not in I[i]: + P[i] = np.insert(P[i], idx, d)[:-1] + I[i] = np.insert(I[i], idx, i + g)[:-1] + + if exclusion_zone is not None and d < P[i + g, -1]: + idx = searchsorted_right(P[i + g], d) + if i not in I[i + g]: + P[i + g] = np.insert(P[i + g], idx, d)[:-1] + I[i + g] = np.insert(I[i + g], idx, i)[:-1] # left matrix profile and left matrix profile indices - if ( - exclusion_zone is not None - and i < i + k - and distance_matrix[i, j] < left_P[i + k] - ): - left_P[i + k] = distance_matrix[i, j] - out[i + k, 2] = i + if exclusion_zone is not None and i < i + g and d < PL[i + g]: + PL[i + g] = d + IL[i + g] = i # right matrix profile and right matrix profile indices - if ( - exclusion_zone is not None - and i + k > i - and 
distance_matrix[i, j] < right_P[i] - ): - right_P[i] = distance_matrix[i, j] - out[i, 3] = i + k + if exclusion_zone is not None and i + g > i and d < PR[i]: + PR[i] = d + IR[i] = i + g - return out + if k == 1: + P = P.flatten() + I = I.flatten() + + return P, I, IL, IR def normalize_pan(pan, ms, bfs_indices, n_processed, T_min=None, T_max=None, p=2.0): diff --git a/tests/test_aamp.py b/tests/test_aamp.py index 9ddaca8d7..603b38fb5 100644 --- a/tests/test_aamp.py +++ b/tests/test_aamp.py @@ -236,3 +236,37 @@ def test_aamp_nan_zero_mean_self_join(): naive.replace_inf(ref_mp) naive.replace_inf(comp_mp) npt.assert_almost_equal(ref_mp, comp_mp) + + +@pytest.mark.parametrize("T_A, T_B", test_data) +def test_aamp_self_join_KNN(T_A, T_B): + m = 3 + for k in range(2, 4): + for p in [1.0, 2.0, 3.0]: + ref_mp = naive.aamp(T_B, m, p=p, k=k) + comp_mp = aamp(T_B, m, p=p, k=k) + naive.replace_inf(ref_mp) + naive.replace_inf(comp_mp) + npt.assert_almost_equal(ref_mp, comp_mp) + + comp_mp = aamp(pd.Series(T_B), m, p=p, k=k) + naive.replace_inf(comp_mp) + npt.assert_almost_equal(ref_mp, comp_mp) + + +@pytest.mark.parametrize("T_A, T_B", test_data) +def test_aamp_A_B_join_KNN(T_A, T_B): + m = 3 + for k in range(2, 4): + for p in [1.0, 2.0, 3.0]: + ref_mp = naive.aamp(T_A, m, T_B=T_B, p=p, k=k) + comp_mp = aamp(T_A, m, T_B, ignore_trivial=False, p=p, k=k) + naive.replace_inf(ref_mp) + naive.replace_inf(comp_mp) + npt.assert_almost_equal(ref_mp, comp_mp) + + comp_mp = aamp( + pd.Series(T_A), m, pd.Series(T_B), ignore_trivial=False, p=p, k=k + ) + naive.replace_inf(comp_mp) + npt.assert_almost_equal(ref_mp, comp_mp) diff --git a/tests/test_aamp_stimp.py b/tests/test_aamp_stimp.py index 78665710c..c8e9aff93 100644 --- a/tests/test_aamp_stimp.py +++ b/tests/test_aamp_stimp.py @@ -51,12 +51,9 @@ def test_aamp_stimp_1_percent(T): zone = int(np.ceil(m / 4)) s = zone tmp_P, tmp_I = naive.prescraamp(T, m, T, s=s, exclusion_zone=zone) - ref_mp = naive.scraamp(T, m, T, percentage, zone, True, s) - for i in range(ref_mp.shape[0]): - if tmp_P[i] < ref_mp[i, 0]: - ref_mp[i, 0] = tmp_P[i] - ref_mp[i, 1] = tmp_I[i] - ref_PAN[pan._bfs_indices[idx], : ref_mp.shape[0]] = ref_mp[:, 0] + ref_P, ref_I, _, _ = naive.scraamp(T, m, T, percentage, zone, True, s) + naive.merge_topk_PI(ref_P, tmp_P, ref_I, tmp_I) + ref_PAN[pan._bfs_indices[idx], : ref_P.shape[0]] = ref_P # Compare raw pan cmp_PAN = pan._PAN @@ -114,12 +111,9 @@ def test_aamp_stimp_max_m(T): zone = int(np.ceil(m / 4)) s = zone tmp_P, tmp_I = naive.prescraamp(T, m, T, s=s, exclusion_zone=zone) - ref_mp = naive.scraamp(T, m, T, percentage, zone, True, s) - for i in range(ref_mp.shape[0]): - if tmp_P[i] < ref_mp[i, 0]: - ref_mp[i, 0] = tmp_P[i] - ref_mp[i, 1] = tmp_I[i] - ref_PAN[pan._bfs_indices[idx], : ref_mp.shape[0]] = ref_mp[:, 0] + ref_P, ref_I, _, _ = naive.scraamp(T, m, T, percentage, zone, True, s) + naive.merge_topk_PI(ref_P, tmp_P, ref_I, tmp_I) + ref_PAN[pan._bfs_indices[idx], : ref_P.shape[0]] = ref_P # Compare raw pan cmp_PAN = pan._PAN diff --git a/tests/test_aamped.py b/tests/test_aamped.py index feae3f6df..8abe46cef 100644 --- a/tests/test_aamped.py +++ b/tests/test_aamped.py @@ -602,3 +602,39 @@ def test_aamped_two_subsequences_nan_inf_A_B_join_swap( naive.replace_inf(ref_mp) naive.replace_inf(comp_mp) npt.assert_almost_equal(ref_mp, comp_mp) + + +@pytest.mark.filterwarnings("ignore:numpy.dtype size changed") +@pytest.mark.filterwarnings("ignore:numpy.ufunc size changed") +@pytest.mark.filterwarnings("ignore:numpy.ndarray size changed") 
+@pytest.mark.filterwarnings("ignore:\\s+Port 8787 is already in use:UserWarning") +@pytest.mark.parametrize("T_A, T_B", test_data) +def test_aamped_self_join_KNN(T_A, T_B, dask_cluster): + with Client(dask_cluster) as dask_client: + m = 3 + for k in range(2, 4): + for p in [1.0, 2.0, 3.0]: + ref_mp = naive.aamp(T_B, m, p=p, k=k) + comp_mp = aamped(dask_client, T_B, m, p=p, k=k) + naive.replace_inf(ref_mp) + naive.replace_inf(comp_mp) + npt.assert_almost_equal(ref_mp, comp_mp) + + +@pytest.mark.filterwarnings("ignore:numpy.dtype size changed") +@pytest.mark.filterwarnings("ignore:numpy.ufunc size changed") +@pytest.mark.filterwarnings("ignore:numpy.ndarray size changed") +@pytest.mark.filterwarnings("ignore:\\s+Port 8787 is already in use:UserWarning") +@pytest.mark.parametrize("T_A, T_B", test_data) +def test_aamped_A_B_join_KNN(T_A, T_B, dask_cluster): + with Client(dask_cluster) as dask_client: + m = 3 + for k in range(2, 4): + for p in [1.0, 2.0, 3.0]: + ref_mp = naive.aamp(T_A, m, T_B=T_B, p=p, k=k) + comp_mp = aamped( + dask_client, T_A, m, T_B, ignore_trivial=False, p=p, k=k + ) + naive.replace_inf(ref_mp) + naive.replace_inf(comp_mp) + npt.assert_almost_equal(ref_mp, comp_mp) diff --git a/tests/test_aampi.py b/tests/test_aampi.py index 4f1f437dd..6969dbf35 100644 --- a/tests/test_aampi.py +++ b/tests/test_aampi.py @@ -976,3 +976,172 @@ def test_aampi_profile_index_match(): npt.assert_almost_equal(stream.left_P_, left_P) n += 1 + + +def test_aampi_self_join_KNN(): + m = 3 + for k in range(2, 4): + for p in [1.0, 2.0, 3.0]: + seed = np.random.randint(100000) + np.random.seed(seed) + + n = 30 + T = np.random.rand(n) + stream = aampi(T, m, egress=False, p=p, k=k) + for i in range(34): + t = np.random.rand() + stream.update(t) + + comp_P = stream.P_ + comp_I = stream.I_ + comp_left_P = stream.left_P_ + comp_left_I = stream.left_I_ + + ref_mp = naive.aamp(stream.T_, m, p=p, k=k) + ref_P = ref_mp[:, :k] + ref_I = ref_mp[:, k : 2 * k] + + ref_left_I = ref_mp[:, 2 * k] + ref_left_P = np.full_like(ref_left_I, np.inf, dtype=np.float64) + for i, j in enumerate(ref_left_I): + if j >= 0: + ref_left_P[i] = np.linalg.norm( + stream.T_[i : i + m] - stream.T_[j : j + m], ord=p + ) + + naive.replace_inf(ref_P) + naive.replace_inf(ref_left_P) + naive.replace_inf(comp_P) + naive.replace_inf(comp_left_P) + + npt.assert_almost_equal(ref_P, comp_P) + npt.assert_almost_equal(ref_I, comp_I) + npt.assert_almost_equal(ref_left_P, comp_left_P) + npt.assert_almost_equal(ref_left_I, comp_left_I) + + np.random.seed(seed) + n = 30 + T = np.random.rand(n) + T = pd.Series(T) + stream = aampi(T, m, egress=False, p=p, k=k) + for i in range(34): + t = np.random.rand() + stream.update(t) + + comp_P = stream.P_ + comp_I = stream.I_ + comp_left_P = stream.left_P_ + comp_left_I = stream.left_I_ + + naive.replace_inf(comp_P) + naive.replace_inf(comp_left_P) + + npt.assert_almost_equal(ref_P, comp_P) + npt.assert_almost_equal(ref_I, comp_I) + npt.assert_almost_equal(ref_left_P, comp_left_P) + npt.assert_almost_equal(ref_left_I, comp_left_I) + + +def test_aampi_self_join_egress_KNN(): + m = 3 + for k in range(2, 4): + for p in [1.0, 2.0, 3.0]: + seed = np.random.randint(100000) + np.random.seed(seed) + + n = 30 + T = np.random.rand(n) + + ref_mp = naive.aampi_egress(T, m, p=p, k=k) + ref_P = ref_mp.P_.copy() + ref_I = ref_mp.I_ + ref_left_P = ref_mp.left_P_.copy() + ref_left_I = ref_mp.left_I_ + + stream = aampi(T, m, egress=True, p=p, k=k) + + comp_P = stream.P_.copy() + comp_I = stream.I_ + comp_left_P = 
stream.left_P_.copy() + comp_left_I = stream.left_I_ + + naive.replace_inf(ref_P) + naive.replace_inf(ref_left_P) + naive.replace_inf(comp_P) + naive.replace_inf(comp_left_P) + + npt.assert_almost_equal(ref_P, comp_P) + npt.assert_almost_equal(ref_I, comp_I) + npt.assert_almost_equal(ref_left_P, comp_left_P) + npt.assert_almost_equal(ref_left_I, comp_left_I) + + for i in range(34): + t = np.random.rand() + + ref_mp.update(t) + stream.update(t) + + comp_P = stream.P_.copy() + comp_I = stream.I_ + comp_left_P = stream.left_P_.copy() + comp_left_I = stream.left_I_ + + ref_P = ref_mp.P_.copy() + ref_I = ref_mp.I_ + ref_left_P = ref_mp.left_P_.copy() + ref_left_I = ref_mp.left_I_ + + naive.replace_inf(ref_P) + naive.replace_inf(ref_left_P) + naive.replace_inf(comp_P) + naive.replace_inf(comp_left_P) + + npt.assert_almost_equal(ref_P, comp_P) + npt.assert_almost_equal(ref_I, comp_I) + npt.assert_almost_equal(ref_left_P, comp_left_P) + npt.assert_almost_equal(ref_left_I, comp_left_I) + + np.random.seed(seed) + T = np.random.rand(n) + T = pd.Series(T) + + ref_mp = naive.aampi_egress(T, m, p=p, k=k) + ref_P = ref_mp.P_.copy() + ref_I = ref_mp.I_ + + stream = aampi(T, m, egress=True, p=p, k=k) + + comp_P = stream.P_.copy() + comp_I = stream.I_ + + naive.replace_inf(ref_P) + naive.replace_inf(comp_P) + + npt.assert_almost_equal(ref_P, comp_P) + npt.assert_almost_equal(ref_I, comp_I) + + for i in range(34): + t = np.random.rand() + + ref_mp.update(t) + stream.update(t) + + comp_P = stream.P_.copy() + comp_I = stream.I_ + comp_left_P = stream.left_P_.copy() + comp_left_I = stream.left_I_ + + ref_P = ref_mp.P_.copy() + ref_I = ref_mp.I_ + ref_left_P = ref_mp.left_P_.copy() + ref_left_I = ref_mp.left_I_ + + naive.replace_inf(ref_P) + naive.replace_inf(ref_left_P) + naive.replace_inf(comp_P) + naive.replace_inf(comp_left_P) + + npt.assert_almost_equal(ref_P, comp_P) + npt.assert_almost_equal(ref_I, comp_I) + npt.assert_almost_equal(ref_left_P, comp_left_P) + npt.assert_almost_equal(ref_left_I, comp_left_I) diff --git a/tests/test_core.py b/tests/test_core.py index 77235326f..a8e7cba8e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,4 +1,5 @@ import numpy as np +from numba import cuda import numpy.testing as npt import pandas as pd from scipy.spatial.distance import cdist @@ -10,6 +11,37 @@ import naive +if cuda.is_available(): + from stumpy.core import _gpu_searchsorted_left, _gpu_searchsorted_right +else: # pragma: no cover + from stumpy.core import ( + _gpu_searchsorted_left_driver_not_found as _gpu_searchsorted_left, + ) + from stumpy.core import ( + _gpu_searchsorted_right_driver_not_found as _gpu_searchsorted_right, + ) + +try: + from numba.errors import NumbaPerformanceWarning +except ModuleNotFoundError: + from numba.core.errors import NumbaPerformanceWarning + +TEST_THREADS_PER_BLOCK = 10 + +if not cuda.is_available(): # pragma: no cover + pytest.skip("Skipping Tests No GPUs Available", allow_module_level=True) + + +@cuda.jit("(f8[:, :], f8[:], i8[:], i8, b1, i8[:])") +def _gpu_searchsorted_kernel(a, v, bfs, nlevel, is_left, idx): + # A wrapper kernel for calling device function _gpu_searchsorted_left/right. 
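+    # Each thread handles a single row: it binary-searches the ascendingly
+    # sorted row a[i] for the value v[i] and writes the insertion point into
+    # idx[i], matching numpy.searchsorted(a[i], v[i], side="left"/"right");
+    # `bfs` and `nlevel` parameterize the device-side search (see stumpy.core).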
+ i = cuda.grid(1) + if i < a.shape[0]: + if is_left: + idx[i] = _gpu_searchsorted_left(a[i], v[i], bfs, nlevel) + else: + idx[i] = _gpu_searchsorted_right(a[i], v[i], bfs, nlevel) + def naive_rolling_window_dot_product(Q, T): window = len(Q) @@ -1365,3 +1397,52 @@ def test_find_matches_maxmatch(): comp = core._find_matches(D, excl_zone, max_distance, max_matches) npt.assert_almost_equal(ref, comp) + + +@pytest.mark.filterwarnings("ignore", category=NumbaPerformanceWarning) +@patch("stumpy.config.STUMPY_THREADS_PER_BLOCK", TEST_THREADS_PER_BLOCK) +def test_gpu_searchsorted(): + if not cuda.is_available(): # pragma: no cover + pytest.skip("Skipping Tests No GPUs Available", allow_module_level=True) + + n = 3 * config.STUMPY_THREADS_PER_BLOCK + 1 + V = np.empty(n, dtype=np.float64) + + threads_per_block = config.STUMPY_THREADS_PER_BLOCK + blocks_per_grid = math.ceil(n / threads_per_block) + + for k in range(1, 32): + device_bfs = cuda.to_device(core._bfs_indices(k, fill_value=-1)) + nlevel = np.floor(np.log2(k) + 1).astype(np.int64) + + A = np.sort(np.random.rand(n, k), axis=1) + device_A = cuda.to_device(A) + + V[:] = np.random.rand(n) + for i, idx in enumerate(np.random.choice(np.arange(n), size=k, replace=False)): + V[idx] = A[idx, i] # create ties + device_V = cuda.to_device(V) + + is_left = True # test case + ref_IDX = [np.searchsorted(A[i], V[i], side="left") for i in range(n)] + ref_IDX = np.asarray(ref_IDX, dtype=np.int64) + + comp_IDX = np.full(n, -1, dtype=np.int64) + device_comp_IDX = cuda.to_device(comp_IDX) + _gpu_searchsorted_kernel[blocks_per_grid, threads_per_block]( + device_A, device_V, device_bfs, nlevel, is_left, device_comp_IDX + ) + comp_IDX = device_comp_IDX.copy_to_host() + npt.assert_array_equal(ref_IDX, comp_IDX) + + is_left = False # test case + ref_IDX = [np.searchsorted(A[i], V[i], side="right") for i in range(n)] + ref_IDX = np.asarray(ref_IDX, dtype=np.int64) + + comp_IDX = np.full(n, -1, dtype=np.int64) + device_comp_IDX = cuda.to_device(comp_IDX) + _gpu_searchsorted_kernel[blocks_per_grid, threads_per_block]( + device_A, device_V, device_bfs, nlevel, is_left, device_comp_IDX + ) + comp_IDX = device_comp_IDX.copy_to_host() + npt.assert_array_equal(ref_IDX, comp_IDX) diff --git a/tests/test_gpu_aamp.py b/tests/test_gpu_aamp.py index 43c805adc..c51a0e768 100644 --- a/tests/test_gpu_aamp.py +++ b/tests/test_gpu_aamp.py @@ -48,7 +48,7 @@ def test_gpu_aamp_self_join(T_A, T_B): m = 3 zone = int(np.ceil(m / 4)) for p in [1.0, 2.0, 3.0]: - ref_mp = naive.aamp(T_B, m, exclusion_zone=zone, p=p) + ref_mp = naive.aamp(T_B, m, exclusion_zone=zone, p=p, row_wise=True) comp_mp = gpu_aamp(T_B, m, ignore_trivial=True, p=p) naive.replace_inf(ref_mp) naive.replace_inf(comp_mp) @@ -88,7 +88,7 @@ def test_gpu_aamp_self_join_larger_window(T_A, T_B, m): def test_gpu_aamp_A_B_join(T_A, T_B): m = 3 for p in [1.0, 2.0, 3.0]: - ref_mp = naive.aamp(T_B, m, T_B=T_A, p=p) + ref_mp = naive.aamp(T_B, m, T_B=T_A, p=p, row_wise=True) comp_mp = gpu_aamp(T_B, m, T_A, ignore_trivial=False, p=p) naive.replace_inf(ref_mp) naive.replace_inf(comp_mp) @@ -369,3 +369,34 @@ def test_gpu_aamp_nan_zero_mean_self_join(): naive.replace_inf(ref_mp) naive.replace_inf(comp_mp) npt.assert_almost_equal(ref_mp, comp_mp) + + +@pytest.mark.filterwarnings("ignore", category=NumbaPerformanceWarning) +@pytest.mark.parametrize("T_A, T_B", test_data) +def test_gpu_aamp_self_join_KNN(T_A, T_B): + m = 3 + zone = int(np.ceil(m / 4)) + for k in range(2, 4): + for p in [1.0, 2.0, 3.0]: + ref_mp = naive.aamp(T_B, m, 
exclusion_zone=zone, p=p, k=k) + comp_mp = gpu_aamp(T_B, m, ignore_trivial=True, p=p, k=k) + naive.replace_inf(ref_mp) + naive.replace_inf(comp_mp) + npt.assert_almost_equal(ref_mp, comp_mp) + + comp_mp = gpu_aamp(pd.Series(T_B), m, ignore_trivial=True, p=p, k=k) + naive.replace_inf(comp_mp) + npt.assert_almost_equal(ref_mp, comp_mp) + + +@pytest.mark.filterwarnings("ignore", category=NumbaPerformanceWarning) +@pytest.mark.parametrize("T_A, T_B", test_data) +def test_gpu_aamp_A_B_join_KNN(T_A, T_B): + m = 3 + for k in range(2, 4): + for p in [1.0, 2.0, 3.0]: + ref_mp = naive.aamp(T_B, m, T_B=T_A, p=p, k=k) + comp_mp = gpu_aamp(T_B, m, T_A, ignore_trivial=False, p=p, k=k) + naive.replace_inf(ref_mp) + naive.replace_inf(comp_mp) + npt.assert_almost_equal(ref_mp, comp_mp) diff --git a/tests/test_gpu_stump.py b/tests/test_gpu_stump.py index 242f8135d..c7517929d 100644 --- a/tests/test_gpu_stump.py +++ b/tests/test_gpu_stump.py @@ -1,22 +1,11 @@ -import math import numpy as np import numpy.testing as npt import pandas as pd -from stumpy import core, gpu_stump +from stumpy import gpu_stump from stumpy import config from numba import cuda from unittest.mock import patch -if cuda.is_available(): - from stumpy.gpu_stump import _gpu_searchsorted_left, _gpu_searchsorted_right -else: # pragma: no cover - from stumpy.core import ( - _gpu_searchsorted_left_driver_not_found as _gpu_searchsorted_left, - ) - from stumpy.core import ( - _gpu_searchsorted_right_driver_not_found as _gpu_searchsorted_right, - ) - try: from numba.errors import NumbaPerformanceWarning except ModuleNotFoundError: @@ -52,63 +41,6 @@ def test_gpu_stump_int_input(): gpu_stump(np.arange(10), 5, ignore_trivial=True) -@cuda.jit("(f8[:, :], f8[:], i8[:], i8, b1, i8[:])") -def _gpu_searchsorted_kernel(a, v, bfs, nlevel, is_left, idx): - # A wrapper kernel for calling device function _gpu_searchsorted_left/right. 
- i = cuda.grid(1) - if i < a.shape[0]: - if is_left: - idx[i] = _gpu_searchsorted_left(a[i], v[i], bfs, nlevel) - else: - idx[i] = _gpu_searchsorted_right(a[i], v[i], bfs, nlevel) - - -@pytest.mark.filterwarnings("ignore", category=NumbaPerformanceWarning) -@patch("stumpy.config.STUMPY_THREADS_PER_BLOCK", TEST_THREADS_PER_BLOCK) -def test_gpu_searchsorted(): - n = 3 * config.STUMPY_THREADS_PER_BLOCK + 1 - V = np.empty(n, dtype=np.float64) - - threads_per_block = config.STUMPY_THREADS_PER_BLOCK - blocks_per_grid = math.ceil(n / threads_per_block) - - for k in range(1, 32): - device_bfs = cuda.to_device(core._bfs_indices(k, fill_value=-1)) - nlevel = np.floor(np.log2(k) + 1).astype(np.int64) - - A = np.sort(np.random.rand(n, k), axis=1) - device_A = cuda.to_device(A) - - V[:] = np.random.rand(n) - for i, idx in enumerate(np.random.choice(np.arange(n), size=k, replace=False)): - V[idx] = A[idx, i] # create ties - device_V = cuda.to_device(V) - - is_left = True # test case - ref_IDX = [np.searchsorted(A[i], V[i], side="left") for i in range(n)] - ref_IDX = np.asarray(ref_IDX, dtype=np.int64) - - comp_IDX = np.full(n, -1, dtype=np.int64) - device_comp_IDX = cuda.to_device(comp_IDX) - _gpu_searchsorted_kernel[blocks_per_grid, threads_per_block]( - device_A, device_V, device_bfs, nlevel, is_left, device_comp_IDX - ) - comp_IDX = device_comp_IDX.copy_to_host() - npt.assert_array_equal(ref_IDX, comp_IDX) - - is_left = False # test case - ref_IDX = [np.searchsorted(A[i], V[i], side="right") for i in range(n)] - ref_IDX = np.asarray(ref_IDX, dtype=np.int64) - - comp_IDX = np.full(n, -1, dtype=np.int64) - device_comp_IDX = cuda.to_device(comp_IDX) - _gpu_searchsorted_kernel[blocks_per_grid, threads_per_block]( - device_A, device_V, device_bfs, nlevel, is_left, device_comp_IDX - ) - comp_IDX = device_comp_IDX.copy_to_host() - npt.assert_array_equal(ref_IDX, comp_IDX) - - @pytest.mark.filterwarnings("ignore", category=NumbaPerformanceWarning) @pytest.mark.parametrize("T_A, T_B", test_data) @patch("stumpy.config.STUMPY_THREADS_PER_BLOCK", TEST_THREADS_PER_BLOCK) diff --git a/tests/test_scraamp.py b/tests/test_scraamp.py index 3193f0753..c0ae88d50 100644 --- a/tests/test_scraamp.py +++ b/tests/test_scraamp.py @@ -113,11 +113,9 @@ def test_scraamp_self_join(T_A, T_B, percentages): seed = np.random.randint(100000) np.random.seed(seed) - ref_mp = naive.scraamp(T_B, m, T_B, percentage, zone, False, None, p=p) - ref_P = ref_mp[:, 0] - ref_I = ref_mp[:, 1] - ref_left_I = ref_mp[:, 2] - ref_right_I = ref_mp[:, 3] + ref_P, ref_I, ref_left_I, ref_right_I = naive.scraamp( + T_B, m, T_B, percentage, zone, False, None, p=p + ) np.random.seed(seed) approx = scraamp( @@ -152,11 +150,9 @@ def test_scraamp_A_B_join(T_A, T_B, percentages): seed = np.random.randint(100000) np.random.seed(seed) - ref_mp = naive.scraamp(T_A, m, T_B, percentage, None, False, None, p=p) - ref_P = ref_mp[:, 0] - ref_I = ref_mp[:, 1] - ref_left_I = ref_mp[:, 2] - ref_right_I = ref_mp[:, 3] + ref_P, ref_I, ref_left_I, ref_right_I = naive.scraamp( + T_A, m, T_B, percentage, None, False, None, p=p + ) np.random.seed(seed) approx = scraamp( @@ -192,11 +188,9 @@ def test_scraamp_A_B_join_swap(T_A, T_B, percentages): seed = np.random.randint(100000) np.random.seed(seed) - ref_mp = naive.scraamp(T_B, m, T_A, percentage, None, False, None) - ref_P = ref_mp[:, 0] - # ref_I = ref_mp[:, 1] - ref_left_I = ref_mp[:, 2] - ref_right_I = ref_mp[:, 3] + ref_P, _, ref_left_I, ref_right_I = naive.scraamp( + T_B, m, T_A, percentage, None, False, None + ) 
np.random.seed(seed) approx = scraamp( @@ -228,11 +222,9 @@ def test_scraamp_self_join_larger_window(T_A, T_B, m, percentages): seed = np.random.randint(100000) np.random.seed(seed) - ref_mp = naive.scraamp(T_B, m, T_B, percentage, zone, False, None) - ref_P = ref_mp[:, 0] - ref_I = ref_mp[:, 1] - ref_left_I = ref_mp[:, 2] - ref_right_I = ref_mp[:, 3] + ref_P, ref_I, ref_left_I, ref_right_I = naive.scraamp( + T_B, m, T_B, percentage, zone, False, None + ) np.random.seed(seed) approx = scraamp( @@ -403,15 +395,11 @@ def test_scraamp_plus_plus_self_join(T_A, T_B, percentages): ref_P, ref_I = naive.prescraamp( T_B, m, T_B, s=s, exclusion_zone=zone, p=p ) - ref_mp = naive.scraamp(T_B, m, T_B, percentage, zone, True, s, p=p) - for i in range(ref_mp.shape[0]): - if ref_P[i] < ref_mp[i, 0]: - ref_mp[i, 0] = ref_P[i] - ref_mp[i, 1] = ref_I[i] - ref_P = ref_mp[:, 0] - ref_I = ref_mp[:, 1] - # ref_left_I = ref_mp[:, 2] - # ref_right_I = ref_mp[:, 3] + ref_P_aux, ref_I_aux, _, _ = naive.scraamp( + T_B, m, T_B, percentage, zone, True, s, p=p + ) + + naive.merge_topk_PI(ref_P, ref_P_aux, ref_I, ref_I_aux) np.random.seed(seed) approx = scraamp( @@ -430,7 +418,7 @@ def test_scraamp_plus_plus_self_join(T_A, T_B, percentages): # comp_right_I = approx.right_I_ naive.replace_inf(ref_P) - naive.replace_inf(comp_I) + naive.replace_inf(comp_P) npt.assert_almost_equal(ref_P, comp_P) npt.assert_almost_equal(ref_I, comp_I) @@ -451,15 +439,13 @@ def test_scraamp_plus_plus_A_B_join(T_A, T_B, percentages): np.random.seed(seed) ref_P, ref_I = naive.prescraamp(T_A, m, T_B, s=s, p=p) - ref_mp = naive.scraamp(T_A, m, T_B, percentage, None, False, None, p=p) - for i in range(ref_mp.shape[0]): - if ref_P[i] < ref_mp[i, 0]: - ref_mp[i, 0] = ref_P[i] - ref_mp[i, 1] = ref_I[i] - ref_P = ref_mp[:, 0] - ref_I = ref_mp[:, 1] - ref_left_I = ref_mp[:, 2] - ref_right_I = ref_mp[:, 3] + ref_P_aux, ref_I_aux, ref_left_I_aux, ref_right_I_aux = naive.scraamp( + T_A, m, T_B, percentage, None, False, None, p=p, k=1 + ) + + naive.merge_topk_PI(ref_P, ref_P_aux, ref_I, ref_I_aux) + ref_left_I = ref_left_I_aux + ref_right_I = ref_right_I_aux approx = scraamp( T_A, @@ -584,11 +570,7 @@ def test_scraamp_constant_subsequence_self_join(percentages): seed = np.random.randint(100000) np.random.seed(seed) - ref_mp = naive.scraamp(T, m, T, percentage, zone, False, None) - ref_P = ref_mp[:, 0] - # ref_I = ref_mp[:, 1] - # ref_left_I = ref_mp[:, 2] - # ref_right_I = ref_mp[:, 3] + ref_P, _, _, _ = naive.scraamp(T, m, T, percentage, zone, False, None) np.random.seed(seed) approx = scraamp( @@ -622,11 +604,7 @@ def test_scraamp_identical_subsequence_self_join(percentages): seed = np.random.randint(100000) np.random.seed(seed) - ref_mp = naive.scraamp(T, m, T, percentage, zone, False, None) - ref_P = ref_mp[:, 0] - # ref_I = ref_mp[:, 1] - # ref_left_I = ref_mp[:, 2] - # ref_right_I = ref_mp[:, 3] + ref_P, _, _, _ = naive.scraamp(T, m, T, percentage, zone, False, None) np.random.seed(seed) approx = scraamp( @@ -668,11 +646,9 @@ def test_scraamp_nan_inf_self_join( seed = np.random.randint(100000) np.random.seed(seed) - ref_mp = naive.scraamp(T_B_sub, m, T_B_sub, percentage, zone, False, None) - ref_P = ref_mp[:, 0] - ref_I = ref_mp[:, 1] - ref_left_I = ref_mp[:, 2] - ref_right_I = ref_mp[:, 3] + ref_P, ref_I, ref_left_I, ref_right_I = naive.scraamp( + T_B_sub, m, T_B_sub, percentage, zone, False, None + ) np.random.seed(seed) approx = scraamp(T_B_sub, m, percentage=percentage, pre_scraamp=False) @@ -702,11 +678,9 @@ def 
test_scraamp_nan_zero_mean_self_join(percentages): seed = np.random.randint(100000) np.random.seed(seed) - ref_mp = naive.scraamp(T, m, T, percentage, zone, False, None) - ref_P = ref_mp[:, 0] - ref_I = ref_mp[:, 1] - ref_left_I = ref_mp[:, 2] - ref_right_I = ref_mp[:, 3] + ref_P, ref_I, ref_left_I, ref_right_I = naive.scraamp( + T, m, T, percentage, zone, False, None + ) np.random.seed(seed) approx = scraamp(T, m, percentage=percentage, pre_scraamp=False) @@ -741,3 +715,211 @@ def test_prescraamp_A_B_join_larger_window(T_A, T_B): npt.assert_almost_equal(ref_P, comp_P) npt.assert_almost_equal(ref_I, comp_I) + + +@pytest.mark.parametrize("T_A, T_B", test_data) +def test_prescraamp_self_join_KNN(T_A, T_B): + m = 3 + zone = int(np.ceil(m / 4)) + for k in range(2, 4): + for p in [1.0, 2.0, 3.0]: + for s in range(1, zone + 1): + seed = np.random.randint(100000) + + np.random.seed(seed) + ref_P, ref_I = naive.prescraamp( + T_B, m, T_B, s=s, exclusion_zone=zone, p=p, k=k + ) + + np.random.seed(seed) + comp_P, comp_I = prescraamp(T_B, m, s=s, p=p, k=k) + + npt.assert_almost_equal(ref_P, comp_P) + npt.assert_almost_equal(ref_I, comp_I) + + +@pytest.mark.parametrize("T_A, T_B", test_data) +def test_prescraamp_A_B_join_KNN(T_A, T_B): + m = 3 + zone = int(np.ceil(m / 4)) + for k in range(2, 4): + for p in [1.0, 2.0, 3.0]: + for s in range(1, zone + 1): + seed = np.random.randint(100000) + + np.random.seed(seed) + ref_P, ref_I = naive.prescraamp(T_A, m, T_B, s=s, p=p, k=k) + + np.random.seed(seed) + comp_P, comp_I = prescraamp(T_A, m, T_B=T_B, s=s, p=p, k=k) + + npt.assert_almost_equal(ref_P, comp_P) + npt.assert_almost_equal(ref_I, comp_I) + + +@pytest.mark.parametrize("T_A, T_B", test_data) +@pytest.mark.parametrize("percentages", percentages) +def test_scraamp_self_join_KNN(T_A, T_B, percentages): + m = 3 + zone = int(np.ceil(m / 4)) + + for k in range(2, 4): + for p in [1.0, 2.0, 3.0]: + for percentage in percentages: + seed = np.random.randint(100000) + + np.random.seed(seed) + ref_P, ref_I, ref_left_I, ref_right_I = naive.scraamp( + T_B, m, T_B, percentage, zone, False, None, p=p, k=k + ) + + np.random.seed(seed) + approx = scraamp( + T_B, + m, + ignore_trivial=True, + percentage=percentage, + pre_scraamp=False, + p=p, + k=k, + ) + approx.update() + comp_P = approx.P_ + comp_I = approx.I_ + comp_left_I = approx.left_I_ + comp_right_I = approx.right_I_ + + naive.replace_inf(ref_P) + naive.replace_inf(comp_P) + npt.assert_almost_equal(ref_P, comp_P) + npt.assert_almost_equal(ref_I, comp_I) + npt.assert_almost_equal(ref_left_I, comp_left_I) + npt.assert_almost_equal(ref_right_I, comp_right_I) + + +@pytest.mark.parametrize("T_A, T_B", test_data) +@pytest.mark.parametrize("percentages", percentages) +def test_scraamp_A_B_join_KNN(T_A, T_B, percentages): + m = 3 + for k in range(2, 4): + for p in [1.0, 2.0, 3.0]: + for percentage in percentages: + seed = np.random.randint(100000) + + np.random.seed(seed) + ref_P, ref_I, ref_left_I, ref_right_I = naive.scraamp( + T_A, m, T_B, percentage, None, False, None, p=p, k=k + ) + + np.random.seed(seed) + approx = scraamp( + T_A, + m, + T_B, + ignore_trivial=False, + percentage=percentage, + pre_scraamp=False, + p=p, + k=k, + ) + approx.update() + comp_P = approx.P_ + comp_I = approx.I_ + comp_left_I = approx.left_I_ + comp_right_I = approx.right_I_ + + naive.replace_inf(ref_P) + naive.replace_inf(comp_P) + + npt.assert_almost_equal(ref_P, comp_P) + npt.assert_almost_equal(ref_I, comp_I) + npt.assert_almost_equal(ref_left_I, comp_left_I) + 
npt.assert_almost_equal(ref_right_I, comp_right_I) + + +@pytest.mark.parametrize("T_A, T_B", test_data) +@pytest.mark.parametrize("percentages", percentages) +def test_scraamp_plus_plus_self_join_KNN(T_A, T_B, percentages): + m = 3 + zone = int(np.ceil(m / 4)) + for k in range(2, 4): + for p in [1.0, 2.0, 3.0]: + for s in range(1, zone + 1): + for percentage in percentages: + seed = np.random.randint(100000) + + np.random.seed(seed) + ref_P, ref_I = naive.prescraamp( + T_B, m, T_B, s=s, exclusion_zone=zone, p=p, k=k + ) + ref_P_aux, ref_I_aux, _, _ = naive.scraamp( + T_B, m, T_B, percentage, zone, True, s, p=p, k=k + ) + + naive.merge_topk_PI(ref_P, ref_P_aux, ref_I, ref_I_aux) + + np.random.seed(seed) + approx = scraamp( + T_B, + m, + ignore_trivial=True, + percentage=percentage, + pre_scraamp=True, + s=s, + p=p, + k=k, + ) + approx.update() + comp_P = approx.P_ + comp_I = approx.I_ + # comp_left_I = approx.left_I_ + # comp_right_I = approx.right_I_ + + naive.replace_inf(ref_P) + naive.replace_inf(comp_P) + + npt.assert_almost_equal(ref_P, comp_P) + npt.assert_almost_equal(ref_I, comp_I) + # npt.assert_almost_equal(ref_left_I, comp_left_I) + # npt.assert_almost_equal(ref_right_I, comp_right_I) + + +@pytest.mark.parametrize("T_A, T_B", test_data) +@pytest.mark.parametrize("m", window_size) +def test_prescraamp_self_join_larger_window_m_5_k_5(T_A, T_B, m): + m = 5 + k = 5 + zone = int(np.ceil(m / 4)) + + if len(T_B) > m: + for s in range(1, zone + 1): + seed = np.random.randint(100000) + + np.random.seed(seed) + ref_P, ref_I = naive.prescraamp(T_B, m, T_B, s=s, exclusion_zone=zone, k=k) + + np.random.seed(seed) + comp_P, comp_I = prescraamp(T_B, m, s=s, k=k) + + npt.assert_almost_equal(ref_P, comp_P) + npt.assert_almost_equal(ref_I, comp_I) + + +@pytest.mark.parametrize("T_A, T_B", test_data) +def test_prescraamp_A_B_join_larger_window_m_5_k_5(T_A, T_B): + m = 5 + k = 5 + zone = int(np.ceil(m / 4)) + + if len(T_A) > m and len(T_B) > m: + for s in range(1, zone + 1): + seed = np.random.randint(100000) + + np.random.seed(seed) + ref_P, ref_I = naive.prescraamp(T_A, m, T_B, s=s, k=k) + + np.random.seed(seed) + comp_P, comp_I = prescraamp(T_A, m, T_B, s=s, k=k) + + npt.assert_almost_equal(ref_P, comp_P) + npt.assert_almost_equal(ref_I, comp_I)
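
(Editorial addendum, not part of the patch: the tests above compare `stumpy.aamp`'s output element-wise against `naive.aamp`, whose column layout is the k matrix profile values, then the k indices, then the top-1 left and right indices. Assuming `stumpy.aamp` mirrors that layout, as the element-wise comparisons imply, a top-k call looks like the sketch below; the time series, window size, and `k` are made up:

    import numpy as np
    import stumpy

    T = np.random.rand(64)
    m, k = 8, 3
    mp = stumpy.aamp(T, m, k=k)  # shape: (len(T) - m + 1, 2 * k + 2)

    P = mp[:, :k].astype(np.float64)             # top-k matrix profile (ascending per row)
    I = mp[:, k : 2 * k].astype(np.int64)        # top-k matrix profile indices
    left_I = mp[:, 2 * k].astype(np.int64)       # top-1 left matrix profile indices
    right_I = mp[:, 2 * k + 1].astype(np.int64)  # top-1 right matrix profile indices

The same slicing convention already appears in the stumpy/stumpi.py hunk above: `mp[:, : self._k]`, `mp[:, self._k : 2 * self._k]`, and `mp[:, 2 * self._k]`.)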