@@ -31,7 +31,16 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
     return _LORA_PTR_DICT.get(key)
 
 
-@triton.jit
+@triton.jit(
+    do_not_specialize=[
+        "num_valid_tokens",
+        "EM",
+        "stride_tl",
+        "stride_el",
+        "slice_a_size",
+        "slice_c_size",
+    ]
+)
 def _fused_moe_lora_kernel(
     a_ptr,
     b_ptr,
@@ -60,11 +69,11 @@ def _fused_moe_lora_kernel(
     stride_cn,
     stride_tl,
     stride_el,
+    slice_a_size,
+    slice_c_size,
     # Meta-parameters
     num_slice_a: tl.constexpr,
     num_slice_c: tl.constexpr,
-    slice_a_size: tl.constexpr,
-    slice_c_size: tl.constexpr,
     top_k: tl.constexpr,
     MUL_ROUTED_WEIGHT: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
@@ -256,10 +265,10 @@ def _fused_moe_lora(
         a_intermediate_cache1.stride(3),
         sorted_token_ids.stride(0),
         expert_ids.stride(0),
-        num_slice_a=1,
-        num_slice_c=num_slices,
         slice_a_size=qcurr_hidden_states.numel(),
         slice_c_size=a_intermediate_cache1.numel() // num_slices,
+        num_slice_a=1,
+        num_slice_c=num_slices,
         top_k=1 if mul_routed_weight else top_k_num,
         MUL_ROUTED_WEIGHT=False,
         **config,
@@ -305,10 +314,10 @@ def _fused_moe_lora(
         b_intermediate_cache1.stride(3),
         sorted_token_ids.stride(0),
         expert_ids.stride(0),
-        num_slice_a=num_slices,
-        num_slice_c=num_slices,
         slice_a_size=a_intermediate_cache1.numel() // num_slices,
         slice_c_size=b_intermediate_cache1.numel() // num_slices,
+        num_slice_a=num_slices,
+        num_slice_c=num_slices,
         top_k=1,
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         **config,
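
The diff moves `slice_a_size` and `slice_c_size` from `tl.constexpr` meta-parameters to ordinary runtime arguments and lists them, together with the other value-varying integers, in `do_not_specialize`. Triton bakes `tl.constexpr` arguments into the compiled binary (one variant per distinct value) and, by default, also specializes plain integer arguments on properties such as equality to 1 or divisibility by 16; both behaviors can trigger recompiles as shapes change between calls. Exempting these arguments keeps one compiled kernel across calls. A minimal sketch of the same technique, separate from this PR (the kernel and names are made up for illustration; a CUDA device is assumed):

```python
import torch
import triton
import triton.language as tl


# Listing "n" in do_not_specialize stops Triton from specializing the kernel
# on n's value (e.g. n == 1 or n % 16 == 0), so varying n at runtime reuses
# the same compiled kernel instead of triggering a recompile.
@triton.jit(do_not_specialize=["n"])
def copy_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    # Elementwise copy of the first n values, guarded by a bounds mask.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)


x = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
for n in (4096, 4095, 16):  # different n values share one compiled variant
    copy_kernel[(triton.cdiv(n, 256),)](x, out, n, BLOCK=256)
```

The trade-off is that disabling specialization forgoes divisibility-based codegen optimizations for those arguments, exchanging some peak kernel quality for far fewer recompilations.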