
Commit fc8e731

committed
Added support of Grok1 model
Signed-off-by: Amit Raj <[email protected]>
1 parent 598b83f commit fc8e731


7 files changed: +437 -25 lines changed


QEfficient/base/pytorch_transforms.py

Lines changed: 4 additions & 0 deletions
@@ -107,6 +107,10 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
            ):
                for orig_method_name, mapped_method in repl_method_map.items():
                    setattr(module, orig_method_name, MethodType(mapped_method, module))

                if hasattr(module, "__qeff_init__"):
                    module.__qeff_init__()

                transformed = True

        return model, transformed
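
A minimal sketch of how this new hook is consumed. Only __qeff_init__ and the hasattr/call pattern come from the diff above; the class names are hypothetical. After a module's methods are remapped onto the existing instance, the transform calls __qeff_init__() so the replacement class can set extra state (QEffGrok1MultiHeadAttention in this commit uses it to default layer_idx to 0).

from types import MethodType

import torch.nn as nn


class ToyAttention(nn.Module):  # hypothetical original module
    pass


class QEffToyAttention(nn.Module):  # hypothetical replacement class providing the hook
    def __qeff_init__(self):
        self.layer_idx = 0


module = ToyAttention()
# Remap a method from the replacement class onto the existing instance ...
setattr(module, "__qeff_init__", MethodType(QEffToyAttention.__qeff_init__, module))
# ... then run the hook, exactly as the transform does above.
if hasattr(module, "__qeff_init__"):
    module.__qeff_init__()
assert module.layer_idx == 0
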
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
Lines changed: 377 additions & 0 deletions
@@ -0,0 +1,377 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask

from transformers.modeling_outputs import (
    MoeCausalLMOutputWithPast,
    MoeModelOutputWithPast,
)


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
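
A quick shape check for repeat_kv as defined above (tensor sizes are arbitrary and assume torch is imported as above): a (batch, num_kv_heads, seq_len, head_dim) tensor is expanded so each KV head serves n_rep query heads.

kv = torch.randn(1, 2, 5, 64)       # (batch, num_kv_heads, seq_len, head_dim)
out = repeat_kv(kv, n_rep=4)        # each KV head repeated 4 times
assert out.shape == (1, 8, 5, 64)   # now matches 8 query heads
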

# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
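
A tiny worked example of rotate_half on a 4-dim vector (values chosen purely for illustration):

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rotate_half(x))  # tensor([-3., -4.,  1.,  2.])
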

# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
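
A usage sketch for qeff_apply_rotary_pos_emb. The toy cos/sin table below follows the usual rotary-embedding recipe and is illustrative only; in the model these come from self.rotary_emb, and the offset position_ids mimic decoding with a KV cache that already holds positions 0..4.

head_dim, max_pos = 8, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)     # (max_pos, head_dim)
cos, sin = emb.cos(), emb.sin()

q = torch.randn(1, 2, 3, head_dim)          # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 2, 3, head_dim)
position_ids = torch.tensor([[5, 6, 7]])    # offset ids for the new tokens
q_rot, k_rot = qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
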

class QEffGrok1MultiHeadAttention(nn.Module):
    def __qeff_init__(self):
        self.layer_idx = 0

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        batch_index: Optional[torch.LongTensor] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "batch_index": batch_index,
                "position_ids": position_ids,
            }  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)).to(torch.float)
        attn_weights = attn_weights * self.attn_output_multiplier
        attn_weights = self.max_attn_val * F.tanh(attn_weights / self.max_attn_val)

        if attention_mask is not None:
            attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)

        attn_weights = F.softmax(attn_weights, dim=-1).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
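
A small illustration of the attention-logit soft cap used above (max_attn_val here is an arbitrary stand-in for the model's configured value): passing the scaled scores through max_attn_val * tanh(x / max_attn_val) bounds every score smoothly to (-max_attn_val, max_attn_val) before the mask and softmax are applied.

max_attn_val = 30.0                                       # stand-in value for illustration
scores = torch.tensor([-1000.0, -10.0, 0.0, 10.0, 1000.0])
capped = max_attn_val * F.tanh(scores / max_attn_val)
print(capped)   # roughly [-30.0, -9.7, 0.0, 9.7, 30.0] -- outliers saturate at +/- max_attn_val
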

class QEffGrok1MoeBlock(nn.Module):
    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be solicited
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            expert_mask_tr = expert_mask[expert_idx].transpose(0, 1)
            current_hidden_states = expert_layer(hidden_states) * (((routing_weights * expert_mask_tr).sum(1))[:, None])
            current_hidden_states = torch.where(
                (routing_weights * expert_mask_tr).sum(1).to(torch.bool)[:, None],
                current_hidden_states,
                torch.tensor(0.0),
            )
            final_hidden_states = final_hidden_states + current_hidden_states
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
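
A toy walk-through of the routing step above (4 experts, top_k=2, a single token; all numbers are illustrative): the gate's logits are softmaxed, the top-k experts are selected, and the one-hot expert mask is what the expert loop uses to zero out tokens an expert was not chosen for.

router_logits = torch.tensor([[2.0, 0.5, -1.0, 0.0]])              # (tokens, num_experts)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, k=2, dim=-1)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=4).permute(2, 1, 0)
print(selected_experts)   # tensor([[0, 1]]) -- experts 0 and 1 handle this token
print(routing_weights)    # their softmax weights, roughly [0.71, 0.16]
print(expert_mask.shape)  # torch.Size([4, 2, 1]) -> (num_experts, top_k, tokens)
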

class QEffGrok1DecoderLayer(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        batch_index: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.pre_attn_norm(hidden_states)
        hidden_states, attention_weights, present_key_value = self.attn(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            batch_index=batch_index,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = self.post_attn_norm(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.pre_moe_norm(hidden_states)
        hidden_states, router_logits = self.moe_block(hidden_states)
        hidden_states = self.post_moe_norm(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attention_weights,)
        if use_cache:
            outputs += (present_key_value,)
        if output_router_logits:
            outputs += (router_logits,)
        return outputs
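
The decoder layer returns a positional tuple, so downstream indexing depends on which flags are set. A minimal sketch (stand-in strings instead of tensors) of why QEffGrok1Model below reads layer_outputs[2 if output_attentions else 1] for the cache entry:

output_attentions, use_cache, output_router_logits = True, True, True
layer_outputs = ("hidden_states", "attention_weights", "present_key_value", "router_logits")
assert layer_outputs[2 if output_attentions else 1] == "present_key_value"
assert layer_outputs[-1] == "router_logits"
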

class QEffGrok1Model(nn.Module):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        batch_index: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MoeModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
            inputs_embeds = inputs_embeds * self.embedding_multiplier_scale

        attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values_length)

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                batch_index=batch_index,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            if output_router_logits:
                all_router_logits += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_values = past_key_values.to_legacy_cache()

        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            router_logits=all_router_logits,
        )
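
A sketch of the legacy cache layout the model consumes and returns. It uses transformers' DynamicCache purely for illustration; QEffDynamicCache is assumed to expose the same from_legacy_cache/to_legacy_cache interface, as the calls above suggest.

from transformers.cache_utils import DynamicCache  # illustration only; the model uses QEffDynamicCache

bsz, n_kv_heads, past_len, head_dim = 1, 2, 5, 64
legacy = tuple(
    (torch.zeros(bsz, n_kv_heads, past_len, head_dim), torch.zeros(bsz, n_kv_heads, past_len, head_dim))
    for _ in range(3)                                # one (key, value) pair per decoder layer
)
cache = DynamicCache.from_legacy_cache(legacy)       # tuple-of-tuples in ...
roundtrip = cache.to_legacy_cache()                  # ... and back out after the forward pass
assert roundtrip[0][0].shape == (bsz, n_kv_heads, past_len, head_dim)
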

class QEffGrok1ModelForCausalLM(nn.Module):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        batch_index: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            batch_index=batch_index,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
            **kwargs,
        )

        # Cast to int32 to avoid ONNXRT issue
        logit_idx = position_ids.to(torch.int32).argmax(1, keepdim=True)
        hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_idx]
        logits = self.lm_head(hidden_states)
        logits = logits * self.output_multiplier_scale
        logits = logits.float()

        return MoeCausalLMOutputWithPast(
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
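
A shape sketch of the last-position gather above (values are arbitrary): argmax over position_ids picks, per batch row, the index holding the largest position id, i.e. the most recent token, so lm_head only runs on one hidden state per sequence.

bsz, seq_len, hidden_size = 2, 4, 8
hidden_states = torch.randn(bsz, seq_len, hidden_size)
position_ids = torch.tensor([[0, 1, 2, 3],
                             [3, 4, 5, 6]])                              # any increasing position ids
logit_idx = position_ids.to(torch.int32).argmax(1, keepdim=True)         # (bsz, 1) -> [[3], [3]]
last_hidden = hidden_states[torch.arange(bsz).view(-1, 1), logit_idx]    # (bsz, 1, hidden_size)
assert last_hidden.shape == (bsz, 1, hidden_size)
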
