Skip to content

Commit 38bb077

Browse files
committed
make both accuracy and loss non-mandatory; improve logging
1 parent 88ebdd1 commit 38bb077

File tree

2 files changed

+32
-12
lines changed

2 files changed

+32
-12
lines changed

hyperengine/impl/tensorflow/tensorflow_runner.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, model=None, extra_feed={},
2424
self._x = self._find_tensor(input)
2525
self._y = self._find_tensor(label)
2626
self._mode = self._find_tensor(mode, mandatory=False)
27-
self._loss = self._find_tensor(loss)
27+
self._loss = self._find_tensor(loss, mandatory=False)
2828
self._accuracy = self._find_tensor(accuracy, mandatory=False)
2929
self._minimize = self._find_op(train)
3030
self._model_size = self._calc_model_size()
@@ -43,12 +43,16 @@ def run_batch(self, batch_x, batch_y):
4343

4444
def evaluate(self, batch_x, batch_y):
4545
feed_dict = self._get_feed_dict(batch_x, batch_y, 'test')
46-
if self._accuracy:
47-
loss, accuracy = self._session.run([self._loss, self._accuracy], feed_dict=feed_dict)
48-
return {'loss': loss, 'accuracy': accuracy}
49-
else:
46+
if self._loss is None and self._accuracy is None:
47+
return {}
48+
if self._accuracy is None:
5049
loss = self._session.run(self._loss, feed_dict=feed_dict)
5150
return {'loss': loss}
51+
if self._loss is None:
52+
accuracy = self._session.run(self._accuracy, feed_dict=feed_dict)
53+
return {'accuracy': accuracy}
54+
loss, accuracy = self._session.run([self._loss, self._accuracy], feed_dict=feed_dict)
55+
return {'loss': loss, 'accuracy': accuracy}
5256

5357
def terminate(self):
5458
tf.reset_default_graph()
@@ -74,6 +78,7 @@ def _find_tensor(self, name, mandatory=True):
7478
return self._graph.get_tensor_by_name(name + ':0')
7579
except KeyError:
7680
if not mandatory:
81+
debug('Tensor not found in Tensorflow graph:', name)
7782
return None
7883
warn('Failed to infer a tensor "%s" in Tensorflow graph. '
7984
'Most likely, you should add "name=\'%s\'" in the placeholder/tensor definition' % (name, name))
@@ -84,6 +89,7 @@ def _find_op(self, name, default=None):
8489
return self._graph.get_operation_by_name(name)
8590
except KeyError:
8691
if default is not None:
92+
debug('Op not found in Tensorflow graph:', name)
8793
return default
8894
warn('Failed to infer an op "%s" in Tensorflow graph. '
8995
'Most likely, you should add "name=\'%s\'" in the op definition' % (name, name))

hyperengine/model/base_solver.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def _update_epochs_dynamically(self):
104104
elif args == 2:
105105
new_epochs = self._dynamic_epochs(curve)
106106
else:
107-
warn('Invalid \"dynamic_epochs\" parameter: '
107+
warn('Invalid "dynamic_epochs" parameter: '
108108
'expected a function with either one or two arguments, but got %s' % args_spec)
109109

110110
if self._epochs != new_epochs:
@@ -115,7 +115,7 @@ def _evaluate_train(self, batch_x, batch_y):
115115
eval_this_step = self._train_set.step % self._eval_train_every == 0
116116
if eval_this_step and is_info_logged():
117117
eval_ = self._runner.evaluate(batch_x, batch_y)
118-
self._log_iteration('train_accuracy', eval_.get('loss', 0), eval_.get('accuracy', 0), False)
118+
self._log_iteration('train_accuracy', eval_.get('loss'), eval_.get('accuracy'), False)
119119

120120
def _evaluate_validation(self):
121121
eval_this_step = self._train_set.step % self._eval_validation_every == 0
@@ -127,7 +127,7 @@ def _evaluate_validation(self):
127127
return
128128

129129
eval_ = self._evaluate(batch_x=self._val_set.x, batch_y=self._val_set.y)
130-
self._log_iteration('validation_accuracy', eval_.get('loss', 0), eval_.get('accuracy', 0), True)
130+
self._log_iteration('validation_accuracy', eval_.get('loss'), eval_.get('accuracy'), True)
131131
return eval_
132132

133133
def _evaluate_test(self):
@@ -139,13 +139,27 @@ def _evaluate_test(self):
139139
return
140140

141141
eval_ = self._evaluate(batch_x=self._test_set.x, batch_y=self._test_set.y)
142-
info('Final test_accuracy=%.4f' % eval_.get('accuracy', 0))
142+
test_accuracy = eval_.get('accuracy')
143+
if test_accuracy is None:
144+
warn('Test accuracy evaluation is not available')
145+
return
146+
147+
info('Final test_accuracy=%.4f' % test_accuracy)
143148
return eval_
144149

145150
def _log_iteration(self, name, loss, accuracy, mark_best):
146-
marker = ' *' if mark_best and (accuracy > self._max_val_accuracy) else ''
147-
info('Epoch %2d, iteration %7d: loss=%.6f, %s=%.4f%s' %
148-
(self._train_set.epochs_completed, self._train_set.index, loss, name, accuracy, marker))
151+
message = 'Epoch %2d, iteration %7d' % (self._train_set.epochs_completed, self._train_set.index)
152+
if accuracy is not None:
153+
marker = ' *' if mark_best and (accuracy > self._max_val_accuracy) else ''
154+
if loss is None:
155+
info('%s: %s=%.4f%s' % (message, name, accuracy, marker))
156+
else:
157+
info('%s: loss=%.6f, %s=%.4f%s' % (message, loss, name, accuracy, marker))
158+
else:
159+
if loss is not None:
160+
info('%s: loss=%.6f' % (message, loss))
161+
else:
162+
info('%s: -- no loss or accuracy defined --' % message)
149163

150164
def _evaluate(self, batch_x, batch_y):
151165
size = len(batch_x)

0 commit comments

Comments
 (0)