@@ -1267,6 +1267,7 @@ def build_plans_cudnn_fp4_gemm_graph(
     device,
     alpha,
     use_nvfp4,
+    tactic: int = -1,
 ):
     graph = create_cudnn_execution_plans_fp4_gemm(
         a_shape,
@@ -1286,7 +1287,10 @@ def build_plans_cudnn_fp4_gemm_graph(
     )
 
     graph.check_support()
-    graph.build_plans()
+    if tactic != -1:
+        graph.build_plan_at_index(tactic)
+    else:
+        graph.build_plans()
     return graph
 
 
@@ -1299,6 +1303,7 @@ def execute_cudnn_gemm_fp4_graph(
     alpha,
     c_final,
     workspace_buffer,
+    tactic: int = -1,
 ):
     variant_pack = {
         UIDs.A_UID.value: a.view(get_native_fp4_dtype()),
@@ -1318,7 +1323,12 @@ def execute_cudnn_gemm_fp4_graph(
 
     stream = torch.cuda.current_stream(a.device)
 
-    graph.execute(variant_pack, workspace_buffer, handle=_get_cudnn_handle(stream))
+    if tactic == -1:
+        graph.execute(variant_pack, workspace_buffer, handle=_get_cudnn_handle(stream))
+    else:
+        graph.execute_plan_at_index(
+            variant_pack, workspace_buffer, tactic, handle=_get_cudnn_handle(stream)
+        )
 
 
 @functools.cache
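The `tactic` plumbing above maps `-1` to the existing path (build all plans, let cuDNN's heuristic ordering pick), and any other value to a single-plan build/execute via `build_plan_at_index` / `execute_plan_at_index`. For reference, a minimal sketch of how one could time each plan on a graph that was fully built with `build_plans()`; the helper name `benchmark_plans` and the warmup/iteration counts are illustrative, not part of this change:

```python
import torch

def benchmark_plans(graph, variant_pack, workspace, handle, iters=10):
    """Illustrative only: time every execution plan of a fully built graph."""
    timings = []
    for idx in range(graph.get_execution_plan_count()):
        start = torch.cuda.Event(enable_timing=True)
        stop = torch.cuda.Event(enable_timing=True)
        graph.execute_plan_at_index(variant_pack, workspace, idx, handle=handle)  # warmup
        torch.cuda.synchronize()
        start.record()
        for _ in range(iters):
            graph.execute_plan_at_index(variant_pack, workspace, idx, handle=handle)
        stop.record()
        torch.cuda.synchronize()
        timings.append((start.elapsed_time(stop) / iters, idx))
    return min(timings)[1]  # index of the fastest plan, i.e. the best `tactic`
```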
@@ -1663,6 +1673,7 @@ def _cudnn_gemm_fp4(
     block_size: int = 16,
     use_nvfp4: bool = True,
     workspace_buffer: torch.Tensor = None,
+    tactic: int = -1,
 ):
     _check_cudnn_availability()
     # the fp4 cudnn graph will be shared for both mm and bmm, so
@@ -1694,11 +1705,12 @@ def _cudnn_gemm_fp4(
         a.device,
         alpha is not None,
         use_nvfp4,
+        tactic=tactic,
     )
 
     # execute the fp4 cudnn graph
     execute_cudnn_gemm_fp4_graph(
-        graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer
+        graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer, tactic=tactic
     )
 
 
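`_cudnn_gemm_fp4` now just forwards `tactic` to both the build and execute helpers. A hedged sketch of pinning a plan from the caller side; every shape and dtype below is a placeholder (real packed operands and descale tensors come from an NVFP4 quantization step), and the positional order is inferred from the `inputs` unpacking in `get_valid_tactics` below:

```python
# Illustrative only: placeholder tensors, not valid FP4 data.
import torch

m, n, k, block = 128, 256, 64, 16
a_packed = torch.empty(m, k // 2, dtype=torch.uint8, device="cuda")  # 2 FP4 values per byte
b_packed = torch.empty(n, k // 2, dtype=torch.uint8, device="cuda")
a_descale = torch.empty(m, k // block, dtype=torch.float8_e4m3fn, device="cuda")  # layout assumed
b_descale = torch.empty(n, k // block, dtype=torch.float8_e4m3fn, device="cuda")
out = torch.empty(m, n, dtype=torch.bfloat16, device="cuda")

_cudnn_gemm_fp4(
    a_packed, b_packed, a_descale, b_descale,
    None,                 # alpha
    torch.bfloat16, out,  # out_dtype, out
    block_size=block,
    use_nvfp4=True,
    workspace_buffer=None,
    tactic=2,             # pin cuDNN to execution plan 2; -1 keeps the heuristic path
)
```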
@@ -1710,7 +1722,48 @@ def get_valid_tactics(
             profile: OptimizationProfile,
         ) -> List[int]:
             # cudnn has heuristic for fp4 gemm, so we only need to use the default tactic
-            return [0]
+            _check_cudnn_availability()
+            (
+                a,
+                b,
+                a_descale,
+                b_descale,
+                alpha,
+                out_dtype,
+                out,
+                block_size,
+                use_nvfp4,
+                workspace_buffer,
+            ) = inputs
+
+            real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a)
+            real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b)
+            batch = real_a_shape[0]
+            expanded_a_descale_shape, expanded_a_descale_stride = (
+                _expand_block_scale_tensor_shape(a_descale, batch)
+            )
+            expanded_b_descale_shape, expanded_b_descale_stride = (
+                _expand_block_scale_tensor_shape(b_descale, batch)
+            )
+
+            graph = build_plans_cudnn_fp4_gemm_graph(
+                real_a_shape,
+                real_a_stride,
+                real_b_shape,
+                real_b_stride,
+                expanded_a_descale_shape,
+                expanded_a_descale_stride,
+                expanded_b_descale_shape,
+                expanded_b_descale_stride,
+                cudnn.data_type.FP4_E2M1,
+                _torch_data_type_to_cudnn_data_type(out_dtype),
+                block_size,
+                a.device,
+                alpha is not None,
+                use_nvfp4,
+            )
+            num_plans = graph.get_execution_plan_count()
+            return list(range(num_plans))
 
         def forward(
             self,
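With `get_valid_tactics` now enumerating every cuDNN execution plan instead of returning `[0]`, the surrounding autotuner can sweep the plans and cache the winner. A sketch of that caller side; `runner`, `inputs`, and `profile` are assumed to come from the tuner, and `time_ms` is a stand-in for a CUDA-event timer like the `benchmark_plans` sketch above:

```python
# Sketch only: assumes a tuner-provided `runner`, `inputs`, `profile`,
# and `time_ms` (returns average milliseconds for a callable).
best_tactic, best_ms = -1, float("inf")
for tactic in runner.get_valid_tactics(inputs, profile):
    ms = time_ms(lambda: runner.forward(inputs, tactic=tactic))
    if ms < best_ms:
        best_tactic, best_ms = tactic, ms
runner.forward(inputs, tactic=best_tactic)  # replay with the fastest plan
```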
@@ -1742,6 +1795,7 @@ def forward(
                 block_size,
                 use_nvfp4,
                 workspace_buffer,
+                tactic=tactic,
             )
 
     return CudnnFp4GemmRunner()