import os
import argparse
import torch
from torch.distributed.fsdp import fully_shard
from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer
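# Example launch (illustrative sketch; assumes torchrun and a host with at least
# 2 CUDA GPUs, and the filename is a placeholder for wherever this script is saved):
#   torchrun --standalone --nproc_per_node=2 fsdp2_example.py --epochs 2
# torchrun sets LOCAL_RANK for each spawned process, which main() reads below.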
def main(args):
    torch.distributed.init_process_group(backend="nccl")
    rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    # Offset the seed by rank so each rank draws different synthetic batches.
    torch.manual_seed(args.seed + rank)
    vocab_size = 1024
    model_args = ModelArgs(
        n_layers=3,
        n_heads=4,
        vocab_size=vocab_size,
        max_seq_len=64,
        dropout_p=0,
    )
    model = Transformer(model_args)
    # FSDP2: shard each transformer block first, then the root module.
    for layer in model.layers:
        fully_shard(layer)
    fully_shard(model)
    optim = torch.optim.Adam(model.parameters(), lr=1e-2)
    # Synthetic training loop: 10 steps of random token batches per epoch.
    for _ in range(args.epochs):
        for _ in range(10):
            x = torch.randint(0, vocab_size, (args.batch_size, 32), device=device)
            loss = model(x).sum()
            loss.backward()
            optim.step()
            optim.zero_grad()
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PyTorch FSDP2 example')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=2, metavar='N',
                        help='number of epochs to train (default: 2)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    args = parser.parse_args()
    main(args)