
[Feature] Add MLA draft model support for Eagle3 training#55

Merged
yubofredwang merged 2 commits into main from feature/mla-draft-model
Apr 3, 2026

Conversation


@cicirori commented Mar 27, 2026

Summary

  • Add DeepSeek MLA (Multi-head Latent Attention) draft model for Eagle3 online training
  • Support both SDPA and flex_attention backends
  • Enable using MLA-based draft models (DeepSeek-V2/V3 style) with any target model
  • Reference config based on nvidia/Kimi-K2.5-Thinking-Eagle3
  • Fix convert_to_hf.py: lm_head pruning bug, draft_vocab_size validation, train-time pruning detection

Changes

  • torchspec/models/draft/deepseek_eagle.py — DeepSeekMLAAttention, DeepSeekMLAFlexAttention, DeepSeekDecoderLayer, DeepSeekForCausalLMEagle3
  • torchspec/models/draft/auto.py — Register DeepseekV3Config → DeepSeekForCausalLMEagle3 dispatch
  • configs/draft_models/qwen3_8b_eagle3_mla.json — Qwen3-8B target + MLA draft config (full vocab, no train-time pruning)
  • configs/sglang_qwen3_8b_mla_draft.yaml — E2E training config (flex_attention)
  • tests/test_deepseek_eagle.py — 12 unit tests (shapes, gradients, config dispatch, softmax scale, TTT loop, SDPA vs flex consistency)
  • tools/convert_to_hf.py:
    • Fix lm_head trimming bug (arange + d2t → direct d2t indexing)
    • Error when draft_vocab_size != vocab_size without --prune-vocab (model needs t2d/d2t)
    • Error when lm_head already pruned (train-time pruning incompatible with post-training re-prune)
  • E2E training loop fixes for single-node runs
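
The lm_head trimming fix amounts to treating d2t as absolute row indices. A minimal sketch under that assumption (function name is illustrative):

```python
import torch

def prune_lm_head(lm_head_weight: torch.Tensor, d2t: torch.Tensor) -> torch.Tensor:
    # Direct d2t indexing: d2t is assumed to hold one absolute
    # target-vocab row index per draft token, so pruning is a plain
    # row gather -- no arange offset added (the bug this PR fixes).
    # lm_head_weight: [target_vocab, hidden]; d2t: [draft_vocab]
    return lm_head_weight[d2t]

full = torch.randn(1000, 16)
d2t = torch.randperm(1000)[:100]
pruned = prune_lm_head(full, d2t)
assert pruned.shape == (100, 16)
assert torch.equal(pruned[0], full[d2t[0]])
```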

Test plan

  • Unit tests: 12/12 passed (CPU + CUDA, SDPA + flex_attention)
  • E2E: ./examples/qwen3-8b-single-node/run.sh configs/sglang_qwen3_8b_mla_draft.yaml — 500 steps on B200
  • convert_to_hf.py without --prune-vocab → full vocab model, from_pretrained OK
  • convert_to_hf.py with --prune-vocab --draft-vocab-size 32000 → pruned model with t2d/d2t, from_pretrained OK
  • pre-commit (ruff + ruff-format + isort) all pass

E2E training result (500 steps, Qwen3-8B + MLA draft, flex_attention, B200)

$ ./examples/qwen3-8b-single-node/run.sh configs/sglang_qwen3_8b_mla_draft.yaml

Training: 100% | 500/500 [01:45<00:00, 4.75step/s, loss=1.274, acc=0.865, acc_len=4.15, thru=6.5, I=145.7, T=12.6, wait=0.0s, pool=0, epoch=1/1]

HF conversion verification

# Without --prune-vocab (full vocab)
$ python tools/convert_to_hf.py --input-dir outputs/.../iter_0000501 --config configs/draft_models/qwen3_8b_eagle3_mla.json -f
vocab=151936/151936, lm_head=[151936, 4096], has_vocab_pruning=False, params=1,493,714,944

# With --prune-vocab (post-training pruning)
$ python tools/convert_to_hf.py ... --prune-vocab --draft-vocab-size 32000 --dataset-path examples/data/sample_conversations.jsonl --tokenizer Qwen/Qwen3-8B --chat-template qwen --prompt-key conversations
vocab=32000/151936, lm_head=[32000, 4096], has_vocab_pruning=True, params=1,002,457,088

# Both load correctly via from_pretrained
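
The pruned conversion ships t2d/d2t mappings. A hypothetical sketch of how such maps could be derived from per-token frequency counts over a dataset (the tool's actual selection logic may differ):

```python
import torch

def build_vocab_maps(token_counts: torch.Tensor, draft_vocab_size: int):
    # Keep the draft_vocab_size most frequent target tokens.
    # d2t[i] -> absolute target id of draft token i (ascending order);
    # t2d[j] -> True iff target token j survives pruning.
    keep = torch.topk(token_counts, draft_vocab_size).indices.sort().values
    t2d = torch.zeros(token_counts.numel(), dtype=torch.bool)
    t2d[keep] = True
    return keep, t2d

counts = torch.randint(0, 1000, (151_936,))
d2t, t2d = build_vocab_maps(counts, 32_000)
assert d2t.numel() == 32_000
assert int(t2d.sum()) == 32_000
```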

FSDP2 compatibility note

MLA's bottleneck structure (down_proj → RMSNorm → up_proj) is compatible with TorchSpec's per-Linear FSDP2 sharding. FSDP2 shards parameters only — weights are all-gathered before forward, so activations (including the compressed latent and torch.split on kv_a_proj output) are always full-size during computation. RMSNorm layers are not nn.Linear and remain replicated. This is the same sharding pattern used by the existing Llama attention, verified end-to-end on B200 (500 steps, loss converges to 1.274).

@cicirori force-pushed the feature/mla-draft-model branch 3 times, most recently from 6a1182b to 2319f08 on March 27, 2026 19:20
@cicirori force-pushed the feature/mla-draft-model branch 4 times, most recently from 88a6b53 to fee3ed3 on March 27, 2026 19:33
@cicirori requested a review from yubofredwang on March 27, 2026 19:35
@cicirori force-pushed the feature/mla-draft-model branch 6 times, most recently from df94bfc to fdbdd38 on March 27, 2026 19:59
cicirori and others added 2 commits April 3, 2026 06:34
Add DeepSeek MLA attention for Eagle3 draft model training, supporting
both SDPA and flex_attention backends. This enables using MLA-based draft
models (DeepSeek-V2/V3 style) with any target model.

New files:
- torchspec/models/draft/deepseek_eagle.py: MLA attention, decoder layer,
  and DeepSeekForCausalLMEagle3 draft model
- configs/draft_models/deepseek_v3_eagle3.json: DeepSeek-V3 draft config
- configs/draft_models/qwen3_8b_eagle3_mla.json: Qwen3-8B + MLA draft config
- configs/sglang_qwen3_8b_mla_draft{,_flex}.yaml: e2e training configs
- tests/test_deepseek_eagle.py: 12 unit tests (SDPA + flex, CPU + CUDA)

Also includes e2e training loop fixes for single-node runs.
…rmalization

Add Kimi K25 MLA draft model config, register DeepSeekV3 in AutoEagle3DraftModel,
and fix rope_scaling normalization to apply unconditionally in config generation.
@yubofredwang force-pushed the feature/mla-draft-model branch from fdbdd38 to 73943f1 on April 3, 2026 06:36
@yubofredwang merged commit 88ca60b into main on Apr 3, 2026
1 check passed
@yubofredwang deleted the feature/mla-draft-model branch on April 3, 2026 06:41