Skip to content

Commit 2b17db2

Browse files
committed
conditional finalize fusion
Signed-off-by: Enwei Zhu <[email protected]>
1 parent 313cfc3 commit 2b17db2

File tree

1 file changed

+42
-18
lines changed

1 file changed

+42
-18
lines changed

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -404,24 +404,48 @@ def forward_chunk_nvfp4(
404404
local_expert_offset=self.slot_start,
405405
tile_size=tile_size,
406406
)
407-
x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_blackwell(
408-
input=x.view(torch.float4_e2m1fn_x2),
409-
weight=self.w2_weight.view(torch.float4_e2m1fn_x2),
410-
input_scale=x_sf.view(torch.uint8),
411-
weight_scale=self.quant_scales.fc2_weight_block.view(torch.uint8),
412-
alpha=self.quant_scales.fc2_global,
413-
tile_idx_to_group_idx=tile_idx_to_expert_idx,
414-
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
415-
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
416-
num_non_exiting_tiles=num_non_exiting_tiles,
417-
token_final_scales=token_final_scales,
418-
num_experts=self.num_slots,
419-
top_k=self.routing_method.experts_per_token,
420-
num_local_experts=self.expert_size_per_partition,
421-
local_expert_offset=self.slot_start,
422-
tile_size=tile_size,
423-
output_dtype=output_dtype,
424-
)
407+
if self.use_fused_finalize:
408+
x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_blackwell(
409+
input=x.view(torch.float4_e2m1fn_x2),
410+
weight=self.w2_weight.view(torch.float4_e2m1fn_x2),
411+
input_scale=x_sf.view(torch.uint8),
412+
weight_scale=self.quant_scales.fc2_weight_block.view(
413+
torch.uint8),
414+
alpha=self.quant_scales.fc2_global,
415+
tile_idx_to_group_idx=tile_idx_to_expert_idx,
416+
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
417+
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
418+
num_non_exiting_tiles=num_non_exiting_tiles,
419+
token_final_scales=token_final_scales,
420+
num_experts=self.num_slots,
421+
top_k=self.routing_method.experts_per_token,
422+
num_local_experts=self.expert_size_per_partition,
423+
local_expert_offset=self.slot_start,
424+
tile_size=tile_size,
425+
output_dtype=output_dtype,
426+
)
427+
else:
428+
x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_blackwell(
429+
input=x.view(torch.float4_e2m1fn_x2),
430+
weight=self.w2_weight.view(torch.float4_e2m1fn_x2),
431+
input_scale=x_sf.view(torch.uint8),
432+
weight_scale=self.quant_scales.fc2_weight_block.view(
433+
torch.uint8),
434+
alpha=self.quant_scales.fc2_global,
435+
tile_idx_to_group_idx=tile_idx_to_expert_idx,
436+
num_non_exiting_tiles=num_non_exiting_tiles,
437+
num_experts=self.num_slots,
438+
top_k=self.routing_method.experts_per_token,
439+
num_local_experts=self.expert_size_per_partition,
440+
local_expert_offset=self.slot_start,
441+
tile_size=tile_size,
442+
output_dtype=output_dtype,
443+
)
444+
x = torch.ops.trtllm.moe_unpermute(
445+
permuted_input=x,
446+
expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
447+
topk_scales=token_final_scales,
448+
)
425449
return x
426450

427451
def forward_chunk(

0 commit comments

Comments (0)