From 40fb02be36cad4c3db987319e7a44f5e70f3ce1d Mon Sep 17 00:00:00 2001 From: mlanger Date: Mon, 29 May 2023 09:06:01 +0000 Subject: [PATCH 1/2] Add deps as submodules to ensure equal env on ARM and x86. --- .gitmodules | 8 +++++++- CMakeLists.txt | 8 ++++++-- {tests => third_party}/googletest | 0 third_party/libcudacxx | 1 + third_party/thrust | 1 + 5 files changed, 15 insertions(+), 3 deletions(-) rename {tests => third_party}/googletest (100%) create mode 160000 third_party/libcudacxx create mode 160000 third_party/thrust diff --git a/.gitmodules b/.gitmodules index 44d6d26f4..72d46f04a 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,10 @@ [submodule "tests/googletest"] - path = tests/googletest + path = third_party/googletest url = https://github.com/google/googletest.git ignore = dirty +[submodule "third_party/libcudacxx"] + path = third_party/libcudacxx + url = https://github.com/NVIDIA/libcudacxx +[submodule "third_party/thrust"] + path = third_party/thrust + url = https://github.com/NVIDIA/thrust.git diff --git a/CMakeLists.txt b/CMakeLists.txt index f7ad51300..cb73d7871 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -75,10 +75,14 @@ message(CMAKE_CUDA_FLAGS="${CMAKE_CUDA_FLAGS}") include_directories( ${PROJECT_SOURCE_DIR}/include - ${PROJECT_SOURCE_DIR}/tests/googletest/googletest/include + ${PROJECT_SOURCE_DIR}/third_party/libcudacxx/include + ${PROJECT_SOURCE_DIR}/third_party/thrust/include + ${PROJECT_SOURCE_DIR}/third_party/googletest/googletest/include ) -ADD_SUBDIRECTORY(tests/googletest) +ADD_SUBDIRECTORY(third_party/googletest) +ADD_SUBDIRECTORY(third_party/libcudacxx) +ADD_SUBDIRECTORY(third_party/thrust) link_directories( ) diff --git a/tests/googletest b/third_party/googletest similarity index 100% rename from tests/googletest rename to third_party/googletest diff --git a/third_party/libcudacxx b/third_party/libcudacxx new file mode 160000 index 000000000..88a91da3d --- /dev/null +++ b/third_party/libcudacxx @@ -0,0 +1 @@ +Subproject commit 88a91da3dddf0eccf7c9ffd66214ddd2cad497da diff --git a/third_party/thrust b/third_party/thrust new file mode 160000 index 000000000..cf0fc2129 --- /dev/null +++ b/third_party/thrust @@ -0,0 +1 @@ +Subproject commit cf0fc2129f82cca026b1ef1c53d5facb35d06c4e From 40cd796a34d915b4db0ee0d8b575f8045631a1df Mon Sep 17 00:00:00 2001 From: mlanger Date: Mon, 29 May 2023 09:23:48 +0000 Subject: [PATCH 2/2] Simplify thrust calls. --- include/merlin_hashtable.cuh | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/include/merlin_hashtable.cuh b/include/merlin_hashtable.cuh index 6d8cf4235..2a1647b43 100644 --- a/include/merlin_hashtable.cuh +++ b/include/merlin_hashtable.cuh @@ -119,12 +119,6 @@ using EraseIfPredict = bool (*)( const S& threshold ///< The threshold to compare with the `score` argument. ); -#if THRUST_VERSION >= 101600 -static constexpr auto& thrust_par = thrust::cuda::par_nosync; -#else -static constexpr auto& thrust_par = thrust::cuda::par; -#endif - /** * A HierarchicalKV hash table is a concurrent and hierarchical hash table that * is powered by GPUs and can use HBM and host memory as storage for key-value @@ -327,7 +321,7 @@ class HashTable { reinterpret_cast(d_dst)); thrust::device_ptr d_src_offset_ptr(d_src_offset); - thrust::sort_by_key(thrust_par.on(stream), d_dst_ptr, d_dst_ptr + n, + thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), d_dst_ptr, d_dst_ptr + n, d_src_offset_ptr, thrust::less()); } @@ -561,7 +555,7 @@ class HashTable { thrust::device_ptr dst_ptr(reinterpret_cast(dst)); thrust::device_ptr src_offset_ptr(src_offset); - thrust::sort_by_key(thrust_par.on(stream), dst_ptr, dst_ptr + n, + thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), dst_ptr, dst_ptr + n, src_offset_ptr, thrust::less()); } @@ -655,7 +649,7 @@ class HashTable { reinterpret_cast(d_table_value_addrs)); thrust::device_ptr param_key_index_ptr(param_key_index); - thrust::sort_by_key(thrust_par.on(stream), table_value_ptr, + thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), table_value_ptr, table_value_ptr + n, param_key_index_ptr, thrust::less()); } @@ -825,7 +819,7 @@ class HashTable { reinterpret_cast(d_dst)); thrust::device_ptr d_src_offset_ptr(d_src_offset); - thrust::sort_by_key(thrust_par.on(stream), d_dst_ptr, d_dst_ptr + n, + thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), d_dst_ptr, d_dst_ptr + n, d_src_offset_ptr, thrust::less()); } @@ -926,7 +920,7 @@ class HashTable { reinterpret_cast(src)); thrust::device_ptr dst_offset_ptr(dst_offset); - thrust::sort_by_key(thrust_par.on(stream), src_ptr, src_ptr + n, + thrust::sort_by_key(thrust::cuda::par_nosync.on(stream), src_ptr, src_ptr + n, dst_offset_ptr, thrust::less()); } @@ -1278,7 +1272,7 @@ class HashTable { for (size_type start_i = 0; start_i < N; start_i += step) { size_type end_i = std::min(start_i + step, N); - h_size += thrust::reduce(thrust_par.on(stream), size_ptr + start_i, + h_size += thrust::reduce(thrust::cuda::par_nosync.on(stream), size_ptr + start_i, size_ptr + end_i, 0, thrust::plus()); } @@ -1594,7 +1588,7 @@ class HashTable { thrust::device_ptr size_ptr(table_->buckets_size); - int size = thrust::reduce(thrust_par.on(stream), size_ptr, size_ptr + N, 0, + int size = thrust::reduce(thrust::cuda::par_nosync.on(stream), size_ptr, size_ptr + N, 0, thrust::plus()); CudaCheckError();