diff --git a/distributed/FSDP2/README.md b/distributed/FSDP2/README.md
new file mode 100644
index 0000000000..5023173e18
--- /dev/null
+++ b/distributed/FSDP2/README.md
@@ -0,0 +1,8 @@
+## FSDP2
+To run FSDP2 on a transformer model:
+```
+torchrun --nproc_per_node 2 train.py
+```
+
+## Ensure you are running a recent version of PyTorch
+See https://pytorch.org/get-started/locally/ to install PyTorch 2.5 or later, ideally a current nightly build.
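`train.py` (added below) also exposes a few optional flags beyond the basic command above. A combined invocation, again assuming 2 GPUs, could look like:

```
torchrun --nproc_per_node 2 train.py --mixed-precision --explicit-prefetching --dcp-api
```

`--mixed-precision` computes in bfloat16 while reducing gradients in float32, `--explicit-prefetching` prefetches the next/previous two transformer blocks during forward/backward, and `--dcp-api` saves and loads checkpoints through the `torch.distributed.checkpoint.state_dict` helpers instead of the hand-rolled DTensor path.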
diff --git a/distributed/FSDP2/checkpoint.py b/distributed/FSDP2/checkpoint.py
new file mode 100644
index 0000000000..aede917823
--- /dev/null
+++ b/distributed/FSDP2/checkpoint.py
@@ -0,0 +1,209 @@
+import os
+import time
+
+import torch
+import torch.nn as nn
+from torch.distributed.checkpoint.state_dict import (
+    _init_optim_state,
+    get_model_state_dict,
+    get_optimizer_state_dict,
+    set_model_state_dict,
+    set_optimizer_state_dict,
+    StateDictOptions,
+)
+from torch.distributed.fsdp import FSDPModule
+from torch.distributed.tensor import distribute_tensor, DTensor
+
+
+MODEL_CHECKPOINT = "model_state_dict.pt"
+OPTIM_CHECKPOINT = "optim_state_dict.pt"
+PARAMS = "params"
+
+
+def get_latest_checkpoint_folder(path):
+    max_num = None
+    if not os.path.exists(path):
+        return max_num
+    for name in os.listdir(path):
+        folder_path = os.path.join(path, name)
+        if os.path.isdir(folder_path):
+            try:
+                num = int(name)
+                if max_num is None or num > max_num:
+                    max_num = num
+            except ValueError:
+                pass  # Skip non-numeric folder names
+    return max_num
+
+
+class Checkpointer:
+    def __init__(self, folder: str, dcp_api: bool):
+        self.folder = folder
+        self.dcp_api = dcp_api
+        self.last_training_time = get_latest_checkpoint_folder(
+            f"{folder}/{'dcp_api' if dcp_api else 'dtensor_api'}"
+        )
+
+    def is_empty(self):
+        return self.last_training_time is None
+
+    def load_model(self, model: FSDPModule):
+        last_model_checkpoint = (
+            f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}"
+            f"/{self.last_training_time}/{MODEL_CHECKPOINT}"
+        )
+        full_sd = torch.load(
+            last_model_checkpoint, mmap=True, weights_only=True, map_location="cpu"
+        )
+        if self.dcp_api:
+            set_model_state_dict(
+                model=model,
+                model_state_dict=full_sd,
+                options=StateDictOptions(
+                    full_state_dict=True,
+                    broadcast_from_rank0=True,
+                ),
+            )
+            return
+        meta_sharded_sd = model.state_dict()
+        sharded_sd = {}
+        for param_name, full_tensor in full_sd.items():
+            sharded_meta_param = meta_sharded_sd.get(param_name)
+            sharded_tensor = distribute_tensor(
+                full_tensor,
+                sharded_meta_param.device_mesh,
+                sharded_meta_param.placements,
+            )
+            sharded_sd[param_name] = nn.Parameter(sharded_tensor)
+        # choose `assign=True` since we cannot call `copy_` on meta tensor
+        model.load_state_dict(sharded_sd, strict=False, assign=True)
+
+    def load_optim(self, model: FSDPModule, opt: torch.optim.Optimizer):
+        last_optim_checkpoint = (
+            f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}"
+            f"/{self.last_training_time}/{OPTIM_CHECKPOINT}"
+        )
+        full_sd = torch.load(
+            last_optim_checkpoint, mmap=True, weights_only=True, map_location="cpu"
+        )
+        if self.dcp_api:
+            set_optimizer_state_dict(
+                model=model,
+                optimizers=opt,
+                optim_state_dict=full_sd,
+                options=StateDictOptions(
+                    full_state_dict=True,
+                    broadcast_from_rank0=True,
+                ),
+            )
+            return
+        _init_optim_state(opt)
+        param_groups = opt.state_dict()["param_groups"]
+        state = opt.state_dict()["state"]
+
+        full_param_groups = full_sd["param_groups"]
+        full_state = full_sd["state"]
+
+        for param_group, full_param_group in zip(param_groups, full_param_groups):
+            for key, value in full_param_group.items():
+                if key == PARAMS:
+                    continue
+                param_group[key] = value
+            for pid, full_pid in zip(param_group[PARAMS], full_param_group[PARAMS]):
+                if pid not in state:
+                    continue
+                param_state = state[pid]
+                full_param_state = full_state[full_pid]
+                for attr, full_tensor in full_param_state.items():
+                    sharded_tensor = param_state[attr]
+                    if isinstance(sharded_tensor, DTensor):
+                        # exp_avg is DTensor
+                        param_state[attr] = distribute_tensor(
+                            full_tensor,
+                            sharded_tensor.device_mesh,
+                            sharded_tensor.placements,
+                        )
+                    else:
+                        # step is plain tensor
+                        param_state[attr] = full_tensor
+        opt.load_state_dict(
+            {
+                "param_groups": param_groups,
+                "state": state,
+            }
+        )
+
+    def _get_full_model_state_dict(self, model: FSDPModule):
+        if self.dcp_api:
+            return get_model_state_dict(
+                model=model,
+                options=StateDictOptions(
+                    full_state_dict=True,
+                    cpu_offload=True,
+                ),
+            )
+
+        sharded_sd = model.state_dict()
+        cpu_state_dict = {}
+        for param_name, sharded_param in sharded_sd.items():
+            full_param = sharded_param.full_tensor()
+            if torch.distributed.get_rank() == 0:
+                cpu_state_dict[param_name] = full_param.cpu()
+            else:
+                del full_param
+        return cpu_state_dict
+
+    def _get_full_optimizer_state_dict(
+        self,
+        model: FSDPModule,
+        opt: torch.optim.Optimizer,
+    ):
+        if self.dcp_api:
+            return get_optimizer_state_dict(
+                model=model,
+                optimizers=opt,
+                options=StateDictOptions(
+                    full_state_dict=True,
+                    cpu_offload=True,
+                ),
+            )
+        is_rank_zero = torch.distributed.get_rank() == 0
+        sharded_sd = opt.state_dict()
+        sharded_state = sharded_sd["state"]
+        full_state = {}
+        for group_id, sharded_group in sharded_state.items():
+            group_state = {}
+            for attr, sharded_tensor in sharded_group.items():
+                if isinstance(sharded_tensor, DTensor):
+                    # "exp_avg" in AdamW is `DTensor`
+                    full_tensor = sharded_tensor.full_tensor()
+                else:
+                    # "step" in AdamW is plain tensor
+                    full_tensor = sharded_tensor
+                if is_rank_zero:
+                    group_state[attr] = full_tensor.cpu()
+                else:
+                    del full_tensor
+            if is_rank_zero:
+                full_state[group_id] = group_state
+            else:
+                del group_state
+        if is_rank_zero:
+            return {
+                "param_groups": sharded_sd["param_groups"],
+                "state": full_state,
+            }
+        else:
+            return {}
+
+    def save(self, model: FSDPModule, optim: torch.optim.Optimizer):
+        model_state_dict = self._get_full_model_state_dict(model)
+        optim_state_dict = self._get_full_optimizer_state_dict(model, optim)
+        if torch.distributed.get_rank() == 0:
+            new_training_time = int(time.time() * 1000)
+            new_checkpoint_folder = f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}/{new_training_time}"
+            new_model_checkpoint = f"{new_checkpoint_folder}/{MODEL_CHECKPOINT}"
+            new_optim_checkpoint = f"{new_checkpoint_folder}/{OPTIM_CHECKPOINT}"
+            os.makedirs(new_checkpoint_folder, exist_ok=True)
+            torch.save(model_state_dict, new_model_checkpoint)
+            torch.save(optim_state_dict, new_optim_checkpoint)
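The non-DCP path in `Checkpointer` rests on two DTensor primitives: `full_tensor()` all-gathers a sharded parameter into a regular tensor when saving, and `distribute_tensor()` re-shards a full tensor onto the parameter's device mesh when loading. The standalone sketch below is not part of the PR (the tensor shape and 1D mesh are arbitrary choices for illustration); it shows that round trip when launched with the same `torchrun --nproc_per_node 2` command as the example:

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

# One GPU per rank, initialized the same way as train.py.
rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl", device_id=device)

# 1D mesh over all ranks; FSDP2 parameters are sharded on dim 0 of such a mesh.
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# A full tensor, identical on every rank thanks to the fixed seed.
torch.manual_seed(0)
full = torch.randn(8, 4, device=device)

# Shard it on dim 0 (what load_model does per parameter) ...
sharded = distribute_tensor(full, mesh, [Shard(0)])
print(f"rank {dist.get_rank()}: local shard shape {tuple(sharded.to_local().shape)}")

# ... and gather it back (what _get_full_model_state_dict does per parameter).
regathered = sharded.full_tensor()
assert torch.equal(regathered, full)

dist.destroy_process_group()
```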
diff --git a/distributed/FSDP2/model.py b/distributed/FSDP2/model.py
new file mode 100644
index 0000000000..21b609615a
--- /dev/null
+++ b/distributed/FSDP2/model.py
@@ -0,0 +1,134 @@
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+@dataclass
+class ModelArgs:
+    n_layers: int = 2
+    vocab_size: int = 8
+    max_seq_len: int = 16
+    dim: int = 16
+    n_heads: int = 4
+    dropout_p: float = 0.1
+
+
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        assert args.dim % args.n_heads == 0
+        self.head_dim = args.dim // args.n_heads
+        self.n_heads = args.n_heads
+        self.dropout_p = args.dropout_p
+        self.resid_dropout = nn.Dropout(args.dropout_p)
+
+        self.wq = nn.Linear(args.dim, args.dim, bias=False)
+        self.wk = nn.Linear(args.dim, args.dim, bias=False)
+        self.wv = nn.Linear(args.dim, args.dim, bias=False)
+        self.wo = nn.Linear(args.dim, args.dim, bias=False)
+
+    def forward(self, x):
+        bsz, seq_len, _ = x.size()
+        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
+        queries = queries.view(bsz, seq_len, self.n_heads, self.head_dim)
+        keys = keys.view(bsz, seq_len, self.n_heads, self.head_dim)
+        values = values.view(bsz, seq_len, self.n_heads, self.head_dim)
+
+        queries = queries.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
+        keys = keys.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
+        values = values.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
+
+        output = F.scaled_dot_product_attention(
+            queries,
+            keys,
+            values,
+            None,
+            self.dropout_p if self.training else 0,
+        )
+        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
+        return self.resid_dropout(self.wo(output))
+
+    def reset_parameters(self):
+        self.wq.reset_parameters()
+        self.wk.reset_parameters()
+        self.wv.reset_parameters()
+        self.wo.reset_parameters()
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, hidden_dim, dropout_p):
+        super().__init__()
+        self.w1 = nn.Linear(dim, hidden_dim)
+        self.gelu = nn.GELU()
+        self.w2 = nn.Linear(hidden_dim, dim)
+        self.resid_dropout = nn.Dropout(dropout_p)
+
+    def forward(self, x):
+        return self.resid_dropout(self.w2(self.gelu(self.w1(x))))
+
+    def reset_parameters(self):
+        self.w1.reset_parameters()
+        self.w2.reset_parameters()
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.attention_norm = nn.LayerNorm(args.dim)
+        self.attention = Attention(args)
+        self.ffn_norm = nn.LayerNorm(args.dim)
+        self.feed_forward = FeedForward(
+            args.dim, hidden_dim=4 * args.dim, dropout_p=args.dropout_p
+        )
+
+    def forward(self, x):
+        h = x + self.attention(self.attention_norm(x))
+        out = h + self.feed_forward(self.ffn_norm(h))
+        return out
+
+    def reset_parameters(self):
+        self.attention_norm.reset_parameters()
+        self.attention.reset_parameters()
+        self.ffn_norm.reset_parameters()
+        self.feed_forward.reset_parameters()
+
+
+# A toy transformer model, partly inspired by the nanoGPT model:
+# https://github.com/karpathy/nanoGPT.
+class Transformer(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        assert args.vocab_size is not None
+        assert args.max_seq_len is not None
+        self.model_args = args
+        self.max_seq_len = args.max_seq_len
+        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
+        self.pos_embeddings = nn.Embedding(args.max_seq_len, args.dim)
+        self.dropout = nn.Dropout(args.dropout_p)
+        self.layers = nn.ModuleList()
+        for _ in range(args.n_layers):
+            self.layers.append(TransformerBlock(args))
+        self.norm = nn.LayerNorm(args.dim)
+        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
+
+    def forward(self, tokens):
+        _bsz, seq_len = tokens.size()
+        assert seq_len <= self.max_seq_len
+        h = self.tok_embeddings(tokens)
+        pos = torch.arange(0, seq_len, device=tokens.device)
+        p = self.pos_embeddings(pos)  # positional embeddings of shape (seq_len, dim)
+        h = h + p
+        h = self.dropout(h)
+        for layer in self.layers:
+            h = layer(h)
+        h = self.norm(h)
+        output = self.output(h).float()
+        return output
+
+    def reset_parameters(self):
+        self.tok_embeddings.reset_parameters()
+        self.pos_embeddings.reset_parameters()
+        self.norm.reset_parameters()
+        self.output.reset_parameters()
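The toy model itself has no distributed dependencies, so a quick single-process forward pass is enough to sanity-check the shapes. A minimal sketch (not part of the PR) using the default `ModelArgs`:

```python
import torch

from model import ModelArgs, Transformer

args = ModelArgs()  # n_layers=2, vocab_size=8, max_seq_len=16, dim=16, n_heads=4
model = Transformer(args)

# A batch of 2 random token sequences of maximum length.
tokens = torch.randint(0, args.vocab_size, (2, args.max_seq_len))
logits = model(tokens)
print(logits.shape)  # torch.Size([2, 16, 8]) -> (batch, seq_len, vocab_size)
```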
diff --git a/distributed/FSDP2/train.py b/distributed/FSDP2/train.py
new file mode 100644
index 0000000000..b47ad9cceb
--- /dev/null
+++ b/distributed/FSDP2/train.py
@@ -0,0 +1,98 @@
+import argparse
+import os
+
+import torch
+from checkpoint import Checkpointer
+from model import ModelArgs, Transformer
+from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
+from utils import inspect_mixed_precision, inspect_model
+
+
+def set_modules_to_forward_prefetch(model, num_to_forward_prefetch):
+    for i, layer in enumerate(model.layers):
+        if i >= len(model.layers) - num_to_forward_prefetch:
+            break
+        layers_to_prefetch = [
+            model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1)
+        ]
+        layer.set_modules_to_forward_prefetch(layers_to_prefetch)
+
+
+def set_modules_to_backward_prefetch(model, num_to_backward_prefetch):
+    for i, layer in enumerate(model.layers):
+        if i < num_to_backward_prefetch:
+            continue
+        layers_to_prefetch = [
+            model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1)
+        ]
+        layer.set_modules_to_backward_prefetch(layers_to_prefetch)
+
+
+def main(args):
+    rank = int(os.environ["LOCAL_RANK"])
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    torch.distributed.init_process_group(backend="nccl", device_id=device)
+    torch.manual_seed(0)
+    vocab_size = 1024
+    batch_size = 32
+    seq_len = 64
+    model_args = ModelArgs(
+        n_layers=10,
+        n_heads=4,
+        vocab_size=vocab_size,
+        max_seq_len=seq_len,
+        dropout_p=0,
+    )
+    with torch.device("meta"):
+        model = Transformer(model_args)
+    fsdp_kwargs = {}
+    if args.mixed_precision:
+        fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(
+            param_dtype=torch.bfloat16,
+            reduce_dtype=torch.float32,
+        )
+    for layer in model.layers:
+        fully_shard(layer, **fsdp_kwargs)
+    fully_shard(model, **fsdp_kwargs)
+
+    inspect_model(model)
+    if args.mixed_precision:
+        inspect_mixed_precision(model)
+
+    if args.explicit_prefetching:
+        set_modules_to_forward_prefetch(model, num_to_forward_prefetch=2)
+        set_modules_to_backward_prefetch(model, num_to_backward_prefetch=2)
+
+    checkpointer = Checkpointer("checkpoints", dcp_api=args.dcp_api)
+    if checkpointer.last_training_time is None:
+        model.to_empty(device="cuda")
+        model.reset_parameters()
+    else:
+        checkpointer.load_model(model)
+
+    optim = torch.optim.Adam(model.parameters(), lr=1e-2)
+    if checkpointer.last_training_time is not None:
+        checkpointer.load_optim(model, optim)
+
+    for _ in range(10):
+        if args.explicit_prefetching:
+            model.unshard()
+        x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
+        loss = model(x).sum()
+        loss.backward()
+        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+        optim.step()
+        optim.zero_grad()
+
+    checkpointer.save(model, optim)
+    torch.distributed.destroy_process_group()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="PyTorch FSDP2 example")
+    parser.add_argument("--explicit-prefetching", action="store_true", default=False)
+    parser.add_argument("--mixed-precision", action="store_true", default=False)
+    parser.add_argument("--dcp-api", action="store_true", default=False)
+    args = parser.parse_args()
+    main(args)
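`Checkpointer` writes full state dicts under `checkpoints/dcp_api/<timestamp>` or `checkpoints/dtensor_api/<timestamp>`, and `train.py` resumes from the latest such folder when one exists, so the save/load path can be exercised simply by running the script twice:

```
torchrun --nproc_per_node 2 train.py
torchrun --nproc_per_node 2 train.py
```

The first run initializes parameters on GPU via `reset_parameters()` and saves a checkpoint at the end; the second loads the saved model and optimizer state before training.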
diff --git a/distributed/FSDP2/utils.py b/distributed/FSDP2/utils.py
new file mode 100644
index 0000000000..402d2bf33d
--- /dev/null
+++ b/distributed/FSDP2/utils.py
@@ -0,0 +1,24 @@
+import torch
+from model import Transformer
+from torch.distributed.fsdp import FSDPModule
+from torch.distributed.tensor import Shard
+
+
+def inspect_model(model: FSDPModule):
+    assert isinstance(model, Transformer)
+    assert isinstance(model, FSDPModule)
+
+    if torch.distributed.get_rank() == 0:
+        print(model)
+
+    for param in model.parameters():
+        assert param.placements == (Shard(0),)
+        assert param.dtype == torch.float32
+        # print(param.get_local_tensor())
+
+
+def inspect_mixed_precision(model: FSDPModule):
+    model.unshard()
+    for param in model.parameters(recurse=False):
+        assert param.dtype == torch.bfloat16
+    model.reshard()