|
| 1 | +/* |
| 2 | + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +#include "cutlass/numeric_conversion.h" |
| 18 | +#include "tensorrt_llm/common/cudaFp8Utils.h" |
| 19 | +#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/cudaCoreGemmNVFP4.h" |
| 20 | +#include <cub/cub.cuh> |
| 21 | + |
| 22 | +namespace tensorrt_llm |
| 23 | +{ |
| 24 | +namespace kernels |
| 25 | +{ |
| 26 | +namespace cuda_core_gemm_nvfp4 |
| 27 | +{ |
| 28 | +template <typename InputType, typename OutputType, typename ScaleType, SizeType32 TILE_M, SizeType32 TILE_N, |
| 29 | + SizeType32 BLOCK_SIZE> |
| 30 | +__device__ void cudaCoreGemmImpl(InputType const* __restrict__ act, InputType const* __restrict__ weight, |
| 31 | + ScaleType const* __restrict__ scale_a, ScaleType const* __restrict__ scale_w, float const alpha, |
| 32 | + OutputType* __restrict__ output, SizeType32 m, SizeType32 n, SizeType32 k) |
| 33 | +{ |
| 34 | + using VecType = int4; |
| 35 | + |
| 36 | + using ScaleVecType = __nv_fp8x2_e4m3; |
| 37 | + using CvtInputType = typename tensorrt_llm::kernels::cutlass_kernels::TllmToCutlassTypeAdapter<InputType>::type; |
| 38 | + static constexpr SizeType32 step_k = static_cast<SizeType32>(128 / cutlass::sizeof_bits<CvtInputType>::value); |
| 39 | + static constexpr SizeType32 nvfp4_scale_granularity = 16; |
| 40 | + static constexpr SizeType32 step_k_scale = step_k / nvfp4_scale_granularity; |
| 41 | + static constexpr SizeType32 tile_k = step_k * BLOCK_SIZE; |
| 42 | + auto tile_id_m = static_cast<SizeType32>(blockIdx.x * TILE_M); |
| 43 | + auto tile_id_n = static_cast<SizeType32>(blockIdx.y * TILE_N); |
| 44 | + auto tid = static_cast<SizeType32>(threadIdx.x); |
| 45 | + float tile_a[step_k]; |
| 46 | + float tile_w[TILE_N * step_k]; |
| 47 | + float tile_a_scale[step_k_scale]; |
| 48 | + float tile_w_scale[TILE_N * step_k_scale]; |
| 49 | + float acc[TILE_M * TILE_N]; |
| 50 | + |
| 51 | + static_assert(step_k % 4 == 0); |
| 52 | + using Converter = cutlass::NumericArrayConverter<float, CvtInputType, 8>; |
| 53 | + using CvtSrcType = typename Converter::source_type; |
| 54 | + using CvtResType = typename Converter::result_type; |
| 55 | + static constexpr SizeType32 k_cvt_count = static_cast<SizeType32>(sizeof(VecType) / sizeof(CvtSrcType)); |
| 56 | + |
| 57 | +#pragma unroll |
| 58 | + for (SizeType32 i = 0; i < TILE_M * TILE_N; ++i) |
| 59 | + { |
| 60 | + acc[i] = 0; |
| 61 | + } |
| 62 | + act += tile_id_m * k / 2; |
| 63 | + weight += tile_id_n * k / 2; |
| 64 | + output += tile_id_m * n + tile_id_n; |
| 65 | + |
| 66 | + scale_a += tile_id_m * k / nvfp4_scale_granularity; |
| 67 | + |
| 68 | +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) |
| 69 | + cudaGridDependencySynchronize(); |
| 70 | +#endif |
| 71 | + |
| 72 | + int const num_cols_sf = k / nvfp4_scale_granularity; |
| 73 | + int const num_sf_tiles_k = (num_cols_sf + 4 - 1) / 4; |
| 74 | + for (SizeType32 idx_k = tid * step_k; idx_k < k; idx_k += tile_k) |
| 75 | + { |
| 76 | + for (SizeType32 j = 0; j < TILE_N; ++j) |
| 77 | + { |
| 78 | + auto tile_w_quantized = reinterpret_cast<VecType const*>(weight + (j * k + idx_k) / 2)[0]; |
| 79 | +#pragma unroll |
| 80 | + for (SizeType32 cvt_idx = 0; cvt_idx < k_cvt_count; ++cvt_idx) |
| 81 | + { |
| 82 | + reinterpret_cast<CvtResType*>(tile_w)[j * k_cvt_count + cvt_idx] |
| 83 | + = Converter::convert(reinterpret_cast<CvtSrcType*>(&tile_w_quantized)[cvt_idx]); |
| 84 | + } |
| 85 | + } |
| 86 | + for (SizeType32 j = 0; j < TILE_N; ++j) |
| 87 | + { |
| 88 | + int const row_idx = tile_id_n + j; |
| 89 | + int const col_idx = idx_k / nvfp4_scale_granularity; |
| 90 | + int const tile_offset = ((row_idx / 128) * num_sf_tiles_k + col_idx / 4) * 512; |
| 91 | + int const dst_idx = tile_offset + (row_idx % 32) * 16 + ((row_idx % 128) / 32) * 4 + col_idx % 4; |
| 92 | + auto tile_w_scale_fp8x2 = reinterpret_cast<ScaleVecType const*>(scale_w + dst_idx)[0]; |
| 93 | + const char2 tmp = reinterpret_cast<char2 const&>(tile_w_scale_fp8x2); |
| 94 | + tile_w_scale[j * step_k_scale + 0] = static_cast<float>(reinterpret_cast<__nv_fp8_e4m3 const&>(tmp.x)); |
| 95 | + tile_w_scale[j * step_k_scale + 1] = static_cast<float>(reinterpret_cast<__nv_fp8_e4m3 const&>(tmp.y)); |
| 96 | + } |
| 97 | +#pragma unroll |
| 98 | + for (SizeType32 i = 0; i < TILE_M; ++i) |
| 99 | + { |
| 100 | + auto tile_a_quantized = reinterpret_cast<VecType const*>(act + (i * k + idx_k) / 2)[0]; |
| 101 | +#pragma unroll |
| 102 | + for (SizeType32 cvt_idx = 0; cvt_idx < k_cvt_count; ++cvt_idx) |
| 103 | + { |
| 104 | + reinterpret_cast<CvtResType*>(tile_a)[cvt_idx] |
| 105 | + = Converter::convert(reinterpret_cast<CvtSrcType*>(&tile_a_quantized)[cvt_idx]); |
| 106 | + } |
| 107 | + auto tile_a_scale_fp8x2 |
| 108 | + = reinterpret_cast<ScaleVecType const*>(scale_a + (i * k + idx_k) / nvfp4_scale_granularity)[0]; |
| 109 | + const char2 tmp = reinterpret_cast<char2 const&>(tile_a_scale_fp8x2); |
| 110 | + tile_a_scale[0] = static_cast<float>(reinterpret_cast<__nv_fp8_e4m3 const&>(tmp.x)); |
| 111 | + tile_a_scale[1] = static_cast<float>(reinterpret_cast<__nv_fp8_e4m3 const&>(tmp.y)); |
| 112 | +#pragma unroll |
| 113 | + for (SizeType32 j = 0; j < TILE_N; ++j) |
| 114 | + { |
| 115 | +#pragma unroll |
| 116 | + for (SizeType32 l = 0; l < step_k; ++l) |
| 117 | + { |
| 118 | + acc[i * TILE_N + j] = fma(alpha * tile_a[l] * tile_a_scale[l / nvfp4_scale_granularity], |
| 119 | + tile_w[j * step_k + l] * tile_w_scale[j * step_k_scale + l / nvfp4_scale_granularity], |
| 120 | + acc[i * TILE_N + j]); |
| 121 | + } |
| 122 | + } |
| 123 | + } |
| 124 | + } |
| 125 | + |
| 126 | + typedef cub::WarpReduce<float> WarpReduce; |
| 127 | + |
| 128 | + static constexpr SizeType32 warp_size = 32; |
| 129 | + static constexpr SizeType32 warp_num = BLOCK_SIZE / warp_size; |
| 130 | + SizeType32 warp_id = tid / warp_size, lane_id = tid % warp_size; |
| 131 | + __shared__ float shmem[TILE_M * TILE_N * warp_num]; |
| 132 | + __shared__ typename WarpReduce::TempStorage temp_storage[warp_num]; |
| 133 | +#pragma unroll |
| 134 | + for (SizeType32 mi = 0; mi < TILE_M; ++mi) |
| 135 | + { |
| 136 | +#pragma unroll |
| 137 | + for (SizeType32 ni = 0; ni < TILE_N; ++ni) |
| 138 | + { |
| 139 | + float val = WarpReduce(temp_storage[warp_id]).Sum(acc[mi * TILE_N + ni]); |
| 140 | + if (lane_id == 0) |
| 141 | + { |
| 142 | + shmem[mi * TILE_N + ni + warp_id * TILE_M * TILE_N] = val; |
| 143 | + } |
| 144 | + } |
| 145 | + } |
| 146 | + __syncthreads(); |
| 147 | + for (SizeType32 ii = tid; ii < TILE_M * TILE_N; ii += BLOCK_SIZE) |
| 148 | + { |
| 149 | + SizeType32 mid = ii / TILE_N, nid = ii % TILE_N; |
| 150 | + float val = 0; |
| 151 | +#pragma unroll |
| 152 | + for (SizeType32 jj = 0; jj < warp_num; ++jj) |
| 153 | + { |
| 154 | + val += shmem[jj * TILE_M * TILE_N + ii]; |
| 155 | + } |
| 156 | + output[mid * n + nid] = static_cast<OutputType>(val); |
| 157 | + } |
| 158 | + |
| 159 | +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) |
| 160 | + cudaTriggerProgrammaticLaunchCompletion(); |
| 161 | +#endif |
| 162 | +} |
| 163 | + |
| 164 | +template <typename InputType, typename OutputType, typename ScaleType, SizeType32 TILE_M, SizeType32 TILE_N, |
| 165 | + SizeType32 BLOCK_SIZE> |
| 166 | +__global__ void cudaCoreGemmFp4(InputType const* __restrict__ act, InputType const* __restrict__ weight, |
| 167 | + ScaleType const* __restrict__ scale_a, ScaleType const* __restrict__ scale_w, float const* alpha_ptr, |
| 168 | + OutputType* __restrict__ output, SizeType32 m, SizeType32 n, SizeType32 k) |
| 169 | +{ |
| 170 | + float alpha = alpha_ptr[0]; |
| 171 | + cudaCoreGemmImpl<InputType, OutputType, ScaleType, TILE_M, TILE_N, BLOCK_SIZE>( |
| 172 | + reinterpret_cast<InputType const*>(act), reinterpret_cast<InputType const*>(weight), |
| 173 | + reinterpret_cast<ScaleType const*>(scale_a), reinterpret_cast<ScaleType const*>(scale_w), alpha, |
| 174 | + reinterpret_cast<OutputType*>(output), m, n, k); |
| 175 | +} |
| 176 | + |
| 177 | +template <typename InputType, typename OutputType, typename ScaleType, SizeType32 TILE_M, SizeType32 TILE_N, |
| 178 | + SizeType32 BLOCK_SIZE> |
| 179 | +void cudaCoreGemmKernel(Params const& params, cudaStream_t stream) |
| 180 | +{ |
| 181 | + dim3 block(BLOCK_SIZE); |
| 182 | + dim3 grid(params.m / TILE_M, params.n / TILE_N); |
| 183 | + |
| 184 | + if (tensorrt_llm::common::getEnvEnablePDL()) |
| 185 | + { |
| 186 | + TLLM_LOG_DEBUG("Enable PDL in fp8_gemm_plugin"); |
| 187 | + cudaLaunchConfig_t kernelConfig = {0}; |
| 188 | + kernelConfig.gridDim = grid; |
| 189 | + kernelConfig.blockDim = block; |
| 190 | + kernelConfig.dynamicSmemBytes = 0; |
| 191 | + kernelConfig.stream = stream; |
| 192 | + |
| 193 | + cudaLaunchAttribute attribute[1]; |
| 194 | + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; |
| 195 | + attribute[0].val.programmaticStreamSerializationAllowed = 1; |
| 196 | + kernelConfig.attrs = attribute; |
| 197 | + kernelConfig.numAttrs = 1; |
| 198 | + |
| 199 | + if (params.scale_a && params.scale_b && params.alpha_ptr) |
| 200 | + { |
| 201 | + TLLM_CUDA_CHECK(cudaLaunchKernelEx(&kernelConfig, |
| 202 | + cudaCoreGemmFp4<InputType, OutputType, ScaleType, TILE_M, TILE_N, BLOCK_SIZE>, |
| 203 | + reinterpret_cast<InputType const*>(params.act), reinterpret_cast<InputType const*>(params.weight), |
| 204 | + reinterpret_cast<ScaleType const*>(params.scale_a), reinterpret_cast<ScaleType const*>(params.scale_b), |
| 205 | + params.alpha_ptr, reinterpret_cast<OutputType*>(params.output), params.m, params.n, params.k)); |
| 206 | + } |
| 207 | + } |
| 208 | + else |
| 209 | + { |
| 210 | + if (params.scale_a && params.scale_b && params.alpha_ptr) |
| 211 | + { |
| 212 | + cudaCoreGemmFp4<InputType, OutputType, ScaleType, TILE_M, TILE_N, BLOCK_SIZE><<<grid, block, 0, stream>>>( |
| 213 | + reinterpret_cast<InputType const*>(params.act), reinterpret_cast<InputType const*>(params.weight), |
| 214 | + reinterpret_cast<ScaleType const*>(params.scale_a), reinterpret_cast<ScaleType const*>(params.scale_b), |
| 215 | + params.alpha_ptr, reinterpret_cast<OutputType*>(params.output), params.m, params.n, params.k); |
| 216 | + } |
| 217 | + } |
| 218 | +} |
| 219 | + |
| 220 | +template <typename InputType, typename OutputType, typename ScaleType, int TILE_M, int TILE_N, int BLOCK_SIZE> |
| 221 | +bool cudaCoreGemmTemplateCaller(Params const& params, cudaStream_t stream) |
| 222 | +{ |
| 223 | + constexpr int cudaCoreGemmTemplateMaxM = 16; |
| 224 | + if (params.m == TILE_M) |
| 225 | + { |
| 226 | + cudaCoreGemmKernel<InputType, OutputType, ScaleType, TILE_M, TILE_N, BLOCK_SIZE>(params, stream); |
| 227 | + return true; |
| 228 | + } |
| 229 | + if constexpr (TILE_M < cudaCoreGemmTemplateMaxM) |
| 230 | + { |
| 231 | + return cudaCoreGemmTemplateCaller<InputType, OutputType, ScaleType, TILE_M + 1, TILE_N, BLOCK_SIZE>( |
| 232 | + params, stream); |
| 233 | + } |
| 234 | + return false; |
| 235 | +} |
| 236 | + |
| 237 | +template <typename InputType, typename OutputType, typename ScaleType = float> |
| 238 | +bool cudaCoreGemmLauncher(Params const& params, cudaStream_t stream) |
| 239 | +{ |
| 240 | + return cudaCoreGemmTemplateCaller<InputType, OutputType, ScaleType, 1, 2, 128>(params, stream); |
| 241 | +} |
| 242 | + |
| 243 | +bool cudaCoreGemmDispatcher(Params const& params, cudaStream_t stream) |
| 244 | +{ |
| 245 | + bool dispatched = true; |
| 246 | + if (params.n % 2 != 0) |
| 247 | + { |
| 248 | + dispatched = false; |
| 249 | + } |
| 250 | + else if (params.inputType == CUDA_R_8U) |
| 251 | + { |
| 252 | + if (params.k % 16 != 0) |
| 253 | + { |
| 254 | + // Expect k % 16 == 0 for nvfp4 scaling granularity |
| 255 | + dispatched = false; |
| 256 | + } |
| 257 | + else if (params.outputType == CUDA_R_16F) |
| 258 | + { |
| 259 | + dispatched = cudaCoreGemmLauncher<__nv_fp4_e2m1, half, __nv_fp8_e4m3>(params, stream); |
| 260 | + } |
| 261 | + else if (params.outputType == CUDA_R_16BF) |
| 262 | + { |
| 263 | + dispatched = cudaCoreGemmLauncher<__nv_fp4_e2m1, __nv_bfloat16, __nv_fp8_e4m3>(params, stream); |
| 264 | + } |
| 265 | + else if (params.outputType == CUDA_R_32F) |
| 266 | + { |
| 267 | + dispatched = cudaCoreGemmLauncher<__nv_fp4_e2m1, float, __nv_fp8_e4m3>(params, stream); |
| 268 | + } |
| 269 | + else |
| 270 | + { |
| 271 | + dispatched = false; |
| 272 | + } |
| 273 | + } |
| 274 | + else |
| 275 | + { |
| 276 | + dispatched = false; |
| 277 | + } |
| 278 | + |
| 279 | + if (!dispatched) |
| 280 | + { |
| 281 | + TLLM_LOG_WARNING( |
| 282 | + "tensorrt_llm::kernels::cuda_core_gemm_nvfp4::cudaCoreGemmDispatcher [NOT DISPATCHED], inputType=%d, " |
| 283 | + "outputType=%d, " |
| 284 | + "m=%d, " |
| 285 | + "n=%d, k=%d", |
| 286 | + params.inputType, params.outputType, params.m, params.n, params.k); |
| 287 | + } |
| 288 | + return dispatched; |
| 289 | +} |
| 290 | + |
| 291 | +} // namespace cuda_core_gemm_nvfp4 |
| 292 | +} // namespace kernels |
| 293 | +} // namespace tensorrt_llm |
0 commit comments