diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml
index 2d1d414..e49e943 100644
--- a/ScaFFold/configs/benchmark_default.yml
+++ b/ScaFFold/configs/benchmark_default.yml
@@ -19,8 +19,11 @@ variance_threshold: 0.15 # Variance threshold for valid fractals. Defa
 n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3.
 val_split: 25 # In percent.
 epochs: -1 # Number of training epochs.
-learning_rate: .0001 # Learning rate for training.
-disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR.
+starting_learning_rate: 0.0025 # Starting LR at the reference problem_scale of 6.
+scale_learning_rate_factor: 0.5 # Multiply starting LR by this factor for each +1 increase in problem_scale.
+gamma: 0.99 # ExponentialLR decay factor applied each epoch when scheduler is enabled.
+min_learning_rate: 0.0001 # Floor for scheduler-adjusted learning rate.
+disable_scheduler: 0 # If 1, disable scheduler during training to use constant LR.
 more_determinism: 0 # If 1, improve model training determinism.
 datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch.
 train_from_scratch: 1 # If 1, delete existing train stats and checkpoint files. Keep 0 if want to restart runs where we left off.
diff --git a/ScaFFold/configs/benchmark_testing.yml b/ScaFFold/configs/benchmark_testing.yml
index f37749c..b1eb1fe 100644
--- a/ScaFFold/configs/benchmark_testing.yml
+++ b/ScaFFold/configs/benchmark_testing.yml
@@ -19,8 +19,11 @@ variance_threshold: 0.15 # Variance threshold for valid fractals. Defa
 n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3.
 val_split: 25 # In percent.
 epochs: 10 # Number of training epochs.
-learning_rate: .0001 # Learning rate for training.
-disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR.
+starting_learning_rate: 0.005 # Starting LR at the reference problem_scale of 6.
+scale_learning_rate_factor: 0.5 # Multiply starting LR by this factor for each +1 increase in problem_scale.
+gamma: 0.99 # ExponentialLR decay factor applied each epoch when scheduler is enabled.
+min_learning_rate: 0.0001 # Floor for scheduler-adjusted learning rate.
+disable_scheduler: 0 # If 1, disable scheduler during training to use constant LR.
 more_determinism: 0 # If 1, improve model training determinism.
 datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch.
 train_from_scratch: 1 # If 1, delete existing train stats and checkpoint files. Keep 0 if want to restart runs where we left off.
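Reviewer note: a minimal sketch of the schedule these keys define (values taken from benchmark_default.yml; problem_scale is an existing config key, and effective_lr is a hypothetical helper, not part of this patch). The starting LR is scaled by scale_learning_rate_factor per unit of problem_scale above the reference scale of 6, then decayed by gamma each epoch and clamped at min_learning_rate.

    # Hypothetical illustration only -- not part of the patch.
    def effective_lr(epoch, problem_scale, starting_lr=0.0025,
                     scale_factor=0.5, gamma=0.99, min_lr=0.0001):
        # Scale-aware starting LR, referenced to problem_scale 6.
        lr0 = starting_lr * scale_factor ** (problem_scale - 6)
        # Per-epoch exponential decay with a floor.
        return max(lr0 * gamma ** epoch, min_lr)

    # e.g. problem_scale=8 halves the start twice: effective_lr(0, 8) == 0.000625
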
diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py
index 9a67182..af4baab 100644
--- a/ScaFFold/utils/config_utils.py
+++ b/ScaFFold/utils/config_utils.py
@@ -61,7 +61,15 @@ def __init__(self, config_dict):
         self.seed = config_dict["seed"]
         self.dist = bool(config_dict["dist"])
         self.framework = config_dict["framework"]
-        self.learning_rate = config_dict["learning_rate"]
+        self.starting_learning_rate = config_dict["starting_learning_rate"]
+        self.scale_learning_rate_factor = config_dict["scale_learning_rate_factor"]
+        self.starting_learning_rate = (
+            self.starting_learning_rate
+            * self.scale_learning_rate_factor
+            ** (self.problem_scale - 6)  # Reference problem scale is 6
+        )
+        self.gamma = config_dict["gamma"]
+        self.min_learning_rate = config_dict["min_learning_rate"]
         self.variance_threshold = config_dict["variance_threshold"]
         self.torch_amp = bool(config_dict["torch_amp"])
         self.loss_freq = config_dict["loss_freq"]
diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py
index c66c255..2e4ace6 100644
--- a/ScaFFold/utils/evaluate.py
+++ b/ScaFFold/utils/evaluate.py
@@ -32,6 +32,11 @@ def evaluate(
     net, dataloader, device, amp, primary, criterion, n_categories, parallel_strategy
 ):
+    def foreground_dice_mean(dice_scores):
+        if dice_scores.size(1) > 1:
+            return dice_scores[:, 1:].mean()
+        return dice_scores.mean()
+
     net.eval()
     num_val_batches = len(dataloader)
     total_dice_score = 0.0
@@ -118,11 +123,11 @@ def evaluate(
         dice_score_probs = compute_sharded_dice(
             mask_pred_probs, mask_true_onehot, spatial_mesh
         )
-        dice_loss_curr = 1.0 - dice_score_probs.mean()
+        dice_loss_curr = 1.0 - foreground_dice_mean(dice_score_probs)

         # Eval metric (excluding background class 0)
         # dice_score_probs shape is [Batch, Channels]. We slice [:, 1:] to drop background
-        batch_dice_score = dice_score_probs[:, 1:].mean()
+        batch_dice_score = foreground_dice_mean(dice_score_probs)

         # --- Combine and Accumulate ---
         loss = CE_loss + dice_loss_curr
diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index c98fa64..6dfbd90 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -154,22 +154,24 @@ def setup_training_components(self):
         if self.config.optimizer == "ADAM":
             self.log.info("Using ADAM optimizer.")
             self.optimizer = optim.Adam(
-                self.model.parameters(), lr=self.config.learning_rate
+                self.model.parameters(), lr=self.config.starting_learning_rate
             )
         elif self.config.optimizer == "SGD":
             self.log.info("Using SGD optimizer.")
             self.optimizer = optim.SGD(
-                self.model.parameters(), lr=self.config.learning_rate
+                self.model.parameters(), lr=self.config.starting_learning_rate
             )
         else:
             self.log.info("Using RMSprop optimizer.")
             self.optimizer = optim.RMSprop(
-                self.model.parameters(), lr=self.config.learning_rate, foreach=True
+                self.model.parameters(),
+                lr=self.config.starting_learning_rate,
+                foreach=True,
             )

         # Set up learning rate scheduler
-        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
-            self.optimizer, "max", patience=25
+        self.scheduler = optim.lr_scheduler.ExponentialLR(
+            self.optimizer, gamma=self.config.gamma
         )

         # Set up gradient scaler for AMP (Automatic Mixed Precision)
@@ -186,6 +188,24 @@ def setup_training_components(self):
             f"Optimizer: {self.optimizer}, Scheduler: {self.scheduler}, Gradient Scaler Enabled: {self.config.torch_amp}"
         )

+    @staticmethod
+    def _foreground_dice_mean(dice_scores):
+        """Match optimization to the reported validation metric by excluding background."""
+        if dice_scores.size(1) > 1:
+            return dice_scores[:, 1:].mean()
+        return dice_scores.mean()
+
+    def _maybe_step_scheduler(self):
+        """Apply scheduler updates when enabled."""
+        if self.config.disable_scheduler:
+            self.log.debug("scheduler disabled, no LR update this step")
+            return
+
+        self.scheduler.step()
+        for param_group in self.optimizer.param_groups:
+            if param_group["lr"] < self.config.min_learning_rate:
+                param_group["lr"] = self.config.min_learning_rate
+

 class PyTorchTrainer(BaseTrainer):
     """
@@ -436,7 +456,7 @@ def warmup(self):
                 dice_scores = compute_sharded_dice(
                     local_preds_softmax, local_labels_one_hot, self.spatial_mesh
                 )
-                loss_dice = 1.0 - dice_scores.mean()
+                loss_dice = 1.0 - self._foreground_dice_mean(dice_scores)

                 # 3. Combine Loss
                 loss = loss_ce + loss_dice
@@ -641,11 +661,13 @@ def train(self):
                     local_labels_one_hot,
                     self.spatial_mesh,
                 )
-                loss_dice = 1.0 - dice_scores.mean()
+                loss_dice = 1.0 - self._foreground_dice_mean(dice_scores)

                 # 3. Combine Loss
                 loss = loss_ce + loss_dice
-                train_dice_total += dice_scores[:, 1:].mean().item()
+                train_dice_total += self._foreground_dice_mean(
+                    dice_scores
+                ).item()

                 end_code_region("calculate_loss")
@@ -698,19 +720,8 @@ def train(self):
                     dice_info, op=torch.distributed.ReduceOp.SUM
                 )
                 val_score = dice_info[0].item() / max(dice_info[1].item(), 1)
-                if not self.config.disable_scheduler:
-                    # The following is true when trying to overfit,
-                    # in which case we only care about train loss
-                    if self.n_train == 1 or "overfit" in self.outfile_path:
-                        self.log.debug(
-                            "WARNING: scheduler step by overall_loss, \
-                            not val_score (n_train==1 or overfit in outfile_path)"
-                        )
-                        self.scheduler.step(overall_loss)
-                    else:  # Otherwise, we're really trying to optimize for validation dice score
-                        self.scheduler.step(val_score)
-                else:
-                    self.log.debug("scheduler disabled, no LR update this step")
+                self._maybe_step_scheduler()
+                current_lr = self.optimizer.param_groups[0]["lr"]

                 epoch_end_time = time.time()
                 epoch_duration = epoch_end_time - epoch_start_time
@@ -721,7 +732,8 @@ def train(self):
             self.log.info(
                 f" epoch {epoch} \
                 | train_dice_loss {train_dice:.6f} (type {type(train_dice)}) \
-                | val_dice_score {val_score:.6f}"
+                | val_dice_score {val_score:.6f} \
+                | lr {current_lr:.8f}"
             )
             self.log.debug(f" writing to csv at {self.outfile_path}")
             if self.world_rank == 0:
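
Reviewer note: a standalone sketch of the new scheduler path (a toy optimizer with hypothetical values, mirroring _maybe_step_scheduler above rather than reproducing trainer state). ExponentialLR multiplies the current LR by gamma on each step, and the clamp loop keeps every param group at or above min_learning_rate.

    # Hypothetical illustration only -- not part of the patch.
    import torch
    from torch import optim

    # Toy parameter standing in for the real model (assumption).
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = optim.SGD(params, lr=0.005)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
    min_learning_rate = 0.0001

    for epoch in range(500):
        optimizer.step()  # training work elided
        scheduler.step()  # lr *= gamma
        for param_group in optimizer.param_groups:
            if param_group["lr"] < min_learning_rate:
                param_group["lr"] = min_learning_rate

    # 0.005 * 0.99**n falls below 1e-4 near n = 389, so this prints 0.0001.
    print(optimizer.param_groups[0]["lr"])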