|
| 1 | +from typing import Dict, Optional |
| 2 | + |
| 3 | +import torch |
| 4 | + |
| 5 | +try: |
| 6 | + from vllm._C import cache_ops as vllm_cache_ops |
| 7 | + from vllm._C import ops as vllm_ops |
| 8 | +except ImportError: |
| 9 | + pass |
| 10 | + |
| 11 | + |
| 12 | +# activation ops |
| 13 | +def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: |
| 14 | + vllm_ops.silu_and_mul(out, x) |
| 15 | + |
| 16 | + |
| 17 | +def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: |
| 18 | + vllm_ops.gelu_and_mul(out, x) |
| 19 | + |
| 20 | + |
| 21 | +def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: |
| 22 | + vllm_ops.gelu_tanh_and_mul(out, x) |
| 23 | + |
| 24 | + |
| 25 | +def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: |
| 26 | + vllm_ops.gelu_fast(out, x) |
| 27 | + |
| 28 | + |
| 29 | +def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: |
| 30 | + vllm_ops.gelu_new(out, x) |
| 31 | + |
| 32 | + |
| 33 | +# page attention ops |
| 34 | +def paged_attention_v1( |
| 35 | + out: torch.Tensor, |
| 36 | + query: torch.Tensor, |
| 37 | + key_cache: torch.Tensor, |
| 38 | + value_cache: torch.Tensor, |
| 39 | + num_kv_heads: int, |
| 40 | + scale: float, |
| 41 | + block_tables: torch.Tensor, |
| 42 | + context_lens: torch.Tensor, |
| 43 | + block_size: int, |
| 44 | + max_context_len: int, |
| 45 | + alibi_slopes: Optional[torch.Tensor], |
| 46 | + kv_cache_dtype: str, |
| 47 | + kv_scale: float, |
| 48 | +) -> None: |
| 49 | + vllm_ops.paged_attention_v1(out, query, key_cache, value_cache, |
| 50 | + num_kv_heads, scale, block_tables, |
| 51 | + context_lens, block_size, max_context_len, |
| 52 | + alibi_slopes, kv_cache_dtype, kv_scale) |
| 53 | + |
| 54 | + |
| 55 | +def paged_attention_v2( |
| 56 | + out: torch.Tensor, |
| 57 | + exp_sum: torch.Tensor, |
| 58 | + max_logits: torch.Tensor, |
| 59 | + tmp_out: torch.Tensor, |
| 60 | + query: torch.Tensor, |
| 61 | + key_cache: torch.Tensor, |
| 62 | + value_cache: torch.Tensor, |
| 63 | + num_kv_heads: int, |
| 64 | + scale: float, |
| 65 | + block_tables: torch.Tensor, |
| 66 | + context_lens: torch.Tensor, |
| 67 | + block_size: int, |
| 68 | + max_context_len: int, |
| 69 | + alibi_slopes: Optional[torch.Tensor], |
| 70 | + kv_cache_dtype: str, |
| 71 | + kv_scale: float, |
| 72 | +) -> None: |
| 73 | + vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, |
| 74 | + key_cache, value_cache, num_kv_heads, scale, |
| 75 | + block_tables, context_lens, block_size, |
| 76 | + max_context_len, alibi_slopes, kv_cache_dtype, |
| 77 | + kv_scale) |
| 78 | + |
| 79 | + |
| 80 | +# pos encoding ops |
| 81 | +def rotary_embedding( |
| 82 | + positions: torch.Tensor, |
| 83 | + query: torch.Tensor, |
| 84 | + key: torch.Tensor, |
| 85 | + head_size: int, |
| 86 | + cos_sin_cache: torch.Tensor, |
| 87 | + is_neox: bool, |
| 88 | +) -> None: |
| 89 | + vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache, |
| 90 | + is_neox) |
| 91 | + |
| 92 | + |
| 93 | +def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, |
| 94 | + key: torch.Tensor, head_size: int, |
| 95 | + cos_sin_cache: torch.Tensor, is_neox: bool, |
| 96 | + rot_dim: int, |
| 97 | + cos_sin_cache_offsets: torch.Tensor) -> None: |
| 98 | + vllm_ops.batched_rotary_embedding(positions, query, key, head_size, |
| 99 | + cos_sin_cache, is_neox, rot_dim, |
| 100 | + cos_sin_cache_offsets) |
| 101 | + |
| 102 | + |
| 103 | +# layer norm ops |
| 104 | +def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, |
| 105 | + epsilon: float) -> None: |
| 106 | + vllm_ops.rms_norm(out, input, weight, epsilon) |
| 107 | + |
| 108 | + |
| 109 | +def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, |
| 110 | + weight: torch.Tensor, epsilon: float) -> None: |
| 111 | + vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon) |
| 112 | + |
| 113 | + |
| 114 | +# quantization ops |
| 115 | +# awq |
| 116 | +def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, |
| 117 | + zeros: torch.Tensor, split_k_iters: int, thx: int, |
| 118 | + thy: int) -> torch.Tensor: |
| 119 | + return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, |
| 120 | + thy) |
| 121 | + |
| 122 | + |
| 123 | +def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, |
| 124 | + scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: |
| 125 | + return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters) |
| 126 | + |
| 127 | + |
| 128 | +# gptq |
| 129 | +def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, |
| 130 | + b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, |
| 131 | + b_g_idx: torch.Tensor, use_exllama: bool, |
| 132 | + bit: int) -> torch.Tensor: |
| 133 | + return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, |
| 134 | + b_g_idx, use_exllama, bit) |
| 135 | + |
| 136 | + |
| 137 | +def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, |
| 138 | + bit: int) -> None: |
| 139 | + vllm_ops.gptq_shuffle(q_weight, q_perm, bit) |
| 140 | + |
| 141 | + |
| 142 | +# squeezellm |
| 143 | +def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor, |
| 144 | + lookup_table: torch.Tensor) -> None: |
| 145 | + vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table) |
| 146 | + |
| 147 | + |
| 148 | +# marlin |
| 149 | +def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, |
| 150 | + b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, |
| 151 | + size_n: int, size_k: int) -> torch.Tensor: |
| 152 | + return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, |
| 153 | + size_n, size_k) |
| 154 | + |
| 155 | + |
| 156 | +# moe |
| 157 | +def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, |
| 158 | + block_size: int, sorted_token_ids: torch.Tensor, |
| 159 | + experts_ids: torch.Tensor, |
| 160 | + num_tokens_post_pad: torch.Tensor) -> None: |
| 161 | + vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size, |
| 162 | + sorted_token_ids, experts_ids, |
| 163 | + num_tokens_post_pad) |
| 164 | + |
| 165 | + |
| 166 | +def reshape_and_cache( |
| 167 | + key: torch.Tensor, |
| 168 | + value: torch.Tensor, |
| 169 | + key_cache: torch.Tensor, |
| 170 | + value_cache: torch.Tensor, |
| 171 | + slot_mapping: torch.Tensor, |
| 172 | + kv_cache_dtype: str, |
| 173 | + kv_scale: float, |
| 174 | +) -> None: |
| 175 | + vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, |
| 176 | + slot_mapping, kv_cache_dtype, kv_scale) |
| 177 | + |
| 178 | + |
| 179 | +def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, |
| 180 | + block_mapping: torch.Tensor) -> None: |
| 181 | + vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) |
| 182 | + |
| 183 | + |
| 184 | +def swap_blocks(src: torch.Tensor, dst: torch.Tensor, |
| 185 | + block_mapping: Dict[int, int]) -> None: |
| 186 | + vllm_cache_ops.swap_blocks(src, dst, block_mapping) |
| 187 | + |
| 188 | + |
| 189 | +def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None: |
| 190 | + vllm_cache_ops.convert_fp8(output, input) |
| 191 | + |
| 192 | + |
| 193 | +#TODO: cuda_utils, custom_ar |
0 commit comments