# pyre-strict

import unittest
-from unittest.mock import ANY, call, MagicMock
+from unittest.mock import ANY, call, MagicMock, patch

import torch
from pyre_extensions import none_throws
from torchtnt.framework._test_utils import (
    DummyAutoUnit,
    DummyPredictUnit,
+    DummyTrainUnit,
    generate_random_dataloader,
)

from torchtnt.framework.callbacks.throughput_logger import ThroughputLogger
from torchtnt.framework.predict import predict

-from torchtnt.framework.state import EntryPoint, PhaseState, State
-from torchtnt.framework.train import _train_impl
+from torchtnt.framework.state import ActivePhase, EntryPoint, PhaseState, State
+from torchtnt.framework.train import _train_impl, train
+from torchtnt.framework.unit import TrainUnit

from torchtnt.utils.loggers.logger import MetricLogger

@@ -121,21 +123,18 @@ def test_with_comparing_time(self) -> None:
                evaluate_every_n_epochs=2,
            ),
        )
+        throughput_logger = ThroughputLogger(
+            logger=logger,
+            throughput_per_batch={"Batches": 1, "Queries": 8},
+            log_every_n_steps=1,
+        )

        # we want to be able to compare the logging value to the state, so we need to create state manually and
        # call _train_impl. This would have been similar to calling fit() and getting the state as a ret value
        _train_impl(
            state,
            DummyAutoUnit(module=torch.nn.Linear(2, 2)),
-            CallbackHandler(
-                [
-                    ThroughputLogger(
-                        logger=logger,
-                        throughput_per_batch={"Batches": 1, "Queries": 8},
-                        log_every_n_steps=1,
-                    )
-                ],
-            ),
+            CallbackHandler([throughput_logger]),
        )

        train_iteration_times = none_throws(
@@ -163,8 +162,8 @@ def test_with_comparing_time(self) -> None:
            eval_iteration_times[i] + eval_twfb_times[i] for i in range(2)
        ]
        self.assertEqual(
-            logger.log.call_count, 12
-        )  # 8 train (2epochs x 2steps x 2items), 4 eval (1x2x2)
+            logger.log.call_count, 18
+        )  # steps: 8 train (2epochs x 2steps x 2items), 4 eval (1x2x2). epochs: 4 train (2epochs x 2items), 2 eval (1x2)
        train_batches_step_logs = [
            call(
                "Train: Batches per second (step granularity)",
@@ -197,11 +196,36 @@ def test_with_comparing_time(self) -> None:
            )
            for i in range(2)
        ]
+        # for epoch, we test the logged value separately
+        train_batches_epoch_logs = [
+            call("Train: Batches per second (epoch granularity)", ANY, i)
+            for i in range(1, 3)
+        ]
+        train_queries_epoch_logs = [
+            call("Train: Queries per second (epoch granularity)", ANY, i)
+            for i in range(1, 3)
+        ]
+        eval_epoch_logs = [
+            call(
+                "Eval: Queries per second (epoch granularity)",
+                ANY,
+                1,
+            ),
+            call(
+                "Eval: Batches per second (epoch granularity)",
+                ANY,
+                1,
+            ),
+        ]
+
        logger.log.assert_has_calls(
            train_batches_step_logs
            + train_queries_step_logs
            + eval_batches_step_logs
-            + eval_queries_step_logs,
+            + eval_queries_step_logs
+            + train_batches_epoch_logs
+            + train_queries_epoch_logs
+            + eval_epoch_logs,
            any_order=True,
        )
@@ -227,6 +251,79 @@ def test_with_predict(self) -> None:
                    1,
                )
            ],
+            [
+                call(
+                    "Predict: Batches per second (epoch granularity)",
+                    ANY,
+                    1,
+                )
+            ],
+        )
+
+    def test_log_for_epoch(self) -> None:
+        logger = MagicMock(spec=MetricLogger)
+        unit = DummyTrainUnit(input_dim=2)
+        throughput_logger = ThroughputLogger(logger, {"Batches": 1, "Queries": 8})
+        state = State(entry_point=EntryPoint.TRAIN)
+
+        self.assertIsNone(throughput_logger._epoch_start_times.get(ActivePhase.TRAIN))
+        self.assertEqual(throughput_logger._steps_in_epoch[ActivePhase.TRAIN], 0)
+        with patch.object(throughput_logger, "_maybe_log_for_step"):
+            throughput_logger.on_train_step_end(state, unit)
+        self.assertEqual(throughput_logger._steps_in_epoch[ActivePhase.TRAIN], 1)
+
+        with patch("time.perf_counter", return_value=0.5):
+            throughput_logger.on_train_epoch_start(state, MagicMock(spec=TrainUnit))
+        self.assertEqual(throughput_logger._epoch_start_times[ActivePhase.TRAIN], 0.5)
+
+        throughput_logger._steps_in_epoch[ActivePhase.TRAIN] = (
+            2  # to assume there were two steps in the epoch
+        )
+        logger.log.reset_mock()
+        with patch("time.perf_counter", return_value=0.6):
+            throughput_logger._log_for_epoch(state, epoch_logging_for=15)
+
+        logger.log.assert_has_calls(
+            [
+                call(
+                    "Train: Batches per second (epoch granularity)",
+                    (1 * 2) / (0.6 - 0.5),
+                    15,
+                ),
+                call(
+                    "Train: Queries per second (epoch granularity)",
+                    (8 * 2) / (0.6 - 0.5),
+                    15,
+                ),
+            ]
+        )
+
+    def test_epoch_logging_time(self) -> None:
+        logger = MagicMock(spec=MetricLogger)
+        throughput_logger = ThroughputLogger(logger, {"Queries": 4})
+        with patch("time.perf_counter", side_effect=[0.1, 0.5, 0.8, 1.5]):
+            train(
+                DummyTrainUnit(input_dim=2),
+                generate_random_dataloader(num_samples=16, input_dim=2, batch_size=4),
+                max_epochs=2,
+                max_steps_per_epoch=2,
+                callbacks=[throughput_logger],
+            )
+
+        logger.log.assert_has_calls(
+            [
+                call(
+                    "Train: Queries per second (epoch granularity)",
+                    (4 * 2) / (0.5 - 0.1),
+                    1,
+                ),
+                call(
+                    "Train: Queries per second (epoch granularity)",
+                    (4 * 2) / (1.5 - 0.8),
+                    2,
+                ),
+            ],
+            any_order=True,
+        )

    def test_input_validation(self) -> None: