Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit b12af9c

Browse files
committed Jan 21, 2018
use epoch for tensorboard
1 parent f9d1afe commit b12af9c

File tree

2 files changed

+11
-19
lines changed

2 files changed

+11
-19
lines changed
 

train-cifar10.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def train(epoch):
119119

120120
print("epoch %3d with lr=%.02e" % (epoch, get_lr()))
121121
phase = 'train'
122-
writer.add_scalar('%s/learning_rate' % phase, get_lr(), global_step)
122+
writer.add_scalar('%s/learning_rate' % phase, get_lr(), epoch)
123123

124124
model.train() # Set model to training mode
125125

@@ -163,10 +163,8 @@ def train(epoch):
163163

164164
accuracy = correct/total
165165
epoch_loss = running_loss / it
166-
writer.add_scalar('%s/accuracy' % phase, 100*accuracy, global_step)
167-
writer.add_scalar('%s/epoch_loss' % phase, epoch_loss, global_step)
168-
writer.add_scalar('%s/accuracy_by_epoch' % phase, 100*accuracy, epoch)
169-
writer.add_scalar('%s/epoch_loss_by_epoch' % phase, epoch_loss, epoch)
166+
writer.add_scalar('%s/accuracy' % phase, 100*accuracy, epoch)
167+
writer.add_scalar('%s/epoch_loss' % phase, epoch_loss, epoch)
170168

171169
def test(epoch):
172170
global best_accuracy, global_step
@@ -211,10 +209,8 @@ def test(epoch):
211209

212210
accuracy = correct/total
213211
epoch_loss = running_loss / it
214-
writer.add_scalar('%s/accuracy' % phase, 100*accuracy, global_step)
215-
writer.add_scalar('%s/epoch_loss' % phase, epoch_loss, global_step)
216-
writer.add_scalar('%s/accuracy_by_epoch' % phase, 100*accuracy, epoch)
217-
writer.add_scalar('%s/epoch_loss_by_epoch' % phase, epoch_loss, epoch)
212+
writer.add_scalar('%s/accuracy' % phase, 100*accuracy, epoch)
213+
writer.add_scalar('%s/epoch_loss' % phase, epoch_loss, epoch)
218214

219215
checkpoint = {
220216
'epoch': epoch,
@@ -228,7 +224,7 @@ def test(epoch):
228224
if accuracy > best_accuracy:
229225
best_accuracy = accuracy
230226
torch.save(checkpoint, 'checkpoints/best-cifar10-checkpoint-%s.pth' % full_name)
231-
torch.save(model, 'best-cifar10-model-%s.pth' % full_name)
227+
torch.save(model, '%d-best-cifar10-model-%s.pth' % (start_timestamp, full_name))
232228

233229
torch.save(checkpoint, 'checkpoints/last-cifar10-checkpoint.pth')
234230
del checkpoint # reduce memory

train-speech-commands.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def train(epoch):
136136

137137
print("epoch %3d with lr=%.02e" % (epoch, get_lr()))
138138
phase = 'train'
139-
writer.add_scalar('%s/learning_rate' % phase, get_lr(), global_step)
139+
writer.add_scalar('%s/learning_rate' % phase, get_lr(), epoch)
140140

141141
model.train() # Set model to training mode
142142

@@ -183,10 +183,8 @@ def train(epoch):
183183

184184
accuracy = correct/total
185185
epoch_loss = running_loss / it
186-
writer.add_scalar('%s/accuracy' % phase, 100*accuracy, global_step)
187-
writer.add_scalar('%s/epoch_loss' % phase, epoch_loss, global_step)
188-
writer.add_scalar('%s/accuracy_by_epoch' % phase, 100*accuracy, epoch)
189-
writer.add_scalar('%s/epoch_loss_by_epoch' % phase, epoch_loss, epoch)
186+
writer.add_scalar('%s/accuracy' % phase, 100*accuracy, epoch)
187+
writer.add_scalar('%s/epoch_loss' % phase, epoch_loss, epoch)
190188

191189
def valid(epoch):
192190
global best_accuracy, best_loss, global_step
@@ -234,10 +232,8 @@ def valid(epoch):
234232

235233
accuracy = correct/total
236234
epoch_loss = running_loss / it
237-
writer.add_scalar('%s/accuracy' % phase, 100*accuracy, global_step)
238-
writer.add_scalar('%s/epoch_loss' % phase, epoch_loss, global_step)
239-
writer.add_scalar('%s/accuracy_by_epoch' % phase, 100*accuracy, epoch)
240-
writer.add_scalar('%s/epoch_loss_by_epoch' % phase, epoch_loss, epoch)
235+
writer.add_scalar('%s/accuracy' % phase, 100*accuracy, epoch)
236+
writer.add_scalar('%s/epoch_loss' % phase, epoch_loss, epoch)
241237

242238
checkpoint = {
243239
'epoch': epoch,

0 commit comments

Comments (0)
Please sign in to comment.