diff --git a/distributed/FSDP2/train.py b/distributed/FSDP2/train.py index 3b38f214d6..f9c123baad 100644 --- a/distributed/FSDP2/train.py +++ b/distributed/FSDP2/train.py @@ -10,7 +10,7 @@ def main(args): rank = int(os.environ["LOCAL_RANK"]) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - torch.manual_seed(rank) + torch.manual_seed(args.seed) vocab_size = 1024 model_args = ModelArgs( n_layers=3, @@ -35,10 +35,6 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser(description='PyTorch FSDP2 example') - parser.add_argument('--meta-init', type=int, default=4, metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument('--epochs', type=int, default=2, metavar='N', - help='number of epochs to train (default: 3)') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') args = parser.parse_args()