Closed
121 commits
367846d
Add DFlash training implementation
Mar 19, 2026
f29ded5
Add torch dynamo config fallback for flex attention
Mar 19, 2026
fdf2822
Add DFlash framework integration: config dispatch, N-layer target mod…
Mar 19, 2026
25dc9ac
Add DFlash training quality tests, architecture validation, and GPU c…
Mar 19, 2026
b28f06c
Add 1-GPU training configs and RunPod training script
Mar 19, 2026
fb8a4e5
Fix non-interactive WandB prompt and add SSH helper
Mar 19, 2026
506298e
Use --system-site-packages venv to avoid re-downloading PyTorch
Mar 19, 2026
7e5958b
Make SGLang and vLLM imports optional in engine __init__
Mar 19, 2026
ac4f3ad
Lazy-import SGLang and vLLM engines in factory.py
Mar 19, 2026
0b9f631
Remove eval_data_path from 1-GPU configs to avoid colocate timeout
Mar 19, 2026
d2836ed
Fix DFlash config for Qwen3-8B and add inductor GEMM fallback
Mar 19, 2026
447705a
Fix inductor GEMM backend and DFlash dtype mismatch
Mar 19, 2026
2a03302
Use aot_eager backend for flex_attention compilation
Mar 19, 2026
5033d17
Cast DFlash draft model to bfloat16 before FSDP wrapping
Mar 19, 2026
d427b4d
Debug DFlash zero-loss bug: vectorize noise input, add anchor fallback
Mar 19, 2026
340837f
Track dflash_implementation_log.md in git
Mar 19, 2026
d127c38
[Feature] add Minimax support and Kimi parser updates (#45)
yubofredwang Mar 19, 2026
8a33620
Fix stale test assertion: update expected config values for Qwen3-8B
Mar 19, 2026
c59097e
Fix DFlash metric key names to match controller expectations
Mar 19, 2026
fbdd8ff
Update implementation log: Session 6 — zero-loss bug fix and GPU results
Mar 19, 2026
7ca08bc
Add DFlash inference benchmark script
Mar 19, 2026
dc75fcc
DFlash inference: checkpoint extract, 1-GPU bench config, benchmark s…
Mar 19, 2026
e4603bc
Fix 5 critical DFlash bugs: 5-layer model, Q/K-norm, block_keep_mask,…
Mar 19, 2026
18a1519
[Docs] Add 5-node MiniMax example files (#47)
yubofredwang Mar 19, 2026
c796ce8
[Docs] refresh README structure and links (#46)
yubofredwang Mar 19, 2026
004b6c4
[Docs] Add PyTorch Blog Link (#48)
yubofredwang Mar 19, 2026
947d3f5
Fix build_target_layer_ids to match SpecForge formula exactly
Mar 19, 2026
345598c
Phase B validation: fix training hyperparams and document 4-GPU results
Mar 19, 2026
636d4ef
Add min_loss_tokens filter to preprocess_conversations
Mar 19, 2026
d63001b
Add Phase C scripts: PerfectBlend data prep and full training launcher
Mar 19, 2026
07356c1
Fix collator crash when loss_mask length differs from input_ids
Mar 20, 2026
10119ac
Switch FlexAttention from aot_eager to inductor backend for 2-5x speedup
Mar 20, 2026
5ea37da
Document aot_eager vs inductor analysis and 3x speedup results
Mar 20, 2026
5e50e38
Fix minimax type mismatch (#49)
yubofredwang Mar 20, 2026
a85dde3
docker: add RunPod Dockerfile for sglang v0.5.8.post1
Mar 20, 2026
407bb7a
Add checkpoint rotation to prevent disk quota exceeded during training
Mar 20, 2026
5a0fb00
Document Phase C crash debugging: disk quota, eval timeout, checkpoin…
Mar 20, 2026
fe48866
Document training speed optimization: 6.6hr → 3.1hr with batch/anchor…
Mar 20, 2026
e3dfd80
Document checkpoint backup, resume instructions, and full attempt sum…
Mar 20, 2026
89de88e
Document complete pod restart recovery procedure for DFlash training …
Mar 20, 2026
c1b9548
Consolidate DFlash implementation log: 1697 → 588 lines
Mar 21, 2026
c42448d
Document Session 10: PyTorch 2.9.1 migration and Phase C resume from …
Mar 21, 2026
83c455b
Optimize DFlash training speed: remove compile overhead and enable GQA
Mar 21, 2026
04fc7da
Document Session 11: torch 2.9.1 speed regression investigation
Mar 21, 2026
9bac294
Document DFlash benchmark results (τ=1.86) and SpecForge comparison
Mar 21, 2026
19532a5
Merge branch 'torchspec-project:main' into feature/dflash-training
zhubohao911 Mar 21, 2026
0f7c860
Clean up DFlash branch: remove dead eval code, add comments to shared…
Mar 21, 2026
17eaecc
Fix DFlash training bugs and add missing test coverage
Mar 22, 2026
7dc18cc
dflash test plan and implementation log updates
Mar 22, 2026
15fb371
remove old implementation log
Mar 22, 2026
3f9184a
Update DFlash docs: Phase 1 complete, Issue 26 resolved, Phase D results
Mar 22, 2026
c4f6102
Document Phase 2 speed benchmark results: pipeline bottleneck identified
Mar 22, 2026
dde8b21
Update speed benchmark with instrumented timing breakdown
Mar 22, 2026
db9c556
Add compute sub-breakdown profiling instrumentation
Mar 22, 2026
407bb44
Document compute sub-breakdown profiling results
Mar 22, 2026
5f5a7cc
Document speed optimization test results (Section 2.5)
Mar 22, 2026
6581975
Add three compute speed optimizations
Mar 22, 2026
48867f3
Document compute optimization results: no_sync +8%, compile not viable
Mar 22, 2026
d941fab
Document Phase F speed optimization session results
Mar 22, 2026
89820a1
Document FSDP strategy and GPU scaling benchmark results (Tests 1-3)
Mar 22, 2026
f60f172
Increase find_free_port timeout from 30s to 120s for PyTorch 2.9+ com…
Mar 22, 2026
55d5558
Document Test 4 results: 2-inference GPU is strictly worse (1.1 step/s)
Mar 22, 2026
4a8616e
Document hardware topology analysis: NVLink unused, Mooncake TCP is b…
Mar 22, 2026
35ffb18
Add async data pre-fetch to overlap Mooncake transfer with GPU compute
Mar 22, 2026
2e8710a
Fix PrefetchedDataFetcher: use persistent thread across training steps
Mar 22, 2026
5a946cb
Fix prefetch GPU contention: stage data on CPU, move to GPU synchrono…
Mar 22, 2026
d71d71d
Document Test 5a/5b/6 results: CPU prefetch 6.8 step/s, NVLink attempt
Mar 22, 2026
39e20ea
Document Test 6 NVLink investigation (not viable) and Test 7 (3 GPU +…
Mar 22, 2026
3fe01a2
Update config to best training configuration and add Phase 3 τ benchm…
Mar 22, 2026
3b3dc5e
Update Phase 3 plan: incremental approach, train 5k steps first
Mar 22, 2026
5c73a1e
Add SpecForge DFlash training reference doc
Mar 22, 2026
a1cab4e
Document Phase 3 status, environment issues, and known workarounds
Mar 22, 2026
822d8e9
Phase 3: 2-epoch training with resume, prefetch_depth=8, auto GEMM fix
Mar 22, 2026
0d2ab60
Phase 3: switch to batch=2 accum=2 seq=2048 for 5.3 step/s throughput
Mar 22, 2026
5f02478
Debug: add bounds check for noise_ids in _create_noise_embed
Mar 22, 2026
78d8938
Fix: clone mooncake tensors before buffer cleanup to prevent use-afte…
Mar 22, 2026
eee9d52
Fix FlexAttention recompile overflow: bucketed padding + higher recom…
Mar 22, 2026
c355f86
Switch to batch=1 accum=4 for better Mooncake TCP overlap
Mar 22, 2026
8c4f2b1
Fix 5x speed regression: skip clone for batch=1 to preserve pinned me…
Mar 22, 2026
1247cf6
Add RunPod setup script consolidating all env fixes from issues/commits
Mar 22, 2026
0a45cd4
Update Dockerfile and setup script for faster pod provisioning
Mar 22, 2026
c5f9588
Add slim/full Docker image variants via INCLUDE_MODEL build arg
Mar 22, 2026
6ca24e9
Rewrite RunPod guide: Docker image + script-based setup
Mar 22, 2026
db9ad51
Add GitHub Actions workflow for Docker image build
Mar 22, 2026
276948a
Fix Docker build: remove [fa] extra (flash-attention-cute metadata mi…
Mar 22, 2026
41f31bf
Expand benchmark to 50 diverse prompts with τ distribution analysis
Mar 23, 2026
c7995ed
Document Phase 3 training results: 2-epoch run, τ=1.78 benchmark
Mar 23, 2026
0b55cbc
Extend training to 3 epochs, keep 2 checkpoints for comparison
Mar 23, 2026
7f7efc2
Add --prompt_file argument to benchmark script
Mar 23, 2026
9ffe525
Update RunPod guide: Docker image, tmux, monitoring instructions
Mar 23, 2026
df045bc
Update training results: epoch 3 confirms data saturation at τ=1.85
Mar 23, 2026
43bae6b
Add general-purpose DFlash training guide
Mar 23, 2026
e8b5e1d
Add Modal GPU platform support and speed tuning results for DFlash tr…
xinghandd Mar 24, 2026
764bd08
Add recommended parameter settings to modal_dflash_train.py docstring
xinghandd Mar 24, 2026
23e497c
Split Modal training into separate SGLang and HF scripts to avoid idl…
xinghandd Mar 24, 2026
b48adfe
Add anchors=512 speed tuning results and update recommended config
xinghandd Mar 24, 2026
5094f5e
Add 512-E/F (4+4 GPU split) results and fix GPU count display
xinghandd Mar 24, 2026
44a4f3d
Fix checkpoint saving to persist on Modal volume
xinghandd Mar 24, 2026
89e8ff6
Add post-training HF conversion and upload to Modal training script
xinghandd Mar 24, 2026
5c1c77e
Add Modal DFlash inference benchmark script
xinghandd Mar 24, 2026
3156a1a
fix: HF repo visibility and missing WandB secret in training script
xinghandd Mar 25, 2026
e00b5f3
refactor: replace hardcoded prompts with z-lab standard benchmarks
xinghandd Mar 25, 2026
9ddb506
feat: add SGLang-backend DFlash benchmark with compat fixes
xinghandd Mar 25, 2026
dc90f9d
add: HF config fix utility and 3-epoch training metrics plot
xinghandd Mar 25, 2026
036d10e
docs: add full z-lab benchmark results (Phase G) and switch to 1x H100
xinghandd Mar 25, 2026
d047c95
fix: correct training data size from 47K to 190K in benchmark results
xinghandd Mar 25, 2026
c375d72
Add min_lr + weight_decay for better convergence, prep 800K training
xinghandd Mar 25, 2026
96f0f5b
Merge branch 'feature/dd-test' into feature/dflash-training
Mar 26, 2026
2b0fe91
Phase H benchmark results, switch to SGLang inference, remove hardcod…
xinghandd Mar 27, 2026
dd4bcb0
Support HF Hub datasets in Arrow/Parquet format and drop unused columns
xinghandd Mar 27, 2026
f3005e7
docs: add decode-only speedup analysis and HF dataset patch for training
xinghandd Mar 28, 2026
c927567
fix: 3 training bugs + 3 perf optimizations from pipeline deep-dive
xinghandd Mar 28, 2026
3e15c11
Add --resume flag for checkpoint resumption and tune save settings
xinghandd Mar 29, 2026
7b2bd5e
docs: add Phase J UltraChat results and training code audit
xinghandd Mar 29, 2026
848b541
Reorganize scripts/ and consolidate dflash docs
xinghandd Mar 29, 2026
7ce2637
restructure dflash docs folder
xinghandd Mar 29, 2026
79e0f64
Add hyperparameter sweep infra and Phase K convergence results
xinghandd Mar 30, 2026
b2bf4ad
fix: isolated run_id per sweep run + resume support for multi-session…
xinghandd Apr 1, 2026
77f992d
docs: Phase K P2-accum1 benchmark results with τ distribution analysis
xinghandd Apr 1, 2026
6023d9b
docs: P2-WSD benchmark results — best model, 2.7% gap to z-lab
xinghandd Apr 2, 2026
fda638b
fix: graceful fallback for empty loss_mask in DFlash anchor sampling
xinghandd Apr 2, 2026
68 changes: 68 additions & 0 deletions .github/workflows/docker-build.yml
@@ -0,0 +1,68 @@
name: Build Docker Image

on:
workflow_dispatch:
inputs:
include_model:
description: 'Include Qwen3-8B model in image (0=slim ~4GB, 1=full ~20GB)'
required: true
default: '0'
type: choice
options:
- '0'
- '1'
tag:
description: 'Image tag (default: slim or latest based on include_model)'
required: false
type: string

jobs:
build:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write

steps:
- name: Checkout
uses: actions/checkout@v4

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

- name: Login to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: Determine tag
id: tag
run: |
if [ -n "${{ inputs.tag }}" ]; then
echo "tag=${{ inputs.tag }}" >> $GITHUB_OUTPUT
elif [ "${{ inputs.include_model }}" = "1" ]; then
echo "tag=latest" >> $GITHUB_OUTPUT
else
echo "tag=slim" >> $GITHUB_OUTPUT
fi

- name: Free up disk space
run: |
# GitHub runners have ~14GB free, we need more for the build
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc
sudo apt-get clean
df -h /

- name: Build and push
uses: docker/build-push-action@v5
with:
context: .
file: docker/sglang/v0.5.8.post1/Dockerfile.runpod
push: true
tags: ghcr.io/${{ github.repository_owner }}/torchspec-dflash:${{ steps.tag.outputs.tag }}
build-args: |
INCLUDE_MODEL=${{ inputs.include_model }}
cache-from: type=gha
cache-to: type=gha,mode=max
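The `Determine tag` step above maps the two workflow inputs to an image tag. A minimal Python sketch of the same decision, useful for predicting which `ghcr.io` tag a dispatch will push (hypothetical helper, not part of the repo):

```python
# Sketch of the workflow's "Determine tag" step: an explicit `tag` input wins,
# otherwise include_model=1 -> "latest" (full image) and include_model=0 -> "slim".
def determine_tag(tag_input: str, include_model: str) -> str:
    if tag_input:
        return tag_input
    if include_model == "1":
        return "latest"
    return "slim"

print(determine_tag("", "0"))    # slim  (default dispatch)
print(determine_tag("", "1"))    # latest
print(determine_tag("v2", "1"))  # v2    (explicit tag overrides)
```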
2 changes: 1 addition & 1 deletion .gitignore
@@ -88,4 +88,4 @@ _sglang/
wandb/log.txt

.claude/
wandb/
wandb/
110 changes: 69 additions & 41 deletions README.md
@@ -1,22 +1,59 @@
# TorchSpec

TorchSpec is a torch-native speculative decoding training framework. We introduce a disaggregated way of training speculative decoding draft models where inference and training are fully decoupled and stream hidden states directly from inference engine groups to distributed training workers via Mooncake (RDMA/TCP) store, allowing each side to scale independently.
TorchSpec is a torch-native speculative decoding training framework. We introduce a disaggregated way of training speculative decoding draft models where inference and training are fully decoupled and stream hidden states directly from inference engine groups to distributed training workers via [Mooncake](https://github.com/kvcache-ai/Mooncake) store, allowing each side to scale independently.

TorchSpec currently includes training flows and examples for:

- Kimi-K2.5
- MiniMax-M2.5
- Qwen3-Coder-Next

## 🚀 Blogs

- PyTorch blog: [TorchSpec: Speculative Decoding Training at Scale](https://pytorch.org/blog/torchspec-speculative-decoding-training-at-scale/)
- Release blog: [TorchSpec: Speculative Decoding Training at Scale](https://lightseek.org/blog/torchspec-speculative-decoding-training-at-scale.html)
- Released draft model: [lightseekorg/kimi-k2.5-eagle3](https://huggingface.co/lightseekorg/kimi-k2.5-eagle3)

## Table of Contents

- [Architecture Overview](#architecture-overview)
- [Quick Start](#quick-start)
- [Setup](#setup)
- [Examples](#examples)
- [Training Modes](#training-modes)
- [Checkpoint Conversion](#checkpoint-conversion)
- [Metrics Reporting](#metrics-reporting)
- [Troubleshooting](#troubleshooting)

## Architecture Overview

<p align="center">
<img src="docs/torchspec_architecture.png" alt="TorchSpec Architecture" width="100%">
</p>

## Setup
TorchSpec is built around a disaggregated training pipeline:

### Choose Your Backend
- **Inference engines** generate target-model hidden states with either vLLM or SGLang.
- **Mooncake store** transfers tensors between inference and training without materializing them on disk.
- **Training workers** consume streamed hidden states to train speculative decoding draft models.

TorchSpec supports two inference backends:
This separation keeps the training side focused on optimization while letting the inference side scale for hidden-state generation throughput.

## Quick Start

| Backend | Best For | Installation |
|---------|----------|--------------|
| **vLLM** | Flexibility, easier deployment | `./tools/build_conda.sh 1 vllm` |
| **SGLang** | Production workloads, high throughput | `./tools/build_conda.sh 1 sglang` |
| **Both** | Development, comparison testing | `./tools/build_conda.sh 1 both` |
Train an Eagle3 draft model for Qwen3-8B on a single node with 4 GPUs (2 for training and 2 for inference):

```bash
./examples/qwen3-8b-single-node/run.sh
```

Override config values directly from the CLI:

```bash
./examples/qwen3-8b-single-node/run.sh training.learning_rate=5e-5 training.num_train_steps=500
```

## Setup

### Quick Setup

@@ -31,58 +68,50 @@ micromamba activate torchspec
```

To install into your current environment instead:

```bash
./tools/build_conda.sh current sglang # or 'vllm' or 'both'
```

Optional: install Flash Attention:
Optional: install Flash Attention support:

```bash
pip install -e ".[fa]"
```

### Backend-Specific Usage

**vLLM:**
```bash
./examples/qwen3-8b-single-node/run.sh --config configs/vllm_qwen3_8b.yaml
```
**vLLM**

**SGLang:**
```bash
./examples/qwen3-8b-single-node/run.sh
./examples/qwen3-8b-single-node/run.sh --config configs/vllm_qwen3_8b.yaml
```

TorchSpec uses vLLM's **Worker Extension** mechanism to hook into the model's forward pass and capture hidden states directly in the worker processes. This avoids RPC serialization issues and enables reliable hidden states extraction.

## Quick Start

Train an Eagle3 draft model for Qwen3-8B using inference engine (4 GPUs: 2 training + 2 inference):
**SGLang**

```bash
./examples/qwen3-8b-single-node/run.sh
```

Override any config value via CLI:

```bash
./examples/qwen3-8b-single-node/run.sh training.learning_rate=5e-5 training.num_train_steps=500
```
TorchSpec uses vLLM's **Worker Extension** mechanism to hook into the model forward pass and capture hidden states directly inside worker processes, which avoids RPC serialization overhead during extraction. For SGLang, TorchSpec applies a patch to the existing codebase to enable hidden-state extraction.

## Examples

| Example | Backend | Model |
|---------|---------|-------|
| [hf-quickstart](examples/hf-quickstart/) | HuggingFace | Qwen3-8B |
| [qwen3-8b-single-node](examples/qwen3-8b-single-node/) | Inference Engine | Qwen3-8B |
| [kimi-k25-2node-h200](examples/kimi-k25-2node-h200/) | Inference Engine | Kimi-K2.5 |
| [kimi-k25-3node-h100](examples/kimi-k25-3node-h100/) | Inference Engine | Kimi-K2.5 |
| [qwen3-8b-single-node](examples/qwen3-8b-single-node/) | Inference engine | Qwen3-8B |
| [kimi-k25-2node-h200](examples/kimi-k25-2node-h200/) | Inference engine | Kimi-K2.5 |
| [kimi-k25-3node-h100](examples/kimi-k25-3node-h100/) | Inference engine | Kimi-K2.5 |
| [minimax-m25-5node-h200](examples/minimax-m25-5node-h200/) | Inference engine | MiniMax-M2.5 |

See [examples/README.md](examples/README.md) for more details about each example.

See [examples/README.md](examples/README.md) for details.
## Training Modes

## Resume Vs Continual Training
### Resume vs. Continual Training

Both modes use `training.load_path`, but they restore different state:
Both modes use `training.load_path`, but they restore different states:

| Goal | `training.load_path` | `training.continual_training` | What gets restored |
|------|----------------------|-------------------------------|--------------------|
@@ -119,11 +148,10 @@ Convert an FSDP checkpoint to HuggingFace format:
python tools/convert_to_hf.py --input-dir ./outputs/my_experiment/iter_0010000/
```

Vocabulary pruning — reducing the draft model's `lm_head` to a smaller token set and emitting `d2t`/`t2d` mappings — can be applied either **during training** (pre-pruning) or **at conversion time** (post-pruning):
Vocabulary pruning, which reduces the draft model `lm_head` to a smaller token set and emits `d2t` and `t2d` mappings, can be applied either during training or at conversion time.

- **Pre-pruning**: set `draft_vocab_size` in your training config. The checkpoint already contains the pruned `lm_head` and `d2t`/`t2d` buffers. Use the basic conversion command above — no extra flags needed.

- **Post-pruning**: train with the full vocabulary, then prune at conversion time by passing `--prune-vocab` along with a representative dataset to compute token frequencies:
- **Pre-pruning**: set `draft_vocab_size` in your training config. The checkpoint already contains the pruned `lm_head` and `d2t`/`t2d` buffers, so the basic conversion command is enough.
- **Post-pruning**: train with the full vocabulary, then pass `--prune-vocab` at conversion time together with a representative dataset to compute token frequencies.

```bash
python tools/convert_to_hf.py \
@@ -140,27 +168,27 @@ Pass `--cache-dir ./cache` to reuse the tokenized dataset cache from training.

## Metrics Reporting

W&B logging is disabled by default (`report_to: none`). To enable it, set `report_to: wandb` in your config and supply your API key.
W&B logging is disabled by default with `report_to: none`. To enable it, set `report_to: wandb` in your config and provide your API key.

## Troubleshooting

Set `TORCHSPEC_LOG_LEVEL=DEBUG` for verbose logging when diagnosing issues:
Set `TORCHSPEC_LOG_LEVEL=DEBUG` for more verbose logging when diagnosing issues:

```bash
TORCHSPEC_LOG_LEVEL=DEBUG ./examples/qwen3-8b-single-node/run.sh
```

### Per-Rank File Logging

Set `TORCHSPEC_LOG_DIR` to an <b> absolute path </b> on a shared filesystem (NFS) to enable per-rank log files for every Ray actor (training and inference):
Set `TORCHSPEC_LOG_DIR` to an absolute path on a shared filesystem (NFS) to enable per-rank log files for every Ray actor on both training and inference:

```bash
export TORCHSPEC_LOG_DIR=/my_project/running_logs
```

This creates a structured directory with one file per actor, organized by role and node:

```
```text
running_logs/
training/
10.0.0.1/
Expand All @@ -175,7 +203,7 @@ running_logs/
inference_g0_rank1_20260301_080015.log
```

The path must be an absolute path on a shared filesystem (NFS) accessible from all nodes. If `TORCHSPEC_LOG_DIR` is not set or the path is not writable, per-rank file logging is disabled and only Ray's default stdout/stderr capture is used.
The path must be absolute and writable from all nodes. If `TORCHSPEC_LOG_DIR` is unset or not writable, per-rank file logging stays disabled and Ray falls back to stdout/stderr capture.

| Issue | Reference |
|-------|-----------|
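The README's quick start shows CLI overrides like `training.learning_rate=5e-5`. A sketch of how such dotted keys can be folded into a nested config dict (hypothetical helper, not TorchSpec's actual parser, which may also coerce value types):

```python
# Hypothetical sketch of dotted-key CLI overrides (not TorchSpec's real parser).
def apply_overrides(cfg: dict, overrides: list[str]) -> dict:
    for item in overrides:
        dotted_key, value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = cfg
        for key in parents:
            node = node.setdefault(key, {})  # create nested sections on demand
        node[leaf] = value  # a real parser would also coerce "500" -> int, etc.
    return cfg

cfg = apply_overrides({}, ["training.learning_rate=5e-5", "training.num_train_steps=500"])
print(cfg["training"]["num_train_steps"])  # 500 (kept as a string in this sketch)
```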
2 changes: 1 addition & 1 deletion configs/draft_models/kimi_k25_eagle3.json
@@ -16,7 +16,7 @@
"model_type": "llama",
"num_attention_heads": 64,
"num_hidden_layers": 1,
"num_key_value_heads": 16,
"num_key_value_heads": 64,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 1000000,
27 changes: 27 additions & 0 deletions configs/draft_models/minimax_m25_eagle3.json
@@ -0,0 +1,27 @@
{
"architectures": [
"LlamaForCausalLMEagle3"
],
"attention_bias": false,
"attention_dropout": 0.0,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 3072,
"initializer_range": 0.02,
"intermediate_size": 9216,
"max_position_embeddings": 196608,
"model_type": "llama",
"num_attention_heads": 48,
"num_hidden_layers": 1,
"num_key_value_heads": 48,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 5000000,
"sliding_window": null,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.51.0",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 200064
}
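The new MiniMax draft config pairs `hidden_size: 3072` with 48 heads of explicit `head_dim: 128`, so the attention projection width (48 × 128 = 6144) deliberately exceeds `hidden_size`, and `num_key_value_heads == num_attention_heads` means plain multi-head attention rather than GQA. A small sanity-check sketch, with values copied from the JSON above:

```python
# Sanity-checking attention geometry in the minimax_m25_eagle3.json draft config.
cfg = {
    "hidden_size": 3072,
    "head_dim": 128,
    "num_attention_heads": 48,
    "num_key_value_heads": 48,
}

# GQA group size: query heads per KV head (1 means full multi-head attention).
groups = cfg["num_attention_heads"] // cfg["num_key_value_heads"]
print(groups)  # 1

# With an explicit head_dim, the Q projection is num_heads * head_dim wide,
# which need not equal hidden_size.
q_width = cfg["num_attention_heads"] * cfg["head_dim"]
print(q_width)  # 6144
```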
60 changes: 60 additions & 0 deletions configs/hf_qwen3_8b_1gpu.yaml
@@ -0,0 +1,60 @@
# Eagle3 training config — single GPU (colocate mode)
#
# GPU allocation:
# - 1 GPU shared between inference (HF) and training (no FSDP)
# - Requires H100 80GB or A100 80GB
#
# Usage:
# python -m torchspec.train_entry --config configs/hf_qwen3_8b_1gpu.yaml

model:
target_model_path: Qwen/Qwen3-8B
trust_remote_code: true

dataset:
train_data_path: ../examples/data/sample_conversations.jsonl
chat_template: qwen
prompt_key: conversations

training:
attention_backend: flex_attention
colocate: true
micro_batch_size: 1
draft_accumulation_steps: 1
learning_rate: 1e-4
max_concurrent_batches: 1
max_grad_norm: 0.5
max_seq_length: 4096
num_epochs: 1
seed: 42
training_num_gpus_per_node: 1
training_num_nodes: 1
ttt_length: 7
save_per_epoch: true
warmup_ratio: 0.015
train_env_vars: '{"TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS": "ATEN,TRITON"}'

inference:
inference_engine_type: hf
inference_num_gpus: 1
inference_num_gpus_per_engine: 1
inference_num_gpus_per_node: 1
max_sample_pool_size: 32
inference_buffer_threshold: 16
inference_batch_size: 4

mooncake:
master_server_address: null
metadata_server: null
protocol: tcp
global_segment_size: 16GB
local_buffer_size: 4GB

output_dir: ./outputs/qwen3-8b-eagle3-1gpu
cache_dir: ./cache/qwen3-8b-1gpu
model_download_dir: null

debug:
save_debug_train_data: null
debug_train_only: false
debug_inference_only: false
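This config runs inference and training on the same device (`colocate: true`, one 80GB GPU). A rough sketch of the GPU accounting implied here versus the disaggregated quick-start layout; the helper and its formula are assumptions for reasoning about configs, not TorchSpec code:

```python
# Hypothetical GPU accounting: in colocate mode inference shares the training
# GPUs, while disaggregated runs dedicate separate GPUs to each side.
def total_gpus(training_gpus: int, inference_gpus: int, colocate: bool) -> int:
    return max(training_gpus, inference_gpus) if colocate else training_gpus + inference_gpus

print(total_gpus(1, 1, colocate=True))   # 1  (this config: one shared H100/A100 80GB)
print(total_gpus(2, 2, colocate=False))  # 4  (the quick-start single-node example)
```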