Skip to content

Commit 90738a7

Browse files
joistick11soumith
authored andcommitted
Fix error AttributeError: 'RNNModel' object has no attribute 'model_type' (pytorch#614)
1 parent 6f62fcd commit 90738a7

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

word_language_model/generate.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,16 @@
4949

5050
corpus = data.Corpus(args.data)
5151
ntokens = len(corpus.dictionary)
52-
if model.model_type != 'Transformer':
52+
53+
is_transformer_model = hasattr(model, 'model_type') and model.model_type == 'Transformer'
54+
if not is_transformer_model:
5355
hidden = model.init_hidden(1)
5456
input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)
5557

5658
with open(args.outf, 'w') as outf:
5759
with torch.no_grad(): # no tracking history
5860
for i in range(args.words):
59-
if model.model_type == 'Transformer':
61+
if is_transformer_model:
6062
output = model(input, False)
6163
word_weights = output[-1].squeeze().div(args.temperature).exp().cpu()
6264
word_idx = torch.multinomial(word_weights, 1)[0]

0 commit comments

Comments
 (0)