Skip to content

Commit 868b11c

Browse files
samialabed and copybara-github
authored and committed
Add functionality to preserve existing notes
Users can set persistent notes that are not overwritten by the ReportProgress periodic action; the persistent notes are prepended to each progress message.
TESTED: unit test
PiperOrigin-RevId: 745215687
1 parent a8152eb commit 868b11c

2 files changed

Lines changed: 80 additions & 33 deletions

File tree

clu/periodic_actions.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -192,6 +192,11 @@ def __init__(self,
192192
# Using max_worker=1 guarantees that the calls to _wait_jax_async_dispatch()
193193
# happen sequentially.
194194
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
195+
self._persistent_notes = ""
196+
197+
def set_persistent_notes(self, message: str):
198+
"""Sets the persistent notes for this work unit (not overwritten by the periodic action)."""
199+
self._persistent_notes = message
195200

196201
def _should_trigger(self, step: int, t: float) -> bool:
197202
# Note: step == self._previous_step is only True on the first step.
@@ -211,6 +216,8 @@ def _apply(self, step: int, t: float):
211216
f"{100 * dt / total:.1f}% {name}"
212217
for name, dt in sorted(self._time_per_part.items())))
213218
# This should be relatively cheap so we can do it in the same main thread.
219+
if self._persistent_notes:
220+
message = f"{self._persistent_notes}\n{message}"
214221
platform.work_unit().set_notes(message)
215222
if self._writer is not None:
216223
self._writer.write_scalars(step, {"steps_per_sec": steps_per_sec})

clu/periodic_actions_test.py

Lines changed: 73 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,8 @@ class ReportProgressTest(parameterized.TestCase):
2727

2828
def test_every_steps(self):
2929
hook = periodic_actions.ReportProgress(
30-
every_steps=4, every_secs=None, num_train_steps=10)
30+
every_steps=4, every_secs=None, num_train_steps=10
31+
)
3132
t = time.monotonic()
3233
with self.assertLogs(level="INFO") as logs:
3334
self.assertFalse(hook(1, t))
@@ -38,13 +39,18 @@ def test_every_steps(self):
3839
t += 0.12
3940
self.assertTrue(hook(4, t))
4041
# We did 1 step every 0.12s => 8.333 steps/s.
41-
self.assertEqual(logs.output, [
42-
"INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10), ETA: 0m"
43-
])
42+
self.assertEqual(
43+
logs.output,
44+
[
45+
"INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10),"
46+
" ETA: 0m"
47+
],
48+
)
4449

4550
def test_every_secs(self):
4651
hook = periodic_actions.ReportProgress(
47-
every_steps=None, every_secs=0.3, num_train_steps=10)
52+
every_steps=None, every_secs=0.3, num_train_steps=10
53+
)
4854
t = time.monotonic()
4955
with self.assertLogs(level="INFO") as logs:
5056
self.assertFalse(hook(1, t))
@@ -55,9 +61,13 @@ def test_every_secs(self):
5561
t += 0.12
5662
self.assertTrue(hook(4, t))
5763
# We did 1 step every 0.12s => 8.333 steps/s.
58-
self.assertEqual(logs.output, [
59-
"INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10), ETA: 0m"
60-
])
64+
self.assertEqual(
65+
logs.output,
66+
[
67+
"INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10),"
68+
" ETA: 0m"
69+
],
70+
)
6171

6272
def test_without_num_train_steps(self):
6373
report = periodic_actions.ReportProgress(every_steps=2)
@@ -66,9 +76,22 @@ def test_without_num_train_steps(self):
6676
self.assertFalse(report(1, t))
6777
self.assertTrue(report(2, t + 0.12))
6878
# We did 1 step in 0.12s => 8.333 steps/s.
69-
self.assertEqual(logs.output, [
70-
"INFO:absl:Setting work unit notes: 8.3 steps/s"
71-
])
79+
self.assertEqual(
80+
logs.output, ["INFO:absl:Setting work unit notes: 8.3 steps/s"]
81+
)
82+
83+
def test_with_persistent_notes(self):
84+
report = periodic_actions.ReportProgress(every_steps=2)
85+
report.set_persistent_notes("Hello world")
86+
t = time.monotonic()
87+
with self.assertLogs(level="INFO") as logs:
88+
self.assertFalse(report(1, t))
89+
self.assertTrue(report(2, t + 0.12))
90+
# We did 1 step in 0.12s => 8.333 steps/s.
91+
self.assertEqual(
92+
logs.output,
93+
["INFO:absl:Setting work unit notes: Hello world\n8.3 steps/s"],
94+
)
7295

7396
def test_unknown_cardinality(self):
7497
report = periodic_actions.ReportProgress(every_steps=2)
@@ -77,15 +100,16 @@ def test_unknown_cardinality(self):
77100
self.assertFalse(report(1, t))
78101
self.assertTrue(report(2, t + 0.12))
79102
# We did 1 step in 0.12s => 8.333 steps/s.
80-
self.assertEqual(logs.output, [
81-
"INFO:absl:Setting work unit notes: 8.3 steps/s"
82-
])
103+
self.assertEqual(
104+
logs.output, ["INFO:absl:Setting work unit notes: 8.3 steps/s"]
105+
)
83106

84107
def test_called_every_step(self):
85108
hook = periodic_actions.ReportProgress(every_steps=3, num_train_steps=10)
86109
t = time.monotonic()
87110
with self.assertRaisesRegex(
88-
ValueError, "PeriodicAction must be called after every step"):
111+
ValueError, "PeriodicAction must be called after every step"
112+
):
89113
hook(1, t)
90114
hook(11, t) # Raises exception.
91115

@@ -97,10 +121,13 @@ def test_called_every_step(self):
97121
def test_named(self, wait_jax_async_dispatch, mock_time):
98122
mock_time.return_value = 0
99123
hook = periodic_actions.ReportProgress(
100-
every_steps=1, every_secs=None, num_train_steps=10)
124+
every_steps=1, every_secs=None, num_train_steps=10
125+
)
126+
101127
def _wait():
102128
# Here we depend on hook._executor=ThreadPoolExecutor(max_workers=1)
103129
hook._executor.submit(lambda: None).result()
130+
104131
self.assertFalse(hook(1)) # Never triggers on first execution.
105132
with hook.timed("test1", wait_jax_async_dispatch):
106133
_wait()
@@ -117,25 +144,32 @@ def _wait():
117144
mock_time.return_value = 4
118145
with self.assertLogs(level="INFO") as logs:
119146
self.assertTrue(hook(2))
120-
self.assertEqual(logs.output, [
121-
"INFO:absl:Setting work unit notes: 0.2 steps/s, 20.0% (2/10), ETA: 0m"
122-
" (0m : 50.0% test1, 25.0% test2)"
123-
])
147+
self.assertEqual(
148+
logs.output,
149+
[
150+
"INFO:absl:Setting work unit notes: 0.2 steps/s, 20.0% (2/10), ETA:"
151+
" 0m (0m : 50.0% test1, 25.0% test2)"
152+
],
153+
)
124154

125155
@mock.patch("time.monotonic")
126156
def test_write_metrics(self, time_mock):
127157
time_mock.return_value = 0
128158
writer_mock = mock.Mock()
129159
hook = periodic_actions.ReportProgress(
130-
every_steps=2, every_secs=None, writer=writer_mock)
160+
every_steps=2, every_secs=None, writer=writer_mock
161+
)
131162
time_mock.return_value = 1
132163
hook(1)
133164
time_mock.return_value = 2
134165
hook(2)
135-
self.assertEqual(writer_mock.write_scalars.mock_calls, [
136-
mock.call(2, {"steps_per_sec": 1}),
137-
mock.call(2, {"uptime": 2}),
138-
])
166+
self.assertEqual(
167+
writer_mock.write_scalars.mock_calls,
168+
[
169+
mock.call(2, {"steps_per_sec": 1}),
170+
mock.call(2, {"uptime": 2}),
171+
],
172+
)
139173

140174

141175
class DummyProfilerSession:
@@ -177,7 +211,8 @@ def add_stop_step():
177211
num_profile_steps=2,
178212
profile_duration_ms=2_000,
179213
first_profile=3,
180-
every_steps=7)
214+
every_steps=7,
215+
)
181216
for step in range(1, 18):
182217
mock_time.return_value = step - 0.5 if step == 9 else step
183218
hook(step)
@@ -202,7 +237,8 @@ def profile_collect(logdir, callback, hosts, duration_ms):
202237
logdir=tempfile.mkdtemp(),
203238
profile_duration_ms=2_000,
204239
first_profile=3,
205-
every_steps=7)
240+
every_steps=7,
241+
)
206242
for step in range(1, 18):
207243
hook(step)
208244
self.assertEqual([3, 7, 14], start_steps)
@@ -213,7 +249,8 @@ class PeriodicCallbackTest(absltest.TestCase):
213249
def test_every_steps(self):
214250
callback = mock.Mock()
215251
hook = periodic_actions.PeriodicCallback(
216-
every_steps=2, callback_fn=callback)
252+
every_steps=2, callback_fn=callback
253+
)
217254

218255
for step in range(1, 10):
219256
hook(step, 3, remainder=step % 3)
@@ -222,7 +259,7 @@ def test_every_steps(self):
222259
mock.call(remainder=2, step=2, t=3),
223260
mock.call(remainder=1, step=4, t=3),
224261
mock.call(remainder=0, step=6, t=3),
225-
mock.call(remainder=2, step=8, t=3)
262+
mock.call(remainder=2, step=8, t=3),
226263
]
227264
self.assertListEqual(expected_calls, callback.call_args_list)
228265

@@ -237,7 +274,7 @@ def test_every_secs(self, mock_time):
237274
# Note: time will be initialized at 1 so hook runs at steps 4 & 7.
238275
expected_calls = [
239276
mock.call(remainder=4, step=4, t=4.0),
240-
mock.call(remainder=2, step=7, t=7.0)
277+
mock.call(remainder=2, step=7, t=7.0),
241278
]
242279
self.assertListEqual(expected_calls, callback.call_args_list)
243280

@@ -258,7 +295,8 @@ def cb(step, t):
258295
out.append(step)
259296

260297
hook = periodic_actions.PeriodicCallback(
261-
every_steps=1, callback_fn=cb, execute_async=True)
298+
every_steps=1, callback_fn=cb, execute_async=True
299+
)
262300
hook(0)
263301
hook(1)
264302
hook(2)
@@ -276,7 +314,8 @@ def cb(step, t):
276314
raise Exception
277315

278316
hook = periodic_actions.PeriodicCallback(
279-
every_steps=1, callback_fn=cb, execute_async=True)
317+
every_steps=1, callback_fn=cb, execute_async=True
318+
)
280319

281320
hook(0)
282321

@@ -290,7 +329,8 @@ def cb():
290329
return 5
291330

292331
hook = periodic_actions.PeriodicCallback(
293-
every_steps=1, callback_fn=cb, pass_step_and_time=False)
332+
every_steps=1, callback_fn=cb, pass_step_and_time=False
333+
)
294334
hook(0)
295335
hook(1)
296336
self.assertEqual(hook.get_last_callback_result(), 5)

0 commit comments

Comments
 (0)