diff --git a/include/merlin/array_kernels.cuh b/include/merlin/array_kernels.cuh index f5093cc8..cfbeb762 100644 --- a/include/merlin/array_kernels.cuh +++ b/include/merlin/array_kernels.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include "cub/cub.cuh" #include "cuda_runtime.h" #include "thrust/device_vector.h" #include "thrust/execution_policy.h" @@ -104,18 +105,13 @@ template void gpu_boolean_mask(size_t grid_size, size_t block_size, const bool* masks, size_t n, size_t* n_evicted, Tidx* offsets, K* __restrict keys, V* __restrict values, - S* __restrict scores, size_t dim, cudaStream_t stream) { + S* __restrict scores, Tidx* offset_ws, + size_t offset_ws_bytes, size_t dim, cudaStream_t stream) { size_t n_offsets = (n + TILE_SIZE - 1) / TILE_SIZE; gpu_cell_count <<>>(masks, offsets, n, n_evicted); -#if THRUST_VERSION >= 101600 - auto policy = thrust::cuda::par_nosync.on(stream); -#else - auto policy = thrust::cuda::par.on(stream); -#endif - thrust::device_ptr d_src(offsets); - thrust::device_ptr d_dest(offsets); - thrust::exclusive_scan(policy, d_src, d_src + n_offsets, d_dest); + CUDA_CHECK(cub::DeviceScan::ExclusiveSum(offset_ws, offset_ws_bytes, offsets, + offsets, n_offsets, stream)); gpu_select_kvm_kernel <<>>(masks, n, offsets, keys, values, scores, dim); diff --git a/include/merlin_hashtable.cuh b/include/merlin_hashtable.cuh index 17c9e126..4a790f70 100644 --- a/include/merlin_hashtable.cuh +++ b/include/merlin_hashtable.cuh @@ -26,6 +26,7 @@ #include #include #include +#include "cub/cub.cuh" #include "merlin/allocator.cuh" #include "merlin/array_kernels.cuh" #include "merlin/core_kernels.cuh" @@ -598,9 +599,20 @@ class HashTable { keys_not_empty <<>>(evicted_keys, d_masks, n); + + void* d_temp_storage = nullptr; + size_t temp_storage_bytes = 0; + CUDA_CHECK(cub::DeviceScan::ExclusiveSum(d_temp_storage, + temp_storage_bytes, d_offsets, + d_offsets, n_offsets, stream)); + auto helper_ws{ + dev_mem_pool_->get_workspace<1>(temp_storage_bytes, stream)}; + int64_t* d_temp_storage_i64 = helper_ws.get(0); + gpu_boolean_mask( grid_size, block_size, d_masks, n, d_evicted_counter, d_offsets, - evicted_keys, evicted_values, evicted_scores, dim(), stream); + evicted_keys, evicted_values, evicted_scores, d_temp_storage_i64, + temp_storage_bytes, dim(), stream); } return; }