From 3e2c3ae8d050b8eecaa1e0140c9cc0cf0c11c48d Mon Sep 17 00:00:00 2001 From: dggaytan Date: Fri, 18 Jul 2025 13:47:55 -0700 Subject: [PATCH 1/2] Adding torch accelerator to ddp-tutorial-series example Signed-off-by: dggaytan --- distributed/ddp-tutorial-series/multigpu.py | 16 ++++++++++--- .../ddp-tutorial-series/multigpu_torchrun.py | 23 +++++++++++++++---- distributed/ddp-tutorial-series/multinode.py | 22 ++++++++++++++---- .../ddp-tutorial-series/requirements.txt | 2 +- .../ddp-tutorial-series/run_example.sh | 10 ++++++++ run_distributed_examples.sh | 6 +++++ 6 files changed, 65 insertions(+), 14 deletions(-) create mode 100644 distributed/ddp-tutorial-series/run_example.sh diff --git a/distributed/ddp-tutorial-series/multigpu.py b/distributed/ddp-tutorial-series/multigpu.py index 7e11633305..652822bb06 100644 --- a/distributed/ddp-tutorial-series/multigpu.py +++ b/distributed/ddp-tutorial-series/multigpu.py @@ -18,8 +18,18 @@ def ddp_setup(rank, world_size): """ os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "12355" - torch.cuda.set_device(rank) - init_process_group(backend="nccl", rank=rank, world_size=world_size) + + rank = int(os.environ["LOCAL_RANK"]) + if torch.accelerator.is_available(): + device = torch.device(f"{torch.accelerator.current_accelerator()}:{rank}") + torch.accelerator.set_device_index(rank) + print(f"Running on rank {rank} on device {device}") + else: + device = torch.device("cpu") + print(f"Running on device {device}") + + backend = torch.distributed.get_default_backend_for_device(device) + init_process_group(backend=backend, rank=rank, world_size=world_size) class Trainer: def __init__( @@ -100,5 +110,5 @@ def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_s parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)') args = parser.parse_args() - world_size = torch.cuda.device_count() + world_size = torch.accelerator.device_count() mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size) diff --git a/distributed/ddp-tutorial-series/multigpu_torchrun.py b/distributed/ddp-tutorial-series/multigpu_torchrun.py index 32d6254d2d..0dadfe9449 100644 --- a/distributed/ddp-tutorial-series/multigpu_torchrun.py +++ b/distributed/ddp-tutorial-series/multigpu_torchrun.py @@ -11,8 +11,19 @@ def ddp_setup(): - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) - init_process_group(backend="nccl") + rank = int(os.environ["LOCAL_RANK"]) + if torch.accelerator.is_available(): + device = torch.device(f"{torch.accelerator.current_accelerator()}:{rank}") + torch.accelerator.set_device_index(rank) + print(f"Running on rank {rank} on device {device}") + else: + device = torch.device("cpu") + print(f"Running on device {device}") + + backend = torch.distributed.get_default_backend_for_device(device) + torch.distributed.init_process_group(backend=backend, device_id=device) + return device + class Trainer: def __init__( @@ -22,6 +33,7 @@ def __init__( optimizer: torch.optim.Optimizer, save_every: int, snapshot_path: str, + device: torch.device, ) -> None: self.gpu_id = int(os.environ["LOCAL_RANK"]) self.model = model.to(self.gpu_id) @@ -30,6 +42,7 @@ def __init__( self.save_every = save_every self.epochs_run = 0 self.snapshot_path = snapshot_path + self.device = device if os.path.exists(snapshot_path): print("Loading snapshot") self._load_snapshot(snapshot_path) @@ -37,7 +50,7 @@ def __init__( self.model = DDP(self.model, 
device_ids=[self.gpu_id]) def _load_snapshot(self, snapshot_path): - loc = f"cuda:{self.gpu_id}" + loc = str(self.device) snapshot = torch.load(snapshot_path, map_location=loc) self.model.load_state_dict(snapshot["MODEL_STATE"]) self.epochs_run = snapshot["EPOCHS_RUN"] @@ -92,10 +105,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int): def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"): - ddp_setup() + device = ddp_setup() dataset, model, optimizer = load_train_objs() train_data = prepare_dataloader(dataset, batch_size) - trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path) + trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path, device) trainer.train(total_epochs) destroy_process_group() diff --git a/distributed/ddp-tutorial-series/multinode.py b/distributed/ddp-tutorial-series/multinode.py index 2cbae84b56..32ee031639 100644 --- a/distributed/ddp-tutorial-series/multinode.py +++ b/distributed/ddp-tutorial-series/multinode.py @@ -11,8 +11,18 @@ def ddp_setup(): - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) - init_process_group(backend="nccl") + rank = int(os.environ["LOCAL_RANK"]) + if torch.accelerator.is_available(): + device = torch.device(f"{torch.accelerator.current_accelerator()}:{rank}") + torch.accelerator.set_device_index(rank) + print(f"Running on rank {rank} on device {device}") + else: + device = torch.device("cpu") + print(f"Running on device {device}") + + backend = torch.distributed.get_default_backend_for_device(device) + torch.distributed.init_process_group(backend=backend, device_id=device) + return device class Trainer: def __init__( @@ -22,6 +32,7 @@ def __init__( optimizer: torch.optim.Optimizer, save_every: int, snapshot_path: str, + device: torch.device, ) -> None: self.local_rank = int(os.environ["LOCAL_RANK"]) self.global_rank = int(os.environ["RANK"]) @@ -31,6 +42,7 @@ def __init__( self.save_every = save_every self.epochs_run = 0 self.snapshot_path = snapshot_path + self.device = device if os.path.exists(snapshot_path): print("Loading snapshot") self._load_snapshot(snapshot_path) @@ -38,7 +50,7 @@ def __init__( self.model = DDP(self.model, device_ids=[self.local_rank]) def _load_snapshot(self, snapshot_path): - loc = f"cuda:{self.local_rank}" + loc = str(self.device) snapshot = torch.load(snapshot_path, map_location=loc) self.model.load_state_dict(snapshot["MODEL_STATE"]) self.epochs_run = snapshot["EPOCHS_RUN"] @@ -93,10 +105,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int): def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"): - ddp_setup() + device = ddp_setup() dataset, model, optimizer = load_train_objs() train_data = prepare_dataloader(dataset, batch_size) - trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path) + trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path, device) trainer.train(total_epochs) destroy_process_group() diff --git a/distributed/ddp-tutorial-series/requirements.txt b/distributed/ddp-tutorial-series/requirements.txt index 9270a1d6ee..285a4d8195 100644 --- a/distributed/ddp-tutorial-series/requirements.txt +++ b/distributed/ddp-tutorial-series/requirements.txt @@ -1 +1 @@ -torch>=1.11.0 \ No newline at end of file +torch>=2.7 diff --git a/distributed/ddp-tutorial-series/run_example.sh b/distributed/ddp-tutorial-series/run_example.sh new file mode 100644 index 0000000000..d439b681b4 --- /dev/null +++ 
b/distributed/ddp-tutorial-series/run_example.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+# bash run_example.sh {file_to_run.py} {num_gpus}
+# where file_to_run = example to run. Default = 'example.py'
+# num_gpus = num local gpus to use (must be at least 2). Default = 2
+
+# samples to run include:
+# example.py
+
+echo "Launching ${1:-example.py} with ${2:-2} gpus"
+torchrun --nnodes=1 --nproc_per_node=${2:-2} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-example.py}
diff --git a/run_distributed_examples.sh b/run_distributed_examples.sh
index e1f579c072..a7b03e489b 100755
--- a/run_distributed_examples.sh
+++ b/run_distributed_examples.sh
@@ -50,6 +50,12 @@ function distributed_tensor_parallelism() {
     uv run bash run_example.sh fsdp_tp_example.py || error "2D parallel example failed"
 }
 
+function distributed_ddp-tutorial-series() {
+    uv run bash run_example.sh multigpu.py || error "ddp tutorial series multigpu example failed"
+    uv run bash run_example.sh multigpu_torchrun.py || error "ddp tutorial series multigpu torchrun example failed"
+    uv run bash run_example.sh multinode.py || error "ddp tutorial series multinode example failed"
+}
+
 function distributed_ddp() {
     uv run main.py || error "ddp example failed"
 }

From 67b4a05e8d04fed848e5fb117aba34f39418a65e Mon Sep 17 00:00:00 2001
From: dggaytan
Date: Mon, 11 Aug 2025 12:46:40 -0700
Subject: [PATCH 2/2] Adding torch accelerator to ddp-tutorial-series example

Signed-off-by: dggaytan
---
 distributed/ddp-tutorial-series/README.md     | 28 ++++++++++++++++---
 distributed/ddp-tutorial-series/multigpu.py   |  4 ---
 .../ddp-tutorial-series/multigpu_torchrun.py  | 19 ++++++-------
 distributed/ddp-tutorial-series/multinode.py  | 19 ++++++-------
 .../ddp-tutorial-series/run_example.sh        |  5 ++--
 distributed/ddp-tutorial-series/single_gpu.py |  2 +-
 run_distributed_examples.sh                   |  3 +-
 7 files changed, 46 insertions(+), 34 deletions(-)

diff --git a/distributed/ddp-tutorial-series/README.md b/distributed/ddp-tutorial-series/README.md
index d0ce17c00f..3a27f3d8e0 100644
--- a/distributed/ddp-tutorial-series/README.md
+++ b/distributed/ddp-tutorial-series/README.md
@@ -15,7 +15,27 @@ Each code file extends upon the previous one. The series starts with a non-distr
 * [slurm/setup_pcluster_slurm.md](slurm/setup_pcluster_slurm.md): instructions to set up an AWS cluster
 * [slurm/config.yaml.template](slurm/config.yaml.template): configuration to set up an AWS cluster
 * [slurm/sbatch_run.sh](slurm/sbatch_run.sh): slurm script to launch the training job
-
-
-
-
+## Installation
+```
+pip install -r requirements.txt
+```
+## Running Examples
+To run an example for 20 epochs, saving a checkpoint every 5 epochs, use one of the following commands:
+### Single GPU
+```
+python single_gpu.py 20 5
+```
+### Multi-GPU
+```
+python multigpu.py 20 5
+```
+### Multi-GPU Torchrun
+```
+torchrun --nnodes=1 --nproc_per_node=4 multigpu_torchrun.py 20 5
+```
+### Multi-Node
+```
+torchrun --nnodes=2 --nproc_per_node=4 multinode.py 20 5
+```
+
+For more details, check the [run_example.sh](run_example.sh) script.
\ No newline at end of file diff --git a/distributed/ddp-tutorial-series/multigpu.py b/distributed/ddp-tutorial-series/multigpu.py index 652822bb06..7968bec80a 100644 --- a/distributed/ddp-tutorial-series/multigpu.py +++ b/distributed/ddp-tutorial-series/multigpu.py @@ -19,14 +19,10 @@ def ddp_setup(rank, world_size): os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "12355" - rank = int(os.environ["LOCAL_RANK"]) if torch.accelerator.is_available(): device = torch.device(f"{torch.accelerator.current_accelerator()}:{rank}") torch.accelerator.set_device_index(rank) print(f"Running on rank {rank} on device {device}") - else: - device = torch.device("cpu") - print(f"Running on device {device}") backend = torch.distributed.get_default_backend_for_device(device) init_process_group(backend=backend, rank=rank, world_size=world_size) diff --git a/distributed/ddp-tutorial-series/multigpu_torchrun.py b/distributed/ddp-tutorial-series/multigpu_torchrun.py index 0dadfe9449..01013ca53d 100644 --- a/distributed/ddp-tutorial-series/multigpu_torchrun.py +++ b/distributed/ddp-tutorial-series/multigpu_torchrun.py @@ -17,12 +17,11 @@ def ddp_setup(): torch.accelerator.set_device_index(rank) print(f"Running on rank {rank} on device {device}") else: - device = torch.device("cpu") - print(f"Running on device {device}") - - backend = torch.distributed.get_default_backend_for_device(device) - torch.distributed.init_process_group(backend=backend, device_id=device) - return device + print(f"Multi-GPU environment not detected") + + backend = torch.distributed.get_default_backend_for_device(rank) + torch.distributed.init_process_group(backend=backend, rank=rank, device_id=rank) + class Trainer: @@ -33,7 +32,6 @@ def __init__( optimizer: torch.optim.Optimizer, save_every: int, snapshot_path: str, - device: torch.device, ) -> None: self.gpu_id = int(os.environ["LOCAL_RANK"]) self.model = model.to(self.gpu_id) @@ -42,7 +40,6 @@ def __init__( self.save_every = save_every self.epochs_run = 0 self.snapshot_path = snapshot_path - self.device = device if os.path.exists(snapshot_path): print("Loading snapshot") self._load_snapshot(snapshot_path) @@ -50,7 +47,7 @@ def __init__( self.model = DDP(self.model, device_ids=[self.gpu_id]) def _load_snapshot(self, snapshot_path): - loc = str(self.device) + loc = str(torch.accelerator.current_accelerator()) snapshot = torch.load(snapshot_path, map_location=loc) self.model.load_state_dict(snapshot["MODEL_STATE"]) self.epochs_run = snapshot["EPOCHS_RUN"] @@ -105,10 +102,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int): def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"): - device = ddp_setup() + ddp_setup() dataset, model, optimizer = load_train_objs() train_data = prepare_dataloader(dataset, batch_size) - trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path, device) + trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path) trainer.train(total_epochs) destroy_process_group() diff --git a/distributed/ddp-tutorial-series/multinode.py b/distributed/ddp-tutorial-series/multinode.py index 32ee031639..83d29108e1 100644 --- a/distributed/ddp-tutorial-series/multinode.py +++ b/distributed/ddp-tutorial-series/multinode.py @@ -17,12 +17,11 @@ def ddp_setup(): torch.accelerator.set_device_index(rank) print(f"Running on rank {rank} on device {device}") else: - device = torch.device("cpu") - print(f"Running on device {device}") - - backend = 
torch.distributed.get_default_backend_for_device(device) - torch.distributed.init_process_group(backend=backend, device_id=device) - return device + print(f"Multi-GPU environment not detected") + + backend = torch.distributed.get_default_backend_for_device(rank) + torch.distributed.init_process_group(backend=backend, rank=rank, device_id=rank) + class Trainer: def __init__( @@ -32,7 +31,6 @@ def __init__( optimizer: torch.optim.Optimizer, save_every: int, snapshot_path: str, - device: torch.device, ) -> None: self.local_rank = int(os.environ["LOCAL_RANK"]) self.global_rank = int(os.environ["RANK"]) @@ -42,7 +40,6 @@ def __init__( self.save_every = save_every self.epochs_run = 0 self.snapshot_path = snapshot_path - self.device = device if os.path.exists(snapshot_path): print("Loading snapshot") self._load_snapshot(snapshot_path) @@ -50,7 +47,7 @@ def __init__( self.model = DDP(self.model, device_ids=[self.local_rank]) def _load_snapshot(self, snapshot_path): - loc = str(self.device) + loc = str(torch.accelerator.current_accelerator()) snapshot = torch.load(snapshot_path, map_location=loc) self.model.load_state_dict(snapshot["MODEL_STATE"]) self.epochs_run = snapshot["EPOCHS_RUN"] @@ -105,10 +102,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int): def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"): - device = ddp_setup() + ddp_setup() dataset, model, optimizer = load_train_objs() train_data = prepare_dataloader(dataset, batch_size) - trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path, device) + trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path) trainer.train(total_epochs) destroy_process_group() diff --git a/distributed/ddp-tutorial-series/run_example.sh b/distributed/ddp-tutorial-series/run_example.sh index d439b681b4..f9c9312171 100644 --- a/distributed/ddp-tutorial-series/run_example.sh +++ b/distributed/ddp-tutorial-series/run_example.sh @@ -4,7 +4,8 @@ # num_gpus = num local gpus to use (must be at least 2). 
Default = 2
 
 # samples to run include:
-# example.py
+# multigpu_torchrun.py
+# multinode.py
 
 echo "Launching ${1:-example.py} with ${2:-2} gpus"
-torchrun --nnodes=1 --nproc_per_node=${2:-2} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-example.py}
+torchrun --nnodes=1 --nproc_per_node=${2:-2} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-example.py} 10 1
diff --git a/distributed/ddp-tutorial-series/single_gpu.py b/distributed/ddp-tutorial-series/single_gpu.py
index e91ab81cc1..c7a0b134e2 100644
--- a/distributed/ddp-tutorial-series/single_gpu.py
+++ b/distributed/ddp-tutorial-series/single_gpu.py
@@ -78,5 +78,5 @@ def main(device, total_epochs, save_every, batch_size):
     parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
     args = parser.parse_args()
 
-    device = 0  # shorthand for cuda:0
+    device = 0
     main(device, args.total_epochs, args.save_every, args.batch_size)
diff --git a/run_distributed_examples.sh b/run_distributed_examples.sh
index a7b03e489b..1fdddbb589 100755
--- a/run_distributed_examples.sh
+++ b/run_distributed_examples.sh
@@ -51,9 +51,10 @@ function distributed_tensor_parallelism() {
 }
 
 function distributed_ddp-tutorial-series() {
-    uv run bash run_example.sh multigpu.py || error "ddp tutorial series multigpu example failed"
+    uv run multigpu.py 10 1 || error "ddp tutorial series multigpu example failed"
     uv run bash run_example.sh multigpu_torchrun.py || error "ddp tutorial series multigpu torchrun example failed"
     uv run bash run_example.sh multinode.py || error "ddp tutorial series multinode example failed"
+    uv run single_gpu.py 10 1 || error "ddp tutorial series single gpu example failed"
 }
 
 function distributed_ddp() {