
Commit bd4c480

Update tests and embedding design.
1 parent e48bc6e commit bd4c480

3 files changed: +37 -250 lines changed


src/transformers/models/t5gemma2/modeling_t5gemma2.py

Lines changed: 13 additions & 26 deletions
@@ -813,7 +813,7 @@ class T5Gemma2Encoder(T5Gemma2PreTrainedModel):
     def __init__(
         self,
         config: T5Gemma2ModuleConfig,
-        shared_embedding: T5Gemma2TextScaledWordEmbedding,
+        eoi_token_index: int = 256_000,
         pixel2feature_preprocessor_fn: Optional[Callable] = None,
     ):
         super().__init__(config)
@@ -823,7 +823,13 @@ def __init__(
         # preprocessor for raw images pixel values: injected from outside.
         self.pixel2feature_preprocessor_fn = pixel2feature_preprocessor_fn

-        self.embed_tokens = shared_embedding
+        self.embed_tokens = T5Gemma2TextScaledWordEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            config.pad_token_id,
+            embed_scale=config.hidden_size**0.5,
+            eoi_token_index=eoi_token_index,
+        )
         self.norm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False

@@ -1051,13 +1057,11 @@ def forward(
 @auto_docstring
 class T5Gemma2Model(T5Gemma2PreTrainedModel):
     _tied_weights_keys = [
-        "embed_tokens.weight",
         "encoder.embed_tokens.weight",
         "decoder.embed_tokens.weight",
     ]
     _dynamic_tied_weights_keys = [
         "encoder.embed_tokens.eoi_embedding",
-        "embed_tokens.eoi_embedding",
         "decoder.embed_tokens.eoi_embedding",
     ]

@@ -1079,18 +1083,9 @@ def __init__(self, config: T5Gemma2Config):
                 f"encoder ({config.encoder.vocab_size}) vs decoder ({config.decoder.vocab_size})."
             )

-        # shared embedding
-        self.embed_tokens = T5Gemma2TextScaledWordEmbedding(
-            config.encoder.vocab_size,
-            config.encoder.hidden_size,
-            config.encoder.pad_token_id,
-            embed_scale=config.encoder.hidden_size**0.5,
-            eoi_token_index=config.eoi_token_index,
-        )
-
         # setup encoder and decoder
-        self.encoder = T5Gemma2Encoder(config.encoder, self.embed_tokens, self.pixel2feature_preprocessor)
-        self.decoder = T5Gemma2Decoder(config.decoder, self.embed_tokens)
+        self.encoder = T5Gemma2Encoder(config.encoder, config.eoi_token_index, self.pixel2feature_preprocessor)
+        self.decoder = T5Gemma2Decoder(config.decoder, config.eoi_token_index)

         # setup vision encoder
         self.vision_tower = AutoModel.from_config(config=config.vision_config)
@@ -1113,10 +1108,8 @@ def set_input_embeddings(self, new_embeddings):
     def _tie_weights(self):
         # Decoder input and output embeddings are tied.
         if self.config.tie_word_embeddings:
-            self.encoder.embed_tokens.weight = self.embed_tokens.weight
-            self.decoder.embed_tokens.weight = self.embed_tokens.weight
-            self.encoder.embed_tokens.eoi_embedding = self.embed_tokens.eoi_embedding
-            self.decoder.embed_tokens.eoi_embedding = self.embed_tokens.eoi_embedding
+            self.decoder.embed_tokens.weight = self.encoder.embed_tokens.weight
+            self.decoder.embed_tokens.eoi_embedding = self.encoder.embed_tokens.eoi_embedding

     def _get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
         """Convert pixel image to image features via the encoder and projector."""
@@ -1234,14 +1227,12 @@ def forward(

 class T5Gemma2ForConditionalGeneration(T5Gemma2PreTrainedModel, GenerationMixin):
     _tied_weights_keys = [
-        "model.embed_tokens.weight",
         "lm_head.out_proj.weight",
         "model.encoder.embed_tokens.weight",
         "model.decoder.embed_tokens.weight",
     ]
     _dynamic_tied_weights_keys = [
         "model.encoder.embed_tokens.eoi_embedding",
-        "model.embed_tokens.eoi_embedding",
         "model.decoder.embed_tokens.eoi_embedding",
     ]
     _tp_plan = {"lm_head.out_proj": "colwise_rep"}
@@ -1273,7 +1264,7 @@ def set_input_embeddings(self, value):
     def _tie_weights(self):
         # Decoder input and output embeddings are tied.
         if self.config.tie_word_embeddings:
-            self.lm_head.out_proj.weight = self.model.embed_tokens.weight
+            self.lm_head.out_proj.weight = self.model.encoder.embed_tokens.weight

     def get_encoder(self):
         return self.model.get_encoder()
@@ -1382,13 +1373,11 @@ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
 @auto_docstring
 class T5Gemma2ForSequenceClassification(T5Gemma2PreTrainedModel):
     _tied_weights_keys = [
-        "model.embed_tokens.weight",
         "model.encoder.embed_tokens.weight",
         "model.decoder.embed_tokens.weight",
     ]
     _dynamic_tied_weights_keys = [
         "model.encoder.embed_tokens.eoi_embedding",
-        "model.embed_tokens.eoi_embedding",
         "model.decoder.embed_tokens.eoi_embedding",
     ]

@@ -1491,13 +1480,11 @@ def forward(
 @auto_docstring
 class T5Gemma2ForTokenClassification(T5Gemma2PreTrainedModel):
     _tied_weights_keys = [
-        "model.embed_tokens.weight",
         "model.encoder.embed_tokens.weight",
         "model.decoder.embed_tokens.weight",
     ]
     _dynamic_tied_weights_keys = [
         "model.encoder.embed_tokens.eoi_embedding",
-        "model.embed_tokens.eoi_embedding",
         "model.decoder.embed_tokens.eoi_embedding",
     ]
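The net effect of this change is that T5Gemma2Model no longer owns a top-level shared embed_tokens: the encoder and decoder each construct their own T5Gemma2TextScaledWordEmbedding, and _tie_weights aliases the decoder's parameters to the encoder's. Below is a minimal sketch of that tying pattern using plain nn.Embedding stand-ins rather than the actual T5Gemma2 classes; the Tiny* names and sizes are illustrative only.

import torch.nn as nn


class TinyEncoder(nn.Module):
    # Each submodule now builds its own embedding instead of receiving a shared one.
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)


class TinyDecoder(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)


class TinyModel(nn.Module):
    def __init__(self, vocab_size: int = 32, hidden_size: int = 8):
        super().__init__()
        self.encoder = TinyEncoder(vocab_size, hidden_size)
        self.decoder = TinyDecoder(vocab_size, hidden_size)
        self._tie_weights()

    def _tie_weights(self):
        # The encoder's embedding is the source of truth; the decoder aliases it,
        # mirroring the updated T5Gemma2Model._tie_weights.
        self.decoder.embed_tokens.weight = self.encoder.embed_tokens.weight


model = TinyModel()
# Both modules now point at the same parameter storage.
assert model.decoder.embed_tokens.weight.data_ptr() == model.encoder.embed_tokens.weight.data_ptr()

Accordingly, the _tied_weights_keys and _dynamic_tied_weights_keys lists drop the standalone "embed_tokens.*" entries, since there is no model-level embedding left to tie.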

src/transformers/models/t5gemma2/modular_t5gemma2.py

Lines changed: 13 additions & 26 deletions
@@ -672,7 +672,7 @@ class T5Gemma2Encoder(T5Gemma2PreTrainedModel):
     def __init__(
         self,
         config: T5Gemma2ModuleConfig,
-        shared_embedding: T5Gemma2TextScaledWordEmbedding,
+        eoi_token_index: int = 256_000,
         pixel2feature_preprocessor_fn: Optional[Callable] = None,
     ):
         super().__init__(config)
@@ -682,7 +682,13 @@ def __init__(
         # preprocessor for raw images pixel values: injected from outside.
         self.pixel2feature_preprocessor_fn = pixel2feature_preprocessor_fn

-        self.embed_tokens = shared_embedding
+        self.embed_tokens = T5Gemma2TextScaledWordEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            config.pad_token_id,
+            embed_scale=config.hidden_size**0.5,
+            eoi_token_index=eoi_token_index,
+        )
         self.norm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False

@@ -910,13 +916,11 @@ def forward(
 @auto_docstring
 class T5Gemma2Model(T5Gemma2PreTrainedModel):
     _tied_weights_keys = [
-        "embed_tokens.weight",
         "encoder.embed_tokens.weight",
         "decoder.embed_tokens.weight",
     ]
     _dynamic_tied_weights_keys = [
         "encoder.embed_tokens.eoi_embedding",
-        "embed_tokens.eoi_embedding",
         "decoder.embed_tokens.eoi_embedding",
     ]

@@ -938,18 +942,9 @@ def __init__(self, config: T5Gemma2Config):
                 f"encoder ({config.encoder.vocab_size}) vs decoder ({config.decoder.vocab_size})."
             )

-        # shared embedding
-        self.embed_tokens = T5Gemma2TextScaledWordEmbedding(
-            config.encoder.vocab_size,
-            config.encoder.hidden_size,
-            config.encoder.pad_token_id,
-            embed_scale=config.encoder.hidden_size**0.5,
-            eoi_token_index=config.eoi_token_index,
-        )
-
         # setup encoder and decoder
-        self.encoder = T5Gemma2Encoder(config.encoder, self.embed_tokens, self.pixel2feature_preprocessor)
-        self.decoder = T5Gemma2Decoder(config.decoder, self.embed_tokens)
+        self.encoder = T5Gemma2Encoder(config.encoder, config.eoi_token_index, self.pixel2feature_preprocessor)
+        self.decoder = T5Gemma2Decoder(config.decoder, config.eoi_token_index)

         # setup vision encoder
         self.vision_tower = AutoModel.from_config(config=config.vision_config)
@@ -972,10 +967,8 @@ def set_input_embeddings(self, new_embeddings):
     def _tie_weights(self):
         # Decoder input and output embeddings are tied.
         if self.config.tie_word_embeddings:
-            self.encoder.embed_tokens.weight = self.embed_tokens.weight
-            self.decoder.embed_tokens.weight = self.embed_tokens.weight
-            self.encoder.embed_tokens.eoi_embedding = self.embed_tokens.eoi_embedding
-            self.decoder.embed_tokens.eoi_embedding = self.embed_tokens.eoi_embedding
+            self.decoder.embed_tokens.weight = self.encoder.embed_tokens.weight
+            self.decoder.embed_tokens.eoi_embedding = self.encoder.embed_tokens.eoi_embedding

     def _get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
         """Convert pixel image to image features via the encoder and projector."""
@@ -1093,14 +1086,12 @@ def forward(

 class T5Gemma2ForConditionalGeneration(T5Gemma2PreTrainedModel, GenerationMixin):
     _tied_weights_keys = [
-        "model.embed_tokens.weight",
         "lm_head.out_proj.weight",
         "model.encoder.embed_tokens.weight",
         "model.decoder.embed_tokens.weight",
     ]
     _dynamic_tied_weights_keys = [
         "model.encoder.embed_tokens.eoi_embedding",
-        "model.embed_tokens.eoi_embedding",
         "model.decoder.embed_tokens.eoi_embedding",
     ]
     _tp_plan = {"lm_head.out_proj": "colwise_rep"}
@@ -1132,7 +1123,7 @@ def set_input_embeddings(self, value):
     def _tie_weights(self):
         # Decoder input and output embeddings are tied.
         if self.config.tie_word_embeddings:
-            self.lm_head.out_proj.weight = self.model.embed_tokens.weight
+            self.lm_head.out_proj.weight = self.model.encoder.embed_tokens.weight

     def get_encoder(self):
         return self.model.get_encoder()
@@ -1241,13 +1232,11 @@ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
 @auto_docstring
 class T5Gemma2ForSequenceClassification(T5Gemma2PreTrainedModel):
     _tied_weights_keys = [
-        "model.embed_tokens.weight",
         "model.encoder.embed_tokens.weight",
         "model.decoder.embed_tokens.weight",
     ]
     _dynamic_tied_weights_keys = [
         "model.encoder.embed_tokens.eoi_embedding",
-        "model.embed_tokens.eoi_embedding",
         "model.decoder.embed_tokens.eoi_embedding",
     ]

@@ -1350,13 +1339,11 @@ def forward(
 @auto_docstring
 class T5Gemma2ForTokenClassification(T5Gemma2PreTrainedModel):
     _tied_weights_keys = [
-        "model.embed_tokens.weight",
        "model.encoder.embed_tokens.weight",
         "model.decoder.embed_tokens.weight",
     ]
     _dynamic_tied_weights_keys = [
         "model.encoder.embed_tokens.eoi_embedding",
-        "model.embed_tokens.eoi_embedding",
         "model.decoder.embed_tokens.eoi_embedding",
     ]
