Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions clu/periodic_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,11 @@ def __init__(self,
# Using max_worker=1 guarantees that the calls to _wait_jax_async_dispatch()
# happen sequentially.
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self._persistent_notes = ""

def set_persistent_notes(self, message: str):
"""Sets the persistent notes for this work unit (not overwritten by the periodic action)."""
self._persistent_notes = message

def _should_trigger(self, step: int, t: float) -> bool:
# Note: step == self._previous_step is only True on the first step.
Expand All @@ -211,6 +216,8 @@ def _apply(self, step: int, t: float):
f"{100 * dt / total:.1f}% {name}"
for name, dt in sorted(self._time_per_part.items())))
# This should be relatively cheap so we can do it in the same main thread.
if self._persistent_notes:
message = f"{self._persistent_notes}\n{message}"
platform.work_unit().set_notes(message)
if self._writer is not None:
self._writer.write_scalars(step, {"steps_per_sec": steps_per_sec})
Expand Down
106 changes: 73 additions & 33 deletions clu/periodic_actions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class ReportProgressTest(parameterized.TestCase):

def test_every_steps(self):
hook = periodic_actions.ReportProgress(
every_steps=4, every_secs=None, num_train_steps=10)
every_steps=4, every_secs=None, num_train_steps=10
)
t = time.monotonic()
with self.assertLogs(level="INFO") as logs:
self.assertFalse(hook(1, t))
Expand All @@ -38,13 +39,18 @@ def test_every_steps(self):
t += 0.12
self.assertTrue(hook(4, t))
# We did 1 step every 0.12s => 8.333 steps/s.
self.assertEqual(logs.output, [
"INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10), ETA: 0m"
])
self.assertEqual(
logs.output,
[
"INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10),"
" ETA: 0m"
],
)

def test_every_secs(self):
hook = periodic_actions.ReportProgress(
every_steps=None, every_secs=0.3, num_train_steps=10)
every_steps=None, every_secs=0.3, num_train_steps=10
)
t = time.monotonic()
with self.assertLogs(level="INFO") as logs:
self.assertFalse(hook(1, t))
Expand All @@ -55,9 +61,13 @@ def test_every_secs(self):
t += 0.12
self.assertTrue(hook(4, t))
# We did 1 step every 0.12s => 8.333 steps/s.
self.assertEqual(logs.output, [
"INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10), ETA: 0m"
])
self.assertEqual(
logs.output,
[
"INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10),"
" ETA: 0m"
],
)

def test_without_num_train_steps(self):
report = periodic_actions.ReportProgress(every_steps=2)
Expand All @@ -66,9 +76,22 @@ def test_without_num_train_steps(self):
self.assertFalse(report(1, t))
self.assertTrue(report(2, t + 0.12))
# We did 1 step in 0.12s => 8.333 steps/s.
self.assertEqual(logs.output, [
"INFO:absl:Setting work unit notes: 8.3 steps/s"
])
self.assertEqual(
logs.output, ["INFO:absl:Setting work unit notes: 8.3 steps/s"]
)

def test_with_persistent_notes(self):
report = periodic_actions.ReportProgress(every_steps=2)
report.set_persistent_notes("Hello world")
t = time.monotonic()
with self.assertLogs(level="INFO") as logs:
self.assertFalse(report(1, t))
self.assertTrue(report(2, t + 0.12))
# We did 1 step in 0.12s => 8.333 steps/s.
self.assertEqual(
logs.output,
["INFO:absl:Setting work unit notes: Hello world\n8.3 steps/s"],
)

def test_unknown_cardinality(self):
report = periodic_actions.ReportProgress(every_steps=2)
Expand All @@ -77,15 +100,16 @@ def test_unknown_cardinality(self):
self.assertFalse(report(1, t))
self.assertTrue(report(2, t + 0.12))
# We did 1 step in 0.12s => 8.333 steps/s.
self.assertEqual(logs.output, [
"INFO:absl:Setting work unit notes: 8.3 steps/s"
])
self.assertEqual(
logs.output, ["INFO:absl:Setting work unit notes: 8.3 steps/s"]
)

def test_called_every_step(self):
hook = periodic_actions.ReportProgress(every_steps=3, num_train_steps=10)
t = time.monotonic()
with self.assertRaisesRegex(
ValueError, "PeriodicAction must be called after every step"):
ValueError, "PeriodicAction must be called after every step"
):
hook(1, t)
hook(11, t) # Raises exception.

Expand All @@ -97,10 +121,13 @@ def test_called_every_step(self):
def test_named(self, wait_jax_async_dispatch, mock_time):
mock_time.return_value = 0
hook = periodic_actions.ReportProgress(
every_steps=1, every_secs=None, num_train_steps=10)
every_steps=1, every_secs=None, num_train_steps=10
)

def _wait():
# Here we depend on hook._executor=ThreadPoolExecutor(max_workers=1)
hook._executor.submit(lambda: None).result()

self.assertFalse(hook(1)) # Never triggers on first execution.
with hook.timed("test1", wait_jax_async_dispatch):
_wait()
Expand All @@ -117,25 +144,32 @@ def _wait():
mock_time.return_value = 4
with self.assertLogs(level="INFO") as logs:
self.assertTrue(hook(2))
self.assertEqual(logs.output, [
"INFO:absl:Setting work unit notes: 0.2 steps/s, 20.0% (2/10), ETA: 0m"
" (0m : 50.0% test1, 25.0% test2)"
])
self.assertEqual(
logs.output,
[
"INFO:absl:Setting work unit notes: 0.2 steps/s, 20.0% (2/10), ETA:"
" 0m (0m : 50.0% test1, 25.0% test2)"
],
)

@mock.patch("time.monotonic")
def test_write_metrics(self, time_mock):
time_mock.return_value = 0
writer_mock = mock.Mock()
hook = periodic_actions.ReportProgress(
every_steps=2, every_secs=None, writer=writer_mock)
every_steps=2, every_secs=None, writer=writer_mock
)
time_mock.return_value = 1
hook(1)
time_mock.return_value = 2
hook(2)
self.assertEqual(writer_mock.write_scalars.mock_calls, [
mock.call(2, {"steps_per_sec": 1}),
mock.call(2, {"uptime": 2}),
])
self.assertEqual(
writer_mock.write_scalars.mock_calls,
[
mock.call(2, {"steps_per_sec": 1}),
mock.call(2, {"uptime": 2}),
],
)


class DummyProfilerSession:
Expand Down Expand Up @@ -177,7 +211,8 @@ def add_stop_step():
num_profile_steps=2,
profile_duration_ms=2_000,
first_profile=3,
every_steps=7)
every_steps=7,
)
for step in range(1, 18):
mock_time.return_value = step - 0.5 if step == 9 else step
hook(step)
Expand All @@ -202,7 +237,8 @@ def profile_collect(logdir, callback, hosts, duration_ms):
logdir=tempfile.mkdtemp(),
profile_duration_ms=2_000,
first_profile=3,
every_steps=7)
every_steps=7,
)
for step in range(1, 18):
hook(step)
self.assertEqual([3, 7, 14], start_steps)
Expand All @@ -213,7 +249,8 @@ class PeriodicCallbackTest(absltest.TestCase):
def test_every_steps(self):
callback = mock.Mock()
hook = periodic_actions.PeriodicCallback(
every_steps=2, callback_fn=callback)
every_steps=2, callback_fn=callback
)

for step in range(1, 10):
hook(step, 3, remainder=step % 3)
Expand All @@ -222,7 +259,7 @@ def test_every_steps(self):
mock.call(remainder=2, step=2, t=3),
mock.call(remainder=1, step=4, t=3),
mock.call(remainder=0, step=6, t=3),
mock.call(remainder=2, step=8, t=3)
mock.call(remainder=2, step=8, t=3),
]
self.assertListEqual(expected_calls, callback.call_args_list)

Expand All @@ -237,7 +274,7 @@ def test_every_secs(self, mock_time):
# Note: time will be initialized at 1 so hook runs at steps 4 & 7.
expected_calls = [
mock.call(remainder=4, step=4, t=4.0),
mock.call(remainder=2, step=7, t=7.0)
mock.call(remainder=2, step=7, t=7.0),
]
self.assertListEqual(expected_calls, callback.call_args_list)

Expand All @@ -258,7 +295,8 @@ def cb(step, t):
out.append(step)

hook = periodic_actions.PeriodicCallback(
every_steps=1, callback_fn=cb, execute_async=True)
every_steps=1, callback_fn=cb, execute_async=True
)
hook(0)
hook(1)
hook(2)
Expand All @@ -276,7 +314,8 @@ def cb(step, t):
raise Exception

hook = periodic_actions.PeriodicCallback(
every_steps=1, callback_fn=cb, execute_async=True)
every_steps=1, callback_fn=cb, execute_async=True
)

hook(0)

Expand All @@ -290,7 +329,8 @@ def cb():
return 5

hook = periodic_actions.PeriodicCallback(
every_steps=1, callback_fn=cb, pass_step_and_time=False)
every_steps=1, callback_fn=cb, pass_step_and_time=False
)
hook(0)
hook(1)
self.assertEqual(hook.get_last_callback_result(), 5)
Expand Down