Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions script/slurm/slurm_nss_matrix.sh
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ EXP_NAME="${EXP_NAME:-multi_jobs}"
BASE_LOG_DIR="${BASE_LOG_DIR:-./}"


cd "${NSS_DIR}"
# Only cd into the repo when using the repo-based install; in PyPI mode the
# working directory does not matter for package resolution.
if [[ -z "${NSS_VERSION:-}" ]]; then
  # Fail fast if the checkout is missing: continuing from the wrong directory
  # would make the later repo-relative steps operate on the wrong tree.
  cd "${NSS_DIR}" || {
    echo "[NSS SLURM] cannot cd into NSS_DIR: ${NSS_DIR}" >&2
    exit 1
  }
fi

# Get dataset and config name from packed strings
declare -a all_datasets
Expand Down Expand Up @@ -88,8 +92,24 @@ apt-get update && apt-get install -y --no-install-recommends \

# Ensure Python environment is available inside the container
source "${LUSTRE_DIR}/.uv/bin/env"
source "${NSS_DIR}/.venv/bin/activate"
uv sync --frozen --extra cu128 --extra engine --group dev
# Select how safe-synthesizer is installed and invoked:
#   NSS_VERSION set   -> install that release from PyPI into a per-version venv
#                        under LUSTRE_DIR and call its console script directly.
#   NSS_VERSION empty -> activate the repo venv at NSS_DIR and use uv sync/run.
# Sets NSS_RUN_CMD, the command prefix for the run phases. The call sites
# expand it unquoted (word-split), so it must remain a plain string.
# NOTE(review): that also means a LUSTRE_DIR containing whitespace would break
# the PyPI-mode command — confirm paths never contain spaces.
if [[ -n "${NSS_VERSION:-}" ]]; then
  # Install nemo-safe-synthesizer from PyPI into a versioned venv cached on
  # lustre so concurrent array jobs can share it without redundant downloads.
  PYPI_VENV="${LUSTRE_DIR}/.venv_nss_${NSS_VERSION}"
  # Create the venv only when its interpreter is missing. Re-running
  # `uv venv` on an existing path does not reuse it as-is (depending on the
  # uv version it rejects the path or recreates the environment), which would
  # defeat the cache and could disturb array jobs already running from it.
  # NOTE(review): jobs that start simultaneously can still race on the very
  # first creation — consider a lock if that is observed in practice.
  if [[ ! -x "${PYPI_VENV}/bin/python" ]]; then
    uv venv --python 3.11 "${PYPI_VENV}"
  fi
  source "${PYPI_VENV}/bin/activate"
  # Re-running the install against an already-populated venv is expected to be
  # cheap; it also repairs a venv left half-installed by an interrupted job.
  uv pip install "nemo-safe-synthesizer[cu128,engine]==${NSS_VERSION}" \
    --extra-index-url https://download.pytorch.org/whl/cu128 \
    --extra-index-url https://flashinfer.ai/whl/cu128 \
    --index-strategy unsafe-best-match
  NSS_RUN_CMD="${PYPI_VENV}/bin/safe-synthesizer"
  echo "[NSS SLURM] Using PyPI install: nemo-safe-synthesizer==${NSS_VERSION}"
else
  # Repo mode: reuse the checkout's venv and sync it to the lockfile.
  source "${NSS_DIR}/.venv/bin/activate"
  uv sync --frozen --extra cu128 --extra engine --group dev
  NSS_RUN_CMD="uv run safe-synthesizer"
fi
# Log the version that the active environment actually resolves to.
echo "[NSS SLURM] nemo-safe-synthesizer version: $(python -c 'from nemo_safe_synthesizer.package_info import __version__; print(__version__)')"


# for column classification
Expand Down Expand Up @@ -166,7 +186,7 @@ fi
if [[ "${NSS_PHASE}" == "train" ]]; then
# Stage 1: PII replacement + training
# Creates new workdir at run_path with adapter
uv run safe-synthesizer run train \
${NSS_RUN_CMD} run train \
--data-source "$dataset" \
--config "$full_config_path" \
--run-path "$run_path" \
Expand All @@ -182,15 +202,15 @@ elif [[ "${NSS_PHASE}" == "generate" ]]; then
wandb_resume_arg="--wandb-resume-job-id $wandb_id_file"
fi

uv run safe-synthesizer run generate \
${NSS_RUN_CMD} run generate \
--data-source "$dataset" \
--config "$full_config_path" \
--run-path "$run_path" \
$dataset_registry_arg \
$wandb_resume_arg
else
# Full end-to-end run
uv run safe-synthesizer run \
${NSS_RUN_CMD} run \
--data-source "$dataset" \
--config "$full_config_path" \
--run-path "$run_path" \
Expand Down
10 changes: 9 additions & 1 deletion script/slurm/submit_slurm_jobs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ PIPELINE_MODE="end_to_end" # values: two_stage | end_to_end
CONFIGS_CSV="" # optional override for CONFIGS array (comma-separated)
WANDB_PROJECT="" # optional wandb project name; uses EXP_NAME if not provided
MAX_CONCURRENT_SLURM_JOBS="" # optional max number of concurrent slurm jobs to run within each array; if not provided, no restriction is applied
NSS_VERSION="" # optional: install nemo-safe-synthesizer from PyPI at this version instead of syncing the repo
ACCOUNT="${ACCOUNT:-llmservice_sdg_research}"
TIME_LIMIT="04:00:00"
TRAIN_TIME_LIMIT=""
Expand All @@ -43,6 +44,8 @@ while [ $# -gt 0 ]; do
WANDB_PROJECT="${2:-}"; shift 2;;
--max-concurrent-slurm-jobs)
MAX_CONCURRENT_SLURM_JOBS="${2:-$MAX_CONCURRENT_SLURM_JOBS}"; shift 2;;
--nss-version|-V)
NSS_VERSION="${2:-}"; shift 2;;
--time-limit|-t)
TIME_LIMIT="${2:-$TIME_LIMIT}"; shift 2;;
--train-time-limit)
Expand All @@ -52,12 +55,16 @@ while [ $# -gt 0 ]; do
--dry-run)
DRY_RUN="true"; shift;;
--help|-h)
echo "Usage: $0 [--configs c1,c2] [--dataset-urls name1,url1,path1] [--dataset-group short|long] [--runs N] [--exp-name NAME] [--pipeline-mode two_stage|end_to_end] [--partition P] [--wandb-project PROJECT] [--max-concurrent-slurm-jobs N] [--time-limit TIME] [--train-time-limit TIME] [--generate-time-limit TIME] [--dry-run]"
echo "Usage: $0 [--configs c1,c2] [--dataset-urls name1,url1,path1] [--dataset-group short|long] [--runs N] [--exp-name NAME] [--pipeline-mode two_stage|end_to_end] [--partition P] [--wandb-project PROJECT] [--max-concurrent-slurm-jobs N] [--time-limit TIME] [--train-time-limit TIME] [--generate-time-limit TIME] [--nss-version VERSION] [--dry-run]"
echo ""
echo "Provide either --dataset-urls to specify a list of datasets by name, url, or path, or --dataset-group to use a predefined set of datasets."
echo "Time limits:"
echo " --time-limit is used for end_to_end mode (defaults to 4 hours)"
echo " --train-time-limit and --generate-time-limit are used for two_stage mode, and will default to --time-limit if the more specific train and generate limits are not provided"
echo "Package installation:"
echo " --nss-version VERSION install nemo-safe-synthesizer==VERSION from PyPI instead of syncing the repo"
echo " example: --nss-version 0.2.3"
echo " if omitted, the repo at NSS_DIR is used (default behavior)"

exit 0;;
--) shift; break;;
Expand Down Expand Up @@ -115,6 +122,7 @@ fi

export ACCOUNT
export EXP_NAME
export NSS_VERSION

# Build configs list: CLI override (comma-separated) takes precedence; otherwise use CONFIGS from env
declare -a CONFIGS_LIST
Expand Down
Loading