From 05ad2ba4f5ffa7be3e068b76d9efeb59183cc3d7 Mon Sep 17 00:00:00 2001
From: Patrick Miles
Date: Thu, 16 Apr 2026 08:52:00 -0700
Subject: [PATCH 01/11] use class weights in CE loss to make background less
 dominant; calc weights at trainer init

---
 ScaFFold/configs/benchmark_default.yml |   1 +
 ScaFFold/configs/benchmark_testing.yml |   1 +
 ScaFFold/utils/config_utils.py         |   1 +
 ScaFFold/utils/evaluate.py             |  30 +++---
 ScaFFold/utils/losses.py               |  82 ++++++++++++++++
 ScaFFold/utils/trainer.py              | 127 +++++++++++++++++--------
 6 files changed, 183 insertions(+), 59 deletions(-)
 create mode 100644 ScaFFold/utils/losses.py

diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml
index ac54d54..14e5f62 100644
--- a/ScaFFold/configs/benchmark_default.yml
+++ b/ScaFFold/configs/benchmark_default.yml
@@ -34,5 +34,6 @@ checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpo
 loss_freq: 1 # Number of epochs between logging the overall loss.
 normalize: 1 # Category search normalization parameter
 warmup_batches: 5 # How many warmup batches per rank to run before training.
+ce_weight_num_samples: 8 # How many training masks to sample when estimating background vs foreground CE weights.
 dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
 target_dice: 0.95
diff --git a/ScaFFold/configs/benchmark_testing.yml b/ScaFFold/configs/benchmark_testing.yml
index 8b72435..15cb264 100644
--- a/ScaFFold/configs/benchmark_testing.yml
+++ b/ScaFFold/configs/benchmark_testing.yml
@@ -34,5 +34,6 @@ checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpo
 loss_freq: 1 # Number of epochs between logging the overall loss.
 normalize: 1 # Category search normalization parameter
 warmup_batches: 5 # How many warmup batches per rank to run before training.
+ce_weight_num_samples: 8 # How many training masks to sample when estimating background vs foreground CE weights.
 dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
 target_dice: 0.95
diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py
index 2c73a6c..98421e0 100644
--- a/ScaFFold/utils/config_utils.py
+++ b/ScaFFold/utils/config_utils.py
@@ -71,6 +71,7 @@ def __init__(self, config_dict):
         self.checkpoint_dir = config_dict["checkpoint_dir"]
         self.normalize = config_dict["normalize"]
         self.warmup_batches = config_dict.get("warmup_batches")
+        self.ce_weight_num_samples = config_dict.get("ce_weight_num_samples", 8)
         self.dataset_reuse_enforce_commit_id = config_dict[
             "dataset_reuse_enforce_commit_id"
         ]
diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py
index 2fd3eb1..67b01da 100644
--- a/ScaFFold/utils/evaluate.py
+++ b/ScaFFold/utils/evaluate.py
@@ -17,11 +17,9 @@
 from distconv import DCTensor
 from tqdm import tqdm
 
-from ScaFFold.utils.data_types import AMP_DTYPE, VOLUME_DTYPE
-from ScaFFold.utils.dice_score import (
-    SpatialAllReduce,
-    compute_sharded_dice,
-)
+from ScaFFold.utils.data_types import AMP_DTYPE
+from ScaFFold.utils.dice_score import compute_sharded_dice
+from ScaFFold.utils.losses import compute_sharded_cross_entropy_loss
 from ScaFFold.utils.perf_measure import annotate
 
 
@@ -56,6 +54,7 @@ def foreground_dice_stats(dice_scores):
 
     with torch.autocast(**autocast_kwargs):
         val_loss_epoch = 0.0
+        class_weights = getattr(criterion, "weight", None)
         for batch in tqdm(
             dataloader,
             total=num_val_batches,
@@ -94,22 +93,15 @@
             # Calculate CE and Dice loss in single precision for numerical stability.
             with torch.autocast(device_type=autocast_device_type, enabled=False):
-                # Compute global CE loss from sharded CE loss
-                local_ce_sum = F.cross_entropy(
-                    local_preds.float(), local_labels, reduction="sum"
+                CE_loss = compute_sharded_cross_entropy_loss(
+                    local_preds,
+                    local_labels,
+                    spatial_mesh,
+                    parallel_strategy.num_shards,
+                    autocast_device_type,
+                    class_weights,
                 )
-                global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh)
-                local_voxel_count = torch.tensor(
-                    float(local_labels.numel()),
-                    device=local_labels.device,
-                    dtype=VOLUME_DTYPE,
-                )
-                global_total_voxels = SpatialAllReduce.apply(
-                    local_voxel_count, spatial_mesh
-                )
-                CE_loss = global_ce_sum / global_total_voxels
 
-                # Compute global dice loss from sharded dice loss
                 mask_pred_probs = F.softmax(local_preds.float(), dim=1)
                 mask_true_onehot = (
                     F.one_hot(local_labels, n_categories + 1)
diff --git a/ScaFFold/utils/losses.py b/ScaFFold/utils/losses.py
new file mode 100644
index 0000000..bddc55a
--- /dev/null
+++ b/ScaFFold/utils/losses.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2014-2026, Lawrence Livermore National Security, LLC.
+# Produced at the Lawrence Livermore National Laboratory.
+# Written by the LBANN Research Team (B. Van Essen, et al.) listed in
+# the CONTRIBUTORS file. See the top-level LICENSE file for details.
+#
+# LLNL-CODE-697807.
+# All rights reserved.
+#
+# This file is part of LBANN: Livermore Big Artificial Neural Network
+# Toolkit. For details, see http://software.llnl.gov/LBANN or
+# https://github.com/LBANN and https://github.com/LBANN/ScaFFold.
+#
+# SPDX-License-Identifier: (Apache-2.0)
+
+import torch
+import torch.nn.functional as F
+
+from ScaFFold.utils.dice_score import SpatialAllReduce
+
+
+def compute_sharded_cross_entropy_loss(
+    local_preds,
+    local_labels,
+    spatial_mesh,
+    _num_shards,
+    device_type,
+    class_weights=None,
+):
+    """
+    Compute the CE loss for a spatially sharded volume.
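+
+    Illustrative check (made-up numbers): if two shards hold 3 and 5 voxels
+    with unweighted local CE sums of 1.2 and 2.0, the recovered global mean
+    is (1.2 + 2.0) / (3 + 5) = 0.4, exactly what `reduction="mean"` would
+    return on the full, unsharded tensor.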
+
+    Each rank only sees a local spatial shard, so we cannot use the local
+    `reduction="mean"` result directly. Instead we:
+      1. compute the local CE numerator with `reduction="sum"`,
+      2. build the correct global denominator,
+      3. all-reduce across the spatial mesh, and
+      4. divide to recover the same value we would get from a non-sharded tensor.
+
+    When `class_weights` is provided, PyTorch's CE "mean" divides by the sum of
+    the target weights, not the raw voxel count, so we reproduce that behavior
+    explicitly here.
+    """
+
+    autocast_device = device_type if device_type != "mps" else "cpu"
+    with torch.autocast(autocast_device, enabled=False):
+        # Accumulate CE in full precision. Using reduction="sum" gives us the
+        # numerator of the final global mean; if class weights are present,
+        # PyTorch applies the target-class weight to each voxel here.
+        local_ce_sum = F.cross_entropy(
+            local_preds.float(),
+            local_labels,
+            weight=class_weights,
+            reduction="sum",
+        )
+
+        if class_weights is None:
+            # Sum the actual local voxel counts across spatial shards. We use
+            # an all-reduced count instead of numel()*num_shards because shard
+            # sizes can differ at chunk boundaries.
+            local_voxel_count = local_ce_sum.new_tensor(float(local_labels.numel()))
+            global_normalizer = SpatialAllReduce.apply(
+                local_voxel_count, spatial_mesh
+            )
+        else:
+            # Weighted CE divides by sum(weight[target_i]) over all voxels.
+            # Build that denominator from the local label histogram, then
+            # all-reduce it across the spatial mesh.
+            local_class_counts = torch.bincount(
+                local_labels.reshape(-1), minlength=class_weights.numel()
+            ).to(dtype=local_ce_sum.dtype)
+            local_weight_sum = torch.dot(
+                local_class_counts, class_weights.to(dtype=local_ce_sum.dtype)
+            )
+            global_normalizer = SpatialAllReduce.apply(local_weight_sum, spatial_mesh)
+
+        # Sum the local CE numerators from each spatial shard to get the global CE
+        # numerator, then divide by the matching global denominator.
+        global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh)
+        # Clamp to avoid a divide-by-zero in degenerate cases.
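+        # (A zero normalizer can only occur when every voxel falls in a
+        # zero-weight class; the weighted numerator is then zero as well,
+        # so the clamped division yields a loss of 0.)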
+        return global_ce_sum / global_normalizer.clamp_min(
+            torch.finfo(global_ce_sum.dtype).eps
+        )
diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index 654cc19..f0f77d7 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -12,6 +12,7 @@
 #
 # SPDX-License-Identifier: (Apache-2.0)
 
+# Standard library
 import math
 import os
 import shutil
@@ -22,6 +23,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist
 from distconv import DCTensor
 from torch import optim
 from torch.utils.data import DataLoader
@@ -29,12 +31,10 @@
 
 from ScaFFold.utils.checkpointing import CheckpointManager
 from ScaFFold.utils.data_loading import FractalDataset, SpatialShardSpec
-from ScaFFold.utils.data_types import AMP_DTYPE, VOLUME_DTYPE
-from ScaFFold.utils.dice_score import (
-    SpatialAllReduce,
-    compute_sharded_dice,
-)
+from ScaFFold.utils.data_types import AMP_DTYPE
+from ScaFFold.utils.dice_score import compute_sharded_dice
 from ScaFFold.utils.distributed import get_local_rank, get_world_rank, get_world_size
+from ScaFFold.utils.losses import compute_sharded_cross_entropy_loss
 
 # Local
 from ScaFFold.utils.evaluate import evaluate
@@ -72,6 +72,7 @@ def __init__(self, model, config, device, log):
         self.scheduler = None
         self.grad_scaler = None
         self.criterion = None
+        self.ce_class_weights = None
         self.global_step = 0
         self.start_epoch = -1
         self.ps = getattr(self.config, "_parallel_strategy", None)
@@ -183,6 +184,70 @@ def create_dataloaders(self):
                 "Reduce batch_size or adjust validation sharding."
             )
 
+    def _sample_ce_weight_indices(self):
+        """Pick a small, deterministic subset of masks to estimate CE weights."""
+        requested_samples = int(getattr(self.config, "ce_weight_num_samples", 8) or 8)
+        sample_count = min(max(requested_samples, 1), self.n_train)
+        if sample_count == self.n_train:
+            return list(range(self.n_train))
+
+        return torch.linspace(0, self.n_train - 1, steps=sample_count).long().tolist()
+
+    def _compute_ce_class_weights(self):
+        """
+        Estimate background vs foreground CE weights from a few training masks.
+
+        Background keeps its own inverse-frequency weight, and every non-zero
+        fractal class shares the foreground weight derived from the aggregate
+        non-empty voxel count.
+        """
+
+        num_classes = self.config.n_categories + 1
+        class_weights = torch.ones(num_classes, device=self.device, dtype=torch.float32)
+
+        if self.n_train == 0:
+            self.log.warning(
+                "Training set is empty while computing CE class weights. Falling back to uniform weights."
+            )
+            return class_weights
+
+        sample_indices = self._sample_ce_weight_indices()
+        sampled_class_counts = torch.zeros(num_classes, dtype=torch.long)
+
+        for sample_idx in sample_indices:
+            mask = self.train_set[sample_idx]["mask"]
+            sampled_class_counts += torch.bincount(
+                mask.reshape(-1), minlength=num_classes
+            )
+
+        # The dataset may already return only this rank's local spatial shard,
+        # so combine per-rank counts before deriving the global CE weights.
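+        # (Every rank must reach the all_reduce below when dist is enabled;
+        # the deterministic linspace sampling above keeps the sampled indices
+        # identical across ranks, assuming n_train matches on every rank.)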
+        sampled_class_counts = sampled_class_counts.to(device=self.device)
+        if self.config.dist:
+            dist.all_reduce(sampled_class_counts, op=dist.ReduceOp.SUM)
+
+        background_voxels = int(sampled_class_counts[0].item())
+        foreground_voxels = int(sampled_class_counts[1:].sum().item())
+
+        if background_voxels > 0 and foreground_voxels > 0:
+            total_voxels = background_voxels + foreground_voxels
+            class_weights[0] = total_voxels / background_voxels
+            class_weights[1:] = total_voxels / foreground_voxels
+        else:
+            self.log.warning(
+                "Sampled masks did not contain both background and foreground voxels. Falling back to uniform CE weights."
+            )
+
+        if not self.config.dist or self.world_rank == 0:
+            self.log.info(
+                f"CE weights estimated from {len(sample_indices)} training masks "
+                f"(indices={sample_indices}): background_voxels={background_voxels} "
+                f"foreground_voxels={foreground_voxels} "
+                f"weights={class_weights.detach().cpu().tolist()}"
+            )
+
+        return class_weights
+
     def setup_training_components(self):
         """Set up the optimizer, scheduler, gradient scaler, and loss function."""
         # Set up optimizer
@@ -221,10 +286,12 @@ def setup_training_components(self):
         self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.use_grad_scaler)
 
         # Set up loss function
         self.criterion = (
-            nn.CrossEntropyLoss()
+            nn.CrossEntropyLoss(weight=self._compute_ce_class_weights()).to(self.device)
             if self.config.n_categories + 1 > 1
-            else nn.BCEWithLogitsLoss()
+            else nn.BCEWithLogitsLoss().to(self.device)
         )
+        if isinstance(self.criterion, nn.CrossEntropyLoss):
+            self.ce_class_weights = self.criterion.weight
 
         self.log.info(
             f"Optimizer: {self.optimizer}, Scheduler: {self.scheduler}, AMP dtype: {self.amp_dtype}, Gradient Scaler Enabled: {self.use_grad_scaler}"
         )
@@ -464,24 +531,15 @@ def warmup(self):
                 # Calculate CE and Dice loss in single precision for numerical stability.
                 with torch.autocast(**self._autocast_kwargs(enabled=False)):
-                    # Compute global CE loss from sharded CE loss
-                    local_ce_sum = F.cross_entropy(
-                        local_preds.float(), local_labels, reduction="sum"
+                    loss_ce = compute_sharded_cross_entropy_loss(
+                        local_preds,
+                        local_labels,
+                        self.spatial_mesh,
+                        self.config.dc_num_shards,
+                        self.amp_device_type,
+                        self.ce_class_weights,
                     )
-                    global_ce_sum = SpatialAllReduce.apply(
-                        local_ce_sum, self.spatial_mesh
-                    )
-                    local_voxel_count = torch.tensor(
-                        float(local_labels.numel()),
-                        device=local_labels.device,
-                        dtype=VOLUME_DTYPE,
-                    )
-                    global_total_voxels = SpatialAllReduce.apply(
-                        local_voxel_count, self.spatial_mesh
-                    )
-                    loss_ce = global_ce_sum / global_total_voxels
 
-                    # Compute global dice loss from sharded dice loss
                     local_preds_softmax = F.softmax(local_preds.float(), dim=1)
                     local_labels_one_hot = (
                         F.one_hot(
@@ -659,26 +717,15 @@
                             # Calculate CE and Dice loss in single precision for numerical stability.
                             with torch.autocast(**self._autocast_kwargs(enabled=False)):
-                                # Compute global CE loss from sharded CE loss
-                                local_ce_sum = F.cross_entropy(
-                                    local_preds.float(),
+                                loss_ce = compute_sharded_cross_entropy_loss(
+                                    local_preds,
                                     local_labels,
-                                    reduction="sum",
-                                )
-                                global_ce_sum = SpatialAllReduce.apply(
-                                    local_ce_sum, self.spatial_mesh
-                                )
-                                local_voxel_count = torch.tensor(
-                                    float(local_labels.numel()),
-                                    device=local_labels.device,
-                                    dtype=VOLUME_DTYPE,
-                                )
-                                global_total_voxels = SpatialAllReduce.apply(
-                                    local_voxel_count, self.spatial_mesh
+                                    self.spatial_mesh,
+                                    self.config.dc_num_shards,
+                                    self.amp_device_type,
+                                    self.ce_class_weights,
                                 )
-                                loss_ce = global_ce_sum / global_total_voxels
 
-                                # Compute global dice loss from sharded dice loss
                                 local_preds_softmax = F.softmax(
                                     local_preds.float(), dim=1
                                 )

From c0807d8f332424a9e983b0cd7830c4ccd696496f Mon Sep 17 00:00:00 2001
From: Patrick Miles
Date: Wed, 22 Apr 2026 16:06:40 -0700
Subject: [PATCH 02/11] ruff

---
 ScaFFold/utils/trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index f0f77d7..d1283a2 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -21,9 +21,9 @@
 
 # Third party
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.distributed as dist
 from distconv import DCTensor
 from torch import optim
 from torch.utils.data import DataLoader
@@ -34,10 +34,10 @@
 from ScaFFold.utils.data_types import AMP_DTYPE
 from ScaFFold.utils.dice_score import compute_sharded_dice
 from ScaFFold.utils.distributed import get_local_rank, get_world_rank, get_world_size
-from ScaFFold.utils.losses import compute_sharded_cross_entropy_loss
 
 # Local
 from ScaFFold.utils.evaluate import evaluate
+from ScaFFold.utils.losses import compute_sharded_cross_entropy_loss
 from ScaFFold.utils.perf_measure import adiak_value, begin_code_region, end_code_region
 from ScaFFold.utils.utils import gather_and_print_mem

From 194b915a249f771fad30517181397a6a8963f1ef Mon Sep 17 00:00:00 2001
From: Patrick Miles
Date: Wed, 22 Apr 2026 16:21:15 -0700
Subject: [PATCH 03/11] missing import

---
 ScaFFold/utils/trainer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index d1283a2..b2053a8 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -17,6 +17,7 @@
 import os
 import shutil
 import time
+import math
 from pathlib import Path
 
 # Third party

From a6b7c85e31c2a31a67a49e984b84d647af1d3b17 Mon Sep 17 00:00:00 2001
From: Patrick Miles
Date: Wed, 22 Apr 2026 16:21:44 -0700
Subject: [PATCH 04/11] ruff

---
 ScaFFold/utils/trainer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index b2053a8..d1283a2 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -17,7 +17,6 @@
 import os
 import shutil
 import time
-import math
 from pathlib import Path
 
 # Third party

From 1ea83f62a2526e62ad3b4062b9393e02a7dfa8f9 Mon Sep 17 00:00:00 2001
From: Patrick Miles
Date: Wed, 6 May 2026 08:16:45 -0700
Subject: [PATCH 05/11] ruff

---
 ScaFFold/utils/losses.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/ScaFFold/utils/losses.py b/ScaFFold/utils/losses.py
index bddc55a..e89a25b 100644
--- a/ScaFFold/utils/losses.py
+++ b/ScaFFold/utils/losses.py
@@ -58,9 +58,7 @@
             # an all-reduced count instead of numel()*num_shards because shard
             # sizes can differ at chunk boundaries.
             local_voxel_count = local_ce_sum.new_tensor(float(local_labels.numel()))
-            global_normalizer = SpatialAllReduce.apply(
-                local_voxel_count, spatial_mesh
-            )
+            global_normalizer = SpatialAllReduce.apply(local_voxel_count, spatial_mesh)
         else:
             # Weighted CE divides by sum(weight[target_i]) over all voxels.
             # Build that denominator from the local label histogram, then

From 25809656b365cf09846581fec2c98db6abce30a4 Mon Sep 17 00:00:00 2001
From: Patrick Miles
Date: Wed, 6 May 2026 08:23:55 -0700
Subject: [PATCH 06/11] fix missing volume dtype

---
 ScaFFold/utils/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index d1283a2..013807e 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -31,7 +31,7 @@
 
 from ScaFFold.utils.checkpointing import CheckpointManager
 from ScaFFold.utils.data_loading import FractalDataset, SpatialShardSpec
-from ScaFFold.utils.data_types import AMP_DTYPE
+from ScaFFold.utils.data_types import AMP_DTYPE, VOLUME_DTYPE
 from ScaFFold.utils.dice_score import compute_sharded_dice
 from ScaFFold.utils.distributed import get_local_rank, get_world_rank, get_world_size

From 42ab124f1287fbffe07a32acdfa6ac9a2c3a27ed Mon Sep 17 00:00:00 2001
From: Patrick R Miles <78748866+PatrickRMiles@users.noreply.github.com>
Date: Wed, 6 May 2026 10:44:42 -0700
Subject: [PATCH 07/11] remove default ce_weight_num_samples

Co-authored-by: Michael McKinsey
---
 ScaFFold/utils/config_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py
index 98421e0..44102bb 100644
--- a/ScaFFold/utils/config_utils.py
+++ b/ScaFFold/utils/config_utils.py
@@ -71,7 +71,7 @@ def __init__(self, config_dict):
         self.checkpoint_dir = config_dict["checkpoint_dir"]
         self.normalize = config_dict["normalize"]
         self.warmup_batches = config_dict.get("warmup_batches")
-        self.ce_weight_num_samples = config_dict.get("ce_weight_num_samples", 8)
+        self.ce_weight_num_samples = config_dict.get("ce_weight_num_samples")
         self.dataset_reuse_enforce_commit_id = config_dict[
             "dataset_reuse_enforce_commit_id"
         ]

From b289d65fde352fda793713099f705bad0f059748 Mon Sep 17 00:00:00 2001
From: Patrick R Miles <78748866+PatrickRMiles@users.noreply.github.com>
Date: Wed, 6 May 2026 10:44:59 -0700
Subject: [PATCH 08/11] remove default ce_weight_num_samples in trainer

Co-authored-by: Michael McKinsey
---
 ScaFFold/utils/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index 013807e..e4249c6 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -186,7 +186,7 @@ def create_dataloaders(self):
 
     def _sample_ce_weight_indices(self):
         """Pick a small, deterministic subset of masks to estimate CE weights."""
-        requested_samples = int(getattr(self.config, "ce_weight_num_samples", 8) or 8)
+        requested_samples = int(self.config["ce_weight_num_samples"])
         sample_count = min(max(requested_samples, 1), self.n_train)
         if sample_count == self.n_train:
             return list(range(self.n_train))

From 464c191e64257e5259d24e8a8e12f2274350e0ab Mon Sep 17 00:00:00 2001
From: Patrick Miles
Date: Wed, 6 May 2026 11:47:32 -0700
Subject: [PATCH 09/11] move ce loss helpers to losses.py

---
 ScaFFold/utils/losses.py  | 79 +++++++++++++++++++++++++++++++++
 ScaFFold/utils/trainer.py | 91 +++++++++------------------------------
 2 files changed, 99 insertions(+), 71 deletions(-)

diff --git a/ScaFFold/utils/losses.py b/ScaFFold/utils/losses.py
index e89a25b..0403028 100644
--- a/ScaFFold/utils/losses.py
+++ b/ScaFFold/utils/losses.py
@@ -13,11 +13,90 @@
 # SPDX-License-Identifier: (Apache-2.0)
 
 import torch
+import torch.distributed as dist
 import torch.nn.functional as F
 
 from ScaFFold.utils.dice_score import SpatialAllReduce
 
 
+def _sample_ce_weight_indices(n_train, requested_samples):
+    """Pick a small, deterministic subset of masks to estimate CE weights."""
+    if n_train <= 0:
+        return []
+
+    sample_count = min(max(int(requested_samples or 8), 1), n_train)
+    if sample_count == n_train:
+        return list(range(n_train))
+
+    return torch.linspace(0, n_train - 1, steps=sample_count).long().tolist()
+
+
+def _compute_ce_class_weights(
+    train_set,
+    n_train,
+    n_categories,
+    device,
+    requested_samples=8,
+    dist_enabled=False,
+    world_rank=0,
+    log=None,
+):
+    """
+    Estimate background vs foreground CE weights from a few training masks.
+
+    Background keeps its own inverse-frequency weight, and every non-zero
+    fractal class shares the foreground weight derived from the aggregate
+    non-empty voxel count.
+    """
+
+    num_classes = n_categories + 1
+    class_weights = torch.ones(num_classes, device=device, dtype=torch.float32)
+
+    if n_train == 0:
+        if log is not None:
+            log.warning(
+                "Training set is empty while computing CE class weights. Falling back to uniform weights."
+            )
+        return class_weights
+
+    sample_indices = _sample_ce_weight_indices(n_train, requested_samples)
+    sampled_class_counts = torch.zeros(num_classes, dtype=torch.long)
+
+    for sample_idx in sample_indices:
+        mask = train_set[sample_idx]["mask"]
+        sampled_class_counts += torch.bincount(
+            mask.reshape(-1), minlength=num_classes
+        )
+
+    # The dataset may already return only this rank's local spatial shard,
+    # so combine per-rank counts before deriving the global CE weights.
+    sampled_class_counts = sampled_class_counts.to(device=device)
+    if dist_enabled:
+        dist.all_reduce(sampled_class_counts, op=dist.ReduceOp.SUM)
+
+    background_voxels = int(sampled_class_counts[0].item())
+    foreground_voxels = int(sampled_class_counts[1:].sum().item())
+
+    if background_voxels > 0 and foreground_voxels > 0:
+        total_voxels = background_voxels + foreground_voxels
+        class_weights[0] = total_voxels / background_voxels
+        class_weights[1:] = total_voxels / foreground_voxels
+    elif log is not None:
+        log.warning(
+            "Sampled masks did not contain both background and foreground voxels. Falling back to uniform CE weights."
+        )
+
+    if log is not None and (not dist_enabled or world_rank == 0):
+        log.info(
+            f"CE weights estimated from {len(sample_indices)} training masks "
+            f"(indices={sample_indices}): background_voxels={background_voxels} "
+            f"foreground_voxels={foreground_voxels} "
+            f"weights={class_weights.detach().cpu().tolist()}"
+        )
+
+    return class_weights
+
+
 def compute_sharded_cross_entropy_loss(
     local_preds,
     local_labels,
diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index e4249c6..645a702 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -21,7 +21,6 @@
 
 # Third party
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 from distconv import DCTensor
@@ -37,7 +36,10 @@
 
 # Local
 from ScaFFold.utils.evaluate import evaluate
-from ScaFFold.utils.losses import compute_sharded_cross_entropy_loss
+from ScaFFold.utils.losses import (
+    _compute_ce_class_weights,
+    compute_sharded_cross_entropy_loss,
+)
 from ScaFFold.utils.perf_measure import adiak_value, begin_code_region, end_code_region
 from ScaFFold.utils.utils import gather_and_print_mem
 
@@ -184,70 +186,6 @@ def create_dataloaders(self):
                 "Reduce batch_size or adjust validation sharding."
             )
 
-    def _sample_ce_weight_indices(self):
-        """Pick a small, deterministic subset of masks to estimate CE weights."""
-        requested_samples = int(self.config["ce_weight_num_samples"])
-        sample_count = min(max(requested_samples, 1), self.n_train)
-        if sample_count == self.n_train:
-            return list(range(self.n_train))
-
-        return torch.linspace(0, self.n_train - 1, steps=sample_count).long().tolist()
-
-    def _compute_ce_class_weights(self):
-        """
-        Estimate background vs foreground CE weights from a few training masks.
-
-        Background keeps its own inverse-frequency weight, and every non-zero
-        fractal class shares the foreground weight derived from the aggregate
-        non-empty voxel count.
-        """
-
-        num_classes = self.config.n_categories + 1
-        class_weights = torch.ones(num_classes, device=self.device, dtype=torch.float32)
-
-        if self.n_train == 0:
-            self.log.warning(
-                "Training set is empty while computing CE class weights. Falling back to uniform weights."
-            )
-            return class_weights
-
-        sample_indices = self._sample_ce_weight_indices()
-        sampled_class_counts = torch.zeros(num_classes, dtype=torch.long)
-
-        for sample_idx in sample_indices:
-            mask = self.train_set[sample_idx]["mask"]
-            sampled_class_counts += torch.bincount(
-                mask.reshape(-1), minlength=num_classes
-            )
-
-        # The dataset may already return only this rank's local spatial shard,
-        # so combine per-rank counts before deriving the global CE weights.
-        sampled_class_counts = sampled_class_counts.to(device=self.device)
-        if self.config.dist:
-            dist.all_reduce(sampled_class_counts, op=dist.ReduceOp.SUM)
-
-        background_voxels = int(sampled_class_counts[0].item())
-        foreground_voxels = int(sampled_class_counts[1:].sum().item())
-
-        if background_voxels > 0 and foreground_voxels > 0:
-            total_voxels = background_voxels + foreground_voxels
-            class_weights[0] = total_voxels / background_voxels
-            class_weights[1:] = total_voxels / foreground_voxels
-        else:
-            self.log.warning(
-                "Sampled masks did not contain both background and foreground voxels. Falling back to uniform CE weights."
-            )
-
-        if not self.config.dist or self.world_rank == 0:
-            self.log.info(
-                f"CE weights estimated from {len(sample_indices)} training masks "
-                f"(indices={sample_indices}): background_voxels={background_voxels} "
-                f"foreground_voxels={foreground_voxels} "
-                f"weights={class_weights.detach().cpu().tolist()}"
-            )
-
-        return class_weights
-
     def setup_training_components(self):
         """Set up the optimizer, scheduler, gradient scaler, and loss function."""
         # Set up optimizer
@@ -285,11 +223,22 @@ def setup_training_components(self):
         self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.use_grad_scaler)
 
         # Set up loss function
-        self.criterion = (
-            nn.CrossEntropyLoss(weight=self._compute_ce_class_weights()).to(self.device)
-            if self.config.n_categories + 1 > 1
-            else nn.BCEWithLogitsLoss().to(self.device)
-        )
+        if self.config.n_categories + 1 > 1:
+            ce_class_weights = _compute_ce_class_weights(
+                train_set=self.train_set,
+                n_train=self.n_train,
+                n_categories=self.config.n_categories,
+                device=self.device,
+                requested_samples=self.config.ce_weight_num_samples,
+                dist_enabled=self.config.dist,
+                world_rank=self.world_rank,
+                log=self.log,
+            )
+            self.criterion = nn.CrossEntropyLoss(weight=ce_class_weights).to(
+                self.device
+            )
+        else:
+            self.criterion = nn.BCEWithLogitsLoss().to(self.device)
         if isinstance(self.criterion, nn.CrossEntropyLoss):
             self.ce_class_weights = self.criterion.weight

From 3be05dd5a508a383a13ed4e5d920f77761283c72 Mon Sep 17 00:00:00 2001
From: Patrick Miles
Date: Wed, 6 May 2026 11:57:44 -0700
Subject: [PATCH 10/11] sample by fraction of total rather than hard number

---
 ScaFFold/configs/benchmark_default.yml |  2 +-
 ScaFFold/configs/benchmark_testing.yml |  2 +-
 ScaFFold/utils/config_utils.py         |  4 +++-
 ScaFFold/utils/losses.py               | 19 ++++++++++++++-----
 ScaFFold/utils/trainer.py              |  2 +-
 5 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml
index 14e5f62..cc0d938 100644
--- a/ScaFFold/configs/benchmark_default.yml
+++ b/ScaFFold/configs/benchmark_default.yml
@@ -34,6 +34,6 @@ checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpo
 loss_freq: 1 # Number of epochs between logging the overall loss.
 normalize: 1 # Category search normalization parameter
 warmup_batches: 5 # How many warmup batches per rank to run before training.
-ce_weight_num_samples: 8 # How many training masks to sample when estimating background vs foreground CE weights.
+ce_weight_sample_fraction: 0.1 # Fraction of training masks to sample when estimating background vs foreground CE weights.
 dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
 target_dice: 0.95
diff --git a/ScaFFold/configs/benchmark_testing.yml b/ScaFFold/configs/benchmark_testing.yml
index 15cb264..5167de1 100644
--- a/ScaFFold/configs/benchmark_testing.yml
+++ b/ScaFFold/configs/benchmark_testing.yml
@@ -34,6 +34,6 @@ checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpo
 loss_freq: 1 # Number of epochs between logging the overall loss.
 normalize: 1 # Category search normalization parameter
 warmup_batches: 5 # How many warmup batches per rank to run before training.
-ce_weight_num_samples: 8 # How many training masks to sample when estimating background vs foreground CE weights.
+ce_weight_sample_fraction: 0.1 # Fraction of training masks to sample when estimating background vs foreground CE weights.
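+# Note: the sampled count is clamped to at least 1 mask and at most the full training set.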
 dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
 target_dice: 0.95
diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py
index 44102bb..37de3c4 100644
--- a/ScaFFold/utils/config_utils.py
+++ b/ScaFFold/utils/config_utils.py
@@ -71,7 +71,9 @@ def __init__(self, config_dict):
         self.checkpoint_dir = config_dict["checkpoint_dir"]
         self.normalize = config_dict["normalize"]
         self.warmup_batches = config_dict.get("warmup_batches")
-        self.ce_weight_num_samples = config_dict.get("ce_weight_num_samples")
+        self.ce_weight_sample_fraction = config_dict.get(
+            "ce_weight_sample_fraction", 0.1
+        )
         self.dataset_reuse_enforce_commit_id = config_dict[
             "dataset_reuse_enforce_commit_id"
         ]
diff --git a/ScaFFold/utils/losses.py b/ScaFFold/utils/losses.py
index 0403028..26978f8 100644
--- a/ScaFFold/utils/losses.py
+++ b/ScaFFold/utils/losses.py
@@ -12,6 +12,8 @@
 #
 # SPDX-License-Identifier: (Apache-2.0)
 
+import math
+
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
@@ -19,12 +21,18 @@
 from ScaFFold.utils.dice_score import SpatialAllReduce
 
 
-def _sample_ce_weight_indices(n_train, requested_samples):
+def _sample_ce_weight_indices(n_train, sample_fraction):
     """Pick a small, deterministic subset of masks to estimate CE weights."""
     if n_train <= 0:
         return []
 
-    sample_count = min(max(int(requested_samples or 8), 1), n_train)
+    if sample_fraction is None:
+        sample_fraction = 0.1
+
+    sample_count = min(
+        max(math.ceil(n_train * float(sample_fraction)), 1),
+        n_train,
+    )
     if sample_count == n_train:
         return list(range(n_train))
 
@@ -36,7 +44,7 @@ def _compute_ce_class_weights(
     n_train,
     n_categories,
     device,
-    requested_samples=8,
+    sample_fraction=0.1,
     dist_enabled=False,
     world_rank=0,
     log=None,
@@ -59,7 +67,7 @@ def _compute_ce_class_weights(
             )
         return class_weights
 
-    sample_indices = _sample_ce_weight_indices(n_train, requested_samples)
+    sample_indices = _sample_ce_weight_indices(n_train, sample_fraction)
     sampled_class_counts = torch.zeros(num_classes, dtype=torch.long)
 
     for sample_idx in sample_indices:
@@ -89,7 +97,8 @@ def _compute_ce_class_weights(
     if log is not None and (not dist_enabled or world_rank == 0):
         log.info(
             f"CE weights estimated from {len(sample_indices)} training masks "
-            f"(indices={sample_indices}): background_voxels={background_voxels} "
+            f"(sample_fraction={sample_fraction}, indices={sample_indices}): "
+            f"background_voxels={background_voxels} "
             f"foreground_voxels={foreground_voxels} "
             f"weights={class_weights.detach().cpu().tolist()}"
         )
diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index 645a702..1a1d2e0 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -229,7 +229,7 @@ def setup_training_components(self):
                 n_train=self.n_train,
                 n_categories=self.config.n_categories,
                 device=self.device,
-                requested_samples=self.config.ce_weight_num_samples,
+                sample_fraction=self.config.ce_weight_sample_fraction,
                 dist_enabled=self.config.dist,
                 world_rank=self.world_rank,
                 log=self.log,

From 9168123f7070406bf59cd5bbddda47b4b45e88ca Mon Sep 17 00:00:00 2001
From: Patrick Miles
Date: Wed, 6 May 2026 11:59:12 -0700
Subject: [PATCH 11/11] ruff

---
 ScaFFold/utils/losses.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/ScaFFold/utils/losses.py b/ScaFFold/utils/losses.py
index 26978f8..3869b8b 100644
--- a/ScaFFold/utils/losses.py
+++ b/ScaFFold/utils/losses.py
@@ -72,9 +72,7 @@ def _compute_ce_class_weights(
 
     for sample_idx in sample_indices:
train_set[sample_idx]["mask"] - sampled_class_counts += torch.bincount( - mask.reshape(-1), minlength=num_classes - ) + sampled_class_counts += torch.bincount(mask.reshape(-1), minlength=num_classes) # The dataset may already return only this rank's local spatial shard, # so combine per-rank counts before deriving the global CE weights.