Expose `weights_only` for loading checkpoints with `Trainer`, `LightningModule`, `LightningDataModule` #21072
base: master
Conversation
Codecov Report ❌ Your project check has failed because the head coverage (49%) is below the target coverage (99%). You can increase the head coverage or adjust the target coverage.
Additional details and impacted files

```
@@            Coverage Diff            @@
##           master   #21072     +/-   ##
=========================================
- Coverage      87%      49%     -39%
=========================================
  Files         269      266       -3
  Lines       23520    23468      -52
=========================================
- Hits        20503    11395    -9108
- Misses       3017    12073    +9056
```
Force-pushed "… based on ckpt version" from d7cb702 to 601e300
```python
@@ -56,11 +56,17 @@ def _load_from_checkpoint(
        map_location: _MAP_LOCATION_TYPE = None,
        hparams_file: Optional[_PATH] = None,
        strict: Optional[bool] = None,
        weights_only: Optional[bool] = None,
```
Should we default to `weights_only=None` or `weights_only=True`? If we have no use for `weights_only=None`, we can simplify the type hint to `weights_only: bool = True`.
```python
@@ -45,7 +46,12 @@ def test_load_legacy_checkpoints(tmp_path, pl_version: str):
    assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"'
    path_ckpt = path_ckpts[-1]

    model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24)
    # legacy load utility added in 1.5.0 (see https://github.com/Lightning-AI/pytorch-lightning/pull/9166)
    if pl_version == "local":
```
This is the simplest way that I could think of to ensure we continue testing the legacy checkpoints. Another way could be to use `torch.serialization.add_safe_globals`, but it seems a little more complicated (particularly since we're using the `pl_legacy_patch` context manager already).
Title changed: "`weights_only=True` by default" → "`weights_only=True` by default for loading weights"
@Borda I wanted to get your opinion on something before moving forward. I've added … My issue right now is with resuming training from a checkpoint with …
I'm leaning towards option 1, but it involves changing up …
The cleanest way would probably be 1), but it brings so many new arguments for a marginal use... so personally I would go with 2).
Hi @matsumotosan, let's do that only if the underlying torch is >= 2.6 (since that's the version where `weights_only` became True by default); otherwise we're going to break a lot of older code.
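A sketch of that gating, assuming a Lightning-style version flag (the flag spelling and the `resolve_weights_only` helper are illustrative, not the PR's code):

```python
import torch

# Illustrative version flag, mirroring Lightning's internal-style
# _TORCH_GREATER_EQUAL_2_6 constants; parses "2.6.0", "2.6.0a0+git..." etc.
_TORCH_GREATER_EQUAL_2_6 = tuple(
    int(part) for part in torch.__version__.split("+")[0].split(".")[:2]
) >= (2, 6)

def resolve_weights_only(weights_only):
    # Opt in to the safer default only where torch itself already defaults
    # to weights_only=True, so users on older torch see no behavior change.
    if weights_only is None and _TORCH_GREATER_EQUAL_2_6:
        return True
    return weights_only
```

An explicit `True`/`False` from the user always wins; only the unset (`None`) case is version-gated.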
* try `deepspeed >=0.14.1,<=0.15.0`
* drop from oldest
* pip uninstall -y deepspeed
* error::DeprecationWarning
I am not sure if it's possible to default to …

The big issue with context managers is that a different one has to be used each time a different checkpoint is loaded. Setting the environment variable …

With this in mind, I think passing … If we need to force the …

I have also added …
Maybe we could default add …
Title changed: "`weights_only=True` by default for loading weights" → "Expose `weights_only` for loading checkpoints with `Trainer`, `LightningModule`, `LightningDataModule`"
```python
@@ -51,6 +56,9 @@ def _load(
        weights_only=weights_only,
    )
    if str(path_or_url).startswith("http"):
        if weights_only is None:
```
Suggested change:

```diff
-        if weights_only is None:
+        if weights_only is None and _TORCH_GREATER_EQUAL_2_6:
```
`torch.load_state_dict_from_url` defaults to `weights_only=False` even for 2.8. Should we still default to `weights_only=True` in this case?
For cases of `torch.load`, I am not modifying `weights_only` if the function receives `weights_only=None`, as the underlying `torch.load` function will handle defaults. We cannot do that with `torch.load_state_dict_from_url`, as the argument is `weights_only: bool = False`.
What does this PR do?
Fixes #20450, #20058, #20643
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:
Reviewer checklist
📚 Documentation preview 📚: https://pytorch-lightning--21072.org.readthedocs.build/en/21072/