@@ -813,7 +813,7 @@ class T5Gemma2Encoder(T5Gemma2PreTrainedModel):
     def __init__(
         self,
         config: T5Gemma2ModuleConfig,
-        shared_embedding: T5Gemma2TextScaledWordEmbedding,
+        eoi_token_index: int = 256_000,
         pixel2feature_preprocessor_fn: Optional[Callable] = None,
     ):
         super().__init__(config)
@@ -823,7 +823,13 @@ def __init__(
         # preprocessor for raw images pixel values: injected from outside.
         self.pixel2feature_preprocessor_fn = pixel2feature_preprocessor_fn
 
-        self.embed_tokens = shared_embedding
+        self.embed_tokens = T5Gemma2TextScaledWordEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            config.pad_token_id,
+            embed_scale=config.hidden_size**0.5,
+            eoi_token_index=eoi_token_index,
+        )
         self.norm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False
 
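Note on the hunk above: the encoder now builds its own T5Gemma2TextScaledWordEmbedding from its module config, parameterized only by eoi_token_index, instead of receiving a pre-built shared module; sharing with the decoder is re-established afterwards through _tie_weights (see the hunks below). As a rough sketch of the behavior implied by the constructor arguments, a scaled word embedding with a dedicated end-of-image row could look like the following. The class name and the forward-pass ordering here are assumptions for illustration, not the actual implementation.

import torch
import torch.nn as nn

class ScaledWordEmbeddingSketch(nn.Embedding):
    # Hypothetical stand-in for T5Gemma2TextScaledWordEmbedding (assumption).
    def __init__(self, num_embeddings, embedding_dim, padding_idx,
                 embed_scale=1.0, eoi_token_index=256_000):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
        self.eoi_token_index = eoi_token_index
        # Separately tied end-of-image embedding; its shape is an assumption.
        self.eoi_embedding = nn.Parameter(torch.zeros(embedding_dim))

    def forward(self, input_ids):
        embeds = super().forward(input_ids)
        # Swap in the dedicated EOI row, then scale by sqrt(hidden_size)
        # (the ordering of these two steps is assumed).
        is_eoi = (input_ids == self.eoi_token_index).unsqueeze(-1)
        embeds = torch.where(is_eoi, self.eoi_embedding.to(embeds.dtype), embeds)
        return embeds * self.embed_scale.to(embeds.dtype)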
@@ -1051,13 +1057,11 @@ def forward(
 @auto_docstring
 class T5Gemma2Model(T5Gemma2PreTrainedModel):
     _tied_weights_keys = [
-        "embed_tokens.weight",
         "encoder.embed_tokens.weight",
         "decoder.embed_tokens.weight",
     ]
     _dynamic_tied_weights_keys = [
         "encoder.embed_tokens.eoi_embedding",
-        "embed_tokens.eoi_embedding",
         "decoder.embed_tokens.eoi_embedding",
     ]
 
@@ -1079,18 +1083,9 @@ def __init__(self, config: T5Gemma2Config):
                 f"encoder ({config.encoder.vocab_size}) vs decoder ({config.decoder.vocab_size})."
             )
 
-        # shared embedding
-        self.embed_tokens = T5Gemma2TextScaledWordEmbedding(
-            config.encoder.vocab_size,
-            config.encoder.hidden_size,
-            config.encoder.pad_token_id,
-            embed_scale=config.encoder.hidden_size**0.5,
-            eoi_token_index=config.eoi_token_index,
-        )
-
         # setup encoder and decoder
-        self.encoder = T5Gemma2Encoder(config.encoder, self.embed_tokens, self.pixel2feature_preprocessor)
-        self.decoder = T5Gemma2Decoder(config.decoder, self.embed_tokens)
+        self.encoder = T5Gemma2Encoder(config.encoder, config.eoi_token_index, self.pixel2feature_preprocessor)
+        self.decoder = T5Gemma2Decoder(config.decoder, config.eoi_token_index)
 
         # setup vision encoder
         self.vision_tower = AutoModel.from_config(config=config.vision_config)
@@ -1113,10 +1108,8 @@ def set_input_embeddings(self, new_embeddings):
     def _tie_weights(self):
         # Decoder input and output embeddings are tied.
         if self.config.tie_word_embeddings:
-            self.encoder.embed_tokens.weight = self.embed_tokens.weight
-            self.decoder.embed_tokens.weight = self.embed_tokens.weight
-            self.encoder.embed_tokens.eoi_embedding = self.embed_tokens.eoi_embedding
-            self.decoder.embed_tokens.eoi_embedding = self.embed_tokens.eoi_embedding
+            self.decoder.embed_tokens.weight = self.encoder.embed_tokens.weight
+            self.decoder.embed_tokens.eoi_embedding = self.encoder.embed_tokens.eoi_embedding
 
     def _get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
         """Convert pixel image to image features via the encoder and projector."""
@@ -1234,14 +1227,12 @@ def forward(
 
 class T5Gemma2ForConditionalGeneration(T5Gemma2PreTrainedModel, GenerationMixin):
     _tied_weights_keys = [
-        "model.embed_tokens.weight",
         "lm_head.out_proj.weight",
         "model.encoder.embed_tokens.weight",
         "model.decoder.embed_tokens.weight",
     ]
     _dynamic_tied_weights_keys = [
         "model.encoder.embed_tokens.eoi_embedding",
-        "model.embed_tokens.eoi_embedding",
         "model.decoder.embed_tokens.eoi_embedding",
     ]
     _tp_plan = {"lm_head.out_proj": "colwise_rep"}
@@ -1273,7 +1264,7 @@ def set_input_embeddings(self, value):
     def _tie_weights(self):
         # Decoder input and output embeddings are tied.
         if self.config.tie_word_embeddings:
-            self.lm_head.out_proj.weight = self.model.embed_tokens.weight
+            self.lm_head.out_proj.weight = self.model.encoder.embed_tokens.weight
 
     def get_encoder(self):
         return self.model.get_encoder()
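With the hunks above applied, the encoder's embedding table becomes the single source of truth: T5Gemma2Model._tie_weights points the decoder's weight and eoi_embedding at the encoder's, and T5Gemma2ForConditionalGeneration._tie_weights points lm_head.out_proj.weight at the same table. A quick sanity check that tying survives loading could look like this sketch; it assumes tie_word_embeddings=True and uses a placeholder checkpoint id.

from transformers import T5Gemma2ForConditionalGeneration

model = T5Gemma2ForConditionalGeneration.from_pretrained("<checkpoint-id>")  # placeholder id
enc = model.model.encoder.embed_tokens
dec = model.model.decoder.embed_tokens

# Encoder, decoder, and output projection should all alias one tensor.
assert enc.weight.data_ptr() == dec.weight.data_ptr()
assert model.lm_head.out_proj.weight.data_ptr() == enc.weight.data_ptr()

# The dynamically tied end-of-image embeddings should alias as well.
assert enc.eoi_embedding.data_ptr() == dec.eoi_embedding.data_ptr()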
@@ -1382,13 +1373,11 @@ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
 @auto_docstring
 class T5Gemma2ForSequenceClassification(T5Gemma2PreTrainedModel):
     _tied_weights_keys = [
-        "model.embed_tokens.weight",
         "model.encoder.embed_tokens.weight",
         "model.decoder.embed_tokens.weight",
     ]
     _dynamic_tied_weights_keys = [
         "model.encoder.embed_tokens.eoi_embedding",
-        "model.embed_tokens.eoi_embedding",
         "model.decoder.embed_tokens.eoi_embedding",
     ]
 
@@ -1491,13 +1480,11 @@ def forward(
 @auto_docstring
 class T5Gemma2ForTokenClassification(T5Gemma2PreTrainedModel):
     _tied_weights_keys = [
-        "model.embed_tokens.weight",
         "model.encoder.embed_tokens.weight",
         "model.decoder.embed_tokens.weight",
     ]
     _dynamic_tied_weights_keys = [
         "model.encoder.embed_tokens.eoi_embedding",
-        "model.embed_tokens.eoi_embedding",
         "model.decoder.embed_tokens.eoi_embedding",
     ]
 