Commit f990465

Minor improvement
Signed-off-by: Amit Raj <[email protected]>
1 parent eb97c4a commit f990465

2 files changed: +15 -31 lines changed
QEfficient/transformers/models/grok_1/modeling_grok1.py

+11 -29 lines changed
@@ -9,35 +9,14 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
-from QEfficient.transformers.cache_utils import QEffDynamicCache
-from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
-
 from transformers.modeling_outputs import (
     MoeCausalLMOutputWithPast,
     MoeModelOutputWithPast,
 )
+from transformers.models.llama.modeling_llama import repeat_kv, rotate_half
 
-
-# Copied from transformers.models.llama.modeling_llama.repeat_kv
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-# Copied from transformers.models.llama.modeling_llama.rotate_half
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
+from QEfficient.transformers.cache_utils import QEffDynamicCache
+from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
 
 
 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
@@ -70,12 +49,10 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
 
 
 class QEffGrok1MultiHeadAttention(nn.Module):
-    def __qeff_init__(self):
-        self.layer_idx = 0
-
     def forward(
         self,
         hidden_states: torch.Tensor,
+        layer_idx: int,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
@@ -96,7 +73,7 @@ def forward(
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            kv_seq_len = past_key_value.get_usable_length(kv_seq_len, layer_idx)
 
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -108,7 +85,7 @@ def forward(
                 "batch_index": batch_index,
                 "position_ids": position_ids,
             } # Specific to RoPE models
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+            key_states, value_states = past_key_value.update(key_states, value_states, layer_idx, cache_kwargs)
 
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -194,6 +171,7 @@ def forward(
         hidden_states = self.pre_attn_norm(hidden_states)
         hidden_states, attention_weights, present_key_value = self.attn(
             hidden_states,
+            layer_idx=self.layer_idx,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
@@ -221,6 +199,10 @@ def forward(
 
 
 class QEffGrok1Model(nn.Module):
+    def __qeff_init__(self):
+        for idx, layer in enumerate(self.layers):
+            layer.layer_idx = idx
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
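Taken together, the modeling changes drop the local copies of repeat_kv and rotate_half in favour of the upstream imports from transformers.models.llama.modeling_llama, and move the KV-cache layer index from a per-module attribute set in QEffGrok1MultiHeadAttention.__qeff_init__ to an explicit argument: QEffGrok1Model.__qeff_init__ numbers the decoder layers once, each decoder layer forwards its own layer_idx, and the attention module uses it when querying and updating past_key_value. The toy sketch below (hypothetical class names, not code from this repository) illustrates that indexing pattern.

# Toy sketch of the layer_idx plumbing introduced above: the model numbers its
# layers once, each decoder layer forwards its index, and attention uses that
# index to address its own slot in a shared KV cache.
from typing import List, Optional

import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    def forward(self, hidden_states: torch.Tensor, layer_idx: int, cache: Optional[list] = None):
        if cache is not None:
            # The real code calls past_key_value.update(..., layer_idx, ...);
            # here we just record which slot this layer touches.
            cache[layer_idx] = hidden_states.detach()
        return hidden_states


class ToyDecoderLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = ToyAttention()
        self.layer_idx = None  # assigned by the model-level init below

    def forward(self, hidden_states, cache=None):
        return self.attn(hidden_states, layer_idx=self.layer_idx, cache=cache)


class ToyModel(nn.Module):
    def __init__(self, num_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(ToyDecoderLayer() for _ in range(num_layers))
        # Mirrors QEffGrok1Model.__qeff_init__: number the layers exactly once.
        for idx, layer in enumerate(self.layers):
            layer.layer_idx = idx

    def forward(self, hidden_states, cache=None):
        for layer in self.layers:
            hidden_states = layer(hidden_states, cache=cache)
        return hidden_states


cache: List[Optional[torch.Tensor]] = [None] * 4
ToyModel()(torch.randn(1, 8, 16), cache=cache)
assert all(entry is not None for entry in cache)  # every layer wrote its own slot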

QEfficient/transformers/models/pytorch_transforms.py

+4 -2 lines changed
@@ -444,12 +444,14 @@ class KVCacheModuleMethodMapperTransform(ModuleMethodMapperTransform):
         "InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward},
         # #Mapping for grok1 model
         "Grok1ModelForCausalLM": {"forward": QEffGrok1ModelForCausalLM.forward},
-        "Grok1Model": {"forward": QEffGrok1Model.forward},
+        "Grok1Model": {
+            "forward": QEffGrok1Model.forward,
+            "__qeff_init__": QEffGrok1Model.__qeff_init__,
+        },
         "DecoderLayer": {"forward": QEffGrok1DecoderLayer.forward},
         "MoeBlock": {"forward": QEffGrok1MoeBlock.forward},
         "MultiHeadAttention": {
             "forward": QEffGrok1MultiHeadAttention.forward,
-            "__qeff_init__": QEffGrok1MultiHeadAttention.__qeff_init__,
         },
     }
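For context, KVCacheModuleMethodMapperTransform keys its replacements by original class name, so after this change the one-time __qeff_init__ hook runs on the Grok1Model instance (numbering the decoder layers) instead of on each MultiHeadAttention. Below is a simplified, illustrative sketch of how such a class-name-to-method mapping can be applied; the helper name and details are assumptions, not the actual QEfficient implementation.

# Illustrative sketch only: apply a class-name -> method mapping like the one
# above by rebinding the mapped functions onto matching module instances.
from types import MethodType

import torch.nn as nn


def apply_method_mapping(model: nn.Module, mapping: dict) -> nn.Module:
    for module in model.modules():
        replacements = mapping.get(module.__class__.__name__)
        if not replacements:
            continue
        for name, func in replacements.items():
            # Bind each replacement function to this instance so `self` resolves to it.
            setattr(module, name, MethodType(func, module))
        if "__qeff_init__" in replacements:
            # Run the one-time init hook right after patching,
            # e.g. numbering the decoder layers on the model.
            module.__qeff_init__()
    return model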
