FSDP2 example #1339

Closed · wants to merge 6 commits

8 changes: 8 additions & 0 deletions distributed/FSDP2/README.md
@@ -0,0 +1,8 @@
## FSDP2
To run FSDP2 on a transformer model:
```
torchrun --nproc_per_node 2 train.py
```

## Ensure you are running a recent version of PyTorch
See https://pytorch.org/get-started/locally/ to install PyTorch 2.5 or later, ideally a current nightly build.
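
The command above launches one process per GPU; each process wraps the model with FSDP2's `fully_shard`. The sketch below illustrates the general wrapping pattern, not the exact contents of `train.py` (which is not shown in this diff). It assumes a PyTorch build where `fully_shard` is exported from `torch.distributed.fsdp` (recent nightlies), and the helper name `build_sharded_model` is made up for illustration.

```
# Sketch only: FSDP2 wrapping pattern for the toy Transformer in model.py.
# Assumes torchrun sets LOCAL_RANK and that fully_shard is importable from
# torch.distributed.fsdp in your PyTorch build.
import os

import torch
from torch.distributed.fsdp import fully_shard

from model import ModelArgs, Transformer


def build_sharded_model() -> Transformer:
    torch.distributed.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    model = Transformer(ModelArgs()).cuda()
    # Shard each transformer block first, then the root module, so parameters
    # are grouped per block and communication can overlap with compute.
    for layer in model.layers:
        fully_shard(layer)
    fully_shard(model)
    return model
```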
209 changes: 209 additions & 0 deletions distributed/FSDP2/checkpoint.py
@@ -0,0 +1,209 @@
import os
import time

import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
_init_optim_state,
get_model_state_dict,
get_optimizer_state_dict,
set_model_state_dict,
set_optimizer_state_dict,
StateDictOptions,
)
from torch.distributed.fsdp import FSDPModule
from torch.distributed.tensor import distribute_tensor, DTensor


MODEL_CHECKPOINT = "model_state_dict.pt"
OPTIM_CHECKPOINT = "optim_state_dict.pt"
PARAMS = "params"


def get_latest_checkpoint_folder(path):
max_num = None
if not os.path.exists(path):
return max_num
for name in os.listdir(path):
folder_path = os.path.join(path, name)
if os.path.isdir(folder_path):
try:
num = int(name)
if max_num is None or num > max_num:
max_num = num
except ValueError:
pass # Skip non-numeric folder names
return max_num


class Checkpointer:
def __init__(self, folder: str, dcp_api: bool):
self.folder = folder
self.dcp_api = dcp_api
self.last_training_time = get_latest_checkpoint_folder(
f"{folder}/{'dcp_api' if dcp_api else 'dtensor_api'}"
)

def is_empty(self):
return self.last_training_time is None

def load_model(self, model: FSDPModule):
last_model_checkpoint = (
f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}"
f"/{self.last_training_time}/{MODEL_CHECKPOINT}"
)
full_sd = torch.load(
last_model_checkpoint, mmap=True, weights_only=True, map_location="cpu"
)
if self.dcp_api:
set_model_state_dict(
model=model,
model_state_dict=full_sd,
options=StateDictOptions(
full_state_dict=True,
broadcast_from_rank0=True,
),
)
return
meta_sharded_sd = model.state_dict()
sharded_sd = {}
for param_name, full_tensor in full_sd.items():
sharded_meta_param = meta_sharded_sd.get(param_name)
sharded_tensor = distribute_tensor(
full_tensor,
sharded_meta_param.device_mesh,
sharded_meta_param.placements,
)
sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
model.load_state_dict(sharded_sd, strict=False, assign=True)

def load_optim(self, model: FSDPModule, opt: torch.optim.Optimizer):
last_optim_checkpoint = (
f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}"
f"/{self.last_training_time}/{OPTIM_CHECKPOINT}"
)
full_sd = torch.load(
last_optim_checkpoint, mmap=True, weights_only=True, map_location="cpu"
)
if self.dcp_api:
set_optimizer_state_dict(
model=model,
optimizers=opt,
optim_state_dict=full_sd,
options=StateDictOptions(
full_state_dict=True,
broadcast_from_rank0=True,
),
)
return
_init_optim_state(opt)
param_groups = opt.state_dict()["param_groups"]
state = opt.state_dict()["state"]

full_param_groups = full_sd["param_groups"]
full_state = full_sd["state"]

for param_group, full_param_group in zip(param_groups, full_param_groups):
for key, value in full_param_group.items():
if key == PARAMS:
continue
param_group[key] = value
for pid, full_pid in zip(param_group[PARAMS], full_param_group[PARAMS]):
if pid not in state:
continue
param_state = state[pid]
full_param_state = full_state[full_pid]
for attr, full_tensor in full_param_state.items():
sharded_tensor = param_state[attr]
if isinstance(sharded_tensor, DTensor):
# exp_avg is DTensor
param_state[attr] = distribute_tensor(
full_tensor,
sharded_tensor.device_mesh,
sharded_tensor.placements,
)
else:
# step is plain tensor
param_state[attr] = full_tensor
opt.load_state_dict(
{
"param_groups": param_groups,
"state": state,
}
)

def _get_full_model_state_dict(self, model: FSDPModule):
if self.dcp_api:
return get_model_state_dict(
model=model,
options=StateDictOptions(
full_state_dict=True,
cpu_offload=True,
),
)

sharded_sd = model.state_dict()
cpu_state_dict = {}
for param_name, sharded_param in sharded_sd.items():
full_param = sharded_param.full_tensor()
if torch.distributed.get_rank() == 0:
cpu_state_dict[param_name] = full_param.cpu()
else:
del full_param
return cpu_state_dict

def _get_full_optimizer_state_dict(
self,
model: FSDPModule,
opt: torch.optim.Optimizer,
):
if self.dcp_api:
return get_optimizer_state_dict(
model=model,
optimizers=opt,
options=StateDictOptions(
full_state_dict=True,
cpu_offload=True,
),
)
is_rank_zero = torch.distributed.get_rank() == 0
sharded_sd = opt.state_dict()
sharded_state = sharded_sd["state"]
full_state = {}
for group_id, sharded_group in sharded_state.items():
group_state = {}
for attr, sharded_tensor in sharded_group.items():
if isinstance(sharded_tensor, DTensor):
# "exp_avg" in AdamW is `DTensor`
full_tensor = sharded_tensor.full_tensor()
else:
# "step" in AdamW is plain tensor
full_tensor = sharded_tensor
if is_rank_zero:
group_state[attr] = full_tensor.cpu()
else:
del full_tensor
if is_rank_zero:
full_state[group_id] = group_state
else:
del group_state
if is_rank_zero:
return {
"param_groups": sharded_sd["param_groups"],
"state": full_state,
}
else:
return {}

def save(self, model: FSDPModule, optim: torch.optim.Optimizer):
model_state_dict = self._get_full_model_state_dict(model)
optim_state_dict = self._get_full_optimizer_state_dict(model, optim)
if torch.distributed.get_rank() == 0:
new_training_time = int(time.time() * 1000)
new_checkpoint_folder = f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}/{new_training_time}"
new_model_checkpoint = f"{new_checkpoint_folder}/{MODEL_CHECKPOINT}"
new_optim_checkpoint = f"{new_checkpoint_folder}/{OPTIM_CHECKPOINT}"
os.makedirs(new_checkpoint_folder, exist_ok=True)
torch.save(model_state_dict, new_model_checkpoint)
torch.save(optim_state_dict, new_optim_checkpoint)
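
For context, a training script could wire `Checkpointer` in roughly as follows. This is a hedged sketch, not the PR's actual `train.py`: the function name `train_with_checkpointing`, the `"checkpoints"` folder, the toy batch, and the loss are placeholders, and `model` is assumed to already be wrapped with `fully_shard`.

```
# Hypothetical use of Checkpointer (sketch only).
import torch
from torch.distributed.fsdp import FSDPModule

from checkpoint import Checkpointer


def train_with_checkpointing(
    model: FSDPModule,
    optim: torch.optim.Optimizer,
    num_steps: int,
    dcp_api: bool = False,
) -> None:
    checkpointer = Checkpointer("checkpoints", dcp_api=dcp_api)
    if not checkpointer.is_empty():
        # Resume: reshard the latest full (unsharded) checkpoint onto this rank.
        checkpointer.load_model(model)
        checkpointer.load_optim(model, optim)

    for _ in range(num_steps):
        # Toy batch matching ModelArgs defaults (vocab_size=8); assumes the
        # current CUDA device was set for this rank.
        tokens = torch.randint(0, 8, (2, 8), device="cuda")
        model(tokens).sum().backward()  # placeholder loss
        optim.step()
        optim.zero_grad()

    # Rank 0 writes full state dicts under checkpoints/<api>/<timestamp>/.
    checkpointer.save(model, optim)
```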
134 changes: 134 additions & 0 deletions distributed/FSDP2/model.py
@@ -0,0 +1,134 @@
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ModelArgs:
n_layers: int = 2
vocab_size: int = 8
max_seq_len: int = 16
dim: int = 16
n_heads: int = 4
dropout_p: float = 0.1


class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
assert args.dim % args.n_heads == 0
self.head_dim = args.dim // args.n_heads
self.n_heads = args.n_heads
self.dropout_p = args.dropout_p
self.resid_dropout = nn.Dropout(args.dropout_p)

self.wq = nn.Linear(args.dim, args.dim, bias=False)
self.wk = nn.Linear(args.dim, args.dim, bias=False)
self.wv = nn.Linear(args.dim, args.dim, bias=False)
self.wo = nn.Linear(args.dim, args.dim, bias=False)

def forward(self, x):
bsz, seq_len, _ = x.size()
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
queries = queries.view(bsz, seq_len, self.n_heads, self.head_dim)
keys = keys.view(bsz, seq_len, self.n_heads, self.head_dim)
values = values.view(bsz, seq_len, self.n_heads, self.head_dim)

queries = queries.transpose(1, 2) # (bsz, n_heads, seq_len, head_dim)
keys = keys.transpose(1, 2) # (bsz, n_heads, seq_len, head_dim)
values = values.transpose(1, 2) # (bsz, n_heads, seq_len, head_dim)

output = F.scaled_dot_product_attention(
queries,
keys,
values,
None,
self.dropout_p if self.training else 0,
)
output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
return self.resid_dropout(self.wo(output))

def reset_parameters(self):
self.wq.reset_parameters()
self.wk.reset_parameters()
self.wv.reset_parameters()
self.wo.reset_parameters()


class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout_p):
super().__init__()
self.w1 = nn.Linear(dim, hidden_dim)
self.gelu = nn.GELU()
self.w2 = nn.Linear(hidden_dim, dim)
self.resid_dropout = nn.Dropout(dropout_p)

def forward(self, x):
return self.resid_dropout(self.w2(self.gelu(self.w1(x))))

def reset_parameters(self):
self.w1.reset_parameters()
self.w2.reset_parameters()


class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.attention_norm = nn.LayerNorm(args.dim)
self.attention = Attention(args)
self.ffn_norm = nn.LayerNorm(args.dim)
self.feed_forward = FeedForward(
args.dim, hidden_dim=4 * args.dim, dropout_p=args.dropout_p
)

def forward(self, x):
h = x + self.attention(self.attention_norm(x))
out = h + self.feed_forward(self.ffn_norm(h))
return out

def reset_parameters(self):
self.attention_norm.reset_parameters()
self.attention.reset_parameters()
self.ffn_norm.reset_parameters()
self.feed_forward.reset_parameters()


# A toy transformer model, partly inspired by the nanoGPT model:
# https://github.com/karpathy/nanoGPT.
class Transformer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
assert args.vocab_size is not None
assert args.max_seq_len is not None
self.model_args = args
self.max_seq_len = args.max_seq_len
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
self.pos_embeddings = nn.Embedding(args.max_seq_len, args.dim)
self.dropout = nn.Dropout(args.dropout_p)
self.layers = nn.ModuleList()
for _ in range(args.n_layers):
self.layers.append(TransformerBlock(args))
self.norm = nn.LayerNorm(args.dim)
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

def forward(self, tokens):
_bsz, seq_len = tokens.size()
assert seq_len <= self.max_seq_len
h = self.tok_embeddings(tokens)
pos = torch.arange(0, seq_len, device=tokens.device)
p = self.pos_embeddings(pos) # positional embeddings of shape (seq_len, dim)
h = h + p
h = self.dropout(h)
for layer in self.layers:
h = layer(h)
h = self.norm(h)
output = self.output(h).float()
return output

def reset_parameters(self):
self.tok_embeddings.reset_parameters()
self.pos_embeddings.reset_parameters()
self.norm.reset_parameters()
self.output.reset_parameters()
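
The `reset_parameters` methods above make meta-device initialization workable with FSDP2: the model can be built on the `meta` device, sharded, then materialized and initialized shard by shard. The sketch below shows one possible flow under that assumption; it is not the PR's actual `train.py`, and the helper name `build_on_meta` is made up for illustration.

```
# Sketch: meta-device init with FSDP2, relying on reset_parameters() above.
# Assumes fully_shard is importable from torch.distributed.fsdp and that the
# current CUDA device has been set for this rank.
import torch
from torch.distributed.fsdp import fully_shard

from model import ModelArgs, Transformer


def build_on_meta() -> Transformer:
    with torch.device("meta"):
        model = Transformer(ModelArgs())
    for layer in model.layers:
        fully_shard(layer)
    fully_shard(model)
    # Allocate sharded storage on the GPU, then run module-level init.
    model.to_empty(device="cuda")
    model.reset_parameters()
    for layer in model.layers:
        layer.reset_parameters()
    return model
```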