Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ScaFFold/configs/benchmark_default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15.
n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3.
val_split: 30 # In percent.
epochs: -1 # Number of training epochs.
starting_learning_rate: 0.01 # Initial learning rate for training.
min_learning_rate: 0.001 # Minimum learning rate for CosineAnnealingWarmRestarts.
starting_learning_rate: 0.001 # Initial learning rate for training.
min_learning_rate: 0.0001 # Minimum learning rate for CosineAnnealingWarmRestarts.
T_0: 100 # Epochs in the first cosine restart cycle.
T_mult: 2 # Restart cycle growth factor.
disable_scheduler: 0 # If 1, disable scheduler during training to use constant LR.
Expand All @@ -33,7 +33,7 @@ framework: "torch" # The DL framework to train with. Only valid
checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints.
loss_freq: 1 # Number of epochs between logging the overall loss.
normalize: 1 # Category search normalization parameter
warmup_batches: 5 # How many warmup batches per rank to run before training.
warmup_batches: 64 # How many warmup batches per rank to run before training.
ce_weight_sample_fraction: 0.1 # Fraction of training masks to sample when estimating background vs foreground CE weights.
dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
target_dice: 0.95
61 changes: 59 additions & 2 deletions ScaFFold/utils/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,46 @@ def wait_for_save(self):
self._log(f"Background save failed with error: {e}")
self.future = None

def snapshot_training_state(self) -> Dict[str, Any]:
    """Capture the mutable in-memory training state without writing a checkpoint.

    Returns:
        Dict containing cloned model/optimizer/scheduler/grad-scaler state
        dicts (``None`` for components that are unset), the model's
        train/eval flag, and the RNG states from ``_get_rng_snapshot``.
    """
    # Unwrap a distributed/data-parallel wrapper so the raw module is snapshotted.
    net = getattr(self.model, "module", self.model)

    def cloned_or_none(component):
        # Clone a component's state dict, or record None when the component is unset.
        return self._clone_state_dict(component.state_dict()) if component else None

    snapshot: Dict[str, Any] = {
        "model_state_dict": self._clone_state_dict(net.state_dict()),
        "optimizer_state_dict": cloned_or_none(self.optimizer),
        "scheduler_state_dict": cloned_or_none(self.scheduler),
        "grad_scaler_state_dict": cloned_or_none(self.grad_scaler),
        "model_training": net.training,
    }
    snapshot.update(self._get_rng_snapshot())
    return snapshot

def restore_training_state(self, snapshot: Dict[str, Any]) -> None:
    """Restore an in-memory training snapshot.

    Args:
        snapshot: Mapping produced by ``snapshot_training_state`` holding the
            cloned state dicts, the train/eval flag, and RNG states. Component
            entries that are missing or ``None`` are skipped.
    """
    net = getattr(self.model, "module", self.model)
    net.load_state_dict(snapshot["model_state_dict"])

    # Reload each optional component only when both the component and its
    # saved state are present.
    for component, key in (
        (self.optimizer, "optimizer_state_dict"),
        (self.scheduler, "scheduler_state_dict"),
        (self.grad_scaler, "grad_scaler_state_dict"),
    ):
        state = snapshot.get(key)
        if component and state is not None:
            component.load_state_dict(state)

    self._restore_rng(snapshot)
    net.train(snapshot.get("model_training", True))

    # Drop any gradients accumulated since the snapshot was taken.
    if self.optimizer:
        self.optimizer.zero_grad(set_to_none=True)

def load_from_checkpoint(self) -> int:
"""Load the latest checkpoint. Returns start_epoch (default 1)."""
self.wait_for_save() # Safety: don't load while writing
Expand Down Expand Up @@ -285,6 +325,19 @@ def _transfer_dict_to_cpu(self, obj):
else:
return obj

def _clone_state_dict(self, obj):
"""Recursively clone tensors so in-memory snapshots are isolated."""
if torch.is_tensor(obj):
return obj.detach().clone()
elif isinstance(obj, dict):
return {k: self._clone_state_dict(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [self._clone_state_dict(v) for v in obj]
elif isinstance(obj, tuple):
return tuple(self._clone_state_dict(v) for v in obj)
else:
return obj

def _barrier(self):
if self.dist_enabled:
dist.barrier()
Expand All @@ -305,7 +358,7 @@ def _log(self, msg):
def _get_rng_snapshot(self) -> Dict[str, Any]:
snap = {"rng_state_pytorch": torch.get_rng_state()}
if torch.cuda.is_available():
snap["rng_state_pytorch_cuda"] = torch.cuda.get_rng_state()
snap["rng_state_pytorch_cuda"] = torch.cuda.get_rng_state_all()
try:
snap["rng_state_numpy"] = np.random.get_state()
except ImportError:
Expand All @@ -321,7 +374,11 @@ def _restore_rng(self, snap: Dict[str, Any]):
if "rng_state_pytorch" in snap:
torch.set_rng_state(snap["rng_state_pytorch"])
if "rng_state_pytorch_cuda" in snap and torch.cuda.is_available():
torch.cuda.set_rng_state(snap["rng_state_pytorch_cuda"])
cuda_state = snap["rng_state_pytorch_cuda"]
if isinstance(cuda_state, list):
torch.cuda.set_rng_state_all(cuda_state)
else:
torch.cuda.set_rng_state(cuda_state)
if "rng_state_numpy" in snap:
np.random.set_state(snap["rng_state_numpy"])
if "rng_state_python" in snap:
Expand Down
5 changes: 4 additions & 1 deletion ScaFFold/utils/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
# Masks are values 0 <= x <= n_categories
MASK_DTYPE = np.uint16
# Volumes/img are 0 <= x <= 1
VOLUME_DTYPE = np.float32
VOLUME_DTYPE_NAME = "float32"
VOLUME_NP_DTYPE = getattr(np, VOLUME_DTYPE_NAME)
VOLUME_TORCH_DTYPE = getattr(torch, VOLUME_DTYPE_NAME)
VOLUME_DTYPE = VOLUME_NP_DTYPE

# Shared AMP dtype selection for torch.autocast.
AMP_DTYPE = torch.bfloat16
30 changes: 22 additions & 8 deletions ScaFFold/utils/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,15 @@
@annotate()
@torch.inference_mode()
def evaluate(
net, dataloader, device, amp, primary, criterion, n_categories, parallel_strategy
net,
dataloader,
device,
amp,
primary,
criterion,
n_categories,
parallel_strategy,
max_batches=None,
):
def foreground_dice_stats(dice_scores):
if dice_scores.size(1) > 1:
Expand All @@ -41,6 +49,8 @@ def foreground_dice_stats(dice_scores):
if amp:
autocast_kwargs["dtype"] = AMP_DTYPE
num_val_batches = len(dataloader)
if max_batches is not None:
num_val_batches = min(num_val_batches, max_batches)
total_dice_score = 0.0
processed_batches = 0
processed_samples = 0
Expand All @@ -55,14 +65,18 @@ def foreground_dice_stats(dice_scores):
with torch.autocast(**autocast_kwargs):
val_loss_epoch = 0.0
class_weights = getattr(criterion, "weight", None)
for batch in tqdm(
dataloader,
total=num_val_batches,
desc="Validation round",
unit="batch",
leave=False,
disable=not primary,
for batch_idx, batch in enumerate(
tqdm(
dataloader,
total=num_val_batches,
desc="Validation round",
unit="batch",
leave=False,
disable=not primary,
)
):
if batch_idx >= num_val_batches:
break
image, mask_true = batch["image"], batch["mask"]

image = image.to(
Expand Down
Loading
Loading