4 changes: 4 additions & 0 deletions ScaFFold/utils/data_types.py
@@ -13,9 +13,13 @@
# SPDX-License-Identifier: (Apache-2.0)

import numpy as np
import torch

DEFAULT_NP_DTYPE = np.float64
# Masks are values 0 <= x <= n_categories
MASK_DTYPE = np.uint16
# Volumes/img are 0 <= x <= 1
VOLUME_DTYPE = np.float32

# Shared AMP dtype selection for torch.autocast.
AMP_DTYPE = torch.bfloat16
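Note: `torch.autocast` consumes this dtype directly. The PR pins bfloat16 unconditionally; a hypothetical hardware-aware variant (not what this change does) would gate on bf16 support and fall back to float16, which, unlike bfloat16, does need a GradScaler:

```python
import torch

# Hypothetical variant, not part of this PR: fall back to float16 on GPUs
# without bfloat16 support. float16 would re-enable the GradScaler path.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    AMP_DTYPE = torch.bfloat16
else:
    AMP_DTYPE = torch.float16
```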
71 changes: 35 additions & 36 deletions ScaFFold/utils/evaluate.py
@@ -17,6 +17,7 @@
from distconv import DCTensor
from tqdm import tqdm

from ScaFFold.utils.data_types import AMP_DTYPE, VOLUME_DTYPE
from ScaFFold.utils.dice_score import (
SpatialAllReduce,
compute_sharded_dice,
@@ -29,13 +30,16 @@
def evaluate(
net, dataloader, device, amp, primary, criterion, n_categories, parallel_strategy
):

def foreground_dice_mean(dice_scores):
if dice_scores.size(1) > 1:
return dice_scores[:, 1:].mean().item()
return dice_scores.mean().item()

net.eval()
autocast_device_type = device.type if device.type != "mps" else "cpu"
autocast_kwargs = {"device_type": autocast_device_type, "enabled": amp}
if amp:
autocast_kwargs["dtype"] = AMP_DTYPE
num_val_batches = len(dataloader)
total_dice_score = 0.0
processed_batches = 0
@@ -47,7 +51,7 @@ def foreground_dice_mean(dice_scores):
f"[eval] ps.shard_dim={parallel_strategy.shard_dim} num_shards={parallel_strategy.num_shards}"
)

with torch.autocast(device.type if device.type != "mps" else "cpu", enabled=amp):
with torch.autocast(**autocast_kwargs):
val_loss_epoch = 0.0
for batch in tqdm(
dataloader,
@@ -85,44 +89,39 @@ def foreground_dice_mean(dice_scores):
if local_preds.size(0) == 0 or local_labels.size(0) == 0:
continue

# --- 1. Sharded CE Loss ---
with torch.autocast(
device.type if device.type != "mps" else "cpu", enabled=False
):
# Calculate CE and Dice loss in single precision for numerical stability.
with torch.autocast(device_type=autocast_device_type, enabled=False):
# Compute global CE loss from sharded CE loss
local_ce_sum = F.cross_entropy(
local_preds.float(), local_labels, reduction="sum"
)
global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh)

# Divide by the actual global voxel count to handle uneven shards.
local_voxel_count = torch.tensor(
float(local_labels.numel()),
device=local_labels.device,
dtype=torch.float32,
)
global_total_voxels = SpatialAllReduce.apply(
local_voxel_count, spatial_mesh
)
CE_loss = global_ce_sum / global_total_voxels

# --- 2. Format Predictions & Labels (Strictly Multiclass) ---
mask_pred_probs = F.softmax(local_preds, dim=1).float()
mask_true_onehot = (
F.one_hot(local_labels, n_categories + 1).permute(0, 4, 1, 2, 3).float()
)
global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh)
local_voxel_count = torch.tensor(
float(local_labels.numel()),
device=local_labels.device,
dtype=VOLUME_DTYPE,
)
global_total_voxels = SpatialAllReduce.apply(
local_voxel_count, spatial_mesh
)
CE_loss = global_ce_sum / global_total_voxels

# Compute global dice loss from sharded dice loss
mask_pred_probs = F.softmax(local_preds.float(), dim=1)
mask_true_onehot = (
F.one_hot(local_labels, n_categories + 1)
.permute(0, 4, 1, 2, 3)
.float()
)
dice_score_probs = compute_sharded_dice(
mask_pred_probs, mask_true_onehot, spatial_mesh
)
batch_dice_score = foreground_dice_mean(dice_score_probs)

# Dice loss uses probabilities
dice_score_probs = compute_sharded_dice(
mask_pred_probs, mask_true_onehot, spatial_mesh
)
# Eval metric (excluding background class 0)
# dice_score_probs shape is [Batch, Channels].
batch_dice_score = foreground_dice_mean(dice_score_probs)

# --- Combine and Accumulate ---
loss = CE_loss + (1.0 - batch_dice_score)
val_loss_epoch += loss.item()
total_dice_score += batch_dice_score.item()
# Sum global CE Loss and Dice loss
loss = CE_loss + (1.0 - batch_dice_score)
val_loss_epoch += loss.item()
total_dice_score += batch_dice_score
processed_batches += 1

net.train()
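For readers outside the repo: `SpatialAllReduce` is ScaFFold's autograd-aware all-reduce over the spatial mesh, and `compute_sharded_dice` aggregates dice statistics across shards. A rough single-process-group sketch of the same loss math using plain `torch.distributed` (assumptions: an initialized default process group, a standard smooth-dice formulation that may differ from the repo's, and no gradient flow through the reduces, unlike the real code):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def sharded_ce_and_dice(local_logits, local_labels, n_categories, eps=1e-6):
    """Sketch of globally correct CE and dice over spatially sharded tensors."""
    # All-reduce the CE *sum* and the voxel count separately, then divide:
    # this weights unevenly sized shards correctly, unlike averaging means.
    ce_sum = F.cross_entropy(local_logits.float(), local_labels, reduction="sum")
    voxels = torch.tensor(float(local_labels.numel()), device=local_logits.device)
    dist.all_reduce(ce_sum)
    dist.all_reduce(voxels)
    ce_mean = ce_sum / voxels

    # Dice: all-reduce per-class intersection and cardinality sums before
    # forming the ratio, so the score matches the unsharded volume.
    probs = F.softmax(local_logits.float(), dim=1)
    onehot = (
        F.one_hot(local_labels, n_categories + 1).permute(0, 4, 1, 2, 3).float()
    )
    inter = (probs * onehot).sum(dim=(2, 3, 4))
    card = (probs + onehot).sum(dim=(2, 3, 4))
    dist.all_reduce(inter)
    dist.all_reduce(card)
    dice = (2.0 * inter + eps) / (card + eps)  # shape [batch, n_categories + 1]
    return ce_mean, dice
```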
182 changes: 95 additions & 87 deletions ScaFFold/utils/trainer.py
@@ -29,7 +29,11 @@

from ScaFFold.utils.checkpointing import CheckpointManager
from ScaFFold.utils.data_loading import FractalDataset, SpatialShardSpec
from ScaFFold.utils.dice_score import SpatialAllReduce, compute_sharded_dice
from ScaFFold.utils.data_types import AMP_DTYPE, VOLUME_DTYPE
from ScaFFold.utils.dice_score import (
SpatialAllReduce,
compute_sharded_dice,
)
from ScaFFold.utils.distributed import get_local_rank, get_world_rank, get_world_size

# Local
@@ -48,6 +52,9 @@ def __init__(self, model, config, device, log):
self.config = config
self.device = device
self.log = log
self.amp_device_type = self.device.type if self.device.type != "mps" else "cpu"
self.amp_dtype = AMP_DTYPE
self.use_grad_scaler = False
self.world_size = get_world_size(required=self.config.dist)
self.world_rank = get_world_rank(required=self.config.dist)
self.local_rank = get_local_rank(required=self.config.dist)
@@ -194,7 +201,11 @@ def setup_training_components(self):
)

# Set up gradient scaler for AMP (Automatic Mixed Precision)
self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.config.torch_amp)
# bfloat16 does not need a grad scaler
self.use_grad_scaler = (
self.config.torch_amp and self.amp_dtype != torch.bfloat16
)
self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.use_grad_scaler)

# Set up loss function
self.criterion = (
@@ -204,15 +215,24 @@ def setup_training_components(self):
)

self.log.info(
f"Optimizer: {self.optimizer}, Scheduler: {self.scheduler}, Gradient Scaler Enabled: {self.config.torch_amp}"
f"Optimizer: {self.optimizer}, Scheduler: {self.scheduler}, AMP dtype: {self.amp_dtype}, Gradient Scaler Enabled: {self.use_grad_scaler}"
)

def _autocast_kwargs(self, enabled=None):
if enabled is None:
enabled = self.config.torch_amp

kwargs = {"device_type": self.amp_device_type, "enabled": enabled}
if enabled:
kwargs["dtype"] = self.amp_dtype
return kwargs
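Usage note: the kwargs dict feeds `torch.autocast` directly, and because GradScaler is a pass-through when constructed with `enabled=False`, the scale/step/update call sites need no branching. A self-contained sketch of that composition (toy model, CPU-safe):

```python
import torch

amp_dtype = torch.bfloat16
use_scaler = amp_dtype != torch.bfloat16          # False: bf16 needs no scaling
scaler = torch.amp.GradScaler("cuda", enabled=use_scaler)

model = torch.nn.Linear(8, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
device_type = "cuda" if torch.cuda.is_available() else "cpu"

with torch.autocast(device_type=device_type, dtype=amp_dtype, enabled=True):
    loss = model(torch.randn(4, 8)).sum()

# With enabled=False these three calls degrade to plain
# loss.backward(); opt.step(), so one code path serves fp16 and bf16.
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
```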

@staticmethod
def _foreground_dice_mean(dice_scores):
"""Match optimization to the reported validation metric by excluding background."""
if dice_scores.size(1) > 1:
return dice_scores[:, 1:].mean().item()
return dice_scores.mean().item()
return dice_scores[:, 1:].mean()
return dice_scores.mean()
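Returning the tensor instead of `.item()` is what lets the training loss stay differentiable: `(1.0 - batch_dice_score)` now carries a `grad_fn`. A toy check, assuming the `[batch, channels]` score layout noted in the old comments:

```python
import torch

def foreground_dice_mean(dice_scores):
    # dice_scores: [batch, channels]; channel 0 is background.
    if dice_scores.size(1) > 1:
        return dice_scores[:, 1:].mean()
    return dice_scores.mean()

scores = torch.tensor([[0.9, 0.5, 0.7]], requires_grad=True)
fg = foreground_dice_mean(scores)   # tensor(0.6000, grad_fn=<MeanBackward0>)
(1.0 - fg).backward()               # gradients flow; .item() would break this
print(scores.grad)                  # tensor([[ 0.0000, -0.5000, -0.5000]])
```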


class PyTorchTrainer(BaseTrainer):
@@ -399,10 +419,7 @@ def warmup(self):
true_masks_dc = DCTensor.from_shard(true_masks, self.ps)
self._get_memsize(images_dc, "Sharded image", self.config.verbose)

with torch.autocast(
self.device.type if self.device.type != "mps" else "cpu",
enabled=self.config.torch_amp,
):
with torch.autocast(**self._autocast_kwargs()):
# Forward on DCTensor
self.log.debug(" warmup: running forward pass")
masks_pred_dc = self.model(images_dc)
@@ -428,42 +445,41 @@
f" warmup: Calculating sharded loss. Mem: {current_mem:.2f} GB."
)

# 1. Sharded Cross Entropy
with torch.autocast(
self.device.type if self.device.type != "mps" else "cpu",
enabled=False,
):
# Calculate CE and Dice loss in single precision for numerical stability.
with torch.autocast(**self._autocast_kwargs(enabled=False)):
# Compute global CE loss from sharded CE loss
local_ce_sum = F.cross_entropy(
local_preds.float(), local_labels, reduction="sum"
)
global_ce_sum = SpatialAllReduce.apply(
local_ce_sum, self.spatial_mesh
)
local_voxel_count = torch.tensor(
float(local_labels.numel()),
device=local_labels.device,
dtype=VOLUME_DTYPE,
)
global_total_voxels = SpatialAllReduce.apply(
local_voxel_count, self.spatial_mesh
)
loss_ce = global_ce_sum / global_total_voxels

# Pass the spatial_mesh directly
global_ce_sum = SpatialAllReduce.apply(local_ce_sum, self.spatial_mesh)

local_voxel_count = torch.tensor(
float(local_labels.numel()),
device=local_labels.device,
dtype=torch.float32,
)
global_total_voxels = SpatialAllReduce.apply(
local_voxel_count, self.spatial_mesh
)
loss_ce = global_ce_sum / global_total_voxels

# 2. Sharded Dice Loss
local_preds_softmax = F.softmax(local_preds, dim=1).float()
local_labels_one_hot = (
F.one_hot(local_labels, num_classes=self.config.n_categories + 1)
.permute(0, 4, 1, 2, 3)
.float()
)
dice_scores = compute_sharded_dice(
local_preds_softmax, local_labels_one_hot, self.spatial_mesh
)
batch_dice_score = self._foreground_dice_mean(dice_scores)
# Compute global dice loss from sharded dice loss
local_preds_softmax = F.softmax(local_preds.float(), dim=1)
local_labels_one_hot = (
F.one_hot(
local_labels, num_classes=self.config.n_categories + 1
)
.permute(0, 4, 1, 2, 3)
.float()
)
dice_scores = compute_sharded_dice(
local_preds_softmax, local_labels_one_hot, self.spatial_mesh
)
batch_dice_score = self._foreground_dice_mean(dice_scores)

# 3. Combine Loss
loss = loss_ce + (1.0 - batch_dice_score)
# Sum global CE Loss and Dice loss
loss = loss_ce + (1.0 - batch_dice_score)

self.log.debug(
" warmup: loss calculation complete. Proceeding to backward pass"
@@ -592,10 +608,7 @@ def train(self):
images_dc, "Sharded image", self.config.verbose
)

with torch.autocast(
self.device.type if self.device.type != "mps" else "cpu",
enabled=self.config.torch_amp,
):
with torch.autocast(**self._autocast_kwargs()):
# Predict on this batch
torch.cuda.reset_peak_memory_stats()
gather_and_print_mem(self.log, "pre_forward")
@@ -627,56 +640,51 @@ def train(self):
f"Calculating sharded loss. Mem: {current_mem:.2f} GB."
)

# 1. Sharded Cross Entropy
with torch.autocast(
self.device.type
if self.device.type != "mps"
else "cpu",
enabled=False,
):
# Calculate CE and Dice loss in single precision for numerical stability.
with torch.autocast(**self._autocast_kwargs(enabled=False)):
# Compute global CE loss from sharded CE loss
local_ce_sum = F.cross_entropy(
local_preds.float(),
local_labels,
reduction="sum",
)

# Pass the spatial_mesh directly
global_ce_sum = SpatialAllReduce.apply(
local_ce_sum, self.spatial_mesh
)

local_voxel_count = torch.tensor(
float(local_labels.numel()),
device=local_labels.device,
dtype=torch.float32,
)
global_total_voxels = SpatialAllReduce.apply(
local_voxel_count, self.spatial_mesh
)
loss_ce = global_ce_sum / global_total_voxels

# 2. Sharded Dice Loss
local_preds_softmax = F.softmax(local_preds, dim=1).float()
local_labels_one_hot = (
F.one_hot(
local_labels,
num_classes=self.config.n_categories + 1,
global_ce_sum = SpatialAllReduce.apply(
local_ce_sum, self.spatial_mesh
)
.permute(0, 4, 1, 2, 3)
.float()
)
local_voxel_count = torch.tensor(
float(local_labels.numel()),
device=local_labels.device,
dtype=VOLUME_DTYPE,
)
global_total_voxels = SpatialAllReduce.apply(
local_voxel_count, self.spatial_mesh
)
loss_ce = global_ce_sum / global_total_voxels

# Compute sharded dice using new function
dice_scores = compute_sharded_dice(
local_preds_softmax,
local_labels_one_hot,
self.spatial_mesh,
)
batch_dice_score = self._foreground_dice_mean(dice_scores)
# Compute global dice loss from sharded dice loss
local_preds_softmax = F.softmax(
local_preds.float(), dim=1
)
local_labels_one_hot = (
F.one_hot(
local_labels,
num_classes=self.config.n_categories + 1,
)
.permute(0, 4, 1, 2, 3)
.float()
)
dice_scores = compute_sharded_dice(
local_preds_softmax,
local_labels_one_hot,
self.spatial_mesh,
)
batch_dice_score = self._foreground_dice_mean(
dice_scores
)

# 3. Combine Loss
loss = loss_ce + (1.0 - batch_dice_score)
train_dice_total += batch_dice_score
# Sum global CE Loss and Dice loss
loss = loss_ce + (1.0 - batch_dice_score)
train_dice_total += batch_dice_score

end_code_region("calculate_loss")

@@ -748,7 +756,7 @@ def train(self):
#
# Write out data for this epoch to train stats csv
#
train_dice = float(train_dice_total / len(self.train_loader))
train_dice = float(train_dice_total.item() / len(self.train_loader))
self.log.info(
f" epoch {epoch} \
| train_dice_loss {train_dice:.6f} (type {type(train_dice)}) \
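One consequence of the tensor-valued dice score: `train_dice_total` now accumulates on-device, and the single `.item()` above is the only device-to-host sync per epoch. The pattern in isolation (names hypothetical):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
total = torch.zeros((), device=device)             # on-device accumulator
for _ in range(10):
    batch_dice = torch.rand((), device=device)     # stand-in for a dice score
    total += batch_dice                            # no host sync inside the loop
epoch_mean = total.item() / 10                     # one sync per epoch
print(f"epoch dice: {epoch_mean:.4f}")
```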