Skip to content

Commit 7995977

Browse files
committed
add new tile token dim (WIP)
1 parent a5f9585 commit 7995977

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ void trtllm_fp8_per_tensor_scale_moe(
369369
auto const hidden_size = hidden_states.size(1);
370370
bool mUseDeepSeekFp8{false}; // FP8 per-tensor doesn't use DeepSeek FP8
371371

372-
std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128};
372+
std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128, 192, 256};
373373
std::set<int32_t> selected_tile_nums =
374374
computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);
375375

@@ -718,7 +718,7 @@ void trtllm_fp8_block_scale_moe(
718718
auto const num_tokens = hidden_states.size(0);
719719
auto const hidden_size = hidden_states.size(1);
720720

721-
std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64};
721+
std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128};
722722
std::set<int32_t> selected_tile_nums =
723723
computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);
724724

@@ -1228,6 +1228,11 @@ Array<Tensor> trtllm_fp4_block_scale_moe(
12281228
if (mDtypeAct != btg::Dtype::Bfloat16) {
12291229
mSupportedTileN.push_back(128);
12301230
}
1231+
if ((mDtypeAct == btg::Dtype::MxE4m3 && mDtypeWeights == btg::Dtype::MxE2m1) ||
1232+
(mDtypeAct == btg::Dtype::E2m1 && mDtypeWeights == btg::Dtype::E2m1)) {
1233+
// MxFP4 x MxFP4 or NvFP4 x NvFP4
1234+
mSupportedTileN.push_back(256);
1235+
}
12311236
std::set<int32_t> selected_tile_nums =
12321237
computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);
12331238
// Build runners for all supported tile sizes
@@ -1305,8 +1310,20 @@ Array<Array<int64_t>> trtllm_get_valid_moe_configs(
13051310
bool is_fp8_per_tensor =
13061311
dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3 && !useDeepSeekFp8;
13071312

1308-
if (is_fp4_without_bf16_act || is_fp8_per_tensor) {
1313+
if (useDeepSeekFp8) {
13091314
supported_tile_nums.push_back(128);
1315+
} else if (is_fp8_per_tensor) {
1316+
supported_tile_nums.push_back(128);
1317+
supported_tile_nums.push_back(192);
1318+
supported_tile_nums.push_back(256);
1319+
} else if (is_fp4_without_bf16_act) {
1320+
supported_tile_nums.push_back(128);
1321+
}
1322+
1323+
if ((dtype_act == btg::Dtype::MxE4m3 && dtype_weights == btg::Dtype::MxE2m1) ||
1324+
(dtype_act == btg::Dtype::E2m1 && dtype_weights == btg::Dtype::E2m1)) {
1325+
// MxFP4 x MxFP4 or NvFP4 x NvFP4
1326+
supported_tile_nums.push_back(256);
13101327
}
13111328
std::set<int32_t> selected_tile_nums =
13121329
computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts);

0 commit comments

Comments
 (0)