From e5d4d6323508e24ac44dfe4d65a96b7d566c3d92 Mon Sep 17 00:00:00 2001
From: Wei Feng <weif@meta.com>
Date: Mon, 5 May 2025 19:45:16 -0700
Subject: [PATCH] FSDP2 example

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 distributed/FSDP2/README.md | 17 ++++++++++++++
 distributed/FSDP2/train.py  | 52 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 69 insertions(+)
 create mode 100644 distributed/FSDP2/README.md
 create mode 100644 distributed/FSDP2/train.py

diff --git a/distributed/FSDP2/README.md b/distributed/FSDP2/README.md
new file mode 100644
index 0000000000..1f85acb469
--- /dev/null
+++ b/distributed/FSDP2/README.md
@@ -0,0 +1,17 @@
+## FSDP2
+
+To run FSDP2 on a transformer model:
+
+## Install the requirements:
+```
+pip install -r requirements.txt
+```
+
+## Ensure you are running a recent version of PyTorch:
+See https://pytorch.org/get-started/locally/ to install PyTorch 2.5 or later, ideally a current nightly build.
+
+Start the training with `torchrun` (adjust `--nproc_per_node` to your GPU count):
+
+```
+torchrun --nnodes 1 --nproc_per_node 2 train.py
+```
diff --git a/distributed/FSDP2/train.py b/distributed/FSDP2/train.py
new file mode 100644
index 0000000000..3b38f214d6
--- /dev/null
+++ b/distributed/FSDP2/train.py
@@ -0,0 +1,52 @@
+import argparse
+import os
+
+import torch
+from torch.distributed.fsdp import fully_shard
+
+# Toy Transformer and its config from PyTorch's internal test utilities;
+# convenient for a small example, but not a stable public API.
+from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer
+
+
+def main(args):
+    torch.distributed.init_process_group(backend="nccl")
+    rank = int(os.environ["LOCAL_RANK"])
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    # Same seed on every rank so all ranks build identical initial weights.
+    torch.manual_seed(args.seed)
+    vocab_size = 1024
+    model_args = ModelArgs(
+        n_layers=3,
+        n_heads=4,
+        vocab_size=vocab_size,
+        max_seq_len=64,
+        dropout_p=0,
+    )
+    model = Transformer(model_args)
+    # Shard each transformer block, then the root module; fully_shard moves
+    # parameters to the mesh's device and shards them across ranks.
+    for layer in model.layers:
+        fully_shard(layer)
+    fully_shard(model)
+    optim = torch.optim.Adam(model.parameters(), lr=1e-2)
+    # Re-seed per rank so each rank draws different random input batches.
+    torch.manual_seed(args.seed + rank)
+    for _ in range(args.epochs):
+        x = torch.randint(0, vocab_size, (32, 32), device=device)
+        loss = model(x).sum()
+        loss.backward()
+        optim.step()
+        optim.zero_grad()
+    torch.distributed.destroy_process_group()
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='PyTorch FSDP2 example')
+    parser.add_argument('--epochs', type=int, default=10, metavar='N',
+                        help='number of training steps (default: 10)')
+    parser.add_argument('--seed', type=int, default=1, metavar='S',
+                        help='base random seed (default: 1)')
+    args = parser.parse_args()
+    main(args)
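
Note for readers (not part of the patch): a quick way to sanity-check the setup is to inspect the parameters after `fully_shard`, which converts them to `DTensor`s sharded on dim 0. A minimal sketch, assuming PyTorch 2.5+ where `torch.distributed.tensor` is public; `print_sharding` is a hypothetical helper, not something the example defines:

```python
import torch
from torch.distributed.tensor import DTensor


def print_sharding(model: torch.nn.Module) -> None:
    """Print each parameter's placements and this rank's local shard shape."""
    for name, param in model.named_parameters():
        assert isinstance(param, DTensor)  # fully_shard converted the params
        # to_local() returns only this rank's shard; dim 0 is split across ranks.
        print(f"{name}: {param.placements}, local shape {tuple(param.to_local().shape)}")
```

Calling `print_sharding(model)` right after `fully_shard(model)` on two ranks should show `(Shard(dim=0),)` placements with dim-0 sizes roughly halved (FSDP2 pads parameters whose dim 0 does not divide evenly).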
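
Also not in the patch: FSDP2 accepts a mixed-precision policy through the `mp_policy` keyword of `fully_shard`. A sketch of how the sharding calls above could be extended, assuming a PyTorch version that exports `MixedPrecisionPolicy` from `torch.distributed.fsdp`:

```python
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

# Cast parameters to bfloat16 for compute while reducing gradients in float32.
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
for layer in model.layers:
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)
```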
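
One last note on checkpointing: since the sharded `state_dict` holds `DTensor`s, `torch.distributed.checkpoint` can save it without materializing full tensors on any rank. A sketch; the `"checkpoint"` directory name is arbitrary:

```python
import torch.distributed.checkpoint as dcp

# Collective call: every rank participates and writes only its own shards.
dcp.save({"model": model.state_dict()}, checkpoint_id="checkpoint")
```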