diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 0198cdd33711..6a0e495ed108 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -716,6 +716,8 @@ title: T5 - local: model_doc/t5gemma title: T5Gemma + - local: model_doc/t5gemma2 + title: T5Gemma2 - local: model_doc/t5v1.1 title: T5v1.1 - local: model_doc/tapex diff --git a/docs/source/en/model_doc/t5gemma2.md b/docs/source/en/model_doc/t5gemma2.md new file mode 100644 index 000000000000..5f68f1f1b0a2 --- /dev/null +++ b/docs/source/en/model_doc/t5gemma2.md @@ -0,0 +1,109 @@ + + +
+
+ PyTorch + FlashAttention + SDPA +
+
+ +# T5Gemma 2 + +T5Gemma 2 is a family of pretrained encoder-decoder large language models with strong multilingual, multimodal and long-context capability, available in 270M-270M, 1B-1B and 4B-4B parameters. Following T5Gemma, it is built via model adaptation (based on Gemma 3) using UL2. The architecture is similar to T5Gemma and Gemma 3, enhanced with tied word embeddings and merged self- and cross-attention to save model parameters. + +> [!TIP] +> Click on the T5Gemma 2 models in the right sidebar for more examples of how to apply T5Gemma 2 to different language tasks. + +The example below demonstrates how to chat with the model with [`Pipeline`] or the [`AutoModel`] class, and from the command line. + + + + +```python +import torch +from transformers import pipeline + +generator = pipeline( + "image-text-to-text", + model="google/t5gemma-2-270m-270m", + dtype=torch.bfloat16, + device_map="auto", +) + +generator( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg", + text=" in this image, there is", + generate_kwargs={"do_sample": False, "max_new_tokens": 50}, +) +``` + + + + +```python +import torch +import requests +from PIL import Image +from transformers import AutoProcessor, AutoModelForSeq2SeqLM + +processor = AutoProcessor.from_pretrained("google/t5gemma-2-270m-270m") +model = AutoModelForSeq2SeqLM.from_pretrained( + "google/t5gemma-2-270m-270m", + device_map="auto", + dtype=torch.bfloat16, +) + +url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg" +image = Image.open(requests.get(url, stream=True).raw) +prompt = " in this image, there is" + +model_inputs = processor(text=prompt, images=image, return_tensors="pt") +generation = model.generate(**model_inputs, max_new_tokens=20, do_sample=False) +print(processor.decode(generation[0])) +``` + + + + +## T5Gemma2Config + +[[autodoc]] T5Gemma2Config + +## T5Gemma2ModuleConfig + +[[autodoc]] T5Gemma2ModuleConfig + +## T5Gemma2Model + +[[autodoc]] T5Gemma2Model + - forward + +## T5Gemma2ForConditionalGeneration + +[[autodoc]] T5Gemma2ForConditionalGeneration + - forward + +## T5Gemma2ForSequenceClassification + +[[autodoc]] T5Gemma2ForSequenceClassification + - forward + +## T5Gemma2ForTokenClassification + +[[autodoc]] T5Gemma2ForTokenClassification + - forward diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 5630063f92ec..7bebfe437919 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -335,6 +335,7 @@ from .switch_transformers import * from .t5 import * from .t5gemma import * + from .t5gemma2 import * from .table_transformer import * from .tapas import * from .textnet import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 7e2e84a445ef..6c90ba74fafc 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -396,6 +396,7 @@ ("switch_transformers", "SwitchTransformersConfig"), ("t5", "T5Config"), ("t5gemma", "T5GemmaConfig"), + ("t5gemma2", "T5Gemma2Config"), ("table-transformer", "TableTransformerConfig"), ("tapas", "TapasConfig"), ("textnet", "TextNetConfig"), @@ -855,6 +856,7 @@ ("switch_transformers", "SwitchTransformers"), ("t5", "T5"), ("t5gemma", "T5Gemma"), + ("t5gemma2", "T5Gemma2"), ("t5v1.1", "T5v1.1"), ("table-transformer", "Table Transformer"), ("tapas", "TAPAS"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 60af0f869bad..1392ca764752 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -176,6 +176,7 @@ ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")), ("swin2sr", ("Swin2SRImageProcessor", "Swin2SRImageProcessorFast")), ("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")), + ("t5gemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")), ("table-transformer", ("DetrImageProcessor", "DetrImageProcessorFast")), ("textnet", ("TextNetImageProcessor", "TextNetImageProcessorFast")), ("timesformer", ("VideoMAEImageProcessor", None)), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 197029464efd..a3cb66345bd0 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -382,6 +382,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("switch_transformers", "SwitchTransformersModel"), ("t5", "T5Model"), ("t5gemma", "T5GemmaModel"), + ("t5gemma2", "T5Gemma2Model"), ("table-transformer", "TableTransformerModel"), ("tapas", "TapasModel"), ("textnet", "TextNetModel"), @@ -513,6 +514,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), ("t5gemma", "T5GemmaForConditionalGeneration"), + ("t5gemma2", "T5Gemma2ForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("tvlt", "TvltForPreTraining"), @@ -612,6 +614,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), ("t5gemma", "T5GemmaForConditionalGeneration"), + ("t5gemma2", "T5Gemma2ForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("wav2vec2", "Wav2Vec2ForMaskedLM"), @@ -1059,6 +1062,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"), ("shieldgemma2", "Gemma3ForConditionalGeneration"), ("smolvlm", "SmolVLMForConditionalGeneration"), + ("t5gemma2", "T5Gemma2ForConditionalGeneration"), ("udop", "UdopForConditionalGeneration"), ("video_llama_3", "VideoLlama3ForConditionalGeneration"), ("vipllava", "VipLlavaForConditionalGeneration"), @@ -1185,6 +1189,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), ("t5gemma", "T5GemmaForConditionalGeneration"), + ("t5gemma2", "T5Gemma2ForConditionalGeneration"), ("umt5", "UMT5ForConditionalGeneration"), ("voxtral", "VoxtralForConditionalGeneration"), ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), @@ -1314,6 +1319,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("starcoder2", "Starcoder2ForSequenceClassification"), ("t5", "T5ForSequenceClassification"), ("t5gemma", "T5GemmaForSequenceClassification"), + ("t5gemma2", "T5Gemma2ForSequenceClassification"), ("tapas", "TapasForSequenceClassification"), ("transfo-xl", "TransfoXLForSequenceClassification"), ("umt5", "UMT5ForSequenceClassification"), @@ -1521,6 +1527,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("starcoder2", "Starcoder2ForTokenClassification"), ("t5", "T5ForTokenClassification"), ("t5gemma", "T5GemmaForTokenClassification"), + ("t5gemma2", "T5Gemma2ForTokenClassification"), ("umt5", "UMT5ForTokenClassification"), ("xlm", "XLMForTokenClassification"), ("xlm-roberta", "XLMRobertaForTokenClassification"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 5186b78b07e0..0381eb557eb3 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -137,6 +137,7 @@ ("speech_to_text", "Speech2TextProcessor"), ("speech_to_text_2", "Speech2Text2Processor"), ("speecht5", "SpeechT5Processor"), + ("t5gemma2", "Gemma3Processor"), ("trocr", "TrOCRProcessor"), ("tvlt", "TvltProcessor"), ("tvp", "TvpProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index a861aee12c57..11ff2d677b78 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -689,6 +689,13 @@ "GemmaTokenizerFast" if is_tokenizers_available() else None, ), ), + ( + "t5gemma2", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), ("tapas", ("TapasTokenizer", None)), ("tapex", ("TapexTokenizer", None)), ("transfo-xl", ("TransfoXLTokenizer", None)), diff --git a/src/transformers/models/t5gemma/__init__.py b/src/transformers/models/t5gemma/__init__.py index aa8099e26782..0688bdb54cbe 100644 --- a/src/transformers/models/t5gemma/__init__.py +++ b/src/transformers/models/t5gemma/__init__.py @@ -18,8 +18,8 @@ if TYPE_CHECKING: - from .configuration_encdecgemma2 import * - from .modeling_encdecgemma2 import * + from .configuration_t5gemma import * + from .modeling_t5gemma import * else: import sys diff --git a/src/transformers/models/t5gemma2/__init__.py b/src/transformers/models/t5gemma2/__init__.py new file mode 100644 index 000000000000..7d018bfe722a --- /dev/null +++ b/src/transformers/models/t5gemma2/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_t5gemma2 import * + from .modeling_t5gemma2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/t5gemma2/configuration_t5gemma2.py b/src/transformers/models/t5gemma2/configuration_t5gemma2.py new file mode 100644 index 000000000000..37e1514a8301 --- /dev/null +++ b/src/transformers/models/t5gemma2/configuration_t5gemma2.py @@ -0,0 +1,420 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/t5gemma2/modular_t5gemma2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_t5gemma2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +from typing import Any, Optional, Union + +from ...configuration_utils import PreTrainedConfig, layer_type_validation +from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params +from ...utils import logging +from ..siglip import SiglipVisionConfig + + +logger = logging.get_logger(__name__) + + +class T5Gemma2ModuleConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`T5Gemma2ModuleModel`]. It is used to instantiate an T5Gemma2Module + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the T5Gemma2Module-7B. + e.g. [google/t5_gemma2_module-7b](https://huggingface.co/google/t5_gemma2_module-7b) + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 262208): + Vocabulary size of the T5Gemma2Module model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`T5Gemma2ModuleModel`] + hidden_size (`int`, *optional*, defaults to 2304): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 9216): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 26): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` + if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. + max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + query_pre_attn_scalar (`float`, *optional*, defaults to 256): + Scaling factor used on the attention scores + sliding_window (`int`, *optional*, defaults to 4096): + In T5Gemma2Module, every other layer uses sliding window attention. This is the size of the sliding window. + layer_types (`list`, *optional*): + Attention pattern for each layer. + final_logit_softcapping (`float`, *optional*): + Scaling factor when applying tanh softcapping on the logits. + attn_logit_softcapping (`float`, *optional*): + Scaling factor when applying tanh softcapping on the attention scores. + rope_parameters (`RopeParameters`, *optional*): + Dictionary containing the configuration parameters for the RoPE embeddings. The dictionaty should contain + a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE + with longer `max_position_embeddings`. + use_bidirectional_attention (`bool`, *optional*, defaults to `False`): + If True, the model will attend to all text tokens instead of using a causal mask. This does not change + behavior for vision tokens. + + ```python + >>> from transformers import T5Gemma2ModuleModel, T5Gemma2ModuleConfig + >>> # Initializing a T5Gemma2Module t5_gemma2_module-7b style configuration + >>> configuration = T5Gemma2ModuleConfig() + >>> # Initializing a model from the t5_gemma2_module-7b style configuration + >>> model = T5Gemma2ModuleModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "t5gemma2_module" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size: Optional[int] = 262_208, + hidden_size: Optional[int] = 2304, + intermediate_size: Optional[int] = 9216, + num_hidden_layers: Optional[int] = 26, + num_attention_heads: Optional[int] = 8, + num_key_value_heads: Optional[int] = 4, + head_dim: Optional[int] = 256, + hidden_activation: Optional[str] = "gelu_pytorch_tanh", + max_position_embeddings: Optional[int] = 131_072, + initializer_range: Optional[float] = 0.02, + rms_norm_eps: Optional[int] = 1e-6, + use_cache: Optional[bool] = True, + pad_token_id: Optional[int] = 0, + eos_token_id: Optional[int] = 1, + bos_token_id: Optional[int] = 2, + tie_word_embeddings: Optional[bool] = True, + attention_bias: Optional[bool] = False, + attention_dropout: Optional[float] = 0.0, + query_pre_attn_scalar: Optional[int] = 256, + sliding_window: Optional[int] = 4096, + layer_types: Optional[list[str]] = None, + final_logit_softcapping: Optional[float] = None, + attn_logit_softcapping: Optional[float] = None, + rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None, + use_bidirectional_attention: Optional[bool] = False, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.hidden_activation = hidden_activation + self.query_pre_attn_scalar = query_pre_attn_scalar + self.sliding_window = sliding_window + self.final_logit_softcapping = final_logit_softcapping + self.attn_logit_softcapping = attn_logit_softcapping + self.layer_types = layer_types + + # Try to set `rope_scaling` if available, otherwise use `rope_parameters` + if (rope_scaling := kwargs.pop("rope_scaling", None)) is not None: + if rope_parameters is None: + rope_parameters = {"sliding_attention": {"rope_type": "default"}, "full_attention": rope_scaling} + elif "full_attention" in rope_parameters: + rope_parameters["full_attention"].update(rope_scaling) + else: + rope_parameters.update(rope_scaling) + + self.rope_parameters = rope_parameters + self.use_bidirectional_attention = use_bidirectional_attention + if use_bidirectional_attention: + self.sliding_window = (self.sliding_window // 2) + 1 # due to fa we set exclusive bounds + + # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub + self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6) + + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types, self.num_hidden_layers) + + # Validate the correctness of rotary position embeddings parameters + rope_theta = getattr(self, "rope_theta", 1_000_000.0) + rope_local_base_freq = getattr(self, "rope_local_base_freq", 10000.0) + standardize_rope_params( + self, rope_theta={"full_attention": rope_theta, "sliding_attention": rope_local_base_freq} + ) + rope_config_validation(self) + + +class T5Gemma2Config(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`T5Gemma2Model`]. It is used to instantiate an T5Gemma2 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to a hypothetical balanced Gemma3 encoder-decoder model. + e.g. [google/t5gemma-2-270m-270m](https://huggingface.co/google/t5gemma-2-270m-270m) + Configuration objects inherit from [PreTrainedConfig] and can be used to control the model outputs. Read the + documentation from [PreTrainedConfig] for more information. + + Args: + encoder (`Union[T5Gemma2ModuleConfig, dict]`, optional, *optional*): + Configuration for the encoder. + decoder (`Union[T5Gemma2ModuleConfig, dict]`, optional, *optional*): + Configuration for the decoder. + vision_config (`Union[SiglipVisionConfig, dict]`, optional, *optional*): + Configuration for the vision encoder. + is_encoder_decoder (bool, optional, *optional*, defaults to `True`): + Whether the model is used as an encoder/decoder or not. + dropout_rate (`float`, *optional*, defaults to 0.0): + The ratio for all dropout layers (following T5). + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for attention. + classifier_dropout_rate (`float`, *optional*, defaults to 0.0): + The dropout ratio for classifier (following T5). + mm_tokens_per_image (`int`, *optional*, defaults to 256): + The number of tokens per image embedding. + boi_token_index (`int`, *optional*, defaults to 255999): + The begin-of-image token index to wrap the image prompt. + eoi_token_index (`int`, *optional*, defaults to 256000): + The end-of-image token index to wrap the image prompt. + image_token_index (`int`, *optional*, defaults to 256001): + The image token index to encode the image prompt. Defaults to 256001, which is right after the eoi_token_index. + Note this is different from Gemma 3. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + vocab_size (`int`, *optional*, defaults to 262144): + Vocabulary size of the T5Gemma2 model (the same as Gemma 3). + kwargs (additional keyword arguments, optional, *optional*): + Will be passed to the PreTrainedConfig base class. + ```python + >>> from transformers import T5Gemma2Config, T5Gemma2Model + >>> t5gemma2_config = T5Gemma2Config.from_pretrained("google/t5gemma-270m-270m") + >>> model = T5Gemma2Model(t5gemma2_config) + ``` + """ + + model_type = "t5gemma2" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + # encoder + "encoder.layers.*.self_attn.q_proj": "colwise", + "encoder.layers.*.self_attn.k_proj": "colwise", + "encoder.layers.*.self_attn.v_proj": "colwise", + "encoder.layers.*.self_attn.o_proj": "rowwise", + "encoder.layers.*.mlp.gate_proj": "colwise", + "encoder.layers.*.mlp.up_proj": "colwise", + "encoder.layers.*.mlp.down_proj": "rowwise", + # decoder + "decoder.layers.*.self_attn.q_proj": "colwise", + "decoder.layers.*.self_attn.k_proj": "colwise", + "decoder.layers.*.self_attn.v_proj": "colwise", + "decoder.layers.*.self_attn.o_proj": "rowwise", + "decoder.layers.*.mlp.gate_proj": "colwise", + "decoder.layers.*.mlp.up_proj": "colwise", + "decoder.layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + # encoder + "encoder.embed_tokens": (["input_ids"], ["inputs_embeds"]), + "encoder.layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "encoder.norm": (["hidden_states"], ["hidden_states"]), + # decoder + "decoder.embed_tokens": (["input_ids"], ["inputs_embeds"]), + "decoder.layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "decoder.norm": (["hidden_states"], ["hidden_states"]), + } + attribute_map = { + "image_token_id": "image_token_index", + "boi_token_id": "boi_token_index", + "eoi_token_id": "eoi_token_index", + } + + def __init__( + self, + encoder: Optional[Union[T5Gemma2ModuleConfig, dict[str, Any]]] = None, + decoder: Optional[Union[T5Gemma2ModuleConfig, dict[str, Any]]] = None, + vision_config: Optional[Union[SiglipVisionConfig, dict[str, Any]]] = None, + is_encoder_decoder: bool = True, + dropout_rate: float = 0.0, + attention_dropout: float = 0.0, + classifier_dropout_rate: float = 0.0, + mm_tokens_per_image: int = 256, + boi_token_index: int = 255_999, + eoi_token_index: int = 256_000, + image_token_index: int = 256_001, + initializer_range: float = 0.02, + vocab_size: int = 262_144, + **kwargs, + ): + if isinstance(encoder, dict): + encoder = T5Gemma2ModuleConfig(**encoder) + elif encoder is None: + encoder = T5Gemma2ModuleConfig() + logger.info("encoder is None, using default T5Gemma2ModuleConfig encoder config.") + else: + if not isinstance(encoder, T5Gemma2ModuleConfig): + raise ValueError(f"{type(encoder)} is not supported.") + + if isinstance(decoder, dict): + decoder = T5Gemma2ModuleConfig(**decoder) + elif decoder is None: + decoder = copy.deepcopy(encoder) + logger.info("decoder is None, using the same config as encoder.") + else: + if not isinstance(decoder, T5Gemma2ModuleConfig): + raise ValueError(f"{type(decoder)} is not supported.") + + if isinstance(vision_config, dict): + vision_config = SiglipVisionConfig(**vision_config) + elif vision_config is None: + vision_config = SiglipVisionConfig() + logger.info("vision_config is None, using default SiglipVisionConfig vision config.") + else: + if not isinstance(vision_config, SiglipVisionConfig): + raise ValueError(f"{type(vision_config)} is not supported.") + + if encoder.hidden_size != decoder.hidden_size: + raise ValueError( + "Imbalanced encoder-decoder is not supported in T5Gemma2: " + f"encoder ({encoder.hidden_size}) vs decoder ({decoder.hidden_size})." + ) + + if not is_encoder_decoder: + raise ValueError("T5Gemma2Model only support encoder-decoder modeling.") + + if encoder.vocab_size != decoder.vocab_size: + raise ValueError( + "Imbalanced encoder-decoder vocabulary size is not supported in T5Gemma2: " + f"encoder ({encoder.vocab_size}) vs decoder ({decoder.vocab_size})." + ) + + encoder = T5Gemma2ModuleConfig(**encoder.to_dict()) + decoder = T5Gemma2ModuleConfig(**decoder.to_dict()) + vision_config = SiglipVisionConfig(**vision_config.to_dict()) + + encoder.is_decoder = False + encoder.dropout_rate = dropout_rate + encoder.attention_dropout = attention_dropout + self.encoder = encoder + + decoder.is_decoder = True + decoder.use_cache = True + decoder.dropout_rate = dropout_rate + decoder.attention_dropout = attention_dropout + self.decoder = decoder + + self.vision_config = vision_config + + for special_token_key in ["bos_token_id", "pad_token_id", "eos_token_id", "use_cache"]: + if special_token_key not in kwargs: + kwargs[special_token_key] = getattr(decoder, special_token_key) + + super().__init__(**kwargs) + + self.mm_tokens_per_image = mm_tokens_per_image + self.boi_token_index = boi_token_index + self.eoi_token_index = eoi_token_index + self.image_token_index = image_token_index + self.initializer_range = initializer_range + + self.is_encoder_decoder = is_encoder_decoder + self.dropout_rate = dropout_rate + self.attention_dropout = attention_dropout + self.classifier_dropout_rate = classifier_dropout_rate + + # Used in pipeline generation. + self.vocab_size = vocab_size + + def __setattr__(self, key, value): + shared_attr_with_submodules = [ + "output_hidden_states", + "output_attentions", + "_attn_implementation", + "dropout_rate", + "attention_dropout", + "vocab_size", + ] + + if key in shared_attr_with_submodules: + setattr(self.encoder, key, value) + setattr(self.decoder, key, value) + if key == "_attn_implementation": + self.vision_config._attn_implementation = value + super().__setattr__(key, value) + + +__all__ = ["T5Gemma2Config", "T5Gemma2ModuleConfig"] diff --git a/src/transformers/models/t5gemma2/modeling_t5gemma2.py b/src/transformers/models/t5gemma2/modeling_t5gemma2.py new file mode 100644 index 000000000000..86d1c845bebe --- /dev/null +++ b/src/transformers/models/t5gemma2/modeling_t5gemma2.py @@ -0,0 +1,1559 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/t5gemma2/modular_t5gemma2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_t5gemma2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Callable +from typing import Optional, Union + +import torch +import torch.nn as nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...masking_utils import ( + and_masks, + create_bidirectional_mask, + create_causal_mask, + create_sliding_window_causal_mask, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import OutputRecorder, check_model_inputs +from ..auto import AutoModel +from .configuration_t5gemma2 import T5Gemma2Config, T5Gemma2ModuleConfig + + +class T5Gemma2RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst T5Gemma2 is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class T5Gemma2MLP(nn.Module): + def __init__(self, config: T5Gemma2ModuleConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_activation] + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, x): + hidden_states = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + hidden_states = self.dropout(hidden_states) + down_proj = self.down_proj(hidden_states) + return down_proj + + +class T5Gemma2RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: T5Gemma2ModuleConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.layer_types = list(set(config.layer_types)) + self.rope_type = {} + for layer_type in self.layer_types: + rope_params = self.config.rope_parameters[layer_type] + if rope_params is None: + continue + + self.rope_type[layer_type] = rope_params["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type[layer_type] != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]] + curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type) + self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False) + setattr(self, f"{layer_type}_original_inv_freq", curr_inv_freq) + setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling) + + @staticmethod + def compute_default_rope_parameters( + config: Optional[T5Gemma2ModuleConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + layer_type: Optional[str] = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + layer_type (`str`, *optional*): + The current layer type if the model has different RoPE parameters per type. + Should not be used unless `config.layer_types is not None` + + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + # For backward compatibility standardize the `rope_parameters_dict` if it uses old format + base = config.rope_parameters[layer_type]["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids, layer_type=None): + inv_freq = getattr(self, f"{layer_type}_inv_freq") + attention_scaling = getattr(self, f"{layer_type}_attention_scaling") + + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * attention_scaling + sin = emb.sin() * attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + softcap: Optional[float] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + + if softcap is not None: + attn_weights = attn_weights / softcap + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * softcap + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class T5Gemma2SelfAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: T5Gemma2ModuleConfig, layer_idx: int): + super().__init__() + self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = config.query_pre_attn_scalar**-0.5 + self.attention_dropout = self.config.attention_dropout + self.is_causal = not self.config.use_bidirectional_attention + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.attn_logit_softcapping = self.config.attn_logit_softcapping + self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None + self.is_sliding = self.layer_type == "sliding_attention" + + self.q_norm = T5Gemma2RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + self.k_norm = T5Gemma2RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class T5Gemma2MergedAttention(nn.Module): + """Merged self-attention and cross-attention for decoder.""" + + def __init__(self, config: T5Gemma2ModuleConfig, layer_idx: int): + super().__init__() + self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = config.query_pre_attn_scalar**-0.5 + self.attention_dropout = self.config.attention_dropout + self.is_causal = not self.config.use_bidirectional_attention + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.attn_logit_softcapping = self.config.attn_logit_softcapping + self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None + self.is_sliding = self.layer_type == "sliding_attention" + + self.q_norm = T5Gemma2RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + self.k_norm = T5Gemma2RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + # decoder self-attention inputs + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + # cross-attention inputs + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: Optional[torch.Tensor], + # cache inputs + past_key_values: Optional[EncoderDecoderCache] = None, + cache_position: Optional[torch.LongTensor] = None, + # others + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + # attention shapes. + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + cross_input_shape = encoder_hidden_states.shape[:-1] + cross_hidden_shape = (*cross_input_shape, -1, self.head_dim) + + # self-attention. + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # self-attention. + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + self_attention_cache = past_key_values.self_attention_cache + key_states, value_states = self_attention_cache.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # cross-attention. + is_updated = past_key_values.is_updated.get(self.layer_idx) + cross_attention_cache = past_key_values.cross_attention_cache + + if past_key_values is None or not is_updated: + cross_key_states = self.k_proj(encoder_hidden_states).view(cross_hidden_shape).transpose(1, 2) + cross_value_states = self.v_proj(encoder_hidden_states).view(cross_hidden_shape).transpose(1, 2) + + cross_key_states = self.k_norm(cross_key_states) + + if past_key_values is not None: + # Handle sliding window for cross-attention: convert the window size to the input size. + if len(cross_attention_cache.layers) > self.layer_idx and hasattr( + cross_attention_cache.layers[self.layer_idx], "sliding_window" + ): + cross_attention_cache.layers[self.layer_idx].sliding_window = cross_key_states.shape[2] + 1 + cross_key_states, cross_value_states = cross_attention_cache.update( + cross_key_states, cross_value_states, self.layer_idx + ) + past_key_values.is_updated[self.layer_idx] = True + else: + cross_key_states = cross_attention_cache.layers[self.layer_idx].keys + cross_value_states = cross_attention_cache.layers[self.layer_idx].values + + # merged attention. + query_states = query_states + cross_key_size = cross_key_states.shape[2] + key_states = torch.cat([key_states, cross_key_states], dim=2) + value_states = torch.cat([value_states, cross_value_states], dim=2) + # merge attention mask. + is_self_attn_mask_none = attention_mask is None + is_cross_attn_mask_none = encoder_attention_mask is None + if is_self_attn_mask_none and is_cross_attn_mask_none: + attention_mask = None + elif is_self_attn_mask_none ^ is_cross_attn_mask_none: + raise ValueError( + f"Either both or neither of attention_mask ({is_self_attn_mask_none}) and " + f"encoder_attention_mask ({is_cross_attn_mask_none}) should be None." + ) + else: + if attention_mask.ndim != encoder_attention_mask.ndim: + raise ValueError( + f"Attention mask dimension {attention_mask.ndim} and encoder attention mask {encoder_attention_mask.ndim} do not match." + ) + + attention_mask = torch.cat([attention_mask, encoder_attention_mask], dim=-1) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + # merged attention is not causal or sliding window + sliding_window=None, + is_causal=False, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + + # decompose merged attention weights into self & cross attention weights + if attn_weights is not None: + self_attn_weights = attn_weights[..., :-cross_key_size] + cross_attn_weights = attn_weights[..., -cross_key_size:] + else: + self_attn_weights, cross_attn_weights = None, None + return attn_output, self_attn_weights, cross_attn_weights + + +class T5Gemma2EncoderLayer(GradientCheckpointingLayer): + """Encoder sub-layer.""" + + def __init__(self, config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.config = config + self.layer_idx = layer_idx + self.attention_type = config.layer_types[layer_idx] + + self.self_attn = T5Gemma2SelfAttention( + config=config, + layer_idx=layer_idx, + ) + self.pre_self_attn_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_self_attn_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.mlp = T5Gemma2MLP(config) + self.pre_feedforward_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.FloatTensor,]: + residual = hidden_states + hidden_states = self.pre_self_attn_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=None, + **kwargs, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + return hidden_states + + +class T5Gemma2DecoderLayer(T5Gemma2EncoderLayer): + """Decoder sub-layer: merged attention instead of vanilla self-attention.""" + + def __init__(self, config, layer_idx: int): + super().__init__(config, layer_idx) + + # replace vanilla self-attention with merged attention to support joint cross-attention. + self.self_attn = T5Gemma2MergedAttention( + config=config, + layer_idx=layer_idx, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + hidden_states = self.pre_self_attn_layernorm(hidden_states) + + hidden_states, _, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + **kwargs, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + return hidden_states + + +class T5Gemma2LMHead(nn.Module): + """Head for language modeling (generation) tasks.""" + + def __init__(self, hidden_size: int, vocab_size: int, bias: bool = False): + super().__init__() + self.out_proj = nn.Linear(hidden_size, vocab_size, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.out_proj(hidden_states) + return logits + + +class T5Gemma2ClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, hidden_size: int, num_labels: int, classifier_dropout_rate: float = 0.0): + super().__init__() + self.dropout = nn.Dropout(p=classifier_dropout_rate) + self.out_proj = nn.Linear(hidden_size, num_labels) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class T5Gemma2MultiModalProjector(nn.Module): + def __init__(self, config: T5Gemma2Config): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.zeros(config.vision_config.hidden_size, config.encoder.hidden_size) + ) + + self.mm_soft_emb_norm = T5Gemma2RMSNorm( + config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps + ) + + self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) + + def forward(self, vision_outputs: torch.Tensor): + batch_size, _, seq_length = vision_outputs.shape + + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, seq_length, self.patches_per_image, self.patches_per_image + ) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight) + return projected_vision_outputs.type_as(vision_outputs) + + +class T5Gemma2TextScaledWordEmbedding(nn.Embedding): + """T5Gemma2 Embedding: override to add eoi token embedding separately.""" + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int, + embed_scale: float = 1.0, + eoi_token_index: int = 256_000, + ): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False) + self.eoi_token_index = eoi_token_index + self.eoi_embedding = nn.Parameter(torch.zeros(self.embedding_dim)) + + def forward(self, input_ids: torch.Tensor): + input_embeddings = super().forward(input_ids) * self.embed_scale.to(self.weight.dtype) + input_embeddings[input_ids == self.eoi_token_index] = self.eoi_embedding.to(input_embeddings.dtype) + return input_embeddings + + +@auto_docstring +class T5Gemma2PreTrainedModel(PreTrainedModel): + config: T5Gemma2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = [ + "T5Gemma2EncoderLayer", + "T5Gemma2DecoderLayer", + "SiglipVisionEmbeddings", + "SiglipEncoderLayer", + "SiglipMultiheadAttentionPoolingHead", + ] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": [T5Gemma2EncoderLayer, T5Gemma2DecoderLayer], + "attentions": [ + OutputRecorder(T5Gemma2SelfAttention, index=1, layer_name="self_attn"), + OutputRecorder(T5Gemma2MergedAttention, index=1, layer_name="self_attn"), + OutputRecorder(T5Gemma2MergedAttention, index=2, layer_name="cross_attn"), + ], + } + input_modalities = ["image", "text"] + + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, T5Gemma2MultiModalProjector): + module.mm_input_projection_weight.data.zero_() + elif isinstance(module, T5Gemma2TextScaledWordEmbedding): + module.eoi_embedding.data.zero_() + elif isinstance(module, T5Gemma2ClassificationHead): + scale = module.out_proj.weight.shape[0] ** -0.5 + module.out_proj.weight.data.normal_(mean=0.0, std=self.config.initializer_range * scale) + if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: + module.out_proj.bias.data.zero_() + # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) + elif "RMSNorm" in module.__class__.__name__: + module.weight.data.zero_() + + def _shift_right(self, input_ids): + """ + Shifts input_ids to the right, prepends the decoder_start_token_id, and handles + pad_token_id replacement for labels that were -100. + This is a common preparation step for decoder inputs in sequence-to-sequence models. + """ + decoder_config = self.config.decoder + decoder_start_token_id = decoder_config.bos_token_id + pad_token_id = decoder_config.pad_token_id + + if decoder_start_token_id is None: + raise ValueError("self.model.config.decoder.bos_token_id has to be defined. ") + + # shift inputs to the right + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.decoder.pad_token_id has to be defined.") + + # Is this T5 specific? + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def bidirectional_mask_function(attention_mask: Optional[torch.Tensor]) -> Callable: + """ + This creates bidirectional attention mask. + """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + if attention_mask is None: + return torch.ones((), dtype=torch.bool) + return attention_mask[batch_idx, kv_idx].to(torch.bool) + + return inner_mask + + +def make_default_2d_attention_mask( + token_ids: Optional[torch.LongTensor], + hidden_states: torch.Tensor, + pad_token_id: Optional[int], +) -> torch.Tensor: + """Construct the default attention mask.""" + if token_ids is not None: + if pad_token_id is None: + raise ValueError("`pad_token_id` is required for padding information.") + attention_mask = (token_ids != pad_token_id).to(hidden_states.device, torch.long) + else: + attention_mask = torch.ones( + (hidden_states.shape[0], hidden_states.shape[1]), device=hidden_states.device, dtype=torch.long + ) + return attention_mask + + +def sliding_window_mask_function(sliding_window: int, is_causal=True) -> Callable: + """ + This creates uni/bidirectional attention mask with sliding window. + """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + if is_causal: + left_window_size, right_window_size = sliding_window, 0 + else: + left_window_size, right_window_size = ((sliding_window + 1) // 2, (sliding_window) // 2 + 1) + + dist = q_idx - kv_idx + left_mask = (dist >= 0) & (dist < left_window_size) + right_mask = (dist < 0) & (-dist < right_window_size) + return left_mask | right_mask + + return inner_mask + + +class T5Gemma2Encoder(T5Gemma2PreTrainedModel): + config: T5Gemma2ModuleConfig + _can_record_outputs = { + "attentions": T5Gemma2SelfAttention, + "hidden_states": T5Gemma2EncoderLayer, + } + + def __init__( + self, + config: T5Gemma2ModuleConfig, + eoi_token_index: int = 256_000, + pixel2feature_preprocessor_fn: Optional[Callable] = None, + ): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # preprocessor for raw images pixel values: injected from outside. + self.pixel2feature_preprocessor_fn = pixel2feature_preprocessor_fn + + self.embed_tokens = T5Gemma2TextScaledWordEmbedding( + config.vocab_size, + config.hidden_size, + config.pad_token_id, + embed_scale=config.hidden_size**0.5, + eoi_token_index=eoi_token_index, + ) + self.norm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + self.layers = nn.ModuleList( + [T5Gemma2EncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.dropout = nn.Dropout(config.dropout_rate) + self.rotary_emb = T5Gemma2RotaryEmbedding(config) + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutput: + """ + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): + Pixel values to be processed by the image encoder to extract image features. + """ + # Part of Gemma 3 processor output but not used by t5gemma 2. + del token_type_ids + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # As we want to pass `past_key_values=None` explicitly everywhere, we need to pop them from kwargs if present + kwargs.pop("past_key_values", None) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if pixel_values is not None: + if self.pixel2feature_preprocessor_fn is None: + raise ValueError("`pixel2feature_preprocessor_fn` has to be provided to process `pixel_values`.") + image_features, image_mask = self.pixel2feature_preprocessor_fn(pixel_values, input_ids, inputs_embeds) + + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_features) + + if position_ids is None: + position_ids = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0) + + if attention_mask is None: + attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id) + + if not isinstance(self_attn_mask_mapping := attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + } + self_attn_mask_mapping = { + "full_attention": create_bidirectional_mask( + **mask_kwargs, + and_mask_function=bidirectional_mask_function(attention_mask), + ), + "sliding_attention": create_bidirectional_mask( + **mask_kwargs, + and_mask_function=and_masks( + sliding_window_mask_function(self.config.sliding_window, is_causal=False), + bidirectional_mask_function(attention_mask), + ), + ), + } + + # input layer + hidden_states = inputs_embeds + + # global and local position embeddings + position_embeddings = {} + for layer_type in self.config.layer_types: + position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type) + + # dropout + hidden_states = self.dropout(hidden_states) + + for layer_module in self.layers[: self.config.num_hidden_layers]: + if not isinstance(layer_module, T5Gemma2EncoderLayer): + raise ValueError(f"Expected T5Gemma2EncoderLayer, but got {type(layer_module)}.") + hidden_states = layer_module( + hidden_states, + position_embeddings[layer_module.attention_type], + self_attn_mask_mapping[layer_module.attention_type], + position_ids, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + hidden_states = self.dropout(hidden_states) + return BaseModelOutput( + last_hidden_state=hidden_states, + ) + + +class T5Gemma2Decoder(T5Gemma2Encoder): + _can_record_outputs = { + "attentions": OutputRecorder(T5Gemma2MergedAttention, index=1), + "cross_attentions": OutputRecorder(T5Gemma2MergedAttention, index=2), + "hidden_states": T5Gemma2DecoderLayer, + } + + def __init__(self, config: T5Gemma2ModuleConfig, shared_embedding: T5Gemma2TextScaledWordEmbedding): + super().__init__(config, shared_embedding) + self.layers = nn.ModuleList( + [T5Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPastAndCrossAttentions: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if encoder_hidden_states is None: + raise ValueError("`encoder_hidden_states` must be given in decoder") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if not self.training and use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + if attention_mask is None and past_key_values is None: + attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id) + + if not isinstance(self_attn_mask_mapping := attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values.self_attention_cache if past_key_values is not None else None, + "position_ids": position_ids, + } + # this masking function did nothing to masking but forces `allow_is_causal_skip` to be False + # as we always need a mask during decoding for merged attention. + mask_kwargs["and_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool) + self_attn_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + + if not isinstance(cross_attn_mask_mapping := encoder_attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": encoder_hidden_states, + "attention_mask": encoder_attention_mask, + "cache_position": cache_position, + "past_key_values": None, + "position_ids": None, + } + cross_attn_mask_mapping = { + "full_attention": create_causal_mask( + **mask_kwargs, + or_mask_function=bidirectional_mask_function(encoder_attention_mask), + ), + } + + # input layer + hidden_states = inputs_embeds + + # global and local position embeddings + position_embeddings = {} + for layer_type in self.config.layer_types: + position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type) + + # dropout + hidden_states = self.dropout(hidden_states) + + for layer_module in self.layers[: self.config.num_hidden_layers]: + if not isinstance(layer_module, T5Gemma2DecoderLayer): + raise ValueError(f"Expected T5Gemma2DecoderLayer, but got {type(layer_module)}.") + hidden_states = layer_module( + hidden_states, + position_embeddings[layer_module.attention_type], + self_attn_mask_mapping[layer_module.attention_type], + position_ids, + past_key_values, + use_cache, + cache_position, + encoder_hidden_states, + cross_attn_mask_mapping["full_attention"], + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + hidden_states = self.dropout(hidden_states) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class T5Gemma2Model(T5Gemma2PreTrainedModel): + _tied_weights_keys = [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + ] + _dynamic_tied_weights_keys = [ + "encoder.embed_tokens.eoi_embedding", + "decoder.embed_tokens.eoi_embedding", + ] + + def __init__(self, config: T5Gemma2Config): + super().__init__(config) + + # setup encoder and decoder + self.encoder = T5Gemma2Encoder(config.encoder, config.eoi_token_index, self.pixel2feature_preprocessor) + self.decoder = T5Gemma2Decoder(config.decoder, config.eoi_token_index) + + # setup vision encoder + self.vision_tower = AutoModel.from_config(config=config.vision_config) + self.multi_modal_projector = T5Gemma2MultiModalProjector(config) + + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def get_input_embeddings(self): + return self.encoder.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.encoder.set_input_embeddings(new_embeddings) + + def _tie_weights(self): + # Decoder input and output embeddings are tied. + if self.config.tie_word_embeddings: + self.decoder.embed_tokens.weight = self.encoder.embed_tokens.weight + self.decoder.embed_tokens.eoi_embedding = self.encoder.embed_tokens.eoi_embedding + + def _get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: + """Convert pixel image to image features via the encoder and projector.""" + # pixel_values: (batch_size, channels, height, width) + # image_features: Image feature tensor of shape (num_images, image_length, embed_dim). + vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state + image_features = self.multi_modal_projector(vision_outputs) + return image_features + + def _get_placeholder_mask( + self, + input_ids: Optional[torch.LongTensor], + inputs_embeds: Optional[torch.FloatTensor], + image_features: torch.FloatTensor, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + + def pixel2feature_preprocessor( + self, + pixel_values: torch.Tensor, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Convert pixel images to image features and obtain placeholder mask.""" + image_features = self._get_image_features(pixel_values) + image_mask = self._get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_features) + return image_features, image_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + # encoder inputs + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + # decoder inputs + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + # others (mainly inference or cache related) + encoder_outputs: Optional[BaseModelOutput] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Seq2SeqModelOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + """ + # encoder + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + return_dict=True, + **kwargs, + ) + + encoder_hidden_states = encoder_outputs.last_hidden_state + + # decoder + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + use_cache=use_cache, + cache_position=cache_position, + return_dict=True, + **kwargs, + ) + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states + if kwargs.get("output_hidden_states", False) + else (decoder_outputs.last_hidden_state,), + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class T5Gemma2ForConditionalGeneration(T5Gemma2PreTrainedModel, GenerationMixin): + _tied_weights_keys = [ + "lm_head.out_proj.weight", + "model.encoder.embed_tokens.weight", + "model.decoder.embed_tokens.weight", + ] + _dynamic_tied_weights_keys = [ + "model.encoder.embed_tokens.eoi_embedding", + "model.decoder.embed_tokens.eoi_embedding", + ] + _tp_plan = {"lm_head.out_proj": "colwise_rep"} + _pp_plan = {"lm_head.out_proj": (["hidden_states"], ["logits"])} + + def __init__(self, config: T5Gemma2Config): + super().__init__(config) + + self.model = T5Gemma2Model(config) + self.vocab_size = config.decoder.vocab_size + self.lm_head = T5Gemma2LMHead(config.decoder.hidden_size, self.vocab_size) + self.loss_type = "ForMaskedLM" + + self.post_init() + + def set_output_embeddings(self, new_embeddings): + self.lm_head.out_proj = new_embeddings + + def get_output_embeddings(self): + return self.lm_head.out_proj + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def _tie_weights(self): + # Decoder input and output embeddings are tied. + if self.config.tie_word_embeddings: + self.lm_head.out_proj.weight = self.model.encoder.embed_tokens.weight + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def get_image_features(self, pixel_values): + return self.model.get_image_features(pixel_values) + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple + @auto_docstring + def forward( + self, + # encoder inputs + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + # decoder inputs + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + # others (mainly inference or cache related) + encoder_outputs: Optional[BaseModelOutput] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self.prepare_decoder_input_ids_from_labels(labels) + + decoder_outputs: Seq2SeqModelOutput = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = decoder_outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + decoder_config = self.config.decoder + if decoder_config.final_logit_softcapping is not None: + logits = logits / decoder_config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * decoder_config.final_logit_softcapping + + loss = None + if labels is not None: + # Input has right-shifted so we directly perform masked lm loss + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + return Seq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.decoder_hidden_states, + decoder_attentions=decoder_outputs.decoder_attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=decoder_outputs.encoder_last_hidden_state, + encoder_hidden_states=decoder_outputs.encoder_hidden_states, + encoder_attentions=decoder_outputs.encoder_attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + +@auto_docstring +class T5Gemma2ForSequenceClassification(T5Gemma2PreTrainedModel): + _tied_weights_keys = [ + "model.encoder.embed_tokens.weight", + "model.decoder.embed_tokens.weight", + ] + _dynamic_tied_weights_keys = [ + "model.encoder.embed_tokens.eoi_embedding", + "model.decoder.embed_tokens.eoi_embedding", + ] + + def __init__(self, config: T5Gemma2Config): + super().__init__(config) + self.num_labels = config.num_labels + self.hidden_size = config.decoder.hidden_size + + self.model = T5Gemma2Model(config) + + classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1) + self.score = T5Gemma2ClassificationHead(self.hidden_size, self.num_labels, classifier_dropout) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[BaseModelOutput] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> SequenceClassifierOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + if inputs_embeds is not None or decoder_inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}." + ) + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + if decoder_input_ids is None: + decoder_input_ids = self._shift_right(input_ids) + + outputs: Seq2SeqModelOutput = self.model( + input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=False, + **kwargs, + ) + + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.decoder_hidden_states + attentions = outputs.decoder_attentions + + logits = self.score(last_hidden_state) + + batch_size = input_ids.shape[0] + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (decoder_input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(decoder_input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + last_non_pad_token = torch.clamp(last_non_pad_token, max=decoder_input_ids.shape[-1] - 1) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + return SequenceClassifierOutput( + loss=loss, + logits=pooled_logits, + hidden_states=hidden_states, + attentions=attentions, + ) + + +@auto_docstring +class T5Gemma2ForTokenClassification(T5Gemma2PreTrainedModel): + _tied_weights_keys = [ + "model.encoder.embed_tokens.weight", + "model.decoder.embed_tokens.weight", + ] + _dynamic_tied_weights_keys = [ + "model.encoder.embed_tokens.eoi_embedding", + "model.decoder.embed_tokens.eoi_embedding", + ] + + def __init__(self, config: T5Gemma2Config): + super().__init__(config) + self.num_labels = config.num_labels + self.hidden_size = config.decoder.hidden_size + + self.model = T5Gemma2Model(config) + + classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1) + self.score = T5Gemma2ClassificationHead(self.hidden_size, self.num_labels, classifier_dropout) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[BaseModelOutput] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + if inputs_embeds is not None or decoder_inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}." + ) + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + if decoder_input_ids is None: + decoder_input_ids = self._shift_right(input_ids) + + outputs: Seq2SeqModelOutput = self.model( + input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=False, + **kwargs, + ) + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.decoder_hidden_states + attentions = outputs.decoder_attentions + + logits = self.score(last_hidden_state) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + attentions=attentions, + ) + + +__all__ = [ + "T5Gemma2ForConditionalGeneration", + "T5Gemma2Model", + "T5Gemma2PreTrainedModel", + "T5Gemma2ForSequenceClassification", + "T5Gemma2ForTokenClassification", +] diff --git a/src/transformers/models/t5gemma2/modular_t5gemma2.py b/src/transformers/models/t5gemma2/modular_t5gemma2.py new file mode 100644 index 000000000000..a6bf12af6556 --- /dev/null +++ b/src/transformers/models/t5gemma2/modular_t5gemma2.py @@ -0,0 +1,1381 @@ +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +from collections.abc import Callable +from typing import Any, Optional, Union + +import torch +import torch.nn as nn + +from ...cache_utils import DynamicCache, EncoderDecoderCache +from ...configuration_utils import PreTrainedConfig +from ...generation import GenerationMixin +from ...masking_utils import and_masks, create_bidirectional_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + TransformersKwargs, + auto_docstring, + can_return_tuple, + logging, +) +from ...utils.generic import OutputRecorder, check_model_inputs +from ..auto import AutoModel +from ..gemma3.configuration_gemma3 import Gemma3TextConfig +from ..gemma3.modeling_gemma3 import ( + Gemma3Attention, + Gemma3MLP, + Gemma3MultiModalProjector, + Gemma3PreTrainedModel, + Gemma3RMSNorm, + Gemma3RotaryEmbedding, + Gemma3TextScaledWordEmbedding, + apply_rotary_pos_emb, + create_causal_mask, + create_sliding_window_causal_mask, + eager_attention_forward, +) +from ..siglip import SiglipVisionConfig +from ..t5gemma.modeling_t5gemma import ( + T5GemmaClassificationHead, + T5GemmaEncoderLayer, + T5GemmaLMHead, + bidirectional_mask_function, + make_default_2d_attention_mask, +) + + +logger = logging.get_logger(__name__) + + +class T5Gemma2ModuleConfig(Gemma3TextConfig): + model_type = "t5gemma2_module" + + +class T5Gemma2Config(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`T5Gemma2Model`]. It is used to instantiate an T5Gemma2 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to a hypothetical balanced Gemma3 encoder-decoder model. + e.g. [google/t5gemma-2-270m-270m](https://huggingface.co/google/t5gemma-2-270m-270m) + Configuration objects inherit from [PreTrainedConfig] and can be used to control the model outputs. Read the + documentation from [PreTrainedConfig] for more information. + + Args: + encoder (`Union[T5Gemma2ModuleConfig, dict]`, optional, *optional*): + Configuration for the encoder. + decoder (`Union[T5Gemma2ModuleConfig, dict]`, optional, *optional*): + Configuration for the decoder. + vision_config (`Union[SiglipVisionConfig, dict]`, optional, *optional*): + Configuration for the vision encoder. + is_encoder_decoder (bool, optional, *optional*, defaults to `True`): + Whether the model is used as an encoder/decoder or not. + dropout_rate (`float`, *optional*, defaults to 0.0): + The ratio for all dropout layers (following T5). + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for attention. + classifier_dropout_rate (`float`, *optional*, defaults to 0.0): + The dropout ratio for classifier (following T5). + mm_tokens_per_image (`int`, *optional*, defaults to 256): + The number of tokens per image embedding. + boi_token_index (`int`, *optional*, defaults to 255999): + The begin-of-image token index to wrap the image prompt. + eoi_token_index (`int`, *optional*, defaults to 256000): + The end-of-image token index to wrap the image prompt. + image_token_index (`int`, *optional*, defaults to 256001): + The image token index to encode the image prompt. Defaults to 256001, which is right after the eoi_token_index. + Note this is different from Gemma 3. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + vocab_size (`int`, *optional*, defaults to 262144): + Vocabulary size of the T5Gemma2 model (the same as Gemma 3). + kwargs (additional keyword arguments, optional, *optional*): + Will be passed to the PreTrainedConfig base class. + ```python + >>> from transformers import T5Gemma2Config, T5Gemma2Model + >>> t5gemma2_config = T5Gemma2Config.from_pretrained("google/t5gemma-270m-270m") + >>> model = T5Gemma2Model(t5gemma2_config) + ``` + """ + + model_type = "t5gemma2" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + # encoder + "encoder.layers.*.self_attn.q_proj": "colwise", + "encoder.layers.*.self_attn.k_proj": "colwise", + "encoder.layers.*.self_attn.v_proj": "colwise", + "encoder.layers.*.self_attn.o_proj": "rowwise", + "encoder.layers.*.mlp.gate_proj": "colwise", + "encoder.layers.*.mlp.up_proj": "colwise", + "encoder.layers.*.mlp.down_proj": "rowwise", + # decoder + "decoder.layers.*.self_attn.q_proj": "colwise", + "decoder.layers.*.self_attn.k_proj": "colwise", + "decoder.layers.*.self_attn.v_proj": "colwise", + "decoder.layers.*.self_attn.o_proj": "rowwise", + "decoder.layers.*.mlp.gate_proj": "colwise", + "decoder.layers.*.mlp.up_proj": "colwise", + "decoder.layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + # encoder + "encoder.embed_tokens": (["input_ids"], ["inputs_embeds"]), + "encoder.layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "encoder.norm": (["hidden_states"], ["hidden_states"]), + # decoder + "decoder.embed_tokens": (["input_ids"], ["inputs_embeds"]), + "decoder.layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "decoder.norm": (["hidden_states"], ["hidden_states"]), + } + attribute_map = { + "image_token_id": "image_token_index", + "boi_token_id": "boi_token_index", + "eoi_token_id": "eoi_token_index", + } + + def __init__( + self, + encoder: Optional[Union[T5Gemma2ModuleConfig, dict[str, Any]]] = None, + decoder: Optional[Union[T5Gemma2ModuleConfig, dict[str, Any]]] = None, + vision_config: Optional[Union[SiglipVisionConfig, dict[str, Any]]] = None, + is_encoder_decoder: bool = True, + dropout_rate: float = 0.0, + attention_dropout: float = 0.0, + classifier_dropout_rate: float = 0.0, + mm_tokens_per_image: int = 256, + boi_token_index: int = 255_999, + eoi_token_index: int = 256_000, + image_token_index: int = 256_001, + initializer_range: float = 0.02, + vocab_size: int = 262_144, + **kwargs, + ): + if isinstance(encoder, dict): + encoder = T5Gemma2ModuleConfig(**encoder) + elif encoder is None: + encoder = T5Gemma2ModuleConfig() + logger.info("encoder is None, using default T5Gemma2ModuleConfig encoder config.") + else: + if not isinstance(encoder, T5Gemma2ModuleConfig): + raise ValueError(f"{type(encoder)} is not supported.") + + if isinstance(decoder, dict): + decoder = T5Gemma2ModuleConfig(**decoder) + elif decoder is None: + decoder = copy.deepcopy(encoder) + logger.info("decoder is None, using the same config as encoder.") + else: + if not isinstance(decoder, T5Gemma2ModuleConfig): + raise ValueError(f"{type(decoder)} is not supported.") + + if isinstance(vision_config, dict): + vision_config = SiglipVisionConfig(**vision_config) + elif vision_config is None: + vision_config = SiglipVisionConfig() + logger.info("vision_config is None, using default SiglipVisionConfig vision config.") + else: + if not isinstance(vision_config, SiglipVisionConfig): + raise ValueError(f"{type(vision_config)} is not supported.") + + if encoder.hidden_size != decoder.hidden_size: + raise ValueError( + "Imbalanced encoder-decoder is not supported in T5Gemma2: " + f"encoder ({encoder.hidden_size}) vs decoder ({decoder.hidden_size})." + ) + + if not is_encoder_decoder: + raise ValueError("T5Gemma2Model only support encoder-decoder modeling.") + + if encoder.vocab_size != decoder.vocab_size: + raise ValueError( + "Imbalanced encoder-decoder vocabulary size is not supported in T5Gemma2: " + f"encoder ({encoder.vocab_size}) vs decoder ({decoder.vocab_size})." + ) + + encoder = T5Gemma2ModuleConfig(**encoder.to_dict()) + decoder = T5Gemma2ModuleConfig(**decoder.to_dict()) + vision_config = SiglipVisionConfig(**vision_config.to_dict()) + + encoder.is_decoder = False + encoder.dropout_rate = dropout_rate + encoder.attention_dropout = attention_dropout + self.encoder = encoder + + decoder.is_decoder = True + decoder.use_cache = True + decoder.dropout_rate = dropout_rate + decoder.attention_dropout = attention_dropout + self.decoder = decoder + + self.vision_config = vision_config + + for special_token_key in ["bos_token_id", "pad_token_id", "eos_token_id", "use_cache"]: + if special_token_key not in kwargs: + kwargs[special_token_key] = getattr(decoder, special_token_key) + + super().__init__(**kwargs) + + self.mm_tokens_per_image = mm_tokens_per_image + self.boi_token_index = boi_token_index + self.eoi_token_index = eoi_token_index + self.image_token_index = image_token_index + self.initializer_range = initializer_range + + self.is_encoder_decoder = is_encoder_decoder + self.dropout_rate = dropout_rate + self.attention_dropout = attention_dropout + self.classifier_dropout_rate = classifier_dropout_rate + + # Used in pipeline generation. + self.vocab_size = vocab_size + + def __setattr__(self, key, value): + shared_attr_with_submodules = [ + "output_hidden_states", + "output_attentions", + "_attn_implementation", + "dropout_rate", + "attention_dropout", + "vocab_size", + ] + + if key in shared_attr_with_submodules: + setattr(self.encoder, key, value) + setattr(self.decoder, key, value) + if key == "_attn_implementation": + self.vision_config._attn_implementation = value + super().__setattr__(key, value) + + +class T5Gemma2RMSNorm(Gemma3RMSNorm): + pass + + +class T5Gemma2MLP(Gemma3MLP): + def __init__(self, config: T5Gemma2ModuleConfig): + super().__init__(config) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, x): + hidden_states = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + hidden_states = self.dropout(hidden_states) + down_proj = self.down_proj(hidden_states) + return down_proj + + +class T5Gemma2RotaryEmbedding(Gemma3RotaryEmbedding): + def __init__(self, config: T5Gemma2ModuleConfig, device=None): + super().__init__(config, device) + + @staticmethod + def compute_default_rope_parameters( + config: Optional[T5Gemma2ModuleConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + layer_type: Optional[str] = None, + ) -> tuple["torch.Tensor", float]: + return super().compute_default_rope_parameters(config, device, seq_len, layer_type) + + +class T5Gemma2SelfAttention(Gemma3Attention): + def __init__(self, config: T5Gemma2ModuleConfig, layer_idx: int): + super().__init__(config, layer_idx) + + +class T5Gemma2MergedAttention(Gemma3Attention): + """Merged self-attention and cross-attention for decoder.""" + + def __init__(self, config: T5Gemma2ModuleConfig, layer_idx: int): + super().__init__(config, layer_idx) + + def forward( + self, + # decoder self-attention inputs + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + # cross-attention inputs + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: Optional[torch.Tensor], + # cache inputs + past_key_values: Optional[EncoderDecoderCache] = None, + cache_position: Optional[torch.LongTensor] = None, + # others + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + # attention shapes. + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + cross_input_shape = encoder_hidden_states.shape[:-1] + cross_hidden_shape = (*cross_input_shape, -1, self.head_dim) + + # self-attention. + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # self-attention. + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + self_attention_cache = past_key_values.self_attention_cache + key_states, value_states = self_attention_cache.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # cross-attention. + is_updated = past_key_values.is_updated.get(self.layer_idx) + cross_attention_cache = past_key_values.cross_attention_cache + + if past_key_values is None or not is_updated: + cross_key_states = self.k_proj(encoder_hidden_states).view(cross_hidden_shape).transpose(1, 2) + cross_value_states = self.v_proj(encoder_hidden_states).view(cross_hidden_shape).transpose(1, 2) + + cross_key_states = self.k_norm(cross_key_states) + + if past_key_values is not None: + # Handle sliding window for cross-attention: convert the window size to the input size. + if len(cross_attention_cache.layers) > self.layer_idx and hasattr( + cross_attention_cache.layers[self.layer_idx], "sliding_window" + ): + cross_attention_cache.layers[self.layer_idx].sliding_window = cross_key_states.shape[2] + 1 + cross_key_states, cross_value_states = cross_attention_cache.update( + cross_key_states, cross_value_states, self.layer_idx + ) + past_key_values.is_updated[self.layer_idx] = True + else: + cross_key_states = cross_attention_cache.layers[self.layer_idx].keys + cross_value_states = cross_attention_cache.layers[self.layer_idx].values + + # merged attention. + query_states = query_states + cross_key_size = cross_key_states.shape[2] + key_states = torch.cat([key_states, cross_key_states], dim=2) + value_states = torch.cat([value_states, cross_value_states], dim=2) + # merge attention mask. + is_self_attn_mask_none = attention_mask is None + is_cross_attn_mask_none = encoder_attention_mask is None + if is_self_attn_mask_none and is_cross_attn_mask_none: + attention_mask = None + elif is_self_attn_mask_none ^ is_cross_attn_mask_none: + raise ValueError( + f"Either both or neither of attention_mask ({is_self_attn_mask_none}) and " + f"encoder_attention_mask ({is_cross_attn_mask_none}) should be None." + ) + else: + if attention_mask.ndim != encoder_attention_mask.ndim: + raise ValueError( + f"Attention mask dimension {attention_mask.ndim} and encoder attention mask {encoder_attention_mask.ndim} do not match." + ) + + attention_mask = torch.cat([attention_mask, encoder_attention_mask], dim=-1) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + # merged attention is not causal or sliding window + sliding_window=None, + is_causal=False, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + + # decompose merged attention weights into self & cross attention weights + if attn_weights is not None: + self_attn_weights = attn_weights[..., :-cross_key_size] + cross_attn_weights = attn_weights[..., -cross_key_size:] + else: + self_attn_weights, cross_attn_weights = None, None + return attn_output, self_attn_weights, cross_attn_weights + + +def sliding_window_mask_function(sliding_window: int, is_causal=True) -> Callable: + """ + This creates uni/bidirectional attention mask with sliding window. + """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + if is_causal: + left_window_size, right_window_size = sliding_window, 0 + else: + left_window_size, right_window_size = ((sliding_window + 1) // 2, (sliding_window) // 2 + 1) + + dist = q_idx - kv_idx + left_mask = (dist >= 0) & (dist < left_window_size) + right_mask = (dist < 0) & (-dist < right_window_size) + return left_mask | right_mask + + return inner_mask + + +class T5Gemma2EncoderLayer(T5GemmaEncoderLayer): + pass + + +class T5Gemma2DecoderLayer(T5Gemma2EncoderLayer): + """Decoder sub-layer: merged attention instead of vanilla self-attention.""" + + def __init__(self, config, layer_idx: int): + super().__init__(config, layer_idx) + + # replace vanilla self-attention with merged attention to support joint cross-attention. + self.self_attn = T5Gemma2MergedAttention( + config=config, + layer_idx=layer_idx, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + hidden_states = self.pre_self_attn_layernorm(hidden_states) + + hidden_states, _, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + **kwargs, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + return hidden_states + + +class T5Gemma2LMHead(T5GemmaLMHead): + pass + + +class T5Gemma2ClassificationHead(T5GemmaClassificationHead): + pass + + +class T5Gemma2MultiModalProjector(Gemma3MultiModalProjector): + def __init__(self, config: T5Gemma2Config): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.zeros(config.vision_config.hidden_size, config.encoder.hidden_size) + ) + + self.mm_soft_emb_norm = T5Gemma2RMSNorm( + config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps + ) + + self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) + + +class T5Gemma2TextScaledWordEmbedding(Gemma3TextScaledWordEmbedding): + """T5Gemma2 Embedding: override to add eoi token embedding separately.""" + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int, + embed_scale: float = 1.0, + eoi_token_index: int = 256_000, + ): + super().__init__(num_embeddings, embedding_dim, padding_idx, embed_scale) + self.eoi_token_index = eoi_token_index + self.eoi_embedding = nn.Parameter(torch.zeros(self.embedding_dim)) + + def forward(self, input_ids: torch.Tensor): + input_embeddings = super().forward(input_ids) * self.embed_scale.to(self.weight.dtype) + input_embeddings[input_ids == self.eoi_token_index] = self.eoi_embedding.to(input_embeddings.dtype) + return input_embeddings + + +@auto_docstring +class T5Gemma2PreTrainedModel(Gemma3PreTrainedModel): + config: T5Gemma2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = [ + "T5Gemma2EncoderLayer", + "T5Gemma2DecoderLayer", + "SiglipVisionEmbeddings", + "SiglipEncoderLayer", + "SiglipMultiheadAttentionPoolingHead", + ] + _can_record_outputs = { + "hidden_states": [T5Gemma2EncoderLayer, T5Gemma2DecoderLayer], + "attentions": [ + OutputRecorder(T5Gemma2SelfAttention, index=1, layer_name="self_attn"), + OutputRecorder(T5Gemma2MergedAttention, index=1, layer_name="self_attn"), + OutputRecorder(T5Gemma2MergedAttention, index=2, layer_name="cross_attn"), + ], + } + + def _init_weights(self, module): + PreTrainedModel._init_weights(self, module) + if isinstance(module, T5Gemma2MultiModalProjector): + module.mm_input_projection_weight.data.zero_() + elif isinstance(module, T5Gemma2TextScaledWordEmbedding): + module.eoi_embedding.data.zero_() + elif isinstance(module, T5Gemma2ClassificationHead): + scale = module.out_proj.weight.shape[0] ** -0.5 + module.out_proj.weight.data.normal_(mean=0.0, std=self.config.initializer_range * scale) + if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: + module.out_proj.bias.data.zero_() + # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) + elif "RMSNorm" in module.__class__.__name__: + module.weight.data.zero_() + + def _shift_right(self, input_ids): + """ + Shifts input_ids to the right, prepends the decoder_start_token_id, and handles + pad_token_id replacement for labels that were -100. + This is a common preparation step for decoder inputs in sequence-to-sequence models. + """ + decoder_config = self.config.decoder + decoder_start_token_id = decoder_config.bos_token_id + pad_token_id = decoder_config.pad_token_id + + if decoder_start_token_id is None: + raise ValueError("self.model.config.decoder.bos_token_id has to be defined. ") + + # shift inputs to the right + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.decoder.pad_token_id has to be defined.") + + # Is this T5 specific? + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class T5Gemma2Encoder(T5Gemma2PreTrainedModel): + config: T5Gemma2ModuleConfig + _can_record_outputs = { + "attentions": T5Gemma2SelfAttention, + "hidden_states": T5Gemma2EncoderLayer, + } + + def __init__( + self, + config: T5Gemma2ModuleConfig, + eoi_token_index: int = 256_000, + pixel2feature_preprocessor_fn: Optional[Callable] = None, + ): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # preprocessor for raw images pixel values: injected from outside. + self.pixel2feature_preprocessor_fn = pixel2feature_preprocessor_fn + + self.embed_tokens = T5Gemma2TextScaledWordEmbedding( + config.vocab_size, + config.hidden_size, + config.pad_token_id, + embed_scale=config.hidden_size**0.5, + eoi_token_index=eoi_token_index, + ) + self.norm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + self.layers = nn.ModuleList( + [T5Gemma2EncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.dropout = nn.Dropout(config.dropout_rate) + self.rotary_emb = T5Gemma2RotaryEmbedding(config) + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutput: + """ + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): + Pixel values to be processed by the image encoder to extract image features. + """ + # Part of Gemma 3 processor output but not used by t5gemma 2. + del token_type_ids + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # As we want to pass `past_key_values=None` explicitly everywhere, we need to pop them from kwargs if present + kwargs.pop("past_key_values", None) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if pixel_values is not None: + if self.pixel2feature_preprocessor_fn is None: + raise ValueError("`pixel2feature_preprocessor_fn` has to be provided to process `pixel_values`.") + image_features, image_mask = self.pixel2feature_preprocessor_fn(pixel_values, input_ids, inputs_embeds) + + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_features) + + if position_ids is None: + position_ids = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0) + + if attention_mask is None: + attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id) + + if not isinstance(self_attn_mask_mapping := attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + } + self_attn_mask_mapping = { + "full_attention": create_bidirectional_mask( + **mask_kwargs, + and_mask_function=bidirectional_mask_function(attention_mask), + ), + "sliding_attention": create_bidirectional_mask( + **mask_kwargs, + and_mask_function=and_masks( + sliding_window_mask_function(self.config.sliding_window, is_causal=False), + bidirectional_mask_function(attention_mask), + ), + ), + } + + # input layer + hidden_states = inputs_embeds + + # global and local position embeddings + position_embeddings = {} + for layer_type in self.config.layer_types: + position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type) + + # dropout + hidden_states = self.dropout(hidden_states) + + for layer_module in self.layers[: self.config.num_hidden_layers]: + if not isinstance(layer_module, T5Gemma2EncoderLayer): + raise ValueError(f"Expected T5Gemma2EncoderLayer, but got {type(layer_module)}.") + hidden_states = layer_module( + hidden_states, + position_embeddings[layer_module.attention_type], + self_attn_mask_mapping[layer_module.attention_type], + position_ids, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + hidden_states = self.dropout(hidden_states) + return BaseModelOutput( + last_hidden_state=hidden_states, + ) + + +class T5Gemma2Decoder(T5Gemma2Encoder): + _can_record_outputs = { + "attentions": OutputRecorder(T5Gemma2MergedAttention, index=1), + "cross_attentions": OutputRecorder(T5Gemma2MergedAttention, index=2), + "hidden_states": T5Gemma2DecoderLayer, + } + + def __init__(self, config: T5Gemma2ModuleConfig, shared_embedding: T5Gemma2TextScaledWordEmbedding): + super().__init__(config, shared_embedding) + self.layers = nn.ModuleList( + [T5Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPastAndCrossAttentions: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if encoder_hidden_states is None: + raise ValueError("`encoder_hidden_states` must be given in decoder") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if not self.training and use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + if attention_mask is None and past_key_values is None: + attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id) + + if not isinstance(self_attn_mask_mapping := attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values.self_attention_cache if past_key_values is not None else None, + "position_ids": position_ids, + } + # this masking function did nothing to masking but forces `allow_is_causal_skip` to be False + # as we always need a mask during decoding for merged attention. + mask_kwargs["and_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool) + self_attn_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + + if not isinstance(cross_attn_mask_mapping := encoder_attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": encoder_hidden_states, + "attention_mask": encoder_attention_mask, + "cache_position": cache_position, + "past_key_values": None, + "position_ids": None, + } + cross_attn_mask_mapping = { + "full_attention": create_causal_mask( + **mask_kwargs, + or_mask_function=bidirectional_mask_function(encoder_attention_mask), + ), + } + + # input layer + hidden_states = inputs_embeds + + # global and local position embeddings + position_embeddings = {} + for layer_type in self.config.layer_types: + position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type) + + # dropout + hidden_states = self.dropout(hidden_states) + + for layer_module in self.layers[: self.config.num_hidden_layers]: + if not isinstance(layer_module, T5Gemma2DecoderLayer): + raise ValueError(f"Expected T5Gemma2DecoderLayer, but got {type(layer_module)}.") + hidden_states = layer_module( + hidden_states, + position_embeddings[layer_module.attention_type], + self_attn_mask_mapping[layer_module.attention_type], + position_ids, + past_key_values, + use_cache, + cache_position, + encoder_hidden_states, + cross_attn_mask_mapping["full_attention"], + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + hidden_states = self.dropout(hidden_states) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class T5Gemma2Model(T5Gemma2PreTrainedModel): + _tied_weights_keys = [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + ] + _dynamic_tied_weights_keys = [ + "encoder.embed_tokens.eoi_embedding", + "decoder.embed_tokens.eoi_embedding", + ] + + def __init__(self, config: T5Gemma2Config): + super().__init__(config) + + # setup encoder and decoder + self.encoder = T5Gemma2Encoder(config.encoder, config.eoi_token_index, self.pixel2feature_preprocessor) + self.decoder = T5Gemma2Decoder(config.decoder, config.eoi_token_index) + + # setup vision encoder + self.vision_tower = AutoModel.from_config(config=config.vision_config) + self.multi_modal_projector = T5Gemma2MultiModalProjector(config) + + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def get_input_embeddings(self): + return self.encoder.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.encoder.set_input_embeddings(new_embeddings) + + def _tie_weights(self): + # Decoder input and output embeddings are tied. + if self.config.tie_word_embeddings: + self.decoder.embed_tokens.weight = self.encoder.embed_tokens.weight + self.decoder.embed_tokens.eoi_embedding = self.encoder.embed_tokens.eoi_embedding + + def _get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: + """Convert pixel image to image features via the encoder and projector.""" + # pixel_values: (batch_size, channels, height, width) + # image_features: Image feature tensor of shape (num_images, image_length, embed_dim). + vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state + image_features = self.multi_modal_projector(vision_outputs) + return image_features + + def _get_placeholder_mask( + self, + input_ids: Optional[torch.LongTensor], + inputs_embeds: Optional[torch.FloatTensor], + image_features: torch.FloatTensor, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + + def pixel2feature_preprocessor( + self, + pixel_values: torch.Tensor, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Convert pixel images to image features and obtain placeholder mask.""" + image_features = self._get_image_features(pixel_values) + image_mask = self._get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_features) + return image_features, image_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + # encoder inputs + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + # decoder inputs + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + # others (mainly inference or cache related) + encoder_outputs: Optional[BaseModelOutput] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Seq2SeqModelOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + """ + # encoder + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + return_dict=True, + **kwargs, + ) + + encoder_hidden_states = encoder_outputs.last_hidden_state + + # decoder + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + use_cache=use_cache, + cache_position=cache_position, + return_dict=True, + **kwargs, + ) + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states + if kwargs.get("output_hidden_states", False) + else (decoder_outputs.last_hidden_state,), + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class T5Gemma2ForConditionalGeneration(T5Gemma2PreTrainedModel, GenerationMixin): + _tied_weights_keys = [ + "lm_head.out_proj.weight", + "model.encoder.embed_tokens.weight", + "model.decoder.embed_tokens.weight", + ] + _dynamic_tied_weights_keys = [ + "model.encoder.embed_tokens.eoi_embedding", + "model.decoder.embed_tokens.eoi_embedding", + ] + _tp_plan = {"lm_head.out_proj": "colwise_rep"} + _pp_plan = {"lm_head.out_proj": (["hidden_states"], ["logits"])} + + def __init__(self, config: T5Gemma2Config): + super().__init__(config) + + self.model = T5Gemma2Model(config) + self.vocab_size = config.decoder.vocab_size + self.lm_head = T5Gemma2LMHead(config.decoder.hidden_size, self.vocab_size) + self.loss_type = "ForMaskedLM" + + self.post_init() + + def set_output_embeddings(self, new_embeddings): + self.lm_head.out_proj = new_embeddings + + def get_output_embeddings(self): + return self.lm_head.out_proj + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def _tie_weights(self): + # Decoder input and output embeddings are tied. + if self.config.tie_word_embeddings: + self.lm_head.out_proj.weight = self.model.encoder.embed_tokens.weight + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def get_image_features(self, pixel_values): + return self.model.get_image_features(pixel_values) + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple + @auto_docstring + def forward( + self, + # encoder inputs + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + # decoder inputs + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + # others (mainly inference or cache related) + encoder_outputs: Optional[BaseModelOutput] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self.prepare_decoder_input_ids_from_labels(labels) + + decoder_outputs: Seq2SeqModelOutput = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = decoder_outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + decoder_config = self.config.decoder + if decoder_config.final_logit_softcapping is not None: + logits = logits / decoder_config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * decoder_config.final_logit_softcapping + + loss = None + if labels is not None: + # Input has right-shifted so we directly perform masked lm loss + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + return Seq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.decoder_hidden_states, + decoder_attentions=decoder_outputs.decoder_attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=decoder_outputs.encoder_last_hidden_state, + encoder_hidden_states=decoder_outputs.encoder_hidden_states, + encoder_attentions=decoder_outputs.encoder_attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + +@auto_docstring +class T5Gemma2ForSequenceClassification(T5Gemma2PreTrainedModel): + _tied_weights_keys = [ + "model.encoder.embed_tokens.weight", + "model.decoder.embed_tokens.weight", + ] + _dynamic_tied_weights_keys = [ + "model.encoder.embed_tokens.eoi_embedding", + "model.decoder.embed_tokens.eoi_embedding", + ] + + def __init__(self, config: T5Gemma2Config): + super().__init__(config) + self.num_labels = config.num_labels + self.hidden_size = config.decoder.hidden_size + + self.model = T5Gemma2Model(config) + + classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1) + self.score = T5Gemma2ClassificationHead(self.hidden_size, self.num_labels, classifier_dropout) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[BaseModelOutput] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> SequenceClassifierOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + if inputs_embeds is not None or decoder_inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}." + ) + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + if decoder_input_ids is None: + decoder_input_ids = self._shift_right(input_ids) + + outputs: Seq2SeqModelOutput = self.model( + input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=False, + **kwargs, + ) + + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.decoder_hidden_states + attentions = outputs.decoder_attentions + + logits = self.score(last_hidden_state) + + batch_size = input_ids.shape[0] + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (decoder_input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(decoder_input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + last_non_pad_token = torch.clamp(last_non_pad_token, max=decoder_input_ids.shape[-1] - 1) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + return SequenceClassifierOutput( + loss=loss, + logits=pooled_logits, + hidden_states=hidden_states, + attentions=attentions, + ) + + +@auto_docstring +class T5Gemma2ForTokenClassification(T5Gemma2PreTrainedModel): + _tied_weights_keys = [ + "model.encoder.embed_tokens.weight", + "model.decoder.embed_tokens.weight", + ] + _dynamic_tied_weights_keys = [ + "model.encoder.embed_tokens.eoi_embedding", + "model.decoder.embed_tokens.eoi_embedding", + ] + + def __init__(self, config: T5Gemma2Config): + super().__init__(config) + self.num_labels = config.num_labels + self.hidden_size = config.decoder.hidden_size + + self.model = T5Gemma2Model(config) + + classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1) + self.score = T5Gemma2ClassificationHead(self.hidden_size, self.num_labels, classifier_dropout) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[BaseModelOutput] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + if inputs_embeds is not None or decoder_inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}." + ) + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + if decoder_input_ids is None: + decoder_input_ids = self._shift_right(input_ids) + + outputs: Seq2SeqModelOutput = self.model( + input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=False, + **kwargs, + ) + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.decoder_hidden_states + attentions = outputs.decoder_attentions + + logits = self.score(last_hidden_state) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + attentions=attentions, + ) + + +__all__ = [ + "T5Gemma2Config", + "T5Gemma2ModuleConfig", + "T5Gemma2ForConditionalGeneration", + "T5Gemma2Model", + "T5Gemma2PreTrainedModel", + "T5Gemma2ForSequenceClassification", + "T5Gemma2ForTokenClassification", +] diff --git a/tests/models/t5gemma2/__init__.py b/tests/models/t5gemma2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/t5gemma2/test_modeling_t5gemma2.py b/tests/models/t5gemma2/test_modeling_t5gemma2.py new file mode 100644 index 000000000000..7ed39fa213a6 --- /dev/null +++ b/tests/models/t5gemma2/test_modeling_t5gemma2.py @@ -0,0 +1,1125 @@ +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch T5Gemma2 model.""" + +import copy +import inspect +import unittest + +import pytest + +from transformers import T5Gemma2Config, T5Gemma2ModuleConfig, is_torch_available +from transformers.testing_utils import ( + require_torch, + require_torch_accelerator, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin, has_similar_generate_outputs +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + import torch.nn.functional as F + + from transformers import ( + T5Gemma2ForConditionalGeneration, + T5Gemma2ForSequenceClassification, + T5Gemma2ForTokenClassification, + T5Gemma2Model, + ) + + +class T5Gemma2ModelTester: + config_class = T5Gemma2Config + module_config_class = T5Gemma2ModuleConfig + + if is_torch_available(): + model_class = T5Gemma2Model + causal_lm_class = T5Gemma2ForConditionalGeneration + sequence_classification_class = T5Gemma2ForSequenceClassification + token_classification_class = T5Gemma2ForTokenClassification + + def __init__( + self, + parent, + batch_size=13, + is_training=True, + use_attention_mask=True, + use_labels=True, + vocab_size=99, + # decoder-specific + seq_length=7, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=37, + # encoder-specific + encoder_seq_length=7, + encoder_hidden_size=32, + encoder_num_hidden_layers=2, + encoder_num_attention_heads=4, + encoder_num_key_value_heads=2, + encoder_intermediate_size=37, + # vision-specific + mm_tokens_per_image=2, + image_token_index=4, + boi_token_index=5, + eoi_token_index=6, + vision_config={ + "use_labels": True, + "image_size": 20, + "patch_size": 5, + "num_channels": 3, + "is_training": True, + "hidden_size": 32, + "num_key_value_heads": 1, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 37, + "dropout": 0.1, + "attention_dropout": 0.1, + "initializer_range": 0.02, + }, + # common + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + # special ids + eos_token_id=1, + pad_token_id=0, + bos_token_id=2, + ): + self.parent = parent + self.batch_size = batch_size + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + # decoder + self.seq_length = seq_length + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.intermediate_size = intermediate_size + # encoder + self.encoder_seq_length = encoder_seq_length + self.encoder_hidden_size = encoder_hidden_size + self.encoder_num_hidden_layers = encoder_num_hidden_layers + self.encoder_num_attention_heads = encoder_num_attention_heads + self.encoder_num_key_value_heads = encoder_num_key_value_heads + self.encoder_intermediate_size = encoder_intermediate_size + # vision + self.mm_tokens_per_image = mm_tokens_per_image + self.image_token_index = image_token_index + self.boi_token_index = boi_token_index + self.eoi_token_index = eoi_token_index + self.vision_config = vision_config + self.num_channels = vision_config["num_channels"] + self.image_size = vision_config["image_size"] + # common + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + self.head_dim = self.hidden_size // self.num_attention_heads + # assume encoder and decoder have the same head dimension. + assert self.head_dim == self.encoder_hidden_size // self.encoder_num_attention_heads + # special ids + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + # assume the number of attention heads are the same across encoder and decoder + # only used for generation testing purpose. + assert self.num_attention_heads == self.encoder_num_attention_heads + + def get_encoder_config(self): + return self.module_config_class( + vocab_size=self.vocab_size, + hidden_size=self.encoder_hidden_size, + num_hidden_layers=self.encoder_num_hidden_layers, + num_attention_heads=self.encoder_num_attention_heads, + num_key_value_heads=self.encoder_num_key_value_heads, + intermediate_size=self.encoder_intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + head_dim=self.head_dim, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + ) + + def get_decoder_config(self): + return self.module_config_class( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + intermediate_size=self.intermediate_size, + cross_attention_hidden_size=self.encoder_hidden_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=True, + initializer_range=self.initializer_range, + head_dim=self.head_dim, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + ) + + def get_config(self, is_encoder_decoder=True): + return self.config_class( + encoder=self.get_encoder_config(), + decoder=self.get_decoder_config(), + vision_config=self.vision_config, + is_encoder_decoder=is_encoder_decoder, + # vision + image_token_index=self.image_token_index, + boi_token_index=self.boi_token_index, + eoi_token_index=self.eoi_token_index, + mm_tokens_per_image=self.mm_tokens_per_image, + # Used for generation test. + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + + input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size - 1) + 1 + decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size - 1) + 1 + # Vision inputs. + pixel_values = floats_tensor( + [ + self.batch_size, + self.vision_config["num_channels"], + self.vision_config["image_size"], + self.vision_config["image_size"], + ] + ) + + # Remove BOS symbols from inputs. + input_ids = torch.where(input_ids == self.bos_token_id, 42, input_ids) + decoder_input_ids = torch.where(decoder_input_ids == self.bos_token_id, 42, decoder_input_ids) + + # Avoid leading PAD tokens from inputs. + decoder_input_ids[:, 0] = self.pad_token_id + 1 + + # set the 3 first tokens to be image, and ensure that no other tokens are image tokens + # do not change this unless you modified image size or patch size + input_ids[input_ids == config.image_token_index] = self.pad_token_id + input_ids[:, :1] = config.image_token_index + + attention_mask = None + decoder_attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) + decoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + lm_labels = None + if self.use_labels: + lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + return ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + "pixel_values": pixel_values, + } + return config, inputs_dict + + def create_and_check_model( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ): + model = self.model_class(config=config).to(torch_device).eval() + + result = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + decoder_output = result.last_hidden_state + decoder_past = result.past_key_values + encoder_output = result.encoder_last_hidden_state + + self.parent.assertEqual( + encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.encoder_hidden_size) + ) + self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertIsNotNone(decoder_past) + self.parent.assertEqual(len(decoder_past.self_attention_cache), config.decoder.num_hidden_layers) + self.parent.assertEqual(len(decoder_past.cross_attention_cache), config.decoder.num_hidden_layers) + + def check_prepare_lm_labels_via_shift_left( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ): + model = self.model_class(config=config).to(torch_device).eval() + + # _shift_right should be called on labels + shifted_labels = model._shift_right(lm_labels) + + # first token should be decoder_start_token_id + self.parent.assertTrue(torch.all(shifted_labels[:, 0] == config.decoder.bos_token_id)) + + # the rest should be the labels shifted by one, with -100 replaced by pad_token_id + labels_without_ignore_index = lm_labels.masked_fill(lm_labels == -100, config.decoder.pad_token_id) + self.parent.assertTrue(torch.all(shifted_labels[:, 1:] == labels_without_ignore_index[:, :-1])) + + def create_and_check_with_lm_head( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ): + model = self.causal_lm_class(config=config).to(torch_device).eval() + outputs = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + labels=lm_labels, + pixel_values=pixel_values, + ) + self.parent.assertEqual(len(outputs), 5) + self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.seq_length, self.vocab_size)) + self.parent.assertEqual(outputs["loss"].size(), ()) + + def create_and_check_with_sequence_classification_head( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ): + labels = torch.tensor([1] * self.batch_size, dtype=torch.long, device=torch_device) + model = self.sequence_classification_class(config=config).to(torch_device).eval() + outputs = model( + input_ids=input_ids, + pixel_values=pixel_values, + decoder_input_ids=decoder_input_ids, + labels=labels, + ) + self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, config.num_labels)) + self.parent.assertEqual(outputs["loss"].size(), ()) + + def create_and_check_with_token_classification_head( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ): + labels = torch.tensor([1] * self.seq_length * self.batch_size, dtype=torch.long, device=torch_device) + model = self.token_classification_class(config=config) + model = model.to(torch_device).eval() + outputs = model( + input_ids=input_ids, + pixel_values=pixel_values, + decoder_input_ids=decoder_input_ids, + labels=labels, + ) + + self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.seq_length, config.num_labels)) + self.parent.assertEqual(outputs["loss"].size(), ()) + + def create_and_check_decoder_model_past( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ): + model = self.model_class(config=config).get_decoder().to(torch_device).eval() + encoder_hidden_states = torch.ones( + (self.batch_size, self.encoder_seq_length, self.encoder_hidden_size), dtype=torch.float32 + ).to(torch_device) + + # first forward pass + outputs = model(decoder_input_ids, encoder_hidden_states=encoder_hidden_states, use_cache=True) + outputs_use_cache_conf = model(decoder_input_ids, encoder_hidden_states=encoder_hidden_states) + outputs_no_past = model(decoder_input_ids, encoder_hidden_states=encoder_hidden_states, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + output, past_key_values = outputs.to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = torch.cat([decoder_input_ids, next_tokens], dim=-1) + + output_from_no_past = model(next_input_ids, encoder_hidden_states=encoder_hidden_states)["last_hidden_state"] + output_from_past = model( + next_tokens, encoder_hidden_states=encoder_hidden_states, past_key_values=past_key_values + )["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_decoder_model_attention_mask_past( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ): + model = self.model_class(config=config).get_decoder().to(torch_device).eval() + encoder_hidden_states = torch.ones( + (self.batch_size, self.encoder_seq_length, self.encoder_hidden_size), dtype=torch.float32 + ).to(torch_device) + + # create attention mask + attn_mask = torch.ones(decoder_input_ids.shape, dtype=torch.long, device=torch_device) + + half_seq_length = decoder_input_ids.shape[-1] // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + output, past_key_values = model( + decoder_input_ids, encoder_hidden_states=encoder_hidden_states, attention_mask=attn_mask, use_cache=True + ).to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + decoder_input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([decoder_input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], + dim=1, + ) + + # get two different outputs + output_from_no_past = model( + next_input_ids, encoder_hidden_states=encoder_hidden_states, attention_mask=attn_mask + )["last_hidden_state"] + output_from_past = model( + next_tokens, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + attention_mask=attn_mask, + )["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ): + model = self.model_class(config=config).get_decoder().to(torch_device).eval() + encoder_hidden_states = torch.ones( + (self.batch_size, self.encoder_seq_length, self.encoder_hidden_size), dtype=torch.float32 + ).to(torch_device) + + # first forward pass + outputs = model( + decoder_input_ids, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + use_cache=True, + ) + + output, past_key_values = outputs.to_tuple() + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([decoder_input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([attention_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, encoder_hidden_states=encoder_hidden_states, attention_mask=next_attention_mask + )["last_hidden_state"] + output_from_past = model( + next_tokens, + encoder_hidden_states=encoder_hidden_states, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + )["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_generate_with_past_key_values( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ): + model = self.causal_lm_class(config=config).to(torch_device).eval() + torch.manual_seed(0) + output_without_past_cache = model.generate( + input_ids, pixel_values=pixel_values, num_beams=2, max_length=5, do_sample=True, use_cache=False + ) + torch.manual_seed(0) + output_with_past_cache = model.generate( + input_ids, pixel_values=pixel_values, num_beams=2, max_length=5, do_sample=True + ) + self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache)) + + def create_and_check_model_fp16_forward( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ): + model = self.model_class(config=config).to(torch_device).half().eval() + output = model( + input_ids, + pixel_values=pixel_values, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + )["last_hidden_state"] + self.parent.assertFalse(torch.isnan(output).any().item()) + + +@require_torch +class T5Gemma2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + T5Gemma2Model, + T5Gemma2ForConditionalGeneration, + T5Gemma2ForSequenceClassification, + T5Gemma2ForTokenClassification, + ) + if is_torch_available() + else () + ) + + _is_stateful = True + is_encoder_decoder = True + + # used in `test_torch_compile_for_training` + _torch_compile_train_cls = T5Gemma2ForConditionalGeneration if is_torch_available() else None + # `T5Gemma2` will give warning or raise error if it is not `eager` during training. + _torch_compile_train_attn_implementation = "eager" + + # won't fix + test_torchscript = False + + # MP works but offload doesn't work when the SigLIP MultiheadAttention is offloaded + # TODO: One potential solution would be to add to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"] + # in the dispatch_model function + test_cpu_offload = False + test_disk_offload_safetensors = False + test_disk_offload_bin = False + + def setUp(self): + self.model_tester = T5Gemma2ModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=T5Gemma2Config, + # For faking the testing. + hidden_size=37, + vocab_size=self.model_tester.vocab_size, + num_attention_heads=self.model_tester.num_attention_heads, + num_hidden_layers=self.model_tester.num_hidden_layers, + ) + + def is_pipeline_test_to_skip( + self, + pipeline_test_case_name, + config_class, + model_architecture, + tokenizer_name, + image_processor_name, + feature_extractor_name, + processor_name, + ): + if tokenizer_name is None: + return True + if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"): + return True + + return False + + def test_config(self): + self.config_tester.run_common_tests() + + def test_shift_right(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + # Based on tests.models.t5.test_modeling_t5.T5ModelTest.test_inputs_embeds + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in (T5Gemma2Model, T5Gemma2ForConditionalGeneration): + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + + if not self.is_encoder_decoder: + input_ids = inputs["input_ids"] + del inputs["input_ids"] + else: + encoder_input_ids = inputs["input_ids"] + decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + del inputs["input_ids"] + inputs.pop("decoder_input_ids", None) + + wte = model.get_input_embeddings() + if not self.is_encoder_decoder: + inputs["inputs_embeds"] = wte(input_ids) + else: + inputs["inputs_embeds"] = wte(encoder_input_ids) + inputs["decoder_inputs_embeds"] = wte(decoder_input_ids) + + with torch.no_grad(): + model(**inputs)[0] + + def test_with_lm_head(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_with_lm_head(*config_and_inputs) + + def test_with_sequence_classification_head(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_with_sequence_classification_head(*config_and_inputs) + + def test_with_token_classification_head(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_with_token_classification_head(*config_and_inputs) + + def test_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) + + def test_decoder_model_past_with_attn_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + + # Based on tests.models.t5.test_modeling_t5.T5ModelTest.test_decoder_model_past_with_3d_attn_mask + def test_decoder_model_past_with_3d_attn_mask(self): + ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ) = self.model_tester.prepare_config_and_inputs() + + attention_mask = ids_tensor( + [self.model_tester.batch_size, self.model_tester.encoder_seq_length, self.model_tester.encoder_seq_length], + vocab_size=2, + ) + decoder_attention_mask = ids_tensor( + [self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.seq_length], + vocab_size=2, + ) + + self.model_tester.create_and_check_decoder_model_attention_mask_past( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_generate_with_past_key_values(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_generate_with_past_key_values(*config_and_inputs) + + @unittest.skipIf(torch_device == "cpu", "Can't do half precision") + def test_model_fp16_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) + + # Based on tests.models.gemma.test_modeling_gemma.GemmaModelTest.test_Gemma_sequence_classification_model with Gemma -> T5Gemma2 (Add is_encoder_decoder option) + def test_T5Gemma2_sequence_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + + for pixel_values in [None, input_dict["pixel_values"]]: + model = self.model_tester.sequence_classification_class(config).to(torch_device).eval() + result = model(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + # Based on tests.models.gemma.test_modeling_gemma.GemmaModelTest.test_Gemma_sequence_classification_model_for_single_label with Gemma -> T5Gemma2 (Add is_encoder_decoder option) + def test_T5Gemma2_sequence_classification_model_for_single_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "single_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + + for pixel_values in [None, input_dict["pixel_values"]]: + model = self.model_tester.sequence_classification_class(config).to(torch_device).eval() + result = model(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + # Based on tests.models.gemma.test_modeling_gemma.GemmaModelTest.test_Gemma_sequence_classification_model_for_multi_label with Gemma -> T5Gemma2 (Add is_encoder_decoder option) + def test_T5Gemma2_sequence_classification_model_for_multi_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "multi_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor( + [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size + ).to(torch.float) + + for pixel_values in [None, input_dict["pixel_values"]]: + model = self.model_tester.sequence_classification_class(config).to(torch_device).eval() + result = model(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + # Based on tests.models.gemma.test_modeling_gemma.GemmaModelTest.test_Gemma_token_classification_model with Gemma -> T5Gemma2 (Add is_encoder_decoder option) + def test_T5Gemma2_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + decoder_input_ids = input_dict["decoder_input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + + for pixel_values in [None, input_dict["pixel_values"]]: + model = self.model_tester.token_classification_class(config).to(torch_device).eval() + + result = model( + input_ids, + decoder_input_ids=decoder_input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + labels=token_labels, + ) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + + # Based on tests.models.gemma.test_modeling_gemma.GemmaModelTest.test_sdpa_equivalence + # Add decoder_input_ids and adjust hidden states. + @require_torch_accelerator + def test_sdpa_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_sdpa: + self.skipTest(reason="Model does not support SDPA") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(torch_device) + dummy_input = inputs_dict[model_class.main_input_name].to(torch_device) + decoder_dummy_input = torch.ones_like(dummy_input) + + model.config._attn_implementation = "sdpa" + states_sdpa = model(dummy_input, decoder_input_ids=decoder_dummy_input, output_hidden_states=True) + + model.config._attn_implementation = "eager" + states_eager = model(dummy_input, decoder_input_ids=decoder_dummy_input, output_hidden_states=True) + + if hasattr(states_sdpa, "decoder_hidden_states"): + states_sdpa = states_sdpa.decoder_hidden_states[-1] + states_eager = states_eager.decoder_hidden_states[-1] + else: + states_sdpa = states_sdpa.hidden_states[-1] + states_eager = states_eager.hidden_states[-1] + + torch.testing.assert_close(states_sdpa, states_eager, atol=1e-5, rtol=1e-5) + + @unittest.skip("T5Gemma2 eager/FA2 attention outputs are expected to be different") + def test_flash_attn_2_equivalence(self): + pass + + @unittest.skip("This was not properly written, submodules need the attribute to be overwritten") + def test_attention_outputs(self): + pass + + @unittest.skip("Mismatch issue doesn't exist in T5Gemma2.") + def test_load_with_mismatched_shapes(self): + pass + + # Based on tests.generation.test_utils.GenerationTesterMixin.test_generate_continue_from_past_key_values + # Updated decoder_attention_mask to consider the appended bos token + @pytest.mark.generate + def test_generate_continue_from_past_key_values(self): + # Tests that we can continue generating from past key values, returned from a previous `generate` call + for model_class in self.all_generative_model_classes: + if model_class == self.model_tester.token_classification_class: + continue + if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt", "mllama"]): + self.skipTest(reason="Won't fix: old model with unique inputs/caches/other") + if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]): + self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility") + + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + if not hasattr(config.get_text_config(), "use_cache"): + self.skipTest(reason=f"{model_class.__name__} doesn't support caching") + + # Let's make it always: + # 1. use cache (for obvious reasons) + # 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which + # would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the + # continuation would force it to generate beyond an EOS token) + # 3. ignore `token_type_ids` for simplicity + # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is + # active by default on some models + # 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When + # we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents + # repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls + # with cache, what is considered a prompt is different in the two cases. + + if "token_type_ids" in inputs: + del inputs["token_type_ids"] + + model = model_class(config).to(torch_device) + model.eval() + + # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format) + outputs = model(**inputs) + if "past_key_values" not in outputs: + self.skipTest(reason="This model doesn't return `past_key_values`") + + generate_kwargs = { + "pad_token_id": -1, + "eos_token_id": -1, + "forced_eos_token_id": None, + "encoder_no_repeat_ngram_size": 0, + "use_cache": True, + "do_sample": False, + "return_dict_in_generate": True, + "output_scores": True, + } + + # Traditional way of generating text, with `return_dict_in_generate` to return the past key values + outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=4) + + # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the + # inputs may need to be tweaked across `generate` calls (like the attention mask). + outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=3) + + # Continue from the tokens generated above, preparing the inputs accordingly + inputs["past_key_values"] = outputs_cached.past_key_values + new_attention_len = outputs_cached.sequences.shape[-1] + + # It must be encoder-decoder models + self.assertTrue(config.is_encoder_decoder) + + inputs["decoder_input_ids"] = outputs_cached.sequences + if "decoder_attention_mask" in inputs: + decoder_attention_mask = inputs["decoder_attention_mask"] + + # Add BOS mask: the new sequence comes with a new BOS token, which is not included in the original inputs + padding_tensor = torch.ones_like(decoder_attention_mask[:, :1]) + decoder_attention_mask = torch.cat([padding_tensor, decoder_attention_mask], dim=1) + + inputs["decoder_attention_mask"] = torch.nn.functional.pad( + decoder_attention_mask, + (0, new_attention_len - decoder_attention_mask.shape[1]), + mode="constant", + value=1, + ) + + first_caches_scores = outputs_cached.scores + outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1) + full_cached_scores = first_caches_scores + outputs_cached.scores + outputs_cached.scores = full_cached_scores + + # The two sets of generated text and past kv should be equal to each other + self.assertTrue(has_similar_generate_outputs(outputs, outputs_cached)) + self._check_caches_are_equal(outputs.past_key_values, outputs_cached.past_key_values) + + # Based on tests.test_modeling_common.ModelTesterMixin.test_inputs_embeds_matches_input_ids + # Update encoder and decoder embeddings + def test_inputs_embeds_matches_input_ids(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model_class = self.model_tester.model_class + + model = model_class(config) + model.to(torch_device) + model.eval() + + model_forward_args = inspect.signature(model.forward).parameters + if "inputs_embeds" not in model_forward_args: + self.skipTest(reason="This model doesn't use `inputs_embeds`") + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1 + + encoder_embedding = model.get_encoder().get_input_embeddings() + decoder_embedding = model.get_decoder().get_input_embeddings() + + encoder_input_ids = inputs["input_ids"] + decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + encoder_input_ids[encoder_input_ids == pad_token_id] = max(0, pad_token_id + 1) + decoder_input_ids[decoder_input_ids == pad_token_id] = max(0, pad_token_id + 1) + del inputs["input_ids"] + inputs.pop("decoder_input_ids", None) + + inputs_embeds = encoder_embedding(encoder_input_ids) + decoder_inputs_embeds = decoder_embedding(decoder_input_ids) + with torch.no_grad(): + out_ids = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids, **inputs)[0] + out_embeds = model(inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, **inputs)[0] + + torch.testing.assert_close(out_embeds, out_ids) + + @unittest.skip("T5Gemma 2 only support final layer hidden states.") + def test_hidden_states_output(self): + pass + + # Based on tests.models.t5.test_modeling_t5.T5ModelTest.test_custom_4d_attention_mask + # Excluding the final token from input_ids + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self._get_custom_4d_mask_test_data() + mask_shared_prefix = mask_shared_prefix == 0.0 + + outputs = model.forward( + decoder_input_ids=input_ids, + input_ids=input_ids[:, :-1], + decoder_position_ids=position_ids, + ) + logits = outputs.logits + # logits.shape == torch.Size([3, 4, ...]) + + outputs_shared_prefix = model( + input_ids=input_ids[:1, :-1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + decoder_position_ids=position_ids_shared_prefix, + ) + logits_shared_prefix = outputs_shared_prefix.logits + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + torch.testing.assert_close( + outputs.encoder_last_hidden_state[0], outputs_shared_prefix.encoder_last_hidden_state[0] + ) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0[2], normalized_1[2], rtol=1e-3, atol=1e-4) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + + # Based on tests.test_modeling_common.ModelTesterMixin.test_flex_attention_with_grads + # Update hidden size for encoder and decoder + @require_torch_accelerator + def test_flex_attention_with_grads(self): + for model_class in self.all_model_classes: + # TODO: raushan, fix for composite models after making VLMs support new attn API + if not model_class._supports_flex_attn or self._is_composite: + self.skipTest(reason="This model does not support flex attention") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config._attn_implementation = "flex_attention" + # Flex Attention cannot use dropout + config.encoder.attention_dropout = 0 + config.decoder.attention_dropout = 0 + + # Flex attention relies on triton on compilation + # However, triton cannot handle hidden dimensions of less than 16 + # --> forcing at least a hidden dim of 16 + config.encoder.hidden_size *= max( + 16 + // getattr( + config.encoder, "head_dim", config.encoder.hidden_size // config.encoder.num_attention_heads + ), + 1, + ) + config.decoder.hidden_size *= max( + 16 + // getattr( + config.decoder, "head_dim", config.decoder.hidden_size // config.decoder.num_attention_heads + ), + 1, + ) + config.decoder.cross_attention_hidden_size = config.encoder.hidden_size + + config.decoder.head_dim = max(16, config.decoder.head_dim) + config.encoder.head_dim = max(16, config.encoder.head_dim) + + model = model_class(config).to(device=torch_device) + self.assertTrue(model.config._attn_implementation == "flex_attention") + + # Elaborate workaround for encoder-decoder models as some do not specify their main input + dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)} + if config.is_encoder_decoder: + dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"].to(torch_device) + dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"].to(torch_device) + + # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605) + _ = model(**dummy_inputs) + + @unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training") + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training") + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="Self&cross attention are splited after the merged attention") + def test_retain_grad_hidden_states_attentions(self): + pass