diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml
index 541d664..1b0310c 100644
--- a/ScaFFold/configs/benchmark_default.yml
+++ b/ScaFFold/configs/benchmark_default.yml
@@ -19,8 +19,8 @@ variance_threshold: 0.15  # Variance threshold for valid fractals. Defa
 n_fracts_per_vol: 3  # Number of fractals overlaid in each volume. Default is 3.
 val_split: 30  # In percent.
 epochs: -1  # Number of training epochs.
-starting_learning_rate: 0.01  # Initial learning rate for training.
-min_learning_rate: 0.001  # Minimum learning rate for CosineAnnealingWarmRestarts.
+starting_learning_rate: 0.001  # Initial learning rate for training.
+min_learning_rate: 0.0001  # Minimum learning rate for CosineAnnealingWarmRestarts.
 T_0: 100  # Epochs in the first cosine restart cycle.
 T_mult: 2  # Restart cycle growth factor.
 disable_scheduler: 0  # If 1, disable scheduler during training to use constant LR.
@@ -33,7 +33,7 @@ framework: "torch"  # The DL framework to train with. Only valid
 checkpoint_dir: "checkpoints"  # Subfolder in which to save training checkpoints.
 loss_freq: 1  # Number of epochs between logging the overall loss.
 normalize: 1  # Category search normalization parameter
-warmup_batches: 5  # How many warmup batches per rank to run before training.
+warmup_batches: 64  # How many warmup batches per rank to run before training.
 ce_weight_sample_fraction: 0.1  # Fraction of training masks to sample when estimating background vs foreground CE weights.
 dataset_reuse_enforce_commit_id: 0  # Enforce matching commit IDs for dataset reuse.
 target_dice: 0.95
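Note: the comments above name CosineAnnealingWarmRestarts, so here is a minimal sketch of how these keys plausibly feed `torch.optim.lr_scheduler.CosineAnnealingWarmRestarts`. The trainer-side wiring is not part of this diff; the model/optimizer names below are illustrative only.

```python
import torch

model = torch.nn.Linear(4, 4)
# starting_learning_rate seeds the optimizer; eta_min is min_learning_rate.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=100,        # first cosine cycle lasts 100 epochs
    T_mult=2,       # subsequent cycles last 200, 400, ... epochs
    eta_min=0.0001,
)

for epoch in range(300):
    optimizer.step()
    scheduler.step()  # LR anneals from 1e-3 toward 1e-4, then restarts
```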
diff --git a/ScaFFold/utils/checkpointing.py b/ScaFFold/utils/checkpointing.py
index 5a65b09..2a06a3c 100644
--- a/ScaFFold/utils/checkpointing.py
+++ b/ScaFFold/utils/checkpointing.py
@@ -105,6 +105,46 @@ def wait_for_save(self):
                 self._log(f"Background save failed with error: {e}")
             self.future = None
 
+    def snapshot_training_state(self) -> Dict[str, Any]:
+        """Capture mutable in-memory training state without writing a checkpoint."""
+        model_ref = self.model.module if hasattr(self.model, "module") else self.model
+        return {
+            "model_state_dict": self._clone_state_dict(model_ref.state_dict()),
+            "optimizer_state_dict": self._clone_state_dict(self.optimizer.state_dict())
+            if self.optimizer
+            else None,
+            "scheduler_state_dict": self._clone_state_dict(self.scheduler.state_dict())
+            if self.scheduler
+            else None,
+            "grad_scaler_state_dict": self._clone_state_dict(
+                self.grad_scaler.state_dict()
+            )
+            if self.grad_scaler
+            else None,
+            "model_training": model_ref.training,
+            **self._get_rng_snapshot(),
+        }
+
+    def restore_training_state(self, snapshot: Dict[str, Any]) -> None:
+        """Restore an in-memory training snapshot."""
+        model_ref = self.model.module if hasattr(self.model, "module") else self.model
+        model_ref.load_state_dict(snapshot["model_state_dict"])
+
+        if self.optimizer and snapshot.get("optimizer_state_dict") is not None:
+            self.optimizer.load_state_dict(snapshot["optimizer_state_dict"])
+
+        if self.scheduler and snapshot.get("scheduler_state_dict") is not None:
+            self.scheduler.load_state_dict(snapshot["scheduler_state_dict"])
+
+        if self.grad_scaler and snapshot.get("grad_scaler_state_dict") is not None:
+            self.grad_scaler.load_state_dict(snapshot["grad_scaler_state_dict"])
+
+        self._restore_rng(snapshot)
+        model_ref.train(snapshot.get("model_training", True))
+
+        if self.optimizer:
+            self.optimizer.zero_grad(set_to_none=True)
+
     def load_from_checkpoint(self) -> int:
         """Load the latest checkpoint. Returns start_epoch (default 1)."""
         self.wait_for_save()  # Safety: don't load while writing
@@ -285,6 +325,19 @@ def _transfer_dict_to_cpu(self, obj):
         else:
             return obj
 
+    def _clone_state_dict(self, obj):
+        """Recursively clone tensors so in-memory snapshots are isolated."""
+        if torch.is_tensor(obj):
+            return obj.detach().clone()
+        elif isinstance(obj, dict):
+            return {k: self._clone_state_dict(v) for k, v in obj.items()}
+        elif isinstance(obj, list):
+            return [self._clone_state_dict(v) for v in obj]
+        elif isinstance(obj, tuple):
+            return tuple(self._clone_state_dict(v) for v in obj)
+        else:
+            return obj
+
     def _barrier(self):
         if self.dist_enabled:
             dist.barrier()
@@ -305,7 +358,7 @@ def _log(self, msg):
     def _get_rng_snapshot(self) -> Dict[str, Any]:
         snap = {"rng_state_pytorch": torch.get_rng_state()}
         if torch.cuda.is_available():
-            snap["rng_state_pytorch_cuda"] = torch.cuda.get_rng_state()
+            snap["rng_state_pytorch_cuda"] = torch.cuda.get_rng_state_all()
         try:
             snap["rng_state_numpy"] = np.random.get_state()
         except ImportError:
@@ -321,7 +374,11 @@ def _restore_rng(self, snap: Dict[str, Any]):
         if "rng_state_pytorch" in snap:
             torch.set_rng_state(snap["rng_state_pytorch"])
         if "rng_state_pytorch_cuda" in snap and torch.cuda.is_available():
-            torch.cuda.set_rng_state(snap["rng_state_pytorch_cuda"])
+            cuda_state = snap["rng_state_pytorch_cuda"]
+            if isinstance(cuda_state, list):
+                torch.cuda.set_rng_state_all(cuda_state)
+            else:
+                torch.cuda.set_rng_state(cuda_state)
         if "rng_state_numpy" in snap:
             np.random.set_state(snap["rng_state_numpy"])
         if "rng_state_python" in snap:
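Note: the `_clone_state_dict` deep copy matters because `state_dict()` returns live references to the model's tensors; a snapshot taken without cloning would silently track subsequent updates. The switch to `get_rng_state_all` likewise captures one CUDA RNG state per visible device (a list), which is why `_restore_rng` now dispatches on the snapshot's type. A standalone sketch of the aliasing failure mode the clone helper prevents:

```python
import torch

net = torch.nn.Linear(2, 2)
aliased = net.state_dict()                                    # live references
cloned = {k: v.detach().clone() for k, v in aliased.items()}  # what _clone_state_dict does

with torch.no_grad():
    net.weight.add_(1.0)  # stand-in for a warmup optimizer step

print(torch.equal(aliased["weight"], net.weight))  # True: snapshot drifted with the model
print(torch.equal(cloned["weight"], net.weight))   # False: clone stayed frozen
```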
diff --git a/ScaFFold/utils/data_types.py b/ScaFFold/utils/data_types.py
index b555811..ef1515d 100644
--- a/ScaFFold/utils/data_types.py
+++ b/ScaFFold/utils/data_types.py
@@ -19,7 +19,10 @@
 # Masks are values 0 <= x <= n_categories
 MASK_DTYPE = np.uint16
 # Volumes/img are 0 <= x <= 1
-VOLUME_DTYPE = np.float32
+VOLUME_DTYPE_NAME = "float32"
+VOLUME_NP_DTYPE = getattr(np, VOLUME_DTYPE_NAME)
+VOLUME_TORCH_DTYPE = getattr(torch, VOLUME_DTYPE_NAME)
+VOLUME_DTYPE = VOLUME_NP_DTYPE
 # Shared AMP dtype selection for torch.autocast.
 AMP_DTYPE = torch.bfloat16
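Note: deriving both dtypes from one name keeps NumPy-side loading and Torch-side transfer in agreement; in particular, `Tensor.to(dtype=...)` requires a `torch.dtype`, so passing the old `np.float32` constant there raises a `TypeError`. A small sketch of the pairing, assuming `numpy` and `torch` are both imported in this module:

```python
import numpy as np
import torch

VOLUME_DTYPE_NAME = "float32"
VOLUME_NP_DTYPE = getattr(np, VOLUME_DTYPE_NAME)        # np.float32
VOLUME_TORCH_DTYPE = getattr(torch, VOLUME_DTYPE_NAME)  # torch.float32

volume = np.zeros((2, 2), dtype=VOLUME_NP_DTYPE)                # host-side array
tensor = torch.from_numpy(volume).to(dtype=VOLUME_TORCH_DTYPE)  # no-op cast
assert tensor.dtype == VOLUME_TORCH_DTYPE
```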
diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py
index 67b01da..6b23b15 100644
--- a/ScaFFold/utils/evaluate.py
+++ b/ScaFFold/utils/evaluate.py
@@ -26,7 +26,15 @@
 @annotate()
 @torch.inference_mode()
 def evaluate(
-    net, dataloader, device, amp, primary, criterion, n_categories, parallel_strategy
+    net,
+    dataloader,
+    device,
+    amp,
+    primary,
+    criterion,
+    n_categories,
+    parallel_strategy,
+    max_batches=None,
 ):
     def foreground_dice_stats(dice_scores):
         if dice_scores.size(1) > 1:
@@ -41,6 +49,8 @@ def foreground_dice_stats(dice_scores):
     if amp:
         autocast_kwargs["dtype"] = AMP_DTYPE
     num_val_batches = len(dataloader)
+    if max_batches is not None:
+        num_val_batches = min(num_val_batches, max_batches)
     total_dice_score = 0.0
     processed_batches = 0
     processed_samples = 0
@@ -55,14 +65,18 @@ def foreground_dice_stats(dice_scores):
     with torch.autocast(**autocast_kwargs):
         val_loss_epoch = 0.0
         class_weights = getattr(criterion, "weight", None)
-        for batch in tqdm(
-            dataloader,
-            total=num_val_batches,
-            desc="Validation round",
-            unit="batch",
-            leave=False,
-            disable=not primary,
+        for batch_idx, batch in enumerate(
+            tqdm(
+                dataloader,
+                total=num_val_batches,
+                desc="Validation round",
+                unit="batch",
+                leave=False,
+                disable=not primary,
+            )
         ):
+            if batch_idx >= num_val_batches:
+                break
             image, mask_true = batch["image"], batch["mask"]
             image = image.to(
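Note: a reduced sketch of the evaluation context this function runs under: `inference_mode` disables autograd bookkeeping, `autocast` is fed the shared `AMP_DTYPE`, and `max_batches` simply caps the loop. Only the decorator/autocast/cap pattern mirrors the real function; the names below are illustrative.

```python
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"
net = torch.nn.Linear(8, 2).to(device_type)

@torch.inference_mode()
def evaluate_capped(batches, amp, max_batches=None):
    kwargs = {"device_type": device_type, "enabled": amp}
    if amp:
        kwargs["dtype"] = torch.bfloat16  # AMP_DTYPE in data_types.py
    limit = len(batches) if max_batches is None else min(len(batches), max_batches)
    outs = []
    with torch.autocast(**kwargs):
        for batch_idx, batch in enumerate(batches):
            if batch_idx >= limit:
                break
            outs.append(net(batch))
    return outs

batches = [torch.randn(4, 8, device=device_type) for _ in range(10)]
print(len(evaluate_capped(batches, amp=False, max_batches=3)))  # 3
```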
Stepping optimizer") + self.grad_scaler.step(self.optimizer) + if gather_mem_stats: + gather_and_print_mem(self.log, "after_optim_step") + self.grad_scaler.update() + self.optimizer.zero_grad(set_to_none=False) + end_code_region("step_and_update") + + batch_size = images_dc.shape[0] + detached_loss = loss.detach() + + # Free memory aggressively + del images_dc, true_masks_dc, masks_pred_dc + del local_preds, local_labels, local_preds_softmax, local_labels_one_hot + del loss_ce, loss + + if log_peak_mem and self.world_rank == 0: + peak_alloc = torch.cuda.max_memory_allocated() / (1024**3) + peak_reserved = torch.cuda.max_memory_reserved() / (1024**3) + self.log.debug( + f"[MEM-PEAK] Peak alloc: {peak_alloc:.2f} GiB | Peak reserved: {peak_reserved:.2f} GiB", + ) + + return batch_size, detached_loss, batch_dice_score + def warmup(self): """Run warmup iterations before the main training loop.""" warmup_batches = self.config.warmup_batches @@ -421,143 +552,52 @@ def warmup(self): if self.config.dist: self.train_loader.sampler.set_epoch(0) - # Match the main training path as closely as possible. - self.model.train() - self.optimizer.zero_grad(set_to_none=False) start_warmup = time.time() max_batches = min(warmup_batches, len(self.train_loader)) - self.log.info(f"Running {max_batches} warmup batch(es) per rank") - - for batch_idx, batch in enumerate(self.train_loader): - if batch_idx >= max_batches: - break + max_val_batches = min(warmup_batches, len(self.val_loader)) + self.log.info( + f"Running {max_batches} training warmup batch(es) and {max_val_batches} validation warmup batch(es) per rank" + ) + snapshot = self.checkpoint_manager.snapshot_training_state() - images, true_masks = batch["image"], batch["mask"] + # Match the main training path as closely as possible, but roll back all + # mutable state so warmup does not affect convergence. + self.model.train() + self.optimizer.zero_grad(set_to_none=False) - images = images.to( - device=self.device, - dtype=VOLUME_DTYPE, - memory_format=torch.channels_last_3d, - non_blocking=True, - ) - true_masks = true_masks.to( - device=self.device, dtype=torch.long, non_blocking=True - ).contiguous() - - # Add a dummy channel dimension to get 5D [B, 1, D, H, W] - true_masks = true_masks.unsqueeze(1) - - # Inputs are already loaded as local shards by the dataset. - images_dc = DCTensor.from_shard(images, self.ps) - true_masks_dc = DCTensor.from_shard(true_masks, self.ps) - self._get_memsize(images_dc, "Sharded image", self.config.verbose) - - with torch.autocast(**self._autocast_kwargs()): - # Forward on DCTensor - self.log.debug(" warmup: running forward pass") - masks_pred_dc = self.model(images_dc) - self.log.debug(" warmup: forward pass complete") - - # Extract the underlying PyTorch local tensors - local_preds = masks_pred_dc - local_labels_5d = true_masks_dc - - # Remove the dummy channel dimension so CE Loss is happy [B, D, H, W] - local_labels = local_labels_5d.squeeze(1) - if self.world_rank == 0: - self.log.debug(f" warmup: Local Preds Shape: {local_preds.shape}") - # Should be something like [1, 6, 128, 128, 64] if sharding Width by 2 - self.log.debug( - f" warmup: Local Labels Shape: {local_labels.shape}" - ) - # Should be something like [1, 128, 128, 64] + try: + for batch_idx, batch in enumerate(self.train_loader): + if batch_idx >= max_batches: + break - # --- SHARDED LOSS CALCULATION --- - current_mem = torch.cuda.memory_allocated() / (1024**3) - self.log.debug( - f" warmup: Calculating sharded loss. Mem: {current_mem:.2f} GB." 
@@ -421,143 +552,52 @@ def warmup(self):
         if self.config.dist:
             self.train_loader.sampler.set_epoch(0)
 
-        # Match the main training path as closely as possible.
-        self.model.train()
-        self.optimizer.zero_grad(set_to_none=False)
         start_warmup = time.time()
         max_batches = min(warmup_batches, len(self.train_loader))
-        self.log.info(f"Running {max_batches} warmup batch(es) per rank")
-
-        for batch_idx, batch in enumerate(self.train_loader):
-            if batch_idx >= max_batches:
-                break
+        max_val_batches = min(warmup_batches, len(self.val_loader))
+        self.log.info(
+            f"Running {max_batches} training warmup batch(es) and {max_val_batches} validation warmup batch(es) per rank"
+        )
+        snapshot = self.checkpoint_manager.snapshot_training_state()
 
-            images, true_masks = batch["image"], batch["mask"]
+        # Match the main training path as closely as possible, but roll back all
+        # mutable state so warmup does not affect convergence.
+        self.model.train()
+        self.optimizer.zero_grad(set_to_none=False)
 
-            images = images.to(
-                device=self.device,
-                dtype=VOLUME_DTYPE,
-                memory_format=torch.channels_last_3d,
-                non_blocking=True,
-            )
-            true_masks = true_masks.to(
-                device=self.device, dtype=torch.long, non_blocking=True
-            ).contiguous()
-
-            # Add a dummy channel dimension to get 5D [B, 1, D, H, W]
-            true_masks = true_masks.unsqueeze(1)
-
-            # Inputs are already loaded as local shards by the dataset.
-            images_dc = DCTensor.from_shard(images, self.ps)
-            true_masks_dc = DCTensor.from_shard(true_masks, self.ps)
-            self._get_memsize(images_dc, "Sharded image", self.config.verbose)
-
-            with torch.autocast(**self._autocast_kwargs()):
-                # Forward on DCTensor
-                self.log.debug("  warmup: running forward pass")
-                masks_pred_dc = self.model(images_dc)
-                self.log.debug("  warmup: forward pass complete")
-
-                # Extract the underlying PyTorch local tensors
-                local_preds = masks_pred_dc
-                local_labels_5d = true_masks_dc
-
-                # Remove the dummy channel dimension so CE Loss is happy [B, D, H, W]
-                local_labels = local_labels_5d.squeeze(1)
-                if self.world_rank == 0:
-                    self.log.debug(f"  warmup: Local Preds Shape: {local_preds.shape}")
-                    # Should be something like [1, 6, 128, 128, 64] if sharding Width by 2
-                    self.log.debug(
-                        f"  warmup: Local Labels Shape: {local_labels.shape}"
-                    )
-                    # Should be something like [1, 128, 128, 64]
+        try:
+            for batch_idx, batch in enumerate(self.train_loader):
+                if batch_idx >= max_batches:
+                    break
 
-                # --- SHARDED LOSS CALCULATION ---
-                current_mem = torch.cuda.memory_allocated() / (1024**3)
-                self.log.debug(
-                    f"  warmup: Calculating sharded loss. Mem: {current_mem:.2f} GB."
-                )
+                self._run_training_batch(
+                    batch,
+                    log_prefix="warmup: ",
+                    log_peak_mem=True,
+                )
-
-                # Calculate CE and Dice loss in single precision for numerical stability.
-                with torch.autocast(**self._autocast_kwargs(enabled=False)):
-                    loss_ce = compute_sharded_cross_entropy_loss(
-                        local_preds,
-                        local_labels,
-                        self.spatial_mesh,
-                        self.config.dc_num_shards,
-                        self.amp_device_type,
-                        self.ce_class_weights,
-                    )
-
-                    local_preds_softmax = F.softmax(local_preds.float(), dim=1)
-                    local_labels_one_hot = (
-                        F.one_hot(
-                            local_labels, num_classes=self.config.n_categories + 1
-                        )
-                        .permute(0, 4, 1, 2, 3)
-                        .float()
-                    )
-                    dice_scores = compute_sharded_dice(
-                        local_preds_softmax, local_labels_one_hot, self.spatial_mesh
-                    )
-                    batch_dice_score = self._foreground_dice_mean(dice_scores)
-
-                    # Sum global CE Loss and Dice loss
-                    loss = loss_ce + (1.0 - batch_dice_score)
-
-                self.log.debug(
-                    "  warmup: loss calculation complete. Proceeding to backward pass"
-                )
-
-                # Backward pass
-                self.grad_scaler.scale(loss).backward()
-                self.grad_scaler.unscale_(self.optimizer)
-                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
-                self.log.debug("  warmup: backward pass complete. Stepping optimizer")
-
-                self.grad_scaler.step(self.optimizer)
-                self.grad_scaler.update()
-
-                # Free memory aggressively
-                del images_dc, true_masks_dc, masks_pred_dc
-                del (
-                    local_preds,
-                    local_labels,
-                    local_preds_softmax,
-                    local_labels_one_hot,
-                )
-                del loss_ce, loss
-
-                if self.world_rank == 0:
-                    peak_alloc = torch.cuda.max_memory_allocated() / (1024**3)
-                    peak_reserved = torch.cuda.max_memory_reserved() / (1024**3)
+                batch_t_end = time.time()
                 self.log.debug(
-                    f"[MEM-PEAK] Peak alloc: {peak_alloc:.2f} GiB | Peak reserved: {peak_reserved:.2f} GiB",
+                    f"  warmup: batch {batch_idx} completed in {batch_t_end - start_warmup} seconds"
                 )
-            batch_t_end = time.time()
-            self.log.debug(
-                f"  warmup: batch {batch_idx} completed in {batch_t_end - start_warmup} seconds"
-            )
-            # Nuke any accumulated grads so the first real step starts clean
-            for p in self.model.parameters():
-                p.grad = None
-            self.optimizer.zero_grad(set_to_none=True)
+            if self.config.dist:
+                self.val_loader.sampler.set_epoch(0)
 
-        if self.config.dist:
-            self.val_loader.sampler.set_epoch(0)
-
-        evaluate(
-            self.model,
-            self.val_loader,
-            self.device,
-            self.config.torch_amp,
-            self.world_rank == 0,
-            self.criterion,
-            self.config.n_categories,
-            self.config._parallel_strategy,
-        )
-        self.model.train()
+            if max_val_batches > 0:
+                self.log.debug("  warmup: running validation warmup pass")
+                evaluate(
+                    self.model,
+                    self.val_loader,
+                    self.device,
+                    self.config.torch_amp,
+                    False,
+                    self.criterion,
+                    self.config.n_categories,
+                    self.config._parallel_strategy,
+                    max_batches=max_val_batches,
+                )
+        finally:
+            self.checkpoint_manager.restore_training_state(snapshot)
 
         if self.config.dist:
             torch.distributed.barrier()
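Note: one detail worth flagging in the rewritten `warmup`: both samplers get `set_epoch(0)`, presumably the same value the first real epoch uses, so warmup iterates over exactly the batch order epoch 0 will see before the snapshot rolls everything back. A standalone illustration, assuming standard `DistributedSampler` semantics (no process group is needed when `num_replicas`/`rank` are passed explicitly):

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

sampler.set_epoch(0)
first = list(sampler)
sampler.set_epoch(0)
assert list(sampler) == first  # same epoch value -> identical shuffle order
sampler.set_epoch(1)
print(list(sampler) == first)  # almost surely False: new epoch reseeds
```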
- end_code_region("image_to_device") - gather_and_print_mem(self.log, "after_batch_to_device") - - # Add a dummy channel dimension to get 5D [B, 1, D, H, W] - true_masks = true_masks.unsqueeze(1) - - # Inputs are already loaded as local shards by the dataset. - images_dc = DCTensor.from_shard(images, self.ps) - true_masks_dc = DCTensor.from_shard(true_masks, self.ps) - del images, true_masks - self._get_memsize( - images_dc, "Sharded image", self.config.verbose - ) - - with torch.autocast(**self._autocast_kwargs()): - # Predict on this batch - torch.cuda.reset_peak_memory_stats() - gather_and_print_mem(self.log, "pre_forward") - begin_code_region("predict") - masks_pred_dc = self.model(images_dc) - end_code_region("predict") - gather_and_print_mem(self.log, "post_forward") - - # Extract the underlying PyTorch local tensors - local_preds = masks_pred_dc - local_labels_5d = true_masks_dc - - # Remove the dummy channel dimension so CE Loss is happy [B, D, H, W] - local_labels = local_labels_5d.squeeze(1) - if self.world_rank == 0: - self.log.debug( - f"Local Preds Shape: {local_preds.shape}" - ) - # Should be something like [1, 6, 128, 128, 64] if sharding Width by 2 - self.log.debug( - f"Local Labels Shape: {local_labels.shape}" - ) - # Should be something like [1, 128, 128, 64] - - begin_code_region("calculate_loss") - # --- SHARDED LOSS CALCULATION --- - current_mem = torch.cuda.memory_allocated() / (1024**3) - self.log.debug( - f"Calculating sharded loss. Mem: {current_mem:.2f} GB." + batch_size, batch_loss, batch_dice_score = ( + self._run_training_batch( + batch, + gather_mem_stats=True, ) - - # Calculate CE and Dice loss in single precision for numerical stability. - with torch.autocast(**self._autocast_kwargs(enabled=False)): - loss_ce = compute_sharded_cross_entropy_loss( - local_preds, - local_labels, - self.spatial_mesh, - self.config.dc_num_shards, - self.amp_device_type, - self.ce_class_weights, - ) - - local_preds_softmax = F.softmax( - local_preds.float(), dim=1 - ) - local_labels_one_hot = ( - F.one_hot( - local_labels, - num_classes=self.config.n_categories + 1, - ) - .permute(0, 4, 1, 2, 3) - .float() - ) - dice_scores = compute_sharded_dice( - local_preds_softmax, - local_labels_one_hot, - self.spatial_mesh, - ) - batch_dice_score = self._foreground_dice_mean( - dice_scores - ) - - # Sum global CE Loss and Dice loss - loss = loss_ce + (1.0 - batch_dice_score) - train_dice_total += batch_dice_score - - end_code_region("calculate_loss") - - gather_and_print_mem(self.log, "pre_backward") - begin_code_region("backward") - self.grad_scaler.scale(loss).backward() - end_code_region("backward") - gather_and_print_mem(self.log, "post_backward") - - begin_code_region("step_and_update") - self.grad_scaler.unscale_(self.optimizer) - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), max_norm=1.0 ) - self.grad_scaler.step(self.optimizer) - gather_and_print_mem(self.log, "after_optim_step") - self.grad_scaler.update() - self.optimizer.zero_grad(set_to_none=False) - end_code_region("step_and_update") + train_dice_total += batch_dice_score # Update the loss begin_code_region("update_loss") - pbar.update(images_dc.shape[0]) + pbar.update(batch_size) self.global_step += 1 # Stay on GPU - epoch_loss += loss.detach() + epoch_loss += batch_loss end_code_region("update_loss") end_code_region("batch_loop") @@ -749,7 +684,9 @@ def train(self): self.config.n_categories, self.config._parallel_strategy, ) - dice_info = torch.tensor([dice_sum, numsamples], dtype=VOLUME_DTYPE) + dice_info = 
@@ -749,7 +684,9 @@ def train(self):
                 self.config.n_categories,
                 self.config._parallel_strategy,
             )
-            dice_info = torch.tensor([dice_sum, numsamples], dtype=VOLUME_DTYPE)
+            dice_info = torch.tensor(
+                [dice_sum, numsamples], dtype=VOLUME_TORCH_DTYPE
+            )
             if self.config.dist:
                 dice_info = dice_info.to(device=self.device)
                 torch.distributed.all_reduce(
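Note: packing `dice_sum` and `numsamples` into one tensor lets a single `all_reduce` produce a globally correct mean (sum of sums over sum of counts) rather than an average of per-rank means. A minimal sketch of that reduction, with the single-process fallback the `if self.config.dist:` guard implies:

```python
import torch
import torch.distributed as dist

def global_mean_dice(dice_sum, numsamples, device):
    info = torch.tensor([dice_sum, numsamples], dtype=torch.float32, device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(info, op=dist.ReduceOp.SUM)  # elementwise sum across ranks
    return (info[0] / info[1]).item()

print(global_mean_dice(7.6, 8, "cpu"))  # 0.95 in the single-process case
```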