
Commit 59a6abf

[Hotfix][CI/Build][Kernel] CUDA 11.8 does not support layernorm optimizations (vllm-project#3782)
1 parent bc0c019 commit 59a6abf

2 files changed (+6 −2)

cmake/utils.cmake (+2)
@@ -100,6 +100,8 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
 
    if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
      list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
+   endif()
+   if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
      list(REMOVE_ITEM GPU_FLAGS
        "-D__CUDA_NO_HALF_OPERATORS__"
        "-D__CUDA_NO_HALF_CONVERSIONS__"

csrc/layernorm_kernels.cu (+4 −2)
@@ -59,6 +59,8 @@ __global__ void rms_norm_kernel(
 template<typename torch_type>
 struct _typeConvert { static constexpr bool exists = false; };
 
+#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
+// CUDA < 12.0 runs into issues with packed type conversion
 template<>
 struct _typeConvert<c10::Half> {
   static constexpr bool exists = true;
@@ -85,8 +87,8 @@ struct _typeConvert<c10::BFloat16> {
   __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); }
   __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); }
 };
-#endif
-
+#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
 
 /* Vector POD struct to generate vectorized and packed FP16/BF16 ops
    for appropriate specializations of fused_add_rms_norm_kernel.
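
Because the primary _typeConvert template leaves exists = false, wrapping the Half/BFloat16 specializations in the new CUDA_VERSION guard means that on CUDA builds older than 12.0 those specializations are simply compiled out and the fused kernel takes its scalar fallback. Below is a minimal host-compilable sketch of that trait-gating pattern; the names TypeConvert, run_norm, and the float specialization are hypothetical and only mirror the idea, they are not code from this commit.

// trait_gate.cpp -- illustrative only (compile with nvcc or g++ -std=c++17).
#include <cstdio>

template <typename T>
struct TypeConvert { static constexpr bool exists = false; };  // default: no packed support

// Hypothetical specialization, compiled in only when the toolchain allows it.
#if !defined(CUDA_VERSION) || (CUDA_VERSION >= 12000)
template <>
struct TypeConvert<float> { static constexpr bool exists = true; };
#endif

template <typename T>
void run_norm() {
  if constexpr (TypeConvert<T>::exists) {
    std::printf("vectorized path (packed loads/stores)\n");
  } else {
    std::printf("scalar fallback path\n");
  }
}

int main() {
  run_norm<float>();   // vectorized if the specialization was compiled in
  run_norm<double>();  // always scalar: no specialization exists
  return 0;
}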
