
Improve AMP stability #60

Merged

michaelmckinsey1 merged 10 commits into LBANN:main from michaelmckinsey1:fix-amp on Apr 30, 2026

Conversation

@michaelmckinsey1 (Collaborator)

  • Compute dice loss in FP32 to avoid val_dice_score=nan
  • Use BF16 to prevent numerical overflow.
  • The datatype for AMP is now centrally configurable in ScaFFold/utils/data_types.py for testing.
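
For reference, a minimal sketch of what such a central dtype switch could look like (illustrative only; the actual contents of ScaFFold/utils/data_types.py may differ):

```python
# Illustrative sketch of a central AMP dtype switch; the real
# ScaFFold/utils/data_types.py may differ.
import torch

# BF16 keeps FP32's exponent range (at the cost of mantissa bits),
# which avoids the numerical overflow seen with FP16.
AMP_DTYPE: torch.dtype = torch.bfloat16

def forward_amp(model, x):
    # Single switch point: tests can flip AMP_DTYPE without touching
    # call sites.
    with torch.autocast("cuda", dtype=AMP_DTYPE):
        return model(x)
```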

Comment thread ScaFFold/utils/evaluate.py Outdated
Comment on lines +98 to +109
# --- 2. Sharded Dice Loss ---
mask_pred_probs = F.softmax(local_preds.float(), dim=1)
mask_true_onehot = (
    F.one_hot(local_labels, n_categories + 1)
    .permute(0, 4, 1, 2, 3)
    .float()
)

# Dice loss uses probabilities
dice_score_probs = compute_sharded_dice(
    mask_pred_probs, mask_true_onehot, spatial_mesh
)
Collaborator:

It looks like this got inserted in the middle of the CE loss calc. Can you move it back to being after CE_loss = ...? That should also shrink the diff and make it clearer what the actual changes are here (not casting local_preds to float).

Collaborator (Author):

done
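
For context, a minimal self-contained sketch of the pattern under discussion: computing the Dice score in FP32 even when the forward pass runs in reduced precision (function and argument names here are illustrative, not ScaFFold's actual helpers):

```python
import torch
import torch.nn.functional as F

def dice_score_fp32(logits: torch.Tensor, labels: torch.Tensor, n_classes: int) -> torch.Tensor:
    # Upcast before softmax: in FP16, small probabilities underflow and
    # the Dice ratio can become nan (the val_dice_score=nan this PR fixes).
    probs = F.softmax(logits.float(), dim=1)   # (N, C, D, H, W)
    onehot = (
        F.one_hot(labels, n_classes)           # (N, D, H, W, C)
        .permute(0, 4, 1, 2, 3)                # (N, C, D, H, W)
        .float()
    )
    eps = 1e-6
    intersection = (probs * onehot).sum(dim=(2, 3, 4))
    denom = probs.sum(dim=(2, 3, 4)) + onehot.sum(dim=(2, 3, 4))
    return ((2 * intersection + eps) / (denom + eps)).mean()
```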

Comment thread ScaFFold/utils/trainer.py Outdated

# Set up gradient scaler for AMP (Automatic Mixed Precision)
self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.config.torch_amp)
self.use_grad_scaler = self.config.torch_amp and self.amp_dtype == torch.float16
Collaborator:

Is the and self.amp_dtype == torch.float16 basically just catching the case where we're NOT running with bf16? Would it be better to write that explicitly, like and self.amp_dtype != torch.bfloat16?

Collaborator (Author):

Yes, basically the options are bf16 or f16. I don't think bf/f8 would work. I'm OK with making this change.
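
A sketch of the agreed change (attribute names follow the diff above; the surrounding class and the enabled= wiring are assumptions, not the PR's exact code). GradScaler only matters for FP16, since BF16 shares FP32's exponent range and its gradients don't underflow the same way:

```python
import torch

class Trainer:
    def __init__(self, config, amp_dtype: torch.dtype = torch.bfloat16):
        self.config = config
        self.amp_dtype = amp_dtype
        # Loss/gradient scaling is an FP16 workaround; skip it for BF16.
        self.use_grad_scaler = (
            self.config.torch_amp and self.amp_dtype != torch.bfloat16
        )
        # With enabled=False, scale()/step()/update() become no-ops,
        # so the training loop stays uniform across dtypes.
        self.grad_scaler = torch.amp.GradScaler(
            "cuda", enabled=self.use_grad_scaler
        )
```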

Comment thread ScaFFold/utils/trainer.py Outdated
Comment on lines +450 to +461
# 2. Sharded Dice Loss
local_preds_softmax = F.softmax(local_preds.float(), dim=1)
local_labels_one_hot = (
    F.one_hot(
        local_labels, num_classes=self.config.n_categories + 1
    )
    .permute(0, 4, 1, 2, 3)
    .float()
)
dice_scores = compute_sharded_dice(
    local_preds_softmax, local_labels_one_hot, self.spatial_mesh
)
Collaborator:

Same as in evaluate.py, I think this should stay after global_ce_sum = ...

Collaborator (Author):

done

Comment thread ScaFFold/utils/trainer.py Outdated
Comment on lines +648 to +665
# Sharded Dice Loss
local_preds_softmax = F.softmax(
    local_preds.float(), dim=1
)
local_labels_one_hot = (
    F.one_hot(
        local_labels,
        num_classes=self.config.n_categories + 1,
    )
    .permute(0, 4, 1, 2, 3)
    .float()
)
# Compute sharded dice
dice_scores = compute_sharded_dice(
    local_preds_softmax,
    local_labels_one_hot,
    self.spatial_mesh,
)
Collaborator:

Same as evaluate and warmup, this should come after global_ce_sum = ...

Collaborator (Author):

done

@michaelmckinsey1 merged commit 574c081 into LBANN:main on Apr 30, 2026
1 check passed
