5 changes: 4 additions & 1 deletion README.md
@@ -83,7 +83,10 @@ variance_threshold: 0.15 # Variance threshold for valid fractals. Defa
n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3.
val_split: 25 # In percent.
epochs: 100 # Number of training epochs.
learning_rate: .0001 # Learning rate for training.
starting_learning_rate: .01 # Initial learning rate for training.
min_learning_rate: .001 # Minimum learning rate for CosineAnnealingWarmRestarts.
T_0: 100 # Epochs in the first cosine restart cycle.
T_mult: 2 # Restart cycle growth factor.
disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR.
more_determinism: 0 # If 1, improve model training determinism.
datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch.
22 changes: 22 additions & 0 deletions ScaFFold/cli.py
@@ -177,6 +177,28 @@ def main():
type=int,
help="Number of training epochs.",
)
benchmark_parser.add_argument(
"--starting-learning-rate",
type=float,
help="Initial learning rate for training.",
)
benchmark_parser.add_argument(
"--min-learning-rate",
type=float,
help="Minimum learning rate for CosineAnnealingWarmRestarts.",
)
benchmark_parser.add_argument(
"--T-0",
dest="T_0",
type=int,
help="Epochs in the first cosine restart cycle.",
)
benchmark_parser.add_argument(
"--T-mult",
dest="T_mult",
type=int,
help="Restart cycle growth factor.",
)

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
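The four new flags mirror the YAML keys below. A minimal standalone sketch of how they parse (this is not ScaFFold's actual parser setup; the argv values are illustrative):

```python
# Standalone sketch of the new benchmark flags (illustrative values; not the real ScaFFold CLI wiring).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--starting-learning-rate", type=float)
parser.add_argument("--min-learning-rate", type=float)
# The explicit dest keeps the attribute names identical to the YAML keys T_0 and T_mult.
parser.add_argument("--T-0", dest="T_0", type=int)
parser.add_argument("--T-mult", dest="T_mult", type=int)

args = parser.parse_args(
    ["--starting-learning-rate", "0.01", "--min-learning-rate", "0.001",
     "--T-0", "100", "--T-mult", "2"]
)
print(args.starting_learning_rate, args.min_learning_rate, args.T_0, args.T_mult)
# 0.01 0.001 100 2
```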
11 changes: 7 additions & 4 deletions ScaFFold/configs/benchmark_default.yml
@@ -7,8 +7,8 @@ n_instances_used_per_fractal: 145 # Number of unique instances to pull from eac
problem_scale: 7 # Determines dataset resolution and number of unet layers. Default is 6.
unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8.
seed: 42 # Random seed.
batch_size: 1 # Batch sizes for each vol size.
dataloader_num_workers: 4 # Number of DataLoader worker processes per rank.
batch_size: 1 # Batch sizes for each vol size per rank.
dataloader_num_workers: 1 # Number of DataLoader worker processes per rank. More workers will use more memory.
optimizer: "ADAM" # "ADAM" is the preferred option; otherwise training defaults to RMSProp.
dc_num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication to a minimum.
dc_shard_dims: [2, 3, 4] # DistConv param: dimensions on which to shard.
@@ -19,8 +19,11 @@ variance_threshold: 0.15 # Variance threshold for valid fractals. Defa
n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3.
val_split: 25 # In percent.
epochs: -1 # Number of training epochs.
learning_rate: .0001 # Learning rate for training.
disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR.
starting_learning_rate: 0.1 # Initial learning rate for training.
min_learning_rate: 0.001 # Minimum learning rate for CosineAnnealingWarmRestarts.
T_0: 100 # Epochs in the first cosine restart cycle.
T_mult: 2 # Restart cycle growth factor.
disable_scheduler: 0 # If 1, disable scheduler during training to use constant LR.
more_determinism: 0 # If 1, improve model training determinism.
datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch.
train_from_scratch: 1 # If 1, delete existing train stats and checkpoint files. Keep 0 if want to restart runs where we left off.
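With the scheduler enabled by default (disable_scheduler: 0), T_0 and T_mult determine where the warm restarts land. A quick sketch of the restart epochs implied by the values above, assuming one scheduler step per epoch as in the trainer change further down:

```python
# Restart epochs implied by T_0=100, T_mult=2 (assumes one scheduler.step() per epoch).
T_0, T_mult = 100, 2

boundaries, length, start = [], T_0, 0
for _ in range(4):  # first four cycles
    start += length
    boundaries.append(start)
    length *= T_mult

print(boundaries)  # [100, 300, 700, 1500] -> epochs where the LR resets to starting_learning_rate
```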
5 changes: 4 additions & 1 deletion ScaFFold/configs/benchmark_testing.yml
@@ -19,7 +19,10 @@ variance_threshold: 0.15 # Variance threshold for valid fractals. Defa
n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3.
val_split: 25 # In percent.
epochs: 10 # Number of training epochs.
learning_rate: .0001 # Learning rate for training.
starting_learning_rate: .0001 # Initial learning rate for training.
min_learning_rate: .0001 # Minimum learning rate for CosineAnnealingWarmRestarts.
T_0: 10 # Epochs in the first cosine restart cycle.
T_mult: 1 # Restart cycle growth factor.
disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR.
more_determinism: 0 # If 1, improve model training determinism.
datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch.
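In this testing config starting_learning_rate equals min_learning_rate, so even if disable_scheduler were set to 0 the cosine schedule would be flat and the LR would stay constant. A quick check (assumes a PyTorch install; the parameter is a throwaway):

```python
# Verify that equal starting/min LRs give a flat cosine schedule (illustrative check).
import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=1e-4)  # starting_learning_rate
sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    opt, T_0=10, T_mult=1, eta_min=1e-4  # min_learning_rate == base LR
)

lrs = []
for _ in range(10):
    lrs.append(opt.param_groups[0]["lr"])
    sched.step()
print(set(lrs))  # {0.0001} -- the LR never moves because eta_min equals the base LR
```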
5 changes: 4 additions & 1 deletion ScaFFold/utils/config_utils.py
@@ -61,7 +61,10 @@ def __init__(self, config_dict):
self.seed = config_dict["seed"]
self.dist = bool(config_dict["dist"])
self.framework = config_dict["framework"]
self.learning_rate = config_dict["learning_rate"]
self.starting_learning_rate = config_dict["starting_learning_rate"]
self.min_learning_rate = config_dict["min_learning_rate"]
self.T_0 = config_dict["T_0"]
self.T_mult = config_dict["T_mult"]
self.variance_threshold = config_dict["variance_threshold"]
self.torch_amp = bool(config_dict["torch_amp"])
self.loss_freq = config_dict["loss_freq"]
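Because the new keys are read with direct indexing, a config file that still only defines learning_rate will now raise a KeyError. A toy stand-in for the mapping (not the real Config class; only the new keys are shown):

```python
# Toy stand-in for the attribute mapping (not the real Config class).
class TinyConfig:
    def __init__(self, config_dict):
        # Missing keys raise KeyError, so older configs that only define
        # "learning_rate" must be updated to the four new keys.
        self.starting_learning_rate = config_dict["starting_learning_rate"]
        self.min_learning_rate = config_dict["min_learning_rate"]
        self.T_0 = config_dict["T_0"]
        self.T_mult = config_dict["T_mult"]

cfg = TinyConfig(
    {"starting_learning_rate": 0.01, "min_learning_rate": 0.001, "T_0": 100, "T_mult": 2}
)
print(cfg.starting_learning_rate, cfg.min_learning_rate, cfg.T_0, cfg.T_mult)
# 0.01 0.001 100 2
```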
27 changes: 20 additions & 7 deletions ScaFFold/utils/evaluate.py
@@ -30,10 +30,12 @@
def evaluate(
net, dataloader, device, amp, primary, criterion, n_categories, parallel_strategy
):
def foreground_dice_mean(dice_scores):
def foreground_dice_stats(dice_scores):
if dice_scores.size(1) > 1:
return dice_scores[:, 1:].mean().item()
return dice_scores.mean().item()
per_sample_scores = dice_scores[:, 1:].mean(dim=1)
else:
per_sample_scores = dice_scores.mean(dim=1)
return per_sample_scores.sum().item(), per_sample_scores.numel()

net.eval()
autocast_device_type = device.type if device.type != "mps" else "cpu"
@@ -43,6 +43,7 @@ def foreground_dice_mean(dice_scores):
num_val_batches = len(dataloader)
total_dice_score = 0.0
processed_batches = 0
processed_samples = 0

spatial_mesh = parallel_strategy.device_mesh[parallel_strategy.distconv_dim_names]

@@ -116,19 +119,29 @@ def foreground_dice_mean(dice_scores):
dice_score_probs = compute_sharded_dice(
mask_pred_probs, mask_true_onehot, spatial_mesh
)
batch_dice_score = foreground_dice_mean(dice_score_probs)
batch_dice_sum, batch_sample_count = foreground_dice_stats(
dice_score_probs
)
batch_dice_score = batch_dice_sum / max(batch_sample_count, 1)

# Sum global CE Loss and Dice loss
loss = CE_loss + (1.0 - batch_dice_score)
val_loss_epoch += loss.item()
total_dice_score += batch_dice_score
total_dice_score += batch_dice_sum
processed_batches += 1
processed_samples += batch_sample_count

net.train()

val_loss_avg = val_loss_epoch / max(processed_batches, 1)
if primary:
print(
f"evaluate.py: dice_score={total_dice_score}, val_loss_epoch={val_loss_epoch}, val_loss_avg={val_loss_avg}, num_val_batches={processed_batches}"
f"evaluate.py: dice_score={total_dice_score}, val_loss_epoch={val_loss_epoch}, val_loss_avg={val_loss_avg}, num_val_batches={processed_batches}, num_val_samples={processed_samples}"
)
return total_dice_score, val_loss_epoch, val_loss_avg, processed_batches
return (
total_dice_score,
val_loss_epoch,
val_loss_avg,
processed_batches,
processed_samples,
)
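Returning per-sample sums and counts instead of per-batch means matters once the validation loader no longer drops the last partial batch (see the trainer change below); averaging batch means would over-weight the small final batch. A toy illustration with made-up dice values:

```python
# Why sample-weighted averaging differs from averaging batch means (made-up dice values).
batches = [
    [0.90, 0.80],  # full batch of 2 samples
    [0.10],        # final partial batch of 1 sample
]

# Averaging per-batch means gives the small batch the same weight as the full one.
batch_means = [sum(b) / len(b) for b in batches]
print(sum(batch_means) / len(batch_means))  # ~0.475

# Summing per-sample scores and dividing by the total sample count weights every sample equally.
total = sum(sum(b) for b in batches)
count = sum(len(b) for b in batches)
print(total / count)  # ~0.60
```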
61 changes: 36 additions & 25 deletions ScaFFold/utils/trainer.py
@@ -130,7 +130,7 @@ def create_dataset(self):
)
self.n_train = len(self.train_set)
self.n_val = len(self.val_set)
self.log.debug(
self.log.info(
f"Datasets created with n_train={self.n_train}, n_val={self.n_val}"
)

@@ -173,31 +173,43 @@ def create_dataloaders(self):
self.train_set, sampler=self.train_sampler, **loader_args
)
self.val_loader = DataLoader(
self.val_set, sampler=self.val_sampler, drop_last=True, **loader_args
self.val_set, sampler=self.val_sampler, drop_last=False, **loader_args
)
if len(self.val_loader) == 0:
raise ValueError(
"Validation DataLoader has zero batches. "
f"n_val={self.n_val}, batch_size={self.config.batch_size}, "
f"data_num_replicas={self.data_num_replicas}. "
"Reduce batch_size or adjust validation sharding."
)

def setup_training_components(self):
"""Set up the optimizer, scheduler, gradient scaler, and loss function."""
# Set up optimizer
if self.config.optimizer == "ADAM":
self.log.info("Using ADAM optimizer.")
self.optimizer = optim.Adam(
self.model.parameters(), lr=self.config.learning_rate
self.model.parameters(), lr=self.config.starting_learning_rate
)
elif self.config.optimizer == "SGD":
self.log.info("Using SGD optimizer.")
self.optimizer = optim.SGD(
self.model.parameters(), lr=self.config.learning_rate
self.model.parameters(), lr=self.config.starting_learning_rate
)
else:
self.log.info("Using RMSprop optimizer.")
self.optimizer = optim.RMSprop(
self.model.parameters(), lr=self.config.learning_rate, foreach=True
self.model.parameters(),
lr=self.config.starting_learning_rate,
foreach=True,
)

# Set up learning rate scheduler
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer, "max", patience=25
self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
self.optimizer,
T_0=self.config.T_0,
T_mult=self.config.T_mult,
eta_min=self.config.min_learning_rate,
)

# Set up gradient scaler for AMP (Automatic Mixed Precision)
@@ -234,6 +246,11 @@ def _foreground_dice_mean(dice_scores):
return dice_scores[:, 1:].mean()
return dice_scores.mean()

def _current_learning_rate(self):
if self.optimizer is None or not self.optimizer.param_groups:
return self.config.starting_learning_rate
return self.optimizer.param_groups[0]["lr"]


class PyTorchTrainer(BaseTrainer):
"""
@@ -403,7 +420,7 @@ def warmup(self):

images = images.to(
device=self.device,
dtype=torch.float32,
dtype=VOLUME_DTYPE,
memory_format=torch.channels_last_3d,
non_blocking=True,
)
@@ -587,7 +604,7 @@ def train(self):
begin_code_region("image_to_device")
images = images.to(
device=self.device,
dtype=torch.float32,
dtype=VOLUME_DTYPE,
memory_format=torch.channels_last_3d, # NDHWC (channels last) vs NCDHW (channels first)
non_blocking=True,
)
@@ -720,7 +737,13 @@ def train(self):
#
# Evaluate model on validation set, update LR if necessary
#
dice_sum, val_loss_epoch, val_loss_avg, numbatch = evaluate(
(
dice_sum,
val_loss_epoch,
val_loss_avg,
numbatch,
numsamples,
) = evaluate(
self.model,
self.val_loader,
self.device,
Expand All @@ -730,24 +753,15 @@ def train(self):
self.config.n_categories,
self.config._parallel_strategy,
)
dice_info = torch.tensor([dice_sum, numbatch])
dice_info = torch.tensor([dice_sum, numsamples], dtype=VOLUME_DTYPE)
if self.config.dist:
dice_info = dice_info.to(device=self.device)
torch.distributed.all_reduce(
dice_info, op=torch.distributed.ReduceOp.SUM
)
val_score = dice_info[0].item() / max(dice_info[1].item(), 1)
if not self.config.disable_scheduler:
# The following is true when trying to overfit,
# in which case we only care about train loss
if self.n_train == 1 or "overfit" in self.outfile_path:
self.log.debug(
"WARNING: scheduler step by overall_loss, \
not val_score (n_train==1 or overfit in outfile_path)"
)
self.scheduler.step(overall_loss)
else: # Otherwise, we're really trying to optimize for validation dice score
self.scheduler.step(val_score)
self.scheduler.step()
else:
self.log.debug("scheduler disabled, no LR update this step")

@@ -758,10 +772,7 @@
#
train_dice = float(train_dice_total.item() / len(self.train_loader))
self.log.info(
f" epoch {epoch} \
| train_dice_loss {train_dice:.6f} (type {type(train_dice)}) \
| val_dice_score {val_score:.6f} \
| lr {self.config.learning_rate:.8f}"
f" epoch {epoch} | train_dice_score {train_dice:.6f} | val_dice_score {val_score:.6f} | lr {self._current_learning_rate():.8f}"
)
self.log.debug(f" writing to csv at {self.outfile_path}")
if self.world_rank == 0:
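For context, a minimal sketch of the new scheduler wiring in isolation, using the default config values above (the model is a placeholder and the values are hard-coded for illustration):

```python
# Isolated sketch of the CosineAnnealingWarmRestarts setup (placeholder model, hard-coded values).
import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # starting_learning_rate
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=100,        # epochs in the first cosine cycle
    T_mult=2,       # each later cycle is twice as long as the previous one
    eta_min=0.001,  # min_learning_rate: the LR never anneals below this value
)

for epoch in range(300):
    # ... train and validate for one epoch ...
    # Unlike the old ReduceLROnPlateau, step() takes no metric; it is called once per
    # epoch unconditionally (unless disable_scheduler is set).
    scheduler.step()
```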
12 changes: 8 additions & 4 deletions ScaFFold/worker.py
@@ -217,14 +217,18 @@ def main(kwargs_dict: dict = {}):
trainer.spatial_mesh = ps.device_mesh[ps.distconv_dim_names]
num_spatial_dims = len(ps.shard_dim)
trainer.ddp_placements = [Shard(0)] + [Replicate()] * num_spatial_dims
global_batch_size = config.batch_size * (
world_size // math.prod(config.dc_num_shards)
)
total_shards = math.prod(config.dc_num_shards)
global_batch_size = config.batch_size * (world_size // total_shards)
ddp_ranks = world_size // total_shards
if rank == 0:
log.info(
f"Effective global batch size = {global_batch_size} "
f"(batch_size={config.batch_size} * "
f"(world_size={world_size} / prod(dc_num_shards)={math.prod(config.dc_num_shards)}))"
f"(world_size={world_size} / prod(dc_num_shards)={total_shards}))"
)
log.info(
f"DDP ranks = {ddp_ranks} "
f"world_size={world_size} // prod(dc_num_shards)={total_shards}"
)
if global_batch_size > trainer.n_train:
raise ValueError(
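The batch-size and rank arithmetic in the new log messages, spelled out with illustrative numbers (the shard layout and world size here are made up, not the shipped defaults):

```python
# Sketch of the global-batch-size / DDP-rank arithmetic (made-up values).
import math

batch_size = 1             # per-rank batch size from the config
world_size = 16            # total ranks in the job
dc_num_shards = [2, 2, 1]  # DistConv shards along each spatial dimension

total_shards = math.prod(dc_num_shards)      # 4 ranks cooperate on a single sample
ddp_ranks = world_size // total_shards       # 4 data-parallel replicas
global_batch_size = batch_size * ddp_ranks   # 4 samples processed per optimizer step

print(total_shards, ddp_ranks, global_batch_size)  # 4 4 4
```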