Custom callback to stop training in DDP. #13834
-
I'm looking to create a callback that, given some external "signal", will stop training & save a checkpoint so that training can be resumed later. Here is the code so far. The "signal" in this case is a file that I create.

```python
import os
from pathlib import Path

from pytorch_lightning.callbacks import Callback


class SignalStopCallback(Callback):
    '''Given signal, will stop training'''

    def __init__(self, signal_fpath: Path, chkpoint_dir: Path):
        super().__init__()
        self.signal_fpath = signal_fpath
        self.chkpoint_dir = chkpoint_dir

    def on_train_epoch_end(self, trainer, pl_module):
        if self.signal_fpath.exists():
            print(f'Signal stop found, stopping at epoch {trainer.current_epoch}')
            trainer.save_checkpoint(self.chkpoint_dir.joinpath('last.ckpt'))
            trainer.should_stop = True
            os.remove(self.signal_fpath)
```

I believe this should work in a single-GPU training setting. However, in DDP (one node, 4 GPUs), I believe this will cause undefined behaviour: once the file is removed by one process, the other processes may no longer see it. I could remedy this by having the callback not delete the file and deleting it manually myself, but is there a simple way to get the desired behaviour here? For example, is there a way to set `trainer.should_stop` on all processes at once?

Thanks
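For illustration, here is a minimal sketch of one possible DDP-safe variant (the class name is hypothetical, not from the thread): only rank 0 checks for and removes the signal file, and its decision is shared with the other ranks via `trainer.strategy.broadcast`, so every process sets `trainer.should_stop` consistently.

```python
import os
from pathlib import Path

from pytorch_lightning.callbacks import Callback


class BroadcastSignalStopCallback(Callback):
    '''Rank 0 checks the signal file; its decision is broadcast to every rank.'''

    def __init__(self, signal_fpath: Path, chkpoint_dir: Path):
        super().__init__()
        self.signal_fpath = signal_fpath
        self.chkpoint_dir = chkpoint_dir

    def on_train_epoch_end(self, trainer, pl_module):
        # Only rank 0 inspects (and later removes) the file, so there is no race
        # between processes over a file one of them has already deleted.
        should_stop = trainer.is_global_zero and self.signal_fpath.exists()
        # Share rank 0's decision so every process sets the same flag.
        should_stop = trainer.strategy.broadcast(should_stop, 0)
        if should_stop:
            print(f'Signal stop found, stopping at epoch {trainer.current_epoch}')
            # Called on every rank; with DDP, Lightning writes the file from rank 0 only.
            trainer.save_checkpoint(self.chkpoint_dir.joinpath('last.ckpt'))
            trainer.should_stop = True
            if trainer.is_global_zero:
                os.remove(self.signal_fpath)
```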
-
Hi @m-lyon,

Use `trainer.strategy.barrier()` to ensure all processes are at the same line. It is strategy-agnostic, so when you're using a single device, for example, it will just be a no-op.

https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.strategies.Strategy.html#pytorch_lightning.strategies.Strategy.barrier
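For context, a rough sketch of how the barrier could slot into the callback from the question (untested here, and it assumes the signal file is created well before the epoch ends so that every rank sees it): each rank checks for the file, the barrier guarantees all ranks have finished that check, and only the global-zero rank then deletes the file.

```python
import os
from pathlib import Path

from pytorch_lightning.callbacks import Callback


class BarrierSignalStopCallback(Callback):
    '''Every rank checks the file; a barrier orders the check before the deletion.'''

    def __init__(self, signal_fpath: Path, chkpoint_dir: Path):
        super().__init__()
        self.signal_fpath = signal_fpath
        self.chkpoint_dir = chkpoint_dir

    def on_train_epoch_end(self, trainer, pl_module):
        if self.signal_fpath.exists():
            print(f'Signal stop found, stopping at epoch {trainer.current_epoch}')
            trainer.save_checkpoint(self.chkpoint_dir.joinpath('last.ckpt'))
            trainer.should_stop = True
        # No rank deletes the file until every rank has finished the check above.
        trainer.strategy.barrier()
        if trainer.is_global_zero and self.signal_fpath.exists():
            os.remove(self.signal_fpath)
```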
-
If anyone is interested, I made the freak package, which can be used for this purpose quite nicely. Here is a callback that lets you stop training by making a REST API call to the rank 0 worker (port 4444 by default):

```python
from typing import Optional

import lightning as L
from freak import Freak
from lightning.pytorch.utilities import rank_zero_only
from pydantic import BaseModel


class TrainingState(BaseModel):
    should_stop: bool = False


class TrainingStopCallback(L.Callback):
    """
    Callback which stops training when self.state.should_stop is set to True.
    """

    def __init__(self, freak: Optional[Freak] = None, state: Optional[TrainingState] = None):
        self.freak = freak if freak is not None else Freak(host="127.0.0.1")
        self.state = state if state is not None else TrainingState()

    @rank_zero_only
    def on_train_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        self.freak.control(self.state)  # launch the Freak server in a background thread

    def on_train_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        self.state = trainer.strategy.broadcast(self.state, 0)
        if self.state.should_stop:  # call the Freak API to set this to True
            # this triggers Lightning to stop training
            trainer.should_stop = True

    @rank_zero_only
    def on_train_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        self.freak.stop()
```
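A usage sketch (the module and datamodule names below are hypothetical placeholders, not from the thread): the callback is attached like any other Lightning callback, and the run can then be stopped remotely through the Freak REST interface on the rank-0 machine.

```python
import lightning as L

# `MyLightningModule` and `MyDataModule` are hypothetical stand-ins for your own code.
model = MyLightningModule()
datamodule = MyDataModule()

trainer = L.Trainer(
    max_epochs=100,
    devices=4,
    strategy="ddp",
    # The Freak server is started on rank 0 in on_train_start and stopped in on_train_end.
    callbacks=[TrainingStopCallback()],
)
trainer.fit(model, datamodule=datamodule)

# While this runs, an HTTP request to the Freak server on the rank-0 machine
# (port 4444 by default, per the description above) can flip TrainingState.should_stop
# to True; the updated state is broadcast at the end of the current epoch and all
# ranks stop training.
```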