# Configuration for Qwen3-8B target model with MLA draft model
#
# Same as sglang_qwen3_8b.yaml but uses DeepSeek MLA attention in the draft model.
#
# Usage:
#   python -m torchspec.train_entry --config configs/sglang_qwen3_8b_mla_draft.yaml
#   ./examples/qwen3-8b-single-node/run.sh configs/sglang_qwen3_8b_mla_draft.yaml

model:
  target_model_path: Qwen/Qwen3-8B
  draft_model_config: configs/draft_models/qwen3_8b_eagle3_mla.json
  trust_remote_code: true

dataset:
  train_data_path: ../examples/data/sample_conversations.jsonl
  chat_template: qwen
  prompt_key: conversations

training:
  attention_backend: flex_attention
  micro_batch_size: 1
  draft_accumulation_steps: 1
  learning_rate: 1e-4
  max_concurrent_batches: 1
  max_grad_norm: 0.5
  max_seq_length: 16384
  num_epochs: 1
  seed: 42
  training_num_gpus_per_node: 2
  training_num_nodes: 1
  ttt_length: 7
  save_per_epoch: true
  warmup_ratio: 0.015

inference:
  inference_engine_type: sgl
  inference_num_gpus: 1
  inference_num_gpus_per_engine: 1
  inference_num_gpus_per_node: 1
  max_sample_pool_size: 64
  inference_buffer_threshold: 32
  inference_batch_size: 8
  sglang:
    tp_size: 1
    mem_fraction_static: 0.7

mooncake:
  master_server_address: null
  metadata_server: null
  protocol: tcp
  global_segment_size: 16GB
  local_buffer_size: 4GB

output_dir: ./outputs/qwen3-8b-mla-draft
cache_dir: ./cache/qwen3-8b-mla-draft
model_download_dir: null

debug:
  save_debug_train_data: null
  debug_train_only: false
  debug_inference_only: false