diff --git a/nemo_run/core/execution/launcher.py b/nemo_run/core/execution/launcher.py
index cad0ebd3..f24a1d2b 100644
--- a/nemo_run/core/execution/launcher.py
+++ b/nemo_run/core/execution/launcher.py
@@ -47,6 +47,7 @@ def transform(self, cmd: list[str]) -> Optional[Script]: ...
 class Torchrun(Launcher):
     rdzv_backend: str = "c10d"
     rdzv_port: int = 29500
+    rdzv_id: Optional[int] = None
 
 
 @dataclass(kw_only=True)
@@ -56,6 +57,7 @@ class FaultTolerance(Launcher):
     job_results_file: str = ""
     rdzv_backend: str = "c10d"
     rdzv_port: int = 29500
+    rdzv_id: Optional[int] = None
     workload_check_interval: Optional[float] = None
     initial_rank_heartbeat_timeout: Optional[float] = None
     rank_heartbeat_timeout: Optional[float] = None
diff --git a/nemo_run/core/execution/local.py b/nemo_run/core/execution/local.py
index 31f96387..e8954bae 100644
--- a/nemo_run/core/execution/local.py
+++ b/nemo_run/core/execution/local.py
@@ -37,6 +37,7 @@ class LocalExecutor(Executor):
 
     #: Used by components like torchrun to deduce the number of tasks to launch.
     ntasks_per_node: int = 1
+    nodes: int = 1
 
     def assign(
         self,
@@ -50,7 +51,7 @@ def assign(
         self.job_dir = os.path.join(exp_dir, task_dir)
 
     def nnodes(self) -> int:
-        return 1
+        return self.nodes
 
     def nproc_per_node(self) -> int:
         return self.ntasks_per_node
diff --git a/nemo_run/run/torchx_backend/components/ft_launcher.py b/nemo_run/run/torchx_backend/components/ft_launcher.py
index 8ebd7c07..3920041f 100644
--- a/nemo_run/run/torchx_backend/components/ft_launcher.py
+++ b/nemo_run/run/torchx_backend/components/ft_launcher.py
@@ -40,6 +40,7 @@ def ft_launcher(
     max_retries: int = 0,
     rdzv_port: int = 49450,
     rdzv_backend: str = "c10d",
+    rdzv_id: Optional[int] = None,
     mounts: Optional[list[str]] = None,
     debug: bool = False,
     workload_check_interval: Optional[float] = None,
@@ -48,6 +49,8 @@ def ft_launcher(
     rank_termination_signal: Optional[str] = None,
     log_level: Optional[str] = None,
     max_restarts: Optional[int] = None,
+    dgxc: bool = False,
+    use_env: bool = False,
 ) -> specs.AppDef:
     torchrun_component = torchrun.torchrun(
         *script_args,
@@ -63,10 +66,13 @@ def ft_launcher(
         j=j,
         rdzv_backend=rdzv_backend,
         rdzv_port=rdzv_port,
+        rdzv_id=rdzv_id,
         env=env,
         mounts=mounts,
         debug=debug,
         max_retries=max_retries,
+        dgxc=dgxc,
+        use_env=use_env,
     )
 
     ft_args = []
diff --git a/nemo_run/run/torchx_backend/components/torchrun.py b/nemo_run/run/torchx_backend/components/torchrun.py
index 6a5c4d3d..05ed96e0 100644
--- a/nemo_run/run/torchx_backend/components/torchrun.py
+++ b/nemo_run/run/torchx_backend/components/torchrun.py
@@ -57,9 +57,11 @@ def torchrun(
     max_retries: int = 0,
     rdzv_port: int = 49450,
     rdzv_backend: str = "c10d",
+    rdzv_id: Optional[int] = None,
     mounts: Optional[list[str]] = None,
     debug: bool = False,
     dgxc: bool = False,
+    use_env: bool = False,
 ) -> specs.AppDef:
     """
     Distributed data parallel style application (one role, multi-replica).
@@ -113,17 +115,27 @@ def torchrun(
         nproc_per_node = str(nproc_per_node)
         node_rank = "0"
     else:
-        # for multi-node, rely on the rank0_env environment variable set by
-        # the schedulers (see scheduler implementation for the actual env var this maps to)
-        # some schedulers (e.g. aws batch) make the rank0's ip-addr available on all BUT on rank0
-        # so default to "localhost" if the env var is not set or is empty
-        # rdzv_endpoint bash resolves to something to the effect of
-        # ${TORCHX_RANK0_HOST:=localhost}:29500
-        # use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument)
-        rdzv_endpoint = torchx_dist._noquote(f"$${ExecutorMacros.HEAD_NODE_IP_VAR}:{rdzv_port}")
-        num_nodes = torchx_dist._noquote(f"$${ExecutorMacros.NUM_NODES_VAR}")
+        if use_env and os.getenv("MASTER_ADDR") and os.getenv("MASTER_PORT"):
+            # The caller exported the rendezvous address, so bake it into the command.
+            master_addr = os.environ["MASTER_ADDR"]
+            master_port = os.environ["MASTER_PORT"]
+            rdzv_endpoint = torchx_dist._noquote(master_addr + ":" + master_port)
+            if rdzv_id is None:
+                # Presumably each node builds this command on its own, yet all of them must
+                # pass the same --rdzv-id; derive it from the shared endpoint with a private
+                # RNG instead of reseeding the process-wide `random` module.
+                rdzv_id = random.Random(f"{master_addr}:{master_port}").randint(1, 10000)
+        else:
+            # use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument)
+            rdzv_endpoint = torchx_dist._noquote(f"$${ExecutorMacros.HEAD_NODE_IP_VAR}:{rdzv_port}")
+
+        num_nodes = nnodes_rep
         nproc_per_node = str(nproc_per_node)
-        node_rank = torchx_dist._noquote(f"$${ExecutorMacros.NODE_RANK_VAR}")
+
+        if use_env and os.getenv("NODE_RANK"):
+            node_rank = os.environ["NODE_RANK"]
+        else:
+            node_rank = torchx_dist._noquote(f"$${ExecutorMacros.NODE_RANK_VAR}")
 
     if env is None:
         env = {}
@@ -141,7 +153,7 @@ def torchrun(
         "--rdzv-endpoint",
         rdzv_endpoint,
         "--rdzv-id",
-        f"{random.randint(1, 10000)}",
+        f"{rdzv_id if rdzv_id is not None else random.randint(1, 10000)}",
         "--nnodes",
         num_nodes,
         "--nproc-per-node",
diff --git a/nemo_run/run/torchx_backend/packaging.py b/nemo_run/run/torchx_backend/packaging.py
index 2c0fdade..49857c90 100644
--- a/nemo_run/run/torchx_backend/packaging.py
+++ b/nemo_run/run/torchx_backend/packaging.py
@@ -145,6 +145,7 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool):
             transformed_script, serialize_configs=False
         )
 
+    use_env = isinstance(executor, LocalExecutor)
     if launcher and isinstance(launcher, Torchrun):
         app_def = torchrun.torchrun(
             *args,
@@ -160,11 +161,13 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool):
             j=f"{executor.nnodes()}x{executor.nproc_per_node()}",
             rdzv_backend=launcher.rdzv_backend,
             rdzv_port=launcher.rdzv_port,
+            rdzv_id=launcher.rdzv_id,
             env=env,
             mounts=mounts,
             debug=executor.packager.debug,
             max_retries=executor.retries,
             dgxc=isinstance(executor, DGXCloudExecutor),
+            use_env=use_env,
         )
     elif launcher and isinstance(launcher, FaultTolerance):
         app_def = ft_launcher.ft_launcher(
@@ -181,6 +184,7 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool):
             j=f"{executor.nnodes()}x{executor.nproc_per_node()}",
             rdzv_backend=launcher.rdzv_backend,
             rdzv_port=launcher.rdzv_port,
+            rdzv_id=launcher.rdzv_id,
             env=env,
             mounts=mounts,
             debug=executor.packager.debug,
@@ -191,6 +195,7 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool):
             log_level=launcher.log_level,
             max_retries=executor.retries,
             max_restarts=launcher.max_restarts,
+            use_env=use_env,
         )
     else:
         app_def = specs.AppDef(
diff --git a/test/run/torchx_backend/test_packaging.py b/test/run/torchx_backend/test_packaging.py
index 9e2b4f54..9343c637 100644
--- a/test/run/torchx_backend/test_packaging.py
+++ b/test/run/torchx_backend/test_packaging.py
@@ -207,7 +207,7 @@ def test_package_torchrun(mock_executor):
         "--rdzv-id",
         "1",
         "--nnodes",
-        "$$${num_nodes_var}",
+        "2",
         "--nproc-per-node",
         "1",
         "--node-rank",