1
+ import json
2
+
1
3
"""
2
4
3
5
Some simple logging functionality, inspired by rllab's logging.
14
16
"""
15
17
16
18
import os .path as osp , shutil , time , atexit , os , subprocess
19
+ import pickle
20
+ import tensorflow as tf
17
21
18
22
color2num = dict (
19
23
gray = 30 ,
def configure_output_dir(d=None):
    """
    Set the output directory to d, or to /tmp/experiments/<timestamp> if d is None.

    Creates the directory if it does not exist, opens ``log.txt`` inside it
    for tabular output, and registers the file handle to be closed at
    interpreter exit.

    Args:
        d: desired output directory path, or None for a timestamped default.
    """
    G.output_dir = d or "/tmp/experiments/%i" % int(time.time())
    if osp.exists(G.output_dir):
        # Warn but keep going: log.txt below is opened in 'w' mode, so any
        # previous log in this directory will be overwritten.
        print("Log dir %s already exists! Delete it first or use a different dir" % G.output_dir)
    else:
        # exist_ok=True closes the race where another process creates the
        # directory between the exists() check above and this call.
        os.makedirs(G.output_dir, exist_ok=True)
    G.output_file = open(osp.join(G.output_dir, "log.txt"), 'w')
    # Ensure the log file is flushed/closed even on abnormal-but-clean exits.
    atexit.register(G.output_file.close)
    print(colorize("Logging data to %s" % G.output_file.name, 'green', bold=True))
60
61
61
62
def log_tabular (key , val ):
@@ -70,19 +71,38 @@ def log_tabular(key, val):
70
71
assert key not in G .log_current_row , "You already set %s this iteration. Maybe you forgot to call dump_tabular()" % key
71
72
G .log_current_row [key ] = val
72
73
74
def save_params(params):
    """
    Write the experiment hyperparameters to params.json in the output dir.

    Keys are sorted so the file is deterministic across runs; the custom
    separators put each entry on its own line for readability.
    """
    target = osp.join(G.output_dir, "params.json")
    serialized = json.dumps(params, separators=(',\n', '\t:\t'), sort_keys=True)
    with open(target, 'w') as out:
        out.write(serialized)
78
def pickle_tf_vars():
    """
    Save all TensorFlow global variables to vars.pkl in the output dir.

    The variables must already be initialized, and a default session must
    exist, because v.eval() runs against the default session.
    """
    var_values = {}
    for v in tf.global_variables():
        var_values[v.name] = v.eval()
    with open(osp.join(G.output_dir, "vars.pkl"), 'wb') as f:
        pickle.dump(var_values, f)
86
+
87
+
73
88
def dump_tabular ():
74
89
"""
75
90
Write all of the diagnostics from the current iteration
76
91
"""
77
92
vals = []
78
- print ("-" * 37 )
93
+ key_lens = [len (key ) for key in G .log_headers ]
94
+ max_key_len = max (15 ,max (key_lens ))
95
+ keystr = '%' + '%d' % max_key_len
96
+ fmt = "| " + keystr + "s | %15s |"
97
+ n_slashes = 22 + max_key_len
98
+ print ("-" * n_slashes )
79
99
for key in G .log_headers :
80
100
val = G .log_current_row .get (key , "" )
81
101
if hasattr (val , "__float__" ): valstr = "%8.3g" % val
82
102
else : valstr = val
83
- print ("| %15s | %15s |" % (key , valstr ))
103
+ print (fmt % (key , valstr ))
84
104
vals .append (val )
85
- print ("-" * 37 )
105
+ print ("-" * n_slashes )
86
106
if G .output_file is not None :
87
107
if G .first_row :
88
108
G .output_file .write ("\t " .join (G .log_headers ))
@@ -91,4 +111,4 @@ def dump_tabular():
91
111
G .output_file .write ("\n " )
92
112
G .output_file .flush ()
93
113
G .log_current_row .clear ()
94
- G .first_row = False
114
+ G .first_row = False
0 commit comments