@@ -29,11 +29,10 @@
 lr = 1e-3  # learning rate
 momentum = 0.9  # momentum
 workers = 4  # number of workers for loading data in the DataLoader
-epochs = 200  # number of epochs to run without early-stopping
+epochs = 2  # number of epochs to run without early-stopping
 grad_clip = None  # clip gradients at this value
 print_freq = 2000  # print training or validation status every __ batches
 checkpoint = None  # path to model checkpoint, None if none
-best_acc = 0.  # assume the accuracy is 0 at first
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -44,7 +43,7 @@ def main():
     """
     Training and validation.
     """
-    global best_acc, epochs_since_improvement, checkpoint, start_epoch, word_map
+    global checkpoint, start_epoch, word_map
 
     # Initialize model or load checkpoint
     if checkpoint is not None:
@@ -53,10 +52,8 @@ def main():
         optimizer = checkpoint['optimizer']
         word_map = checkpoint['word_map']
         start_epoch = checkpoint['epoch'] + 1
-        best_acc = checkpoint['best_acc']
-        epochs_since_improvement = checkpoint['epochs_since_improvement']
         print(
-            '\nLoaded checkpoint from epoch %d, with a previous best accuracy of %.3f.\n' % (start_epoch - 1, best_acc))
+            '\nLoaded checkpoint from epoch %d.\n' % (start_epoch - 1))
     else:
         embeddings, emb_size = load_word2vec_embeddings(word2vec_file, word_map)  # load pre-trained word2vec embeddings
 
@@ -85,8 +82,6 @@ def main():
     # DataLoaders
     train_loader = torch.utils.data.DataLoader(HANDataset(data_folder, 'train'), batch_size=batch_size, shuffle=True,
                                                num_workers=workers, pin_memory=True)
-    val_loader = torch.utils.data.DataLoader(HANDataset(data_folder, 'test'), batch_size=batch_size, shuffle=True,
-                                             num_workers=workers, pin_memory=True)
 
     # Epochs
     for epoch in range(start_epoch, epochs):
@@ -97,25 +92,11 @@ def main():
               optimizer=optimizer,
               epoch=epoch)
 
-        # One epoch's validation
-        acc = validate(val_loader=val_loader,
-                       model=model,
-                       criterion=criterion)
-
-        # Did validation accuracy improve?
-        is_best = acc > best_acc
-        best_acc = max(acc, best_acc)
-        if not is_best:
-            epochs_since_improvement += 1
-            print("\nEpochs since improvement: %d\n" % (epochs_since_improvement,))
-        else:
-            epochs_since_improvement = 0
-
         # Decay learning rate every epoch
-        # adjust_learning_rate(optimizer, 0.5)
+        adjust_learning_rate(optimizer, 0.1)
 
         # Save checkpoint
-        save_checkpoint(epoch, model, optimizer, best_acc, word_map, epochs_since_improvement, is_best)
+        save_checkpoint(epoch, model, optimizer, word_map)
 
 
 def train(train_loader, model, criterion, optimizer, epoch):
@@ -190,69 +171,5 @@ def train(train_loader, model, criterion, optimizer, epoch):
                                                               acc=accs))
 
 
-def validate(val_loader, model, criterion):
-    """
-    Performs one epoch's validation.
-
-    :param val_loader: DataLoader for validation data
-    :param model: model
-    :param criterion: cross entropy loss layer
-    :return: validation accuracy score
-    """
-    model.eval()
-
-    batch_time = AverageMeter()  # forward prop. + back prop. time per batch
-    data_time = AverageMeter()  # data loading time per batch
-    losses = AverageMeter()  # cross entropy loss
-    accs = AverageMeter()  # accuracies
-
-    start = time.time()
-
-    # Batches
-    for i, (documents, sentences_per_document, words_per_sentence, labels) in enumerate(val_loader):
-
-        data_time.update(time.time() - start)
-
-        documents = documents.to(device)  # (batch_size, sentence_limit, word_limit)
-        sentences_per_document = sentences_per_document.squeeze(1).to(device)  # (batch_size)
-        words_per_sentence = words_per_sentence.to(device)  # (batch_size, sentence_limit)
-        labels = labels.squeeze(1).to(device)  # (batch_size)
-
-        # Forward prop.
-        scores, word_alphas, sentence_alphas = model(documents, sentences_per_document,
-                                                     words_per_sentence)  # (n_documents, n_classes), (n_documents, max_doc_len_in_batch, max_sent_len_in_batch), (n_documents, max_doc_len_in_batch)
-
-        # Loss
-        loss = criterion(scores, labels)
-
-        # Find accuracy
-        _, predictions = scores.max(dim=1)  # (n_documents)
-        correct_predictions = torch.eq(predictions, labels).sum().item()
-        accuracy = correct_predictions / labels.size(0)
-
-        # Keep track of metrics
-        losses.update(loss.item(), labels.size(0))
-        batch_time.update(time.time() - start)
-        accs.update(accuracy, labels.size(0))
-
-        start = time.time()
-
-        # Print training status
-        if i % print_freq == 0:
-            print('[{0}/{1}]\t'
-                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
-                  'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
-                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
-                  'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(i, len(val_loader),
-                                                                  batch_time=batch_time,
-                                                                  data_time=data_time, loss=losses,
-                                                                  acc=accs))
-
-    print('\n * LOSS - {loss.avg:.3f}, ACCURACY - {acc.avg:.3f}\n'.format(loss=losses,
-                                                                          acc=accs))
-
-    return accs.avg
-
-
 if __name__ == '__main__':
     main()
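
With the validation loop gone, the commented-out decay becomes an unconditional adjust_learning_rate(optimizer, 0.1) after every epoch, so the learning rate shrinks tenfold per epoch (1e-3 becomes 1e-4 after the first). The helper itself is outside this diff; a minimal sketch, assuming it simply scales each parameter group's learning rate in place, which is the standard PyTorch pattern:

def adjust_learning_rate(optimizer, scale_factor):
    """Multiply the learning rate of every parameter group by scale_factor.

    Sketch only: the real helper is not shown in this diff.
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * scale_factor
    print("\nDecayed learning rate to %f.\n" % (optimizer.param_groups[0]['lr'],))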
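
The checkpointing call also shrinks from save_checkpoint(epoch, model, optimizer, best_acc, word_map, epochs_since_improvement, is_best) to save_checkpoint(epoch, model, optimizer, word_map), so the helper must have been simplified to match. A sketch of the reduced signature, where the state keys mirror the ones main() reads back ('epoch', 'optimizer', 'word_map', plus the model itself) and the output filename is an assumption, not shown in this diff:

import torch


def save_checkpoint(epoch, model, optimizer, word_map):
    """Save a checkpoint after every epoch; no best-model bookkeeping.

    Sketch only: keys match what main() loads from the checkpoint dict.
    """
    state = {'epoch': epoch,
             'model': model,
             'optimizer': optimizer,
             'word_map': word_map}
    torch.save(state, 'checkpoint.pth.tar')  # assumed filename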