Commit 3fe0830

1 parent dc4095b

4 files changed: +60 additions, -95 deletions

README.md

Lines changed: 1 addition & 1 deletion
@@ -80,4 +80,4 @@ I am still writing this tutorial.
 
 In the meantime, you could take a look at the code – it works!
 
-We achieve an accuracy of **74.8%** (against **75.8%** in the paper) on the Yahoo Answer dataset.
+We achieve an accuracy of **75.1%** (against **75.8%** in the paper) on the Yahoo Answer dataset.

eval.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+import time
+from utils import *
+from datasets import HANDataset
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Data parameters
+data_folder = '/media/ssd/han data'
+
+# Evaluation parameters
+batch_size = 64  # batch size
+workers = 4  # number of workers for loading data in the DataLoader
+print_freq = 2000  # print training or validation status every __ batches
+checkpoint = 'checkpoint_han.pth.tar'
+
+# Load model
+checkpoint = torch.load(checkpoint)
+model = checkpoint['model']
+model = model.to(device)
+model.eval()
+
+# Load test data
+test_loader = torch.utils.data.DataLoader(HANDataset(data_folder, 'test'), batch_size=batch_size, shuffle=False,
+                                          num_workers=workers, pin_memory=True)
+
+# Track metrics
+accs = AverageMeter()  # accuracies
+
+# Evaluate in batches
+for i, (documents, sentences_per_document, words_per_sentence, labels) in enumerate(
+        tqdm(test_loader, desc='Evaluating')):
+
+    documents = documents.to(device)  # (batch_size, sentence_limit, word_limit)
+    sentences_per_document = sentences_per_document.squeeze(1).to(device)  # (batch_size)
+    words_per_sentence = words_per_sentence.to(device)  # (batch_size, sentence_limit)
+    labels = labels.squeeze(1).to(device)  # (batch_size)
+
+    # Forward prop.
+    scores, word_alphas, sentence_alphas = model(documents, sentences_per_document,
+                                                 words_per_sentence)  # (n_documents, n_classes), (n_documents, max_doc_len_in_batch, max_sent_len_in_batch), (n_documents, max_doc_len_in_batch)
+
+    # Find accuracy
+    _, predictions = scores.max(dim=1)  # (n_documents)
+    correct_predictions = torch.eq(predictions, labels).sum().item()
+    accuracy = correct_predictions / labels.size(0)
+
+    # Keep track of metrics
+    accs.update(accuracy, labels.size(0))
+
+    start = time.time()
+
+# Print final result
+print('\n * TEST ACCURACY - %.1f per cent\n' % (accs.avg * 100))
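The new eval.py runs the forward pass with autograd enabled, which is unnecessary at test time. Purely as an illustration, not part of this commit, the same measurement can be factored into a helper and wrapped in torch.no_grad(); the function name evaluate and its signature are hypothetical, while the batch unpacking and the squeeze(1) calls mirror eval.py above.

import torch

def evaluate(model, loader, device):
    # Hypothetical helper (not in the commit): classification accuracy over a
    # DataLoader, with autograd disabled since no backward pass happens here.
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for documents, sentences_per_document, words_per_sentence, labels in loader:
            labels = labels.squeeze(1).to(device)  # (batch_size)
            scores, _, _ = model(documents.to(device),
                                 sentences_per_document.squeeze(1).to(device),
                                 words_per_sentence.to(device))  # scores: (batch_size, n_classes)
            correct += (scores.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return correct / total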

train.py

Lines changed: 5 additions & 88 deletions
@@ -29,11 +29,10 @@
 lr = 1e-3  # learning rate
 momentum = 0.9  # momentum
 workers = 4  # number of workers for loading data in the DataLoader
-epochs = 200  # number of epochs to run without early-stopping
+epochs = 2  # number of epochs to run without early-stopping
 grad_clip = None  # clip gradients at this value
 print_freq = 2000  # print training or validation status every __ batches
 checkpoint = None  # path to model checkpoint, None if none
-best_acc = 0.  # assume the accuracy is 0 at first
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -44,7 +43,7 @@ def main():
     """
     Training and validation.
     """
-    global best_acc, epochs_since_improvement, checkpoint, start_epoch, word_map
+    global checkpoint, start_epoch, word_map
 
     # Initialize model or load checkpoint
     if checkpoint is not None:
@@ -53,10 +52,8 @@ def main():
         optimizer = checkpoint['optimizer']
         word_map = checkpoint['word_map']
         start_epoch = checkpoint['epoch'] + 1
-        best_acc = checkpoint['best_acc']
-        epochs_since_improvement = checkpoint['epochs_since_improvement']
         print(
-            '\nLoaded checkpoint from epoch %d, with a previous best accuracy of %.3f.\n' % (start_epoch - 1, best_acc))
+            '\nLoaded checkpoint from epoch %d.\n' % (start_epoch - 1))
     else:
         embeddings, emb_size = load_word2vec_embeddings(word2vec_file, word_map)  # load pre-trained word2vec embeddings
 
@@ -85,8 +82,6 @@ def main():
     # DataLoaders
     train_loader = torch.utils.data.DataLoader(HANDataset(data_folder, 'train'), batch_size=batch_size, shuffle=True,
                                                num_workers=workers, pin_memory=True)
-    val_loader = torch.utils.data.DataLoader(HANDataset(data_folder, 'test'), batch_size=batch_size, shuffle=True,
-                                             num_workers=workers, pin_memory=True)
 
     # Epochs
     for epoch in range(start_epoch, epochs):
@@ -97,25 +92,11 @@ def main():
               optimizer=optimizer,
               epoch=epoch)
 
-        # One epoch's validation
-        acc = validate(val_loader=val_loader,
-                       model=model,
-                       criterion=criterion)
-
-        # Did validation accuracy improve?
-        is_best = acc > best_acc
-        best_acc = max(acc, best_acc)
-        if not is_best:
-            epochs_since_improvement += 1
-            print("\nEpochs since improvement: %d\n" % (epochs_since_improvement,))
-        else:
-            epochs_since_improvement = 0
-
         # Decay learning rate every epoch
-        # adjust_learning_rate(optimizer, 0.5)
+        adjust_learning_rate(optimizer, 0.1)
 
         # Save checkpoint
-        save_checkpoint(epoch, model, optimizer, best_acc, word_map, epochs_since_improvement, is_best)
+        save_checkpoint(epoch, model, optimizer, word_map)
 
 
 def train(train_loader, model, criterion, optimizer, epoch):
@@ -190,69 +171,5 @@ def train(train_loader, model, criterion, optimizer, epoch):
                                                                   acc=accs))
 
 
-def validate(val_loader, model, criterion):
-    """
-    Performs one epoch's validation.
-
-    :param val_loader: DataLoader for validation data
-    :param model: model
-    :param criterion: cross entropy loss layer
-    :return: validation accuracy score
-    """
-    model.eval()
-
-    batch_time = AverageMeter()  # forward prop. + back prop. time per batch
-    data_time = AverageMeter()  # data loading time per batch
-    losses = AverageMeter()  # cross entropy loss
-    accs = AverageMeter()  # accuracies
-
-    start = time.time()
-
-    # Batches
-    for i, (documents, sentences_per_document, words_per_sentence, labels) in enumerate(val_loader):
-
-        data_time.update(time.time() - start)
-
-        documents = documents.to(device)  # (batch_size, sentence_limit, word_limit)
-        sentences_per_document = sentences_per_document.squeeze(1).to(device)  # (batch_size)
-        words_per_sentence = words_per_sentence.to(device)  # (batch_size, sentence_limit)
-        labels = labels.squeeze(1).to(device)  # (batch_size)
-
-        # Forward prop.
-        scores, word_alphas, sentence_alphas = model(documents, sentences_per_document,
-                                                     words_per_sentence)  # (n_documents, n_classes), (n_documents, max_doc_len_in_batch, max_sent_len_in_batch), (n_documents, max_doc_len_in_batch)
-
-        # Loss
-        loss = criterion(scores, labels)
-
-        # Find accuracy
-        _, predictions = scores.max(dim=1)  # (n_documents)
-        correct_predictions = torch.eq(predictions, labels).sum().item()
-        accuracy = correct_predictions / labels.size(0)
-
-        # Keep track of metrics
-        losses.update(loss.item(), labels.size(0))
-        batch_time.update(time.time() - start)
-        accs.update(accuracy, labels.size(0))
-
-        start = time.time()
-
-        # Print training status
-        if i % print_freq == 0:
-            print('[{0}/{1}]\t'
-                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
-                  'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
-                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
-                  'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(i, len(val_loader),
-                                                                  batch_time=batch_time,
-                                                                  data_time=data_time, loss=losses,
-                                                                  acc=accs))
-
-    print('\n * LOSS - {loss.avg:.3f}, ACCURACY - {acc.avg:.3f}\n'.format(loss=losses,
-                                                                          acc=accs))
-
-    return accs.avg
-
-
 if __name__ == '__main__':
     main()
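train.py now decays the learning rate after every epoch with adjust_learning_rate(optimizer, 0.1) and drops the early-stopping bookkeeping entirely. The helper's body is not shown in this diff; assuming it follows the usual PyTorch pattern, it scales each parameter group's learning rate in place. A minimal sketch under that assumption:

def adjust_learning_rate(optimizer, scale_factor):
    # Assumed implementation (the function lives in utils.py but is not shown
    # in this diff): shrink every parameter group's learning rate in place.
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * scale_factor

With epochs = 2 and a factor of 0.1, the second epoch runs at one tenth of the initial learning rate of 1e-3.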

utils.py

Lines changed: 1 addition & 6 deletions
@@ -250,7 +250,7 @@ def clip_gradient(optimizer, grad_clip):
                 param.grad.data.clamp_(-grad_clip, grad_clip)
 
 
-def save_checkpoint(epoch, model, optimizer, best_acc, word_map, epochs_since_improvement, is_best):
+def save_checkpoint(epoch, model, optimizer, word_map):
     """
     Save model checkpoint.
 
@@ -263,16 +263,11 @@ def save_checkpoint(epoch, model, optimizer, word_map, epochs_since_im
     :param is_best: is this checkpoint the best so far?
     """
     state = {'epoch': epoch,
-             'best_acc': best_acc,
              'model': model,
              'optimizer': optimizer,
-             'epochs_since_improvement': epochs_since_improvement,
              'word_map': word_map}
     filename = 'checkpoint_han.pth.tar'
     torch.save(state, filename)
-    # If checkpoint is the best so far, create a copy to avoid being overwritten by a subsequent worse checkpoint
-    if is_best:
-        torch.save(state, 'BEST_' + filename)
 
 
 class AverageMeter(object):
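After this change a checkpoint holds just four keys: 'epoch', 'model', 'optimizer', and 'word_map', which is exactly what train.py reads when resuming and what eval.py reads at test time. A minimal resume sketch, using only calls that appear in the diffs above:

import torch

checkpoint = torch.load('checkpoint_han.pth.tar')  # the file save_checkpoint() writes
model = checkpoint['model']                        # the whole model object is pickled, not a state_dict
optimizer = checkpoint['optimizer']                # optimizer state (e.g. momentum buffers) comes along
word_map = checkpoint['word_map']
start_epoch = checkpoint['epoch'] + 1              # resume with the epoch after the saved one

Because entire objects are serialized rather than state_dicts, loading requires the original class definitions to be importable.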
