diff --git a/.gitignore b/.gitignore index b40cab9..dc396ef 100644 --- a/.gitignore +++ b/.gitignore @@ -60,6 +60,7 @@ expected_outputs/ testdata/ actual_outputs/ *.html +*.zip # IDE/project-specific folders .vscode/ diff --git a/config/multi_node/shampoo_opt_multi_node.yaml b/config/multi_node/shampoo_opt_multi_node.yaml index 88fc42b..6123efc 100644 --- a/config/multi_node/shampoo_opt_multi_node.yaml +++ b/config/multi_node/shampoo_opt_multi_node.yaml @@ -6,19 +6,28 @@ logging: level: INFO training: - epochs: 10 + epochs: 100 batch_size: 512 gradient_accumulation: 2 mixed_precision: bf16 - max_steps: 2200 + max_steps: 4400 grad_clip_norm: 1.0 output_dir: artifacts/user_shampoo - log_interval: 20 + log_interval: 5 additional_compute_streams: 2 lightweight_op_waves: 3 +# Warmup settings to prevent RCCL hangs in multi-node training +warmup: + # RCCL communicator warmup - runs all_reduce on process groups before FSDP init + enable_rccl_warmup: true + rccl_warmup_iterations: 5 + # Training warmup - runs forward/backward/optimizer steps before main loop + enable_training_warmup: true + training_warmup_steps: 1 + optimizer: - name: shampoo + name: adamW lr: 0.0002 weight_decay: 0.01 betas: [0.9, 0.985] @@ -35,7 +44,7 @@ dataset: sparse_features: 64 vocab_size: 350000 num_dense_features: 32 - seed: 2025 + seed: 42 model: vocab_size: 350000 @@ -94,19 +103,17 @@ dataloader: pin_memory: true profiling: - enabled: true - wait: 2 - warmup: 2 - active: 6 + enabled: false + wait: 0 + warmup: 0 + active: 20 repeat: 1 record_shapes: true profile_memory: true with_stack: false with_flops: false - # tensorboard: true - # chrome_trace: true tensorboard: false - chrome_trace: false + chrome_trace: true trace_filename: user_shampoo.json tracelens: diff --git a/docker/docker-compose.rocm70_9-1-shampoo.yaml b/docker/docker-compose.rocm70_9-1-shampoo.yaml index 232e3a3..c376d3a 100644 --- a/docker/docker-compose.rocm70_9-1-shampoo.yaml +++ b/docker/docker-compose.rocm70_9-1-shampoo.yaml @@ -22,7 +22,8 @@ services: volumes: - /home/manrao:/manrao - - /home/oyazdanb/aorta:/workspace/aorta + - /apps/oyazdanb/aorta:/workspace/aorta + - /apps/oyazdanb/rccl:/rccl devices: - /dev/kfd - /dev/dri diff --git a/scripts/multi_node/README.md b/scripts/multi_node/README.md index 2650af7..b4714e7 100644 --- a/scripts/multi_node/README.md +++ b/scripts/multi_node/README.md @@ -211,6 +211,34 @@ done | NCCL timeout | Update `NCCL_SOCKET_IFNAME` in `set_env_variables.sh` | | World size mismatch | Check `rocm-smi --showid \| wc -l`, adjust `--nproc` | +### Training Hangs at RCCL Initialization + +If training hangs at "Warming up global world group..." or during FSDP initialization: + +1. **Ensure NCCL environment variables are set** in `local_launch.sh`: + - `NCCL_SOCKET_IFNAME` and `TORCH_NCCL_DUMP_ON_TIMEOUT=1` are critical + - See the full set in `local_launch.sh` DOCKER_EXEC section + +2. **Enable warmup settings** in your config YAML: + +```yaml +warmup: + # RCCL communicator warmup - runs all_reduce before FSDP init + enable_rccl_warmup: true + rccl_warmup_iterations: 5 + # Training warmup - runs forward/backward/optimizer before main loop + enable_training_warmup: true + training_warmup_steps: 1 +``` + +3. **Debug with NCCL logging**: +```bash +export NCCL_DEBUG=INFO +export NCCL_DEBUG_SUBSYS=ALL +``` + +The warmup settings exercise RCCL communicators before the main training loop starts, preventing race conditions during inter-node RDMA setup with HYBRID_SHARD strategy. 
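+
+For reference, the warmup is conceptually just a handful of small collectives issued on each
+process group before FSDP wraps the model. The sketch below is a simplified version of what
+`_warmup_rccl_communicators` in `src/aorta/training/fsdp_trainer.py` does; the helper name
+`warmup_process_group` and the iteration count here are illustrative only:
+
+```python
+import torch
+import torch.distributed as dist
+
+def warmup_process_group(group, device, iterations=5):
+    """Issue small collectives so RCCL finishes communicator/RDMA setup
+    before FSDP starts broadcasting parameters."""
+    tensor = torch.ones(8192, device=device, dtype=torch.float32)
+    src = dist.get_process_group_ranks(group)[0]  # first global rank in the group
+    for _ in range(iterations):
+        dist.all_reduce(tensor, group=group)
+        dist.broadcast(tensor, src=src, group=group)
+        torch.cuda.synchronize()
+    dist.barrier()
+```
+
+The trainer applies this to the global world group first, then to the shard and replicate
+groups created for HYBRID_SHARD, before any FSDP parameter broadcasts are issued.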
+ --- ## NCCL Configuration diff --git a/scripts/multi_node/config_node.sh b/scripts/multi_node/config_node.sh index 0115124..d808c97 100755 --- a/scripts/multi_node/config_node.sh +++ b/scripts/multi_node/config_node.sh @@ -21,6 +21,10 @@ ROCPROF_INPUT=$(echo "${15}" | sed 's/"//g') DOCKER_CONTAINER="${DOCKER_CONTAINER:-$(echo "${16}" | sed 's/"//g')}" DOCKER_CONTAINER="${DOCKER_CONTAINER:-training-overlap-bugs-rocm70_9-1}" +echo "============================================" +echo "DEBUG: Received ${16} parameters" +echo "DEBUG: Param 16 (DOCKER_CONTAINER) = '${16}'" +echo "DEBUG: After processing = '$DOCKER_CONTAINER'" echo "============================================" echo "Node Configuration" echo "============================================" diff --git a/scripts/multi_node/local_launch.sh b/scripts/multi_node/local_launch.sh index 2eeff18..497cfd8 100755 --- a/scripts/multi_node/local_launch.sh +++ b/scripts/multi_node/local_launch.sh @@ -1,6 +1,9 @@ #!/bin/bash # Multi-node local launch script for GEMM training # Runs on each node with single channel/thread configuration +# +# NCCL/RCCL environment variables are sourced from set_env_variables.sh +# Edit that file to change NCCL configuration - no need to modify this script. if [[ $# -lt 11 ]]; then echo "Usage: $0 [ENABLE_ROCPROF] [ROCPROF_STATS] [ROCPROF_INPUT] [DOCKER_CONTAINER]" @@ -23,6 +26,16 @@ ROCPROF_STATS="${13:-false}" ROCPROF_INPUT="${14:-}" DOCKER_CONTAINER="${15:-training-overlap-bugs-rocm70_9-1}" +# Source environment variables (should already be sourced by config_node.sh, but ensure it's loaded) +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +if [[ -f "$SCRIPT_DIR/set_env_variables.sh" ]]; then + source "$SCRIPT_DIR/set_env_variables.sh" +fi + +# Override channel/thread settings from command line arguments +export NCCL_MAX_NCHANNELS="${CHANNELS}" +export RCCL_THREADS_PER_BLOCK="${THREADS}" + echo "==========================================" echo "Local Launch Configuration" echo "==========================================" @@ -37,7 +50,6 @@ echo "Experiment Dir: $EXPERIMENT_DIR" echo "Config File: $CONFIG_FILE" echo "Channels: $CHANNELS" echo "Threads: $THREADS" -echo "Docker Container: $DOCKER_CONTAINER" echo "rocprof enabled: $ENABLE_ROCPROF" echo "==========================================" echo "" @@ -60,25 +72,22 @@ else CONFIG_FILE_DOCKER="$CONFIG_FILE" fi -# Log file -LOG_FILE="${OUTPUT_DIR}/node_${NODE_RANK}_output.log" - # Function to log with timestamp log() { local message="$1" local timestamp=$(date '+%Y-%m-%d %H:%M:%S') - echo "[${timestamp}] [Node ${NODE_RANK}] ${message}" | tee -a "${LOG_FILE}" + echo "[${timestamp}] [Node ${NODE_RANK}] ${message}" } # Cleanup function cleanup() { echo "" - echo "=== Caught interrupt signal ===" | tee -a "${LOG_FILE}" + echo "=== Caught interrupt signal ===" log "Cleaning up training processes on node ${NODE_RANK}..." 
# Try to kill processes inside Docker container - docker exec "$DOCKER_CONTAINER" pkill -9 -f "train.py" 2>/dev/null || true - docker exec "$DOCKER_CONTAINER" pkill -9 -f "torchrun" 2>/dev/null || true + docker exec training-overlap-bugs-rocm70_9-1 pkill -9 -f "train.py" 2>/dev/null || true + docker exec training-overlap-bugs-rocm70_9-1 pkill -9 -f "torchrun" 2>/dev/null || true # Also try on host (in case anything leaked) sudo pkill -9 -f "train.py" 2>/dev/null || true @@ -109,12 +118,15 @@ BASE_CMD="torchrun --nnodes ${NNODES} --node_rank ${NODE_RANK} --nproc_per_node BASE_OVERRIDES="--override profiling.tensorboard=false" # Build docker exec prefix with environment variables -DOCKER_EXEC="docker exec \ - -e RCCL_THREADS_PER_BLOCK=${THREADS} \ - -e NCCL_MAX_NCHANNELS=${CHANNELS} \ - -e HSA_ENABLE_SDMA=0 \ - -e PYTORCH_ROCM_PROFILER_ENABLE_TRACING=1 \ - ${DOCKER_CONTAINER}" +# All NCCL/RCCL variables are defined in set_env_variables.sh +DOCKER_ENV_FLAGS=$(build_docker_env_flags) +DOCKER_EXEC="docker exec ${DOCKER_ENV_FLAGS} ${DOCKER_CONTAINER}" + +# Log which env vars are being passed +log "Docker environment variables:" +for var in "${DOCKER_ENV_VARS[@]}"; do + log " ${var}=${!var}" +done # Run with or without rocprofv3 if [ "${ENABLE_ROCPROF}" = "true" ]; then diff --git a/scripts/multi_node/master_launch.sh b/scripts/multi_node/master_launch.sh index 9955052..43ab2aa 100755 --- a/scripts/multi_node/master_launch.sh +++ b/scripts/multi_node/master_launch.sh @@ -1,11 +1,6 @@ #!/bin/bash # Multi-node orchestration script for Aorta GEMM training # Adapted from DLRM master_launch.sh pattern -# -# TODO: Convert to SLURM-native launch using srun instead of SSH to individual nodes. -# Currently this script runs from a compute node and SSHs to other nodes. -# Ideally, we should run SLURM commands from the login node, which would -# eliminate the need for SSH connectivity checks and branch verification. usage() { echo "Usage: $0 [OPTIONS]" diff --git a/scripts/multi_node/set_env_variables.sh b/scripts/multi_node/set_env_variables.sh index 3e9c070..4151f08 100755 --- a/scripts/multi_node/set_env_variables.sh +++ b/scripts/multi_node/set_env_variables.sh @@ -1,42 +1,121 @@ #!/bin/bash +# ============================================================================= # Global NCCL/RCCL environment variables for multi-node training -# Based on DLRM_set_env_variables.sh +# Configured for MI350X cluster +# +# This file is the SINGLE SOURCE OF TRUTH for all NCCL/RCCL configuration. +# Edit variables here - local_launch.sh will automatically pick them up. 
+# ============================================================================= -# NCCL Debug Settings (use INFO for debugging network issues) -export NCCL_DEBUG=INFO -export NCCL_DEBUG_SUBSYS=INIT,NET -# Try disabling IB if InfiniBand is not properly configured -export NCCL_IB_DISABLE=1 +# ----------------------------------------------------------------------------- +# NCCL Debug Settings +# ----------------------------------------------------------------------------- +export NCCL_DEBUG=WARN +export NCCL_DEBUG_SUBSYS= # Options: COLL,INIT,NET (empty = none) -# IB/RNIC Configuration (commented out when IB is disabled) -# export NCCL_IB_HCA=bnxt_re0,bnxt_re1,bnxt_re2,bnxt_re3,bnxt_re4,bnxt_re5,bnxt_re6,bnxt_re7 -# export NCCL_IB_GID_INDEX=3 +# ----------------------------------------------------------------------------- +# RCCL-Specific Settings (ROCm) +# ----------------------------------------------------------------------------- +export RCCL_DIRECT_ALLGATHER_DISABLE=1 # Disable direct allgather +export RCCL_MSCCL_ENABLE=0 # Disable MSCCL +export RCCL_THREADS_PER_BLOCK=256 # Threads per block (override via --threads) + +# ----------------------------------------------------------------------------- +# IB/RNIC Configuration for MI350X +# ----------------------------------------------------------------------------- +export NCCL_IB_HCA=bnxt_re0,bnxt_re1,bnxt_re2,bnxt_re3,bnxt_re4,bnxt_re5,bnxt_re6,bnxt_re7 +export NCCL_IB_GID_INDEX=3 export NCCL_NCHANNELS_PER_NET_PEER=8 +# ----------------------------------------------------------------------------- # HSA Settings for ROCm +# ----------------------------------------------------------------------------- export HSA_ENABLE_IPC_MODE_LEGACY=1 +export HSA_ENABLE_SDMA=0 # Disable SDMA for stability -# NCCL Protocol +# ----------------------------------------------------------------------------- +# NCCL Protocol and Channels +# ----------------------------------------------------------------------------- export NCCL_PROTO=Simple - -# Channel Configuration (can be overridden by sweep parameters) export NCCL_MIN_NCHANNELS=40 -export NCCL_MAX_NCHANNELS=40 +export NCCL_MAX_NCHANNELS=40 # Override via --channels -# Network Interface -# Change this to match your network interface: eth0, ib0, enp49s0f0np0, etc. 
-# Temporarily commented out for auto-detection: -# export NCCL_SOCKET_IFNAME=enp193s0f0 +# ----------------------------------------------------------------------------- +# Network Interface for MI350X cluster +# ----------------------------------------------------------------------------- +export NCCL_SOCKET_IFNAME=enp49s0f0np0,fenic0 +# ----------------------------------------------------------------------------- +# Timeout and Error Handling +# ----------------------------------------------------------------------------- +export NCCL_TIMEOUT_MS=12000 # 12 second timeout +export TORCH_DIST_INIT_TIMEOUT=60 +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 +export TORCH_NCCL_TRACE_BUFFER_SIZE=10000 +export TORCH_NCCL_DUMP_ON_TIMEOUT=1 # Critical for hang debugging + +# ----------------------------------------------------------------------------- # PyTorch ROCm Profiler +# ----------------------------------------------------------------------------- export PYTORCH_ROCM_PROFILER_ENABLE_TRACING=1 -# Optional: Force non-overlap for debugging +# ----------------------------------------------------------------------------- +# List of environment variables to pass to Docker container +# Add/remove variables here to control what gets passed through +# ----------------------------------------------------------------------------- +DOCKER_ENV_VARS=( + # NCCL Debug + NCCL_DEBUG + NCCL_DEBUG_SUBSYS + # RCCL + RCCL_DIRECT_ALLGATHER_DISABLE + RCCL_MSCCL_ENABLE + RCCL_THREADS_PER_BLOCK + # IB/RNIC + NCCL_IB_HCA + NCCL_IB_GID_INDEX + NCCL_NCHANNELS_PER_NET_PEER + # HSA + HSA_ENABLE_IPC_MODE_LEGACY + HSA_ENABLE_SDMA + # Protocol/Channels + NCCL_PROTO + NCCL_MIN_NCHANNELS + NCCL_MAX_NCHANNELS + # Network + NCCL_SOCKET_IFNAME + # Timeout/Error Handling + NCCL_TIMEOUT_MS + TORCH_DIST_INIT_TIMEOUT + TORCH_NCCL_ASYNC_ERROR_HANDLING + TORCH_NCCL_TRACE_BUFFER_SIZE + TORCH_NCCL_DUMP_ON_TIMEOUT + # Profiler + PYTORCH_ROCM_PROFILER_ENABLE_TRACING +) +export DOCKER_ENV_VARS + +# ----------------------------------------------------------------------------- +# Helper function: Build docker -e flags from DOCKER_ENV_VARS +# Usage: DOCKER_ENV_FLAGS=$(build_docker_env_flags) +# ----------------------------------------------------------------------------- +build_docker_env_flags() { + local flags="" + for var in "${DOCKER_ENV_VARS[@]}"; do + local value="${!var}" + flags+=" -e ${var}=${value}" + done + echo "$flags" +} +export -f build_docker_env_flags + +# ============================================================================= +# Optional settings (uncomment to enable) +# ============================================================================= + +# Force non-overlap for debugging (single HW queue) # export GPU_MAX_HW_QUEUES=1 # unset TORCH_NCCL_HIGH_PRIORITY -# Optional: Disable SDMA for testing -# export HSA_ENABLE_SDMA=0 - -# Optional: Disable IB for Ethernet-only testing +# Disable IB for Ethernet-only testing # export NCCL_IB_DISABLE=1 diff --git a/src/aorta/profiling/stream_profiler.py b/src/aorta/profiling/stream_profiler.py index 90c2e66..7e6ab01 100644 --- a/src/aorta/profiling/stream_profiler.py +++ b/src/aorta/profiling/stream_profiler.py @@ -37,7 +37,11 @@ class MarkerRecord: class StreamProfiler: """Track activity across multiple CUDA/HIP streams with precise timing.""" - def __init__(self, device: torch.device, stream_names: Optional[Iterable[StreamName]] = None) -> None: + def __init__( + self, + device: torch.device, + stream_names: Optional[Iterable[StreamName]] = None, + ) -> None: if not 
torch.cuda.is_available(): # pragma: no cover - runtime guard raise RuntimeError("StreamProfiler requires CUDA/HIP availability") diff --git a/src/aorta/training/fsdp_trainer.py b/src/aorta/training/fsdp_trainer.py index eea1923..f2e24ab 100644 --- a/src/aorta/training/fsdp_trainer.py +++ b/src/aorta/training/fsdp_trainer.py @@ -7,6 +7,7 @@ import json import logging import os +import random import signal import subprocess from dataclasses import dataclass, field @@ -14,6 +15,7 @@ from typing import Any, Dict, Generator, Iterable, Optional from functools import partial +import numpy as np import torch import torch.distributed as dist import torch.nn as nn @@ -23,7 +25,7 @@ from torch.nn.utils import clip_grad_norm_ from torch.optim import AdamW from torch.nn.parallel import DistributedDataParallel as DDP - +from datetime import timedelta from aorta.data import SyntheticDatasetConfig, create_dataloader from aorta.models import ModelConfig, RankingTransformerModel from aorta.profiling.stream_profiler import StreamProfiler @@ -61,6 +63,17 @@ class TrainingConfig: allreduce_stress_level: int = 1 # Number of all_reduce ops per iteration (1-10) +@dataclass +class WarmupConfig: + """Warmup settings to prevent RCCL hangs in multi-node training.""" + # RCCL communicator warmup - runs all_reduce on process groups before FSDP init + enable_rccl_warmup: bool = True + rccl_warmup_iterations: int = 10 + # Training warmup - runs forward/backward/optimizer steps before main loop + enable_training_warmup: bool = True + training_warmup_steps: int = 1 + + @dataclass class FSDPConfig: sharding_strategy: str = "full_shard" @@ -171,6 +184,15 @@ def _build_fsdp_config(raw: Dict[str, Any]) -> FSDPConfig: return cfg +def _build_warmup_config(raw: Dict[str, Any]) -> WarmupConfig: + section = raw.get("warmup", {}) + cfg = WarmupConfig() + for field in dataclass_fields(WarmupConfig): + if field.name in section: + setattr(cfg, field.name, section[field.name]) + return cfg + + def _build_ddp_config(raw: Dict[str, Any]) -> DDPConfig: section = raw.get("distributed", {}) cfg = DDPConfig() @@ -204,9 +226,22 @@ def dataclass_fields(cls) -> Iterable[Any]: return getattr(cls, "__dataclass_fields__").values() +def set_seed(seed: int, rank: int) -> None: + """Set all random seeds for reproducibility across runs.""" + seed_value = seed + rank + random.seed(seed_value) + np.random.seed(seed_value) + torch.manual_seed(seed_value) + torch.cuda.manual_seed(seed_value) + torch.cuda.manual_seed_all(seed_value) + log.info("Set random seed=%d for rank=%d (base_seed=%d)", seed_value, rank, seed) + + def init_distributed(training_cfg: TrainingConfig, log_level: str) -> Dict[str, Any]: + backend = get_distributed_backend() - dist.init_process_group(backend=backend) + timeout_seconds = int(os.environ.get("TORCH_DIST_INIT_TIMEOUT", "600")) + dist.init_process_group(backend=backend, timeout=timedelta(seconds=timeout_seconds)) rank = dist.get_rank() world_size = dist.get_world_size() local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", 0))) @@ -237,6 +272,7 @@ def build_fsdp_model( model_cfg: ModelConfig, fsdp_cfg: FSDPConfig, compile_cfg: CompileConfig, + warmup_cfg: WarmupConfig, device: torch.device, ) -> FSDP: model = RankingTransformerModel(model_cfg) @@ -250,11 +286,46 @@ def build_fsdp_model( # Create process groups for hybrid_shard strategy process_group = None + shard_group = None + replicate_group = None + if sharding == ShardingStrategy.HYBRID_SHARD: - process_group = 
_create_hybrid_shard_process_groups(fsdp_cfg.hybrid_shard_gpus_per_node) - if process_group is not None: + result = _create_hybrid_shard_process_groups(fsdp_cfg.hybrid_shard_gpus_per_node) + if result is not None: + shard_group, replicate_group = result + process_group = (shard_group, replicate_group) + + # Warmup RCCL communicators BEFORE FSDP initialization + # This ensures inter-node communicators are fully established before + # the _sync_params_and_buffers broadcasts that can cause hangs + if warmup_cfg.enable_rccl_warmup: + _warmup_rccl_communicators( + shard_group, + replicate_group, + device, + num_warmup_ops=warmup_cfg.rccl_warmup_iterations, + ) log.info("Created custom process groups for HYBRID_SHARD strategy") + # Ensure CUDA operations are complete before FSDP wrapping + # This helps prevent race conditions with inter-node communicators + torch.cuda.synchronize() + dist.barrier() + + # For HYBRID_SHARD with sync_module_states, we disable automatic sync and do it + # manually with explicit barriers to avoid RCCL race conditions + use_sync_module_states = fsdp_cfg.sync_module_states + needs_manual_sync = False + + if sharding == ShardingStrategy.HYBRID_SHARD and fsdp_cfg.sync_module_states: + use_sync_module_states = False + needs_manual_sync = True + log.info( + "Disabling sync_module_states for HYBRID_SHARD - will sync manually after wrapping" + ) + + log.info("Starting FSDP model wrapping with sync_module_states=%s", use_sync_module_states) + fsdp_model = FSDP( model.to(device), sharding_strategy=sharding, @@ -265,8 +336,19 @@ def build_fsdp_model( limit_all_gathers=fsdp_cfg.limit_all_gathers, forward_prefetch=fsdp_cfg.forward_prefetch, device_id=torch.cuda.current_device(), - sync_module_states=fsdp_cfg.sync_module_states, + sync_module_states=use_sync_module_states, ) + + log.info("FSDP model wrapping complete") + + # Manual parameter sync for HYBRID_SHARD after FSDP wrapping + if needs_manual_sync and replicate_group is not None: + _manual_sync_params(fsdp_model, replicate_group) + # Extra synchronization after manual sync before proceeding + torch.cuda.synchronize() + dist.barrier() + log.info("Post-sync barrier complete") + if compile_cfg.enabled: fsdp_model = _maybe_compile(fsdp_model, compile_cfg) return fsdp_model @@ -344,9 +426,227 @@ def _create_hybrid_shard_process_groups(gpus_per_node: Optional[int] = None): dist.get_world_size(my_replicate_group), ) + # Note: We don't barrier here because the warmup function will handle synchronization. + # Calling barrier here can trigger NCCL init race conditions before warmup runs. + return (my_shard_group, my_replicate_group) +def _warmup_rccl_communicators( + shard_group: Optional[dist.ProcessGroup], + replicate_group: Optional[dist.ProcessGroup], + device: torch.device, + num_warmup_ops: int = 5, +) -> None: + """ + Warm up RCCL communicators with small operations before heavy FSDP usage. + + This ensures inter-node communicators are fully established before the + _sync_params_and_buffers broadcasts. The race condition in RCCL/RoCE RDMA + setup can cause hangs during FSDP initialization if broadcasts are issued + before the communicators are ready. 
+ + Args: + shard_group: Intra-node shard process group (may be None) + replicate_group: Inter-node replicate process group (may be None) + device: CUDA device to use for warmup tensors + num_warmup_ops: Number of warmup operations to perform (default: 5) + """ + rank = dist.get_rank() + # Use a larger tensor for more thorough warmup + warmup_tensor = torch.ones(8192, device=device, dtype=torch.float32) + + log.info("Starting RCCL communicator warmup with %d iterations (rank=%d)...", num_warmup_ops, rank) + + # First, warmup the global world group + log.info("Warming up global world group...") + for i in range(num_warmup_ops): + dist.all_reduce(warmup_tensor) + dist.broadcast(warmup_tensor, src=0) + torch.cuda.synchronize() + + dist.barrier() + log.info("Global world group warmup complete") + + # Then warmup the shard and replicate groups + for i in range(num_warmup_ops): + # Warmup intra-node shard group + if shard_group is not None: + dist.all_reduce(warmup_tensor, group=shard_group) + # Also do broadcast from first rank in shard group + shard_ranks = dist.get_process_group_ranks(shard_group) + dist.broadcast(warmup_tensor, src=shard_ranks[0], group=shard_group) + + # Warmup inter-node replicate group (this is where the race condition occurs) + if replicate_group is not None: + # Get the ranks in this replicate group and use the first one as source + # Note: dist.get_process_group_ranks returns global ranks in the group + group_ranks = dist.get_process_group_ranks(replicate_group) + src_global_rank = group_ranks[0] # First rank in the group + dist.broadcast(warmup_tensor, src=src_global_rank, group=replicate_group) + dist.all_reduce(warmup_tensor, group=replicate_group) + + # Synchronize CUDA and global barrier between iterations + torch.cuda.synchronize() + dist.barrier() + + # Final synchronization with extra delay + torch.cuda.synchronize() + dist.barrier() + torch.cuda.synchronize() + dist.barrier() + + log.info("RCCL communicator warmup complete (rank=%d)", rank) + + +def _manual_sync_params( + model: FSDP, + replicate_group: Optional[dist.ProcessGroup], +) -> None: + """ + Manually synchronize FSDP parameters from the first rank in each replicate group. + + This replaces the automatic sync_module_states with controlled synchronization + to avoid race conditions in RCCL/RDMA during FSDP initialization. Parameters + are broadcast from the first rank in each replicate group to ensure consistency. 
+ + Args: + model: The FSDP-wrapped model + replicate_group: Inter-node replicate process group for broadcasting + """ + rank = dist.get_rank() + + log.info("Starting manual parameter synchronization (rank=%d)...", rank) + + # Synchronize before param sync + torch.cuda.synchronize() + dist.barrier() + + # Determine the source rank for this replicate group + # Each replicate group contains ranks with the same local_rank across nodes + # e.g., group for local_rank 2: [2, 10, 18] - we broadcast from rank 2 (first in group) + src_global_rank = None + if replicate_group is not None: + group_ranks = dist.get_process_group_ranks(replicate_group) + src_global_rank = group_ranks[0] # First rank in the group + log.info("Manual sync: replicate group ranks=%s, src_rank=%d", group_ranks, src_global_rank) + + param_count = 0 + with torch.no_grad(): + for name, param in model.named_parameters(): + if param.is_meta: + log.debug("Skipping meta parameter: %s", name) + continue + + # Broadcast from the first rank within this replicate group + if replicate_group is not None and src_global_rank is not None: + dist.broadcast(param.data, src=src_global_rank, group=replicate_group) + + param_count += 1 + + # Periodic sync to prevent overwhelming the network + if param_count % 10 == 0: + torch.cuda.synchronize() + + # Final barrier to ensure all ranks complete + torch.cuda.synchronize() + dist.barrier() + + log.info("Manual parameter synchronization complete (rank=%d, params=%d)", rank, param_count) + + +def _warmup_training_collectives( + model: nn.Module, + optimizer: torch.optim.Optimizer, + dataloader, + device: torch.device, + autocast_dtype: Optional[torch.dtype], + scaler: Optional[torch.cuda.amp.GradScaler], + num_warmup_steps: int = 3, +) -> None: + """ + Warm up training collectives by running dummy forward/backward/optimizer steps. + + This exercises all the collective operations used during training (all-gather, + reduce-scatter, all-reduce) to ensure RCCL communicators are fully established + before the main training loop starts. 
+ + Args: + model: The model (FSDP-wrapped) + optimizer: The optimizer + dataloader: Training dataloader + device: CUDA device + autocast_dtype: Mixed precision dtype (or None) + scaler: Gradient scaler for fp16 (or None) + num_warmup_steps: Number of warmup steps to run + """ + rank = dist.get_rank() if dist.is_initialized() else 0 + + log.info("[Warmup] Starting training warmup (rank=%d, steps=%d)...", rank, num_warmup_steps) + + # Get an iterator from the dataloader + data_iter = iter(dataloader) + + for warmup_step in range(num_warmup_steps): + log.info("[Warmup] Step %d/%d: Getting batch (rank=%d)...", warmup_step + 1, num_warmup_steps, rank) + try: + cpu_batch = next(data_iter) + except StopIteration: + # Restart iterator if dataloader is exhausted + data_iter = iter(dataloader) + cpu_batch = next(data_iter) + + # Move batch to device + batch = {k: v.to(device, non_blocking=True) if hasattr(v, 'to') else v + for k, v in cpu_batch.items()} + torch.cuda.synchronize() + log.info("[Warmup] Step %d/%d: Batch moved to device (rank=%d)", warmup_step + 1, num_warmup_steps, rank) + + # Forward pass + optimizer.zero_grad(set_to_none=True) + log.info("[Warmup] Step %d/%d: Starting forward pass (rank=%d)...", warmup_step + 1, num_warmup_steps, rank) + if autocast_dtype: + with torch.autocast(device_type="cuda", dtype=autocast_dtype): + scores = model(batch) + loss = compute_loss(scores, batch) + else: + scores = model(batch) + loss = compute_loss(scores, batch) + torch.cuda.synchronize() + log.info("[Warmup] Step %d/%d: Forward pass complete, loss=%.4f (rank=%d)", warmup_step + 1, num_warmup_steps, loss.item(), rank) + + # Backward pass + log.info("[Warmup] Step %d/%d: Starting backward pass (rank=%d)...", warmup_step + 1, num_warmup_steps, rank) + if scaler is not None: + scaler.scale(loss).backward() + else: + loss.backward() + torch.cuda.synchronize() + log.info("[Warmup] Step %d/%d: Backward pass complete (rank=%d)", warmup_step + 1, num_warmup_steps, rank) + + # Optimizer step + log.info("[Warmup] Step %d/%d: Starting optimizer step (rank=%d)...", warmup_step + 1, num_warmup_steps, rank) + if scaler is not None: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + torch.cuda.synchronize() + log.info("[Warmup] Step %d/%d: Optimizer step complete (rank=%d)", warmup_step + 1, num_warmup_steps, rank) + + # Synchronize all ranks after each warmup step + log.info("[Warmup] Step %d/%d: Starting barrier (rank=%d)...", warmup_step + 1, num_warmup_steps, rank) + dist.barrier() + log.info("[Warmup] Step %d/%d: Barrier complete (rank=%d)", warmup_step + 1, num_warmup_steps, rank) + + # Reset optimizer state after warmup to not affect actual training + log.info("[Warmup] Resetting optimizer state (rank=%d)...", rank) + optimizer.zero_grad(set_to_none=True) + torch.cuda.synchronize() + dist.barrier() + log.info("[Warmup] Training warmup complete (rank=%d)", rank) + + def build_ddp_model( model_cfg: ModelConfig, ddp_cfg: DDPConfig, @@ -438,6 +738,7 @@ def training_loop( optimizer: torch.optim.Optimizer, dataloader, training_cfg: TrainingConfig, + warmup_cfg: WarmupConfig, scheduler: Optional[torch.optim.lr_scheduler._LRScheduler], environment: Dict[str, Any], profiler: StreamProfiler, @@ -470,6 +771,16 @@ def training_loop( model.train() + # Warmup training collectives before main loop to avoid RCCL race conditions + # This exercises forward/backward/optimizer step on all communicators + if warmup_cfg.enable_training_warmup: + log.info("Starting training warmup pass (rank=%d, 
num_steps=%d)...", rank, warmup_cfg.training_warmup_steps) + _warmup_training_collectives(model, optimizer, dataloader, device, autocast_dtype, scaler, num_warmup_steps=warmup_cfg.training_warmup_steps) + log.info("Training warmup complete (rank=%d)", rank) + + # Note: force_async is now controlled dynamically per-step in training loop + # Only racing ranks (local 1, 2 on Node 0) at step >= 3 will have force_async enabled + profiler_dir = training_cfg.output_dir / "torch_profiler" with profiler.intercept_distributed_ops(): with _torch_profiler_context(profiler_cfg, profiler_dir, rank, device) as torch_profiler: @@ -508,11 +819,20 @@ def training_loop( grad_norm = clip_grad_norm_(model.parameters(), training_cfg.grad_clip_norm) with profiler.range("aux", f"epoch{epoch}_step{step}_optimizer"): - if scaler is not None: - scaler.step(optimizer) - scaler.update() - else: - optimizer.step() + try: + if scaler is not None: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + except AssertionError as e: + if "NaN" in str(e) or "Inf" in str(e): + log.error("NaN/Inf detected in rank %d at step %d: %s", rank, global_step, e) + log.error("Stopping training to save traces") + stop_flag["stop"] = True + break + else: + raise if scheduler is not None: scheduler.step() @@ -520,7 +840,6 @@ def training_loop( # Inject all_reduce operations to trigger hang pattern # Pattern: all_reduce → device-to-device copy → host-device copy → compute blocked if training_cfg.inject_allreduce_copies: - import torch.distributed as dist if dist.is_initialized(): with profiler.range("aux", f"epoch{epoch}_step{step}_allreduce_sync"): # Perform multiple all_reduce + memory copy cycles @@ -594,6 +913,10 @@ def training_loop( iteration_payload.update(collect_rocm_metrics(enable_rocm_metrics)) metrics_logger.log(iteration_payload) + loss_log = training_cfg.output_dir / f"loss_rank{rank}.log" + with open(loss_log, "a") as f: + f.write(f"step={global_step} epoch={epoch} loss={iteration_payload['loss']:.6f} lr={iteration_payload['lr']:.6f}\n") + if global_step % training_cfg.log_interval == 0 and rank == 0: log.info( "epoch=%s step=%s loss=%.5f lr=%.6f overlap=%.3fms compute=%.3fms", @@ -616,10 +939,19 @@ def training_loop( break if stop_flag["stop"]: + if rank == 0: + log.info("Training stopped at epoch=%d step=%d", epoch, step) break metrics_logger.close() + if rank == 0: + log.info("Training loop finished. 
Profiler will export traces in cleanup phase.") + log.info("Output directory: %s", training_cfg.output_dir) + log.info("Torch profiler traces: %s", training_cfg.output_dir / "torch_profiler") + log.info("Loss logs: %s/loss_rank*.log", training_cfg.output_dir) + log.info("Metrics: %s/rank_*_metrics.jsonl", training_cfg.output_dir) + def configure_optimizer(model: nn.Module, cfg: OptimizerConfig, dist_mode: str = "ddp") -> torch.optim.Optimizer: if cfg.name.lower() == "shampoo": @@ -754,19 +1086,16 @@ def _torch_profiler_context( prof.__exit__(None, None, None) produce_tb = cfg.tensorboard produce_chrome = cfg.chrome_trace - try: - stats_available = prof._stats() is not None # type: ignore[attr-defined] - except Exception: - stats_available = False - if produce_tb and stats_available: + if produce_tb: try: handler = tensorboard_trace_handler(str(rank_dir)) handler(prof) - except Exception as exc: # pragma: no cover - best effort - log.warning("TensorBoard trace export failed: %s", exc, exc_info=True) + log.info("Exported TensorBoard trace to %s", rank_dir) + except Exception as exc: + log.warning("TensorBoard trace export failed: %s", exc) - if produce_chrome and stats_available: + if produce_chrome: stem, ext = os.path.splitext(cfg.trace_filename) if not ext: ext = ".json" @@ -775,8 +1104,9 @@ def _torch_profiler_context( trace_name = f"{stem}_step{prof.step_num}{ext}" try: prof.export_chrome_trace(str(rank_dir / trace_name)) - except Exception as exc: # pragma: no cover - best effort - log.warning("Chrome trace export failed: %s", exc, exc_info=True) + log.info("Exported chrome trace to %s/%s", rank_dir, trace_name) + except Exception as exc: + log.warning("Chrome trace export failed: %s", exc) def main_cli() -> None: # pragma: no cover - CLI entry @@ -810,6 +1140,7 @@ def main(args: Optional[argparse.Namespace] = None, *, enable_rocm_metrics: bool model_cfg = _build_model_config(config) dataset_cfg = _build_dataset_config(config) fsdp_cfg = _build_fsdp_config(config) + warmup_cfg = _build_warmup_config(config) ddp_cfg = _build_ddp_config(config) compile_cfg = _build_compile_config(config) profiler_cfg = _build_profiler_config(config) @@ -818,6 +1149,8 @@ def main(args: Optional[argparse.Namespace] = None, *, enable_rocm_metrics: bool env = init_distributed(training_cfg, log_level) rank = env["rank"] + set_seed(dataset_cfg.seed, rank) + dataloader = create_dataloader( dataset_cfg, batch_size=training_cfg.batch_size, @@ -835,7 +1168,7 @@ def main(args: Optional[argparse.Namespace] = None, *, enable_rocm_metrics: bool if dist_mode == "ddp": model = build_ddp_model(model_cfg, ddp_cfg, compile_cfg, env["device"]) else: - model = build_fsdp_model(model_cfg, fsdp_cfg, compile_cfg, env["device"]) + model = build_fsdp_model(model_cfg, fsdp_cfg, compile_cfg, warmup_cfg, env["device"]) optimizer = configure_optimizer(model, optimizer_cfg, dist_mode) scheduler = configure_scheduler( optimizer, @@ -851,6 +1184,7 @@ def main(args: Optional[argparse.Namespace] = None, *, enable_rocm_metrics: bool optimizer, dataloader, training_cfg, + warmup_cfg, scheduler, env, profiler, @@ -858,8 +1192,15 @@ def main(args: Optional[argparse.Namespace] = None, *, enable_rocm_metrics: bool profiler_cfg, ) finally: - dist.barrier() - dist.destroy_process_group() + if dist.is_initialized(): + try: + dist.barrier() + except Exception as e: + log.warning("Barrier failed during cleanup: %s", e) + try: + dist.destroy_process_group() + except Exception as e: + log.warning("destroy_process_group failed: %s", e) __all__ = 
["main", "main_cli"]