@@ -238,7 +238,10 @@ def maybe_enable_amp(
 
 
 def init_distributed(
-    comm_config: CommConfig, enable_cpu_backend: bool = False, base_folder: str = ""
+    comm_config: CommConfig,
+    enable_cpu_backend: bool = False,
+    base_folder: str = "",
+    replica_id: int | None = None,
 ):
     def _warn_overwrite_env(env, val):
         if env in os.environ:
@@ -279,9 +282,17 @@ def _get_distributed_backend(enable_cpu_backend):
         os.makedirs(dump_dir, exist_ok=True)
         _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/{prefix}")
 
+    local_rank = os.environ.get("RANK")
+    world_size = os.environ.get("WORLD_SIZE")
+
+    global_rank = None
+    if local_rank is not None and replica_id is not None and world_size is not None:
+        global_rank = int(local_rank) + int(replica_id) * int(world_size)
+
     torch.distributed.init_process_group(
         backend=_get_distributed_backend(enable_cpu_backend),
         timeout=timedelta(seconds=comm_config.init_timeout_seconds),
+        rank=global_rank,
     )
 
 
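As a rough sketch of what the new `replica_id` offset does (assumptions, not from the diff: each replica is launched as its own job with `RANK` in `[0, WORLD_SIZE)`, replicas are numbered from 0, and the helper name below is hypothetical):

```python
# Hypothetical helper illustrating the rank offset added in this change.
# Replica 0 keeps ranks 0..world_size-1, replica 1 gets the next block, etc.
def compute_global_rank(rank: int, world_size: int, replica_id: int) -> int:
    return rank + replica_id * world_size

# Example: 8 processes per replica, second replica (replica_id=1), RANK=3
assert compute_global_rank(3, 8, 1) == 11
```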
@@ -432,9 +443,7 @@ def _clip_grad_norm_with_ep(
     if math.isinf(norm_type):
         total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm)
     else:
-        total_norm = (
-            ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
-        )
+        total_norm = ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
     total_norm **= 1.0 / norm_type
 
     if pp_mesh is not None:
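For reference, a small self-contained check (the toy tensors and `norm_type=2.0` are assumptions, not part of the diff) that combining the EP and non-EP partial norms as `(a**p + b**p) ** (1/p)`, which is what the simplified expression above computes, matches the norm taken over the concatenated gradients:

```python
import torch

ep_grads = torch.tensor([3.0, 4.0])
non_ep_grads = torch.tensor([5.0, 12.0])
norm_type = 2.0

# Partial norms computed separately, as in _clip_grad_norm_with_ep.
ep_norm = torch.linalg.vector_norm(ep_grads, norm_type)
non_ep_norm = torch.linalg.vector_norm(non_ep_grads, norm_type)

# Combine: (a**p + b**p) ** (1/p)
total = (ep_norm**norm_type + non_ep_norm**norm_type) ** (1.0 / norm_type)

# Same result as the norm over all gradients at once.
expected = torch.linalg.vector_norm(torch.cat([ep_grads, non_ep_grads]), norm_type)
assert torch.allclose(total, expected)
```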