#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
CS224N 2018-19: Homework 5
"""

import torch
import torch.nn as nn

class CharDecoder(nn.Module):
    def __init__(self, hidden_size, char_embedding_size=50, target_vocab=None):
        """ Init Character Decoder.

        @param hidden_size (int): Hidden size of the decoder LSTM
        @param char_embedding_size (int): dimensionality of character embeddings
        @param target_vocab (VocabEntry): vocabulary for the target language. See vocab.py for documentation.
        """
        ### YOUR CODE HERE for part 2a
        ### TODO - Initialize as an nn.Module.
        ###      - Initialize the following variables:
        ###        self.charDecoder: LSTM. Please use nn.LSTM() to construct this.
        ###        self.char_output_projection: Linear layer, called W_{dec} and b_{dec} in the PDF
        ###        self.decoderCharEmb: Embedding matrix of character embeddings
        ###        self.target_vocab: vocabulary for the target language
        ###
        ### Hint: - Use target_vocab.char2id to access the character vocabulary for the target language.
        ###       - Set the padding_idx argument of the embedding matrix.
        ###       - Create a new Embedding layer. Do not reuse embeddings created in Part 1 of this assignment.
        super(CharDecoder, self).__init__()
        # batch_first=True: the LSTM expects (batch, length, char_embedding_size) inputs.
        self.charDecoder = nn.LSTM(char_embedding_size, hidden_size, batch_first=True)
        self.char_output_projection = nn.Linear(hidden_size, len(target_vocab.char2id))
        self.decoderCharEmb = nn.Embedding(len(target_vocab.char2id), char_embedding_size,
                                           padding_idx=target_vocab.char2id['<pad>'])
        self.target_vocab = target_vocab
        ### END YOUR CODE


    def forward(self, input, dec_hidden=None):
        """ Forward pass of character decoder.

        @param input: tensor of integers, shape (length, batch)
        @param dec_hidden: internal state of the LSTM before reading the input characters. A tuple of two tensors of shape (1, batch, hidden_size)

        @returns scores: called s_t in the PDF, shape (length, batch, self.vocab_size)
        @returns dec_hidden: internal state of the LSTM after reading the input characters. A tuple of two tensors of shape (1, batch, hidden_size)
        """
        ### YOUR CODE HERE for part 2b
        ### TODO - Implement the forward pass of the character decoder.
        # The LSTM was built with batch_first=True, so move the batch dimension first.
        input = input.permute(1, 0).contiguous()                         # (batch, length)
        ip_embedding = self.decoderCharEmb(input)                        # (batch, length, char_embedding_size)
        output, (h_n, c_n) = self.charDecoder(ip_embedding, dec_hidden)  # (batch, length, hidden_size)
        s_t = self.char_output_projection(output)                        # (batch, length, vocab_size)
        s_t = s_t.permute(1, 0, 2).contiguous()                          # (length, batch, vocab_size)

        return s_t, (h_n, c_n)
        ### END YOUR CODE


    def train_forward(self, char_sequence, dec_hidden=None):
        """ Forward computation during training.

        @param char_sequence: tensor of integers, shape (length, batch). Note that "length" here and in forward() need not be the same.
        @param dec_hidden: initial internal state of the LSTM, obtained from the output of the word-level decoder. A tuple of two tensors of shape (1, batch, hidden_size)

        @returns The cross-entropy loss, computed as the *sum* of cross-entropy losses of all the words in the batch.
        """
        ### YOUR CODE HERE for part 2c
        ### TODO - Implement training forward pass.
        ###
        ### Hint: - Make sure padding characters do not contribute to the cross-entropy loss.
        ###       - char_sequence corresponds to the sequence x_1 ... x_{n+1} from the handout (e.g., <START>,m,u,s,i,c,<END>).

        # Feed x_1 ... x_n and predict x_2 ... x_{n+1},
        # e.g. input <START>,m,u,s,i,c with target m,u,s,i,c,<END>.
        input = char_sequence[:-1, :]                       # (length - 1, batch)
        target = char_sequence[1:, :].reshape(-1)           # ((length - 1) * batch,)
        s_t, (h_n, c_n) = self.forward(input, dec_hidden)   # (length - 1, batch, vocab_size)
        s_t_re = s_t.reshape(-1, s_t.shape[2])              # ((length - 1) * batch, vocab_size)

        # Summed (not averaged) cross-entropy; padding positions do not contribute.
        loss = nn.CrossEntropyLoss(ignore_index=self.target_vocab.char2id['<pad>'], reduction='sum')
        return loss(s_t_re, target)
        ### END YOUR CODE

    def decode_greedy(self, initialStates, device, max_length=21):
        """ Greedy decoding
        @param initialStates: initial internal state of the LSTM, a tuple of two tensors of size (1, batch, hidden_size)
        @param device: torch.device (indicates whether the model is on CPU or GPU)
        @param max_length: maximum length of words to decode

        @returns decodedWords: a list (of length batch) of strings, each of which has length <= max_length.
                               The decoded strings should NOT contain the start-of-word and end-of-word characters.
        """

        ### YOUR CODE HERE for part 2d
        ### TODO - Implement greedy decoding.
        ### Hints:
        ###      - Use target_vocab.char2id and target_vocab.id2char to convert between integers and characters
        ###      - Use torch.tensor(..., device=device) to turn a list of character indices into a tensor.
        ###      - We use curly brackets as start-of-word and end-of-word characters. That is, use the character '{' for <START> and '}' for <END>.
        ###        Their indices are self.target_vocab.start_of_word and self.target_vocab.end_of_word, respectively.
        batch_size = initialStates[0].shape[1]

        # Every word starts from the <START> character.
        start_tensor = torch.tensor([self.target_vocab.start_of_word], device=device)
        start_batch = start_tensor.repeat(batch_size, 1)         # (batch, 1)
        embed_current_char = self.decoderCharEmb(start_batch)    # (batch, 1, char_embedding_size)

        h_n, c_n = initialStates
        output_word = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
        for t in range(max_length):
            # One step: embed the previous character, advance the LSTM, pick the argmax character.
            output, (h_n, c_n) = self.charDecoder(embed_current_char, (h_n, c_n))
            s_t = self.char_output_projection(output)            # (batch, 1, vocab_size)
            current_char = s_t.argmax(dim=2)                     # (batch, 1); argmax of logits == argmax of softmax
            embed_current_char = self.decoderCharEmb(current_char)
            output_word = torch.cat((output_word, current_char), dim=1)

        # Convert each row of indices to a string, truncating at the first <END> character.
        out_list = output_word.tolist()
        out_list = [[self.target_vocab.id2char[x] for x in ilist[1:]] for ilist in out_list]
        decodedWords = []
        for string_list in out_list:
            stringer = ''
            for char in string_list:
                if char != '}':
                    stringer = stringer + char
                else:
                    break
            decodedWords.append(stringer)
        return decodedWords
        ### END YOUR CODE

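# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the assignment skeleton). It assumes
# a toy stand-in for VocabEntry that exposes only the attributes CharDecoder
# touches (char2id, id2char, start_of_word, end_of_word); the real class lives
# in vocab.py and may differ.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    chars = ['<pad>', '{', '}', 'a', 'b', 'c']
    char2id = {c: i for i, c in enumerate(chars)}
    toy_vocab = SimpleNamespace(
        char2id=char2id,
        id2char={i: c for c, i in char2id.items()},
        start_of_word=char2id['{'],
        end_of_word=char2id['}'],
    )

    hidden_size, batch_size, length = 8, 4, 6
    decoder = CharDecoder(hidden_size, char_embedding_size=5, target_vocab=toy_vocab)

    # train_forward takes (length, batch) character ids and returns a summed loss.
    seq = torch.randint(1, len(chars), (length, batch_size))
    print('summed cross-entropy loss:', decoder.train_forward(seq).item())

    # decode_greedy takes a tuple of (1, batch, hidden_size) initial states.
    init = (torch.zeros(1, batch_size, hidden_size),
            torch.zeros(1, batch_size, hidden_size))
    print('decoded words:', decoder.decode_greedy(init, device=torch.device('cpu')))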