|
| 1 | +# ----------------------------------------------------------------------------- |
| 2 | +# |
| 3 | +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. |
| 4 | +# SPDX-License-Identifier: BSD-3-Clause |
| 5 | +# |
| 6 | +# ----------------------------------------------------------------------------- |
| 7 | +from typing import List, Optional, Tuple, Union |
| 8 | + |
| 9 | +import torch |
| 10 | +import torch.nn as nn |
| 11 | +import torch.nn.functional as F |
| 12 | + |
| 13 | +from QEfficient.transformers.cache_utils import QEffDynamicCache |
| 14 | +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask |
| 15 | + |
| 16 | +from transformers.modeling_outputs import ( |
| 17 | + MoeCausalLMOutputWithPast, |
| 18 | + MoeModelOutputWithPast, |
| 19 | +) |
| 20 | + |
| 21 | + |
| 22 | +# Copied from transformers.models.llama.modeling_llama.repeat_kv |
| 23 | +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
| 24 | + """ |
| 25 | + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
| 26 | + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) |
| 27 | + """ |
| 28 | + batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
| 29 | + if n_rep == 1: |
| 30 | + return hidden_states |
| 31 | + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) |
| 32 | + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
| 33 | + |
| 34 | + |
| 35 | +# Copied from transformers.models.llama.modeling_llama.rotate_half |
| 36 | +def rotate_half(x): |
| 37 | + """Rotates half the hidden dims of the input.""" |
| 38 | + x1 = x[..., : x.shape[-1] // 2] |
| 39 | + x2 = x[..., x.shape[-1] // 2 :] |
| 40 | + return torch.cat((-x2, x1), dim=-1) |
| 41 | + |
| 42 | + |
| 43 | +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb |
| 44 | +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): |
| 45 | + """Applies Rotary Position Embedding to the query and key tensors. |
| 46 | +
|
| 47 | + Args: |
| 48 | + q (`torch.Tensor`): The query tensor. |
| 49 | + k (`torch.Tensor`): The key tensor. |
| 50 | + cos (`torch.Tensor`): The cosine part of the rotary embedding. |
| 51 | + sin (`torch.Tensor`): The sine part of the rotary embedding. |
| 52 | + position_ids (`torch.Tensor`): |
| 53 | + The position indices of the tokens corresponding to the query and key tensors. For example, this can be |
| 54 | + used to pass offsetted position ids when working with a KV-cache. |
| 55 | + unsqueeze_dim (`int`, *optional*, defaults to 1): |
| 56 | + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and |
| 57 | + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note |
| 58 | + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and |
| 59 | + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes |
| 60 | + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have |
| 61 | + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. |
| 62 | + Returns: |
| 63 | + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. |
| 64 | + """ |
| 65 | + cos = cos[position_ids].unsqueeze(unsqueeze_dim) |
| 66 | + sin = sin[position_ids].unsqueeze(unsqueeze_dim) |
| 67 | + q_embed = (q * cos) + (rotate_half(q) * sin) |
| 68 | + k_embed = (k * cos) + (rotate_half(k) * sin) |
| 69 | + return q_embed, k_embed |
| 70 | + |
| 71 | + |
| 72 | +class QEffGrok1MultiHeadAttention(nn.Module): |
| 73 | + def __qeff_init__(self): |
| 74 | + self.layer_idx = 0 |
| 75 | + |
| 76 | + def forward( |
| 77 | + self, |
| 78 | + hidden_states: torch.Tensor, |
| 79 | + attention_mask: Optional[torch.Tensor] = None, |
| 80 | + position_ids: Optional[torch.LongTensor] = None, |
| 81 | + past_key_value: Optional[Tuple[torch.Tensor]] = None, |
| 82 | + batch_index: Optional[torch.LongTensor] = None, |
| 83 | + output_attentions: bool = False, |
| 84 | + use_cache: bool = False, |
| 85 | + **kwargs, |
| 86 | + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
| 87 | + bsz, q_len, _ = hidden_states.size() |
| 88 | + |
| 89 | + query_states = self.q_proj(hidden_states) |
| 90 | + key_states = self.k_proj(hidden_states) |
| 91 | + value_states = self.v_proj(hidden_states) |
| 92 | + |
| 93 | + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
| 94 | + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
| 95 | + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
| 96 | + |
| 97 | + kv_seq_len = key_states.shape[-2] |
| 98 | + if past_key_value is not None: |
| 99 | + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) |
| 100 | + |
| 101 | + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) |
| 102 | + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) |
| 103 | + |
| 104 | + if past_key_value is not None: |
| 105 | + cache_kwargs = { |
| 106 | + "sin": sin, |
| 107 | + "cos": cos, |
| 108 | + "batch_index": batch_index, |
| 109 | + "position_ids": position_ids, |
| 110 | + } # Specific to RoPE models |
| 111 | + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) |
| 112 | + |
| 113 | + # repeat k/v heads if n_kv_heads < n_heads |
| 114 | + key_states = repeat_kv(key_states, self.num_key_value_groups) |
| 115 | + value_states = repeat_kv(value_states, self.num_key_value_groups) |
| 116 | + |
| 117 | + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)).to(torch.float) |
| 118 | + attn_weights = attn_weights * self.attn_output_multiplier |
| 119 | + attn_weights = self.max_attn_val * F.tanh(attn_weights / self.max_attn_val) |
| 120 | + |
| 121 | + if attention_mask is not None: |
| 122 | + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) |
| 123 | + |
| 124 | + attn_weights = F.softmax(attn_weights, dim=-1).to(query_states.dtype) |
| 125 | + attn_output = torch.matmul(attn_weights, value_states) |
| 126 | + |
| 127 | + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): |
| 128 | + raise ValueError( |
| 129 | + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" |
| 130 | + f" {attn_output.size()}" |
| 131 | + ) |
| 132 | + |
| 133 | + attn_output = attn_output.transpose(1, 2).contiguous() |
| 134 | + attn_output = attn_output.reshape(bsz, q_len, -1) |
| 135 | + |
| 136 | + attn_output = self.o_proj(attn_output) |
| 137 | + |
| 138 | + if not output_attentions: |
| 139 | + attn_weights = None |
| 140 | + |
| 141 | + return attn_output, attn_weights, past_key_value |
| 142 | + |
| 143 | + |
| 144 | +class QEffGrok1MoeBlock(nn.Module): |
| 145 | + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: |
| 146 | + batch_size, sequence_length, hidden_dim = hidden_states.shape |
| 147 | + hidden_states = hidden_states.view(-1, hidden_dim) |
| 148 | + # router_logits: (batch * sequence_length, n_experts) |
| 149 | + router_logits = self.gate(hidden_states) |
| 150 | + |
| 151 | + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) |
| 152 | + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) |
| 153 | + # we cast back to the input dtype |
| 154 | + routing_weights = routing_weights.to(hidden_states.dtype) |
| 155 | + |
| 156 | + final_hidden_states = torch.zeros( |
| 157 | + (batch_size * sequence_length, hidden_dim), |
| 158 | + dtype=hidden_states.dtype, |
| 159 | + device=hidden_states.device, |
| 160 | + ) |
| 161 | + # One hot encode the selected experts to create an expert mask |
| 162 | + # this will be used to easily index which expert is going to be sollicitated |
| 163 | + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) |
| 164 | + |
| 165 | + # Loop over all available experts in the model and perform the computation on each expert |
| 166 | + for expert_idx in range(self.num_experts): |
| 167 | + expert_layer = self.experts[expert_idx] |
| 168 | + expert_mask_tr = expert_mask[expert_idx].transpose(0, 1) |
| 169 | + current_hidden_states = expert_layer(hidden_states) * (((routing_weights * expert_mask_tr).sum(1))[:, None]) |
| 170 | + current_hidden_states = torch.where( |
| 171 | + (routing_weights * expert_mask_tr).sum(1).to(torch.bool)[:, None], |
| 172 | + current_hidden_states, |
| 173 | + torch.tensor(0.0), |
| 174 | + ) |
| 175 | + final_hidden_states = final_hidden_states + current_hidden_states |
| 176 | + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) |
| 177 | + return final_hidden_states, router_logits |
| 178 | + |
| 179 | + |
| 180 | +class QEffGrok1DecoderLayer(nn.Module): |
| 181 | + def forward( |
| 182 | + self, |
| 183 | + hidden_states: torch.Tensor, |
| 184 | + attention_mask: Optional[torch.Tensor] = None, |
| 185 | + position_ids: Optional[torch.LongTensor] = None, |
| 186 | + past_key_value: Optional[Tuple[torch.Tensor]] = None, |
| 187 | + batch_index: Optional[torch.LongTensor] = None, |
| 188 | + output_attentions: Optional[bool] = False, |
| 189 | + output_router_logits: Optional[bool] = False, |
| 190 | + use_cache: Optional[bool] = False, |
| 191 | + **kwargs, |
| 192 | + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: |
| 193 | + residual = hidden_states |
| 194 | + hidden_states = self.pre_attn_norm(hidden_states) |
| 195 | + hidden_states, attention_weights, present_key_value = self.attn( |
| 196 | + hidden_states, |
| 197 | + attention_mask=attention_mask, |
| 198 | + position_ids=position_ids, |
| 199 | + past_key_value=past_key_value, |
| 200 | + batch_index=batch_index, |
| 201 | + output_attentions=output_attentions, |
| 202 | + use_cache=use_cache, |
| 203 | + ) |
| 204 | + hidden_states = self.post_attn_norm(hidden_states) |
| 205 | + hidden_states = residual + hidden_states |
| 206 | + |
| 207 | + residual = hidden_states |
| 208 | + hidden_states = self.pre_moe_norm(hidden_states) |
| 209 | + hidden_states, router_logits = self.moe_block(hidden_states) |
| 210 | + hidden_states = self.post_moe_norm(hidden_states) |
| 211 | + hidden_states = residual + hidden_states |
| 212 | + |
| 213 | + outputs = (hidden_states,) |
| 214 | + if output_attentions: |
| 215 | + outputs += (attention_weights,) |
| 216 | + if use_cache: |
| 217 | + outputs += (present_key_value,) |
| 218 | + if output_router_logits: |
| 219 | + outputs += (router_logits,) |
| 220 | + return outputs |
| 221 | + |
| 222 | + |
| 223 | +class QEffGrok1Model(nn.Module): |
| 224 | + def forward( |
| 225 | + self, |
| 226 | + input_ids: torch.LongTensor = None, |
| 227 | + attention_mask: Optional[torch.Tensor] = None, |
| 228 | + position_ids: Optional[torch.LongTensor] = None, |
| 229 | + past_key_values: Optional[List[torch.FloatTensor]] = None, |
| 230 | + batch_index: Optional[torch.LongTensor] = None, |
| 231 | + inputs_embeds: Optional[torch.FloatTensor] = None, |
| 232 | + use_cache: Optional[bool] = None, |
| 233 | + output_attentions: Optional[bool] = None, |
| 234 | + output_hidden_states: Optional[bool] = None, |
| 235 | + output_router_logits: Optional[bool] = None, |
| 236 | + return_dict: Optional[bool] = None, |
| 237 | + ) -> Union[Tuple, MoeModelOutputWithPast]: |
| 238 | + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| 239 | + output_hidden_states = ( |
| 240 | + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| 241 | + ) |
| 242 | + use_cache = use_cache if use_cache is not None else self.config.use_cache |
| 243 | + |
| 244 | + return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| 245 | + |
| 246 | + # retrieve input_ids and inputs_embeds |
| 247 | + if input_ids is not None and inputs_embeds is not None: |
| 248 | + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") |
| 249 | + elif input_ids is not None: |
| 250 | + batch_size, seq_length = input_ids.shape[:2] |
| 251 | + elif inputs_embeds is not None: |
| 252 | + batch_size, seq_length = inputs_embeds.shape[:2] |
| 253 | + else: |
| 254 | + raise ValueError("You have to specify either input_ids or inputs_embeds") |
| 255 | + |
| 256 | + seq_length_with_past = seq_length |
| 257 | + past_key_values_length = 0 |
| 258 | + if past_key_values is not None: |
| 259 | + past_key_values_length = past_key_values[0][0].shape[2] |
| 260 | + seq_length_with_past = seq_length_with_past + past_key_values_length |
| 261 | + |
| 262 | + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) |
| 263 | + |
| 264 | + if inputs_embeds is None: |
| 265 | + inputs_embeds = self.embed_tokens(input_ids) |
| 266 | + inputs_embeds = inputs_embeds * self.embedding_multiplier_scale |
| 267 | + |
| 268 | + attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values_length) |
| 269 | + |
| 270 | + # embed positions |
| 271 | + hidden_states = inputs_embeds |
| 272 | + |
| 273 | + # decoder layers |
| 274 | + all_hidden_states = () if output_hidden_states else None |
| 275 | + all_self_attns = () if output_attentions else None |
| 276 | + all_router_logits = () if output_router_logits else None |
| 277 | + next_decoder_cache = () if use_cache else None |
| 278 | + |
| 279 | + for idx, decoder_layer in enumerate(self.layers): |
| 280 | + if output_hidden_states: |
| 281 | + all_hidden_states += (hidden_states,) |
| 282 | + |
| 283 | + layer_outputs = decoder_layer( |
| 284 | + hidden_states, |
| 285 | + attention_mask=attention_mask, |
| 286 | + position_ids=position_ids, |
| 287 | + past_key_value=past_key_values, |
| 288 | + batch_index=batch_index, |
| 289 | + output_attentions=output_attentions, |
| 290 | + use_cache=use_cache, |
| 291 | + ) |
| 292 | + |
| 293 | + hidden_states = layer_outputs[0] |
| 294 | + |
| 295 | + if use_cache: |
| 296 | + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) |
| 297 | + |
| 298 | + if output_attentions: |
| 299 | + all_self_attns += (layer_outputs[1],) |
| 300 | + |
| 301 | + if output_router_logits: |
| 302 | + all_router_logits += (layer_outputs[-1],) |
| 303 | + |
| 304 | + hidden_states = self.norm(hidden_states) |
| 305 | + |
| 306 | + # add hidden states from the last decoder layer |
| 307 | + if output_hidden_states: |
| 308 | + all_hidden_states += (hidden_states,) |
| 309 | + |
| 310 | + past_key_values = past_key_values.to_legacy_cache() |
| 311 | + |
| 312 | + return MoeModelOutputWithPast( |
| 313 | + last_hidden_state=hidden_states, |
| 314 | + past_key_values=past_key_values, |
| 315 | + hidden_states=all_hidden_states, |
| 316 | + attentions=all_self_attns, |
| 317 | + router_logits=all_router_logits, |
| 318 | + ) |
| 319 | + |
| 320 | + |
| 321 | +class QEffGrok1ModelForCausalLM(nn.Module): |
| 322 | + def forward( |
| 323 | + self, |
| 324 | + input_ids: torch.LongTensor = None, |
| 325 | + attention_mask: Optional[torch.Tensor] = None, |
| 326 | + position_ids: Optional[torch.LongTensor] = None, |
| 327 | + past_key_values: Optional[List[torch.FloatTensor]] = None, |
| 328 | + batch_index: Optional[torch.LongTensor] = None, |
| 329 | + inputs_embeds: Optional[torch.FloatTensor] = None, |
| 330 | + labels: Optional[torch.LongTensor] = None, |
| 331 | + use_cache: Optional[bool] = None, |
| 332 | + output_attentions: Optional[bool] = None, |
| 333 | + output_hidden_states: Optional[bool] = None, |
| 334 | + output_router_logits: Optional[bool] = None, |
| 335 | + return_dict: Optional[bool] = None, |
| 336 | + **kwargs, |
| 337 | + ): |
| 338 | + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| 339 | + output_router_logits = ( |
| 340 | + output_router_logits if output_router_logits is not None else self.config.output_router_logits |
| 341 | + ) |
| 342 | + |
| 343 | + output_hidden_states = ( |
| 344 | + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| 345 | + ) |
| 346 | + return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| 347 | + |
| 348 | + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) |
| 349 | + outputs = self.model( |
| 350 | + input_ids=input_ids, |
| 351 | + attention_mask=attention_mask, |
| 352 | + position_ids=position_ids, |
| 353 | + past_key_values=past_key_values, |
| 354 | + batch_index=batch_index, |
| 355 | + inputs_embeds=inputs_embeds, |
| 356 | + use_cache=use_cache, |
| 357 | + output_attentions=output_attentions, |
| 358 | + output_hidden_states=output_hidden_states, |
| 359 | + output_router_logits=output_router_logits, |
| 360 | + return_dict=return_dict, |
| 361 | + **kwargs, |
| 362 | + ) |
| 363 | + |
| 364 | + # Cast to int32 to avoid ONNXRT issue |
| 365 | + logit_idx = position_ids.to(torch.int32).argmax(1, keepdim=True) |
| 366 | + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_idx] |
| 367 | + logits = self.lm_head(hidden_states) |
| 368 | + logits = logits * self.output_multiplier_scale |
| 369 | + logits = logits.float() |
| 370 | + |
| 371 | + return MoeCausalLMOutputWithPast( |
| 372 | + logits=logits, |
| 373 | + past_key_values=outputs.past_key_values, |
| 374 | + hidden_states=outputs.hidden_states, |
| 375 | + attentions=outputs.attentions, |
| 376 | + router_logits=outputs.router_logits, |
| 377 | + ) |
0 commit comments