@@ -37,10 +37,6 @@
                     help='tie the word embedding and softmax weights')
 parser.add_argument('--seed', type=int, default=1111,
                     help='random seed')
-parser.add_argument('--cuda', action='store_true', default=False,
-                    help='use CUDA')
-parser.add_argument('--mps', action='store_true', default=False,
-                    help='enables macOS GPU training')
 parser.add_argument('--log-interval', type=int, default=200, metavar='N',
                     help='report interval')
 parser.add_argument('--save', type=str, default='model.pt',
@@ -51,25 +47,20 @@
                     help='the number of heads in the encoder/decoder of the transformer model')
 parser.add_argument('--dry-run', action='store_true',
                     help='verify the code and the model')
+parser.add_argument('--accel', action='store_true',help='Enables accelerated training')
 args = parser.parse_args()
 
 # Set the random seed manually for reproducibility.
 torch.manual_seed(args.seed)
-if torch.cuda.is_available():
-    if not args.cuda:
-        print("WARNING: You have a CUDA device, so you should probably run with --cuda.")
-if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-    if not args.mps:
-        print("WARNING: You have mps device, to enable macOS GPU run with --mps.")
-
-use_mps = args.mps and torch.backends.mps.is_available()
-if args.cuda:
-    device = torch.device("cuda")
-elif use_mps:
-    device = torch.device("mps")
+
+if args.accel and torch.accelerator.is_available():
+    device = torch.accelerator.current_accelerator()
+
 else:
     device = torch.device("cpu")
 
+print("Using device:", device)
+
 ###############################################################################
 # Load data
 ###############################################################################
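The device-selection logic added in this hunk can be exercised on its own. Below is a minimal standalone sketch, assuming PyTorch 2.6 or newer where the torch.accelerator API (is_available, current_accelerator) exists; the --accel flag and the CPU fallback mirror the diff above, while the final tensor move is only illustrative.

import argparse

import torch

parser = argparse.ArgumentParser(description='torch.accelerator device selection sketch')
parser.add_argument('--accel', action='store_true',
                    help='enables accelerated training (CUDA, MPS, XPU, ...)')
args = parser.parse_args()

if args.accel and torch.accelerator.is_available():
    # current_accelerator() returns a torch.device for whichever backend
    # this PyTorch build supports, e.g. cuda, mps or xpu.
    device = torch.accelerator.current_accelerator()
else:
    device = torch.device("cpu")

print("Using device:", device)

# Modules and tensors are then moved with the usual .to(device) call.
x = torch.randn(2, 3).to(device)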
|
@@ -243,11 +234,11 @@ def export_onnx(path, batch_size, seq_len):
 
 # Load the best saved model.
 with open(args.save, 'rb') as f:
-    model = torch.load(f)
+    model = torch.load(f, weights_only=False)
     # after load the rnn params are not a continuous chunk of memory
     # this makes them a continuous chunk, and will speed up forward pass
     # Currently, only rnn model supports flatten_parameters function.
-    if args.model in ['RNN_TANH', 'RNN_RELU', 'LSTM', 'GRU']:
+    if args.model in ['RNN_TANH', 'RNN_RELU', 'LSTM', 'GRU'] and device.type == 'cuda':
         model.rnn.flatten_parameters()
 
 # Run on test data.
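The checkpoint-loading change above reflects a newer PyTorch default: since PyTorch 2.6, torch.load() uses weights_only=True, which refuses to unpickle full nn.Module objects, so a model saved with torch.save(model, path) must be loaded with weights_only=False. The sketch below is a hedged, self-contained illustration under that assumption; TinyRNNModel is a stand-in for the example's RNNModel, and the device selection repeats the pattern from the first hunk.

import torch
import torch.nn as nn

class TinyRNNModel(nn.Module):
    # Stand-in model with an rnn submodule, saved the same way the example
    # saves its best model (the whole module object, not just a state_dict).
    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTM(input_size=8, hidden_size=8)

    def forward(self, x):
        output, _ = self.rnn(x)
        return output

torch.save(TinyRNNModel(), 'model.pt')

device = (torch.accelerator.current_accelerator()
          if torch.accelerator.is_available() else torch.device('cpu'))

with open('model.pt', 'rb') as f:
    # weights_only=False restores the pickled nn.Module; only do this for
    # checkpoints from a trusted source.
    model = torch.load(f, weights_only=False)

# flatten_parameters() repacks the RNN weights into one contiguous buffer for
# cuDNN, so it is only useful (and only guarded for) on CUDA devices.
if device.type == 'cuda':
    model = model.to(device)
    model.rnn.flatten_parameters()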