We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 6f62fcd commit 90738a7Copy full SHA for 90738a7
word_language_model/generate.py
@@ -49,14 +49,16 @@
49
50
corpus = data.Corpus(args.data)
51
ntokens = len(corpus.dictionary)
52
-if model.model_type != 'Transformer':
+
53
+is_transformer_model = hasattr(model, 'model_type') and model.model_type == 'Transformer'
54
+if not is_transformer_model:
55
hidden = model.init_hidden(1)
56
input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)
57
58
with open(args.outf, 'w') as outf:
59
with torch.no_grad(): # no tracking history
60
for i in range(args.words):
- if model.model_type == 'Transformer':
61
+ if is_transformer_model:
62
output = model(input, False)
63
word_weights = output[-1].squeeze().div(args.temperature).exp().cpu()
64
word_idx = torch.multinomial(word_weights, 1)[0]
0 commit comments