Skip to content

Commit 77694fc

Browse files
committed
chore: allow using a custom wheel on slurm
Signed-off-by: Matt Kornfield <mkornfield@nvidia.com>
1 parent ad93dca commit 77694fc

2 files changed

Lines changed: 32 additions & 7 deletions

File tree

script/slurm/slurm_nss_matrix.sh

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,11 @@ EXP_NAME="${EXP_NAME:-multi_jobs}"
5555
BASE_LOG_DIR="${BASE_LOG_DIR:-./}"
5656

5757

58-
cd "${NSS_DIR}"
58+
# Only cd into the repo when using the repo-based install; in PyPI mode the
59+
# working directory does not matter for package resolution.
60+
if [[ -z "${NSS_VERSION:-}" ]]; then
61+
cd "${NSS_DIR}"
62+
fi
5963

6064
# Get dataset and config name from packed strings
6165
declare -a all_datasets
@@ -88,8 +92,21 @@ apt-get update && apt-get install -y --no-install-recommends \
8892

8993
# Ensure Python environment is available inside the container
9094
source "${LUSTRE_DIR}/.uv/bin/env"
91-
source "${NSS_DIR}/.venv/bin/activate"
92-
uv sync --frozen --extra cu128 --extra engine --group dev
95+
if [[ -n "${NSS_VERSION:-}" ]]; then
96+
# Install nemo-safe-synthesizer from PyPI into a versioned venv cached on
97+
# lustre so concurrent array jobs can share it without redundant downloads.
98+
PYPI_VENV="${LUSTRE_DIR}/.venv_nss_${NSS_VERSION}"
99+
uv venv --python 3.11 "${PYPI_VENV}"
100+
source "${PYPI_VENV}/bin/activate"
101+
uv pip install "nemo-safe-synthesizer[cu128,engine]==${NSS_VERSION}"
102+
NSS_RUN_CMD="${PYPI_VENV}/bin/safe-synthesizer"
103+
echo "[NSS SLURM] Using PyPI install: nemo-safe-synthesizer==${NSS_VERSION}"
104+
else
105+
source "${NSS_DIR}/.venv/bin/activate"
106+
uv sync --frozen --extra cu128 --extra engine --group dev
107+
NSS_RUN_CMD="uv run safe-synthesizer"
108+
fi
109+
echo "[NSS SLURM] nemo-safe-synthesizer version: $(python -c 'from nemo_safe_synthesizer.package_info import __version__; print(__version__)')"
93110

94111

95112
# for column classification
@@ -166,7 +183,7 @@ fi
166183
if [[ "${NSS_PHASE}" == "train" ]]; then
167184
# Stage 1: PII replacement + training
168185
# Creates new workdir at run_path with adapter
169-
uv run safe-synthesizer run train \
186+
${NSS_RUN_CMD} run train \
170187
--data-source "$dataset" \
171188
--config "$full_config_path" \
172189
--run-path "$run_path" \
@@ -182,15 +199,15 @@ elif [[ "${NSS_PHASE}" == "generate" ]]; then
182199
wandb_resume_arg="--wandb-resume-job-id $wandb_id_file"
183200
fi
184201

185-
uv run safe-synthesizer run generate \
202+
${NSS_RUN_CMD} run generate \
186203
--data-source "$dataset" \
187204
--config "$full_config_path" \
188205
--run-path "$run_path" \
189206
$dataset_registry_arg \
190207
$wandb_resume_arg
191208
else
192209
# Full end-to-end run
193-
uv run safe-synthesizer run \
210+
${NSS_RUN_CMD} run \
194211
--data-source "$dataset" \
195212
--config "$full_config_path" \
196213
--run-path "$run_path" \

script/slurm/submit_slurm_jobs.sh

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ PIPELINE_MODE="end_to_end" # values: two_stage | end_to_end
1717
CONFIGS_CSV="" # optional override for CONFIGS array (comma-separated)
1818
WANDB_PROJECT="" # optional wandb project name; uses EXP_NAME if not provided
1919
MAX_CONCURRENT_SLURM_JOBS="" # optional max number of concurrent slurm jobs to run within each array; if not provided, no restriction is applied
20+
NSS_VERSION="" # optional: install nemo-safe-synthesizer from PyPI at this version instead of syncing the repo
2021
ACCOUNT="${ACCOUNT:-llmservice_sdg_research}"
2122
TIME_LIMIT="04:00:00"
2223
TRAIN_TIME_LIMIT=""
@@ -43,6 +44,8 @@ while [ $# -gt 0 ]; do
4344
WANDB_PROJECT="${2:-}"; shift 2;;
4445
--max-concurrent-slurm-jobs)
4546
MAX_CONCURRENT_SLURM_JOBS="${2:-$MAX_CONCURRENT_SLURM_JOBS}"; shift 2;;
47+
--nss-version|-V)
48+
NSS_VERSION="${2:-}"; shift 2;;
4649
--time-limit|-t)
4750
TIME_LIMIT="${2:-$TIME_LIMIT}"; shift 2;;
4851
--train-time-limit)
@@ -52,12 +55,16 @@ while [ $# -gt 0 ]; do
5255
--dry-run)
5356
DRY_RUN="true"; shift;;
5457
--help|-h)
55-
echo "Usage: $0 [--configs c1,c2] [--dataset-urls name1,url1,path1] [--dataset-group short|long] [--runs N] [--exp-name NAME] [--pipeline-mode two_stage|end_to_end] [--partition P] [--wandb-project PROJECT] [--max-concurrent-slurm-jobs N] [--time-limit TIME] [--train-time-limit TIME] [--generate-time-limit TIME] [--dry-run]"
58+
echo "Usage: $0 [--configs c1,c2] [--dataset-urls name1,url1,path1] [--dataset-group short|long] [--runs N] [--exp-name NAME] [--pipeline-mode two_stage|end_to_end] [--partition P] [--wandb-project PROJECT] [--max-concurrent-slurm-jobs N] [--time-limit TIME] [--train-time-limit TIME] [--generate-time-limit TIME] [--nss-version VERSION] [--dry-run]"
5659
echo ""
5760
echo "Provide either --dataset-urls to specify a list of datasets by name, url, or path, or --dataset-group to use a predefined set of datasets."
5861
echo "Time limits:"
5962
echo " --time-limit is used for end_to_end mode (defaults to 4 hours)"
6063
echo " --train-time-limit and --generate-time-limit are used for two_stage mode, and will default to --time-limit if the more more specific train and generate limits are not provided"
64+
echo "Package installation:"
65+
echo " --nss-version VERSION install nemo-safe-synthesizer==VERSION from PyPI instead of syncing the repo"
66+
echo " example: --nss-version 0.2.3"
67+
echo " if omitted, the repo at NSS_DIR is used (default behavior)"
6168

6269
exit 0;;
6370
--) shift; break;;
@@ -115,6 +122,7 @@ fi
115122

116123
export ACCOUNT
117124
export EXP_NAME
125+
export NSS_VERSION
118126

119127
# Build configs list: CLI override (comma-separated) takes precedence; otherwise use CONFIGS from env
120128
declare -a CONFIGS_LIST

0 commit comments

Comments
 (0)