
Commit 9708f36

Committed May 6, 2025
FSDP2 example
ghstack-source-id: f3eeff8
Pull Request resolved: #1339
1 parent 8393ceb · commit 9708f36

File tree

- distributed/FSDP2/README.md
- distributed/FSDP2/train.py

2 files changed: +62 -0 lines changed

distributed/FSDP2/README.md

Lines changed: 17 additions & 0 deletions

## FSDP2

To run FSDP2 on a transformer model:

## Install the requirements

~~~
pip install -r requirements.txt
~~~

## Ensure you are running a recent version of PyTorch

See https://pytorch.org/get-started/locally/ to install at least 2.5, and ideally a current nightly build.

Start the training with `torchrun` (adjust `--nproc_per_node` to your GPU count):

```
torchrun --nnodes 1 --nproc_per_node 2 train.py
```
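As a quick sanity check that your install is new enough (a minimal sketch, not part of this commit; it only assumes the `fully_shard` import that `train.py` itself uses):

```
import torch

print(torch.__version__)  # expect 2.5 or newer, ideally a current nightly

# train.py depends on this FSDP2 entry point; the import fails on
# builds that predate the public fully_shard API
from torch.distributed.fsdp import fully_shard
```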

distributed/FSDP2/train.py

Lines changed: 45 additions & 0 deletions
1+
import os
2+
import argparse
3+
import torch
4+
from torch.distributed.fsdp import fully_shard
5+
from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer
6+
7+
8+
def main(args):
9+
torch.distributed.init_process_group(backend="nccl")
10+
rank = int(os.environ["LOCAL_RANK"])
11+
device = torch.device(f"cuda:{rank}")
12+
torch.cuda.set_device(device)
13+
torch.manual_seed(rank)
14+
vocab_size = 1024
15+
model_args = ModelArgs(
16+
n_layers=3,
17+
n_heads=4,
18+
vocab_size=vocab_size,
19+
max_seq_len=64,
20+
dropout_p=0,
21+
)
22+
model = Transformer(model_args)
23+
for layer in model.layers:
24+
fully_shard(layer)
25+
fully_shard(model)
26+
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
27+
for _ in range(10):
28+
x = torch.randint(0, vocab_size, (32, 32), device=device)
29+
loss = model(x).sum()
30+
loss.backward()
31+
optim.step()
32+
optim.zero_grad()
33+
torch.distributed.destroy_process_group()
34+
35+
36+
if __name__ == "__main__":
37+
parser = argparse.ArgumentParser(description='PyTorch FSDP2 example')
38+
parser.add_argument('--meta-init', type=int, default=4, metavar='N',
39+
help='input batch size for training (default: 64)')
40+
parser.add_argument('--epochs', type=int, default=2, metavar='N',
41+
help='number of epochs to train (default: 3)')
42+
parser.add_argument('--seed', type=int, default=1, metavar='S',
43+
help='random seed (default: 1)')
44+
args = parser.parse_args()
45+
main(args)
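The `--meta-init` flag above is parsed but never read by `main`, which materializes the full model eagerly on each GPU before sharding. For reference, here is a minimal sketch of FSDP2's usual meta-device initialization pattern (an illustration of what such a flag could drive, not part of this commit; it assumes the process group and device are already set up as in `main`):

```
import torch
from torch.distributed.fsdp import fully_shard
from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer

# Build the model on the meta device: no parameter memory is allocated yet
with torch.device("meta"):
    model = Transformer(ModelArgs(n_layers=3, n_heads=4, vocab_size=1024,
                                  max_seq_len=64, dropout_p=0))

# Shard while still on meta, so each rank only ever materializes its own shard
for layer in model.layers:
    fully_shard(layer)
fully_shard(model)

# Allocate real storage for the sharded parameters on the local GPU; to_empty
# leaves them uninitialized, so re-run the module initializers afterwards
model.to_empty(device="cuda")
for module in model.modules():
    if hasattr(module, "reset_parameters"):  # convention; modules may differ
        module.reset_parameters()
```

The benefit is peak memory: during initialization each rank allocates roughly one shard instead of a full replica of the model.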
