@@ -404,24 +404,48 @@ def forward_chunk_nvfp4(
             local_expert_offset=self.slot_start,
             tile_size=tile_size,
         )
-        x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_blackwell(
-            input=x.view(torch.float4_e2m1fn_x2),
-            weight=self.w2_weight.view(torch.float4_e2m1fn_x2),
-            input_scale=x_sf.view(torch.uint8),
-            weight_scale=self.quant_scales.fc2_weight_block.view(torch.uint8),
-            alpha=self.quant_scales.fc2_global,
-            tile_idx_to_group_idx=tile_idx_to_expert_idx,
-            tile_idx_to_mn_limit=tile_idx_to_mn_limit,
-            permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
-            num_non_exiting_tiles=num_non_exiting_tiles,
-            token_final_scales=token_final_scales,
-            num_experts=self.num_slots,
-            top_k=self.routing_method.experts_per_token,
-            num_local_experts=self.expert_size_per_partition,
-            local_expert_offset=self.slot_start,
-            tile_size=tile_size,
-            output_dtype=output_dtype,
-        )
+        if self.use_fused_finalize:
+            x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_blackwell(
+                input=x.view(torch.float4_e2m1fn_x2),
+                weight=self.w2_weight.view(torch.float4_e2m1fn_x2),
+                input_scale=x_sf.view(torch.uint8),
+                weight_scale=self.quant_scales.fc2_weight_block.view(
+                    torch.uint8),
+                alpha=self.quant_scales.fc2_global,
+                tile_idx_to_group_idx=tile_idx_to_expert_idx,
+                tile_idx_to_mn_limit=tile_idx_to_mn_limit,
+                permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
+                num_non_exiting_tiles=num_non_exiting_tiles,
+                token_final_scales=token_final_scales,
+                num_experts=self.num_slots,
+                top_k=self.routing_method.experts_per_token,
+                num_local_experts=self.expert_size_per_partition,
+                local_expert_offset=self.slot_start,
+                tile_size=tile_size,
+                output_dtype=output_dtype,
+            )
+        else:
+            x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_blackwell(
+                input=x.view(torch.float4_e2m1fn_x2),
+                weight=self.w2_weight.view(torch.float4_e2m1fn_x2),
+                input_scale=x_sf.view(torch.uint8),
+                weight_scale=self.quant_scales.fc2_weight_block.view(
+                    torch.uint8),
+                alpha=self.quant_scales.fc2_global,
+                tile_idx_to_group_idx=tile_idx_to_expert_idx,
+                num_non_exiting_tiles=num_non_exiting_tiles,
+                num_experts=self.num_slots,
+                top_k=self.routing_method.experts_per_token,
+                num_local_experts=self.expert_size_per_partition,
+                local_expert_offset=self.slot_start,
+                tile_size=tile_size,
+                output_dtype=output_dtype,
+            )
+            x = torch.ops.trtllm.moe_unpermute(
+                permuted_input=x,
+                expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
+                topk_scales=token_final_scales,
+            )
         return x
 
     def forward_chunk(
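
To make the two paths easier to compare: in the unfused branch, `moe_unpermute` gathers each token's top-k expert outputs back from permuted order and applies the routing scales, i.e. the same finalize reduction that the fused kernel performs in its GEMM epilogue. Below is a minimal PyTorch sketch of that semantics; the helper name and shapes are illustrative assumptions, not the kernel's actual contract, and it assumes every expanded slot maps to a valid permuted row (no dropped tokens).

```python
import torch


def moe_unpermute_reference(
        permuted_input: torch.Tensor,  # (num_permuted_rows, hidden)
        expanded_idx_to_permuted_idx: torch.Tensor,  # (num_tokens * top_k,)
        topk_scales: torch.Tensor,  # (num_tokens, top_k)
) -> torch.Tensor:
    """Hypothetical reference for the unpermute + finalize reduction."""
    num_tokens, top_k = topk_scales.shape
    # Gather each token's top_k expert outputs back from permuted order.
    gathered = permuted_input[expanded_idx_to_permuted_idx].view(
        num_tokens, top_k, -1)
    # Weighted sum over the top_k experts; the fused kernel folds this
    # step into the grouped GEMM epilogue instead of running it as a
    # separate pass over the permuted activations.
    out = (gathered.float() * topk_scales.unsqueeze(-1).float()).sum(dim=1)
    return out.to(permuted_input.dtype)
```

That is presumably the trade-off the new branch encodes: fusing the finalize into the GEMM epilogue saves a separate kernel launch and a round trip through global memory for the permuted activations, while the unfused path keeps a simpler, more general grouped GEMM followed by a standalone unpermute.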