
Commit

Merge pull request #1302 from bghira/feature/prodigy-lr-scheduler
Prodigy may work best with a cosine scheduler over long training runs
bghira authored Jan 26, 2025
2 parents d67aabe + bb69f42 commit 9165da4
Showing 2 changed files with 50 additions and 44 deletions.
25 changes: 19 additions & 6 deletions helpers/training/optimizer_param.py
@@ -126,6 +126,7 @@
"adamw_schedulefree": {
"precision": "any",
"override_lr_scheduler": True,
"is_schedulefree": True,
"can_warmup": True,
"default_settings": {
"betas": (0.9, 0.999),
@@ -137,6 +138,7 @@
"adamw_schedulefree+aggressive": {
"precision": "any",
"override_lr_scheduler": True,
"is_schedulefree": True,
"can_warmup": True,
"default_settings": {
"betas": (0.9, 0.999),
@@ -148,6 +150,7 @@
"adamw_schedulefree+no_kahan": {
"precision": "any",
"override_lr_scheduler": True,
"is_schedulefree": True,
"can_warmup": True,
"default_settings": {
"betas": (0.9, 0.999),
@@ -473,7 +476,8 @@
{
"prodigy": {
"precision": "any",
"override_lr_scheduler": True,
"override_lr_scheduler": False,
"is_schedulefree": True,
"can_warmup": False,
"default_settings": {
"lr": 1.0,
@@ -562,10 +566,7 @@ def optimizer_parameters(optimizer, args):
if args.optimizer == "prodigy":
prodigy_steps = args.prodigy_steps
if prodigy_steps and prodigy_steps > 0:
optimizer_params["prodigy_steps"] = prodigy_steps
else:
# 25% of the total number of steps
optimizer_params["prodigy_steps"] = int(args.max_train_steps * 0.25)
optimizer_params["prodigy_steps"] = int(prodigy_steps)
print(
f"Using Prodigy optimiser with {optimizer_params['prodigy_steps']} steps of learning rate adjustment."
)
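
With the fallback removed, prodigy_steps is only forwarded when the user sets a positive value; otherwise the key is left unset and the optimiser's own default applies. A minimal standalone sketch of that behaviour (the helper name and dict layout here are illustrative, not from the repo):

def resolve_prodigy_steps(prodigy_steps, optimizer_params):
    # Forward only an explicit, positive value; no more 25%-of-max_train_steps fallback.
    if prodigy_steps and prodigy_steps > 0:
        optimizer_params["prodigy_steps"] = int(prodigy_steps)
    return optimizer_params

print(resolve_prodigy_steps(None, {}))  # {} -> the optimiser's own default is used
print(resolve_prodigy_steps(500, {}))   # {'prodigy_steps': 500}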
@@ -582,7 +583,19 @@ def is_lr_scheduler_disabled(optimizer: str):
"override_lr_scheduler", False
)
return is_disabled

def is_lr_schedulefree(optimizer: str):
"""
Check if the optimizer has ScheduleFree logic.
This is separate from the disabling of LR schedulers, because some optimizers
that contain ScheduleFree logic (Prodigy) can use an LR scheduler.
"""
is_schedulefree = False
if optimizer in optimizer_choices:
is_schedulefree = optimizer_choices.get(optimizer).get(
"is_schedulefree", False
)
return is_schedulefree

def show_optimizer_defaults(optimizer: str = None):
"""we'll print the defaults on a single line, eg. foo=bar, buz=baz"""
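A short sketch of how the two predicates now differ, assuming they are imported from helpers/training/optimizer_param.py and the optimizer_choices entries shown in this diff:

from helpers.training.optimizer_param import (
    is_lr_scheduler_disabled,
    is_lr_schedulefree,
)

# Prodigy is ScheduleFree-style but no longer disables the LR scheduler,
# while the adamw_schedulefree variants remain both.
for name in ("prodigy", "adamw_schedulefree"):
    print(name, is_lr_schedulefree(name), is_lr_scheduler_disabled(name))
# Expected, per the table above:
#   prodigy True False
#   adamw_schedulefree True True
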
69 changes: 31 additions & 38 deletions helpers/training/trainer.py
@@ -35,6 +35,7 @@
determine_optimizer_class_with_config,
determine_params_to_optimize,
is_lr_scheduler_disabled,
is_lr_schedulefree,
cpu_offload_optimizer,
)
from helpers.data_backend.factory import BatchFetcher
@@ -1231,23 +1232,14 @@ def init_optimizer(self):
)

def init_lr_scheduler(self):
self.config.is_schedulefree = is_lr_scheduler_disabled(self.config.optimizer)
self.config.is_schedulefree = is_lr_schedulefree(self.config.optimizer)
self.config.is_lr_scheduler_disabled = is_lr_scheduler_disabled(self.config.optimizer) or self.config.use_deepspeed_scheduler
if self.config.is_schedulefree:
logger.info("Using experimental ScheduleFree optimiser..")
if self.config.is_lr_scheduler_disabled:
# we don't use LR schedulers with schedulefree optimisers
logger.info("Optimiser cannot use an LR scheduler, so we are disabling it.")
lr_scheduler = None
if not self.config.use_deepspeed_scheduler and not self.config.is_schedulefree:
logger.info(
f"Loading {self.config.lr_scheduler} learning rate scheduler with {self.config.lr_warmup_steps} warmup steps"
)
lr_scheduler = get_lr_scheduler(
self.config,
self.optimizer,
self.accelerator,
logger,
use_deepspeed_scheduler=False,
)
else:
logger.info(f"Using dummy learning rate scheduler")
if torch.backends.mps.is_available():
lr_scheduler = None
@@ -1257,13 +1249,24 @@
total_num_steps=self.config.max_train_steps,
warmup_num_steps=self.config.lr_warmup_steps,
)
if lr_scheduler is not None:
if hasattr(lr_scheduler, "num_update_steps_per_epoch"):
lr_scheduler.num_update_steps_per_epoch = (
self.config.num_update_steps_per_epoch
)
if hasattr(lr_scheduler, "last_step"):
lr_scheduler.last_step = self.state.get("global_resume_step", 0)
return lr_scheduler

logger.info(
f"Loading {self.config.lr_scheduler} learning rate scheduler with {self.config.lr_warmup_steps} warmup steps"
)
lr_scheduler = get_lr_scheduler(
self.config,
self.optimizer,
self.accelerator,
logger,
use_deepspeed_scheduler=False,
)
if hasattr(lr_scheduler, "num_update_steps_per_epoch"):
lr_scheduler.num_update_steps_per_epoch = (
self.config.num_update_steps_per_epoch
)
if hasattr(lr_scheduler, "last_step"):
lr_scheduler.last_step = self.state.get("global_resume_step", 0)

return lr_scheduler
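
Outside the trainer, the idea behind this merge (per the commit message) can be sketched with plain PyTorch: Prodigy adapts its step size while a cosine schedule decays the base LR over a long run. The Prodigy import below is an assumption about the installed package, not this repo's wiring.

import torch

model = torch.nn.Linear(16, 16)
try:
    from prodigyopt import Prodigy  # assumed package; the repo may use its own variant
    optimizer = Prodigy(model.parameters(), lr=1.0)  # lr=1.0 as in the default_settings above
except ImportError:
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

max_train_steps = 10_000
# Cosine decay layered on top of the optimiser's own adaptation, stepped once per update.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_train_steps)

loss = model(torch.randn(4, 16)).sum()
loss.backward()
optimizer.step()
lr_scheduler.step()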

@@ -1859,17 +1862,19 @@ def move_models(self, destination: str = "accelerator"):
)

def mark_optimizer_train(self):
if is_lr_scheduler_disabled(self.config.optimizer) and hasattr(
if is_lr_schedulefree(self.config.optimizer) and hasattr(
self.optimizer, "train"
):
# we typically have to call train() on the optim for schedulefree.
logger.debug("Setting optimiser into train() mode.")
self.optimizer.train()

def mark_optimizer_eval(self):
if is_lr_scheduler_disabled(self.config.optimizer) and hasattr(
if is_lr_schedulefree(self.config.optimizer) and hasattr(
self.optimizer, "eval"
):
# we typically have to call eval() on the optim for schedulefree before saving or running validations.
logger.debug("Setting optimiser into eval() mode.")
self.optimizer.eval()
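
Both helpers implement the usual ScheduleFree pattern: put the optimiser into eval() before validation or checkpointing and back into train() afterwards. A generic sketch of that pattern, with illustrative names:

def with_optimizer_eval(optimizer, fn):
    # ScheduleFree-style optimisers keep distinct train/eval parameter states,
    # so inference and checkpoint saving should happen in eval mode.
    if hasattr(optimizer, "eval"):
        optimizer.eval()
    try:
        return fn()
    finally:
        # Restore train mode before the next optimisation step.
        if hasattr(optimizer, "train"):
            optimizer.train()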

def _send_webhook_msg(
@@ -2835,21 +2840,9 @@ def train(self):
if self.accelerator.sync_gradients:
try:
if "prodigy" in self.config.optimizer:
self.lr_scheduler.step(**self.extra_lr_scheduler_kwargs)
self.lr = self.optimizer.param_groups[0]["d"]
wandb_logs.update(
{
"prodigy/d": self.optimizer.param_groups[0]["d"],
"prodigy/d_prev": self.optimizer.param_groups[0][
"d_prev"
],
"prodigy/d0": self.optimizer.param_groups[0]["d0"],
"prodigy/d_coef": self.optimizer.param_groups[0][
"d_coef"
],
"prodigy/k": self.optimizer.param_groups[0]["k"],
}
)
elif self.config.is_schedulefree:
elif self.config.is_lr_scheduler_disabled:
# hackjob method of retrieving LR from accelerated optims
self.lr = StateTracker.get_last_lr()
else:
@@ -3053,9 +3046,9 @@ def train(self):
logs["grad_absmax"] = self.grad_norm

progress_bar.set_postfix(**logs)
self.mark_optimizer_eval()
if self.validation is not None:
if self.validation.would_validate():
self.mark_optimizer_eval()
self.enable_sageattention_inference()
self.disable_gradient_checkpointing()
self.validation.run_validations(
Expand All @@ -3064,7 +3057,7 @@ def train(self):
if self.validation.would_validate():
self.disable_sageattention_inference()
self.enable_gradient_checkpointing()
self.mark_optimizer_train()
self.mark_optimizer_train()
if (
self.config.push_to_hub
and self.config.push_checkpoints_to_hub