
Commit 08fa4ed

Use lower case for backend name.
Signed-off-by: Bo Li <[email protected]>
1 parent 90686db commit 08fa4ed

File tree

1 file changed: +8 −8 lines changed


tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py

Lines changed: 8 additions & 8 deletions
@@ -113,13 +113,13 @@ def __init__(
         self.alltoall_workspace = None
         self.alltoall_prepare_workspace = None
         if self.enable_alltoall:
-            if self.moe_alltoall_backend == "MnnvlLatency":
+            if self.moe_alltoall_backend == "mnnvllatency":
                 MnnvlMemory.initialize()
                 self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
                     model_config.mapping)
                 self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
                     model_config.mapping)
-            elif self.moe_alltoall_backend == "MnnvlThroughput":
+            elif self.moe_alltoall_backend == "mnnvlthroughput":
                 workspace_mb = int(
                     os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "512"))
                 self.moe_a2a = MoeAlltoAll(

@@ -149,9 +149,9 @@ def enable_alltoall(self):

     @cached_property
     def moe_alltoall_backend(self):
-        # "MnnvlLatency" (default) or "MnnvlThroughput"
+        # "mnnvllatency" (default) or "mnnvlthroughput"
         return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
-                              "MnnvlLatency").strip().lower()
+                              "mnnvllatency").strip().lower()

     def _check_configs(self):
         assert self.has_deepseek_fp8_block_scales \

@@ -320,7 +320,7 @@ def forward_impl(
         else:
             token_final_scales = token_final_scales.to(torch.float32)

-        if self.moe_alltoall_backend == "MnnvlLatency":
+        if self.moe_alltoall_backend == "mnnvllatency":
             assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
             alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
                 token_selected_experts,

@@ -360,7 +360,7 @@ def forward_impl(

             if token_final_scales is not None:
                 token_final_scales = token_final_scales.to(torch.bfloat16)
-        elif self.moe_alltoall_backend == "MnnvlThroughput":
+        elif self.moe_alltoall_backend == "mnnvlthroughput":
             if x_sf is not None:
                 x_sf = x_sf.view(x_row,
                                  ceil_div(x_col, self.scaling_vector_size))

@@ -667,7 +667,7 @@ def forward_impl(

         # Combine results if using alltoall
         if self.enable_alltoall:
-            if self.moe_alltoall_backend == "MnnvlLatency":
+            if self.moe_alltoall_backend == "mnnvllatency":
                 if alltoall_info is not None:
                     final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
                         final_hidden_states,

@@ -678,7 +678,7 @@ def forward_impl(
                     top_k=top_k,
                     token_count=token_count,
                 )
-            elif self.moe_alltoall_backend == "MnnvlThroughput":
+            elif self.moe_alltoall_backend == "mnnvlthroughput":
                 hidden = final_hidden_states.shape[-1]
                 payload = final_hidden_states.view(
                     self.ep_size, self.moe_a2a.max_num_tokens_per_rank, hidden)
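Why lowercase: the moe_alltoall_backend property normalizes the TRTLLM_MOE_ALLTOALL_BACKEND environment variable with .strip().lower() before it is compared, so the mixed-case literals "MnnvlLatency" / "MnnvlThroughput" could never match the normalized value. A minimal standalone sketch of that selection pattern follows; the select_backend helper and its return strings are illustrative only, not part of the file.

    import os

    # Mirrors the normalization in moe_alltoall_backend: the env var is
    # lowercased, so the branch literals must be lowercase as well.
    def select_backend() -> str:
        backend = os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
                                 "mnnvllatency").strip().lower()
        if backend == "mnnvllatency":
            return "latency-optimized MNNVL all-to-all"
        elif backend == "mnnvlthroughput":
            return "throughput-optimized MNNVL all-to-all"
        raise ValueError(f"Unknown TRTLLM_MOE_ALLTOALL_BACKEND: {backend!r}")

    if __name__ == "__main__":
        # Either casing in the environment resolves to the same branch.
        os.environ["TRTLLM_MOE_ALLTOALL_BACKEND"] = "MnnvlThroughput"
        print(select_backend())  # throughput-optimized MNNVL all-to-all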
