Commit f56fb9c

Add accelerate API support for Word Language Model example
1 parent: 7092296

3 files changed (+17 lines, -36 lines)

run_python_examples.sh

Lines changed: 1 addition & 1 deletion
@@ -153,7 +153,7 @@ function vision_transformer() {
 }

 function word_language_model() {
-    uv run main.py --epochs 1 --dry-run $CUDA_FLAG --mps || error "word_language_model failed"
+    uv run main.py --epochs 1 --dry-run $ACCEL_FLAG || error "word_language_model failed"
 }

 function gcn() {
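The runner now passes a single $ACCEL_FLAG in place of the $CUDA_FLAG/--mps pair; $ACCEL_FLAG is presumably defined elsewhere in run_python_examples.sh, as its definition is not part of this diff. When the flag expands to --accel, the invocation becomes uv run main.py --epochs 1 --dry-run --accel. Inside the example scripts, that flag drives device selection through the torch.accelerator API (available since PyTorch 2.6). A minimal sketch of that selection logic, under those assumptions:

import torch

# torch.accelerator abstracts over backends (CUDA, MPS, XPU, ...),
# replacing the per-backend --cuda/--mps flags.
if torch.accelerator.is_available():
    # Returns the torch.device of the detected accelerator, e.g. "cuda" or "mps".
    device = torch.accelerator.current_accelerator()
else:
    device = torch.device("cpu")

print("Using device:", device)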

word_language_model/generate.py

Lines changed: 7 additions & 17 deletions
@@ -21,38 +21,28 @@
                     help='number of words to generate')
 parser.add_argument('--seed', type=int, default=1111,
                     help='random seed')
-parser.add_argument('--cuda', action='store_true',
-                    help='use CUDA')
-parser.add_argument('--mps', action='store_true', default=False,
-                    help='enables macOS GPU training')
 parser.add_argument('--temperature', type=float, default=1.0,
                     help='temperature - higher will increase diversity')
 parser.add_argument('--log-interval', type=int, default=100,
                     help='reporting interval')
+parser.add_argument('--accel', action='store_true', default=False,
+                    help='Enables accelerated inference')
 args = parser.parse_args()

 # Set the random seed manually for reproducibility.
 torch.manual_seed(args.seed)
-if torch.cuda.is_available():
-    if not args.cuda:
-        print("WARNING: You have a CUDA device, so you should probably run with --cuda.")
-if torch.backends.mps.is_available():
-    if not args.mps:
-        print("WARNING: You have mps device, to enable macOS GPU run with --mps.")
-
-use_mps = args.mps and torch.backends.mps.is_available()
-if args.cuda:
-    device = torch.device("cuda")
-elif use_mps:
-    device = torch.device("mps")
+
+if args.accel and torch.accelerator.is_available():
+    device = torch.accelerator.current_accelerator()
+
 else:
     device = torch.device("cpu")

 if args.temperature < 1e-3:
     parser.error("--temperature has to be greater or equal 1e-3.")

 with open(args.checkpoint, 'rb') as f:
-    model = torch.load(f, map_location=device)
+    model = torch.load(f, map_location=device, weights_only=False)
 model.eval()

 corpus = data.Corpus(args.data)
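A note on the torch.load change above: since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to unpickle full nn.Module objects such as the checkpoint this example saves, so weights_only=False is needed to restore the old behavior; use it only with checkpoints you trust. A self-contained sketch of the loading pattern (the model.pt path and the unconditional device fallback are illustrative, not taken from the diff):

import torch

# Resolve the device the same way generate.py now does.
device = (torch.accelerator.current_accelerator()
          if torch.accelerator.is_available()
          else torch.device("cpu"))

# weights_only=False allows unpickling a full nn.Module checkpoint;
# only use it for files you trust.
with open("model.pt", "rb") as f:  # illustrative checkpoint path
    model = torch.load(f, map_location=device, weights_only=False)
model.eval()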

word_language_model/main.py

Lines changed: 9 additions & 18 deletions
@@ -37,10 +37,6 @@
                     help='tie the word embedding and softmax weights')
 parser.add_argument('--seed', type=int, default=1111,
                     help='random seed')
-parser.add_argument('--cuda', action='store_true', default=False,
-                    help='use CUDA')
-parser.add_argument('--mps', action='store_true', default=False,
-                    help='enables macOS GPU training')
 parser.add_argument('--log-interval', type=int, default=200, metavar='N',
                     help='report interval')
 parser.add_argument('--save', type=str, default='model.pt',
@@ -51,25 +47,20 @@
                     help='the number of heads in the encoder/decoder of the transformer model')
 parser.add_argument('--dry-run', action='store_true',
                     help='verify the code and the model')
+parser.add_argument('--accel', action='store_true', help='Enables accelerated training')
 args = parser.parse_args()

 # Set the random seed manually for reproducibility.
 torch.manual_seed(args.seed)
-if torch.cuda.is_available():
-    if not args.cuda:
-        print("WARNING: You have a CUDA device, so you should probably run with --cuda.")
-if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-    if not args.mps:
-        print("WARNING: You have mps device, to enable macOS GPU run with --mps.")
-
-use_mps = args.mps and torch.backends.mps.is_available()
-if args.cuda:
-    device = torch.device("cuda")
-elif use_mps:
-    device = torch.device("mps")
+
+if args.accel and torch.accelerator.is_available():
+    device = torch.accelerator.current_accelerator()
+
 else:
     device = torch.device("cpu")

+print("Using device:", device)
+
 ###############################################################################
 # Load data
 ###############################################################################
@@ -243,11 +234,11 @@ def export_onnx(path, batch_size, seq_len):

 # Load the best saved model.
 with open(args.save, 'rb') as f:
-    model = torch.load(f)
+    model = torch.load(f, weights_only=False)
 # after load the rnn params are not a continuous chunk of memory
 # this makes them a continuous chunk, and will speed up forward pass
 # Currently, only rnn model supports flatten_parameters function.
-if args.model in ['RNN_TANH', 'RNN_RELU', 'LSTM', 'GRU']:
+if args.model in ['RNN_TANH', 'RNN_RELU', 'LSTM', 'GRU'] and device.type == 'cuda':
     model.rnn.flatten_parameters()

 # Run on test data.
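The new device.type == 'cuda' guard reflects that flatten_parameters() is a cuDNN optimization: it repacks the RNN weights into one contiguous memory chunk so cuDNN can run fused kernels, and it does nothing useful (and may warn) on CPU or MPS. A short sketch of the guarded call, with illustrative layer sizes rather than the example's CLI-derived ones:

import torch
import torch.nn as nn

device = (torch.accelerator.current_accelerator()
          if torch.accelerator.is_available()
          else torch.device("cpu"))

# Illustrative sizes; main.py builds its model from command-line args.
rnn = nn.LSTM(input_size=200, hidden_size=200, num_layers=2).to(device)
if device.type == "cuda":
    # Repack weights into a single contiguous buffer for cuDNN.
    rnn.flatten_parameters()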
