diff --git a/nemo_run/core/execution/slurm.py b/nemo_run/core/execution/slurm.py index 3539a83f..f483a61e 100644 --- a/nemo_run/core/execution/slurm.py +++ b/nemo_run/core/execution/slurm.py @@ -560,7 +560,7 @@ def get_nsys_entrypoint(self) -> str: launcher = self.get_launcher() entrypoint, postfix = "nsys", "" if launcher.nsys_gpu_metrics: - entrypoint = 'bash -c \'GPU_METRICS_FLAG=""; if [ "$SLURM_PROCID" -eq 0 ]; then GPU_METRICS_FLAG="--gpu-metrics-devices=all"; fi; nsys' + entrypoint = 'bash -c \'GPU_METRICS_FLAG=""; if echo "${GPU_METRICS_NODES}" | grep -q -w "${SLURM_NODEID}"; then GPU_METRICS_FLAG="--gpu-metrics-devices=${SLURM_LOCALID}"; fi; nsys' postfix = "'" return (entrypoint, postfix) diff --git a/test/core/execution/test_slurm.py b/test/core/execution/test_slurm.py index 4c681111..5f194df0 100644 --- a/test/core/execution/test_slurm.py +++ b/test/core/execution/test_slurm.py @@ -198,7 +198,7 @@ def test_get_nsys_entrypoint(self): with patch.object(executor, "get_launcher", return_value=launcher_mock): assert executor.get_nsys_entrypoint() == ( - 'bash -c \'GPU_METRICS_FLAG=""; if [ "$SLURM_PROCID" -eq 0 ]; then GPU_METRICS_FLAG="--gpu-metrics-devices=all"; fi; nsys', + 'bash -c \'GPU_METRICS_FLAG=""; if echo "${GPU_METRICS_NODES}" | grep -q -w "${SLURM_NODEID}"; then GPU_METRICS_FLAG="--gpu-metrics-devices=${SLURM_LOCALID}"; fi; nsys', "'", )