Skip to content

Commit 88ca60b

Browse files
[Feature] Add MLA draft model support for Eagle3 training (#55)
Co-authored-by: Yubo Wang <yubowang2019@gmail.com>
1 parent 06a9760 commit 88ca60b

File tree

11 files changed

+1200
-11
lines changed

11 files changed

+1200
-11
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
{
2+
"architectures": [
3+
"Eagle3DeepseekV2ForCausalLM"
4+
],
5+
"model_type": "kimi_k2",
6+
"hidden_size": 7168,
7+
"intermediate_size": 18432,
8+
"num_hidden_layers": 1,
9+
"num_attention_heads": 64,
10+
"num_key_value_heads": 64,
11+
"q_lora_rank": 1536,
12+
"kv_lora_rank": 512,
13+
"qk_nope_head_dim": 128,
14+
"qk_rope_head_dim": 64,
15+
"v_head_dim": 128,
16+
"hidden_act": "silu",
17+
"rms_norm_eps": 1e-05,
18+
"vocab_size": 163840,
19+
"draft_vocab_size": 163840,
20+
"torch_dtype": "bfloat16",
21+
"rope_theta": 50000.0,
22+
"rope_scaling": {
23+
"beta_fast": 1.0,
24+
"beta_slow": 1.0,
25+
"factor": 64.0,
26+
"mscale": 1.0,
27+
"mscale_all_dim": 1.0,
28+
"original_max_position_embeddings": 4096,
29+
"type": "yarn"
30+
},
31+
"eagle_config": {
32+
"eagle_aux_hidden_state_layer_ids": [1, 29, 57],
33+
"use_aux_hidden_state": true,
34+
"use_input_layernorm_in_first_layer": true,
35+
"use_last_layernorm": true,
36+
"use_mtp_layernorm": false
37+
},
38+
"bos_token_id": 163584,
39+
"eos_token_id": 163585,
40+
"pad_token_id": 0
41+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
{
2+
"architectures": [
3+
"DeepSeekForCausalLMEagle3"
4+
],
5+
"model_type": "deepseek_v3",
6+
"hidden_size": 4096,
7+
"num_attention_heads": 32,
8+
"num_key_value_heads": 32,
9+
"num_hidden_layers": 1,
10+
"intermediate_size": 12288,
11+
"hidden_act": "silu",
12+
"rms_norm_eps": 1e-06,
13+
"q_lora_rank": 1536,
14+
"kv_lora_rank": 512,
15+
"qk_nope_head_dim": 128,
16+
"qk_rope_head_dim": 64,
17+
"v_head_dim": 128,
18+
"max_position_embeddings": 262144,
19+
"rope_theta": 1000000,
20+
"rope_scaling": {
21+
"type": "yarn",
22+
"factor": 64.0,
23+
"original_max_position_embeddings": 4096,
24+
"beta_fast": 1.0,
25+
"beta_slow": 1.0,
26+
"mscale": 1.0,
27+
"mscale_all_dim": 1.0
28+
},
29+
"vocab_size": 151936,
30+
"tie_word_embeddings": false,
31+
"pretraining_tp": 1,
32+
"bos_token_id": 151643,
33+
"eos_token_id": 151645,
34+
"pad_token_id": 0
35+
}

configs/sglang_qwen3_8b.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ model:
1515

1616
dataset:
1717
train_data_path: ../examples/data/sample_conversations.jsonl
18-
eval_data_path: ../examples/data/eval_conversations.jsonl
19-
eval_interval: 100
18+
# eval_data_path: ../examples/data/eval_conversations.jsonl
19+
# eval_interval: 100
2020
chat_template: qwen
2121
prompt_key: conversations
2222

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Configuration for Qwen3-8B target model with MLA draft model
2+
#
3+
# Same as sglang_qwen3_8b.yaml but uses DeepSeek MLA attention in the draft model.
4+
#
5+
# Usage:
6+
# python -m torchspec.train_entry --config configs/sglang_qwen3_8b_mla_draft.yaml
7+
# ./examples/qwen3-8b-single-node/run.sh configs/sglang_qwen3_8b_mla_draft.yaml
8+
9+
model:
10+
target_model_path: Qwen/Qwen3-8B
11+
draft_model_config: configs/draft_models/qwen3_8b_eagle3_mla.json
12+
trust_remote_code: true
13+
14+
dataset:
15+
train_data_path: ../examples/data/sample_conversations.jsonl
16+
chat_template: qwen
17+
prompt_key: conversations
18+
19+
training:
20+
attention_backend: flex_attention
21+
micro_batch_size: 1
22+
draft_accumulation_steps: 1
23+
learning_rate: 1e-4
24+
max_concurrent_batches: 1
25+
max_grad_norm: 0.5
26+
max_seq_length: 16384
27+
num_epochs: 1
28+
seed: 42
29+
training_num_gpus_per_node: 2
30+
training_num_nodes: 1
31+
ttt_length: 7
32+
save_per_epoch: true
33+
warmup_ratio: 0.015
34+
35+
inference:
36+
inference_engine_type: sgl
37+
inference_num_gpus: 1
38+
inference_num_gpus_per_engine: 1
39+
inference_num_gpus_per_node: 1
40+
max_sample_pool_size: 64
41+
inference_buffer_threshold: 32
42+
inference_batch_size: 8
43+
sglang:
44+
tp_size: 1
45+
mem_fraction_static: 0.7
46+
47+
mooncake:
48+
master_server_address: null
49+
metadata_server: null
50+
protocol: tcp
51+
global_segment_size: 16GB
52+
local_buffer_size: 4GB
53+
54+
output_dir: ./outputs/qwen3-8b-mla-draft
55+
cache_dir: ./cache/qwen3-8b-mla-draft
56+
model_download_dir: null
57+
58+
debug:
59+
save_debug_train_data: null
60+
debug_train_only: false
61+
debug_inference_only: false

examples/qwen3-8b-single-node/run.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,14 @@ echo "Local IP: $LOCAL_IP"
5656
echo "Extra args: $*"
5757
echo "=============================================="
5858

59+
# TODO: unify tp_size config across sglang/vllm backends
5960
python3 -m torchspec.train_entry \
6061
--config "$CONFIG_FILE" \
6162
training.training_num_gpus_per_node="$TRAIN_GPUS" \
6263
inference.inference_num_gpus="$INFERENCE_GPUS" \
6364
inference.inference_num_gpus_per_engine=2 \
6465
inference.inference_num_gpus_per_node="$TOTAL_GPUS" \
66+
inference.sglang.tp_size=2 \
6567
"$@"
6668

6769
echo "=============================================="

0 commit comments

Comments
 (0)