
Commit 14cf404

Add 5th draft of mm_fp4 backend -- enable cudnn autotune
1 parent 8c46854 commit 14cf404

File tree

2 files changed: +58 -6 lines

flashinfer/gemm.py

Lines changed: 58 additions & 4 deletions
@@ -1267,6 +1267,7 @@ def build_plans_cudnn_fp4_gemm_graph(
     device,
     alpha,
     use_nvfp4,
+    tactic: int = -1,
 ):
     graph = create_cudnn_execution_plans_fp4_gemm(
         a_shape,
@@ -1286,7 +1287,10 @@ def build_plans_cudnn_fp4_gemm_graph(
     )

     graph.check_support()
-    graph.build_plans()
+    if tactic != -1:
+        graph.build_plan_at_index(tactic)
+    else:
+        graph.build_plans()
     return graph

@@ -1299,6 +1303,7 @@ def execute_cudnn_gemm_fp4_graph(
     alpha,
     c_final,
     workspace_buffer,
+    tactic: int = -1,
 ):
     variant_pack = {
         UIDs.A_UID.value: a.view(get_native_fp4_dtype()),
@@ -1318,7 +1323,12 @@ def execute_cudnn_gemm_fp4_graph(

     stream = torch.cuda.current_stream(a.device)

-    graph.execute(variant_pack, workspace_buffer, handle=_get_cudnn_handle(stream))
+    if tactic == -1:
+        graph.execute(variant_pack, workspace_buffer, handle=_get_cudnn_handle(stream))
+    else:
+        graph.execute_plan_at_index(
+            variant_pack, workspace_buffer, tactic, handle=_get_cudnn_handle(stream)
+        )


 @functools.cache
@@ -1663,6 +1673,7 @@ def _cudnn_gemm_fp4(
     block_size: int = 16,
     use_nvfp4: bool = True,
     workspace_buffer: torch.Tensor = None,
+    tactic: int = -1,
 ):
     _check_cudnn_availability()
     # the fp4 cudnn graph will be shared for both mm and bmm, so
@@ -1694,11 +1705,12 @@ def _cudnn_gemm_fp4(
         a.device,
         alpha is not None,
         use_nvfp4,
+        tactic=tactic,
     )

     # execute the fp4 cudnn graph
     execute_cudnn_gemm_fp4_graph(
-        graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer
+        graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer, tactic=tactic
     )

@@ -1710,7 +1722,48 @@ def get_valid_tactics(
             profile: OptimizationProfile,
         ) -> List[int]:
             # cudnn has heuristic for fp4 gemm, so we only need to use the default tactic
-            return [0]
+            _check_cudnn_availability()
+            (
+                a,
+                b,
+                a_descale,
+                b_descale,
+                alpha,
+                out_dtype,
+                out,
+                block_size,
+                use_nvfp4,
+                workspace_buffer,
+            ) = inputs
+
+            real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a)
+            real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b)
+            batch = real_a_shape[0]
+            expanded_a_descale_shape, expanded_a_descale_stride = (
+                _expand_block_scale_tensor_shape(a_descale, batch)
+            )
+            expanded_b_descale_shape, expanded_b_descale_stride = (
+                _expand_block_scale_tensor_shape(b_descale, batch)
+            )
+
+            graph = build_plans_cudnn_fp4_gemm_graph(
+                real_a_shape,
+                real_a_stride,
+                real_b_shape,
+                real_b_stride,
+                expanded_a_descale_shape,
+                expanded_a_descale_stride,
+                expanded_b_descale_shape,
+                expanded_b_descale_stride,
+                cudnn.data_type.FP4_E2M1,
+                _torch_data_type_to_cudnn_data_type(out_dtype),
+                block_size,
+                a.device,
+                alpha is not None,
+                use_nvfp4,
+            )
+            num_plans = graph.get_execution_plan_count()
+            return list(range(num_plans))

         def forward(
             self,
@@ -1742,6 +1795,7 @@ def forward(
                 block_size,
                 use_nvfp4,
                 workspace_buffer,
+                tactic=tactic,
             )

     return CudnnFp4GemmRunner()
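
Taken together, the gemm.py changes thread a tactic index from the runner's forward() down to the cuDNN graph: tactic=-1 keeps the previous behavior (build_plans() followed by execute(), i.e. cuDNN's default plan choice), while a non-negative index builds and executes only the plan at that index, and get_valid_tactics() now reports one tactic per available execution plan instead of the fixed [0]. A minimal sketch of the kind of sweep this enables follows; sweep_tactics is a hypothetical helper for illustration, not part of this commit or of flashinfer's autotuner.

import time

import torch


def sweep_tactics(run_gemm, tactics, warmup=3, iters=10):
    """Hypothetical helper: time each candidate execution-plan index, return the fastest.

    run_gemm -- callable that launches the FP4 GEMM for a given tactic index,
                e.g. a closure that ends up calling _cudnn_gemm_fp4 with tactic=t
                (the exact call shape is an assumption, not taken from this diff)
    tactics  -- candidate indices, e.g. the list returned by get_valid_tactics(...)
    """
    best_tactic, best_ms = -1, float("inf")
    for t in tactics:
        for _ in range(warmup):  # the first call builds the plan at index t
            run_gemm(t)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            run_gemm(t)
        torch.cuda.synchronize()
        elapsed_ms = (time.perf_counter() - start) * 1e3 / iters
        if elapsed_ms < best_ms:
            best_tactic, best_ms = t, elapsed_ms
    return best_tactic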

tests/gemm/test_mm_fp4.py

Lines changed: 0 additions & 2 deletions
@@ -40,8 +40,6 @@ def test_mm_fp4(
         pytest.skip("trtllm gemm does not support SM110/SM120/SM121 GPUs.")
     if not use_128x4_sf_layout and backend != "trtllm":
         pytest.skip("Skipping test for non-trtllm fp4 with use_128x4_sf_layout=False")
-    if auto_tuning and backend == "cudnn":
-        pytest.skip("Skipping test for cudnn fp4 with auto_tuning=True")
     if not use_nvfp4 and backend != "cudnn":
         pytest.skip("mx_fp4 is only supported for cudnn backend")
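
With this skip removed, the existing auto_tuning=True cases in test_mm_fp4 now also run against the cudnn backend, exercising the tactic enumeration added above. Roughly, the pattern such a test drives looks like the sketch below; the autotune context manager and the mm_fp4 call shape shown here are assumptions about flashinfer's public API rather than something this diff confirms.

# Sketch only: `autotune` and the mm_fp4 argument order are assumed, not
# confirmed by this commit's diff.
import flashinfer
from flashinfer.autotuner import autotune


def run_cudnn_fp4_autotuned(a_fp4, b_fp4, a_descale, b_descale, alpha, out):
    # Under autotuning, the first call sweeps the tactics reported by
    # get_valid_tactics() and caches the fastest plan for this problem shape;
    # subsequent calls reuse the cached choice.
    with autotune(True):
        flashinfer.mm_fp4(
            a_fp4, b_fp4, a_descale, b_descale, alpha, out=out, backend="cudnn"
        )
    return out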

0 commit comments
