
Commit 15c293a

[None][feat] Enable nvfp4 cuda core for sm120 (#8620)
Signed-off-by: Cheng Hang <[email protected]>
1 parent bc26f4c commit 15c293a

File tree

7 files changed: +562, -12 lines changed

Lines changed: 293 additions & 0 deletions
@@ -0,0 +1,293 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cutlass/numeric_conversion.h"
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/cudaCoreGemmNVFP4.h"
#include <cub/cub.cuh>

namespace tensorrt_llm
{
namespace kernels
{
namespace cuda_core_gemm_nvfp4
{
template <typename InputType, typename OutputType, typename ScaleType, SizeType32 TILE_M, SizeType32 TILE_N,
    SizeType32 BLOCK_SIZE>
__device__ void cudaCoreGemmImpl(InputType const* __restrict__ act, InputType const* __restrict__ weight,
    ScaleType const* __restrict__ scale_a, ScaleType const* __restrict__ scale_w, float const alpha,
    OutputType* __restrict__ output, SizeType32 m, SizeType32 n, SizeType32 k)
{
    using VecType = int4;

    using ScaleVecType = __nv_fp8x2_e4m3;
    using CvtInputType = typename tensorrt_llm::kernels::cutlass_kernels::TllmToCutlassTypeAdapter<InputType>::type;
    static constexpr SizeType32 step_k = static_cast<SizeType32>(128 / cutlass::sizeof_bits<CvtInputType>::value);
    static constexpr SizeType32 nvfp4_scale_granularity = 16;
    static constexpr SizeType32 step_k_scale = step_k / nvfp4_scale_granularity;
    static constexpr SizeType32 tile_k = step_k * BLOCK_SIZE;
    auto tile_id_m = static_cast<SizeType32>(blockIdx.x * TILE_M);
    auto tile_id_n = static_cast<SizeType32>(blockIdx.y * TILE_N);
    auto tid = static_cast<SizeType32>(threadIdx.x);
    float tile_a[step_k];
    float tile_w[TILE_N * step_k];
    float tile_a_scale[step_k_scale];
    float tile_w_scale[TILE_N * step_k_scale];
    float acc[TILE_M * TILE_N];

    static_assert(step_k % 4 == 0);
    using Converter = cutlass::NumericArrayConverter<float, CvtInputType, 8>;
    using CvtSrcType = typename Converter::source_type;
    using CvtResType = typename Converter::result_type;
    static constexpr SizeType32 k_cvt_count = static_cast<SizeType32>(sizeof(VecType) / sizeof(CvtSrcType));

#pragma unroll
    for (SizeType32 i = 0; i < TILE_M * TILE_N; ++i)
    {
        acc[i] = 0;
    }
    act += tile_id_m * k / 2;
    weight += tile_id_n * k / 2;
    output += tile_id_m * n + tile_id_n;

    scale_a += tile_id_m * k / nvfp4_scale_granularity;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaGridDependencySynchronize();
#endif

    int const num_cols_sf = k / nvfp4_scale_granularity;
    int const num_sf_tiles_k = (num_cols_sf + 4 - 1) / 4;
    for (SizeType32 idx_k = tid * step_k; idx_k < k; idx_k += tile_k)
    {
        for (SizeType32 j = 0; j < TILE_N; ++j)
        {
            auto tile_w_quantized = reinterpret_cast<VecType const*>(weight + (j * k + idx_k) / 2)[0];
#pragma unroll
            for (SizeType32 cvt_idx = 0; cvt_idx < k_cvt_count; ++cvt_idx)
            {
                reinterpret_cast<CvtResType*>(tile_w)[j * k_cvt_count + cvt_idx]
                    = Converter::convert(reinterpret_cast<CvtSrcType*>(&tile_w_quantized)[cvt_idx]);
            }
        }
        for (SizeType32 j = 0; j < TILE_N; ++j)
        {
            int const row_idx = tile_id_n + j;
            int const col_idx = idx_k / nvfp4_scale_granularity;
            int const tile_offset = ((row_idx / 128) * num_sf_tiles_k + col_idx / 4) * 512;
            int const dst_idx = tile_offset + (row_idx % 32) * 16 + ((row_idx % 128) / 32) * 4 + col_idx % 4;
            auto tile_w_scale_fp8x2 = reinterpret_cast<ScaleVecType const*>(scale_w + dst_idx)[0];
            const char2 tmp = reinterpret_cast<char2 const&>(tile_w_scale_fp8x2);
            tile_w_scale[j * step_k_scale + 0] = static_cast<float>(reinterpret_cast<__nv_fp8_e4m3 const&>(tmp.x));
            tile_w_scale[j * step_k_scale + 1] = static_cast<float>(reinterpret_cast<__nv_fp8_e4m3 const&>(tmp.y));
        }
#pragma unroll
        for (SizeType32 i = 0; i < TILE_M; ++i)
        {
            auto tile_a_quantized = reinterpret_cast<VecType const*>(act + (i * k + idx_k) / 2)[0];
#pragma unroll
            for (SizeType32 cvt_idx = 0; cvt_idx < k_cvt_count; ++cvt_idx)
            {
                reinterpret_cast<CvtResType*>(tile_a)[cvt_idx]
                    = Converter::convert(reinterpret_cast<CvtSrcType*>(&tile_a_quantized)[cvt_idx]);
            }
            auto tile_a_scale_fp8x2
                = reinterpret_cast<ScaleVecType const*>(scale_a + (i * k + idx_k) / nvfp4_scale_granularity)[0];
            const char2 tmp = reinterpret_cast<char2 const&>(tile_a_scale_fp8x2);
            tile_a_scale[0] = static_cast<float>(reinterpret_cast<__nv_fp8_e4m3 const&>(tmp.x));
            tile_a_scale[1] = static_cast<float>(reinterpret_cast<__nv_fp8_e4m3 const&>(tmp.y));
#pragma unroll
            for (SizeType32 j = 0; j < TILE_N; ++j)
            {
#pragma unroll
                for (SizeType32 l = 0; l < step_k; ++l)
                {
                    acc[i * TILE_N + j] = fma(alpha * tile_a[l] * tile_a_scale[l / nvfp4_scale_granularity],
                        tile_w[j * step_k + l] * tile_w_scale[j * step_k_scale + l / nvfp4_scale_granularity],
                        acc[i * TILE_N + j]);
                }
            }
        }
    }

    typedef cub::WarpReduce<float> WarpReduce;

    static constexpr SizeType32 warp_size = 32;
    static constexpr SizeType32 warp_num = BLOCK_SIZE / warp_size;
    SizeType32 warp_id = tid / warp_size, lane_id = tid % warp_size;
    __shared__ float shmem[TILE_M * TILE_N * warp_num];
    __shared__ typename WarpReduce::TempStorage temp_storage[warp_num];
#pragma unroll
    for (SizeType32 mi = 0; mi < TILE_M; ++mi)
    {
#pragma unroll
        for (SizeType32 ni = 0; ni < TILE_N; ++ni)
        {
            float val = WarpReduce(temp_storage[warp_id]).Sum(acc[mi * TILE_N + ni]);
            if (lane_id == 0)
            {
                shmem[mi * TILE_N + ni + warp_id * TILE_M * TILE_N] = val;
            }
        }
    }
    __syncthreads();
    for (SizeType32 ii = tid; ii < TILE_M * TILE_N; ii += BLOCK_SIZE)
    {
        SizeType32 mid = ii / TILE_N, nid = ii % TILE_N;
        float val = 0;
#pragma unroll
        for (SizeType32 jj = 0; jj < warp_num; ++jj)
        {
            val += shmem[jj * TILE_M * TILE_N + ii];
        }
        output[mid * n + nid] = static_cast<OutputType>(val);
    }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}

template <typename InputType, typename OutputType, typename ScaleType, SizeType32 TILE_M, SizeType32 TILE_N,
    SizeType32 BLOCK_SIZE>
__global__ void cudaCoreGemmFp4(InputType const* __restrict__ act, InputType const* __restrict__ weight,
    ScaleType const* __restrict__ scale_a, ScaleType const* __restrict__ scale_w, float const* alpha_ptr,
    OutputType* __restrict__ output, SizeType32 m, SizeType32 n, SizeType32 k)
{
    float alpha = alpha_ptr[0];
    cudaCoreGemmImpl<InputType, OutputType, ScaleType, TILE_M, TILE_N, BLOCK_SIZE>(
        reinterpret_cast<InputType const*>(act), reinterpret_cast<InputType const*>(weight),
        reinterpret_cast<ScaleType const*>(scale_a), reinterpret_cast<ScaleType const*>(scale_w), alpha,
        reinterpret_cast<OutputType*>(output), m, n, k);
}

template <typename InputType, typename OutputType, typename ScaleType, SizeType32 TILE_M, SizeType32 TILE_N,
    SizeType32 BLOCK_SIZE>
void cudaCoreGemmKernel(Params const& params, cudaStream_t stream)
{
    dim3 block(BLOCK_SIZE);
    dim3 grid(params.m / TILE_M, params.n / TILE_N);

    if (tensorrt_llm::common::getEnvEnablePDL())
    {
        TLLM_LOG_DEBUG("Enable PDL in fp8_gemm_plugin");
        cudaLaunchConfig_t kernelConfig = {0};
        kernelConfig.gridDim = grid;
        kernelConfig.blockDim = block;
        kernelConfig.dynamicSmemBytes = 0;
        kernelConfig.stream = stream;

        cudaLaunchAttribute attribute[1];
        attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
        attribute[0].val.programmaticStreamSerializationAllowed = 1;
        kernelConfig.attrs = attribute;
        kernelConfig.numAttrs = 1;

        if (params.scale_a && params.scale_b && params.alpha_ptr)
        {
            TLLM_CUDA_CHECK(cudaLaunchKernelEx(&kernelConfig,
                cudaCoreGemmFp4<InputType, OutputType, ScaleType, TILE_M, TILE_N, BLOCK_SIZE>,
                reinterpret_cast<InputType const*>(params.act), reinterpret_cast<InputType const*>(params.weight),
                reinterpret_cast<ScaleType const*>(params.scale_a), reinterpret_cast<ScaleType const*>(params.scale_b),
                params.alpha_ptr, reinterpret_cast<OutputType*>(params.output), params.m, params.n, params.k));
        }
    }
    else
    {
        if (params.scale_a && params.scale_b && params.alpha_ptr)
        {
            cudaCoreGemmFp4<InputType, OutputType, ScaleType, TILE_M, TILE_N, BLOCK_SIZE><<<grid, block, 0, stream>>>(
                reinterpret_cast<InputType const*>(params.act), reinterpret_cast<InputType const*>(params.weight),
                reinterpret_cast<ScaleType const*>(params.scale_a), reinterpret_cast<ScaleType const*>(params.scale_b),
                params.alpha_ptr, reinterpret_cast<OutputType*>(params.output), params.m, params.n, params.k);
        }
    }
}

template <typename InputType, typename OutputType, typename ScaleType, int TILE_M, int TILE_N, int BLOCK_SIZE>
bool cudaCoreGemmTemplateCaller(Params const& params, cudaStream_t stream)
{
    constexpr int cudaCoreGemmTemplateMaxM = 16;
    if (params.m == TILE_M)
    {
        cudaCoreGemmKernel<InputType, OutputType, ScaleType, TILE_M, TILE_N, BLOCK_SIZE>(params, stream);
        return true;
    }
    if constexpr (TILE_M < cudaCoreGemmTemplateMaxM)
    {
        return cudaCoreGemmTemplateCaller<InputType, OutputType, ScaleType, TILE_M + 1, TILE_N, BLOCK_SIZE>(
            params, stream);
    }
    return false;
}

template <typename InputType, typename OutputType, typename ScaleType = float>
bool cudaCoreGemmLauncher(Params const& params, cudaStream_t stream)
{
    return cudaCoreGemmTemplateCaller<InputType, OutputType, ScaleType, 1, 2, 128>(params, stream);
}

bool cudaCoreGemmDispatcher(Params const& params, cudaStream_t stream)
{
    bool dispatched = true;
    if (params.n % 2 != 0)
    {
        dispatched = false;
    }
    else if (params.inputType == CUDA_R_8U)
    {
        if (params.k % 16 != 0)
        {
            // Expect k % 16 == 0 for nvfp4 scaling granularity
            dispatched = false;
        }
        else if (params.outputType == CUDA_R_16F)
        {
            dispatched = cudaCoreGemmLauncher<__nv_fp4_e2m1, half, __nv_fp8_e4m3>(params, stream);
        }
        else if (params.outputType == CUDA_R_16BF)
        {
            dispatched = cudaCoreGemmLauncher<__nv_fp4_e2m1, __nv_bfloat16, __nv_fp8_e4m3>(params, stream);
        }
        else if (params.outputType == CUDA_R_32F)
        {
            dispatched = cudaCoreGemmLauncher<__nv_fp4_e2m1, float, __nv_fp8_e4m3>(params, stream);
        }
        else
        {
            dispatched = false;
        }
    }
    else
    {
        dispatched = false;
    }

    if (!dispatched)
    {
        TLLM_LOG_WARNING(
            "tensorrt_llm::kernels::cuda_core_gemm_nvfp4::cudaCoreGemmDispatcher [NOT DISPATCHED], inputType=%d, "
            "outputType=%d, m=%d, n=%d, k=%d",
            params.inputType, params.outputType, params.m, params.n, params.k);
    }
    return dispatched;
}

} // namespace cuda_core_gemm_nvfp4
} // namespace kernels
} // namespace tensorrt_llm
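
Note on the weight scale-factor lookup in cudaCoreGemmImpl above: the dst_idx computation corresponds to a swizzled block-scale layout built from 512-entry tiles (128 rows by 4 scale columns, rows interleaved in groups of 32). The sketch below only restates that index math as a standalone host-side helper for tracing or verification; the helper name and signature are ours, not part of the commit.

// Illustrative reproduction of the scale-factor index math from cudaCoreGemmImpl.
// sfIndex() and its parameter names are hypothetical; the formula is copied verbatim
// from the kernel's tile_offset / dst_idx computation.
#include <cstdint>

inline int64_t sfIndex(int rowIdx, int colIdx, int numColsSf)
{
    // Number of 4-wide scale-column tiles along K.
    int const numSfTilesK = (numColsSf + 4 - 1) / 4;
    // Each 128-row x 4-column tile holds 512 FP8 scale entries.
    int const tileOffset = ((rowIdx / 128) * numSfTilesK + colIdx / 4) * 512;
    // Within a tile: 32-row interleave, then the 4 row groups, then the column.
    return tileOffset + (rowIdx % 32) * 16 + ((rowIdx % 128) / 32) * 4 + colIdx % 4;
}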
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/runtime/common.h"

#include <NvInferRuntime.h>

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <iostream>

namespace tensorrt_llm
{
namespace kernels
{
namespace cuda_core_gemm_nvfp4
{
using SizeType32 = tensorrt_llm::runtime::SizeType32;

struct Params
{
    void const* act;
    void const* weight;
    void* output;
    SizeType32 m, n, k;
    cudaDataType_t inputType;
    cudaDataType_t outputType;
    // torch flow
    __nv_fp8_e4m3 const* scale_a;
    __nv_fp8_e4m3 const* scale_b;
    float const* alpha_ptr;

    // used by torch flow
    Params(void const* _act, void const* _weight, void* _output, SizeType32 _m, SizeType32 _n, SizeType32 _k,
        __nv_fp8_e4m3 const* _scale_a, __nv_fp8_e4m3 const* _scale_b, cudaDataType_t _inputType,
        cudaDataType_t _outputType, float const* _alpha_ptr)
        : act(_act)
        , weight(_weight)
        , output(_output)
        , m(_m)
        , n(_n)
        , k(_k)
        , inputType(_inputType)
        , outputType(_outputType)
        , scale_a(_scale_a)
        , scale_b(_scale_b)
        , alpha_ptr(_alpha_ptr)
    {
    }
};

bool cudaCoreGemmDispatcher(Params const& params, cudaStream_t stream);
} // namespace cuda_core_gemm_nvfp4
} // namespace kernels
} // namespace tensorrt_llm
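
For orientation, here is a minimal sketch of how a caller might wire up the Params struct and dispatcher declared above. The pointer names and the wrapper function are placeholders, and the constraints in the comments simply mirror the checks in cudaCoreGemmDispatcher (n even, k a multiple of 16, FP4 input reported as CUDA_R_8U, FP8 block scales, FP16/BF16/FP32 output); this is not a documented API contract from the commit.

// Sketch only: placeholder pointers and wrapper; assumes packed FP4 act/weight
// buffers and swizzled FP8 block scales laid out as the kernel expects.
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/cudaCoreGemmNVFP4.h"

void runNvfp4CudaCoreGemm(void const* actFp4, void const* weightFp4, void* outHalf,
    __nv_fp8_e4m3 const* scaleA, __nv_fp8_e4m3 const* scaleB, float const* alphaDevice,
    int m, int n, int k, cudaStream_t stream)
{
    using tensorrt_llm::kernels::cuda_core_gemm_nvfp4::Params;
    using tensorrt_llm::kernels::cuda_core_gemm_nvfp4::cudaCoreGemmDispatcher;

    // Dispatcher requires n % 2 == 0 and k % 16 == 0 (nvfp4 scale granularity).
    Params params(actFp4, weightFp4, outHalf, m, n, k, scaleA, scaleB,
        /*inputType=*/CUDA_R_8U, /*outputType=*/CUDA_R_16F, alphaDevice);
    bool const dispatched = cudaCoreGemmDispatcher(params, stream);
    // dispatched == false means this shape/type combination fell outside the cuda core path.
    (void) dispatched;
}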

cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ add_library(
   cutlassScaledMM.cpp
   cublasScaledMM.cpp
   cublasFp4ScaledMM.cpp
+  cudaNvfp4MM.cpp
   cudaScaledMM.cpp
   dynamicDecodeOp.cpp
   fmhaPackMaskOp.cpp
