1 change: 1 addition & 0 deletions ScaFFold/configs/benchmark_default.yml
@@ -34,5 +34,6 @@ checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpo
loss_freq: 1 # Number of epochs between logging the overall loss.
normalize: 1 # Category search normalization parameter
warmup_batches: 5 # How many warmup batches per rank to run before training.
ce_weight_sample_fraction: 0.1 # Fraction of training masks to sample when estimating background vs foreground CE weights.
dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
target_dice: 0.95
1 change: 1 addition & 0 deletions ScaFFold/configs/benchmark_testing.yml
@@ -34,5 +34,6 @@ checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpo
loss_freq: 1 # Number of epochs between logging the overall loss.
normalize: 1 # Category search normalization parameter
warmup_batches: 5 # How many warmup batches per rank to run before training.
ce_weight_sample_fraction: 0.1 # Fraction of training masks to sample when estimating background vs foreground CE weights.
dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
target_dice: 0.95
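The new `ce_weight_sample_fraction` key controls how many training masks are scanned when estimating the CE class weights. As a rough illustration (not part of the PR, with a hypothetical dataset size of 500), the fraction maps to a mask count the same way the sampling helper added in `ScaFFold/utils/losses.py` further down does:

```python
import math

n_train = 500                    # hypothetical number of training masks
ce_weight_sample_fraction = 0.1  # value from the config above
sample_count = min(max(math.ceil(n_train * ce_weight_sample_fraction), 1), n_train)
print(sample_count)  # 50 masks, taken at evenly spaced indices across the dataset
```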
3 changes: 3 additions & 0 deletions ScaFFold/utils/config_utils.py
@@ -71,6 +71,9 @@ def __init__(self, config_dict):
self.checkpoint_dir = config_dict["checkpoint_dir"]
self.normalize = config_dict["normalize"]
self.warmup_batches = config_dict.get("warmup_batches")
self.ce_weight_sample_fraction = config_dict.get(
"ce_weight_sample_fraction", 0.1
)
self.dataset_reuse_enforce_commit_id = config_dict[
"dataset_reuse_enforce_commit_id"
]
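Because older configs may not define the new key, reading it with `.get` and a default keeps them working. A two-line illustration (hypothetical config dict, not from the PR):

```python
config_dict = {"normalize": 1, "dataset_reuse_enforce_commit_id": 0}
print(config_dict.get("ce_weight_sample_fraction", 0.1))  # falls back to 0.1
```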
30 changes: 11 additions & 19 deletions ScaFFold/utils/evaluate.py
@@ -17,11 +17,9 @@
from distconv import DCTensor
from tqdm import tqdm

from ScaFFold.utils.data_types import AMP_DTYPE, VOLUME_DTYPE
from ScaFFold.utils.dice_score import (
SpatialAllReduce,
compute_sharded_dice,
)
from ScaFFold.utils.data_types import AMP_DTYPE
from ScaFFold.utils.dice_score import compute_sharded_dice
from ScaFFold.utils.losses import compute_sharded_cross_entropy_loss
from ScaFFold.utils.perf_measure import annotate


@@ -56,6 +54,7 @@ def foreground_dice_stats(dice_scores):

with torch.autocast(**autocast_kwargs):
val_loss_epoch = 0.0
class_weights = getattr(criterion, "weight", None)
for batch in tqdm(
dataloader,
total=num_val_batches,
@@ -94,22 +93,15 @@

# Calculate CE and Dice loss in single precision for numerical stability.
with torch.autocast(device_type=autocast_device_type, enabled=False):
# Compute global CE loss from sharded CE loss
local_ce_sum = F.cross_entropy(
local_preds.float(), local_labels, reduction="sum"
CE_loss = compute_sharded_cross_entropy_loss(
local_preds,
local_labels,
spatial_mesh,
parallel_strategy.num_shards,
autocast_device_type,
class_weights,
)
global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh)
local_voxel_count = torch.tensor(
float(local_labels.numel()),
device=local_labels.device,
dtype=VOLUME_DTYPE,
)
global_total_voxels = SpatialAllReduce.apply(
local_voxel_count, spatial_mesh
)
CE_loss = global_ce_sum / global_total_voxels

# Compute global dice loss from sharded dice loss
mask_pred_probs = F.softmax(local_preds.float(), dim=1)
mask_true_onehot = (
F.one_hot(local_labels, n_categories + 1)
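The `class_weights = getattr(criterion, "weight", None)` line added above is safe for either criterion the trainer can construct: the attribute is `None` for an unweighted `CrossEntropyLoss` and for the default `BCEWithLogitsLoss`, so unweighted runs keep the old behavior. A quick sketch (not from the PR):

```python
import torch
import torch.nn as nn

print(getattr(nn.CrossEntropyLoss(), "weight", None))                                 # None
print(getattr(nn.CrossEntropyLoss(weight=torch.tensor([1.0, 4.0])), "weight", None))  # tensor([1., 4.])
print(getattr(nn.BCEWithLogitsLoss(), "weight", None))                                # None
```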
166 changes: 166 additions & 0 deletions ScaFFold/utils/losses.py
@@ -0,0 +1,166 @@
# Copyright (c) 2014-2026, Lawrence Livermore National Security, LLC.
# Produced at the Lawrence Livermore National Laboratory.
# Written by the LBANN Research Team (B. Van Essen, et al.) listed in
# the CONTRIBUTORS file. See the top-level LICENSE file for details.
#
# LLNL-CODE-697807.
# All rights reserved.
#
# This file is part of LBANN: Livermore Big Artificial Neural Network
# Toolkit. For details, see http://software.llnl.gov/LBANN or
# https://github.com/LBANN and https://github.com/LBANN/ScaFFold.
#
# SPDX-License-Identifier: (Apache-2.0)

import math

import torch
import torch.distributed as dist
import torch.nn.functional as F

from ScaFFold.utils.dice_score import SpatialAllReduce


def _sample_ce_weight_indices(n_train, sample_fraction):
"""Pick a small, deterministic subset of masks to estimate CE weights."""
if n_train <= 0:
return []

if sample_fraction is None:
sample_fraction = 0.1

sample_count = min(
max(math.ceil(n_train * float(sample_fraction)), 1),
n_train,
)
if sample_count == n_train:
return list(range(n_train))

return torch.linspace(0, n_train - 1, steps=sample_count).long().tolist()


def _compute_ce_class_weights(
train_set,
n_train,
n_categories,
device,
sample_fraction=0.1,
dist_enabled=False,
world_rank=0,
log=None,
):
"""
Estimate background vs foreground CE weights from a few training masks.

Background keeps its own inverse-frequency weight, and every non-zero
fractal class shares the foreground weight derived from the aggregate
non-empty voxel count.
"""

num_classes = n_categories + 1
class_weights = torch.ones(num_classes, device=device, dtype=torch.float32)

if n_train == 0:
if log is not None:
log.warning(
"Training set is empty while computing CE class weights. Falling back to uniform weights."
)
return class_weights

sample_indices = _sample_ce_weight_indices(n_train, sample_fraction)
sampled_class_counts = torch.zeros(num_classes, dtype=torch.long)

for sample_idx in sample_indices:
mask = train_set[sample_idx]["mask"]
sampled_class_counts += torch.bincount(mask.reshape(-1), minlength=num_classes)

# The dataset may already return only this rank's local spatial shard,
# so combine per-rank counts before deriving the global CE weights.
sampled_class_counts = sampled_class_counts.to(device=device)
if dist_enabled:
dist.all_reduce(sampled_class_counts, op=dist.ReduceOp.SUM)

background_voxels = int(sampled_class_counts[0].item())
foreground_voxels = int(sampled_class_counts[1:].sum().item())

if background_voxels > 0 and foreground_voxels > 0:
total_voxels = background_voxels + foreground_voxels
class_weights[0] = total_voxels / background_voxels
class_weights[1:] = total_voxels / foreground_voxels
elif log is not None:
log.warning(
"Sampled masks did not contain both background and foreground voxels. Falling back to uniform CE weights."
)

if log is not None and (not dist_enabled or world_rank == 0):
log.info(
f"CE weights estimated from {len(sample_indices)} training masks "
f"(sample_fraction={sample_fraction}, indices={sample_indices}): "
f"background_voxels={background_voxels} "
f"foreground_voxels={foreground_voxels} "
f"weights={class_weights.detach().cpu().tolist()}"
)

return class_weights


def compute_sharded_cross_entropy_loss(
local_preds,
local_labels,
spatial_mesh,
_num_shards,
device_type,
class_weights=None,
):
"""
Compute the CE loss for a spatially sharded volume.

Each rank only sees a local spatial shard, so we cannot use the local
`reduction="mean"` result directly. Instead we:
1. compute the local CE numerator with `reduction="sum"`,
2. build the correct global denominator,
3. all-reduce across the spatial mesh, and
4. divide to recover the same value we would get from a non-sharded tensor.

When `class_weights` is provided, PyTorch's CE "mean" divides by the sum of
the target weights, not the raw voxel count, so we reproduce that behavior
explicitly here.
"""

autocast_device = device_type if device_type != "mps" else "cpu"
with torch.autocast(autocast_device, enabled=False):
# Accumulate CE in full precision. Using reduction="sum" gives us the
# numerator of the final global mean; if class weights are present,
# PyTorch applies the target-class weight to each voxel here.
local_ce_sum = F.cross_entropy(
local_preds.float(),
local_labels,
weight=class_weights,
reduction="sum",
)

if class_weights is None:
# Sum the actual local voxel counts across spatial shards. We use
# an all-reduced count instead of numel()*num_shards because shard
# sizes can differ at chunk boundaries.
local_voxel_count = local_ce_sum.new_tensor(float(local_labels.numel()))
global_normalizer = SpatialAllReduce.apply(local_voxel_count, spatial_mesh)
else:
# Weighted CE divides by sum(weight[target_i]) over all voxels.
# Build that denominator from the local label histogram, then
# all-reduce it across the spatial mesh.
local_class_counts = torch.bincount(
local_labels.reshape(-1), minlength=class_weights.numel()
).to(dtype=local_ce_sum.dtype)
local_weight_sum = torch.dot(
local_class_counts, class_weights.to(dtype=local_ce_sum.dtype)
)
global_normalizer = SpatialAllReduce.apply(local_weight_sum, spatial_mesh)

# Sum the local CE numerators from each spatial shard to get the global CE
# numerator, then divide by the matching global denominator.
global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh)
# Clamp to avoid a divide-by-zero in degenerate cases.
return global_ce_sum / global_normalizer.clamp_min(
torch.finfo(global_ce_sum.dtype).eps
)
80 changes: 38 additions & 42 deletions ScaFFold/utils/trainer.py
Expand Up @@ -12,6 +12,7 @@
#
# SPDX-License-Identifier: (Apache-2.0)

# Standard library
import math
import os
import shutil
@@ -30,14 +31,15 @@
from ScaFFold.utils.checkpointing import CheckpointManager
from ScaFFold.utils.data_loading import FractalDataset, SpatialShardSpec
from ScaFFold.utils.data_types import AMP_DTYPE, VOLUME_DTYPE
from ScaFFold.utils.dice_score import (
SpatialAllReduce,
compute_sharded_dice,
)
from ScaFFold.utils.dice_score import compute_sharded_dice
from ScaFFold.utils.distributed import get_local_rank, get_world_rank, get_world_size

# Local
from ScaFFold.utils.evaluate import evaluate
from ScaFFold.utils.losses import (
_compute_ce_class_weights,
compute_sharded_cross_entropy_loss,
)
from ScaFFold.utils.perf_measure import adiak_value, begin_code_region, end_code_region
from ScaFFold.utils.utils import gather_and_print_mem

@@ -72,6 +74,7 @@ def __init__(self, model, config, device, log):
self.scheduler = None
self.grad_scaler = None
self.criterion = None
self.ce_class_weights = None
self.global_step = 0
self.start_epoch = -1
self.ps = getattr(self.config, "_parallel_strategy", None)
@@ -220,11 +223,24 @@ def setup_training_components(self):
self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.use_grad_scaler)

# Set up loss function
self.criterion = (
nn.CrossEntropyLoss()
if self.config.n_categories + 1 > 1
else nn.BCEWithLogitsLoss()
)
if self.config.n_categories + 1 > 1:
ce_class_weights = _compute_ce_class_weights(
train_set=self.train_set,
n_train=self.n_train,
n_categories=self.config.n_categories,
device=self.device,
sample_fraction=self.config.ce_weight_sample_fraction,
dist_enabled=self.config.dist,
world_rank=self.world_rank,
log=self.log,
)
self.criterion = nn.CrossEntropyLoss(weight=ce_class_weights).to(
self.device
)
else:
self.criterion = nn.BCEWithLogitsLoss().to(self.device)
if isinstance(self.criterion, nn.CrossEntropyLoss):
self.ce_class_weights = self.criterion.weight

self.log.info(
f"Optimizer: {self.optimizer}, Scheduler: {self.scheduler}, AMP dtype: {self.amp_dtype}, Gradient Scaler Enabled: {self.use_grad_scaler}"
@@ -464,24 +480,15 @@ def warmup(self):

# Calculate CE and Dice loss in single precision for numerical stability.
with torch.autocast(**self._autocast_kwargs(enabled=False)):
# Compute global CE loss from sharded CE loss
local_ce_sum = F.cross_entropy(
local_preds.float(), local_labels, reduction="sum"
loss_ce = compute_sharded_cross_entropy_loss(
local_preds,
local_labels,
self.spatial_mesh,
self.config.dc_num_shards,
self.amp_device_type,
self.ce_class_weights,
)
global_ce_sum = SpatialAllReduce.apply(
local_ce_sum, self.spatial_mesh
)
local_voxel_count = torch.tensor(
float(local_labels.numel()),
device=local_labels.device,
dtype=VOLUME_DTYPE,
)
global_total_voxels = SpatialAllReduce.apply(
local_voxel_count, self.spatial_mesh
)
loss_ce = global_ce_sum / global_total_voxels

# Compute global dice loss from sharded dice loss
local_preds_softmax = F.softmax(local_preds.float(), dim=1)
local_labels_one_hot = (
F.one_hot(
@@ -659,26 +666,15 @@ def train(self):

# Calculate CE and Dice loss in single precision for numerical stability.
with torch.autocast(**self._autocast_kwargs(enabled=False)):
# Compute global CE loss from sharded CE loss
local_ce_sum = F.cross_entropy(
local_preds.float(),
loss_ce = compute_sharded_cross_entropy_loss(
local_preds,
local_labels,
reduction="sum",
)
global_ce_sum = SpatialAllReduce.apply(
local_ce_sum, self.spatial_mesh
)
local_voxel_count = torch.tensor(
float(local_labels.numel()),
device=local_labels.device,
dtype=VOLUME_DTYPE,
)
global_total_voxels = SpatialAllReduce.apply(
local_voxel_count, self.spatial_mesh
self.spatial_mesh,
self.config.dc_num_shards,
self.amp_device_type,
self.ce_class_weights,
)
loss_ce = global_ce_sum / global_total_voxels

# Compute global dice loss from sharded dice loss
local_preds_softmax = F.softmax(
local_preds.float(), dim=1
)
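A hypothetical single-process sketch (not part of the PR) of the wiring in `setup_training_components`: estimate the class weights from a handful of masks, then build the weighted criterion. The synthetic random masks stand in for `FractalDataset` samples.

```python
import torch
import torch.nn as nn

from ScaFFold.utils.losses import _compute_ce_class_weights

n_categories = 3  # three fractal classes plus background => 4 CE classes
train_set = [
    {"mask": torch.randint(0, n_categories + 1, (16, 16, 16))} for _ in range(40)
]

ce_class_weights = _compute_ce_class_weights(
    train_set=train_set,
    n_train=len(train_set),
    n_categories=n_categories,
    device=torch.device("cpu"),
    sample_fraction=0.1,  # ceil(40 * 0.1) = 4 evenly spaced masks
    dist_enabled=False,
)
criterion = nn.CrossEntropyLoss(weight=ce_class_weights)
print(ce_class_weights)  # index 0: background weight; indices 1-3: shared foreground weight
```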