
TrainingEpochLoop._should_check_val_fx discrepancy between continued run <> restore from ckpt #14579

Closed
Anner-deJong opened this issue Sep 7, 2022 · 4 comments · Fixed by #20379
Labels
bug Something isn't working checkpointing Related to checkpointing help wanted Open to be worked on loops Related to the Loop API

Comments

@Anner-deJong
Contributor

Anner-deJong commented Sep 7, 2022

🐛 Bug

Found a discrepancy between a run continued after checkpointing and a run restored from that checkpoint.

Observation:

The training-batch / validation-loop ordering after restoring from a checkpoint is not the same as in the original run that continued past that checkpoint.

The restored run still performs the same number of training steps, but its validation loops are interleaved one step later, which can cause the restored run to end up with one fewer validation loop (see colab).

Assumption / expectation:

Zero difference between a training run that continues past a checkpoint and a run restored from said checkpoint

Investigation so far:

I'm new to some of this Lightning code, but IIUC:

Key:

TrainingEpochLoop's self.batch_progress.increment_completed() is called after the "on_train_batch_end" hooks, and those hooks kick off checkpoint saving.

  1. Upon restoring, TrainingEpochLoop.batch_progress.current.reset_on_restart() resets ready back to completed.
  2. Yet the global_step, which refers to TrainingEpochLoop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed, has already had increment_completed() called on it (inside TrainingEpochLoop.batch_loop.run), so upon restoring, ..optimizer.step.total.ready is set to the already up-to-date optimizer.step.total.completed, out of sync with the above.
  3. [simplification] In "val_check_interval mode", validation is triggered when TrainingEpochLoop.batch_progress.current.ready % val_check_interval == 0 (through TrainingEpochLoop.on_advance_end -> TrainingEpochLoop._should_check_val_fx).
  4. Combining the above three: the same batch_progress.current ready/completed counters in the continued and restored runs end up aligned with different global_steps, and hence validation triggers at different global_steps (see the illustrative trace after this list).
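
To make the misalignment concrete, here is a rough illustrative trace (based purely on the reading above; the attribute names and values are my assumptions, not verified output):

# Checkpoint saved inside on_train_batch_end of batch N:
#   batch_progress.current:              ready=N, completed=N-1  # completed is only incremented after the hook
#   optim_progress.optimizer.step.total: completed=N             # already incremented inside batch_loop.run
#
# Continued run, once batch N fully finishes:
#   batch_progress.current:              ready=N, completed=N    # global_step == N
#
# Restored run:
#   reset_on_restart() pulls current.ready back to completed (N-1), while global_step stays at N,
#   so the next time current.ready hits a multiple of val_check_interval, global_step is one higher
#   than at the corresponding point of the continued run -> validation fires one training step later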

Another observation:

The following if statement seems intended to allow for a zero-difference restart, except that, just like 4. above, _should_check_val_fx won't trigger at the step where it did in the original run on the checkpointing step (there it is called in on_advance_end). I'm not sure whether the original intention of this snippet covered the current scenario.

class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
    ...
    def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
        ...
        if self.restarting and self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch):
            # skip training and run validation in `on_advance_end`
            return

PRs relevant to this line:

Potential impact:

Presumably not too worrisome for the more typical Lightning use cases:

  • With val_check_interval >> 3 (the colab example uses 3), or with it turned off in favor of check_val_every_n_epoch

However, in theory it can influence all of the following:

  • no 1:1 deterministic reproducibility
  • it can affect the latest/best validation loss
    • which in turn affects any code flow / decision making based on it
  • it causes a "different usage order" of RNGs (<- how I initially caught the issue: even with correctly restored RNG states, if both the validation and training steps draw from the same RNG, they each end up with different random numbers compared to the continued run; see the small example after this list)
  • other
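
A tiny standalone illustration of that RNG point (nothing Lightning-specific; just one shared generator consumed in a different train/val interleaving):

import torch

def consume(order):
    # one shared RNG; each phase draws the next number when it runs
    torch.manual_seed(0)
    return [(phase, torch.rand(1).item()) for phase in order]

continued = consume(["train", "train", "val", "train"])
restored = consume(["train", "train", "train", "val"])
# The overall stream of random numbers is identical, but because validation is interleaved one
# step later, each phase receives different numbers, so per-phase results diverge even with
# correctly restored RNG states.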

To Reproduce

Customized Google Colab bug_report_model.ipynb showing the same observation on the BoringModel
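
For convenience, a minimal local approximation of that setup (my own sketch, not the notebook's code; the BoringModel import path, checkpoint cadence, and checkpoint filename below are assumptions):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel

def run(ckpt_path=None):
    model = BoringModel()
    # save a checkpoint every few training steps so the run can be restarted mid-epoch
    ckpt_cb = ModelCheckpoint(dirpath="ckpts", save_top_k=-1, every_n_train_steps=5)
    trainer = pl.Trainer(
        max_steps=20,
        val_check_interval=3,  # validate every 3 training batches, as in the colab
        limit_val_batches=2,
        callbacks=[ckpt_cb],
        enable_progress_bar=False,
    )
    trainer.fit(model, ckpt_path=ckpt_path)

run()                                       # 1) continued run: note the global_steps at which validation runs
run(ckpt_path="ckpts/epoch=0-step=5.ckpt")  # 2) restored run: compare where validation runs now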

Expected behavior

Zero difference between a training run that continues past a checkpoint and a run restored from said checkpoint

Environment

Note:

  • The details below are from the original investigation in our own code base, with pytorch-lightning v1.6.4.
  • The environment details from the BoringModel reproduction are listed in the colab, with pytorch-lightning v1.7.4.
  • I also browsed through the master branch in recent weeks and the relevant code seems unchanged.
Details
  • CUDA:
    • GPU:
      • NVIDIA RTX A4000
      • NVIDIA RTX A4000
      • NVIDIA RTX A4000
      • NVIDIA RTX A4000
    • available: True
    • version: 11.0
  • Lightning:
    • efficientnet-pytorch: 0.7.1
    • pytorch-lightning: 1.6.4
    • torch: 1.11.0.post1103
    • torchmetrics: 0.7.0
    • torchvision: 0.12.0a1110.post1103
  • Packages:
    • absl-py: 0.15.0
    • adal: 1.2.7
    • adlfs: 2021.10.0
    • aiohttp: 3.7.4
    • applicationinsights: 0.11.10
    • argcomplete: 1.12.3
    • async-timeout: 3.0.1
    • attrdict: 2.0.0
    • attrs: 21.1.0
    • av: 8.0.3
    • azure-cli-core: 2.38.0
    • azure-cli-telemetry: 1.0.6
    • azure-common: 1.1.27
    • azure-core: 1.20.0
    • azure-datalake-store: 0.0.52
    • azure-identity: 1.10.0
    • azure-keyvault-secrets: 4.2.0
    • azure-mgmt-core: 1.2.2
    • azure-storage-blob: 12.11.0
    • backcall: 0.2.0
    • backoff: 1.10.0
    • bcrypt: 3.2.0
    • cachetools: 4.2.2
    • certifi: 2020.12.5
    • cffi: 1.14.5
    • chardet: 3.0.4
    • charset-normalizer: 2.0.12
    • click: 7.1.2
    • confluent-kafka: 1.7.0
    • cryptography: 3.4.8
    • cycler: 0.10.0
    • datadog: 0.44.0
    • decorator: 5.0.7
    • deepdiff: 5.5.0
    • deltalake: 0.5.8
    • docker-pycreds: 0.4.0
    • efficientnet-pytorch: 0.7.1
    • einops: 0.4.1
    • filelock: 3.7.1
    • fonttools: 4.37.1
    • frozendict: 2.3.2
    • fsspec: 2022.1.0
    • gitdb: 4.0.7
    • gitpython: 3.1.14
    • google-auth: 1.30.0
    • google-auth-oauthlib: 0.4.4
    • grpcio: 1.37.1
    • htmlmin: 0.1.12
    • humanfriendly: 10.0
    • idna: 2.10
    • imagehash: 4.2.1
    • inplace-abn: 1.1.0a1110.post1103
    • ipdb: 0.13.9
    • ipython: 7.23.1
    • isodate: 0.6.0
    • jedi: 0.18.0
    • jinja2: 3.1.2
    • jmespath: 0.10.0
    • joblib: 1.0.1
    • kafka-python: 2.0.2
    • kiwisolver: 1.3.1
    • knack: 0.9.0
    • markdown: 3.3.4
    • markupsafe: 2.0.1
    • matplotlib: 3.5.3
    • matplotlib-inline: 0.1.2
    • methodtools: 0.1.2
    • missingno: 0.5.0
    • msal: 1.16.0
    • msal-extensions: 0.3.0
    • msrest: 0.6.21
    • msrestazure: 0.6.4
    • multidict: 5.1.0
    • multimethod: 1.6
    • networkx: 2.5.1
    • numpy: 1.22.4
    • oauthlib: 3.1.0
    • opencv-python: 4.4.0.44
    • ordered-set: 4.0.2
    • packaging: 21.3
    • pandas: 1.4.3
    • pandas-profiling: 3.1.0
    • paramiko: 2.7.2
    • parso: 0.8.2
    • pathtools: 0.1.2
    • pexpect: 4.8.0
    • phik: 0.12.0
    • pickleshare: 0.7.5
    • pillow: 9.2.0
    • pip: 22.0.3
    • pkginfo: 1.7.0
    • polyline: 1.4.0
    • portalocker: 1.7.1
    • prometheus-client: 0.8.0
    • promise: 2.3
    • prompt-toolkit: 2.0.10
    • protobuf: 3.15.8
    • psutil: 5.9.1
    • psycopg2: 2.8.3
    • ptyprocess: 0.7.0
    • py: 1.10.0
    • py3nvml: 0.2.7
    • pyarrow: 9.0.0
    • pyasn1: 0.4.8
    • pyasn1-modules: 0.2.8
    • pycparser: 2.20
    • pydantic: 1.8.2
    • pydeprecate: 0.3.1
    • pygame: 2.1.2
    • pygments: 2.9.0
    • pyjwt: 1.7.1
    • pynacl: 1.4.0
    • pyntcloud: 0.1.6
    • pyopenssl: 20.0.1
    • pyparsing: 2.4.7
    • pyquaternion: 0.9.9
    • pysocks: 1.7.1
    • python-dateutil: 2.8.2
    • python-json-logger: 2.0.2
    • pytorch-lightning: 1.6.4
    • pytz: 2022.1
    • pywavelets: 1.1.1
    • pyyaml: 6.0
    • qrcode: 6.1
    • requests: 2.27.1
    • requests-oauthlib: 1.3.0
    • retry: 0.9.2
    • rsa: 4.7.2
    • runai: 0.3.0
    • scipy: 1.6.2
    • seaborn: 0.11.2
    • semver: 2.13.0
    • sentry-sdk: 1.9.4
    • setproctitle: 1.2.2
    • setuptools: 59.5.0
    • shapely: 1.8.0
    • shortuuid: 1.0.1
    • simplejpeg: 1.4.1
    • six: 1.16.0
    • slackclient: 2.9.4
    • smmap: 4.0.0
    • sqlalchemy: 1.3.24
    • tabulate: 0.8.9
    • tangled-up-in-unicode: 0.1.0
    • tensorboard: 2.6.0
    • tensorboard-data-server: 0.6.1
    • tensorboard-plugin-wit: 1.8.0
    • timm: 0.4.5
    • toml: 0.10.2
    • torch: 1.11.0.post1103
    • torchmetrics: 0.7.0
    • torchvision: 0.12.0a1110.post1103
    • tqdm: 4.60.0
    • traitlets: 5.3.0
    • transforms3d: 0.3.1
    • typing-extensions: 4.1.1
    • urllib3: 1.26.11
    • visions: 0.7.4
    • wandb: 0.12.14
    • wcwidth: 0.2.5
    • werkzeug: 1.0.1
    • wheel: 0.36.2
    • wirerope: 0.3.1
    • wrapt: 1.14.1
    • xmltodict: 0.12.0
    • xxhash: 1.4.1
    • yarl: 1.6.3
  • System:

Additional context

cc @awaelchli @ananthsub @ninginthecloud @rohitgr7 @otaj @carmocca @justusschock

@Anner-deJong Anner-deJong added the needs triage Waiting to be triaged by maintainers label Sep 7, 2022
@Anner-deJong
Contributor Author

I'd be happy to help fix this if it's agreed to be a bug, but I'd like the opinion + take on a solution from folks more involved with PyTorch Lightning.
Some simple ideas, though none of them great:

  • Give the model checkpoint callback a special status compared to other callbacks, and call it at the end of training_epoch_loop.on_advance_end
    • Sounds like this really shouldn't be required + it makes the code fragile / harder to maintain
  • Have a relevant hook called at the beginning of training_epoch_loop.on_advance_end that checkpoints
    • Why would we need yet another hook? And if we move one, I assume any hook's location in the code flow was well thought out already
  • Call on_train_batch_end after the increment_completed
    • Definition-based: does on_train_batch_end imply it won't change the training module anymore and could be run after the increment? Is there logic in on_train_batch_end that expects "completed" not to be incremented yet?

@krshrimali krshrimali added checkpointing Related to checkpointing bug Something isn't working and removed needs triage Waiting to be triaged by maintainers labels Sep 7, 2022
@krshrimali
Contributor

Hi, @Anner-deJong - Thank you for creating this issue, and helping with all the context around this. Probably @awaelchli or @carmocca can help you on this one.

@rohitgr7 rohitgr7 added the loops Related to the Loop API label Sep 7, 2022
@carmocca carmocca self-assigned this Sep 7, 2022
@awaelchli awaelchli added the help wanted Open to be worked on label Dec 31, 2023
@carmocca carmocca removed their assignment Jul 30, 2024
@lantiga
Collaborator

lantiga commented Oct 30, 2024

Getting back to this issue, as it came up recently in a different context. This is definitely a behavior that needs fixing.

My take would be to not add an extra hook, but invoke increment_completed prior to on_train_batch_end. Which one comes before is very debatable from a definition standpoint. Having increments already set up correctly when I'm in my on_train_batch_end hook is a fair expectation IMO.

The sequence would go from:

self.batch_progress.increment_ready()
on_train_batch_start
self.batch_progress.increment_started()
...
self.batch_progress.increment_processed()
...
on_train_batch_end
self.batch_progress.increment_completed()

to

self.batch_progress.increment_ready()
on_train_batch_start
self.batch_progress.increment_started()
...
self.batch_progress.increment_processed()
...
self.batch_progress.increment_completed()
on_train_batch_end
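
For concreteness, a minimal sketch of where that reordering would land inside TrainingEpochLoop.advance (the hook-dispatch helpers below are abbreviated/approximate, not the exact internal call sites):

...
self.batch_progress.increment_processed()
...
# proposed order: mark the batch completed *before* the hooks run, so a checkpoint written from
# on_train_batch_end (and any run restored from it) sees counters consistent with the
# already-incremented optimizer step / global_step
self.batch_progress.increment_completed()
call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
call._call_lightning_module_hook(trainer, "on_train_batch_end", batch_output, batch, batch_idx)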

Not sure what to think about the started part, but that's for another time.

@lantiga
Collaborator

lantiga commented Oct 31, 2024

On further thought, I'll take a more conservative approach since the implications are pretty wide.

@Borda Borda changed the title TrainingEpochLoop._should_check_val_fx discrepancy between continued run <> restore from ckpt Nov 12, 2024