diff --git a/nemo_run/run/experiment.py b/nemo_run/run/experiment.py index f7c747aa..49b9e43e 100644 --- a/nemo_run/run/experiment.py +++ b/nemo_run/run/experiment.py @@ -665,7 +665,10 @@ def run( return # Prepare experiment before running - self._prepare() + + # in case of multi-node execution with LocalExecutor+torchrun+slurm, run only on first rank + if int(os.getenv("SLURM_PROCID", 0)) == 0: + self._prepare() if direct: self.console.log( diff --git a/nemo_run/run/torchx_backend/components/torchrun.py b/nemo_run/run/torchx_backend/components/torchrun.py index f1ea4d4f..d9599b0d 100644 --- a/nemo_run/run/torchx_backend/components/torchrun.py +++ b/nemo_run/run/torchx_backend/components/torchrun.py @@ -128,8 +128,11 @@ def torchrun( num_nodes = nnodes_rep nproc_per_node = str(nproc_per_node) + # set node rank to relative node id in the current allocation if use_env and os.getenv("NODE_RANK"): node_rank = os.environ["NODE_RANK"] + elif use_env and os.getenv("SLURM_NODEID"): + node_rank = os.environ["SLURM_NODEID"] else: node_rank = torchx_dist._noquote(f"$${ExecutorMacros.NODE_RANK_VAR}")