Skip to content

Commit 8aa4ffd

Browse files
authored
Merge pull request #43 from sczhengyabin/sczhengyabin-patch-model_restore_speedup
Update files.py, speed up model saving and restoring process.
2 parents 400d717 + 9eff9a6 commit 8aa4ffd

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

tensorlayer/files.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -646,14 +646,14 @@ def save_npz(save_list=[], name='model.npz', sess=None):
646646
"""
647647
## save params into a list
648648
save_list_var = []
649-
for k, value in enumerate(save_list):
650-
if sess:
651-
save_list_var.append( sess.run(value) )
652-
else:
653-
try:
654-
save_list_var.append( value.eval() )
655-
except:
656-
print(" Fail to save model, Hint: pass the session into this function, save_npz(network.all_params, name='model.npz', sess=sess)")
649+
if sess:
650+
save_list_var = sess.run(save_list)
651+
else:
652+
try:
653+
for k, value in enumerate(save_list):
654+
save_list_var.append(value.eval())
655+
except:
656+
print(" Fail to save model, Hint: pass the session into this function, save_npz(network.all_params, name='model.npz', sess=sess)")
657657
np.savez(name, params=save_list_var)
658658
save_list_var = None
659659
del save_list_var
@@ -734,9 +734,10 @@ def assign_params(sess, params, network):
734734
----------
735735
- `Assign value to a TensorFlow variable <http://stackoverflow.com/questions/34220532/how-to-assign-value-to-a-tensorflow-variable>`_
736736
"""
737+
ops = []
737738
for idx, param in enumerate(params):
738-
assign_op = network.all_params[idx].assign(param)
739-
sess.run(assign_op)
739+
ops.append(network.all_params[idx].assign(param))
740+
sess.run(ops)
740741

741742

742743

0 commit comments

Comments
 (0)