2 changes: 2 additions & 0 deletions lightllm/common/kv_cache_mem_manager/__init__.py
@@ -6,6 +6,7 @@
from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
from .deepseek2_mem_manager import Deepseek2MemoryManager
from .deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager
from .neo_mem_manager import NeoMemoryManager

__all__ = [
"MemoryManager",
@@ -17,4 +18,5 @@
"PPLINT8KVMemoryManager",
"Deepseek2MemoryManager",
"Deepseek2FP8KVMemoryManager",
"NeoMemoryManager",
]
2 changes: 1 addition & 1 deletion lightllm/common/kv_cache_mem_manager/mem_manager.py
@@ -60,7 +60,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
self.size,
dtype,
head_num,
head_dim,
self.head_dim,
layer_num,
)
self.HOLD_TOKEN_MEMINDEX = self.size
7 changes: 6 additions & 1 deletion lightllm/common/kv_cache_mem_manager/mem_utils.py
@@ -7,6 +7,7 @@
PPLINT4KVMemoryManager,
Deepseek2MemoryManager,
Deepseek2FP8KVMemoryManager,
NeoMemoryManager,
)
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_env_start_args
@@ -23,7 +24,7 @@ def select_mem_manager_class():
# case 1
# First check whether the model is from the deepseek family
model_class = get_llm_model_class()
from lightllm.models import Deepseek2TpPartModel
from lightllm.models import Deepseek2TpPartModel, NeoTpMOEPartModel, NeoTpPartModel

if issubclass(model_class, Deepseek2TpPartModel):
mem_class = Deepseek2MemoryManager
@@ -32,6 +33,10 @@

logger.info(f"Model kv cache using mode {mode}, mem_manager class: {mem_class}")
return mem_class
# Check whether the model is from the neo family
elif issubclass(model_class, NeoTpMOEPartModel) or issubclass(model_class, NeoTpPartModel):
mem_class = NeoMemoryManager
return mem_class

# case normal
logger.info(f"mode setting params: {mode}")
46 changes: 46 additions & 0 deletions lightllm/common/kv_cache_mem_manager/neo_mem_manager.py
@@ -0,0 +1,46 @@
import torch
from lightllm.utils.dist_utils import get_current_rank_in_node
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager


class NeoMemoryManager(MemoryManager):
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
self.size = size
self.head_num = head_num
self.head_dim = head_dim * 2 # neo kv is [k, k_h, k_w] concatenated together
self.layer_num = layer_num
self.always_copy = always_copy
self.dtype = dtype
# profile the max total token num if the size is None
self.profile_size(mem_fraction)

self.mem_state = torch.arange(
0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
)
self._mem_state_return = torch.arange(
0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
)
self._return_start = 0
self.mark_start = 0
self.mark_end = self.size

self.can_use_mem_size = self.size

# Shared through shared memory so the router module can read it for precise scheduling estimates; the nccl port serves as a tag for a single instance on one machine, to avoid conflicts.
from lightllm.utils.envs_utils import get_unique_server_name

rank_in_node = get_current_rank_in_node()
self.shared_can_use_token_num = SharedInt(
f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
)

self.shared_can_use_token_num.set_value(self.can_use_mem_size)
self._init_buffers(
self.size,
dtype,
head_num,
self.head_dim,
layer_num,
)
self.HOLD_TOKEN_MEMINDEX = self.size
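
Note: a back-of-the-envelope sketch of what the doubled head_dim means for cache cost, under assumed sizes (head_dim=128, 8 KV heads, bf16; illustrative numbers, not taken from this PR). The packed K row holds k plus the half-width k_h and k_w, and V is zero-padded to the same width downstream, so each cached token takes roughly twice the bytes of the standard layout.

```python
# Illustrative arithmetic only (assumed sizes, not from the PR).
head_dim, kv_heads, dtype_bytes = 128, 8, 2        # bf16
packed_dim = head_dim * 2                          # k | k_h | k_w packed on the last axis
neo_bytes_per_token_layer = kv_heads * 2 * packed_dim * dtype_bytes   # K part + padded V part
std_bytes_per_token_layer = kv_heads * 2 * head_dim * dtype_bytes
print(neo_bytes_per_token_layer // std_bytes_per_token_layer)         # -> 2
```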
2 changes: 2 additions & 0 deletions lightllm/models/__init__.py
@@ -38,4 +38,6 @@
Tarsier2LlamaTpPartModel,
)
from lightllm.models.gpt_oss.model import GptOssTpPartModel
from lightllm.models.neo_chat_moe.model import NeoTpMOEPartModel
from lightllm.models.neo_chat.model import NeoTpPartModel
from .registry import get_model, get_model_class
47 changes: 46 additions & 1 deletion lightllm/models/llama/model.py
@@ -110,6 +110,8 @@ def _init_custom(self):
rope_scaling = self.config.get("rope_scaling", None)
if rope_scaling is None:
self._init_to_get_rotary()
if "rope_theta_hw" in self.config:
self._init_to_get_hw_rotary()
return

if "rope_type" in rope_scaling:
@@ -132,6 +134,8 @@
self._init_to_get_rotary()
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
if "rope_theta_hw" in self.config:
self._init_to_get_hw_rotary()
return

def _init_weights(self):
@@ -178,7 +182,7 @@ def _init_to_get_rotary(self, default_base=10000):
rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)

base = self.config.get("rope_theta", float(default_base))

print(f"base is {base}")
Review comment (Contributor, medium): This print statement appears to be for debugging purposes and should be removed before merging.
if "max_sequence_length" in self.config:
max_seq_len = self.config["max_sequence_length"]
else:
@@ -211,6 +215,47 @@ def _init_to_get_rotary(self, default_base=10000):
self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return

def _init_to_get_hw_rotary(self, default_base=10000):
partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_ // 2)
if self.config.get("rope_scaling", {}) is None:
rope_scaling_factor = 1.0
else:
rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)

base = self.config.get("rope_theta_hw", float(default_base))
print(f"hw_base is {base}")
Review comment (Contributor, medium): This print statement appears to be for debugging purposes and should be removed before merging.

if "max_sequence_length" in self.config:
max_seq_len = self.config["max_sequence_length"]
else:
max_position_embeddings = self.config.get(
"max_position_embeddings_hw", 2048 if base <= 10000.0 + 1e-5 else 16384
)
max_seq_len = max_position_embeddings * rope_scaling_factor

# NTK
try:
ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
assert ntk_alpha >= 1
if ntk_alpha > 1:
logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
max_seq_len *= ntk_alpha
base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula
except:
pass
Comment on lines +236 to +244
Review comment (Contributor, high): Using a bare except: is generally discouraged as it can catch and silence a wide range of unexpected errors, making debugging difficult. It's better to catch specific exceptions that you expect might occur, such as ValueError or AssertionError, and log them for better diagnostics.

Suggested change, replacing:
try:
ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
assert ntk_alpha >= 1
if ntk_alpha > 1:
logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
max_seq_len *= ntk_alpha
base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula
except:
pass
with:
try:
ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
assert ntk_alpha >= 1
if ntk_alpha > 1:
logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
max_seq_len *= ntk_alpha
base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula
except (ValueError, AssertionError) as e:
logger.warning(f"Could not apply NTK scaling: {e}")


inv_freq = 1.0 / (
base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
)
t = (
torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32)
/ rope_scaling_factor
)
freqs = torch.outer(t, inv_freq)

self._hw_cos_cached = torch.cos(freqs).to(self.data_type).cuda()
self._hw_sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return

def _init_to_get_dynamic_ntk_rotary(self):
partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
max_position_embeddings = self.config.get("max_position_embeddings", 2048)
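
Note: for readers of the "# Base change formula" line above, a small worked example of the NTK-aware base adjustment as written in _init_to_get_hw_rotary, using assumed values (head_dim_=128, rope_theta_hw=10000, LIGHTLLM_NTK_ALPHA=2; none of these come from the PR):

```python
# Illustrative arithmetic for the NTK base change in _init_to_get_hw_rotary.
partial_head_dim = 128 // 2        # hw rotary uses half of the assumed head_dim_=128
base, ntk_alpha = 10000.0, 2.0     # assumed rope_theta_hw and LIGHTLLM_NTK_ALPHA
new_base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2)))
print(round(new_base))             # ~20452; a larger base stretches the rotary period
```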
@@ -74,7 +74,8 @@ def token_att_fwd(q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen
Lq, Lk = q.shape[-1], k.shape[-1]
assert Lq == Lk
assert Lk in {16, 32, 64, 128, 256}
sm_scale = 1.0 / (Lk ** 0.5)
Lk_scale = Lk // 2
sm_scale = 1.0 / (Lk_scale ** 0.5)

batch, head_num = B_req_idx.shape[0], q.shape[1]

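
Note on the scale change above: in the neo path the last dimension of q/k passed into this kernel is the packed width (k | k_h | k_w), i.e. twice the logical head_dim, so the softmax scale is taken over Lk // 2, presumably to keep it at 1/sqrt(head_dim) despite the doubled packed width. A minimal sketch with an assumed head_dim of 128 (not stated in the PR):

```python
# Illustrative only: packed Lk is 2x the logical head_dim in the neo layout.
head_dim = 128                      # assumed
Lk = head_dim * 2                   # packed k | k_h | k_w width seen by the kernel
sm_scale = 1.0 / ((Lk // 2) ** 0.5)
assert abs(sm_scale - 1.0 / head_dim ** 0.5) < 1e-12
```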
Empty file.
Empty file.
159 changes: 159 additions & 0 deletions lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,159 @@
import torch
from functools import partial
from typing import Tuple
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo
from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo
from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd
from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd
from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer
from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight
from lightllm.distributed import all_reduce
import torch.distributed as dist
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward


class NeoChatTransformerLayerInfer(Qwen3TransformerLayerInfer):
def __init__(self, data_type, network_config, mode):
super().__init__(data_type, network_config, mode)
return

def _bind_attention(self):
self._context_attention_kernel = self._context_attention_kernel
self._token_attention_kernel = self._token_decode_attention_normal
self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal
return

def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatTransformerLayerWeight):
input = input.view(-1, self.embed_dim_)
q = layer_weight.q_proj.mm(input) # [T, Hq*D]

q_hw = layer_weight.q_hw_proj.mm(input)
q_hw = q_hw.view(-1, self.tp_q_head_num_, self.head_dim_)
q_h, q_w = q_hw.chunk(2, dim=-1)

k_hw = layer_weight.k_hw_proj.mm(input)
k_hw = k_hw.view(-1, self.tp_k_head_num_, self.head_dim_)
k_h, k_w = k_hw.chunk(2, dim=-1)

cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D]

qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_)

q_h_2d = q_h.reshape(q.shape[0], -1)
q_w_2d = q_w.reshape(q.shape[0], -1)
qk_rmsnorm_forward(q_h_2d, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_)
qk_rmsnorm_forward(q_w_2d, weight=layer_weight.q_norm_w_weight_.weight, eps=self.eps_)
q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2)
q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2)

qk_rmsnorm_forward(
cache_kv[:, : self.tp_k_head_num_ * self.head_dim_],
weight=layer_weight.k_norm_weight_.weight,
eps=self.eps_,
)

k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)]
k_w_2d = k_w.reshape(q.shape[0], -1)
qk_rmsnorm_forward(k_h_2d, weight=layer_weight.k_norm_h_weight_.weight, eps=self.eps_)
qk_rmsnorm_forward(k_w_2d, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_)
k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2)
k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2)

cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)

rotary_emb_fwd(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
cache_kv[:, : self.tp_k_head_num_, :],
infer_state.position_cos,
infer_state.position_sin,
)
rotary_emb_fwd(
q_h,
k_h,
infer_state.position_cos_h,
infer_state.position_sin_h,
)
rotary_emb_fwd(
q_w,
k_w,
infer_state.position_cos_w,
infer_state.position_sin_w,
)

q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_)
q3 = torch.cat([q3, q_h, q_w], dim=-1)
q = q3.reshape(q3.shape[0], -1)

k = cache_kv[:, : self.tp_k_head_num_, :]
k = torch.cat([k, k_h, k_w], dim=-1)

v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :]
v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype)
v = torch.cat([v, v_pad], dim=-1)

cache_kv = torch.cat([k, v], dim=1)
return q, cache_kv

def _context_attention_kernel(
self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None
) -> torch.Tensor:
o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
context_attention_fwd_neo(
q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2),
kv[:, 0 : self.tp_k_head_num_, :],
kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2),
infer_state.position_ids[0], # [0,0,1,2,3,3,3,4]
infer_state.b_req_idx,
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.b_ready_cache_len,
infer_state.max_len_in_batch,
infer_state.req_manager.req_to_token_indexs,
)
o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2)
o3 = o3[:, :, : self.head_dim_].contiguous()
return o3.view(o3.shape[0], -1)

def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, layer_weight, out=None):
total_token_num = infer_state.total_token_num
batch_size = infer_state.batch_size

q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2)

att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32)

k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :]
token_att_fwd(
q_3d,
k_3d,
att_m_tensor,
infer_state.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.max_len_in_batch,
)

from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd

v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][
:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_
]

o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) if out is None else out

token_softmax_reducev_fwd(
att_m_tensor,
v_3d,
o_3d,
infer_state.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_start_loc,
infer_state.b_seq_len,
)
return o_3d.view(batch_size, -1)
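
Note: a rough shape sketch of what _get_qkv above produces, with assumed per-rank sizes (T=4 tokens, 8 query heads, 2 K/V heads, head_dim D=128; illustrative, not from the PR). q is packed as [q | q_h | q_w], and the cached KV row holds [k | k_h | k_w] for K plus a zero-padded V of the same width:

```python
import torch

# Assumed sizes, for illustration only.
T, Hq, Hk, Hv, D = 4, 8, 2, 2, 128

q   = torch.randn(T, Hq, D)                      # rotary over token positions
q_h = torch.randn(T, Hq, D // 2)                 # rotary over height positions
q_w = torch.randn(T, Hq, D // 2)                 # rotary over width positions
q_packed = torch.cat([q, q_h, q_w], dim=-1)      # [T, Hq, 2*D]

k_packed = torch.cat([torch.randn(T, Hk, D),
                      torch.randn(T, Hk, D // 2),
                      torch.randn(T, Hk, D // 2)], dim=-1)   # [T, Hk, 2*D]
v = torch.randn(T, Hv, D)
v_padded = torch.cat([v, torch.zeros(T, Hv, D)], dim=-1)     # pad V to the packed width
cache_kv = torch.cat([k_packed, v_padded], dim=1)            # [T, Hk+Hv, 2*D]

assert q_packed.shape == (T, Hq, 2 * D)
assert cache_kv.shape == (T, Hk + Hv, 2 * D)
```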
Empty file.
@@ -0,0 +1,23 @@
import torch
import numpy as np
from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight

# Rename keys: language_model.xxx -> xxx
# Only rename keys when loading PreAndPostLayerWeight; TransformerLayerWeight keys are already correct
def rename_weight_keys(weights):
prefix = "language_model."
keys = list(weights.keys())
for k in keys:
if prefix in k:
weights[k.replace(prefix, "")] = weights.pop(k)


class NeoChatPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
def __init__(self, data_type, network_config, mode):
super().__init__(data_type, network_config, mode)
return

def load_hf_weights(self, weights):
rename_weight_keys(weights)
super().load_hf_weights(weights)
return
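
Note: a minimal usage sketch of rename_weight_keys (the checkpoint key names below are hypothetical, and the import path assumes this file lives at lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py, following the repo's naming convention):

```python
# Hypothetical key names; adjust the import path to wherever this module lives.
from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import rename_weight_keys

weights = {
    "language_model.model.embed_tokens.weight": "tensor-0",
    "lm_head.weight": "tensor-1",
}
rename_weight_keys(weights)
assert "model.embed_tokens.weight" in weights                    # prefix stripped
assert "language_model.model.embed_tokens.weight" not in weights
assert "lm_head.weight" in weights                               # untouched
```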