Skip to content

Commit 98948a4

Browse files
committed
fix perf
Signed-off-by: Enwei Zhu <[email protected]>
1 parent 2717692 commit 98948a4

File tree

4 files changed

+21
-7
lines changed

4 files changed

+21
-7
lines changed

cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ void moePermute(InputType const* input, InputType* permuted_output, SFType const
142142
#endif
143143

144144
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
145-
int32_t const blocks = std::min(smCount, max_num_permuted_tokens);
145+
int32_t const blocks = std::min(smCount * 8, max_num_permuted_tokens);
146146
int32_t const threads = kThreadsPerBlock;
147147

148148
auto kernel = &moePermuteKernel<InputType, SFType, kSFVecSize, kThreadsPerBlock>;
@@ -383,7 +383,7 @@ void moeActivation(InputType const* input, OutputType* output, float const* glob
383383
#endif
384384

385385
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
386-
int32_t const blocks = std::min(smCount, max_num_permuted_tokens);
386+
int32_t const blocks = std::min(smCount * 8, max_num_permuted_tokens);
387387
int32_t const threads = kThreadsPerBlock;
388388

389389
auto get_act_kernel = [](ActivationType activation_type) -> void (*)(InputType const* input, OutputType* output,

cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ std::vector<torch::Tensor> moe_sort(torch::Tensor const& token_selected_experts,
120120
TORCH_CHECK(token_final_scales.size(0) == num_tokens, "token_final_scales.size(0) must be num_tokens.");
121121
TORCH_CHECK(token_final_scales.size(1) == top_k, "token_final_scales.size(1) must be top_k.");
122122
return moe_topk_sort_impl(std::nullopt, std::nullopt, token_selected_experts, token_final_scales, num_experts,
123-
top_k, std::nullopt, std::nullopt, local_expert_offset, local_num_experts, std::nullopt, tile_tokens_dim,
124-
RoutingMethodType::Renormalize);
123+
top_k, 1, 1, local_expert_offset, local_num_experts, std::nullopt, tile_tokens_dim,
124+
RoutingMethodType::DeepSeekV3);
125125
}
126126

127127
// Permute

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,10 @@ def forward(self, inputs: List[torch.Tensor],
676676
mma_tiler_mn=mma_tiler_mn,
677677
cluster_shape_mn=cluster_shape_mn,
678678
)
679+
# Compute max active clusters on current device
680+
hardware_info = cutlass.utils.HardwareInfo()
681+
max_active_clusters = hardware_info.get_max_active_clusters(
682+
cluster_shape_mn[0] * cluster_shape_mn[1])
679683

680684
compiled_gemm = cute.compile(
681685
gemm.wrapper,
@@ -693,7 +697,7 @@ def forward(self, inputs: List[torch.Tensor],
693697
l,
694698
tile_size=self.tile_size,
695699
scaling_vector_size=self.scaling_vector_size,
696-
max_active_clusters=16,
700+
max_active_clusters=max_active_clusters,
697701
stream=stream,
698702
)
699703
self.__class__.kernel_cache[cache_key] = compiled_gemm
@@ -970,6 +974,10 @@ def forward(self, inputs: List[torch.Tensor],
970974
mma_tiler_mn=mma_tiler_mn,
971975
cluster_shape_mn=cluster_shape_mn,
972976
)
977+
# Compute max active clusters on current device
978+
hardware_info = cutlass.utils.HardwareInfo()
979+
max_active_clusters = hardware_info.get_max_active_clusters(
980+
cluster_shape_mn[0] * cluster_shape_mn[1])
973981

974982
compiled_gemm = cute.compile(
975983
gemm.wrapper,
@@ -992,7 +1000,7 @@ def forward(self, inputs: List[torch.Tensor],
9921000
self.top_k,
9931001
tile_size=self.tile_size,
9941002
scaling_vector_size=self.scaling_vector_size,
995-
max_active_clusters=16,
1003+
max_active_clusters=max_active_clusters,
9961004
stream=stream,
9971005
)
9981006
self.__class__.kernel_cache[cache_key] = compiled_gemm

tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,13 @@ def __init__(
209209
tensorrt_llm._torch.models.modeling_deepseekv3.DeepseekV3Gate = gate_cls_orig
210210

211211
def replace_routing_method(self, balance_method: BalanceMethod, balance_ratio: float):
212-
if self.model_config.moe_backend not in ["CUTLASS", "DEEPGEMM", "TRTLLM", "WIDEEP"]:
212+
if self.model_config.moe_backend not in [
213+
"CUTLASS",
214+
"DEEPGEMM",
215+
"TRTLLM",
216+
"WIDEEP",
217+
"CUTEDSL",
218+
]:
213219
raise NotImplementedError(
214220
f'Not support replace routing method for moe_backend "{self.model_config.moe_backend}",'
215221
f' please set balance_method to "NotModified"'

0 commit comments

Comments (0)