-
Notifications
You must be signed in to change notification settings - Fork 293
Add neo chat #1161
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add neo chat #1161
Changes from all commits
8a67a47
fdc1369
e8e7416
ba44983
4d41a33
0e8845c
b48cd49
7a904f3
4b757dd
245357c
6503ac8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| import torch | ||
| from lightllm.utils.dist_utils import get_current_rank_in_node | ||
| from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt | ||
| from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager | ||
|
|
||
|
|
||
| class NeoMemoryManager(MemoryManager): | ||
| def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): | ||
| self.size = size | ||
| self.head_num = head_num | ||
| self.head_dim = head_dim * 2 # neo kv 是[k, k_h, k_w]拼在一起的 | ||
| self.layer_num = layer_num | ||
| self.always_copy = always_copy | ||
| self.dtype = dtype | ||
| # profile the max total token num if the size is None | ||
| self.profile_size(mem_fraction) | ||
|
|
||
| self.mem_state = torch.arange( | ||
| 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True | ||
| ) | ||
| self._mem_state_return = torch.arange( | ||
| 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True | ||
| ) | ||
| self._return_start = 0 | ||
| self.mark_start = 0 | ||
| self.mark_end = self.size | ||
|
|
||
| self.can_use_mem_size = self.size | ||
|
|
||
| # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 | ||
| from lightllm.utils.envs_utils import get_unique_server_name | ||
|
|
||
| rank_in_node = get_current_rank_in_node() | ||
| self.shared_can_use_token_num = SharedInt( | ||
| f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" | ||
| ) | ||
|
|
||
| self.shared_can_use_token_num.set_value(self.can_use_mem_size) | ||
| self._init_buffers( | ||
| self.size, | ||
| dtype, | ||
| head_num, | ||
| self.head_dim, | ||
| layer_num, | ||
| ) | ||
| self.HOLD_TOKEN_MEMINDEX = self.size |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -110,6 +110,8 @@ def _init_custom(self): | |||||||||||||||||||||||||||||||||||||
| rope_scaling = self.config.get("rope_scaling", None) | ||||||||||||||||||||||||||||||||||||||
| if rope_scaling is None: | ||||||||||||||||||||||||||||||||||||||
| self._init_to_get_rotary() | ||||||||||||||||||||||||||||||||||||||
| if "rope_theta_hw" in self.config: | ||||||||||||||||||||||||||||||||||||||
| self._init_to_get_hw_rotary() | ||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if "rope_type" in rope_scaling: | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -132,6 +134,8 @@ def _init_custom(self): | |||||||||||||||||||||||||||||||||||||
| self._init_to_get_rotary() | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"Unknown RoPE scaling type {scaling_type}") | ||||||||||||||||||||||||||||||||||||||
| if "rope_theta_hw" in self.config: | ||||||||||||||||||||||||||||||||||||||
| self._init_to_get_hw_rotary() | ||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _init_weights(self): | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -178,7 +182,7 @@ def _init_to_get_rotary(self, default_base=10000): | |||||||||||||||||||||||||||||||||||||
| rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| base = self.config.get("rope_theta", float(default_base)) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| print(f"base is {base}") | ||||||||||||||||||||||||||||||||||||||
| if "max_sequence_length" in self.config: | ||||||||||||||||||||||||||||||||||||||
| max_seq_len = self.config["max_sequence_length"] | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -211,6 +215,47 @@ def _init_to_get_rotary(self, default_base=10000): | |||||||||||||||||||||||||||||||||||||
| self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() | ||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _init_to_get_hw_rotary(self, default_base=10000): | ||||||||||||||||||||||||||||||||||||||
| partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_ // 2) | ||||||||||||||||||||||||||||||||||||||
| if self.config.get("rope_scaling", {}) is None: | ||||||||||||||||||||||||||||||||||||||
| rope_scaling_factor = 1.0 | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
| rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| base = self.config.get("rope_theta_hw", float(default_base)) | ||||||||||||||||||||||||||||||||||||||
| print(f"hw_base is {base}") | ||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||||
| if "max_sequence_length" in self.config: | ||||||||||||||||||||||||||||||||||||||
| max_seq_len = self.config["max_sequence_length"] | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
| max_position_embeddings = self.config.get( | ||||||||||||||||||||||||||||||||||||||
| "max_position_embeddings_hw", 2048 if base <= 10000.0 + 1e-5 else 16384 | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| max_seq_len = max_position_embeddings * rope_scaling_factor | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # NTK | ||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||
| ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1)) | ||||||||||||||||||||||||||||||||||||||
| assert ntk_alpha >= 1 | ||||||||||||||||||||||||||||||||||||||
| if ntk_alpha > 1: | ||||||||||||||||||||||||||||||||||||||
| logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}") | ||||||||||||||||||||||||||||||||||||||
| max_seq_len *= ntk_alpha | ||||||||||||||||||||||||||||||||||||||
| base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula | ||||||||||||||||||||||||||||||||||||||
| except: | ||||||||||||||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+236
to
+244
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using a bare
Suggested change
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| inv_freq = 1.0 / ( | ||||||||||||||||||||||||||||||||||||||
| base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim) | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| t = ( | ||||||||||||||||||||||||||||||||||||||
| torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32) | ||||||||||||||||||||||||||||||||||||||
| / rope_scaling_factor | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| freqs = torch.outer(t, inv_freq) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| self._hw_cos_cached = torch.cos(freqs).to(self.data_type).cuda() | ||||||||||||||||||||||||||||||||||||||
| self._hw_sin_cached = torch.sin(freqs).to(self.data_type).cuda() | ||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _init_to_get_dynamic_ntk_rotary(self): | ||||||||||||||||||||||||||||||||||||||
| partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) | ||||||||||||||||||||||||||||||||||||||
| max_position_embeddings = self.config.get("max_position_embeddings", 2048) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,159 @@ | ||
| import torch | ||
| from functools import partial | ||
| from typing import Tuple | ||
| from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward | ||
| from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd | ||
| from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo | ||
| from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo | ||
| from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd | ||
| from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd | ||
| from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer | ||
| from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight | ||
| from lightllm.distributed import all_reduce | ||
| import torch.distributed as dist | ||
| from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer | ||
| from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward | ||
|
|
||
|
|
||
| class NeoChatTransformerLayerInfer(Qwen3TransformerLayerInfer): | ||
| def __init__(self, data_type, network_config, mode): | ||
| super().__init__(data_type, network_config, mode) | ||
| return | ||
|
|
||
| def _bind_attention(self): | ||
| self._context_attention_kernel = self._context_attention_kernel | ||
| self._token_attention_kernel = self._token_decode_attention_normal | ||
| self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal | ||
| return | ||
|
|
||
| def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatTransformerLayerWeight): | ||
| input = input.view(-1, self.embed_dim_) | ||
| q = layer_weight.q_proj.mm(input) # [T, Hq*D] | ||
|
|
||
| q_hw = layer_weight.q_hw_proj.mm(input) | ||
| q_hw = q_hw.view(-1, self.tp_q_head_num_, self.head_dim_) | ||
| q_h, q_w = q_hw.chunk(2, dim=-1) | ||
|
|
||
| k_hw = layer_weight.k_hw_proj.mm(input) | ||
| k_hw = k_hw.view(-1, self.tp_k_head_num_, self.head_dim_) | ||
| k_h, k_w = k_hw.chunk(2, dim=-1) | ||
|
|
||
| cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] | ||
|
|
||
| qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_) | ||
|
|
||
| q_h_2d = q_h.reshape(q.shape[0], -1) | ||
| q_w_2d = q_w.reshape(q.shape[0], -1) | ||
| qk_rmsnorm_forward(q_h_2d, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_) | ||
| qk_rmsnorm_forward(q_w_2d, weight=layer_weight.q_norm_w_weight_.weight, eps=self.eps_) | ||
| q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) | ||
| q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) | ||
|
|
||
| qk_rmsnorm_forward( | ||
| cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], | ||
| weight=layer_weight.k_norm_weight_.weight, | ||
| eps=self.eps_, | ||
| ) | ||
|
|
||
| k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)] | ||
| k_w_2d = k_w.reshape(q.shape[0], -1) | ||
| qk_rmsnorm_forward(k_h_2d, weight=layer_weight.k_norm_h_weight_.weight, eps=self.eps_) | ||
| qk_rmsnorm_forward(k_w_2d, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_) | ||
| k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) | ||
| k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) | ||
|
|
||
| cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) | ||
|
|
||
| rotary_emb_fwd( | ||
| q.view(-1, self.tp_q_head_num_, self.head_dim_), | ||
| cache_kv[:, : self.tp_k_head_num_, :], | ||
| infer_state.position_cos, | ||
| infer_state.position_sin, | ||
| ) | ||
| rotary_emb_fwd( | ||
| q_h, | ||
| k_h, | ||
| infer_state.position_cos_h, | ||
| infer_state.position_sin_h, | ||
| ) | ||
| rotary_emb_fwd( | ||
| q_w, | ||
| k_w, | ||
| infer_state.position_cos_w, | ||
| infer_state.position_sin_w, | ||
| ) | ||
|
|
||
| q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_) | ||
| q3 = torch.cat([q3, q_h, q_w], dim=-1) | ||
| q = q3.reshape(q3.shape[0], -1) | ||
|
|
||
| k = cache_kv[:, : self.tp_k_head_num_, :] | ||
| k = torch.cat([k, k_h, k_w], dim=-1) | ||
|
|
||
| v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] | ||
| v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype) | ||
| v = torch.cat([v, v_pad], dim=-1) | ||
|
|
||
| cache_kv = torch.cat([k, v], dim=1) | ||
| return q, cache_kv | ||
|
|
||
| def _context_attention_kernel( | ||
| self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None | ||
| ) -> torch.Tensor: | ||
| o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out | ||
| kv = infer_state.mem_manager.kv_buffer[self.layer_num_] | ||
| context_attention_fwd_neo( | ||
| q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), | ||
| kv[:, 0 : self.tp_k_head_num_, :], | ||
| kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], | ||
| o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), | ||
| infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] | ||
| infer_state.b_req_idx, | ||
| infer_state.b_start_loc, | ||
| infer_state.b_seq_len, | ||
| infer_state.b_ready_cache_len, | ||
| infer_state.max_len_in_batch, | ||
| infer_state.req_manager.req_to_token_indexs, | ||
| ) | ||
| o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) | ||
| o3 = o3[:, :, : self.head_dim_].contiguous() | ||
| return o3.view(o3.shape[0], -1) | ||
|
|
||
| def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, layer_weight, out=None): | ||
| total_token_num = infer_state.total_token_num | ||
| batch_size = infer_state.batch_size | ||
|
|
||
| q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2) | ||
|
|
||
| att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32) | ||
|
|
||
| k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] | ||
| token_att_fwd( | ||
| q_3d, | ||
| k_3d, | ||
| att_m_tensor, | ||
| infer_state.req_manager.req_to_token_indexs, | ||
| infer_state.b_req_idx, | ||
| infer_state.b_start_loc, | ||
| infer_state.b_seq_len, | ||
| infer_state.max_len_in_batch, | ||
| ) | ||
|
|
||
| from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd | ||
|
|
||
| v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][ | ||
| :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_ | ||
| ] | ||
|
|
||
| o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) if out is None else out | ||
|
|
||
| token_softmax_reducev_fwd( | ||
| att_m_tensor, | ||
| v_3d, | ||
| o_3d, | ||
| infer_state.req_manager.req_to_token_indexs, | ||
| infer_state.b_req_idx, | ||
| infer_state.b_start_loc, | ||
| infer_state.b_seq_len, | ||
| ) | ||
| return o_3d.view(batch_size, -1) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| import torch | ||
| import numpy as np | ||
| from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight | ||
|
|
||
| # add key: language_model.xxx -> xxx | ||
| # only change keys at PreAndPostLayerWeight load, TransformLayerWeight is correct now | ||
| def rename_weight_keys(weights): | ||
| prefix = "language_model." | ||
| keys = list(weights.keys()) | ||
| for k in keys: | ||
| if prefix in k: | ||
| weights[k.replace(prefix, "")] = weights.pop(k) | ||
|
|
||
|
|
||
| class NeoChatPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): | ||
| def __init__(self, data_type, network_config, mode): | ||
| super().__init__(data_type, network_config, mode) | ||
| return | ||
|
|
||
| def load_hf_weights(self, weights): | ||
| rename_weight_keys(weights) | ||
| super().load_hf_weights(weights) | ||
| return |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This
printstatement appears to be for debugging purposes and should be removed before merging.