Skip to content

Commit a806c14

Browse files
authored
[Performance][LoRA] add context varying params to 'do_not_specialize' in fused moe lora (vllm-project#27445)
Signed-off-by: gnovack <[email protected]>
1 parent 181bf5b commit a806c14

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

vllm/lora/ops/triton_ops/fused_moe_lora_op.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,16 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
3131
return _LORA_PTR_DICT.get(key)
3232

3333

34-
@triton.jit
34+
@triton.jit(
35+
do_not_specialize=[
36+
"num_valid_tokens",
37+
"EM",
38+
"stride_tl",
39+
"stride_el",
40+
"slice_a_size",
41+
"slice_c_size",
42+
]
43+
)
3544
def _fused_moe_lora_kernel(
3645
a_ptr,
3746
b_ptr,
@@ -60,11 +69,11 @@ def _fused_moe_lora_kernel(
6069
stride_cn,
6170
stride_tl,
6271
stride_el,
72+
slice_a_size,
73+
slice_c_size,
6374
# Meta-parameters
6475
num_slice_a: tl.constexpr,
6576
num_slice_c: tl.constexpr,
66-
slice_a_size: tl.constexpr,
67-
slice_c_size: tl.constexpr,
6877
top_k: tl.constexpr,
6978
MUL_ROUTED_WEIGHT: tl.constexpr,
7079
BLOCK_SIZE_M: tl.constexpr,
@@ -256,10 +265,10 @@ def _fused_moe_lora(
256265
a_intermediate_cache1.stride(3),
257266
sorted_token_ids.stride(0),
258267
expert_ids.stride(0),
259-
num_slice_a=1,
260-
num_slice_c=num_slices,
261268
slice_a_size=qcurr_hidden_states.numel(),
262269
slice_c_size=a_intermediate_cache1.numel() // num_slices,
270+
num_slice_a=1,
271+
num_slice_c=num_slices,
263272
top_k=1 if mul_routed_weight else top_k_num,
264273
MUL_ROUTED_WEIGHT=False,
265274
**config,
@@ -305,10 +314,10 @@ def _fused_moe_lora(
305314
b_intermediate_cache1.stride(3),
306315
sorted_token_ids.stride(0),
307316
expert_ids.stride(0),
308-
num_slice_a=num_slices,
309-
num_slice_c=num_slices,
310317
slice_a_size=a_intermediate_cache1.numel() // num_slices,
311318
slice_c_size=b_intermediate_cache1.numel() // num_slices,
319+
num_slice_a=num_slices,
320+
num_slice_c=num_slices,
312321
top_k=1,
313322
MUL_ROUTED_WEIGHT=mul_routed_weight,
314323
**config,

0 commit comments

Comments
 (0)