Feature/dflash training#51

Draft
zhubohao911 wants to merge 121 commits into torchspec-project:main from zhubohao911:feature/dflash-training

Conversation

@zhubohao911
No description provided.

Xing Han and others added 30 commits March 19, 2026 01:51
- DFlash draft model with dual-source KV attention and W_proj
- DFlash training wrapper with block-causal FlexAttention mask, anchor sampling, and decay-weighted CE loss
- DFlash trainer with FSDP2 support
- Unit tests covering config, model forward, mask, loss, and mini training loop
- RunPod test script and sample config for Qwen2.5-7B
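The decay-weighted CE loss in the first bullet can be sketched as follows. This is a minimal illustration, not the actual implementation: it assumes the `loss_decay_gamma` config parameter down-weights position `i` of the drafted block by `gamma**i`, so earlier (more often accepted) positions dominate the loss.

```python
import torch
import torch.nn.functional as F

def decay_weighted_ce(logits, labels, gamma=0.8):
    """Illustrative sketch: cross-entropy where position i in the
    drafted block is weighted by gamma**i. Shapes: logits [T, V],
    labels [T]. gamma=0.8 is a placeholder, not the repo's default."""
    per_pos = F.cross_entropy(logits, labels, reduction="none")  # [T]
    weights = gamma ** torch.arange(per_pos.numel(), dtype=per_pos.dtype)
    # Normalize by the weight sum so the scale is comparable across gammas.
    return (weights * per_pos).sum() / weights.sum()
```

With `gamma=1.0` this reduces to plain mean cross-entropy, which is a quick sanity check on the weighting.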

Made-with: Cursor
…el, and training entry

- Add DFlash-specific parameters to TrainingConfig (block_size, num_anchors, loss_decay_gamma, num_target_layers)
- Generalize Eagle3TargetModel to support N auxiliary layers (was hardcoded to 3)
- Add config-based trainer dispatch in TrainerActor (DFlashConfig → DFlashTrainer)
- Auto-set aux_hidden_states_layers for DFlash in train_entry.py
- Add get_default_dflash_aux_layer_ids() utility using uniform spacing algorithm

Made-with: Cursor
…onfig

- Add framework integration tests: trainer dispatch, target model generalization, Mooncake buffer sizing, aux layer IDs, train entry integration, config roundtrip
- Add training quality tests: convergence, accuracy improvement, gradient health, multi-layer, padding
- Add architecture comparison tests: parameter counts, context projection, CE vs KL, block-parallel, decay weights
- Add DFlash YAML config for Qwen3-8B GPU training
- Add implementation log to gitignore

Made-with: Cursor
- configs/hf_qwen3_8b_1gpu.yaml: Eagle3 single-GPU colocate mode (HF backend)
- configs/hf_qwen3_8b_dflash_1gpu.yaml: DFlash single-GPU colocate mode (HF backend)
- scripts/runpod_dflash_train.sh: auto-detects GPU count (1/2/4), installs
  PyTorch 2.6+ for FlexAttention, runs Eagle3 then DFlash sequentially

Made-with: Cursor
- Remove interactive read prompt for WandB (auto-skips when no API key)
- Add runpod_ssh.sh helper for expect-based PTY workaround

Made-with: Cursor
RunPod images already have PyTorch+CUDA. Inheriting system packages
avoids a ~2.5GB redundant download. Only PyTorch 2.6+ upgrade needed.

Made-with: Cursor
Allows running with HF backend only, without installing SGLang or vLLM.
Needed for minimal single-GPU RunPod setups.

Made-with: Cursor
Move SglEngine and VllmEngine imports into the functions that use them,
so HF-only setups don't require these packages to be installed.

Made-with: Cursor
The eval cache generation times out in single-GPU colocate mode because
the AsyncInferenceManager fails to dispatch buffered samples to the
HFEngine. Disabling eval for now to allow training to proceed.

Made-with: Cursor
1. dflash_draft_config.json: Update dimensions to match Qwen3-8B
   (hidden_size 3584->4096, target_hidden_size 3584->4096,
   intermediate_size 18944->12288, num_attention_heads 28->32,
   num_key_value_heads 4->8, target_num_hidden_layers 28->36,
   vocab_size 152064->151936, max_position_embeddings 32768->40960).
   The old values caused RuntimeError during embedding load because
   the draft model's hidden_size must match the target for shared
   embeddings.

2. Add TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=ATEN,TRITON to
   train_env_vars in both 1-GPU configs. Without this, FlexAttention's
   torch.compile hits NoValidChoicesError during backward pass
   autotuning on RunPod's --no-deps PyTorch install.

Made-with: Cursor
1. flex_attention.py: Set inductor max_autotune_gemm_backends to
   include ATEN at import time. Env vars set by Ray runtime_env
   may not take effect because torch._inductor.config is read at
   module import time. Setting it directly in code ensures ATEN
   is always available as a GEMM fallback.

2. dflash.py: Cast concatenated hidden states to context_proj weight
   dtype in extract_context_feature(). Target model hidden states are
   BFloat16 but the draft model's projection layer initializes as
   float32, causing RuntimeError on matmul.

Made-with: Cursor
The inductor backend hits NoValidChoicesError during backward pass
compilation on PyTorch 2.6.0+cu124 with H100. Using aot_eager avoids
the inductor kernel selection issue by running in eager mode with
proper autograd support.

Made-with: Cursor
The draft model is created in float32 by default, but the target
model's LM head weight and hidden states are in bfloat16. This caused
RuntimeError in F.linear(draft_hidden, lm_head_weight) due to dtype
mismatch. Converting the draft model to bfloat16 after embedding
loading ensures consistent dtypes throughout the forward pass.
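The ordering described above (build in fp32, load embeddings, then convert) can be sketched with a toy model; the module layout here is illustrative, not the actual DFlash draft model.

```python
import torch
from torch import nn

# Toy stand-in for the draft model, created in float32 by default.
draft = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 16))
target_embed = torch.randn(100, 16, dtype=torch.bfloat16)

with torch.no_grad():
    # copy_ casts bf16 -> fp32 during the embedding load.
    draft[0].weight.copy_(target_embed)

# Convert the whole draft model to bfloat16 AFTER loading, so every
# later F.linear against bf16 target weights sees matching dtypes.
draft.to(torch.bfloat16)
```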

Made-with: Cursor
- Vectorize _prepare_noise_input using torch.gather (eliminates 2760 scalar
  GPU-CPU syncs per step via .item() calls)
- Add fallback in _sample_anchor_positions: when loss_mask has no valid
  positions before valid_end, sample uniformly instead of returning all-zero
  anchors
- Add targeted debug logging (first 5 steps) to trace shapes, values, and
  code paths through DFlashModel.forward() and _compute_loss()
- Add implementation log v7 with Session 5 covering Eagle3 success, DFlash
  zero-loss investigation, and Issues 8-13
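The gather vectorization in the first bullet replaces a per-anchor Python loop (each `.item()` call forcing a GPU-to-CPU sync) with a single `torch.gather`. A minimal before/after sketch, with illustrative names and shapes:

```python
import torch

def gather_noise_input_loop(hidden, anchor_idx):
    # Original pattern: one host sync per anchor via .item().
    rows = [hidden[b, anchor_idx[b, i].item()]
            for b in range(hidden.size(0))
            for i in range(anchor_idx.size(1))]
    return torch.stack(rows).view(hidden.size(0), anchor_idx.size(1), -1)

def gather_noise_input_vec(hidden, anchor_idx):
    # Vectorized: expand indices to [B, A, H], gather along the seq dim.
    idx = anchor_idx.unsqueeze(-1).expand(-1, -1, hidden.size(-1))
    return torch.gather(hidden, 1, idx)
```

Both return the same `[B, A, H]` tensor; the gather version issues no synchronizing host reads.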

Made-with: Cursor
Remove the gitignore entry so the implementation log is versioned
alongside the code it documents.

Made-with: Cursor
The dflash_draft_config.json was updated from Qwen2.5-7B to Qwen3-8B
dimensions in Session 5, but the test assertion wasn't updated.

The controller reads 'train/avg_loss' and 'train/avg_acc' (matching
Eagle3Trainer), but DFlashTrainer was returning 'train/loss' and
'train/accuracy'. This caused the progress bar to always show 0.000
even though the loss was being computed correctly internally.
Root cause: metric key mismatch (train/loss vs train/avg_loss).
DFlash GPU results: 0.477 loss, 89.4% accuracy, 2.5x faster than Eagle3.

Implements speculative decoding with TorchSpec's DFlash draft model.
Benchmarks target-only baseline vs DFlash spec-decode, measuring
acceptance length (τ), tokens/sec, and wall-clock speedup.
…cript

- Add hf_qwen3_8b_dflash_1gpu_bench.yaml for short 1-GPU training runs
- extract_dflash_checkpoint.py: load FSDP .distcp with dist_cp (match convert_to_hf)
- benchmark_dflash_inference: single prefill, full draft_model forward, wikitext
  quick-train with cached hidden states, cosine LR + warmup, higher default steps
- dflash_implementation_log: Session 7 inference benchmark and analysis

Made-with: Cursor
… loss_mask, RoPE order

Cross-checked against SpecForge reference implementation and fixed:
1. Draft model num_hidden_layers: 1 → 5 (matching SpecForge qwen3-8b-dflash.json)
2. Added Q-norm and K-norm (RMSNorm on head_dim) to DFlashAttention
3. Added block_keep_mask to anchor sampling, attention mask, and loss computation
4. Gather loss_mask at label positions for per-position loss weighting
5. Restructured RoPE: concatenate K → K-norm → RoPE (matching SpecForge order)

Also: max_seq_length 16384 → 4096 in sglang config, updated tests for all changes.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
The previous formula produced [1, 10, 18, 26, 35] for Qwen3-8B (36 layers),
but SpecForge uses [1, 9, 17, 25, 33] (end = num_hidden_layers - 3). The old
layer 35 + SGLang's +1 capture offset = 36, which is out of bounds for
range(36), causing only 4 of 5 hooks to fire and a Mooncake size mismatch crash.

Also set explicit target_layer_ids in dflash_draft_config.json as safety net.
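Assuming the corrected formula is uniform spacing from layer 1 to `num_hidden_layers - 3` (the endpoints quoted above), the selection can be sketched as below. The function name mirrors the `build_target_layer_ids` mentioned elsewhere in this PR but is a reconstruction, not the actual code.

```python
import torch

def build_target_layer_ids(num_hidden_layers: int, num_target_layers: int):
    """Evenly spaced capture layers from 1 to num_hidden_layers - 3,
    matching the SpecForge values quoted in the commit message."""
    ids = torch.linspace(1, num_hidden_layers - 3, num_target_layers)
    return [int(round(x)) for x in ids.tolist()]
```

For Qwen3-8B (36 layers, 5 draft layers) this yields [1, 9, 17, 25, 33], so the old out-of-bounds layer 35 never appears.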
- Align config with SpecForge defaults: LR 6e-4 (was 1e-4), warmup 0.04
  (was 0.015), max_grad_norm 1.0 (was 0.5), num_epochs 6 (was 1)
- Clean up extract_context_feature: remove redundant list copy
- Document Session 8: 4-GPU SGLang training validation, build_target_layer_ids
  bug fix, inference benchmark results (τ=1.03 at 200 steps), and full
  cross-check findings vs SpecForge reference
Matches SpecForge's minimum 2*block_size filtering for DFlash training.
Samples with fewer loss-eligible tokens than min_loss_tokens are skipped
during data preparation (default 0 preserves backward compatibility).
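The filter itself is a one-line predicate; a hedged sketch (names illustrative), where SpecForge's convention would pass `min_loss_tokens = 2 * block_size`:

```python
import torch

def keep_sample(loss_mask: torch.Tensor, min_loss_tokens: int) -> bool:
    """Drop samples whose count of loss-eligible tokens is below
    min_loss_tokens. The default of 0 keeps every sample, preserving
    backward compatibility as the commit notes."""
    return int(loss_mask.sum()) >= min_loss_tokens
```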
- prepare_perfectblend.py: Downloads mlabonne/open-perfectblend,
  normalizes ShareGPT format, filters invalid samples, outputs JSONL
- runpod_phase_c.sh: Full 4-GPU training pipeline (pull code, prep data,
  verify SGLang patch, launch training with SpecForge-matched hyperparams)
Xing Han and others added 30 commits March 22, 2026 21:12
Covers Docker image setup, manual installation, data prep, multi-GPU
and single-GPU training, monitoring, benchmarking, and troubleshooting
for any GPU environment (not just RunPod).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…aining

Port DFlash training from RunPod to Modal with 8x H100 support. Add
--extra-overrides CLI parameter for runtime config experiments without
YAML changes. Document speed tuning results across 4 configurations,
finding batch=4 + anchors=256 + 2 inference GPUs as the fastest at
22-25 samples/s (476s for 200 steps).
Document best-tested config (2 inference + 6 training, batch=4,
anchors=256) and tuning insights from speed experiments.
…e GPU provisioning

Previously, both train_hf (2x H100) and train_sglang (8x H100) were registered
as @app.function in a single file, causing Modal to show train_hf in the dashboard
even when only running 4+ GPU SGLang tasks. Moved HF backend (1-2 GPU) into a
dedicated modal_dflash_train_hf.py so each script only provisions the GPUs it needs.
Phase 2 tests show anchors=512 matches anchors=256 speed (446-457s vs
476s) when using 2 inference GPUs, which is essential to prevent pool
starvation. Updated recommendation to quality-optimized 512-D config
(anchors=512, batch=1, accum=4, 2+6 GPUs) for best acceptance length.
512-E (batch=1, accum=4, 4 infer + 4 train) is fastest at 368s for
200 steps. 512-F (batch=2) slower at 394s. Fix misleading GPU count
in print statements: now parses extra_overrides to show actual
inference/training split instead of hardcoded "1 infer + 7 train".
output_dir was resolving to /workspace/TorchSpec/outputs/ (ephemeral
container filesystem) instead of /workspace/outputs/ (Modal volume
mount). Checkpoints were lost on container exit. Now explicitly sets
output_dir={OUTPUTS_DIR}/{run_id} so checkpoints, configs, and logs
all land on the persistent 'torchspec-outputs' volume. Added download
instructions to completion message.
Converts final FSDP checkpoint to HuggingFace format using
tools/convert_to_hf.py with explicit dflash_draft_config.json,
then optionally uploads to a HuggingFace Hub repo via --hf-repo flag.
Loads pre-trained draft model from HuggingFace (Xingh3/dflash-qwen3-8b-1epoch)
and benchmarks speculative decoding vs target-only baseline on Modal H100.

Initial results (50 prompts, 256 tokens): 1.39x speedup, τ=2.13
- Change `private=True` to `private=False` on `api.create_repo()` calls
  in both `_convert_and_upload_hf` and `convert_checkpoint`. Private repos
  on the free HF tier block large file uploads via LFS.
- Add `wandb-secret` to Modal secrets for both training and conversion
  functions so WandB metric logging works during training runs.
- Update WandB setup comment with link to key URL.
The original Transformers-backend benchmark used 10 hand-written prompts,
making results incomparable to z-lab's published DFlash numbers. This
overhaul adopts z-lab's exact evaluation methodology:

- Add `load_benchmark_dataset()` supporting all 10 z-lab datasets
  (GSM8K, MATH-500, AIME24/25, HumanEval, MBPP, LiveCodeBench,
  SWE-Bench, MT-Bench, Alpaca) with matching prompt formatting.
- Replace fixed prompt list with dataset-driven evaluation loop that
  handles single-turn and multi-turn (MT-Bench) conversations.
- Apply chat template with `enable_thinking=False` to match z-lab.
- Add z-lab reference τ and speedup values for comparison in output.
- Add cross-dataset summary table in CLI entrypoint.
- Increase timeout from 2h to 4h for full dataset runs.
- Set deterministic seeds for reproducibility.
- Update default draft model to 3-epoch checkpoint.
The Transformers backend was too slow for full evaluation (~130 tok/s).
This adds an SGLang-based benchmark that achieves ~350 tok/s with KV
cache, CUDA graphs, and tensor parallelism on 2x H100.

Five issues were discovered and fixed during development:

1. SGLang v0.5.8.post1 lacks DFLASH speculative algorithm support.
   Fixed by installing from sgl-project/sglang PR #16818.

2. PR #16818 omits `set_dflash_layers_to_capture` on Qwen2/Qwen3
   models (only has eagle3 variant). Patched via sed during image
   build to add alias method.

3. TorchSpec saves weights with different names than z-lab/SGLang
   expects (context_proj→fc, context_norm→hidden_norm,
   final_norm→norm, extra embed_tokens). Added runtime safetensors
   key remapping in `patch_draft_model_for_sglang()`.

4. SGLang TP workers survive `SIGTERM`, causing Modal container
   timeout before results print. Fixed with `start_new_session=True`
   and `SIGKILL` on the process group.

5. Modal containers intermittently fail `snapshot_download` with
   `ConnectionResetError`. Added exponential backoff retry (5
   attempts).

Result: τ=3.21 on GSM8K (z-lab reference: 3.38, gap: -0.17).
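The key remapping in item 3 can be sketched as a dict-driven rename over the safetensors state dict. The mapping below is taken verbatim from the commit message; the function name and surrounding logic are illustrative.

```python
# Weight-name differences between TorchSpec exports and the z-lab/SGLang
# loader, per the commit message.
KEY_MAP = {
    "context_proj": "fc",
    "context_norm": "hidden_norm",
    "final_norm": "norm",
}

def remap_state_dict(state_dict: dict) -> dict:
    """Rename TorchSpec keys to the z-lab convention and drop the
    extra embed_tokens the SGLang loader does not expect."""
    out = {}
    for key, tensor in state_dict.items():
        if key.startswith("embed_tokens"):
            continue  # SGLang shares the target model's embeddings
        for src, dst in KEY_MAP.items():
            if src in key:
                key = key.replace(src, dst)
                break
        out[key] = tensor
    return out
```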
- fix_hf_model_config.py: Converts TorchSpec-exported model config from
  model_type="dflash" (unrecognized by Transformers/SGLang) to z-lab
  format (model_type="qwen3" + nested dflash_config + auto_map).
  Also uploads z-lab's dflash.py/modeling_dflash.py/utils.py for
  trust_remote_code loading. Required for SGLang DFlash inference.

- training_metrics_3epoch.png: Loss and accuracy curves from the full
  200K-sample, 3-epoch training run.
Run all 10 z-lab datasets on 1x H100 (tp=1) via SGLang backend.
Average τ=3.06 across datasets, reaching 78% of z-lab reference on math
benchmarks with 33x less training data. Comprehensive results with
per-dataset τ distributions and domain analysis added to results doc.
Verified from Modal training logs: actual launch used --dataset-size
200000, producing 190,095 samples after filtering (188,977 usable).
Training was 23,622 optimizer steps across 3 epochs with global_batch=24.
Updated gap analysis: 8.5x fewer sample passes vs z-lab (not 33x).
- Expose min_lr and weight_decay in TrainingConfig (was hardcoded 0)
- Plumb through BF16Optimizer → LRSchedulerWithWarmup in both trainers
- Update YAML config: accum 4→2, min_lr=6e-5, weight_decay=0.01
- Fix WandB config prefix (logging.* not training.*)
- Support epoch-based training (--num-epochs without --max-steps)
- Document Phase G convergence analysis and Phase H training plan
…ed secrets

- Add Phase H inference benchmark results (10 datasets, SGLang backend):
  τ improved +0.59 avg over Phase G (3.06 → 3.45), math gap to z-lab
  narrowed from 21% to 6.4%
- Remove naive HF benchmark (modal_dflash_benchmark.py) in favor of
  SGLang-based benchmark with KV cache, CUDA graphs, paged attention
- Update SGLang benchmark to support Modal volume paths for draft model
  loading (no HuggingFace upload required)
- Remove hardcoded API tokens from train script docstring, add
  setup_modal_secrets.sh for safe secret management
- Pin TorchSpec commit in Modal image for reproducible builds
load_hf_dataset now tries datasets.load_dataset() first for Hub paths,
falling back to manual JSON download. Also strips columns not needed by
the training pipeline (e.g. pre-tokenized input_ids, attention_mask)
to reduce memory during streaming iteration.
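The column-stripping half of this change can be sketched as a pure function over one example; the set of needed columns below is an assumption for illustration, not the pipeline's actual schema.

```python
def strip_unused_columns(example: dict, needed=("messages", "text")) -> dict:
    """Drop pre-tokenized fields (input_ids, attention_mask, ...) that
    the training pipeline does not consume, reducing memory held per
    record during streaming iteration. 'needed' is illustrative."""
    return {k: v for k, v in example.items() if k in needed}
```

With a `datasets.Dataset`, the equivalent is typically done once up front via `dataset.map(..., remove_columns=[...])` rather than per-record.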
Add speedup methodology analysis comparing our SGLang E2E measurement
vs z-lab's decode-only Transformers benchmark. Add decode-only timing
to benchmark script. Patch training script for Arrow/Parquet Hub datasets.
BUG 33: Remove dead gradient_checkpointing flag in DFlashModel — the
  parameter was stored but never used (no checkpoint wrapping).

BUG 34: Fix FlexAttention recompilation from variable Q_LEN — anchor
  sampling now always returns num_anchors slots with block_keep_mask
  for invalids, keeping Q_LEN constant across all steps.

BUG 35: Fix no_sync silently no-oping under FSDP2 fully_shard — use
  set_requires_gradient_sync(bool) for FSDP2, fall back to no_sync()
  for replicate (DDP) strategy.

OPT 1: Per-layer FSDP sharding — each decoder layer is now an
  independent FSDP unit, enabling comm/compute overlap and reducing
  peak memory.

OPT 2: Async H2D transfers — replace synchronous .cuda() with
  .to(device, non_blocking=True) in DFlash _forward().

OPT 3: Pre-allocate fp32 grad buffers in BF16Optimizer — eliminates
  per-step tensor allocation/deallocation, reducing GC pressure.

Document all findings in dflash_issues.md (Session 13 section).
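OPT 3 above follows a standard pattern: keep one persistent fp32 buffer per bf16 parameter and copy grads into it each step, rather than allocating a fresh fp32 tensor every step. A minimal sketch with illustrative names:

```python
import torch

# One bf16 "parameter" and its pre-allocated fp32 gradient buffer.
params = [torch.zeros(4, 4, dtype=torch.bfloat16, requires_grad=True)]
fp32_grads = [torch.zeros_like(p, dtype=torch.float32) for p in params]

def accumulate_grads():
    for p, buf in zip(params, fp32_grads):
        if p.grad is not None:
            # In-place bf16 -> fp32 copy; no per-step allocation.
            buf.copy_(p.grad)
```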
Training was restarting from step 0 because load_path was never set.
Added --resume CLI flag that reads latest_checkpointed_iteration.txt
from the output volume and passes training.load_path to the trainer.
Also reduced save_interval to 1000 and max_checkpoints to 2 for
more frequent checkpointing with less disk usage.
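The --resume logic can be sketched as below. The tracker filename comes from the commit message; the `iter_NNNNNNN` checkpoint directory layout is an assumption for illustration.

```python
from pathlib import Path

def resolve_resume_path(output_dir: str):
    """Read the latest checkpointed iteration from the tracker file
    and return the matching load path, or None to start from step 0.
    Checkpoint directory naming here is hypothetical."""
    tracker = Path(output_dir) / "latest_checkpointed_iteration.txt"
    if not tracker.exists():
        return None
    step = int(tracker.read_text().strip())
    return str(Path(output_dir) / f"iter_{step:07d}")
```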
UltraChat model (208K samples, 43K steps) underperforms Phase H across
all domains (τ=3.04 vs 3.45 overall). Deep code audit against SpecForge
PRs #427/#472/#473 confirms no training bugs — the τ gap to z-lab is
fully explained by recipe differences (data volume, seq_len, epochs).
Scripts:
- Move RunPod scripts to scripts/runpod/
- Move Modal scripts to scripts/modal/
- Move utility scripts to scripts/tools/
- Remove obsolete scripts (benchmark_dflash_inference.py,
  fix_hf_model_config.py, modal_dflash_train_hf.py, modal_example.py)
- Update all self-referencing paths in script comments/docstrings

Docs (docs/inference/dflash/):
- Merge TRAINING_GUIDE.md + dflash_runpod_guide.md into one guide
  with Generic/Modal/RunPod/Troubleshooting sections
- Merge dflash_training_results.md + dflash_modal_training_results.md
  into training_results.md (Phases A-J, deduplicated)
- Fold dflash_pending_work.md into issues.md as Future Work section
- Fold dflash_training_test_plan.md success criteria into results header
- Rename dflash_overview.md to README.md
- Rename dflash_issues.md to issues.md
- Rename specforge_dflash_training_reference.md to specforge_reference.md
- Remove raw benchmark logs (already summarized in results)
- Fix all stale cross-references across repo
- Add --sweep-config flag to modal_dflash_train.py for parallel HP sweeps
- Add sweep configs for Phase 1 (screening) and Phase 2 (validation)
- Handle WSD schedule in dflash_trainer.py without modifying shared framework
- Document Phase K results: WSD and accum=1 both beat cosine baseline by ~12%
… training

- Add dflash_run_id parameter to train_sglang/sweep runner so each run
  gets its own output directory (fixes checkpoint collision bug where all
  sweep runs overwrote the same dflash-qwen3-8b/ path)
- Propagate --resume flag through sweep runner to train_sglang.spawn()
- Add WSD schedule parameters (wsd_decay_ratio, wsd_decay_style) to
  TrainingConfig for DFlash trainer
- Add phase2_resume.json sweep config for resumed Phase 2 runs
- Add Phase 2 results: P2-WSD vs P2-baseline vs P2-accum1 training curves
  from 24h Modal runs, documenting checkpoint collision bug and fix
- Add P2-accum1 inference benchmark (30 samples/dataset, 10 datasets):
  τ=3.83 math avg, beats z-lab on GSM8K (3.79 vs 3.38)
- Add per-dataset τ distribution tables showing 80-100% of math/code
  requests achieve τ≥3, with τ≥5 emerging as a real category (up to 37%)
- Add distribution comparison vs Phase G showing rightward shift that
  mean τ alone understates
- Update success criteria to reference Phase K P2-accum1
- P2-WSD (WSD schedule) achieves τ=3.94 math avg, our best result
  and only 2.7% below z-lab (vs 6.4% for Phase H)
- Decode-only speedup hits 3.02x on livecodebench, meeting the 3.0x target
- Add full cross-dataset comparison (P2-WSD vs P2-accum1 vs Phase H vs z-lab),
  per-dataset τ distributions, and distribution shift analysis
- Update success criteria: speedup target now met
Replace ValueError crash in _sample_anchor_positions with a graceful
fallback that returns all-False keep_mask (zero loss) when no valid
anchor positions exist. This fixes a training crash on long-context
models (e.g. Kimi 2.5 with seq_len=65536) where some batches have
all-zero loss_mask in the anchor-eligible range.

Also plumb min_loss_tokens through the online mooncake data path
(data_fetcher → trainer) so samples with too few supervised tokens
are filtered at loading time — previously this filter only worked
in the offline preprocessing path.
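The graceful fallback can be sketched as follows; the function name echoes `_sample_anchor_positions` but the body is an illustration of the described behavior, not the actual implementation.

```python
import torch

def sample_anchor_positions(loss_mask: torch.Tensor, num_anchors: int):
    """When no position in loss_mask is eligible, return dummy anchors
    with an all-False keep_mask (contributing zero loss) instead of
    raising ValueError -- avoiding the long-context training crash."""
    valid = loss_mask.nonzero(as_tuple=False).flatten()
    if valid.numel() == 0:
        anchors = torch.zeros(num_anchors, dtype=torch.long)
        keep_mask = torch.zeros(num_anchors, dtype=torch.bool)
        return anchors, keep_mask
    idx = torch.randint(0, valid.numel(), (num_anchors,))
    return valid[idx], torch.ones(num_anchors, dtype=torch.bool)
```

Returning a fixed `num_anchors` slots in both branches also keeps Q_LEN constant, consistent with the BUG 34 fix earlier in this PR.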