@@ -369,7 +369,7 @@ void trtllm_fp8_per_tensor_scale_moe(
369369 auto const hidden_size = hidden_states.size (1 );
370370 bool mUseDeepSeekFp8 {false }; // FP8 per-tensor doesn't use DeepSeek FP8
371371
372- std::vector<int32_t > mSupportedTileN = {8 , 16 , 32 , 64 , 128 };
372+ std::vector<int32_t > mSupportedTileN = {8 , 16 , 32 , 64 , 128 , 192 , 256 };
373373 std::set<int32_t > selected_tile_nums =
374374 computeSelectedTileN (mSupportedTileN , num_tokens, top_k, local_num_experts);
375375
@@ -718,7 +718,7 @@ void trtllm_fp8_block_scale_moe(
718718 auto const num_tokens = hidden_states.size (0 );
719719 auto const hidden_size = hidden_states.size (1 );
720720
721- std::vector<int32_t > mSupportedTileN = {8 , 16 , 32 , 64 };
721+ std::vector<int32_t > mSupportedTileN = {8 , 16 , 32 , 64 , 128 };
722722 std::set<int32_t > selected_tile_nums =
723723 computeSelectedTileN (mSupportedTileN , num_tokens, top_k, local_num_experts);
724724
@@ -1228,6 +1228,11 @@ Array<Tensor> trtllm_fp4_block_scale_moe(
12281228 if (mDtypeAct != btg::Dtype::Bfloat16) {
12291229 mSupportedTileN .push_back (128 );
12301230 }
1231+ if ((mDtypeAct == btg::Dtype::MxE4m3 && mDtypeWeights == btg::Dtype::MxE2m1) ||
1232+ (mDtypeAct == btg::Dtype::E2m1 && mDtypeWeights == btg::Dtype::E2m1)) {
1233+ // MxFP8 activations x MxFP4 weights, or NvFP4 x NvFP4
1234+ mSupportedTileN .push_back (256 );
1235+ }
12311236 std::set<int32_t > selected_tile_nums =
12321237 computeSelectedTileN (mSupportedTileN , num_tokens, top_k, local_num_experts);
12331238 // Build runners for all supported tile sizes
@@ -1305,8 +1310,20 @@ Array<Array<int64_t>> trtllm_get_valid_moe_configs(
13051310 bool is_fp8_per_tensor =
13061311 dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3 && !useDeepSeekFp8;
13071312
1308- if (is_fp4_without_bf16_act || is_fp8_per_tensor ) {
1313+ if (useDeepSeekFp8 ) {
13091314 supported_tile_nums.push_back (128 );
1315+ } else if (is_fp8_per_tensor) {
1316+ supported_tile_nums.push_back (128 );
1317+ supported_tile_nums.push_back (192 );
1318+ supported_tile_nums.push_back (256 );
1319+ } else if (is_fp4_without_bf16_act) {
1320+ supported_tile_nums.push_back (128 );
1321+ }
1322+
1323+ if ((dtype_act == btg::Dtype::MxE4m3 && dtype_weights == btg::Dtype::MxE2m1) ||
1324+ (dtype_act == btg::Dtype::E2m1 && dtype_weights == btg::Dtype::E2m1)) {
1325+ // MxFP8 activations x MxFP4 weights, or NvFP4 x NvFP4
1326+ supported_tile_nums.push_back (256 );
13101327 }
13111328 std::set<int32_t > selected_tile_nums =
13121329 computeSelectedTileN (supported_tile_nums, num_tokens, top_k, num_local_experts);
0 commit comments