@@ -1255,6 +1255,7 @@ def build_plans_cudnn_fp4_gemm_graph(
12551255 device ,
12561256 alpha ,
12571257 use_nvfp4 ,
1258+ tactic : int = - 1 ,
12581259):
12591260 graph = create_cudnn_execution_plans_fp4_gemm (
12601261 a_shape ,
@@ -1274,7 +1275,10 @@ def build_plans_cudnn_fp4_gemm_graph(
12741275 )
12751276
12761277 graph .check_support ()
1277- graph .build_plans ()
1278+ if tactic != - 1 :
1279+ graph .build_plan_at_index (tactic )
1280+ else :
1281+ graph .build_plans ()
12781282 return graph
12791283
12801284
@@ -1287,6 +1291,7 @@ def execute_cudnn_gemm_fp4_graph(
12871291 alpha ,
12881292 c_final ,
12891293 workspace_buffer ,
1294+ tactic : int = - 1 ,
12901295):
12911296 variant_pack = {
12921297 UIDs .A_UID .value : a .view (get_native_fp4_dtype ()),
@@ -1306,7 +1311,12 @@ def execute_cudnn_gemm_fp4_graph(
13061311
13071312 stream = torch .cuda .current_stream (a .device )
13081313
1309- graph .execute (variant_pack , workspace_buffer , handle = _get_cudnn_handle (stream ))
1314+ if tactic == - 1 :
1315+ graph .execute (variant_pack , workspace_buffer , handle = _get_cudnn_handle (stream ))
1316+ else :
1317+ graph .execute_plan_at_index (
1318+ variant_pack , workspace_buffer , tactic , handle = _get_cudnn_handle (stream )
1319+ )
13101320
13111321
13121322@functools .cache
@@ -1651,6 +1661,7 @@ def _cudnn_gemm_fp4(
16511661 block_size : int = 16 ,
16521662 use_nvfp4 : bool = True ,
16531663 workspace_buffer : torch .Tensor = None ,
1664+ tactic : int = - 1 ,
16541665):
16551666 _check_cudnn_availability ()
16561667 # the fp4 cudnn graph will be shared for both mm and bmm, so
@@ -1682,11 +1693,12 @@ def _cudnn_gemm_fp4(
16821693 a .device ,
16831694 alpha is not None ,
16841695 use_nvfp4 ,
1696+ tactic = tactic ,
16851697 )
16861698
16871699 # execute the fp4 cudnn graph
16881700 execute_cudnn_gemm_fp4_graph (
1689- graph , a , b , a_descale , b_descale , alpha , out , workspace_buffer
1701+ graph , a , b , a_descale , b_descale , alpha , out , workspace_buffer , tactic = tactic
16901702 )
16911703
16921704
@@ -1698,7 +1710,48 @@ def get_valid_tactics(
16981710 profile : OptimizationProfile ,
16991711 ) -> List [int ]:
17001712 # cudnn has heuristic for fp4 gemm, so we only need to use the default tactic
1701- return [0 ]
1713+ _check_cudnn_availability ()
1714+ (
1715+ a ,
1716+ b ,
1717+ a_descale ,
1718+ b_descale ,
1719+ alpha ,
1720+ out_dtype ,
1721+ out ,
1722+ block_size ,
1723+ use_nvfp4 ,
1724+ workspace_buffer ,
1725+ ) = inputs
1726+
1727+ real_a_shape , real_a_stride = _get_real_fp4_shape_from_packed_uint8 (a )
1728+ real_b_shape , real_b_stride = _get_real_fp4_shape_from_packed_uint8 (b )
1729+ batch = real_a_shape [0 ]
1730+ expanded_a_descale_shape , expanded_a_descale_stride = (
1731+ _expand_block_scale_tensor_shape (a_descale , batch )
1732+ )
1733+ expanded_b_descale_shape , expanded_b_descale_stride = (
1734+ _expand_block_scale_tensor_shape (b_descale , batch )
1735+ )
1736+
1737+ graph = build_plans_cudnn_fp4_gemm_graph (
1738+ real_a_shape ,
1739+ real_a_stride ,
1740+ real_b_shape ,
1741+ real_b_stride ,
1742+ expanded_a_descale_shape ,
1743+ expanded_a_descale_stride ,
1744+ expanded_b_descale_shape ,
1745+ expanded_b_descale_stride ,
1746+ cudnn .data_type .FP4_E2M1 ,
1747+ _torch_data_type_to_cudnn_data_type (out_dtype ),
1748+ block_size ,
1749+ a .device ,
1750+ alpha is not None ,
1751+ use_nvfp4 ,
1752+ )
1753+ num_plans = graph .get_execution_plan_count ()
1754+ return list (range (num_plans ))
17021755
17031756 def forward (
17041757 self ,
@@ -1730,6 +1783,7 @@ def forward(
17301783 block_size ,
17311784 use_nvfp4 ,
17321785 workspace_buffer ,
1786+ tactic = tactic ,
17331787 )
17341788
17351789 return CudnnFp4GemmRunner ()
0 commit comments