Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions csrc/attention/attention_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "cuda_compat.h"

#ifdef USE_ROCM
#include <hip/hip_bf16.h>
Expand All @@ -33,12 +34,6 @@ typedef __hip_bfloat16 __nv_bfloat16;
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif

#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
Expand Down Expand Up @@ -670,7 +665,6 @@ __global__ void paged_attention_v2_reduce_kernel(

} // namespace vllm

#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
8 changes: 1 addition & 7 deletions csrc/attention/paged_attention_v1.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,7 @@
*/

#include "attention_kernels.cuh"

#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#include "cuda_compat.h"

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
Expand Down Expand Up @@ -190,7 +185,6 @@ void paged_attention_v1(
CALL_V1_LAUNCHER_BLOCK_SIZE)
}

#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
8 changes: 1 addition & 7 deletions csrc/attention/paged_attention_v2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,7 @@
*/

#include "attention_kernels.cuh"

#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#include "cuda_compat.h"

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
Expand Down Expand Up @@ -200,7 +195,6 @@ void paged_attention_v2(
CALL_V2_LAUNCHER_BLOCK_SIZE)
}

#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
6 changes: 3 additions & 3 deletions csrc/cuda_compat.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
#include <hip/hip_runtime.h>
#endif

#ifndef USE_ROCM
#define WARP_SIZE 32
#if defined(USE_ROCM) && defined(__GFX9__)
#define WARP_SIZE 64
#else
#define WARP_SIZE warpSize
#define WARP_SIZE 32
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're cherry picking the wrong change. You'll always get 32 on the host doing it like this

#endif

#ifndef USE_ROCM
Expand Down
Loading