Commit 5a72bac

Resolved Conflicts
Signed-off-by: Amit Raj <[email protected]>
1 parent b0b5003 commit 5a72bac

5 files changed: +8 −73 lines

QEfficient/transformers/models/gemma/modeling_gemma.py

Lines changed: 2 additions & 2 deletions

@@ -86,8 +86,8 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
+    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
 
     # Apply rotation
     q_embed = (q * cos) + (rotate_half(q) * sin)
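For readers skimming this hunk: the only functional change is that `cos`/`sin` are gathered with `position_ids` before being broadcast against the query and key tensors. A minimal, self-contained sketch of the gathered variant, assuming the usual Hugging Face `rotate_half` convention (illustrative only, not the exact QEfficient code):

import torch

def rotate_half(x):
    # Rotates half the hidden dims of the input (standard RoPE helper).
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_gathered(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    # cos/sin: (max_seq_len, head_dim) lookup tables; position_ids: (batch, seq_len).
    # Indexing with position_ids picks the rows for the actual token positions,
    # then unsqueeze adds the head dimension for broadcasting.
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # (batch, 1, seq_len, head_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed.to(q.dtype), k_embed.to(k.dtype)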

QEfficient/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 0 additions & 38 deletions

@@ -40,11 +40,7 @@ class QEffGemma2RotaryEmbedding(Gemma2RotaryEmbedding):
     """
 
     def __init__(self, config: Gemma2Config, device=None):
-<<<<<<< HEAD
         super().__init__(config=config)
-=======
-        Gemma2RotaryEmbedding.__init__(self, config=config)
->>>>>>> 6ba4c76 (Code cleaning and formating)
 
         # Build here to make `torch.jit.trace` work.
         self._set_cos_sin_cache(
@@ -136,16 +132,6 @@ class QEffGemma2Attention(Gemma2Attention):
     - add new args cache idx for the kv retention
     """
 
-<<<<<<< HEAD
-=======
-    def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
-        super().__init__(config, layer_idx)
-        # Define the general __qeff_init__() for any changes in the init calls
-        # Set the init in the module mapping pytorch transforms
-        self.config = config
-        self.__qeff_init__()
-
->>>>>>> 6ba4c76 (Code cleaning and formating)
     def __qeff_init__(self):
         self.rotary_emb = QEffGemma2RotaryEmbedding(config=self.config)
 
@@ -355,10 +341,6 @@ def forward(
                 output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
-<<<<<<< HEAD
-=======
-                position_embeddings=position_embeddings,
->>>>>>> 6ba4c76 (Code cleaning and formating)
                 **kwargs,
             )
 
@@ -384,26 +366,6 @@ def forward(
         )
         return output if return_dict else output.to_tuple()
 
-<<<<<<< HEAD
-=======
-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        position_ids: torch.Tensor,
-        past_key_values: Cache,
-        output_attentions: bool,
-    ):
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens
-        causal_mask = _create_causal_mask(
-            position_ids=position_ids, target_length=target_length, sliding_window=self.config.sliding_window
-        )
-
-        return causal_mask
-
->>>>>>> 6ba4c76 (Code cleaning and formating)
 
 class QEffGemma2ForCausalLM(Gemma2ForCausalLM, GenerationMixin):
     """

QEfficient/transformers/models/llama/modeling_llama.py

Lines changed: 4 additions & 13 deletions

@@ -85,26 +85,15 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
+    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
 
     # Apply rotation
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     # Cast back to original dtype
     return q_embed.to(q.dtype), k_embed.to(k.dtype)
 
-def eager_attention_forward(
-    module: nn.Module,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    scaling: float,
-    **kwargs,
-):
-    key_states = repeat_kv(key, module.num_key_value_groups)
-    value_states = repeat_kv(value, module.num_key_value_groups)
 
 def eager_attention_forward(
     module: nn.Module,
@@ -142,6 +131,8 @@ def forward(
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         batch_index: Optional[torch.LongTensor] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
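The deletion above drops a duplicated, truncated copy of `eager_attention_forward`; the surviving definition is the one that continues past the `repeat_kv` calls. For reference, a hedged, self-contained sketch of the standard eager attention path such a function typically implements (shapes and the `repeat_kv` helper follow the common transformers convention; this is illustrative, not copied from the file):

import torch
from torch import nn

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand KV heads to match the number of query heads (grouped-query attention).
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

def eager_attention_forward(module, query, key, value, attention_mask, scaling, **kwargs):
    # query/key/value: (batch, heads, seq_len, head_dim); module provides num_key_value_groups.
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
    return attn_output, attn_weights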

QEfficient/transformers/models/mistral/modeling_mistral.py

Lines changed: 2 additions & 2 deletions

@@ -140,6 +140,8 @@ def forward(
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         batch_index: Optional[torch.LongTensor] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -199,7 +201,6 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -230,7 +231,6 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
-            position_embeddings=position_embeddings,
             **kwargs,
         )
         hidden_states = residual + hidden_states
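The two added keyword arguments give the attention `forward` explicit defaults instead of leaving those flags to arrive through `**kwargs`, and dropping `position_embeddings` is consistent with the other files, where cos/sin appear to be produced inside the attention block from `position_ids`. A toy, hedged illustration of why explicit defaults help export-style call sites (names are invented for the example):

def forward_sketch(hidden_states, attention_mask=None, position_ids=None,
                   past_key_value=None, batch_index=None,
                   output_attentions: bool = False, use_cache: bool = False,
                   cache_position=None, **kwargs):
    # With explicit defaults, wrappers that omit these flags still resolve them
    # deterministically rather than depending on what happens to be in **kwargs.
    return {"output_attentions": output_attentions, "use_cache": use_cache}

print(forward_sketch(hidden_states=None))                  # both flags default to False
print(forward_sketch(hidden_states=None, use_cache=True))  # explicit override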

QEfficient/transformers/models/qwen2/modeling_qwen2.py

Lines changed: 0 additions & 18 deletions

@@ -68,10 +68,6 @@ def forward(self, x, seq_len=None):
         )
 
 
-<<<<<<< HEAD
-<<<<<<< HEAD
-=======
->>>>>>> e4503c5 (Minor fixes)
 def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
@@ -105,24 +101,15 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
 
-<<<<<<< HEAD
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-=======
     cos = cos[position_ids].unsqueeze(unsqueeze_dim)
     sin = sin[position_ids].unsqueeze(unsqueeze_dim)
->>>>>>> e4503c5 (Minor fixes)
 
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
 
     return q_embed.to(q.dtype), k_embed.to(k.dtype)
 
 
-<<<<<<< HEAD
-=======
->>>>>>> d0f7ffd (Ruff check and format)
-=======
 def eager_attention_forward(
     module: nn.Module,
     query: torch.Tensor,
@@ -146,7 +133,6 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
->>>>>>> e4503c5 (Minor fixes)
 class QEffQwen2Attention(Qwen2Attention):
     """
     Copied from Qwen2Attention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/modeling_qwen2.py
@@ -177,11 +163,7 @@ def forward(
         kv_seq_len = key_states.shape[-2]
 
         kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-<<<<<<< HEAD
-        cos, sin = position_embeddings
-=======
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
->>>>>>> e4503c5 (Minor fixes)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
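The resolution keeps the path where cos/sin come from `self.rotary_emb(value_states, seq_len=kv_seq_len)` rather than from a precomputed `position_embeddings` tuple. A small runnable sketch of the shape bookkeeping this implies, with `head_dim`, `past_len`, and `base=10000` invented for illustration: the cos/sin table must cover past plus current tokens, while `position_ids` selects only the rows for the tokens currently being processed.

import torch

head_dim, past_len, cur_len = 8, 6, 2
kv_seq_len = past_len + cur_len                    # what get_usable_length would report
t = torch.arange(kv_seq_len, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
emb = torch.cat([torch.outer(t, inv_freq)] * 2, dim=-1)
cos, sin = emb.cos(), emb.sin()                    # (kv_seq_len, head_dim) tables

position_ids = torch.tensor([[past_len, past_len + 1]])  # positions of the new tokens only
cos_sel = cos[position_ids].unsqueeze(1)           # (batch, 1, cur_len, head_dim)
sin_sel = sin[position_ids].unsqueeze(1)
print(cos_sel.shape, sin_sel.shape)                # torch.Size([1, 1, 2, 8]) twice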
