[Refactor][vLLM] Replace worker extension with extract_hidden_states KV connector #57
Conversation
…MooncakeHiddenStatesConnector

Switch hidden states capture from monkey-patching `model.forward` via `VllmWorkerExtension` to vLLM's public `extract_hidden_states` speculative method paired with a custom `MooncakeHiddenStatesConnector` KV connector.

Key changes:
- Add `MooncakeHiddenStatesConnector` that writes hidden states to Mooncake via RDMA directly from vLLM worker processes
- Rewrite `VllmEngine` to use `speculative_config` + `kv_transfer_config`
- Delete `VllmWorkerExtension` (774 lines); no longer needed
- Add verifier norm support to `TargetLMHead` and `Eagle3Trainer` for pre-norm hidden states (vLLM `extract_hidden_states` captures pre-norm outputs)
- Add `last_hidden_states_prenorm` config with auto-detection per engine type
- Add proper engine shutdown in training loop cleanup
- Add output count mismatch guard in inference manager
- Bump vLLM minimum to `>=0.18.0`
- Update tests for new connector/engine
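The connector's write path can be sketched roughly as follows. This is a hedged illustration, not the actual `MooncakeHiddenStatesConnector`: an in-memory dict stands in for the Mooncake store (real transfers go over RDMA), and the class and method names (`HiddenStatesConnector`, `record`, `save`) are illustrative rather than the vLLM KV-connector interface.

```python
import json

class InMemoryMooncakeStore:
    """Stand-in for the Mooncake key/value store (real transfers use RDMA)."""
    def __init__(self):
        self._data = {}

    def put(self, key: str, value: bytes) -> None:
        self._data[key] = value

    def get(self, key: str) -> bytes:
        return self._data[key]

class HiddenStatesConnector:
    """Illustrative connector: buffers per-request hidden states captured
    during the forward pass, then flushes them to the store on save."""
    def __init__(self, store):
        self.store = store
        self._pending = {}  # request_id -> list of hidden-state rows

    def record(self, request_id: str, hidden_row: list) -> None:
        # Called once per generated token with that token's hidden state.
        self._pending.setdefault(request_id, []).append(hidden_row)

    def save(self, request_id: str) -> str:
        # Serialize the buffered states, push them under a per-request
        # key, and drop the local buffer.
        key = f"hidden_states/{request_id}"
        rows = self._pending.pop(request_id)
        self.store.put(key, json.dumps(rows).encode())
        return key

store = InMemoryMooncakeStore()
conn = HiddenStatesConnector(store)
conn.record("req-0", [0.1, 0.2])
conn.record("req-0", [0.3, 0.4])
key = conn.save("req-0")
```

The point of the design (as described in the PR) is that this save step runs inside each vLLM worker process, so hidden states never round-trip through the driver.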
Pull request overview
Refactors TorchSpec’s vLLM integration to capture hidden states via vLLM’s public extract_hidden_states speculative method and a custom KV connector, replacing the prior worker-extension monkey-patching approach.
Changes:
- Introduces `MooncakeHiddenStatesConnector` to write hidden states directly to Mooncake from vLLM worker processes via KV transfer.
- Updates `VllmEngine` to use `speculative_config` + `kv_transfer_config`, and removes the legacy `VllmWorkerExtension`.
- Adjusts training/config/test plumbing for pre-norm `last_hidden_states` (verifier norm), adds safer shutdown/guards, and bumps the vLLM minimum to `>=0.18.0`.
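The pre-norm handling in the last bullet can be sketched as follows: since `extract_hidden_states` captures outputs before the model's final norm, the trainer must apply the verifier's final norm (RMSNorm in Llama/Qwen-style models) before projecting through `lm_head`. This is a minimal stand-alone sketch in plain Python; `rms_norm` and the sample vectors are illustrative, not the torchspec API.

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """Apply RMSNorm (the 'verifier norm') to one pre-norm hidden-state row."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for w, v in zip(weight, x)]

# A pre-norm hidden state as captured by extract_hidden_states (made-up values),
# and an identity norm weight for illustration.
pre_norm_hidden = [1.0, -2.0, 3.0, 0.5]
norm_weight = [1.0, 1.0, 1.0, 1.0]

normed = rms_norm(pre_norm_hidden, norm_weight)
# `normed` is what would then be fed to the lm_head projection.
```

Skipping this step (or failing to load the norm weights, as in the follow-up fix below in this thread) feeds unnormalized activations to `lm_head` and corrupts the target logits.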
Reviewed changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| `torchspec/inference/engine/vllm_engine.py` | Switches vLLM engine wiring to `extract_hidden_states` + KV transfer metadata flow. |
| `torchspec/inference/engine/mooncake_hidden_states_connector.py` | Adds a KV connector that extracts hidden states from the KV cache and stores them to Mooncake. |
| `torchspec/inference/engine/vllm_worker_extension.py` | Deletes the previous monkey-patching worker extension implementation. |
| `torchspec/training/eagle3_trainer.py` | Applies the optional verifier norm to pre-norm `last_hidden_states` before target projection. |
| `torchspec/models/target/target_utils.py` | Adds optional loading/initialization of the final norm alongside `lm_head` for pre-norm handling. |
| `torchspec/controller/loop.py` | Ensures inference engines are shut down during training cleanup. |
| `torchspec/controller/inference_manager.py` | Adds an output-count mismatch guard to avoid incorrect zipping/dispatch. |
| `torchspec/controller/eval.py` | Adjusts initial eval submission sizing with `inference_batch_size`. |
| `torchspec/config/train_config.py` | Auto-defaults `last_hidden_states_prenorm` based on engine type. |
| `torchspec/config/inference_config.py` | Documents the new vLLM behavior and adds a `last_hidden_states_prenorm` resolution helper. |
| `tests/test_vllm_engine.py` | Updates unit tests for the new connector/metadata flow and adds chunked-prefill coverage. |
| `pyproject.toml` | Bumps the vLLM extra requirement to `>=0.18.0`. |
| `configs/vllm_qwen3_8b.yaml` | Updates the example config for the new connector-based capture path. |
| `examples/data/sample_conversations.jsonl` | Adds an additional sample conversation entry used for long/chunked scenarios. |
```python
assert hs_shape[1] == num_training_layers * hidden_size
assert lhs_shape[1] == hidden_size
assert hs_shape[1] + lhs_shape[1] != (num_training_layers + 1) * hidden_size or True
```
This assertion is a no-op: ... != ... or True always evaluates to True, so it can’t catch regressions in the hidden_states/last_hidden_states split. Replace it with a meaningful check (e.g., assert equality to the expected combined width, or assert the combined width equals num_hidden_states * hidden_size).
```diff
- assert hs_shape[1] + lhs_shape[1] != (num_training_layers + 1) * hidden_size or True
+ assert hs_shape[1] + lhs_shape[1] == (num_training_layers + 1) * hidden_size
```
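To make the reviewer's point concrete: `assert X or True` can never fail, because `or` yields a truthy value whenever either operand is truthy, and the right operand here is literally `True`. A tiny demonstration with made-up widths:

```python
# Dummy dimensions standing in for the test's real values.
num_training_layers, hidden_size = 3, 8

# No-op form: passes even with deliberately wrong split widths.
bad_hs_width, bad_lhs_width = 999, 1
assert (bad_hs_width + bad_lhs_width
        != (num_training_layers + 1) * hidden_size) or True  # always passes

# Suggested fix: actually constrains the hidden/last-hidden split.
hs_width = num_training_layers * hidden_size   # 24
lhs_width = hidden_size                        # 8
assert hs_width + lhs_width == (num_training_layers + 1) * hidden_size  # 32 == 32
```

The corrected assertion fails the moment either width drifts from the expected `(num_training_layers + 1) * hidden_size` total.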
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 7c4fb89136
…errides

When pre-norm last hidden states are enabled, `TargetLMHead` always used the default `norm_key` (`model.norm.weight`) because only `lm_head_key` was forwarded from the config. For models with custom key prefixes, norm loading silently failed, leaving `self.norm = None` and corrupting the target logits.
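The shape of the fix can be sketched as follows: forward both weight keys from the config rather than only `lm_head_key`. The names here (`TargetHeadConfig`, `build_head_kwargs`) are illustrative stand-ins, not the actual torchspec API.

```python
from dataclasses import dataclass

@dataclass
class TargetHeadConfig:
    # Defaults match standard HF checkpoints; models with custom
    # prefixes override both keys.
    lm_head_key: str = "lm_head.weight"
    norm_key: str = "model.norm.weight"

def build_head_kwargs(cfg: TargetHeadConfig) -> dict:
    # The bug: only lm_head_key was passed through, so a custom norm_key
    # was silently dropped and the default was used (and then not found).
    # The fix: forward both.
    return {"lm_head_key": cfg.lm_head_key, "norm_key": cfg.norm_key}

cfg = TargetHeadConfig(
    lm_head_key="language_model.lm_head.weight",
    norm_key="language_model.model.norm.weight",
)
kwargs = build_head_kwargs(cfg)
```

A loud failure (raise when the norm weights cannot be found) would also be preferable to the silent `self.norm = None` path described above, since unnormalized activations corrupt the target logits without any error.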
Fixes issue #53.