feat: roll out DiffusionNFT under the EMA shadow adapter#21
Open
Jayce-Ping wants to merge 1 commit into
Open
Conversation
fe74bba to
1322d6e
Compare
Collaborator
|
Do we have the e2e test result? |
…name
Make the backend the single source of truth for which adapter the rollout
samples under, so a SEPARATE SGLang/vLLM-omni engine rolls out DiffusionNFT
under its EMA-smoothed "old" shadow adapter — replacing the full-sync-only
knob plus trainer config-reaching.
- FSDPBackend.rollout_adapter_name: the EMA shadow ("old") for adapter-EMA,
else "default". The in-process apply_eval_ema swap and the cross-process
weight sync both derive from it, so they cannot disagree.
- FullWeightSync + LoRA sync: one unified adapter_name knob (None -> backend);
fix _extract_canonical_lora to forward adapter_name into extract_lora_tensors
(it previously always shipped "default", silently dropping the EMA shadow on
the vLLM-omni LoRA-ship path).
- inject.py: _activate_keep_grad re-asserts the trainable/frozen requires_grad
split across rollout/restore swaps (FSDP2 reshard_after_forward=False).
- sglang request/engine: NFT empty sde_indices -> deterministic ODE rollout,
always collect the dit-trajectory, keep the full trajectory.
- sd3_nft_sglang.yaml: now faithful; EMA shadow auto-derived from the backend.
Made with Cursor
Co-authored-by: Cursor <cursoragent@cursor.com>
988e5f9 to
8a701c6
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Roll out DiffusionNFT under its EMA-smoothed
oldshadow adapter on a SEPARATE rollout engine (SGLang / vLLM-omni), the wayrequires_ema_rolloutintends. The in-processapply_eval_emaswap only touches the train process, so a separate engine never saw the EMA weights.Reworked (rebased fresh on latest
main) into a single, transport-agnostic mechanism owned by the backend:unirl/train/backend/fsdp.py):FSDPBackend.rollout_adapter_namereturns the EMA shadow ("old") for adapter-EMA, else"default". The in-process eval-EMA swap and the cross-process weight sync both derive from it, so they cannot disagree.adapter_nameknob across both weight-sync families (weight_sync/full/*,weight_sync/lora/*): defaults toNone-> asks the backend; an explicit value overrides. Full-sync folds the shadow into the merged base (SGLang); LoRA-sync ships the shadow adapter directly (vLLM-omni). Also fixes a latent bug where_extract_canonical_loranever forwardedadapter_name, so the LoRA path always shippeddefault.unirl/train/inject.py):_activate_keep_gradre-asserts the trainable/frozenrequires_gradsplit across rollout/restore swaps, so the trainabledefaultis never all-gathered grad-less underreshard_after_forward=False.unirl/rollout/engine/sglang/*): DiffusionNFT'snum_sde_steps=0resolvessde_indicesto[]; request the deterministic ODE rollout, always collect the dit-trajectory, and keep the full trajectory.Related Issue
N/A
Test Plan
python -m py_compileon all touched files: passes.rollout_adapter_name(oldwithema_lora_cfg, elsedefault); both sync families resolveNone -> backend; full-sync fail-closed guard fires for non-default +lora_merged=False;_extract_canonical_loraforwardsadapter_name; SGLang empty-sde_indices-> ODE branch + full trajectory; EMA update math.examples/diffusion/sd3_nft.yaml(trainside trains past the first backward) andexamples/diffusion/sd3_nft_sglang.yaml(separate, rolls out under EMAold).Compatibility / Risk
adapter_nameto"default"(unchanged behavior).FullWeightSyncgains anadapter_namector arg (defaults toNone); the LoRA-syncadapter_namedefault changes from"default"toNone(resolves to"default"for non-EMA backends), so existing recipes are unaffected.backendto exposerollout_adapter_name;FSDPBackendis the only training backend, so all paths are covered.Reviewer Notes
main; this force-push supersedes the earlier full-sync-onlylora_adapter+ trainer_resolve_sync_lora_adapterapproach with the backend-sourced unified mechanism.RemoteLoraWeightSync(shipsold, no merge) orIPC/NCCLWeightSync + lora_merged: true(mergedoldbase).Checklist