File tree 2 files changed +6
-2
lines changed
2 files changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -100,6 +100,8 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
100
100
101
101
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
102
102
list (APPEND GPU_FLAGS "-DENABLE_FP8_E5M2" )
103
+ endif ()
104
+ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
103
105
list (REMOVE_ITEM GPU_FLAGS
104
106
"-D__CUDA_NO_HALF_OPERATORS__"
105
107
"-D__CUDA_NO_HALF_CONVERSIONS__"
Original file line number Diff line number Diff line change @@ -59,6 +59,8 @@ __global__ void rms_norm_kernel(
59
59
template <typename torch_type>
60
60
struct _typeConvert { static constexpr bool exists = false ; };
61
61
62
+ #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
63
+ // CUDA < 12.0 runs into issues with packed type conversion
62
64
template <>
63
65
struct _typeConvert <c10::Half> {
64
66
static constexpr bool exists = true ;
@@ -85,8 +87,8 @@ struct _typeConvert<c10::BFloat16> {
85
87
__device__ static inline hip_type convert (float x) { return __float2bfloat16 (x); }
86
88
__device__ static inline packed_hip_type convert (float2 x) { return __float22bfloat162_rn (x); }
87
89
};
88
- #endif
89
-
90
+ #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
91
+ # endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
90
92
91
93
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
92
94
for appropriate specializations of fused_add_rms_norm_kernel.
You can’t perform that action at this time.
0 commit comments