 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 from collections.abc import Callable
 from typing import Optional, Union
 
@@ -98,33 +97,78 @@ class T5Gemma2RotaryEmbedding(nn.Module):
 
     def __init__(self, config: T5Gemma2ModuleConfig, device=None):
         super().__init__()
-        # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            self.rope_type = "default"
         self.max_seq_len_cached = config.max_position_embeddings
         self.original_max_seq_len = config.max_position_embeddings
 
         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
 
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.layer_types = list(set(config.layer_types))
+        self.rope_type = {}
+        for layer_type in self.layer_types:
+            rope_params = self.config.rope_parameters[layer_type]
+            if rope_params is None:
+                continue
+
+            self.rope_type[layer_type] = rope_params["rope_type"]
+            rope_init_fn: Callable = self.compute_default_rope_parameters
+            if self.rope_type[layer_type] != "default":
+                rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
+            curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
+            self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
+            setattr(self, f"{layer_type}_original_inv_freq", curr_inv_freq)
+            setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
+
+    @staticmethod
+    def compute_default_rope_parameters(
+        config: Optional[T5Gemma2ModuleConfig] = None,
+        device: Optional["torch.device"] = None,
+        seq_len: Optional[int] = None,
+        layer_type: Optional[str] = None,
+    ) -> tuple["torch.Tensor", float]:
+        """
+        Computes the inverse frequencies according to the original RoPE implementation.
+
+        Args:
+            config ([`~transformers.PreTrainedConfig`]):
+                The model configuration.
+            device (`torch.device`):
+                The device to use for initialization of the inverse frequencies.
+            seq_len (`int`, *optional*):
+                The current sequence length. Unused for this type of RoPE.
+            layer_type (`str`, *optional*):
+                The current layer type if the model has different RoPE parameters per type.
+                Should not be used unless `config.layer_types is not None`.
+
+        Returns:
+            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+        """
+        # For backward compatibility standardize the `rope_parameters_dict` if it uses old format
+        base = config.rope_parameters[layer_type]["rope_theta"]
+        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+
+        attention_factor = 1.0  # Unused in this type of RoPE
+
+        # Compute the inverse frequencies
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+        )
+        return inv_freq, attention_factor
 
     @torch.no_grad()
     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
-    def forward(self, x, position_ids):
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+    def forward(self, x, position_ids, layer_type=None):
+        inv_freq = getattr(self, f"{layer_type}_inv_freq")
+        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
+
+        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos() * self.attention_scaling
-            sin = emb.sin() * self.attention_scaling
+            cos = emb.cos() * attention_scaling
+            sin = emb.sin() * attention_scaling
 
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
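For reference, here is a minimal standalone sketch of the per-layer-type RoPE computation this hunk introduces. The `rope_parameters` layout and the layer-type keys mirror the diff but are assumptions about the config format, and the theta values are illustrative only:

import torch

head_dim = 256
# assumed per-layer-type parameters, analogous to config.rope_parameters in the patch
rope_parameters = {
    "full_attention": {"rope_type": "default", "rope_theta": 1_000_000.0},
    "sliding_attention": {"rope_type": "default", "rope_theta": 10_000.0},
}

# one inverse-frequency buffer per layer type, same formula as compute_default_rope_parameters
inv_freq = {
    layer_type: 1.0 / (params["rope_theta"] ** (torch.arange(0, head_dim, 2).float() / head_dim))
    for layer_type, params in rope_parameters.items()
}

position_ids = torch.arange(8)[None, :].float()  # (batch=1, seq_len=8)
freqs = (inv_freq["sliding_attention"][None, :, None] @ position_ids[:, None, :]).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()  # attention_scaling is 1.0 for the "default" rope type
print(cos.shape)  # torch.Size([1, 8, 256])
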
@@ -215,7 +259,7 @@ class T5Gemma2SelfAttention(nn.Module):
 
     def __init__(self, config: T5Gemma2ModuleConfig, layer_idx: int):
         super().__init__()
-        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
         self.config = config
         self.layer_idx = layer_idx
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
@@ -237,16 +281,17 @@ def __init__(self, config: T5Gemma2ModuleConfig, layer_idx: int):
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
         self.attn_logit_softcapping = self.config.attn_logit_softcapping
-        self.sliding_window = config.sliding_window if self.is_sliding else None
+        self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
+        self.is_sliding = self.layer_type == "sliding_attention"
 
         self.q_norm = T5Gemma2RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
         self.k_norm = T5Gemma2RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
 
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_embeddings: torch.Tensor,
-        attention_mask: Optional[torch.Tensor],
+        position_embeddings: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
@@ -295,7 +340,7 @@ class T5Gemma2MergedAttention(nn.Module):
 
     def __init__(self, config: T5Gemma2ModuleConfig, layer_idx: int):
         super().__init__()
-        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
         self.config = config
         self.layer_idx = layer_idx
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
@@ -317,7 +362,8 @@ def __init__(self, config: T5Gemma2ModuleConfig, layer_idx: int):
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
         self.attn_logit_softcapping = self.config.attn_logit_softcapping
-        self.sliding_window = config.sliding_window if self.is_sliding else None
+        self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
+        self.is_sliding = self.layer_type == "sliding_attention"
 
         self.q_norm = T5Gemma2RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
         self.k_norm = T5Gemma2RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
@@ -474,21 +520,14 @@ def __init__(self, config: T5Gemma2ModuleConfig, layer_idx: int):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_embeddings_global: tuple[torch.Tensor, torch.Tensor],
-        position_embeddings_local: tuple[torch.Tensor, torch.Tensor],
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> torch.FloatTensor:
         residual = hidden_states
         hidden_states = self.pre_self_attn_layernorm(hidden_states)
 
-        # apply global RoPE to non-sliding layer only
-        if self.self_attn.is_sliding:
-            position_embeddings = position_embeddings_local
-        else:
-            position_embeddings = position_embeddings_global
-
         hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             position_embeddings=position_embeddings,
@@ -523,8 +562,7 @@ def __init__(self, config, layer_idx: int):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_embeddings_global: tuple[torch.Tensor, torch.Tensor],
-        position_embeddings_local: tuple[torch.Tensor, torch.Tensor],
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[EncoderDecoderCache] = None,
@@ -537,12 +575,6 @@ def forward(
         residual = hidden_states
         hidden_states = self.pre_self_attn_layernorm(hidden_states)
 
-        # apply global RoPE to non-sliding layer only
-        if self.self_attn.is_sliding:
-            position_embeddings = position_embeddings_local
-        else:
-            position_embeddings = position_embeddings_global
-
         hidden_states, _, _ = self.self_attn(
             hidden_states=hidden_states,
             position_embeddings=position_embeddings,
@@ -677,6 +709,7 @@ class T5Gemma2PreTrainedModel(PreTrainedModel):
             OutputRecorder(T5Gemma2MergedAttention, index=2, layer_name="cross_attn"),
         ],
     }
+    input_modalities = ["image", "text"]
 
     def _init_weights(self, module):
         super()._init_weights(module)
@@ -798,14 +831,7 @@ def __init__(
             [T5Gemma2EncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.dropout = nn.Dropout(config.dropout_rate)
-
-        # global rope.
-        self.rotary_emb = T5Gemma2RotaryEmbedding(config=config)
-        # local rope.
-        config = copy.deepcopy(config)
-        config.rope_theta = config.rope_local_base_freq
-        config.rope_scaling = {"rope_type": "default"}
-        self.rotary_emb_local = T5Gemma2RotaryEmbedding(config=config)
+        self.rotary_emb = T5Gemma2RotaryEmbedding(config)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -877,8 +903,9 @@ def forward(
         hidden_states = inputs_embeds
 
         # global and local position embeddings
-        position_embeddings_global = self.rotary_emb(hidden_states, position_ids)
-        position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids)
+        position_embeddings = {}
+        for layer_type in self.config.layer_types:
+            position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
 
         # dropout
         hidden_states = self.dropout(hidden_states)
@@ -887,8 +914,7 @@ def forward(
             assert isinstance(layer_module, T5Gemma2EncoderLayer)
             hidden_states = layer_module(
                 hidden_states,
-                position_embeddings_global,
-                position_embeddings_local,
+                position_embeddings[layer_module.attention_type],
                 self_attn_mask_mapping[layer_module.attention_type],
                 position_ids,
                 **kwargs,
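Below is a hedged sketch (stand-in names only, not the real T5Gemma2 modules) of the dispatch pattern the encoder and decoder loops now share: the (cos, sin) pair is computed once per layer type and each layer indexes it by its own attention_type, replacing the per-layer is_sliding branch removed above:

import torch

layer_types = ["sliding_attention", "full_attention"]  # assumed config.layer_types, deduplicated

def rotary_emb(hidden_states, position_ids, layer_type):
    # stand-in for the rotary module; the real one selects the f"{layer_type}_inv_freq" buffer
    dim = hidden_states.shape[-1]
    return torch.ones(1, position_ids.shape[-1], dim), torch.zeros(1, position_ids.shape[-1], dim)

hidden_states = torch.randn(1, 8, 256)
position_ids = torch.arange(8)[None, :]

# computed once per forward pass, shared by every layer of the same type
position_embeddings = {t: rotary_emb(hidden_states, position_ids, t) for t in layer_types}

for attention_type in ["sliding_attention", "full_attention", "sliding_attention"]:
    cos, sin = position_embeddings[attention_type]  # dict lookup replaces the old is_sliding check
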
@@ -992,8 +1018,9 @@ def forward(
         hidden_states = inputs_embeds
 
         # global and local position embeddings
-        position_embeddings_global = self.rotary_emb(hidden_states, position_ids)
-        position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids)
+        position_embeddings = {}
+        for layer_type in self.config.layer_types:
+            position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
 
         # dropout
         hidden_states = self.dropout(hidden_states)
@@ -1002,8 +1029,7 @@ def forward(
             assert isinstance(layer_module, T5Gemma2DecoderLayer)
             hidden_states = layer_module(
                 hidden_states,
-                position_embeddings_global,
-                position_embeddings_local,
+                position_embeddings[layer_module.attention_type],
                 self_attn_mask_mapping[layer_module.attention_type],
                 position_ids,
                 past_key_values,