|
27 | 27 | from tensorflow.python.framework import errors
|
28 | 28 | from tensorflow.python.framework import ops
|
29 | 29 | from tensorflow.python.framework import sparse_tensor
|
| 30 | +from tensorflow.python.ops import lookup_ops |
30 | 31 | from tensorflow.python.ops import variables
|
31 | 32 | from tensorflow.python.platform import gfile
|
32 | 33 | from tensorflow.python.platform import test
|
@@ -235,8 +236,7 @@ def verify_reset_restored_iterator(self,
|
235 | 236 | ds_fn, sparse_tensors=sparse_tensors)
|
236 | 237 | with self.test_session(graph=g) as sess:
|
237 | 238 | self._restore(saver, sess)
|
238 |
| - sess.run(variables.global_variables_initializer()) |
239 |
| - sess.run(init_op) |
| 239 | + self._initialize(init_op, sess) |
240 | 240 | for _ in range(num_outputs):
|
241 | 241 | actual.append(sess.run(get_next_op))
|
242 | 242 | if verify_exhausted:
|
@@ -390,8 +390,7 @@ def verify_error_on_save(self,
|
390 | 390 | init_op, get_next_op, saver = self._build_graph(
|
391 | 391 | ds_fn, sparse_tensors=sparse_tensors)
|
392 | 392 | with self.test_session(graph=g) as sess:
|
393 |
| - sess.run(variables.global_variables_initializer()) |
394 |
| - sess.run(init_op) |
| 393 | + self._initialize(init_op, sess) |
395 | 394 | for _ in range(break_point):
|
396 | 395 | sess.run(get_next_op)
|
397 | 396 | with self.assertRaises(error):
|
@@ -493,12 +492,10 @@ def get_ops():
|
493 | 492 | with self.test_session(graph=g) as sess:
|
494 | 493 | if ckpt_saved:
|
495 | 494 | if init_before_restore:
|
496 |
| - sess.run(variables.global_variables_initializer()) |
497 |
| - sess.run(init_op) |
| 495 | + self._initialize(init_op, sess) |
498 | 496 | self._restore(saver, sess)
|
499 | 497 | else:
|
500 |
| - sess.run(variables.global_variables_initializer()) |
501 |
| - sess.run(init_op) |
| 498 | + self._initialize(init_op, sess) |
502 | 499 | start = break_points[i - 1] if i > 0 else 0
|
503 | 500 | end = break_points[i] if i < len(break_points) else num_outputs
|
504 | 501 | num_iters = end - start
|
@@ -621,8 +618,14 @@ def _save(self, sess, saver):
|
621 | 618 | saver.save(sess, self._ckpt_path())
|
622 | 619 |
|
623 | 620 | def _restore(self, saver, sess):
|
| 621 | + sess.run(lookup_ops.tables_initializer()) |
624 | 622 | saver.restore(sess, self._latest_ckpt())
|
625 | 623 |
|
| 624 | + def _initialize(self, init_op, sess): |
| 625 | + sess.run(variables.global_variables_initializer()) |
| 626 | + sess.run(lookup_ops.tables_initializer()) |
| 627 | + sess.run(init_op) |
| 628 | + |
626 | 629 | def _import_meta_graph(self):
|
627 | 630 | meta_file_path = self._ckpt_path() + ".meta"
|
628 | 631 | return saver_lib.import_meta_graph(meta_file_path)
|
|
0 commit comments