@@ -27,7 +27,8 @@ class ReportProgressTest(parameterized.TestCase):
2727
2828 def test_every_steps (self ):
2929 hook = periodic_actions .ReportProgress (
30- every_steps = 4 , every_secs = None , num_train_steps = 10 )
30+ every_steps = 4 , every_secs = None , num_train_steps = 10
31+ )
3132 t = time .monotonic ()
3233 with self .assertLogs (level = "INFO" ) as logs :
3334 self .assertFalse (hook (1 , t ))
@@ -38,13 +39,18 @@ def test_every_steps(self):
3839 t += 0.12
3940 self .assertTrue (hook (4 , t ))
4041 # We did 1 step every 0.12s => 8.333 steps/s.
41- self .assertEqual (logs .output , [
42- "INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10), ETA: 0m"
43- ])
42+ self .assertEqual (
43+ logs .output ,
44+ [
45+ "INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10),"
46+ " ETA: 0m"
47+ ],
48+ )
4449
4550 def test_every_secs (self ):
4651 hook = periodic_actions .ReportProgress (
47- every_steps = None , every_secs = 0.3 , num_train_steps = 10 )
52+ every_steps = None , every_secs = 0.3 , num_train_steps = 10
53+ )
4854 t = time .monotonic ()
4955 with self .assertLogs (level = "INFO" ) as logs :
5056 self .assertFalse (hook (1 , t ))
@@ -55,9 +61,13 @@ def test_every_secs(self):
5561 t += 0.12
5662 self .assertTrue (hook (4 , t ))
5763 # We did 1 step every 0.12s => 8.333 steps/s.
58- self .assertEqual (logs .output , [
59- "INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10), ETA: 0m"
60- ])
64+ self .assertEqual (
65+ logs .output ,
66+ [
67+ "INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10),"
68+ " ETA: 0m"
69+ ],
70+ )
6171
6272 def test_without_num_train_steps (self ):
6373 report = periodic_actions .ReportProgress (every_steps = 2 )
@@ -66,9 +76,22 @@ def test_without_num_train_steps(self):
6676 self .assertFalse (report (1 , t ))
6777 self .assertTrue (report (2 , t + 0.12 ))
6878 # We did 1 step in 0.12s => 8.333 steps/s.
69- self .assertEqual (logs .output , [
70- "INFO:absl:Setting work unit notes: 8.3 steps/s"
71- ])
79+ self .assertEqual (
80+ logs .output , ["INFO:absl:Setting work unit notes: 8.3 steps/s" ]
81+ )
82+
83+ def test_with_persistent_notes (self ):
84+ report = periodic_actions .ReportProgress (every_steps = 2 )
85+ report .set_persistent_notes ("Hello world" )
86+ t = time .monotonic ()
87+ with self .assertLogs (level = "INFO" ) as logs :
88+ self .assertFalse (report (1 , t ))
89+ self .assertTrue (report (2 , t + 0.12 ))
90+ # We did 1 step in 0.12s => 8.333 steps/s.
91+ self .assertEqual (
92+ logs .output ,
93+ ["INFO:absl:Setting work unit notes: Hello world\n 8.3 steps/s" ],
94+ )
7295
7396 def test_unknown_cardinality (self ):
7497 report = periodic_actions .ReportProgress (every_steps = 2 )
@@ -77,15 +100,16 @@ def test_unknown_cardinality(self):
77100 self .assertFalse (report (1 , t ))
78101 self .assertTrue (report (2 , t + 0.12 ))
79102 # We did 1 step in 0.12s => 8.333 steps/s.
80- self .assertEqual (logs . output , [
81- "INFO:absl:Setting work unit notes: 8.3 steps/s"
82- ] )
103+ self .assertEqual (
104+ logs . output , [ "INFO:absl:Setting work unit notes: 8.3 steps/s" ]
105+ )
83106
84107 def test_called_every_step (self ):
85108 hook = periodic_actions .ReportProgress (every_steps = 3 , num_train_steps = 10 )
86109 t = time .monotonic ()
87110 with self .assertRaisesRegex (
88- ValueError , "PeriodicAction must be called after every step" ):
111+ ValueError , "PeriodicAction must be called after every step"
112+ ):
89113 hook (1 , t )
90114 hook (11 , t ) # Raises exception.
91115
@@ -97,10 +121,13 @@ def test_called_every_step(self):
97121 def test_named (self , wait_jax_async_dispatch , mock_time ):
98122 mock_time .return_value = 0
99123 hook = periodic_actions .ReportProgress (
100- every_steps = 1 , every_secs = None , num_train_steps = 10 )
124+ every_steps = 1 , every_secs = None , num_train_steps = 10
125+ )
126+
101127 def _wait ():
102128 # Here we depend on hook._executor=ThreadPoolExecutor(max_workers=1)
103129 hook ._executor .submit (lambda : None ).result ()
130+
104131 self .assertFalse (hook (1 )) # Never triggers on first execution.
105132 with hook .timed ("test1" , wait_jax_async_dispatch ):
106133 _wait ()
@@ -117,25 +144,32 @@ def _wait():
117144 mock_time .return_value = 4
118145 with self .assertLogs (level = "INFO" ) as logs :
119146 self .assertTrue (hook (2 ))
120- self .assertEqual (logs .output , [
121- "INFO:absl:Setting work unit notes: 0.2 steps/s, 20.0% (2/10), ETA: 0m"
122- " (0m : 50.0% test1, 25.0% test2)"
123- ])
147+ self .assertEqual (
148+ logs .output ,
149+ [
150+ "INFO:absl:Setting work unit notes: 0.2 steps/s, 20.0% (2/10), ETA:"
151+ " 0m (0m : 50.0% test1, 25.0% test2)"
152+ ],
153+ )
124154
125155 @mock .patch ("time.monotonic" )
126156 def test_write_metrics (self , time_mock ):
127157 time_mock .return_value = 0
128158 writer_mock = mock .Mock ()
129159 hook = periodic_actions .ReportProgress (
130- every_steps = 2 , every_secs = None , writer = writer_mock )
160+ every_steps = 2 , every_secs = None , writer = writer_mock
161+ )
131162 time_mock .return_value = 1
132163 hook (1 )
133164 time_mock .return_value = 2
134165 hook (2 )
135- self .assertEqual (writer_mock .write_scalars .mock_calls , [
136- mock .call (2 , {"steps_per_sec" : 1 }),
137- mock .call (2 , {"uptime" : 2 }),
138- ])
166+ self .assertEqual (
167+ writer_mock .write_scalars .mock_calls ,
168+ [
169+ mock .call (2 , {"steps_per_sec" : 1 }),
170+ mock .call (2 , {"uptime" : 2 }),
171+ ],
172+ )
139173
140174
141175class DummyProfilerSession :
@@ -177,7 +211,8 @@ def add_stop_step():
177211 num_profile_steps = 2 ,
178212 profile_duration_ms = 2_000 ,
179213 first_profile = 3 ,
180- every_steps = 7 )
214+ every_steps = 7 ,
215+ )
181216 for step in range (1 , 18 ):
182217 mock_time .return_value = step - 0.5 if step == 9 else step
183218 hook (step )
@@ -202,7 +237,8 @@ def profile_collect(logdir, callback, hosts, duration_ms):
202237 logdir = tempfile .mkdtemp (),
203238 profile_duration_ms = 2_000 ,
204239 first_profile = 3 ,
205- every_steps = 7 )
240+ every_steps = 7 ,
241+ )
206242 for step in range (1 , 18 ):
207243 hook (step )
208244 self .assertEqual ([3 , 7 , 14 ], start_steps )
@@ -213,7 +249,8 @@ class PeriodicCallbackTest(absltest.TestCase):
213249 def test_every_steps (self ):
214250 callback = mock .Mock ()
215251 hook = periodic_actions .PeriodicCallback (
216- every_steps = 2 , callback_fn = callback )
252+ every_steps = 2 , callback_fn = callback
253+ )
217254
218255 for step in range (1 , 10 ):
219256 hook (step , 3 , remainder = step % 3 )
@@ -222,7 +259,7 @@ def test_every_steps(self):
222259 mock .call (remainder = 2 , step = 2 , t = 3 ),
223260 mock .call (remainder = 1 , step = 4 , t = 3 ),
224261 mock .call (remainder = 0 , step = 6 , t = 3 ),
225- mock .call (remainder = 2 , step = 8 , t = 3 )
262+ mock .call (remainder = 2 , step = 8 , t = 3 ),
226263 ]
227264 self .assertListEqual (expected_calls , callback .call_args_list )
228265
@@ -237,7 +274,7 @@ def test_every_secs(self, mock_time):
237274 # Note: time will be initialized at 1 so hook runs at steps 4 & 7.
238275 expected_calls = [
239276 mock .call (remainder = 4 , step = 4 , t = 4.0 ),
240- mock .call (remainder = 2 , step = 7 , t = 7.0 )
277+ mock .call (remainder = 2 , step = 7 , t = 7.0 ),
241278 ]
242279 self .assertListEqual (expected_calls , callback .call_args_list )
243280
@@ -258,7 +295,8 @@ def cb(step, t):
258295 out .append (step )
259296
260297 hook = periodic_actions .PeriodicCallback (
261- every_steps = 1 , callback_fn = cb , execute_async = True )
298+ every_steps = 1 , callback_fn = cb , execute_async = True
299+ )
262300 hook (0 )
263301 hook (1 )
264302 hook (2 )
@@ -276,7 +314,8 @@ def cb(step, t):
276314 raise Exception
277315
278316 hook = periodic_actions .PeriodicCallback (
279- every_steps = 1 , callback_fn = cb , execute_async = True )
317+ every_steps = 1 , callback_fn = cb , execute_async = True
318+ )
280319
281320 hook (0 )
282321
@@ -290,7 +329,8 @@ def cb():
290329 return 5
291330
292331 hook = periodic_actions .PeriodicCallback (
293- every_steps = 1 , callback_fn = cb , pass_step_and_time = False )
332+ every_steps = 1 , callback_fn = cb , pass_step_and_time = False
333+ )
294334 hook (0 )
295335 hook (1 )
296336 self .assertEqual (hook .get_last_callback_result (), 5 )
0 commit comments