Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions configs/draft_models/kimi_k25_eagle3_mla.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{
"architectures": [
"Eagle3DeepseekV2ForCausalLM"
],
"model_type": "kimi_k2",
"hidden_size": 7168,
"intermediate_size": 18432,
"num_hidden_layers": 1,
"num_attention_heads": 64,
"num_key_value_heads": 64,
"q_lora_rank": 1536,
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"hidden_act": "silu",
"rms_norm_eps": 1e-05,
"vocab_size": 163840,
"draft_vocab_size": 163840,
"torch_dtype": "bfloat16",
"rope_theta": 50000.0,
"rope_scaling": {
"beta_fast": 1.0,
"beta_slow": 1.0,
"factor": 64.0,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn"
},
"eagle_config": {
"eagle_aux_hidden_state_layer_ids": [1, 29, 57],
"use_aux_hidden_state": true,
"use_input_layernorm_in_first_layer": true,
"use_last_layernorm": true,
"use_mtp_layernorm": false
},
"bos_token_id": 163584,
"eos_token_id": 163585,
"pad_token_id": 0
}
35 changes: 35 additions & 0 deletions configs/draft_models/qwen3_8b_eagle3_mla.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
{
"architectures": [
"DeepSeekForCausalLMEagle3"
],
"model_type": "deepseek_v3",
"hidden_size": 4096,
"num_attention_heads": 32,
"num_key_value_heads": 32,
"num_hidden_layers": 1,
"intermediate_size": 12288,
"hidden_act": "silu",
"rms_norm_eps": 1e-06,
"q_lora_rank": 1536,
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"max_position_embeddings": 262144,
"rope_theta": 1000000,
"rope_scaling": {
"type": "yarn",
"factor": 64.0,
"original_max_position_embeddings": 4096,
"beta_fast": 1.0,
"beta_slow": 1.0,
"mscale": 1.0,
"mscale_all_dim": 1.0
},
"vocab_size": 151936,
"tie_word_embeddings": false,
"pretraining_tp": 1,
"bos_token_id": 151643,
"eos_token_id": 151645,
"pad_token_id": 0
}
4 changes: 2 additions & 2 deletions configs/sglang_qwen3_8b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ model:

dataset:
train_data_path: ../examples/data/sample_conversations.jsonl
eval_data_path: ../examples/data/eval_conversations.jsonl
eval_interval: 100
# eval_data_path: ../examples/data/eval_conversations.jsonl
# eval_interval: 100
chat_template: qwen
prompt_key: conversations

Expand Down
61 changes: 61 additions & 0 deletions configs/sglang_qwen3_8b_mla_draft.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Configuration for Qwen3-8B target model with MLA draft model
#
# Same as sglang_qwen3_8b.yaml but uses DeepSeek MLA attention in the draft model.
#
# Usage:
# python -m torchspec.train_entry --config configs/sglang_qwen3_8b_mla_draft.yaml
# ./examples/qwen3-8b-single-node/run.sh configs/sglang_qwen3_8b_mla_draft.yaml

model:
  target_model_path: Qwen/Qwen3-8B
  draft_model_config: configs/draft_models/qwen3_8b_eagle3_mla.json
  trust_remote_code: true

dataset:
  train_data_path: ../examples/data/sample_conversations.jsonl
  chat_template: qwen
  prompt_key: conversations

training:
  attention_backend: flex_attention
  micro_batch_size: 1
  draft_accumulation_steps: 1
  learning_rate: 1e-4
  max_concurrent_batches: 1
  max_grad_norm: 0.5
  max_seq_length: 16384
  num_epochs: 1
  seed: 42
  training_num_gpus_per_node: 2
  training_num_nodes: 1
  ttt_length: 7
  save_per_epoch: true
  warmup_ratio: 0.015

inference:
  inference_engine_type: sgl
  inference_num_gpus: 1
  inference_num_gpus_per_engine: 1
  inference_num_gpus_per_node: 1
  max_sample_pool_size: 64
  inference_buffer_threshold: 32
  inference_batch_size: 8
  sglang:
    tp_size: 1
    mem_fraction_static: 0.7

mooncake:
  master_server_address: null
  metadata_server: null
  protocol: tcp
  global_segment_size: 16GB
  local_buffer_size: 4GB

output_dir: ./outputs/qwen3-8b-mla-draft
cache_dir: ./cache/qwen3-8b-mla-draft
model_download_dir: null

debug:
  save_debug_train_data: null
  debug_train_only: false
  debug_inference_only: false
2 changes: 2 additions & 0 deletions examples/qwen3-8b-single-node/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,14 @@ echo "Local IP: $LOCAL_IP"
echo "Extra args: $*"
echo "=============================================="

# TODO: unify tp_size config across sglang/vllm backends
python3 -m torchspec.train_entry \
--config "$CONFIG_FILE" \
training.training_num_gpus_per_node="$TRAIN_GPUS" \
inference.inference_num_gpus="$INFERENCE_GPUS" \
inference.inference_num_gpus_per_engine=2 \
inference.inference_num_gpus_per_node="$TOTAL_GPUS" \
inference.sglang.tp_size=2 \
"$@"

echo "=============================================="
Expand Down
Loading
Loading