Skip to content

Commit 1b1853d

Browse files
aliafzalfacebook-github-bot
authored andcommitted
Log sparse hbm estimate along with max hbm (#3275)
Summary: Pull Request resolved: #3275 as title Reviewed By: micrain Differential Revision: D80079739 fbshipit-source-id: b7d7eb187201fc924e38008e5a4b39d021eaf36a
1 parent 08d1685 commit 1b1853d

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

torchrec/distributed/planner/stats.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def log(
232232
stats[rank]["input_sizes"] += input_sizes[i]
233233
stats[rank]["output_sizes"] += output_sizes[i]
234234

235-
used_hbm, used_ddr, perf = _compute_mem_usage_and_perf(
235+
sparse_hbm, used_hbm, used_ddr, perf = _compute_mem_usage_and_perf(
236236
topology=topology,
237237
best_plan=best_plan,
238238
dense_storage=dense_storage,
@@ -322,7 +322,7 @@ def log(
322322
)
323323

324324
# Max perf and HBM to help root cause imbalance
325-
self._log_max_perf_and_max_hbm(perf, used_hbm, best_plan)
325+
self._log_max_perf_and_max_hbm(perf, sparse_hbm, used_hbm, best_plan)
326326
self._log_storage_reservation_stats(
327327
storage_reservation,
328328
topology,
@@ -449,7 +449,11 @@ def _log_plan_imbalance_stats(
449449
)
450450

451451
def _log_max_perf_and_max_hbm(
452-
self, perfs: List[Perf], used_hbm: List[int], best_plan: List[ShardingOption]
452+
self,
453+
perfs: List[Perf],
454+
sparse_hbm: List[int],
455+
used_hbm: List[int],
456+
best_plan: List[ShardingOption],
453457
) -> None:
454458
total_perfs = [perf.total for perf in perfs]
455459

@@ -506,6 +510,12 @@ def _log_max_perf_and_max_hbm(
506510
self._stats_table.append(
507511
f"# {'Estimated Sharding Distribution' : <{self._width-2}}#"
508512
)
513+
self._stats_table.append(
514+
f"# {'Sparse only Max HBM: '+_generate_rank_hbm_stats(sparse_hbm, max) : <{self._width-3}}#"
515+
)
516+
self._stats_table.append(
517+
f"# {'Sparse only Min HBM: '+_generate_rank_hbm_stats(sparse_hbm, min) : <{self._width-3}}#"
518+
)
509519
self._stats_table.append(
510520
f"# {'Max HBM: '+_generate_rank_hbm_stats(used_hbm, max) : <{self._width-3}}#"
511521
)
@@ -996,8 +1006,8 @@ def _compute_mem_usage_and_perf(
9961006
best_plan: List[ShardingOption],
9971007
dense_storage: Storage,
9981008
kjt_storage: Storage,
999-
) -> Tuple[List[int], List[int], List[Perf]]:
1000-
used_hbm = [0] * topology.world_size
1009+
) -> Tuple[List[int], List[int], List[int], List[Perf]]:
1010+
sparse_hbm = [0] * topology.world_size
10011011
used_ddr = [0] * topology.world_size
10021012
perf = [
10031013
Perf(fwd_compute=0, fwd_comms=0, bwd_compute=0, bwd_comms=0)
@@ -1007,13 +1017,12 @@ def _compute_mem_usage_and_perf(
10071017
for shard in sharding_option.shards:
10081018
shard_storage = cast(Storage, shard.storage)
10091019
rank = cast(int, shard.rank)
1010-
used_hbm[rank] += shard_storage.hbm
1020+
sparse_hbm[rank] += shard_storage.hbm
10111021
used_ddr[rank] += shard_storage.ddr
10121022
perf[rank] += cast(Perf, shard.perf)
1013-
1014-
used_hbm = [hbm + dense_storage.hbm + kjt_storage.hbm for hbm in used_hbm]
1023+
used_hbm = [hbm + dense_storage.hbm + kjt_storage.hbm for hbm in sparse_hbm]
10151024
used_ddr = [ddr + dense_storage.ddr + kjt_storage.ddr for ddr in used_ddr]
1016-
return used_hbm, used_ddr, perf
1025+
return sparse_hbm, used_hbm, used_ddr, perf
10171026

10181027

10191028
def _format_storage_breakdown(storage: Storage) -> str:

0 commit comments

Comments
 (0)