Commit c3bf79b

feat: add DGX Spark (aarch64) support
- containers/Dockerfile.spark: container-based install using nvcr.io/nvidia/vllm:26.02-py3
- docs/DGX_SPARK.md: quick start guide (build + run in 2 steps)
- pyproject.toml: platform markers for aarch64-incompatible packages (faiss-gpu-cu12, torchvision+cu128, torchao, xformers)
- config/training.py: auto-fallback Flash Attention 3 to sdpa on aarch64
- vllm_backend.py: handle vllm versions without attention_config kwarg
1 parent 9a4a150 commit c3bf79b

5 files changed

Lines changed: 177 additions & 15 deletions

containers/Dockerfile.spark

Lines changed: 37 additions & 0 deletions
```dockerfile
# Dockerfile for NeMo Safe Synthesizer on DGX Spark (aarch64)
#
# Base: NVIDIA vLLM container with torch 2.11 + vLLM 0.15.1 + Triton 3.6
#
# Build:
#   docker build -f containers/Dockerfile.spark -t nss-spark .
#
# Run:
#   docker run --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
#     -it nss-spark
#
FROM nvcr.io/nvidia/vllm:26.02-py3

ENV TRITON_CACHE_DIR=/workspace/.triton_cache
ENV BNB_CUDA_VERSION=130

WORKDIR /workspace/Safe-Synthesizer
COPY . .

# Torch-dependent packages — install with --no-deps to preserve container's torch/CUDA
RUN pip install --no-deps \
    peft accelerate bitsandbytes datasets==4.3.0 trl==0.26.1 \
    hf_transfer unsloth unsloth_zoo \
    opacus sentence-transformers gliner kernels

# Safe Synthesizer + remaining deps
RUN pip install --no-deps -e . && \
    pip install \
    faker 'pydantic[email]>=2.12.5' pydantic-settings pyyaml jsonschema rich structlog \
    colorama 'huggingface-hub>=0.34.4,<1' anyascii pycountry betterproto flashtext \
    cached-property category-encoders dython dateparser langchain-core json-repair \
    matplotlib 'outlines>=1.0.0' plotly prv-accountant 'smart-open==7.0.5' python-stdnum \
    'pandas>=2.1.3' ratelimit 'sqlfluff==3.2.0' 'range_regex>=0.1.0' 'tenacity==9.1.2' \
    'tiktoken>=0.7.0' tldextract 'wandb==0.23.1' python-dotenv patsy \
    pyarrow multiprocess onnxruntime opt_einsum dill==0.3.8 faiss-cpu

ENTRYPOINT ["/usr/bin/bash"]
```

docs/DGX_SPARK.md

Lines changed: 103 additions & 0 deletions
# NeMo Safe Synthesizer on DGX Spark

Generate synthetic tabular data with quality and privacy guarantees — train, generate, and evaluate in one command.

## Quick Start

### 1. Build and launch the container

```bash
git clone https://github.com/NVIDIA-NeMo/Safe-Synthesizer.git && cd Safe-Synthesizer
docker build -f containers/Dockerfile.spark -t nss-spark .
docker run --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -it nss-spark
```

### 2. Run

```bash
python -c "
import pandas as pd, numpy as np
from nemo_safe_synthesizer.sdk.library_builder import SafeSynthesizer

# Sample data — replace with your own CSV or DataFrame
np.random.seed(42)
df = pd.DataFrame({
    'age': np.random.randint(18, 85, 500),
    'income': np.random.lognormal(10.5, 0.8, 500).astype(int),
    'credit_score': np.random.randint(300, 850, 500),
    'default': np.random.choice(['yes', 'no'], 500, p=[0.15, 0.85]),
})

builder = (
    SafeSynthesizer()
    .with_data_source(df)
    .with_replace_pii()
    .with_generate(num_records=500)
    .with_evaluate()
)
builder.run()

s = builder.results.summary
print(f'Quality (SQS): {s.synthetic_data_quality_score}/10')
print(f'Privacy (DPS): {s.data_privacy_score}/10')
builder.save_results()
"
```

Expected: SQS ~8-9, DPS ~9-10.

> **First run is slower.** Model weights (~6 GB) download from HuggingFace and Triton
> JIT-compiles LoRA kernels for the GB10. Subsequent runs reuse cached weights and kernels.

## Use Your Own Data

```python
from nemo_safe_synthesizer.sdk.library_builder import SafeSynthesizer

builder = (
    SafeSynthesizer()
    .with_data_source("your_data.csv")  # or pass a DataFrame
    .with_replace_pii()
    .with_generate(num_records=1000)
    .with_evaluate()
)
builder.run()
builder.save_results()
```

Outputs are saved to `safe-synthesizer-artifacts/` — a synthetic CSV and an HTML evaluation report.

## Optional: Improve PII Detection

Set a NIM API key for LLM-based column classification (more accurate than NER-only):

```bash
export NIM_ENDPOINT_URL="https://integrate.api.nvidia.com/v1"
export NIM_API_KEY="your-key"  # get one at build.nvidia.com/settings/api-keys
```

## Optional: Differential Privacy

```python
builder = (
    SafeSynthesizer()
    .with_data_source(df)
    .with_replace_pii()
    .with_generate(num_records=1000)
    .with_differential_privacy(dp_enabled=True, epsilon=8.0)
    .with_evaluate()
)
```

## Troubleshooting

**Slow first generation batch?** Triton JIT-compiles LoRA kernels for the GB10 on first use. This is normal and only happens once per container session.

**Memory issues between runs?** Flush the page cache:

```bash
sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'
```

**Why a container?** DGX Spark's CUDA 13 + aarch64 stack requires specific Triton, vLLM, and PyTorch versions. The container (`nvcr.io/nvidia/vllm:26.02-py3`) provides a tested stack where Unsloth training and vLLM generation work natively.

**Full documentation:** [Safe Synthesizer User Guide](https://github.com/NVIDIA-NeMo/Safe-Synthesizer/blob/main/docs/user-guide/getting-started.md)

pyproject.toml

Lines changed: 8 additions & 6 deletions
```diff
@@ -141,7 +141,8 @@ cpu = [
 cu128 = [
     "accelerate",
     "bitsandbytes==0.49.1",
-    "faiss-gpu-cu12==1.13.2; sys_platform == 'linux'",
+    "faiss-gpu-cu12==1.13.2; sys_platform == 'linux' and platform_machine == 'x86_64'",
+    "faiss-cpu==1.13.2; sys_platform == 'linux' and platform_machine == 'aarch64'",
     "flashinfer-python==0.6.1; sys_platform == 'linux'",
     "flashinfer-cubin==0.6.1; sys_platform == 'linux'",
     "flashinfer-jit-cache==0.6.1+cu128; sys_platform == 'linux'",
@@ -154,15 +155,16 @@ cu128 = [
     "sentence-transformers",
     "torch==2.9.1+cu128; sys_platform == 'linux'",
     "torch-c-dlpack-ext",
-    "torchvision==0.24.1+cu128; sys_platform == 'linux'",
-    "torchao==0.15.0; sys_platform == 'linux'",
+    "torchvision==0.24.1+cu128; sys_platform == 'linux' and platform_machine == 'x86_64'",
+    "torchvision==0.24.1; sys_platform == 'linux' and platform_machine == 'aarch64'",
+    "torchao==0.15.0; sys_platform == 'linux' and platform_machine == 'x86_64'",
     "transformers==4.57.3",
     "triton>=2.0.0",
     "trl>=0.23.0",
     "unsloth[cu128-torch291]==2025.12.4",
     "unsloth_zoo==2025.12.4",
     "vllm==0.15.0",
-    "xformers==v0.0.33.post2; sys_platform == 'linux'",
+    "xformers==v0.0.33.post2; sys_platform == 'linux' and platform_machine == 'x86_64'",
 ]

 # at some point, do per-subpackage dependencies
@@ -188,8 +190,8 @@ dependency-metadata = [

 override-dependencies = [
-    "flashinfer-python==0.6.1; sys_platform != 'darwin'",  # uv locking won't find the matching versions of flashinfer-python and -cubin without overriding
-    "flashinfer-cubin==0.6.1; sys_platform != 'darwin'",  # perhaps because the published wheels have some wrong metadata
+    "flashinfer-python==0.6.1; sys_platform != 'darwin' and platform_machine != 'aarch64'",  # uv locking won't find the matching versions of flashinfer-python and -cubin without overriding
+    "flashinfer-cubin==0.6.1; sys_platform != 'darwin' and platform_machine != 'aarch64'",  # perhaps because the published wheels have some wrong metadata
     "xgrammar>=0.1.32,<1.0.0",  # CVE-2026-25048: override vllm's pin on 0.1.29
 ]
```
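The platform markers above are standard PEP 508 environment markers, so their effect can be checked offline with the `packaging` library. A small sketch (not part of this commit) evaluating the faiss markers against a simulated DGX Spark environment:

```python
from packaging.markers import Marker

# Markers taken from the cu128 extras above
gpu_faiss = Marker("sys_platform == 'linux' and platform_machine == 'x86_64'")
cpu_faiss = Marker("sys_platform == 'linux' and platform_machine == 'aarch64'")

# Simulated DGX Spark environment: Linux on aarch64 (GB10)
spark = {"sys_platform": "linux", "platform_machine": "aarch64"}

print(gpu_faiss.evaluate(spark))  # False -> faiss-gpu-cu12 is skipped on Spark
print(cpu_faiss.evaluate(spark))  # True  -> faiss-cpu is installed instead
```

`Marker.evaluate` merges the supplied dict over the defaults for the running interpreter, so the same check works on any machine.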

src/nemo_safe_synthesizer/config/training.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -3,13 +3,15 @@

 from __future__ import annotations

+import platform
 from typing import (
     Annotated,
     Literal,
 )

 from pydantic import (
     Field,
+    model_validator,
 )

 from ..configurator.parameters import (
@@ -265,3 +267,11 @@ class TrainingHyperparams(Parameters):
             ),
         ),
     ] = "kernels-community/vllm-flash-attn3"
+
+    @model_validator(mode="after")
+    def _resolve_platform_defaults(self) -> "TrainingHyperparams":
+        """Override defaults that are incompatible with the current platform."""
+        if platform.machine() == "aarch64":
+            if self.attn_implementation == "kernels-community/vllm-flash-attn3":
+                self.attn_implementation = "sdpa"
+        return self
```
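Because the validator only rewrites the Flash Attention 3 default, any other explicitly chosen backend survives on aarch64. A minimal standalone sketch of the same pattern, using a hypothetical `AttnConfig` in place of `TrainingHyperparams`:

```python
import platform

from pydantic import BaseModel, model_validator


class AttnConfig(BaseModel):
    # Same default as TrainingHyperparams.attn_implementation
    attn_implementation: str = "kernels-community/vllm-flash-attn3"

    @model_validator(mode="after")
    def _resolve_platform_defaults(self) -> "AttnConfig":
        # FA3 kernels are unavailable on aarch64; rewrite only the default,
        # any other explicitly chosen backend is left untouched
        if platform.machine() == "aarch64":
            if self.attn_implementation == "kernels-community/vllm-flash-attn3":
                self.attn_implementation = "sdpa"
        return self


print(AttnConfig().attn_implementation)  # "sdpa" on aarch64, the FA3 default elsewhere
print(AttnConfig(attn_implementation="eager").attn_implementation)  # "eager" everywhere
```

One limitation shared with the commit's validator: an aarch64 user who explicitly passes the FA3 string is also rewritten, since a mode="after" validator cannot distinguish an explicit value from the default.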

src/nemo_safe_synthesizer/generation/vllm_backend.py

Lines changed: 19 additions & 9 deletions
```diff
@@ -203,16 +203,26 @@ def initialize(self, **kwargs) -> None:
         # check this when updating unsloth in the future.
         enforce_eager = self.config.training.use_unsloth is True

+        vllm_kwargs = dict(
+            model=self.config.training.pretrained_model,
+            gpu_memory_utilization=max_vram,
+            enable_lora=True,
+            max_lora_rank=self.config.training.lora_r,
+            structured_outputs_config=structured_outputs_config,
+            enforce_eager=enforce_eager,
+        )
+        # attention_config was added in vLLM 0.12+ but removed in some builds.
+        # Fall back to VLLM_ATTENTION_BACKEND env var if the kwarg is not accepted.
         with heartbeat("Model loading", logger_name=__name__, model=self.config.training.pretrained_model):
-            self.llm = vLLM(
-                model=self.config.training.pretrained_model,
-                gpu_memory_utilization=max_vram,
-                enable_lora=True,
-                max_lora_rank=self.config.training.lora_r,
-                structured_outputs_config=structured_outputs_config,
-                enforce_eager=enforce_eager,
-                attention_config=attention_config,
-            )
+            if attention_config is not None:
+                try:
+                    self.llm = vLLM(**vllm_kwargs, attention_config=attention_config)
+                except TypeError:
+                    if attn_backend not in (None, "auto"):
+                        os.environ["VLLM_ATTENTION_BACKEND"] = attn_backend
+                    self.llm = vLLM(**vllm_kwargs)
+            else:
+                self.llm = vLLM(**vllm_kwargs)

     def _build_structured_output_params(self) -> StructuredOutputsParams | None:
         """Build structured output parameters based on generation config.
```
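The try/except fallback generalizes to any constructor whose keyword surface varies across versions. A self-contained sketch of the pattern, with a hypothetical `build_llm` helper and a stand-in class instead of the real vLLM:

```python
import os


def build_llm(ctor, base_kwargs, attention_config=None, attn_backend=None):
    """Pass attention_config if this build accepts it; otherwise fall back
    to the VLLM_ATTENTION_BACKEND env var. Hypothetical helper, not the
    project's API."""
    if attention_config is not None:
        try:
            return ctor(**base_kwargs, attention_config=attention_config)
        except TypeError:
            # This build's __init__ rejects the kwarg
            if attn_backend not in (None, "auto"):
                os.environ["VLLM_ATTENTION_BACKEND"] = attn_backend
    return ctor(**base_kwargs)


class OldLLM:
    """Stand-in for a vLLM build without the attention_config kwarg."""

    def __init__(self, model):
        self.model = model


llm = build_llm(OldLLM, {"model": "demo"},
                attention_config={"backend": "FLASHINFER"},
                attn_backend="FLASHINFER")
print(llm.model, os.environ.get("VLLM_ATTENTION_BACKEND"))  # demo FLASHINFER
```

One caveat of this shape: a `TypeError` raised for an unrelated reason inside a constructor that does accept the kwarg is silently retried without it, which is why the commit guards the retry behind the kwarg actually being set.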
