Skip to content

Commit c9ddc2c

Browse files
committed
Add 5th draft of mm_fp4 backend -- enable cudnn autotune
1 parent 66dbb15 commit c9ddc2c

File tree

2 files changed

+58
-6
lines changed

2 files changed

+58
-6
lines changed

flashinfer/gemm/gemm_base.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,6 +1255,7 @@ def build_plans_cudnn_fp4_gemm_graph(
12551255
device,
12561256
alpha,
12571257
use_nvfp4,
1258+
tactic: int = -1,
12581259
):
12591260
graph = create_cudnn_execution_plans_fp4_gemm(
12601261
a_shape,
@@ -1274,7 +1275,10 @@ def build_plans_cudnn_fp4_gemm_graph(
12741275
)
12751276

12761277
graph.check_support()
1277-
graph.build_plans()
1278+
if tactic != -1:
1279+
graph.build_plan_at_index(tactic)
1280+
else:
1281+
graph.build_plans()
12781282
return graph
12791283

12801284

@@ -1287,6 +1291,7 @@ def execute_cudnn_gemm_fp4_graph(
12871291
alpha,
12881292
c_final,
12891293
workspace_buffer,
1294+
tactic: int = -1,
12901295
):
12911296
variant_pack = {
12921297
UIDs.A_UID.value: a.view(get_native_fp4_dtype()),
@@ -1306,7 +1311,12 @@ def execute_cudnn_gemm_fp4_graph(
13061311

13071312
stream = torch.cuda.current_stream(a.device)
13081313

1309-
graph.execute(variant_pack, workspace_buffer, handle=_get_cudnn_handle(stream))
1314+
if tactic == -1:
1315+
graph.execute(variant_pack, workspace_buffer, handle=_get_cudnn_handle(stream))
1316+
else:
1317+
graph.execute_plan_at_index(
1318+
variant_pack, workspace_buffer, tactic, handle=_get_cudnn_handle(stream)
1319+
)
13101320

13111321

13121322
@functools.cache
@@ -1651,6 +1661,7 @@ def _cudnn_gemm_fp4(
16511661
block_size: int = 16,
16521662
use_nvfp4: bool = True,
16531663
workspace_buffer: torch.Tensor = None,
1664+
tactic: int = -1,
16541665
):
16551666
_check_cudnn_availability()
16561667
# the fp4 cudnn graph will be shared for both mm and bmm, so
@@ -1682,11 +1693,12 @@ def _cudnn_gemm_fp4(
16821693
a.device,
16831694
alpha is not None,
16841695
use_nvfp4,
1696+
tactic=tactic,
16851697
)
16861698

16871699
# execute the fp4 cudnn graph
16881700
execute_cudnn_gemm_fp4_graph(
1689-
graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer
1701+
graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer, tactic=tactic
16901702
)
16911703

16921704

@@ -1698,7 +1710,48 @@ def get_valid_tactics(
16981710
profile: OptimizationProfile,
16991711
) -> List[int]:
17001712
# cudnn exposes multiple execution plans for fp4 gemm; enumerate all plan
# indices so the autotuner can pick the best tactic
1701-
return [0]
1713+
_check_cudnn_availability()
1714+
(
1715+
a,
1716+
b,
1717+
a_descale,
1718+
b_descale,
1719+
alpha,
1720+
out_dtype,
1721+
out,
1722+
block_size,
1723+
use_nvfp4,
1724+
workspace_buffer,
1725+
) = inputs
1726+
1727+
real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a)
1728+
real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b)
1729+
batch = real_a_shape[0]
1730+
expanded_a_descale_shape, expanded_a_descale_stride = (
1731+
_expand_block_scale_tensor_shape(a_descale, batch)
1732+
)
1733+
expanded_b_descale_shape, expanded_b_descale_stride = (
1734+
_expand_block_scale_tensor_shape(b_descale, batch)
1735+
)
1736+
1737+
graph = build_plans_cudnn_fp4_gemm_graph(
1738+
real_a_shape,
1739+
real_a_stride,
1740+
real_b_shape,
1741+
real_b_stride,
1742+
expanded_a_descale_shape,
1743+
expanded_a_descale_stride,
1744+
expanded_b_descale_shape,
1745+
expanded_b_descale_stride,
1746+
cudnn.data_type.FP4_E2M1,
1747+
_torch_data_type_to_cudnn_data_type(out_dtype),
1748+
block_size,
1749+
a.device,
1750+
alpha is not None,
1751+
use_nvfp4,
1752+
)
1753+
num_plans = graph.get_execution_plan_count()
1754+
return list(range(num_plans))
17021755

17031756
def forward(
17041757
self,
@@ -1730,6 +1783,7 @@ def forward(
17301783
block_size,
17311784
use_nvfp4,
17321785
workspace_buffer,
1786+
tactic=tactic,
17331787
)
17341788

17351789
return CudnnFp4GemmRunner()

tests/gemm/test_mm_fp4.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ def test_mm_fp4(
4040
pytest.skip("trtllm gemm does not support SM110/SM120/SM121 GPUs.")
4141
if not use_128x4_sf_layout and backend != "trtllm":
4242
pytest.skip("Skipping test for non-trtllm fp4 with use_128x4_sf_layout=False")
43-
if auto_tuning and backend == "cudnn":
44-
pytest.skip("Skipping test for cudnn fp4 with auto_tuning=True")
4543
if not use_nvfp4 and backend != "cudnn":
4644
pytest.skip("mx_fp4 is only supported for cudnn backend")
4745

0 commit comments

Comments (0)