@@ -119,7 +119,7 @@ def train(epoch):
119
119
120
120
print ("epoch %3d with lr=%.02e" % (epoch , get_lr ()))
121
121
phase = 'train'
122
- writer .add_scalar ('%s/learning_rate' % phase , get_lr (), global_step )
122
+ writer .add_scalar ('%s/learning_rate' % phase , get_lr (), epoch )
123
123
124
124
model .train () # Set model to training mode
125
125
@@ -163,10 +163,8 @@ def train(epoch):
163
163
164
164
accuracy = correct / total
165
165
epoch_loss = running_loss / it
166
- writer .add_scalar ('%s/accuracy' % phase , 100 * accuracy , global_step )
167
- writer .add_scalar ('%s/epoch_loss' % phase , epoch_loss , global_step )
168
- writer .add_scalar ('%s/accuracy_by_epoch' % phase , 100 * accuracy , epoch )
169
- writer .add_scalar ('%s/epoch_loss_by_epoch' % phase , epoch_loss , epoch )
166
+ writer .add_scalar ('%s/accuracy' % phase , 100 * accuracy , epoch )
167
+ writer .add_scalar ('%s/epoch_loss' % phase , epoch_loss , epoch )
170
168
171
169
def test (epoch ):
172
170
global best_accuracy , global_step
@@ -211,10 +209,8 @@ def test(epoch):
211
209
212
210
accuracy = correct / total
213
211
epoch_loss = running_loss / it
214
- writer .add_scalar ('%s/accuracy' % phase , 100 * accuracy , global_step )
215
- writer .add_scalar ('%s/epoch_loss' % phase , epoch_loss , global_step )
216
- writer .add_scalar ('%s/accuracy_by_epoch' % phase , 100 * accuracy , epoch )
217
- writer .add_scalar ('%s/epoch_loss_by_epoch' % phase , epoch_loss , epoch )
212
+ writer .add_scalar ('%s/accuracy' % phase , 100 * accuracy , epoch )
213
+ writer .add_scalar ('%s/epoch_loss' % phase , epoch_loss , epoch )
218
214
219
215
checkpoint = {
220
216
'epoch' : epoch ,
@@ -228,7 +224,7 @@ def test(epoch):
228
224
if accuracy > best_accuracy :
229
225
best_accuracy = accuracy
230
226
torch .save (checkpoint , 'checkpoints/best-cifar10-checkpoint-%s.pth' % full_name )
231
- torch .save (model , 'best-cifar10-model-%s.pth' % full_name )
227
+ torch .save (model , '%d- best-cifar10-model-%s.pth' % ( start_timestamp , full_name ) )
232
228
233
229
torch .save (checkpoint , 'checkpoints/last-cifar10-checkpoint.pth' )
234
230
del checkpoint # reduce memory
0 commit comments