 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
-from QEfficient.transformers.cache_utils import QEffDynamicCache
-from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
-
 from transformers.modeling_outputs import (
     MoeCausalLMOutputWithPast,
     MoeModelOutputWithPast,
 )
+from transformers.models.llama.modeling_llama import repeat_kv, rotate_half
 
-
-# Copied from transformers.models.llama.modeling_llama.repeat_kv
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-# Copied from transformers.models.llama.modeling_llama.rotate_half
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
+from QEfficient.transformers.cache_utils import QEffDynamicCache
+from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
 
 
 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
@@ -70,12 +49,10 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
 
 
 class QEffGrok1MultiHeadAttention(nn.Module):
-    def __qeff_init__(self):
-        self.layer_idx = 0
-
     def forward(
         self,
         hidden_states: torch.Tensor,
+        layer_idx: int,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
@@ -96,7 +73,7 @@ def forward(
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            kv_seq_len = past_key_value.get_usable_length(kv_seq_len, layer_idx)
 
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -108,7 +85,7 @@ def forward(
                 "batch_index": batch_index,
                 "position_ids": position_ids,
             }  # Specific to RoPE models
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+            key_states, value_states = past_key_value.update(key_states, value_states, layer_idx, cache_kwargs)
 
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -194,6 +171,7 @@ def forward(
         hidden_states = self.pre_attn_norm(hidden_states)
         hidden_states, attention_weights, present_key_value = self.attn(
             hidden_states,
+            layer_idx=self.layer_idx,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
@@ -221,6 +199,10 @@ def forward(
 
 
 class QEffGrok1Model(nn.Module):
+    def __qeff_init__(self):
+        for idx, layer in enumerate(self.layers):
+            layer.layer_idx = idx
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
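
The sketch below illustrates the indexing pattern this diff introduces: the model-level __qeff_init__ hook stamps each decoder layer with its position once, and the layer then passes that index into attention explicitly, instead of attention carrying a hard-coded self.layer_idx = 0. ToyAttention, ToyDecoderLayer, and ToyModel are hypothetical stand-ins for the QEfficient classes, and a plain dict stands in for QEffDynamicCache; this is not part of the actual file.

class ToyAttention:
    def forward(self, hidden_states, layer_idx, past_key_value=None):
        # The cache is addressed by the index handed in by the caller,
        # mirroring past_key_value.update(..., layer_idx, ...) in the diff.
        if past_key_value is not None:
            past_key_value.setdefault(layer_idx, []).append(hidden_states)
        return hidden_states


class ToyDecoderLayer:
    def __init__(self):
        self.attn = ToyAttention()
        self.layer_idx = None  # filled in by the model-level hook below

    def forward(self, hidden_states, past_key_value=None):
        # Mirrors `self.attn(hidden_states, layer_idx=self.layer_idx, ...)`.
        return self.attn.forward(hidden_states, layer_idx=self.layer_idx, past_key_value=past_key_value)


class ToyModel:
    def __init__(self, num_layers):
        self.layers = [ToyDecoderLayer() for _ in range(num_layers)]
        self.__qeff_init__()

    def __qeff_init__(self):
        # Same idea as QEffGrok1Model.__qeff_init__ in the diff:
        # assign each layer its index once, at init time.
        for idx, layer in enumerate(self.layers):
            layer.layer_idx = idx


model = ToyModel(num_layers=2)
cache = {}
for layer in model.layers:
    layer.forward("token-states", past_key_value=cache)
print(sorted(cache))  # prints [0, 1]: each layer wrote to its own cache slot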