@@ -232,7 +232,7 @@ def log(
232
232
stats [rank ]["input_sizes" ] += input_sizes [i ]
233
233
stats [rank ]["output_sizes" ] += output_sizes [i ]
234
234
235
- used_hbm , used_ddr , perf = _compute_mem_usage_and_perf (
235
+ sparse_hbm , used_hbm , used_ddr , perf = _compute_mem_usage_and_perf (
236
236
topology = topology ,
237
237
best_plan = best_plan ,
238
238
dense_storage = dense_storage ,
@@ -322,7 +322,7 @@ def log(
322
322
)
323
323
324
324
# Max perf and HBM to help root cause imbalance
325
- self ._log_max_perf_and_max_hbm (perf , used_hbm , best_plan )
325
+ self ._log_max_perf_and_max_hbm (perf , sparse_hbm , used_hbm , best_plan )
326
326
self ._log_storage_reservation_stats (
327
327
storage_reservation ,
328
328
topology ,
@@ -449,7 +449,11 @@ def _log_plan_imbalance_stats(
449
449
)
450
450
451
451
def _log_max_perf_and_max_hbm (
452
- self , perfs : List [Perf ], used_hbm : List [int ], best_plan : List [ShardingOption ]
452
+ self ,
453
+ perfs : List [Perf ],
454
+ sparse_hbm : List [int ],
455
+ used_hbm : List [int ],
456
+ best_plan : List [ShardingOption ],
453
457
) -> None :
454
458
total_perfs = [perf .total for perf in perfs ]
455
459
@@ -506,6 +510,12 @@ def _log_max_perf_and_max_hbm(
506
510
self ._stats_table .append (
507
511
f"# { 'Estimated Sharding Distribution' : <{self ._width - 2 }} #"
508
512
)
513
+ self ._stats_table .append (
514
+ f"# { 'Sparse only Max HBM: ' + _generate_rank_hbm_stats (sparse_hbm , max ) : <{self ._width - 3 }} #"
515
+ )
516
+ self ._stats_table .append (
517
+ f"# { 'Sparse only Min HBM: ' + _generate_rank_hbm_stats (sparse_hbm , min ) : <{self ._width - 3 }} #"
518
+ )
509
519
self ._stats_table .append (
510
520
f"# { 'Max HBM: ' + _generate_rank_hbm_stats (used_hbm , max ) : <{self ._width - 3 }} #"
511
521
)
@@ -996,8 +1006,8 @@ def _compute_mem_usage_and_perf(
996
1006
best_plan : List [ShardingOption ],
997
1007
dense_storage : Storage ,
998
1008
kjt_storage : Storage ,
999
- ) -> Tuple [List [int ], List [int ], List [Perf ]]:
1000
- used_hbm = [0 ] * topology .world_size
1009
+ ) -> Tuple [List [int ], List [int ], List [int ], List [ Perf ]]:
1010
+ sparse_hbm = [0 ] * topology .world_size
1001
1011
used_ddr = [0 ] * topology .world_size
1002
1012
perf = [
1003
1013
Perf (fwd_compute = 0 , fwd_comms = 0 , bwd_compute = 0 , bwd_comms = 0 )
@@ -1007,13 +1017,12 @@ def _compute_mem_usage_and_perf(
1007
1017
for shard in sharding_option .shards :
1008
1018
shard_storage = cast (Storage , shard .storage )
1009
1019
rank = cast (int , shard .rank )
1010
- used_hbm [rank ] += shard_storage .hbm
1020
+ sparse_hbm [rank ] += shard_storage .hbm
1011
1021
used_ddr [rank ] += shard_storage .ddr
1012
1022
perf [rank ] += cast (Perf , shard .perf )
1013
-
1014
- used_hbm = [hbm + dense_storage .hbm + kjt_storage .hbm for hbm in used_hbm ]
1023
+ used_hbm = [hbm + dense_storage .hbm + kjt_storage .hbm for hbm in sparse_hbm ]
1015
1024
used_ddr = [ddr + dense_storage .ddr + kjt_storage .ddr for ddr in used_ddr ]
1016
- return used_hbm , used_ddr , perf
1025
+ return sparse_hbm , used_hbm , used_ddr , perf
1017
1026
1018
1027
1019
1028
def _format_storage_breakdown (storage : Storage ) -> str :
0 commit comments