 #include "tensorrt_llm/common/assert.h"
 #include "tensorrt_llm/common/cublasVersionCheck.h"
 #include <algorithm>
+#include <unordered_map>
 
 #ifndef CUDART_VERSION
 #error CUDART_VERSION Undefined!
@@ -63,6 +64,16 @@ void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperatio
         mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)));
     check_cuda_error(
         cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAcc, sizeof(int8_t)));
+
+#ifdef ENABLE_CUBLASLT_FP4_GEMM
+    // Set pointer mode for FP4 GEMM
+    if (mAType == CUDA_R_4F_E2M1)
+    {
+        cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
+        check_cuda_error(cublasLtMatmulDescSetAttribute(
+            mOperationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode)));
+    }
+#endif
 }
 
 void CublasMMWrapper::setScaleDescriptors(void* scale_a, void* scale_b)
@@ -71,6 +82,39 @@ void CublasMMWrapper::setScaleDescriptors(void* scale_a, void* scale_b)
         cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scale_a, sizeof(void*)));
     check_cuda_error(
         cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, sizeof(void*)));
+
+    // Set scaling modes for FP4 GEMM
+    if (mAType == CUDA_R_4F_E2M1)
+    {
+        // Set scaling mode - cuBLASLt requires E4M3-format scaling factors
+        cublasLtMatmulMatrixScale_t AScaleMode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
+        cublasLtMatmulMatrixScale_t BScaleMode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
+        cublasLtMatmulMatrixScale_t CScaleMode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
+        cublasLtMatmulMatrixScale_t DScaleMode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
+        cublasLtMatmulMatrixScale_t DOutScaleMode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
+
+        check_cuda_error(cublasLtMatmulDescSetAttribute(
+            mOperationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &AScaleMode, sizeof(AScaleMode)));
+        check_cuda_error(cublasLtMatmulDescSetAttribute(
+            mOperationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &BScaleMode, sizeof(BScaleMode)));
+        check_cuda_error(cublasLtMatmulDescSetAttribute(
+            mOperationDesc, CUBLASLT_MATMUL_DESC_C_SCALE_MODE, &CScaleMode, sizeof(CScaleMode)));
+        check_cuda_error(cublasLtMatmulDescSetAttribute(
+            mOperationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &DScaleMode, sizeof(DScaleMode)));
+        check_cuda_error(cublasLtMatmulDescSetAttribute(
+            mOperationDesc, CUBLASLT_MATMUL_DESC_D_OUT_SCALE_MODE, &DOutScaleMode, sizeof(DOutScaleMode)));
+
+        // Set C/D matrix scale pointers to nullptr
+        void const* c_scale_ptr = nullptr;
+        void const* d_scale_ptr = nullptr;
+        void const* d_out_scale_ptr = nullptr;
+        check_cuda_error(cublasLtMatmulDescSetAttribute(
+            mOperationDesc, CUBLASLT_MATMUL_DESC_C_SCALE_POINTER, &c_scale_ptr, sizeof(c_scale_ptr)));
+        check_cuda_error(cublasLtMatmulDescSetAttribute(
+            mOperationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &d_scale_ptr, sizeof(d_scale_ptr)));
+        check_cuda_error(cublasLtMatmulDescSetAttribute(
+            mOperationDesc, CUBLASLT_MATMUL_DESC_D_OUT_SCALE_POINTER, &d_out_scale_ptr, sizeof(d_out_scale_ptr)));
+    }
 }
 
 void CublasMMWrapper::setBiasDescriptor(void* bias)
@@ -247,14 +291,27 @@ void CublasMMWrapper::setFP8GemmConfig(cudaDataType_t outputType)
 }
 #endif
 
+#ifdef ENABLE_CUBLASLT_FP4_GEMM
+void CublasMMWrapper::setFP4GemmConfig(cudaDataType_t outputType)
+{
+    setGemmConfig(CUDA_R_4F_E2M1, CUDA_R_4F_E2M1, outputType, CUDA_R_32F);
+}
+#endif
+
 void CublasMMWrapper::setGemmConfig(
     cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType)
 {
     mAType = aType;
     mBType = bType;
     mCType = cType;
     bool isFp16ComputeType = computeType == CUDA_R_16F;
-    if (isFp16ComputeType)
+    if (mAType == CUDA_R_4F_E2M1)
+    {
+        // for cuBLASLt NVFP4 GEMM, FP32 compute type and FP32 scale type are required
+        mComputeType = CUBLAS_COMPUTE_32F;
+        mScaleType = CUDA_R_32F;
+    }
+    else if (isFp16ComputeType)
     {
         mComputeType = CUBLAS_COMPUTE_16F;
         mScaleType = CUDA_R_16F;
@@ -481,6 +538,127 @@ std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasL
 #endif
 }
 
+#ifdef ENABLE_CUBLASLT_FP4_GEMM
+
+namespace
+{
+// Helper function: Get or create a zero beta tensor on GPU for the given device
+// Beta is always 0 for FP4 GEMM and is allocated once per device per thread
+float const* getBetaDevicePointer()
+{
+    thread_local static std::unordered_map<int, float*> beta_per_device;
+
+    int current_device;
+    cudaGetDevice(&current_device);
+
+    auto it = beta_per_device.find(current_device);
+    if (it == beta_per_device.end())
+    {
+        // Allocate GPU memory for beta and initialize to 0
+        float* d_beta;
+        cudaMalloc(&d_beta, sizeof(float));
+        cudaMemset(d_beta, 0, sizeof(float));
+        beta_per_device[current_device] = d_beta;
+        return d_beta;
+    }
+
+    return it->second;
+}
+} // namespace
+
+// BlockScaleGemm Version 1: Default algorithm (uses first valid heuristic)
+void CublasMMWrapper::BlockScaleGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
+    int const k, void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, void const* a_sf,
+    void const* b_sf, float const* alpha)
+{
+    // Forward to the overloaded version with nullptr (use default algorithm)
+    BlockScaleGemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, a_sf, b_sf, alpha, nullptr);
+}
+
+// BlockScaleGemm Version 2: Specified algorithm (unified implementation)
+void CublasMMWrapper::BlockScaleGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
+    int const k, void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, void const* a_sf,
+    void const* b_sf, float const* alpha, cublasLtMatmulAlgo_t const* algo)
+{
+    // Verify input data types (currently supports FP4, can be extended to more formats in the future)
+    TLLM_CHECK_WITH_INFO(mAType == CUDA_R_4F_E2M1 && mBType == CUDA_R_4F_E2M1,
+        "BlockScaleGemm currently requires FP4 input types. "
+        "Future versions may support other quantized formats with block-wise scaling.");
+
+    // Validate input pointers
+    TLLM_CHECK_WITH_INFO(A != nullptr, "A pointer is null");
+    TLLM_CHECK_WITH_INFO(B != nullptr, "B pointer is null");
+    TLLM_CHECK_WITH_INFO(C != nullptr, "C pointer is null");
+    TLLM_CHECK_WITH_INFO(a_sf != nullptr, "a_sf (A scale factor) pointer is null");
+    TLLM_CHECK_WITH_INFO(b_sf != nullptr, "b_sf (B scale factor) pointer is null");
+    TLLM_CHECK_WITH_INFO(alpha != nullptr, "alpha pointer is null");
+
+    // Beta is always 0 for FP4 GEMM, get per-device GPU pointer
+    float const* beta = getBetaDevicePointer();
+
+    // Create descriptors for block-scaled GEMM
+    createDescriptors(transa, transb, m, n, k, lda, ldb, ldc, 0);
+
+    // Create D descriptor for output matrix
+    cublasLtMatrixLayout_t Ddesc = NULL;
+    check_cuda_error(cublasLtMatrixLayoutCreate(&Ddesc, mCType, m, n, ldc));
+
+    // Set block-wise scaling descriptors
+    setScaleDescriptors(const_cast<void*>(a_sf), const_cast<void*>(b_sf));
+
+    // Validate cuBLASLt handle
+    TLLM_CHECK_WITH_INFO(mCublasLtHandle != nullptr, "cuBLASLt handle is null");
+
+    // Determine which algorithm to use
+    cublasLtMatmulAlgo_t const* selected_algo = algo;
+    cublasLtMatmulAlgo_t default_algo;
+
+    if (algo == nullptr)
+    {
+        // No algorithm specified, use heuristic (default behavior)
+        auto heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, Ddesc);
+
+        if (heuristics.empty())
+        {
+            if (Ddesc)
+                cublasLtMatrixLayoutDestroy(Ddesc);
+            destroyDescriptors();
+            throw std::runtime_error("No suitable cuBLASLt algorithm found for block-scaled GEMM");
+        }
+
+        // Use the first valid heuristic
+        auto const& heuristic = heuristics[0];
+        bool hasAlgo = heuristic.state == CUBLAS_STATUS_SUCCESS && heuristic.workspaceSize <= CUBLAS_WORKSPACE_SIZE;
+
+        if (hasAlgo)
+        {
+            default_algo = heuristic.algo;
+            selected_algo = &default_algo;
+        }
+        else
+        {
+            selected_algo = nullptr; // No valid algorithm, let cuBLASLt choose
+        }
+    }
+
+    int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
+
+    // Call cuBLASLt matmul with selected or default algorithm
+    check_cuda_error(cublasLtMatmul(getCublasLtHandle(), mOperationDesc, alpha, A, mADesc, B, mBDesc, beta, C, mCDesc,
+        C, Ddesc, selected_algo, // nullptr or specific algorithm
+        mCublasWorkspace, workspaceSize, mStream));
+
+    // Synchronize stream
+    sync_check_cuda_error(mStream);
+
+    // Clean up descriptors
+    if (Ddesc)
+        cublasLtMatrixLayoutDestroy(Ddesc);
+    destroyDescriptors();
+}
+
+#endif
+
 } // namespace common
 
 } // namespace tensorrt_llm
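For reference, a minimal sketch of how the new FP4 path might be driven from a call site, assuming the wrapper header lives at tensorrt_llm/common/cublasMMWrapper.h, that the build defines ENABLE_CUBLASLT_FP4_GEMM, and that all buffers (packed E2M1 operands, UE4M3 block scales, and the alpha scalar) are already resident on the GPU. The function name, buffer names, transpose flags, leading dimensions, and the BF16 output type below are illustrative assumptions, not part of this change:

#include "tensorrt_llm/common/cublasMMWrapper.h" // assumed header path

#include <cublasLt.h>
#include <cuda_runtime_api.h>

using tensorrt_llm::common::CublasMMWrapper;

// Hypothetical call site: all device buffers are prepared by the caller.
void runFp4BlockScaleGemm(CublasMMWrapper& wrapper, int m, int n, int k,
    void const* dA,       // packed FP4 (E2M1) operand A on device
    void const* dB,       // packed FP4 (E2M1) operand B on device
    void* dC,             // BF16 output on device (illustrative output type)
    void const* dAScales, // UE4M3 block scales for A (one per 16 elements)
    void const* dBScales, // UE4M3 block scales for B (one per 16 elements)
    float const* dAlpha)  // device pointer: the FP4 descriptor uses CUBLASLT_POINTER_MODE_DEVICE
{
#ifdef ENABLE_CUBLASLT_FP4_GEMM
    // FP32 compute and scale types are forced internally by setGemmConfig for E2M1 inputs.
    wrapper.setFP4GemmConfig(CUDA_R_16BF);

    // Transpose flags and leading dimensions are placeholders; they must match how A and B
    // were packed upstream. Beta is fixed to 0 inside BlockScaleGemm via getBetaDevicePointer().
    wrapper.BlockScaleGemm(CUBLAS_OP_T, CUBLAS_OP_N, m, n, k,
        dA, k, dB, k, dC, m, dAScales, dBScales, dAlpha);
#endif
}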