
Commit e9da5a4

[Misc] Add indirection layer for custom ops (vllm-project#3913)
1 parent e42df72 commit e9da5a4

14 files changed: +224 -32 lines changed

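In short: call sites stop importing the compiled vllm._C extension modules directly and instead go through a new pure-Python module, vllm/_custom_ops.py, that forwards to them. A minimal sketch of the import swap applied across the files below (the exact lines are in each per-file diff):

# Before: bind directly to the compiled extension modules.
# from vllm._C import cache_ops, ops

# After: import the Python indirection layer, which forwards calls to vllm._C.
from vllm import _custom_ops as ops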

benchmarks/kernels/benchmark_paged_attention.py

+1 -1

@@ -5,7 +5,7 @@

 import torch

-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random

 NUM_BLOCKS = 1024

tests/kernels/test_attention.py

+3 -3

@@ -7,7 +7,7 @@
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

-from vllm._C import cache_ops, ops
+from vllm import _custom_ops as ops
 from vllm.utils import get_max_shared_memory_bytes, is_hip

 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
@@ -237,14 +237,14 @@ def test_paged_attention(
         dequantized_key_cache = torch.empty(size=key_cache_shape,
                                             dtype=dtype,
                                             device=device)
-        cache_ops.convert_fp8(key_cache, dequantized_key_cache)
+        ops.convert_fp8(key_cache, dequantized_key_cache)
         key_cache = dequantized_key_cache

         value_cache_shape = value_cache.shape
         dequantized_value_cache = torch.empty(size=value_cache_shape,
                                               dtype=dtype,
                                               device=device)
-        cache_ops.convert_fp8(value_cache, dequantized_value_cache)
+        ops.convert_fp8(value_cache, dequantized_value_cache)
         value_cache = dequantized_value_cache

     ref_output = torch.empty_like(query)

tests/kernels/test_cache.py

+12 -13

@@ -4,7 +4,7 @@
 import pytest
 import torch

-from vllm._C import cache_ops
+from vllm import _custom_ops as ops
 from vllm.utils import is_hip

 COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
@@ -80,7 +80,7 @@ def test_copy_blocks(
     cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

     # Call the copy blocks kernel.
-    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
+    ops.copy_blocks(key_caches, value_caches, block_mapping)

     # Run the reference implementation.
     for src, dsts in block_mapping.items():
@@ -145,9 +145,9 @@ def test_reshape_and_cache(
     # Clone the KV caches.
     if kv_cache_dtype == "fp8":
         cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        cache_ops.convert_fp8(key_cache, cloned_key_cache)
+        ops.convert_fp8(key_cache, cloned_key_cache)
         cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        cache_ops.convert_fp8(value_cache, cloned_value_cache)
+        ops.convert_fp8(value_cache, cloned_value_cache)
     else:
         cloned_key_cache = key_cache.clone()
         cloned_value_cache = value_cache.clone()
@@ -156,14 +156,14 @@ def test_reshape_and_cache(
     kv_scale = 1.0

     # Call the reshape_and_cache kernel.
-    cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
-                                slot_mapping, kv_cache_dtype, kv_scale)
+    ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
+                          kv_cache_dtype, kv_scale)

     if kv_cache_dtype == "fp8":
         result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        cache_ops.convert_fp8(key_cache, result_key_cache)
+        ops.convert_fp8(key_cache, result_key_cache)
         result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        cache_ops.convert_fp8(value_cache, result_value_cache)
+        ops.convert_fp8(value_cache, result_value_cache)

     # Run the reference implementation.
     reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
@@ -251,9 +251,8 @@ def test_swap_blocks(
     src_value_caches_clone = src_value_caches[0].clone()

     # Call the swap_blocks kernel.
-    cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
-    cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
-                          block_mapping)
+    ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
+    ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping)

     for src, dst in block_mapping.items():
         assert torch.allclose(src_key_caches_clone[src].cpu(),
@@ -291,9 +290,9 @@ def test_fp8_conversion(
     cache.uniform_(low, high)

     cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
-    cache_ops.convert_fp8(cache, cache_fp8)
+    ops.convert_fp8(cache, cache_fp8)

     converted_cache = torch.empty_like(cache)
-    cache_ops.convert_fp8(cache_fp8, converted_cache)
+    ops.convert_fp8(cache_fp8, converted_cache)

     assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)

vllm/_custom_ops.py

+193 (new file; full contents below)

from typing import Dict, Optional

import torch

try:
    from vllm._C import cache_ops as vllm_cache_ops
    from vllm._C import ops as vllm_ops
except ImportError:
    pass


# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    vllm_ops.silu_and_mul(out, x)


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    vllm_ops.gelu_and_mul(out, x)


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    vllm_ops.gelu_tanh_and_mul(out, x)


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    vllm_ops.gelu_fast(out, x)


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
    vllm_ops.gelu_new(out, x)


# page attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    block_size: int,
    max_context_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    kv_scale: float,
) -> None:
    vllm_ops.paged_attention_v1(out, query, key_cache, value_cache,
                                num_kv_heads, scale, block_tables,
                                context_lens, block_size, max_context_len,
                                alibi_slopes, kv_cache_dtype, kv_scale)


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    block_size: int,
    max_context_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    kv_scale: float,
) -> None:
    vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query,
                                key_cache, value_cache, num_kv_heads, scale,
                                block_tables, context_lens, block_size,
                                max_context_len, alibi_slopes, kv_cache_dtype,
                                kv_scale)


# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
                              is_neox)


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                             key: torch.Tensor, head_size: int,
                             cos_sin_cache: torch.Tensor, is_neox: bool,
                             rot_dim: int,
                             cos_sin_cache_offsets: torch.Tensor) -> None:
    vllm_ops.batched_rotary_embedding(positions, query, key, head_size,
                                      cos_sin_cache, is_neox, rot_dim,
                                      cos_sin_cache_offsets)


# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
             epsilon: float) -> None:
    vllm_ops.rms_norm(out, input, weight, epsilon)


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, epsilon: float) -> None:
    vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon)


# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx,
                                   thy)


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
    return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
                              b_g_idx, use_exllama, bit)


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
    vllm_ops.gptq_shuffle(q_weight, q_perm, bit)


# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
                    lookup_table: torch.Tensor) -> None:
    vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table)


# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
                size_n: int, size_k: int) -> torch.Tensor:
    return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
                                size_n, size_k)


# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
                         experts_ids: torch.Tensor,
                         num_tokens_post_pad: torch.Tensor) -> None:
    vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size,
                                  sorted_token_ids, experts_ids,
                                  num_tokens_post_pad)


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    kv_scale: float,
) -> None:
    vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
                                     slot_mapping, kv_cache_dtype, kv_scale)


def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)


def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                block_mapping: Dict[int, int]) -> None:
    vllm_cache_ops.swap_blocks(src, dst, block_mapping)


def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None:
    vllm_cache_ops.convert_fp8(output, input)


#TODO: cuda_utils, custom_ar
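
For context, a minimal usage sketch of the new wrapper module (not part of the commit): it assumes vLLM is installed with its CUDA extensions built, and the tensor shapes, dtype, and device below are illustrative placeholders.

import torch

from vllm import _custom_ops as ops  # forwards to the compiled vllm._C ops

# silu_and_mul writes silu(x[..., :d]) * x[..., d:] into out, so the output
# has half the width of the input along the last dimension.
num_tokens, d = 4, 128
x = torch.randn(num_tokens, 2 * d, dtype=torch.float16, device="cuda")
out = torch.empty(num_tokens, d, dtype=torch.float16, device="cuda")
ops.silu_and_mul(out, x)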

vllm/attention/ops/paged_attn.py

+5 -5

@@ -3,7 +3,7 @@

 import torch

-from vllm._C import cache_ops, ops
+from vllm import _custom_ops as ops
 from vllm.attention.ops.prefix_prefill import context_attention_fwd

 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -69,7 +69,7 @@ def write_to_paged_cache(
         kv_cache_dtype: str,
         kv_scale: float,
     ) -> None:
-        cache_ops.reshape_and_cache(
+        ops.reshape_and_cache(
             key,
             value,
             key_cache,
@@ -199,11 +199,11 @@ def swap_blocks(
     ) -> None:
         src_key_cache = src_kv_cache[0]
         dst_key_cache = dst_kv_cache[0]
-        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)

         src_value_cache = src_kv_cache[1]
         dst_value_cache = dst_kv_cache[1]
-        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
+        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)

     @staticmethod
     def copy_blocks(
@@ -212,4 +212,4 @@ def copy_blocks(
     ) -> None:
         key_caches = [kv_cache[0] for kv_cache in kv_caches]
         value_caches = [kv_cache[1] for kv_cache in kv_caches]
-        cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
+        ops.copy_blocks(key_caches, value_caches, src_to_dists)

vllm/model_executor/layers/activation.py

+1 -1

@@ -6,7 +6,7 @@
 import torch.nn as nn
 import torch.nn.functional as F

-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.quantization import QuantizationConfig

vllm/model_executor/layers/fused_moe/fused_moe.py

+1 -1

@@ -8,7 +8,7 @@
 import triton
 import triton.language as tl

-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.utils import is_hip

vllm/model_executor/layers/layernorm.py

+1 -1

@@ -4,7 +4,7 @@
 import torch
 import torch.nn as nn

-from vllm._C import ops
+from vllm import _custom_ops as ops


 class RMSNorm(nn.Module):

vllm/model_executor/layers/quantization/awq.py

+1 -1

@@ -3,7 +3,7 @@
 import torch
 from torch.nn.parameter import Parameter

-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (

vllm/model_executor/layers/quantization/gptq.py

+1 -1

@@ -6,7 +6,7 @@
 import torch
 from torch.nn.parameter import Parameter

-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (

vllm/model_executor/layers/quantization/marlin.py

+1 -1

@@ -3,7 +3,7 @@
 import torch
 from torch.nn.parameter import Parameter

-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
