Skip to content

Commit e955377

Browse files
shivaniagtensorflower-gardener
authored andcommitted
[tf.data] Support for initializing all the tables of the given graph.
PiperOrigin-RevId: 183466905
1 parent a977a77 commit e955377

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

tensorflow/contrib/data/python/kernel_tests/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ py_library(
126126
"//tensorflow/python:client_testlib",
127127
"//tensorflow/python:errors",
128128
"//tensorflow/python:framework_ops",
129+
"//tensorflow/python:lookup_ops",
129130
"//tensorflow/python:platform",
130131
"//tensorflow/python:sparse_tensor",
131132
"//tensorflow/python:training",

tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from tensorflow.python.framework import errors
2828
from tensorflow.python.framework import ops
2929
from tensorflow.python.framework import sparse_tensor
30+
from tensorflow.python.ops import lookup_ops
3031
from tensorflow.python.ops import variables
3132
from tensorflow.python.platform import gfile
3233
from tensorflow.python.platform import test
@@ -235,8 +236,7 @@ def verify_reset_restored_iterator(self,
235236
ds_fn, sparse_tensors=sparse_tensors)
236237
with self.test_session(graph=g) as sess:
237238
self._restore(saver, sess)
238-
sess.run(variables.global_variables_initializer())
239-
sess.run(init_op)
239+
self._initialize(init_op, sess)
240240
for _ in range(num_outputs):
241241
actual.append(sess.run(get_next_op))
242242
if verify_exhausted:
@@ -390,8 +390,7 @@ def verify_error_on_save(self,
390390
init_op, get_next_op, saver = self._build_graph(
391391
ds_fn, sparse_tensors=sparse_tensors)
392392
with self.test_session(graph=g) as sess:
393-
sess.run(variables.global_variables_initializer())
394-
sess.run(init_op)
393+
self._initialize(init_op, sess)
395394
for _ in range(break_point):
396395
sess.run(get_next_op)
397396
with self.assertRaises(error):
@@ -493,12 +492,10 @@ def get_ops():
493492
with self.test_session(graph=g) as sess:
494493
if ckpt_saved:
495494
if init_before_restore:
496-
sess.run(variables.global_variables_initializer())
497-
sess.run(init_op)
495+
self._initialize(init_op, sess)
498496
self._restore(saver, sess)
499497
else:
500-
sess.run(variables.global_variables_initializer())
501-
sess.run(init_op)
498+
self._initialize(init_op, sess)
502499
start = break_points[i - 1] if i > 0 else 0
503500
end = break_points[i] if i < len(break_points) else num_outputs
504501
num_iters = end - start
@@ -621,8 +618,14 @@ def _save(self, sess, saver):
621618
saver.save(sess, self._ckpt_path())
622619

623620
def _restore(self, saver, sess):
621+
sess.run(lookup_ops.tables_initializer())
624622
saver.restore(sess, self._latest_ckpt())
625623

624+
def _initialize(self, init_op, sess):
625+
sess.run(variables.global_variables_initializer())
626+
sess.run(lookup_ops.tables_initializer())
627+
sess.run(init_op)
628+
626629
def _import_meta_graph(self):
627630
meta_file_path = self._ckpt_path() + ".meta"
628631
return saver_lib.import_meta_graph(meta_file_path)

0 commit comments

Comments
 (0)