diff --git a/distributed/FSDP2/README.md b/distributed/FSDP2/README.md
new file mode 100644
index 0000000000..5023173e18
--- /dev/null
+++ b/distributed/FSDP2/README.md
@@ -0,0 +1,8 @@
+## FSDP2
+To run FSDP2 on a transformer model:
+```
+torchrun --nproc_per_node 2 train.py
+```
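+
+`train.py` also accepts a few optional flags (see its `argparse` setup). For example, to enable bf16 mixed precision and explicit forward/backward prefetching:
+```
+torchrun --nproc_per_node 2 train.py --mixed-precision --explicit-prefetching
+```
+Checkpoints are written under `checkpoints/`; pass `--dcp-api` to save and load them through the `torch.distributed.checkpoint` state-dict helpers instead of the manual DTensor path.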
+
+## Ensure you are running a recent version of PyTorch
+See https://pytorch.org/get-started/locally/ to install PyTorch 2.5 or later, ideally a current nightly build.
diff --git a/distributed/FSDP2/checkpoint.py b/distributed/FSDP2/checkpoint.py
new file mode 100644
index 0000000000..aede917823
--- /dev/null
+++ b/distributed/FSDP2/checkpoint.py
@@ -0,0 +1,209 @@
+import os
+import time
+
+import torch
+import torch.nn as nn
+from torch.distributed.checkpoint.state_dict import (
+    _init_optim_state,
+    get_model_state_dict,
+    get_optimizer_state_dict,
+    set_model_state_dict,
+    set_optimizer_state_dict,
+    StateDictOptions,
+)
+from torch.distributed.fsdp import FSDPModule
+from torch.distributed.tensor import distribute_tensor, DTensor
+
+
+MODEL_CHECKPOINT = "model_state_dict.pt"
+OPTIM_CHECKPOINT = "optim_state_dict.pt"
+PARAMS = "params"
+
+
+def get_latest_checkpoint_folder(path):
+    max_num = None
+    if not os.path.exists(path):
+        return max_num
+    for name in os.listdir(path):
+        folder_path = os.path.join(path, name)
+        if os.path.isdir(folder_path):
+            try:
+                num = int(name)
+                if max_num is None or num > max_num:
+                    max_num = num
+            except ValueError:
+                pass  # Skip non-numeric folder names
+    return max_num
+
+
+class Checkpointer:
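+    """Save and load full (unsharded) model and optimizer checkpoints.
+
+    With dcp_api=True the torch.distributed.checkpoint state-dict helpers are
+    used; with dcp_api=False full tensors are gathered and re-sharded manually
+    via DTensor.
+    """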
+    def __init__(self, folder: str, dcp_api: bool):
+        self.folder = folder
+        self.dcp_api = dcp_api
+        self.last_training_time = get_latest_checkpoint_folder(
+            f"{folder}/{'dcp_api' if dcp_api else 'dtensor_api'}"
+        )
+
+    def is_empty(self):
+        return self.last_training_time is None
+
+    def load_model(self, model: FSDPModule):
+        last_model_checkpoint = (
+            f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}"
+            f"/{self.last_training_time}/{MODEL_CHECKPOINT}"
+        )
+        full_sd = torch.load(
+            last_model_checkpoint, mmap=True, weights_only=True, map_location="cpu"
+        )
+        if self.dcp_api:
+            set_model_state_dict(
+                model=model,
+                model_state_dict=full_sd,
+                options=StateDictOptions(
+                    full_state_dict=True,
+                    broadcast_from_rank0=True,
+                ),
+            )
+            return
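+        # Manual (DTensor) path: the sharded state dict holds meta-device DTensors;
+        # re-shard each full tensor from the checkpoint onto the same mesh/placements.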
+        meta_sharded_sd = model.state_dict()
+        sharded_sd = {}
+        for param_name, full_tensor in full_sd.items():
+            sharded_meta_param = meta_sharded_sd.get(param_name)
+            sharded_tensor = distribute_tensor(
+                full_tensor,
+                sharded_meta_param.device_mesh,
+                sharded_meta_param.placements,
+            )
+            sharded_sd[param_name] = nn.Parameter(sharded_tensor)
+        # use `assign=True` since we cannot call `copy_` on a meta tensor
+        model.load_state_dict(sharded_sd, strict=False, assign=True)
+
+    def load_optim(self, model: FSDPModule, opt: torch.optim.Optimizer):
+        last_optim_checkpoint = (
+            f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}"
+            f"/{self.last_training_time}/{OPTIM_CHECKPOINT}"
+        )
+        full_sd = torch.load(
+            last_optim_checkpoint, mmap=True, weights_only=True, map_location="cpu"
+        )
+        if self.dcp_api:
+            set_optimizer_state_dict(
+                model=model,
+                optimizers=opt,
+                optim_state_dict=full_sd,
+                options=StateDictOptions(
+                    full_state_dict=True,
+                    broadcast_from_rank0=True,
+                ),
+            )
+            return
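+        # Manual (DTensor) path: let the optimizer create its state entries, then
+        # overwrite them with checkpoint values, re-sharding DTensor entries to
+        # match the current placements.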
+        _init_optim_state(opt)
+        param_groups = opt.state_dict()["param_groups"]
+        state = opt.state_dict()["state"]
+
+        full_param_groups = full_sd["param_groups"]
+        full_state = full_sd["state"]
+
+        for param_group, full_param_group in zip(param_groups, full_param_groups):
+            for key, value in full_param_group.items():
+                if key == PARAMS:
+                    continue
+                param_group[key] = value
+            for pid, full_pid in zip(param_group[PARAMS], full_param_group[PARAMS]):
+                if pid not in state:
+                    continue
+                param_state = state[pid]
+                full_param_state = full_state[full_pid]
+                for attr, full_tensor in full_param_state.items():
+                    sharded_tensor = param_state[attr]
+                    if isinstance(sharded_tensor, DTensor):
+                        # exp_avg is DTensor
+                        param_state[attr] = distribute_tensor(
+                            full_tensor,
+                            sharded_tensor.device_mesh,
+                            sharded_tensor.placements,
+                        )
+                    else:
+                        # step is plain tensor
+                        param_state[attr] = full_tensor
+        opt.load_state_dict(
+            {
+                "param_groups": param_groups,
+                "state": state,
+            }
+        )
+
+    def _get_full_model_state_dict(self, model: FSDPModule):
+        if self.dcp_api:
+            return get_model_state_dict(
+                model=model,
+                options=StateDictOptions(
+                    full_state_dict=True,
+                    cpu_offload=True,
+                ),
+            )
+
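+        # Manual path: all-gather each sharded parameter into a full tensor;
+        # only rank 0 keeps a CPU copy, the other ranks free it immediately.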
+        sharded_sd = model.state_dict()
+        cpu_state_dict = {}
+        for param_name, sharded_param in sharded_sd.items():
+            full_param = sharded_param.full_tensor()
+            if torch.distributed.get_rank() == 0:
+                cpu_state_dict[param_name] = full_param.cpu()
+            else:
+                del full_param
+        return cpu_state_dict
+
+    def _get_full_optimizer_state_dict(
+        self,
+        model: FSDPModule,
+        opt: torch.optim.Optimizer,
+    ):
+        if self.dcp_api:
+            return get_optimizer_state_dict(
+                model=model,
+                optimizers=opt,
+                options=StateDictOptions(
+                    full_state_dict=True,
+                    cpu_offload=True,
+                ),
+            )
+        is_rank_zero = torch.distributed.get_rank() == 0
+        sharded_sd = opt.state_dict()
+        sharded_state = sharded_sd["state"]
+        full_state = {}
+        for group_id, sharded_group in sharded_state.items():
+            group_state = {}
+            for attr, sharded_tensor in sharded_group.items():
+                if isinstance(sharded_tensor, DTensor):
+                    # "exp_avg" in AdamW is `DTensor`
+                    full_tensor = sharded_tensor.full_tensor()
+                else:
+                    # "step" in AdamW is plain tensor
+                    full_tensor = sharded_tensor
+                if is_rank_zero:
+                    group_state[attr] = full_tensor.cpu()
+                else:
+                    del full_tensor
+            if is_rank_zero:
+                full_state[group_id] = group_state
+            else:
+                del group_state
+        if is_rank_zero:
+            return {
+                "param_groups": sharded_sd["param_groups"],
+                "state": full_state,
+            }
+        else:
+            return {}
+
+    def save(self, model: FSDPModule, optim: torch.optim.Optimizer):
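+        # Gather full state dicts, then write them from rank 0 into a new
+        # timestamped checkpoint folder.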
+        model_state_dict = self._get_full_model_state_dict(model)
+        optim_state_dict = self._get_full_optimizer_state_dict(model, optim)
+        if torch.distributed.get_rank() == 0:
+            new_training_time = int(time.time() * 1000)
+            new_checkpoint_folder = f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}/{new_training_time}"
+            new_model_checkpoint = f"{new_checkpoint_folder}/{MODEL_CHECKPOINT}"
+            new_optim_checkpoint = f"{new_checkpoint_folder}/{OPTIM_CHECKPOINT}"
+            os.makedirs(new_checkpoint_folder, exist_ok=True)
+            torch.save(model_state_dict, new_model_checkpoint)
+            torch.save(optim_state_dict, new_optim_checkpoint)
diff --git a/distributed/FSDP2/model.py b/distributed/FSDP2/model.py
new file mode 100644
index 0000000000..21b609615a
--- /dev/null
+++ b/distributed/FSDP2/model.py
@@ -0,0 +1,134 @@
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+@dataclass
+class ModelArgs:
+    n_layers: int = 2
+    vocab_size: int = 8
+    max_seq_len: int = 16
+    dim: int = 16
+    n_heads: int = 4
+    dropout_p: float = 0.1
+
+
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        assert args.dim % args.n_heads == 0
+        self.head_dim = args.dim // args.n_heads
+        self.n_heads = args.n_heads
+        self.dropout_p = args.dropout_p
+        self.resid_dropout = nn.Dropout(args.dropout_p)
+
+        self.wq = nn.Linear(args.dim, args.dim, bias=False)
+        self.wk = nn.Linear(args.dim, args.dim, bias=False)
+        self.wv = nn.Linear(args.dim, args.dim, bias=False)
+        self.wo = nn.Linear(args.dim, args.dim, bias=False)
+
+    def forward(self, x):
+        bsz, seq_len, _ = x.size()
+        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
+        queries = queries.view(bsz, seq_len, self.n_heads, self.head_dim)
+        keys = keys.view(bsz, seq_len, self.n_heads, self.head_dim)
+        values = values.view(bsz, seq_len, self.n_heads, self.head_dim)
+
+        queries = queries.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
+        keys = keys.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
+        values = values.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
+
+        output = F.scaled_dot_product_attention(
+            queries,
+            keys,
+            values,
+            None,
+            self.dropout_p if self.training else 0,
+        )
+        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
+        return self.resid_dropout(self.wo(output))
+
+    def reset_parameters(self):
+        self.wq.reset_parameters()
+        self.wk.reset_parameters()
+        self.wv.reset_parameters()
+        self.wo.reset_parameters()
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, hidden_dim, dropout_p):
+        super().__init__()
+        self.w1 = nn.Linear(dim, hidden_dim)
+        self.gelu = nn.GELU()
+        self.w2 = nn.Linear(hidden_dim, dim)
+        self.resid_dropout = nn.Dropout(dropout_p)
+
+    def forward(self, x):
+        return self.resid_dropout(self.w2(self.gelu(self.w1(x))))
+
+    def reset_parameters(self):
+        self.w1.reset_parameters()
+        self.w2.reset_parameters()
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.attention_norm = nn.LayerNorm(args.dim)
+        self.attention = Attention(args)
+        self.ffn_norm = nn.LayerNorm(args.dim)
+        self.feed_forward = FeedForward(
+            args.dim, hidden_dim=4 * args.dim, dropout_p=args.dropout_p
+        )
+
+    def forward(self, x):
+        h = x + self.attention(self.attention_norm(x))
+        out = h + self.feed_forward(self.ffn_norm(h))
+        return out
+
+    def reset_parameters(self):
+        self.attention_norm.reset_parameters()
+        self.attention.reset_parameters()
+        self.ffn_norm.reset_parameters()
+        self.feed_forward.reset_parameters()
+
+
+# A toy transformer model, partly inspired by the nanoGPT model:
+# https://github.com/karpathy/nanoGPT.
+class Transformer(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        assert args.vocab_size is not None
+        assert args.max_seq_len is not None
+        self.model_args = args
+        self.max_seq_len = args.max_seq_len
+        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
+        self.pos_embeddings = nn.Embedding(args.max_seq_len, args.dim)
+        self.dropout = nn.Dropout(args.dropout_p)
+        self.layers = nn.ModuleList()
+        for _ in range(args.n_layers):
+            self.layers.append(TransformerBlock(args))
+        self.norm = nn.LayerNorm(args.dim)
+        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
+
+    def forward(self, tokens):
+        _bsz, seq_len = tokens.size()
+        assert seq_len <= self.max_seq_len
+        h = self.tok_embeddings(tokens)
+        pos = torch.arange(0, seq_len, device=tokens.device)
+        p = self.pos_embeddings(pos)  # positional embeddings of shape (seq_len, dim)
+        h = h + p
+        h = self.dropout(h)
+        for layer in self.layers:
+            h = layer(h)
+        h = self.norm(h)
+        output = self.output(h).float()
+        return output
+
+    def reset_parameters(self):
+        self.tok_embeddings.reset_parameters()
+        self.pos_embeddings.reset_parameters()
+        self.norm.reset_parameters()
+        self.output.reset_parameters()
diff --git a/distributed/FSDP2/train.py b/distributed/FSDP2/train.py
new file mode 100644
index 0000000000..b47ad9cceb
--- /dev/null
+++ b/distributed/FSDP2/train.py
@@ -0,0 +1,98 @@
+import argparse
+import os
+
+import torch
+from checkpoint import Checkpointer
+from model import ModelArgs, Transformer
+from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
+from utils import inspect_mixed_precision, inspect_model
+
+
+def set_modules_to_forward_prefetch(model, num_to_forward_prefetch):
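+    """Have each layer explicitly prefetch (all-gather) the parameters of the
+    next `num_to_forward_prefetch` layers during its forward pass."""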
+    for i, layer in enumerate(model.layers):
+        if i >= len(model.layers) - num_to_forward_prefetch:
+            break
+        layers_to_prefetch = [
+            model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1)
+        ]
+        layer.set_modules_to_forward_prefetch(layers_to_prefetch)
+
+
+def set_modules_to_backward_prefetch(model, num_to_backward_prefetch):
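+    """Have each layer explicitly prefetch (all-gather) the parameters of the
+    previous `num_to_backward_prefetch` layers during its backward pass."""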
+    for i, layer in enumerate(model.layers):
+        if i < num_to_backward_prefetch:
+            continue
+        layers_to_prefetch = [
+            model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1)
+        ]
+        layer.set_modules_to_backward_prefetch(layers_to_prefetch)
+
+
+def main(args):
+    rank = int(os.environ["LOCAL_RANK"])
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    torch.distributed.init_process_group(backend="nccl", device_id=device)
+    torch.manual_seed(0)
+    vocab_size = 1024
+    batch_size = 32
+    seq_len = 64
+    model_args = ModelArgs(
+        n_layers=10,
+        n_heads=4,
+        vocab_size=vocab_size,
+        max_seq_len=seq_len,
+        dropout_p=0,
+    )
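+    # Build the model on the meta device so no parameter memory is allocated yet;
+    # weights are materialized later via to_empty()/reset_parameters() or by
+    # loading a checkpoint.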
+    with torch.device("meta"):
+        model = Transformer(model_args)
+    fsdp_kwargs = {}
+    if args.mixed_precision:
+        fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(
+            param_dtype=torch.bfloat16,
+            reduce_dtype=torch.float32,
+        )
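+    # Shard each transformer block as its own FSDP unit, then shard the root
+    # module to cover the remaining parameters (embeddings, norm, output projection).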
+    for layer in model.layers:
+        fully_shard(layer, **fsdp_kwargs)
+    fully_shard(model, **fsdp_kwargs)
+
+    inspect_model(model)
+    if args.mixed_precision:
+        inspect_mixed_precision(model)
+
+    if args.explicit_prefetching:
+        set_modules_to_forward_prefetch(model, num_to_forward_prefetch=2)
+        set_modules_to_backward_prefetch(model, num_to_backward_prefetch=2)
+
+    checkpointer = Checkpointer("checkpoints", dcp_api=args.dcp_api)
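+    # Fresh run: materialize parameters on GPU and initialize them; otherwise
+    # resume from the latest checkpoint's sharded weights.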
+    if checkpointer.last_training_time is None:
+        model.to_empty(device="cuda")
+        model.reset_parameters()
+    else:
+        checkpointer.load_model(model)
+
+    optim = torch.optim.Adam(model.parameters(), lr=1e-2)
+    if checkpointer.last_training_time is not None:
+        checkpointer.load_optim(model, optim)
+
+    for _ in range(10):
+        if args.explicit_prefetching:
+            model.unshard()
+        x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
+        loss = model(x).sum()
+        loss.backward()
+        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+        optim.step()
+        optim.zero_grad()
+
+    checkpointer.save(model, optim)
+    torch.distributed.destroy_process_group()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="PyTorch FSDP2 example")
+    parser.add_argument("--explicit-prefetching", action="store_true", default=False)
+    parser.add_argument("--mixed-precision", action="store_true", default=False)
+    parser.add_argument("--dcp-api", action="store_true", default=False)
+    args = parser.parse_args()
+    main(args)
diff --git a/distributed/FSDP2/utils.py b/distributed/FSDP2/utils.py
new file mode 100644
index 0000000000..402d2bf33d
--- /dev/null
+++ b/distributed/FSDP2/utils.py
@@ -0,0 +1,24 @@
+import torch
+from model import Transformer
+from torch.distributed.fsdp import FSDPModule
+from torch.distributed.tensor import Shard
+
+
+def inspect_model(model: FSDPModule):
+    assert isinstance(model, Transformer)
+    assert isinstance(model, FSDPModule)
+
+    if torch.distributed.get_rank() == 0:
+        print(model)
+
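+    # After fully_shard, every parameter is a DTensor sharded along dim 0.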
+    for param in model.parameters():
+        assert param.placements == (Shard(0),)
+        assert param.dtype == torch.float32
+        # print(param.get_local_tensor())
+
+
+def inspect_mixed_precision(model: FSDPModule):
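+    # With the bf16 mixed-precision policy, the unsharded parameters produced by
+    # unshard() are bfloat16; reshard() frees them again afterwards.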
+    model.unshard()
+    for param in model.parameters(recurse=False):
+        assert param.dtype == torch.bfloat16
+    model.reshard()