@@ -68,7 +68,9 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
6868 TVM_FFI_ICHECK_EQ (routing_logits->ndim , 2 ) << " routing_logits must be 2D." ;
6969 TVM_FFI_ICHECK_EQ (routing_logits->shape [1 ], num_experts) << " routing_logits has incorrect shape." ;
7070 if (routing_bias.has_value ()) {
71- TVM_FFI_ICHECK_EQ (routing_bias.value ()->dtype , dl_bfloat16) << " routing_bias must be bfloat16." ;
71+ TVM_FFI_ICHECK (routing_bias.value ()->dtype == dl_bfloat16 ||
72+ routing_bias.value ()->dtype == dl_float32)
73+ << " routing_bias must be bfloat16 or float." ;
7274 TVM_FFI_ICHECK_EQ (routing_bias.value ()->ndim , 1 ) << " routing_bias must be 1D." ;
7375 TVM_FFI_ICHECK_EQ (routing_bias.value ()->shape [0 ], num_experts)
7476 << " routing_bias has incorrect shape." ;
@@ -110,6 +112,10 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
110112 args.routing_logits = routing_logits->data ;
111113 auto const routing_bias_dtype =
112114 routing_bias.has_value () ? routing_bias.value ()->dtype : dl_bfloat16;
115+ auto btg_routing_bias_dtype = btg::Dtype::Fp32;
116+ if (routing_bias_dtype == dl_bfloat16) {
117+ btg_routing_bias_dtype = btg::Dtype::Bfloat16;
118+ }
113119 args.routing_bias = routing_bias.has_value () ? routing_bias.value ()->data : nullptr ;
114120 args.hidden_states = hidden_states->data ;
115121 args.gemm1_weights = gemm1_weights->data ;
@@ -141,7 +147,7 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
141147 Tensor permuted_idx_to_token_idx =
142148 alloc_tensor ({max_num_padded_tokens}, dl_int32, routing_logits->device );
143149 Tensor expert_weights =
144- alloc_tensor ({args.num_tokens , args.top_k }, routing_bias_dtype , routing_logits->device );
150+ alloc_tensor ({args.num_tokens , args.top_k }, dl_bfloat16 , routing_logits->device );
145151 Tensor expert_indexes =
146152 alloc_tensor ({args.num_tokens , args.top_k }, dl_int32, routing_logits->device );
147153 Tensor expert_count_histogram = alloc_tensor (
@@ -185,8 +191,9 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
185191 static_cast <int *>(num_tokens_per_expert->data ),
186192 static_cast <int *>(cta_idx_xy_to_batch_idx->data ),
187193 static_cast <int *>(cta_idx_xy_to_mn_limit->data ),
188- static_cast <int *>(num_non_exiting_ctas->data ), args.mDtypeElt , use_routing_scales_on_input,
189- false /* use_deep_seek_fp8 */ , static_cast <RoutingMethodType>(routing_method_type), stream);
194+ static_cast <int *>(num_non_exiting_ctas->data ), args.mDtypeElt , btg_routing_bias_dtype,
195+ use_routing_scales_on_input, false /* use_deep_seek_fp8 */ ,
196+ static_cast <RoutingMethodType>(routing_method_type), stream);
190197
191198 // MoE kernel except routing
192199 TVM_FFI_ICHECK_EQ (hidden_states->dtype , dl_float8_e4m3fn) << " hidden_states must be fp8." ;
@@ -369,7 +376,9 @@ void trtllm_fp8_block_scale_moe_launcher(
369376
370377 auto const routing_bias_dtype =
371378 routing_bias.has_value () ? routing_bias.value ()->dtype : dl_bfloat16;
372- args.mDtypeExpW = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
379+ auto btg_routing_bias_dtype =
380+ routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
381+
373382 args.routing_logits = static_cast <float *>(routing_logits->data );
374383 args.routing_bias = routing_bias.has_value () ? routing_bias.value ()->data : nullptr ;
375384 args.hidden_states = hidden_states->data ;
@@ -407,8 +416,10 @@ void trtllm_fp8_block_scale_moe_launcher(
407416 alloc_tensor ({args.num_tokens * args.top_k }, dl_int32, routing_logits->device );
408417 Tensor permuted_idx_to_token_idx =
409418 alloc_tensor ({max_num_padded_tokens}, dl_int32, routing_logits->device );
419+
410420 Tensor expert_weights =
411- alloc_tensor ({args.num_tokens , args.top_k }, routing_bias_dtype, routing_logits->device );
421+ alloc_tensor ({args.num_tokens , args.top_k }, dl_bfloat16, routing_logits->device );
422+ // NOTE: the output type of routing kernel is currently always bfloat16
412423 Tensor expert_indexes =
413424 alloc_tensor ({args.num_tokens , args.top_k }, dl_int32, routing_logits->device );
414425 int64_t const size_of_expert_count_histogram = std::max (num_experts * 2 , int64_t (256 * 2 ));
@@ -441,20 +452,21 @@ void trtllm_fp8_block_scale_moe_launcher(
441452
442453 tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner (tile_tokens_dim);
443454 cudaStream_t stream = get_stream (routing_logits->device );
444- routing_runner.run (static_cast <float *>(routing_logits->data ), args.routing_bias , args.num_tokens ,
445- args.num_experts , args.top_k , args.n_group , args.topk_group ,
446- args.local_expert_offset , args.local_num_experts , args.routed_scaling_factor ,
447- static_cast <int *>(expert_indexes->data ),
448- static_cast <int *>(expert_count_histogram->data ),
449- static_cast <int *>(total_num_padded_tokens->data ),
450- static_cast <int *>(expanded_idx_to_permuted_idx->data ),
451- nullptr /* static_cast<int*>(permuted_idx_to_expanded_idx->data)*/ ,
452- static_cast <int *>(permuted_idx_to_token_idx->data ), expert_weights->data ,
453- static_cast <int *>(num_tokens_per_expert->data ),
454- static_cast <int *>(cta_idx_xy_to_batch_idx->data ),
455- static_cast <int *>(cta_idx_xy_to_mn_limit->data ),
456- static_cast <int *>(num_non_exiting_ctas->data ), args.mDtypeElt , false , true ,
457- static_cast <RoutingMethodType>(routing_method_type), stream);
455+ routing_runner.run (
456+ static_cast <float *>(routing_logits->data ), args.routing_bias , args.num_tokens ,
457+ args.num_experts , args.top_k , args.n_group , args.topk_group , args.local_expert_offset ,
458+ args.local_num_experts , args.routed_scaling_factor , static_cast <int *>(expert_indexes->data ),
459+ static_cast <int *>(expert_count_histogram->data ),
460+ static_cast <int *>(total_num_padded_tokens->data ),
461+ static_cast <int *>(expanded_idx_to_permuted_idx->data ),
462+ nullptr /* static_cast<int*>(permuted_idx_to_expanded_idx->data)*/ ,
463+ static_cast <int *>(permuted_idx_to_token_idx->data ), expert_weights->data ,
464+ static_cast <int *>(num_tokens_per_expert->data ),
465+ static_cast <int *>(cta_idx_xy_to_batch_idx->data ),
466+ static_cast <int *>(cta_idx_xy_to_mn_limit->data ),
467+ static_cast <int *>(num_non_exiting_ctas->data ), args.mDtypeElt , btg_routing_bias_dtype,
468+ false /* use_routing_scales_on_input */ , true /* use_deep_seek_fp8 */ ,
469+ static_cast <RoutingMethodType>(routing_method_type), stream);
458470
459471 // MoE kernel except routing
460472 TVM_FFI_ICHECK_EQ (hidden_states->dtype , dl_float8_e4m3fn) << " hidden_states must be fp8." ;
@@ -683,7 +695,10 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
683695 << " routing_logits has incorrect shape." ;
684696 }
685697 if (routing_bias.has_value ()) {
686- TVM_FFI_ICHECK_EQ (routing_bias.value ()->dtype , dl_bfloat16) << " routing_bias must be bfloat16." ;
698+ TVM_FFI_ICHECK (routing_bias.value ()->dtype == dl_bfloat16 ||
699+ routing_bias.value ()->dtype == dl_float32)
700+ << " routing_bias must be bfloat16 or float." ;
701+
687702 TVM_FFI_ICHECK_EQ (routing_bias.value ()->ndim , 1 ) << " routing_bias must be 1D." ;
688703 TVM_FFI_ICHECK_EQ (routing_bias.value ()->shape [0 ], num_experts)
689704 << " routing_bias has incorrect shape." ;
@@ -726,15 +741,14 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
726741 tensorrt_llm::kernels::trtllmgen_moe::MoE::MoEWorkspace workspace;
727742
728743 // setup args
729- // note: the assumption is that output data type is always Bfloat16 (the default)
730- auto routing_bias_dtype = dl_bfloat16;
731- if (routing_bias.has_value ()) {
732- routing_bias_dtype = routing_bias.value ()->dtype ;
733- } else if (routing_logits.has_value ()) {
734- routing_bias_dtype = routing_logits.value ()->dtype ;
735- }
736744 args.mDtypeElt = dtype_act;
737- args.mDtypeExpW = routing_bias_dtype == dl_float32 ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16;
745+ // note: the assumption is that output data type is always Bfloat16 (the default)
746+ auto routing_bias_dtype = routing_bias.has_value () ? routing_bias.value ()->dtype : dl_bfloat16;
747+ auto btg_routing_bias_dtype =
748+ routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
749+ // We shouln't use args.mDtypeExpW since it indicates the output data type of routing kernel,
750+ // which is currently always bfloat16 for routing kernel while the data type of routing bias now
751+ // can be fp32
738752 args.routing_logits = routing_logits.has_value () ? routing_logits.value ()->data : nullptr ;
739753 args.routing_bias = routing_bias.has_value () ? routing_bias.value ()->data : nullptr ;
740754 args.hidden_states = hidden_states->data ;
@@ -789,7 +803,7 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
789803 Tensor permuted_idx_to_token_idx =
790804 alloc_tensor ({max_num_padded_tokens}, dl_int32, hidden_states->device );
791805 // Tensor expert_weights = alloc_tensor(
792- // {args.num_tokens, args.top_k}, routing_bias_dtype , hidden_states->device);
806+ // {args.num_tokens, args.top_k}, dl_bfloat16 , hidden_states->device);
793807 // Tensor expert_indexes = alloc_tensor(
794808 // {args.num_tokens, args.top_k}, dl_int32, hidden_states->device);
795809 int constexpr MAX_NUM_EXPERTS = 384 ;
@@ -833,21 +847,21 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
833847
834848 tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner (tile_tokens_dim);
835849 cudaStream_t stream = get_stream (hidden_states->device );
836- routing_runner.run (args. routing_logits , args. routing_bias , args. num_tokens , args. num_experts ,
837- args.top_k , args.n_group , args.topk_group , args.local_expert_offset ,
838- args.local_num_experts , args.routed_scaling_factor ,
839- static_cast <int *>(expert_indices->data ),
840- static_cast <int *>(expert_count_histogram->data ),
841- static_cast <int *>(total_num_padded_tokens->data ),
842- static_cast <int *>(expanded_idx_to_permuted_idx->data ),
843- nullptr , /* static_cast<int*>(permuted_idx_to_expanded_idx->data),*/
844- static_cast <int *>(permuted_idx_to_token_idx->data ), expert_weights->data ,
845- static_cast <int *>(num_tokens_per_expert->data ),
846- static_cast <int *>(cta_idx_xy_to_batch_idx->data ),
847- static_cast <int *>(cta_idx_xy_to_mn_limit->data ),
848- static_cast <int *>(num_non_exiting_ctas->data ), args.mDtypeElt ,
849- false /* use_routing_scales_on_input */ , false /* use_deep_seek_fp8 */ ,
850- static_cast <RoutingMethodType>(routing_method_type), stream);
850+ routing_runner.run (
851+ args. routing_logits , args.routing_bias , args.num_tokens , args.num_experts , args.top_k ,
852+ args. n_group , args. topk_group , args.local_expert_offset , args.local_num_experts ,
853+ args. routed_scaling_factor , static_cast <int *>(expert_indices->data ),
854+ static_cast <int *>(expert_count_histogram->data ),
855+ static_cast <int *>(total_num_padded_tokens->data ),
856+ static_cast <int *>(expanded_idx_to_permuted_idx->data ),
857+ nullptr , /* static_cast<int*>(permuted_idx_to_expanded_idx->data),*/
858+ static_cast <int *>(permuted_idx_to_token_idx->data ), expert_weights->data ,
859+ static_cast <int *>(num_tokens_per_expert->data ),
860+ static_cast <int *>(cta_idx_xy_to_batch_idx->data ),
861+ static_cast <int *>(cta_idx_xy_to_mn_limit->data ),
862+ static_cast <int *>(num_non_exiting_ctas->data ), args.mDtypeElt , btg_routing_bias_dtype ,
863+ false /* use_routing_scales_on_input */ , false /* use_deep_seek_fp8 */ ,
864+ static_cast <RoutingMethodType>(routing_method_type), stream);
851865
852866 //
853867 // FC13 (gemm1) + FC2 (gemm2)
0 commit comments