Commit e73fed7

FSDP2 example
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: b80f513
Pull Request resolved: #1339

args

ghstack-source-id: b80f513
Pull Request resolved: #1340
1 parent 54e132e commit e73fed7

File tree

5 files changed: +454 -0 lines changed

distributed/FSDP2/README.md (+8)
@@ -0,0 +1,8 @@
## FSDP2

To run FSDP2 on a transformer model:

```
torchrun --nproc_per_node 2 train.py
```

## Ensure you are running a recent version of PyTorch

See https://pytorch.org/get-started/locally/ to install at least 2.5, and ideally a current nightly build.
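For context, below is a minimal, hypothetical sketch of the FSDP2 wrapping that a `train.py` like the one in this commit would perform; it is not the actual file from this commit. It assumes the `fully_shard` API (importable from `torch.distributed.fsdp` in recent releases; older 2.x builds expose it under `torch.distributed._composable.fsdp`) and the `model.py` added below.

```
# Hypothetical launcher sketch; not the train.py from this commit.
import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # older 2.x: torch.distributed._composable.fsdp

from model import ModelArgs, Transformer

# One process per GPU, launched via `torchrun --nproc_per_node 2 train.py`.
torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Shard each transformer block first, then the root module for the remaining
# parameters (embeddings, final norm, output projection).
mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))
model = Transformer(ModelArgs()).cuda()
for block in model.layers:
    fully_shard(block, mesh=mesh)
fully_shard(model, mesh=mesh)

optim = torch.optim.AdamW(model.parameters(), lr=1e-2)

# One dummy step: random token ids within vocab_size, sequence length <= max_seq_len.
tokens = torch.randint(0, 8, (2, 16), device="cuda")
loss = model(tokens).sum()
loss.backward()
optim.step()

torch.distributed.destroy_process_group()
```

Wrapping each `TransformerBlock` before the root module lets FSDP2 all-gather and reshard parameters block by block rather than holding the whole model unsharded at once.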

distributed/FSDP2/checkpoint.py (+195)
@@ -0,0 +1,195 @@
import os
import time

import torch
import torch.nn as nn
from torch.distributed.fsdp import FSDPModule
from torch.distributed.checkpoint.state_dict import (
    _init_optim_state,
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
    StateDictOptions,
)
from torch.distributed.tensor import distribute_tensor, DTensor


MODEL_CHECKPOINT = "model_state_dict.pt"
OPTIM_CHECKPOINT = "optim_state_dict.pt"


def get_latest_checkpoint_folder(path):
    max_num = None
    if not os.path.exists(path):
        return max_num
    for name in os.listdir(path):
        folder_path = os.path.join(path, name)
        if os.path.isdir(folder_path):
            try:
                num = int(name)
                if max_num is None or num > max_num:
                    max_num = num
            except ValueError:
                pass  # Skip non-numeric folder names
    return max_num


class Checkpointer:
    def __init__(self, folder: str, dcp_api: bool):
        self.folder = folder
        self.dcp_api = dcp_api
        self.last_training_time = get_latest_checkpoint_folder(f"{folder}/{'dcp_api' if dcp_api else 'dtensor_api'}")

    def is_empty(self):
        return self.last_training_time is None

    def load_model(self, model: FSDPModule):
        last_model_checkpoint = f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}/{self.last_training_time}/{MODEL_CHECKPOINT}"
        full_sd = torch.load(last_model_checkpoint, mmap=True, weights_only=True, map_location='cpu')
        if self.dcp_api:
            set_model_state_dict(
                model=model,
                model_state_dict=full_sd,
                options=StateDictOptions(
                    full_state_dict=True,
                    broadcast_from_rank0=True,
                ),
            )
            return
        # DTensor API: re-shard each full tensor to match the model's sharded
        # state dict (device mesh and placements), then load in place.
        meta_sharded_sd = model.state_dict()
        sharded_sd = {}
        for param_name, full_tensor in full_sd.items():
            sharded_meta_param = meta_sharded_sd.get(param_name)
            sharded_tensor = distribute_tensor(
                full_tensor,
                sharded_meta_param.device_mesh,
                sharded_meta_param.placements,
            )
            sharded_sd[param_name] = nn.Parameter(sharded_tensor)
        # choose `assign=True` since we cannot call `copy_` on meta tensor
        model.load_state_dict(sharded_sd, strict=False, assign=True)

    def load_optim(self, model: FSDPModule, opt: torch.optim.Optimizer):
        last_optim_checkpoint = f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}/{self.last_training_time}/{OPTIM_CHECKPOINT}"
        full_sd = torch.load(last_optim_checkpoint, mmap=True, weights_only=True, map_location='cpu')
        if self.dcp_api:
            set_optimizer_state_dict(
                model=model,
                optimizers=opt,
                optim_state_dict=full_sd,
                options=StateDictOptions(
                    full_state_dict=True,
                    broadcast_from_rank0=True,
                )
            )
            return
        PARAMS = "params"
        _init_optim_state(opt)
        param_groups = opt.state_dict()["param_groups"]
        state = opt.state_dict()["state"]

        full_param_groups = full_sd["param_groups"]
        full_state = full_sd["state"]

        for param_group, full_param_group in zip(param_groups, full_param_groups):
            for key, value in full_param_group.items():
                if key == PARAMS:
                    continue
                param_group[key] = value
            for pid, full_pid in zip(param_group[PARAMS], full_param_group[PARAMS]):
                if pid not in state:
                    continue
                param_state = state[pid]
                full_param_state = full_state[full_pid]
                for attr, full_tensor in full_param_state.items():
                    sharded_tensor = param_state[attr]
                    if isinstance(sharded_tensor, DTensor):
                        # exp_avg is DTensor
                        param_state[attr] = distribute_tensor(
                            full_tensor,
                            sharded_tensor.device_mesh,
                            sharded_tensor.placements,
                        )
                    else:
                        # step is plain tensor
                        param_state[attr] = full_tensor
        opt.load_state_dict(
            {
                "param_groups": param_groups,
                "state": state,
            }
        )

    def _get_full_model_state_dict(self, model: FSDPModule):
        if self.dcp_api:
            return get_model_state_dict(
                model=model,
                options=StateDictOptions(
                    full_state_dict=True,
                    cpu_offload=True,
                )
            )

        # DTensor API: all-gather each sharded parameter into a full tensor
        # and keep a CPU copy only on rank 0.
        sharded_sd = model.state_dict()
        cpu_state_dict = {}
        for param_name, sharded_param in sharded_sd.items():
            full_param = sharded_param.full_tensor()
            if torch.distributed.get_rank() == 0:
                cpu_state_dict[param_name] = full_param.cpu()
            else:
                del full_param
        return cpu_state_dict

    def _get_full_optimizer_state_dict(
        self,
        model: FSDPModule,
        opt: torch.optim.Optimizer,
    ):
        if self.dcp_api:
            return get_optimizer_state_dict(
                model=model,
                optimizers=opt,
                options=StateDictOptions(
                    full_state_dict=True,
                    cpu_offload=True,
                )
            )
        is_rank_zero = (torch.distributed.get_rank() == 0)
        sharded_sd = opt.state_dict()
        sharded_state = sharded_sd["state"]
        full_state = {}
        for group_id, sharded_group in sharded_state.items():
            group_state = {}
            for attr, sharded_tensor in sharded_group.items():
                if isinstance(sharded_tensor, DTensor):
                    # "exp_avg" in AdamW is `DTensor`
                    full_tensor = sharded_tensor.full_tensor()
                else:
                    # "step" in AdamW is plain tensor
                    full_tensor = sharded_tensor
                if is_rank_zero:
                    group_state[attr] = full_tensor.cpu()
                else:
                    del full_tensor
            if is_rank_zero:
                full_state[group_id] = group_state
            else:
                del group_state
        if is_rank_zero:
            return {
                "param_groups": sharded_sd["param_groups"],
                "state": full_state,
            }
        else:
            return {}

    def save(self, model: FSDPModule, optim: torch.optim.Optimizer):
        model_state_dict = self._get_full_model_state_dict(model)
        optim_state_dict = self._get_full_optimizer_state_dict(model, optim)
        if torch.distributed.get_rank() == 0:
            new_training_time = int(time.time() * 1000)
            new_checkpoint_folder = f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}/{new_training_time}"
            new_model_checkpoint = f"{new_checkpoint_folder}/{MODEL_CHECKPOINT}"
            new_optim_checkpoint = f"{new_checkpoint_folder}/{OPTIM_CHECKPOINT}"
            os.makedirs(new_checkpoint_folder, exist_ok=True)
            torch.save(model_state_dict, new_model_checkpoint)
            torch.save(optim_state_dict, new_optim_checkpoint)
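A hedged sketch of how this `Checkpointer` might be driven from a training loop follows; the commit's actual `train.py` is not shown in this diff, and names such as `"checkpoints"` and `num_steps` are illustrative. It continues from the hypothetical launcher sketch under the README above, which defines `model`, `optim`, and `tokens`.

```
# Illustrative wiring only; the commit's real training script may differ.
from checkpoint import Checkpointer

checkpointer = Checkpointer("checkpoints", dcp_api=False)  # assumed folder name

# Resume from the latest timestamped folder, if any. The DTensor load path
# assumes the model was initialized so that loading with assign=True is valid
# (e.g. parameters created on the meta device before fully_shard).
if not checkpointer.is_empty():
    checkpointer.load_model(model)
    checkpointer.load_optim(model, optim)

num_steps = 10  # arbitrary for this sketch
for _ in range(num_steps):
    optim.zero_grad()
    loss = model(tokens).sum()
    loss.backward()
    optim.step()

# Rank 0 writes model_state_dict.pt and optim_state_dict.pt under
# <folder>/<dtensor_api|dcp_api>/<timestamp in ms>/.
checkpointer.save(model, optim)
```

With `dcp_api=True` the same calls delegate to the `torch.distributed.checkpoint.state_dict` helpers; with `dcp_api=False` they take the manual DTensor path implemented above.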

distributed/FSDP2/model.py (+133)
@@ -0,0 +1,133 @@
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ModelArgs:
    n_layers: int = 2
    vocab_size: int = 8
    max_seq_len: int = 16
    dim: int = 16
    n_heads: int = 4
    dropout_p: float = 0.1


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.dim % args.n_heads == 0
        self.head_dim = args.dim // args.n_heads
        self.n_heads = args.n_heads
        self.dropout_p = args.dropout_p
        self.resid_dropout = nn.Dropout(args.dropout_p)

        self.wq = nn.Linear(args.dim, args.dim, bias=False)
        self.wk = nn.Linear(args.dim, args.dim, bias=False)
        self.wv = nn.Linear(args.dim, args.dim, bias=False)
        self.wo = nn.Linear(args.dim, args.dim, bias=False)

    def forward(self, x):
        bsz, seq_len, _ = x.size()
        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
        queries = queries.view(bsz, seq_len, self.n_heads, self.head_dim)
        keys = keys.view(bsz, seq_len, self.n_heads, self.head_dim)
        values = values.view(bsz, seq_len, self.n_heads, self.head_dim)

        queries = queries.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
        keys = keys.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
        values = values.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)

        output = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            None,
            self.dropout_p if self.training else 0,
        )
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.resid_dropout(self.wo(output))

    def reset_parameters(self):
        self.wq.reset_parameters()
        self.wk.reset_parameters()
        self.wv.reset_parameters()
        self.wo.reset_parameters()


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout_p):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim)
        self.gelu = nn.GELU()
        self.w2 = nn.Linear(hidden_dim, dim)
        self.resid_dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        return self.resid_dropout(self.w2(self.gelu(self.w1(x))))

    def reset_parameters(self):
        self.w1.reset_parameters()
        self.w2.reset_parameters()


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attention_norm = nn.LayerNorm(args.dim)
        self.attention = Attention(args)
        self.ffn_norm = nn.LayerNorm(args.dim)
        self.feed_forward = FeedForward(
            args.dim, hidden_dim=4 * args.dim, dropout_p=args.dropout_p
        )

    def forward(self, x):
        h = x + self.attention(self.attention_norm(x))
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

    def reset_parameters(self):
        self.attention_norm.reset_parameters()
        self.attention.reset_parameters()
        self.ffn_norm.reset_parameters()
        self.feed_forward.reset_parameters()


# A toy transformer model, partly inspired by the nanoGPT model:
# https://github.com/karpathy/nanoGPT.
class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.vocab_size is not None
        assert args.max_seq_len is not None
        self.model_args = args
        self.max_seq_len = args.max_seq_len
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.pos_embeddings = nn.Embedding(args.max_seq_len, args.dim)
        self.dropout = nn.Dropout(args.dropout_p)
        self.layers = nn.ModuleList()
        for _ in range(args.n_layers):
            self.layers.append(TransformerBlock(args))
        self.norm = nn.LayerNorm(args.dim)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def forward(self, tokens):
        _bsz, seq_len = tokens.size()
        assert seq_len <= self.max_seq_len
        h = self.tok_embeddings(tokens)
        pos = torch.arange(0, seq_len, device=tokens.device)
        p = self.pos_embeddings(pos)  # positional embeddings of shape (seq_len, dim)
        h = h + p
        h = self.dropout(h)
        for layer in self.layers:
            h = layer(h)
        h = self.norm(h)
        output = self.output(h).float()
        return output

    def reset_parameters(self):
        self.tok_embeddings.reset_parameters()
        self.pos_embeddings.reset_parameters()
        self.norm.reset_parameters()
        self.output.reset_parameters()
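As a quick, single-process sanity check (not part of this commit), the toy model can be exercised on its own; the shapes follow the `ModelArgs` defaults.

```
# Smoke test of the toy transformer; no distributed setup or FSDP involved.
import torch

from model import ModelArgs, Transformer

args = ModelArgs()  # n_layers=2, vocab_size=8, max_seq_len=16, dim=16, n_heads=4
model = Transformer(args)

tokens = torch.randint(0, args.vocab_size, (2, args.max_seq_len))  # (batch, seq_len)
logits = model(tokens)
print(logits.shape)  # torch.Size([2, 16, 8]): (batch, seq_len, vocab_size)
```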
