Skip to content

Add CLI workflow for diffusion UNet training#33

Draft
wyli wants to merge 1 commit into
NVIDIA-Medtech:mainfrom
wyli:codex/diff-unet-train-workflow
Draft

Add CLI workflow for diffusion UNet training#33
wyli wants to merge 1 commit into
NVIDIA-Medtech:mainfrom
wyli:codex/diff-unet-train-workflow

Conversation

@wyli

@wyli wyli commented May 28, 2026

Copy link
Copy Markdown

Summary

  • Add a noninteractive CLI extraction of train_diff_unet_tutorial.ipynb as scripts/diff_model_train_workflow.py
  • Stage model/environment/network configs, datalists, and embedding sidecar JSON files before invoking the existing create-training-data/train/infer scripts
  • Support MR-brain, MR, rflow CT, and DDPM CT generate versions with optional inference after training

Validation

  • python -m py_compile scripts/diff_model_train_workflow.py
  • python -m scripts.diff_model_train_workflow --help

@greptile-apps

greptile-apps Bot commented May 28, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR adds scripts/diff_model_train_workflow.py, a noninteractive CLI extraction of train_diff_unet_tutorial.ipynb that stages model/environment/network configs, creates per-embedding conditioning sidecar JSON files, and orchestrates create-training-data → train → infer for four model variants (DDPM-CT, rflow-CT, rflow-MR, rflow-MR-brain).

  • Stages all required JSON configs into a user-specified --work-dir, then delegates to the existing diff_model_create_training_data, diff_model_train, and diff_model_infer functions.
  • Adds argument parsing with sensible defaults for epochs, batch size, learning rate, modality, inference dimensions, and region indices, and writes a workflow_summary.json on completion.
  • One functional defect: _create_embedding_sidecars is called unconditionally even when --skip-create-training-data is passed, causing a FileNotFoundError on any fresh working directory where the embedding subdirectory hasn't been created yet.

Confidence Score: 3/5

The new workflow script orchestrates existing steps correctly but has a defect where using --skip-create-training-data on a fresh working directory crashes before training begins.

The unconditional call to _create_embedding_sidecars after the skipped training-data step will raise FileNotFoundError whenever the embedding subdirectory hasn't been created yet — a common scenario for anyone resuming a partially failed run with --skip-create-training-data pointing to an empty work dir. The fix is straightforward but the current code will fail before any training happens.

scripts/diff_model_train_workflow.py — specifically the unconditional _create_embedding_sidecars call in main() and the flag-conflict handling around --train-from-scratch.

Important Files Changed

Filename Overview
scripts/diff_model_train_workflow.py New CLI workflow script that stages configs, creates embedding sidecars, and orchestrates training/inference — crashes on fresh work dirs when --skip-create-training-data is used, and has minor flag-conflict/error-message issues.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[CLI args parsed] --> B[Set default modality]
    B --> C{--download-model-data?}
    C -- yes --> D[download_model_data]
    C -- no --> E
    D --> E[_stage_configs\nwrite env/model/network JSONs\nstage datalist]
    E --> F{--skip-create-training-data?}
    F -- no --> G[diff_model_create_training_data\ncreate latent embeddings]
    F -- yes --> H
    G --> H[_create_embedding_sidecars\nwrite .nii.gz.json files]
    H --> I{--skip-train?}
    I -- no --> J[diff_model_train]
    I -- yes --> K
    J --> K{--run-inference?}
    K -- yes --> L[diff_model_infer]
    K -- no --> M
    L --> M[Write workflow_summary.json]

    style H fill:#ffcccc,stroke:#cc0000
    H -.->|crashes if embedding dir does not exist| H
Loading

Reviews (1): Last reviewed commit: "Add CLI workflow for diffusion UNet trai..." | Re-trigger Greptile

Comment on lines +309 to +315
sidecars = _create_embedding_sidecars(
staged["embedding_dir"],
args.modality,
staged["include_body_region"],
args.top_region_index,
args.bottom_region_index,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 _create_embedding_sidecars crashes when embedding dir doesn't exist

_create_embedding_sidecars is called unconditionally regardless of --skip-create-training-data. When that flag is used on a fresh --work-dir (i.e., the embedding directory has never been populated), sorted(embedding_base_dir.rglob("*_emb.nii.gz")) raises FileNotFoundError because the directory doesn't exist yet. The error surface is unhelpful — there's no message indicating the user should drop --skip-create-training-data or point to an existing work dir with embeddings. Consider guarding with embedding_base_dir.is_dir() or only calling _create_embedding_sidecars when embeddings are present.

Comment thread scripts/diff_model_train_workflow.py
Comment on lines +81 to +85
def _modality_mapping(repo_root: Path) -> dict[str, int]:
path = repo_root / "configs" / "modality_mapping.json"
if not path.is_file():
return {}
return {str(k): int(v) for k, v in _load_json(path).items()}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Misleading error when modality_mapping.json is absent

_modality_mapping silently returns {} when modality_mapping.json doesn't exist. The subsequent ValueError then says the modality was "not found in configs/modality_mapping.json", implying the file exists but lacks the entry — when in fact the file itself is missing. Distinguishing the two cases helps the user understand what's wrong faster.

Suggested change
def _modality_mapping(repo_root: Path) -> dict[str, int]:
path = repo_root / "configs" / "modality_mapping.json"
if not path.is_file():
return {}
return {str(k): int(v) for k, v in _load_json(path).items()}
def _modality_mapping(repo_root: Path) -> dict[str, int]:
path = repo_root / "configs" / "modality_mapping.json"
if not path.is_file():
raise FileNotFoundError(f"modality_mapping.json not found at {path}")
return {str(k): int(v) for k, v in _load_json(path).items()}

@wyli wyli marked this pull request as draft May 28, 2026 10:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant