diff --git a/word_language_model/README.md b/word_language_model/README.md
index 254b726585..6d9ce2c7b8 100644
--- a/word_language_model/README.md
+++ b/word_language_model/README.md
@@ -4,15 +4,18 @@ This example trains a multi-layer RNN (Elman, GRU, or LSTM) or Transformer on a
 The trained model can then be used by the generate script to generate new text.
 
 ```bash
-python main.py --cuda --epochs 6           # Train a LSTM on Wikitext-2 with CUDA.
-python main.py --cuda --epochs 6 --tied    # Train a tied LSTM on Wikitext-2 with CUDA.
-python main.py --cuda --tied               # Train a tied LSTM on Wikitext-2 with CUDA for 40 epochs.
-python main.py --cuda --epochs 6 --model Transformer --lr 5
-                                           # Train a Transformer model on Wikitext-2 with CUDA.
+python main.py --accel --epochs 6           # Train an LSTM on Wikitext-2.
+python main.py --accel --epochs 6 --tied    # Train a tied LSTM on Wikitext-2.
+python main.py --accel --tied               # Train a tied LSTM on Wikitext-2 for 40 epochs.
+python main.py --accel --epochs 6 --model Transformer --lr 5
+                                            # Train a Transformer model on Wikitext-2.
 
-python generate.py                         # Generate samples from the default model checkpoint.
+python generate.py --accel                  # Generate samples from the default model checkpoint.
 ```
 
+> [!NOTE]
+> This example supports running on accelerator devices (CUDA, MPS, XPU).
+
 The model uses the `nn.RNN` module (and its sister modules `nn.GRU` and `nn.LSTM`) or Transformer module (`nn.TransformerEncoder` and `nn.TransformerEncoderLayer`) which will automatically use the cuDNN backend if run on CUDA with cuDNN installed.
 
 During training, if a keyboard interrupt (Ctrl-C) is received, training is stopped and the current model is evaluated against the test dataset.
@@ -35,8 +38,7 @@ optional arguments:
   --dropout DROPOUT     dropout applied to layers (0 = no dropout)
   --tied                tie the word embedding and softmax weights
   --seed SEED           random seed
-  --cuda                use CUDA
-  --mps                 enable GPU on macOS
+  --accel               use accelerator
   --log-interval N      report interval
   --save SAVE           path to save the final model
   --onnx-export ONNX_EXPORT
@@ -49,8 +51,8 @@ With these arguments, a variety of models can be tested.
 As an example, the following arguments produce slower but better models:
 
 ```bash
-python main.py --cuda --emsize 650 --nhid 650 --dropout 0.5 --epochs 40
-python main.py --cuda --emsize 650 --nhid 650 --dropout 0.5 --epochs 40 --tied
-python main.py --cuda --emsize 1500 --nhid 1500 --dropout 0.65 --epochs 40
-python main.py --cuda --emsize 1500 --nhid 1500 --dropout 0.65 --epochs 40 --tied
+python main.py --accel --emsize 650 --nhid 650 --dropout 0.5 --epochs 40
+python main.py --accel --emsize 650 --nhid 650 --dropout 0.5 --epochs 40 --tied
+python main.py --accel --emsize 1500 --nhid 1500 --dropout 0.65 --epochs 40
+python main.py --accel --emsize 1500 --nhid 1500 --dropout 0.65 --epochs 40 --tied
 ```
diff --git a/word_language_model/generate.py b/word_language_model/generate.py
index e8214abdd7..e3dd1dd8a0 100644
--- a/word_language_model/generate.py
+++ b/word_language_model/generate.py
@@ -22,30 +22,20 @@
                     help='number of words to generate')
 parser.add_argument('--seed', type=int, default=1111,
                     help='random seed')
-parser.add_argument('--cuda', action='store_true',
-                    help='use CUDA')
-parser.add_argument('--mps', action='store_true', default=False,
-                    help='enables macOS GPU training')
 parser.add_argument('--temperature', type=float, default=1.0,
                     help='temperature - higher will increase diversity')
 parser.add_argument('--log-interval', type=int, default=100,
                     help='reporting interval')
+parser.add_argument('--accel', action='store_true', default=False,
+                    help='use accelerator')
 args = parser.parse_args()
 
 # Set the random seed manually for reproducibility.
 torch.manual_seed(args.seed)
-if torch.cuda.is_available():
-    if not args.cuda:
-        print("WARNING: You have a CUDA device, so you should probably run with --cuda.")
-if torch.backends.mps.is_available():
-    if not args.mps:
-        print("WARNING: You have mps device, to enable macOS GPU run with --mps.")
-
-use_mps = args.mps and torch.backends.mps.is_available()
-if args.cuda:
-    device = torch.device("cuda")
-elif use_mps:
-    device = torch.device("mps")
+
+if args.accel and torch.accelerator.is_available():
+    device = torch.accelerator.current_accelerator()
+
 else:
     device = torch.device("cpu")
 
diff --git a/word_language_model/main.py b/word_language_model/main.py
index 72fee6cd3b..cd697da3db 100644
--- a/word_language_model/main.py
+++ b/word_language_model/main.py
@@ -37,10 +37,6 @@
                     help='tie the word embedding and softmax weights')
 parser.add_argument('--seed', type=int, default=1111,
                     help='random seed')
-parser.add_argument('--cuda', action='store_true', default=False,
-                    help='use CUDA')
-parser.add_argument('--mps', action='store_true', default=False,
-                    help='enables macOS GPU training')
 parser.add_argument('--log-interval', type=int, default=200, metavar='N',
                     help='report interval')
 parser.add_argument('--save', type=str, default='model.pt',
@@ -51,25 +47,20 @@
                     help='the number of heads in the encoder/decoder of the transformer model')
 parser.add_argument('--dry-run', action='store_true',
                     help='verify the code and the model')
+parser.add_argument('--accel', action='store_true', help='use accelerator')
 args = parser.parse_args()
 
 # Set the random seed manually for reproducibility.
 torch.manual_seed(args.seed)
-if torch.cuda.is_available():
-    if not args.cuda:
-        print("WARNING: You have a CUDA device, so you should probably run with --cuda.")
-if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-    if not args.mps:
-        print("WARNING: You have mps device, to enable macOS GPU run with --mps.")
-
-use_mps = args.mps and torch.backends.mps.is_available()
-if args.cuda:
-    device = torch.device("cuda")
-elif use_mps:
-    device = torch.device("mps")
+
+if args.accel and torch.accelerator.is_available():
+    device = torch.accelerator.current_accelerator()
+
 else:
     device = torch.device("cpu")
 
+print("Using device:", device)
+
 ###############################################################################
 # Load data
 ###############################################################################
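
For reference, the device-selection pattern this patch introduces in both scripts can be exercised on its own. The sketch below is not part of the patch: it assumes a PyTorch build recent enough to ship the `torch.accelerator` API (roughly 2.6+), and its `--accel` flag simply mirrors the one added above.

```python
# Standalone sketch of the device-selection pattern used in generate.py and
# main.py above. Assumes a PyTorch build that provides torch.accelerator.
import argparse

import torch

parser = argparse.ArgumentParser(description="accelerator device selection sketch")
parser.add_argument("--accel", action="store_true", default=False,
                    help="use accelerator")
args = parser.parse_args()

# Pick whichever accelerator backend (cuda, mps, xpu, ...) the current build
# exposes, but only when the user asked for it; otherwise fall back to CPU.
if args.accel and torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator()
else:
    device = torch.device("cpu")

print("Using device:", device)

# Tensors and modules are then placed on the chosen device as usual.
x = torch.randn(2, 3, device=device)
print(x.sum().item())
```

Compared with the `--cuda`/`--mps` flags it replaces, this keeps the scripts backend-agnostic: there is no per-backend branch or warning message to maintain, and a build that exposes a new accelerator type works without further changes to either script.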