
Commit e48bc6e

Style fix & add use of final_logit_softcapping for attributes check.
1 parent b25f290 commit e48bc6e

6 files changed: +160 -73 lines changed


src/transformers/models/auto/image_processing_auto.py

Lines changed: 1 addition & 1 deletion
@@ -176,13 +176,13 @@
         ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),
         ("swin2sr", ("Swin2SRImageProcessor", "Swin2SRImageProcessorFast")),
         ("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")),
+        ("t5gemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
         ("table-transformer", ("DetrImageProcessor", "DetrImageProcessorFast")),
         ("textnet", ("TextNetImageProcessor", "TextNetImageProcessorFast")),
         ("timesformer", ("VideoMAEImageProcessor", None)),
         ("timm_wrapper", ("TimmWrapperImageProcessor", None)),
         ("tvlt", ("TvltImageProcessor", None)),
         ("tvp", ("TvpImageProcessor", "TvpImageProcessorFast")),
-        ("t5gemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
         ("udop", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
         ("upernet", ("SegformerImageProcessor", "SegformerImageProcessorFast")),
         ("van", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),

src/transformers/models/auto/processing_auto.py

Lines changed: 1 addition & 1 deletion
@@ -137,10 +137,10 @@
         ("speech_to_text", "Speech2TextProcessor"),
         ("speech_to_text_2", "Speech2Text2Processor"),
         ("speecht5", "SpeechT5Processor"),
+        ("t5gemma2", "Gemma3Processor"),
         ("trocr", "TrOCRProcessor"),
         ("tvlt", "TvltProcessor"),
         ("tvp", "TvpProcessor"),
-        ("t5gemma2", "Gemma3Processor"),
         ("udop", "UdopProcessor"),
         ("unispeech", "Wav2Vec2Processor"),
         ("unispeech-sat", "Wav2Vec2Processor"),

src/transformers/models/t5gemma2/configuration_t5gemma2.py

Lines changed: 82 additions & 1 deletion
@@ -32,7 +32,88 @@
 
 
 class T5Gemma2ModuleConfig(PreTrainedConfig):
-    """Module config for encoder or decoder backbone."""
+    r"""
+    This is the configuration class to store the configuration of a [`T5Gemma2ModuleModel`]. It is used to instantiate a T5Gemma2Module
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the T5Gemma2Module-7B.
+    e.g. [google/t5_gemma2_module-7b](https://huggingface.co/google/t5_gemma2_module-7b)
+    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PreTrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 262208):
+            Vocabulary size of the T5Gemma2Module model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`T5Gemma2ModuleModel`].
+        hidden_size (`int`, *optional*, defaults to 2304):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 9216):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 26):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*, defaults to 4):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+            `num_attention_heads`.
+        head_dim (`int`, *optional*, defaults to 256):
+            The attention head dimension.
+        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
+            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
+        max_position_embeddings (`int`, *optional*, defaults to 131072):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            Padding token id.
+        eos_token_id (`int`, *optional*, defaults to 1):
+            End of stream token id.
+        bos_token_id (`int`, *optional*, defaults to 2):
+            Beginning of stream token id.
+        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+            Whether to tie weight embeddings.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        query_pre_attn_scalar (`float`, *optional*, defaults to 256):
+            Scaling factor used on the attention scores.
+        sliding_window (`int`, *optional*, defaults to 4096):
+            In T5Gemma2Module, every other layer uses sliding window attention. This is the size of the sliding window.
+        layer_types (`list`, *optional*):
+            Attention pattern for each layer.
+        final_logit_softcapping (`float`, *optional*):
+            Scaling factor when applying tanh softcapping on the logits.
+        attn_logit_softcapping (`float`, *optional*):
+            Scaling factor when applying tanh softcapping on the attention scores.
+        rope_parameters (`RopeParameters`, *optional*):
+            Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
+            a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
+            with longer `max_position_embeddings`.
+        use_bidirectional_attention (`bool`, *optional*, defaults to `False`):
+            If True, the model will attend to all text tokens instead of using a causal mask. This does not change
+            behavior for vision tokens.
+
+    ```python
+    >>> from transformers import T5Gemma2ModuleModel, T5Gemma2ModuleConfig
+    >>> # Initializing a T5Gemma2Module t5_gemma2_module-7b style configuration
+    >>> configuration = T5Gemma2ModuleConfig()
+    >>> # Initializing a model from the t5_gemma2_module-7b style configuration
+    >>> model = T5Gemma2ModuleModel(configuration)
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
 
     model_type = "t5gemma2_module"
     keys_to_ignore_at_inference = ["past_key_values"]
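The docstring above also documents the two softcapping attributes this commit starts using. As a hedged sketch of setting them by hand, assuming `T5Gemma2ModuleConfig` is importable exactly as the docstring example shows (the numeric values are illustrative, not checkpoint defaults):

```python
# Illustrative only: construct the module config with the softcapping knobs
# documented above. Import path follows the docstring example; 30.0 and 50.0
# are made-up values, not shipped defaults.
from transformers import T5Gemma2ModuleConfig

config = T5Gemma2ModuleConfig(
    final_logit_softcapping=30.0,  # consumed by the lm_head path in the modeling diff below
    attn_logit_softcapping=50.0,   # applied to attention scores
)
print(config.final_logit_softcapping, config.attn_logit_softcapping)
```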

src/transformers/models/t5gemma2/modeling_t5gemma2.py

Lines changed: 6 additions & 0 deletions
@@ -1352,6 +1352,12 @@ def forward(
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.lm_head(hidden_states[:, slice_indices, :])
 
+        decoder_config = self.config.decoder
+        if decoder_config.final_logit_softcapping is not None:
+            logits = logits / decoder_config.final_logit_softcapping
+            logits = torch.tanh(logits)
+            logits = logits * decoder_config.final_logit_softcapping
+
         loss = None
         if labels is not None:
             # Input has right-shifted so we directly perform masked lm loss
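The six added lines apply Gemma-style tanh softcapping to the final logits when `final_logit_softcapping` is set. As a standalone illustration (not part of the commit), the transform is `cap * tanh(logits / cap)`, which leaves small logits nearly unchanged and bounds large ones to the open interval (-cap, cap):

```python
# Standalone illustration of the softcapping applied in the diff above
# (not part of the commit).
import torch


def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Near-identity for |logits| << cap; saturates towards +/-cap for large magnitudes.
    return cap * torch.tanh(logits / cap)


x = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(softcap(x, cap=30.0))  # small entries pass through almost unchanged, large ones saturate near +/-30
```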
