Skip to content

Commit d84e1d5

Browse files
wenscarl and ChristinaZ
authored
Fix bias dtype in cutlass_moe (#1876)
<!-- .github/pull_request_template.md --> ## 📌 Description Co-authored by @ChristinaZ (major contribution). After this change, in DeepSeek routing mode the logits are always fp32, and the routing bias may be either bfloat16 or fp32. ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> --------- Signed-off-by: Christina Zhang <[email protected]> Co-authored-by: Christina Zhang <[email protected]>
1 parent 9ee58ac commit d84e1d5

File tree

6 files changed

+151
-80
lines changed

6 files changed

+151
-80
lines changed

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 59 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
6868
TVM_FFI_ICHECK_EQ(routing_logits->ndim, 2) << "routing_logits must be 2D.";
6969
TVM_FFI_ICHECK_EQ(routing_logits->shape[1], num_experts) << "routing_logits has incorrect shape.";
7070
if (routing_bias.has_value()) {
71-
TVM_FFI_ICHECK_EQ(routing_bias.value()->dtype, dl_bfloat16) << "routing_bias must be bfloat16.";
71+
TVM_FFI_ICHECK(routing_bias.value()->dtype == dl_bfloat16 ||
72+
routing_bias.value()->dtype == dl_float32)
73+
<< "routing_bias must be bfloat16 or float.";
7274
TVM_FFI_ICHECK_EQ(routing_bias.value()->ndim, 1) << "routing_bias must be 1D.";
7375
TVM_FFI_ICHECK_EQ(routing_bias.value()->shape[0], num_experts)
7476
<< "routing_bias has incorrect shape.";
@@ -110,6 +112,10 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
110112
args.routing_logits = routing_logits->data;
111113
auto const routing_bias_dtype =
112114
routing_bias.has_value() ? routing_bias.value()->dtype : dl_bfloat16;
115+
auto btg_routing_bias_dtype = btg::Dtype::Fp32;
116+
if (routing_bias_dtype == dl_bfloat16) {
117+
btg_routing_bias_dtype = btg::Dtype::Bfloat16;
118+
}
113119
args.routing_bias = routing_bias.has_value() ? routing_bias.value()->data : nullptr;
114120
args.hidden_states = hidden_states->data;
115121
args.gemm1_weights = gemm1_weights->data;
@@ -141,7 +147,7 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
141147
Tensor permuted_idx_to_token_idx =
142148
alloc_tensor({max_num_padded_tokens}, dl_int32, routing_logits->device);
143149
Tensor expert_weights =
144-
alloc_tensor({args.num_tokens, args.top_k}, routing_bias_dtype, routing_logits->device);
150+
alloc_tensor({args.num_tokens, args.top_k}, dl_bfloat16, routing_logits->device);
145151
Tensor expert_indexes =
146152
alloc_tensor({args.num_tokens, args.top_k}, dl_int32, routing_logits->device);
147153
Tensor expert_count_histogram = alloc_tensor(
@@ -185,8 +191,9 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
185191
static_cast<int*>(num_tokens_per_expert->data),
186192
static_cast<int*>(cta_idx_xy_to_batch_idx->data),
187193
static_cast<int*>(cta_idx_xy_to_mn_limit->data),
188-
static_cast<int*>(num_non_exiting_ctas->data), args.mDtypeElt, use_routing_scales_on_input,
189-
false /* use_deep_seek_fp8 */, static_cast<RoutingMethodType>(routing_method_type), stream);
194+
static_cast<int*>(num_non_exiting_ctas->data), args.mDtypeElt, btg_routing_bias_dtype,
195+
use_routing_scales_on_input, false /* use_deep_seek_fp8 */,
196+
static_cast<RoutingMethodType>(routing_method_type), stream);
190197

191198
// MoE kernel except routing
192199
TVM_FFI_ICHECK_EQ(hidden_states->dtype, dl_float8_e4m3fn) << "hidden_states must be fp8.";
@@ -369,7 +376,9 @@ void trtllm_fp8_block_scale_moe_launcher(
369376

370377
auto const routing_bias_dtype =
371378
routing_bias.has_value() ? routing_bias.value()->dtype : dl_bfloat16;
372-
args.mDtypeExpW = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
379+
auto btg_routing_bias_dtype =
380+
routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
381+
373382
args.routing_logits = static_cast<float*>(routing_logits->data);
374383
args.routing_bias = routing_bias.has_value() ? routing_bias.value()->data : nullptr;
375384
args.hidden_states = hidden_states->data;
@@ -407,8 +416,10 @@ void trtllm_fp8_block_scale_moe_launcher(
407416
alloc_tensor({args.num_tokens * args.top_k}, dl_int32, routing_logits->device);
408417
Tensor permuted_idx_to_token_idx =
409418
alloc_tensor({max_num_padded_tokens}, dl_int32, routing_logits->device);
419+
410420
Tensor expert_weights =
411-
alloc_tensor({args.num_tokens, args.top_k}, routing_bias_dtype, routing_logits->device);
421+
alloc_tensor({args.num_tokens, args.top_k}, dl_bfloat16, routing_logits->device);
422+
// NOTE: the output type of routing kernel is currently always bfloat16
412423
Tensor expert_indexes =
413424
alloc_tensor({args.num_tokens, args.top_k}, dl_int32, routing_logits->device);
414425
int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2));
@@ -441,20 +452,21 @@ void trtllm_fp8_block_scale_moe_launcher(
441452

442453
tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
443454
cudaStream_t stream = get_stream(routing_logits->device);
444-
routing_runner.run(static_cast<float*>(routing_logits->data), args.routing_bias, args.num_tokens,
445-
args.num_experts, args.top_k, args.n_group, args.topk_group,
446-
args.local_expert_offset, args.local_num_experts, args.routed_scaling_factor,
447-
static_cast<int*>(expert_indexes->data),
448-
static_cast<int*>(expert_count_histogram->data),
449-
static_cast<int*>(total_num_padded_tokens->data),
450-
static_cast<int*>(expanded_idx_to_permuted_idx->data),
451-
nullptr /*static_cast<int*>(permuted_idx_to_expanded_idx->data)*/,
452-
static_cast<int*>(permuted_idx_to_token_idx->data), expert_weights->data,
453-
static_cast<int*>(num_tokens_per_expert->data),
454-
static_cast<int*>(cta_idx_xy_to_batch_idx->data),
455-
static_cast<int*>(cta_idx_xy_to_mn_limit->data),
456-
static_cast<int*>(num_non_exiting_ctas->data), args.mDtypeElt, false, true,
457-
static_cast<RoutingMethodType>(routing_method_type), stream);
455+
routing_runner.run(
456+
static_cast<float*>(routing_logits->data), args.routing_bias, args.num_tokens,
457+
args.num_experts, args.top_k, args.n_group, args.topk_group, args.local_expert_offset,
458+
args.local_num_experts, args.routed_scaling_factor, static_cast<int*>(expert_indexes->data),
459+
static_cast<int*>(expert_count_histogram->data),
460+
static_cast<int*>(total_num_padded_tokens->data),
461+
static_cast<int*>(expanded_idx_to_permuted_idx->data),
462+
nullptr /*static_cast<int*>(permuted_idx_to_expanded_idx->data)*/,
463+
static_cast<int*>(permuted_idx_to_token_idx->data), expert_weights->data,
464+
static_cast<int*>(num_tokens_per_expert->data),
465+
static_cast<int*>(cta_idx_xy_to_batch_idx->data),
466+
static_cast<int*>(cta_idx_xy_to_mn_limit->data),
467+
static_cast<int*>(num_non_exiting_ctas->data), args.mDtypeElt, btg_routing_bias_dtype,
468+
false /* use_routing_scales_on_input */, true /* use_deep_seek_fp8 */,
469+
static_cast<RoutingMethodType>(routing_method_type), stream);
458470

459471
// MoE kernel except routing
460472
TVM_FFI_ICHECK_EQ(hidden_states->dtype, dl_float8_e4m3fn) << "hidden_states must be fp8.";
@@ -683,7 +695,10 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
683695
<< "routing_logits has incorrect shape.";
684696
}
685697
if (routing_bias.has_value()) {
686-
TVM_FFI_ICHECK_EQ(routing_bias.value()->dtype, dl_bfloat16) << "routing_bias must be bfloat16.";
698+
TVM_FFI_ICHECK(routing_bias.value()->dtype == dl_bfloat16 ||
699+
routing_bias.value()->dtype == dl_float32)
700+
<< "routing_bias must be bfloat16 or float.";
701+
687702
TVM_FFI_ICHECK_EQ(routing_bias.value()->ndim, 1) << "routing_bias must be 1D.";
688703
TVM_FFI_ICHECK_EQ(routing_bias.value()->shape[0], num_experts)
689704
<< "routing_bias has incorrect shape.";
@@ -726,15 +741,14 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
726741
tensorrt_llm::kernels::trtllmgen_moe::MoE::MoEWorkspace workspace;
727742

728743
// setup args
729-
// note: the assumption is that output data type is always Bfloat16 (the default)
730-
auto routing_bias_dtype = dl_bfloat16;
731-
if (routing_bias.has_value()) {
732-
routing_bias_dtype = routing_bias.value()->dtype;
733-
} else if (routing_logits.has_value()) {
734-
routing_bias_dtype = routing_logits.value()->dtype;
735-
}
736744
args.mDtypeElt = dtype_act;
737-
args.mDtypeExpW = routing_bias_dtype == dl_float32 ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16;
745+
// note: the assumption is that output data type is always Bfloat16 (the default)
746+
auto routing_bias_dtype = routing_bias.has_value() ? routing_bias.value()->dtype : dl_bfloat16;
747+
auto btg_routing_bias_dtype =
748+
routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
749+
// We shouldn't use args.mDtypeExpW since it indicates the output data type of the routing kernel,
750+
// which is currently always bfloat16 for routing kernel while the data type of routing bias now
751+
// can be fp32
738752
args.routing_logits = routing_logits.has_value() ? routing_logits.value()->data : nullptr;
739753
args.routing_bias = routing_bias.has_value() ? routing_bias.value()->data : nullptr;
740754
args.hidden_states = hidden_states->data;
@@ -789,7 +803,7 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
789803
Tensor permuted_idx_to_token_idx =
790804
alloc_tensor({max_num_padded_tokens}, dl_int32, hidden_states->device);
791805
// Tensor expert_weights = alloc_tensor(
792-
// {args.num_tokens, args.top_k}, routing_bias_dtype, hidden_states->device);
806+
// {args.num_tokens, args.top_k}, dl_bfloat16, hidden_states->device);
793807
// Tensor expert_indexes = alloc_tensor(
794808
// {args.num_tokens, args.top_k}, dl_int32, hidden_states->device);
795809
int constexpr MAX_NUM_EXPERTS = 384;
@@ -833,21 +847,21 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
833847

834848
tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
835849
cudaStream_t stream = get_stream(hidden_states->device);
836-
routing_runner.run(args.routing_logits, args.routing_bias, args.num_tokens, args.num_experts,
837-
args.top_k, args.n_group, args.topk_group, args.local_expert_offset,
838-
args.local_num_experts, args.routed_scaling_factor,
839-
static_cast<int*>(expert_indices->data),
840-
static_cast<int*>(expert_count_histogram->data),
841-
static_cast<int*>(total_num_padded_tokens->data),
842-
static_cast<int*>(expanded_idx_to_permuted_idx->data),
843-
nullptr, /*static_cast<int*>(permuted_idx_to_expanded_idx->data),*/
844-
static_cast<int*>(permuted_idx_to_token_idx->data), expert_weights->data,
845-
static_cast<int*>(num_tokens_per_expert->data),
846-
static_cast<int*>(cta_idx_xy_to_batch_idx->data),
847-
static_cast<int*>(cta_idx_xy_to_mn_limit->data),
848-
static_cast<int*>(num_non_exiting_ctas->data), args.mDtypeElt,
849-
false /* use_routing_scales_on_input */, false /* use_deep_seek_fp8 */,
850-
static_cast<RoutingMethodType>(routing_method_type), stream);
850+
routing_runner.run(
851+
args.routing_logits, args.routing_bias, args.num_tokens, args.num_experts, args.top_k,
852+
args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts,
853+
args.routed_scaling_factor, static_cast<int*>(expert_indices->data),
854+
static_cast<int*>(expert_count_histogram->data),
855+
static_cast<int*>(total_num_padded_tokens->data),
856+
static_cast<int*>(expanded_idx_to_permuted_idx->data),
857+
nullptr, /*static_cast<int*>(permuted_idx_to_expanded_idx->data),*/
858+
static_cast<int*>(permuted_idx_to_token_idx->data), expert_weights->data,
859+
static_cast<int*>(num_tokens_per_expert->data),
860+
static_cast<int*>(cta_idx_xy_to_batch_idx->data),
861+
static_cast<int*>(cta_idx_xy_to_mn_limit->data),
862+
static_cast<int*>(num_non_exiting_ctas->data), args.mDtypeElt, btg_routing_bias_dtype,
863+
false /* use_routing_scales_on_input */, false /* use_deep_seek_fp8 */,
864+
static_cast<RoutingMethodType>(routing_method_type), stream);
851865

852866
//
853867
// FC13 (gemm1) + FC2 (gemm2)

csrc/trtllm_fused_moe_routing_deepseek.cu

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ __global__ void routingMainKernel(KernelParams params) {
7979
expertSelected = laneIdx < params.mNumExpertsPerGroup;
8080
}
8181
auto scoreIdx = int64_t{blockIdx.x} * int64_t{params.mNumExperts} + threadExpert;
82-
auto biasVal = expertSelected ? params.mPtrRoutingBias[threadExpert] : invalidScore;
83-
82+
auto biasVal =
83+
expertSelected ? static_cast<float>(params.mPtrRoutingBias[threadExpert]) : invalidScoreFloat;
8484
// initialize the mPtrExpertCounts
8585
if (params.mPtrExpertCounts) {
8686
int32_t globalThreadIdx = blockIdx.x * NumThreads + threadIdx.x;
@@ -496,24 +496,24 @@ void runImpl(Data& data, void* stream) {
496496

497497
// Maximum number of tokens supported by the kernel using a cooperative launch.
498498
int const maxTokensCoop = (numBlocksCoop * NumThreads * 64) / data.mTopK;
499-
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
500-
/*coopLaunch=*/false, routingMainKernel, numBlocks, NumThreads,
501-
/*smemSize=*/0, // No dynamic smem
502-
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
499+
LAUNCH_ROUTING_DEEPSEEK(data,
500+
/*coopLaunch=*/false, routingMainKernel, numBlocks, NumThreads,
501+
/*smemSize=*/0, // No dynamic smem
502+
stream, data.mNumExpertGroups > 1);
503503

504504
if (data.mPtrPermutedIdxSize != nullptr) {
505505
if (useSingleCluster) {
506-
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
507-
/*coopLaunch=*/false, routingIndicesClusterKernel,
508-
NumBlocksPerCluster, NumThreads,
509-
/*smemSize=*/0, // No dynamic smem
510-
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
506+
LAUNCH_ROUTING_DEEPSEEK(data,
507+
/*coopLaunch=*/false, routingIndicesClusterKernel,
508+
NumBlocksPerCluster, NumThreads,
509+
/*smemSize=*/0, // No dynamic smem
510+
stream, data.mNumExpertGroups > 1);
511511
} else if (data.mNumTokens <= maxTokensCoop) {
512-
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
513-
/*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop,
514-
NumThreads,
515-
/*smemSize=*/0, // No dynamic smem
516-
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
512+
LAUNCH_ROUTING_DEEPSEEK(data,
513+
/*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop,
514+
NumThreads,
515+
/*smemSize=*/0, // No dynamic smem
516+
stream, data.mNumExpertGroups > 1);
517517
} else {
518518
const int32_t expandedIdxSize = data.mNumTokens * data.mTopK;
519519

@@ -528,16 +528,16 @@ void runImpl(Data& data, void* stream) {
528528
int const numBlocksOffsets =
529529
std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks);
530530

531-
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
532-
/*coopLaunch=*/false, routingIndicesHistogramKernel,
533-
numBlocksHistogram, NumThreads,
534-
/*smemSize=*/0, // No dynamic smem
535-
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
536-
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
537-
/*coopLaunch=*/false, routingIndicesOffsetsKernel,
538-
numBlocksOffsets, NumThreads,
539-
/*smemSize=*/0, // No dynamic smem
540-
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
531+
LAUNCH_ROUTING_DEEPSEEK(data,
532+
/*coopLaunch=*/false, routingIndicesHistogramKernel,
533+
numBlocksHistogram, NumThreads,
534+
/*smemSize=*/0, // No dynamic smem
535+
stream, data.mNumExpertGroups > 1);
536+
LAUNCH_ROUTING_DEEPSEEK(data,
537+
/*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets,
538+
NumThreads,
539+
/*smemSize=*/0, // No dynamic smem
540+
stream, data.mNumExpertGroups > 1);
541541
}
542542
}
543543
}

csrc/trtllm_fused_moe_runner.cu

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,18 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
5555
int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx,
5656
int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert,
5757
int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit,
58-
int32_t* numNonExitingCtas, btg::Dtype dtypeElt, bool useRoutingScalesOnInput,
59-
bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream) {
58+
int32_t* numNonExitingCtas, btg::Dtype dtypeElt, btg::Dtype dtypeBias,
59+
bool useRoutingScalesOnInput, bool useDeepSeekFp8,
60+
RoutingMethodType routingMethodType, cudaStream_t stream) {
6061
if (routingMethodType == RoutingMethodType::DeepSeekV3) {
6162
FLASHINFER_CHECK(topK <= 8, "For DeepSeek routing method, must have topK <= 8");
6263
FLASHINFER_CHECK(topkGroup <= 4, "For DeepSeek routing method, must have topkGroup <= 4");
6364
moe::dev::routing::routingDeepSeek::Data routingData;
64-
routingData.mDtypeExpW = btg::Dtype::Bfloat16;
65+
routingData.mDtypeExpW =
66+
btg::Dtype::Bfloat16; // for DeepSeek, the expW is currently always bfloat16
67+
routingData.mDtypeBias = dtypeBias; // for DeepSeek, the bias can be bfloat16 or fp32
68+
69+
routingData.mDtypeScore = btg::Dtype::Fp32; // for DeepSeek, the score is currently always fp32
6570
routingData.mUsePdl = true;
6671

6772
// output:

0 commit comments

Comments
 (0)