From e5d4d6323508e24ac44dfe4d65a96b7d566c3d92 Mon Sep 17 00:00:00 2001
From: Wei Feng <weif@meta.com>
Date: Mon, 5 May 2025 19:45:16 -0700
Subject: [PATCH] FSDP2 example

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 distributed/FSDP2/README.md | 34 ++++++++++++++++++++++++++++++++++
 distributed/FSDP2/train.py  | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 84 insertions(+)
 create mode 100644 distributed/FSDP2/README.md
 create mode 100644 distributed/FSDP2/train.py

diff --git a/distributed/FSDP2/README.md b/distributed/FSDP2/README.md
new file mode 100644
index 0000000000..1f85acb469
--- /dev/null
+++ b/distributed/FSDP2/README.md
@@ -0,0 +1,34 @@
+## FSDP2
+
+To run FSDP2 on a Transformer model:
+
+## Install the requirements:
+```
+pip install -r requirements.txt
+```
+
+## Ensure you are running a recent version of PyTorch:
+See https://pytorch.org/get-started/locally/ to install at least PyTorch 2.5, and ideally a current nightly build.
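+
+You can check the installed version with:
+```
+python -c "import torch; print(torch.__version__)"
+```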
+
+Start the training with `torchrun` (adjust `--nproc_per_node` to your GPU count):
+
+```
+torchrun --nnodes 1 --nproc_per_node 2 train.py
+```
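+
+## How FSDP2 is applied
+
+`train.py` wraps each transformer block and then the root model with `fully_shard`:
+
+```
+for layer in model.layers:
+    fully_shard(layer)
+fully_shard(model)
+```
+
+After wrapping, each parameter is sharded across the ranks and all-gathered on demand during forward and backward.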
diff --git a/distributed/FSDP2/train.py b/distributed/FSDP2/train.py
new file mode 100644
index 0000000000..3b38f214d6
--- /dev/null
+++ b/distributed/FSDP2/train.py
@@ -0,0 +1,50 @@
+import os
+import argparse
+import torch
+from torch.distributed.fsdp import fully_shard
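+# ModelArgs and Transformer are a small demo model definition shipped with torch's internal test utilities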
+from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer
+
+
+def main(args):
+    torch.distributed.init_process_group(backend="nccl")
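+    # torchrun sets LOCAL_RANK for each process; use it to pin this process to one GPU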
+    rank = int(os.environ["LOCAL_RANK"])
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    torch.manual_seed(args.seed + rank)  # offset the seed by rank so each rank draws different data
+    vocab_size = 1024
+    model_args = ModelArgs(
+        n_layers=3,
+        n_heads=4,
+        vocab_size=vocab_size,
+        max_seq_len=64,
+        dropout_p=0,
+    )
+    model = Transformer(model_args)
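+    # apply FSDP2: shard each transformer block individually, then the root
+    # model so the remaining parameters form their own shard group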
+    for layer in model.layers:
+        fully_shard(layer)
+    fully_shard(model)
+    optim = torch.optim.Adam(model.parameters(), lr=1e-2)
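+    # toy training loop: each rank trains on its own randomly generated token batches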
+    for _ in range(args.num_steps):
+        x = torch.randint(0, vocab_size, (args.batch_size, 32), device=device)
+        loss = model(x).sum()
+        loss.backward()
+        optim.step()
+        optim.zero_grad()
+    torch.distributed.destroy_process_group()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='PyTorch FSDP2 example')
+    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
+                        help='input batch size for training (default: 32)')
+    parser.add_argument('--num-steps', type=int, default=10, metavar='N',
+                        help='number of training steps to run (default: 10)')
+    parser.add_argument('--seed', type=int, default=1, metavar='S',
+                        help='base random seed, offset by rank (default: 1)')
+    args = parser.parse_args()
+    main(args)