
Commit 1c166f3

create default process group with global rank
1 parent 4910cd6 commit 1c166f3

2 files changed: +18 -4 lines changed

torchtitan/distributed/utils.py

Lines changed: 13 additions & 4 deletions
@@ -238,7 +238,10 @@ def maybe_enable_amp(
 
 
 def init_distributed(
-    comm_config: CommConfig, enable_cpu_backend: bool = False, base_folder: str = ""
+    comm_config: CommConfig,
+    enable_cpu_backend: bool = False,
+    base_folder: str = "",
+    replica_id: int | None = None,
 ):
     def _warn_overwrite_env(env, val):
         if env in os.environ:
@@ -279,9 +282,17 @@ def _get_distributed_backend(enable_cpu_backend):
         os.makedirs(dump_dir, exist_ok=True)
         _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/{prefix}")
 
+    local_rank = os.environ.get("RANK")
+    world_size = os.environ.get("WORLD_SIZE")
+
+    global_rank = None
+    if local_rank is not None and replica_id is not None and world_size is not None:
+        global_rank = int(local_rank) + int(replica_id) * int(world_size)
+
     torch.distributed.init_process_group(
         backend=_get_distributed_backend(enable_cpu_backend),
         timeout=timedelta(seconds=comm_config.init_timeout_seconds),
+        rank=global_rank,
     )
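
The new block derives a global rank for the default process group from the per-replica environment. A minimal sketch of the arithmetic, using made-up values (replica group 1, group size 8, local rank 3), shows the offset it produces:

```python
import os

# Hypothetical fault-tolerant setup: this process is rank 3 inside
# replica group 1, and each replica group spans 8 ranks.
os.environ["RANK"] = "3"
os.environ["WORLD_SIZE"] = "8"
replica_id = 1

local_rank = os.environ.get("RANK")
world_size = os.environ.get("WORLD_SIZE")

global_rank = None
if local_rank is not None and replica_id is not None and world_size is not None:
    # Same formula as the diff: skip replica_id whole groups of
    # world_size ranks, then add the rank within this group.
    global_rank = int(local_rank) + int(replica_id) * int(world_size)

print(global_rank)  # 3 + 1 * 8 = 11
```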
@@ -432,9 +443,7 @@ def _clip_grad_norm_with_ep(
     if math.isinf(norm_type):
         total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm)
     else:
-        total_norm = (
-            ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
-        )
+        total_norm = ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
         total_norm **= 1.0 / norm_type
 
     if pp_mesh is not None:
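
The reflowed else-branch in `_clip_grad_norm_with_ep` is the standard p-norm combination: raising the expert-parallel and non-expert partial norms to the p-th power, summing, and taking the p-th root gives the same value as a single norm over all gradients. A quick sanity check of that identity, with arbitrary tensor values:

```python
import torch

norm_type = 2.0
ep_grads = torch.tensor([3.0, 4.0])
non_ep_grads = torch.tensor([1.0, 2.0, 2.0])

ep_norm = torch.linalg.vector_norm(ep_grads, norm_type)
non_ep_norm = torch.linalg.vector_norm(non_ep_grads, norm_type)

# Combine the partial norms exactly as the diff does.
total_norm = ep_norm**norm_type + non_ep_norm**norm_type
total_norm **= 1.0 / norm_type

# Equivalent to taking one norm over the concatenated gradients.
reference = torch.linalg.vector_norm(torch.cat([ep_grads, non_ep_grads]), norm_type)
assert torch.allclose(total_norm, reference)
```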

torchtitan/train.py

Lines changed: 5 additions & 0 deletions
@@ -98,6 +98,11 @@ def __init__(self, job_config: JobConfig):
             job_config.comm,
             enable_cpu_backend=job_config.training.enable_cpu_offload,
             base_folder=job_config.job.dump_folder,
+            replica_id=(
+                job_config.fault_tolerance.replica_id
+                if job_config.fault_tolerance.enable
+                else None
+            ),
         )
         world_size = int(os.environ["WORLD_SIZE"])
         parallelism_config = job_config.parallelism
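
Read together with the utils.py change, the call site amounts to a small helper like the sketch below; the import path and helper name are assumptions for illustration, and only the keyword arguments come from the diff:

```python
from torchtitan.distributed import utils as dist_utils


def init_from_config(job_config) -> None:
    # Hypothetical wrapper around the call in Trainer.__init__ above.
    dist_utils.init_distributed(
        job_config.comm,
        enable_cpu_backend=job_config.training.enable_cpu_offload,
        base_folder=job_config.job.dump_folder,
        # Only pass a replica_id when fault tolerance is enabled,
        # mirroring the train.py diff above.
        replica_id=(
            job_config.fault_tolerance.replica_id
            if job_config.fault_tolerance.enable
            else None
        ),
    )
```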
