@@ -113,13 +113,13 @@ def __init__(
         self.alltoall_workspace = None
         self.alltoall_prepare_workspace = None
         if self.enable_alltoall:
-            if self.moe_alltoall_backend == "MnnvlLatency":
+            if self.moe_alltoall_backend == "mnnvllatency":
                 MnnvlMemory.initialize()
                 self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
                     model_config.mapping)
                 self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
                     model_config.mapping)
-            elif self.moe_alltoall_backend == "MnnvlThroughput":
+            elif self.moe_alltoall_backend == "mnnvlthroughput":
                 workspace_mb = int(
                     os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "512"))
                 self.moe_a2a = MoeAlltoAll(
@@ -149,9 +149,9 @@ def enable_alltoall(self):
 
     @cached_property
     def moe_alltoall_backend(self):
-        # "MnnvlLatency" (default) or "MnnvlThroughput"
+        # "mnnvllatency" (default) or "mnnvlthroughput"
         return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
-                              "MnnvlLatency").strip().lower()
+                              "mnnvllatency").strip().lower()
 
     def _check_configs(self):
         assert self.has_deepseek_fp8_block_scales \
@@ -320,7 +320,7 @@ def forward_impl(
             else:
                 token_final_scales = token_final_scales.to(torch.float32)
 
-            if self.moe_alltoall_backend == "MnnvlLatency":
+            if self.moe_alltoall_backend == "mnnvllatency":
                 assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
                 alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
                     token_selected_experts,
@@ -360,7 +360,7 @@ def forward_impl(
 
                 if token_final_scales is not None:
                     token_final_scales = token_final_scales.to(torch.bfloat16)
-            elif self.moe_alltoall_backend == "MnnvlThroughput":
+            elif self.moe_alltoall_backend == "mnnvlthroughput":
                 if x_sf is not None:
                     x_sf = x_sf.view(x_row,
                                      ceil_div(x_col, self.scaling_vector_size))
@@ -667,7 +667,7 @@ def forward_impl(
 
         # Combine results if using alltoall
        if self.enable_alltoall:
-            if self.moe_alltoall_backend == "MnnvlLatency":
+            if self.moe_alltoall_backend == "mnnvllatency":
                if alltoall_info is not None:
                    final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
                        final_hidden_states,
@@ -678,7 +678,7 @@ def forward_impl(
                        top_k=top_k,
                        token_count=token_count,
                    )
-            elif self.moe_alltoall_backend == "MnnvlThroughput":
+            elif self.moe_alltoall_backend == "mnnvlthroughput":
                hidden = final_hidden_states.shape[-1]
                payload = final_hidden_states.view(
                    self.ep_size, self.moe_a2a.max_num_tokens_per_rank, hidden)
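Why the literals change: the moe_alltoall_backend property normalizes the TRTLLM_MOE_ALLTOALL_BACKEND environment variable with .strip().lower(), so whatever casing a user exports (for example "MnnvlThroughput") resolves to lower case, and the previous mixed-case comparisons could never match. Below is a minimal, self-contained sketch of that behavior under the env-var handling shown in the diff; the resolve_backend helper is hypothetical and only stands in for the cached_property.

import os

# Hypothetical stand-in for the moe_alltoall_backend cached_property:
# read the env var, fall back to the latency backend, normalize casing.
def resolve_backend() -> str:
    return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
                          "mnnvllatency").strip().lower()

# A user may export the value with the original mixed casing.
os.environ["TRTLLM_MOE_ALLTOALL_BACKEND"] = "MnnvlThroughput"

backend = resolve_backend()
# Normalization means the mixed-case literal "MnnvlThroughput" never appears
# downstream, so comparisons must use the lower-case literals.
assert backend == "mnnvlthroughput"

if backend == "mnnvllatency":
    print("latency-optimized MNNVL all-to-all path")
elif backend == "mnnvlthroughput":
    print("throughput-optimized MNNVL all-to-all path")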