Feature/dflash training #51
Draft
zhubohao911 wants to merge 121 commits into torchspec-project:main from
Conversation
- DFlash draft model with dual-source KV attention and W_proj
- DFlash training wrapper with block-causal FlexAttention mask, anchor sampling, and decay-weighted CE loss
- DFlash trainer with FSDP2 support
- Unit tests covering config, model forward, mask, loss, and mini training loop
- RunPod test script and sample config for Qwen2.5-7B

Made-with: Cursor
…el, and training entry

- Add DFlash-specific parameters to TrainingConfig (block_size, num_anchors, loss_decay_gamma, num_target_layers)
- Generalize Eagle3TargetModel to support N auxiliary layers (was hardcoded to 3)
- Add config-based trainer dispatch in TrainerActor (DFlashConfig → DFlashTrainer)
- Auto-set aux_hidden_states_layers for DFlash in train_entry.py
- Add get_default_dflash_aux_layer_ids() utility using a uniform spacing algorithm

Made-with: Cursor
…onfig

- Add framework integration tests: trainer dispatch, target model generalization, Mooncake buffer sizing, aux layer IDs, train entry integration, config roundtrip
- Add training quality tests: convergence, accuracy improvement, gradient health, multi-layer, padding
- Add architecture comparison tests: parameter counts, context projection, CE vs KL, block-parallel, decay weights
- Add DFlash YAML config for Qwen3-8B GPU training
- Add implementation log to gitignore

Made-with: Cursor
- configs/hf_qwen3_8b_1gpu.yaml: Eagle3 single-GPU colocate mode (HF backend)
- configs/hf_qwen3_8b_dflash_1gpu.yaml: DFlash single-GPU colocate mode (HF backend)
- scripts/runpod_dflash_train.sh: auto-detects GPU count (1/2/4), installs PyTorch 2.6+ for FlexAttention, runs Eagle3 then DFlash sequentially

Made-with: Cursor
- Remove interactive read prompt for WandB (auto-skips when no API key)
- Add runpod_ssh.sh helper for expect-based PTY workaround

Made-with: Cursor
RunPod images already ship with PyTorch+CUDA. Inheriting system packages avoids a ~2.5GB redundant download; only the PyTorch 2.6+ upgrade is needed. Made-with: Cursor
Allows running with HF backend only, without installing SGLang or vLLM. Needed for minimal single-GPU RunPod setups. Made-with: Cursor
Move SglEngine and VllmEngine imports into the functions that use them, so HF-only setups don't require these packages to be installed. Made-with: Cursor
The eval cache generation times out in single-GPU colocate mode because the AsyncInferenceManager fails to dispatch buffered samples to the HFEngine. Disabling eval for now to allow training to proceed. Made-with: Cursor
1. dflash_draft_config.json: Update dimensions to match Qwen3-8B (hidden_size 3584->4096, target_hidden_size 3584->4096, intermediate_size 18944->12288, num_attention_heads 28->32, num_key_value_heads 4->8, target_num_hidden_layers 28->36, vocab_size 152064->151936, max_position_embeddings 32768->40960). The old values caused a RuntimeError during embedding load because the draft model's hidden_size must match the target's for shared embeddings.
2. Add TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=ATEN,TRITON to train_env_vars in both 1-GPU configs. Without this, FlexAttention's torch.compile hits NoValidChoicesError during backward-pass autotuning on RunPod's --no-deps PyTorch install.

Made-with: Cursor
1. flex_attention.py: Set inductor max_autotune_gemm_backends to include ATEN at import time. Env vars set by Ray runtime_env may not take effect because torch._inductor.config is read at module import time; setting it directly in code ensures ATEN is always available as a GEMM fallback.
2. dflash.py: Cast concatenated hidden states to the context_proj weight dtype in extract_context_feature(). Target model hidden states are BFloat16 but the draft model's projection layer initializes as float32, causing a RuntimeError on matmul.

Made-with: Cursor
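A minimal sketch of the dtype fix described above, assuming illustrative names and sizes (the real extract_context_feature lives in dflash.py and concatenates N target-layer hidden states before projecting): the bfloat16 input is cast to the projection weight's dtype before the matmul.

```python
import torch
import torch.nn as nn

# Hypothetical sizes: 4 captured layers, hidden size 64.
context_proj = nn.Linear(4 * 64, 64, bias=False)  # initializes in float32

def extract_context_feature(aux_hidden_states):
    # aux_hidden_states: list of [batch, seq, hidden] tensors from the target model
    concatenated = torch.cat(aux_hidden_states, dim=-1)
    # Cast to the projection weight's dtype to avoid a dtype-mismatch RuntimeError
    concatenated = concatenated.to(context_proj.weight.dtype)
    return context_proj(concatenated)

# Target hidden states arrive in bfloat16; without the cast, the matmul raises.
hidden = [torch.randn(1, 8, 64, dtype=torch.bfloat16) for _ in range(4)]
out = extract_context_feature(hidden)
```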
The inductor backend hits NoValidChoicesError during backward pass compilation on PyTorch 2.6.0+cu124 with H100. Using aot_eager avoids the inductor kernel selection issue by running in eager mode with proper autograd support. Made-with: Cursor
The draft model is created in float32 by default, but the target model's LM head weight and hidden states are in bfloat16. This caused RuntimeError in F.linear(draft_hidden, lm_head_weight) due to dtype mismatch. Converting the draft model to bfloat16 after embedding loading ensures consistent dtypes throughout the forward pass. Made-with: Cursor
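The fix above can be sketched as follows, with a toy module standing in for the draft model (names and shapes are illustrative): load the shared embeddings first, then cast the whole module to bfloat16 so F.linear against the target's bfloat16 LM head sees matching dtypes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-in for the draft model: embedding + one projection, float32 by default.
draft = nn.Sequential(nn.Embedding(100, 32), nn.Linear(32, 32))
target_embed = torch.randn(100, 32, dtype=torch.bfloat16)

with torch.no_grad():
    draft[0].weight.copy_(target_embed)  # load shared embeddings from the target

draft = draft.to(torch.bfloat16)  # align with the target's bfloat16 lm_head / hidden states

lm_head_weight = torch.randn(100, 32, dtype=torch.bfloat16)
hidden = draft(torch.tensor([[1, 2, 3]]))
logits = F.linear(hidden, lm_head_weight)  # no dtype-mismatch RuntimeError
```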
- Vectorize _prepare_noise_input using torch.gather (eliminates 2760 scalar GPU-CPU syncs per step via .item() calls) - Add fallback in _sample_anchor_positions: when loss_mask has no valid positions before valid_end, sample uniformly instead of returning all-zero anchors - Add targeted debug logging (first 5 steps) to trace shapes, values, and code paths through DFlashModel.forward() and _compute_loss() - Add implementation log v7 with Session 5 covering Eagle3 success, DFlash zero-loss investigation, and Issues 8-13 Made-with: Cursor
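The gather vectorization above can be illustrated with a simplified stand-in (function names mirror the commit but the bodies are hypothetical): the loop version forces a GPU-to-CPU sync on every .item() call, while torch.gather does the same indexing in a single batched op.

```python
import torch

def prepare_noise_input_loop(input_ids, anchor_positions):
    # Scalar indexing: each .item() is a GPU->CPU sync when tensors live on GPU.
    batch, num_anchors = anchor_positions.shape
    out = torch.empty(batch, num_anchors, dtype=input_ids.dtype)
    for b in range(batch):
        for a in range(num_anchors):
            out[b, a] = input_ids[b, anchor_positions[b, a].item()]
    return out

def prepare_noise_input_gather(input_ids, anchor_positions):
    # One batched gather along the sequence dimension: no per-element syncs.
    return torch.gather(input_ids, 1, anchor_positions)

ids = torch.arange(32).reshape(2, 16)
pos = torch.tensor([[0, 5, 9], [3, 3, 15]])
assert torch.equal(prepare_noise_input_loop(ids, pos),
                   prepare_noise_input_gather(ids, pos))
```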
Remove the gitignore entry so the implementation log is versioned alongside the code it documents. Made-with: Cursor
The dflash_draft_config.json was updated from Qwen2.5-7B to Qwen3-8B dimensions in Session 5, but the test assertion wasn't updated.
The controller reads 'train/avg_loss' and 'train/avg_acc' (matching Eagle3Trainer), but DFlashTrainer was returning 'train/loss' and 'train/accuracy'. This caused the progress bar to always show 0.000 even though the loss was being computed correctly internally.
Root cause: metric key mismatch (train/loss vs train/avg_loss). DFlash GPU results: 0.477 loss, 89.4% accuracy, 2.5x faster than Eagle3.
Implements speculative decoding with TorchSpec's DFlash draft model. Benchmarks target-only baseline vs DFlash spec-decode, measuring acceptance length (τ), tokens/sec, and wall-clock speedup.
…cript

- Add hf_qwen3_8b_dflash_1gpu_bench.yaml for short 1-GPU training runs
- extract_dflash_checkpoint.py: load FSDP .distcp with dist_cp (match convert_to_hf)
- benchmark_dflash_inference: single prefill, full draft_model forward, wikitext quick-train with cached hidden states, cosine LR + warmup, higher default steps
- dflash_implementation_log: Session 7 inference benchmark and analysis

Made-with: Cursor
… loss_mask, RoPE order

Cross-checked against the SpecForge reference implementation and fixed:
1. Draft model num_hidden_layers: 1 → 5 (matching SpecForge qwen3-8b-dflash.json)
2. Added Q-norm and K-norm (RMSNorm on head_dim) to DFlashAttention
3. Added block_keep_mask to anchor sampling, attention mask, and loss computation
4. Gather loss_mask at label positions for per-position loss weighting
5. Restructured RoPE: concatenate K → K-norm → RoPE (matching SpecForge order)

Also: max_seq_length 16384 → 4096 in sglang config; updated tests for all changes.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
The previous formula produced [1, 10, 18, 26, 35] for Qwen3-8B (36 layers), but SpecForge uses [1, 9, 17, 25, 33] (end = num_hidden_layers - 3). The old layer 35 + SGLang's +1 capture offset = 36, which is out of bounds for range(36), causing only 4 of 5 hooks to fire and a Mooncake size mismatch crash. Also set explicit target_layer_ids in dflash_draft_config.json as safety net.
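A sketch of the corrected spacing, assuming the utility's shape (the function name comes from an earlier commit; the body is illustrative): five capture layers spaced uniformly from layer 1 to num_hidden_layers - 3, which leaves headroom for SGLang's +1 capture offset.

```python
import torch

def get_default_dflash_aux_layer_ids(num_hidden_layers, num_target_layers=5):
    # End at num_hidden_layers - 3 so the +1 capture offset stays in range.
    end = num_hidden_layers - 3
    ids = torch.linspace(1, end, num_target_layers).round().to(torch.int64)
    return ids.tolist()

print(get_default_dflash_aux_layer_ids(36))  # → [1, 9, 17, 25, 33]
```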
- Align config with SpecForge defaults: LR 6e-4 (was 1e-4), warmup 0.04 (was 0.015), max_grad_norm 1.0 (was 0.5), num_epochs 6 (was 1)
- Clean up extract_context_feature: remove redundant list copy
- Document Session 8: 4-GPU SGLang training validation, build_target_layer_ids bug fix, inference benchmark results (τ=1.03 at 200 steps), and full cross-check findings vs SpecForge reference
Matches SpecForge's minimum 2*block_size filtering for DFlash training. Samples with fewer loss-eligible tokens than min_loss_tokens are skipped during data preparation (default 0 preserves backward compatibility).
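The filter can be sketched as a one-liner over the prepared samples (names are illustrative, not the exact data-prep code): samples whose loss_mask marks fewer than min_loss_tokens supervised tokens are dropped, with the 2*block_size threshold matching SpecForge.

```python
def filter_samples(samples, min_loss_tokens=0):
    # Default 0 preserves backward compatibility: nothing is filtered.
    return [s for s in samples if sum(s["loss_mask"]) >= min_loss_tokens]

block_size = 16
samples = [
    {"loss_mask": [1] * 40},   # enough loss-eligible tokens, kept
    {"loss_mask": [1] * 10},   # fewer than 2 * block_size, skipped
]
kept = filter_samples(samples, min_loss_tokens=2 * block_size)
assert len(kept) == 1
```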
- prepare_perfectblend.py: Downloads mlabonne/open-perfectblend, normalizes ShareGPT format, filters invalid samples, outputs JSONL
- runpod_phase_c.sh: Full 4-GPU training pipeline (pull code, prep data, verify SGLang patch, launch training with SpecForge-matched hyperparams)
Covers Docker image setup, manual installation, data prep, multi-GPU and single-GPU training, monitoring, benchmarking, and troubleshooting for any GPU environment (not just RunPod). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…aining

Port DFlash training from RunPod to Modal with 8x H100 support. Add an --extra-overrides CLI parameter for runtime config experiments without YAML changes. Document speed-tuning results across 4 configurations, finding batch=4 + anchors=256 + 2 inference GPUs fastest at 22-25 samples/s (476s for 200 steps).
Document best-tested config (2 inference + 6 training, batch=4, anchors=256) and tuning insights from speed experiments.
…e GPU provisioning

Previously, both train_hf (2x H100) and train_sglang (8x H100) were registered as @app.function in a single file, causing Modal to show train_hf in the dashboard even when only running 4+ GPU SGLang tasks. Moved the HF backend (1-2 GPU) into a dedicated modal_dflash_train_hf.py so each script only provisions the GPUs it needs.
Phase 2 tests show anchors=512 matches anchors=256 speed (446-457s vs 476s) when using 2 inference GPUs, which is essential to prevent pool starvation. Updated recommendation to quality-optimized 512-D config (anchors=512, batch=1, accum=4, 2+6 GPUs) for best acceptance length.
512-E (batch=1, accum=4, 4 infer + 4 train) is fastest at 368s for 200 steps. 512-F (batch=2) slower at 394s. Fix misleading GPU count in print statements: now parses extra_overrides to show actual inference/training split instead of hardcoded "1 infer + 7 train".
output_dir was resolving to /workspace/TorchSpec/outputs/ (ephemeral
container filesystem) instead of /workspace/outputs/ (Modal volume
mount). Checkpoints were lost on container exit. Now explicitly sets
output_dir={OUTPUTS_DIR}/{run_id} so checkpoints, configs, and logs
all land on the persistent 'torchspec-outputs' volume. Added download
instructions to completion message.
Converts final FSDP checkpoint to HuggingFace format using tools/convert_to_hf.py with explicit dflash_draft_config.json, then optionally uploads to a HuggingFace Hub repo via --hf-repo flag.
Loads the pre-trained draft model from HuggingFace (Xingh3/dflash-qwen3-8b-1epoch) and benchmarks speculative decoding vs the target-only baseline on Modal H100. Initial results (50 prompts, 256 tokens): 1.39x speedup, τ=2.13.
- Change `private=True` to `private=False` on `api.create_repo()` calls in both `_convert_and_upload_hf` and `convert_checkpoint`. Private repos on the free HF tier block large file uploads via LFS.
- Add `wandb-secret` to Modal secrets for both training and conversion functions so WandB metric logging works during training runs.
- Update WandB setup comment with link to key URL.
The original Transformers-backend benchmark used 10 hand-written prompts, making results incomparable to z-lab's published DFlash numbers. This overhaul adopts z-lab's exact evaluation methodology:

- Add `load_benchmark_dataset()` supporting all 10 z-lab datasets (GSM8K, MATH-500, AIME24/25, HumanEval, MBPP, LiveCodeBench, SWE-Bench, MT-Bench, Alpaca) with matching prompt formatting.
- Replace the fixed prompt list with a dataset-driven evaluation loop that handles single-turn and multi-turn (MT-Bench) conversations.
- Apply chat template with `enable_thinking=False` to match z-lab.
- Add z-lab reference τ and speedup values for comparison in output.
- Add cross-dataset summary table in CLI entrypoint.
- Increase timeout from 2h to 4h for full dataset runs.
- Set deterministic seeds for reproducibility.
- Update default draft model to the 3-epoch checkpoint.
The Transformers backend was too slow for full evaluation (~130 tok/s). This adds an SGLang-based benchmark that achieves ~350 tok/s with KV cache, CUDA graphs, and tensor parallelism on 2x H100. Five issues were discovered and fixed during development:

1. SGLang v0.5.8.post1 lacks DFLASH speculative algorithm support. Fixed by installing from sgl-project/sglang PR #16818.
2. PR #16818 omits `set_dflash_layers_to_capture` on Qwen2/Qwen3 models (only has the eagle3 variant). Patched via sed during image build to add an alias method.
3. TorchSpec saves weights with different names than z-lab/SGLang expects (context_proj→fc, context_norm→hidden_norm, final_norm→norm, extra embed_tokens). Added runtime safetensors key remapping in `patch_draft_model_for_sglang()`.
4. SGLang TP workers survive `SIGTERM`, causing Modal container timeout before results print. Fixed with `start_new_session=True` and `SIGKILL` on the process group.
5. Modal containers intermittently fail `snapshot_download` with `ConnectionResetError`. Added exponential backoff retry (5 attempts).

Result: τ=3.21 on GSM8K (z-lab reference: 3.38, gap: -0.17).
- fix_hf_model_config.py: Converts the TorchSpec-exported model config from model_type="dflash" (unrecognized by Transformers/SGLang) to z-lab format (model_type="qwen3" + nested dflash_config + auto_map). Also uploads z-lab's dflash.py/modeling_dflash.py/utils.py for trust_remote_code loading. Required for SGLang DFlash inference.
- training_metrics_3epoch.png: Loss and accuracy curves from the full 200K-sample, 3-epoch training run.
Run all 10 z-lab datasets on 1x H100 (tp=1) via SGLang backend. Average τ=3.06 across datasets, reaching 78% of z-lab reference on math benchmarks with 33x less training data. Comprehensive results with per-dataset τ distributions and domain analysis added to results doc.
Verified from Modal training logs: actual launch used --dataset-size 200000, producing 190,095 samples after filtering (188,977 usable). Training was 23,622 optimizer steps across 3 epochs with global_batch=24. Updated gap analysis: 8.5x fewer sample passes vs z-lab (not 33x).
- Expose min_lr and weight_decay in TrainingConfig (was hardcoded to 0)
- Plumb through BF16Optimizer → LRSchedulerWithWarmup in both trainers
- Update YAML config: accum 4→2, min_lr=6e-5, weight_decay=0.01
- Fix WandB config prefix (logging.* not training.*)
- Support epoch-based training (--num-epochs without --max-steps)
- Document Phase G convergence analysis and Phase H training plan
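The min_lr plumbing can be illustrated with a hedged sketch of a warmup + cosine schedule that decays to a floor rather than to zero (the function shape is illustrative, not the exact LRSchedulerWithWarmup code; the 6e-4/6e-5/0.04 values come from the config changes in this PR).

```python
import math

def lr_at_step(step, max_steps, base_lr=6e-4, min_lr=6e-5, warmup_ratio=0.04):
    warmup_steps = int(max_steps * warmup_ratio)
    if step < warmup_steps:
        # Linear warmup from ~0 up to base_lr.
        return base_lr * (step + 1) / warmup_steps
    # Cosine decay from base_lr down to min_lr (not zero).
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (base_lr - min_lr) * cosine

lrs = [lr_at_step(s, 1000) for s in range(1000)]
```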
…ed secrets

- Add Phase H inference benchmark results (10 datasets, SGLang backend): τ improved +0.59 avg over Phase G (3.06 → 3.45), math gap to z-lab narrowed from 21% to 6.4%
- Remove naive HF benchmark (modal_dflash_benchmark.py) in favor of the SGLang-based benchmark with KV cache, CUDA graphs, paged attention
- Update SGLang benchmark to support Modal volume paths for draft model loading (no HuggingFace upload required)
- Remove hardcoded API tokens from train script docstring; add setup_modal_secrets.sh for safe secret management
- Pin TorchSpec commit in Modal image for reproducible builds
load_hf_dataset now tries datasets.load_dataset() first for Hub paths, falling back to manual JSON download. Also strips columns not needed by the training pipeline (e.g. pre-tokenized input_ids, attention_mask) to reduce memory during streaming iteration.
Add speedup methodology analysis comparing our SGLang E2E measurement vs z-lab's decode-only Transformers benchmark. Add decode-only timing to benchmark script. Patch training script for Arrow/Parquet Hub datasets.
BUG 33: Remove dead gradient_checkpointing flag in DFlashModel — the parameter was stored but never used (no checkpoint wrapping).
BUG 34: Fix FlexAttention recompilation from variable Q_LEN — anchor sampling now always returns num_anchors slots with block_keep_mask for invalids, keeping Q_LEN constant across all steps.
BUG 35: Fix no_sync silently no-oping under FSDP2 fully_shard — use set_requires_gradient_sync(bool) for FSDP2, fall back to no_sync() for the replicate (DDP) strategy.
OPT 1: Per-layer FSDP sharding — each decoder layer is now an independent FSDP unit, enabling comm/compute overlap and reducing peak memory.
OPT 2: Async H2D transfers — replace synchronous .cuda() with .to(device, non_blocking=True) in DFlash _forward().
OPT 3: Pre-allocate fp32 grad buffers in BF16Optimizer — eliminates per-step tensor allocation/deallocation, reducing GC pressure.

Document all findings in dflash_issues.md (Session 13 section).
Training was restarting from step 0 because load_path was never set. Added --resume CLI flag that reads latest_checkpointed_iteration.txt from the output volume and passes training.load_path to the trainer. Also reduced save_interval to 1000 and max_checkpoints to 2 for more frequent checkpointing with less disk usage.
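A minimal sketch of the resume logic described above, assuming a conventional checkpoint layout (the marker filename comes from the commit; the iter directory naming is an assumption): read the latest step from the marker file and derive the load path, or return None for a fresh run.

```python
import os
import tempfile

def resolve_resume_path(output_dir):
    # latest_checkpointed_iteration.txt holds the step number of the newest checkpoint.
    marker = os.path.join(output_dir, "latest_checkpointed_iteration.txt")
    if not os.path.exists(marker):
        return None  # fresh run, start from step 0
    with open(marker) as f:
        step = int(f.read().strip())
    # Hypothetical checkpoint directory naming.
    return os.path.join(output_dir, f"iter_{step:07d}")

with tempfile.TemporaryDirectory() as d:
    with open(os.path.join(d, "latest_checkpointed_iteration.txt"), "w") as f:
        f.write("1000")
    path = resolve_resume_path(d)
    fresh = resolve_resume_path(os.path.join(d, "empty-subdir-that-does-not-exist"))
```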
UltraChat model (208K samples, 43K steps) underperforms Phase H across all domains (τ=3.04 vs 3.45 overall). Deep code audit against SpecForge PRs #427/#472/#473 confirms no training bugs — the τ gap to z-lab is fully explained by recipe differences (data volume, seq_len, epochs).
Scripts:
- Move RunPod scripts to scripts/runpod/
- Move Modal scripts to scripts/modal/
- Move utility scripts to scripts/tools/
- Remove obsolete scripts (benchmark_dflash_inference.py, fix_hf_model_config.py, modal_dflash_train_hf.py, modal_example.py)
- Update all self-referencing paths in script comments/docstrings

Docs (docs/inference/dflash/):
- Merge TRAINING_GUIDE.md + dflash_runpod_guide.md into one guide with Generic/Modal/RunPod/Troubleshooting sections
- Merge dflash_training_results.md + dflash_modal_training_results.md into training_results.md (Phases A-J, deduplicated)
- Fold dflash_pending_work.md into issues.md as a Future Work section
- Fold dflash_training_test_plan.md success criteria into the results header
- Rename dflash_overview.md to README.md
- Rename dflash_issues.md to issues.md
- Rename specforge_dflash_training_reference.md to specforge_reference.md
- Remove raw benchmark logs (already summarized in results)
- Fix all stale cross-references across the repo
- Add --sweep-config flag to modal_dflash_train.py for parallel HP sweeps
- Add sweep configs for Phase 1 (screening) and Phase 2 (validation)
- Handle WSD schedule in dflash_trainer.py without modifying the shared framework
- Document Phase K results: WSD and accum=1 both beat the cosine baseline by ~12%
… training

- Add dflash_run_id parameter to train_sglang/sweep runner so each run gets its own output directory (fixes a checkpoint collision bug where all sweep runs overwrote the same dflash-qwen3-8b/ path)
- Propagate --resume flag through the sweep runner to train_sglang.spawn()
- Add WSD schedule parameters (wsd_decay_ratio, wsd_decay_style) to TrainingConfig for the DFlash trainer
- Add phase2_resume.json sweep config for resumed Phase 2 runs
- Add Phase 2 results: P2-WSD vs P2-baseline vs P2-accum1 training curves from 24h Modal runs, documenting the checkpoint collision bug and fix
- Add P2-accum1 inference benchmark (30 samples/dataset, 10 datasets): τ=3.83 math avg, beats z-lab on GSM8K (3.79 vs 3.38)
- Add per-dataset τ distribution tables showing 80-100% of math/code requests achieve τ≥3, with τ≥5 emerging as a real category (up to 37%)
- Add distribution comparison vs Phase G showing a rightward shift that mean τ alone understates
- Update success criteria to reference Phase K P2-accum1
- P2-WSD (WSD schedule) achieves τ=3.94 math avg, our best result and only 2.7% below z-lab (vs 6.4% for Phase H)
- Decode-only speedup hits 3.02x on livecodebench, meeting the 3.0x target
- Add full cross-dataset comparison (P2-WSD vs P2-accum1 vs Phase H vs z-lab), per-dataset τ distributions, and distribution-shift analysis
- Update success criteria: speedup target now met
Replace ValueError crash in _sample_anchor_positions with a graceful fallback that returns all-False keep_mask (zero loss) when no valid anchor positions exist. This fixes a training crash on long-context models (e.g. Kimi 2.5 with seq_len=65536) where some batches have all-zero loss_mask in the anchor-eligible range. Also plumb min_loss_tokens through the online mooncake data path (data_fetcher → trainer) so samples with too few supervised tokens are filtered at loading time — previously this filter only worked in the offline preprocessing path.
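The fallback can be sketched as follows (a simplified stand-in for _sample_anchor_positions; names and shapes are illustrative): when no position in the anchor-eligible range has loss_mask set, return dummy positions with an all-False keep_mask so the step contributes zero loss instead of crashing.

```python
import torch

def sample_anchor_positions(loss_mask, num_anchors):
    # loss_mask: [seq] bool tensor of supervised positions.
    valid = loss_mask.nonzero(as_tuple=True)[0]
    if valid.numel() == 0:
        # Graceful fallback: dummy positions, all-False keep_mask -> zero loss.
        positions = torch.zeros(num_anchors, dtype=torch.long)
        keep_mask = torch.zeros(num_anchors, dtype=torch.bool)
        return positions, keep_mask
    # Normal path: sample anchors (with replacement) from valid positions.
    idx = torch.randint(valid.numel(), (num_anchors,))
    return valid[idx], torch.ones(num_anchors, dtype=torch.bool)

# A batch with an all-zero loss_mask no longer raises ValueError.
pos, keep = sample_anchor_positions(torch.zeros(64, dtype=torch.bool), 8)
assert not keep.any()
```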