Commit cf7ae82

[Model] Reuse RoPE positions for Deepseek-v2 model (#3084)
This PR updates the Deepseek-v2 model implementation to reuse the RoPE position array across layers. Prior to this PR, we queried the RoPE positions from the paged KV cache once per layer, even though the same array can be reused, so a single query per forward pass is sufficient.
1 parent bf70bea commit cf7ae82
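
The pattern behind the change, as a minimal runnable sketch (plain Python with hypothetical stand-in classes such as FakePagedKVCache, not the actual MLC-LLM API): the model queries the positions from the KV cache once per forward pass and hands the same array down to every layer, instead of each layer querying the cache itself.

class FakePagedKVCache:
    """Hypothetical stand-in for PagedKVCache that counts position queries."""

    def __init__(self):
        self.num_queries = 0

    def get_query_positions(self, total_length: int):
        self.num_queries += 1
        return list(range(total_length))


class Layer:
    """Stand-in decoder layer: it reuses the positions handed down by the model."""

    def forward(self, hidden_states, paged_kv_cache, layer_id, query_positions):
        # No per-layer call to paged_kv_cache.get_query_positions anymore.
        return hidden_states


class Model:
    def __init__(self, num_layers: int):
        self.layers = [Layer() for _ in range(num_layers)]

    def forward(self, hidden_states, paged_kv_cache, batch: int, seq_len: int):
        # Query once per forward pass, sized batch * seq_len, then reuse everywhere.
        query_positions = paged_kv_cache.get_query_positions(batch * seq_len)
        for layer_id, layer in enumerate(self.layers):
            hidden_states = layer.forward(
                hidden_states, paged_kv_cache, layer_id, query_positions
            )
        return hidden_states


if __name__ == "__main__":
    cache = FakePagedKVCache()
    Model(num_layers=60).forward(
        hidden_states=None, paged_kv_cache=cache, batch=1, seq_len=8
    )
    print("position queries per forward:", cache.num_queries)  # prints 1

Running the sketch prints a query count of 1; before this change the equivalent count would have scaled with the number of layers.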

File tree: 1 file changed (+19 −5 lines)

python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py

@@ -229,7 +229,13 @@ def __init__(self, config: DeepseekV2Config):
         self.softmax_scale = self.softmax_scale * mscale * mscale
         self.rotary_emb = DeepseekV2YarnRotaryEmbedding(config)

-    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
+    def forward(
+        self,
+        hidden_states: Tensor,
+        paged_kv_cache: PagedKVCache,
+        layer_id: int,
+        query_positions: Tensor,
+    ):
         b, s, _ = hidden_states.shape

         if self.q_lora_rank is None:
@@ -260,7 +266,7 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id:
             kv, [self.qk_nope_head_dim], axis=-1
         )  # (b, s, num_heads, qk_nope_head_dim), (b, s, num_heads, v_head_dim)

-        q_pe, k_pe = self.rotary_emb(q_pe, k_pe, paged_kv_cache.get_query_positions(s))
+        q_pe, k_pe = self.rotary_emb(q_pe, k_pe, query_positions)

         @T.prim_func
         def inplace_q(var_q: T.handle, var_pe: T.handle):
@@ -471,9 +477,15 @@ def _set(layer, hint):
         self.tensor_parallel_shards = config.tensor_parallel_shards
         _set_tp()

-    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
+    def forward(
+        self,
+        hidden_states: Tensor,
+        paged_kv_cache: PagedKVCache,
+        layer_id: int,
+        query_positions: Tensor,
+    ):
         out = self.input_layernorm(hidden_states)
-        out = self.self_attn(out, paged_kv_cache, layer_id)
+        out = self.self_attn(out, paged_kv_cache, layer_id, query_positions)
         hidden_states = self._apply_residual(out, residual=hidden_states)
         out = self.post_attention_layernorm(hidden_states)
         out = self.mlp(out)  # type: ignore[operator]
@@ -499,8 +511,10 @@ def __init__(self, config: DeepseekV2Config):

     def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):
         hidden_states = inputs
+        print(f"inputs.shape = {inputs.shape}")
+        query_positions = paged_kv_cache.get_query_positions(inputs.shape[0] * inputs.shape[1])
         for layer_id, layer in enumerate(self.layers):
-            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)
+            hidden_states = layer(hidden_states, paged_kv_cache, layer_id, query_positions)
         hidden_states = self.norm(hidden_states)
         return hidden_states
