feat: roll out DiffusionNFT under the EMA shadow adapter #21

Open
Jayce-Ping wants to merge 1 commit into Tencent-Hunyuan:main from Jayce-Ping:pr/nft-ema-rollout

@Jayce-Ping Jayce-Ping commented Jun 10, 2026

Collaborator

Summary

Roll out DiffusionNFT under its EMA-smoothed "old" shadow adapter on a separate rollout engine (SGLang / vLLM-omni), as requires_ema_rollout intends. Until now, the in-process apply_eval_ema swap only touched the train process, so a separate engine never saw the EMA weights.

Reworked (rebased fresh on latest main) into a single, transport-agnostic mechanism owned by the backend:

  • Backend single source of truth (unirl/train/backend/fsdp.py): FSDPBackend.rollout_adapter_name returns the EMA shadow ("old") for adapter-EMA, else "default". The in-process eval-EMA swap and the cross-process weight sync both derive from it, so they cannot disagree.
  • One unified adapter_name knob across both weight-sync families (weight_sync/full/*, weight_sync/lora/*): defaults to None -> asks the backend; an explicit value overrides. Full-sync folds the shadow into the merged base (SGLang); LoRA-sync ships the shadow adapter directly (vLLM-omni). Also fixes a latent bug where _extract_canonical_lora never forwarded adapter_name, so the LoRA path always shipped default.
  • Trainer untouched: no per-recipe config-reaching; the sync derives the adapter from the backend.
  • FSDP grad fix (unirl/train/inject.py): _activate_keep_grad re-asserts the trainable/frozen requires_grad split across rollout/restore swaps, so the trainable default is never all-gathered grad-less under reshard_after_forward=False.
  • SGLang ODE path (unirl/rollout/engine/sglang/*): DiffusionNFT's num_sde_steps=0 resolves sde_indices to []; request the deterministic ODE rollout, always collect the dit-trajectory, and keep the full trajectory.
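
For reviewers, the adapter resolution described above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: `BackendSketch` and `resolve_adapter_name` are hypothetical names, and only the `rollout_adapter_name` property, the `ema_lora_cfg` field, and the `None -> backend` rule come from the description.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BackendSketch:
    """Hypothetical stand-in for FSDPBackend; models only the adapter choice."""

    # Assumption: adapter-EMA is considered enabled when this config is set.
    ema_lora_cfg: Optional[dict] = None

    @property
    def rollout_adapter_name(self) -> str:
        # EMA shadow ("old") for adapter-EMA, else the trainable "default".
        return "old" if self.ema_lora_cfg is not None else "default"


def resolve_adapter_name(explicit: Optional[str], backend: BackendSketch) -> str:
    """Hypothetical helper: None asks the backend, an explicit value overrides."""
    return backend.rollout_adapter_name if explicit is None else explicit
```

Because both the in-process swap and the weight sync would read the same property, a GRPO or plain-LoRA recipe (no `ema_lora_cfg`) keeps resolving to "default".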

Related Issue

N/A

Test Plan

  • python -m py_compile on all touched files: passes.
  • Linter on touched files: clean.
  • Standalone CPU logic checks (no torch in env): backend rollout_adapter_name (old with ema_lora_cfg, else default); both sync families resolve None -> backend; full-sync fail-closed guard fires for non-default + lora_merged=False; _extract_canonical_lora forwards adapter_name; SGLang empty-sde_indices -> ODE branch + full trajectory; EMA update math.
  • End-to-end runs: not run (no local GPU). Recommended e2e: examples/diffusion/sd3_nft.yaml (trainside trains past the first backward) and examples/diffusion/sd3_nft_sglang.yaml (separate, rolls out under EMA old).
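
The standalone CPU logic checks listed above are of roughly this shape. The sketch below is illustrative only and torch-free: all three function names are hypothetical, the exception type of the guard is an assumption, and the EMA step uses the standard `decay * shadow + (1 - decay) * current` form, which may differ in detail from the repository's update.

```python
from typing import List


def check_full_sync(adapter_name: str, lora_merged: bool) -> None:
    """Hypothetical fail-closed guard: full-sync can only ship a non-default
    adapter by folding it into the merged base."""
    if adapter_name != "default" and not lora_merged:
        raise ValueError(
            f"adapter {adapter_name!r} requires lora_merged=True for full sync"
        )


def ema_update(shadow: List[float], current: List[float], decay: float) -> List[float]:
    """Standard EMA step applied elementwise (assumed form)."""
    return [decay * s + (1.0 - decay) * c for s, c in zip(shadow, current)]


def rollout_mode(sde_indices: List[int]) -> str:
    """Empty sde_indices (num_sde_steps=0) selects the deterministic ODE rollout."""
    return "sde" if sde_indices else "ode"
```

Plain lists stand in for tensors so the checks run in an environment without torch, matching the constraint noted in the test plan.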

Compatibility / Risk

  • Backward compatible: GRPO and plain-LoRA recipes resolve adapter_name to "default" (unchanged behavior). FullWeightSync gains an adapter_name ctor arg (defaults to None); the LoRA-sync adapter_name default changes from "default" to None (resolves to "default" for non-EMA backends), so existing recipes are unaffected.
  • Requires the sync handler's backend to expose rollout_adapter_name; FSDPBackend is the only training backend, so all paths are covered.

Reviewer Notes

  • AI-assisted (Cursor). Rebased fresh on latest main; this force-push supersedes the earlier full-sync-only lora_adapter + trainer _resolve_sync_lora_adapter approach with the backend-sourced unified mechanism.
  • Out of scope: no new vLLM-omni NFT recipe is shipped, but the unified design makes it work via either RemoteLoraWeightSync (ships old, no merge) or IPC/NCCLWeightSync + lora_merged: true (merged old base).

Checklist

  • I reviewed the changed code and removed unrelated/generated artifacts.
  • I updated tests, docs, and configs where needed, or explained why not.

@Jayce-Ping Jayce-Ping force-pushed the pr/nft-ema-rollout branch from fe74bba to 1322d6e Compare June 10, 2026 07:08
celve commented Jun 11, 2026

Collaborator

Do we have the e2e test result?

…name

Make the backend the single source of truth for which adapter the rollout
samples under, so a SEPARATE SGLang/vLLM-omni engine rolls out DiffusionNFT
under its EMA-smoothed "old" shadow adapter — replacing the full-sync-only
knob plus trainer config-reaching.

- FSDPBackend.rollout_adapter_name: the EMA shadow ("old") for adapter-EMA,
  else "default". The in-process apply_eval_ema swap and the cross-process
  weight sync both derive from it, so they cannot disagree.
- FullWeightSync + LoRA sync: one unified adapter_name knob (None -> backend);
  fix _extract_canonical_lora to forward adapter_name into extract_lora_tensors
  (it previously always shipped "default", silently dropping the EMA shadow on
  the vLLM-omni LoRA-ship path).
- inject.py: _activate_keep_grad re-asserts the trainable/frozen requires_grad
  split across rollout/restore swaps (FSDP2 reshard_after_forward=False).
- sglang request/engine: NFT empty sde_indices -> deterministic ODE rollout,
  always collect the dit-trajectory, keep the full trajectory.
- sd3_nft_sglang.yaml: now faithful; EMA shadow auto-derived from the backend.

Made with Cursor

Co-authored-by: Cursor <cursoragent@cursor.com>
@Jayce-Ping Jayce-Ping force-pushed the pr/nft-ema-rollout branch from 988e5f9 to 8a701c6 Compare June 12, 2026 07:35