[Feature] Add MLA draft model support for Eagle3 training #55
Merged: yubofredwang merged 2 commits into main (Apr 3, 2026)
Conversation
Force-pushed 6a1182b → 2319f08
cicirori commented on Mar 27, 2026
Force-pushed 88a6b53 → fee3ed3
Force-pushed df94bfc → fdbdd38
Add DeepSeek MLA attention for Eagle3 draft model training, supporting
both SDPA and flex_attention backends. This enables using MLA-based draft
models (DeepSeek-V2/V3 style) with any target model.
New files:
- torchspec/models/draft/deepseek_eagle.py: MLA attention, decoder layer,
and DeepSeekForCausalLMEagle3 draft model
- configs/draft_models/deepseek_v3_eagle3.json: DeepSeek-V3 draft config
- configs/draft_models/qwen3_8b_eagle3_mla.json: Qwen3-8B + MLA draft config
- configs/sglang_qwen3_8b_mla_draft{,_flex}.yaml: e2e training configs
- tests/test_deepseek_eagle.py: 12 unit tests (SDPA + flex, CPU + CUDA)
Also includes e2e training loop fixes for single-node runs.
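To make the MLA mechanism described above concrete, here is a minimal sketch of DeepSeek-style Multi-head Latent Attention with an SDPA backend. The class and parameter names (`SimpleMLAAttention`, `kv_lora_rank`, etc.) are illustrative assumptions, not the actual `DeepSeekMLAAttention` in `torchspec/models/draft/deepseek_eagle.py`; it only shows the characteristic down-project → RMSNorm → up-project KV path.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Minimal RMSNorm, inlined so the sketch runs on older PyTorch."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


class SimpleMLAAttention(nn.Module):
    """Illustrative MLA sketch (not the torchspec implementation):
    keys/values are reconstructed from a small shared latent."""

    def __init__(self, hidden_size=256, num_heads=4, kv_lora_rank=64):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        # Down-project hidden states into a compressed KV latent.
        self.kv_a_proj = nn.Linear(hidden_size, kv_lora_rank, bias=False)
        self.kv_a_norm = RMSNorm(kv_lora_rank)
        # Up-project the latent back into per-head keys and values.
        self.kv_b_proj = nn.Linear(kv_lora_rank, 2 * hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x):
        b, s, _ = x.shape
        q = self.q_proj(x).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        latent = self.kv_a_norm(self.kv_a_proj(x))        # (b, s, kv_lora_rank)
        kv = self.kv_b_proj(latent)
        k, v = torch.split(kv, kv.shape[-1] // 2, dim=-1)
        k = k.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(b, s, -1))
```

Swapping `F.scaled_dot_product_attention` for a `flex_attention` call with an equivalent causal mask is what the second backend amounts to.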
Add Kimi K25 MLA draft model config, register DeepSeekV3 in AutoEagle3DraftModel, and fix rope_scaling normalization to apply unconditionally in config generation.
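The rope_scaling normalization mentioned above could look like the following sketch. It is a guess at the intent (the PR does not show the code): HF configs have used both `type` and `rope_type` as the key, so normalizing unconditionally during config generation keeps downstream loaders consistent. The function name and key handling are assumptions.

```python
def normalize_rope_scaling(config: dict) -> dict:
    """Hypothetical sketch: normalize the rope_scaling dict unconditionally
    during draft-config generation, accepting either the legacy "type" key
    or the newer "rope_type" key and emitting both."""
    rope = config.get("rope_scaling")
    if rope is not None:
        rope_type = rope.get("rope_type", rope.get("type"))
        # Emit both spellings so any transformers version reads the same value.
        config["rope_scaling"] = {**rope, "rope_type": rope_type, "type": rope_type}
    return config
```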
Force-pushed fdbdd38 → 73943f1
Summary
convert_to_hf.py: lm_head pruning bug fix, draft_vocab_size validation, train-time pruning detection

Changes
- torchspec/models/draft/deepseek_eagle.py: DeepSeekMLAAttention, DeepSeekMLAFlexAttention, DeepSeekDecoderLayer, DeepSeekForCausalLMEagle3
- torchspec/models/draft/auto.py: register DeepseekV3Config → DeepSeekForCausalLMEagle3 dispatch
- configs/draft_models/qwen3_8b_eagle3_mla.json: Qwen3-8B target + MLA draft config (full vocab, no train-time pruning)
- configs/sglang_qwen3_8b_mla_draft.yaml: e2e training config (flex_attention)
- tests/test_deepseek_eagle.py: 12 unit tests (shapes, gradients, config dispatch, softmax scale, TTT loop, SDPA vs flex consistency)
- tools/convert_to_hf.py: lm_head pruning fix (arange + d2t → direct d2t indexing); reject draft_vocab_size != vocab_size without --prune-vocab (model needs t2d/d2t)

Test plan
- ./examples/qwen3-8b-single-node/run.sh configs/sglang_qwen3_8b_mla_draft.yaml: 500 steps on B200
- convert_to_hf.py without --prune-vocab: full-vocab model, from_pretrained OK
- convert_to_hf.py with --prune-vocab --draft-vocab-size 32000: pruned model with t2d/d2t, from_pretrained OK

E2E training result (500 steps, Qwen3-8B + MLA draft, flex_attention, B200)
HF conversion verification
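The lm_head pruning fix listed in the changes (arange + d2t → direct d2t indexing) can be sketched with a toy vocab mapping. This assumes the post-fix convention where d2t stores absolute target-vocab indices per draft id; under the offset convention (target_id = draft_id + d2t[draft_id]) one would index with `torch.arange(len(d2t)) + d2t` instead. The helper name is hypothetical.

```python
import torch


def prune_lm_head(weight: torch.Tensor, d2t: torch.Tensor) -> torch.Tensor:
    """Select the target-vocab rows of lm_head kept in the draft vocab.

    Assumes d2t holds absolute target-vocab row indices per draft id
    (the "direct d2t indexing" convention); a hypothetical sketch,
    not the actual convert_to_hf.py code.
    """
    return weight[d2t]


# Toy example: a 6-token target vocab pruned to a 3-token draft vocab.
full = torch.arange(12.0).reshape(6, 2)  # (target_vocab_size, hidden)
d2t = torch.tensor([0, 2, 5])            # draft id -> target id
pruned = prune_lm_head(full, d2t)        # keeps rows 0, 2, and 5
```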
FSDP2 compatibility note
MLA's bottleneck structure (down_proj → RMSNorm → up_proj) is compatible with TorchSpec's per-Linear FSDP2 sharding. FSDP2 shards parameters only: weights are all-gathered before forward, so activations (including the compressed latent and the torch.split on the kv_a_proj output) are always full-size during computation. RMSNorm layers are not nn.Linear and remain replicated. This is the same sharding pattern used by the existing Llama attention, verified end-to-end on B200 (500 steps; loss converges to 1.274).
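The per-Linear wrapping policy described above can be illustrated without a distributed setup by listing which submodules of an MLA-style bottleneck would each receive their own fully_shard group. This is a sketch of the policy only, not the actual TorchSpec wrapping code; in real training each listed module would be passed to FSDP2's fully_shard, and LayerNorm stands in here for RMSNorm (the point is just that it is not nn.Linear, so it stays replicated).

```python
import torch.nn as nn


def per_linear_shard_targets(model: nn.Module) -> list:
    """Illustrative sketch: under a per-Linear FSDP2 policy, every
    nn.Linear gets its own fully_shard group; everything else
    (norms here) keeps replicated parameters."""
    return [name for name, m in model.named_modules() if isinstance(m, nn.Linear)]


# MLA-style bottleneck: down_proj -> norm -> up_proj.
bottleneck = nn.Sequential(
    nn.Linear(64, 16, bias=False),  # kv_a_proj-style down projection (sharded)
    nn.LayerNorm(16),               # stand-in for RMSNorm: not nn.Linear, replicated
    nn.Linear(16, 64, bias=False),  # kv_b_proj-style up projection (sharded)
)
targets = per_linear_shard_targets(bottleneck)  # the norm is absent from this list
```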