diff --git a/clu/periodic_actions.py b/clu/periodic_actions.py index c804128..7fe87a7 100644 --- a/clu/periodic_actions.py +++ b/clu/periodic_actions.py @@ -192,6 +192,11 @@ def __init__(self, # Using max_worker=1 guarantees that the calls to _wait_jax_async_dispatch() # happen sequentially. self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + self._persistent_notes = "" + + def set_persistent_notes(self, message: str): + """Sets the persistent notes for this work unit (not overwritten by the periodic action).""" + self._persistent_notes = message def _should_trigger(self, step: int, t: float) -> bool: # Note: step == self._previous_step is only True on the first step. @@ -211,6 +216,8 @@ def _apply(self, step: int, t: float): f"{100 * dt / total:.1f}% {name}" for name, dt in sorted(self._time_per_part.items()))) # This should be relatively cheap so we can do it in the same main thread. + if self._persistent_notes: + message = f"{self._persistent_notes}\n{message}" platform.work_unit().set_notes(message) if self._writer is not None: self._writer.write_scalars(step, {"steps_per_sec": steps_per_sec}) diff --git a/clu/periodic_actions_test.py b/clu/periodic_actions_test.py index 2f80dd7..bf70e97 100644 --- a/clu/periodic_actions_test.py +++ b/clu/periodic_actions_test.py @@ -27,7 +27,8 @@ class ReportProgressTest(parameterized.TestCase): def test_every_steps(self): hook = periodic_actions.ReportProgress( - every_steps=4, every_secs=None, num_train_steps=10) + every_steps=4, every_secs=None, num_train_steps=10 + ) t = time.monotonic() with self.assertLogs(level="INFO") as logs: self.assertFalse(hook(1, t)) @@ -38,13 +39,18 @@ def test_every_steps(self): t += 0.12 self.assertTrue(hook(4, t)) # We did 1 step every 0.12s => 8.333 steps/s. - self.assertEqual(logs.output, [ - "INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10), ETA: 0m" - ]) + self.assertEqual( + logs.output, + [ + "INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10)," + " ETA: 0m" + ], + ) def test_every_secs(self): hook = periodic_actions.ReportProgress( - every_steps=None, every_secs=0.3, num_train_steps=10) + every_steps=None, every_secs=0.3, num_train_steps=10 + ) t = time.monotonic() with self.assertLogs(level="INFO") as logs: self.assertFalse(hook(1, t)) @@ -55,9 +61,13 @@ def test_every_secs(self): t += 0.12 self.assertTrue(hook(4, t)) # We did 1 step every 0.12s => 8.333 steps/s. - self.assertEqual(logs.output, [ - "INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10), ETA: 0m" - ]) + self.assertEqual( + logs.output, + [ + "INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10)," + " ETA: 0m" + ], + ) def test_without_num_train_steps(self): report = periodic_actions.ReportProgress(every_steps=2) @@ -66,9 +76,22 @@ def test_without_num_train_steps(self): self.assertFalse(report(1, t)) self.assertTrue(report(2, t + 0.12)) # We did 1 step in 0.12s => 8.333 steps/s. - self.assertEqual(logs.output, [ - "INFO:absl:Setting work unit notes: 8.3 steps/s" - ]) + self.assertEqual( + logs.output, ["INFO:absl:Setting work unit notes: 8.3 steps/s"] + ) + + def test_with_persistent_notes(self): + report = periodic_actions.ReportProgress(every_steps=2) + report.set_persistent_notes("Hello world") + t = time.monotonic() + with self.assertLogs(level="INFO") as logs: + self.assertFalse(report(1, t)) + self.assertTrue(report(2, t + 0.12)) + # We did 1 step in 0.12s => 8.333 steps/s. + self.assertEqual( + logs.output, + ["INFO:absl:Setting work unit notes: Hello world\n8.3 steps/s"], + ) def test_unknown_cardinality(self): report = periodic_actions.ReportProgress(every_steps=2) @@ -77,15 +100,16 @@ def test_unknown_cardinality(self): self.assertFalse(report(1, t)) self.assertTrue(report(2, t + 0.12)) # We did 1 step in 0.12s => 8.333 steps/s. - self.assertEqual(logs.output, [ - "INFO:absl:Setting work unit notes: 8.3 steps/s" - ]) + self.assertEqual( + logs.output, ["INFO:absl:Setting work unit notes: 8.3 steps/s"] + ) def test_called_every_step(self): hook = periodic_actions.ReportProgress(every_steps=3, num_train_steps=10) t = time.monotonic() with self.assertRaisesRegex( - ValueError, "PeriodicAction must be called after every step"): + ValueError, "PeriodicAction must be called after every step" + ): hook(1, t) hook(11, t) # Raises exception. @@ -97,10 +121,13 @@ def test_called_every_step(self): def test_named(self, wait_jax_async_dispatch, mock_time): mock_time.return_value = 0 hook = periodic_actions.ReportProgress( - every_steps=1, every_secs=None, num_train_steps=10) + every_steps=1, every_secs=None, num_train_steps=10 + ) + def _wait(): # Here we depend on hook._executor=ThreadPoolExecutor(max_workers=1) hook._executor.submit(lambda: None).result() + self.assertFalse(hook(1)) # Never triggers on first execution. with hook.timed("test1", wait_jax_async_dispatch): _wait() @@ -117,25 +144,32 @@ def _wait(): mock_time.return_value = 4 with self.assertLogs(level="INFO") as logs: self.assertTrue(hook(2)) - self.assertEqual(logs.output, [ - "INFO:absl:Setting work unit notes: 0.2 steps/s, 20.0% (2/10), ETA: 0m" - " (0m : 50.0% test1, 25.0% test2)" - ]) + self.assertEqual( + logs.output, + [ + "INFO:absl:Setting work unit notes: 0.2 steps/s, 20.0% (2/10), ETA:" + " 0m (0m : 50.0% test1, 25.0% test2)" + ], + ) @mock.patch("time.monotonic") def test_write_metrics(self, time_mock): time_mock.return_value = 0 writer_mock = mock.Mock() hook = periodic_actions.ReportProgress( - every_steps=2, every_secs=None, writer=writer_mock) + every_steps=2, every_secs=None, writer=writer_mock + ) time_mock.return_value = 1 hook(1) time_mock.return_value = 2 hook(2) - self.assertEqual(writer_mock.write_scalars.mock_calls, [ - mock.call(2, {"steps_per_sec": 1}), - mock.call(2, {"uptime": 2}), - ]) + self.assertEqual( + writer_mock.write_scalars.mock_calls, + [ + mock.call(2, {"steps_per_sec": 1}), + mock.call(2, {"uptime": 2}), + ], + ) class DummyProfilerSession: @@ -177,7 +211,8 @@ def add_stop_step(): num_profile_steps=2, profile_duration_ms=2_000, first_profile=3, - every_steps=7) + every_steps=7, + ) for step in range(1, 18): mock_time.return_value = step - 0.5 if step == 9 else step hook(step) @@ -202,7 +237,8 @@ def profile_collect(logdir, callback, hosts, duration_ms): logdir=tempfile.mkdtemp(), profile_duration_ms=2_000, first_profile=3, - every_steps=7) + every_steps=7, + ) for step in range(1, 18): hook(step) self.assertEqual([3, 7, 14], start_steps) @@ -213,7 +249,8 @@ class PeriodicCallbackTest(absltest.TestCase): def test_every_steps(self): callback = mock.Mock() hook = periodic_actions.PeriodicCallback( - every_steps=2, callback_fn=callback) + every_steps=2, callback_fn=callback + ) for step in range(1, 10): hook(step, 3, remainder=step % 3) @@ -222,7 +259,7 @@ def test_every_steps(self): mock.call(remainder=2, step=2, t=3), mock.call(remainder=1, step=4, t=3), mock.call(remainder=0, step=6, t=3), - mock.call(remainder=2, step=8, t=3) + mock.call(remainder=2, step=8, t=3), ] self.assertListEqual(expected_calls, callback.call_args_list) @@ -237,7 +274,7 @@ def test_every_secs(self, mock_time): # Note: time will be initialized at 1 so hook runs at steps 4 & 7. expected_calls = [ mock.call(remainder=4, step=4, t=4.0), - mock.call(remainder=2, step=7, t=7.0) + mock.call(remainder=2, step=7, t=7.0), ] self.assertListEqual(expected_calls, callback.call_args_list) @@ -258,7 +295,8 @@ def cb(step, t): out.append(step) hook = periodic_actions.PeriodicCallback( - every_steps=1, callback_fn=cb, execute_async=True) + every_steps=1, callback_fn=cb, execute_async=True + ) hook(0) hook(1) hook(2) @@ -276,7 +314,8 @@ def cb(step, t): raise Exception hook = periodic_actions.PeriodicCallback( - every_steps=1, callback_fn=cb, execute_async=True) + every_steps=1, callback_fn=cb, execute_async=True + ) hook(0) @@ -290,7 +329,8 @@ def cb(): return 5 hook = periodic_actions.PeriodicCallback( - every_steps=1, callback_fn=cb, pass_step_and_time=False) + every_steps=1, callback_fn=cb, pass_step_and_time=False + ) hook(0) hook(1) self.assertEqual(hook.get_last_callback_result(), 5)