Skip to content

Commit ce7a5e0

Browse files
authored
Correctly create tied key mapping in post_init, and dynamic tie weight (#42270)
* add dynamic
* improve
* doc
* true dynamic
* everywhere
* improve
* fix
* more
* small fix
* small fix
* fix duplicates
* fix
* doc
* fix
* improve doc
* comment
* more doc
* style
1 parent f15b95e commit ce7a5e0

File tree

12 files changed: 164 additions & 116 deletions

12 files changed: 164 additions & 116 deletions

src/transformers/modeling_utils.py

Lines changed: 147 additions & 100 deletions
Large diffs are not rendered by default.

src/transformers/models/esm/modeling_esm.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -732,8 +732,6 @@ def __init__(self, config):
732732
self.esm = EsmModel(config, add_pooling_layer=False)
733733
self.lm_head = EsmLMHead(config)
734734

735-
self.init_weights()
736-
737735
self.post_init()
738736

739737
def get_output_embeddings(self):
@@ -828,8 +826,6 @@ def __init__(self, config):
828826
self.esm = EsmModel(config, add_pooling_layer=False)
829827
self.classifier = EsmClassificationHead(config)
830828

831-
self.init_weights()
832-
833829
self.post_init()
834830

835831
@can_return_tuple
@@ -903,8 +899,6 @@ def __init__(self, config):
903899
self.dropout = nn.Dropout(config.hidden_dropout_prob)
904900
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
905901

906-
self.init_weights()
907-
908902
self.post_init()
909903

910904
@can_return_tuple

src/transformers/models/hubert/modeling_hubert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -993,7 +993,7 @@ def __init__(self, config, target_lang: Optional[str] = None):
993993
# Initialize weights and apply final processing
994994
self.post_init()
995995

996-
def tie_weights(self, missing_keys=None):
996+
def tie_weights(self, **kwargs):
997997
"""
998998
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
999999
passing `target_lang=...` to `from_pretrained(...)`.

src/transformers/models/idefics/modeling_idefics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1111,7 +1111,7 @@ def __init__(self, config, vision_model=None):
11111111
# Initialize weights and apply final processing
11121112
self.post_init()
11131113

1114-
def tie_weights(self, missing_keys=None):
1114+
def tie_weights(self, **kwargs):
11151115
"""
11161116
Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of
11171117
IdeficsDecoupledLinear and IdeficsDecoupledEmbedding.

src/transformers/models/sew/modeling_sew.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -857,7 +857,7 @@ def __init__(self, config, target_lang: Optional[str] = None):
857857
# Initialize weights and apply final processing
858858
self.post_init()
859859

860-
def tie_weights(self, missing_keys=None):
860+
def tie_weights(self, **kwargs):
861861
"""
862862
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
863863
passing `target_lang=...` to `from_pretrained(...)`.

src/transformers/models/sew_d/modeling_sew_d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1400,7 +1400,7 @@ def __init__(self, config, target_lang: Optional[str] = None):
14001400
# Initialize weights and apply final processing
14011401
self.post_init()
14021402

1403-
def tie_weights(self, missing_keys=None):
1403+
def tie_weights(self, **kwargs):
14041404
"""
14051405
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
14061406
passing `target_lang=...` to `from_pretrained(...)`.

src/transformers/models/unispeech/modeling_unispeech.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1210,7 +1210,7 @@ def __init__(self, config, target_lang: Optional[str] = None):
12101210
# Initialize weights and apply final processing
12111211
self.post_init()
12121212

1213-
def tie_weights(self, missing_keys=None):
1213+
def tie_weights(self, **kwargs):
12141214
"""
12151215
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
12161216
passing `target_lang=...` to `from_pretrained(...)`.

src/transformers/models/unispeech_sat/modeling_unispeech_sat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1206,7 +1206,7 @@ def __init__(self, config, target_lang: Optional[str] = None):
12061206
# Initialize weights and apply final processing
12071207
self.post_init()
12081208

1209-
def tie_weights(self, missing_keys=None):
1209+
def tie_weights(self, **kwargs):
12101210
"""
12111211
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
12121212
passing `target_lang=...` to `from_pretrained(...)`.

src/transformers/models/wav2vec2/modeling_wav2vec2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1684,7 +1684,7 @@ def __init__(self, config, target_lang: Optional[str] = None):
16841684
# Initialize weights and apply final processing
16851685
self.post_init()
16861686

1687-
def tie_weights(self, missing_keys=None):
1687+
def tie_weights(self, **kwargs):
16881688
"""
16891689
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
16901690
passing `target_lang=...` to `from_pretrained(...)`.

src/transformers/models/wavlm/modeling_wavlm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1135,7 +1135,7 @@ def __init__(self, config, target_lang: Optional[str] = None):
11351135
# Initialize weights and apply final processing
11361136
self.post_init()
11371137

1138-
def tie_weights(self, missing_keys=None):
1138+
def tie_weights(self, **kwargs):
11391139
"""
11401140
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
11411141
passing `target_lang=...` to `from_pretrained(...)`.

0 commit comments

Comments
 (0)