
Commit 0159a07

galrotem authored and facebook-github-bot committed Apr 26, 2024
throughput logger - log per epoch
Reviewed By: JKSenthil
Differential Revision: D56498952
fbshipit-source-id: 3a8acb74d688c5c609807855c078be5eb1c6046e
1 parent e3ffa1f commit 0159a07

2 files changed, +168 -17 lines
 

tests/framework/callbacks/test_throughput_logger.py (+112 -15)
@@ -8,7 +8,7 @@
 # pyre-strict
 
 import unittest
-from unittest.mock import ANY, call, MagicMock
+from unittest.mock import ANY, call, MagicMock, patch
 
 import torch
 from pyre_extensions import none_throws
@@ -17,13 +17,15 @@
 from torchtnt.framework._test_utils import (
     DummyAutoUnit,
     DummyPredictUnit,
+    DummyTrainUnit,
     generate_random_dataloader,
 )
 from torchtnt.framework.callbacks.throughput_logger import ThroughputLogger
 from torchtnt.framework.predict import predict
 
-from torchtnt.framework.state import EntryPoint, PhaseState, State
-from torchtnt.framework.train import _train_impl
+from torchtnt.framework.state import ActivePhase, EntryPoint, PhaseState, State
+from torchtnt.framework.train import _train_impl, train
+from torchtnt.framework.unit import TrainUnit
 from torchtnt.utils.loggers.logger import MetricLogger
 
 
@@ -121,21 +123,18 @@ def test_with_comparing_time(self) -> None:
                 evaluate_every_n_epochs=2,
             ),
         )
+        throughput_logger = ThroughputLogger(
+            logger=logger,
+            throughput_per_batch={"Batches": 1, "Queries": 8},
+            log_every_n_steps=1,
+        )
 
         # we want to be able to compare the logging value to the state, so we need to create state manually and
         # call _train_impl. This would have been similar to calling fit() and getting the state as a ret value
         _train_impl(
             state,
             DummyAutoUnit(module=torch.nn.Linear(2, 2)),
-            CallbackHandler(
-                [
-                    ThroughputLogger(
-                        logger=logger,
-                        throughput_per_batch={"Batches": 1, "Queries": 8},
-                        log_every_n_steps=1,
-                    )
-                ],
-            ),
+            CallbackHandler([throughput_logger]),
         )
 
         train_iteration_times = none_throws(
@@ -163,8 +162,8 @@ def test_with_comparing_time(self) -> None:
             eval_iteration_times[i] + eval_twfb_times[i] for i in range(2)
         ]
         self.assertEqual(
-            logger.log.call_count, 12
-        )  # 8 train (2epochs x 2steps x 2items), 4 eval (1x2x2)
+            logger.log.call_count, 18
+        )  # steps: 8 train (2epochs x 2steps x 2items), 4 eval (1x2x2). epochs: 4 train (2epoch x 2items). 2 eval (1x2)
         train_batches_step_logs = [
             call(
                 "Train: Batches per second (step granularity)",
@@ -197,11 +196,36 @@ def test_with_comparing_time(self) -> None:
             )
             for i in range(2)
         ]
+        # for epoch, we test the logged value separately
+        train_batches_epoch_logs = [
+            call("Train: Batches per second (epoch granularity)", ANY, i)
+            for i in range(1, 3)
+        ]
+        train_queries_epoch_logs = [
+            call("Train: Queries per second (epoch granularity)", ANY, i)
+            for i in range(1, 3)
+        ]
+        eval_epoch_logs = [
+            call(
+                "Eval: Queries per second (epoch granularity)",
+                ANY,
+                1,
+            ),
+            call(
+                "Eval: Batches per second (epoch granularity)",
+                ANY,
+                1,
+            ),
+        ]
+
         logger.log.assert_has_calls(
             train_batches_step_logs
             + train_queries_step_logs
             + eval_batches_step_logs
-            + eval_queries_step_logs,
+            + eval_queries_step_logs
+            + train_batches_epoch_logs
+            + train_queries_epoch_logs
+            + eval_epoch_logs,
             any_order=True,
         )
 
@@ -227,6 +251,79 @@ def test_with_predict(self) -> None:
                     1,
                 )
             ],
+            [
+                call(
+                    "Predict: Batches per second (epoch granularity)",
+                    ANY,
+                    1,
+                )
+            ],
+        )
+
+    def test_log_for_epoch(self) -> None:
+        logger = MagicMock(spec=MetricLogger)
+        unit = DummyTrainUnit(input_dim=2)
+        throughput_logger = ThroughputLogger(logger, {"Batches": 1, "Queries": 8})
+        state = State(entry_point=EntryPoint.TRAIN)
+
+        self.assertIsNone(throughput_logger._epoch_start_times.get(ActivePhase.TRAIN))
+        self.assertEqual(throughput_logger._steps_in_epoch[ActivePhase.TRAIN], 0)
+        with patch.object(throughput_logger, "_maybe_log_for_step"):
+            throughput_logger.on_train_step_end(state, unit)
+        self.assertEqual(throughput_logger._steps_in_epoch[ActivePhase.TRAIN], 1)
+
+        with patch("time.perf_counter", return_value=0.5):
+            throughput_logger.on_train_epoch_start(state, MagicMock(spec=TrainUnit))
+        self.assertEqual(throughput_logger._epoch_start_times[ActivePhase.TRAIN], 0.5)
+
+        throughput_logger._steps_in_epoch[ActivePhase.TRAIN] = (
+            2  # to assume there were two steps in the epoch
+        )
+        logger.log.reset_mock()
+        with patch("time.perf_counter", return_value=0.6):
+            throughput_logger._log_for_epoch(state, epoch_logging_for=15)
+
+        logger.log.assert_has_calls(
+            [
+                call(
+                    "Train: Batches per second (epoch granularity)",
+                    (1 * 2) / (0.6 - 0.5),
+                    15,
+                ),
+                call(
+                    "Train: Queries per second (epoch granularity)",
+                    (8 * 2) / (0.6 - 0.5),
+                    15,
+                ),
+            ]
+        )
+
+    def test_epoch_logging_time(self) -> None:
+        logger = MagicMock(spec=MetricLogger)
+        throughput_logger = ThroughputLogger(logger, {"Queries": 4})
+        with patch("time.perf_counter", side_effect=[0.1, 0.5, 0.8, 1.5]):
+            train(
+                DummyTrainUnit(input_dim=2),
+                generate_random_dataloader(num_samples=16, input_dim=2, batch_size=4),
+                max_epochs=2,
+                max_steps_per_epoch=2,
+                callbacks=[throughput_logger],
+            )
+
+        logger.log.assert_has_calls(
+            [
+                call(
+                    "Train: Queries per second (epoch granularity)",
+                    (4 * 2) / (0.5 - 0.1),
+                    1,
+                ),
+                call(
+                    "Train: Queries per second (epoch granularity)",
+                    (4 * 2) / (1.5 - 0.8),
+                    2,
+                ),
+            ],
+            any_order=True,
+        )
 
     def test_input_validation(self) -> None:
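For context on the values asserted in test_epoch_logging_time above: the expected numbers follow directly from the epoch formula (items per batch times steps in the epoch, divided by the epoch's elapsed time), assuming the four patched perf_counter readings are consumed as epoch-start/epoch-end pairs. A minimal sketch reproducing that arithmetic; the variable names are illustrative, not part of the test:

# Illustrative sketch (not part of the commit): recompute the values asserted by
# test_epoch_logging_time, assuming the patched perf_counter readings
# [0.1, 0.5, 0.8, 1.5] are consumed as (epoch-1 start, epoch-1 end, epoch-2 start, epoch-2 end).
perf_counter_readings = [0.1, 0.5, 0.8, 1.5]
queries_per_batch = 4  # from ThroughputLogger(logger, {"Queries": 4})
steps_per_epoch = 2    # from max_steps_per_epoch=2

starts = perf_counter_readings[0::2]
ends = perf_counter_readings[1::2]
for epoch, (start, end) in enumerate(zip(starts, ends), start=1):
    throughput = (queries_per_batch * steps_per_epoch) / (end - start)
    print(f"epoch {epoch}: {throughput:.2f} Queries per second (epoch granularity)")

# epoch 1: (4 * 2) / (0.5 - 0.1) = 20.00
# epoch 2: (4 * 2) / (1.5 - 0.8) ≈ 11.43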

torchtnt/framework/callbacks/throughput_logger.py (+56 -2)
@@ -7,7 +7,9 @@
 # pyre-strict
 
 
-from typing import Mapping
+import time
+from collections import defaultdict
+from typing import Dict, Mapping
 
 from pyre_extensions import none_throws
 
@@ -32,9 +34,10 @@
 class ThroughputLogger(Callback):
     """
     A callback which logs the train/eval/predict/fit throughput. For instance, it can be used to log QPS and number of batches processed per second.
-    The callback logs the throughput on a step basis.
+    The callback logs the throughput on a step basis and on an epoch basis.
     We measure the throughput by dividing the number of batches processed (times the number of items in batch) by the time it took to process the batch:
     On a step granularity, we do this by leveraging the already collected timers for the iteration time and data wait time.
+    On an epoch granularity, we measure the time between on_train_epoch_start and on_train_epoch_end on this callback to calculate the throughput.
 
     Args:
         logger: A a subclass of :class:`torchtnt.utils.loggers.logger.MetricLogger`.
@@ -46,6 +49,7 @@ class ThroughputLogger(Callback):
 
     Note:
         The values reported are only for rank 0.
+        For more accurate measurement of epoch throughput, it is recommended to place this callback at the end of the callback list.
     """
 
     def __init__(
@@ -73,12 +77,15 @@ def __init__(
             )
 
         self._log_every_n_steps = log_every_n_steps
+        self._epoch_start_times: Dict[ActivePhase, float] = {}
+        self._steps_in_epoch: Dict[ActivePhase, int] = defaultdict(int)
 
     def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
         self._maybe_log_for_step(
             state,
             unit.train_progress.num_steps_completed - 1,
         )
+        self._steps_in_epoch[ActivePhase.TRAIN] += 1
 
     def on_train_end(self, state: State, unit: TTrainUnit) -> None:
         self._maybe_log_for_step(
@@ -92,13 +99,18 @@ def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
             state,
             unit.eval_progress.num_steps_completed - 1,
         )
+        self._steps_in_epoch[ActivePhase.EVALUATE] += 1
 
     def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
         self._maybe_log_for_step(
             state,
             unit.eval_progress.num_steps_completed,
             is_step_end_hook=False,
         )
+        self._log_for_epoch(
+            state,
+            epoch_logging_for=unit.eval_progress.num_epochs_completed,
+        )
 
     def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
         self._maybe_log_for_step(
@@ -112,6 +124,25 @@ def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
             unit.predict_progress.num_steps_completed,
             is_step_end_hook=False,
         )
+        self._log_for_epoch(
+            state,
+            epoch_logging_for=unit.predict_progress.num_epochs_completed,
+        )
+
+    def on_train_epoch_start(self, state: State, unit: TTrainUnit) -> None:
+        self._epoch_start_times[ActivePhase.TRAIN] = time.perf_counter()
+
+    def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
+        self._log_for_epoch(
+            state,
+            epoch_logging_for=unit.train_progress.num_epochs_completed,
+        )
+
+    def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
+        self._epoch_start_times[ActivePhase.EVALUATE] = time.perf_counter()
+
+    def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
+        self._epoch_start_times[ActivePhase.PREDICT] = time.perf_counter()
 
     def _maybe_log_for_step(
         self,
@@ -154,3 +185,26 @@ def _maybe_log_for_step(
                 num_items / total_time,
                 step_logging_for,
             )
+
+    def _log_for_epoch(
+        self,
+        state: State,
+        *,
+        epoch_logging_for: int,
+    ) -> None:
+        time_since_epoch_start = (
+            time.perf_counter() - self._epoch_start_times[state.active_phase]
+        )
+
+        steps_in_epoch = self._steps_in_epoch[state.active_phase]
+        if steps_in_epoch <= 0:
+            return
+
+        for item, num_items in self._throughput_per_batch.items():
+            self._logger.log(
+                f"{ACTIVE_PHASE_TO_LABEL_PREFIX[state.active_phase]}: {item} per second (epoch granularity)",
+                (num_items * steps_in_epoch) / time_since_epoch_start,
+                epoch_logging_for,
+            )
+
+        self._steps_in_epoch[state.active_phase] = 0
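A brief usage sketch of the epoch-granularity logging added by this commit. It is not part of the diff: the TensorBoardLogger import and its path argument are assumptions for illustration, and the dummy unit/dataloader helpers come from torchtnt.framework._test_utils as used in the tests above; substitute your own MetricLogger implementation, unit, and dataloader.

# Illustrative sketch only -- not part of this commit.
# Assumes torchtnt's TensorBoardLogger as the MetricLogger implementation and the
# test-utility DummyTrainUnit / generate_random_dataloader helpers shown in the tests.
from torchtnt.framework._test_utils import DummyTrainUnit, generate_random_dataloader
from torchtnt.framework.callbacks.throughput_logger import ThroughputLogger
from torchtnt.framework.train import train
from torchtnt.utils.loggers import TensorBoardLogger  # assumed import path

throughput_logger = ThroughputLogger(
    logger=TensorBoardLogger(path="/tmp/tb"),  # any MetricLogger subclass works
    throughput_per_batch={"Batches": 1, "Queries": 8},  # items processed per batch
    log_every_n_steps=1,
)

# Per the docstring note added in this commit, place the callback at the end of
# the callback list for a more accurate epoch throughput measurement.
train(
    DummyTrainUnit(input_dim=2),
    generate_random_dataloader(num_samples=16, input_dim=2, batch_size=8),
    max_epochs=2,
    max_steps_per_epoch=2,
    callbacks=[throughput_logger],
)

# In addition to the existing step-granularity metrics, each train epoch now logs
# "Train: Batches per second (epoch granularity)" and
# "Train: Queries per second (epoch granularity)", computed as
# (num_items * steps_in_epoch) / (time since on_train_epoch_start).

Note that _log_for_epoch resets the per-phase step counter after logging, so each epoch is measured independently of the previous one.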
