|
1 | 1 | import numpy as np
|
| 2 | +from numba import cuda |
2 | 3 | import numpy.testing as npt
|
3 | 4 | import pandas as pd
|
4 | 5 | from scipy.spatial.distance import cdist
|
|
10 | 11 |
|
11 | 12 | import naive
|
12 | 13 |
|
# Bind the device functions under test. Without CUDA, fall back to the
# `*_driver_not_found` stand-ins exported by `stumpy.core` so that this module
# still imports cleanly.
if cuda.is_available():
    from stumpy.core import _gpu_searchsorted_left, _gpu_searchsorted_right
else:  # pragma: no cover
    from stumpy.core import (
        _gpu_searchsorted_left_driver_not_found as _gpu_searchsorted_left,
    )
    from stumpy.core import (
        _gpu_searchsorted_right_driver_not_found as _gpu_searchsorted_right,
    )

# `NumbaPerformanceWarning` is imported from `numba.errors` when that module
# exists and from `numba.core.errors` otherwise.
try:
    from numba.errors import NumbaPerformanceWarning
except ModuleNotFoundError:
    from numba.core.errors import NumbaPerformanceWarning

# Deliberately small block size that the GPU test patches into
# `stumpy.config.STUMPY_THREADS_PER_BLOCK`.
TEST_THREADS_PER_BLOCK = 10

# NOTE: There is intentionally no module-level `pytest.skip` here. This module
# also holds CPU-only tests, and a module-level skip would silently disable all
# of them on machines without a GPU. `test_gpu_searchsorted` performs its own
# skip instead.

if cuda.is_available():
    # The explicit signature makes the decorator compile the kernel when it is
    # applied, so it must only be applied when CUDA is usable.

    @cuda.jit("(f8[:, :], f8[:], i8[:], i8, b1, i8[:])")
    def _gpu_searchsorted_kernel(a, v, bfs, nlevel, is_left, idx):
        """
        A wrapper kernel for calling the device functions
        `_gpu_searchsorted_left` and `_gpu_searchsorted_right`.

        Thread `i` searches row `a[i]` for the value `v[i]` and stores the
        result in `idx[i]`.

        Parameters
        ----------
        a : 2-D float64 device array
            One row per thread to search in.

        v : 1-D float64 device array
            `v[i]` is the value searched for in `a[i]`.

        bfs : 1-D int64 device array
            Forwarded unchanged to the device function.

        nlevel : int
            Forwarded unchanged to the device function.

        is_left : bool
            If `True`, call `_gpu_searchsorted_left`; otherwise call
            `_gpu_searchsorted_right`.

        idx : 1-D int64 device array
            Output array; `idx[i]` receives the result for row `i`.

        Returns
        -------
        None
        """
        i = cuda.grid(1)
        # The launch grid may contain more threads than rows; surplus threads
        # must not touch memory.
        if i < a.shape[0]:
            if is_left:
                idx[i] = _gpu_searchsorted_left(a[i], v[i], bfs, nlevel)
            else:
                idx[i] = _gpu_searchsorted_right(a[i], v[i], bfs, nlevel)
| 44 | + |
13 | 45 |
|
14 | 46 | def naive_rolling_window_dot_product(Q, T):
|
15 | 47 | window = len(Q)
|
@@ -1365,3 +1397,52 @@ def test_find_matches_maxmatch():
|
1365 | 1397 | comp = core._find_matches(D, excl_zone, max_distance, max_matches)
|
1366 | 1398 |
|
1367 | 1399 | npt.assert_almost_equal(ref, comp)
|
| 1400 | + |
| 1401 | + |
@pytest.mark.filterwarnings("ignore", category=NumbaPerformanceWarning)
@patch("stumpy.config.STUMPY_THREADS_PER_BLOCK", TEST_THREADS_PER_BLOCK)
def test_gpu_searchsorted():
    """
    Compare the `_gpu_searchsorted_left` and `_gpu_searchsorted_right` device
    functions (called through `_gpu_searchsorted_kernel`) against
    `np.searchsorted` with `side="left"` and `side="right"`.

    For every row length `k` in `range(1, 32)`, a fresh batch of `n` sorted
    random rows and `n` random query values is generated. `k` of the query
    values are overwritten with elements of their own row so that exact ties
    are exercised, which is where the left and right variants differ.
    """
    if not cuda.is_available():  # pragma: no cover
        # `allow_module_level` is omitted: it only matters for skips issued at
        # module scope, not inside a test function.
        pytest.skip("Skipping Tests No GPUs Available")

    # One more row than a whole number of blocks, so the last block of the
    # launch is only partially used.
    n = 3 * config.STUMPY_THREADS_PER_BLOCK + 1
    V = np.empty(n, dtype=np.float64)

    threads_per_block = config.STUMPY_THREADS_PER_BLOCK
    blocks_per_grid = math.ceil(n / threads_per_block)

    for k in range(1, 32):
        device_bfs = cuda.to_device(core._bfs_indices(k, fill_value=-1))
        nlevel = np.floor(np.log2(k) + 1).astype(np.int64)

        A = np.sort(np.random.rand(n, k), axis=1)
        device_A = cuda.to_device(A)

        V[:] = np.random.rand(n)
        tie_rows = np.random.choice(np.arange(n), size=k, replace=False)
        for i, idx in enumerate(tie_rows):
            V[idx] = A[idx, i]  # create ties
        device_V = cuda.to_device(V)

        # Same data, both search directions: left first, then right.
        for is_left, side in ((True, "left"), (False, "right")):
            ref_IDX = np.asarray(
                [np.searchsorted(row, value, side=side) for row, value in zip(A, V)],
                dtype=np.int64,
            )

            # Pre-fill with -1, which `np.searchsorted` never returns, so any
            # entry the kernel fails to write shows up as a mismatch.
            comp_IDX = np.full(n, -1, dtype=np.int64)
            device_comp_IDX = cuda.to_device(comp_IDX)
            _gpu_searchsorted_kernel[blocks_per_grid, threads_per_block](
                device_A, device_V, device_bfs, nlevel, is_left, device_comp_IDX
            )
            comp_IDX = device_comp_IDX.copy_to_host()
            npt.assert_array_equal(ref_IDX, comp_IDX)
0 commit comments