
[Feature Request] EarlyStopping for torchrl.trainers.Trainer #2370

Open
1 of 4 tasks
jkrude opened this issue Aug 6, 2024 · 0 comments
Labels: enhancement (New feature or request)

jkrude commented Aug 6, 2024

Motivation

Often we only want to train an algorithm until it has learned the intended behavior, and a total number of frames is only a proxy for that stopping condition.
I would like to propose adding a new callback that makes early stopping possible, much like the EarlyStopping callback in Lightning.

Currently, no callback can influence the end of the training loop, so my current workaround is to set self.total_frames to 1, which isn't great.
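
For context, that workaround in code could look roughly like the sketch below; solved() stands in for a user-defined success criterion, and the registration line is only illustrative (the exact signature expected for post-steps ops may differ):

def make_early_stop_hack(trainer, solved):
    # Rough sketch of the current workaround, not proposed API: train() stops
    # once self.collected_frames >= self.total_frames, so setting total_frames
    # to 1 from a registered hook ends training at the next check.
    def _post_steps():
        if solved():  # user-defined stopping criterion
            trainer.total_frames = 1
    return _post_steps

# e.g. trainer.register_op("post_steps", make_early_stop_hack(trainer, solved))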

Solution

  • Adding a new early-stopping hook
  • Creating a default hook implementation based on TrainerHookBase
  • Stopping the training loop if the hook returns a positive signal

The main question for me is where to call the hook and what it can observe.
I would argue that both the batch and the losses_detached would be relevant signals for the stopping hook, as one might want to stop when a certain return has been seen n times or the loss drops below a threshold.
Therefore, the call could either happen in optim_steps(self, batch: TensorDictBase) -> None, with that function returning the stopping signal to train(self), or the hook could be called in train, with optim_steps returning the losses so that the early-stopping hook can observe them.
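
A rough sketch of the first approach, where optim_steps itself calls the hook and hands only the verdict back to train (the changed return type is part of the proposal, not the current API):

def optim_steps(self, batch: TensorDictBase) -> bool:
    average_losses = None
    # ... existing optimization logic, accumulating the detached losses ...
    # The hook is called here, so train() only receives a boolean stop signal.
    return bool(self.early_stopping_hook(batch, average_losses))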

An example of the second approach could look something like this:

def train(self):
    ...
    for batch in self.collector:
        ...
        average_losses = None
        if self.collected_frames > self.collector.init_random_frames:
            average_losses = self.optim_steps(batch)

        if self.early_stopping_hook(batch, average_losses):
            self.save_trainer(force_save=True)
            break

        self._post_steps_hook()
        ...
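
The hook itself could look roughly like the sketch below, assuming an interface in the spirit of TrainerHookBase (__call__, state_dict, load_state_dict); the name EarlyStoppingHook and the loss_threshold / patience parameters are purely illustrative and not part of the existing API:

from typing import Optional

from tensordict import TensorDictBase


class EarlyStoppingHook:
    """Illustrative hook: signal stop once the mean loss stays below a
    threshold for `patience` consecutive optimization steps."""

    def __init__(self, loss_threshold: float, patience: int = 5):
        self.loss_threshold = loss_threshold
        self.patience = patience
        self._hits = 0

    def __call__(
        self, batch: TensorDictBase, average_losses: Optional[TensorDictBase]
    ) -> bool:
        if average_losses is None:
            # No optimization has happened yet (e.g. still collecting random frames).
            return False
        # Average all "loss_*" entries of the detached loss TensorDict.
        losses = [
            average_losses[key].item()
            for key in average_losses.keys()
            if key.startswith("loss")
        ]
        mean_loss = sum(losses) / max(len(losses), 1)
        self._hits = self._hits + 1 if mean_loss < self.loss_threshold else 0
        return self._hits >= self.patience

    def state_dict(self) -> dict:
        return {"hits": self._hits}

    def load_state_dict(self, state_dict: dict) -> None:
        self._hits = state_dict["hits"]

A return-based criterion (e.g. stop once the episode reward has exceeded a target n times) would follow the same pattern, reading from batch instead of average_losses.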

I am happy to open a PR if this is something of interest to you.

Checklist

  • I have checked that there is no similar issue in the repo (required)
jkrude added the enhancement (New feature or request) label on Aug 6, 2024