Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/sglang/srt/mem_cache/radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import heapq
import time
import warnings
from collections import defaultdict
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Tuple
Expand Down
38 changes: 35 additions & 3 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""

import copy
import logging
import os
from dataclasses import dataclass
Expand Down Expand Up @@ -1093,6 +1094,9 @@ def forward_normal_from_cache(
forward_batch.token_to_kv_pool.set_kv_buffer(
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
)

k_current = k
v_current = v

# Fetch latent cache from memory pool with precomputed chunked kv indices
latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
Expand All @@ -1110,11 +1114,13 @@ def forward_normal_from_cache(
chunk_len = forward_batch.extend_seq_lens_cpu[ibatch]

q_chunk = q[acc_chunk_len : acc_chunk_len + chunk_len][None, ...]
k_chunk = k_current[acc_chunk_len : acc_chunk_len + chunk_len][None, ...]
v_chunk = v_current[acc_chunk_len : acc_chunk_len + chunk_len][None, ...]

acc_chunk_len += chunk_len

latent_cache = latent_cache_buf[
block_table[ibatch : ibatch + 1, : prefix_len + chunk_len]
block_table[ibatch : ibatch + 1, : prefix_len]
]

kv_a_normed, k_pe = latent_cache.split(
Expand All @@ -1128,7 +1134,7 @@ def forward_normal_from_cache(
v = kv[..., self.qk_nope_head_dim :]
k_nope = kv[..., : self.qk_nope_head_dim]

k = torch.empty(
k = torch.zeros(
(
k_nope.shape[0],
self.num_local_heads,
Expand All @@ -1139,8 +1145,34 @@ def forward_normal_from_cache(
)
k[..., : self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim :] = k_pe

# k = k[:-k_chunk.shape[1]]
# v = v[:-k_chunk.shape[1]]

k = torch.cat([k, k_chunk[0]], dim=0)
v = torch.cat([v, v_chunk[0]], dim=0)

current_forward_batch = copy.copy(forward_batch)
current_forward_batch.batch_size = 1
current_forward_batch.req_pool_indices = forward_batch.req_pool_indices[ibatch:ibatch+1]
current_forward_batch.extend_seq_lens = forward_batch.extend_seq_lens[ibatch: ibatch+1]
current_forward_batch.extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu[ibatch: ibatch+1]
current_forward_batch.positions = forward_batch.positions[acc_chunk_len:acc_chunk_len + chunk_len]
# cache_loc = (
# forward_batch.out_cache_loc
# if not layer.is_cross_attention
# else forward_batch.encoder_out_cache_loc
# )
assert not self.attn_mha.is_cross_attention
current_forward_batch.out_cache_loc = forward_batch.out_cache_loc[acc_chunk_len:acc_chunk_len + chunk_len]

output = self.attn_mha(q_chunk, k, v, forward_batch, save_kv_cache=False)
output = self.attn_mha(
q_chunk,
k,
v,
forward_batch,
save_kv_cache=False
)

outputs.append(output)
attn_output = torch.cat(outputs, dim=0)
Expand Down
8 changes: 3 additions & 5 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,16 +1366,14 @@ def from_cli_args(cls, args: argparse.Namespace):

if args.enable_hip_attention:
from hip_attn.v1_2 import HiPAttentionConfig

if args.hip_attention_config_path is not None:
json_or_path = args.hip_attention_config_path
else:
assert hasattr(args, 'hip_attention_config')
assert hasattr(args, "hip_attention_config")
json_or_path = args.hip_attention_config

args.hip_attention_config = HiPAttentionConfig(
json_or_path=json_or_path
)
args.hip_attention_config = HiPAttentionConfig(json_or_path=json_or_path)
logger.info(
f"attention_backend changed {args.attention_backend} -> hip_attention"
)
Expand Down
Loading