Correctly create tied key mapping in post_init, and dynamic tie weight #42270
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ArthurZucker left a comment
Missing a bit of representative doc! Let's take T5 as an example? Or RT-DETR? To have a complex list.
```python
for prefix, submodule in self.named_modules():
    if isinstance(submodule, PreTrainedModel):
        # Will dynamically check the config if it has changed
        submodel_tied_weights = submodule.get_expanded_tied_weights_keys(all_submodels=False)
```
don't know if we really have to go the inheritance path here?
given that we do `named_parameters` afterwards
Yes, in order to check the proper subconfig... No better way unfortunately, as sometimes we cannot get the subconfig in a proper way.
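A minimal sketch of the recursion pattern being discussed: iterate `named_modules`, let each `PreTrainedModel`-like submodule report its own tied keys, and re-prefix them so they are valid on the top-level model. The `TinyPreTrained` class and the `{target: source}` mapping shape here are illustrative stand-ins, not the actual transformers implementation:

```python
import torch.nn as nn

class TinyPreTrained(nn.Module):
    """Hypothetical stand-in for PreTrainedModel with a per-class tied-keys dict."""
    _tied_weights_keys = {}

    def get_expanded_tied_weights_keys(self, all_submodels=True):
        mapping = dict(self._tied_weights_keys)
        if all_submodels:
            for prefix, submodule in self.named_modules():
                if submodule is not self and isinstance(submodule, TinyPreTrained):
                    sub = submodule.get_expanded_tied_weights_keys(all_submodels=False)
                    # Re-prefix the submodel's keys so they name parameters of `self`
                    mapping.update({f"{prefix}.{t}": f"{prefix}.{s}" for t, s in sub.items()})
        return mapping

class Decoder(TinyPreTrained):
    _tied_weights_keys = {"lm_head.weight": "embed.weight"}
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.lm_head = nn.Linear(4, 10, bias=False)

class Wrapper(TinyPreTrained):
    def __init__(self):
        super().__init__()
        self.decoder = Decoder()

print(Wrapper().get_expanded_tied_weights_keys())
# {'decoder.lm_head.weight': 'decoder.embed.weight'}
```

This is why the `isinstance` check matters: only submodules that carry their own config/tied-keys declaration should contribute, and their keys must be translated into the top model's namespace.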
```python
source_name = "^" + source_name
target_name = "^" + target_name
# In this case, the keys stored in `all_tied_weights_keys` are already correct
if not recompute_mapping:
```
To update with setter and getter for `tie_word_embeddings`, no?
No, was already checked before!
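The `"^"`-anchoring shown in the diff above can be sketched like this: anchor each stored pattern so it only matches from the start of a parameter name, then expand it against `named_parameters` into concrete key pairs. The helper name and the one-source/many-targets convention are assumptions for illustration, not the actual transformers code:

```python
import re
import torch.nn as nn

def expand_tied_patterns(model, tied_patterns):
    """Expand {target_pattern: source_pattern} regexes into concrete
    parameter-name pairs by matching against the model's named_parameters."""
    param_names = [name for name, _ in model.named_parameters()]
    expanded = {}
    for target_name, source_name in tied_patterns.items():
        # Anchor so e.g. "decoder.embed" cannot match "my_decoder.embed"
        target_re = re.compile("^" + target_name)
        source_re = re.compile("^" + source_name)
        targets = [n for n in param_names if target_re.match(n)]
        sources = [n for n in param_names if source_re.match(n)]
        if len(sources) == 1:
            # One shared source tied to every matching target
            expanded.update({t: sources[0] for t in targets})
        else:
            expanded.update(dict(zip(targets, sources)))
    return expanded

model = nn.ModuleDict({
    "embed": nn.Embedding(8, 4),
    "heads": nn.ModuleList([nn.Linear(4, 8, bias=False) for _ in range(2)]),
})
print(expand_tied_patterns(model, {r"heads\.\d+\.weight": "embed.weight"}))
# {'heads.0.weight': 'embed.weight', 'heads.1.weight': 'embed.weight'}
```

Without the `"^"` anchor, `re.match` would still anchor at the string start, but `re.search`-style usage or substring patterns could silently tie the wrong parameters, which is the failure mode the anchoring guards against.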
[For maintainers] Suggested jobs to run (before merge): `run-slow: esm, hubert, idefics, openai, sew, sew_d, unispeech, unispeech_sat, wav2vec2, wavlm`
ArthurZucker left a comment
Thanks for iterating! I like that it's explicit now!
As we rely more and more on `self.all_tied_weight_keys` everywhere (i.e. the list of tied keys obtained during `post_init`) for multiple manipulations (device_map computation, CUDA warmup, post-processing of `from_pretrained`...), it becomes very important that the (few) models containing regex patterns for their `_tied_weights_keys` mapping have the patterns expanded to fit in `all_tied_weight_keys` as well, instead of containing simple patterns that are skipped in different ways for all downstream applications.

This PR fixes that, by expanding correctly at `post_init` time, so the mappings contain correct param names everywhere.

Also allows for recomputing this mapping in `tie_weights` dynamically, so that it is correct when calling `tie_weights` after having modified the config.
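The dynamic recomputation described above can be illustrated with a toy model. All names here (`ToyConfig`, `ToyModel`, the `recompute_mapping` flag) are hypothetical stand-ins for the real `PreTrainedModel.tie_weights` logic:

```python
import torch.nn as nn

class ToyConfig:
    def __init__(self, tie_word_embeddings=True):
        self.tie_word_embeddings = tie_word_embeddings

class ToyModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(16, 8)
        self.lm_head = nn.Linear(8, 16, bias=False)
        self.all_tied_weights_keys = {}
        self.tie_weights(recompute_mapping=True)

    def tie_weights(self, recompute_mapping=False):
        if recompute_mapping:
            # Re-read the config so the mapping stays correct after config edits
            if self.config.tie_word_embeddings:
                self.all_tied_weights_keys = {"lm_head.weight": "embed.weight"}
            else:
                self.all_tied_weights_keys = {}
        if "lm_head.weight" in self.all_tied_weights_keys:
            self.lm_head.weight = self.embed.weight

model = ToyModel(ToyConfig(tie_word_embeddings=True))
print(model.lm_head.weight is model.embed.weight)  # True

# Change the config, then re-tie: the mapping is recomputed, not reused stale
model.config.tie_word_embeddings = False
model.tie_weights(recompute_mapping=True)
print(model.all_tied_weights_keys)  # {}
```

The point of the flag is the second call: without recomputation, the mapping built at `post_init` time would still list `lm_head.weight` even though the config no longer asks for tying.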