Skip to content

Commit 928247a

Browse files
authored
[https://nvbugs/5451205][feat] Add cuBLASLt NVFP4 GEMM backend support (#7943)
Signed-off-by: Shijie Wang <[email protected]>
1 parent 04e2b27 commit 928247a

File tree

10 files changed

+860
-5
lines changed

10 files changed

+860
-5
lines changed

cpp/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,15 @@ option(USING_OSS_CUTLASS_LOW_LATENCY_GEMM
5454
"Using open sourced Cutlass low latency gemm kernel" ON)
5555
option(USING_OSS_CUTLASS_FP4_GEMM "Using open sourced Cutlass fp4 gemm kernel"
5656
ON)
57+
# cuBLASLt NVFP4 (block-scaled FP4) GEMM needs the cuBLASLt that ships with
# CUDA 12.8 or newer; force the feature off on older toolkits.
option(ENABLE_CUBLASLT_FP4_GEMM "Enable cuBLASLt FP4 GEMM support" ON)
if(ENABLE_CUBLASLT_FP4_GEMM AND ${CUDAToolkit_VERSION} VERSION_LESS "12.8")
  # FORCE so a stale ON value cached from a configure against a newer toolkit
  # cannot survive; keep the option's help string instead of clearing it.
  set(ENABLE_CUBLASLT_FP4_GEMM
      OFF
      CACHE BOOL "Enable cuBLASLt FP4 GEMM support" FORCE)
  message(
    STATUS
      "CUDA ${CUDAToolkit_VERSION} < 12.8: disabling ENABLE_CUBLASLT_FP4_GEMM")
endif()
5766
option(USING_OSS_CUTLASS_MOE_GEMM "Using open sourced Cutlass moe gemm kernel"
5867
ON)
5968
option(USING_OSS_CUTLASS_ALLREDUCE_GEMM

cpp/tensorrt_llm/common/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,7 @@ add_library(common_src OBJECT ${SRCS} ${CU_SRCS})
3636
add_cuda_architectures(common_src 89)
3737
set_property(TARGET common_src PROPERTY POSITION_INDEPENDENT_CODE ON)
3838
set_property(TARGET common_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
39+
40+
# Gate the cuBLASLt FP4 GEMM code paths in common_src behind the CMake option
# (the top-level cpp/CMakeLists.txt forces it OFF for CUDA < 12.8).
if(ENABLE_CUBLASLT_FP4_GEMM)
  target_compile_definitions(common_src PRIVATE ENABLE_CUBLASLT_FP4_GEMM)
endif()

cpp/tensorrt_llm/common/cublasMMWrapper.cpp

Lines changed: 179 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "tensorrt_llm/common/assert.h"
1919
#include "tensorrt_llm/common/cublasVersionCheck.h"
2020
#include <algorithm>
21+
#include <unordered_map>
2122

2223
#ifndef CUDART_VERSION
2324
#error CUDART_VERSION Undefined!
@@ -63,6 +64,16 @@ void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperatio
6364
mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)));
6465
check_cuda_error(
6566
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAcc, sizeof(int8_t)));
67+
68+
#ifdef ENABLE_CUBLASLT_FP4_GEMM
69+
// Set pointer mode for FP4 GEMM
70+
if (mAType == CUDA_R_4F_E2M1)
71+
{
72+
cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
73+
check_cuda_error(cublasLtMatmulDescSetAttribute(
74+
mOperationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode)));
75+
}
76+
#endif
6677
}
6778

6879
void CublasMMWrapper::setScaleDescriptors(void* scale_a, void* scale_b)
@@ -71,6 +82,39 @@ void CublasMMWrapper::setScaleDescriptors(void* scale_a, void* scale_b)
7182
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scale_a, sizeof(void*)));
7283
check_cuda_error(
7384
cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, sizeof(void*)));
85+
86+
// Set scaling modes for FP4 GEMM
87+
if (mAType == CUDA_R_4F_E2M1)
88+
{
89+
// Set scaling mode - cuBLASLt requires e4m3 format scaling factors
90+
cublasLtMatmulMatrixScale_t AScaleMode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
91+
cublasLtMatmulMatrixScale_t BScaleMode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
92+
cublasLtMatmulMatrixScale_t CScaleMode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
93+
cublasLtMatmulMatrixScale_t DScaleMode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
94+
cublasLtMatmulMatrixScale_t DOutScaleMode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
95+
96+
check_cuda_error(cublasLtMatmulDescSetAttribute(
97+
mOperationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &AScaleMode, sizeof(AScaleMode)));
98+
check_cuda_error(cublasLtMatmulDescSetAttribute(
99+
mOperationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &BScaleMode, sizeof(BScaleMode)));
100+
check_cuda_error(cublasLtMatmulDescSetAttribute(
101+
mOperationDesc, CUBLASLT_MATMUL_DESC_C_SCALE_MODE, &CScaleMode, sizeof(CScaleMode)));
102+
check_cuda_error(cublasLtMatmulDescSetAttribute(
103+
mOperationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &DScaleMode, sizeof(DScaleMode)));
104+
check_cuda_error(cublasLtMatmulDescSetAttribute(
105+
mOperationDesc, CUBLASLT_MATMUL_DESC_D_OUT_SCALE_MODE, &DOutScaleMode, sizeof(DOutScaleMode)));
106+
107+
// Set C/D matrix scale pointers to nullptr
108+
void const* c_scale_ptr = nullptr;
109+
void const* d_scale_ptr = nullptr;
110+
void const* d_out_scale_ptr = nullptr;
111+
check_cuda_error(cublasLtMatmulDescSetAttribute(
112+
mOperationDesc, CUBLASLT_MATMUL_DESC_C_SCALE_POINTER, &c_scale_ptr, sizeof(c_scale_ptr)));
113+
check_cuda_error(cublasLtMatmulDescSetAttribute(
114+
mOperationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &d_scale_ptr, sizeof(d_scale_ptr)));
115+
check_cuda_error(cublasLtMatmulDescSetAttribute(
116+
mOperationDesc, CUBLASLT_MATMUL_DESC_D_OUT_SCALE_POINTER, &d_out_scale_ptr, sizeof(d_out_scale_ptr)));
117+
}
74118
}
75119

76120
void CublasMMWrapper::setBiasDescriptor(void* bias)
@@ -247,14 +291,27 @@ void CublasMMWrapper::setFP8GemmConfig(cudaDataType_t outputType)
247291
}
248292
#endif
249293

294+
#ifdef ENABLE_CUBLASLT_FP4_GEMM
// Configures the wrapper for an NVFP4 GEMM: both inputs are packed FP4
// (CUDA_R_4F_E2M1) and the caller picks the output type. The compute type is
// fixed to FP32, which setGemmConfig requires for cuBLASLt NVFP4 GEMMs.
void CublasMMWrapper::setFP4GemmConfig(cudaDataType_t outputType)
{
    constexpr cudaDataType_t kFp4InputType = CUDA_R_4F_E2M1;
    setGemmConfig(kFp4InputType, kFp4InputType, outputType, /*computeType=*/CUDA_R_32F);
}
#endif
300+
250301
void CublasMMWrapper::setGemmConfig(
251302
cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType)
252303
{
253304
mAType = aType;
254305
mBType = bType;
255306
mCType = cType;
256307
bool isFp16ComputeType = computeType == CUDA_R_16F;
257-
if (isFp16ComputeType)
308+
if (mAType == CUDA_R_4F_E2M1)
309+
{
310+
// for cublaslt nvfp4 gemm, fp32 compute type and fp32 scale type are required
311+
mComputeType = CUBLAS_COMPUTE_32F;
312+
mScaleType = CUDA_R_32F;
313+
}
314+
else if (isFp16ComputeType)
258315
{
259316
mComputeType = CUBLAS_COMPUTE_16F;
260317
mScaleType = CUDA_R_16F;
@@ -481,6 +538,127 @@ std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasL
481538
#endif
482539
}
483540

541+
#ifdef ENABLE_CUBLASLT_FP4_GEMM

namespace
{
// Returns a device pointer to a float holding 0.0f, used as the beta scalar
// for cublasLtMatmul when the matmul descriptor is in device pointer mode
// (createDescriptors switches to CUBLASLT_POINTER_MODE_DEVICE for FP4).
//
// The allocation is made once per (thread, device) pair and intentionally
// never freed: it must outlive every GEMM issued by this thread, and freeing
// at thread exit could run after the CUDA context has been torn down.
//
// Fix over the original: CUDA API return codes are now checked; previously a
// failed cudaMalloc/cudaMemset would silently cache and return an invalid
// pointer for the rest of the process lifetime.
float const* getBetaDevicePointer()
{
    thread_local static std::unordered_map<int, float*> beta_per_device;

    int current_device = 0;
    check_cuda_error(cudaGetDevice(&current_device));

    auto it = beta_per_device.find(current_device);
    if (it != beta_per_device.end())
    {
        return it->second;
    }

    // First use on this device from this thread: allocate and zero-fill.
    float* d_beta = nullptr;
    check_cuda_error(cudaMalloc(&d_beta, sizeof(float)));
    check_cuda_error(cudaMemset(d_beta, 0, sizeof(float)));
    beta_per_device[current_device] = d_beta;
    return d_beta;
}
} // namespace
568+
569+
// BlockScaleGemm Version 1: Default algorithm (uses first valid heuristic)
570+
void CublasMMWrapper::BlockScaleGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
571+
int const k, void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, void const* a_sf,
572+
void const* b_sf, float const* alpha)
573+
{
574+
// Forward to the overloaded version with nullptr (use default algorithm)
575+
BlockScaleGemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, a_sf, b_sf, alpha, nullptr);
576+
}
577+
578+
// BlockScaleGemm Version 2: Specified algorithm (unified implementation)
579+
void CublasMMWrapper::BlockScaleGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
580+
int const k, void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, void const* a_sf,
581+
void const* b_sf, float const* alpha, cublasLtMatmulAlgo_t const* algo)
582+
{
583+
// Verify input data types (currently supports FP4, can be extended to more formats in the future)
584+
TLLM_CHECK_WITH_INFO(mAType == CUDA_R_4F_E2M1 && mBType == CUDA_R_4F_E2M1,
585+
"BlockScaleGemm currently requires FP4 input types. "
586+
"Future versions may support other quantized formats with block-wise scaling.");
587+
588+
// Validate input pointers
589+
TLLM_CHECK_WITH_INFO(A != nullptr, "A pointer is null");
590+
TLLM_CHECK_WITH_INFO(B != nullptr, "B pointer is null");
591+
TLLM_CHECK_WITH_INFO(C != nullptr, "C pointer is null");
592+
TLLM_CHECK_WITH_INFO(a_sf != nullptr, "a_sf (A scale factor) pointer is null");
593+
TLLM_CHECK_WITH_INFO(b_sf != nullptr, "b_sf (B scale factor) pointer is null");
594+
TLLM_CHECK_WITH_INFO(alpha != nullptr, "alpha pointer is null");
595+
596+
// Beta is always 0 for FP4 GEMM, get per-device GPU pointer
597+
float const* beta = getBetaDevicePointer();
598+
599+
// Create descriptors for block-scaled GEMM
600+
createDescriptors(transa, transb, m, n, k, lda, ldb, ldc, 0);
601+
602+
// Create D descriptor for output matrix
603+
cublasLtMatrixLayout_t Ddesc = NULL;
604+
check_cuda_error(cublasLtMatrixLayoutCreate(&Ddesc, mCType, m, n, ldc));
605+
606+
// Set block-wise scaling descriptors
607+
setScaleDescriptors(const_cast<void*>(a_sf), const_cast<void*>(b_sf));
608+
609+
// Validate cuBLASLt handle
610+
TLLM_CHECK_WITH_INFO(mCublasLtHandle != nullptr, "cuBLASLt handle is null");
611+
612+
// Determine which algorithm to use
613+
cublasLtMatmulAlgo_t const* selected_algo = algo;
614+
cublasLtMatmulAlgo_t default_algo;
615+
616+
if (algo == nullptr)
617+
{
618+
// No algorithm specified, use heuristic (default behavior)
619+
auto heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, Ddesc);
620+
621+
if (heuristics.empty())
622+
{
623+
if (Ddesc)
624+
cublasLtMatrixLayoutDestroy(Ddesc);
625+
destroyDescriptors();
626+
throw std::runtime_error("No suitable cuBLASLt algorithm found for block-scaled GEMM");
627+
}
628+
629+
// Use the first valid heuristic
630+
auto const& heuristic = heuristics[0];
631+
bool hasAlgo = heuristic.state == CUBLAS_STATUS_SUCCESS && heuristic.workspaceSize <= CUBLAS_WORKSPACE_SIZE;
632+
633+
if (hasAlgo)
634+
{
635+
default_algo = heuristic.algo;
636+
selected_algo = &default_algo;
637+
}
638+
else
639+
{
640+
selected_algo = nullptr; // No valid algorithm, let cuBLASLt choose
641+
}
642+
}
643+
644+
int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
645+
646+
// Call cuBLASLt matmul with selected or default algorithm
647+
check_cuda_error(cublasLtMatmul(getCublasLtHandle(), mOperationDesc, alpha, A, mADesc, B, mBDesc, beta, C, mCDesc,
648+
C, Ddesc, selected_algo, // nullptr or specific algorithm
649+
mCublasWorkspace, workspaceSize, mStream));
650+
651+
// Synchronize stream
652+
sync_check_cuda_error(mStream);
653+
654+
// Clean up descriptors
655+
if (Ddesc)
656+
cublasLtMatrixLayoutDestroy(Ddesc);
657+
destroyDescriptors();
658+
}
659+
660+
#endif
661+
484662
} // namespace common
485663

486664
} // namespace tensorrt_llm

cpp/tensorrt_llm/common/cublasMMWrapper.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,22 @@ class CublasMMWrapper
8383
int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
8484
cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt);
8585

86+
#ifdef ENABLE_CUBLASLT_FP4_GEMM
87+
/********************** Block-Scaled GEMMs **********************/
88+
// Generic block-scaled GEMM interface supporting FP4, FP8, and other quantized formats
89+
// that require per-block scaling factors (a_sf, b_sf)
90+
91+
// Uses default/heuristic algorithm
92+
void BlockScaleGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
93+
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, void const* a_sf,
94+
void const* b_sf, float const* alpha);
95+
96+
// Uses specified algorithm (for autotuning)
97+
void BlockScaleGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
98+
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, void const* a_sf,
99+
void const* b_sf, float const* alpha, cublasLtMatmulAlgo_t const* algo);
100+
#endif
101+
86102
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
87103
void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, const int64_t strideB,
88104
void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha = 1.0f,
@@ -120,6 +136,9 @@ class CublasMMWrapper
120136
#ifdef ENABLE_FP8
121137
void setFP8GemmConfig(cudaDataType_t outputType = CUDA_R_16F);
122138
#endif
139+
#ifdef ENABLE_CUBLASLT_FP4_GEMM
140+
void setFP4GemmConfig(cudaDataType_t outputType = CUDA_R_16BF);
141+
#endif
123142

124143
void setStream(cudaStream_t stream);
125144

@@ -142,6 +161,26 @@ class CublasMMWrapper
142161
{
143162
return *(this->mCublasLtHandle);
144163
}
164+
165+
    // Read-only accessors for the cuBLASLt operation/matrix descriptors held
    // by the wrapper, so external code can query tactics against the current
    // configuration. NOTE(review): these are likely null until
    // createDescriptors() has run — confirm before dereferencing.
    cublasLtMatmulDesc_t getOperationDesc() const
    {
        return mOperationDesc;
    }

    // Layout descriptor of matrix A.
    cublasLtMatrixLayout_t getADesc() const
    {
        return mADesc;
    }

    // Layout descriptor of matrix B.
    cublasLtMatrixLayout_t getBDesc() const
    {
        return mBDesc;
    }

    // Layout descriptor of matrix C.
    cublasLtMatrixLayout_t getCDesc() const
    {
        return mCDesc;
    }
145184
};
146185

147186
} // namespace common

cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ add_library(
4646
convertSpecDecodingMaskToPackedMaskOp.cpp
4747
cutlassScaledMM.cpp
4848
cublasScaledMM.cpp
49+
cublasFp4ScaledMM.cpp
4950
cudaScaledMM.cpp
5051
dynamicDecodeOp.cpp
5152
fmhaPackMaskOp.cpp
@@ -115,6 +116,11 @@ if(USING_OSS_CUTLASS_MOE_GEMM)
115116
target_compile_definitions(th_common PUBLIC USING_OSS_CUTLASS_MOE_GEMM)
116117
endif()
117118

119+
# Enable the cuBLASLt FP4 GEMM paths in th_common (PUBLIC so dependents see
# the definition) and link cuBLASLt, which those paths call directly.
# NOTE(review): assumes CUBLASLT_LIB is defined by the toolchain setup
# elsewhere in the build — not visible in this diff.
if(ENABLE_CUBLASLT_FP4_GEMM)
  target_compile_definitions(th_common PUBLIC ENABLE_CUBLASLT_FP4_GEMM)
  target_link_libraries(th_common PRIVATE ${CUBLASLT_LIB})
endif()
123+
118124
if(ENABLE_MULTI_DEVICE)
119125
target_include_directories(th_common PUBLIC ${MPI_C_INCLUDE_DIRS})
120126
target_link_libraries(th_common PRIVATE ${MPI_C_LIBRARIES} ${NCCL_LIB}

0 commit comments

Comments
 (0)