
Commit b25f290

Update positional embeddings to match the latest RoPE parameter handling (`rope_parameters` / `standardize_rope_params`).
1 parent d81f428 commit b25f290

File tree: 3 files changed (+138, -117 lines)


src/transformers/models/t5gemma2/configuration_t5gemma2.py

Lines changed: 39 additions & 33 deletions
@@ -23,7 +23,7 @@
 from typing import Any, Optional, Union

 from ...configuration_utils import PreTrainedConfig, layer_type_validation
-from ...modeling_rope_utils import rope_config_validation
+from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params
 from ...utils import logging
 from ..siglip import SiglipVisionConfig

@@ -53,33 +53,31 @@ class T5Gemma2ModuleConfig(PreTrainedConfig):

     def __init__(
         self,
-        vocab_size=262_208,
-        hidden_size=2304,
-        intermediate_size=9216,
-        num_hidden_layers=26,
-        num_attention_heads=8,
-        num_key_value_heads=4,
-        head_dim=256,
-        hidden_activation="gelu_pytorch_tanh",
-        max_position_embeddings=131_072,
-        initializer_range=0.02,
-        rms_norm_eps=1e-6,
-        use_cache=True,
-        pad_token_id=0,
-        eos_token_id=1,
-        bos_token_id=2,
-        tie_word_embeddings=True,
-        rope_theta=1_000_000.0,
-        attention_bias=False,
-        attention_dropout=0.0,
-        query_pre_attn_scalar=256,
-        sliding_window=4096,
-        layer_types=None,
-        final_logit_softcapping=None,
-        attn_logit_softcapping=None,
-        rope_scaling=None,
-        rope_local_base_freq=10_000.0,
-        use_bidirectional_attention=False,
+        vocab_size: Optional[int] = 262_208,
+        hidden_size: Optional[int] = 2304,
+        intermediate_size: Optional[int] = 9216,
+        num_hidden_layers: Optional[int] = 26,
+        num_attention_heads: Optional[int] = 8,
+        num_key_value_heads: Optional[int] = 4,
+        head_dim: Optional[int] = 256,
+        hidden_activation: Optional[str] = "gelu_pytorch_tanh",
+        max_position_embeddings: Optional[int] = 131_072,
+        initializer_range: Optional[float] = 0.02,
+        rms_norm_eps: Optional[int] = 1e-6,
+        use_cache: Optional[bool] = True,
+        pad_token_id: Optional[int] = 0,
+        eos_token_id: Optional[int] = 1,
+        bos_token_id: Optional[int] = 2,
+        tie_word_embeddings: Optional[bool] = True,
+        attention_bias: Optional[bool] = False,
+        attention_dropout: Optional[float] = 0.0,
+        query_pre_attn_scalar: Optional[int] = 256,
+        sliding_window: Optional[int] = 4096,
+        layer_types: Optional[list[str]] = None,
+        final_logit_softcapping: Optional[float] = None,
+        attn_logit_softcapping: Optional[float] = None,
+        rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
+        use_bidirectional_attention: Optional[bool] = False,
         **kwargs,
     ):
         super().__init__(
@@ -100,7 +98,6 @@ def __init__(
         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
         self.use_cache = use_cache
-        self.rope_theta = rope_theta
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
         self.hidden_activation = hidden_activation
@@ -109,14 +106,15 @@ def __init__(
         self.final_logit_softcapping = final_logit_softcapping
         self.attn_logit_softcapping = attn_logit_softcapping
         self.layer_types = layer_types
+        # Try to set `rope_scaling` if available, otherwise use `rope_parameters`
+        rope_scaling = kwargs.pop("rope_scaling", None)
+        if rope_scaling is not None:
+            rope_parameters = {"sliding_attention": {"rope_type": "default"}, "full_attention": rope_scaling}
+        self.rope_parameters = rope_parameters
         self.use_bidirectional_attention = use_bidirectional_attention
         if use_bidirectional_attention:
             self.sliding_window = (self.sliding_window // 2) + 1  # due to fa we set exclusive bounds

-        self.rope_local_base_freq = rope_local_base_freq
-        self.rope_scaling = rope_scaling
-        rope_config_validation(self)
-
         # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
         self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)

@@ -127,6 +125,14 @@ def __init__(
             ]
         layer_type_validation(self.layer_types, self.num_hidden_layers)

+        # Validate the correctness of rotary position embeddings parameters
+        rope_theta = getattr(self, "rope_theta", 1_000_000.0)
+        rope_local_base_freq = getattr(self, "rope_local_base_freq", 10000.0)
+        standardize_rope_params(
+            self, rope_theta={"full_attention": rope_theta, "sliding_attention": rope_local_base_freq}
+        )
+        rope_config_validation(self)
+

 class T5Gemma2Config(PreTrainedConfig):
     r"""

src/transformers/models/t5gemma2/modeling_t5gemma2.py

Lines changed: 78 additions & 52 deletions
@@ -19,7 +19,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 from collections.abc import Callable
 from typing import Optional, Union

@@ -98,33 +97,78 @@ class T5Gemma2RotaryEmbedding(nn.Module):

     def __init__(self, config: T5Gemma2ModuleConfig, device=None):
         super().__init__()
-        # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            self.rope_type = "default"
         self.max_seq_len_cached = config.max_position_embeddings
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.layer_types = list(set(config.layer_types))
+        self.rope_type = {}
+        for layer_type in self.layer_types:
+            rope_params = self.config.rope_parameters[layer_type]
+            if rope_params is None:
+                continue
+
+            self.rope_type[layer_type] = rope_params["rope_type"]
+            rope_init_fn: Callable = self.compute_default_rope_parameters
+            if self.rope_type[layer_type] != "default":
+                rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
+            curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
+            self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
+            setattr(self, f"{layer_type}_original_inv_freq", curr_inv_freq)
+            setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
+
+    @staticmethod
+    def compute_default_rope_parameters(
+        config: Optional[T5Gemma2ModuleConfig] = None,
+        device: Optional["torch.device"] = None,
+        seq_len: Optional[int] = None,
+        layer_type: Optional[str] = None,
+    ) -> tuple["torch.Tensor", float]:
+        """
+        Computes the inverse frequencies according to the original RoPE implementation
+        Args:
+            config ([`~transformers.PreTrainedConfig`]):
+                The model configuration.
+            device (`torch.device`):
+                The device to use for initialization of the inverse frequencies.
+            seq_len (`int`, *optional*):
+                The current sequence length. Unused for this type of RoPE.
+            layer_type (`str`, *optional*):
+                The current layer type if the model has different RoPE parameters per type.
+                Should not be used unless `config.layer_types is not None`
+
+        Returns:
+            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+        """
+        # For backward compatibility standardize the `rope_parameters_dict` if it uses old format
+        base = config.rope_parameters[layer_type]["rope_theta"]
+        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+
+        attention_factor = 1.0  # Unused in this type of RoPE
+
+        # Compute the inverse frequencies
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+        )
+        return inv_freq, attention_factor

     @torch.no_grad()
     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
-    def forward(self, x, position_ids):
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+    def forward(self, x, position_ids, layer_type=None):
+        inv_freq = getattr(self, f"{layer_type}_inv_freq")
+        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
+
+        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos() * self.attention_scaling
-            sin = emb.sin() * self.attention_scaling
+            cos = emb.cos() * attention_scaling
+            sin = emb.sin() * attention_scaling

         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
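The per-layer-type buffers registered above reduce to the standard RoPE inverse-frequency formula, just with a different base per attention type (1e6 for `full_attention`, 1e4 for `sliding_attention`, and `head_dim=256` by default). A standalone sketch of that computation, independent of the module class:

```python
# Standalone sketch of the default inverse-frequency computation; bases and head_dim mirror the config defaults.
import torch


def default_inv_freq(base: float, dim: int, device=None) -> torch.Tensor:
    # inv_freq[i] = 1 / base**(2i / dim) for i = 0 .. dim // 2 - 1
    exponents = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim
    return 1.0 / (base**exponents)


inv_freq = {
    "full_attention": default_inv_freq(1_000_000.0, 256),  # global RoPE base
    "sliding_attention": default_inv_freq(10_000.0, 256),  # local RoPE base
}

# The rotation angle for position p and frequency i is p * inv_freq[i]; cos/sin of these feed the attention layers.
position_ids = torch.arange(8, dtype=torch.float)
angles = position_ids[:, None] * inv_freq["full_attention"][None, :]
print(angles.shape)  # torch.Size([8, 128])
```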

@@ -215,7 +259,7 @@ class T5Gemma2SelfAttention(nn.Module):

     def __init__(self, config: T5Gemma2ModuleConfig, layer_idx: int):
         super().__init__()
-        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
         self.config = config
         self.layer_idx = layer_idx
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
@@ -237,16 +281,17 @@ def __init__(self, config: T5Gemma2ModuleConfig, layer_idx: int):
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
         self.attn_logit_softcapping = self.config.attn_logit_softcapping
-        self.sliding_window = config.sliding_window if self.is_sliding else None
+        self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
+        self.is_sliding = self.layer_type == "sliding_attention"

         self.q_norm = T5Gemma2RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
         self.k_norm = T5Gemma2RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)

     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_embeddings: torch.Tensor,
-        attention_mask: Optional[torch.Tensor],
+        position_embeddings: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
@@ -295,7 +340,7 @@ class T5Gemma2MergedAttention(nn.Module):

     def __init__(self, config: T5Gemma2ModuleConfig, layer_idx: int):
         super().__init__()
-        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
         self.config = config
         self.layer_idx = layer_idx
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
@@ -317,7 +362,8 @@ def __init__(self, config: T5Gemma2ModuleConfig, layer_idx: int):
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
         self.attn_logit_softcapping = self.config.attn_logit_softcapping
-        self.sliding_window = config.sliding_window if self.is_sliding else None
+        self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
+        self.is_sliding = self.layer_type == "sliding_attention"

         self.q_norm = T5Gemma2RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
         self.k_norm = T5Gemma2RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
@@ -474,21 +520,14 @@ def __init__(self, config: T5Gemma2ModuleConfig, layer_idx: int):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_embeddings_global: tuple[torch.Tensor, torch.Tensor],
-        position_embeddings_local: tuple[torch.Tensor, torch.Tensor],
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> torch.FloatTensor:
         residual = hidden_states
         hidden_states = self.pre_self_attn_layernorm(hidden_states)

-        # apply global RoPE to non-sliding layer only
-        if self.self_attn.is_sliding:
-            position_embeddings = position_embeddings_local
-        else:
-            position_embeddings = position_embeddings_global
-
         hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             position_embeddings=position_embeddings,
@@ -523,8 +562,7 @@ def __init__(self, config, layer_idx: int):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_embeddings_global: tuple[torch.Tensor, torch.Tensor],
-        position_embeddings_local: tuple[torch.Tensor, torch.Tensor],
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[EncoderDecoderCache] = None,
@@ -537,12 +575,6 @@ def forward(
         residual = hidden_states
         hidden_states = self.pre_self_attn_layernorm(hidden_states)

-        # apply global RoPE to non-sliding layer only
-        if self.self_attn.is_sliding:
-            position_embeddings = position_embeddings_local
-        else:
-            position_embeddings = position_embeddings_global
-
         hidden_states, _, _ = self.self_attn(
             hidden_states=hidden_states,
             position_embeddings=position_embeddings,
@@ -677,6 +709,7 @@ class T5Gemma2PreTrainedModel(PreTrainedModel):
             OutputRecorder(T5Gemma2MergedAttention, index=2, layer_name="cross_attn"),
         ],
     }
+    input_modalities = ["image", "text"]

     def _init_weights(self, module):
         super()._init_weights(module)
@@ -798,14 +831,7 @@ def __init__(
             [T5Gemma2EncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.dropout = nn.Dropout(config.dropout_rate)
-
-        # global rope.
-        self.rotary_emb = T5Gemma2RotaryEmbedding(config=config)
-        # local rope.
-        config = copy.deepcopy(config)
-        config.rope_theta = config.rope_local_base_freq
-        config.rope_scaling = {"rope_type": "default"}
-        self.rotary_emb_local = T5Gemma2RotaryEmbedding(config=config)
+        self.rotary_emb = T5Gemma2RotaryEmbedding(config)

         # Initialize weights and apply final processing
         self.post_init()
@@ -877,8 +903,9 @@ def forward(
         hidden_states = inputs_embeds

         # global and local position embeddings
-        position_embeddings_global = self.rotary_emb(hidden_states, position_ids)
-        position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids)
+        position_embeddings = {}
+        for layer_type in self.config.layer_types:
+            position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

         # dropout
         hidden_states = self.dropout(hidden_states)
@@ -887,8 +914,7 @@
             assert isinstance(layer_module, T5Gemma2EncoderLayer)
             hidden_states = layer_module(
                 hidden_states,
-                position_embeddings_global,
-                position_embeddings_local,
+                position_embeddings[layer_module.attention_type],
                 self_attn_mask_mapping[layer_module.attention_type],
                 position_ids,
                 **kwargs,
@@ -992,8 +1018,9 @@ def forward(
         hidden_states = inputs_embeds

         # global and local position embeddings
-        position_embeddings_global = self.rotary_emb(hidden_states, position_ids)
-        position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids)
+        position_embeddings = {}
+        for layer_type in self.config.layer_types:
+            position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

         # dropout
         hidden_states = self.dropout(hidden_states)
@@ -1002,8 +1029,7 @@
             assert isinstance(layer_module, T5Gemma2DecoderLayer)
             hidden_states = layer_module(
                 hidden_states,
-                position_embeddings_global,
-                position_embeddings_local,
+                position_embeddings[layer_module.attention_type],
                 self_attn_mask_mapping[layer_module.attention_type],
                 position_ids,
                 past_key_values,
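Both the encoder and decoder forward passes now compute cos/sin once per distinct layer type and index into that dict with each layer's `attention_type`, instead of passing separate global and local embeddings to every layer. A runnable toy sketch of that dispatch; the helper below is a stand-in, not the real `T5Gemma2RotaryEmbedding`, and the 5:1 sliding/full layer pattern is only illustrative:

```python
# Toy sketch of the per-layer-type RoPE dispatch; names mirror the diff but the modules are stand-ins.
import torch


def rotary(position_ids: torch.Tensor, base: float, dim: int = 256):
    # Same default inverse-frequency formula as compute_default_rope_parameters, without attention scaling.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()


bases = {"full_attention": 1_000_000.0, "sliding_attention": 10_000.0}
layer_types = ["sliding_attention"] * 5 + ["full_attention"]  # illustrative Gemma-style pattern
position_ids = torch.arange(16)[None, :]  # (batch=1, seq_len=16)

# Computed once per layer *type*, not once per layer.
position_embeddings = {t: rotary(position_ids, bases[t]) for t in set(layer_types)}

for layer_idx, attention_type in enumerate(layer_types):
    cos, sin = position_embeddings[attention_type]  # what each layer receives via `attention_type`
    print(layer_idx, attention_type, tuple(cos.shape))  # (1, 16, 256)
```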
