
Commit df086d4

Browse files
prajjwal1 authored and facebook-github-bot committed
fix RecMetrics loading (make trained_batches a buffer) (#3534)
Summary: This diff addresses the following task: T209753398.

Currently `trained_batches` is not stored in the state_dict, requiring us to manually sync this variable upon checkpoint loading. We make this variable a buffer so that it is now captured in the model's state_dict.

Differential Revision: D86697665
1 parent 115aaa8 commit df086d4
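
The core of the change: a plain Python `int` attribute on an `nn.Module` is invisible to `state_dict()`, whereas a registered buffer is saved and restored with the checkpoint. A minimal standalone sketch of that behavior (the `TinyMetricModule` class below is a toy stand-in for illustration, not the actual TorchRec module):

    import torch
    from torch import nn


    class TinyMetricModule(nn.Module):
        """Toy stand-in for RecMetricModule; illustrates buffer vs. plain attribute."""

        def __init__(self) -> None:
            super().__init__()
            # Plain Python attribute: not included in state_dict(), lost on checkpoint load.
            self.plain_counter: int = 0
            # Registered buffer: saved and restored with the module's state_dict.
            self.register_buffer(
                "_trained_batches", torch.tensor([0], dtype=torch.int64), persistent=True
            )

        def update(self) -> None:
            self.plain_counter += 1
            self._trained_batches.add_(1)


    module = TinyMetricModule()
    for _ in range(3):
        module.update()

    print(sorted(module.state_dict().keys()))  # ['_trained_batches'] -- no plain_counter
    # Loading the state_dict into a fresh module restores the buffer value only.
    fresh = TinyMetricModule()
    fresh.load_state_dict(module.state_dict())
    print(fresh._trained_batches.item())  # 3
    print(fresh.plain_counter)            # 0 -- the plain int did not round-trip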

File tree

4 files changed: +24, −3 lines


torchrec/metrics/cpu_offloaded_metric_module.py

Lines changed: 3 additions & 0 deletions

@@ -474,6 +474,9 @@ def sync(self) -> None:
         )
         self.comms_module.load_pre_compute_states(aggregated_states)
 
+        # Sync _trained_batches to comms module
+        self.comms_module._trained_batches.copy_(self._trained_batches)
+
         logger.info("CPUOffloadedRecMetricModule synced.")
 
     @override
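
A note on the `copy_` call above: an in-place copy writes the value into the comms module's existing buffer rather than aliasing the source tensor, so later in-place updates to one module do not silently mutate the other. A hedged sketch of that behavior (bare `nn.Module` instances are used purely for illustration):

    import torch
    from torch import nn

    src = nn.Module()
    src.register_buffer("_trained_batches", torch.tensor([42], dtype=torch.int64))

    dst = nn.Module()
    dst.register_buffer("_trained_batches", torch.tensor([0], dtype=torch.int64))

    # In-place copy: the destination buffer keeps its own storage, only its value
    # changes, and dst.state_dict() immediately reflects the synced count.
    dst._trained_batches.copy_(src._trained_batches)
    assert dst.state_dict()["_trained_batches"].item() == 42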

torchrec/metrics/metric_module.py

Lines changed: 15 additions & 2 deletions

@@ -202,7 +202,11 @@ def __init__(
         self.rec_metrics = rec_metrics if rec_metrics else RecMetricList([])
         self.throughput_metric = throughput_metric
         self.state_metrics = state_metrics if state_metrics else {}
-        self.trained_batches: int = 0
+
+        self.register_buffer(
+            "_trained_batches", torch.tensor([0], dtype=torch.int64), persistent=True
+        )
+
         self.batch_size = batch_size
         self.world_size = world_size
         self.oom_count = 0
@@ -228,6 +232,15 @@ def __init__(
         )
         self.last_compute_time = -1.0
 
+    @property
+    def trained_batches(self) -> int:
+        # .trained_batches should return an int
+        return int(self._trained_batches.item())
+
+    @trained_batches.setter
+    def trained_batches(self, value: int) -> None:
+        self._trained_batches.fill_(int(value))
+
     def _update_rec_metrics(
         self, model_out: Dict[str, torch.Tensor], **kwargs: Any
     ) -> None:
@@ -260,7 +273,7 @@ def update(self, model_out: Dict[str, torch.Tensor], **kwargs: Any) -> None:
         self._update_rec_metrics(model_out, **kwargs)
         if self.throughput_metric:
             self.throughput_metric.update()
-        self.trained_batches += 1
+        self._trained_batches.add_(1)
 
     def _adjust_compute_interval(self) -> None:
         """

torchrec/metrics/tests/test_cpu_offloaded_metric_module.py

Lines changed: 1 addition & 0 deletions

@@ -341,6 +341,7 @@ def test_state_dict_save_load(self) -> None:
                 "rec_metrics.rec_metrics.0._metrics_computations.0.state_3": torch.tensor(
                     [6.0]
                 ),
+                "_trained_batches": torch.tensor([0], dtype=torch.int64),
             },
         )
torchrec/metrics/tests/test_metric_module.py

Lines changed: 5 additions & 1 deletion

@@ -248,7 +248,11 @@ def _run_trainer_checkpointing(rank: int, world_size: int, backend: str) -> None
     state_dict = metric_module.state_dict()
     keys = list(state_dict.keys())
     for k in state_dict.keys():
-        state_dict[k] = torch.tensor(value, dtype=torch.long).detach()
+        # _trained_batches is now a 1-D tensor, not a scalar
+        if k == "_trained_batches":
+            state_dict[k] = torch.tensor([value], dtype=torch.long).detach()
+        else:
+            state_dict[k] = torch.tensor(value, dtype=torch.long).detach()
     logging.info(f"Metrics state keys = {keys}")
     metric_module.load_state_dict(state_dict)
     tc = unittest.TestCase()
