diff --git a/README.md b/README.md
index 04bef50..9438b1d 100644
--- a/README.md
+++ b/README.md
@@ -83,7 +83,10 @@ variance_threshold: 0.15 # Variance threshold for valid fractals. Defa
 n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3.
 val_split: 25 # In percent.
 epochs: 100 # Number of training epochs.
-learning_rate: .0001 # Learning rate for training.
+starting_learning_rate: .01 # Initial learning rate for training.
+min_learning_rate: .001 # Minimum learning rate for CosineAnnealingWarmRestarts.
+T_0: 100 # Epochs in the first cosine restart cycle.
+T_mult: 2 # Restart cycle growth factor.
 disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR.
 more_determinism: 0 # If 1, improve model training determinism.
 datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch.
diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py
index 30512a6..461c338 100644
--- a/ScaFFold/cli.py
+++ b/ScaFFold/cli.py
@@ -177,6 +177,28 @@ def main():
         type=int,
         help="Number of training epochs.",
     )
+    benchmark_parser.add_argument(
+        "--starting-learning-rate",
+        type=float,
+        help="Initial learning rate for training.",
+    )
+    benchmark_parser.add_argument(
+        "--min-learning-rate",
+        type=float,
+        help="Minimum learning rate for CosineAnnealingWarmRestarts.",
+    )
+    benchmark_parser.add_argument(
+        "--T-0",
+        dest="T_0",
+        type=int,
+        help="Epochs in the first cosine restart cycle.",
+    )
+    benchmark_parser.add_argument(
+        "--T-mult",
+        dest="T_mult",
+        type=int,
+        help="Restart cycle growth factor.",
+    )
 
     comm = MPI.COMM_WORLD
     rank = comm.Get_rank()
diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml
index 0bc715d..ac54d54 100644
--- a/ScaFFold/configs/benchmark_default.yml
+++ b/ScaFFold/configs/benchmark_default.yml
@@ -7,8 +7,8 @@ n_instances_used_per_fractal: 145 # Number of unique instances to pull from eac
 problem_scale: 7 # Determines dataset resolution and number of unet layers. Default is 6.
 unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8.
 seed: 42 # Random seed.
-batch_size: 1 # Batch sizes for each vol size.
-dataloader_num_workers: 4 # Number of DataLoader worker processes per rank.
+batch_size: 1 # Batch sizes for each vol size per rank.
+dataloader_num_workers: 1 # Number of DataLoader worker processes per rank. More workers will use more memory.
 optimizer: "ADAM" # "ADAM" is the preferred option; otherwise training defaults to RMSProp.
 dc_num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum
 dc_shard_dims: [2, 3, 4] # DistConv param: dimension on which to shard
@@ -19,8 +19,11 @@ variance_threshold: 0.15 # Variance threshold for valid fractals. Defa
 n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3.
 val_split: 25 # In percent.
 epochs: -1 # Number of training epochs.
-learning_rate: .0001 # Learning rate for training.
-disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR.
+starting_learning_rate: 0.1 # Initial learning rate for training.
+min_learning_rate: 0.001 # Minimum learning rate for CosineAnnealingWarmRestarts.
+T_0: 100 # Epochs in the first cosine restart cycle.
+T_mult: 2 # Restart cycle growth factor.
+disable_scheduler: 0 # If 1, disable scheduler during training to use constant LR.
 more_determinism: 0 # If 1, improve model training determinism.
 datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch.
 train_from_scratch: 1 # If 1, delete existing train stats and checkpoint files. Keep 0 if you want to resume runs where they left off.
diff --git a/ScaFFold/configs/benchmark_testing.yml b/ScaFFold/configs/benchmark_testing.yml
index 8fbd52c..8b72435 100644
--- a/ScaFFold/configs/benchmark_testing.yml
+++ b/ScaFFold/configs/benchmark_testing.yml
@@ -19,7 +19,10 @@ variance_threshold: 0.15 # Variance threshold for valid fractals. Defa
 n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3.
 val_split: 25 # In percent.
 epochs: 10 # Number of training epochs.
-learning_rate: .0001 # Learning rate for training.
+starting_learning_rate: .0001 # Initial learning rate for training.
+min_learning_rate: .0001 # Minimum learning rate for CosineAnnealingWarmRestarts.
+T_0: 10 # Epochs in the first cosine restart cycle.
+T_mult: 1 # Restart cycle growth factor.
 disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR.
 more_determinism: 0 # If 1, improve model training determinism.
 datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch.
diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py
index 4684779..2c73a6c 100644
--- a/ScaFFold/utils/config_utils.py
+++ b/ScaFFold/utils/config_utils.py
@@ -61,7 +61,10 @@ def __init__(self, config_dict):
         self.seed = config_dict["seed"]
         self.dist = bool(config_dict["dist"])
         self.framework = config_dict["framework"]
-        self.learning_rate = config_dict["learning_rate"]
+        self.starting_learning_rate = config_dict["starting_learning_rate"]
+        self.min_learning_rate = config_dict["min_learning_rate"]
+        self.T_0 = config_dict["T_0"]
+        self.T_mult = config_dict["T_mult"]
         self.variance_threshold = config_dict["variance_threshold"]
         self.torch_amp = bool(config_dict["torch_amp"])
         self.loss_freq = config_dict["loss_freq"]
diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py
index 62d0fdf..2fd3eb1 100644
--- a/ScaFFold/utils/evaluate.py
+++ b/ScaFFold/utils/evaluate.py
@@ -30,10 +30,12 @@ def evaluate(
     net, dataloader, device, amp, primary, criterion, n_categories, parallel_strategy
 ):
-    def foreground_dice_mean(dice_scores):
+    def foreground_dice_stats(dice_scores):
         if dice_scores.size(1) > 1:
-            return dice_scores[:, 1:].mean().item()
-        return dice_scores.mean().item()
+            per_sample_scores = dice_scores[:, 1:].mean(dim=1)
+        else:
+            per_sample_scores = dice_scores.mean(dim=1)
+        return per_sample_scores.sum().item(), per_sample_scores.numel()
 
     net.eval()
     autocast_device_type = device.type if device.type != "mps" else "cpu"
@@ -43,6 +45,7 @@ def foreground_dice_mean(dice_scores):
     num_val_batches = len(dataloader)
     total_dice_score = 0.0
     processed_batches = 0
+    processed_samples = 0
 
     spatial_mesh = parallel_strategy.device_mesh[parallel_strategy.distconv_dim_names]
 
@@ -116,19 +119,29 @@ def foreground_dice_mean(dice_scores):
                 dice_score_probs = compute_sharded_dice(
                     mask_pred_probs, mask_true_onehot, spatial_mesh
                 )
-                batch_dice_score = foreground_dice_mean(dice_score_probs)
+                batch_dice_sum, batch_sample_count = foreground_dice_stats(
+                    dice_score_probs
+                )
+                batch_dice_score = batch_dice_sum / max(batch_sample_count, 1)
 
                 # Sum global CE Loss and Dice loss
                 loss = CE_loss + (1.0 - batch_dice_score)
                 val_loss_epoch += loss.item()
-                total_dice_score += batch_dice_score
+                total_dice_score += batch_dice_sum
                 processed_batches += 1
+                processed_samples += batch_sample_count
 
     net.train()
     val_loss_avg = val_loss_epoch / max(processed_batches, 1)
 
     if primary:
         print(
-            f"evaluate.py: dice_score={total_dice_score}, val_loss_epoch={val_loss_epoch}, val_loss_avg={val_loss_avg}, num_val_batches={processed_batches}"
+            f"evaluate.py: dice_score={total_dice_score}, val_loss_epoch={val_loss_epoch}, val_loss_avg={val_loss_avg}, num_val_batches={processed_batches}, num_val_samples={processed_samples}"
         )
-    return total_dice_score, val_loss_epoch, val_loss_avg, processed_batches
+    return (
+        total_dice_score,
+        val_loss_epoch,
+        val_loss_avg,
+        processed_batches,
+        processed_samples,
+    )
diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index b680746..654cc19 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -130,7 +130,7 @@ def create_dataset(self):
         )
         self.n_train = len(self.train_set)
         self.n_val = len(self.val_set)
-        self.log.debug(
+        self.log.info(
             f"Datasets created with n_train={self.n_train}, n_val={self.n_val}"
         )
 
@@ -173,8 +173,15 @@ def create_dataloaders(self):
             self.train_set, sampler=self.train_sampler, **loader_args
         )
         self.val_loader = DataLoader(
-            self.val_set, sampler=self.val_sampler, drop_last=True, **loader_args
+            self.val_set, sampler=self.val_sampler, drop_last=False, **loader_args
         )
+        if len(self.val_loader) == 0:
+            raise ValueError(
+                "Validation DataLoader has zero batches. "
+                f"n_val={self.n_val}, batch_size={self.config.batch_size}, "
+                f"data_num_replicas={self.data_num_replicas}. "
+                "Reduce batch_size or adjust validation sharding."
+            )
 
     def setup_training_components(self):
         """Set up the optimizer, scheduler, gradient scaler, and loss function."""
@@ -182,22 +189,27 @@ def setup_training_components(self):
 
         if self.config.optimizer == "ADAM":
             self.log.info("Using ADAM optimizer.")
             self.optimizer = optim.Adam(
-                self.model.parameters(), lr=self.config.learning_rate
+                self.model.parameters(), lr=self.config.starting_learning_rate
             )
         elif self.config.optimizer == "SGD":
             self.log.info("Using SGD optimizer.")
             self.optimizer = optim.SGD(
-                self.model.parameters(), lr=self.config.learning_rate
+                self.model.parameters(), lr=self.config.starting_learning_rate
             )
         else:
             self.log.info("Using RMSprop optimizer.")
             self.optimizer = optim.RMSprop(
-                self.model.parameters(), lr=self.config.learning_rate, foreach=True
+                self.model.parameters(),
+                lr=self.config.starting_learning_rate,
+                foreach=True,
             )
 
         # Set up learning rate scheduler
-        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
-            self.optimizer, "max", patience=25
+        self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
+            self.optimizer,
+            T_0=self.config.T_0,
+            T_mult=self.config.T_mult,
+            eta_min=self.config.min_learning_rate,
         )
 
         # Set up gradient scaler for AMP (Automatic Mixed Precision)
@@ -234,6 +246,11 @@ def _foreground_dice_mean(dice_scores):
             return dice_scores[:, 1:].mean()
         return dice_scores.mean()
 
+    def _current_learning_rate(self):
+        if self.optimizer is None or not self.optimizer.param_groups:
+            return self.config.starting_learning_rate
+        return self.optimizer.param_groups[0]["lr"]
+
 
 class PyTorchTrainer(BaseTrainer):
     """
@@ -403,7 +420,7 @@ def warmup(self):
 
             images = images.to(
                 device=self.device,
-                dtype=torch.float32,
+                dtype=VOLUME_DTYPE,
                 memory_format=torch.channels_last_3d,
                 non_blocking=True,
             )
@@ -587,7 +604,7 @@ def train(self):
             begin_code_region("image_to_device")
             images = images.to(
                 device=self.device,
-                dtype=torch.float32,
+                dtype=VOLUME_DTYPE,
                 memory_format=torch.channels_last_3d,  # NDHWC (channels last) vs NCDHW (channels first)
                 non_blocking=True,
             )
@@ -720,7 +737,13 @@ def train(self):
             #
             # Evaluate model on validation set, update LR if necessary
             #
-            dice_sum, val_loss_epoch, val_loss_avg, numbatch = evaluate(
+            (
+                dice_sum,
+                val_loss_epoch,
+                val_loss_avg,
+                numbatch,
+                numsamples,
+            ) = evaluate(
                 self.model,
                 self.val_loader,
                 self.device,
@@ -730,7 +753,7 @@ def train(self):
                 self.config.n_categories,
                 self.config._parallel_strategy,
             )
-            dice_info = torch.tensor([dice_sum, numbatch])
+            dice_info = torch.tensor([dice_sum, numsamples], dtype=VOLUME_DTYPE)
             if self.config.dist:
                 dice_info = dice_info.to(device=self.device)
                 torch.distributed.all_reduce(
@@ -738,16 +761,7 @@ def train(self):
                 )
             val_score = dice_info[0].item() / max(dice_info[1].item(), 1)
             if not self.config.disable_scheduler:
-                # The following is true when trying to overfit,
-                # in which case we only care about train loss
-                if self.n_train == 1 or "overfit" in self.outfile_path:
-                    self.log.debug(
-                        "WARNING: scheduler step by overall_loss, \
-                        not val_score (n_train==1 or overfit in outfile_path)"
-                    )
-                    self.scheduler.step(overall_loss)
-                else:  # Otherwise, we're really trying to optimize for validation dice score
-                    self.scheduler.step(val_score)
+                self.scheduler.step()
             else:
                 self.log.debug("scheduler disabled, no LR update this step")
 
@@ -758,10 +772,7 @@ def train(self):
             # train_dice = float(train_dice_total.item() / len(self.train_loader))
 
             self.log.info(
-                f" epoch {epoch} \
-                | train_dice_loss {train_dice:.6f} (type {type(train_dice)}) \
-                | val_dice_score {val_score:.6f} \
-                | lr {self.config.learning_rate:.8f}"
+                f" epoch {epoch} | train_dice_score {train_dice:.6f} | val_dice_score {val_score:.6f} | lr {self._current_learning_rate():.8f}"
             )
             self.log.debug(f" writing to csv at {self.outfile_path}")
             if self.world_rank == 0:
diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py
index d2818c1..16087c8 100644
--- a/ScaFFold/worker.py
+++ b/ScaFFold/worker.py
@@ -217,14 +217,18 @@ def main(kwargs_dict: dict = {}):
         trainer.spatial_mesh = ps.device_mesh[ps.distconv_dim_names]
         num_spatial_dims = len(ps.shard_dim)
         trainer.ddp_placements = [Shard(0)] + [Replicate()] * num_spatial_dims
-        global_batch_size = config.batch_size * (
-            world_size // math.prod(config.dc_num_shards)
-        )
+        total_shards = math.prod(config.dc_num_shards)
+        global_batch_size = config.batch_size * (world_size // total_shards)
+        ddp_ranks = world_size // total_shards
         if rank == 0:
             log.info(
                 f"Effective global batch size = {global_batch_size} "
                 f"(batch_size={config.batch_size} * "
-                f"(world_size={world_size} / prod(dc_num_shards)={math.prod(config.dc_num_shards)}))"
+                f"(world_size={world_size} / prod(dc_num_shards)={total_shards}))"
+            )
+            log.info(
+                f"DDP ranks = {ddp_ranks} "
+                f"(world_size={world_size} // prod(dc_num_shards)={total_shards})"
             )
         if global_batch_size > trainer.n_train:
             raise ValueError(
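
For reference on the scheduler swap above (ReduceLROnPlateau to CosineAnnealingWarmRestarts): the sketch below is not part of the patch; it only illustrates how the new scheduler moves the learning rate when scheduler.step() is called once per epoch, using the benchmark_default.yml values (starting_learning_rate=0.1, min_learning_rate=0.001, T_0=100, T_mult=2). The Linear model and Adam optimizer are placeholders for the real UNet setup.

    # Sketch only: per-epoch LR trajectory of CosineAnnealingWarmRestarts.
    import torch
    from torch import optim

    model = torch.nn.Linear(4, 4)  # stand-in for the real model
    optimizer = optim.Adam(model.parameters(), lr=0.1)  # starting_learning_rate
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=100, T_mult=2, eta_min=0.001  # min_learning_rate
    )

    for epoch in range(300):
        # ... forward/backward and optimizer.step() would run here ...
        scheduler.step()  # unconditional per-epoch step, as in the new trainer.py code
        print(epoch, optimizer.param_groups[0]["lr"])
    # The LR follows a cosine curve from 0.1 down toward 0.001, restarting at the
    # full value after the first 100 epochs, then after a further 200 epochs
    # (each cycle is T_mult times longer than the previous one).

With disable_scheduler: 1 (as in the README snippet and benchmark_testing.yml) the scheduler is never stepped, so the learning rate stays at starting_learning_rate.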
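The evaluate.py change replaces the per-batch dice average with a (sum, count) pair so validation dice is averaged per sample rather than per batch; with drop_last=False in trainer.py the final validation batch can be smaller, and a per-batch average would overweight it. A small illustration with invented scores (not taken from any real run):

    # One full batch of 4 samples and a ragged final batch of 1 sample.
    import torch

    batch_scores = [
        torch.tensor([0.8, 0.8, 0.8, 0.8]),  # per-sample foreground dice, batch 1
        torch.tensor([0.2]),                 # per-sample foreground dice, batch 2
    ]

    # Old behaviour: average the per-batch means, so the lone sample counts 4x too much.
    per_batch_mean = sum(b.mean().item() for b in batch_scores) / len(batch_scores)

    # New behaviour: accumulate a dice sum and a sample count (as
    # foreground_dice_stats does) and divide once at the end.
    dice_sum = sum(b.sum().item() for b in batch_scores)
    n_samples = sum(b.numel() for b in batch_scores)
    per_sample_mean = dice_sum / max(n_samples, 1)

    print(per_batch_mean)   # 0.5
    print(per_sample_mean)  # 0.68

The same reasoning carries across ranks: the all_reduce in trainer.py now sums [dice_sum, numsamples] instead of [dice_sum, numbatch], so val_score is a per-sample mean over the whole validation set.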
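The worker.py hunk factors prod(dc_num_shards) into total_shards and logs the number of data-parallel replicas alongside the effective global batch size. A worked example of that arithmetic, using assumed values (16 ranks, a 2x2x1 DistConv shard grid, batch_size 2) rather than anything from the configs in this patch:

    # Worked example of the global-batch-size arithmetic from worker.py.
    import math

    world_size = 16              # total ranks in the job (assumed)
    dc_num_shards = [2, 2, 1]    # ranks cooperating on one sample via DistConv (assumed)
    batch_size = 2               # samples per data-parallel replica (assumed)

    total_shards = math.prod(dc_num_shards)      # 4 ranks share one sample
    ddp_ranks = world_size // total_shards       # 4 data-parallel replicas
    global_batch_size = batch_size * ddp_ranks   # 8 samples per optimizer step

    print(total_shards, ddp_ranks, global_batch_size)  # 4 4 8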