@@ -229,7 +229,13 @@ def __init__(self, config: DeepseekV2Config):
                 self.softmax_scale = self.softmax_scale * mscale * mscale
         self.rotary_emb = DeepseekV2YarnRotaryEmbedding(config)
 
-    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
+    def forward(
+        self,
+        hidden_states: Tensor,
+        paged_kv_cache: PagedKVCache,
+        layer_id: int,
+        query_positions: Tensor,
+    ):
         b, s, _ = hidden_states.shape
 
         if self.q_lora_rank is None:
@@ -260,7 +266,7 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id:
             kv, [self.qk_nope_head_dim], axis=-1
         )  # (b, s, num_heads, qk_nope_head_dim), (b, s, num_heads, v_head_dim)
 
-        q_pe, k_pe = self.rotary_emb(q_pe, k_pe, paged_kv_cache.get_query_positions(s))
+        q_pe, k_pe = self.rotary_emb(q_pe, k_pe, query_positions)
 
         @T.prim_func
         def inplace_q(var_q: T.handle, var_pe: T.handle):
@@ -471,9 +477,15 @@ def _set(layer, hint):
         self.tensor_parallel_shards = config.tensor_parallel_shards
         _set_tp()
 
-    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
+    def forward(
+        self,
+        hidden_states: Tensor,
+        paged_kv_cache: PagedKVCache,
+        layer_id: int,
+        query_positions: Tensor,
+    ):
         out = self.input_layernorm(hidden_states)
-        out = self.self_attn(out, paged_kv_cache, layer_id)
+        out = self.self_attn(out, paged_kv_cache, layer_id, query_positions)
         hidden_states = self._apply_residual(out, residual=hidden_states)
         out = self.post_attention_layernorm(hidden_states)
         out = self.mlp(out)  # type: ignore[operator]
@@ -499,8 +511,10 @@ def __init__(self, config: DeepseekV2Config):
 
     def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):
         hidden_states = inputs
+        print(f"inputs.shape = {inputs.shape}")
+        query_positions = paged_kv_cache.get_query_positions(inputs.shape[0] * inputs.shape[1])
         for layer_id, layer in enumerate(self.layers):
-            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)
+            hidden_states = layer(hidden_states, paged_kv_cache, layer_id, query_positions)
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
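The pattern this diff applies is to hoist the `paged_kv_cache.get_query_positions(...)` lookup out of every attention layer and into the model-level `forward`, then thread the resulting `query_positions` tensor down through each decoder layer into attention. Below is a minimal standalone sketch of that control flow in plain Python/NumPy; the `PagedKVCacheStub`, `AttentionStub`, `DecoderLayerStub`, and `ModelStub` classes are hypothetical stand-ins for illustration only, not the real mlc_llm `PagedKVCache` or model classes.

import numpy as np


class PagedKVCacheStub:
    """Hypothetical stand-in: only models the query-position lookup."""

    def __init__(self, seq_start: int):
        self.seq_start = seq_start

    def get_query_positions(self, total_length: int) -> np.ndarray:
        # Positions of the query tokens currently being appended to the cache.
        return np.arange(self.seq_start, self.seq_start + total_length)


class AttentionStub:
    def forward(self, hidden_states, kv_cache, layer_id, query_positions):
        # The rotary embedding would consume query_positions here; the layer
        # no longer asks the cache for positions itself.
        assert query_positions.shape[0] == hidden_states.shape[0] * hidden_states.shape[1]
        return hidden_states


class DecoderLayerStub:
    def __init__(self):
        self.self_attn = AttentionStub()

    def forward(self, hidden_states, kv_cache, layer_id, query_positions):
        # Pass the shared positions tensor straight through to attention.
        return self.self_attn.forward(hidden_states, kv_cache, layer_id, query_positions)


class ModelStub:
    def __init__(self, num_layers: int):
        self.layers = [DecoderLayerStub() for _ in range(num_layers)]

    def forward(self, inputs, kv_cache):
        hidden_states = inputs
        # Hoisted out of the attention layers: one lookup shared by all layers.
        query_positions = kv_cache.get_query_positions(inputs.shape[0] * inputs.shape[1])
        for layer_id, layer in enumerate(self.layers):
            hidden_states = layer.forward(hidden_states, kv_cache, layer_id, query_positions)
        return hidden_states


if __name__ == "__main__":
    model = ModelStub(num_layers=2)
    out = model.forward(np.zeros((1, 4, 8)), PagedKVCacheStub(seq_start=16))
    print(out.shape)  # (1, 4, 8)

Under these assumptions, the per-layer calls become purely data-flow on an already-computed tensor, which mirrors how the diff passes `query_positions` from the model's `forward` through each decoder layer into `self_attn` and `rotary_emb`.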