From a6274e0c6b95de1b0fea29a8537d6bf2055d524d Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Thu, 30 Jan 2025 03:36:50 -0800 Subject: [PATCH 1/3] Support torchrun multi node on local executor Signed-off-by: Hemil Desai --- nemo_run/core/execution/local.py | 3 ++- .../torchx_backend/components/ft_launcher.py | 2 ++ .../run/torchx_backend/components/torchrun.py | 24 +++++++++++-------- nemo_run/run/torchx_backend/packaging.py | 2 ++ test/run/torchx_backend/test_packaging.py | 2 +- 5 files changed, 21 insertions(+), 12 deletions(-) diff --git a/nemo_run/core/execution/local.py b/nemo_run/core/execution/local.py index 31f96387..e8954bae 100644 --- a/nemo_run/core/execution/local.py +++ b/nemo_run/core/execution/local.py @@ -37,6 +37,7 @@ class LocalExecutor(Executor): #: Used by components like torchrun to deduce the number of tasks to launch. ntasks_per_node: int = 1 + nodes: int = 1 def assign( self, @@ -50,7 +51,7 @@ def assign( self.job_dir = os.path.join(exp_dir, task_dir) def nnodes(self) -> int: - return 1 + return self.nodes def nproc_per_node(self) -> int: return self.ntasks_per_node diff --git a/nemo_run/run/torchx_backend/components/ft_launcher.py b/nemo_run/run/torchx_backend/components/ft_launcher.py index 8ebd7c07..2f880f88 100644 --- a/nemo_run/run/torchx_backend/components/ft_launcher.py +++ b/nemo_run/run/torchx_backend/components/ft_launcher.py @@ -48,6 +48,7 @@ def ft_launcher( rank_termination_signal: Optional[str] = None, log_level: Optional[str] = None, max_restarts: Optional[int] = None, + use_env: bool = False, ) -> specs.AppDef: torchrun_component = torchrun.torchrun( *script_args, @@ -67,6 +68,7 @@ def ft_launcher( mounts=mounts, debug=debug, max_retries=max_retries, + use_env=use_env, ) ft_args = [] diff --git a/nemo_run/run/torchx_backend/components/torchrun.py b/nemo_run/run/torchx_backend/components/torchrun.py index 6a5c4d3d..ccbe83cc 100644 --- a/nemo_run/run/torchx_backend/components/torchrun.py +++ b/nemo_run/run/torchx_backend/components/torchrun.py @@ -60,6 +60,7 @@ def torchrun( mounts: Optional[list[str]] = None, debug: bool = False, dgxc: bool = False, + use_env: bool = False, ) -> specs.AppDef: """ Distributed data parallel style application (one role, multi-replica). @@ -113,17 +114,20 @@ def torchrun( nproc_per_node = str(nproc_per_node) node_rank = "0" else: - # for multi-node, rely on the rank0_env environment variable set by - # the schedulers (see scheduler implementation for the actual env var this maps to) - # some schedulers (e.g. aws batch) make the rank0's ip-addr available on all BUT on rank0 - # so default to "localhost" if the env var is not set or is empty - # rdzv_endpoint bash resolves to something to the effect of - # ${TORCHX_RANK0_HOST:=localhost}:29500 - # use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument) - rdzv_endpoint = torchx_dist._noquote(f"$${ExecutorMacros.HEAD_NODE_IP_VAR}:{rdzv_port}") - num_nodes = torchx_dist._noquote(f"$${ExecutorMacros.NUM_NODES_VAR}") + if use_env and os.getenv("MASTER_ADDR") and os.getenv("MASTER_PORT"): + master_addr = os.environ["MASTER_ADDR"] + master_port = os.environ["MASTER_PORT"] + rdzv_endpoint = torchx_dist._noquote(master_addr + ":" + master_port) + else: + rdzv_endpoint = torchx_dist._noquote(f"$${ExecutorMacros.HEAD_NODE_IP_VAR}:{rdzv_port}") + + num_nodes = nnodes_rep nproc_per_node = str(nproc_per_node) - node_rank = torchx_dist._noquote(f"$${ExecutorMacros.NODE_RANK_VAR}") + + if use_env and os.getenv("NODE_RANK"): + node_rank = os.environ["NODE_RANK"] + else: + node_rank = torchx_dist._noquote(f"$${ExecutorMacros.NODE_RANK_VAR}") if env is None: env = {} diff --git a/nemo_run/run/torchx_backend/packaging.py b/nemo_run/run/torchx_backend/packaging.py index 2c0fdade..168a877f 100644 --- a/nemo_run/run/torchx_backend/packaging.py +++ b/nemo_run/run/torchx_backend/packaging.py @@ -165,6 +165,7 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool): debug=executor.packager.debug, max_retries=executor.retries, dgxc=isinstance(executor, DGXCloudExecutor), + use_env=use_env, ) elif launcher and isinstance(launcher, FaultTolerance): app_def = ft_launcher.ft_launcher( @@ -191,6 +192,7 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool): log_level=launcher.log_level, max_retries=executor.retries, max_restarts=launcher.max_restarts, + use_env=use_env, ) else: app_def = specs.AppDef( diff --git a/test/run/torchx_backend/test_packaging.py b/test/run/torchx_backend/test_packaging.py index 9e2b4f54..9343c637 100644 --- a/test/run/torchx_backend/test_packaging.py +++ b/test/run/torchx_backend/test_packaging.py @@ -207,7 +207,7 @@ def test_package_torchrun(mock_executor): "--rdzv-id", "1", "--nnodes", - "$$${num_nodes_var}", + "2", "--nproc-per-node", "1", "--node-rank", From 75dbc7e1d5494a6d0967d8eae2b25d50aa7a98c4 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Fri, 11 Apr 2025 11:01:20 -0700 Subject: [PATCH 2/3] fix Signed-off-by: Hemil Desai --- nemo_run/core/execution/launcher.py | 2 ++ nemo_run/run/torchx_backend/components/ft_launcher.py | 4 ++++ nemo_run/run/torchx_backend/components/torchrun.py | 6 +++++- nemo_run/run/torchx_backend/packaging.py | 3 +++ 4 files changed, 14 insertions(+), 1 deletion(-) diff --git a/nemo_run/core/execution/launcher.py b/nemo_run/core/execution/launcher.py index cad0ebd3..f24a1d2b 100644 --- a/nemo_run/core/execution/launcher.py +++ b/nemo_run/core/execution/launcher.py @@ -47,6 +47,7 @@ def transform(self, cmd: list[str]) -> Optional[Script]: ... class Torchrun(Launcher): rdzv_backend: str = "c10d" rdzv_port: int = 29500 + rdzv_id: Optional[int] = None @dataclass(kw_only=True) @@ -56,6 +57,7 @@ class FaultTolerance(Launcher): job_results_file: str = "" rdzv_backend: str = "c10d" rdzv_port: int = 29500 + rdzv_id: Optional[int] = None workload_check_interval: Optional[float] = None initial_rank_heartbeat_timeout: Optional[float] = None rank_heartbeat_timeout: Optional[float] = None diff --git a/nemo_run/run/torchx_backend/components/ft_launcher.py b/nemo_run/run/torchx_backend/components/ft_launcher.py index 2f880f88..3920041f 100644 --- a/nemo_run/run/torchx_backend/components/ft_launcher.py +++ b/nemo_run/run/torchx_backend/components/ft_launcher.py @@ -40,6 +40,7 @@ def ft_launcher( max_retries: int = 0, rdzv_port: int = 49450, rdzv_backend: str = "c10d", + rdzv_id: Optional[int] = None, mounts: Optional[list[str]] = None, debug: bool = False, workload_check_interval: Optional[float] = None, @@ -48,6 +49,7 @@ def ft_launcher( rank_termination_signal: Optional[str] = None, log_level: Optional[str] = None, max_restarts: Optional[int] = None, + dgxc: bool = False, use_env: bool = False, ) -> specs.AppDef: torchrun_component = torchrun.torchrun( @@ -64,10 +66,12 @@ def ft_launcher( j=j, rdzv_backend=rdzv_backend, rdzv_port=rdzv_port, + rdzv_id=rdzv_id, env=env, mounts=mounts, debug=debug, max_retries=max_retries, + dgxc=dgxc, use_env=use_env, ) diff --git a/nemo_run/run/torchx_backend/components/torchrun.py b/nemo_run/run/torchx_backend/components/torchrun.py index ccbe83cc..08d1eab2 100644 --- a/nemo_run/run/torchx_backend/components/torchrun.py +++ b/nemo_run/run/torchx_backend/components/torchrun.py @@ -39,6 +39,8 @@ * ``TORCH_SHOW_CPP_STACKTRACES``: Read more `here `__. """ +DEFAULT_SEED = 11111 + # Adapted from https://github.com/pytorch/torchx/blob/main/torchx/components/dist.py def torchrun( @@ -57,6 +59,7 @@ def torchrun( max_retries: int = 0, rdzv_port: int = 49450, rdzv_backend: str = "c10d", + rdzv_id: Optional[int] = None, mounts: Optional[list[str]] = None, debug: bool = False, dgxc: bool = False, @@ -118,6 +121,7 @@ def torchrun( master_addr = os.environ["MASTER_ADDR"] master_port = os.environ["MASTER_PORT"] rdzv_endpoint = torchx_dist._noquote(master_addr + ":" + master_port) + random.seed(DEFAULT_SEED) else: rdzv_endpoint = torchx_dist._noquote(f"$${ExecutorMacros.HEAD_NODE_IP_VAR}:{rdzv_port}") @@ -145,7 +149,7 @@ def torchrun( "--rdzv-endpoint", rdzv_endpoint, "--rdzv-id", - f"{random.randint(1, 10000)}", + f"{rdzv_id or random.randint(1, 10000)}", "--nnodes", num_nodes, "--nproc-per-node", diff --git a/nemo_run/run/torchx_backend/packaging.py b/nemo_run/run/torchx_backend/packaging.py index 168a877f..49857c90 100644 --- a/nemo_run/run/torchx_backend/packaging.py +++ b/nemo_run/run/torchx_backend/packaging.py @@ -145,6 +145,7 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool): transformed_script, serialize_configs=False ) + use_env = isinstance(executor, LocalExecutor) if launcher and isinstance(launcher, Torchrun): app_def = torchrun.torchrun( *args, @@ -160,6 +161,7 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool): j=f"{executor.nnodes()}x{executor.nproc_per_node()}", rdzv_backend=launcher.rdzv_backend, rdzv_port=launcher.rdzv_port, + rdzv_id=launcher.rdzv_id, env=env, mounts=mounts, debug=executor.packager.debug, @@ -182,6 +184,7 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool): j=f"{executor.nnodes()}x{executor.nproc_per_node()}", rdzv_backend=launcher.rdzv_backend, rdzv_port=launcher.rdzv_port, + rdzv_id=launcher.rdzv_id, env=env, mounts=mounts, debug=executor.packager.debug, From 32bd62354a3d5f7f14aab381896631b9e3bc3204 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Fri, 11 Apr 2025 11:15:40 -0700 Subject: [PATCH 3/3] fix Signed-off-by: Hemil Desai --- nemo_run/run/torchx_backend/components/torchrun.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/nemo_run/run/torchx_backend/components/torchrun.py b/nemo_run/run/torchx_backend/components/torchrun.py index 08d1eab2..05ed96e0 100644 --- a/nemo_run/run/torchx_backend/components/torchrun.py +++ b/nemo_run/run/torchx_backend/components/torchrun.py @@ -39,8 +39,6 @@ * ``TORCH_SHOW_CPP_STACKTRACES``: Read more `here `__. """ -DEFAULT_SEED = 11111 - # Adapted from https://github.com/pytorch/torchx/blob/main/torchx/components/dist.py def torchrun( @@ -121,7 +119,7 @@ def torchrun( master_addr = os.environ["MASTER_ADDR"] master_port = os.environ["MASTER_PORT"] rdzv_endpoint = torchx_dist._noquote(master_addr + ":" + master_port) - random.seed(DEFAULT_SEED) + random.seed(rdzv_id) else: rdzv_endpoint = torchx_dist._noquote(f"$${ExecutorMacros.HEAD_NODE_IP_VAR}:{rdzv_port}")