Skip to content

Commit 370cdce

Browse files
author
chongjiu.jin
committed
add rnn nlg
1 parent 704f59e commit 370cdce

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

rnn-poetry/model.py

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
batch_size=128
6+
embed_size=128
7+
hidden_dims=256
8+
9+
def generate_poetry(model,word2ix,ix2word,device,begin,sent_len=4):
10+
start_idx=[word2ix['[']]
11+
end_word=''
12+
lens=0
13+
hidden = None
14+
ret=''
15+
data_ = torch.tensor([start_idx], device=device).long()
16+
output, hidden = model(data_, hidden)
17+
start_idx=[word2ix[begin]]
18+
ret+=begin
19+
while end_word!=']' and len(ret)<100:
20+
data_ = torch.tensor([start_idx],device=device).long()
21+
# print("data size",data_.size())
22+
output, hidden = model(data_, hidden)
23+
# print("output size", output.size())
24+
ouput_idx=output.view(-1).argmax().cpu()
25+
# print('ouput_idx',ouput_idx)
26+
# print('ouput_idx', ouput_idx.item())
27+
ouput_idx=ouput_idx.item()
28+
start_idx=[ouput_idx]
29+
end_word=ix2word[ouput_idx]
30+
ret+=end_word
31+
return ret
32+
33+
class RNNModel(nn.Module):
34+
def __init__(self, vocab_size, embedding_dim, hidden_dim):
35+
super(RNNModel, self).__init__()
36+
self.hidden_dim = hidden_dim
37+
self.embeddings = nn.Embedding(vocab_size, embedding_dim)
38+
self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, num_layers=2)
39+
self.linear1 = nn.Linear(self.hidden_dim, vocab_size)
40+
41+
42+
43+
def forward(self, x, hidden=None):
44+
seq_len, batch_size = x.size()
45+
46+
47+
# size: (seq_len,batch_size,embeding_dim)
48+
embeds = self.embeddings(x)
49+
# output size: (seq_len,batch_size,hidden_dim)
50+
if hidden is None:
51+
output, hidden = self.lstm(embeds)
52+
else:
53+
h_0, c_0 = hidden
54+
output, hidden = self.lstm(embeds, (h_0, c_0))
55+
56+
# size: (seq_len*batch_size,vocab_size)
57+
output = self.linear1(output.view(seq_len * batch_size, -1))
58+
return output, hidden

0 commit comments

Comments
 (0)