@@ -104,7 +104,7 @@ def _update_epochs_dynamically(self):
104
104
elif args == 2 :
105
105
new_epochs = self ._dynamic_epochs (curve )
106
106
else :
107
- warn ('Invalid \ " dynamic_epochs\ " parameter: '
107
+ warn ('Invalid "dynamic_epochs" parameter: '
108
108
'expected a function with either one or two arguments, but got %s' % args_spec )
109
109
110
110
if self ._epochs != new_epochs :
@@ -115,7 +115,7 @@ def _evaluate_train(self, batch_x, batch_y):
115
115
eval_this_step = self ._train_set .step % self ._eval_train_every == 0
116
116
if eval_this_step and is_info_logged ():
117
117
eval_ = self ._runner .evaluate (batch_x , batch_y )
118
- self ._log_iteration ('train_accuracy' , eval_ .get ('loss' , 0 ), eval_ .get ('accuracy' , 0 ), False )
118
+ self ._log_iteration ('train_accuracy' , eval_ .get ('loss' ), eval_ .get ('accuracy' ), False )
119
119
120
120
def _evaluate_validation (self ):
121
121
eval_this_step = self ._train_set .step % self ._eval_validation_every == 0
@@ -127,7 +127,7 @@ def _evaluate_validation(self):
127
127
return
128
128
129
129
eval_ = self ._evaluate (batch_x = self ._val_set .x , batch_y = self ._val_set .y )
130
- self ._log_iteration ('validation_accuracy' , eval_ .get ('loss' , 0 ), eval_ .get ('accuracy' , 0 ), True )
130
+ self ._log_iteration ('validation_accuracy' , eval_ .get ('loss' ), eval_ .get ('accuracy' ), True )
131
131
return eval_
132
132
133
133
def _evaluate_test (self ):
@@ -139,13 +139,27 @@ def _evaluate_test(self):
139
139
return
140
140
141
141
eval_ = self ._evaluate (batch_x = self ._test_set .x , batch_y = self ._test_set .y )
142
- info ('Final test_accuracy=%.4f' % eval_ .get ('accuracy' , 0 ))
142
+ test_accuracy = eval_ .get ('accuracy' )
143
+ if test_accuracy is None :
144
+ warn ('Test accuracy evaluation is not available' )
145
+ return
146
+
147
+ info ('Final test_accuracy=%.4f' % test_accuracy )
143
148
return eval_
144
149
145
150
def _log_iteration (self , name , loss , accuracy , mark_best ):
146
- marker = ' *' if mark_best and (accuracy > self ._max_val_accuracy ) else ''
147
- info ('Epoch %2d, iteration %7d: loss=%.6f, %s=%.4f%s' %
148
- (self ._train_set .epochs_completed , self ._train_set .index , loss , name , accuracy , marker ))
151
+ message = 'Epoch %2d, iteration %7d' % (self ._train_set .epochs_completed , self ._train_set .index )
152
+ if accuracy is not None :
153
+ marker = ' *' if mark_best and (accuracy > self ._max_val_accuracy ) else ''
154
+ if loss is None :
155
+ info ('%s: %s=%.4f%s' % (message , name , accuracy , marker ))
156
+ else :
157
+ info ('%s: loss=%.6f, %s=%.4f%s' % (message , loss , name , accuracy , marker ))
158
+ else :
159
+ if loss is not None :
160
+ info ('%s: loss=%.6f' % (message , loss ))
161
+ else :
162
+ info ('%s: -- no loss or accuracy defined --' % message )
149
163
150
164
def _evaluate (self , batch_x , batch_y ):
151
165
size = len (batch_x )
0 commit comments