Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/cuda/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ __device__ inline void __syncwarp(uint32_t mask){} //TODO: 6.1 should have this

#include "ctranslate2/types.h"

#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/iterator/transform_iterator.h>

#include "utils.h"

#ifdef CT2_USE_HIP
Expand Down
14 changes: 9 additions & 5 deletions src/cuda/primitives.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,22 @@
#define CUBLAS_COMPUTE_32I HIPBLAS_COMPUTE_32I
#define CUDA_R_32F HIP_R_32F
#define CUDA_R_16BF HIP_R_16BF
-#define cublasGemmEx hipblasGemmEx_v2
+#define cublasGemmEx hipblasGemmEx
#define CUDA_R_8I HIP_R_8I
#define CUDA_R_32I HIP_R_32I
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define cublasSgemmStridedBatched hipblasSgemmStridedBatched
-#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx_v2
+#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#else
#include <cuda_runtime.h>
#include <cublas_v2.h>
#endif

#include <thrust/device_ptr.h>
#include <thrust/reduce.h>
#include <thrust/extrema.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include "cuda/helpers.h"
#include "type_dispatch.h"

Expand Down Expand Up @@ -517,7 +521,7 @@ namespace ctranslate2 {
}

// cuBLAS assumes column-major storage, so swap a and b accordingly.
-CUBLAS_CHECK(cublasGemmEx(cuda::get_cublas_handle(),
+CUBLAS_CHECK(hipblasGemmEx(cuda::get_cublas_handle(),
transpose_b ? CUBLAS_OP_T : CUBLAS_OP_N,
transpose_a ? CUBLAS_OP_T : CUBLAS_OP_N,
n, m, k,
Expand Down Expand Up @@ -572,7 +576,7 @@ namespace ctranslate2 {
int32_t beta_i = beta;

// cuBLAS assumes column-major storage, so swap a and b accordingly.
-CUBLAS_CHECK(cublasGemmEx(cuda::get_cublas_handle(),
+CUBLAS_CHECK(hipblasGemmEx(cuda::get_cublas_handle(),
transpose_b ? CUBLAS_OP_T : CUBLAS_OP_N,
transpose_a ? CUBLAS_OP_T : CUBLAS_OP_N,
n, m, k,
Expand Down Expand Up @@ -632,7 +636,7 @@ namespace ctranslate2 {
}

// cuBLAS assumes column-major storage, so swap a and b accordingly.
-CUBLAS_CHECK(cublasGemmStridedBatchedEx(cuda::get_cublas_handle(),
+CUBLAS_CHECK(hipblasGemmStridedBatchedEx(cuda::get_cublas_handle(),
transpose_b ? CUBLAS_OP_T : CUBLAS_OP_N,
transpose_a ? CUBLAS_OP_T : CUBLAS_OP_N,
n, m, k,
Expand Down
3 changes: 3 additions & 0 deletions src/ops/gumbel_max_gpu.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include "ctranslate2/ops/gumbel_max.h"

#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform.h>

#include "type_dispatch.h"
#include "cuda/helpers.h"
#include "cuda/random.h"
Expand Down