diff --git a/hydragnn/train/train_validate_test.py b/hydragnn/train/train_validate_test.py
index 8dd828435..77a76c187 100644
--- a/hydragnn/train/train_validate_test.py
+++ b/hydragnn/train/train_validate_test.py
@@ -9,32 +9,35 @@
 # SPDX-License-Identifier: BSD-3-Clause                                      #
 ##############################################################################
 
-from tqdm import tqdm
-import numpy as np
 import torch
+from torch.cuda.amp import autocast, GradScaler
 
-from hydragnn.preprocess.serialized_dataset_loader import SerializedDataLoader
 from hydragnn.postprocess.postprocess import output_denormalize
 from hydragnn.postprocess.visualizer import Visualizer
-from hydragnn.utils.print_utils import print_distributed, iterate_tqdm, log
+from hydragnn.utils.print_utils import print_distributed, iterate_tqdm
 from hydragnn.utils.time_utils import Timer
 from hydragnn.utils.profile import Profiler
-from hydragnn.utils.distributed import get_device, print_peak_memory
-from hydragnn.preprocess.load_data import HydraDataLoader
+from hydragnn.utils.distributed import get_device
 from hydragnn.utils.model import Checkpoint, EarlyStopping
 import os
 from torch.profiler import record_function
-import contextlib
-from unittest.mock import MagicMock
 from hydragnn.utils.distributed import get_comm_size_and_rank
 import torch.distributed as dist
-import pickle
 import hydragnn.utils.tracer as tr
 
+# Check whether Tensor Cores are available on the GPU.
+# We check only one GPU, assuming all GPUs in the distributed compute environment are identical.
+use_tensor_cores = (
+    torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 7
+)
+
+# Initialize the AMP gradient scaler, enabled only if Tensor Cores are available
+scaler = GradScaler(enabled=use_tensor_cores)
+
 
 def train_validate_test(
     model,
@@ -425,17 +428,25 @@ def train(
         tr.stop("get_head_indices")
         tr.start("forward")
         with record_function("forward"):
-            data = data.to(get_device())
-            pred = model(data)
-            loss, tasks_loss = model.module.loss(pred, data.y, head_index)
-        tr.stop("forward")
-        tr.start("backward")
+            # Perform the forward pass under autocast (mixed precision)
+            with autocast(enabled=use_tensor_cores, dtype=torch.bfloat16):
+                # with autocast(enabled=use_tensor_cores, dtype=torch.float16):
+                # with autocast(enabled=use_tensor_cores, dtype=torch.float32):
+                data = data.to(get_device())
+                pred = model(data)
+                loss, tasks_loss = model.module.loss(pred, data.y, head_index)
+            tr.stop("forward")
+            tr.start("backward")
         with record_function("backward"):
-            loss.backward()
+            # Scale the loss and perform backpropagation
+            scaler.scale(loss).backward()
+
         tr.stop("backward")
         tr.start("opt_step")
         # print_peak_memory(verbosity, "Max memory allocated before optimizer step")
-        opt.step()
+        # scaler.step() unscales the gradients before calling opt.step()
+        scaler.step(opt)
+        scaler.update()
         # print_peak_memory(verbosity, "Max memory allocated after optimizer step")
         tr.stop("opt_step")
         profiler.step()
@@ -458,7 +469,6 @@ def train(
 
 @torch.no_grad()
 def validate(loader, model, verbosity, reduce_ranks=True):
-
     total_error = torch.tensor(0.0, device=get_device())
     tasks_error = torch.zeros(model.module.num_heads, device=get_device())
     num_samples_local = 0
@@ -496,7 +506,6 @@ def validate(loader, model, verbosity, reduce_ranks=True):
 
 @torch.no_grad()
 def test(loader, model, verbosity, reduce_ranks=True, return_samples=True):
-
     total_error = torch.tensor(0.0, device=get_device())
     tasks_error = torch.zeros(model.module.num_heads, device=get_device())
     num_samples_local = 0
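
For reference, below is a minimal, self-contained sketch of the autocast + GradScaler pattern this diff wires into train(). It is not HydraGNN code: the linear model, SGD optimizer, random tensors, and loop length are placeholder assumptions used purely for illustration.

import torch
from torch.cuda.amp import autocast, GradScaler

# Same capability check as the diff: Tensor Cores require compute capability >= 7.0
use_tensor_cores = (
    torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 7
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder model, optimizer, and data (not HydraGNN components)
model = torch.nn.Linear(16, 1).to(device)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = GradScaler(enabled=use_tensor_cores)
x = torch.randn(32, 16, device=device)
y = torch.randn(32, 1, device=device)

for _ in range(3):
    opt.zero_grad()
    # Forward pass in reduced precision; bfloat16 mirrors the dtype chosen in the diff
    with autocast(enabled=use_tensor_cores, dtype=torch.bfloat16):
        pred = model(x)
        loss = torch.nn.functional.mse_loss(pred, y)
    # Backward on the scaled loss, then unscale-and-step, then update the scale factor
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()

Gradient scaling exists mainly for float16, whose narrow exponent range can underflow small gradients; with the bfloat16 autocast selected here the scaler is not strictly needed, so keeping it enabled mainly preserves the option of switching to the commented-out float16 variant with a one-line change.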