From c29531843160f413a47e39bca0c7e94f777a2afe Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 4 Nov 2025 18:11:45 -0800 Subject: [PATCH 01/19] run vllm engine --- torchtitan/experiments/vllm/README.md | 185 ++++++++++++++++++ .../experiments/vllm/checkpoint/config.json | 22 +++ torchtitan/experiments/vllm/infer.py | 177 +++++++++++++++++ torchtitan/experiments/vllm/model/qwen3.py | 40 ++++ torchtitan/experiments/vllm/test_infer.sh | 54 +++++ 5 files changed, 478 insertions(+) create mode 100644 torchtitan/experiments/vllm/README.md create mode 100644 torchtitan/experiments/vllm/checkpoint/config.json create mode 100644 torchtitan/experiments/vllm/infer.py create mode 100644 torchtitan/experiments/vllm/model/qwen3.py create mode 100755 torchtitan/experiments/vllm/test_infer.sh diff --git a/torchtitan/experiments/vllm/README.md b/torchtitan/experiments/vllm/README.md new file mode 100644 index 0000000000..18a027c500 --- /dev/null +++ b/torchtitan/experiments/vllm/README.md @@ -0,0 +1,185 @@ +# TorchTitan Qwen3 Model with vLLM Inference + +This directory contains code to run vLLM inference on models trained with TorchTitan. + +## Overview + +The integration consists of two main components: + +1. **Model Adapter** (`model/qwen3.py`): A custom model class that extends vLLM's `Qwen3ForCausalLM` to handle TorchTitan checkpoint naming conventions +2. **Inference Script** (`infer.py`): A simple script to register the model and run inference + + +## Quick Start + +### Prerequisites + +1. Install vLLM: +```bash +pip install vllm +``` + +### Running Inference + +#### Single Prompt + +```bash +python torchtitan/experiments/vllm/infer.py \ + --model-path /path/to/torchtitan/checkpoint \ + --prompt "Explain quantum computing in simple terms" +``` + +#### Multiple Prompts from File + +```bash +# Create a file with prompts (one per line) +cat > prompts.txt << EOF +What is the meaning of life? +Explain how transformers work +Write a poem about AI +EOF + +# Run inference +python torchtitan/experiments/vllm/infer.py \ + --model-path /path/to/torchtitan/checkpoint \ + --prompts-file prompts.txt +``` + +#### With Tensor Parallelism + +```bash +python torchtitan/experiments/vllm/infer.py \ + --model-path /path/to/torchtitan/checkpoint \ + --prompt "Explain deep learning" \ + --tensor-parallel-size 4 \ + --max-tokens 200 +``` + +## Model Configuration + +Your checkpoint directory should contain: + +1. **`config.json`**: HuggingFace-style model configuration +2. **Model weights**: Either PyTorch checkpoint files or safetensors + +Example `config.json` for a Qwen3-7B model: + +```json +{ + "architectures": ["TorchTitanQwen3ForCausalLM"], + "model_type": "qwen3", + "hidden_size": 3584, + "intermediate_size": 18944, + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "vocab_size": 151936, + "max_position_embeddings": 32768, + "rope_theta": 1000000.0, + "rms_norm_eps": 1e-06, + "tie_word_embeddings": false, + "head_dim": 128 +} +``` + +## Command-line Arguments + +| Argument | Type | Default | Description | +|----------|------|---------|-------------| +| `--model-path` | str | Required | Path to TorchTitan checkpoint directory | +| `--prompt` | str | "Hello, how are you?" 
| Single prompt to generate from | +| `--prompts-file` | str | None | Path to file with prompts (one per line) | +| `--max-tokens` | int | 100 | Maximum tokens to generate | +| `--temperature` | float | 0.8 | Sampling temperature | +| `--top-p` | float | 0.95 | Nucleus sampling parameter | +| `--tensor-parallel-size` | int | 1 | Number of GPUs for tensor parallelism | + +## Implementation Details + +### Model Registration + +The inference script registers the custom model with vLLM's model registry: + +```python +from vllm.model_executor.models import ModelRegistry +from torchtitan.experiments.vllm.model.qwen3 import TorchTitanQwen3ForCausalLM + +ModelRegistry.register_model( + "TorchTitanQwen3ForCausalLM", + TorchTitanQwen3ForCausalLM, +) +``` + +### Weight Mapping + +The `WeightsMapper` class handles automatic name translation: + +```python +mapper = WeightsMapper( + orig_to_new_substr={ + ".attention.wq": ".self_attn.q_proj", + # ... other mappings + }, + orig_to_new_prefix={ + "tok_embeddings.weight": "model.embed_tokens.weight", + # ... other mappings + }, +) +``` + +### vLLM Engine Initialization + +The script uses vLLM's high-level `LLM` class: + +```python +llm = LLM( + model=model_path, + tensor_parallel_size=tensor_parallel_size, + trust_remote_code=True, +) +``` + +## Troubleshooting + +### "Model not found" Error + +Ensure `config.json` exists in your checkpoint directory and specifies the correct architecture: +```json +{ + "architectures": ["TorchTitanQwen3ForCausalLM"], + "model_type": "qwen3" +} +``` + +### Weight Loading Errors + +Check that your checkpoint contains weights with TorchTitan naming conventions. You can inspect checkpoint keys: + +```python +import torch +checkpoint = torch.load("path/to/checkpoint.pt") +print(checkpoint.keys()) +``` + +### Memory Issues + +- Reduce `--tensor-parallel-size` if you have limited GPU memory +- Use quantization (see vLLM documentation for quantization options) + +## Performance Notes + +- **Batch Processing**: The script processes multiple prompts in a single batch for efficiency +- **KV Caching**: vLLM automatically uses KV caching for fast autoregressive generation +- **Tensor Parallelism**: Use `--tensor-parallel-size` to distribute the model across multiple GPUs + +## Next Steps + +- See vLLM documentation for advanced features: https://docs.vllm.ai/ +- Explore vLLM's serving capabilities for production deployments +- Configure quantization for reduced memory usage + +## References + +- [TorchTitan Qwen3 Model](../../../models/qwen3/model/model.py) +- [vLLM Qwen3 Model](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3.py) +- [vLLM Documentation](https://docs.vllm.ai/) diff --git a/torchtitan/experiments/vllm/checkpoint/config.json b/torchtitan/experiments/vllm/checkpoint/config.json new file mode 100644 index 0000000000..8e23aa1fa5 --- /dev/null +++ b/torchtitan/experiments/vllm/checkpoint/config.json @@ -0,0 +1,22 @@ +{ + "architectures": [ + "TorchTitanQwen3ForCausalLM" + ], + "model_type": "qwen3", + "hidden_size": 3584, + "intermediate_size": 18944, + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "vocab_size": 151936, + "max_position_embeddings": 32768, + "rope_theta": 1000000.0, + "rms_norm_eps": 1e-06, + "tie_word_embeddings": false, + "head_dim": 128, + "attention_bias": false, + "hidden_act": "silu", + "qk_norm": true, + "torch_dtype": "bfloat16", + "skip_tokenizer_init": true +} diff --git a/torchtitan/experiments/vllm/infer.py 
b/torchtitan/experiments/vllm/infer.py new file mode 100644 index 0000000000..77b727e601 --- /dev/null +++ b/torchtitan/experiments/vllm/infer.py @@ -0,0 +1,177 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Simple inference script for TorchTitan-trained Qwen3 model using vLLM. + +This script demonstrates how to: +1. Register a custom TorchTitan Qwen3 model with vLLM +2. Load a TorchTitan checkpoint into vLLM +3. Run inference using vLLM's optimized engine + +Usage: + python infer.py --model-path /path/to/torchtitan/checkpoint --prompt "Hello, world!" +""" + +import argparse +import logging +from pathlib import Path + +from vllm import LLM, SamplingParams +from vllm.model_executor.models import ModelRegistry + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + + +def register_torchtitan_qwen3_model(): + """Register the TorchTitan Qwen3 model with vLLM's model registry.""" + from torchtitan.experiments.vllm.model.qwen3 import TorchTitanQwen3ForCausalLM + + logger.info("Registering TorchTitan Qwen3 model with vLLM") + + # Register the model using Qwen3's architecture but with custom weight loading + ModelRegistry.register_model( + "TorchTitanQwen3ForCausalLM", + TorchTitanQwen3ForCausalLM, + ) + + print("Successfully registered TorchTitanQwen3ForCausalLM") + + +def run_inference( + model: str, + prompts: list[str], + max_tokens: int = 100, + temperature: float = 0.8, + top_p: float = 0.95, + tensor_parallel_size: int = 1, +): + """ + Run inference using vLLM with a TorchTitan-trained Qwen3 model. 
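+
+    A minimal call, assuming the default checkpoint path used by the CLI
+    below (the remaining sampling arguments keep their defaults):
+
+        run_inference(
+            model="torchtitan/experiments/vllm/checkpoint/",
+            prompts=["Hello, how are you?"],
+            max_tokens=50,
+        )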
+ + Args: + model: Model name + prompts: List of prompts to generate from + max_tokens: Maximum number of tokens to generate + temperature: Sampling temperature + top_p: Top-p sampling parameter + tensor_parallel_size: Number of GPUs for tensor parallelism + """ + # Create sampling parameters + sampling_params = SamplingParams( + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + ) + + # entry point: + try: + llm = LLM( + model=model, + model_impl="vllm", + skip_tokenizer_init=True, + ) + except Exception as e: + logger.error( + "Failed to initialize vLLM engine with TorchTitanQwen3ForCausalLM model\n" + ) + raise + + logger.info("Model loaded successfully, starting generation...") + + # Generate outputs + outputs = llm.generate(prompts, sampling_params) + + # Print results + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + logger.info("-" * 80) + logger.info(f"Prompt: {prompt}") + logger.info(f"Generated: {generated_text}") + + logger.info("-" * 80) + logger.info(f"Generated {len(outputs)} outputs successfully") + + +def main(): + parser = argparse.ArgumentParser( + description="Run inference with TorchTitan Qwen3 model using vLLM" + ) + parser.add_argument( + "--model", + type=str, + default="torchtitan/experiments/vllm/checkpoint/", + help="Path to the TorchTitan checkpoint or HuggingFace model directory", + ) + parser.add_argument( + "--prompt", + type=str, + default="Hello, how are you?", + help="Single prompt to generate from", + ) + parser.add_argument( + "--prompts-file", + type=str, + help="Path to file containing prompts (one per line)", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=100, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.8, + help="Sampling temperature", + ) + parser.add_argument( + "--top-p", + type=float, + default=0.95, + help="Top-p (nucleus) sampling parameter", + ) + parser.add_argument( + "--tensor-parallel-size", + type=int, + default=1, + help="Number of GPUs for tensor parallelism", + ) + + args = parser.parse_args() + + # Register the custom model + register_torchtitan_qwen3_model() + + # Prepare prompts + if args.prompts_file: + prompts_path = Path(args.prompts_file) + if not prompts_path.exists(): + raise FileNotFoundError(f"Prompts file not found: {prompts_path}") + prompts = prompts_path.read_text().strip().split("\n") + logger.info(f"Loaded {len(prompts)} prompts from {prompts_path}") + else: + prompts = [args.prompt] + + # Run inference + run_inference( + model=args.model, + prompts=prompts, + max_tokens=args.max_tokens, + temperature=args.temperature, + top_p=args.top_p, + tensor_parallel_size=args.tensor_parallel_size, + ) + + +if __name__ == "__main__": + main() diff --git a/torchtitan/experiments/vllm/model/qwen3.py b/torchtitan/experiments/vllm/model/qwen3.py new file mode 100644 index 0000000000..131b2a8c7a --- /dev/null +++ b/torchtitan/experiments/vllm/model/qwen3.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from collections.abc import Iterable + +import torch + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM + +logger = init_logger(__name__) + + +class TorchTitanQwen3ForCausalLM(Qwen3ForCausalLM): + """ + TorchTitan-trained Qwen3 dense model adapter for vLLM. + + This class extends the standard Qwen3ForCausalLM to support loading + weights from TorchTitan checkpoints with different naming conventions. + The architecture is identical to standard Qwen3 - only weight names differ. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + pass + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + pass + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + pass diff --git a/torchtitan/experiments/vllm/test_infer.sh b/torchtitan/experiments/vllm/test_infer.sh new file mode 100755 index 0000000000..a13810613c --- /dev/null +++ b/torchtitan/experiments/vllm/test_infer.sh @@ -0,0 +1,54 @@ +#!/bin/bash +# Simple test script to demonstrate vLLM inference with TorchTitan Qwen3 model + +set -e + +echo "========================================" +echo "TorchTitan Qwen3 + vLLM Inference Test" +echo "========================================" + +# Check if model path is provided +if [ -z "$1" ]; then + echo "Usage: $0 " + echo "" + echo "Example:" + echo " $0 /path/to/torchtitan/checkpoint" + echo "" + echo "The checkpoint directory should contain:" + echo " - config.json (HuggingFace-style model config)" + echo " - Model weights (PyTorch checkpoint or safetensors)" + exit 1 +fi + +MODEL_PATH="$1" + +# Verify model path exists +if [ ! -d "$MODEL_PATH" ]; then + echo "Error: Model path not found: $MODEL_PATH" + exit 1 +fi + +# Verify config.json exists +if [ ! -f "$MODEL_PATH/config.json" ]; then + echo "Error: config.json not found in $MODEL_PATH" + echo "" + echo "Please ensure your checkpoint contains a HuggingFace-style config.json" + echo "See example_config.json for reference" + exit 1 +fi + +echo "Model path: $MODEL_PATH" +echo "" + +# Run inference with a simple prompt +echo "Running inference with single prompt..." +python torchtitan/experiments/vllm/infer.py \ + --model-path "$MODEL_PATH" \ + --prompt "What is the meaning of life?" \ + --max-tokens 50 \ + --temperature 0.7 + +echo "" +echo "========================================" +echo "Test completed successfully!" 
+echo "========================================" From b5f9e59694698181769aa87faf9abd4dcbc6509f Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Sun, 23 Nov 2025 19:05:50 -0800 Subject: [PATCH 02/19] add 1st version --- .../experiments/vllm/checkpoint/config.json | 22 - .../vllm/custom_models/__init__.py | 35 ++ .../custom_models/attention_replacement.py | 170 ++++++ .../custom_models/custom_model_wrapper.py | 133 +++++ .../vllm/custom_models/trainable_attention.py | 501 ++++++++++++++++ .../custom_models/trainable_mla_attention.py | 562 ++++++++++++++++++ .../experiments/vllm/custom_models/utils.py | 245 ++++++++ .../vllm/example_checkpoint/config.json | 0 torchtitan/experiments/vllm/infer.py | 291 +++++---- torchtitan/experiments/vllm/model/qwen3.py | 192 +++++- torchtitan/experiments/vllm/test_infer.sh | 54 -- 11 files changed, 2002 insertions(+), 203 deletions(-) delete mode 100644 torchtitan/experiments/vllm/checkpoint/config.json create mode 100644 torchtitan/experiments/vllm/custom_models/__init__.py create mode 100644 torchtitan/experiments/vllm/custom_models/attention_replacement.py create mode 100644 torchtitan/experiments/vllm/custom_models/custom_model_wrapper.py create mode 100644 torchtitan/experiments/vllm/custom_models/trainable_attention.py create mode 100644 torchtitan/experiments/vllm/custom_models/trainable_mla_attention.py create mode 100644 torchtitan/experiments/vllm/custom_models/utils.py create mode 100644 torchtitan/experiments/vllm/example_checkpoint/config.json mode change 100644 => 100755 torchtitan/experiments/vllm/infer.py delete mode 100755 torchtitan/experiments/vllm/test_infer.sh diff --git a/torchtitan/experiments/vllm/checkpoint/config.json b/torchtitan/experiments/vllm/checkpoint/config.json deleted file mode 100644 index 8e23aa1fa5..0000000000 --- a/torchtitan/experiments/vllm/checkpoint/config.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "architectures": [ - "TorchTitanQwen3ForCausalLM" - ], - "model_type": "qwen3", - "hidden_size": 3584, - "intermediate_size": 18944, - "num_attention_heads": 28, - "num_hidden_layers": 28, - "num_key_value_heads": 4, - "vocab_size": 151936, - "max_position_embeddings": 32768, - "rope_theta": 1000000.0, - "rms_norm_eps": 1e-06, - "tie_word_embeddings": false, - "head_dim": 128, - "attention_bias": false, - "hidden_act": "silu", - "qk_norm": true, - "torch_dtype": "bfloat16", - "skip_tokenizer_init": true -} diff --git a/torchtitan/experiments/vllm/custom_models/__init__.py b/torchtitan/experiments/vllm/custom_models/__init__.py new file mode 100644 index 0000000000..ffe2a498ce --- /dev/null +++ b/torchtitan/experiments/vllm/custom_models/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Utilities for integrating custom model implementations with vLLM.""" + +from .attention_replacement import replace_with_trainable_attention +from .custom_model_wrapper import VLLMModelForCausalLM +from .trainable_attention import TrainableFlashAttention +from .trainable_mla_attention import MLAConfig, TrainableMLA +from .utils import ( + convert_freqs_cis_to_real, + create_mla_kv_cache_spec, + load_external_weights, + store_positions_in_context, +) + +__all__ = [ + # Attention modules + "TrainableFlashAttention", + "TrainableMLA", + "MLAConfig", + "replace_with_trainable_attention", + # Base wrapper + "VLLMModelForCausalLM", + # Utilities + "convert_freqs_cis_to_real", + "create_mla_kv_cache_spec", + "load_external_weights", + "store_positions_in_context", +] diff --git a/torchtitan/experiments/vllm/custom_models/attention_replacement.py b/torchtitan/experiments/vllm/custom_models/attention_replacement.py new file mode 100644 index 0000000000..a3d51dbe65 --- /dev/null +++ b/torchtitan/experiments/vllm/custom_models/attention_replacement.py @@ -0,0 +1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Utilities for replacing TorchTitan attention layers with vLLM trainable attention. + +This module provides functions to automatically replace TorchTitan's attention +layers with vLLM's optimized trainable attention layers (TrainableFlashAttention +or TrainableMLA) while preserving weights. + +Example usage: + ```python + from torchtitan.models.qwen3.model import Qwen3Model + from . import replace_with_trainable_attention + + # Create TorchTitan model + model = Qwen3Model(model_args) + + # Replace attention layers with vLLM trainable attention + replace_with_trainable_attention(model, use_mla=False) + ``` +""" + +import torch.nn as nn + +from .trainable_attention import TrainableFlashAttention +from .trainable_mla_attention import MLAConfig, TrainableMLA + + +def replace_with_trainable_attention( + model: nn.Module, + use_mla: bool = False, +) -> None: + """ + Replace TorchTitan attention layers with vLLM trainable attention. + + This function performs in-place module surgery, replacing all attention + layers in model.layers with either TrainableFlashAttention or TrainableMLA + while preserving the original weights. 
+ + Args: + model: TorchTitan model with .layers attribute (dict or nn.ModuleDict) + use_mla: If True, use TrainableMLA; otherwise use TrainableFlashAttention + + Raises: + AttributeError: If model doesn't have .layers attribute + ValueError: If attention layer structure is not recognized + """ + if not hasattr(model, "layers"): + raise AttributeError( + f"Model {type(model).__name__} must have .layers attribute" + ) + + for layer_name, layer in model.layers.items(): + if not hasattr(layer, "attention"): + raise ValueError(f"Layer {layer_name} must have .attention attribute") + + old_attn = layer.attention + + if use_mla: + # Create TrainableMLA and transfer weights + new_attn = _create_trainable_mla_from_torchtitan(old_attn) + else: + # Create TrainableFlashAttention and transfer weights + new_attn = _create_trainable_flash_attention_from_torchtitan(old_attn) + + # Replace attention module + layer.attention = new_attn + + +def _create_trainable_flash_attention_from_torchtitan( + torchtitan_attn: nn.Module, +) -> TrainableFlashAttention: + """ + Create TrainableFlashAttention from TorchTitan attention and transfer weights. + + Args: + torchtitan_attn: TorchTitan Attention module + + Returns: + TrainableFlashAttention with transferred weights + """ + # Extract config from TorchTitan attention + hidden_size = torchtitan_attn.wq.weight.shape[1] + num_heads = torchtitan_attn.n_heads + num_kv_heads = getattr(torchtitan_attn, "n_kv_heads", num_heads) + head_dim = torchtitan_attn.head_dim + use_qk_norm = ( + hasattr(torchtitan_attn, "q_norm") and torchtitan_attn.q_norm is not None + ) + + # Create vLLM attention + vllm_attn = TrainableFlashAttention( + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + use_fused_qkv=False, # TorchTitan uses separate wq/wk/wv + use_qk_norm=use_qk_norm, + ) + + # Transfer weights (TorchTitan and vLLM use same naming: wq, wk, wv, wo) + vllm_attn.wq.weight.data.copy_(torchtitan_attn.wq.weight.data) + vllm_attn.wk.weight.data.copy_(torchtitan_attn.wk.weight.data) + vllm_attn.wv.weight.data.copy_(torchtitan_attn.wv.weight.data) + vllm_attn.wo.weight.data.copy_(torchtitan_attn.wo.weight.data) + + # Transfer QK norm weights if present + if use_qk_norm: + vllm_attn.q_norm.weight.data.copy_(torchtitan_attn.q_norm.weight.data) + vllm_attn.k_norm.weight.data.copy_(torchtitan_attn.k_norm.weight.data) + + return vllm_attn + + +def _create_trainable_mla_from_torchtitan( + torchtitan_attn: nn.Module, +) -> TrainableMLA: + """ + Create TrainableMLA from TorchTitan MLA attention and transfer weights. 
+ + Args: + torchtitan_attn: TorchTitan MLA Attention module + + Returns: + TrainableMLA with transferred weights + """ + # Extract MLA config from TorchTitan attention + config = MLAConfig( + hidden_size=torchtitan_attn.dim, + num_heads=torchtitan_attn.n_heads, + q_lora_rank=torchtitan_attn.q_lora_rank, + kv_lora_rank=torchtitan_attn.kv_lora_rank, + qk_nope_head_dim=torchtitan_attn.qk_nope_head_dim, + qk_rope_head_dim=torchtitan_attn.qk_rope_head_dim, + v_head_dim=torchtitan_attn.v_head_dim, + norm_eps=1e-5, # Standard value for DeepSeek + dropout=0.0, + scale=torchtitan_attn.softmax_scale, + causal=True, + ) + + # Create vLLM MLA + vllm_mla = TrainableMLA(config) + + # Transfer weights + if vllm_mla.q_lora_rank == 0: + # Direct Q projection + vllm_mla.wq.weight.data.copy_(torchtitan_attn.wq.weight.data) + else: + # LoRA Q projection + assert vllm_mla.q_norm is not None # q_norm exists when q_lora_rank > 0 + vllm_mla.wq_a.weight.data.copy_(torchtitan_attn.wq_a.weight.data) + vllm_mla.wq_b.weight.data.copy_(torchtitan_attn.wq_b.weight.data) + vllm_mla.q_norm.weight.data.copy_(torchtitan_attn.q_norm.weight.data) + + # KV projection (always LoRA) + vllm_mla.wkv_a.weight.data.copy_(torchtitan_attn.wkv_a.weight.data) + vllm_mla.wkv_b.weight.data.copy_(torchtitan_attn.wkv_b.weight.data) + vllm_mla.kv_norm.weight.data.copy_(torchtitan_attn.kv_norm.weight.data) + + # Output projection + vllm_mla.wo.weight.data.copy_(torchtitan_attn.wo.weight.data) + + return vllm_mla diff --git a/torchtitan/experiments/vllm/custom_models/custom_model_wrapper.py b/torchtitan/experiments/vllm/custom_models/custom_model_wrapper.py new file mode 100644 index 0000000000..d30c711def --- /dev/null +++ b/torchtitan/experiments/vllm/custom_models/custom_model_wrapper.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Base class for integrating custom model implementations with vLLM. + +This module provides an abstract base class that enforces the vLLM model interface, +making it easy to integrate external model implementations (e.g., from TorchTitan, +NanoGPT, etc.) with vLLM. + +Example usage: + ```python + from some_external_lib.models import ExternalModel + from . import VLLMModelForCausalLM + + + class MyCustomModelForCausalLM(VLLMModelForCausalLM): + def __init__(self, vllm_config, parallel_context=None, **kwargs): + super().__init__() + self.model = ExternalModel(...) # Create external model + # Replace attention layers with vLLM's trainable attention + + def get_input_embeddings(self, input_ids): + return self.model.tok_embeddings(input_ids) + + def forward(self, input_ids, positions=None, **kwargs): + # Forward pass + return hidden_states + + def compute_logits(self, hidden_states, sampling_metadata=None): + return self.model.output(hidden_states) + + def load_weights(self, weights_iter): + # Load weights from HuggingFace checkpoint + pass + ``` +""" + +from abc import ABC, abstractmethod +from collections.abc import Iterator +from typing import Any + +import torch +import torch.nn as nn + + +class VLLMModelForCausalLM(nn.Module, ABC): + """ + Abstract base class for integrating custom model implementations with vLLM. + + This class enforces the vLLM model interface that all text generation models + must implement. 
Subclasses should: + 1. Import and instantiate the external model in __init__ + 2. Replace attention layers with vLLM's trainable attention + 3. Implement the abstract methods below + + Class attributes: + supports_pp: Whether pipeline parallelism is supported + supports_multimodal: Whether multimodal inputs are supported + """ + + supports_pp: bool = False + supports_multimodal: bool = False + + @abstractmethod + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + """ + Convert input token IDs to embeddings. + + Args: + input_ids: Token IDs [batch, seq_len] or [total_tokens] + + Returns: + Embeddings [batch, seq_len, hidden_size] or [total_tokens, hidden_size] + """ + pass + + @abstractmethod + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor | None = None, + **kwargs: Any, + ) -> torch.Tensor: + """ + Forward pass through the model. + + Args: + input_ids: Token IDs [batch, seq_len] or [total_tokens] + positions: Position indices from vLLM for RoPE indexing + **kwargs: Additional vLLM-specific arguments + + Returns: + Hidden states before final projection [batch, seq_len, hidden_size] + or [total_tokens, hidden_size] + """ + pass + + @abstractmethod + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: Any = None, + ) -> torch.Tensor: + """ + Compute logits from hidden states. + + Args: + hidden_states: Output from forward() [batch, seq_len, hidden_size] + sampling_metadata: vLLM sampling metadata (optional) + + Returns: + Logits [batch, seq_len, vocab_size] + """ + pass + + @abstractmethod + def load_weights(self, weights_iter: Iterator[tuple[str, torch.Tensor]]) -> None: + """ + Load weights from HuggingFace checkpoint. + + This method should map HuggingFace weight names to model parameter names + and load them into the model. + + Args: + weights_iter: Iterator yielding (name, tensor) tuples from HF checkpoint + """ + pass diff --git a/torchtitan/experiments/vllm/custom_models/trainable_attention.py b/torchtitan/experiments/vllm/custom_models/trainable_attention.py new file mode 100644 index 0000000000..aa8d074d85 --- /dev/null +++ b/torchtitan/experiments/vllm/custom_models/trainable_attention.py @@ -0,0 +1,501 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Trainable Flash Attention module for research and fine-tuning. + +This module provides a training-compatible wrapper around vLLM's optimized +flash attention implementation, enabling backpropagation for RL and fine-tuning +use cases. +""" + +import itertools +from typing import Any + +import torch +import torch.nn as nn + +from vllm.attention import Attention +from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available + +if is_flash_attn_varlen_func_available(): + from vllm.attention.utils.fa_utils import flash_attn_varlen_func + +from vllm.config import VllmConfig +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.v1.kv_cache_interface import FullAttentionSpec + + +class TrainableFlashAttention(nn.Module, AttentionLayerBase): + """ + Training-compatible flash attention module using vLLM's optimized kernels. 
+ + This module wraps vLLM's flash attention forward pass and adds backward + support for training scenarios like reinforcement learning and fine-tuning. + + Supports both fused QKV projections (efficient) and separate projections + (for compatibility with TorchTitan models during module surgery). + + Example: + ```python + # Create attention module (fused, efficient) + attn = TrainableFlashAttention(hidden_size=768, num_heads=12, dropout=0.1) + + # Create TorchTitan-compatible module (separate projections) + attn = TrainableFlashAttention( + hidden_size=768, + num_heads=12, + use_fused_qkv=False, # Separate wq/wk/wv for compatibility + use_qk_norm=True, # QK normalization like Qwen3 + ) + + # Use in training + attn.train() + hidden_states = torch.randn(2, 16, 768, requires_grad=True) + output = attn(hidden_states) + + # Backward pass works + loss = output.sum() + loss.backward() + ``` + + Args: + hidden_size: Hidden dimension of the model + num_heads: Number of attention heads + num_kv_heads: Number of key/value heads (for GQA). Defaults to num_heads + head_dim: Dimension per head. Defaults to hidden_size // num_heads + dropout: Dropout probability during training. Defaults to 0.0 + scale: Attention scale factor. Defaults to 1/sqrt(head_dim) + causal: Whether to use causal masking. Defaults to True + use_fused_qkv: Use fused QKV projection (efficient). Set False for + TorchTitan compatibility. Defaults to True. + use_qk_norm: Apply RMSNorm to Q and K after projection (Qwen3 style). + Defaults to False. + norm_eps: Epsilon for RMSNorm if use_qk_norm=True. Defaults to 1e-6. + """ + + # Class variable for auto-generating unique layer names (thread-safe) + _layer_counter = itertools.count() + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int | None = None, + head_dim: int | None = None, + dropout: float = 0.0, + scale: float | None = None, + causal: bool = True, + use_fused_qkv: bool = True, + use_qk_norm: bool = False, + norm_eps: float = 1e-6, + ): + super().__init__() + + if not is_flash_attn_varlen_func_available(): + raise RuntimeError( + "Flash attention is not available. " + "Please install flash-attn: pip install flash-attn" + ) + + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads or num_heads + self.head_dim = head_dim or (hidden_size // num_heads) + self.dropout = dropout + self.causal = causal + self.use_fused_qkv = use_fused_qkv + self.use_qk_norm = use_qk_norm + + if scale is None: + self.scale = self.head_dim**-0.5 + else: + self.scale = scale + + # TODO(future optimization): Always use fused QKV for efficiency + # Currently supporting separate projections for TorchTitan compatibility + # during module surgery. Once we have weight conversion utilities, + # we should always initialize with fused weights and convert TorchTitan + # weights (wq, wk, wv) -> fused (qkv) during load_weights(). + # This will give us the best of both worlds: compatibility + efficiency. 
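+        # A rough sketch of that conversion (an assumption, not wired up here):
+        # nn.Linear stores weights as [out_features, in_features], and forward()
+        # splits the fused output in q, k, v order, so the separate weights
+        # could be concatenated along dim 0 once loaded, e.g.
+        #     fused = torch.cat([wq.weight, wk.weight, wv.weight], dim=0)
+        #     self.qkv.weight.data.copy_(fused)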
+ + if use_fused_qkv: + # Fused QKV projection (efficient - single matmul) + self.qkv = nn.Linear( + hidden_size, + (num_heads + 2 * self.num_kv_heads) * self.head_dim, + bias=False, + ) + else: + # Separate projections (TorchTitan compatibility) + self.wq = nn.Linear(hidden_size, num_heads * self.head_dim, bias=False) + self.wk = nn.Linear( + hidden_size, self.num_kv_heads * self.head_dim, bias=False + ) + self.wv = nn.Linear( + hidden_size, self.num_kv_heads * self.head_dim, bias=False + ) + + # Output projection (naming convention follows use_fused_qkv) + if use_fused_qkv: + self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_size, bias=False) + else: + # TorchTitan uses 'wo' naming + self.wo = nn.Linear(num_heads * self.head_dim, hidden_size, bias=False) + + # Optional QK normalization (for Qwen3 and similar models) + if use_qk_norm: + self.q_norm = nn.RMSNorm(self.head_dim, eps=norm_eps) + self.k_norm = nn.RMSNorm(self.head_dim, eps=norm_eps) + else: + self.q_norm = None + self.k_norm = None + + # Create vLLM Attention layer to handle KV cache automatically + # This delegates all the complex KV cache logic to vLLM + try: + from vllm.config import get_current_vllm_config + + config = get_current_vllm_config() + cache_config = ( + config.cache_config if hasattr(config, "cache_config") else None + ) + + # Generate unique prefix for this attention layer + # vLLM expects format "layers.X" for layer index extraction + layer_idx = next(TrainableFlashAttention._layer_counter) + prefix = f"layers.{layer_idx}" + + self.vllm_attn = Attention( + num_heads=num_heads, + head_size=self.head_dim, + scale=self.scale, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=None, + prefix=prefix, + ) + except (ImportError, RuntimeError, AttributeError): + # Not in vLLM context - attention layer not needed + self.vllm_attn = None + + # KV cache - will be populated by vLLM during model loading + # For V1 engine, this is a list[torch.Tensor] indexed by virtual_engine + self.kv_cache: list[torch.Tensor] | None = None + + # Auto-register for vLLM KV cache if in vLLM context + self._auto_register_for_kv_cache() + + def _auto_register_for_kv_cache(self): + """Automatically register this layer for vLLM KV cache allocation. + + This is called during __init__ and will register the layer if we're in + a vLLM context. If not in vLLM context (e.g., pure PyTorch training), + this silently does nothing. + """ + # Initialize layer_name attribute + self.layer_name: str | None = None + + try: + from vllm.config import get_current_vllm_config + + config = get_current_vllm_config() + compilation_config = config.compilation_config + + # Generate unique layer name using class counter + # Format: "layers.{index}" for compatibility with extract_layer_index() + layer_name = f"layers.{next(TrainableFlashAttention._layer_counter)}" + + # Register this layer in static forward context + if layer_name in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {layer_name}") + compilation_config.static_forward_context[layer_name] = self + self.layer_name = layer_name + + except (ImportError, RuntimeError, AttributeError): + # Not in vLLM context - this is fine! 
+ # Layer will work normally for training/inference without vLLM + pass + + def forward( + self, + hidden_states: torch.Tensor, + freqs_cis: torch.Tensor + | None = None, # RoPE frequencies (TorchTitan compatibility) + attention_mask: torch.Tensor | None = None, + **kwargs: Any, # Accept any additional vLLM-specific kwargs + ) -> torch.Tensor: + """ + Forward pass with flash attention. + + Supports both training (full sequences) and vLLM inference (with KV cache). + + Args: + hidden_states: Input tensor of shape [total_tokens, hidden_size] + or [batch, seq_len, hidden_size] + freqs_cis: RoPE frequencies (for TorchTitan compatibility, currently unused) + attention_mask: Optional attention mask (not yet fully supported) + **kwargs: Additional vLLM-specific kwargs (intermediate_tensors, etc.) + + Returns: + output: Attention output of same shape as hidden_states + """ + # Handle both batched [batch, seq, hidden] and flattened [total_tokens, hidden] + input_is_batched = hidden_states.dim() == 3 + if input_is_batched: + original_batch_size, original_seq_len, _ = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + else: + original_batch_size = None + original_seq_len = None + + total_tokens = hidden_states.shape[0] + + # Project to Q, K, V (supports both fused and separate modes) + if self.use_fused_qkv: + # Fused projection path (efficient) + qkv = self.qkv(hidden_states) + + # Split into Q, K, V + # qkv shape: [total_tokens, (num_heads + 2*num_kv_heads) * head_dim] + q_size = self.num_heads * self.head_dim + k_size = self.num_kv_heads * self.head_dim + v_size = self.num_kv_heads * self.head_dim + + q = qkv[:, :q_size] + k = qkv[:, q_size : q_size + k_size] + v = qkv[:, q_size + k_size : q_size + k_size + v_size] + else: + # Separate projections (TorchTitan compatibility) + q = self.wq(hidden_states) + k = self.wk(hidden_states) + v = self.wv(hidden_states) + + # Reshape for attention: [total_tokens, num_heads, head_dim] + q = q.view(total_tokens, self.num_heads, self.head_dim) + k = k.view(total_tokens, self.num_kv_heads, self.head_dim) + v = v.view(total_tokens, self.num_kv_heads, self.head_dim) + + # Optional QK normalization (Qwen3 style) + if self.q_norm is not None: + q = self.q_norm(q) + if self.k_norm is not None: + k = self.k_norm(k) + + # DEBUG: Log layer 0 values to compare with TorchTitan + is_layer_0 = not hasattr(self, "_debug_logged") + if is_layer_0 and total_tokens > 1 and total_tokens < 100: # Skip warmup + self._debug_logged = True + print("\n[VLLM ATT DEBUG] Layer 0 - Input") + print(f" hidden_states.shape: {hidden_states.shape}") + print(f" total_tokens: {total_tokens}") + print(f" q (before RoPE)[0,0,:5]: {q[0, 0, :5]}") + print(f" k (before RoPE)[0,0,:5]: {k[0, 0, :5]}") + + # Apply RoPE if freqs_cis is provided (TorchTitan integration) + if freqs_cis is not None: + # Get positions from vLLM forward context + try: + from vllm.forward_context import get_forward_context + + forward_ctx = get_forward_context() + + # Try to get positions from custom attribute set by wrapper + positions = None + if hasattr(forward_ctx, "_torchtitan_positions"): + positions = forward_ctx._torchtitan_positions + # Debug: Log positions during generation, not just warmup + unique_pos = torch.unique(positions[: min(100, len(positions))]) + # Skip warmup with all zeros + if (len(unique_pos) > 1 or unique_pos[0] != 0) and not hasattr( + self, "_rope_gen_debug" + ): + self._rope_gen_debug = True + print(f"\n[ROPE GEN] Got real positions: {unique_pos[:20]}") + print( + 
f"[ROPE GEN] total_tokens: {total_tokens}, " + f"freqs_cis.shape: {freqs_cis.shape}" + ) + else: + # Fallback to sequential positions + positions = torch.arange(total_tokens, device=q.device) + + # Index rope_cache by positions + # freqs_cis shape after convert_freqs_cis_to_real: + # [max_seq_len, head_dim] (cos and sin concatenated) + positions_flat = positions.flatten() + + # Ensure positions are within bounds + max_pos = freqs_cis.shape[0] - 1 + positions_flat = torch.clamp(positions_flat[:total_tokens], 0, max_pos) + + cos_sin = freqs_cis.index_select(0, positions_flat) + + # Split into cos and sin at half of last dimension + # (works for both [*, head_dim] and [*, 2*head_dim] formats) + half_dim = cos_sin.shape[-1] // 2 + cos = cos_sin[..., :half_dim] + sin = cos_sin[..., half_dim:] + + # Expand cos/sin to full head_dim by repeating each element + # [total_tokens, head_dim//2] -> [total_tokens, head_dim] + cos = cos.repeat_interleave(2, dim=-1) + sin = sin.repeat_interleave(2, dim=-1) + + # Apply rotary embedding (same as TorchTitan's apply_rotary_emb) + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + # Reshape cos/sin for broadcast: [total_tokens, 1, head_dim] + cos = cos.unsqueeze(1).to(dtype=q.dtype, device=q.device) + sin = sin.unsqueeze(1).to(dtype=q.dtype, device=q.device) + + # Apply rotation + q = (q * cos) + (rotate_half(q) * sin) + k = (k * cos) + (rotate_half(k) * sin) + + # DEBUG: Log after RoPE + if is_layer_0 and total_tokens > 1 and total_tokens < 100: + print(f" RoPE applied with positions: {unique_pos[:10]}") + print(f" freqs_cis.shape: {freqs_cis.shape}") + print(f" q (after RoPE)[0,0,:5]: {q[0, 0, :5]}") + print(f" k (after RoPE)[0,0,:5]: {k[0, 0, :5]}") + + except (ImportError, AttributeError, IndexError, AssertionError) as e: + # If we can't get positions, fall through without RoPE + # This will happen in pure training mode + if not hasattr(self, "_rope_error"): + self._rope_error = True + print(f"\n[ROPE DEBUG] Error applying RoPE: {e}") + pass + + # Delegate to vLLM's Attention layer if available + # (handles KV cache automatically) + if self.vllm_attn is not None and not self.training: + # Let vLLM handle all KV cache logic + # vllm_attn expects q,k,v in shape [total_tokens, num_heads*head_dim] + # or [total_tokens, num_heads, head_dim] + attn_output = self.vllm_attn(q, k, v) + # vllm_attn returns [total_tokens, num_heads * head_dim] + else: + # Training mode or fallback: use regular flash attention (no KV cache) + if not self.training and hidden_states.is_cuda: + # Inference without KV cache: use flash attention varlen + # Create simple cu_seqlens for single sequence + cu_seqlens_q = torch.tensor( + [0, total_tokens], + dtype=torch.int32, + device=hidden_states.device, + ) + + attn_output = flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_q=total_tokens, + max_seqlen_k=total_tokens, + softmax_scale=self.scale, + causal=self.causal, + dropout_p=0.0, + fa_version=3, + ) + else: + # Training mode with CPU: use PyTorch SDPA + batch_size = 1 + seq_len = total_tokens + + q = q.view(batch_size, seq_len, self.num_heads, self.head_dim) + k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) + v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) + + q = q.transpose(1, 2) # [batch, num_heads, seq_len, head_dim] + k = k.transpose(1, 2) + v = 
v.transpose(1, 2) + + # Handle GQA by repeating k, v if needed + if self.num_kv_heads != self.num_heads: + num_repeats = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(num_repeats, dim=1) + v = v.repeat_interleave(num_repeats, dim=1) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=self.causal and attention_mask is None, + ) + + attn_output = attn_output.transpose( + 1, 2 + ) # [batch, seq_len, heads, dim] + attn_output = attn_output.reshape( + total_tokens, self.num_heads, self.head_dim + ) + + # Flatten heads and project output + attn_output = attn_output.reshape(total_tokens, -1) + if self.use_fused_qkv: + output = self.o_proj(attn_output) + else: + output = self.wo(attn_output) + + # DEBUG: Log attention output for layer 0 + if is_layer_0 and total_tokens > 1 and total_tokens < 100: + print(f" attn_output (before o_proj)[0,:5]: {attn_output[0, :5]}") + print(f" output (after o_proj)[0,:5]: {output[0, :5]}") + + # Restore original shape if input was batched + if input_is_batched: + output = output.view( + original_batch_size, original_seq_len, self.hidden_size + ) + + return output + + def get_attn_backend(self): + """ + Get the attention backend for this layer. + + For TrainableFlashAttention, we don't use a specific vLLM backend + since we implement attention directly. Return None to indicate + this layer manages its own attention computation. + """ + # Import here to avoid circular dependency + from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend + + return FlashAttentionBackend + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> FullAttentionSpec: + """ + Return KV cache specification for V1 engine integration. + + This allows TrainableFlashAttention to work with vLLM's V1 engine + by providing the necessary KV cache metadata. + """ + block_size = vllm_config.cache_config.block_size + # Determine the dtype for KV cache + kv_cache_dtype = vllm_config.cache_config.cache_dtype + if kv_cache_dtype == "auto": + kv_cache_dtype = vllm_config.model_config.dtype + + return FullAttentionSpec( + block_size=block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_dim, + dtype=kv_cache_dtype, + ) + + +__all__ = ["TrainableFlashAttention"] diff --git a/torchtitan/experiments/vllm/custom_models/trainable_mla_attention.py b/torchtitan/experiments/vllm/custom_models/trainable_mla_attention.py new file mode 100644 index 0000000000..8e16812249 --- /dev/null +++ b/torchtitan/experiments/vllm/custom_models/trainable_mla_attention.py @@ -0,0 +1,562 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Trainable Multi-Head Latent Attention (MLA) for DeepSeek V3. + +This module implements the MLA architecture used in DeepSeek V3, which uses: +1. Low-rank compression for Q and KV projections +2. Split Q/K into RoPE and non-RoPE parts +3. Shared K_PE (RoPE-encoded key) across all heads + +Reference: https://github.com/deepseek-ai/DeepSeek-V3 +""" + +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn as nn + + +@dataclass +class MLAConfig: + """Configuration for Multi-Head Latent Attention. 
+ + Args: + hidden_size: Hidden dimension of the model + num_heads: Number of attention heads + q_lora_rank: LoRA rank for query projection. If 0, use direct projection. + kv_lora_rank: LoRA rank for key-value projection + qk_nope_head_dim: Dimension of Q/K without positional encoding + qk_rope_head_dim: Dimension of Q/K with RoPE + v_head_dim: Dimension of value projection per head + norm_eps: Epsilon for RMSNorm layers + dropout: Dropout probability during training + scale: Attention scale factor. If None, defaults to 1/sqrt(qk_head_dim) + causal: Whether to use causal masking + """ + + hidden_size: int + num_heads: int + q_lora_rank: int = 0 + kv_lora_rank: int = 512 + qk_nope_head_dim: int = 128 + qk_rope_head_dim: int = 64 + v_head_dim: int = 128 + norm_eps: float = 1e-5 + dropout: float = 0.0 + scale: float | None = None + causal: bool = True + + @property + def qk_head_dim(self) -> int: + """Total Q/K head dimension.""" + return self.qk_nope_head_dim + self.qk_rope_head_dim + + +class TrainableMLA(nn.Module): + """ + Training-compatible Multi-Head Latent Attention (MLA). + + This implements DeepSeek V3's MLA architecture: + - Low-rank compression with intermediate RMSNorm + - Split Q/K into RoPE and non-RoPE parts + - Shared K_PE across all attention heads (memory efficient!) + + Example: + ```python + config = MLAConfig( + hidden_size=2048, + num_heads=16, + q_lora_rank=512, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ) + mla = TrainableMLA(config) + + # Forward pass + hidden_states = torch.randn(2, 16, 2048) + freqs_cis = torch.randn(16, config.qk_rope_head_dim // 2) + output = mla(hidden_states, freqs_cis=freqs_cis) + ``` + """ + + def __init__(self, config: MLAConfig): + super().__init__() + + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_heads + self.q_lora_rank = config.q_lora_rank + self.kv_lora_rank = config.kv_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.qk_head_dim = config.qk_head_dim + self.v_head_dim = config.v_head_dim + self.dropout = config.dropout + self.causal = config.causal + + # Attention scale + if config.scale is None: + self.scale = self.qk_head_dim**-0.5 + else: + self.scale = config.scale + + # Query projection + if self.q_lora_rank == 0: + # Direct projection without LoRA + self.wq = nn.Linear( + self.hidden_size, self.num_heads * self.qk_head_dim, bias=False + ) + self.q_norm = None + else: + # Low-rank projection with intermediate norm + self.wq_a = nn.Linear(self.hidden_size, self.q_lora_rank, bias=False) + self.q_norm = nn.RMSNorm(self.q_lora_rank, eps=config.norm_eps) + self.wq_b = nn.Linear( + self.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False + ) + + # Key-value projection (always uses LoRA) + # Note: wkv_a outputs kv_lora_rank + qk_rope_head_dim + # The extra qk_rope_head_dim is for the shared K_PE + self.wkv_a = nn.Linear( + self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False + ) + self.kv_norm = nn.RMSNorm(self.kv_lora_rank, eps=config.norm_eps) + self.wkv_b = nn.Linear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + # Output projection + self.wo = nn.Linear( + self.num_heads * self.v_head_dim, self.hidden_size, bias=False + ) + + # Create vLLM MLAAttention for KV cache + optimized attention + # We'll initialize it in a lazy way since we need vLLM config + self.vllm_mla_attn = None + 
self._init_vllm_mla_attention() + + def _init_vllm_mla_attention(self): + """Initialize vLLM's MLAAttention for KV cache and optimized attention.""" + try: + from vllm.attention.layer import MLAAttention + from vllm.config import get_current_vllm_config + from vllm.model_executor.layers.linear import ColumnParallelLinear + + # Get vLLM config if available + try: + vllm_config = get_current_vllm_config() + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + except (RuntimeError, AttributeError): + # Not in vLLM context - skip MLAAttention initialization + return + + # Generate unique layer name for KV cache registration + import itertools + + if not hasattr(TrainableMLA, "_layer_counter"): + TrainableMLA._layer_counter = itertools.count() + + layer_name = f"layers.{next(TrainableMLA._layer_counter)}.attention" + + # Wrap wkv_b in ColumnParallelLinear (vLLM's parallel layer) + # This allows vLLM to handle TP sharding properly + kv_b_proj = ColumnParallelLinear( + input_size=self.kv_lora_rank, + output_size=self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + ) + # Copy weights from our regular Linear layer + kv_b_proj.weight.data.copy_(self.wkv_b.weight.data) + + # Create vLLM's MLAAttention + self.vllm_mla_attn = MLAAttention( + num_heads=self.num_heads, + scale=self.scale, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + v_head_dim=self.v_head_dim, + q_lora_rank=self.q_lora_rank if self.q_lora_rank > 0 else None, + kv_lora_rank=self.kv_lora_rank, + kv_b_proj=kv_b_proj, + cache_config=cache_config, + quant_config=quant_config, + prefix=layer_name, + ) + + print(f" ✓ Created vLLM MLAAttention for {layer_name}") + + except (ImportError, RuntimeError, AttributeError, AssertionError) as e: + # vLLM not available or not in vLLM context - use manual implementation + print(f" ⚠️ Could not create vLLM MLAAttention: {e}") + pass + + def _auto_register_for_kv_cache(self): + """Automatically register this layer for vLLM KV cache allocation. + + This is called during __init__ and will register the layer if we're in + a vLLM context. If not in vLLM context, this silently does nothing. + """ + self.layer_name: str | None = None + + try: + from vllm.config import get_current_vllm_config + + config = get_current_vllm_config() + compilation_config = config.compilation_config + + # Generate unique layer name + import itertools + + if not hasattr(TrainableMLA, "_layer_counter"): + TrainableMLA._layer_counter = itertools.count() + + layer_name = f"layers.{next(TrainableMLA._layer_counter)}" + + # Register this layer in static forward context + if layer_name in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {layer_name}") + compilation_config.static_forward_context[layer_name] = self + self.layer_name = layer_name + + except (ImportError, RuntimeError, AttributeError): + # Not in vLLM context - this is fine! + pass + + def apply_rotary_emb( + self, x: torch.Tensor, freqs_cis: torch.Tensor + ) -> torch.Tensor: + """ + Apply rotary positional embeddings to the input tensor. 
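+
+        A shapes-only sketch of the expected inputs (illustrative; the local
+        names are made up, and max_seq_len must cover the token count):
+
+            rope_dim = mla.qk_rope_head_dim                  # must be even
+            x = torch.randn(total_tokens, mla.num_heads, rope_dim)
+            angles = torch.zeros(max_seq_len, rope_dim // 2)
+            freqs_cis = torch.polar(torch.ones_like(angles), angles)
+            out = mla.apply_rotary_emb(x, freqs_cis)         # same shape as x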
+ + Args: + x: Input tensor [total_tokens, heads, qk_rope_head_dim] + or [bsz, seq_len, heads, qk_rope_head_dim] + freqs_cis: Precomputed complex exponentials + [max_seq_len, qk_rope_head_dim//2] (complex64/complex128) + + Returns: + Tensor with rotary embeddings applied + """ + # Determine if batched or flattened + # [bsz, seq_len, heads, dim] or [total_tokens, heads, dim] + seq_len = x.size(1) if x.dim() == 4 else x.size(0) + + # Slice freqs_cis to actual sequence length + # freqs_cis is complex: [max_seq_len, qk_rope_head_dim//2] + freqs = freqs_cis[:seq_len] # [seq_len, qk_rope_head_dim//2] + + # Convert x to complex for rotation + # x: [..., qk_rope_head_dim] -> [..., qk_rope_head_dim//2] complex + x_complex = torch.view_as_complex( + x.float().reshape(*x.shape[:-1], -1, 2) + ) # [..., qk_rope_head_dim//2] + + # Reshape freqs for broadcasting + # Batched: [bsz, seq_len, heads, dim] -> freqs [1, seq_len, 1, dim//2] + # Flattened: [total_tokens, heads, dim] -> freqs [seq_len, 1, dim//2] + freqs = freqs.unsqueeze(0).unsqueeze(2) if x.dim() == 4 else freqs.unsqueeze(1) + + # Apply rotation: multiply by complex exponential + x_rotated = x_complex * freqs + + # Convert back to real + x_out = torch.view_as_real(x_rotated).flatten(-2) # [..., qk_rope_head_dim] + + return x_out.to(x.dtype) + + def apply_rotary_emb_with_cos_sin( + self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + ) -> torch.Tensor: + """ + Apply rotary positional embeddings using cos and sin directly. + + Args: + x: Input tensor [total_tokens, heads, qk_rope_head_dim] + cos: Cosine values [total_tokens, qk_rope_head_dim//2] + sin: Sine values [total_tokens, qk_rope_head_dim//2] + + Returns: + Tensor with rotary embeddings applied + [total_tokens, heads, qk_rope_head_dim] + """ + # Expand cos/sin to match x's head dimension + # cos/sin: [total_tokens, qk_rope_head_dim//2] + # -> [total_tokens, 1, qk_rope_head_dim//2] + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + + # Repeat to full dimension: [total_tokens, 1, qk_rope_head_dim] + cos = cos.repeat_interleave(2, dim=-1) + sin = sin.repeat_interleave(2, dim=-1) + + # Apply rotation using rotate_half (avoids complex operations) + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + # Apply RoPE: x_rotated = x * cos + rotate_half(x) * sin + x_out = (x * cos) + (rotate_half(x) * sin) + + return x_out.to(x.dtype) + + def apply_rotary_emb_indexed( + self, x: torch.Tensor, freqs_for_tokens: torch.Tensor + ) -> torch.Tensor: + """ + DEPRECATED: Use apply_rotary_emb_with_cos_sin instead. + + Apply rotary positional embeddings using pre-indexed frequencies. 
+ + Args: + x: Input tensor [total_tokens, heads, qk_rope_head_dim] + freqs_for_tokens: Pre-indexed frequencies - complex or real format + + Returns: + Tensor with rotary embeddings applied + [total_tokens, heads, qk_rope_head_dim] + """ + # Check if freqs_for_tokens is complex or already split into cos/sin + if freqs_for_tokens.is_complex(): + # Extract cos and sin from complex frequencies + # freqs_for_tokens is complex exponentials: e^(i*theta) + # = cos(theta) + i*sin(theta) + cos = freqs_for_tokens.real # [total_tokens, qk_rope_head_dim//2] + sin = freqs_for_tokens.imag # [total_tokens, qk_rope_head_dim//2] + elif freqs_for_tokens.shape[-1] == x.shape[-1] // 2: + # Format: [total_tokens, qk_rope_head_dim//2] + # complex stored as real + # This happens after index_select on complex tensor + # The tensor is complex data stored in real format + # We need to extract real and imaginary parts + # Actually this shouldn't happen, but handle it anyway + print( + f"[DEBUG] freqs_for_tokens shape: {freqs_for_tokens.shape}, " + f"dtype: {freqs_for_tokens.dtype}" + ) + print(f"[DEBUG] x shape: {x.shape}") + # This format is ambiguous - assume it needs to be duplicated + cos = freqs_for_tokens + sin = freqs_for_tokens + else: + # freqs_for_tokens is already real, split it into cos and sin + # Assume format: [total_tokens, qk_rope_head_dim] + # where first half is cos, second is sin + half_dim = freqs_for_tokens.shape[-1] // 2 + cos = freqs_for_tokens[ + ..., :half_dim + ] # [total_tokens, qk_rope_head_dim//2] + sin = freqs_for_tokens[ + ..., half_dim: + ] # [total_tokens, qk_rope_head_dim//2] + + return self.apply_rotary_emb_with_cos_sin(x, cos, sin) + + def forward( + self, + hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, + attention_mask: torch.Tensor | None = None, + positions: torch.Tensor | None = None, # vLLM provides positions + **kwargs: Any, + ) -> torch.Tensor: + """ + Forward pass for Multi-Head Latent Attention. 
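+
+        When vLLM's MLAAttention was constructed (inference inside a vLLM
+        engine) and the module is in eval mode, the call is delegated to it so
+        the engine manages the KV cache; otherwise the manual path below
+        decompresses K/V and falls back to PyTorch SDPA, since flash attention
+        does not support qk_head_dim != v_head_dim.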
+ + Args: + hidden_states: Input tensor of shape [batch, seq_len, hidden_size] + or [total_tokens, hidden_size] + freqs_cis: Precomputed RoPE frequencies + [max_seq_len, qk_rope_head_dim//2] + attention_mask: Optional attention mask (not fully supported yet) + positions: Per-token positions for RoPE indexing (from vLLM) + **kwargs: Additional vLLM-specific kwargs + + Returns: + Output tensor of same shape as hidden_states + """ + # Handle both batched [batch, seq, hidden] and flattened [total_tokens, hidden] + input_is_batched = hidden_states.dim() == 3 + if input_is_batched: + bsz, seqlen, _ = hidden_states.shape + hidden_states_flat = hidden_states.view(-1, self.hidden_size) + else: + # Flattened format (vLLM inference) + hidden_states_flat = hidden_states + bsz = 1 + seqlen = hidden_states.shape[0] + + total_tokens = hidden_states_flat.shape[0] + + # Get positions for RoPE indexing + if positions is None: + # Try to get from vLLM forward context + try: + from vllm.forward_context import get_forward_context + + forward_ctx = get_forward_context() + if hasattr(forward_ctx, "_torchtitan_positions"): + positions = forward_ctx._torchtitan_positions + else: + # Fallback: sequential positions + positions = torch.arange( + total_tokens, device=hidden_states_flat.device + ) + except (ImportError, AttributeError, AssertionError): + # Training mode: sequential positions + positions = torch.arange(total_tokens, device=hidden_states_flat.device) + + # Query projection + if self.q_lora_rank == 0: + q = self.wq(hidden_states_flat) # [total_tokens, n_heads * qk_head_dim] + else: + q = self.wq_a(hidden_states_flat) # [total_tokens, q_lora_rank] + assert self.q_norm is not None # q_norm exists when q_lora_rank > 0 + q = self.wq_b(self.q_norm(q)) # [total_tokens, n_heads * qk_head_dim] + + # Reshape: [total_tokens, n_heads, qk_head_dim] + q = q.view(total_tokens, self.num_heads, self.qk_head_dim) + + # Split Q into non-RoPE and RoPE parts + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Apply RoPE to q_pe using positions to index freqs_cis + # Convert freqs_cis from complex to cos/sin BEFORE indexing + # to avoid dtype issues + if freqs_cis.is_complex(): + # Extract cos and sin from complex freqs_cis + freqs_cos = freqs_cis.real # [max_seq_len, qk_rope_head_dim//2] + freqs_sin = freqs_cis.imag # [max_seq_len, qk_rope_head_dim//2] + # Concatenate for easier indexing + freqs_real = torch.cat( + [freqs_cos, freqs_sin], dim=-1 + ) # [max_seq_len, qk_rope_head_dim] + else: + freqs_real = freqs_cis + + # Index by positions + positions_flat = positions.flatten()[:total_tokens] + max_pos = freqs_real.shape[0] - 1 + positions_clamped = torch.clamp(positions_flat, 0, max_pos) + freqs_for_tokens = freqs_real.index_select( + 0, positions_clamped + ) # [total_tokens, qk_rope_head_dim] + + # Split into cos and sin + half_dim = self.qk_rope_head_dim // 2 + cos_for_tokens = freqs_for_tokens[ + ..., :half_dim + ] # [total_tokens, qk_rope_head_dim//2] + sin_for_tokens = freqs_for_tokens[ + ..., half_dim: + ] # [total_tokens, qk_rope_head_dim//2] + + # Apply RoPE to q_pe: [total_tokens, num_heads, qk_rope_head_dim] + q_pe = self.apply_rotary_emb_with_cos_sin(q_pe, cos_for_tokens, sin_for_tokens) + + # Concatenate back: [total_tokens, n_heads, qk_head_dim] + q = torch.cat([q_nope, q_pe], dim=-1) + + # Key-value projection + kv = self.wkv_a( + hidden_states_flat + ) # [total_tokens, kv_lora_rank + qk_rope_head_dim] + + # Split into compressed KV and K_PE + kv_c, k_pe = 
torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + # Apply RoPE to k_pe: [total_tokens, qk_rope_head_dim] + # Reshape to [total_tokens, 1, qk_rope_head_dim] for apply_rotary_emb + k_pe = k_pe.unsqueeze(1) # [total_tokens, 1, qk_rope_head_dim] + k_pe = self.apply_rotary_emb_with_cos_sin( + k_pe, cos_for_tokens, sin_for_tokens + ) # [total_tokens, 1, qk_rope_head_dim] + + # Normalize compressed KV + kv_c_normed = self.kv_norm(kv_c) # [total_tokens, kv_lora_rank] + + # Delegate to vLLM's MLAAttention if available (handles KV cache automatically) + if self.vllm_mla_attn is not None and not self.training: + # Let vLLM handle all KV cache logic + attn_output = self.vllm_mla_attn( + q, + kv_c_normed, + k_pe, + output_shape=(total_tokens, self.num_heads * self.v_head_dim), + ) + else: + # Training mode or fallback: manual implementation + # Decompress KV + kv = self.wkv_b( + kv_c_normed + ) # [total_tokens, n_heads * (qk_nope_head_dim + v_head_dim)] + kv = kv.view( + total_tokens, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + + # Split into K_nope and V + k_nope, v = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + + # Concatenate K_nope with broadcasted K_PE + # k_pe shape: [total_tokens, 1, qk_rope_head_dim] + # Expand to: [total_tokens, n_heads, qk_rope_head_dim] + k = torch.cat( + [k_nope, k_pe.expand(-1, self.num_heads, -1)], dim=-1 + ) # [total_tokens, n_heads, qk_head_dim] + + # Reshape for batched attention: [bsz, seqlen, n_heads, head_dim] + q = q.view(bsz, seqlen, self.num_heads, self.qk_head_dim) + k = k.view(bsz, seqlen, self.num_heads, self.qk_head_dim) + v = v.view(bsz, seqlen, self.num_heads, self.v_head_dim) + + # Transpose for attention: [bsz, n_heads, seqlen, head_dim] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Use PyTorch SDPA (supports different head dims for Q/K vs V) + # Flash attention doesn't support qk_head_dim != v_head_dim, so we use SDPA + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=self.causal and attention_mask is None, + scale=self.scale, + ) + + # Transpose back and reshape: [total_tokens, n_heads * v_head_dim] + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape( + total_tokens, self.num_heads * self.v_head_dim + ) + + # Output projection: [total_tokens, hidden_size] + output = self.wo(attn_output) + + # Restore original shape if input was batched + if input_is_batched: + output = output.view(bsz, seqlen, self.hidden_size) + + return output diff --git a/torchtitan/experiments/vllm/custom_models/utils.py b/torchtitan/experiments/vllm/custom_models/utils.py new file mode 100644 index 0000000000..5ea546d563 --- /dev/null +++ b/torchtitan/experiments/vllm/custom_models/utils.py @@ -0,0 +1,245 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Utility functions for TorchTitan model integration with vLLM. 
+ +This module provides helper functions for common TorchTitan integration tasks: +- Converting RoPE frequency tensors from complex to real format +- Managing forward context for position indices +- Creating KV cache specifications for MLA +- Loading weights from HuggingFace checkpoints with name mapping + +Example usage: + ```python + from vllm.model_executor.utils.torchtitan_utils import ( + convert_freqs_cis_to_real, + create_mla_kv_cache_spec, + ) + + # Convert TorchTitan's complex freqs_cis to vLLM-compatible format + model.freqs_cis = convert_freqs_cis_to_real(model.freqs_cis) + + # Create KV cache spec for MLA attention + kv_cache_spec = create_mla_kv_cache_spec( + kv_lora_rank=512, + qk_rope_head_dim=64, + vllm_config=vllm_config, + ) + ``` +""" + +from collections.abc import Iterator +from typing import Any + +import torch +import torch.nn as nn + + +def convert_freqs_cis_to_real(freqs_cis: torch.Tensor) -> torch.Tensor: + """ + Convert complex RoPE frequencies to real format (cos, sin concatenated). + + TorchTitan uses complex exponentials e^(i*theta) for RoPE, but vLLM and + dtype conversion to bfloat16 require real tensors. This converts: + complex[max_seq_len, dim//2] -> real[max_seq_len, dim] + + Args: + freqs_cis: Complex frequency tensor [max_seq_len, dim//2] + + Returns: + Real tensor [max_seq_len, dim] with cos and sin concatenated + """ + if not freqs_cis.is_complex(): + # Already in real format + return freqs_cis + + # Extract cos and sin from complex exponentials + # e^(i*theta) = cos(theta) + i*sin(theta) + freqs_cos = freqs_cis.real # [max_seq_len, dim//2] + freqs_sin = freqs_cis.imag # [max_seq_len, dim//2] + + # Concatenate: [max_seq_len, dim] + freqs_real = torch.cat([freqs_cos, freqs_sin], dim=-1) + + return freqs_real + + +def store_positions_in_context(positions: torch.Tensor | None) -> None: + """ + Store position indices in vLLM's forward context. + + This allows attention layers to access per-token positions for RoPE indexing + during inference without explicitly passing them through every layer. + + Args: + positions: Position indices from vLLM [total_tokens] or None + """ + if positions is None: + return + + try: + from vllm.forward_context import get_forward_context + + forward_ctx = get_forward_context() + # Store positions in a custom attribute accessible to attention layers + forward_ctx._torchtitan_positions = positions + except (ImportError, RuntimeError, AttributeError): + # Not in vLLM context - this is fine (e.g., during testing) + pass + + +def create_mla_kv_cache_spec( + kv_lora_rank: int, + qk_rope_head_dim: int, + block_size: int, + dtype: torch.dtype, +) -> Any: + """ + Create KV cache specification for Multi-Head Latent Attention (MLA). 
+ + MLA uses compressed KV cache with layout: + [kv_lora_rank + qk_rope_head_dim] per token + + Args: + kv_lora_rank: LoRA rank for compressed KV (e.g., 512) + qk_rope_head_dim: Dimension of RoPE-encoded keys (e.g., 64) + block_size: KV cache block size from vLLM config + dtype: Data type for KV cache + + Returns: + MLAAttentionSpec instance + """ + from vllm.v1.kv_cache_interface import MLAAttentionSpec + + # MLA cache layout: compressed KV + shared K_PE + head_size = kv_lora_rank + qk_rope_head_dim + + return MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, # MLA shares K_PE across all heads + head_size=head_size, + dtype=dtype, + ) + + +def load_external_weights( + model: nn.Module, + weights_iter: Iterator[tuple[str, torch.Tensor]], + name_mapping: dict[str, str], + verbose: bool = False, +) -> tuple[int, int]: + """ + Load weights from HuggingFace checkpoint into external model. + + Maps HuggingFace parameter names to model parameter names and loads + them into the model. Supports layer-specific patterns with {} placeholders. + + Args: + model: Model instance to load weights into + weights_iter: Iterator yielding (name, tensor) from HF checkpoint + name_mapping: Dict mapping HF names to model parameter names. + Use {} as placeholder for layer numbers, e.g.: + {"model.layers.{}.attn.weight": "layers.{}.attention.weight"} + verbose: Whether to print detailed loading progress + + Returns: + Tuple of (loaded_count, skipped_count) + + Example: + ```python + name_mapping = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "lm_head.weight": "output.weight", + } + loaded, skipped = load_external_weights(model, weights_iter, name_mapping) + ``` + """ + from vllm.model_executor.model_loader.weight_utils import default_weight_loader + + # Get all parameter names in the model + params_dict = dict(model.named_parameters()) + + loaded_count = 0 + skipped_count = 0 + + # Convert iterator to list to check if empty + weights_list = list(weights_iter) + if len(weights_list) == 0: + if verbose: + print(" ⚠️ No weight files found - using random initialization") + return 0, 0 + + for hf_name, loaded_weight in weights_list: + # Try to find matching pattern in name_mapping + target_name = None + + # Check if it's a layer-specific weight + if "layers" in hf_name: + # Extract layer number + import regex as re + + layer_match = re.search(r"layers\.(\d+)\.", hf_name) + if layer_match: + layer_num = layer_match.group(1) + + # Try to find matching pattern + for hf_pattern, target_pattern in name_mapping.items(): + if "{}" in hf_pattern: + hf_concrete = hf_pattern.format(layer_num) + if hf_name == hf_concrete: + target_name = target_pattern.format(layer_num) + break + else: + # Non-layer weight (embeddings, norms, output) + target_name = name_mapping.get(hf_name) + + if target_name is None: + # Skip MoE weights and other unmapped weights + if ( + "mlp.experts" in hf_name + or "mlp.gate" in hf_name + or "mlp.shared_experts" in hf_name + ): + # MoE weights - skip silently + skipped_count += 1 + continue + else: + if verbose: + print(f" ⚠️ No mapping for: {hf_name}") + skipped_count += 1 + continue + + # Check if parameter exists in model + if target_name not in params_dict: + if verbose: + print(f" ⚠️ Parameter not found in model: {target_name}") + skipped_count += 1 + continue + + # Load the weight + param = params_dict[target_name] + + # Verify shapes match + if param.shape != loaded_weight.shape: + if verbose: + 
print(f" ⚠️ Shape mismatch for {target_name}:") + print(f" Model: {param.shape}, Checkpoint: {loaded_weight.shape}") + skipped_count += 1 + continue + + # Load the weight + default_weight_loader(param, loaded_weight) + loaded_count += 1 + + # Log first few loads for verification + if verbose and loaded_count <= 5: + print(f" ✓ Loaded {target_name}: {loaded_weight.shape}") + + return loaded_count, skipped_count diff --git a/torchtitan/experiments/vllm/example_checkpoint/config.json b/torchtitan/experiments/vllm/example_checkpoint/config.json new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchtitan/experiments/vllm/infer.py b/torchtitan/experiments/vllm/infer.py old mode 100644 new mode 100755 index 77b727e601..976c34e451 --- a/torchtitan/experiments/vllm/infer.py +++ b/torchtitan/experiments/vllm/infer.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # @@ -5,122 +6,128 @@ # LICENSE file in the root directory of this source tree. """ -Simple inference script for TorchTitan-trained Qwen3 model using vLLM. +Example CLI to run TorchTitan Qwen3 model inference with vLLM: -This script demonstrates how to: -1. Register a custom TorchTitan Qwen3 model with vLLM -2. Load a TorchTitan checkpoint into vLLM -3. Run inference using vLLM's optimized engine - -Usage: - python infer.py --model-path /path/to/torchtitan/checkpoint --prompt "Hello, world!" +# Run inference +python torchtitan/experiments/vllm/infer.py """ import argparse -import logging -from pathlib import Path +import json +import os +import shutil +import tempfile +import torch.nn as nn from vllm import LLM, SamplingParams -from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.parallel_context import ParallelContext + -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", -) -logger = logging.getLogger(__name__) +def build_qwen3_torchtitan(vllm_config, parallel_context: ParallelContext) -> nn.Module: + """ + Factory function to build Qwen3 with TorchTitan + vLLM. + This is registered with vLLM's ModelRegistry to enable: + LLM(model="Qwen/Qwen3-0.6B", ...) 
-def register_torchtitan_qwen3_model(): - """Register the TorchTitan Qwen3 model with vLLM's model registry.""" - from torchtitan.experiments.vllm.model.qwen3 import TorchTitanQwen3ForCausalLM + Args: + vllm_config: vLLM configuration object + parallel_context: Parallelism context with TP/PP info - logger.info("Registering TorchTitan Qwen3 model with vLLM") + Returns: + TorchTitanQwen3ForCausalLM instance + """ + from torchtitan.experiments.vllm.model.qwen3 import TorchTitanQwen3ForCausalLM - # Register the model using Qwen3's architecture but with custom weight loading - ModelRegistry.register_model( - "TorchTitanQwen3ForCausalLM", - TorchTitanQwen3ForCausalLM, + # Create model + model = TorchTitanQwen3ForCausalLM( + vllm_config=vllm_config, parallel_context=parallel_context ) - print("Successfully registered TorchTitanQwen3ForCausalLM") + # Apply tensor parallelism if TP > 1 + # This must happen AFTER model creation and attention replacement + # but BEFORE dtype conversion (to avoid dtype issues with DTensors) + if parallel_context is not None: + tp_size = parallel_context.get_tensor_parallel_world_size() + if tp_size > 1: + from torch.distributed.device_mesh import init_device_mesh + from torchtitan.models.qwen3.infra.parallelize import apply_non_moe_tp + print(f"🔧 Applying Tensor Parallelism (TP={tp_size})...") -def run_inference( - model: str, - prompts: list[str], - max_tokens: int = 100, - temperature: float = 0.8, - top_p: float = 0.95, - tensor_parallel_size: int = 1, -): - """ - Run inference using vLLM with a TorchTitan-trained Qwen3 model. + # Create DeviceMesh for TorchTitan + tp_mesh = init_device_mesh( + "cuda", + (tp_size,), + mesh_dim_names=("tp",), + ) - Args: - model: Model name - prompts: List of prompts to generate from - max_tokens: Maximum number of tokens to generate - temperature: Sampling temperature - top_p: Top-p sampling parameter - tensor_parallel_size: Number of GPUs for tensor parallelism + # Apply TorchTitan's tensor parallelism to shard weights + apply_non_moe_tp( + model.model, + tp_mesh=tp_mesh, + loss_parallel=False, # Don't shard the output for loss computation + enable_float8_tensorwise_tp=False, + enable_async_tp=False, + ) + + print(f"✅ Applied Tensor Parallelism (TP={tp_size})") + + # Convert to dtype if specified (happens after TP) + if hasattr(vllm_config, "model_config") and hasattr( + vllm_config.model_config, "dtype" + ): + model = model.to(dtype=vllm_config.model_config.dtype) + + return model + + +# Register with vLLM's ModelRegistry +from vllm import ModelRegistry + +ModelRegistry.register_model("Qwen3TorchTitan", build_qwen3_torchtitan) + + +def register_torchtitan_model(): """ - # Create sampling parameters - sampling_params = SamplingParams( - temperature=temperature, - top_p=top_p, - max_tokens=max_tokens, - ) + Register the TorchTitan Qwen3 custom model with vLLM using factory function pattern. - # entry point: + This registers a factory function that vLLM will call to create the model, + allowing us to apply tensor parallelism and other transformations. 
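+
+    Example (a minimal sketch mirroring main() below; both paths are
+    placeholders, and the config.json in the model directory must list
+    "Qwen3TorchTitanForCausalLM" under "architectures"):
+        ```python
+        register_torchtitan_model()
+        llm = LLM(
+            model="/path/to/config_dir",  # placeholder: directory holding config.json
+            hf_overrides={"checkpoint_dir": "/path/to/checkpoint"},  # placeholder
+            dtype="bfloat16",
+            enforce_eager=True,
+        )
+        ```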
+ """ try: - llm = LLM( - model=model, - model_impl="vllm", - skip_tokenizer_init=True, - ) - except Exception as e: - logger.error( - "Failed to initialize vLLM engine with TorchTitanQwen3ForCausalLM model\n" - ) - raise - - logger.info("Model loaded successfully, starting generation...") + from vllm import ModelRegistry - # Generate outputs - outputs = llm.generate(prompts, sampling_params) + # Register the factory function with vLLM + # vLLM will call build_qwen3_torchtitan(vllm_config, parallel_context) + ModelRegistry.register_model( + "Qwen3TorchTitanForCausalLM", build_qwen3_torchtitan + ) - # Print results - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - logger.info("-" * 80) - logger.info(f"Prompt: {prompt}") - logger.info(f"Generated: {generated_text}") + print("✅ Successfully registered TorchTitan Qwen3 custom model with vLLM") + return True - logger.info("-" * 80) - logger.info(f"Generated {len(outputs)} outputs successfully") + except Exception as e: + print(f"⚠️ Warning: Failed to register custom model: {e}") + return False -def main(): +def parse_args(): parser = argparse.ArgumentParser( - description="Run inference with TorchTitan Qwen3 model using vLLM" + description="Run TorchTitan Qwen3 model inference with vLLM", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "--model", type=str, default="torchtitan/experiments/vllm/checkpoint/", - help="Path to the TorchTitan checkpoint or HuggingFace model directory", + help="Path to TorchTitan checkpoint directory", ) parser.add_argument( "--prompt", type=str, - default="Hello, how are you?", - help="Single prompt to generate from", - ) - parser.add_argument( - "--prompts-file", - type=str, - help="Path to file containing prompts (one per line)", + default="Hello, my name is", + help="Prompt text for generation", ) parser.add_argument( "--max-tokens", @@ -134,44 +141,102 @@ def main(): default=0.8, help="Sampling temperature", ) - parser.add_argument( - "--top-p", - type=float, - default=0.95, - help="Top-p (nucleus) sampling parameter", - ) - parser.add_argument( - "--tensor-parallel-size", - type=int, - default=1, - help="Number of GPUs for tensor parallelism", - ) + return parser.parse_args() - args = parser.parse_args() - # Register the custom model - register_torchtitan_qwen3_model() +def main(): + args = parse_args() + + print("=" * 80) + print("REGISTERING TORCHTITAN QWEN3 CUSTOM MODEL") + print("=" * 80) + + # Register the custom model with vLLM + register_torchtitan_model() + + # Create a temporary directory with minimal config.json for vLLM + temp_dir = tempfile.mkdtemp(prefix="vllm_torchtitan_qwen_") + minimal_config = { + "architectures": ["Qwen3TorchTitanForCausalLM"], + # Why `model_type`: Tells HuggingFace Transformers to use `Qwen3Config` class (known type) + "model_type": "qwen3", # Use known HF model type + # The following parameter is Qwen3-0.6B + "hidden_size": 3584, + "intermediate_size": 18944, + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "vocab_size": 151936, + "max_position_embeddings": 32768, + "rope_theta": 1000000.0, + "rms_norm_eps": 1e-06, + "tie_word_embeddings": False, + "head_dim": 128, + "attention_bias": False, + "hidden_act": "silu", + "qk_norm": True, + "torch_dtype": "bfloat16", + "skip_tokenizer_init": True, + } + + config_path = os.path.join(temp_dir, "config.json") + with open(config_path, "w") as f: + json.dump(minimal_config, f, indent=2) + + print(f"Created temporary model 
config at: {temp_dir}") + print(f"Using checkpoint: {args.model}") + + print("=" * 80) + print("INITIALIZING vLLM WITH TORCHTITAN QWEN3 MODEL") + print("=" * 80) + + # Build hf_overrides with checkpoint path + hf_overrides = { + "checkpoint_dir": args.model, + } + + # Initialize vLLM with custom TorchTitan Qwen3 model + llm = LLM( + model=temp_dir, # Use temporary directory with config.json + hf_overrides=hf_overrides, + dtype="bfloat16", + trust_remote_code=True, + enforce_eager=True, # Use eager mode for debugging + ) - # Prepare prompts - if args.prompts_file: - prompts_path = Path(args.prompts_file) - if not prompts_path.exists(): - raise FileNotFoundError(f"Prompts file not found: {prompts_path}") - prompts = prompts_path.read_text().strip().split("\n") - logger.info(f"Loaded {len(prompts)} prompts from {prompts_path}") - else: - prompts = [args.prompt] + print("=" * 80) + print("vLLM ENGINE INITIALIZED - STARTING GENERATION") + print("=" * 80) - # Run inference - run_inference( - model=args.model, - prompts=prompts, - max_tokens=args.max_tokens, + # Prepare prompt + prompts = [args.prompt] + sampling_params = SamplingParams( temperature=args.temperature, - top_p=args.top_p, - tensor_parallel_size=args.tensor_parallel_size, + top_p=0.95, + max_tokens=args.max_tokens, ) + # Generate + outputs = llm.generate( + prompts=prompts, + sampling_params=sampling_params, + ) + + # Print results + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + + print(f"\nPrompt: {prompt}") + print(f"Generated text: {generated_text!r}") + + # Clean up temporary directory + try: + shutil.rmtree(temp_dir) + print(f"\nCleaned up temporary directory: {temp_dir}") + except Exception as e: + print(f"Warning: Could not clean up temporary directory {temp_dir}: {e}") + if __name__ == "__main__": main() diff --git a/torchtitan/experiments/vllm/model/qwen3.py b/torchtitan/experiments/vllm/model/qwen3.py index 131b2a8c7a..148b4ade9f 100644 --- a/torchtitan/experiments/vllm/model/qwen3.py +++ b/torchtitan/experiments/vllm/model/qwen3.py @@ -3,38 +3,202 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +""" +TorchTitan Qwen3 Model for vLLM Integration. -from collections.abc import Iterable +This module provides a vLLM-compatible wrapper around TorchTitan's Qwen3 model, +enabling models trained with TorchTitan to be served through vLLM for inference. + +Example: + ```python + from vllm import LLM + + llm = LLM( + model="Qwen/Qwen3-0.6B", + trust_remote_code=True, + ) + ``` + +IMPORTANT: TorchTitan imports are deferred to avoid CUDA initialization +before vLLM's multiprocessing fork. +""" import torch +# Import from local custom_models directory +from torchtitan.experiments.vllm.custom_models import ( + load_external_weights, + replace_with_trainable_attention, + store_positions_in_context, + VLLMModelForCausalLM, +) + from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM +from vllm.model_executor.parallel_context import ParallelContext logger = init_logger(__name__) -class TorchTitanQwen3ForCausalLM(Qwen3ForCausalLM): +class TorchTitanQwen3ForCausalLM(VLLMModelForCausalLM): """ - TorchTitan-trained Qwen3 dense model adapter for vLLM. + vLLM-compatible wrapper for TorchTitan's Qwen3 model. 
- This class extends the standard Qwen3ForCausalLM to support loading - weights from TorchTitan checkpoints with different naming conventions. - The architecture is identical to standard Qwen3 - only weight names differ. + This class integrates TorchTitan's Qwen3Model with vLLM by: + 1. Importing TorchTitan's model architecture + 2. Replacing attention with vLLM's TrainableFlashAttention + 3. Implementing the vLLM model interface + + The architecture uses standard multi-head attention (not MLA), + with RoPE positional embeddings and optional QK normalization. """ - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) + supports_pp = False # Pipeline parallelism not supported yet + supports_multimodal = False + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + parallel_context: ParallelContext | None = None, + ): + super().__init__() + + # vLLM config is required + assert vllm_config is not None, "vllm_config is required" + + # Import TorchTitan's Qwen3 model (deferred import to avoid CUDA init issues) + from torchtitan.models.qwen3.model.args import Qwen3ModelArgs + from torchtitan.models.qwen3.model.model import Qwen3Model + + # Map HuggingFace config to TorchTitan ModelArgs + hf_config = vllm_config.hf_config + print("hf_config: ", hf_config) + model_args = Qwen3ModelArgs( + vocab_size=getattr(hf_config, "vocab_size", 151936), + dim=getattr(hf_config, "hidden_size", 2048), + n_layers=getattr(hf_config, "num_hidden_layers", 4), + n_heads=getattr(hf_config, "num_attention_heads", 16), + n_kv_heads=getattr(hf_config, "num_key_value_heads", 2), + head_dim=getattr(hf_config, "head_dim", 128), + hidden_dim=getattr(hf_config, "intermediate_size", 11008), + norm_eps=getattr(hf_config, "rms_norm_eps", 1e-6), + max_seq_len=getattr(hf_config, "max_position_embeddings", 8192), + rope_theta=getattr(hf_config, "rope_theta", 1000000.0), + qk_norm=getattr(hf_config, "qk_norm", True), + ) + + # Create TorchTitan model + self.model = Qwen3Model(model_args) + self.config = model_args + self.parallel_context = parallel_context + + # Replace attention with vLLM's TrainableFlashAttention + # (This happens before TP so TP can shard the attention weights) + replace_with_trainable_attention(self.model, use_mla=False) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - pass + """Convert input token IDs to embeddings.""" + return self.model.tok_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + """ + Forward pass with vLLM interface. 
+ + Args: + input_ids: Token IDs [batch, seq_len] (optional if inputs_embeds provided) + positions: Position indices from vLLM for RoPE + inputs_embeds: Pre-computed embeddings (optional, used by vLLM) + **kwargs: Additional vLLM kwargs + + Returns: + hidden_states: Final hidden states before LM head + """ + # Store positions in forward context for attention layers + store_positions_in_context(positions) + + # Get embeddings + h = ( + inputs_embeds + if inputs_embeds is not None + else self.model.tok_embeddings(input_ids) + ) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - pass + # Get RoPE cache + seqlen = h.shape[1] if h.dim() == 3 else h.shape[0] + rope_cache = self.model.rope_cache[:seqlen] + + # Pass through transformer layers + for layer in self.model.layers.values(): + h = layer(h, rope_cache, attention_masks=None) + + # Final norm + return self.model.norm(h) def compute_logits( self, hidden_states: torch.Tensor, - ) -> torch.Tensor | None: - pass + sampling_metadata=None, + ) -> torch.Tensor: + """Compute logits from hidden states.""" + return self.model.output(hidden_states) + + def load_weights(self, weights_iter): + """ + Load weights from HuggingFace checkpoint. + + Maps HF Qwen weight names → TorchTitan naming convention. + This uses the same mapping as TorchTitan's Qwen3StateDictAdapter. + + Args: + weights_iter: Iterator of (name, tensor) pairs from HF checkpoint + + Returns: + Number of loaded and skipped parameters + """ + # HF → TorchTitan name mapping (from Qwen3StateDictAdapter) + hf_to_tt = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "lm_head.weight": "output.weight", + "model.norm.weight": "norm.weight", + # Attention weights + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + "model.layers.{}.self_attn.q_norm.weight": ( + "layers.{}.attention.q_norm.weight" + ), + "model.layers.{}.self_attn.k_norm.weight": ( + "layers.{}.attention.k_norm.weight" + ), + # Skip rotary_emb.inv_freq (not used in TorchTitan) + "model.layers.{}.self_attn.rotary_emb.inv_freq": None, + # MLP weights + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + # Layer norms + "model.layers.{}.input_layernorm.weight": ( + "layers.{}.attention_norm.weight" + ), + "model.layers.{}.post_attention_layernorm.weight": ( + "layers.{}.ffn_norm.weight" + ), + } + + # Load weights using utility function + loaded, skipped = load_external_weights( + model=self.model, weights_iter=weights_iter, name_mapping=hf_to_tt + ) + + logger.info(f"✅ Loaded {loaded} parameters, skipped {skipped}") + + return loaded, skipped diff --git a/torchtitan/experiments/vllm/test_infer.sh b/torchtitan/experiments/vllm/test_infer.sh deleted file mode 100755 index a13810613c..0000000000 --- a/torchtitan/experiments/vllm/test_infer.sh +++ /dev/null @@ -1,54 +0,0 @@ -#!/bin/bash -# Simple test script to demonstrate vLLM inference with TorchTitan Qwen3 model - -set -e - -echo "========================================" -echo "TorchTitan Qwen3 + vLLM Inference Test" -echo "========================================" - -# Check if model path is provided -if [ 
-z "$1" ]; then - echo "Usage: $0 " - echo "" - echo "Example:" - echo " $0 /path/to/torchtitan/checkpoint" - echo "" - echo "The checkpoint directory should contain:" - echo " - config.json (HuggingFace-style model config)" - echo " - Model weights (PyTorch checkpoint or safetensors)" - exit 1 -fi - -MODEL_PATH="$1" - -# Verify model path exists -if [ ! -d "$MODEL_PATH" ]; then - echo "Error: Model path not found: $MODEL_PATH" - exit 1 -fi - -# Verify config.json exists -if [ ! -f "$MODEL_PATH/config.json" ]; then - echo "Error: config.json not found in $MODEL_PATH" - echo "" - echo "Please ensure your checkpoint contains a HuggingFace-style config.json" - echo "See example_config.json for reference" - exit 1 -fi - -echo "Model path: $MODEL_PATH" -echo "" - -# Run inference with a simple prompt -echo "Running inference with single prompt..." -python torchtitan/experiments/vllm/infer.py \ - --model-path "$MODEL_PATH" \ - --prompt "What is the meaning of life?" \ - --max-tokens 50 \ - --temperature 0.7 - -echo "" -echo "========================================" -echo "Test completed successfully!" -echo "========================================" From b86b8a3e8db177ff4d4c09653c41384f8297fbea Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 24 Nov 2025 13:01:46 -0800 Subject: [PATCH 03/19] hit TP issue --- torchtitan/experiments/vllm/README.md | 178 +----------------- .../vllm/example_checkpoint/config.json | 0 torchtitan/experiments/vllm/infer.py | 37 +--- torchtitan/experiments/vllm/model/qwen3.py | 80 +++++++- 4 files changed, 83 insertions(+), 212 deletions(-) delete mode 100644 torchtitan/experiments/vllm/example_checkpoint/config.json diff --git a/torchtitan/experiments/vllm/README.md b/torchtitan/experiments/vllm/README.md index 18a027c500..3eec52c74d 100644 --- a/torchtitan/experiments/vllm/README.md +++ b/torchtitan/experiments/vllm/README.md @@ -14,172 +14,14 @@ The integration consists of two main components: ### Prerequisites -1. Install vLLM: +1. Install vLLM from source: ```bash -pip install vllm -``` - -### Running Inference - -#### Single Prompt - -```bash -python torchtitan/experiments/vllm/infer.py \ - --model-path /path/to/torchtitan/checkpoint \ - --prompt "Explain quantum computing in simple terms" -``` - -#### Multiple Prompts from File - -```bash -# Create a file with prompts (one per line) -cat > prompts.txt << EOF -What is the meaning of life? -Explain how transformers work -Write a poem about AI -EOF - -# Run inference -python torchtitan/experiments/vllm/infer.py \ - --model-path /path/to/torchtitan/checkpoint \ - --prompts-file prompts.txt -``` - -#### With Tensor Parallelism - -```bash -python torchtitan/experiments/vllm/infer.py \ - --model-path /path/to/torchtitan/checkpoint \ - --prompt "Explain deep learning" \ - --tensor-parallel-size 4 \ - --max-tokens 200 -``` - -## Model Configuration - -Your checkpoint directory should contain: - -1. **`config.json`**: HuggingFace-style model configuration -2. 
**Model weights**: Either PyTorch checkpoint files or safetensors - -Example `config.json` for a Qwen3-7B model: - -```json -{ - "architectures": ["TorchTitanQwen3ForCausalLM"], - "model_type": "qwen3", - "hidden_size": 3584, - "intermediate_size": 18944, - "num_attention_heads": 28, - "num_hidden_layers": 28, - "num_key_value_heads": 4, - "vocab_size": 151936, - "max_position_embeddings": 32768, - "rope_theta": 1000000.0, - "rms_norm_eps": 1e-06, - "tie_word_embeddings": false, - "head_dim": 128 -} -``` - -## Command-line Arguments - -| Argument | Type | Default | Description | -|----------|------|---------|-------------| -| `--model-path` | str | Required | Path to TorchTitan checkpoint directory | -| `--prompt` | str | "Hello, how are you?" | Single prompt to generate from | -| `--prompts-file` | str | None | Path to file with prompts (one per line) | -| `--max-tokens` | int | 100 | Maximum tokens to generate | -| `--temperature` | float | 0.8 | Sampling temperature | -| `--top-p` | float | 0.95 | Nucleus sampling parameter | -| `--tensor-parallel-size` | int | 1 | Number of GPUs for tensor parallelism | - -## Implementation Details - -### Model Registration - -The inference script registers the custom model with vLLM's model registry: - -```python -from vllm.model_executor.models import ModelRegistry -from torchtitan.experiments.vllm.model.qwen3 import TorchTitanQwen3ForCausalLM - -ModelRegistry.register_model( - "TorchTitanQwen3ForCausalLM", - TorchTitanQwen3ForCausalLM, -) -``` - -### Weight Mapping - -The `WeightsMapper` class handles automatic name translation: - -```python -mapper = WeightsMapper( - orig_to_new_substr={ - ".attention.wq": ".self_attn.q_proj", - # ... other mappings - }, - orig_to_new_prefix={ - "tok_embeddings.weight": "model.embed_tokens.weight", - # ... other mappings - }, -) -``` - -### vLLM Engine Initialization - -The script uses vLLM's high-level `LLM` class: - -```python -llm = LLM( - model=model_path, - tensor_parallel_size=tensor_parallel_size, - trust_remote_code=True, -) -``` - -## Troubleshooting - -### "Model not found" Error - -Ensure `config.json` exists in your checkpoint directory and specifies the correct architecture: -```json -{ - "architectures": ["TorchTitanQwen3ForCausalLM"], - "model_type": "qwen3" -} -``` - -### Weight Loading Errors - -Check that your checkpoint contains weights with TorchTitan naming conventions. 
You can inspect checkpoint keys: - -```python -import torch -checkpoint = torch.load("path/to/checkpoint.pt") -print(checkpoint.keys()) -``` - -### Memory Issues - -- Reduce `--tensor-parallel-size` if you have limited GPU memory -- Use quantization (see vLLM documentation for quantization options) - -## Performance Notes - -- **Batch Processing**: The script processes multiple prompts in a single batch for efficiency -- **KV Caching**: vLLM automatically uses KV caching for fast autoregressive generation -- **Tensor Parallelism**: Use `--tensor-parallel-size` to distribute the model across multiple GPUs - -## Next Steps - -- See vLLM documentation for advanced features: https://docs.vllm.ai/ -- Explore vLLM's serving capabilities for production deployments -- Configure quantization for reduced memory usage - -## References - -- [TorchTitan Qwen3 Model](../../../models/qwen3/model/model.py) -- [vLLM Qwen3 Model](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3.py) -- [vLLM Documentation](https://docs.vllm.ai/) +# install PyTorch first, either from PyPI or from source +git clone https://github.com/vllm-project/vllm.git +cd vllm +python use_existing_torch.py +uv pip install -r requirements/build.txt +uv pip install --no-build-isolation -e . +``` +Using following command +https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#use-an-existing-pytorch-installation diff --git a/torchtitan/experiments/vllm/example_checkpoint/config.json b/torchtitan/experiments/vllm/example_checkpoint/config.json deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/torchtitan/experiments/vllm/infer.py b/torchtitan/experiments/vllm/infer.py index 976c34e451..bd39029835 100755 --- a/torchtitan/experiments/vllm/infer.py +++ b/torchtitan/experiments/vllm/infer.py @@ -13,10 +13,7 @@ """ import argparse -import json -import os import shutil -import tempfile import torch.nn as nn from vllm import LLM, SamplingParams @@ -120,7 +117,7 @@ def parse_args(): parser.add_argument( "--model", type=str, - default="torchtitan/experiments/vllm/checkpoint/", + default="torchtitan/experiments/vllm/example_checkpoint/qwen3-0.6B", help="Path to TorchTitan checkpoint directory", ) parser.add_argument( @@ -155,36 +152,8 @@ def main(): register_torchtitan_model() # Create a temporary directory with minimal config.json for vLLM - temp_dir = tempfile.mkdtemp(prefix="vllm_torchtitan_qwen_") - minimal_config = { - "architectures": ["Qwen3TorchTitanForCausalLM"], - # Why `model_type`: Tells HuggingFace Transformers to use `Qwen3Config` class (known type) - "model_type": "qwen3", # Use known HF model type - # The following parameter is Qwen3-0.6B - "hidden_size": 3584, - "intermediate_size": 18944, - "num_attention_heads": 28, - "num_hidden_layers": 28, - "num_key_value_heads": 4, - "vocab_size": 151936, - "max_position_embeddings": 32768, - "rope_theta": 1000000.0, - "rms_norm_eps": 1e-06, - "tie_word_embeddings": False, - "head_dim": 128, - "attention_bias": False, - "hidden_act": "silu", - "qk_norm": True, - "torch_dtype": "bfloat16", - "skip_tokenizer_init": True, - } - - config_path = os.path.join(temp_dir, "config.json") - with open(config_path, "w") as f: - json.dump(minimal_config, f, indent=2) - print(f"Created temporary model config at: {temp_dir}") - print(f"Using checkpoint: {args.model}") + print(f"Using checkpoint and config.json from: {args.model}") print("=" * 80) print("INITIALIZING vLLM WITH TORCHTITAN QWEN3 MODEL") @@ -197,7 +166,7 @@ def main(): # Initialize 
vLLM with custom TorchTitan Qwen3 model llm = LLM( - model=temp_dir, # Use temporary directory with config.json + model=args.model, # Use temporary directory with config.json hf_overrides=hf_overrides, dtype="bfloat16", trust_remote_code=True, diff --git a/torchtitan/experiments/vllm/model/qwen3.py b/torchtitan/experiments/vllm/model/qwen3.py index 148b4ade9f..63d3121897 100644 --- a/torchtitan/experiments/vllm/model/qwen3.py +++ b/torchtitan/experiments/vllm/model/qwen3.py @@ -27,7 +27,7 @@ # Import from local custom_models directory from torchtitan.experiments.vllm.custom_models import ( - load_external_weights, + # load_external_weights, replace_with_trainable_attention, store_positions_in_context, VLLMModelForCausalLM, @@ -73,8 +73,9 @@ def __init__( from torchtitan.models.qwen3.model.model import Qwen3Model # Map HuggingFace config to TorchTitan ModelArgs - hf_config = vllm_config.hf_config - print("hf_config: ", hf_config) + logger.info("vllm config: ", vllm_config.__class__) + hf_config = vllm_config.model_config.hf_config + logger.info("hf_config: ", hf_config) model_args = Qwen3ModelArgs( vocab_size=getattr(hf_config, "vocab_size", 151936), dim=getattr(hf_config, "hidden_size", 2048), @@ -161,7 +162,7 @@ def load_weights(self, weights_iter): weights_iter: Iterator of (name, tensor) pairs from HF checkpoint Returns: - Number of loaded and skipped parameters + Set of loaded parameter names (for vLLM compatibility) """ # HF → TorchTitan name mapping (from Qwen3StateDictAdapter) hf_to_tt = { @@ -194,11 +195,70 @@ def load_weights(self, weights_iter): ), } - # Load weights using utility function - loaded, skipped = load_external_weights( - model=self.model, weights_iter=weights_iter, name_mapping=hf_to_tt + # Track loaded parameter names + loaded_params = set() + + # Convert iterator to list for processing + from vllm.model_executor.model_loader.weight_utils import default_weight_loader + + # Get parameters from model + params_dict = dict(self.model.named_parameters()) + + weights_list = list(weights_iter) + + for hf_name, loaded_weight in weights_list: + # Try to find matching pattern in name_mapping + target_name = None + + # Check if it's a layer-specific weight + if "layers" in hf_name: + # Extract layer number + import regex as re + + layer_match = re.search(r"layers\.(\d+)\.", hf_name) + if layer_match: + layer_num = layer_match.group(1) + + # Try to find matching pattern + for hf_pattern, target_pattern in hf_to_tt.items(): + if "{}" in hf_pattern and target_pattern is not None: + hf_concrete = hf_pattern.format(layer_num) + if hf_name == hf_concrete: + target_name = target_pattern.format(layer_num) + break + else: + # Non-layer weight (embeddings, norms, output) + target_name = hf_to_tt.get(hf_name) + + # Skip if no mapping or explicitly marked as None + if target_name is None: + continue + + # Check if parameter exists in model + if target_name not in params_dict: + continue + + # Load the weight into model parameter + param = params_dict[target_name] + + # Verify shapes match + if param.shape != loaded_weight.shape: + logger.warning( + f"Shape mismatch for {target_name}: " + f"Model: {param.shape}, Checkpoint: {loaded_weight.shape}" + ) + continue + + # Load the weight + default_weight_loader(param, loaded_weight) + + # Add the parameter name to loaded set + # Since CallableModelWrapper overrides named_parameters(), + # the names returned here already match what vLLM expects + loaded_params.add(target_name) + + logger.info( + f"✅ Loaded {len(loaded_params)} parameters, loaded 
weights are: {loaded_params}" ) - logger.info(f"✅ Loaded {loaded} parameters, skipped {skipped}") - - return loaded, skipped + return loaded_params From b6cf72bcf4028a5c95b3b2887a962b2191c6b52f Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Fri, 28 Nov 2025 20:20:15 -0800 Subject: [PATCH 04/19] switch to attention --- torchtitan/experiments/vllm/README.md | 2 +- .../vllm/custom_models/trainable_attention.py | 5 +- torchtitan/experiments/vllm/infer.py | 13 +- .../experiments/vllm/model/attention.py | 176 +++++++++++++++++ torchtitan/experiments/vllm/model/qwen3.py | 80 ++++++-- .../vllm/tests/test_flash_attn_varlen.py | 183 ++++++++++++++++++ 6 files changed, 434 insertions(+), 25 deletions(-) create mode 100644 torchtitan/experiments/vllm/model/attention.py create mode 100644 torchtitan/experiments/vllm/tests/test_flash_attn_varlen.py diff --git a/torchtitan/experiments/vllm/README.md b/torchtitan/experiments/vllm/README.md index 3eec52c74d..8acad8e6b2 100644 --- a/torchtitan/experiments/vllm/README.md +++ b/torchtitan/experiments/vllm/README.md @@ -1,6 +1,6 @@ # TorchTitan Qwen3 Model with vLLM Inference -This directory contains code to run vLLM inference on models trained with TorchTitan. +This directory contains code to run TorchTitan model definition with vLLM inference engine. ## Overview diff --git a/torchtitan/experiments/vllm/custom_models/trainable_attention.py b/torchtitan/experiments/vllm/custom_models/trainable_attention.py index aa8d074d85..80a304c3fc 100644 --- a/torchtitan/experiments/vllm/custom_models/trainable_attention.py +++ b/torchtitan/experiments/vllm/custom_models/trainable_attention.py @@ -225,8 +225,9 @@ def _auto_register_for_kv_cache(self): def forward( self, hidden_states: torch.Tensor, - freqs_cis: torch.Tensor - | None = None, # RoPE frequencies (TorchTitan compatibility) + freqs_cis: ( + torch.Tensor | None + ) = None, # RoPE frequencies (TorchTitan compatibility) attention_mask: torch.Tensor | None = None, **kwargs: Any, # Accept any additional vLLM-specific kwargs ) -> torch.Tensor: diff --git a/torchtitan/experiments/vllm/infer.py b/torchtitan/experiments/vllm/infer.py index bd39029835..b94abd883d 100755 --- a/torchtitan/experiments/vllm/infer.py +++ b/torchtitan/experiments/vllm/infer.py @@ -79,12 +79,6 @@ def build_qwen3_torchtitan(vllm_config, parallel_context: ParallelContext) -> nn return model -# Register with vLLM's ModelRegistry -from vllm import ModelRegistry - -ModelRegistry.register_model("Qwen3TorchTitan", build_qwen3_torchtitan) - - def register_torchtitan_model(): """ Register the TorchTitan Qwen3 custom model with vLLM using factory function pattern. @@ -138,6 +132,12 @@ def parse_args(): default=0.8, help="Sampling temperature", ) + parser.add_argument( + "--tensor-parallel-size", + type=int, + default=1, + help="Number of GPUs for tensor parallelism (default: 1 for single GPU)", + ) return parser.parse_args() @@ -171,6 +171,7 @@ def main(): dtype="bfloat16", trust_remote_code=True, enforce_eager=True, # Use eager mode for debugging + tensor_parallel_size=args.tensor_parallel_size, # Multi-GPU support ) print("=" * 80) diff --git a/torchtitan/experiments/vllm/model/attention.py b/torchtitan/experiments/vllm/model/attention.py new file mode 100644 index 0000000000..ea6f677c75 --- /dev/null +++ b/torchtitan/experiments/vllm/model/attention.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +vLLM-compatible Flash Attention implementation for deterministic RL training. +""" + +import itertools + +import torch + +from vllm.attention import Attention +from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available + + +class VLLMCompatibleFlashAttention(torch.nn.Module): + """ + Wrapper around vLLM's Attention layer for deterministic RL training. + + This uses vLLM's high-level Attention module which handles: + - KV cache management + - Multiple attention backend selection (FlashAttention, xFormers, SDPA, etc.) + - Quantization support + - Optimized inference + """ + + # Class variable for auto-generating unique layer names (thread-safe) + _layer_counter = itertools.count() + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int | None = None, + head_dim: int | None = None, + dropout: float = 0.0, + scale: float | None = None, + causal: bool = True, + use_fused_qkv: bool = True, + use_qk_norm: bool = False, + norm_eps: float = 1e-6, + ): + super().__init__() + + if not is_flash_attn_varlen_func_available(): + raise RuntimeError( + "Flash attention is not available. " + "Please install flash-attn or use XPU platform with IPEX." + ) + + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads or num_heads + self.head_dim = head_dim or (hidden_size // num_heads) + self.dropout = dropout + self.causal = causal + + if scale is None: + self.scale = self.head_dim**-0.5 + else: + self.scale = scale + + # Create vLLM Attention layer to handle KV cache and backend selection + try: + from vllm.config import get_current_vllm_config + + config = get_current_vllm_config() + cache_config = ( + config.cache_config if hasattr(config, "cache_config") else None + ) + + # Generate unique prefix for this attention layer + # vLLM expects format "layers.X" for layer index extraction + layer_idx = next(VLLMCompatibleFlashAttention._layer_counter) + prefix = f"layers.{layer_idx}" + + self.vllm_attn = Attention( + num_heads=num_heads, + head_size=self.head_dim, + scale=self.scale, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=None, + prefix=prefix, + ) + except (ImportError, RuntimeError, AttributeError): + # Not in vLLM context - will fall back to direct flash attention + self.vllm_attn = None + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + scale: float | None = None, + ) -> torch.Tensor: + """ + Forward pass using vLLM's Attention layer. 
+ + Args: + q: Query tensor [batch, num_heads, seq_len, head_dim] + k: Key tensor [batch, num_kv_heads, seq_len, head_dim] + v: Value tensor [batch, num_kv_heads, seq_len, head_dim] + scale: Optional scale override (uses self.scale if None) + + Returns: + output: Attention output [batch, num_heads, seq_len, head_dim] + """ + # Input is (batch, num_heads, seq_len, head_dim) from TorchTitan + # Need to transpose to (batch, seq_len, num_heads, head_dim) + q = q.transpose(1, 2) # -> (batch, seq_len, num_heads, head_dim) + k = k.transpose(1, 2) # -> (batch, seq_len, num_kv_heads, head_dim) + v = v.transpose(1, 2) # -> (batch, seq_len, num_kv_heads, head_dim) + + # Get dimensions + batch_size, seq_len, num_heads, head_dim = q.shape + + # Flatten to vLLM format: (total_tokens, num_heads, head_dim) + total_tokens = batch_size * seq_len + q = q.reshape(total_tokens, num_heads, head_dim) + k = k.reshape(total_tokens, self.num_kv_heads, head_dim) + v = v.reshape(total_tokens, self.num_kv_heads, head_dim) + + # Use vLLM's Attention layer if available (handles KV cache, backend selection) + if self.vllm_attn is not None and not self.training: + # vLLM Attention expects and returns [total_tokens, num_heads * head_dim] + # But it can also accept [total_tokens, num_heads, head_dim] + attn_output = self.vllm_attn(q, k, v) + # Output is [total_tokens, num_heads * head_dim] or [total_tokens, num_heads, head_dim] + if attn_output.dim() == 2: + # Reshape to [total_tokens, num_heads, head_dim] + attn_output = attn_output.reshape(total_tokens, num_heads, head_dim) + else: + # Training mode or fallback: use PyTorch SDPA + # Reshape for SDPA: [batch, num_heads, seq_len, head_dim] + q = q.reshape(batch_size, seq_len, num_heads, head_dim).transpose(1, 2) + k = k.reshape(batch_size, seq_len, self.num_kv_heads, head_dim).transpose( + 1, 2 + ) + v = v.reshape(batch_size, seq_len, self.num_kv_heads, head_dim).transpose( + 1, 2 + ) + + # Handle GQA by repeating k, v if needed + if self.num_kv_heads != self.num_heads: + num_repeats = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(num_repeats, dim=1) + v = v.repeat_interleave(num_repeats, dim=1) + + # Use PyTorch SDPA + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + dropout_p=self.dropout if self.training else 0.0, + is_causal=self.causal, + scale=scale or self.scale, + ) + + # Transpose back and flatten: [batch, num_heads, seq_len, head_dim] -> [total_tokens, num_heads, head_dim] + attn_output = attn_output.transpose(1, 2).reshape( + total_tokens, num_heads, head_dim + ) + + # Reshape back to batch format and transpose to TorchTitan format + # [total_tokens, num_heads, head_dim] -> [batch, seq_len, num_heads, head_dim] -> [batch, num_heads, seq_len, head_dim] + output = attn_output.reshape( + batch_size, seq_len, num_heads, head_dim + ).transpose(1, 2) + + return output diff --git a/torchtitan/experiments/vllm/model/qwen3.py b/torchtitan/experiments/vllm/model/qwen3.py index 63d3121897..70d8333e26 100644 --- a/torchtitan/experiments/vllm/model/qwen3.py +++ b/torchtitan/experiments/vllm/model/qwen3.py @@ -28,10 +28,10 @@ # Import from local custom_models directory from torchtitan.experiments.vllm.custom_models import ( # load_external_weights, - replace_with_trainable_attention, store_positions_in_context, VLLMModelForCausalLM, ) +from torchtitan.experiments.vllm.model.attention import VLLMCompatibleFlashAttention from vllm.config import VllmConfig from vllm.logger import init_logger @@ -95,9 +95,32 @@ def 
__init__( self.config = model_args self.parallel_context = parallel_context - # Replace attention with vLLM's TrainableFlashAttention - # (This happens before TP so TP can shard the attention weights) - replace_with_trainable_attention(self.model, use_mla=False) + # Replace inner_attention with vLLM compatible Flash Attention + # NOTE: We replace `inner_attention` (the attention kernel), NOT the whole `Attention` module + # The `Attention` module handles QKV projection, RoPE, etc., and calls `inner_attention` + if not hasattr(self.model, "layers"): + raise AttributeError( + f"Model {type(self.model).__name__} must have .layers attribute" + ) + + for layer_name, layer in self.model.layers.items(): + if not hasattr(layer, "attention"): + raise ValueError(f"Layer {layer_name} must have .attention attribute") + + if not hasattr(layer.attention, "inner_attention"): + raise ValueError( + f"Layer {layer_name}.attention must have .inner_attention attribute" + ) + + # NOTE(jianiw): Attention implementation 1: Add backward for vllm FlashAttn + # Replace only the inner attention kernel, not the whole Attention module + layer.attention.inner_attention = VLLMCompatibleFlashAttention( + hidden_size=model_args.dim, + num_heads=model_args.n_heads, + num_kv_heads=model_args.n_kv_heads, + head_dim=model_args.head_dim, + causal=True, + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: """Convert input token IDs to embeddings.""" @@ -114,26 +137,42 @@ def forward( Forward pass with vLLM interface. Args: - input_ids: Token IDs [batch, seq_len] (optional if inputs_embeds provided) - positions: Position indices from vLLM for RoPE + input_ids: Token IDs from vLLM [total_tokens] (1D varlen format) + positions: Position indices from vLLM [total_tokens] (1D varlen format) inputs_embeds: Pre-computed embeddings (optional, used by vLLM) **kwargs: Additional vLLM kwargs Returns: - hidden_states: Final hidden states before LM head + hidden_states: Final hidden states [total_tokens, hidden_size] """ + # Handle inputs_embeds vs input_ids properly + if inputs_embeds is not None: + raise NotImplementedError( + "inputs_embeds is not yet supported by TorchTitan Qwen3. " + "The model expects token IDs and computes embeddings internally. " + "Please provide input_ids instead." 
+ ) + + if input_ids is None: + raise ValueError("Either input_ids or inputs_embeds must be provided") + + # Convert vLLM interface to TorchTitan interface + # vLLM passes input_ids as [total_tokens] but TorchTitan expects [batch_size, seq_len] + # For now, reshape to [1, total_tokens] as a simple batch of 1 + # TODO: In future, use attn_metadata.seq_lens to properly reconstruct batch structure + tokens_2d = input_ids.unsqueeze(0) # [total_tokens] -> [1, total_tokens] + # Store positions in forward context for attention layers - store_positions_in_context(positions) + # Also convert positions to 2D format + if positions is not None: + positions_2d = positions.unsqueeze(0) # [total_tokens] -> [1, total_tokens] + store_positions_in_context(positions_2d) - # Get embeddings - h = ( - inputs_embeds - if inputs_embeds is not None - else self.model.tok_embeddings(input_ids) - ) + # Get embeddings from 2D tokens + h = self.model.tok_embeddings(tokens_2d) # [1, total_tokens, hidden_size] # Get RoPE cache - seqlen = h.shape[1] if h.dim() == 3 else h.shape[0] + seqlen = h.shape[1] # seq_len dimension rope_cache = self.model.rope_cache[:seqlen] # Pass through transformer layers @@ -141,7 +180,16 @@ def forward( h = layer(h, rope_cache, attention_masks=None) # Final norm - return self.model.norm(h) + h = self.model.norm(h) # [1, total_tokens, hidden_size] + + # Convert output format back to vLLM expectations + # vLLM expects hidden_states in [total_tokens, hidden_size] format + # TorchTitan returns [batch_size, seq_len, hidden_size], so we need to flatten + if h.dim() == 3: # [batch_size, seq_len, hidden_size] + batch_size, seq_len, hidden_size = h.shape + h = h.view(batch_size * seq_len, hidden_size) # [total_tokens, hidden_size] + + return h def compute_logits( self, diff --git a/torchtitan/experiments/vllm/tests/test_flash_attn_varlen.py b/torchtitan/experiments/vllm/tests/test_flash_attn_varlen.py new file mode 100644 index 0000000000..195adc5a07 --- /dev/null +++ b/torchtitan/experiments/vllm/tests/test_flash_attn_varlen.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Simple test script for flash_attn_varlen_func forward pass. +Tests the variable-length flash attention function from vLLM. +""" + +import torch +from vllm.attention.utils.fa_utils import flash_attn_varlen_func + + +def test_flash_attn_varlen_func(): + """Test flash_attn_varlen_func with simple dummy inputs.""" + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # Parameters + batch_size = 2 + num_heads = 8 + head_dim = 64 + seq_len_q = 128 + seq_len_k = 128 + + # Total tokens (for variable length) + total_tokens_q = batch_size * seq_len_q + total_tokens_k = batch_size * seq_len_k + + # Create input tensors + # Shape: (total_tokens, num_heads, head_dim) + q = torch.randn( + total_tokens_q, num_heads, head_dim, dtype=torch.float16, device=device + ) + k = torch.randn( + total_tokens_k, num_heads, head_dim, dtype=torch.float16, device=device + ) + v = torch.randn( + total_tokens_k, num_heads, head_dim, dtype=torch.float16, device=device + ) + + # Create cumulative sequence lengths + # cu_seqlens_q and cu_seqlens_k indicate the start position of each sequence + # For uniform sequences: [0, seq_len, 2*seq_len, ...] 
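+    # For example, two packed sequences of lengths 3 and 5 would give
+    # cu_seqlens = [0, 3, 8]: entry i is the token offset where sequence i
+    # starts, and the last entry is the total number of tokens.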
+ cu_seqlens_q = torch.tensor( + [0, seq_len_q, 2 * seq_len_q], dtype=torch.int32, device=device + ) + cu_seqlens_k = torch.tensor( + [0, seq_len_k, 2 * seq_len_k], dtype=torch.int32, device=device + ) + + # Maximum sequence lengths + max_seqlen_q = seq_len_q + max_seqlen_k = seq_len_k + + # Softmax scale (typically 1/sqrt(head_dim)) + softmax_scale = 1.0 / (head_dim**0.5) + + print("\nInput shapes:") + print(f" q: {q.shape}") + print(f" k: {k.shape}") + print(f" v: {v.shape}") + print(f" cu_seqlens_q: {cu_seqlens_q}") + print(f" cu_seqlens_k: {cu_seqlens_k}") + print(f" max_seqlen_q: {max_seqlen_q}") + print(f" max_seqlen_k: {max_seqlen_k}") + print(f" softmax_scale: {softmax_scale}") + + try: + # Call flash_attn_varlen_func + print("\nCalling flash_attn_varlen_func...") + output = flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=softmax_scale, + causal=False, # Set to True for causal attention + ) + + print(f"\nOutput shape: {output.shape}") + print(f"Output dtype: {output.dtype}") + print(f"Output device: {output.device}") + print("\nOutput statistics:") + print(f" Mean: {output.mean().item():.6f}") + print(f" Std: {output.std().item():.6f}") + print(f" Min: {output.min().item():.6f}") + print(f" Max: {output.max().item():.6f}") + + print("\n✓ Test passed successfully!") + return output + + except Exception as e: + print(f"\n✗ Test failed with error: {e}") + import traceback + + traceback.print_exc() + raise + + +def test_flash_attn_varlen_func_causal(): + """Test flash_attn_varlen_func with causal attention.""" + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"\n{'=' * 60}") + print("Testing with causal attention") + print(f"{'=' * 60}") + print(f"Using device: {device}") + + # Smaller test for causal + batch_size = 1 + num_heads = 4 + head_dim = 32 + seq_len = 64 + + total_tokens = batch_size * seq_len + + q = torch.randn( + total_tokens, num_heads, head_dim, dtype=torch.float16, device=device + ) + k = torch.randn( + total_tokens, num_heads, head_dim, dtype=torch.float16, device=device + ) + v = torch.randn( + total_tokens, num_heads, head_dim, dtype=torch.float16, device=device + ) + + cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device) + + print("\nInput shapes:") + print(f" q, k, v: {q.shape}") + print(f" cu_seqlens: {cu_seqlens}") + + try: + output = flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=seq_len, + max_seqlen_k=seq_len, + causal=True, + ) + + print(f"\nCausal output shape: {output.shape}") + print("Output statistics:") + print(f" Mean: {output.mean().item():.6f}") + print(f" Std: {output.std().item():.6f}") + + print("\n✓ Causal test passed successfully!") + return output + + except Exception as e: + print(f"\n✗ Causal test failed with error: {e}") + import traceback + + traceback.print_exc() + raise + + +if __name__ == "__main__": + print("=" * 60) + print("Testing flash_attn_varlen_func") + print("=" * 60) + + # Test 1: Basic forward pass + test_flash_attn_varlen_func() + + # Test 2: Causal attention + test_flash_attn_varlen_func_causal() + + print("\n" + "=" * 60) + print("All tests completed successfully! 
✓") + print("=" * 60) From 410f3b3b0e3677081c1bc2b1a46916e495dcecab Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 1 Dec 2025 11:13:05 -0800 Subject: [PATCH 05/19] single GPU works --- torchtitan/experiments/vllm/README.md | 24 +- torchtitan/experiments/vllm/infer.py | 8 - .../experiments/vllm/model/attention.py | 324 +++++++++++++----- torchtitan/experiments/vllm/model/qwen3.py | 2 - 4 files changed, 255 insertions(+), 103 deletions(-) diff --git a/torchtitan/experiments/vllm/README.md b/torchtitan/experiments/vllm/README.md index 8acad8e6b2..9f4836b2e9 100644 --- a/torchtitan/experiments/vllm/README.md +++ b/torchtitan/experiments/vllm/README.md @@ -14,7 +14,7 @@ The integration consists of two main components: ### Prerequisites -1. Install vLLM from source: +1. Install vLLM from source [vllm-use-an-existing-pytorch-installation](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#use-an-existing-pytorch-installation): ```bash # install PyTorch first, either from PyPI or from source git clone https://github.com/vllm-project/vllm.git @@ -23,5 +23,23 @@ python use_existing_torch.py uv pip install -r requirements/build.txt uv pip install --no-build-isolation -e . ``` -Using following command -https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#use-an-existing-pytorch-installation + + +NOTE: If `flash_attn_varlen_func` hits error "torch.AcceleratorError: CUDA error: the provided PTX was compiled with an unsupported toolchain" during forward path, this is due to GPU driver version is not compatible with vLLM/PyTorch compiled version. Use the following command to recompile vLLM. + +``` +# Set CUDA version environment variable +export CUDA_HOME=/usr/local/cuda-12.4 +export PATH=/usr/local/cuda-12.4/bin:$PATH +export LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64:$LD_LIBRARY_PATH + +# Clean previous build +rm -rf build dist *.egg-info +pip uninstall -y vllm + +# Rebuild vLLM from source with CUDA 12.4 +pip install -e . + +``` + +2. Download Qwen3/Qwen3-0.6b checkpoint from HuggingFace and put into `example_checkpoint` folder. diff --git a/torchtitan/experiments/vllm/infer.py b/torchtitan/experiments/vllm/infer.py index b94abd883d..34a739c2df 100755 --- a/torchtitan/experiments/vllm/infer.py +++ b/torchtitan/experiments/vllm/infer.py @@ -13,7 +13,6 @@ """ import argparse -import shutil import torch.nn as nn from vllm import LLM, SamplingParams @@ -200,13 +199,6 @@ def main(): print(f"\nPrompt: {prompt}") print(f"Generated text: {generated_text!r}") - # Clean up temporary directory - try: - shutil.rmtree(temp_dir) - print(f"\nCleaned up temporary directory: {temp_dir}") - except Exception as e: - print(f"Warning: Could not clean up temporary directory {temp_dir}: {e}") - if __name__ == "__main__": main() diff --git a/torchtitan/experiments/vllm/model/attention.py b/torchtitan/experiments/vllm/model/attention.py index ea6f677c75..72116d5052 100644 --- a/torchtitan/experiments/vllm/model/attention.py +++ b/torchtitan/experiments/vllm/model/attention.py @@ -6,6 +6,8 @@ """ vLLM-compatible Flash Attention implementation for deterministic RL training. + +Uses vLLM's flash_attn_varlen_func for forward pass (fast) with custom backward pass. 
""" import itertools @@ -13,18 +15,16 @@ import torch from vllm.attention import Attention -from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available +from vllm.attention.utils.fa_utils import flash_attn_varlen_func, get_flash_attn_version +from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant class VLLMCompatibleFlashAttention(torch.nn.Module): """ - Wrapper around vLLM's Attention layer for deterministic RL training. + Wrapper around vLLM's Attention with custom backward pass. - This uses vLLM's high-level Attention module which handles: - - KV cache management - - Multiple attention backend selection (FlashAttention, xFormers, SDPA, etc.) - - Quantization support - - Optimized inference + Forward: Uses vLLM's optimized Attention layer (flash attention kernels) + Backward: Custom implementation for deterministic gradients """ # Class variable for auto-generating unique layer names (thread-safe) @@ -34,36 +34,25 @@ def __init__( self, hidden_size: int, num_heads: int, - num_kv_heads: int | None = None, - head_dim: int | None = None, - dropout: float = 0.0, + num_kv_heads: int, + head_dim: int, scale: float | None = None, causal: bool = True, - use_fused_qkv: bool = True, - use_qk_norm: bool = False, - norm_eps: float = 1e-6, - ): + ) -> None: super().__init__() - if not is_flash_attn_varlen_func_available(): - raise RuntimeError( - "Flash attention is not available. " - "Please install flash-attn or use XPU platform with IPEX." - ) - self.hidden_size = hidden_size self.num_heads = num_heads - self.num_kv_heads = num_kv_heads or num_heads - self.head_dim = head_dim or (hidden_size // num_heads) - self.dropout = dropout + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim self.causal = causal if scale is None: - self.scale = self.head_dim**-0.5 + self.scale = head_dim**-0.5 else: self.scale = scale - # Create vLLM Attention layer to handle KV cache and backend selection + # Create vLLM Attention layer try: from vllm.config import get_current_vllm_config @@ -79,17 +68,54 @@ def __init__( self.vllm_attn = Attention( num_heads=num_heads, - head_size=self.head_dim, + head_size=head_dim, scale=self.scale, - num_kv_heads=self.num_kv_heads, + num_kv_heads=num_kv_heads, cache_config=cache_config, quant_config=None, prefix=prefix, ) except (ImportError, RuntimeError, AttributeError): - # Not in vLLM context - will fall back to direct flash attention + # Not in vLLM context - will need to set up manually self.vllm_attn = None + # KV cache - will be populated by vLLM during model loading + self.kv_cache: list[torch.Tensor] | None = None + + # Auto-register for vLLM KV cache if in vLLM context + self._auto_register_for_kv_cache() + + def _auto_register_for_kv_cache(self): + """Automatically register this layer for vLLM KV cache allocation. + + This is called during __init__ and will register the layer if we're in + a vLLM context. If not in vLLM context (e.g., pure PyTorch training), + this silently does nothing. 
+ """ + # Initialize layer_name attribute + self.layer_name: str | None = None + + try: + from vllm.config import get_current_vllm_config + + config = get_current_vllm_config() + compilation_config = config.compilation_config + + # Generate unique layer name using class counter + # Format: "layers.{index}" for compatibility with extract_layer_index() + layer_name = f"layers.{next(VLLMCompatibleFlashAttention._layer_counter)}" + + # Register this layer in static forward context + if layer_name in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {layer_name}") + compilation_config.static_forward_context[layer_name] = self + self.layer_name = layer_name + + except (ImportError, RuntimeError, AttributeError): + # Not in vLLM context - this is fine! + # Layer will work normally for training/inference without vLLM + pass + def forward( self, q: torch.Tensor, @@ -99,78 +125,196 @@ def forward( scale: float | None = None, ) -> torch.Tensor: """ - Forward pass using vLLM's Attention layer. + Forward with dual-mode behavior: + - Inference (model.training=False): Use vLLM's Attention layer (KV cache, etc.) + - Training (model.training=True): Use flash_attn_varlen_func with custom backward + - vLLM's Attention used flash_attn_varlen_func kernel by default. Args: q: Query tensor [batch, num_heads, seq_len, head_dim] k: Key tensor [batch, num_kv_heads, seq_len, head_dim] v: Value tensor [batch, num_kv_heads, seq_len, head_dim] - scale: Optional scale override (uses self.scale if None) + scale: Optional attention scale override Returns: - output: Attention output [batch, num_heads, seq_len, head_dim] + output: [batch, num_heads, seq_len, head_dim] """ - # Input is (batch, num_heads, seq_len, head_dim) from TorchTitan - # Need to transpose to (batch, seq_len, num_heads, head_dim) - q = q.transpose(1, 2) # -> (batch, seq_len, num_heads, head_dim) - k = k.transpose(1, 2) # -> (batch, seq_len, num_kv_heads, head_dim) - v = v.transpose(1, 2) # -> (batch, seq_len, num_kv_heads, head_dim) - - # Get dimensions - batch_size, seq_len, num_heads, head_dim = q.shape - - # Flatten to vLLM format: (total_tokens, num_heads, head_dim) - total_tokens = batch_size * seq_len - q = q.reshape(total_tokens, num_heads, head_dim) - k = k.reshape(total_tokens, self.num_kv_heads, head_dim) - v = v.reshape(total_tokens, self.num_kv_heads, head_dim) - - # Use vLLM's Attention layer if available (handles KV cache, backend selection) - if self.vllm_attn is not None and not self.training: - # vLLM Attention expects and returns [total_tokens, num_heads * head_dim] - # But it can also accept [total_tokens, num_heads, head_dim] - attn_output = self.vllm_attn(q, k, v) - # Output is [total_tokens, num_heads * head_dim] or [total_tokens, num_heads, head_dim] - if attn_output.dim() == 2: - # Reshape to [total_tokens, num_heads, head_dim] - attn_output = attn_output.reshape(total_tokens, num_heads, head_dim) - else: - # Training mode or fallback: use PyTorch SDPA - # Reshape for SDPA: [batch, num_heads, seq_len, head_dim] - q = q.reshape(batch_size, seq_len, num_heads, head_dim).transpose(1, 2) - k = k.reshape(batch_size, seq_len, self.num_kv_heads, head_dim).transpose( - 1, 2 - ) - v = v.reshape(batch_size, seq_len, self.num_kv_heads, head_dim).transpose( - 1, 2 - ) + # Input is (batch, num_heads, seq_len, head_dim) + batch_size, num_heads, seq_len, head_dim = q.shape - # Handle GQA by repeating k, v if needed - if self.num_kv_heads != self.num_heads: - num_repeats = self.num_heads // self.num_kv_heads - 
k = k.repeat_interleave(num_repeats, dim=1) - v = v.repeat_interleave(num_repeats, dim=1) - - # Use PyTorch SDPA - attn_output = torch.nn.functional.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - dropout_p=self.dropout if self.training else 0.0, - is_causal=self.causal, - scale=scale or self.scale, - ) + # INFERENCE MODE: Use vLLM's Attention layer + if not self.training and self.vllm_attn is not None: + # Transpose to (batch, seq_len, num_heads, head_dim) for vLLM + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) - # Transpose back and flatten: [batch, num_heads, seq_len, head_dim] -> [total_tokens, num_heads, head_dim] - attn_output = attn_output.transpose(1, 2).reshape( - total_tokens, num_heads, head_dim - ) + # Flatten to (total_tokens, num_heads, head_dim) + q_varlen = q.reshape(-1, num_heads, head_dim) + k_varlen = k.reshape(-1, k.shape[2], head_dim) + v_varlen = v.reshape(-1, v.shape[2], head_dim) + + try: + # Use vLLM's Attention layer (requires forward context) + output_varlen = self.vllm_attn(q_varlen, k_varlen, v_varlen) + + # Reshape back to batch format + output = output_varlen.reshape(batch_size, seq_len, num_heads, head_dim) + + # Transpose back to TorchTitan format + output = output.transpose(1, 2) + + return output + except (AssertionError, RuntimeError): + # Forward context not available, fall through to training mode + pass + + # TRAINING MODE: Use flash_attn_varlen_func with custom backward + # Transpose to (batch, seq_len, num_heads, head_dim) for vLLM + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Convert to varlen format for vLLM: flatten batch and sequence + # (batch, seq_len, num_heads, head_dim) -> (total_tokens, num_heads, head_dim) + q_varlen = q.reshape(-1, num_heads, head_dim) + k_varlen = k.reshape(-1, k.shape[2], head_dim) + v_varlen = v.reshape(-1, v.shape[2], head_dim) + + # Use custom autograd function with flash_attn_varlen_func forward and manual backward + class VLLMForwardCustomBackward(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, scale, batch_size, seq_len, causal, fa_version): + # Use flash_attn_varlen_func directly for fast forward pass + # This is the SAME kernel vLLM uses internally! 
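+                # cu_seqlens marks sequence boundaries in the packed (varlen) token dimension.
+                # For a uniform batch, torch.arange below gives [0, seq_len, 2*seq_len, ...],
+                # e.g. batch_size=2 with seq_len=4 -> [0, 4, 8], so tokens [0, 4) are sequence 0
+                # and tokens [4, 8) are sequence 1.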
+ cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * seq_len, + seq_len, + dtype=torch.int32, + device=q.device, + ) + + output = flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_q=seq_len, + max_seqlen_k=seq_len, + softmax_scale=scale, + causal=causal, + num_splits=1 if vllm_is_batch_invariant() else 0, + fa_version=fa_version, + ) + + # Save for backward + ctx.save_for_backward(q, k, v, output) + ctx.scale = scale + ctx.seq_len = seq_len + ctx.batch_size = batch_size + ctx.causal = causal + + return output + + @staticmethod + def backward(ctx, grad_output): + q, k, v, output = ctx.saved_tensors + scale = ctx.scale + seq_len = ctx.seq_len + batch_size = ctx.batch_size + causal = ctx.causal + + # Reshape from varlen to batch format + total_tokens = q.shape[0] + num_heads = q.shape[1] + head_dim = q.shape[2] + + q_batch = q.reshape(batch_size, seq_len, num_heads, head_dim) + k_batch = k.reshape(batch_size, seq_len, num_heads, head_dim) + v_batch = v.reshape(batch_size, seq_len, num_heads, head_dim) + grad_out_batch = grad_output.reshape( + batch_size, seq_len, num_heads, head_dim + ) + + # Transpose to (batch, num_heads, seq_len, head_dim) + q_t = q_batch.transpose(1, 2) + k_t = k_batch.transpose(1, 2) + v_t = v_batch.transpose(1, 2) + grad_out_t = grad_out_batch.transpose(1, 2) + + # Compute attention scores: QK^T + scores = torch.matmul(q_t, k_t.transpose(-2, -1)) * scale + + # Apply causal mask if needed + if causal: + causal_mask = torch.triu( + torch.ones(seq_len, seq_len, device=q.device, dtype=torch.bool), + diagonal=1, + ) + scores = scores.masked_fill(causal_mask, float("-inf")) + + # Softmax + attn_weights = torch.nn.functional.softmax(scores, dim=-1) + + # Backward through attention + # grad_v = attn_weights^T @ grad_out + grad_v_t = torch.matmul(attn_weights.transpose(-2, -1), grad_out_t) + + # grad_attn_weights = grad_out @ v^T + grad_attn_weights = torch.matmul(grad_out_t, v_t.transpose(-2, -1)) + + # Backward through softmax + sum_term = (grad_attn_weights * attn_weights).sum(dim=-1, keepdim=True) + grad_scores = attn_weights * (grad_attn_weights - sum_term) + + # Apply causal mask to gradients + if causal: + grad_scores = grad_scores.masked_fill(causal_mask, 0.0) + + # Backward through scale + grad_scores = grad_scores * scale + + # grad_q = grad_scores @ K + grad_q_t = torch.matmul(grad_scores, k_t) + + # grad_k = grad_scores^T @ Q + grad_k_t = torch.matmul(grad_scores.transpose(-2, -1), q_t) + + # Transpose back and reshape to varlen format + grad_q = grad_q_t.transpose(1, 2).reshape( + total_tokens, num_heads, head_dim + ) + grad_k = grad_k_t.transpose(1, 2).reshape( + total_tokens, num_heads, head_dim + ) + grad_v = grad_v_t.transpose(1, 2).reshape( + total_tokens, num_heads, head_dim + ) + + return grad_q, grad_k, grad_v, None, None, None, None, None + + # Get flash attention version + fa_version = get_flash_attn_version() + + # Apply custom autograd function + output_varlen = VLLMForwardCustomBackward.apply( + q_varlen, + k_varlen, + v_varlen, + scale or self.scale, + batch_size, + seq_len, + self.causal, + fa_version, + ) + + # Convert back to batch format + # (total_tokens, num_heads, head_dim) -> (batch, seq_len, num_heads, head_dim) + output = output_varlen.reshape(batch_size, seq_len, num_heads, head_dim) - # Reshape back to batch format and transpose to TorchTitan format - # [total_tokens, num_heads, head_dim] -> [batch, seq_len, num_heads, head_dim] -> [batch, num_heads, seq_len, head_dim] - 
output = attn_output.reshape(
-            batch_size, seq_len, num_heads, head_dim
-        ).transpose(1, 2)
+        # Transpose back to TorchTitan format: (batch, num_heads, seq_len, head_dim)
+        output = output.transpose(1, 2)
 
         return output
diff --git a/torchtitan/experiments/vllm/model/qwen3.py b/torchtitan/experiments/vllm/model/qwen3.py
index 70d8333e26..4e182ff47c 100644
--- a/torchtitan/experiments/vllm/model/qwen3.py
+++ b/torchtitan/experiments/vllm/model/qwen3.py
@@ -112,8 +112,6 @@ def __init__(
                     f"Layer {layer_name}.attention must have .inner_attention attribute"
                 )
 
-            # NOTE(jianiw): Attention implementation 1: Add backward for vllm FlashAttn
-            # Replace only the inner attention kernel, not the whole Attention module
             layer.attention.inner_attention = VLLMCompatibleFlashAttention(
                 hidden_size=model_args.dim,
                 num_heads=model_args.n_heads,

From 4376528cfb6dea73ebe0778bf83d241d68913a5b Mon Sep 17 00:00:00 2001
From: Jiani Wang
Date: Mon, 1 Dec 2025 11:43:02 -0800
Subject: [PATCH 06/19] single GPU works

---
 torchtitan/experiments/vllm/README.md         |   8 +-
 .../vllm/custom_models/__init__.py            |  18 +-
 .../custom_models/attention_replacement.py    | 170 ------
 .../vllm/custom_models/trainable_attention.py | 502 ----------------
 .../custom_models/trainable_mla_attention.py  | 562 ------------------
 .../experiments/vllm/model/attention.py       |  33 +
 .../tests/test_fwd_bwd_with_vllm_attention.py | 537 +++++++++++++++++
 .../vllm/tests/test_vllm_attention.py         | 360 +++++++++++
 8 files changed, 938 insertions(+), 1252 deletions(-)
 delete mode 100644 torchtitan/experiments/vllm/custom_models/attention_replacement.py
 delete mode 100644 torchtitan/experiments/vllm/custom_models/trainable_attention.py
 delete mode 100644 torchtitan/experiments/vllm/custom_models/trainable_mla_attention.py
 create mode 100644 torchtitan/experiments/vllm/tests/test_fwd_bwd_with_vllm_attention.py
 create mode 100644 torchtitan/experiments/vllm/tests/test_vllm_attention.py

diff --git a/torchtitan/experiments/vllm/README.md b/torchtitan/experiments/vllm/README.md
index 9f4836b2e9..33725c1f94 100644
--- a/torchtitan/experiments/vllm/README.md
+++ b/torchtitan/experiments/vllm/README.md
@@ -42,4 +42,10 @@ pip install -e .
 
 ```
 
-2. Download the Qwen3/Qwen3-0.6b checkpoint from HuggingFace and put it into the `example_checkpoint` folder.
+2. Download the Qwen3/Qwen3-0.6b checkpoint from HuggingFace and put it into the `example_checkpoint` folder. Make sure to change the `architectures` field in `config.json` to `Qwen3TorchTitanForCausalLM` so the vLLM engine can use the TorchTitan model.
+
+
+3. 
Run inference: +``` +python torchtitan/experiments/vllm/infer.py --model torchtitan/experiments/vllm/example_checkpoint/qwen3-0.6B +``` diff --git a/torchtitan/experiments/vllm/custom_models/__init__.py b/torchtitan/experiments/vllm/custom_models/__init__.py index ffe2a498ce..f036e77c62 100644 --- a/torchtitan/experiments/vllm/custom_models/__init__.py +++ b/torchtitan/experiments/vllm/custom_models/__init__.py @@ -8,28 +8,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utilities for integrating custom model implementations with vLLM.""" -from .attention_replacement import replace_with_trainable_attention from .custom_model_wrapper import VLLMModelForCausalLM -from .trainable_attention import TrainableFlashAttention -from .trainable_mla_attention import MLAConfig, TrainableMLA -from .utils import ( - convert_freqs_cis_to_real, - create_mla_kv_cache_spec, - load_external_weights, - store_positions_in_context, -) +from .utils import store_positions_in_context __all__ = [ - # Attention modules - "TrainableFlashAttention", - "TrainableMLA", - "MLAConfig", - "replace_with_trainable_attention", # Base wrapper "VLLMModelForCausalLM", # Utilities - "convert_freqs_cis_to_real", - "create_mla_kv_cache_spec", - "load_external_weights", "store_positions_in_context", ] diff --git a/torchtitan/experiments/vllm/custom_models/attention_replacement.py b/torchtitan/experiments/vllm/custom_models/attention_replacement.py deleted file mode 100644 index a3d51dbe65..0000000000 --- a/torchtitan/experiments/vllm/custom_models/attention_replacement.py +++ /dev/null @@ -1,170 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Utilities for replacing TorchTitan attention layers with vLLM trainable attention. - -This module provides functions to automatically replace TorchTitan's attention -layers with vLLM's optimized trainable attention layers (TrainableFlashAttention -or TrainableMLA) while preserving weights. - -Example usage: - ```python - from torchtitan.models.qwen3.model import Qwen3Model - from . import replace_with_trainable_attention - - # Create TorchTitan model - model = Qwen3Model(model_args) - - # Replace attention layers with vLLM trainable attention - replace_with_trainable_attention(model, use_mla=False) - ``` -""" - -import torch.nn as nn - -from .trainable_attention import TrainableFlashAttention -from .trainable_mla_attention import MLAConfig, TrainableMLA - - -def replace_with_trainable_attention( - model: nn.Module, - use_mla: bool = False, -) -> None: - """ - Replace TorchTitan attention layers with vLLM trainable attention. - - This function performs in-place module surgery, replacing all attention - layers in model.layers with either TrainableFlashAttention or TrainableMLA - while preserving the original weights. 
- - Args: - model: TorchTitan model with .layers attribute (dict or nn.ModuleDict) - use_mla: If True, use TrainableMLA; otherwise use TrainableFlashAttention - - Raises: - AttributeError: If model doesn't have .layers attribute - ValueError: If attention layer structure is not recognized - """ - if not hasattr(model, "layers"): - raise AttributeError( - f"Model {type(model).__name__} must have .layers attribute" - ) - - for layer_name, layer in model.layers.items(): - if not hasattr(layer, "attention"): - raise ValueError(f"Layer {layer_name} must have .attention attribute") - - old_attn = layer.attention - - if use_mla: - # Create TrainableMLA and transfer weights - new_attn = _create_trainable_mla_from_torchtitan(old_attn) - else: - # Create TrainableFlashAttention and transfer weights - new_attn = _create_trainable_flash_attention_from_torchtitan(old_attn) - - # Replace attention module - layer.attention = new_attn - - -def _create_trainable_flash_attention_from_torchtitan( - torchtitan_attn: nn.Module, -) -> TrainableFlashAttention: - """ - Create TrainableFlashAttention from TorchTitan attention and transfer weights. - - Args: - torchtitan_attn: TorchTitan Attention module - - Returns: - TrainableFlashAttention with transferred weights - """ - # Extract config from TorchTitan attention - hidden_size = torchtitan_attn.wq.weight.shape[1] - num_heads = torchtitan_attn.n_heads - num_kv_heads = getattr(torchtitan_attn, "n_kv_heads", num_heads) - head_dim = torchtitan_attn.head_dim - use_qk_norm = ( - hasattr(torchtitan_attn, "q_norm") and torchtitan_attn.q_norm is not None - ) - - # Create vLLM attention - vllm_attn = TrainableFlashAttention( - hidden_size=hidden_size, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - use_fused_qkv=False, # TorchTitan uses separate wq/wk/wv - use_qk_norm=use_qk_norm, - ) - - # Transfer weights (TorchTitan and vLLM use same naming: wq, wk, wv, wo) - vllm_attn.wq.weight.data.copy_(torchtitan_attn.wq.weight.data) - vllm_attn.wk.weight.data.copy_(torchtitan_attn.wk.weight.data) - vllm_attn.wv.weight.data.copy_(torchtitan_attn.wv.weight.data) - vllm_attn.wo.weight.data.copy_(torchtitan_attn.wo.weight.data) - - # Transfer QK norm weights if present - if use_qk_norm: - vllm_attn.q_norm.weight.data.copy_(torchtitan_attn.q_norm.weight.data) - vllm_attn.k_norm.weight.data.copy_(torchtitan_attn.k_norm.weight.data) - - return vllm_attn - - -def _create_trainable_mla_from_torchtitan( - torchtitan_attn: nn.Module, -) -> TrainableMLA: - """ - Create TrainableMLA from TorchTitan MLA attention and transfer weights. 
- - Args: - torchtitan_attn: TorchTitan MLA Attention module - - Returns: - TrainableMLA with transferred weights - """ - # Extract MLA config from TorchTitan attention - config = MLAConfig( - hidden_size=torchtitan_attn.dim, - num_heads=torchtitan_attn.n_heads, - q_lora_rank=torchtitan_attn.q_lora_rank, - kv_lora_rank=torchtitan_attn.kv_lora_rank, - qk_nope_head_dim=torchtitan_attn.qk_nope_head_dim, - qk_rope_head_dim=torchtitan_attn.qk_rope_head_dim, - v_head_dim=torchtitan_attn.v_head_dim, - norm_eps=1e-5, # Standard value for DeepSeek - dropout=0.0, - scale=torchtitan_attn.softmax_scale, - causal=True, - ) - - # Create vLLM MLA - vllm_mla = TrainableMLA(config) - - # Transfer weights - if vllm_mla.q_lora_rank == 0: - # Direct Q projection - vllm_mla.wq.weight.data.copy_(torchtitan_attn.wq.weight.data) - else: - # LoRA Q projection - assert vllm_mla.q_norm is not None # q_norm exists when q_lora_rank > 0 - vllm_mla.wq_a.weight.data.copy_(torchtitan_attn.wq_a.weight.data) - vllm_mla.wq_b.weight.data.copy_(torchtitan_attn.wq_b.weight.data) - vllm_mla.q_norm.weight.data.copy_(torchtitan_attn.q_norm.weight.data) - - # KV projection (always LoRA) - vllm_mla.wkv_a.weight.data.copy_(torchtitan_attn.wkv_a.weight.data) - vllm_mla.wkv_b.weight.data.copy_(torchtitan_attn.wkv_b.weight.data) - vllm_mla.kv_norm.weight.data.copy_(torchtitan_attn.kv_norm.weight.data) - - # Output projection - vllm_mla.wo.weight.data.copy_(torchtitan_attn.wo.weight.data) - - return vllm_mla diff --git a/torchtitan/experiments/vllm/custom_models/trainable_attention.py b/torchtitan/experiments/vllm/custom_models/trainable_attention.py deleted file mode 100644 index 80a304c3fc..0000000000 --- a/torchtitan/experiments/vllm/custom_models/trainable_attention.py +++ /dev/null @@ -1,502 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Trainable Flash Attention module for research and fine-tuning. - -This module provides a training-compatible wrapper around vLLM's optimized -flash attention implementation, enabling backpropagation for RL and fine-tuning -use cases. -""" - -import itertools -from typing import Any - -import torch -import torch.nn as nn - -from vllm.attention import Attention -from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available - -if is_flash_attn_varlen_func_available(): - from vllm.attention.utils.fa_utils import flash_attn_varlen_func - -from vllm.config import VllmConfig -from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.v1.kv_cache_interface import FullAttentionSpec - - -class TrainableFlashAttention(nn.Module, AttentionLayerBase): - """ - Training-compatible flash attention module using vLLM's optimized kernels. - - This module wraps vLLM's flash attention forward pass and adds backward - support for training scenarios like reinforcement learning and fine-tuning. - - Supports both fused QKV projections (efficient) and separate projections - (for compatibility with TorchTitan models during module surgery). 
- - Example: - ```python - # Create attention module (fused, efficient) - attn = TrainableFlashAttention(hidden_size=768, num_heads=12, dropout=0.1) - - # Create TorchTitan-compatible module (separate projections) - attn = TrainableFlashAttention( - hidden_size=768, - num_heads=12, - use_fused_qkv=False, # Separate wq/wk/wv for compatibility - use_qk_norm=True, # QK normalization like Qwen3 - ) - - # Use in training - attn.train() - hidden_states = torch.randn(2, 16, 768, requires_grad=True) - output = attn(hidden_states) - - # Backward pass works - loss = output.sum() - loss.backward() - ``` - - Args: - hidden_size: Hidden dimension of the model - num_heads: Number of attention heads - num_kv_heads: Number of key/value heads (for GQA). Defaults to num_heads - head_dim: Dimension per head. Defaults to hidden_size // num_heads - dropout: Dropout probability during training. Defaults to 0.0 - scale: Attention scale factor. Defaults to 1/sqrt(head_dim) - causal: Whether to use causal masking. Defaults to True - use_fused_qkv: Use fused QKV projection (efficient). Set False for - TorchTitan compatibility. Defaults to True. - use_qk_norm: Apply RMSNorm to Q and K after projection (Qwen3 style). - Defaults to False. - norm_eps: Epsilon for RMSNorm if use_qk_norm=True. Defaults to 1e-6. - """ - - # Class variable for auto-generating unique layer names (thread-safe) - _layer_counter = itertools.count() - - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int | None = None, - head_dim: int | None = None, - dropout: float = 0.0, - scale: float | None = None, - causal: bool = True, - use_fused_qkv: bool = True, - use_qk_norm: bool = False, - norm_eps: float = 1e-6, - ): - super().__init__() - - if not is_flash_attn_varlen_func_available(): - raise RuntimeError( - "Flash attention is not available. " - "Please install flash-attn: pip install flash-attn" - ) - - self.hidden_size = hidden_size - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads or num_heads - self.head_dim = head_dim or (hidden_size // num_heads) - self.dropout = dropout - self.causal = causal - self.use_fused_qkv = use_fused_qkv - self.use_qk_norm = use_qk_norm - - if scale is None: - self.scale = self.head_dim**-0.5 - else: - self.scale = scale - - # TODO(future optimization): Always use fused QKV for efficiency - # Currently supporting separate projections for TorchTitan compatibility - # during module surgery. Once we have weight conversion utilities, - # we should always initialize with fused weights and convert TorchTitan - # weights (wq, wk, wv) -> fused (qkv) during load_weights(). - # This will give us the best of both worlds: compatibility + efficiency. 
- - if use_fused_qkv: - # Fused QKV projection (efficient - single matmul) - self.qkv = nn.Linear( - hidden_size, - (num_heads + 2 * self.num_kv_heads) * self.head_dim, - bias=False, - ) - else: - # Separate projections (TorchTitan compatibility) - self.wq = nn.Linear(hidden_size, num_heads * self.head_dim, bias=False) - self.wk = nn.Linear( - hidden_size, self.num_kv_heads * self.head_dim, bias=False - ) - self.wv = nn.Linear( - hidden_size, self.num_kv_heads * self.head_dim, bias=False - ) - - # Output projection (naming convention follows use_fused_qkv) - if use_fused_qkv: - self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_size, bias=False) - else: - # TorchTitan uses 'wo' naming - self.wo = nn.Linear(num_heads * self.head_dim, hidden_size, bias=False) - - # Optional QK normalization (for Qwen3 and similar models) - if use_qk_norm: - self.q_norm = nn.RMSNorm(self.head_dim, eps=norm_eps) - self.k_norm = nn.RMSNorm(self.head_dim, eps=norm_eps) - else: - self.q_norm = None - self.k_norm = None - - # Create vLLM Attention layer to handle KV cache automatically - # This delegates all the complex KV cache logic to vLLM - try: - from vllm.config import get_current_vllm_config - - config = get_current_vllm_config() - cache_config = ( - config.cache_config if hasattr(config, "cache_config") else None - ) - - # Generate unique prefix for this attention layer - # vLLM expects format "layers.X" for layer index extraction - layer_idx = next(TrainableFlashAttention._layer_counter) - prefix = f"layers.{layer_idx}" - - self.vllm_attn = Attention( - num_heads=num_heads, - head_size=self.head_dim, - scale=self.scale, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=None, - prefix=prefix, - ) - except (ImportError, RuntimeError, AttributeError): - # Not in vLLM context - attention layer not needed - self.vllm_attn = None - - # KV cache - will be populated by vLLM during model loading - # For V1 engine, this is a list[torch.Tensor] indexed by virtual_engine - self.kv_cache: list[torch.Tensor] | None = None - - # Auto-register for vLLM KV cache if in vLLM context - self._auto_register_for_kv_cache() - - def _auto_register_for_kv_cache(self): - """Automatically register this layer for vLLM KV cache allocation. - - This is called during __init__ and will register the layer if we're in - a vLLM context. If not in vLLM context (e.g., pure PyTorch training), - this silently does nothing. - """ - # Initialize layer_name attribute - self.layer_name: str | None = None - - try: - from vllm.config import get_current_vllm_config - - config = get_current_vllm_config() - compilation_config = config.compilation_config - - # Generate unique layer name using class counter - # Format: "layers.{index}" for compatibility with extract_layer_index() - layer_name = f"layers.{next(TrainableFlashAttention._layer_counter)}" - - # Register this layer in static forward context - if layer_name in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {layer_name}") - compilation_config.static_forward_context[layer_name] = self - self.layer_name = layer_name - - except (ImportError, RuntimeError, AttributeError): - # Not in vLLM context - this is fine! 
- # Layer will work normally for training/inference without vLLM - pass - - def forward( - self, - hidden_states: torch.Tensor, - freqs_cis: ( - torch.Tensor | None - ) = None, # RoPE frequencies (TorchTitan compatibility) - attention_mask: torch.Tensor | None = None, - **kwargs: Any, # Accept any additional vLLM-specific kwargs - ) -> torch.Tensor: - """ - Forward pass with flash attention. - - Supports both training (full sequences) and vLLM inference (with KV cache). - - Args: - hidden_states: Input tensor of shape [total_tokens, hidden_size] - or [batch, seq_len, hidden_size] - freqs_cis: RoPE frequencies (for TorchTitan compatibility, currently unused) - attention_mask: Optional attention mask (not yet fully supported) - **kwargs: Additional vLLM-specific kwargs (intermediate_tensors, etc.) - - Returns: - output: Attention output of same shape as hidden_states - """ - # Handle both batched [batch, seq, hidden] and flattened [total_tokens, hidden] - input_is_batched = hidden_states.dim() == 3 - if input_is_batched: - original_batch_size, original_seq_len, _ = hidden_states.shape - hidden_states = hidden_states.view(-1, self.hidden_size) - else: - original_batch_size = None - original_seq_len = None - - total_tokens = hidden_states.shape[0] - - # Project to Q, K, V (supports both fused and separate modes) - if self.use_fused_qkv: - # Fused projection path (efficient) - qkv = self.qkv(hidden_states) - - # Split into Q, K, V - # qkv shape: [total_tokens, (num_heads + 2*num_kv_heads) * head_dim] - q_size = self.num_heads * self.head_dim - k_size = self.num_kv_heads * self.head_dim - v_size = self.num_kv_heads * self.head_dim - - q = qkv[:, :q_size] - k = qkv[:, q_size : q_size + k_size] - v = qkv[:, q_size + k_size : q_size + k_size + v_size] - else: - # Separate projections (TorchTitan compatibility) - q = self.wq(hidden_states) - k = self.wk(hidden_states) - v = self.wv(hidden_states) - - # Reshape for attention: [total_tokens, num_heads, head_dim] - q = q.view(total_tokens, self.num_heads, self.head_dim) - k = k.view(total_tokens, self.num_kv_heads, self.head_dim) - v = v.view(total_tokens, self.num_kv_heads, self.head_dim) - - # Optional QK normalization (Qwen3 style) - if self.q_norm is not None: - q = self.q_norm(q) - if self.k_norm is not None: - k = self.k_norm(k) - - # DEBUG: Log layer 0 values to compare with TorchTitan - is_layer_0 = not hasattr(self, "_debug_logged") - if is_layer_0 and total_tokens > 1 and total_tokens < 100: # Skip warmup - self._debug_logged = True - print("\n[VLLM ATT DEBUG] Layer 0 - Input") - print(f" hidden_states.shape: {hidden_states.shape}") - print(f" total_tokens: {total_tokens}") - print(f" q (before RoPE)[0,0,:5]: {q[0, 0, :5]}") - print(f" k (before RoPE)[0,0,:5]: {k[0, 0, :5]}") - - # Apply RoPE if freqs_cis is provided (TorchTitan integration) - if freqs_cis is not None: - # Get positions from vLLM forward context - try: - from vllm.forward_context import get_forward_context - - forward_ctx = get_forward_context() - - # Try to get positions from custom attribute set by wrapper - positions = None - if hasattr(forward_ctx, "_torchtitan_positions"): - positions = forward_ctx._torchtitan_positions - # Debug: Log positions during generation, not just warmup - unique_pos = torch.unique(positions[: min(100, len(positions))]) - # Skip warmup with all zeros - if (len(unique_pos) > 1 or unique_pos[0] != 0) and not hasattr( - self, "_rope_gen_debug" - ): - self._rope_gen_debug = True - print(f"\n[ROPE GEN] Got real positions: {unique_pos[:20]}") - print( 
- f"[ROPE GEN] total_tokens: {total_tokens}, " - f"freqs_cis.shape: {freqs_cis.shape}" - ) - else: - # Fallback to sequential positions - positions = torch.arange(total_tokens, device=q.device) - - # Index rope_cache by positions - # freqs_cis shape after convert_freqs_cis_to_real: - # [max_seq_len, head_dim] (cos and sin concatenated) - positions_flat = positions.flatten() - - # Ensure positions are within bounds - max_pos = freqs_cis.shape[0] - 1 - positions_flat = torch.clamp(positions_flat[:total_tokens], 0, max_pos) - - cos_sin = freqs_cis.index_select(0, positions_flat) - - # Split into cos and sin at half of last dimension - # (works for both [*, head_dim] and [*, 2*head_dim] formats) - half_dim = cos_sin.shape[-1] // 2 - cos = cos_sin[..., :half_dim] - sin = cos_sin[..., half_dim:] - - # Expand cos/sin to full head_dim by repeating each element - # [total_tokens, head_dim//2] -> [total_tokens, head_dim] - cos = cos.repeat_interleave(2, dim=-1) - sin = sin.repeat_interleave(2, dim=-1) - - # Apply rotary embedding (same as TorchTitan's apply_rotary_emb) - def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - # Reshape cos/sin for broadcast: [total_tokens, 1, head_dim] - cos = cos.unsqueeze(1).to(dtype=q.dtype, device=q.device) - sin = sin.unsqueeze(1).to(dtype=q.dtype, device=q.device) - - # Apply rotation - q = (q * cos) + (rotate_half(q) * sin) - k = (k * cos) + (rotate_half(k) * sin) - - # DEBUG: Log after RoPE - if is_layer_0 and total_tokens > 1 and total_tokens < 100: - print(f" RoPE applied with positions: {unique_pos[:10]}") - print(f" freqs_cis.shape: {freqs_cis.shape}") - print(f" q (after RoPE)[0,0,:5]: {q[0, 0, :5]}") - print(f" k (after RoPE)[0,0,:5]: {k[0, 0, :5]}") - - except (ImportError, AttributeError, IndexError, AssertionError) as e: - # If we can't get positions, fall through without RoPE - # This will happen in pure training mode - if not hasattr(self, "_rope_error"): - self._rope_error = True - print(f"\n[ROPE DEBUG] Error applying RoPE: {e}") - pass - - # Delegate to vLLM's Attention layer if available - # (handles KV cache automatically) - if self.vllm_attn is not None and not self.training: - # Let vLLM handle all KV cache logic - # vllm_attn expects q,k,v in shape [total_tokens, num_heads*head_dim] - # or [total_tokens, num_heads, head_dim] - attn_output = self.vllm_attn(q, k, v) - # vllm_attn returns [total_tokens, num_heads * head_dim] - else: - # Training mode or fallback: use regular flash attention (no KV cache) - if not self.training and hidden_states.is_cuda: - # Inference without KV cache: use flash attention varlen - # Create simple cu_seqlens for single sequence - cu_seqlens_q = torch.tensor( - [0, total_tokens], - dtype=torch.int32, - device=hidden_states.device, - ) - - attn_output = flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_q, - max_seqlen_q=total_tokens, - max_seqlen_k=total_tokens, - softmax_scale=self.scale, - causal=self.causal, - dropout_p=0.0, - fa_version=3, - ) - else: - # Training mode with CPU: use PyTorch SDPA - batch_size = 1 - seq_len = total_tokens - - q = q.view(batch_size, seq_len, self.num_heads, self.head_dim) - k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) - v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) - - q = q.transpose(1, 2) # [batch, num_heads, seq_len, head_dim] - k = k.transpose(1, 2) - v = 
v.transpose(1, 2) - - # Handle GQA by repeating k, v if needed - if self.num_kv_heads != self.num_heads: - num_repeats = self.num_heads // self.num_kv_heads - k = k.repeat_interleave(num_repeats, dim=1) - v = v.repeat_interleave(num_repeats, dim=1) - - attn_output = torch.nn.functional.scaled_dot_product_attention( - q, - k, - v, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=self.causal and attention_mask is None, - ) - - attn_output = attn_output.transpose( - 1, 2 - ) # [batch, seq_len, heads, dim] - attn_output = attn_output.reshape( - total_tokens, self.num_heads, self.head_dim - ) - - # Flatten heads and project output - attn_output = attn_output.reshape(total_tokens, -1) - if self.use_fused_qkv: - output = self.o_proj(attn_output) - else: - output = self.wo(attn_output) - - # DEBUG: Log attention output for layer 0 - if is_layer_0 and total_tokens > 1 and total_tokens < 100: - print(f" attn_output (before o_proj)[0,:5]: {attn_output[0, :5]}") - print(f" output (after o_proj)[0,:5]: {output[0, :5]}") - - # Restore original shape if input was batched - if input_is_batched: - output = output.view( - original_batch_size, original_seq_len, self.hidden_size - ) - - return output - - def get_attn_backend(self): - """ - Get the attention backend for this layer. - - For TrainableFlashAttention, we don't use a specific vLLM backend - since we implement attention directly. Return None to indicate - this layer manages its own attention computation. - """ - # Import here to avoid circular dependency - from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend - - return FlashAttentionBackend - - def get_kv_cache_spec(self, vllm_config: VllmConfig) -> FullAttentionSpec: - """ - Return KV cache specification for V1 engine integration. - - This allows TrainableFlashAttention to work with vLLM's V1 engine - by providing the necessary KV cache metadata. - """ - block_size = vllm_config.cache_config.block_size - # Determine the dtype for KV cache - kv_cache_dtype = vllm_config.cache_config.cache_dtype - if kv_cache_dtype == "auto": - kv_cache_dtype = vllm_config.model_config.dtype - - return FullAttentionSpec( - block_size=block_size, - num_kv_heads=self.num_kv_heads, - head_size=self.head_dim, - dtype=kv_cache_dtype, - ) - - -__all__ = ["TrainableFlashAttention"] diff --git a/torchtitan/experiments/vllm/custom_models/trainable_mla_attention.py b/torchtitan/experiments/vllm/custom_models/trainable_mla_attention.py deleted file mode 100644 index 8e16812249..0000000000 --- a/torchtitan/experiments/vllm/custom_models/trainable_mla_attention.py +++ /dev/null @@ -1,562 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Trainable Multi-Head Latent Attention (MLA) for DeepSeek V3. - -This module implements the MLA architecture used in DeepSeek V3, which uses: -1. Low-rank compression for Q and KV projections -2. Split Q/K into RoPE and non-RoPE parts -3. Shared K_PE (RoPE-encoded key) across all heads - -Reference: https://github.com/deepseek-ai/DeepSeek-V3 -""" - -from dataclasses import dataclass -from typing import Any - -import torch -import torch.nn as nn - - -@dataclass -class MLAConfig: - """Configuration for Multi-Head Latent Attention. 
- - Args: - hidden_size: Hidden dimension of the model - num_heads: Number of attention heads - q_lora_rank: LoRA rank for query projection. If 0, use direct projection. - kv_lora_rank: LoRA rank for key-value projection - qk_nope_head_dim: Dimension of Q/K without positional encoding - qk_rope_head_dim: Dimension of Q/K with RoPE - v_head_dim: Dimension of value projection per head - norm_eps: Epsilon for RMSNorm layers - dropout: Dropout probability during training - scale: Attention scale factor. If None, defaults to 1/sqrt(qk_head_dim) - causal: Whether to use causal masking - """ - - hidden_size: int - num_heads: int - q_lora_rank: int = 0 - kv_lora_rank: int = 512 - qk_nope_head_dim: int = 128 - qk_rope_head_dim: int = 64 - v_head_dim: int = 128 - norm_eps: float = 1e-5 - dropout: float = 0.0 - scale: float | None = None - causal: bool = True - - @property - def qk_head_dim(self) -> int: - """Total Q/K head dimension.""" - return self.qk_nope_head_dim + self.qk_rope_head_dim - - -class TrainableMLA(nn.Module): - """ - Training-compatible Multi-Head Latent Attention (MLA). - - This implements DeepSeek V3's MLA architecture: - - Low-rank compression with intermediate RMSNorm - - Split Q/K into RoPE and non-RoPE parts - - Shared K_PE across all attention heads (memory efficient!) - - Example: - ```python - config = MLAConfig( - hidden_size=2048, - num_heads=16, - q_lora_rank=512, - kv_lora_rank=512, - qk_nope_head_dim=128, - qk_rope_head_dim=64, - v_head_dim=128, - ) - mla = TrainableMLA(config) - - # Forward pass - hidden_states = torch.randn(2, 16, 2048) - freqs_cis = torch.randn(16, config.qk_rope_head_dim // 2) - output = mla(hidden_states, freqs_cis=freqs_cis) - ``` - """ - - def __init__(self, config: MLAConfig): - super().__init__() - - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_heads - self.q_lora_rank = config.q_lora_rank - self.kv_lora_rank = config.kv_lora_rank - self.qk_nope_head_dim = config.qk_nope_head_dim - self.qk_rope_head_dim = config.qk_rope_head_dim - self.qk_head_dim = config.qk_head_dim - self.v_head_dim = config.v_head_dim - self.dropout = config.dropout - self.causal = config.causal - - # Attention scale - if config.scale is None: - self.scale = self.qk_head_dim**-0.5 - else: - self.scale = config.scale - - # Query projection - if self.q_lora_rank == 0: - # Direct projection without LoRA - self.wq = nn.Linear( - self.hidden_size, self.num_heads * self.qk_head_dim, bias=False - ) - self.q_norm = None - else: - # Low-rank projection with intermediate norm - self.wq_a = nn.Linear(self.hidden_size, self.q_lora_rank, bias=False) - self.q_norm = nn.RMSNorm(self.q_lora_rank, eps=config.norm_eps) - self.wq_b = nn.Linear( - self.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False - ) - - # Key-value projection (always uses LoRA) - # Note: wkv_a outputs kv_lora_rank + qk_rope_head_dim - # The extra qk_rope_head_dim is for the shared K_PE - self.wkv_a = nn.Linear( - self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False - ) - self.kv_norm = nn.RMSNorm(self.kv_lora_rank, eps=config.norm_eps) - self.wkv_b = nn.Linear( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - bias=False, - ) - - # Output projection - self.wo = nn.Linear( - self.num_heads * self.v_head_dim, self.hidden_size, bias=False - ) - - # Create vLLM MLAAttention for KV cache + optimized attention - # We'll initialize it in a lazy way since we need vLLM config - self.vllm_mla_attn = None - 
self._init_vllm_mla_attention() - - def _init_vllm_mla_attention(self): - """Initialize vLLM's MLAAttention for KV cache and optimized attention.""" - try: - from vllm.attention.layer import MLAAttention - from vllm.config import get_current_vllm_config - from vllm.model_executor.layers.linear import ColumnParallelLinear - - # Get vLLM config if available - try: - vllm_config = get_current_vllm_config() - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - except (RuntimeError, AttributeError): - # Not in vLLM context - skip MLAAttention initialization - return - - # Generate unique layer name for KV cache registration - import itertools - - if not hasattr(TrainableMLA, "_layer_counter"): - TrainableMLA._layer_counter = itertools.count() - - layer_name = f"layers.{next(TrainableMLA._layer_counter)}.attention" - - # Wrap wkv_b in ColumnParallelLinear (vLLM's parallel layer) - # This allows vLLM to handle TP sharding properly - kv_b_proj = ColumnParallelLinear( - input_size=self.kv_lora_rank, - output_size=self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - bias=False, - quant_config=quant_config, - ) - # Copy weights from our regular Linear layer - kv_b_proj.weight.data.copy_(self.wkv_b.weight.data) - - # Create vLLM's MLAAttention - self.vllm_mla_attn = MLAAttention( - num_heads=self.num_heads, - scale=self.scale, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - v_head_dim=self.v_head_dim, - q_lora_rank=self.q_lora_rank if self.q_lora_rank > 0 else None, - kv_lora_rank=self.kv_lora_rank, - kv_b_proj=kv_b_proj, - cache_config=cache_config, - quant_config=quant_config, - prefix=layer_name, - ) - - print(f" ✓ Created vLLM MLAAttention for {layer_name}") - - except (ImportError, RuntimeError, AttributeError, AssertionError) as e: - # vLLM not available or not in vLLM context - use manual implementation - print(f" ⚠️ Could not create vLLM MLAAttention: {e}") - pass - - def _auto_register_for_kv_cache(self): - """Automatically register this layer for vLLM KV cache allocation. - - This is called during __init__ and will register the layer if we're in - a vLLM context. If not in vLLM context, this silently does nothing. - """ - self.layer_name: str | None = None - - try: - from vllm.config import get_current_vllm_config - - config = get_current_vllm_config() - compilation_config = config.compilation_config - - # Generate unique layer name - import itertools - - if not hasattr(TrainableMLA, "_layer_counter"): - TrainableMLA._layer_counter = itertools.count() - - layer_name = f"layers.{next(TrainableMLA._layer_counter)}" - - # Register this layer in static forward context - if layer_name in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {layer_name}") - compilation_config.static_forward_context[layer_name] = self - self.layer_name = layer_name - - except (ImportError, RuntimeError, AttributeError): - # Not in vLLM context - this is fine! - pass - - def apply_rotary_emb( - self, x: torch.Tensor, freqs_cis: torch.Tensor - ) -> torch.Tensor: - """ - Apply rotary positional embeddings to the input tensor. 
- - Args: - x: Input tensor [total_tokens, heads, qk_rope_head_dim] - or [bsz, seq_len, heads, qk_rope_head_dim] - freqs_cis: Precomputed complex exponentials - [max_seq_len, qk_rope_head_dim//2] (complex64/complex128) - - Returns: - Tensor with rotary embeddings applied - """ - # Determine if batched or flattened - # [bsz, seq_len, heads, dim] or [total_tokens, heads, dim] - seq_len = x.size(1) if x.dim() == 4 else x.size(0) - - # Slice freqs_cis to actual sequence length - # freqs_cis is complex: [max_seq_len, qk_rope_head_dim//2] - freqs = freqs_cis[:seq_len] # [seq_len, qk_rope_head_dim//2] - - # Convert x to complex for rotation - # x: [..., qk_rope_head_dim] -> [..., qk_rope_head_dim//2] complex - x_complex = torch.view_as_complex( - x.float().reshape(*x.shape[:-1], -1, 2) - ) # [..., qk_rope_head_dim//2] - - # Reshape freqs for broadcasting - # Batched: [bsz, seq_len, heads, dim] -> freqs [1, seq_len, 1, dim//2] - # Flattened: [total_tokens, heads, dim] -> freqs [seq_len, 1, dim//2] - freqs = freqs.unsqueeze(0).unsqueeze(2) if x.dim() == 4 else freqs.unsqueeze(1) - - # Apply rotation: multiply by complex exponential - x_rotated = x_complex * freqs - - # Convert back to real - x_out = torch.view_as_real(x_rotated).flatten(-2) # [..., qk_rope_head_dim] - - return x_out.to(x.dtype) - - def apply_rotary_emb_with_cos_sin( - self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor - ) -> torch.Tensor: - """ - Apply rotary positional embeddings using cos and sin directly. - - Args: - x: Input tensor [total_tokens, heads, qk_rope_head_dim] - cos: Cosine values [total_tokens, qk_rope_head_dim//2] - sin: Sine values [total_tokens, qk_rope_head_dim//2] - - Returns: - Tensor with rotary embeddings applied - [total_tokens, heads, qk_rope_head_dim] - """ - # Expand cos/sin to match x's head dimension - # cos/sin: [total_tokens, qk_rope_head_dim//2] - # -> [total_tokens, 1, qk_rope_head_dim//2] - cos = cos.unsqueeze(1) - sin = sin.unsqueeze(1) - - # Repeat to full dimension: [total_tokens, 1, qk_rope_head_dim] - cos = cos.repeat_interleave(2, dim=-1) - sin = sin.repeat_interleave(2, dim=-1) - - # Apply rotation using rotate_half (avoids complex operations) - def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - # Apply RoPE: x_rotated = x * cos + rotate_half(x) * sin - x_out = (x * cos) + (rotate_half(x) * sin) - - return x_out.to(x.dtype) - - def apply_rotary_emb_indexed( - self, x: torch.Tensor, freqs_for_tokens: torch.Tensor - ) -> torch.Tensor: - """ - DEPRECATED: Use apply_rotary_emb_with_cos_sin instead. - - Apply rotary positional embeddings using pre-indexed frequencies. 
- - Args: - x: Input tensor [total_tokens, heads, qk_rope_head_dim] - freqs_for_tokens: Pre-indexed frequencies - complex or real format - - Returns: - Tensor with rotary embeddings applied - [total_tokens, heads, qk_rope_head_dim] - """ - # Check if freqs_for_tokens is complex or already split into cos/sin - if freqs_for_tokens.is_complex(): - # Extract cos and sin from complex frequencies - # freqs_for_tokens is complex exponentials: e^(i*theta) - # = cos(theta) + i*sin(theta) - cos = freqs_for_tokens.real # [total_tokens, qk_rope_head_dim//2] - sin = freqs_for_tokens.imag # [total_tokens, qk_rope_head_dim//2] - elif freqs_for_tokens.shape[-1] == x.shape[-1] // 2: - # Format: [total_tokens, qk_rope_head_dim//2] - # complex stored as real - # This happens after index_select on complex tensor - # The tensor is complex data stored in real format - # We need to extract real and imaginary parts - # Actually this shouldn't happen, but handle it anyway - print( - f"[DEBUG] freqs_for_tokens shape: {freqs_for_tokens.shape}, " - f"dtype: {freqs_for_tokens.dtype}" - ) - print(f"[DEBUG] x shape: {x.shape}") - # This format is ambiguous - assume it needs to be duplicated - cos = freqs_for_tokens - sin = freqs_for_tokens - else: - # freqs_for_tokens is already real, split it into cos and sin - # Assume format: [total_tokens, qk_rope_head_dim] - # where first half is cos, second is sin - half_dim = freqs_for_tokens.shape[-1] // 2 - cos = freqs_for_tokens[ - ..., :half_dim - ] # [total_tokens, qk_rope_head_dim//2] - sin = freqs_for_tokens[ - ..., half_dim: - ] # [total_tokens, qk_rope_head_dim//2] - - return self.apply_rotary_emb_with_cos_sin(x, cos, sin) - - def forward( - self, - hidden_states: torch.Tensor, - freqs_cis: torch.Tensor, - attention_mask: torch.Tensor | None = None, - positions: torch.Tensor | None = None, # vLLM provides positions - **kwargs: Any, - ) -> torch.Tensor: - """ - Forward pass for Multi-Head Latent Attention. 
- - Args: - hidden_states: Input tensor of shape [batch, seq_len, hidden_size] - or [total_tokens, hidden_size] - freqs_cis: Precomputed RoPE frequencies - [max_seq_len, qk_rope_head_dim//2] - attention_mask: Optional attention mask (not fully supported yet) - positions: Per-token positions for RoPE indexing (from vLLM) - **kwargs: Additional vLLM-specific kwargs - - Returns: - Output tensor of same shape as hidden_states - """ - # Handle both batched [batch, seq, hidden] and flattened [total_tokens, hidden] - input_is_batched = hidden_states.dim() == 3 - if input_is_batched: - bsz, seqlen, _ = hidden_states.shape - hidden_states_flat = hidden_states.view(-1, self.hidden_size) - else: - # Flattened format (vLLM inference) - hidden_states_flat = hidden_states - bsz = 1 - seqlen = hidden_states.shape[0] - - total_tokens = hidden_states_flat.shape[0] - - # Get positions for RoPE indexing - if positions is None: - # Try to get from vLLM forward context - try: - from vllm.forward_context import get_forward_context - - forward_ctx = get_forward_context() - if hasattr(forward_ctx, "_torchtitan_positions"): - positions = forward_ctx._torchtitan_positions - else: - # Fallback: sequential positions - positions = torch.arange( - total_tokens, device=hidden_states_flat.device - ) - except (ImportError, AttributeError, AssertionError): - # Training mode: sequential positions - positions = torch.arange(total_tokens, device=hidden_states_flat.device) - - # Query projection - if self.q_lora_rank == 0: - q = self.wq(hidden_states_flat) # [total_tokens, n_heads * qk_head_dim] - else: - q = self.wq_a(hidden_states_flat) # [total_tokens, q_lora_rank] - assert self.q_norm is not None # q_norm exists when q_lora_rank > 0 - q = self.wq_b(self.q_norm(q)) # [total_tokens, n_heads * qk_head_dim] - - # Reshape: [total_tokens, n_heads, qk_head_dim] - q = q.view(total_tokens, self.num_heads, self.qk_head_dim) - - # Split Q into non-RoPE and RoPE parts - q_nope, q_pe = torch.split( - q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 - ) - - # Apply RoPE to q_pe using positions to index freqs_cis - # Convert freqs_cis from complex to cos/sin BEFORE indexing - # to avoid dtype issues - if freqs_cis.is_complex(): - # Extract cos and sin from complex freqs_cis - freqs_cos = freqs_cis.real # [max_seq_len, qk_rope_head_dim//2] - freqs_sin = freqs_cis.imag # [max_seq_len, qk_rope_head_dim//2] - # Concatenate for easier indexing - freqs_real = torch.cat( - [freqs_cos, freqs_sin], dim=-1 - ) # [max_seq_len, qk_rope_head_dim] - else: - freqs_real = freqs_cis - - # Index by positions - positions_flat = positions.flatten()[:total_tokens] - max_pos = freqs_real.shape[0] - 1 - positions_clamped = torch.clamp(positions_flat, 0, max_pos) - freqs_for_tokens = freqs_real.index_select( - 0, positions_clamped - ) # [total_tokens, qk_rope_head_dim] - - # Split into cos and sin - half_dim = self.qk_rope_head_dim // 2 - cos_for_tokens = freqs_for_tokens[ - ..., :half_dim - ] # [total_tokens, qk_rope_head_dim//2] - sin_for_tokens = freqs_for_tokens[ - ..., half_dim: - ] # [total_tokens, qk_rope_head_dim//2] - - # Apply RoPE to q_pe: [total_tokens, num_heads, qk_rope_head_dim] - q_pe = self.apply_rotary_emb_with_cos_sin(q_pe, cos_for_tokens, sin_for_tokens) - - # Concatenate back: [total_tokens, n_heads, qk_head_dim] - q = torch.cat([q_nope, q_pe], dim=-1) - - # Key-value projection - kv = self.wkv_a( - hidden_states_flat - ) # [total_tokens, kv_lora_rank + qk_rope_head_dim] - - # Split into compressed KV and K_PE - kv_c, k_pe = 
torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - - # Apply RoPE to k_pe: [total_tokens, qk_rope_head_dim] - # Reshape to [total_tokens, 1, qk_rope_head_dim] for apply_rotary_emb - k_pe = k_pe.unsqueeze(1) # [total_tokens, 1, qk_rope_head_dim] - k_pe = self.apply_rotary_emb_with_cos_sin( - k_pe, cos_for_tokens, sin_for_tokens - ) # [total_tokens, 1, qk_rope_head_dim] - - # Normalize compressed KV - kv_c_normed = self.kv_norm(kv_c) # [total_tokens, kv_lora_rank] - - # Delegate to vLLM's MLAAttention if available (handles KV cache automatically) - if self.vllm_mla_attn is not None and not self.training: - # Let vLLM handle all KV cache logic - attn_output = self.vllm_mla_attn( - q, - kv_c_normed, - k_pe, - output_shape=(total_tokens, self.num_heads * self.v_head_dim), - ) - else: - # Training mode or fallback: manual implementation - # Decompress KV - kv = self.wkv_b( - kv_c_normed - ) # [total_tokens, n_heads * (qk_nope_head_dim + v_head_dim)] - kv = kv.view( - total_tokens, self.num_heads, self.qk_nope_head_dim + self.v_head_dim - ) - - # Split into K_nope and V - k_nope, v = torch.split( - kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 - ) - - # Concatenate K_nope with broadcasted K_PE - # k_pe shape: [total_tokens, 1, qk_rope_head_dim] - # Expand to: [total_tokens, n_heads, qk_rope_head_dim] - k = torch.cat( - [k_nope, k_pe.expand(-1, self.num_heads, -1)], dim=-1 - ) # [total_tokens, n_heads, qk_head_dim] - - # Reshape for batched attention: [bsz, seqlen, n_heads, head_dim] - q = q.view(bsz, seqlen, self.num_heads, self.qk_head_dim) - k = k.view(bsz, seqlen, self.num_heads, self.qk_head_dim) - v = v.view(bsz, seqlen, self.num_heads, self.v_head_dim) - - # Transpose for attention: [bsz, n_heads, seqlen, head_dim] - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - # Use PyTorch SDPA (supports different head dims for Q/K vs V) - # Flash attention doesn't support qk_head_dim != v_head_dim, so we use SDPA - attn_output = torch.nn.functional.scaled_dot_product_attention( - q, - k, - v, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=self.causal and attention_mask is None, - scale=self.scale, - ) - - # Transpose back and reshape: [total_tokens, n_heads * v_head_dim] - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape( - total_tokens, self.num_heads * self.v_head_dim - ) - - # Output projection: [total_tokens, hidden_size] - output = self.wo(attn_output) - - # Restore original shape if input was batched - if input_is_batched: - output = output.view(bsz, seqlen, self.hidden_size) - - return output diff --git a/torchtitan/experiments/vllm/model/attention.py b/torchtitan/experiments/vllm/model/attention.py index 72116d5052..0a3c048234 100644 --- a/torchtitan/experiments/vllm/model/attention.py +++ b/torchtitan/experiments/vllm/model/attention.py @@ -185,6 +185,23 @@ def forward( class VLLMForwardCustomBackward(torch.autograd.Function): @staticmethod def forward(ctx, q, k, v, scale, batch_size, seq_len, causal, fa_version): + # Flash Attention only supports fp16 and bf16 + # Store original dtype for conversion back + original_dtype = q.dtype + + # Convert to bf16 if not already fp16/bf16 + if original_dtype not in [torch.float16, torch.bfloat16]: + target_dtype = ( + torch.bfloat16 + if torch.cuda.is_bf16_supported() + else torch.float16 + ) + q = q.to(target_dtype) + k = k.to(target_dtype) + v = v.to(target_dtype) + else: + target_dtype = original_dtype + # Use 
flash_attn_varlen_func directly for fast forward pass # This is the SAME kernel vLLM uses internally! cu_seqlens_q = torch.arange( @@ -209,12 +226,17 @@ def forward(ctx, q, k, v, scale, batch_size, seq_len, causal, fa_version): fa_version=fa_version, ) + # Convert output back to original dtype if needed + if original_dtype not in [torch.float16, torch.bfloat16]: + output = output.to(original_dtype) + # Save for backward ctx.save_for_backward(q, k, v, output) ctx.scale = scale ctx.seq_len = seq_len ctx.batch_size = batch_size ctx.causal = causal + ctx.original_dtype = original_dtype return output @@ -225,6 +247,11 @@ def backward(ctx, grad_output): seq_len = ctx.seq_len batch_size = ctx.batch_size causal = ctx.causal + original_dtype = ctx.original_dtype + + # Convert grad_output to match saved tensor dtype + if grad_output.dtype != q.dtype: + grad_output = grad_output.to(q.dtype) # Reshape from varlen to batch format total_tokens = q.shape[0] @@ -293,6 +320,12 @@ def backward(ctx, grad_output): total_tokens, num_heads, head_dim ) + # Convert gradients back to original dtype if needed + if original_dtype not in [torch.float16, torch.bfloat16]: + grad_q = grad_q.to(original_dtype) + grad_k = grad_k.to(original_dtype) + grad_v = grad_v.to(original_dtype) + return grad_q, grad_k, grad_v, None, None, None, None, None # Get flash attention version diff --git a/torchtitan/experiments/vllm/tests/test_fwd_bwd_with_vllm_attention.py b/torchtitan/experiments/vllm/tests/test_fwd_bwd_with_vllm_attention.py new file mode 100644 index 0000000000..3669b54280 --- /dev/null +++ b/torchtitan/experiments/vllm/tests/test_fwd_bwd_with_vllm_attention.py @@ -0,0 +1,537 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Test script to verify that patching TorchTitan Qwen3 model with VLLMCompatibleFlashAttention +still allows it to run with TorchTitan's training loop. + +This tests: +1. Model creation with patched attention +2. Forward pass with dummy data +3. Backward pass and gradient computation +4. Training step execution +5. Compatibility with TorchTitan's model protocol +""" + +import sys +from pathlib import Path + +import torch +import torch.nn as nn + +# Add TorchTitan to path +torchtitan_root = Path(__file__).parent.parent.parent.parent.parent +sys.path.insert(0, str(torchtitan_root)) + +from torchtitan.experiments.vllm.model.attention import VLLMCompatibleFlashAttention +from torchtitan.models.qwen3.model.args import Qwen3ModelArgs +from torchtitan.models.qwen3.model.model import Qwen3Model + + +def print_section(title): + """Print a formatted section header.""" + print("\n" + "=" * 70) + print(f" {title}") + print("=" * 70) + + +def patch_qwen3_attention(model: Qwen3Model, model_args: Qwen3ModelArgs) -> int: + """ + Patch all attention modules in Qwen3Model with VLLMCompatibleFlashAttention. 
+ + Args: + model: Qwen3Model instance + model_args: Model configuration + + Returns: + Number of attention modules patched + """ + num_patched = 0 + + for layer_name, layer in model.layers.items(): + # Replace inner_attention with VLLMCompatibleFlashAttention + layer.attention.inner_attention = VLLMCompatibleFlashAttention( + hidden_size=model_args.dim, + num_heads=model_args.n_heads, + num_kv_heads=model_args.n_kv_heads, + head_dim=model_args.head_dim, + causal=True, + ) + num_patched += 1 + + return num_patched + + +def test_model_creation(): + """Test 1: Create Qwen3 model and patch with VLLMCompatibleFlashAttention.""" + print_section("Test 1: Model Creation and Patching") + + try: + # Create small test model + model_args = Qwen3ModelArgs( + dim=512, + n_layers=4, + n_heads=8, + n_kv_heads=4, # GQA + vocab_size=1000, + max_seq_len=512, + rope_theta=1000000.0, + hidden_dim=1024, + norm_eps=1e-6, + qk_norm=True, + ) + + print("Creating Qwen3Model with config:") + print(f" dim: {model_args.dim}") + print(f" n_layers: {model_args.n_layers}") + print(f" n_heads: {model_args.n_heads}") + print(f" n_kv_heads: {model_args.n_kv_heads}") + print(f" vocab_size: {model_args.vocab_size}") + + model = Qwen3Model(model_args) + print("✅ Model created successfully") + + # Patch attention modules + print("\nPatching attention modules...") + num_patched = patch_qwen3_attention(model, model_args) + print( + f"✅ Patched {num_patched} attention modules with VLLMCompatibleFlashAttention" + ) + + # Verify patch + first_layer = model.layers["0"] + assert isinstance( + first_layer.attention.inner_attention, VLLMCompatibleFlashAttention + ), "Attention module not patched correctly" + print("✅ Verified attention module type") + + return model, model_args + + except Exception as e: + print(f"❌ Test 1 failed: {e}") + import traceback + + traceback.print_exc() + return None, None + + +def test_forward_pass(model: Qwen3Model, model_args: Qwen3ModelArgs): + """Test 2: Run forward pass with dummy data.""" + print_section("Test 2: Forward Pass") + + if model is None: + print("⚠️ Skipping (model creation failed)") + return None + + try: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + model.eval() + + print(f"Using device: {device}") + + # Create dummy input + batch_size = 2 + seq_len = 64 + tokens = torch.randint( + 0, model_args.vocab_size, (batch_size, seq_len), device=device + ) + + print(f"\nInput shape: {tokens.shape}") + + # Forward pass + with torch.no_grad(): + output = model(tokens) + + print(f"Output shape: {output.shape}") + print(f"Output dtype: {output.dtype}") + + # Verify output shape + expected_shape = (batch_size, seq_len, model_args.vocab_size) + assert ( + output.shape == expected_shape + ), f"Expected {expected_shape}, got {output.shape}" + + print("\nOutput statistics:") + print(f" Mean: {output.mean().item():.6f}") + print(f" Std: {output.std().item():.6f}") + print(f" Min: {output.min().item():.6f}") + print(f" Max: {output.max().item():.6f}") + + print("\n✅ Forward pass successful") + return output + + except Exception as e: + print(f"❌ Test 2 failed: {e}") + import traceback + + traceback.print_exc() + return None + + +def test_backward_pass(model: Qwen3Model, model_args: Qwen3ModelArgs): + """Test 3: Run backward pass and verify gradients.""" + print_section("Test 3: Backward Pass and Gradient Computation") + + if model is None: + print("⚠️ Skipping (model creation failed)") + return False + + try: + device = torch.device("cuda" if 
torch.cuda.is_available() else "cpu") + model = model.to(device) + model.train() # Enable training mode + + # Create dummy input and target + batch_size = 2 + seq_len = 64 + tokens = torch.randint( + 0, model_args.vocab_size, (batch_size, seq_len), device=device + ) + targets = torch.randint( + 0, model_args.vocab_size, (batch_size, seq_len), device=device + ) + + print(f"Input shape: {tokens.shape}") + print(f"Target shape: {targets.shape}") + + # Forward pass + output = model(tokens) + + # Compute loss + loss_fn = nn.CrossEntropyLoss() + loss = loss_fn(output.view(-1, model_args.vocab_size), targets.view(-1)) + + print(f"\nLoss: {loss.item():.6f}") + + # Backward pass + print("\nRunning backward pass...") + loss.backward() + + # Check gradients + grad_count = 0 + grad_norms = {} + + for name, param in model.named_parameters(): + if param.grad is not None: + grad_count += 1 + grad_norm = param.grad.norm().item() + if "attention" in name: + grad_norms[name] = grad_norm + + print(f"✅ Gradients computed for {grad_count} parameters") + + # Show some attention gradients + if grad_norms: + print("\nSample attention gradient norms:") + for name, norm in list(grad_norms.items())[:5]: + print(f" {name}: {norm:.6f}") + + # Verify gradients are non-zero + assert grad_count > 0, "No gradients computed" + + print("\n✅ Backward pass successful") + return True + + except Exception as e: + print(f"❌ Test 3 failed: {e}") + import traceback + + traceback.print_exc() + return False + + +def test_training_step(model: Qwen3Model, model_args: Qwen3ModelArgs): + """Test 4: Run a full training step with optimizer.""" + print_section("Test 4: Training Step with Optimizer") + + if model is None: + print("⚠️ Skipping (model creation failed)") + return False + + try: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + model.train() + + # Create optimizer + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + loss_fn = nn.CrossEntropyLoss() + + print(f"Optimizer: {type(optimizer).__name__}") + print(f"Learning rate: {optimizer.param_groups[0]['lr']}") + + # Run multiple training steps + num_steps = 3 + losses = [] + + for step in range(num_steps): + # Create dummy data + batch_size = 2 + seq_len = 64 + tokens = torch.randint( + 0, model_args.vocab_size, (batch_size, seq_len), device=device + ) + targets = torch.randint( + 0, model_args.vocab_size, (batch_size, seq_len), device=device + ) + + # Zero gradients + optimizer.zero_grad() + + # Forward pass + output = model(tokens) + loss = loss_fn(output.view(-1, model_args.vocab_size), targets.view(-1)) + + # Backward pass + loss.backward() + + # Optimizer step + optimizer.step() + + losses.append(loss.item()) + print(f" Step {step + 1}/{num_steps}: loss = {loss.item():.6f}") + + print(f"\n✅ Completed {num_steps} training steps") + print(f"Loss values: {losses}") + + return True + + except Exception as e: + print(f"❌ Test 4 failed: {e}") + import traceback + + traceback.print_exc() + return False + + +def test_attention_shapes(model: Qwen3Model, model_args: Qwen3ModelArgs): + """Test 5: Verify attention input/output shapes in detail.""" + print_section("Test 5: Attention Shape Verification") + + if model is None: + print("⚠️ Skipping (model creation failed)") + return False + + try: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + model.eval() + + # Hook to capture attention inputs/outputs + attention_info = {} + + def attention_hook(module, input_args, output): + 
"""Capture attention layer inputs and outputs.""" + # input_args is a tuple: (x, rope_cache, attention_masks) + x = input_args[0] + attention_info["input_shape"] = x.shape + attention_info["output_shape"] = output.shape + + # Register hook on first layer's attention + first_layer = model.layers["0"] + hook = first_layer.attention.register_forward_hook(attention_hook) + + # Run forward pass + batch_size = 2 + seq_len = 64 + tokens = torch.randint( + 0, model_args.vocab_size, (batch_size, seq_len), device=device + ) + + with torch.no_grad(): + _ = model(tokens) + + # Remove hook + hook.remove() + + # Verify shapes + print(f"Attention input shape: {attention_info['input_shape']}") + print(f"Attention output shape: {attention_info['output_shape']}") + + expected_input = (batch_size, seq_len, model_args.dim) + expected_output = (batch_size, seq_len, model_args.dim) + + assert ( + attention_info["input_shape"] == expected_input + ), f"Expected input {expected_input}, got {attention_info['input_shape']}" + assert ( + attention_info["output_shape"] == expected_output + ), f"Expected output {expected_output}, got {attention_info['output_shape']}" + + print("\n✅ Attention shapes verified") + return True + + except Exception as e: + print(f"❌ Test 5 failed: {e}") + import traceback + + traceback.print_exc() + return False + + +def test_gqa_correctness(model_args: Qwen3ModelArgs): + """Test 6: Verify GQA expansion works correctly.""" + print_section("Test 6: GQA (Grouped Query Attention) Verification") + + try: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Create attention module directly + attn = VLLMCompatibleFlashAttention( + hidden_size=model_args.dim, + num_heads=model_args.n_heads, + num_kv_heads=model_args.n_kv_heads, + head_dim=model_args.head_dim, + causal=True, + ).to(device) + + print("Attention config:") + print(f" num_heads (Q): {model_args.n_heads}") + print(f" num_kv_heads (K/V): {model_args.n_kv_heads}") + print(f" head_dim: {model_args.head_dim}") + print(f" n_rep: {model_args.n_heads // model_args.n_kv_heads}") + + batch_size = 2 + seq_len = 32 + + # Test with unexpanded K/V (num_kv_heads) + print(f"\nTest 6a: K/V with num_kv_heads ({model_args.n_kv_heads})") + q = torch.randn( + batch_size, + model_args.n_heads, + seq_len, + model_args.head_dim, + device=device, + dtype=torch.float32, + ) + k_small = torch.randn( + batch_size, + model_args.n_kv_heads, + seq_len, + model_args.head_dim, + device=device, + dtype=torch.float32, + ) + v_small = torch.randn( + batch_size, + model_args.n_kv_heads, + seq_len, + model_args.head_dim, + device=device, + dtype=torch.float32, + ) + + with torch.no_grad(): + output_small = attn(q, k_small, v_small) + + print(f" Output shape: {output_small.shape}") + assert ( + output_small.shape == q.shape + ), f"Shape mismatch: {output_small.shape} vs {q.shape}" + print(" ✅ Unexpanded K/V works") + + # Test with expanded K/V (num_heads) + print(f"\nTest 6b: K/V already expanded to num_heads ({model_args.n_heads})") + k_large = k_small.repeat_interleave( + model_args.n_heads // model_args.n_kv_heads, dim=1 + ) + v_large = v_small.repeat_interleave( + model_args.n_heads // model_args.n_kv_heads, dim=1 + ) + + print(f" k_large shape: {k_large.shape}") + print(f" v_large shape: {v_large.shape}") + + with torch.no_grad(): + output_large = attn(q, k_large, v_large) + + print(f" Output shape: {output_large.shape}") + assert ( + output_large.shape == q.shape + ), f"Shape mismatch: {output_large.shape} vs {q.shape}" + print(" ✅ Expanded 
K/V works") + + print("\n✅ GQA verification successful") + return True + + except Exception as e: + print(f"❌ Test 6 failed: {e}") + import traceback + + traceback.print_exc() + return False + + +def main(): + """Run all tests.""" + print("=" * 70) + print(" TorchTitan + VLLMCompatibleFlashAttention Integration Test") + print("=" * 70) + + # Set random seed for reproducibility + torch.manual_seed(42) + if torch.cuda.is_available(): + torch.cuda.manual_seed(42) + + results = {} + + # Test 1: Model creation and patching + model, model_args = test_model_creation() + results["model_creation"] = model is not None + + if model is None: + print("\n❌ Cannot continue - model creation failed") + return 1 + + # Test 2: Forward pass + output = test_forward_pass(model, model_args) + results["forward_pass"] = output is not None + + # Test 3: Backward pass + results["backward_pass"] = test_backward_pass(model, model_args) + + # Test 4: Training step + results["training_step"] = test_training_step(model, model_args) + + # Test 5: Attention shapes + results["attention_shapes"] = test_attention_shapes(model, model_args) + + # Test 6: GQA verification + results["gqa_verification"] = test_gqa_correctness(model_args) + + # Summary + print_section("FINAL SUMMARY") + + print("\nTest Results:") + for test_name, passed in results.items(): + status = "✅ PASS" if passed else "❌ FAIL" + print(f" {test_name:20s}: {status}") + + all_passed = all(results.values()) + + if all_passed: + print("\n🎉 ALL TESTS PASSED!") + print("\n✅ VLLMCompatibleFlashAttention is compatible with TorchTitan!") + print("✅ Model can be trained with patched attention modules.") + print( + "\nYou can safely use this attention implementation in TorchTitan training." + ) + return 0 + else: + failed_tests = [name for name, passed in results.items() if not passed] + print(f"\n❌ {len(failed_tests)} TEST(S) FAILED:") + for test in failed_tests: + print(f" - {test}") + print( + "\nPlease fix the issues before using VLLMCompatibleFlashAttention in production." + ) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/torchtitan/experiments/vllm/tests/test_vllm_attention.py b/torchtitan/experiments/vllm/tests/test_vllm_attention.py new file mode 100644 index 0000000000..aad6229b76 --- /dev/null +++ b/torchtitan/experiments/vllm/tests/test_vllm_attention.py @@ -0,0 +1,360 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Simple test script for vLLM's Attention layer. +Tests the high-level Attention module with KV cache management. 
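These tests can be run directly as a script (assuming vLLM and TorchTitan are importable in the environment): `python torchtitan/experiments/vllm/tests/test_vllm_attention.py`.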
+""" + +import torch +from vllm.attention import Attention +from vllm.config import CacheConfig + + +def test_vllm_attention_basic(): + """Test vLLM Attention layer with basic inputs (no KV cache).""" + print("=" * 70) + print("Test 1: Basic Attention (No KV Cache)") + print("=" * 70) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # Model parameters + num_heads = 16 + num_kv_heads = 8 # GQA + head_dim = 128 + hidden_size = num_heads * head_dim + + # Create minimal vLLM config + # Note: In production, this comes from get_current_vllm_config() + cache_config = CacheConfig( + block_size=16, + gpu_memory_utilization=0.9, + swap_space=4, + cache_dtype="auto", + ) + + # Create Attention layer + print("\nCreating Attention layer:") + print(f" num_heads: {num_heads}") + print(f" num_kv_heads: {num_kv_heads}") + print(f" head_dim: {head_dim}") + print(f" scale: {1.0 / (head_dim ** 0.5)}") + + attn = Attention( + num_heads=num_heads, + head_size=head_dim, + scale=1.0 / (head_dim**0.5), + num_kv_heads=num_kv_heads, + cache_config=cache_config, + quant_config=None, + prefix="layers.0", + ) + attn = attn.to(device) + attn.eval() + + # Create dummy inputs + batch_size = 2 + seq_len = 128 + total_tokens = batch_size * seq_len + + # Format: [total_tokens, num_heads, head_dim] + q = torch.randn( + total_tokens, num_heads, head_dim, dtype=torch.float16, device=device + ) + k = torch.randn( + total_tokens, num_kv_heads, head_dim, dtype=torch.float16, device=device + ) + v = torch.randn( + total_tokens, num_kv_heads, head_dim, dtype=torch.float16, device=device + ) + + print("\nInput shapes:") + print(f" q: {q.shape}") + print(f" k: {k.shape}") + print(f" v: {v.shape}") + + # Forward pass + try: + with torch.no_grad(): + output = attn(q, k, v) + + print(f"\nOutput shape: {output.shape}") + print(f"Output dtype: {output.dtype}") + print(f"Output device: {output.device}") + print("\nOutput statistics:") + print(f" Mean: {output.mean().item():.6f}") + print(f" Std: {output.std().item():.6f}") + print(f" Min: {output.min().item():.6f}") + print(f" Max: {output.max().item():.6f}") + + print("\n✅ Test 1 passed!") + return True + + except Exception as e: + print(f"\n❌ Test 1 failed with error: {e}") + import traceback + + traceback.print_exc() + return False + + +def test_vllm_attention_gqa_expansion(): + """Test that GQA expansion works correctly.""" + print("\n" + "=" * 70) + print("Test 2: GQA Expansion Test") + print("=" * 70) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + num_heads = 16 + num_kv_heads = 8 + head_dim = 128 + + cache_config = CacheConfig( + block_size=16, + gpu_memory_utilization=0.9, + swap_space=4, + cache_dtype="auto", + ) + + attn = Attention( + num_heads=num_heads, + head_size=head_dim, + scale=1.0 / (head_dim**0.5), + num_kv_heads=num_kv_heads, + cache_config=cache_config, + quant_config=None, + prefix="layers.1", + ) + attn = attn.to(device) + attn.eval() + + # Test with both unexpanded and expanded k/v + total_tokens = 64 + + q = torch.randn( + total_tokens, num_heads, head_dim, device=device, dtype=torch.float16 + ) + + print(f"\nTest 2a: K/V with num_kv_heads ({num_kv_heads})") + k_small = torch.randn( + total_tokens, num_kv_heads, head_dim, device=device, dtype=torch.float16 + ) + v_small = torch.randn( + total_tokens, num_kv_heads, head_dim, device=device, dtype=torch.float16 + ) + + try: + with torch.no_grad(): + output_small = attn(q, k_small, v_small) + print(f" Output shape: 
{output_small.shape}") + print(" ✅ GQA with num_kv_heads works!") + except Exception as e: + print(f" ❌ Failed: {e}") + return False + + print(f"\nTest 2b: K/V already expanded to num_heads ({num_heads})") + # Simulate what TorchTitan does - expand k/v to num_heads + k_large = k_small.repeat_interleave(num_heads // num_kv_heads, dim=1) + v_large = v_small.repeat_interleave(num_heads // num_kv_heads, dim=1) + print(f" k_large shape: {k_large.shape}") + print(f" v_large shape: {v_large.shape}") + + try: + with torch.no_grad(): + output_large = attn(q, k_large, v_large) + print(f" Output shape: {output_large.shape}") + print(" ✅ Already-expanded K/V works!") + except Exception as e: + print(f" ❌ Failed: {e}") + return False + + print("\n✅ Test 2 passed!") + return True + + +def test_vllm_attention_shapes(): + """Test various input shapes.""" + print("\n" + "=" * 70) + print("Test 3: Various Input Shapes") + print("=" * 70) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + num_heads = 8 + num_kv_heads = 8 # MHA + head_dim = 64 + + cache_config = CacheConfig( + block_size=16, + gpu_memory_utilization=0.9, + swap_space=4, + cache_dtype="auto", + ) + + attn = Attention( + num_heads=num_heads, + head_size=head_dim, + scale=1.0 / (head_dim**0.5), + num_kv_heads=num_kv_heads, + cache_config=cache_config, + quant_config=None, + prefix="layers.2", + ) + attn = attn.to(device) + attn.eval() + + test_cases = [ + (1, "Single token"), + (32, "Small batch"), + (256, "Medium batch"), + (1024, "Large batch"), + ] + + for total_tokens, description in test_cases: + print( + f"\nTest 3.{test_cases.index((total_tokens, description)) + 1}: {description} ({total_tokens} tokens)" + ) + q = torch.randn( + total_tokens, num_heads, head_dim, device=device, dtype=torch.float16 + ) + k = torch.randn( + total_tokens, num_kv_heads, head_dim, device=device, dtype=torch.float16 + ) + v = torch.randn( + total_tokens, num_kv_heads, head_dim, device=device, dtype=torch.float16 + ) + + try: + with torch.no_grad(): + output = attn(q, k, v) + assert ( + output.shape[0] == total_tokens + ), f"Expected {total_tokens} tokens, got {output.shape[0]}" + print(f" ✅ Shape: {output.shape}") + except Exception as e: + print(f" ❌ Failed: {e}") + return False + + print("\n✅ Test 3 passed!") + return True + + +def test_integration_with_torchtitan_format(): + """Test integration with TorchTitan's tensor format.""" + print("\n" + "=" * 70) + print("Test 4: TorchTitan Format Integration") + print("=" * 70) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + num_heads = 16 + num_kv_heads = 8 + head_dim = 128 + + cache_config = CacheConfig( + block_size=16, + gpu_memory_utilization=0.9, + swap_space=4, + cache_dtype="auto", + ) + + attn = Attention( + num_heads=num_heads, + head_size=head_dim, + scale=1.0 / (head_dim**0.5), + num_kv_heads=num_kv_heads, + cache_config=cache_config, + quant_config=None, + prefix="layers.3", + ) + attn = attn.to(device) + attn.eval() + + # Simulate TorchTitan format: [batch, num_heads, seq_len, head_dim] + batch_size = 2 + seq_len = 64 + + print( + f"\nTorchTitan input format: [batch={batch_size}, num_heads={num_heads}, seq_len={seq_len}, head_dim={head_dim}]" + ) + + q_tt = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=torch.float16 + ) + k_tt = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=torch.float16 + ) + v_tt = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, 
dtype=torch.float16 + ) + + print(f" q_tt: {q_tt.shape}") + print(f" k_tt: {k_tt.shape}") + print(f" v_tt: {v_tt.shape}") + + # Convert to vLLM format: [total_tokens, num_heads, head_dim] + total_tokens = batch_size * seq_len + q_vllm = q_tt.transpose(1, 2).reshape(total_tokens, num_heads, head_dim) + k_vllm = k_tt.transpose(1, 2).reshape(total_tokens, num_heads, head_dim) + v_vllm = v_tt.transpose(1, 2).reshape(total_tokens, num_heads, head_dim) + + print( + f"\nvLLM input format: [total_tokens={total_tokens}, num_heads={num_heads}, head_dim={head_dim}]" + ) + print(f" q_vllm: {q_vllm.shape}") + print(f" k_vllm: {k_vllm.shape}") + print(f" v_vllm: {v_vllm.shape}") + + try: + with torch.no_grad(): + output_vllm = attn(q_vllm, k_vllm, v_vllm) + + print(f"\nvLLM output: {output_vllm.shape}") + + # Convert back to TorchTitan format + output_tt = output_vllm.reshape( + batch_size, seq_len, num_heads, head_dim + ).transpose(1, 2) + print(f"TorchTitan output: {output_tt.shape}") + + assert ( + output_tt.shape == q_tt.shape + ), f"Output shape mismatch: {output_tt.shape} vs {q_tt.shape}" + print("\n✅ Test 4 passed!") + return True + + except Exception as e: + print(f"\n❌ Test 4 failed: {e}") + import traceback + + traceback.print_exc() + return False + + +if __name__ == "__main__": + print("=" * 70) + print("Testing vLLM Attention Layer") + print("=" * 70) + + all_passed = True + + # Run all tests + all_passed &= test_vllm_attention_basic() + all_passed &= test_vllm_attention_gqa_expansion() + all_passed &= test_vllm_attention_shapes() + all_passed &= test_integration_with_torchtitan_format() + + # Summary + print("\n" + "=" * 70) + if all_passed: + print("🎉 All tests passed successfully!") + else: + print("❌ Some tests failed!") + print("=" * 70) From a1640103bf835cbcb7bbe401519212557f5df6e3 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Wed, 3 Dec 2025 10:29:10 -0800 Subject: [PATCH 07/19] add TP v1 --- .../experiments/vllm/model/attention.py | 77 ++++++-- torchtitan/experiments/vllm/model/qwen3.py | 186 +++++++----------- torchtitan/models/qwen3/infra/parallelize.py | 17 +- torchtitan/models/qwen3/model/model.py | 10 +- 4 files changed, 154 insertions(+), 136 deletions(-) diff --git a/torchtitan/experiments/vllm/model/attention.py b/torchtitan/experiments/vllm/model/attention.py index 0a3c048234..bde41fd38d 100644 --- a/torchtitan/experiments/vllm/model/attention.py +++ b/torchtitan/experiments/vllm/model/attention.py @@ -19,7 +19,7 @@ from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant -class VLLMCompatibleFlashAttention(torch.nn.Module): +class VLLMPagedFlashAttention(torch.nn.Module): """ Wrapper around vLLM's Attention with custom backward pass. 
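The hunk that follows makes the renamed wrapper tensor-parallel aware: when the KV-head count does not divide evenly by the TP degree, both head counts are padded before being split across ranks. A minimal sketch of that arithmetic, using assumed example values (28 query heads, 4 KV heads, TP=8) rather than numbers taken from the patch:

```python
# Illustrative head-padding arithmetic; the values below are assumptions for the example.
num_heads, num_kv_heads, tp_size = 28, 4, 8

if num_kv_heads % tp_size != 0:
    assert num_heads % num_kv_heads == 0
    pad = tp_size - num_kv_heads % tp_size                           # 8 - 4 = 4
    padded_num_kv_heads = num_kv_heads + pad                         # 8
    padded_num_heads = num_heads + pad * num_heads // num_kv_heads   # 28 + 28 = 56
    num_heads_per_rank = padded_num_heads // tp_size                 # 7
    num_kv_heads_per_rank = padded_num_kv_heads // tp_size           # 1
else:
    num_heads_per_rank = num_heads // tp_size
    num_kv_heads_per_rank = num_kv_heads // tp_size

print(num_heads_per_rank, num_kv_heads_per_rank)  # 7 1
```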
@@ -42,6 +42,44 @@ def __init__( super().__init__() self.hidden_size = hidden_size + + # Handle tensor parallelism: adjust num_heads and num_kv_heads for TP + # NOTE(jianiw): As we use local tensor for this region, we need to manually + try: + from vllm.config import get_current_vllm_config + from vllm.logger import init_logger + + logger = init_logger(__name__) + vllm_config = get_current_vllm_config() + tp_size = vllm_config.parallel_config.tensor_parallel_size + + if tp_size > 1: + if num_kv_heads % tp_size != 0: + # Pad num_kv_heads and num_heads to be divisible by tp_size + assert num_heads % num_kv_heads == 0 + padded_size = tp_size - num_kv_heads % tp_size + padded_num_kv_heads = num_kv_heads + padded_size + padded_num_heads = ( + num_heads + padded_size * num_heads // num_kv_heads + ) + assert padded_num_heads % tp_size == 0 + assert padded_num_kv_heads % tp_size == 0 + + logger.info( + f"Padding attention heads for tensor parallelism: " + f"{num_heads=}, {padded_num_heads=}, " + f"{num_kv_heads=}, {padded_num_kv_heads=}" + ) + + num_heads = padded_num_heads // tp_size + num_kv_heads = padded_num_kv_heads // tp_size + else: + num_heads //= tp_size + num_kv_heads //= tp_size + except (ImportError, RuntimeError, AttributeError): + # Not in vLLM context - use original values + pass + self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_dim = head_dim @@ -63,7 +101,7 @@ def __init__( # Generate unique prefix for this attention layer # vLLM expects format "layers.X" for layer index extraction - layer_idx = next(VLLMCompatibleFlashAttention._layer_counter) + layer_idx = next(VLLMPagedFlashAttention._layer_counter) prefix = f"layers.{layer_idx}" self.vllm_attn = Attention( @@ -75,6 +113,7 @@ def __init__( quant_config=None, prefix=prefix, ) + except (ImportError, RuntimeError, AttributeError): # Not in vLLM context - will need to set up manually self.vllm_attn = None @@ -103,7 +142,7 @@ def _auto_register_for_kv_cache(self): # Generate unique layer name using class counter # Format: "layers.{index}" for compatibility with extract_layer_index() - layer_name = f"layers.{next(VLLMCompatibleFlashAttention._layer_counter)}" + layer_name = f"layers.{next(VLLMPagedFlashAttention._layer_counter)}" # Register this layer in static forward context if layer_name in compilation_config.static_forward_context: @@ -149,25 +188,29 @@ def forward( k = k.transpose(1, 2) v = v.transpose(1, 2) - # Flatten to (total_tokens, num_heads, head_dim) - q_varlen = q.reshape(-1, num_heads, head_dim) - k_varlen = k.reshape(-1, k.shape[2], head_dim) - v_varlen = v.reshape(-1, v.shape[2], head_dim) + # # Flatten to (total_tokens, num_heads, head_dim) + # NOTE(jianiw): vllm_attention can also take input as shape (batch, seq_len, num_heads, head_dim) and do internally + + # q_varlen = q.reshape(-1, num_heads, head_dim) + # k_varlen = k.reshape(-1, k.shape[2], head_dim) # k.shape[2] = num_kv_head + # v_varlen = v.reshape(-1, v.shape[2], head_dim) try: # Use vLLM's Attention layer (requires forward context) - output_varlen = self.vllm_attn(q_varlen, k_varlen, v_varlen) + output_varlen = self.vllm_attn(q, k, v) + print(f"[jianiw] vllm_attn output is: {output_varlen}") # Reshape back to batch format - output = output_varlen.reshape(batch_size, seq_len, num_heads, head_dim) + output = output_varlen.view(batch_size, seq_len, num_heads, head_dim) # Transpose back to TorchTitan format output = output.transpose(1, 2) return output - except (AssertionError, RuntimeError): + except (AssertionError, RuntimeError) as 
e: # Forward context not available, fall through to training mode - pass + print(f"Error when calling self.vllm_attn during Inference, {str(e)}") + raise # TRAINING MODE: Use flash_attn_varlen_func with custom backward # Transpose to (batch, seq_len, num_heads, head_dim) for vLLM @@ -175,11 +218,16 @@ def forward( k = k.transpose(1, 2) v = v.transpose(1, 2) + # After to_local(), use actual tensor shapes (TP may have sharded heads) + # Shape: (batch, seq_len, num_heads_local, head_dim) + _, _, num_heads_local, _ = q.shape + _, _, num_kv_heads_local, _ = k.shape + # Convert to varlen format for vLLM: flatten batch and sequence # (batch, seq_len, num_heads, head_dim) -> (total_tokens, num_heads, head_dim) - q_varlen = q.reshape(-1, num_heads, head_dim) - k_varlen = k.reshape(-1, k.shape[2], head_dim) - v_varlen = v.reshape(-1, v.shape[2], head_dim) + q_varlen = q.reshape(-1, num_heads_local, head_dim) + k_varlen = k.reshape(-1, num_kv_heads_local, head_dim) + v_varlen = v.reshape(-1, num_kv_heads_local, head_dim) # Use custom autograd function with flash_attn_varlen_func forward and manual backward class VLLMForwardCustomBackward(torch.autograd.Function): @@ -204,6 +252,7 @@ def forward(ctx, q, k, v, scale, batch_size, seq_len, causal, fa_version): # Use flash_attn_varlen_func directly for fast forward pass # This is the SAME kernel vLLM uses internally! + # TODO(jianiw): Need to double-check cu_seqlens_q = torch.arange( 0, (batch_size + 1) * seq_len, diff --git a/torchtitan/experiments/vllm/model/qwen3.py b/torchtitan/experiments/vllm/model/qwen3.py index 4e182ff47c..97256c4e74 100644 --- a/torchtitan/experiments/vllm/model/qwen3.py +++ b/torchtitan/experiments/vllm/model/qwen3.py @@ -24,6 +24,7 @@ """ import torch +from torch.distributed.tensor import DTensor # Import from local custom_models directory from torchtitan.experiments.vllm.custom_models import ( @@ -31,7 +32,7 @@ store_positions_in_context, VLLMModelForCausalLM, ) -from torchtitan.experiments.vllm.model.attention import VLLMCompatibleFlashAttention +from torchtitan.experiments.vllm.model.attention import VLLMPagedFlashAttention from vllm.config import VllmConfig from vllm.logger import init_logger @@ -73,9 +74,9 @@ def __init__( from torchtitan.models.qwen3.model.model import Qwen3Model # Map HuggingFace config to TorchTitan ModelArgs - logger.info("vllm config: ", vllm_config.__class__) + logger.info("vllm config: " + str(vllm_config.__class__)) hf_config = vllm_config.model_config.hf_config - logger.info("hf_config: ", hf_config) + logger.info("hf_config: " + str(hf_config)) model_args = Qwen3ModelArgs( vocab_size=getattr(hf_config, "vocab_size", 151936), dim=getattr(hf_config, "hidden_size", 2048), @@ -90,14 +91,14 @@ def __init__( qk_norm=getattr(hf_config, "qk_norm", True), ) + print(f"In attention initialization, model args are : {model_args}") + # Create TorchTitan model self.model = Qwen3Model(model_args) self.config = model_args self.parallel_context = parallel_context - # Replace inner_attention with vLLM compatible Flash Attention - # NOTE: We replace `inner_attention` (the attention kernel), NOT the whole `Attention` module - # The `Attention` module handles QKV projection, RoPE, etc., and calls `inner_attention` + # The `vllm.Attention` module handles QKV projection, RoPE, etc., and calls `inner_attention` if not hasattr(self.model, "layers"): raise AttributeError( f"Model {type(self.model).__name__} must have .layers attribute" @@ -107,19 +108,17 @@ def __init__( if not hasattr(layer, "attention"): raise 
ValueError(f"Layer {layer_name} must have .attention attribute") - if not hasattr(layer.attention, "inner_attention"): - raise ValueError( - f"Layer {layer_name}.attention must have .inner_attention attribute" - ) - - layer.attention.inner_attention = VLLMCompatibleFlashAttention( + vllm_attn = VLLMPagedFlashAttention( hidden_size=model_args.dim, - num_heads=model_args.n_heads, - num_kv_heads=model_args.n_kv_heads, + num_heads=model_args.n_heads, # 16 (8 when TP =2) + # NOTE(jianiw): Before feeding into inner_attention, the n_kv_heads has been replicated -> num_heads + num_kv_heads=model_args.n_heads, # 16 (8 When TP=2) head_dim=model_args.head_dim, causal=True, ) + layer.attention.inner_attention = vllm_attn + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: """Convert input token IDs to embeddings.""" return self.model.tok_embeddings(input_ids) @@ -177,9 +176,6 @@ def forward( for layer in self.model.layers.values(): h = layer(h, rope_cache, attention_masks=None) - # Final norm - h = self.model.norm(h) # [1, total_tokens, hidden_size] - # Convert output format back to vLLM expectations # vLLM expects hidden_states in [total_tokens, hidden_size] format # TorchTitan returns [batch_size, seq_len, hidden_size], so we need to flatten @@ -187,6 +183,9 @@ def forward( batch_size, seq_len, hidden_size = h.shape h = h.view(batch_size * seq_len, hidden_size) # [total_tokens, hidden_size] + # NOTE(jianiw): explicitly insert communication and return full tensor to vLLM Engine + if isinstance(h, DTensor): + h = h.full_tensor() return h def compute_logits( @@ -195,14 +194,20 @@ def compute_logits( sampling_metadata=None, ) -> torch.Tensor: """Compute logits from hidden states.""" + # hidden states is the output from + # NOTE(jianiw): When TP is enabled, self.model.norm isself.model.output is ColumnParallel() annotated + + # Current hidden_states is replicate() + + h = self.model.norm(hidden_states) # [1, total_tokens, hidden_size] + output = self.model.output(h) + return self.model.output(hidden_states) def load_weights(self, weights_iter): """ - Load weights from HuggingFace checkpoint. - - Maps HF Qwen weight names → TorchTitan naming convention. - This uses the same mapping as TorchTitan's Qwen3StateDictAdapter. + Uses TorchTitan's Qwen3StateDictAdapter to map HF → TorchTitan naming, + then uses set_model_state_dict for proper distributed tensor handling. 
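In outline, the rewritten loader collects the HuggingFace tensors into a dict, converts the names with TorchTitan's adapter, and hands the result to `set_model_state_dict`. A condensed sketch of that flow (single-rank case; the DTensor/Replicate handling in the full hunk below is omitted):

```python
# Condensed sketch of the loading flow implemented in the hunk below.
from torch.distributed.checkpoint.state_dict import (
    set_model_state_dict,
    StateDictOptions,
)
from torchtitan.models.qwen3.model.state_dict_adapter import Qwen3StateDictAdapter

hf_state_dict = dict(weights_iter)                       # HF names -> tensors
adapter = Qwen3StateDictAdapter(model_args=self.config, hf_assets_path=None)
torchtitan_state_dict = adapter.from_hf(hf_state_dict)   # TorchTitan names
set_model_state_dict(
    model=self.model,
    model_state_dict=torchtitan_state_dict,
    options=StateDictOptions(strict=False),              # tolerate missing keys
)
loaded_params = set(torchtitan_state_dict.keys())
```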
Args: weights_iter: Iterator of (name, tensor) pairs from HF checkpoint @@ -210,101 +215,60 @@ def load_weights(self, weights_iter): Returns: Set of loaded parameter names (for vLLM compatibility) """ - # HF → TorchTitan name mapping (from Qwen3StateDictAdapter) - hf_to_tt = { - "model.embed_tokens.weight": "tok_embeddings.weight", - "lm_head.weight": "output.weight", - "model.norm.weight": "norm.weight", - # Attention weights - "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", - "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", - "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", - "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", - "model.layers.{}.self_attn.q_norm.weight": ( - "layers.{}.attention.q_norm.weight" - ), - "model.layers.{}.self_attn.k_norm.weight": ( - "layers.{}.attention.k_norm.weight" - ), - # Skip rotary_emb.inv_freq (not used in TorchTitan) - "model.layers.{}.self_attn.rotary_emb.inv_freq": None, - # MLP weights - "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", - "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", - "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", - # Layer norms - "model.layers.{}.input_layernorm.weight": ( - "layers.{}.attention_norm.weight" - ), - "model.layers.{}.post_attention_layernorm.weight": ( - "layers.{}.ffn_norm.weight" - ), - } - - # Track loaded parameter names - loaded_params = set() - - # Convert iterator to list for processing - from vllm.model_executor.model_loader.weight_utils import default_weight_loader - - # Get parameters from model - params_dict = dict(self.model.named_parameters()) - - weights_list = list(weights_iter) - - for hf_name, loaded_weight in weights_list: - # Try to find matching pattern in name_mapping - target_name = None - - # Check if it's a layer-specific weight - if "layers" in hf_name: - # Extract layer number - import regex as re - - layer_match = re.search(r"layers\.(\d+)\.", hf_name) - if layer_match: - layer_num = layer_match.group(1) - - # Try to find matching pattern - for hf_pattern, target_pattern in hf_to_tt.items(): - if "{}" in hf_pattern and target_pattern is not None: - hf_concrete = hf_pattern.format(layer_num) - if hf_name == hf_concrete: - target_name = target_pattern.format(layer_num) - break - else: - # Non-layer weight (embeddings, norms, output) - target_name = hf_to_tt.get(hf_name) - - # Skip if no mapping or explicitly marked as None - if target_name is None: - continue - - # Check if parameter exists in model - if target_name not in params_dict: - continue - - # Load the weight into model parameter - param = params_dict[target_name] - - # Verify shapes match - if param.shape != loaded_weight.shape: - logger.warning( - f"Shape mismatch for {target_name}: " - f"Model: {param.shape}, Checkpoint: {loaded_weight.shape}" + from torch.distributed._tensor import DTensor, Replicate + from torch.distributed.checkpoint.state_dict import ( + set_model_state_dict, + StateDictOptions, + ) + from torchtitan.models.qwen3.model.state_dict_adapter import ( + Qwen3StateDictAdapter, + ) + + # Collect weights from iterator into a dict + hf_state_dict = {} + for name, tensor in weights_iter: + hf_state_dict[name] = tensor + + # Use TorchTitan's adapter to convert HF → TorchTitan format + adapter = Qwen3StateDictAdapter( + model_args=self.config, + hf_assets_path=None, # Not needed for from_hf conversion + ) + + # Convert the entire 
state dict at once + torchtitan_state_dict = adapter.from_hf(hf_state_dict) + + # Get model state dict to check which tensors are DTensors + model_state_dict = {k: v for k, v in self.model.state_dict().items()} + + # Convert HF tensors to replicate DTensor if target is DTensor + for name, tensor in torchtitan_state_dict.items(): + if name in model_state_dict and isinstance(model_state_dict[name], DTensor): + # Get the device mesh from the target DTensor + target_dtensor = model_state_dict[name] + device_mesh = target_dtensor.device_mesh + # Convert to replicate DTensor + torchtitan_state_dict[name] = DTensor.from_local( + tensor.to(device_mesh.device_type), + device_mesh=device_mesh, + placements=[Replicate()], ) - continue - # Load the weight - default_weight_loader(param, loaded_weight) + # Use TorchTitan's distributed state dict loading + # This handles TP/PP sharding automatically + set_model_state_dict( + model=self.model, + model_state_dict=torchtitan_state_dict, + options=StateDictOptions( + strict=False, # Allow missing keys + ), + ) - # Add the parameter name to loaded set - # Since CallableModelWrapper overrides named_parameters(), - # the names returned here already match what vLLM expects - loaded_params.add(target_name) + # Get loaded parameter names for vLLM compatibility + loaded_params = set(torchtitan_state_dict.keys()) logger.info( - f"✅ Loaded {len(loaded_params)} parameters, loaded weights are: {loaded_params}" + f"Loaded {len(loaded_params)} parameters from checkpoint using distributed-aware loading" ) return loaded_params diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index c2eaed8de6..1dd54657c2 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -17,6 +17,7 @@ ColwiseParallel, parallelize_module, PrepareModuleInput, + PrepareModuleOutput, RowwiseParallel, SequenceParallel, ) @@ -256,11 +257,19 @@ def apply_non_moe_tp( input_layouts=(Shard(1), Replicate(), None, None), desired_input_layouts=(Replicate(), Replicate(), None, None), ), - "attention.wq": colwise_parallel(use_local_output=False), + "attention.wq": colwise_parallel( + use_local_output=False + ), # NOTE(jianiw): Try using Tensor instead of DTensor "attention.wk": colwise_parallel(use_local_output=False), - "attention.wv": colwise_parallel(use_local_output=False), - "attention.q_norm": SequenceParallel(sequence_dim=2), - "attention.k_norm": SequenceParallel(sequence_dim=2), + "attention.wv": colwise_parallel(use_local_output=True), + "attention.q_norm": SequenceParallel(sequence_dim=2, use_local_output=True), + "attention.k_norm": SequenceParallel(sequence_dim=2, use_local_output=True), + # NOTE(jianiw): manually convert the output to be Shard(1) + "attention.inner_attention": PrepareModuleOutput( + output_layouts=(Shard(1),), + desired_output_layouts=(Shard(1),), + use_local_output=False, + ), # NOTE(jianiw): inner_attention output shape (batch, num_heads, seq_len, head_dim) "attention.wo": rowwise_parallel(output_layouts=Shard(1)), "ffn_norm": SequenceParallel(), } diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index 0683b4c42d..378f203bf6 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -61,13 +61,7 @@ def reshape_for_broadcast( rope_cache: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None ) -> torch.Tensor: """ - Reshape frequency tensor (represented by cos, sin) for 
broadcasting it with another tensor. - - This function reshapes the frequency tensor to have the same shape as the target tensor 'x' - for the purpose of broadcasting the frequency tensor during element-wise operations. - - The input freqs_cis tensor is assumed to be of shape (max_seqlen, head_dim * 2), - and the first seqlen elements will be sliced, but dim must match x. + Reshapes the RoPE frequency tensor to be broadcastable with the input tensor. Args: rope_cache (torch.Tensor): RoPE tensor (cos and sin) to be reshaped. @@ -264,6 +258,8 @@ def forward( keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + # NOTE(jianiw) + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) From 9c8cfe88d8ae531659f1e01a63aa4a4c9de7cb6f Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Thu, 4 Dec 2025 16:36:59 -0800 Subject: [PATCH 08/19] remove dependency on bram's PR --- torchtitan/experiments/vllm/README.md | 19 +- torchtitan/experiments/vllm/infer.py | 102 +---------- .../experiments/vllm/model/attention.py | 4 +- torchtitan/experiments/vllm/model/qwen3.py | 166 +++++++++++++++--- 4 files changed, 170 insertions(+), 121 deletions(-) diff --git a/torchtitan/experiments/vllm/README.md b/torchtitan/experiments/vllm/README.md index 33725c1f94..7b23b9ced7 100644 --- a/torchtitan/experiments/vllm/README.md +++ b/torchtitan/experiments/vllm/README.md @@ -14,7 +14,14 @@ The integration consists of two main components: ### Prerequisites -1. Install vLLM from source [vllm-use-an-existing-pytorch-installation](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#use-an-existing-pytorch-installation): + +1. Install PyTorch nightly for torchtitan: +``` +pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall +``` + + +2. Install vLLM from source [vllm-use-an-existing-pytorch-installation](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#use-an-existing-pytorch-installation): ```bash # install PyTorch first, either from PyPI or from source git clone https://github.com/vllm-project/vllm.git @@ -42,10 +49,16 @@ pip install -e . ``` -2. Download Qwen3/Qwen3-0.6b checkpoint from HuggingFace and put into `example_checkpoint` folder. Make sure to change the "architecture" field in `config.json` to be `Qwen3TorchTitanForCausalLM` so vllm engine could use torchtitan model. +3. Download Qwen3/Qwen3-0.6b checkpoint from HuggingFace and put into `example_checkpoint` folder. Make sure to change the "architecture" field in `config.json` to be `Qwen3TorchTitanForCausalLM` so vllm engine could use torchtitan model. -3. Run inference: +4. 
Run inference: ``` python torchtitan/experiments/vllm/infer.py --model torchtitan/experiments/vllm/example_checkpoint/qwen3-0.6B ``` + +Run with TP: (work in progress) +``` +python torchtitan/experiments/vllm/infer.py --model torchtitan/experiments/vllm/example_checkpoint/qwen3-0.6B --tensor-parallel-size 2 + +``` diff --git a/torchtitan/experiments/vllm/infer.py b/torchtitan/experiments/vllm/infer.py index 34a739c2df..d178b428ec 100755 --- a/torchtitan/experiments/vllm/infer.py +++ b/torchtitan/experiments/vllm/infer.py @@ -14,92 +14,13 @@ import argparse -import torch.nn as nn +# Import and register the TorchTitan vLLM plugin +from torchtitan.experiments.vllm.register import register from vllm import LLM, SamplingParams -from vllm.model_executor.parallel_context import ParallelContext - -def build_qwen3_torchtitan(vllm_config, parallel_context: ParallelContext) -> nn.Module: - """ - Factory function to build Qwen3 with TorchTitan + vLLM. - - This is registered with vLLM's ModelRegistry to enable: - LLM(model="Qwen/Qwen3-0.6B", ...) - - Args: - vllm_config: vLLM configuration object - parallel_context: Parallelism context with TP/PP info - - Returns: - TorchTitanQwen3ForCausalLM instance - """ - from torchtitan.experiments.vllm.model.qwen3 import TorchTitanQwen3ForCausalLM - - # Create model - model = TorchTitanQwen3ForCausalLM( - vllm_config=vllm_config, parallel_context=parallel_context - ) - - # Apply tensor parallelism if TP > 1 - # This must happen AFTER model creation and attention replacement - # but BEFORE dtype conversion (to avoid dtype issues with DTensors) - if parallel_context is not None: - tp_size = parallel_context.get_tensor_parallel_world_size() - if tp_size > 1: - from torch.distributed.device_mesh import init_device_mesh - from torchtitan.models.qwen3.infra.parallelize import apply_non_moe_tp - - print(f"🔧 Applying Tensor Parallelism (TP={tp_size})...") - - # Create DeviceMesh for TorchTitan - tp_mesh = init_device_mesh( - "cuda", - (tp_size,), - mesh_dim_names=("tp",), - ) - - # Apply TorchTitan's tensor parallelism to shard weights - apply_non_moe_tp( - model.model, - tp_mesh=tp_mesh, - loss_parallel=False, # Don't shard the output for loss computation - enable_float8_tensorwise_tp=False, - enable_async_tp=False, - ) - - print(f"✅ Applied Tensor Parallelism (TP={tp_size})") - - # Convert to dtype if specified (happens after TP) - if hasattr(vllm_config, "model_config") and hasattr( - vllm_config.model_config, "dtype" - ): - model = model.to(dtype=vllm_config.model_config.dtype) - - return model - - -def register_torchtitan_model(): - """ - Register the TorchTitan Qwen3 custom model with vLLM using factory function pattern. - - This registers a factory function that vLLM will call to create the model, - allowing us to apply tensor parallelism and other transformations. - """ - try: - from vllm import ModelRegistry - - # Register the factory function with vLLM - # vLLM will call build_qwen3_torchtitan(vllm_config, parallel_context) - ModelRegistry.register_model( - "Qwen3TorchTitanForCausalLM", build_qwen3_torchtitan - ) - - print("✅ Successfully registered TorchTitan Qwen3 custom model with vLLM") - return True - - except Exception as e: - print(f"⚠️ Warning: Failed to register custom model: {e}") - return False +# Register TorchTitan models with vLLM. 
+# NOTE(jianiw): We could use plug-in system instead: https://docs.vllm.ai/en/latest/design/plugin_system/ +register() def parse_args(): @@ -143,17 +64,6 @@ def parse_args(): def main(): args = parse_args() - print("=" * 80) - print("REGISTERING TORCHTITAN QWEN3 CUSTOM MODEL") - print("=" * 80) - - # Register the custom model with vLLM - register_torchtitan_model() - - # Create a temporary directory with minimal config.json for vLLM - - print(f"Using checkpoint and config.json from: {args.model}") - print("=" * 80) print("INITIALIZING vLLM WITH TORCHTITAN QWEN3 MODEL") print("=" * 80) @@ -170,6 +80,8 @@ def main(): dtype="bfloat16", trust_remote_code=True, enforce_eager=True, # Use eager mode for debugging + # Disable kv cache, required for now + enable_prefix_caching=False, tensor_parallel_size=args.tensor_parallel_size, # Multi-GPU support ) diff --git a/torchtitan/experiments/vllm/model/attention.py b/torchtitan/experiments/vllm/model/attention.py index bde41fd38d..aebc579b12 100644 --- a/torchtitan/experiments/vllm/model/attention.py +++ b/torchtitan/experiments/vllm/model/attention.py @@ -14,7 +14,7 @@ import torch -from vllm.attention import Attention +from vllm.attention.layer import Attention from vllm.attention.utils.fa_utils import flash_attn_varlen_func, get_flash_attn_version from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant @@ -199,7 +199,7 @@ def forward( # Use vLLM's Attention layer (requires forward context) output_varlen = self.vllm_attn(q, k, v) - print(f"[jianiw] vllm_attn output is: {output_varlen}") + # print(f"[jianiw] vllm_attn output is: {output_varlen}") # Reshape back to batch format output = output_varlen.view(batch_size, seq_len, num_heads, head_dim) diff --git a/torchtitan/experiments/vllm/model/qwen3.py b/torchtitan/experiments/vllm/model/qwen3.py index 97256c4e74..6abaef480d 100644 --- a/torchtitan/experiments/vllm/model/qwen3.py +++ b/torchtitan/experiments/vllm/model/qwen3.py @@ -24,6 +24,7 @@ """ import torch +from torch.distributed.device_mesh import init_device_mesh from torch.distributed.tensor import DTensor # Import from local custom_models directory @@ -33,10 +34,13 @@ VLLMModelForCausalLM, ) from torchtitan.experiments.vllm.model.attention import VLLMPagedFlashAttention +from torchtitan.models.qwen3.infra.parallelize import apply_non_moe_tp +from torchtitan.tools.utils import device_type from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.parallel_context import ParallelContext + +# from vllm.model_executor.parallel_context import ParallelContext logger = init_logger(__name__) @@ -54,6 +58,7 @@ class TorchTitanQwen3ForCausalLM(VLLMModelForCausalLM): with RoPE positional embeddings and optional QK normalization. 
""" + is_text_generation_model = True # Required for vLLM runner validation supports_pp = False # Pipeline parallelism not supported yet supports_multimodal = False @@ -61,8 +66,7 @@ def __init__( self, *, vllm_config: VllmConfig, - prefix: str = "", - parallel_context: ParallelContext | None = None, + prefix: str = "", # This is required for vLLM interface ): super().__init__() @@ -96,8 +100,26 @@ def __init__( # Create TorchTitan model self.model = Qwen3Model(model_args) self.config = model_args - self.parallel_context = parallel_context + # self.parallel_context = parallel_context + + self._replice_with_vllm_paged_attention(model_args) + + ( + dp_size, + mp_size, + cp_size, + pp_size, + ep_size, + etp_size, + ) = self._process_parallelism_settings(vllm_config) + + # Build device mesh and apply parallelization + if mp_size > 1 or ep_size > 1: + self._build_device_mesh_and_parallelize( + dp_size, mp_size, cp_size, pp_size, ep_size, etp_size + ) + def _replice_with_vllm_paged_attention(self, model_args): # The `vllm.Attention` module handles QKV projection, RoPE, etc., and calls `inner_attention` if not hasattr(self.model, "layers"): raise AttributeError( @@ -118,11 +140,109 @@ def __init__( ) layer.attention.inner_attention = vllm_attn + logger.info( + "Successfully replaced TorchTitan attention with vLLM PagedFlashAttention" + ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - """Convert input token IDs to embeddings.""" + def _process_parallelism_settings( + self, vllm_config: VllmConfig, use_token_shuffling_moe: bool = False + ): + """ + Parse parallel config from vllm config + """ + world_size = ( + vllm_config.parallel_config.data_parallel_size + * vllm_config.parallel_config.tensor_parallel_size + ) + ep_size = ( + world_size if vllm_config.parallel_config.enable_expert_parallel else 1 + ) + etp_size = ( + 1 if vllm_config.parallel_config.enable_expert_parallel else world_size + ) + dp_size = vllm_config.parallel_config.data_parallel_size + + mp_size = vllm_config.parallel_config.tensor_parallel_size + cp_size = vllm_config.parallel_config.decode_context_parallel_size + pp_size = vllm_config.parallel_config.pipeline_parallel_size + self.pp_size = pp_size + + return dp_size, mp_size, cp_size, pp_size, ep_size, etp_size + + def _build_device_mesh_and_parallelize( + self, + dp_size: int, + mp_size: int, + cp_size: int, + pp_size: int, + ep_size: int, + etp_size: int, + ): + """ + Build device mesh in TorchTitan style and apply parallelization to the model. + + This follows the same approach as TorchTitan's ParallelDims.build_mesh() + and parallelize_qwen3() functions. 
+ """ + import torch.distributed as dist + + # Get world size and validate + world_size = dist.get_world_size() + + # For now, assume dp_shard=1 (no data parallel sharding) + # In full implementation, you may need to calculate dp_replicate and dp_shard + dp_replicate = dp_size + dp_shard = 1 + + # Validate parallelism settings + assert dp_replicate * dp_shard * cp_size * mp_size * pp_size == world_size, ( + f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * " + f"cp({cp_size}) * tp({mp_size}) * pp({pp_size}) != WORLD_SIZE({world_size})" + ) + + # Build device mesh following TorchTitan's _build_mesh_without_ep pattern + # (assuming no EP for now) + dims = [] + names = [] + for d, name in zip( + [pp_size, dp_replicate, dp_shard, cp_size, mp_size], + ["pp", "dp_replicate", "dp_shard", "cp", "tp"], + ): + if d > 1: + dims.append(d) + names.append(name) + + logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") + world_mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + + logger.info(f"Build torchtitan device mesh: {world_mesh}") + + # Apply tensor parallelism if enabled + if mp_size > 1: + tp_mesh = world_mesh["tp"] + apply_non_moe_tp( + model=self.model, + tp_mesh=tp_mesh, + loss_parallel=False, # vLLM handles loss computation separately + enable_float8_tensorwise_tp=False, # Can be enabled if needed + enable_async_tp=False, # Can be enabled if needed + ) + logger.info(f"Applied Tensor Parallelism with TP={mp_size}") + + # Store the mesh for future use + self.world_mesh = world_mesh + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + """Convert input token IDs to embeddings. + + This is the vLLM-standard method name for embedding tokens. + """ return self.model.tok_embeddings(input_ids) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + """Convert input token IDs to embeddings (deprecated vLLM interface).""" + return self.embed_input_ids(input_ids) + def forward( self, input_ids: torch.Tensor | None = None, @@ -183,7 +303,7 @@ def forward( batch_size, seq_len, hidden_size = h.shape h = h.view(batch_size * seq_len, hidden_size) # [total_tokens, hidden_size] - # NOTE(jianiw): explicitly insert communication and return full tensor to vLLM Engine + # TODO(jianiw): explicitly insert communication and return full tensor to vLLM Engine. To be checked. if isinstance(h, DTensor): h = h.full_tensor() return h @@ -192,17 +312,25 @@ def compute_logits( self, hidden_states: torch.Tensor, sampling_metadata=None, - ) -> torch.Tensor: - """Compute logits from hidden states.""" - # hidden states is the output from - # NOTE(jianiw): When TP is enabled, self.model.norm isself.model.output is ColumnParallel() annotated + ) -> torch.Tensor | None: + """Compute logits from hidden states. 
- # Current hidden_states is replicate() + Returns: + Logits tensor, or None if TP rank > 0 + """ + # Apply final layer norm + h = self.model.norm(hidden_states) + + # Apply output projection to get logits + logits = self.model.output(h) - h = self.model.norm(hidden_states) # [1, total_tokens, hidden_size] - output = self.model.output(h) + # When using TP, only rank 0 returns logits + # vLLM expects None from other ranks + if isinstance(logits, DTensor): + # Convert DTensor to local tensor for vLLM + logits = logits.full_tensor() - return self.model.output(hidden_states) + return logits def load_weights(self, weights_iter): """ @@ -235,10 +363,7 @@ def load_weights(self, weights_iter): hf_assets_path=None, # Not needed for from_hf conversion ) - # Convert the entire state dict at once torchtitan_state_dict = adapter.from_hf(hf_state_dict) - - # Get model state dict to check which tensors are DTensors model_state_dict = {k: v for k, v in self.model.state_dict().items()} # Convert HF tensors to replicate DTensor if target is DTensor @@ -264,9 +389,8 @@ def load_weights(self, weights_iter): ), ) - # Get loaded parameter names for vLLM compatibility - loaded_params = set(torchtitan_state_dict.keys()) - + # manually patch the loaded + loaded_params = {f"model.{name}" for name in torchtitan_state_dict.keys()} logger.info( f"Loaded {len(loaded_params)} parameters from checkpoint using distributed-aware loading" ) From 58e1cdb5552a18957e75a14d29e5999379f01c31 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Thu, 4 Dec 2025 17:32:46 -0800 Subject: [PATCH 09/19] add missing file --- torchtitan/experiments/vllm/register.py | 26 +++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 torchtitan/experiments/vllm/register.py diff --git a/torchtitan/experiments/vllm/register.py b/torchtitan/experiments/vllm/register.py new file mode 100644 index 0000000000..e4803c492d --- /dev/null +++ b/torchtitan/experiments/vllm/register.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +def register(): + """ + Register TorchTitan models with vLLM. + + This function is called to register TorchTitan-trained models with vLLM. + It sets up the necessary model registry entries for TorchTitan models. 
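+
+    A minimal usage sketch (the checkpoint path is illustrative):
+
+        from vllm import LLM, SamplingParams
+        from torchtitan.experiments.vllm.register import register
+
+        register()  # must run before the engine resolves architectures
+        llm = LLM(model="path/to/checkpoint", dtype="bfloat16",
+                  enforce_eager=True, enable_prefix_caching=False)
+        outputs = llm.generate(["Hello"], SamplingParams(max_tokens=32))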
+ + Currently supports: + - Qwen3TorchTitanForCausalLM: Qwen3 models trained with TorchTitan + + """ + from torchtitan.experiments.vllm.model.qwen3 import TorchTitanQwen3ForCausalLM + from vllm.model_executor.models.registry import ModelRegistry + + # Register Qwen3TorchTitanForCausalLM with vLLM's ModelRegistry + # This maps the architecture name from config.json to the model class + ModelRegistry.register_model( + "Qwen3TorchTitanForCausalLM", TorchTitanQwen3ForCausalLM + ) From d70ff99bae10de9e2c5d65ac32ef16232d850497 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Fri, 5 Dec 2025 13:57:22 -0800 Subject: [PATCH 10/19] merge to deterministic_rl --- .../deterministic_vllm_rl/README.md | 71 +- .../{vllm => deterministic_vllm_rl}/infer.py | 9 +- .../deterministic_vllm_rl/models/__init__.py | 4 +- .../deterministic_vllm_rl/models/attention.py | 392 ++++++++++- .../models/qwen3/__init__.py | 5 +- .../models/qwen3/model_batch_invariant.py | 368 +++++++++++ .../models/qwen3/model_vllm_compat.py | 618 +++++++++--------- .../register.py | 5 +- torchtitan/experiments/vllm/README.md | 64 -- .../vllm/custom_models/__init__.py | 19 - .../custom_models/custom_model_wrapper.py | 133 ---- .../experiments/vllm/custom_models/utils.py | 245 ------- .../experiments/vllm/model/attention.py | 402 ------------ torchtitan/experiments/vllm/model/qwen3.py | 398 ----------- .../vllm/tests/test_flash_attn_varlen.py | 183 ------ .../tests/test_fwd_bwd_with_vllm_attention.py | 537 --------------- .../vllm/tests/test_vllm_attention.py | 360 ---------- 17 files changed, 1152 insertions(+), 2661 deletions(-) rename torchtitan/experiments/{vllm => deterministic_vllm_rl}/infer.py (92%) create mode 100644 torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_batch_invariant.py rename torchtitan/experiments/{vllm => deterministic_vllm_rl}/register.py (87%) delete mode 100644 torchtitan/experiments/vllm/README.md delete mode 100644 torchtitan/experiments/vllm/custom_models/__init__.py delete mode 100644 torchtitan/experiments/vllm/custom_models/custom_model_wrapper.py delete mode 100644 torchtitan/experiments/vllm/custom_models/utils.py delete mode 100644 torchtitan/experiments/vllm/model/attention.py delete mode 100644 torchtitan/experiments/vllm/model/qwen3.py delete mode 100644 torchtitan/experiments/vllm/tests/test_flash_attn_varlen.py delete mode 100644 torchtitan/experiments/vllm/tests/test_fwd_bwd_with_vllm_attention.py delete mode 100644 torchtitan/experiments/vllm/tests/test_vllm_attention.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/README.md b/torchtitan/experiments/deterministic_vllm_rl/README.md index d2ef719c0d..e667d9c147 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/README.md +++ b/torchtitan/experiments/deterministic_vllm_rl/README.md @@ -28,7 +28,7 @@ Note: Currently supports single-device training only. - Implements custom backward pass for gradient computation - Uses `num_splits=1` for deterministic behavior -2. `models/qwen3/model_vllm_compat.py`: Qwen3VLLMCompatModel +2. `models/qwen3/model_batch_invariant.py`: Qwen3VLLMCompatModel - Qwen3 model with merged gate/up projections matching vLLM format - Uses VLLMRMSNorm with gradient support @@ -211,7 +211,72 @@ This implementation uses the same kernels for both rollouts (vLLM) and training 2. Only causal attention is supported 3. 
Requires NVIDIA GPUs with Flash Attention support
-## Project Structure
+
+# Run vLLM inference with TorchTitan Qwen3 Model
+
+This directory contains code to run the TorchTitan Qwen3 model definition inside the vLLM inference engine (not yet batch-invariant; work in progress).
+This work is inspired by https://github.com/vllm-project/vllm/pull/28685.
+
+## Overview
+The integration consists of two main components:
+
+1. **Model Adapter** (`models/qwen3/model_vllm_compat.py`): A custom model class (`TorchTitanQwen3ForCausalLM`) that wraps TorchTitan's `Qwen3Model`, replaces its attention with vLLM's paged attention, and maps TorchTitan checkpoint naming conventions
+2. **Inference Script** (`infer.py`): A simple script to register the model with vLLM and run inference
+
+
+## Quick Start
+### Prerequisites
+
+1. Install the PyTorch nightly build required by torchtitan:
+```bash
+pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall
+```
+
+
+2. Install vLLM from source, following [use an existing PyTorch installation](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#use-an-existing-pytorch-installation):
+```bash
+# install PyTorch first, either from PyPI or from source
+git clone https://github.com/vllm-project/vllm.git
+cd vllm
+python use_existing_torch.py
+uv pip install -r requirements/build.txt
+uv pip install --no-build-isolation -e .
+```
+
+
+NOTE: If `flash_attn_varlen_func` fails in the forward pass with "torch.AcceleratorError: CUDA error: the provided PTX was compiled with an unsupported toolchain", the GPU driver is not compatible with the CUDA toolkit that vLLM/PyTorch were compiled against. Use the following commands to rebuild vLLM:
+
+```bash
+# Set CUDA version environment variable
+export CUDA_HOME=/usr/local/cuda-12.4
+export PATH=/usr/local/cuda-12.4/bin:$PATH
+export LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64:$LD_LIBRARY_PATH
+
+# Clean previous build
+rm -rf build dist *.egg-info
+pip uninstall -y vllm
+
+# Rebuild vLLM from source with CUDA 12.4
+pip install -e .
+
+```
+
+3. Download the Qwen/Qwen3-0.6B checkpoint from HuggingFace and put it into the `example_checkpoint` folder. Make sure to change the "architectures" field in `config.json` to `Qwen3TorchTitanForCausalLM` so the vLLM engine uses the TorchTitan model.
+
+
+4.
Run inference: +``` +python torchtitan/experiments/deterministic_vllm_rl/infer.py --model torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B +``` + +Run with TP: (work in progress) +``` +python torchtitan/experiments/deterministic_vllm_rl/infer.py --model torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B --tensor-parallel-size 2 + +``` + + +# Project Structure ``` deterministic_vllm_rl/ @@ -225,7 +290,7 @@ deterministic_vllm_rl/ │ ├── attention.py # VLLMCompatibleFlashAttention │ └── qwen3/ │ ├── __init__.py -│ └── model_vllm_compat.py # vLLM-compatible Qwen3 model +│ └── model_batch_invariant.py # Batch-invariant Qwen3 model ├── weights/ │ ├── __init__.py │ ├── converter.py # Weight conversion script diff --git a/torchtitan/experiments/vllm/infer.py b/torchtitan/experiments/deterministic_vllm_rl/infer.py similarity index 92% rename from torchtitan/experiments/vllm/infer.py rename to torchtitan/experiments/deterministic_vllm_rl/infer.py index d178b428ec..e9517a456d 100755 --- a/torchtitan/experiments/vllm/infer.py +++ b/torchtitan/experiments/deterministic_vllm_rl/infer.py @@ -9,15 +9,16 @@ Example CLI to run TorchTitan Qwen3 model inference with vLLM: # Run inference -python torchtitan/experiments/vllm/infer.py +python torchtitan/experiments/deterministic_vllm_rl/infer.py """ import argparse -# Import and register the TorchTitan vLLM plugin -from torchtitan.experiments.vllm.register import register from vllm import LLM, SamplingParams +# Import and register the TorchTitan vLLM plugin +from torchtitan.experiments.deterministic_vllm_rl.register import register + # Register TorchTitan models with vLLM. # NOTE(jianiw): We could use plug-in system instead: https://docs.vllm.ai/en/latest/design/plugin_system/ register() @@ -31,7 +32,7 @@ def parse_args(): parser.add_argument( "--model", type=str, - default="torchtitan/experiments/vllm/example_checkpoint/qwen3-0.6B", + default="torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B", help="Path to TorchTitan checkpoint directory", ) parser.add_argument( diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py index c8c11a170a..823ec82ad7 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py +++ b/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py @@ -8,6 +8,6 @@ Models for deterministic vLLM RL training. """ -from .attention import VLLMCompatibleFlashAttention +from .attention import VLLMCompatibleFlashAttention, VLLMPagedFlashAttention -__all__ = ["VLLMCompatibleFlashAttention"] +__all__ = ["VLLMCompatibleFlashAttention", "VLLMPagedFlashAttention"] diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/attention.py b/torchtitan/experiments/deterministic_vllm_rl/models/attention.py index 33dd5a140d..00f1245230 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/attention.py +++ b/torchtitan/experiments/deterministic_vllm_rl/models/attention.py @@ -8,8 +8,13 @@ vLLM-compatible Flash Attention implementation for deterministic RL training. 
""" +import itertools + import torch -from vllm.vllm_flash_attn import flash_attn_varlen_func + +from vllm.attention.layer import Attention +from vllm.attention.utils.fa_utils import flash_attn_varlen_func, get_flash_attn_version +from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant class VLLMCompatibleFlashAttention(torch.nn.Module): @@ -194,3 +199,388 @@ def backward(ctx, grad_output): output = output.transpose(1, 2) return output + + +class VLLMPagedFlashAttention(torch.nn.Module): + """ + Wrapper around vLLM's Attention with custom backward pass. + + Forward: Uses vLLM's optimized Attention layer (flash attention kernels) + Backward: Custom implementation for deterministic gradients + + TODO: This class need to be future refined. + """ + + # Class variable for auto-generating unique layer names (thread-safe) + _layer_counter = itertools.count() + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + scale: float | None = None, + causal: bool = True, + ) -> None: + super().__init__() + + self.hidden_size = hidden_size + + # Handle tensor parallelism: adjust num_heads and num_kv_heads for TP + # NOTE(jianiw): As we use local tensor for this region, we need to manually + try: + from vllm.config import get_current_vllm_config + from vllm.logger import init_logger + + logger = init_logger(__name__) + vllm_config = get_current_vllm_config() + tp_size = vllm_config.parallel_config.tensor_parallel_size + + if tp_size > 1: + if num_kv_heads % tp_size != 0: + # Pad num_kv_heads and num_heads to be divisible by tp_size + assert num_heads % num_kv_heads == 0 + padded_size = tp_size - num_kv_heads % tp_size + padded_num_kv_heads = num_kv_heads + padded_size + padded_num_heads = ( + num_heads + padded_size * num_heads // num_kv_heads + ) + assert padded_num_heads % tp_size == 0 + assert padded_num_kv_heads % tp_size == 0 + + logger.info( + f"Padding attention heads for tensor parallelism: " + f"{num_heads=}, {padded_num_heads=}, " + f"{num_kv_heads=}, {padded_num_kv_heads=}" + ) + + num_heads = padded_num_heads // tp_size + num_kv_heads = padded_num_kv_heads // tp_size + else: + num_heads //= tp_size + num_kv_heads //= tp_size + except (ImportError, RuntimeError, AttributeError): + # Not in vLLM context - use original values + pass + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.causal = causal + + if scale is None: + self.scale = head_dim**-0.5 + else: + self.scale = scale + + # Create vLLM Attention layer + try: + from vllm.config import get_current_vllm_config + + config = get_current_vllm_config() + cache_config = ( + config.cache_config if hasattr(config, "cache_config") else None + ) + + # Generate unique prefix for this attention layer + # vLLM expects format "layers.X" for layer index extraction + layer_idx = next(VLLMPagedFlashAttention._layer_counter) + prefix = f"layers.{layer_idx}" + + self.vllm_attn = Attention( + num_heads=num_heads, + head_size=head_dim, + scale=self.scale, + num_kv_heads=num_kv_heads, + cache_config=cache_config, + quant_config=None, + prefix=prefix, + ) + + except (ImportError, RuntimeError, AttributeError): + # Not in vLLM context - will need to set up manually + self.vllm_attn = None + + # KV cache - will be populated by vLLM during model loading + self.kv_cache: list[torch.Tensor] | None = None + + # Auto-register for vLLM KV cache if in vLLM context + self._auto_register_for_kv_cache() + + def _auto_register_for_kv_cache(self): + 
"""Automatically register this layer for vLLM KV cache allocation. + + This is called during __init__ and will register the layer if we're in + a vLLM context. If not in vLLM context (e.g., pure PyTorch training), + this silently does nothing. + """ + # Initialize layer_name attribute + self.layer_name: str | None = None + + try: + from vllm.config import get_current_vllm_config + + config = get_current_vllm_config() + compilation_config = config.compilation_config + + # Generate unique layer name using class counter + # Format: "layers.{index}" for compatibility with extract_layer_index() + layer_name = f"layers.{next(VLLMPagedFlashAttention._layer_counter)}" + + # Register this layer in static forward context + if layer_name in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {layer_name}") + compilation_config.static_forward_context[layer_name] = self + self.layer_name = layer_name + + except (ImportError, RuntimeError, AttributeError): + # Not in vLLM context - this is fine! + # Layer will work normally for training/inference without vLLM + pass + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + scale: float | None = None, + ) -> torch.Tensor: + """ + Forward with dual-mode behavior: + - Inference (model.training=False): Use vLLM's Attention layer (KV cache, etc.) + - Training (model.training=True): Use flash_attn_varlen_func with custom backward + - vLLM's Attention used flash_attn_varlen_func kernel by default. + + Args: + q: Query tensor [batch, num_heads, seq_len, head_dim] + k: Key tensor [batch, num_kv_heads, seq_len, head_dim] + v: Value tensor [batch, num_kv_heads, seq_len, head_dim] + scale: Optional attention scale override + + Returns: + output: [batch, num_heads, seq_len, head_dim] + """ + # Input is (batch, num_heads, seq_len, head_dim) + batch_size, num_heads, seq_len, head_dim = q.shape + + # INFERENCE MODE: Use vLLM's Attention layer + if not self.training and self.vllm_attn is not None: + # Transpose to (batch, seq_len, num_heads, head_dim) for vLLM + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # # Flatten to (total_tokens, num_heads, head_dim) + # NOTE(jianiw): vllm_attention can also take input as shape (batch, seq_len, num_heads, head_dim) and do internally + + # q_varlen = q.reshape(-1, num_heads, head_dim) + # k_varlen = k.reshape(-1, k.shape[2], head_dim) # k.shape[2] = num_kv_head + # v_varlen = v.reshape(-1, v.shape[2], head_dim) + + try: + # Use vLLM's Attention layer (requires forward context) + output_varlen = self.vllm_attn(q, k, v) + + # print(f"[jianiw] vllm_attn output is: {output_varlen}") + # Reshape back to batch format + output = output_varlen.view(batch_size, seq_len, num_heads, head_dim) + + # Transpose back to TorchTitan format + output = output.transpose(1, 2) + + return output + except (AssertionError, RuntimeError) as e: + # Forward context not available, fall through to training mode + print(f"Error when calling self.vllm_attn during Inference, {str(e)}") + raise + + # TRAINING MODE: Use flash_attn_varlen_func with custom backward + # Transpose to (batch, seq_len, num_heads, head_dim) for vLLM + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # After to_local(), use actual tensor shapes (TP may have sharded heads) + # Shape: (batch, seq_len, num_heads_local, head_dim) + _, _, num_heads_local, _ = q.shape + _, _, num_kv_heads_local, _ = k.shape + + # Convert to varlen format for vLLM: flatten batch and sequence + # 
(batch, seq_len, num_heads, head_dim) -> (total_tokens, num_heads, head_dim) + q_varlen = q.reshape(-1, num_heads_local, head_dim) + k_varlen = k.reshape(-1, num_kv_heads_local, head_dim) + v_varlen = v.reshape(-1, num_kv_heads_local, head_dim) + + # Use custom autograd function with flash_attn_varlen_func forward and manual backward + class VLLMForwardCustomBackward(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, scale, batch_size, seq_len, causal, fa_version): + # Flash Attention only supports fp16 and bf16 + # Store original dtype for conversion back + original_dtype = q.dtype + + # Convert to bf16 if not already fp16/bf16 + if original_dtype not in [torch.float16, torch.bfloat16]: + target_dtype = ( + torch.bfloat16 + if torch.cuda.is_bf16_supported() + else torch.float16 + ) + q = q.to(target_dtype) + k = k.to(target_dtype) + v = v.to(target_dtype) + else: + target_dtype = original_dtype + + # Use flash_attn_varlen_func directly for fast forward pass + # This is the SAME kernel vLLM uses internally! + # TODO(jianiw): Need to double-check + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * seq_len, + seq_len, + dtype=torch.int32, + device=q.device, + ) + + output = flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_q=seq_len, + max_seqlen_k=seq_len, + softmax_scale=scale, + causal=causal, + num_splits=1 if vllm_is_batch_invariant() else 0, + fa_version=fa_version, + ) + + # Convert output back to original dtype if needed + if original_dtype not in [torch.float16, torch.bfloat16]: + output = output.to(original_dtype) + + # Save for backward + ctx.save_for_backward(q, k, v, output) + ctx.scale = scale + ctx.seq_len = seq_len + ctx.batch_size = batch_size + ctx.causal = causal + ctx.original_dtype = original_dtype + + return output + + @staticmethod + def backward(ctx, grad_output): + q, k, v, output = ctx.saved_tensors + scale = ctx.scale + seq_len = ctx.seq_len + batch_size = ctx.batch_size + causal = ctx.causal + original_dtype = ctx.original_dtype + + # Convert grad_output to match saved tensor dtype + if grad_output.dtype != q.dtype: + grad_output = grad_output.to(q.dtype) + + # Reshape from varlen to batch format + total_tokens = q.shape[0] + num_heads = q.shape[1] + head_dim = q.shape[2] + + q_batch = q.reshape(batch_size, seq_len, num_heads, head_dim) + k_batch = k.reshape(batch_size, seq_len, num_heads, head_dim) + v_batch = v.reshape(batch_size, seq_len, num_heads, head_dim) + grad_out_batch = grad_output.reshape( + batch_size, seq_len, num_heads, head_dim + ) + + # Transpose to (batch, num_heads, seq_len, head_dim) + q_t = q_batch.transpose(1, 2) + k_t = k_batch.transpose(1, 2) + v_t = v_batch.transpose(1, 2) + grad_out_t = grad_out_batch.transpose(1, 2) + + # Compute attention scores: QK^T + scores = torch.matmul(q_t, k_t.transpose(-2, -1)) * scale + + # Apply causal mask if needed + if causal: + causal_mask = torch.triu( + torch.ones(seq_len, seq_len, device=q.device, dtype=torch.bool), + diagonal=1, + ) + scores = scores.masked_fill(causal_mask, float("-inf")) + + # Softmax + attn_weights = torch.nn.functional.softmax(scores, dim=-1) + + # Backward through attention + # grad_v = attn_weights^T @ grad_out + grad_v_t = torch.matmul(attn_weights.transpose(-2, -1), grad_out_t) + + # grad_attn_weights = grad_out @ v^T + grad_attn_weights = torch.matmul(grad_out_t, v_t.transpose(-2, -1)) + + # Backward through softmax + sum_term = (grad_attn_weights * attn_weights).sum(dim=-1, keepdim=True) 
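+                # Softmax backward: with P = softmax(S) and upstream grad dP,
+                # dS = P * (dP - sum(dP * P, dim=-1, keepdim=True)).
+                # The next line applies exactly this row-wise formula.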
+ grad_scores = attn_weights * (grad_attn_weights - sum_term) + + # Apply causal mask to gradients + if causal: + grad_scores = grad_scores.masked_fill(causal_mask, 0.0) + + # Backward through scale + grad_scores = grad_scores * scale + + # grad_q = grad_scores @ K + grad_q_t = torch.matmul(grad_scores, k_t) + + # grad_k = grad_scores^T @ Q + grad_k_t = torch.matmul(grad_scores.transpose(-2, -1), q_t) + + # Transpose back and reshape to varlen format + grad_q = grad_q_t.transpose(1, 2).reshape( + total_tokens, num_heads, head_dim + ) + grad_k = grad_k_t.transpose(1, 2).reshape( + total_tokens, num_heads, head_dim + ) + grad_v = grad_v_t.transpose(1, 2).reshape( + total_tokens, num_heads, head_dim + ) + + # Convert gradients back to original dtype if needed + if original_dtype not in [torch.float16, torch.bfloat16]: + grad_q = grad_q.to(original_dtype) + grad_k = grad_k.to(original_dtype) + grad_v = grad_v.to(original_dtype) + + return grad_q, grad_k, grad_v, None, None, None, None, None + + # Get flash attention version + fa_version = get_flash_attn_version() + + # Apply custom autograd function + output_varlen = VLLMForwardCustomBackward.apply( + q_varlen, + k_varlen, + v_varlen, + scale or self.scale, + batch_size, + seq_len, + self.causal, + fa_version, + ) + + # Convert back to batch format + # (total_tokens, num_heads, head_dim) -> (batch, seq_len, num_heads, head_dim) + output = output_varlen.reshape(batch_size, seq_len, num_heads, head_dim) + + # Transpose back to TorchTitan format: (batch, num_heads, seq_len, head_dim) + output = output.transpose(1, 2) + + return output diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py index 10f49db8b5..c721a01b3a 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py +++ b/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py @@ -8,6 +8,7 @@ Qwen3 model with vLLM compatibility for deterministic RL training. """ -from .model_vllm_compat import Qwen3VLLMCompatModel +from .model_batch_invariant import Qwen3VLLMCompatModel +from .model_vllm_compat import TorchTitanQwen3ForCausalLM -__all__ = ["Qwen3VLLMCompatModel"] +__all__ = ["Qwen3VLLMCompatModel", "TorchTitanQwen3ForCausalLM"] diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_batch_invariant.py b/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_batch_invariant.py new file mode 100644 index 0000000000..dd84665091 --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_batch_invariant.py @@ -0,0 +1,368 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+# +# Qwen3 model compatible with vLLM's implementation +# Uses merged gate_up projections and vLLM Flash Attention + +import torch +from torch import nn + +from torchtitan.components.tokenizer import BaseTokenizer + +# Import gradient-enabled operations from experiment utilities +from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( + rms_norm_with_gradients, + silu_and_mul_with_gradients, +) + +# Import from main torchtitan +from torchtitan.models.qwen3.model.args import Qwen3ModelArgs +from torchtitan.protocols.model import AttentionMasksType +from torchtitan.protocols.train_spec import ModelProtocol + +# Import from local experiment's models +from ..attention import VLLMCompatibleFlashAttention + + +# RoPE functions (same as original) +def precompute_rope_cache( + dim: int, max_seq_len: int, base: float = 1_000_000.0 +) -> torch.Tensor: + freqs = 1.0 / (base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(max_seq_len, dtype=freqs.dtype, device=freqs.device) + idx_theta = torch.outer(t, freqs).float() + freqs = torch.cat([idx_theta, idx_theta], dim=-1) + rope_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) + return rope_cache + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """Reshape frequency tensor for broadcasting.""" + ndim = x.ndim + assert ndim > 1 + _, seqlen, _, head_dim = x.shape + rope_cache = rope_cache[0:seqlen] + assert rope_cache.shape == (seqlen, head_dim * 2) + shape = [-1, seqlen, 1, head_dim * 2] + return rope_cache.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, xk: torch.Tensor, rope_cache: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + head_dim = xq.shape[-1] + rope_cache = reshape_for_broadcast(rope_cache, xq) + cos = rope_cache[..., :head_dim].to(dtype=xq.dtype, device=xq.device) + sin = rope_cache[..., head_dim:].to(dtype=xq.dtype, device=xq.device) + xq_out = (xq * cos) + (rotate_half(xq) * sin) + xk_out = (xk * cos) + (rotate_half(xk) * sin) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class VLLMRMSNorm(nn.Module): + """ + RMSNorm using vLLM's exact Triton kernel for bitwise determinism. + Compatible with PyTorch's nn.RMSNorm interface but uses vLLM's implementation. + + Supports gradients through a custom autograd function that uses vLLM's + kernel for forward and batch-invariant PyTorch ops for backward. + """ + + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(hidden_size)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Use vLLM's RMSNorm with gradient support for training + return rms_norm_with_gradients(x, self.weight, self.eps) + + def reset_parameters(self): + nn.init.ones_(self.weight) + + +class FeedForwardVLLMCompat(nn.Module): + """ + FeedForward module compatible with vLLM implementation. + Uses merged gate_up projection like vLLM. 
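+
+    The first half of the gate_up_proj output is treated as the gate and the
+    second half as the up projection; silu_and_mul_with_gradients then computes
+    silu(gate) * up, assuming the same [gate | up] layout as vLLM's SiluAndMul.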
+ """ + + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + + # Merged gate and up projections (like vLLM's gate_up_proj) + self.gate_up_proj = nn.Linear(dim, hidden_dim * 2, bias=False) + + # Down projection (like vLLM's down_proj) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + + def forward(self, x): + # Project to gate and up in one go + gate_up = self.gate_up_proj(x) + # Apply SiluAndMul activation with gradient support + activated = silu_and_mul_with_gradients(gate_up) + # Project down + output = self.down_proj(activated) + return output + + def init_weights(self, init_std: float): + # Initialize like vLLM + nn.init.trunc_normal_(self.gate_up_proj.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.down_proj.weight, mean=0.0, std=init_std) + + +class Attention(nn.Module): + """ + Multi-head attention module compatible with vLLM. + """ + + def __init__(self, model_args: Qwen3ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.n_kv_heads = ( + model_args.n_heads + if model_args.n_kv_heads is None + else model_args.n_kv_heads + ) + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = model_args.head_dim + self.scaling = self.head_dim**-0.5 + + # QK norm (Qwen3 specific) - use vLLM's RMSNorm + if model_args.qk_norm: + self.q_norm = VLLMRMSNorm(self.head_dim, eps=model_args.norm_eps) + self.k_norm = VLLMRMSNorm(self.head_dim, eps=model_args.norm_eps) + else: + self.q_norm = None + self.k_norm = None + + # QKV projections + self.wq = nn.Linear( + model_args.dim, model_args.n_heads * self.head_dim, bias=False + ) + self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear( + model_args.n_heads * self.head_dim, model_args.dim, bias=False + ) + + # Always use vLLM compatible flash attention + self.inner_attention = VLLMCompatibleFlashAttention() + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + if self.q_norm is not None: + self.q_norm.reset_parameters() + if self.k_norm is not None: + self.k_norm.reset_parameters() + + def forward( + self, + x: torch.Tensor, + rope_cache: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + # Reshape to heads + xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + # Apply QK norm + if self.q_norm: + xq = self.q_norm(xq) + if self.k_norm: + xk = self.k_norm(xk) + + # Apply rotary embedding + xq, xk = apply_rotary_emb(xq, xk, rope_cache) + + # Repeat k/v heads if needed + keys = repeat_kv(xk, self.n_rep) + values = repeat_kv(xv, self.n_rep) + + # Transpose for attention + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) + xv = values.transpose(1, 2) + + # Apply flash attention (vLLM compatible, no flex attention) + assert ( + attention_masks is None + ), "vLLM compat mode doesn't use flex attention masks" + output = self.inner_attention(xq, xk, xv, scale=self.scaling) + + # Transpose back + output = output.transpose(1, 2).contiguous() + output = output.view(bs, seqlen, -1) + + return self.wo(output) + + +class TransformerBlock(nn.Module): + """ + TransformerBlock with 
vLLM-compatible FFN. + """ + + def __init__(self, layer_id: int, model_args: Qwen3ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.dim = model_args.dim + + self.attention = Attention(model_args) + + # Use vLLM-compatible FFN with merged projections + self.feed_forward = FeedForwardVLLMCompat( + dim=model_args.dim, hidden_dim=model_args.hidden_dim + ) + + self.attention_norm = VLLMRMSNorm(model_args.dim, eps=model_args.norm_eps) + self.ffn_norm = VLLMRMSNorm(model_args.dim, eps=model_args.norm_eps) + + if model_args.depth_init: + self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5 + + def forward( + self, + x: torch.Tensor, + rope_cache: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): + # Self attention with residual + attn_norm_out = self.attention_norm(x) + x = x + self.attention(attn_norm_out, rope_cache, attention_masks) + + # FFN with residual + ffn_norm_out = self.ffn_norm(x) + x = x + self.feed_forward(ffn_norm_out) + + return x + + def init_weights(self, buffer_device: torch.device): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.feed_forward.init_weights(self.weight_init_std) + + +class Qwen3VLLMCompatModel(nn.Module, ModelProtocol): + """ + Qwen3 model with vLLM-compatible implementation. + Uses merged gate_up projections and vLLM Flash Attention. + """ + + def __init__(self, model_args: Qwen3ModelArgs): + super().__init__() + self.model_args = model_args + self.vocab_size = model_args.vocab_size + self.n_layers = model_args.n_layers + self.eos_id = model_args.eos_id + self.head_dim = model_args.head_dim + + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + + self.register_buffer( + "rope_cache", self._precompute_rope_cache(), persistent=False + ) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) + + self.norm = VLLMRMSNorm(model_args.dim, eps=model_args.norm_eps) + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) + + # IMPORTANT: To match vLLM's behavior and Qwen3's config + # (tie_word_embeddings: true), tie output layer weights to + # embedding weights. 
When either weight updates during training, + # both update together + self.output.weight = self.tok_embeddings.weight + + def init_weights( + self, + buffer_device: torch.device | None = None, + ): + buffer_device = buffer_device or self.rope_cache.device + with torch.device(buffer_device): + self.rope_cache = self._precompute_rope_cache() + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights(buffer_device) + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def _precompute_rope_cache(self) -> torch.Tensor: + return precompute_rope_cache( + self.model_args.head_dim, + self.model_args.max_seq_len, + self.model_args.rope_theta, + ) + + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType | None: + # vLLM compat mode: no flex attention masks + return None + + def forward( + self, + tokens: torch.Tensor, + attention_masks: AttentionMasksType | None = None, + ): + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + for layer in self.layers.values(): + h = layer(h, self.rope_cache, attention_masks) + + h = self.norm(h) if self.norm else h + output = self.output(h) if self.output else h + + return output diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py b/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py index dd84665091..10b4841c31 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py +++ b/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py @@ -3,366 +3,370 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# -# Qwen3 model compatible with vLLM's implementation -# Uses merged gate_up projections and vLLM Flash Attention import torch -from torch import nn - -from torchtitan.components.tokenizer import BaseTokenizer - -# Import gradient-enabled operations from experiment utilities -from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( - rms_norm_with_gradients, - silu_and_mul_with_gradients, -) -# Import from main torchtitan -from torchtitan.models.qwen3.model.args import Qwen3ModelArgs -from torchtitan.protocols.model import AttentionMasksType -from torchtitan.protocols.train_spec import ModelProtocol - -# Import from local experiment's models -from ..attention import VLLMCompatibleFlashAttention - - -# RoPE functions (same as original) -def precompute_rope_cache( - dim: int, max_seq_len: int, base: float = 1_000_000.0 -) -> torch.Tensor: - freqs = 1.0 / (base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(max_seq_len, dtype=freqs.dtype, device=freqs.device) - idx_theta = torch.outer(t, freqs).float() - freqs = torch.cat([idx_theta, idx_theta], dim=-1) - rope_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) - return rope_cache - - -def rotate_half(x: torch.Tensor) -> torch.Tensor: - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - """Reshape frequency tensor for broadcasting.""" - ndim = x.ndim - assert ndim > 1 - _, seqlen, _, head_dim = x.shape - rope_cache = rope_cache[0:seqlen] - assert rope_cache.shape == (seqlen, head_dim * 2) - shape = [-1, seqlen, 1, head_dim * 2] - return rope_cache.view(*shape) - - -def apply_rotary_emb( - xq: torch.Tensor, xk: torch.Tensor, rope_cache: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - head_dim = xq.shape[-1] - rope_cache = reshape_for_broadcast(rope_cache, xq) - cos = rope_cache[..., :head_dim].to(dtype=xq.dtype, device=xq.device) - sin = rope_cache[..., head_dim:].to(dtype=xq.dtype, device=xq.device) - xq_out = (xq * cos) + (rotate_half(xq) * sin) - xk_out = (xk * cos) + (rotate_half(xk) * sin) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" - bs, slen, n_kv_heads, head_dim = x.shape - if n_rep == 1: - return x - return ( - torch.unsqueeze(x, dim=3) - .expand(bs, slen, n_kv_heads, n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * n_rep, head_dim) - ) - - -class VLLMRMSNorm(nn.Module): - """ - RMSNorm using vLLM's exact Triton kernel for bitwise determinism. - Compatible with PyTorch's nn.RMSNorm interface but uses vLLM's implementation. +import torch.nn as nn +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.tensor import DTensor - Supports gradients through a custom autograd function that uses vLLM's - kernel for forward and batch-invariant PyTorch ops for backward. 
- """ +from vllm.config import VllmConfig +from vllm.logger import init_logger - def __init__(self, hidden_size: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(hidden_size)) +from torchtitan.experiments.deterministic_vllm_rl.models.attention import ( + VLLMPagedFlashAttention, +) +from torchtitan.models.qwen3.infra.parallelize import apply_non_moe_tp +from torchtitan.tools.utils import device_type - def forward(self, x: torch.Tensor) -> torch.Tensor: - # Use vLLM's RMSNorm with gradient support for training - return rms_norm_with_gradients(x, self.weight, self.eps) - def reset_parameters(self): - nn.init.ones_(self.weight) +logger = init_logger(__name__) -class FeedForwardVLLMCompat(nn.Module): +class TorchTitanQwen3ForCausalLM(nn.Module): """ - FeedForward module compatible with vLLM implementation. - Uses merged gate_up projection like vLLM. + vLLM-compatible wrapper for TorchTitan's Qwen3 model. + + This class integrates TorchTitan's Qwen3Model with vLLM by: + 1. Importing TorchTitan's model architecture + 2. Replacing attention with vLLM's Attention with PagedAttention and kv cache capability. + 3. Implementing the vLLM model interface """ + is_text_generation_model = True # Required for vLLM runner validation + supports_pp = False # Pipeline parallelism not supported yet + supports_multimodal = False + def __init__( self, - dim: int, - hidden_dim: int, + *, + vllm_config: VllmConfig, + prefix: str = "", # This is required for vLLM interface ): super().__init__() - # Merged gate and up projections (like vLLM's gate_up_proj) - self.gate_up_proj = nn.Linear(dim, hidden_dim * 2, bias=False) + # vLLM config is required + assert vllm_config is not None, "vllm_config is required" + + # Import TorchTitan's Qwen3 model (deferred import to avoid CUDA init issues) + from torchtitan.models.qwen3.model.args import Qwen3ModelArgs + from torchtitan.models.qwen3.model.model import Qwen3Model + + # Map HuggingFace config to TorchTitan ModelArgs + logger.info("vllm config: " + str(vllm_config.__class__)) + hf_config = vllm_config.model_config.hf_config + logger.info("hf_config: " + str(hf_config)) + model_args = Qwen3ModelArgs( + vocab_size=getattr(hf_config, "vocab_size", 151936), + dim=getattr(hf_config, "hidden_size", 2048), + n_layers=getattr(hf_config, "num_hidden_layers", 4), + n_heads=getattr(hf_config, "num_attention_heads", 16), + n_kv_heads=getattr(hf_config, "num_key_value_heads", 2), + head_dim=getattr(hf_config, "head_dim", 128), + hidden_dim=getattr(hf_config, "intermediate_size", 11008), + norm_eps=getattr(hf_config, "rms_norm_eps", 1e-6), + max_seq_len=getattr(hf_config, "max_position_embeddings", 8192), + rope_theta=getattr(hf_config, "rope_theta", 1000000.0), + qk_norm=getattr(hf_config, "qk_norm", True), + ) - # Down projection (like vLLM's down_proj) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + print(f"In attention initialization, model args are : {model_args}") - def forward(self, x): - # Project to gate and up in one go - gate_up = self.gate_up_proj(x) - # Apply SiluAndMul activation with gradient support - activated = silu_and_mul_with_gradients(gate_up) - # Project down - output = self.down_proj(activated) - return output + # Create TorchTitan model + self.model = Qwen3Model(model_args) + self.config = model_args - def init_weights(self, init_std: float): - # Initialize like vLLM - nn.init.trunc_normal_(self.gate_up_proj.weight, mean=0.0, std=0.02) - nn.init.trunc_normal_(self.down_proj.weight, mean=0.0, 
std=init_std) + self._replice_with_vllm_paged_attention(model_args) + ( + dp_size, + mp_size, + cp_size, + pp_size, + ep_size, + etp_size, + ) = self._process_parallelism_settings(vllm_config) -class Attention(nn.Module): - """ - Multi-head attention module compatible with vLLM. - """ + # Build device mesh and apply parallelization + if mp_size > 1 or ep_size > 1: + self._build_device_mesh_and_parallelize( + dp_size, mp_size, cp_size, pp_size, ep_size, etp_size + ) - def __init__(self, model_args: Qwen3ModelArgs): - super().__init__() - self.n_heads = model_args.n_heads - self.n_kv_heads = ( - model_args.n_heads - if model_args.n_kv_heads is None - else model_args.n_kv_heads + def _replice_with_vllm_paged_attention(self, model_args): + # The `vllm.Attention` module handles QKV projection, RoPE, etc., and calls `inner_attention` + if not hasattr(self.model, "layers"): + raise AttributeError( + f"Model {type(self.model).__name__} must have .layers attribute" + ) + + for layer_name, layer in self.model.layers.items(): + if not hasattr(layer, "attention"): + raise ValueError(f"Layer {layer_name} must have .attention attribute") + + vllm_attn = VLLMPagedFlashAttention( + hidden_size=model_args.dim, + num_heads=model_args.n_heads, # 16 (8 when TP =2) + # NOTE(jianiw): Before feeding into inner_attention, the n_kv_heads has been replicated -> num_heads + num_kv_heads=model_args.n_heads, # 16 (8 When TP=2) + head_dim=model_args.head_dim, + causal=True, + ) + + layer.attention.inner_attention = vllm_attn + logger.info( + "Successfully replaced TorchTitan attention with vLLM PagedFlashAttention" + ) + + def _process_parallelism_settings( + self, vllm_config: VllmConfig, use_token_shuffling_moe: bool = False + ): + """ + Parse parallel config from vllm config + """ + world_size = ( + vllm_config.parallel_config.data_parallel_size + * vllm_config.parallel_config.tensor_parallel_size ) - self.n_rep = self.n_heads // self.n_kv_heads - self.head_dim = model_args.head_dim - self.scaling = self.head_dim**-0.5 - - # QK norm (Qwen3 specific) - use vLLM's RMSNorm - if model_args.qk_norm: - self.q_norm = VLLMRMSNorm(self.head_dim, eps=model_args.norm_eps) - self.k_norm = VLLMRMSNorm(self.head_dim, eps=model_args.norm_eps) - else: - self.q_norm = None - self.k_norm = None - - # QKV projections - self.wq = nn.Linear( - model_args.dim, model_args.n_heads * self.head_dim, bias=False + ep_size = ( + world_size if vllm_config.parallel_config.enable_expert_parallel else 1 ) - self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wo = nn.Linear( - model_args.n_heads * self.head_dim, model_args.dim, bias=False + etp_size = ( + 1 if vllm_config.parallel_config.enable_expert_parallel else world_size ) + dp_size = vllm_config.parallel_config.data_parallel_size - # Always use vLLM compatible flash attention - self.inner_attention = VLLMCompatibleFlashAttention() + mp_size = vllm_config.parallel_config.tensor_parallel_size + cp_size = vllm_config.parallel_config.decode_context_parallel_size + pp_size = vllm_config.parallel_config.pipeline_parallel_size + self.pp_size = pp_size - def init_weights(self, init_std: float): - for linear in (self.wq, self.wk, self.wv): - nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) - nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) - if self.q_norm is not None: - self.q_norm.reset_parameters() - if self.k_norm is not None: - self.k_norm.reset_parameters() + 
return dp_size, mp_size, cp_size, pp_size, ep_size, etp_size - def forward( + def _build_device_mesh_and_parallelize( self, - x: torch.Tensor, - rope_cache: torch.Tensor, - attention_masks: AttentionMasksType | None, + dp_size: int, + mp_size: int, + cp_size: int, + pp_size: int, + ep_size: int, + etp_size: int, ): - bs, seqlen, _ = x.shape - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - - # Reshape to heads - xq = xq.view(bs, seqlen, -1, self.head_dim) - xk = xk.view(bs, seqlen, -1, self.head_dim) - xv = xv.view(bs, seqlen, -1, self.head_dim) - - # Apply QK norm - if self.q_norm: - xq = self.q_norm(xq) - if self.k_norm: - xk = self.k_norm(xk) - - # Apply rotary embedding - xq, xk = apply_rotary_emb(xq, xk, rope_cache) - - # Repeat k/v heads if needed - keys = repeat_kv(xk, self.n_rep) - values = repeat_kv(xv, self.n_rep) - - # Transpose for attention - xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - xk = keys.transpose(1, 2) - xv = values.transpose(1, 2) - - # Apply flash attention (vLLM compatible, no flex attention) - assert ( - attention_masks is None - ), "vLLM compat mode doesn't use flex attention masks" - output = self.inner_attention(xq, xk, xv, scale=self.scaling) - - # Transpose back - output = output.transpose(1, 2).contiguous() - output = output.view(bs, seqlen, -1) - - return self.wo(output) - - -class TransformerBlock(nn.Module): - """ - TransformerBlock with vLLM-compatible FFN. - """ + """ + Build device mesh in TorchTitan style and apply parallelization to the model. + + This follows the same approach as TorchTitan's ParallelDims.build_mesh() + and parallelize_qwen3() functions. + """ + import torch.distributed as dist + + # Get world size and validate + world_size = dist.get_world_size() + + # For now, assume dp_shard=1 (no data parallel sharding) + # In full implementation, you may need to calculate dp_replicate and dp_shard + dp_replicate = dp_size + dp_shard = 1 + + # Validate parallelism settings + assert dp_replicate * dp_shard * cp_size * mp_size * pp_size == world_size, ( + f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * " + f"cp({cp_size}) * tp({mp_size}) * pp({pp_size}) != WORLD_SIZE({world_size})" + ) - def __init__(self, layer_id: int, model_args: Qwen3ModelArgs): - super().__init__() - self.n_heads = model_args.n_heads - self.dim = model_args.dim + # Build device mesh following TorchTitan's _build_mesh_without_ep pattern + # (assuming no EP for now) + dims = [] + names = [] + for d, name in zip( + [pp_size, dp_replicate, dp_shard, cp_size, mp_size], + ["pp", "dp_replicate", "dp_shard", "cp", "tp"], + ): + if d > 1: + dims.append(d) + names.append(name) + + logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") + world_mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + + logger.info(f"Build torchtitan device mesh: {world_mesh}") + + # Apply tensor parallelism if enabled + if mp_size > 1: + tp_mesh = world_mesh["tp"] + apply_non_moe_tp( + model=self.model, + tp_mesh=tp_mesh, + loss_parallel=False, # vLLM handles loss computation separately + enable_float8_tensorwise_tp=False, # Can be enabled if needed + enable_async_tp=False, # Can be enabled if needed + ) + logger.info(f"Applied Tensor Parallelism with TP={mp_size}") - self.attention = Attention(model_args) + # Store the mesh for future use + self.world_mesh = world_mesh - # Use vLLM-compatible FFN with merged projections - self.feed_forward = FeedForwardVLLMCompat( - dim=model_args.dim, hidden_dim=model_args.hidden_dim - ) 
+ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + """Convert input token IDs to embeddings. - self.attention_norm = VLLMRMSNorm(model_args.dim, eps=model_args.norm_eps) - self.ffn_norm = VLLMRMSNorm(model_args.dim, eps=model_args.norm_eps) + This is the vLLM-standard method name for embedding tokens. + """ + return self.model.tok_embeddings(input_ids) - if model_args.depth_init: - self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 - else: - self.weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5 + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + """Convert input token IDs to embeddings (deprecated vLLM interface).""" + return self.embed_input_ids(input_ids) def forward( self, - x: torch.Tensor, - rope_cache: torch.Tensor, - attention_masks: AttentionMasksType | None, - ): - # Self attention with residual - attn_norm_out = self.attention_norm(x) - x = x + self.attention(attn_norm_out, rope_cache, attention_masks) + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + """ + Forward pass with vLLM interface. + + Args: + input_ids: Token IDs from vLLM [total_tokens] (1D varlen format) + positions: Position indices from vLLM [total_tokens] (1D varlen format) + inputs_embeds: Pre-computed embeddings (optional, used by vLLM) + **kwargs: Additional vLLM kwargs + + Returns: + hidden_states: Final hidden states [total_tokens, hidden_size] + """ + # Handle inputs_embeds vs input_ids properly + if inputs_embeds is not None: + raise NotImplementedError( + "inputs_embeds is not yet supported by TorchTitan Qwen3. " + "The model expects token IDs and computes embeddings internally. " + "Please provide input_ids instead." + ) - # FFN with residual - ffn_norm_out = self.ffn_norm(x) - x = x + self.feed_forward(ffn_norm_out) + if input_ids is None: + raise ValueError("Either input_ids or inputs_embeds must be provided") - return x + # Convert vLLM interface to TorchTitan interface + # vLLM passes input_ids as [total_tokens] but TorchTitan expects [batch_size, seq_len] + # For now, reshape to [1, total_tokens] as a simple batch of 1 + # TODO: In future, use attn_metadata.seq_lens to properly reconstruct batch structure + tokens_2d = input_ids.unsqueeze(0) # [total_tokens] -> [1, total_tokens] - def init_weights(self, buffer_device: torch.device): - for norm in (self.attention_norm, self.ffn_norm): - norm.reset_parameters() - self.attention.init_weights(self.weight_init_std) - self.feed_forward.init_weights(self.weight_init_std) + # Store positions in forward context for attention layers + # Also convert positions to 2D format + # TODO: The position id information is not properly used yet + if positions is not None: + positions_2d = positions.unsqueeze(0) # [total_tokens] -> [1, total_tokens] + # Get embeddings from 2D tokens + h = self.model.tok_embeddings(tokens_2d) # [1, total_tokens, hidden_size] -class Qwen3VLLMCompatModel(nn.Module, ModelProtocol): - """ - Qwen3 model with vLLM-compatible implementation. - Uses merged gate_up projections and vLLM Flash Attention. 
- """ + # Get RoPE cache indexed by positions + rope_cache = self.model.rope_cache[positions] - def __init__(self, model_args: Qwen3ModelArgs): - super().__init__() - self.model_args = model_args - self.vocab_size = model_args.vocab_size - self.n_layers = model_args.n_layers - self.eos_id = model_args.eos_id - self.head_dim = model_args.head_dim + # Pass through transformer layers + for layer in self.model.layers.values(): + h = layer(h, rope_cache, attention_masks=None) - self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + # Convert output format back to vLLM expectations + # vLLM expects hidden_states in [total_tokens, hidden_size] format + # TorchTitan returns [batch_size, seq_len, hidden_size], so we need to flatten + if h.dim() == 3: # [batch_size, seq_len, hidden_size] + batch_size, seq_len, hidden_size = h.shape + h = h.view(batch_size * seq_len, hidden_size) # [total_tokens, hidden_size] - self.register_buffer( - "rope_cache", self._precompute_rope_cache(), persistent=False - ) - - self.layers = torch.nn.ModuleDict() - for layer_id in range(model_args.n_layers): - self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) - - self.norm = VLLMRMSNorm(model_args.dim, eps=model_args.norm_eps) - self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) - - # IMPORTANT: To match vLLM's behavior and Qwen3's config - # (tie_word_embeddings: true), tie output layer weights to - # embedding weights. When either weight updates during training, - # both update together - self.output.weight = self.tok_embeddings.weight + # TODO(jianiw): explicitly insert communication and return full tensor to vLLM Engine. To be checked. + if isinstance(h, DTensor): + h = h.full_tensor() + return h - def init_weights( + def compute_logits( self, - buffer_device: torch.device | None = None, - ): - buffer_device = buffer_device or self.rope_cache.device - with torch.device(buffer_device): - self.rope_cache = self._precompute_rope_cache() - if self.tok_embeddings is not None: - nn.init.normal_(self.tok_embeddings.weight) - for layer in self.layers.values(): - if layer is not None: - layer.init_weights(buffer_device) - if self.norm is not None: - self.norm.reset_parameters() - final_out_std = self.model_args.dim**-0.5 - cutoff_factor = 3 - - if self.output is not None: - nn.init.trunc_normal_( - self.output.weight, - mean=0.0, - std=final_out_std, - a=-cutoff_factor * final_out_std, - b=cutoff_factor * final_out_std, - ) + hidden_states: torch.Tensor, + sampling_metadata=None, + ) -> torch.Tensor | None: + """Compute logits from hidden states. + + Returns: + Logits tensor, or None if TP rank > 0 + """ + # Apply final layer norm + h = self.model.norm(hidden_states) + + # Apply output projection to get logits + logits = self.model.output(h) + + # When using TP, only rank 0 returns logits + # vLLM expects None from other ranks + if isinstance(logits, DTensor): + # Convert DTensor to local tensor for vLLM + logits = logits.full_tensor() + + return logits + + def load_weights(self, weights_iter): + """ + Uses TorchTitan's Qwen3StateDictAdapter to map HF → TorchTitan naming, + then uses set_model_state_dict for proper distributed tensor handling. 
+ + Args: + weights_iter: Iterator of (name, tensor) pairs from HF checkpoint + + Returns: + Set of loaded parameter names (for vLLM compatibility) + """ + from torch.distributed._tensor import DTensor, Replicate + from torch.distributed.checkpoint.state_dict import ( + set_model_state_dict, + StateDictOptions, + ) - def _precompute_rope_cache(self) -> torch.Tensor: - return precompute_rope_cache( - self.model_args.head_dim, - self.model_args.max_seq_len, - self.model_args.rope_theta, + from torchtitan.models.qwen3.model.state_dict_adapter import ( + Qwen3StateDictAdapter, ) - def get_attention_masks( - self, - input_batch: torch.Tensor, - tokenizer: BaseTokenizer, - extra_inputs: dict[str, torch.Tensor] | None = None, - ) -> AttentionMasksType | None: - # vLLM compat mode: no flex attention masks - return None + # Collect weights from iterator into a dict + hf_state_dict = {} + for name, tensor in weights_iter: + hf_state_dict[name] = tensor - def forward( - self, - tokens: torch.Tensor, - attention_masks: AttentionMasksType | None = None, - ): - h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + # Use TorchTitan's adapter to convert HF → TorchTitan format + adapter = Qwen3StateDictAdapter( + model_args=self.config, + hf_assets_path=None, # Not needed for from_hf conversion + ) - for layer in self.layers.values(): - h = layer(h, self.rope_cache, attention_masks) + torchtitan_state_dict = adapter.from_hf(hf_state_dict) + model_state_dict = {k: v for k, v in self.model.state_dict().items()} + + # Convert HF tensors to replicate DTensor if target is DTensor + for name, tensor in torchtitan_state_dict.items(): + if name in model_state_dict and isinstance(model_state_dict[name], DTensor): + # Get the device mesh from the target DTensor + target_dtensor = model_state_dict[name] + device_mesh = target_dtensor.device_mesh + # Convert to replicate DTensor + torchtitan_state_dict[name] = DTensor.from_local( + tensor.to(device_mesh.device_type), + device_mesh=device_mesh, + placements=[Replicate()], + ) + + # Use TorchTitan's distributed state dict loading + # This handles TP/PP sharding automatically + set_model_state_dict( + model=self.model, + model_state_dict=torchtitan_state_dict, + options=StateDictOptions( + strict=False, # Allow missing keys + ), + ) - h = self.norm(h) if self.norm else h - output = self.output(h) if self.output else h + # manually patch the loaded + loaded_params = {f"model.{name}" for name in torchtitan_state_dict.keys()} + logger.info( + f"Loaded {len(loaded_params)} parameters from checkpoint using distributed-aware loading" + ) - return output + return loaded_params diff --git a/torchtitan/experiments/vllm/register.py b/torchtitan/experiments/deterministic_vllm_rl/register.py similarity index 87% rename from torchtitan/experiments/vllm/register.py rename to torchtitan/experiments/deterministic_vllm_rl/register.py index e4803c492d..d7ee8c3b12 100644 --- a/torchtitan/experiments/vllm/register.py +++ b/torchtitan/experiments/deterministic_vllm_rl/register.py @@ -16,9 +16,12 @@ def register(): - Qwen3TorchTitanForCausalLM: Qwen3 models trained with TorchTitan """ - from torchtitan.experiments.vllm.model.qwen3 import TorchTitanQwen3ForCausalLM from vllm.model_executor.models.registry import ModelRegistry + from torchtitan.experiments.deterministic_vllm_rl.models.qwen3 import ( + TorchTitanQwen3ForCausalLM, + ) + # Register Qwen3TorchTitanForCausalLM with vLLM's ModelRegistry # This maps the architecture name from config.json to the model class 
ModelRegistry.register_model( diff --git a/torchtitan/experiments/vllm/README.md b/torchtitan/experiments/vllm/README.md deleted file mode 100644 index 7b23b9ced7..0000000000 --- a/torchtitan/experiments/vllm/README.md +++ /dev/null @@ -1,64 +0,0 @@ -# TorchTitan Qwen3 Model with vLLM Inference - -This directory contains code to run TorchTitan model definition with vLLM inference engine. - -## Overview - -The integration consists of two main components: - -1. **Model Adapter** (`model/qwen3.py`): A custom model class that extends vLLM's `Qwen3ForCausalLM` to handle TorchTitan checkpoint naming conventions -2. **Inference Script** (`infer.py`): A simple script to register the model and run inference - - -## Quick Start - -### Prerequisites - - -1. Install PyTorch nightly for torchtitan: -``` -pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall -``` - - -2. Install vLLM from source [vllm-use-an-existing-pytorch-installation](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#use-an-existing-pytorch-installation): -```bash -# install PyTorch first, either from PyPI or from source -git clone https://github.com/vllm-project/vllm.git -cd vllm -python use_existing_torch.py -uv pip install -r requirements/build.txt -uv pip install --no-build-isolation -e . -``` - - -NOTE: If `flash_attn_varlen_func` hits error "torch.AcceleratorError: CUDA error: the provided PTX was compiled with an unsupported toolchain" during forward path, this is due to GPU driver version is not compatible with vLLM/PyTorch compiled version. Use the following command to recompile vLLM. - -``` -# Set CUDA version environment variable -export CUDA_HOME=/usr/local/cuda-12.4 -export PATH=/usr/local/cuda-12.4/bin:$PATH -export LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64:$LD_LIBRARY_PATH - -# Clean previous build -rm -rf build dist *.egg-info -pip uninstall -y vllm - -# Rebuild vLLM from source with CUDA 12.4 -pip install -e . - -``` - -3. Download Qwen3/Qwen3-0.6b checkpoint from HuggingFace and put into `example_checkpoint` folder. Make sure to change the "architecture" field in `config.json` to be `Qwen3TorchTitanForCausalLM` so vllm engine could use torchtitan model. - - -4. Run inference: -``` -python torchtitan/experiments/vllm/infer.py --model torchtitan/experiments/vllm/example_checkpoint/qwen3-0.6B -``` - -Run with TP: (work in progress) -``` -python torchtitan/experiments/vllm/infer.py --model torchtitan/experiments/vllm/example_checkpoint/qwen3-0.6B --tensor-parallel-size 2 - -``` diff --git a/torchtitan/experiments/vllm/custom_models/__init__.py b/torchtitan/experiments/vllm/custom_models/__init__.py deleted file mode 100644 index f036e77c62..0000000000 --- a/torchtitan/experiments/vllm/custom_models/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Utilities for integrating custom model implementations with vLLM.""" - -from .custom_model_wrapper import VLLMModelForCausalLM -from .utils import store_positions_in_context - -__all__ = [ - # Base wrapper - "VLLMModelForCausalLM", - # Utilities - "store_positions_in_context", -] diff --git a/torchtitan/experiments/vllm/custom_models/custom_model_wrapper.py b/torchtitan/experiments/vllm/custom_models/custom_model_wrapper.py deleted file mode 100644 index d30c711def..0000000000 --- a/torchtitan/experiments/vllm/custom_models/custom_model_wrapper.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Base class for integrating custom model implementations with vLLM. - -This module provides an abstract base class that enforces the vLLM model interface, -making it easy to integrate external model implementations (e.g., from TorchTitan, -NanoGPT, etc.) with vLLM. - -Example usage: - ```python - from some_external_lib.models import ExternalModel - from . import VLLMModelForCausalLM - - - class MyCustomModelForCausalLM(VLLMModelForCausalLM): - def __init__(self, vllm_config, parallel_context=None, **kwargs): - super().__init__() - self.model = ExternalModel(...) # Create external model - # Replace attention layers with vLLM's trainable attention - - def get_input_embeddings(self, input_ids): - return self.model.tok_embeddings(input_ids) - - def forward(self, input_ids, positions=None, **kwargs): - # Forward pass - return hidden_states - - def compute_logits(self, hidden_states, sampling_metadata=None): - return self.model.output(hidden_states) - - def load_weights(self, weights_iter): - # Load weights from HuggingFace checkpoint - pass - ``` -""" - -from abc import ABC, abstractmethod -from collections.abc import Iterator -from typing import Any - -import torch -import torch.nn as nn - - -class VLLMModelForCausalLM(nn.Module, ABC): - """ - Abstract base class for integrating custom model implementations with vLLM. - - This class enforces the vLLM model interface that all text generation models - must implement. Subclasses should: - 1. Import and instantiate the external model in __init__ - 2. Replace attention layers with vLLM's trainable attention - 3. Implement the abstract methods below - - Class attributes: - supports_pp: Whether pipeline parallelism is supported - supports_multimodal: Whether multimodal inputs are supported - """ - - supports_pp: bool = False - supports_multimodal: bool = False - - @abstractmethod - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - """ - Convert input token IDs to embeddings. - - Args: - input_ids: Token IDs [batch, seq_len] or [total_tokens] - - Returns: - Embeddings [batch, seq_len, hidden_size] or [total_tokens, hidden_size] - """ - pass - - @abstractmethod - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor | None = None, - **kwargs: Any, - ) -> torch.Tensor: - """ - Forward pass through the model. 
- - Args: - input_ids: Token IDs [batch, seq_len] or [total_tokens] - positions: Position indices from vLLM for RoPE indexing - **kwargs: Additional vLLM-specific arguments - - Returns: - Hidden states before final projection [batch, seq_len, hidden_size] - or [total_tokens, hidden_size] - """ - pass - - @abstractmethod - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: Any = None, - ) -> torch.Tensor: - """ - Compute logits from hidden states. - - Args: - hidden_states: Output from forward() [batch, seq_len, hidden_size] - sampling_metadata: vLLM sampling metadata (optional) - - Returns: - Logits [batch, seq_len, vocab_size] - """ - pass - - @abstractmethod - def load_weights(self, weights_iter: Iterator[tuple[str, torch.Tensor]]) -> None: - """ - Load weights from HuggingFace checkpoint. - - This method should map HuggingFace weight names to model parameter names - and load them into the model. - - Args: - weights_iter: Iterator yielding (name, tensor) tuples from HF checkpoint - """ - pass diff --git a/torchtitan/experiments/vllm/custom_models/utils.py b/torchtitan/experiments/vllm/custom_models/utils.py deleted file mode 100644 index 5ea546d563..0000000000 --- a/torchtitan/experiments/vllm/custom_models/utils.py +++ /dev/null @@ -1,245 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Utility functions for TorchTitan model integration with vLLM. - -This module provides helper functions for common TorchTitan integration tasks: -- Converting RoPE frequency tensors from complex to real format -- Managing forward context for position indices -- Creating KV cache specifications for MLA -- Loading weights from HuggingFace checkpoints with name mapping - -Example usage: - ```python - from vllm.model_executor.utils.torchtitan_utils import ( - convert_freqs_cis_to_real, - create_mla_kv_cache_spec, - ) - - # Convert TorchTitan's complex freqs_cis to vLLM-compatible format - model.freqs_cis = convert_freqs_cis_to_real(model.freqs_cis) - - # Create KV cache spec for MLA attention - kv_cache_spec = create_mla_kv_cache_spec( - kv_lora_rank=512, - qk_rope_head_dim=64, - vllm_config=vllm_config, - ) - ``` -""" - -from collections.abc import Iterator -from typing import Any - -import torch -import torch.nn as nn - - -def convert_freqs_cis_to_real(freqs_cis: torch.Tensor) -> torch.Tensor: - """ - Convert complex RoPE frequencies to real format (cos, sin concatenated). - - TorchTitan uses complex exponentials e^(i*theta) for RoPE, but vLLM and - dtype conversion to bfloat16 require real tensors. 
This converts: - complex[max_seq_len, dim//2] -> real[max_seq_len, dim] - - Args: - freqs_cis: Complex frequency tensor [max_seq_len, dim//2] - - Returns: - Real tensor [max_seq_len, dim] with cos and sin concatenated - """ - if not freqs_cis.is_complex(): - # Already in real format - return freqs_cis - - # Extract cos and sin from complex exponentials - # e^(i*theta) = cos(theta) + i*sin(theta) - freqs_cos = freqs_cis.real # [max_seq_len, dim//2] - freqs_sin = freqs_cis.imag # [max_seq_len, dim//2] - - # Concatenate: [max_seq_len, dim] - freqs_real = torch.cat([freqs_cos, freqs_sin], dim=-1) - - return freqs_real - - -def store_positions_in_context(positions: torch.Tensor | None) -> None: - """ - Store position indices in vLLM's forward context. - - This allows attention layers to access per-token positions for RoPE indexing - during inference without explicitly passing them through every layer. - - Args: - positions: Position indices from vLLM [total_tokens] or None - """ - if positions is None: - return - - try: - from vllm.forward_context import get_forward_context - - forward_ctx = get_forward_context() - # Store positions in a custom attribute accessible to attention layers - forward_ctx._torchtitan_positions = positions - except (ImportError, RuntimeError, AttributeError): - # Not in vLLM context - this is fine (e.g., during testing) - pass - - -def create_mla_kv_cache_spec( - kv_lora_rank: int, - qk_rope_head_dim: int, - block_size: int, - dtype: torch.dtype, -) -> Any: - """ - Create KV cache specification for Multi-Head Latent Attention (MLA). - - MLA uses compressed KV cache with layout: - [kv_lora_rank + qk_rope_head_dim] per token - - Args: - kv_lora_rank: LoRA rank for compressed KV (e.g., 512) - qk_rope_head_dim: Dimension of RoPE-encoded keys (e.g., 64) - block_size: KV cache block size from vLLM config - dtype: Data type for KV cache - - Returns: - MLAAttentionSpec instance - """ - from vllm.v1.kv_cache_interface import MLAAttentionSpec - - # MLA cache layout: compressed KV + shared K_PE - head_size = kv_lora_rank + qk_rope_head_dim - - return MLAAttentionSpec( - block_size=block_size, - num_kv_heads=1, # MLA shares K_PE across all heads - head_size=head_size, - dtype=dtype, - ) - - -def load_external_weights( - model: nn.Module, - weights_iter: Iterator[tuple[str, torch.Tensor]], - name_mapping: dict[str, str], - verbose: bool = False, -) -> tuple[int, int]: - """ - Load weights from HuggingFace checkpoint into external model. - - Maps HuggingFace parameter names to model parameter names and loads - them into the model. Supports layer-specific patterns with {} placeholders. - - Args: - model: Model instance to load weights into - weights_iter: Iterator yielding (name, tensor) from HF checkpoint - name_mapping: Dict mapping HF names to model parameter names. 
- Use {} as placeholder for layer numbers, e.g.: - {"model.layers.{}.attn.weight": "layers.{}.attention.weight"} - verbose: Whether to print detailed loading progress - - Returns: - Tuple of (loaded_count, skipped_count) - - Example: - ```python - name_mapping = { - "model.embed_tokens.weight": "tok_embeddings.weight", - "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", - "lm_head.weight": "output.weight", - } - loaded, skipped = load_external_weights(model, weights_iter, name_mapping) - ``` - """ - from vllm.model_executor.model_loader.weight_utils import default_weight_loader - - # Get all parameter names in the model - params_dict = dict(model.named_parameters()) - - loaded_count = 0 - skipped_count = 0 - - # Convert iterator to list to check if empty - weights_list = list(weights_iter) - if len(weights_list) == 0: - if verbose: - print(" ⚠️ No weight files found - using random initialization") - return 0, 0 - - for hf_name, loaded_weight in weights_list: - # Try to find matching pattern in name_mapping - target_name = None - - # Check if it's a layer-specific weight - if "layers" in hf_name: - # Extract layer number - import regex as re - - layer_match = re.search(r"layers\.(\d+)\.", hf_name) - if layer_match: - layer_num = layer_match.group(1) - - # Try to find matching pattern - for hf_pattern, target_pattern in name_mapping.items(): - if "{}" in hf_pattern: - hf_concrete = hf_pattern.format(layer_num) - if hf_name == hf_concrete: - target_name = target_pattern.format(layer_num) - break - else: - # Non-layer weight (embeddings, norms, output) - target_name = name_mapping.get(hf_name) - - if target_name is None: - # Skip MoE weights and other unmapped weights - if ( - "mlp.experts" in hf_name - or "mlp.gate" in hf_name - or "mlp.shared_experts" in hf_name - ): - # MoE weights - skip silently - skipped_count += 1 - continue - else: - if verbose: - print(f" ⚠️ No mapping for: {hf_name}") - skipped_count += 1 - continue - - # Check if parameter exists in model - if target_name not in params_dict: - if verbose: - print(f" ⚠️ Parameter not found in model: {target_name}") - skipped_count += 1 - continue - - # Load the weight - param = params_dict[target_name] - - # Verify shapes match - if param.shape != loaded_weight.shape: - if verbose: - print(f" ⚠️ Shape mismatch for {target_name}:") - print(f" Model: {param.shape}, Checkpoint: {loaded_weight.shape}") - skipped_count += 1 - continue - - # Load the weight - default_weight_loader(param, loaded_weight) - loaded_count += 1 - - # Log first few loads for verification - if verbose and loaded_count <= 5: - print(f" ✓ Loaded {target_name}: {loaded_weight.shape}") - - return loaded_count, skipped_count diff --git a/torchtitan/experiments/vllm/model/attention.py b/torchtitan/experiments/vllm/model/attention.py deleted file mode 100644 index aebc579b12..0000000000 --- a/torchtitan/experiments/vllm/model/attention.py +++ /dev/null @@ -1,402 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -vLLM-compatible Flash Attention implementation for deterministic RL training. - -Uses vLLM's flash_attn_varlen_func for forward pass (fast) with custom backward pass. 
-""" - -import itertools - -import torch - -from vllm.attention.layer import Attention -from vllm.attention.utils.fa_utils import flash_attn_varlen_func, get_flash_attn_version -from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant - - -class VLLMPagedFlashAttention(torch.nn.Module): - """ - Wrapper around vLLM's Attention with custom backward pass. - - Forward: Uses vLLM's optimized Attention layer (flash attention kernels) - Backward: Custom implementation for deterministic gradients - """ - - # Class variable for auto-generating unique layer names (thread-safe) - _layer_counter = itertools.count() - - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - scale: float | None = None, - causal: bool = True, - ) -> None: - super().__init__() - - self.hidden_size = hidden_size - - # Handle tensor parallelism: adjust num_heads and num_kv_heads for TP - # NOTE(jianiw): As we use local tensor for this region, we need to manually - try: - from vllm.config import get_current_vllm_config - from vllm.logger import init_logger - - logger = init_logger(__name__) - vllm_config = get_current_vllm_config() - tp_size = vllm_config.parallel_config.tensor_parallel_size - - if tp_size > 1: - if num_kv_heads % tp_size != 0: - # Pad num_kv_heads and num_heads to be divisible by tp_size - assert num_heads % num_kv_heads == 0 - padded_size = tp_size - num_kv_heads % tp_size - padded_num_kv_heads = num_kv_heads + padded_size - padded_num_heads = ( - num_heads + padded_size * num_heads // num_kv_heads - ) - assert padded_num_heads % tp_size == 0 - assert padded_num_kv_heads % tp_size == 0 - - logger.info( - f"Padding attention heads for tensor parallelism: " - f"{num_heads=}, {padded_num_heads=}, " - f"{num_kv_heads=}, {padded_num_kv_heads=}" - ) - - num_heads = padded_num_heads // tp_size - num_kv_heads = padded_num_kv_heads // tp_size - else: - num_heads //= tp_size - num_kv_heads //= tp_size - except (ImportError, RuntimeError, AttributeError): - # Not in vLLM context - use original values - pass - - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = head_dim - self.causal = causal - - if scale is None: - self.scale = head_dim**-0.5 - else: - self.scale = scale - - # Create vLLM Attention layer - try: - from vllm.config import get_current_vllm_config - - config = get_current_vllm_config() - cache_config = ( - config.cache_config if hasattr(config, "cache_config") else None - ) - - # Generate unique prefix for this attention layer - # vLLM expects format "layers.X" for layer index extraction - layer_idx = next(VLLMPagedFlashAttention._layer_counter) - prefix = f"layers.{layer_idx}" - - self.vllm_attn = Attention( - num_heads=num_heads, - head_size=head_dim, - scale=self.scale, - num_kv_heads=num_kv_heads, - cache_config=cache_config, - quant_config=None, - prefix=prefix, - ) - - except (ImportError, RuntimeError, AttributeError): - # Not in vLLM context - will need to set up manually - self.vllm_attn = None - - # KV cache - will be populated by vLLM during model loading - self.kv_cache: list[torch.Tensor] | None = None - - # Auto-register for vLLM KV cache if in vLLM context - self._auto_register_for_kv_cache() - - def _auto_register_for_kv_cache(self): - """Automatically register this layer for vLLM KV cache allocation. - - This is called during __init__ and will register the layer if we're in - a vLLM context. If not in vLLM context (e.g., pure PyTorch training), - this silently does nothing. 
- """ - # Initialize layer_name attribute - self.layer_name: str | None = None - - try: - from vllm.config import get_current_vllm_config - - config = get_current_vllm_config() - compilation_config = config.compilation_config - - # Generate unique layer name using class counter - # Format: "layers.{index}" for compatibility with extract_layer_index() - layer_name = f"layers.{next(VLLMPagedFlashAttention._layer_counter)}" - - # Register this layer in static forward context - if layer_name in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {layer_name}") - compilation_config.static_forward_context[layer_name] = self - self.layer_name = layer_name - - except (ImportError, RuntimeError, AttributeError): - # Not in vLLM context - this is fine! - # Layer will work normally for training/inference without vLLM - pass - - def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - *, - scale: float | None = None, - ) -> torch.Tensor: - """ - Forward with dual-mode behavior: - - Inference (model.training=False): Use vLLM's Attention layer (KV cache, etc.) - - Training (model.training=True): Use flash_attn_varlen_func with custom backward - - vLLM's Attention used flash_attn_varlen_func kernel by default. - - Args: - q: Query tensor [batch, num_heads, seq_len, head_dim] - k: Key tensor [batch, num_kv_heads, seq_len, head_dim] - v: Value tensor [batch, num_kv_heads, seq_len, head_dim] - scale: Optional attention scale override - - Returns: - output: [batch, num_heads, seq_len, head_dim] - """ - # Input is (batch, num_heads, seq_len, head_dim) - batch_size, num_heads, seq_len, head_dim = q.shape - - # INFERENCE MODE: Use vLLM's Attention layer - if not self.training and self.vllm_attn is not None: - # Transpose to (batch, seq_len, num_heads, head_dim) for vLLM - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - # # Flatten to (total_tokens, num_heads, head_dim) - # NOTE(jianiw): vllm_attention can also take input as shape (batch, seq_len, num_heads, head_dim) and do internally - - # q_varlen = q.reshape(-1, num_heads, head_dim) - # k_varlen = k.reshape(-1, k.shape[2], head_dim) # k.shape[2] = num_kv_head - # v_varlen = v.reshape(-1, v.shape[2], head_dim) - - try: - # Use vLLM's Attention layer (requires forward context) - output_varlen = self.vllm_attn(q, k, v) - - # print(f"[jianiw] vllm_attn output is: {output_varlen}") - # Reshape back to batch format - output = output_varlen.view(batch_size, seq_len, num_heads, head_dim) - - # Transpose back to TorchTitan format - output = output.transpose(1, 2) - - return output - except (AssertionError, RuntimeError) as e: - # Forward context not available, fall through to training mode - print(f"Error when calling self.vllm_attn during Inference, {str(e)}") - raise - - # TRAINING MODE: Use flash_attn_varlen_func with custom backward - # Transpose to (batch, seq_len, num_heads, head_dim) for vLLM - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - # After to_local(), use actual tensor shapes (TP may have sharded heads) - # Shape: (batch, seq_len, num_heads_local, head_dim) - _, _, num_heads_local, _ = q.shape - _, _, num_kv_heads_local, _ = k.shape - - # Convert to varlen format for vLLM: flatten batch and sequence - # (batch, seq_len, num_heads, head_dim) -> (total_tokens, num_heads, head_dim) - q_varlen = q.reshape(-1, num_heads_local, head_dim) - k_varlen = k.reshape(-1, num_kv_heads_local, head_dim) - v_varlen = v.reshape(-1, num_kv_heads_local, head_dim) 
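For reference, the varlen layout built above packs every sequence along the token dimension and describes sequence boundaries with a cumulative-lengths tensor; the uniform-batch case used in this file reduces to a simple `torch.arange`. A plain-PyTorch sketch of the packing (illustrative only, not part of this patch):

```python
import torch

# Two example sequences of different lengths packed into one varlen batch.
seq_lens = [3, 5]
num_heads, head_dim = 4, 8

# Per-sequence tensors of shape (seq_len, num_heads, head_dim).
seqs = [torch.randn(n, num_heads, head_dim) for n in seq_lens]

# Varlen format: concatenate along the token dimension
# -> (total_tokens, num_heads, head_dim) = (8, 4, 8).
q_varlen = torch.cat(seqs, dim=0)

# Cumulative sequence lengths mark where each sequence starts: [0, 3, 8].
cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(torch.tensor(seq_lens), dim=0)

# Longest sequence, passed as max_seqlen_q / max_seqlen_k.
max_seqlen = max(seq_lens)

# For a uniform batch (batch_size sequences of identical seq_len) this reduces to
# torch.arange(0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32),
# which is what the training-mode path below constructs.
```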
- - # Use custom autograd function with flash_attn_varlen_func forward and manual backward - class VLLMForwardCustomBackward(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, v, scale, batch_size, seq_len, causal, fa_version): - # Flash Attention only supports fp16 and bf16 - # Store original dtype for conversion back - original_dtype = q.dtype - - # Convert to bf16 if not already fp16/bf16 - if original_dtype not in [torch.float16, torch.bfloat16]: - target_dtype = ( - torch.bfloat16 - if torch.cuda.is_bf16_supported() - else torch.float16 - ) - q = q.to(target_dtype) - k = k.to(target_dtype) - v = v.to(target_dtype) - else: - target_dtype = original_dtype - - # Use flash_attn_varlen_func directly for fast forward pass - # This is the SAME kernel vLLM uses internally! - # TODO(jianiw): Need to double-check - cu_seqlens_q = torch.arange( - 0, - (batch_size + 1) * seq_len, - seq_len, - dtype=torch.int32, - device=q.device, - ) - - output = flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_q, - max_seqlen_q=seq_len, - max_seqlen_k=seq_len, - softmax_scale=scale, - causal=causal, - num_splits=1 if vllm_is_batch_invariant() else 0, - fa_version=fa_version, - ) - - # Convert output back to original dtype if needed - if original_dtype not in [torch.float16, torch.bfloat16]: - output = output.to(original_dtype) - - # Save for backward - ctx.save_for_backward(q, k, v, output) - ctx.scale = scale - ctx.seq_len = seq_len - ctx.batch_size = batch_size - ctx.causal = causal - ctx.original_dtype = original_dtype - - return output - - @staticmethod - def backward(ctx, grad_output): - q, k, v, output = ctx.saved_tensors - scale = ctx.scale - seq_len = ctx.seq_len - batch_size = ctx.batch_size - causal = ctx.causal - original_dtype = ctx.original_dtype - - # Convert grad_output to match saved tensor dtype - if grad_output.dtype != q.dtype: - grad_output = grad_output.to(q.dtype) - - # Reshape from varlen to batch format - total_tokens = q.shape[0] - num_heads = q.shape[1] - head_dim = q.shape[2] - - q_batch = q.reshape(batch_size, seq_len, num_heads, head_dim) - k_batch = k.reshape(batch_size, seq_len, num_heads, head_dim) - v_batch = v.reshape(batch_size, seq_len, num_heads, head_dim) - grad_out_batch = grad_output.reshape( - batch_size, seq_len, num_heads, head_dim - ) - - # Transpose to (batch, num_heads, seq_len, head_dim) - q_t = q_batch.transpose(1, 2) - k_t = k_batch.transpose(1, 2) - v_t = v_batch.transpose(1, 2) - grad_out_t = grad_out_batch.transpose(1, 2) - - # Compute attention scores: QK^T - scores = torch.matmul(q_t, k_t.transpose(-2, -1)) * scale - - # Apply causal mask if needed - if causal: - causal_mask = torch.triu( - torch.ones(seq_len, seq_len, device=q.device, dtype=torch.bool), - diagonal=1, - ) - scores = scores.masked_fill(causal_mask, float("-inf")) - - # Softmax - attn_weights = torch.nn.functional.softmax(scores, dim=-1) - - # Backward through attention - # grad_v = attn_weights^T @ grad_out - grad_v_t = torch.matmul(attn_weights.transpose(-2, -1), grad_out_t) - - # grad_attn_weights = grad_out @ v^T - grad_attn_weights = torch.matmul(grad_out_t, v_t.transpose(-2, -1)) - - # Backward through softmax - sum_term = (grad_attn_weights * attn_weights).sum(dim=-1, keepdim=True) - grad_scores = attn_weights * (grad_attn_weights - sum_term) - - # Apply causal mask to gradients - if causal: - grad_scores = grad_scores.masked_fill(causal_mask, 0.0) - - # Backward through scale - grad_scores = grad_scores * scale - - # 
grad_q = grad_scores @ K - grad_q_t = torch.matmul(grad_scores, k_t) - - # grad_k = grad_scores^T @ Q - grad_k_t = torch.matmul(grad_scores.transpose(-2, -1), q_t) - - # Transpose back and reshape to varlen format - grad_q = grad_q_t.transpose(1, 2).reshape( - total_tokens, num_heads, head_dim - ) - grad_k = grad_k_t.transpose(1, 2).reshape( - total_tokens, num_heads, head_dim - ) - grad_v = grad_v_t.transpose(1, 2).reshape( - total_tokens, num_heads, head_dim - ) - - # Convert gradients back to original dtype if needed - if original_dtype not in [torch.float16, torch.bfloat16]: - grad_q = grad_q.to(original_dtype) - grad_k = grad_k.to(original_dtype) - grad_v = grad_v.to(original_dtype) - - return grad_q, grad_k, grad_v, None, None, None, None, None - - # Get flash attention version - fa_version = get_flash_attn_version() - - # Apply custom autograd function - output_varlen = VLLMForwardCustomBackward.apply( - q_varlen, - k_varlen, - v_varlen, - scale or self.scale, - batch_size, - seq_len, - self.causal, - fa_version, - ) - - # Convert back to batch format - # (total_tokens, num_heads, head_dim) -> (batch, seq_len, num_heads, head_dim) - output = output_varlen.reshape(batch_size, seq_len, num_heads, head_dim) - - # Transpose back to TorchTitan format: (batch, num_heads, seq_len, head_dim) - output = output.transpose(1, 2) - - return output diff --git a/torchtitan/experiments/vllm/model/qwen3.py b/torchtitan/experiments/vllm/model/qwen3.py deleted file mode 100644 index 6abaef480d..0000000000 --- a/torchtitan/experiments/vllm/model/qwen3.py +++ /dev/null @@ -1,398 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -""" -TorchTitan Qwen3 Model for vLLM Integration. - -This module provides a vLLM-compatible wrapper around TorchTitan's Qwen3 model, -enabling models trained with TorchTitan to be served through vLLM for inference. - -Example: - ```python - from vllm import LLM - - llm = LLM( - model="Qwen/Qwen3-0.6B", - trust_remote_code=True, - ) - ``` - -IMPORTANT: TorchTitan imports are deferred to avoid CUDA initialization -before vLLM's multiprocessing fork. -""" - -import torch -from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.tensor import DTensor - -# Import from local custom_models directory -from torchtitan.experiments.vllm.custom_models import ( - # load_external_weights, - store_positions_in_context, - VLLMModelForCausalLM, -) -from torchtitan.experiments.vllm.model.attention import VLLMPagedFlashAttention -from torchtitan.models.qwen3.infra.parallelize import apply_non_moe_tp -from torchtitan.tools.utils import device_type - -from vllm.config import VllmConfig -from vllm.logger import init_logger - -# from vllm.model_executor.parallel_context import ParallelContext - -logger = init_logger(__name__) - - -class TorchTitanQwen3ForCausalLM(VLLMModelForCausalLM): - """ - vLLM-compatible wrapper for TorchTitan's Qwen3 model. - - This class integrates TorchTitan's Qwen3Model with vLLM by: - 1. Importing TorchTitan's model architecture - 2. Replacing attention with vLLM's TrainableFlashAttention - 3. Implementing the vLLM model interface - - The architecture uses standard multi-head attention (not MLA), - with RoPE positional embeddings and optional QK normalization. 
- """ - - is_text_generation_model = True # Required for vLLM runner validation - supports_pp = False # Pipeline parallelism not supported yet - supports_multimodal = False - - def __init__( - self, - *, - vllm_config: VllmConfig, - prefix: str = "", # This is required for vLLM interface - ): - super().__init__() - - # vLLM config is required - assert vllm_config is not None, "vllm_config is required" - - # Import TorchTitan's Qwen3 model (deferred import to avoid CUDA init issues) - from torchtitan.models.qwen3.model.args import Qwen3ModelArgs - from torchtitan.models.qwen3.model.model import Qwen3Model - - # Map HuggingFace config to TorchTitan ModelArgs - logger.info("vllm config: " + str(vllm_config.__class__)) - hf_config = vllm_config.model_config.hf_config - logger.info("hf_config: " + str(hf_config)) - model_args = Qwen3ModelArgs( - vocab_size=getattr(hf_config, "vocab_size", 151936), - dim=getattr(hf_config, "hidden_size", 2048), - n_layers=getattr(hf_config, "num_hidden_layers", 4), - n_heads=getattr(hf_config, "num_attention_heads", 16), - n_kv_heads=getattr(hf_config, "num_key_value_heads", 2), - head_dim=getattr(hf_config, "head_dim", 128), - hidden_dim=getattr(hf_config, "intermediate_size", 11008), - norm_eps=getattr(hf_config, "rms_norm_eps", 1e-6), - max_seq_len=getattr(hf_config, "max_position_embeddings", 8192), - rope_theta=getattr(hf_config, "rope_theta", 1000000.0), - qk_norm=getattr(hf_config, "qk_norm", True), - ) - - print(f"In attention initialization, model args are : {model_args}") - - # Create TorchTitan model - self.model = Qwen3Model(model_args) - self.config = model_args - # self.parallel_context = parallel_context - - self._replice_with_vllm_paged_attention(model_args) - - ( - dp_size, - mp_size, - cp_size, - pp_size, - ep_size, - etp_size, - ) = self._process_parallelism_settings(vllm_config) - - # Build device mesh and apply parallelization - if mp_size > 1 or ep_size > 1: - self._build_device_mesh_and_parallelize( - dp_size, mp_size, cp_size, pp_size, ep_size, etp_size - ) - - def _replice_with_vllm_paged_attention(self, model_args): - # The `vllm.Attention` module handles QKV projection, RoPE, etc., and calls `inner_attention` - if not hasattr(self.model, "layers"): - raise AttributeError( - f"Model {type(self.model).__name__} must have .layers attribute" - ) - - for layer_name, layer in self.model.layers.items(): - if not hasattr(layer, "attention"): - raise ValueError(f"Layer {layer_name} must have .attention attribute") - - vllm_attn = VLLMPagedFlashAttention( - hidden_size=model_args.dim, - num_heads=model_args.n_heads, # 16 (8 when TP =2) - # NOTE(jianiw): Before feeding into inner_attention, the n_kv_heads has been replicated -> num_heads - num_kv_heads=model_args.n_heads, # 16 (8 When TP=2) - head_dim=model_args.head_dim, - causal=True, - ) - - layer.attention.inner_attention = vllm_attn - logger.info( - "Successfully replaced TorchTitan attention with vLLM PagedFlashAttention" - ) - - def _process_parallelism_settings( - self, vllm_config: VllmConfig, use_token_shuffling_moe: bool = False - ): - """ - Parse parallel config from vllm config - """ - world_size = ( - vllm_config.parallel_config.data_parallel_size - * vllm_config.parallel_config.tensor_parallel_size - ) - ep_size = ( - world_size if vllm_config.parallel_config.enable_expert_parallel else 1 - ) - etp_size = ( - 1 if vllm_config.parallel_config.enable_expert_parallel else world_size - ) - dp_size = vllm_config.parallel_config.data_parallel_size - - mp_size = 
vllm_config.parallel_config.tensor_parallel_size - cp_size = vllm_config.parallel_config.decode_context_parallel_size - pp_size = vllm_config.parallel_config.pipeline_parallel_size - self.pp_size = pp_size - - return dp_size, mp_size, cp_size, pp_size, ep_size, etp_size - - def _build_device_mesh_and_parallelize( - self, - dp_size: int, - mp_size: int, - cp_size: int, - pp_size: int, - ep_size: int, - etp_size: int, - ): - """ - Build device mesh in TorchTitan style and apply parallelization to the model. - - This follows the same approach as TorchTitan's ParallelDims.build_mesh() - and parallelize_qwen3() functions. - """ - import torch.distributed as dist - - # Get world size and validate - world_size = dist.get_world_size() - - # For now, assume dp_shard=1 (no data parallel sharding) - # In full implementation, you may need to calculate dp_replicate and dp_shard - dp_replicate = dp_size - dp_shard = 1 - - # Validate parallelism settings - assert dp_replicate * dp_shard * cp_size * mp_size * pp_size == world_size, ( - f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * " - f"cp({cp_size}) * tp({mp_size}) * pp({pp_size}) != WORLD_SIZE({world_size})" - ) - - # Build device mesh following TorchTitan's _build_mesh_without_ep pattern - # (assuming no EP for now) - dims = [] - names = [] - for d, name in zip( - [pp_size, dp_replicate, dp_shard, cp_size, mp_size], - ["pp", "dp_replicate", "dp_shard", "cp", "tp"], - ): - if d > 1: - dims.append(d) - names.append(name) - - logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - world_mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) - - logger.info(f"Build torchtitan device mesh: {world_mesh}") - - # Apply tensor parallelism if enabled - if mp_size > 1: - tp_mesh = world_mesh["tp"] - apply_non_moe_tp( - model=self.model, - tp_mesh=tp_mesh, - loss_parallel=False, # vLLM handles loss computation separately - enable_float8_tensorwise_tp=False, # Can be enabled if needed - enable_async_tp=False, # Can be enabled if needed - ) - logger.info(f"Applied Tensor Parallelism with TP={mp_size}") - - # Store the mesh for future use - self.world_mesh = world_mesh - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - """Convert input token IDs to embeddings. - - This is the vLLM-standard method name for embedding tokens. - """ - return self.model.tok_embeddings(input_ids) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - """Convert input token IDs to embeddings (deprecated vLLM interface).""" - return self.embed_input_ids(input_ids) - - def forward( - self, - input_ids: torch.Tensor | None = None, - positions: torch.Tensor | None = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs, - ) -> torch.Tensor: - """ - Forward pass with vLLM interface. - - Args: - input_ids: Token IDs from vLLM [total_tokens] (1D varlen format) - positions: Position indices from vLLM [total_tokens] (1D varlen format) - inputs_embeds: Pre-computed embeddings (optional, used by vLLM) - **kwargs: Additional vLLM kwargs - - Returns: - hidden_states: Final hidden states [total_tokens, hidden_size] - """ - # Handle inputs_embeds vs input_ids properly - if inputs_embeds is not None: - raise NotImplementedError( - "inputs_embeds is not yet supported by TorchTitan Qwen3. " - "The model expects token IDs and computes embeddings internally. " - "Please provide input_ids instead." 
- ) - - if input_ids is None: - raise ValueError("Either input_ids or inputs_embeds must be provided") - - # Convert vLLM interface to TorchTitan interface - # vLLM passes input_ids as [total_tokens] but TorchTitan expects [batch_size, seq_len] - # For now, reshape to [1, total_tokens] as a simple batch of 1 - # TODO: In future, use attn_metadata.seq_lens to properly reconstruct batch structure - tokens_2d = input_ids.unsqueeze(0) # [total_tokens] -> [1, total_tokens] - - # Store positions in forward context for attention layers - # Also convert positions to 2D format - if positions is not None: - positions_2d = positions.unsqueeze(0) # [total_tokens] -> [1, total_tokens] - store_positions_in_context(positions_2d) - - # Get embeddings from 2D tokens - h = self.model.tok_embeddings(tokens_2d) # [1, total_tokens, hidden_size] - - # Get RoPE cache - seqlen = h.shape[1] # seq_len dimension - rope_cache = self.model.rope_cache[:seqlen] - - # Pass through transformer layers - for layer in self.model.layers.values(): - h = layer(h, rope_cache, attention_masks=None) - - # Convert output format back to vLLM expectations - # vLLM expects hidden_states in [total_tokens, hidden_size] format - # TorchTitan returns [batch_size, seq_len, hidden_size], so we need to flatten - if h.dim() == 3: # [batch_size, seq_len, hidden_size] - batch_size, seq_len, hidden_size = h.shape - h = h.view(batch_size * seq_len, hidden_size) # [total_tokens, hidden_size] - - # TODO(jianiw): explicitly insert communication and return full tensor to vLLM Engine. To be checked. - if isinstance(h, DTensor): - h = h.full_tensor() - return h - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata=None, - ) -> torch.Tensor | None: - """Compute logits from hidden states. - - Returns: - Logits tensor, or None if TP rank > 0 - """ - # Apply final layer norm - h = self.model.norm(hidden_states) - - # Apply output projection to get logits - logits = self.model.output(h) - - # When using TP, only rank 0 returns logits - # vLLM expects None from other ranks - if isinstance(logits, DTensor): - # Convert DTensor to local tensor for vLLM - logits = logits.full_tensor() - - return logits - - def load_weights(self, weights_iter): - """ - Uses TorchTitan's Qwen3StateDictAdapter to map HF → TorchTitan naming, - then uses set_model_state_dict for proper distributed tensor handling. 
- - Args: - weights_iter: Iterator of (name, tensor) pairs from HF checkpoint - - Returns: - Set of loaded parameter names (for vLLM compatibility) - """ - from torch.distributed._tensor import DTensor, Replicate - from torch.distributed.checkpoint.state_dict import ( - set_model_state_dict, - StateDictOptions, - ) - from torchtitan.models.qwen3.model.state_dict_adapter import ( - Qwen3StateDictAdapter, - ) - - # Collect weights from iterator into a dict - hf_state_dict = {} - for name, tensor in weights_iter: - hf_state_dict[name] = tensor - - # Use TorchTitan's adapter to convert HF → TorchTitan format - adapter = Qwen3StateDictAdapter( - model_args=self.config, - hf_assets_path=None, # Not needed for from_hf conversion - ) - - torchtitan_state_dict = adapter.from_hf(hf_state_dict) - model_state_dict = {k: v for k, v in self.model.state_dict().items()} - - # Convert HF tensors to replicate DTensor if target is DTensor - for name, tensor in torchtitan_state_dict.items(): - if name in model_state_dict and isinstance(model_state_dict[name], DTensor): - # Get the device mesh from the target DTensor - target_dtensor = model_state_dict[name] - device_mesh = target_dtensor.device_mesh - # Convert to replicate DTensor - torchtitan_state_dict[name] = DTensor.from_local( - tensor.to(device_mesh.device_type), - device_mesh=device_mesh, - placements=[Replicate()], - ) - - # Use TorchTitan's distributed state dict loading - # This handles TP/PP sharding automatically - set_model_state_dict( - model=self.model, - model_state_dict=torchtitan_state_dict, - options=StateDictOptions( - strict=False, # Allow missing keys - ), - ) - - # manually patch the loaded - loaded_params = {f"model.{name}" for name in torchtitan_state_dict.keys()} - logger.info( - f"Loaded {len(loaded_params)} parameters from checkpoint using distributed-aware loading" - ) - - return loaded_params diff --git a/torchtitan/experiments/vllm/tests/test_flash_attn_varlen.py b/torchtitan/experiments/vllm/tests/test_flash_attn_varlen.py deleted file mode 100644 index 195adc5a07..0000000000 --- a/torchtitan/experiments/vllm/tests/test_flash_attn_varlen.py +++ /dev/null @@ -1,183 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Simple test script for flash_attn_varlen_func forward pass. -Tests the variable-length flash attention function from vLLM. 
-""" - -import torch -from vllm.attention.utils.fa_utils import flash_attn_varlen_func - - -def test_flash_attn_varlen_func(): - """Test flash_attn_varlen_func with simple dummy inputs.""" - - # Set device - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Using device: {device}") - - # Parameters - batch_size = 2 - num_heads = 8 - head_dim = 64 - seq_len_q = 128 - seq_len_k = 128 - - # Total tokens (for variable length) - total_tokens_q = batch_size * seq_len_q - total_tokens_k = batch_size * seq_len_k - - # Create input tensors - # Shape: (total_tokens, num_heads, head_dim) - q = torch.randn( - total_tokens_q, num_heads, head_dim, dtype=torch.float16, device=device - ) - k = torch.randn( - total_tokens_k, num_heads, head_dim, dtype=torch.float16, device=device - ) - v = torch.randn( - total_tokens_k, num_heads, head_dim, dtype=torch.float16, device=device - ) - - # Create cumulative sequence lengths - # cu_seqlens_q and cu_seqlens_k indicate the start position of each sequence - # For uniform sequences: [0, seq_len, 2*seq_len, ...] - cu_seqlens_q = torch.tensor( - [0, seq_len_q, 2 * seq_len_q], dtype=torch.int32, device=device - ) - cu_seqlens_k = torch.tensor( - [0, seq_len_k, 2 * seq_len_k], dtype=torch.int32, device=device - ) - - # Maximum sequence lengths - max_seqlen_q = seq_len_q - max_seqlen_k = seq_len_k - - # Softmax scale (typically 1/sqrt(head_dim)) - softmax_scale = 1.0 / (head_dim**0.5) - - print("\nInput shapes:") - print(f" q: {q.shape}") - print(f" k: {k.shape}") - print(f" v: {v.shape}") - print(f" cu_seqlens_q: {cu_seqlens_q}") - print(f" cu_seqlens_k: {cu_seqlens_k}") - print(f" max_seqlen_q: {max_seqlen_q}") - print(f" max_seqlen_k: {max_seqlen_k}") - print(f" softmax_scale: {softmax_scale}") - - try: - # Call flash_attn_varlen_func - print("\nCalling flash_attn_varlen_func...") - output = flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - softmax_scale=softmax_scale, - causal=False, # Set to True for causal attention - ) - - print(f"\nOutput shape: {output.shape}") - print(f"Output dtype: {output.dtype}") - print(f"Output device: {output.device}") - print("\nOutput statistics:") - print(f" Mean: {output.mean().item():.6f}") - print(f" Std: {output.std().item():.6f}") - print(f" Min: {output.min().item():.6f}") - print(f" Max: {output.max().item():.6f}") - - print("\n✓ Test passed successfully!") - return output - - except Exception as e: - print(f"\n✗ Test failed with error: {e}") - import traceback - - traceback.print_exc() - raise - - -def test_flash_attn_varlen_func_causal(): - """Test flash_attn_varlen_func with causal attention.""" - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"\n{'=' * 60}") - print("Testing with causal attention") - print(f"{'=' * 60}") - print(f"Using device: {device}") - - # Smaller test for causal - batch_size = 1 - num_heads = 4 - head_dim = 32 - seq_len = 64 - - total_tokens = batch_size * seq_len - - q = torch.randn( - total_tokens, num_heads, head_dim, dtype=torch.float16, device=device - ) - k = torch.randn( - total_tokens, num_heads, head_dim, dtype=torch.float16, device=device - ) - v = torch.randn( - total_tokens, num_heads, head_dim, dtype=torch.float16, device=device - ) - - cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device) - - print("\nInput shapes:") - print(f" q, k, v: {q.shape}") - print(f" cu_seqlens: {cu_seqlens}") - - try: 
- output = flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=seq_len, - max_seqlen_k=seq_len, - causal=True, - ) - - print(f"\nCausal output shape: {output.shape}") - print("Output statistics:") - print(f" Mean: {output.mean().item():.6f}") - print(f" Std: {output.std().item():.6f}") - - print("\n✓ Causal test passed successfully!") - return output - - except Exception as e: - print(f"\n✗ Causal test failed with error: {e}") - import traceback - - traceback.print_exc() - raise - - -if __name__ == "__main__": - print("=" * 60) - print("Testing flash_attn_varlen_func") - print("=" * 60) - - # Test 1: Basic forward pass - test_flash_attn_varlen_func() - - # Test 2: Causal attention - test_flash_attn_varlen_func_causal() - - print("\n" + "=" * 60) - print("All tests completed successfully! ✓") - print("=" * 60) diff --git a/torchtitan/experiments/vllm/tests/test_fwd_bwd_with_vllm_attention.py b/torchtitan/experiments/vllm/tests/test_fwd_bwd_with_vllm_attention.py deleted file mode 100644 index 3669b54280..0000000000 --- a/torchtitan/experiments/vllm/tests/test_fwd_bwd_with_vllm_attention.py +++ /dev/null @@ -1,537 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Test script to verify that patching TorchTitan Qwen3 model with VLLMCompatibleFlashAttention -still allows it to run with TorchTitan's training loop. - -This tests: -1. Model creation with patched attention -2. Forward pass with dummy data -3. Backward pass and gradient computation -4. Training step execution -5. Compatibility with TorchTitan's model protocol -""" - -import sys -from pathlib import Path - -import torch -import torch.nn as nn - -# Add TorchTitan to path -torchtitan_root = Path(__file__).parent.parent.parent.parent.parent -sys.path.insert(0, str(torchtitan_root)) - -from torchtitan.experiments.vllm.model.attention import VLLMCompatibleFlashAttention -from torchtitan.models.qwen3.model.args import Qwen3ModelArgs -from torchtitan.models.qwen3.model.model import Qwen3Model - - -def print_section(title): - """Print a formatted section header.""" - print("\n" + "=" * 70) - print(f" {title}") - print("=" * 70) - - -def patch_qwen3_attention(model: Qwen3Model, model_args: Qwen3ModelArgs) -> int: - """ - Patch all attention modules in Qwen3Model with VLLMCompatibleFlashAttention. 
- - Args: - model: Qwen3Model instance - model_args: Model configuration - - Returns: - Number of attention modules patched - """ - num_patched = 0 - - for layer_name, layer in model.layers.items(): - # Replace inner_attention with VLLMCompatibleFlashAttention - layer.attention.inner_attention = VLLMCompatibleFlashAttention( - hidden_size=model_args.dim, - num_heads=model_args.n_heads, - num_kv_heads=model_args.n_kv_heads, - head_dim=model_args.head_dim, - causal=True, - ) - num_patched += 1 - - return num_patched - - -def test_model_creation(): - """Test 1: Create Qwen3 model and patch with VLLMCompatibleFlashAttention.""" - print_section("Test 1: Model Creation and Patching") - - try: - # Create small test model - model_args = Qwen3ModelArgs( - dim=512, - n_layers=4, - n_heads=8, - n_kv_heads=4, # GQA - vocab_size=1000, - max_seq_len=512, - rope_theta=1000000.0, - hidden_dim=1024, - norm_eps=1e-6, - qk_norm=True, - ) - - print("Creating Qwen3Model with config:") - print(f" dim: {model_args.dim}") - print(f" n_layers: {model_args.n_layers}") - print(f" n_heads: {model_args.n_heads}") - print(f" n_kv_heads: {model_args.n_kv_heads}") - print(f" vocab_size: {model_args.vocab_size}") - - model = Qwen3Model(model_args) - print("✅ Model created successfully") - - # Patch attention modules - print("\nPatching attention modules...") - num_patched = patch_qwen3_attention(model, model_args) - print( - f"✅ Patched {num_patched} attention modules with VLLMCompatibleFlashAttention" - ) - - # Verify patch - first_layer = model.layers["0"] - assert isinstance( - first_layer.attention.inner_attention, VLLMCompatibleFlashAttention - ), "Attention module not patched correctly" - print("✅ Verified attention module type") - - return model, model_args - - except Exception as e: - print(f"❌ Test 1 failed: {e}") - import traceback - - traceback.print_exc() - return None, None - - -def test_forward_pass(model: Qwen3Model, model_args: Qwen3ModelArgs): - """Test 2: Run forward pass with dummy data.""" - print_section("Test 2: Forward Pass") - - if model is None: - print("⚠️ Skipping (model creation failed)") - return None - - try: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = model.to(device) - model.eval() - - print(f"Using device: {device}") - - # Create dummy input - batch_size = 2 - seq_len = 64 - tokens = torch.randint( - 0, model_args.vocab_size, (batch_size, seq_len), device=device - ) - - print(f"\nInput shape: {tokens.shape}") - - # Forward pass - with torch.no_grad(): - output = model(tokens) - - print(f"Output shape: {output.shape}") - print(f"Output dtype: {output.dtype}") - - # Verify output shape - expected_shape = (batch_size, seq_len, model_args.vocab_size) - assert ( - output.shape == expected_shape - ), f"Expected {expected_shape}, got {output.shape}" - - print("\nOutput statistics:") - print(f" Mean: {output.mean().item():.6f}") - print(f" Std: {output.std().item():.6f}") - print(f" Min: {output.min().item():.6f}") - print(f" Max: {output.max().item():.6f}") - - print("\n✅ Forward pass successful") - return output - - except Exception as e: - print(f"❌ Test 2 failed: {e}") - import traceback - - traceback.print_exc() - return None - - -def test_backward_pass(model: Qwen3Model, model_args: Qwen3ModelArgs): - """Test 3: Run backward pass and verify gradients.""" - print_section("Test 3: Backward Pass and Gradient Computation") - - if model is None: - print("⚠️ Skipping (model creation failed)") - return False - - try: - device = torch.device("cuda" if 
torch.cuda.is_available() else "cpu") - model = model.to(device) - model.train() # Enable training mode - - # Create dummy input and target - batch_size = 2 - seq_len = 64 - tokens = torch.randint( - 0, model_args.vocab_size, (batch_size, seq_len), device=device - ) - targets = torch.randint( - 0, model_args.vocab_size, (batch_size, seq_len), device=device - ) - - print(f"Input shape: {tokens.shape}") - print(f"Target shape: {targets.shape}") - - # Forward pass - output = model(tokens) - - # Compute loss - loss_fn = nn.CrossEntropyLoss() - loss = loss_fn(output.view(-1, model_args.vocab_size), targets.view(-1)) - - print(f"\nLoss: {loss.item():.6f}") - - # Backward pass - print("\nRunning backward pass...") - loss.backward() - - # Check gradients - grad_count = 0 - grad_norms = {} - - for name, param in model.named_parameters(): - if param.grad is not None: - grad_count += 1 - grad_norm = param.grad.norm().item() - if "attention" in name: - grad_norms[name] = grad_norm - - print(f"✅ Gradients computed for {grad_count} parameters") - - # Show some attention gradients - if grad_norms: - print("\nSample attention gradient norms:") - for name, norm in list(grad_norms.items())[:5]: - print(f" {name}: {norm:.6f}") - - # Verify gradients are non-zero - assert grad_count > 0, "No gradients computed" - - print("\n✅ Backward pass successful") - return True - - except Exception as e: - print(f"❌ Test 3 failed: {e}") - import traceback - - traceback.print_exc() - return False - - -def test_training_step(model: Qwen3Model, model_args: Qwen3ModelArgs): - """Test 4: Run a full training step with optimizer.""" - print_section("Test 4: Training Step with Optimizer") - - if model is None: - print("⚠️ Skipping (model creation failed)") - return False - - try: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = model.to(device) - model.train() - - # Create optimizer - optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) - loss_fn = nn.CrossEntropyLoss() - - print(f"Optimizer: {type(optimizer).__name__}") - print(f"Learning rate: {optimizer.param_groups[0]['lr']}") - - # Run multiple training steps - num_steps = 3 - losses = [] - - for step in range(num_steps): - # Create dummy data - batch_size = 2 - seq_len = 64 - tokens = torch.randint( - 0, model_args.vocab_size, (batch_size, seq_len), device=device - ) - targets = torch.randint( - 0, model_args.vocab_size, (batch_size, seq_len), device=device - ) - - # Zero gradients - optimizer.zero_grad() - - # Forward pass - output = model(tokens) - loss = loss_fn(output.view(-1, model_args.vocab_size), targets.view(-1)) - - # Backward pass - loss.backward() - - # Optimizer step - optimizer.step() - - losses.append(loss.item()) - print(f" Step {step + 1}/{num_steps}: loss = {loss.item():.6f}") - - print(f"\n✅ Completed {num_steps} training steps") - print(f"Loss values: {losses}") - - return True - - except Exception as e: - print(f"❌ Test 4 failed: {e}") - import traceback - - traceback.print_exc() - return False - - -def test_attention_shapes(model: Qwen3Model, model_args: Qwen3ModelArgs): - """Test 5: Verify attention input/output shapes in detail.""" - print_section("Test 5: Attention Shape Verification") - - if model is None: - print("⚠️ Skipping (model creation failed)") - return False - - try: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = model.to(device) - model.eval() - - # Hook to capture attention inputs/outputs - attention_info = {} - - def attention_hook(module, input_args, output): - 
"""Capture attention layer inputs and outputs.""" - # input_args is a tuple: (x, rope_cache, attention_masks) - x = input_args[0] - attention_info["input_shape"] = x.shape - attention_info["output_shape"] = output.shape - - # Register hook on first layer's attention - first_layer = model.layers["0"] - hook = first_layer.attention.register_forward_hook(attention_hook) - - # Run forward pass - batch_size = 2 - seq_len = 64 - tokens = torch.randint( - 0, model_args.vocab_size, (batch_size, seq_len), device=device - ) - - with torch.no_grad(): - _ = model(tokens) - - # Remove hook - hook.remove() - - # Verify shapes - print(f"Attention input shape: {attention_info['input_shape']}") - print(f"Attention output shape: {attention_info['output_shape']}") - - expected_input = (batch_size, seq_len, model_args.dim) - expected_output = (batch_size, seq_len, model_args.dim) - - assert ( - attention_info["input_shape"] == expected_input - ), f"Expected input {expected_input}, got {attention_info['input_shape']}" - assert ( - attention_info["output_shape"] == expected_output - ), f"Expected output {expected_output}, got {attention_info['output_shape']}" - - print("\n✅ Attention shapes verified") - return True - - except Exception as e: - print(f"❌ Test 5 failed: {e}") - import traceback - - traceback.print_exc() - return False - - -def test_gqa_correctness(model_args: Qwen3ModelArgs): - """Test 6: Verify GQA expansion works correctly.""" - print_section("Test 6: GQA (Grouped Query Attention) Verification") - - try: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # Create attention module directly - attn = VLLMCompatibleFlashAttention( - hidden_size=model_args.dim, - num_heads=model_args.n_heads, - num_kv_heads=model_args.n_kv_heads, - head_dim=model_args.head_dim, - causal=True, - ).to(device) - - print("Attention config:") - print(f" num_heads (Q): {model_args.n_heads}") - print(f" num_kv_heads (K/V): {model_args.n_kv_heads}") - print(f" head_dim: {model_args.head_dim}") - print(f" n_rep: {model_args.n_heads // model_args.n_kv_heads}") - - batch_size = 2 - seq_len = 32 - - # Test with unexpanded K/V (num_kv_heads) - print(f"\nTest 6a: K/V with num_kv_heads ({model_args.n_kv_heads})") - q = torch.randn( - batch_size, - model_args.n_heads, - seq_len, - model_args.head_dim, - device=device, - dtype=torch.float32, - ) - k_small = torch.randn( - batch_size, - model_args.n_kv_heads, - seq_len, - model_args.head_dim, - device=device, - dtype=torch.float32, - ) - v_small = torch.randn( - batch_size, - model_args.n_kv_heads, - seq_len, - model_args.head_dim, - device=device, - dtype=torch.float32, - ) - - with torch.no_grad(): - output_small = attn(q, k_small, v_small) - - print(f" Output shape: {output_small.shape}") - assert ( - output_small.shape == q.shape - ), f"Shape mismatch: {output_small.shape} vs {q.shape}" - print(" ✅ Unexpanded K/V works") - - # Test with expanded K/V (num_heads) - print(f"\nTest 6b: K/V already expanded to num_heads ({model_args.n_heads})") - k_large = k_small.repeat_interleave( - model_args.n_heads // model_args.n_kv_heads, dim=1 - ) - v_large = v_small.repeat_interleave( - model_args.n_heads // model_args.n_kv_heads, dim=1 - ) - - print(f" k_large shape: {k_large.shape}") - print(f" v_large shape: {v_large.shape}") - - with torch.no_grad(): - output_large = attn(q, k_large, v_large) - - print(f" Output shape: {output_large.shape}") - assert ( - output_large.shape == q.shape - ), f"Shape mismatch: {output_large.shape} vs {q.shape}" - print(" ✅ Expanded 
K/V works") - - print("\n✅ GQA verification successful") - return True - - except Exception as e: - print(f"❌ Test 6 failed: {e}") - import traceback - - traceback.print_exc() - return False - - -def main(): - """Run all tests.""" - print("=" * 70) - print(" TorchTitan + VLLMCompatibleFlashAttention Integration Test") - print("=" * 70) - - # Set random seed for reproducibility - torch.manual_seed(42) - if torch.cuda.is_available(): - torch.cuda.manual_seed(42) - - results = {} - - # Test 1: Model creation and patching - model, model_args = test_model_creation() - results["model_creation"] = model is not None - - if model is None: - print("\n❌ Cannot continue - model creation failed") - return 1 - - # Test 2: Forward pass - output = test_forward_pass(model, model_args) - results["forward_pass"] = output is not None - - # Test 3: Backward pass - results["backward_pass"] = test_backward_pass(model, model_args) - - # Test 4: Training step - results["training_step"] = test_training_step(model, model_args) - - # Test 5: Attention shapes - results["attention_shapes"] = test_attention_shapes(model, model_args) - - # Test 6: GQA verification - results["gqa_verification"] = test_gqa_correctness(model_args) - - # Summary - print_section("FINAL SUMMARY") - - print("\nTest Results:") - for test_name, passed in results.items(): - status = "✅ PASS" if passed else "❌ FAIL" - print(f" {test_name:20s}: {status}") - - all_passed = all(results.values()) - - if all_passed: - print("\n🎉 ALL TESTS PASSED!") - print("\n✅ VLLMCompatibleFlashAttention is compatible with TorchTitan!") - print("✅ Model can be trained with patched attention modules.") - print( - "\nYou can safely use this attention implementation in TorchTitan training." - ) - return 0 - else: - failed_tests = [name for name, passed in results.items() if not passed] - print(f"\n❌ {len(failed_tests)} TEST(S) FAILED:") - for test in failed_tests: - print(f" - {test}") - print( - "\nPlease fix the issues before using VLLMCompatibleFlashAttention in production." - ) - return 1 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/torchtitan/experiments/vllm/tests/test_vllm_attention.py b/torchtitan/experiments/vllm/tests/test_vllm_attention.py deleted file mode 100644 index aad6229b76..0000000000 --- a/torchtitan/experiments/vllm/tests/test_vllm_attention.py +++ /dev/null @@ -1,360 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Simple test script for vLLM's Attention layer. -Tests the high-level Attention module with KV cache management. 
-""" - -import torch -from vllm.attention import Attention -from vllm.config import CacheConfig - - -def test_vllm_attention_basic(): - """Test vLLM Attention layer with basic inputs (no KV cache).""" - print("=" * 70) - print("Test 1: Basic Attention (No KV Cache)") - print("=" * 70) - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Using device: {device}") - - # Model parameters - num_heads = 16 - num_kv_heads = 8 # GQA - head_dim = 128 - hidden_size = num_heads * head_dim - - # Create minimal vLLM config - # Note: In production, this comes from get_current_vllm_config() - cache_config = CacheConfig( - block_size=16, - gpu_memory_utilization=0.9, - swap_space=4, - cache_dtype="auto", - ) - - # Create Attention layer - print("\nCreating Attention layer:") - print(f" num_heads: {num_heads}") - print(f" num_kv_heads: {num_kv_heads}") - print(f" head_dim: {head_dim}") - print(f" scale: {1.0 / (head_dim ** 0.5)}") - - attn = Attention( - num_heads=num_heads, - head_size=head_dim, - scale=1.0 / (head_dim**0.5), - num_kv_heads=num_kv_heads, - cache_config=cache_config, - quant_config=None, - prefix="layers.0", - ) - attn = attn.to(device) - attn.eval() - - # Create dummy inputs - batch_size = 2 - seq_len = 128 - total_tokens = batch_size * seq_len - - # Format: [total_tokens, num_heads, head_dim] - q = torch.randn( - total_tokens, num_heads, head_dim, dtype=torch.float16, device=device - ) - k = torch.randn( - total_tokens, num_kv_heads, head_dim, dtype=torch.float16, device=device - ) - v = torch.randn( - total_tokens, num_kv_heads, head_dim, dtype=torch.float16, device=device - ) - - print("\nInput shapes:") - print(f" q: {q.shape}") - print(f" k: {k.shape}") - print(f" v: {v.shape}") - - # Forward pass - try: - with torch.no_grad(): - output = attn(q, k, v) - - print(f"\nOutput shape: {output.shape}") - print(f"Output dtype: {output.dtype}") - print(f"Output device: {output.device}") - print("\nOutput statistics:") - print(f" Mean: {output.mean().item():.6f}") - print(f" Std: {output.std().item():.6f}") - print(f" Min: {output.min().item():.6f}") - print(f" Max: {output.max().item():.6f}") - - print("\n✅ Test 1 passed!") - return True - - except Exception as e: - print(f"\n❌ Test 1 failed with error: {e}") - import traceback - - traceback.print_exc() - return False - - -def test_vllm_attention_gqa_expansion(): - """Test that GQA expansion works correctly.""" - print("\n" + "=" * 70) - print("Test 2: GQA Expansion Test") - print("=" * 70) - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - num_heads = 16 - num_kv_heads = 8 - head_dim = 128 - - cache_config = CacheConfig( - block_size=16, - gpu_memory_utilization=0.9, - swap_space=4, - cache_dtype="auto", - ) - - attn = Attention( - num_heads=num_heads, - head_size=head_dim, - scale=1.0 / (head_dim**0.5), - num_kv_heads=num_kv_heads, - cache_config=cache_config, - quant_config=None, - prefix="layers.1", - ) - attn = attn.to(device) - attn.eval() - - # Test with both unexpanded and expanded k/v - total_tokens = 64 - - q = torch.randn( - total_tokens, num_heads, head_dim, device=device, dtype=torch.float16 - ) - - print(f"\nTest 2a: K/V with num_kv_heads ({num_kv_heads})") - k_small = torch.randn( - total_tokens, num_kv_heads, head_dim, device=device, dtype=torch.float16 - ) - v_small = torch.randn( - total_tokens, num_kv_heads, head_dim, device=device, dtype=torch.float16 - ) - - try: - with torch.no_grad(): - output_small = attn(q, k_small, v_small) - print(f" Output shape: 
{output_small.shape}") - print(" ✅ GQA with num_kv_heads works!") - except Exception as e: - print(f" ❌ Failed: {e}") - return False - - print(f"\nTest 2b: K/V already expanded to num_heads ({num_heads})") - # Simulate what TorchTitan does - expand k/v to num_heads - k_large = k_small.repeat_interleave(num_heads // num_kv_heads, dim=1) - v_large = v_small.repeat_interleave(num_heads // num_kv_heads, dim=1) - print(f" k_large shape: {k_large.shape}") - print(f" v_large shape: {v_large.shape}") - - try: - with torch.no_grad(): - output_large = attn(q, k_large, v_large) - print(f" Output shape: {output_large.shape}") - print(" ✅ Already-expanded K/V works!") - except Exception as e: - print(f" ❌ Failed: {e}") - return False - - print("\n✅ Test 2 passed!") - return True - - -def test_vllm_attention_shapes(): - """Test various input shapes.""" - print("\n" + "=" * 70) - print("Test 3: Various Input Shapes") - print("=" * 70) - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - num_heads = 8 - num_kv_heads = 8 # MHA - head_dim = 64 - - cache_config = CacheConfig( - block_size=16, - gpu_memory_utilization=0.9, - swap_space=4, - cache_dtype="auto", - ) - - attn = Attention( - num_heads=num_heads, - head_size=head_dim, - scale=1.0 / (head_dim**0.5), - num_kv_heads=num_kv_heads, - cache_config=cache_config, - quant_config=None, - prefix="layers.2", - ) - attn = attn.to(device) - attn.eval() - - test_cases = [ - (1, "Single token"), - (32, "Small batch"), - (256, "Medium batch"), - (1024, "Large batch"), - ] - - for total_tokens, description in test_cases: - print( - f"\nTest 3.{test_cases.index((total_tokens, description)) + 1}: {description} ({total_tokens} tokens)" - ) - q = torch.randn( - total_tokens, num_heads, head_dim, device=device, dtype=torch.float16 - ) - k = torch.randn( - total_tokens, num_kv_heads, head_dim, device=device, dtype=torch.float16 - ) - v = torch.randn( - total_tokens, num_kv_heads, head_dim, device=device, dtype=torch.float16 - ) - - try: - with torch.no_grad(): - output = attn(q, k, v) - assert ( - output.shape[0] == total_tokens - ), f"Expected {total_tokens} tokens, got {output.shape[0]}" - print(f" ✅ Shape: {output.shape}") - except Exception as e: - print(f" ❌ Failed: {e}") - return False - - print("\n✅ Test 3 passed!") - return True - - -def test_integration_with_torchtitan_format(): - """Test integration with TorchTitan's tensor format.""" - print("\n" + "=" * 70) - print("Test 4: TorchTitan Format Integration") - print("=" * 70) - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - num_heads = 16 - num_kv_heads = 8 - head_dim = 128 - - cache_config = CacheConfig( - block_size=16, - gpu_memory_utilization=0.9, - swap_space=4, - cache_dtype="auto", - ) - - attn = Attention( - num_heads=num_heads, - head_size=head_dim, - scale=1.0 / (head_dim**0.5), - num_kv_heads=num_kv_heads, - cache_config=cache_config, - quant_config=None, - prefix="layers.3", - ) - attn = attn.to(device) - attn.eval() - - # Simulate TorchTitan format: [batch, num_heads, seq_len, head_dim] - batch_size = 2 - seq_len = 64 - - print( - f"\nTorchTitan input format: [batch={batch_size}, num_heads={num_heads}, seq_len={seq_len}, head_dim={head_dim}]" - ) - - q_tt = torch.randn( - batch_size, num_heads, seq_len, head_dim, device=device, dtype=torch.float16 - ) - k_tt = torch.randn( - batch_size, num_heads, seq_len, head_dim, device=device, dtype=torch.float16 - ) - v_tt = torch.randn( - batch_size, num_heads, seq_len, head_dim, device=device, 
dtype=torch.float16 - ) - - print(f" q_tt: {q_tt.shape}") - print(f" k_tt: {k_tt.shape}") - print(f" v_tt: {v_tt.shape}") - - # Convert to vLLM format: [total_tokens, num_heads, head_dim] - total_tokens = batch_size * seq_len - q_vllm = q_tt.transpose(1, 2).reshape(total_tokens, num_heads, head_dim) - k_vllm = k_tt.transpose(1, 2).reshape(total_tokens, num_heads, head_dim) - v_vllm = v_tt.transpose(1, 2).reshape(total_tokens, num_heads, head_dim) - - print( - f"\nvLLM input format: [total_tokens={total_tokens}, num_heads={num_heads}, head_dim={head_dim}]" - ) - print(f" q_vllm: {q_vllm.shape}") - print(f" k_vllm: {k_vllm.shape}") - print(f" v_vllm: {v_vllm.shape}") - - try: - with torch.no_grad(): - output_vllm = attn(q_vllm, k_vllm, v_vllm) - - print(f"\nvLLM output: {output_vllm.shape}") - - # Convert back to TorchTitan format - output_tt = output_vllm.reshape( - batch_size, seq_len, num_heads, head_dim - ).transpose(1, 2) - print(f"TorchTitan output: {output_tt.shape}") - - assert ( - output_tt.shape == q_tt.shape - ), f"Output shape mismatch: {output_tt.shape} vs {q_tt.shape}" - print("\n✅ Test 4 passed!") - return True - - except Exception as e: - print(f"\n❌ Test 4 failed: {e}") - import traceback - - traceback.print_exc() - return False - - -if __name__ == "__main__": - print("=" * 70) - print("Testing vLLM Attention Layer") - print("=" * 70) - - all_passed = True - - # Run all tests - all_passed &= test_vllm_attention_basic() - all_passed &= test_vllm_attention_gqa_expansion() - all_passed &= test_vllm_attention_shapes() - all_passed &= test_integration_with_torchtitan_format() - - # Summary - print("\n" + "=" * 70) - if all_passed: - print("🎉 All tests passed successfully!") - else: - print("❌ Some tests failed!") - print("=" * 70) From 4da69a510f586e77ff7e89eff989724d63015385 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Fri, 5 Dec 2025 14:14:34 -0800 Subject: [PATCH 11/19] add readme --- .../deterministic_vllm_rl/README.md | 30 ++++++++++++++----- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/torchtitan/experiments/deterministic_vllm_rl/README.md b/torchtitan/experiments/deterministic_vllm_rl/README.md index e667d9c147..0387528853 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/README.md +++ b/torchtitan/experiments/deterministic_vllm_rl/README.md @@ -214,7 +214,7 @@ This implementation uses the same kernels for both rollouts (vLLM) and training # Run vLLM inference with TorchTitan Qwen3 Model -This directory contains code to run TorchTitan model definition with vLLM inference engine (not batch-invariant yet, working in progress). +This directory contains code to run a single canonical model definition (TorchTitan model definition) with vLLM inference engine (not batch-invariant yet, working in progress). This work is inspired by https://github.com/vllm-project/vllm/pull/28685. ## Overview @@ -275,30 +275,46 @@ python torchtitan/experiments/deterministic_vllm_rl/infer.py --model torchtitan/ ``` +## TODOs +1. Rewrite attention part to use vllm.Attention() with backward as the only attention path. +2. Integrate with simple_rl.py to run end-to-end RL with one canonical model definition. +3. Leverage batch-invariant kernels into model definition. 
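+
+As a rough sketch of item 1 above, the direction is a single attention op whose forward runs the fused vLLM kernel and whose backward recomputes softmax attention in plain PyTorch (the pattern previously prototyped in `models/attention.py`). The example below is illustrative only: `F.scaled_dot_product_attention` stands in for the vLLM flash-attention kernel, sequences are assumed equal-length and causal, and GQA expansion is omitted.
+
+```python
+import torch
+import torch.nn.functional as F
+
+
+class AttnWithRecomputedBackward(torch.autograd.Function):
+    """Fused-style forward; backward recomputes the softmax instead of storing it."""
+
+    @staticmethod
+    def forward(ctx, q, k, v):
+        # Stand-in for the vLLM flash-attention kernel (causal, no paging).
+        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+        ctx.save_for_backward(q, k, v)
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        q, k, v = ctx.saved_tensors
+        scale = q.shape[-1] ** -0.5
+        seq_len = q.shape[-2]
+
+        # Recompute causal attention probabilities.
+        scores = torch.matmul(q, k.transpose(-2, -1)) * scale
+        mask = torch.triu(
+            torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device), diagonal=1
+        )
+        probs = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
+
+        # Standard softmax-attention backward.
+        grad_v = torch.matmul(probs.transpose(-2, -1), grad_out)
+        grad_probs = torch.matmul(grad_out, v.transpose(-2, -1))
+        grad_scores = probs * (grad_probs - (grad_probs * probs).sum(-1, keepdim=True))
+        grad_scores = grad_scores * scale
+
+        grad_q = torch.matmul(grad_scores, k)
+        grad_k = torch.matmul(grad_scores.transpose(-2, -1), q)
+        return grad_q, grad_k, grad_v
+
+
+if __name__ == "__main__":
+    # Tiny smoke test: [batch, heads, seq, head_dim]
+    q, k, v = (torch.randn(1, 2, 8, 16, requires_grad=True) for _ in range(3))
+    AttnWithRecomputedBackward.apply(q, k, v).sum().backward()
+    print(q.grad.shape, k.grad.shape, v.grad.shape)
+```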
+ + # Project Structure ``` deterministic_vllm_rl/ -├── README.md # Documentation +├── README.md # This documentation ├── __init__.py # Package initialization ├── batch_invariant_backward.py # Backward passes for vLLM ops ├── weights_vllm_compat.py # Weight conversion utilities ├── simple_rl.py # RL training loop +├── infer.py # vLLM Inference script to run with torchtitan model definition +├── register.py # Register torchtitan model to vLLM Engine ├── models/ │ ├── __init__.py │ ├── attention.py # VLLMCompatibleFlashAttention │ └── qwen3/ │ ├── __init__.py -│ └── model_batch_invariant.py # Batch-invariant Qwen3 model +│ ├── model_batch_invariant.py # Batch-invariant Qwen3 model +│ └── model_vllm_compat.py # vLLM-compatible Qwen3 model ├── weights/ │ ├── __init__.py │ ├── converter.py # Weight conversion script │ └── README.md # Weight conversion documentation -└── tests/ - ├── __init__.py - ├── test_batch_invariant_backward.py # Test backward passes - └── test_exact_determinism.py # Test determinism +├── tests/ +│ ├── __init__.py +│ ├── test_batch_invariant_backward.py # Test backward passes +│ └── test_exact_determinism.py # Test determinism +└── example_checkpoint/ # Example checkpoint directory (to be moved) + └── qwen3-0.6B/ + ├── config.json # Model configuration + ├── model.safetensors # Model weights + ├── tokenizer.json # Tokenizer + ├── vocab.json # Vocabulary + └── merges.txt # BPE merges ``` ## TODO From 3d5ad6b213ed57d7ee09ec5f28e2dd8ad2705b0f Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Sat, 6 Dec 2025 14:37:08 -0800 Subject: [PATCH 12/19] address comments --- .../deterministic_vllm_rl/README.md | 49 +--- .../deterministic_vllm_rl/models/attention.py | 227 +----------------- .../models/qwen3/model_vllm_compat.py | 68 +++++- 3 files changed, 87 insertions(+), 257 deletions(-) diff --git a/torchtitan/experiments/deterministic_vllm_rl/README.md b/torchtitan/experiments/deterministic_vllm_rl/README.md index 0387528853..741a47127b 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/README.md +++ b/torchtitan/experiments/deterministic_vllm_rl/README.md @@ -212,6 +212,12 @@ This implementation uses the same kernels for both rollouts (vLLM) and training 3. Requires NVIDIA GPUs with Flash Attention support +## TODO + +- `FlashAttnWithBackward` will need to become more composable and should not live exclusively within this directory. +- vLLM integration will need to become more generic with a provided Attention operator that is KV-cache compatible. +- vLLM parallelism will need to add generic parallelism initialization to support Monarch managed TP/DP. + # Run vLLM inference with TorchTitan Qwen3 Model This directory contains code to run a single canonical model definition (TorchTitan model definition) with vLLM inference engine (not batch-invariant yet, working in progress). @@ -275,54 +281,13 @@ python torchtitan/experiments/deterministic_vllm_rl/infer.py --model torchtitan/ ``` -## TODOs +## TODO 1. Rewrite attention part to use vllm.Attention() with backward as the only attention path. 2. Integrate with simple_rl.py to run end-to-end RL with one canonical model definition. 3. Leverage batch-invariant kernels into model definition. 
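+
+For item 2, a schematic of how the pieces could fit together is shown below. This is not the current `simple_rl.py`: the reward function, tokenizer handling, and REINFORCE-style update are placeholders, and real code would mask prompt tokens and batch the log-prob computation.
+
+```python
+import torch
+from vllm import LLM, SamplingParams
+
+
+def rl_step(llm: LLM, policy, tokenizer, optimizer, prompts, reward_fn):
+    """One rollout + update step; `policy` is any causal LM returning [b, seq, vocab] logits."""
+    # 1) Rollout with the vLLM engine (same canonical model definition).
+    outputs = llm.generate(prompts, SamplingParams(temperature=0.8, max_tokens=64))
+    completions = [o.outputs[0].text for o in outputs]
+
+    # 2) Score completions with a task-specific reward (placeholder callable).
+    rewards = [reward_fn(p, c) for p, c in zip(prompts, completions)]
+
+    # 3) Recompute log-probs with the training-side model and take a policy-gradient step.
+    device = next(policy.parameters()).device
+    optimizer.zero_grad()
+    loss = torch.zeros((), device=device)
+    for prompt, completion, reward in zip(prompts, completions, rewards):
+        ids = torch.tensor([tokenizer.encode(prompt + completion)], device=device)
+        logp = torch.log_softmax(policy(ids)[:, :-1], dim=-1)
+        token_logp = logp.gather(-1, ids[:, 1:].unsqueeze(-1)).squeeze(-1).sum()
+        loss = loss - reward * token_logp  # full-sequence log-prob for brevity
+    (loss / len(prompts)).backward()
+    optimizer.step()
+    return rewards
+```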
-# Project Structure - -``` -deterministic_vllm_rl/ -├── README.md # This documentation -├── __init__.py # Package initialization -├── batch_invariant_backward.py # Backward passes for vLLM ops -├── weights_vllm_compat.py # Weight conversion utilities -├── simple_rl.py # RL training loop -├── infer.py # vLLM Inference script to run with torchtitan model definition -├── register.py # Register torchtitan model to vLLM Engine -├── models/ -│ ├── __init__.py -│ ├── attention.py # VLLMCompatibleFlashAttention -│ └── qwen3/ -│ ├── __init__.py -│ ├── model_batch_invariant.py # Batch-invariant Qwen3 model -│ └── model_vllm_compat.py # vLLM-compatible Qwen3 model -├── weights/ -│ ├── __init__.py -│ ├── converter.py # Weight conversion script -│ └── README.md # Weight conversion documentation -├── tests/ -│ ├── __init__.py -│ ├── test_batch_invariant_backward.py # Test backward passes -│ └── test_exact_determinism.py # Test determinism -└── example_checkpoint/ # Example checkpoint directory (to be moved) - └── qwen3-0.6B/ - ├── config.json # Model configuration - ├── model.safetensors # Model weights - ├── tokenizer.json # Tokenizer - ├── vocab.json # Vocabulary - └── merges.txt # BPE merges -``` - -## TODO - -- `FlashAttnWithBackward` will need to become more composable and should not live exclusively within this directory. -- vLLM integration will need to become more generic with a provided Attention operator that is KV-cache compatible. -- vLLM parallelism will need to add generic parallelism initialization to support Monarch managed TP/DP. - ## Contributing This experiment is part of TorchTitan. To contribute: diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/attention.py b/torchtitan/experiments/deterministic_vllm_rl/models/attention.py index 00f1245230..49ce1ccd4a 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/attention.py +++ b/torchtitan/experiments/deterministic_vllm_rl/models/attention.py @@ -13,8 +13,7 @@ import torch from vllm.attention.layer import Attention -from vllm.attention.utils.fa_utils import flash_attn_varlen_func, get_flash_attn_version -from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant +from vllm.attention.utils.fa_utils import flash_attn_varlen_func class VLLMCompatibleFlashAttention(torch.nn.Module): @@ -348,16 +347,13 @@ def forward( scale: float | None = None, ) -> torch.Tensor: """ - Forward with dual-mode behavior: - - Inference (model.training=False): Use vLLM's Attention layer (KV cache, etc.) - - Training (model.training=True): Use flash_attn_varlen_func with custom backward - - vLLM's Attention used flash_attn_varlen_func kernel by default. + Forward pass using vLLM's Attention layer for inference. 
Args: q: Query tensor [batch, num_heads, seq_len, head_dim] k: Key tensor [batch, num_kv_heads, seq_len, head_dim] v: Value tensor [batch, num_kv_heads, seq_len, head_dim] - scale: Optional attention scale override + scale: Optional attention scale override (unused, vLLM uses internal scale) Returns: output: [batch, num_heads, seq_len, head_dim] @@ -365,220 +361,23 @@ def forward( # Input is (batch, num_heads, seq_len, head_dim) batch_size, num_heads, seq_len, head_dim = q.shape - # INFERENCE MODE: Use vLLM's Attention layer - if not self.training and self.vllm_attn is not None: - # Transpose to (batch, seq_len, num_heads, head_dim) for vLLM - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - # # Flatten to (total_tokens, num_heads, head_dim) - # NOTE(jianiw): vllm_attention can also take input as shape (batch, seq_len, num_heads, head_dim) and do internally - - # q_varlen = q.reshape(-1, num_heads, head_dim) - # k_varlen = k.reshape(-1, k.shape[2], head_dim) # k.shape[2] = num_kv_head - # v_varlen = v.reshape(-1, v.shape[2], head_dim) - - try: - # Use vLLM's Attention layer (requires forward context) - output_varlen = self.vllm_attn(q, k, v) - - # print(f"[jianiw] vllm_attn output is: {output_varlen}") - # Reshape back to batch format - output = output_varlen.view(batch_size, seq_len, num_heads, head_dim) - - # Transpose back to TorchTitan format - output = output.transpose(1, 2) - - return output - except (AssertionError, RuntimeError) as e: - # Forward context not available, fall through to training mode - print(f"Error when calling self.vllm_attn during Inference, {str(e)}") - raise + if self.vllm_attn is None: + raise RuntimeError( + "vLLM attention not initialized. This module requires vLLM context." + ) - # TRAINING MODE: Use flash_attn_varlen_func with custom backward # Transpose to (batch, seq_len, num_heads, head_dim) for vLLM q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) - # After to_local(), use actual tensor shapes (TP may have sharded heads) - # Shape: (batch, seq_len, num_heads_local, head_dim) - _, _, num_heads_local, _ = q.shape - _, _, num_kv_heads_local, _ = k.shape - - # Convert to varlen format for vLLM: flatten batch and sequence - # (batch, seq_len, num_heads, head_dim) -> (total_tokens, num_heads, head_dim) - q_varlen = q.reshape(-1, num_heads_local, head_dim) - k_varlen = k.reshape(-1, num_kv_heads_local, head_dim) - v_varlen = v.reshape(-1, num_kv_heads_local, head_dim) - - # Use custom autograd function with flash_attn_varlen_func forward and manual backward - class VLLMForwardCustomBackward(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, v, scale, batch_size, seq_len, causal, fa_version): - # Flash Attention only supports fp16 and bf16 - # Store original dtype for conversion back - original_dtype = q.dtype - - # Convert to bf16 if not already fp16/bf16 - if original_dtype not in [torch.float16, torch.bfloat16]: - target_dtype = ( - torch.bfloat16 - if torch.cuda.is_bf16_supported() - else torch.float16 - ) - q = q.to(target_dtype) - k = k.to(target_dtype) - v = v.to(target_dtype) - else: - target_dtype = original_dtype - - # Use flash_attn_varlen_func directly for fast forward pass - # This is the SAME kernel vLLM uses internally! 
- # TODO(jianiw): Need to double-check - cu_seqlens_q = torch.arange( - 0, - (batch_size + 1) * seq_len, - seq_len, - dtype=torch.int32, - device=q.device, - ) - - output = flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_q, - max_seqlen_q=seq_len, - max_seqlen_k=seq_len, - softmax_scale=scale, - causal=causal, - num_splits=1 if vllm_is_batch_invariant() else 0, - fa_version=fa_version, - ) - - # Convert output back to original dtype if needed - if original_dtype not in [torch.float16, torch.bfloat16]: - output = output.to(original_dtype) - - # Save for backward - ctx.save_for_backward(q, k, v, output) - ctx.scale = scale - ctx.seq_len = seq_len - ctx.batch_size = batch_size - ctx.causal = causal - ctx.original_dtype = original_dtype - - return output - - @staticmethod - def backward(ctx, grad_output): - q, k, v, output = ctx.saved_tensors - scale = ctx.scale - seq_len = ctx.seq_len - batch_size = ctx.batch_size - causal = ctx.causal - original_dtype = ctx.original_dtype - - # Convert grad_output to match saved tensor dtype - if grad_output.dtype != q.dtype: - grad_output = grad_output.to(q.dtype) - - # Reshape from varlen to batch format - total_tokens = q.shape[0] - num_heads = q.shape[1] - head_dim = q.shape[2] - - q_batch = q.reshape(batch_size, seq_len, num_heads, head_dim) - k_batch = k.reshape(batch_size, seq_len, num_heads, head_dim) - v_batch = v.reshape(batch_size, seq_len, num_heads, head_dim) - grad_out_batch = grad_output.reshape( - batch_size, seq_len, num_heads, head_dim - ) - - # Transpose to (batch, num_heads, seq_len, head_dim) - q_t = q_batch.transpose(1, 2) - k_t = k_batch.transpose(1, 2) - v_t = v_batch.transpose(1, 2) - grad_out_t = grad_out_batch.transpose(1, 2) - - # Compute attention scores: QK^T - scores = torch.matmul(q_t, k_t.transpose(-2, -1)) * scale - - # Apply causal mask if needed - if causal: - causal_mask = torch.triu( - torch.ones(seq_len, seq_len, device=q.device, dtype=torch.bool), - diagonal=1, - ) - scores = scores.masked_fill(causal_mask, float("-inf")) - - # Softmax - attn_weights = torch.nn.functional.softmax(scores, dim=-1) - - # Backward through attention - # grad_v = attn_weights^T @ grad_out - grad_v_t = torch.matmul(attn_weights.transpose(-2, -1), grad_out_t) - - # grad_attn_weights = grad_out @ v^T - grad_attn_weights = torch.matmul(grad_out_t, v_t.transpose(-2, -1)) - - # Backward through softmax - sum_term = (grad_attn_weights * attn_weights).sum(dim=-1, keepdim=True) - grad_scores = attn_weights * (grad_attn_weights - sum_term) - - # Apply causal mask to gradients - if causal: - grad_scores = grad_scores.masked_fill(causal_mask, 0.0) - - # Backward through scale - grad_scores = grad_scores * scale - - # grad_q = grad_scores @ K - grad_q_t = torch.matmul(grad_scores, k_t) + # NOTE(jianiw): vllm_attention can take input as shape (batch, seq_len, num_heads, head_dim) + # and handle the reshaping internally + # Use vLLM's Attention layer for paged attention and + output_varlen = self.vllm_attn(q, k, v) - # grad_k = grad_scores^T @ Q - grad_k_t = torch.matmul(grad_scores.transpose(-2, -1), q_t) - - # Transpose back and reshape to varlen format - grad_q = grad_q_t.transpose(1, 2).reshape( - total_tokens, num_heads, head_dim - ) - grad_k = grad_k_t.transpose(1, 2).reshape( - total_tokens, num_heads, head_dim - ) - grad_v = grad_v_t.transpose(1, 2).reshape( - total_tokens, num_heads, head_dim - ) - - # Convert gradients back to original dtype if needed - if original_dtype not in [torch.float16, 
torch.bfloat16]: - grad_q = grad_q.to(original_dtype) - grad_k = grad_k.to(original_dtype) - grad_v = grad_v.to(original_dtype) - - return grad_q, grad_k, grad_v, None, None, None, None, None - - # Get flash attention version - fa_version = get_flash_attn_version() - - # Apply custom autograd function - output_varlen = VLLMForwardCustomBackward.apply( - q_varlen, - k_varlen, - v_varlen, - scale or self.scale, - batch_size, - seq_len, - self.causal, - fa_version, - ) - - # Convert back to batch format - # (total_tokens, num_heads, head_dim) -> (batch, seq_len, num_heads, head_dim) - output = output_varlen.reshape(batch_size, seq_len, num_heads, head_dim) + # Reshape back to batch format + output = output_varlen.view(batch_size, seq_len, num_heads, head_dim) # Transpose back to TorchTitan format: (batch, num_heads, seq_len, head_dim) output = output.transpose(1, 2) diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py b/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py index 10b4841c31..d77dcf4695 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py +++ b/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py @@ -206,6 +206,65 @@ def _build_device_mesh_and_parallelize( # Store the mesh for future use self.world_mesh = world_mesh + def _extend_rope_cache_if_needed( + self, rope_cache: torch.Tensor, max_position: int + ) -> torch.Tensor: + """ + Extend rope_cache if needed (e.g., during vLLM profiling with 2x max_seq_len). + + Args: + rope_cache: Current RoPE cache tensor + max_position: Maximum position index needed + + Returns: + Extended rope_cache if needed, otherwise original rope_cache + """ + required_len = max_position + 1 + + # No extension needed + if required_len <= rope_cache.shape[0]: + return rope_cache + + # Handle DTensor case - convert to local tensor first + from torch.distributed._tensor import DTensor, Replicate + + is_dtensor = isinstance(rope_cache, DTensor) + if is_dtensor: + # Get the local tensor and device mesh + device_mesh = rope_cache.device_mesh + local_rope_cache = rope_cache.to_local() + device = local_rope_cache.device + dtype = local_rope_cache.dtype + else: + device = rope_cache.device + dtype = rope_cache.dtype + + # Import precompute_rope_cache from the model module + from torchtitan.models.qwen3.model.model import precompute_rope_cache + + # Precompute additional RoPE frequencies on-the-fly + # TODO: This also has strong assumptions on config names + rope_theta = self.config.rope_theta + head_dim = self.config.head_dim + extended_cache = precompute_rope_cache( + dim=head_dim, + max_seq_len=required_len, + base=rope_theta, + ) + extended_cache = extended_cache.to(device=device, dtype=dtype) + + # If original was DTensor, convert extended cache to DTensor too + if is_dtensor: + rope_cache = DTensor.from_local( + extended_cache, + device_mesh=device_mesh, + placements=[Replicate()], + ) + else: + rope_cache = extended_cache + + return rope_cache + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: """Convert input token IDs to embeddings. 
@@ -256,14 +315,21 @@ def forward( # Store positions in forward context for attention layers # Also convert positions to 2D format # TODO: The position id information is not properly used yet + # TODO: The batch_size might not be 1, need double check if positions is not None: positions_2d = positions.unsqueeze(0) # [total_tokens] -> [1, total_tokens] # Get embeddings from 2D tokens h = self.model.tok_embeddings(tokens_2d) # [1, total_tokens, hidden_size] + # Extend RoPE cache if needed (vLLM profiling may use 2x max_seq_len) + max_position = positions.max().item() if positions is not None else 0 + rope_cache = self._extend_rope_cache_if_needed( + self.model.rope_cache, max_position + ) + # Get RoPE cache indexed by positions - rope_cache = self.model.rope_cache[positions] + rope_cache = rope_cache[positions] # Pass through transformer layers for layer in self.model.layers.values(): From d3e93a0a5e86bb996b6a97a3d63f4c0b625ea19f Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 8 Dec 2025 16:11:40 -0800 Subject: [PATCH 13/19] restore unnecessary changes --- torchtitan/models/qwen3/infra/parallelize.py | 17 ++++------------- torchtitan/models/qwen3/model/model.py | 10 +++++++--- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 1dd54657c2..c2eaed8de6 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -17,7 +17,6 @@ ColwiseParallel, parallelize_module, PrepareModuleInput, - PrepareModuleOutput, RowwiseParallel, SequenceParallel, ) @@ -257,19 +256,11 @@ def apply_non_moe_tp( input_layouts=(Shard(1), Replicate(), None, None), desired_input_layouts=(Replicate(), Replicate(), None, None), ), - "attention.wq": colwise_parallel( - use_local_output=False - ), # NOTE(jianiw): Try using Tensor instead of DTensor + "attention.wq": colwise_parallel(use_local_output=False), "attention.wk": colwise_parallel(use_local_output=False), - "attention.wv": colwise_parallel(use_local_output=True), - "attention.q_norm": SequenceParallel(sequence_dim=2, use_local_output=True), - "attention.k_norm": SequenceParallel(sequence_dim=2, use_local_output=True), - # NOTE(jianiw): manually convert the output to be Shard(1) - "attention.inner_attention": PrepareModuleOutput( - output_layouts=(Shard(1),), - desired_output_layouts=(Shard(1),), - use_local_output=False, - ), # NOTE(jianiw): inner_attention output shape (batch, num_heads, seq_len, head_dim) + "attention.wv": colwise_parallel(use_local_output=False), + "attention.q_norm": SequenceParallel(sequence_dim=2), + "attention.k_norm": SequenceParallel(sequence_dim=2), "attention.wo": rowwise_parallel(output_layouts=Shard(1)), "ffn_norm": SequenceParallel(), } diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index 378f203bf6..0683b4c42d 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -61,7 +61,13 @@ def reshape_for_broadcast( rope_cache: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None ) -> torch.Tensor: """ - Reshapes the RoPE frequency tensor to be broadcastable with the input tensor. + Reshape frequency tensor (represented by cos, sin) for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. 
+ + The input freqs_cis tensor is assumed to be of shape (max_seqlen, head_dim * 2), + and the first seqlen elements will be sliced, but dim must match x. Args: rope_cache (torch.Tensor): RoPE tensor (cos and sin) to be reshaped. @@ -258,8 +264,6 @@ def forward( keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) - # NOTE(jianiw) - xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) From 8154b74d506c86ca72b7253c217ee678a6037477 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 8 Dec 2025 16:15:04 -0800 Subject: [PATCH 14/19] refactor --- .../model_vllm_compat.py => base_wrapper.py} | 331 +++++++++--------- .../models/qwen3/__init__.py | 3 +- .../deterministic_vllm_rl/register.py | 48 ++- 3 files changed, 213 insertions(+), 169 deletions(-) rename torchtitan/experiments/deterministic_vllm_rl/models/{qwen3/model_vllm_compat.py => base_wrapper.py} (51%) diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py b/torchtitan/experiments/deterministic_vllm_rl/models/base_wrapper.py similarity index 51% rename from torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py rename to torchtitan/experiments/deterministic_vllm_rl/models/base_wrapper.py index d77dcf4695..79e6c1fc47 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py +++ b/torchtitan/experiments/deterministic_vllm_rl/models/base_wrapper.py @@ -4,8 +4,21 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import torch +""" +Generic vLLM wrapper for TorchTitan models. + +This module provides TorchTitanVLLMWrapper, a generic base class that makes +any TorchTitan model compatible with vLLM by: +1. Accepting 4 pluggable model-specific components +2. Replacing attention with vLLM's paged attention +3. Setting up parallelization (TP/EP) +4. Loading weights from HuggingFace checkpoints +""" +import functools +from typing import Callable + +import torch import torch.nn as nn from torch.distributed.device_mesh import init_device_mesh from torch.distributed.tensor import DTensor @@ -16,21 +29,32 @@ from torchtitan.experiments.deterministic_vllm_rl.models.attention import ( VLLMPagedFlashAttention, ) -from torchtitan.models.qwen3.infra.parallelize import apply_non_moe_tp +from torchtitan.protocols.model import BaseModelArgs +from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter from torchtitan.tools.utils import device_type logger = init_logger(__name__) -class TorchTitanQwen3ForCausalLM(nn.Module): +class TorchTitanVLLMWrapper(nn.Module): """ - vLLM-compatible wrapper for TorchTitan's Qwen3 model. - - This class integrates TorchTitan's Qwen3Model with vLLM by: - 1. Importing TorchTitan's model architecture - 2. Replacing attention with vLLM's Attention with PagedAttention and kv cache capability. - 3. Implementing the vLLM model interface + Generic vLLM-compatible wrapper for TorchTitan models. + + This base class integrates any TorchTitan model with vLLM by accepting + 4 pluggable model-specific components: + + 1. model_cls: The TorchTitan model class (e.g., Qwen3Model) + 2. model_args_cls: The model args class (e.g., Qwen3ModelArgs) + 3. state_dict_adapter: State dict adapter for loading HF weights + 4. 
parallelize_fn: Function to apply tensor parallelism + + The wrapper handles: + - HF config → TorchTitan model args mapping + - Attention replacement with vLLM paged attention + - Tensor parallelism setup + - Weight loading from HF checkpoints + - vLLM forward/compute_logits interface """ is_text_generation_model = True # Required for vLLM runner validation @@ -40,44 +64,43 @@ class TorchTitanQwen3ForCausalLM(nn.Module): def __init__( self, *, + model_cls: type, + model_args_cls: BaseModelArgs, + state_dict_adapter: BaseStateDictAdapter, + parallelize_fn: Callable, + rope_cache_compute_fn: Callable | None = None, vllm_config: VllmConfig, - prefix: str = "", # This is required for vLLM interface + prefix: str = "", ): super().__init__() - # vLLM config is required assert vllm_config is not None, "vllm_config is required" - # Import TorchTitan's Qwen3 model (deferred import to avoid CUDA init issues) - from torchtitan.models.qwen3.model.args import Qwen3ModelArgs - from torchtitan.models.qwen3.model.model import Qwen3Model + # Store components + self.model_cls = model_cls + self.model_args_cls = model_args_cls + self.state_dict_adapter = state_dict_adapter + self.parallelize_fn = parallelize_fn - # Map HuggingFace config to TorchTitan ModelArgs - logger.info("vllm config: " + str(vllm_config.__class__)) + # Map HF config to TorchTitan ModelArgs hf_config = vllm_config.model_config.hf_config - logger.info("hf_config: " + str(hf_config)) - model_args = Qwen3ModelArgs( - vocab_size=getattr(hf_config, "vocab_size", 151936), - dim=getattr(hf_config, "hidden_size", 2048), - n_layers=getattr(hf_config, "num_hidden_layers", 4), - n_heads=getattr(hf_config, "num_attention_heads", 16), - n_kv_heads=getattr(hf_config, "num_key_value_heads", 2), - head_dim=getattr(hf_config, "head_dim", 128), - hidden_dim=getattr(hf_config, "intermediate_size", 11008), - norm_eps=getattr(hf_config, "rms_norm_eps", 1e-6), - max_seq_len=getattr(hf_config, "max_position_embeddings", 8192), - rope_theta=getattr(hf_config, "rope_theta", 1000000.0), - qk_norm=getattr(hf_config, "qk_norm", True), - ) - - print(f"In attention initialization, model args are : {model_args}") + logger.info(f"Mapping HF config to {model_args_cls.__name__}") + model_args = self._map_hf_config_to_model_args(hf_config, model_args_cls) # Create TorchTitan model - self.model = Qwen3Model(model_args) + logger.info(f"Creating {model_cls.__name__} with config: {model_args}") + self.model = model_cls(model_args) self.config = model_args - self._replice_with_vllm_paged_attention(model_args) + # NOTE: Here's assumptions of rope_cache_compute_fn function signature + self.rope_cache_extension_fn = functools.partial( + rope_cache_compute_fn, dim=self.config.head_dim, base=self.config.rope_theta + ) + + # Replace attention with vLLM paged attention + self._replace_with_vllm_paged_attention(model_args) + # Setup parallelization ( dp_size, mp_size, @@ -87,14 +110,49 @@ def __init__( etp_size, ) = self._process_parallelism_settings(vllm_config) - # Build device mesh and apply parallelization if mp_size > 1 or ep_size > 1: self._build_device_mesh_and_parallelize( dp_size, mp_size, cp_size, pp_size, ep_size, etp_size ) - def _replice_with_vllm_paged_attention(self, model_args): - # The `vllm.Attention` module handles QKV projection, RoPE, etc., and calls `inner_attention` + def _map_hf_config_to_model_args(self, hf_config, model_args_cls): + """ + Map HuggingFace config to TorchTitan ModelArgs. + + Default implementation that handles common model args fields. 
+ Override in subclass if custom mapping is needed. + """ + # Maps TorchTitan parameter name to HF config attribute name + mapping = { + "vocab_size": "vocab_size", + "dim": "hidden_size", + "n_layers": "num_hidden_layers", + "n_heads": "num_attention_heads", + "n_kv_heads": "num_key_value_heads", + "head_dim": "head_dim", + "hidden_dim": "intermediate_size", + "norm_eps": "rms_norm_eps", + "max_seq_len": "max_position_embeddings", + "rope_theta": "rope_theta", + "qk_norm": "qk_norm", + } + + # Build kwargs for model args from mapping + kwargs = {} + for torchtitan_param, hf_attr in mapping.items(): + # Try to get value from HF config + if hasattr(hf_config, hf_attr): + kwargs[torchtitan_param] = getattr(hf_config, hf_attr) + + return model_args_cls(**kwargs) + + def _replace_with_vllm_paged_attention(self, model_args): + """ + Replace TorchTitan attention with vLLM paged attention. + + Assumes model has .layers dict with .attention.inner_attention structure. + Override in subclass if different structure. + """ if not hasattr(self.model, "layers"): raise AttributeError( f"Model {type(self.model).__name__} must have .layers attribute" @@ -104,26 +162,24 @@ def _replice_with_vllm_paged_attention(self, model_args): if not hasattr(layer, "attention"): raise ValueError(f"Layer {layer_name} must have .attention attribute") + # Create vLLM paged attention vllm_attn = VLLMPagedFlashAttention( hidden_size=model_args.dim, - num_heads=model_args.n_heads, # 16 (8 when TP =2) - # NOTE(jianiw): Before feeding into inner_attention, the n_kv_heads has been replicated -> num_heads - num_kv_heads=model_args.n_heads, # 16 (8 When TP=2) + num_heads=model_args.n_heads, + num_kv_heads=model_args.n_heads, # Use n_heads (already replicated) head_dim=model_args.head_dim, causal=True, ) + # Replace inner attention layer.attention.inner_attention = vllm_attn + logger.info( "Successfully replaced TorchTitan attention with vLLM PagedFlashAttention" ) - def _process_parallelism_settings( - self, vllm_config: VllmConfig, use_token_shuffling_moe: bool = False - ): - """ - Parse parallel config from vllm config - """ + def _process_parallelism_settings(self, vllm_config: VllmConfig): + """Parse parallel config from vLLM config.""" world_size = ( vllm_config.parallel_config.data_parallel_size * vllm_config.parallel_config.tensor_parallel_size @@ -135,7 +191,6 @@ def _process_parallelism_settings( 1 if vllm_config.parallel_config.enable_expert_parallel else world_size ) dp_size = vllm_config.parallel_config.data_parallel_size - mp_size = vllm_config.parallel_config.tensor_parallel_size cp_size = vllm_config.parallel_config.decode_context_parallel_size pp_size = vllm_config.parallel_config.pipeline_parallel_size @@ -153,29 +208,20 @@ def _build_device_mesh_and_parallelize( etp_size: int, ): """ - Build device mesh in TorchTitan style and apply parallelization to the model. - - This follows the same approach as TorchTitan's ParallelDims.build_mesh() - and parallelize_qwen3() functions. + Build device mesh and apply parallelization using the provided parallelize_fn. 
""" import torch.distributed as dist - # Get world size and validate world_size = dist.get_world_size() - - # For now, assume dp_shard=1 (no data parallel sharding) - # In full implementation, you may need to calculate dp_replicate and dp_shard dp_replicate = dp_size dp_shard = 1 - # Validate parallelism settings assert dp_replicate * dp_shard * cp_size * mp_size * pp_size == world_size, ( f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * " f"cp({cp_size}) * tp({mp_size}) * pp({pp_size}) != WORLD_SIZE({world_size})" ) - # Build device mesh following TorchTitan's _build_mesh_without_ep pattern - # (assuming no EP for now) + # Build device mesh dims = [] names = [] for d, name in zip( @@ -189,48 +235,55 @@ def _build_device_mesh_and_parallelize( logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") world_mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) - logger.info(f"Build torchtitan device mesh: {world_mesh}") - - # Apply tensor parallelism if enabled + # Apply tensor parallelism using provided function if mp_size > 1: tp_mesh = world_mesh["tp"] - apply_non_moe_tp( + self.parallelize_fn( model=self.model, tp_mesh=tp_mesh, - loss_parallel=False, # vLLM handles loss computation separately - enable_float8_tensorwise_tp=False, # Can be enabled if needed - enable_async_tp=False, # Can be enabled if needed + loss_parallel=False, + enable_float8_tensorwise_tp=False, + enable_async_tp=False, ) logger.info(f"Applied Tensor Parallelism with TP={mp_size}") - # Store the mesh for future use self.world_mesh = world_mesh def _extend_rope_cache_if_needed( self, rope_cache: torch.Tensor, max_position: int ) -> torch.Tensor: """ - Extend rope_cache if needed (e.g., during vLLM profiling with 2x max_seq_len). + Extend RoPE cache if needed during vLLM profiling. + + Uses the rope_cache_extension_fn provided during initialization if available. Args: rope_cache: Current RoPE cache tensor max_position: Maximum position index needed Returns: - Extended rope_cache if needed, otherwise original rope_cache + Extended RoPE cache if needed, otherwise original cache """ + from torch.distributed._tensor import DTensor, Replicate + required_len = max_position + 1 # No extension needed if required_len <= rope_cache.shape[0]: return rope_cache - # Handle DTensor case - convert to local tensor first - from torch.distributed._tensor import DTensor, Replicate + # If no extension function provided, return original cache + if self.rope_cache_extension_fn is None: + logger.warning( + f"RoPE cache extension needed (required_len={required_len}, " + f"current_len={rope_cache.shape[0]}) but no rope_cache_extension_fn provided. " + "Returning original cache." 
+ ) + return rope_cache + # Handle DTensor case is_dtensor = isinstance(rope_cache, DTensor) if is_dtensor: - # Get the local tensor and device mesh device_mesh = rope_cache.device_mesh local_rope_cache = rope_cache.to_local() device = local_rope_cache.device @@ -239,21 +292,18 @@ def _extend_rope_cache_if_needed( device = rope_cache.device dtype = rope_cache.dtype - # Import precompute_rope_cache from the model module - from torchtitan.models.qwen3.model.model import precompute_rope_cache - - # Precompute additional RoPE frequencies on-the-fly - # TODO: This also has strong assumptions on config names - rope_theta = self.config.rope_theta - head_dim = self.config.head_dim - extended_cache = precompute_rope_cache( - dim=head_dim, - max_seq_len=required_len, - base=rope_theta, - ) - extended_cache = extended_cache.to(device=device, dtype=dtype) + # Use provided extension function + try: + extended_cache = self.rope_cache_extension_fn(self.config, required_len) + extended_cache = extended_cache.to(device=device, dtype=dtype) + except Exception as e: + logger.warning( + f"Failed to extend RoPE cache using rope_cache_extension_fn: {e}. " + "Returning original cache." + ) + return rope_cache - # If original was DTensor, convert extended cache to DTensor too + # Convert back to DTensor if needed if is_dtensor: rope_cache = DTensor.from_local( extended_cache, @@ -266,10 +316,7 @@ def _extend_rope_cache_if_needed( return rope_cache def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - """Convert input token IDs to embeddings. - - This is the vLLM-standard method name for embedding tokens. - """ + """Convert input token IDs to embeddings.""" return self.model.tok_embeddings(input_ids) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -287,64 +334,58 @@ def forward( Forward pass with vLLM interface. Args: - input_ids: Token IDs from vLLM [total_tokens] (1D varlen format) - positions: Position indices from vLLM [total_tokens] (1D varlen format) - inputs_embeds: Pre-computed embeddings (optional, used by vLLM) + input_ids: Token IDs [total_tokens] (1D varlen format) + positions: Position indices [total_tokens] (1D varlen format) + inputs_embeds: Pre-computed embeddings (optional) **kwargs: Additional vLLM kwargs Returns: hidden_states: Final hidden states [total_tokens, hidden_size] """ - # Handle inputs_embeds vs input_ids properly if inputs_embeds is not None: - raise NotImplementedError( - "inputs_embeds is not yet supported by TorchTitan Qwen3. " - "The model expects token IDs and computes embeddings internally. " - "Please provide input_ids instead." 
- ) + raise NotImplementedError("inputs_embeds not yet supported") if input_ids is None: raise ValueError("Either input_ids or inputs_embeds must be provided") # Convert vLLM interface to TorchTitan interface - # vLLM passes input_ids as [total_tokens] but TorchTitan expects [batch_size, seq_len] - # For now, reshape to [1, total_tokens] as a simple batch of 1 - # TODO: In future, use attn_metadata.seq_lens to properly reconstruct batch structure - tokens_2d = input_ids.unsqueeze(0) # [total_tokens] -> [1, total_tokens] - - # Store positions in forward context for attention layers - # Also convert positions to 2D format - # TODO: The position id information is not properly used yet - # TODO: The batch_size might not be 1, need double check - if positions is not None: - positions_2d = positions.unsqueeze(0) # [total_tokens] -> [1, total_tokens] - - # Get embeddings from 2D tokens - h = self.model.tok_embeddings(tokens_2d) # [1, total_tokens, hidden_size] + # vLLM: [total_tokens] → TorchTitan: [batch_size, seq_len] + tokens_2d = input_ids.unsqueeze(0) + + # Get embeddings + h = self.model.tok_embeddings(tokens_2d) + + # Get RoPE cache (handle model-specific attribute names) + # Use hasattr to avoid ambiguous boolean value error with tensors + if hasattr(self.model, "rope_cache"): + rope_attr = self.model.rope_cache + elif hasattr(self.model, "freqs_cis"): + rope_attr = self.model.freqs_cis + else: + rope_attr = None # Extend RoPE cache if needed (vLLM profiling may use 2x max_seq_len) - max_position = positions.max().item() if positions is not None else 0 - rope_cache = self._extend_rope_cache_if_needed( - self.model.rope_cache, max_position - ) + if positions is not None: + max_position = positions.max().item() + else: + max_position = 0 - # Get RoPE cache indexed by positions + rope_cache = self._extend_rope_cache_if_needed(rope_attr, max_position) rope_cache = rope_cache[positions] # Pass through transformer layers for layer in self.model.layers.values(): h = layer(h, rope_cache, attention_masks=None) - # Convert output format back to vLLM expectations - # vLLM expects hidden_states in [total_tokens, hidden_size] format - # TorchTitan returns [batch_size, seq_len, hidden_size], so we need to flatten - if h.dim() == 3: # [batch_size, seq_len, hidden_size] + # Convert to vLLM format: [total_tokens, hidden_size] + if h.dim() == 3: batch_size, seq_len, hidden_size = h.shape - h = h.view(batch_size * seq_len, hidden_size) # [total_tokens, hidden_size] + h = h.view(batch_size * seq_len, hidden_size) - # TODO(jianiw): explicitly insert communication and return full tensor to vLLM Engine. To be checked. + # Convert DTensor to regular tensor if isinstance(h, DTensor): h = h.full_tensor() + return h def compute_logits( @@ -352,35 +393,24 @@ def compute_logits( hidden_states: torch.Tensor, sampling_metadata=None, ) -> torch.Tensor | None: - """Compute logits from hidden states. - - Returns: - Logits tensor, or None if TP rank > 0 - """ - # Apply final layer norm + """Compute logits from hidden states.""" h = self.model.norm(hidden_states) - - # Apply output projection to get logits logits = self.model.output(h) - # When using TP, only rank 0 returns logits - # vLLM expects None from other ranks if isinstance(logits, DTensor): - # Convert DTensor to local tensor for vLLM logits = logits.full_tensor() return logits def load_weights(self, weights_iter): """ - Uses TorchTitan's Qwen3StateDictAdapter to map HF → TorchTitan naming, - then uses set_model_state_dict for proper distributed tensor handling. 
+ Load weights from HF checkpoint using the provided state dict adapter. Args: weights_iter: Iterator of (name, tensor) pairs from HF checkpoint Returns: - Set of loaded parameter names (for vLLM compatibility) + Set of loaded parameter names """ from torch.distributed._tensor import DTensor, Replicate from torch.distributed.checkpoint.state_dict import ( @@ -388,51 +418,38 @@ def load_weights(self, weights_iter): StateDictOptions, ) - from torchtitan.models.qwen3.model.state_dict_adapter import ( - Qwen3StateDictAdapter, - ) - - # Collect weights from iterator into a dict + # Collect weights from iterator hf_state_dict = {} for name, tensor in weights_iter: hf_state_dict[name] = tensor - # Use TorchTitan's adapter to convert HF → TorchTitan format - adapter = Qwen3StateDictAdapter( + # Use adapter to convert HF → TorchTitan format + adapter = self.state_dict_adapter( model_args=self.config, - hf_assets_path=None, # Not needed for from_hf conversion + hf_assets_path=None, ) torchtitan_state_dict = adapter.from_hf(hf_state_dict) model_state_dict = {k: v for k, v in self.model.state_dict().items()} - # Convert HF tensors to replicate DTensor if target is DTensor + # Convert to DTensor if target is DTensor for name, tensor in torchtitan_state_dict.items(): if name in model_state_dict and isinstance(model_state_dict[name], DTensor): - # Get the device mesh from the target DTensor target_dtensor = model_state_dict[name] device_mesh = target_dtensor.device_mesh - # Convert to replicate DTensor torchtitan_state_dict[name] = DTensor.from_local( tensor.to(device_mesh.device_type), device_mesh=device_mesh, placements=[Replicate()], ) - # Use TorchTitan's distributed state dict loading - # This handles TP/PP sharding automatically + # Load state dict set_model_state_dict( model=self.model, model_state_dict=torchtitan_state_dict, - options=StateDictOptions( - strict=False, # Allow missing keys - ), + options=StateDictOptions(strict=False), ) - # manually patch the loaded loaded_params = {f"model.{name}" for name in torchtitan_state_dict.keys()} - logger.info( - f"Loaded {len(loaded_params)} parameters from checkpoint using distributed-aware loading" - ) return loaded_params diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py index c721a01b3a..4ac63df232 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py +++ b/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py @@ -9,6 +9,5 @@ """ from .model_batch_invariant import Qwen3VLLMCompatModel -from .model_vllm_compat import TorchTitanQwen3ForCausalLM -__all__ = ["Qwen3VLLMCompatModel", "TorchTitanQwen3ForCausalLM"] +__all__ = ["Qwen3VLLMCompatModel"] diff --git a/torchtitan/experiments/deterministic_vllm_rl/register.py b/torchtitan/experiments/deterministic_vllm_rl/register.py index d7ee8c3b12..17a94e8821 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/register.py +++ b/torchtitan/experiments/deterministic_vllm_rl/register.py @@ -7,23 +7,51 @@ def register(): """ - Register TorchTitan models with vLLM. + Register TorchTitan models with vLLM using class inheritance pattern. - This function is called to register TorchTitan-trained models with vLLM. - It sets up the necessary model registry entries for TorchTitan models. 
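To make the weight-loading flow above concrete, here is a minimal sketch of the renaming step an adapter's `from_hf()` performs before the converted dict is handed to `set_model_state_dict`. The key mapping is an invented three-entry subset; the real `Qwen3StateDictAdapter` covers the full parameter set, and the wrapper additionally wraps plain tensors as `Replicate()` DTensors when the target parameter is sharded:

```python
# Toy sketch of the HF -> TorchTitan renaming an adapter's from_hf() performs.
# The mapping below is an invented subset for illustration only.
import torch

HF_TO_TORCHTITAN = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.norm.weight": "norm.weight",
    "lm_head.weight": "output.weight",
}

def from_hf_sketch(hf_state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Rename known HF keys to TorchTitan names; ignore everything else."""
    return {
        HF_TO_TORCHTITAN[name]: tensor
        for name, tensor in hf_state_dict.items()
        if name in HF_TO_TORCHTITAN
    }

hf_state_dict = {name: torch.zeros(4) for name in HF_TO_TORCHTITAN}
torchtitan_state_dict = from_hf_sketch(hf_state_dict)
assert set(torchtitan_state_dict) == {"tok_embeddings.weight", "norm.weight", "output.weight"}
```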
+ This function registers TorchTitan-trained models with vLLM's model registry + by creating subclasses of TorchTitanVLLMWrapper and passing in the 5 core + model-specific components: - Currently supports: - - Qwen3TorchTitanForCausalLM: Qwen3 models trained with TorchTitan + 1. model_cls - The TorchTitan model class (e.g., Qwen3Model, Transformer) + 2. model_args_cls - The model args class (e.g., Qwen3ModelArgs) + 3. state_dict_adapter - State dict adapter for loading HF weights + 4. parallelize_fn - Function to apply tensor parallelism + 5. rope_cache_extension_fn - Optional function to extend RoPE cache """ from vllm.model_executor.models.registry import ModelRegistry - from torchtitan.experiments.deterministic_vllm_rl.models.qwen3 import ( - TorchTitanQwen3ForCausalLM, + from torchtitan.experiments.deterministic_vllm_rl.models.base_wrapper import ( + TorchTitanVLLMWrapper, ) - # Register Qwen3TorchTitanForCausalLM with vLLM's ModelRegistry - # This maps the architecture name from config.json to the model class + from torchtitan.models.qwen3 import Qwen3Model, Qwen3ModelArgs + from torchtitan.models.qwen3.infra.parallelize import ( + apply_non_moe_tp as apply_qwen3_tp, + ) + from torchtitan.models.qwen3.model.model import precompute_rope_cache + from torchtitan.models.qwen3.model.state_dict_adapter import Qwen3StateDictAdapter + + class Qwen3TorchTitanForCausalLM(TorchTitanVLLMWrapper): + """ + vLLM wrapper for TorchTitan-trained Qwen3 models. + + This class plugs in the 5 Qwen3-specific components into the + generic TorchTitanVLLMWrapper. + """ + + def __init__(self, *, vllm_config, prefix=""): + super().__init__( + model_cls=Qwen3Model, + model_args_cls=Qwen3ModelArgs, + state_dict_adapter=Qwen3StateDictAdapter, + parallelize_fn=apply_qwen3_tp, + rope_cache_compute_fn=precompute_rope_cache, + vllm_config=vllm_config, + prefix=prefix, + ) + ModelRegistry.register_model( - "Qwen3TorchTitanForCausalLM", TorchTitanQwen3ForCausalLM + "Qwen3TorchTitanForCausalLM", Qwen3TorchTitanForCausalLM ) From 4ea07eef36dafcf22fb563f440f6c82a18446bfe Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 8 Dec 2025 23:34:18 -0800 Subject: [PATCH 15/19] refactor v2 --- .../deterministic_vllm_rl/__init__.py | 9 + .../deterministic_vllm_rl/infer.py | 45 ++--- .../deterministic_vllm_rl/models/__init__.py | 71 ++++++- .../models/base_wrapper.py | 181 +++++------------- .../deterministic_vllm_rl/models/utils.py | 104 ++++++++++ .../deterministic_vllm_rl/register.py | 57 ------ 6 files changed, 257 insertions(+), 210 deletions(-) create mode 100644 torchtitan/experiments/deterministic_vllm_rl/models/utils.py delete mode 100644 torchtitan/experiments/deterministic_vllm_rl/register.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/__init__.py index 067555251f..0e6c76c6e5 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/__init__.py +++ b/torchtitan/experiments/deterministic_vllm_rl/__init__.py @@ -15,6 +15,12 @@ - Qwen3VLLMCompatModel: vLLM-compatible model with merged projections - batch_invariant_backward: Gradient support for vLLM's deterministic operations - simple_rl: End-to-end RL training loop +- TorchTitanVLLMModel: Generic wrapper for TorchTitan models with vLLM + +For vLLM inference with TorchTitan models, see: +- models/base_wrapper.py: Core vLLM wrapper +- models/__init__.py: Auto-registration with vLLM +- infer.py: Example inference script """ from .batch_invariant_backward import ( @@ -23,12 +29,15 @@ 
silu_and_mul_with_gradients, ) from .models import VLLMCompatibleFlashAttention +from .models.base_wrapper import TorchTitanVLLMModel from .models.qwen3 import Qwen3VLLMCompatModel + __all__ = [ "VLLMCompatibleFlashAttention", "Qwen3VLLMCompatModel", "enable_batch_invariant_backward_mode", "rms_norm_with_gradients", "silu_and_mul_with_gradients", + "TorchTitanVLLMModel", ] diff --git a/torchtitan/experiments/deterministic_vllm_rl/infer.py b/torchtitan/experiments/deterministic_vllm_rl/infer.py index e9517a456d..fa311b8b6e 100755 --- a/torchtitan/experiments/deterministic_vllm_rl/infer.py +++ b/torchtitan/experiments/deterministic_vllm_rl/infer.py @@ -5,23 +5,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -""" -Example CLI to run TorchTitan Qwen3 model inference with vLLM: - -# Run inference -python torchtitan/experiments/deterministic_vllm_rl/infer.py -""" - import argparse from vllm import LLM, SamplingParams -# Import and register the TorchTitan vLLM plugin -from torchtitan.experiments.deterministic_vllm_rl.register import register - -# Register TorchTitan models with vLLM. -# NOTE(jianiw): We could use plug-in system instead: https://docs.vllm.ai/en/latest/design/plugin_system/ -register() +# Import models module - this automatically registers TorchTitan models with vLLM +from torchtitan.experiments.deterministic_vllm_rl import models # noqa: F401 def parse_args(): @@ -66,8 +55,11 @@ def main(): args = parse_args() print("=" * 80) - print("INITIALIZING vLLM WITH TORCHTITAN QWEN3 MODEL") + print("INITIALIZING vLLM WITH TORCHTITAN QWEN3 MODEL ") print("=" * 80) + print(f"Model: {args.model}") + print(f"Tensor Parallel Size: {args.tensor_parallel_size}") + print() # Build hf_overrides with checkpoint path hf_overrides = { @@ -75,22 +67,30 @@ def main(): } # Initialize vLLM with custom TorchTitan Qwen3 model + # The LLM initialization will internally: + # 1. Load TrainSpec for Qwen3 (from register()) + # 2. Create TorchTitanVLLMModel instance + # 3. Process parallelism settings via process_parallelism_settings() + # 4. Build device mesh and apply parallelization via build_device_mesh_and_parallelize() + # 5. 
Load model weights and prepare for inference + print("Initializing vLLM engine...") llm = LLM( - model=args.model, # Use temporary directory with config.json + model=args.model, # Model checkpoint path hf_overrides=hf_overrides, dtype="bfloat16", trust_remote_code=True, - enforce_eager=True, # Use eager mode for debugging - # Disable kv cache, required for now - enable_prefix_caching=False, + enforce_eager=True, # Use eager mode + enable_prefix_caching=False, # Disable kv cache for now tensor_parallel_size=args.tensor_parallel_size, # Multi-GPU support ) print("=" * 80) - print("vLLM ENGINE INITIALIZED - STARTING GENERATION") + print("vLLM ENGINE INITIALIZED - CONFIGURATION DETAILS") print("=" * 80) + print(f"Prompt: {args.prompt}") + print() - # Prepare prompt + # Prepare prompt and sampling parameters prompts = [args.prompt] sampling_params = SamplingParams( temperature=args.temperature, @@ -98,7 +98,7 @@ def main(): max_tokens=args.max_tokens, ) - # Generate + # Generate text outputs = llm.generate( prompts=prompts, sampling_params=sampling_params, @@ -109,8 +109,9 @@ def main(): prompt = output.prompt generated_text = output.outputs[0].text - print(f"\nPrompt: {prompt}") + print(f"Prompt: {prompt}") print(f"Generated text: {generated_text!r}") + print() if __name__ == "__main__": diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py index 823ec82ad7..7c84d7efb3 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py +++ b/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py @@ -6,8 +6,77 @@ """ Models for deterministic vLLM RL training. + +This module automatically registers TorchTitan models with vLLM when imported. """ +from vllm.logger import init_logger + +from torchtitan.protocols.train_spec import get_train_spec, TrainSpec from .attention import VLLMCompatibleFlashAttention, VLLMPagedFlashAttention +from .base_wrapper import TorchTitanVLLMModel + + +logger = init_logger(__name__) + + +def register_torchtitan_model_from_train_spec( + train_spec: TrainSpec, + model_name: str, +) -> None: + """ + Register a TorchTitan model with vLLM using a TrainSpec. 
+ + Args: + train_spec: TorchTitan TrainSpec containing model components + model_name: Name to register in vLLM (e.g., "Qwen3TorchTitanForCausalLM") + + """ + from vllm.model_executor.models.registry import ModelRegistry + + # Extract model_args from TrainSpec + # TrainSpec has model_args as a Mapping, get the first value + if isinstance(train_spec.model_args, dict): + model_args_cls = type(next(iter(train_spec.model_args.values()))) + else: + model_args_cls = train_spec.model_args + + # Create dynamic model class directly from TrainSpec components + class TorchTitanVLLMModelFromSpec(TorchTitanVLLMModel): + """Dynamically created vLLM model from TrainSpec.""" + + def __init__(self, *, vllm_config, prefix=""): + super().__init__( + model_cls=train_spec.model_cls, + model_args_cls=model_args_cls, + state_dict_adapter=train_spec.state_dict_adapter, + parallelize_fn=train_spec.parallelize_fn, + vllm_config=vllm_config, + prefix=prefix, + ) + + # Set the class name + TorchTitanVLLMModelFromSpec.__name__ = model_name + TorchTitanVLLMModelFromSpec.__qualname__ = model_name + + # Register with vLLM + ModelRegistry.register_model(model_name, TorchTitanVLLMModelFromSpec) + + logger.info( + f"Successfully registered {model_name} with vLLM using TrainSpec " + f"(model_cls={train_spec.model_cls.__name__})" + ) + + +# Auto-register TorchTitan models with vLLM when this module is imported +register_torchtitan_model_from_train_spec( + train_spec=get_train_spec("qwen3"), + model_name="Qwen3TorchTitanForCausalLM", +) + -__all__ = ["VLLMCompatibleFlashAttention", "VLLMPagedFlashAttention"] +__all__ = [ + "VLLMCompatibleFlashAttention", + "VLLMPagedFlashAttention", + "TorchTitanVLLMModel", +] diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/base_wrapper.py b/torchtitan/experiments/deterministic_vllm_rl/models/base_wrapper.py index 79e6c1fc47..aabde774f2 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/base_wrapper.py +++ b/torchtitan/experiments/deterministic_vllm_rl/models/base_wrapper.py @@ -5,23 +5,22 @@ # LICENSE file in the root directory of this source tree. """ -Generic vLLM wrapper for TorchTitan models. - -This module provides TorchTitanVLLMWrapper, a generic base class that makes -any TorchTitan model compatible with vLLM by: -1. Accepting 4 pluggable model-specific components -2. Replacing attention with vLLM's paged attention -3. Setting up parallelization (TP/EP) -4. Loading weights from HuggingFace checkpoints +Base wrapper for TorchTitan models to work with vLLM V1 engine. + +This module provides TorchTitanVLLMModel: Core model class that adapts +TorchTitan models for vLLM. 
""" -import functools -from typing import Callable +from functools import partial +from typing import Callable, TypeAlias import torch import torch.nn as nn -from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.tensor import DTensor +from torch.distributed._tensor import DTensor, Replicate +from torch.distributed.checkpoint.state_dict import ( + set_model_state_dict, + StateDictOptions, +) from vllm.config import VllmConfig from vllm.logger import init_logger @@ -29,28 +28,25 @@ from torchtitan.experiments.deterministic_vllm_rl.models.attention import ( VLLMPagedFlashAttention, ) -from torchtitan.protocols.model import BaseModelArgs +from torchtitan.experiments.deterministic_vllm_rl.models.utils import ( + create_parallel_dims_from_vllm_config, +) +from torchtitan.models.qwen3.model.model import precompute_rope_cache +from torchtitan.protocols.model import BaseModelArgs, ModelProtocol from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter -from torchtitan.tools.utils import device_type logger = init_logger(__name__) +ParallelizeFunction: TypeAlias = Callable[..., nn.Module] -class TorchTitanVLLMWrapper(nn.Module): - """ - Generic vLLM-compatible wrapper for TorchTitan models. - - This base class integrates any TorchTitan model with vLLM by accepting - 4 pluggable model-specific components: - 1. model_cls: The TorchTitan model class (e.g., Qwen3Model) - 2. model_args_cls: The model args class (e.g., Qwen3ModelArgs) - 3. state_dict_adapter: State dict adapter for loading HF weights - 4. parallelize_fn: Function to apply tensor parallelism +class TorchTitanVLLMModel(nn.Module): + """ + Generic vLLM-compatible model wrapper for TorchTitan models. The wrapper handles: - - HF config → TorchTitan model args mapping + - HF config to TorchTitan model args mapping - Attention replacement with vLLM paged attention - Tensor parallelism setup - Weight loading from HF checkpoints @@ -64,11 +60,10 @@ class TorchTitanVLLMWrapper(nn.Module): def __init__( self, *, - model_cls: type, - model_args_cls: BaseModelArgs, - state_dict_adapter: BaseStateDictAdapter, - parallelize_fn: Callable, - rope_cache_compute_fn: Callable | None = None, + model_cls: type[ModelProtocol], # passing types that is not instantiated + model_args_cls: type[BaseModelArgs], + state_dict_adapter: type[BaseStateDictAdapter], + parallelize_fn: ParallelizeFunction, vllm_config: VllmConfig, prefix: str = "", ): @@ -84,36 +79,41 @@ def __init__( # Map HF config to TorchTitan ModelArgs hf_config = vllm_config.model_config.hf_config - logger.info(f"Mapping HF config to {model_args_cls.__name__}") - model_args = self._map_hf_config_to_model_args(hf_config, model_args_cls) + logger.info(f"Mapping HF config to {self.model_args_cls.__name__}") + model_args = self._map_hf_config_to_model_args(hf_config, self.model_args_cls) # Create TorchTitan model - logger.info(f"Creating {model_cls.__name__} with config: {model_args}") - self.model = model_cls(model_args) + logger.info(f"Creating {self.model_cls.__name__} with config: {model_args}") + self.model = self.model_cls(model_args) self.config = model_args - # NOTE: Here's assumptions of rope_cache_compute_fn function signature - self.rope_cache_extension_fn = functools.partial( - rope_cache_compute_fn, dim=self.config.head_dim, base=self.config.rope_theta + # Setup RoPE cache extension function if provided + self.rope_cache_extension_fn = partial( + precompute_rope_cache, + dim=self.config.head_dim, + base=self.config.rope_theta, ) - # Replace attention 
with vLLM paged attention
         self._replace_with_vllm_paged_attention(model_args)
 
-        # Setup parallelization
-        (
-            dp_size,
-            mp_size,
-            cp_size,
-            pp_size,
-            ep_size,
-            etp_size,
-        ) = self._process_parallelism_settings(vllm_config)
-
-        if mp_size > 1 or ep_size > 1:
-            self._build_device_mesh_and_parallelize(
-                dp_size, mp_size, cp_size, pp_size, ep_size, etp_size
+        # Create ParallelDims from vLLM config and apply parallelization
+        # NOTE: We need to apply parallelization within model.__init__
+        parallel_dims = create_parallel_dims_from_vllm_config(vllm_config)
+        if parallel_dims.tp_enabled:
+            self.world_mesh = parallel_dims.world_mesh
+            tp_mesh = self.world_mesh["tp"]
+            parallelize_fn(
+                model=self.model,
+                tp_mesh=tp_mesh,
+                loss_parallel=False,
+                enable_float8_tensorwise_tp=False,
+                enable_async_tp=False,
+            )
+            logger.info(
+                f"Successfully initialized model with TP={parallel_dims.tp}"
             )
+        else:
+            logger.info("Single GPU mode - no parallelization needed")
 
     def _map_hf_config_to_model_args(self, hf_config, model_args_cls):
         """
@@ -178,85 +178,12 @@ def _replace_with_vllm_paged_attention(self, model_args):
             "Successfully replaced TorchTitan attention with vLLM PagedFlashAttention"
         )
 
-    def _process_parallelism_settings(self, vllm_config: VllmConfig):
-        """Parse parallel config from vLLM config."""
-        world_size = (
-            vllm_config.parallel_config.data_parallel_size
-            * vllm_config.parallel_config.tensor_parallel_size
-        )
-        ep_size = (
-            world_size if vllm_config.parallel_config.enable_expert_parallel else 1
-        )
-        etp_size = (
-            1 if vllm_config.parallel_config.enable_expert_parallel else world_size
-        )
-        dp_size = vllm_config.parallel_config.data_parallel_size
-        mp_size = vllm_config.parallel_config.tensor_parallel_size
-        cp_size = vllm_config.parallel_config.decode_context_parallel_size
-        pp_size = vllm_config.parallel_config.pipeline_parallel_size
-        self.pp_size = pp_size
-
-        return dp_size, mp_size, cp_size, pp_size, ep_size, etp_size
-
-    def _build_device_mesh_and_parallelize(
-        self,
-        dp_size: int,
-        mp_size: int,
-        cp_size: int,
-        pp_size: int,
-        ep_size: int,
-        etp_size: int,
-    ):
-        """
-        Build device mesh and apply parallelization using the provided parallelize_fn.
-        """
-        import torch.distributed as dist
-
-        world_size = dist.get_world_size()
-        dp_replicate = dp_size
-        dp_shard = 1
-
-        assert dp_replicate * dp_shard * cp_size * mp_size * pp_size == world_size, (
-            f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
-            f"cp({cp_size}) * tp({mp_size}) * pp({pp_size}) != WORLD_SIZE({world_size})"
-        )
-
-        # Build device mesh
-        dims = []
-        names = []
-        for d, name in zip(
-            [pp_size, dp_replicate, dp_shard, cp_size, mp_size],
-            ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
-        ):
-            if d > 1:
-                dims.append(d)
-                names.append(name)
-
-        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
-        world_mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
-
-        # Apply tensor parallelism using provided function
-        if mp_size > 1:
-            tp_mesh = world_mesh["tp"]
-            self.parallelize_fn(
-                model=self.model,
-                tp_mesh=tp_mesh,
-                loss_parallel=False,
-                enable_float8_tensorwise_tp=False,
-                enable_async_tp=False,
-            )
-            logger.info(f"Applied Tensor Parallelism with TP={mp_size}")
-
-        self.world_mesh = world_mesh
-
     def _extend_rope_cache_if_needed(
         self, rope_cache: torch.Tensor, max_position: int
     ) -> torch.Tensor:
         """
         Extend RoPE cache if needed during vLLM profiling.
 
-        Uses the rope_cache_extension_fn provided during initialization if available.
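The extension logic being documented here reduces to: when vLLM's profiling pass (or a long prompt) requests a position beyond the precomputed cache, recompute a longer cache using the same `head_dim` and `rope_theta`, then move it to the original device and dtype. A simplified, self-contained stand-in is sketched below; the real `precompute_rope_cache` in TorchTitan has its own cache layout, so this only illustrates the recompute-on-demand pattern:

```python
# Simplified stand-in for RoPE cache precompute/extension. The real
# torchtitan precompute_rope_cache returns its own cache layout; this sketch
# only illustrates the "recompute a longer cache on demand" pattern.
import torch

def precompute_rope_cache_sketch(dim: int, max_seq_len: int, base: float = 1_000_000.0) -> torch.Tensor:
    """Return a [max_seq_len, dim] cache holding cos/sin values."""
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # [dim // 2]
    positions = torch.arange(max_seq_len).float()                       # [max_seq_len]
    angles = torch.outer(positions, inv_freq)                           # [max_seq_len, dim // 2]
    return torch.cat([angles.cos(), angles.sin()], dim=-1)              # [max_seq_len, dim]

def extend_if_needed(rope_cache: torch.Tensor, max_position: int, dim: int, base: float) -> torch.Tensor:
    """Recompute a longer cache only when an out-of-range position is requested."""
    required_len = max_position + 1
    if required_len <= rope_cache.shape[0]:
        return rope_cache
    extended = precompute_rope_cache_sketch(dim, required_len, base)
    return extended.to(device=rope_cache.device, dtype=rope_cache.dtype)

cache = precompute_rope_cache_sketch(dim=128, max_seq_len=4096)
cache = extend_if_needed(cache, max_position=8191, dim=128, base=1_000_000.0)
assert cache.shape[0] == 8192
```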
- Args: rope_cache: Current RoPE cache tensor max_position: Maximum position index needed @@ -412,12 +339,6 @@ def load_weights(self, weights_iter): Returns: Set of loaded parameter names """ - from torch.distributed._tensor import DTensor, Replicate - from torch.distributed.checkpoint.state_dict import ( - set_model_state_dict, - StateDictOptions, - ) - # Collect weights from iterator hf_state_dict = {} for name, tensor in weights_iter: diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/utils.py b/torchtitan/experiments/deterministic_vllm_rl/models/utils.py new file mode 100644 index 0000000000..e3657a9e92 --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/models/utils.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Parallelization utilities for vLLM + TorchTitan models. + +This module provides functions for setting up device mesh and applying +tensor parallelism to TorchTitan models in vLLM using TorchTitan's ParallelDims. +""" + +import torch.distributed as dist +from vllm.config import VllmConfig +from vllm.logger import init_logger + +from torchtitan.distributed.parallel_dims import ParallelDims + + +logger = init_logger(__name__) + + +def create_parallel_dims_from_vllm_config(vllm_config: VllmConfig) -> ParallelDims: + """ + Create ParallelDims from vLLM config. + + Maps vLLM parallelism settings to TorchTitan's ParallelDims dataclass. + + Args: + vllm_config: vLLM configuration object + + Returns: + ParallelDims object with parallelism settings validated + + Note: + vLLM doesn't use FSDP sharding (dp_shard=1) or expert parallelism (ep=1, etp=1) + in inference. These are set to default values. + """ + world_size = dist.get_world_size() + + # Map vLLM config to TorchTitan ParallelDims + parallel_dims = ParallelDims( + dp_replicate=vllm_config.parallel_config.data_parallel_size, + dp_shard=1, # vLLM doesn't use FSDP sharding + cp=vllm_config.parallel_config.decode_context_parallel_size, + tp=vllm_config.parallel_config.tensor_parallel_size, + pp=vllm_config.parallel_config.pipeline_parallel_size, + ep=1, # Expert parallelism not used in vLLM inference yet + etp=1, # Expert tensor parallelism not used in vLLM inference yet + world_size=world_size, + ) + + logger.info( + f"Created ParallelDims from vLLM config: " + f"DP={parallel_dims.dp_replicate}, TP={parallel_dims.tp}, " + f"CP={parallel_dims.cp}, PP={parallel_dims.pp}" + ) + + return parallel_dims + + +def build_device_mesh_and_parallelize( + model, + parallelize_fn, + parallel_dims: ParallelDims, +): + """ + Build device mesh and apply parallelization to the model. + + Uses TorchTitan's ParallelDims to build the device mesh with proper validation + and submesh creation, then applies tensor parallelism to the model using the + provided parallelize function. + + Args: + model: The TorchTitan model to parallelize + parallelize_fn: Function to apply tensor parallelism (e.g., apply_qwen3_tp) + parallel_dims: ParallelDims object with validated parallelism settings + + Returns: + The device mesh object + """ + # Use ParallelDims to build the device mesh + # This handles all the complexity of: + # - Validation of parallel dimensions + # - Building multi-dimensional device mesh + # - Creating all required submeshes (dp, dp_shard_cp, dp_cp, etc.) 
+ world_mesh = parallel_dims.world_mesh + + logger.info(f"Built device mesh using ParallelDims: {world_mesh}") + + # Apply tensor parallelism using provided function + if parallel_dims.tp_enabled: + tp_mesh = world_mesh["tp"] + parallelize_fn( + model=model, + tp_mesh=tp_mesh, + loss_parallel=False, + enable_float8_tensorwise_tp=False, + enable_async_tp=False, + ) + logger.info(f"Applied Tensor Parallelism with TP={parallel_dims.tp}") + + return world_mesh diff --git a/torchtitan/experiments/deterministic_vllm_rl/register.py b/torchtitan/experiments/deterministic_vllm_rl/register.py deleted file mode 100644 index 17a94e8821..0000000000 --- a/torchtitan/experiments/deterministic_vllm_rl/register.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -def register(): - """ - Register TorchTitan models with vLLM using class inheritance pattern. - - This function registers TorchTitan-trained models with vLLM's model registry - by creating subclasses of TorchTitanVLLMWrapper and passing in the 5 core - model-specific components: - - 1. model_cls - The TorchTitan model class (e.g., Qwen3Model, Transformer) - 2. model_args_cls - The model args class (e.g., Qwen3ModelArgs) - 3. state_dict_adapter - State dict adapter for loading HF weights - 4. parallelize_fn - Function to apply tensor parallelism - 5. rope_cache_extension_fn - Optional function to extend RoPE cache - - """ - from vllm.model_executor.models.registry import ModelRegistry - - from torchtitan.experiments.deterministic_vllm_rl.models.base_wrapper import ( - TorchTitanVLLMWrapper, - ) - - from torchtitan.models.qwen3 import Qwen3Model, Qwen3ModelArgs - from torchtitan.models.qwen3.infra.parallelize import ( - apply_non_moe_tp as apply_qwen3_tp, - ) - from torchtitan.models.qwen3.model.model import precompute_rope_cache - from torchtitan.models.qwen3.model.state_dict_adapter import Qwen3StateDictAdapter - - class Qwen3TorchTitanForCausalLM(TorchTitanVLLMWrapper): - """ - vLLM wrapper for TorchTitan-trained Qwen3 models. - - This class plugs in the 5 Qwen3-specific components into the - generic TorchTitanVLLMWrapper. 
- """ - - def __init__(self, *, vllm_config, prefix=""): - super().__init__( - model_cls=Qwen3Model, - model_args_cls=Qwen3ModelArgs, - state_dict_adapter=Qwen3StateDictAdapter, - parallelize_fn=apply_qwen3_tp, - rope_cache_compute_fn=precompute_rope_cache, - vllm_config=vllm_config, - prefix=prefix, - ) - - ModelRegistry.register_model( - "Qwen3TorchTitanForCausalLM", Qwen3TorchTitanForCausalLM - ) From c8f12a11b6a719b00cbb82f66c91bd888af314f1 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 9 Dec 2025 20:14:15 -0800 Subject: [PATCH 16/19] comments --- .../deterministic_vllm_rl/README.md | 2 +- .../deterministic_vllm_rl/__init__.py | 4 +- .../deterministic_vllm_rl/infer.py | 56 ++++---- .../deterministic_vllm_rl/models/__init__.py | 2 +- .../deterministic_vllm_rl/models/attention.py | 136 ++++++++---------- .../{base_wrapper.py => vllm_wrapper.py} | 9 +- 6 files changed, 94 insertions(+), 115 deletions(-) rename torchtitan/experiments/deterministic_vllm_rl/models/{base_wrapper.py => vllm_wrapper.py} (97%) diff --git a/torchtitan/experiments/deterministic_vllm_rl/README.md b/torchtitan/experiments/deterministic_vllm_rl/README.md index 741a47127b..02107db5df 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/README.md +++ b/torchtitan/experiments/deterministic_vllm_rl/README.md @@ -260,7 +260,7 @@ export LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64:$LD_LIBRARY_PATH # Clean previous build rm -rf build dist *.egg-info -pip uninstall -y vllm +uv pip uninstall -y vllm # Rebuild vLLM from source with CUDA 12.4 pip install -e . diff --git a/torchtitan/experiments/deterministic_vllm_rl/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/__init__.py index 0e6c76c6e5..8525e2a387 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/__init__.py +++ b/torchtitan/experiments/deterministic_vllm_rl/__init__.py @@ -18,7 +18,7 @@ - TorchTitanVLLMModel: Generic wrapper for TorchTitan models with vLLM For vLLM inference with TorchTitan models, see: -- models/base_wrapper.py: Core vLLM wrapper +- models/vllm_wrapper.py: Core vLLM wrapper - models/__init__.py: Auto-registration with vLLM - infer.py: Example inference script """ @@ -29,8 +29,8 @@ silu_and_mul_with_gradients, ) from .models import VLLMCompatibleFlashAttention -from .models.base_wrapper import TorchTitanVLLMModel from .models.qwen3 import Qwen3VLLMCompatModel +from .models.vllm_wrapper import TorchTitanVLLMModel __all__ = [ diff --git a/torchtitan/experiments/deterministic_vllm_rl/infer.py b/torchtitan/experiments/deterministic_vllm_rl/infer.py index fa311b8b6e..3863b228bd 100755 --- a/torchtitan/experiments/deterministic_vllm_rl/infer.py +++ b/torchtitan/experiments/deterministic_vllm_rl/infer.py @@ -8,18 +8,22 @@ import argparse from vllm import LLM, SamplingParams +from vllm.logger import init_logger # Import models module - this automatically registers TorchTitan models with vLLM from torchtitan.experiments.deterministic_vllm_rl import models # noqa: F401 +logger = init_logger(__name__) + + def parse_args(): parser = argparse.ArgumentParser( - description="Run TorchTitan Qwen3 model inference with vLLM", + description="Run TorchTitan model inference with vLLM Engine", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( - "--model", + "--model_ckpt_path", type=str, default="torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B", help="Path to TorchTitan checkpoint directory", @@ -54,41 +58,32 @@ def parse_args(): def main(): args = parse_args() - print("=" * 80) - 
print("INITIALIZING vLLM WITH TORCHTITAN QWEN3 MODEL ") - print("=" * 80) - print(f"Model: {args.model}") - print(f"Tensor Parallel Size: {args.tensor_parallel_size}") - print() + logger.info("Initializing vLLM with TorchTitan model") + logger.info(f"Model: {args.model_ckpt_path}") + logger.info(f"Tensor Parallel Size: {args.tensor_parallel_size}") - # Build hf_overrides with checkpoint path - hf_overrides = { - "checkpoint_dir": args.model, - } - - # Initialize vLLM with custom TorchTitan Qwen3 model + # Initialize vLLM with custom TorchTitan model # The LLM initialization will internally: - # 1. Load TrainSpec for Qwen3 (from register()) + # 1. Load TrainSpec for Qwen3 (from models/__init__.py register()) # 2. Create TorchTitanVLLMModel instance - # 3. Process parallelism settings via process_parallelism_settings() - # 4. Build device mesh and apply parallelization via build_device_mesh_and_parallelize() + # 3. Create JobConfig and ParallelDims from vLLM config + # 4. Apply parallelization using parallelize_qwen3 # 5. Load model weights and prepare for inference - print("Initializing vLLM engine...") + logger.info("Creating vLLM LLM engine...") + llm = LLM( - model=args.model, # Model checkpoint path - hf_overrides=hf_overrides, + model=args.model_ckpt_path, # Model checkpoint path + hf_overrides={ + "checkpoint_dir": args.model_ckpt_path, + }, dtype="bfloat16", trust_remote_code=True, enforce_eager=True, # Use eager mode - enable_prefix_caching=False, # Disable kv cache for now - tensor_parallel_size=args.tensor_parallel_size, # Multi-GPU support + tensor_parallel_size=args.tensor_parallel_size, ) - print("=" * 80) - print("vLLM ENGINE INITIALIZED - CONFIGURATION DETAILS") - print("=" * 80) - print(f"Prompt: {args.prompt}") - print() + logger.info("vLLM engine initialized successfully") + logger.info(f"Prompt: {args.prompt}") # Prepare prompt and sampling parameters prompts = [args.prompt] @@ -99,19 +94,20 @@ def main(): ) # Generate text + logger.info("Generating text...") outputs = llm.generate( prompts=prompts, sampling_params=sampling_params, ) # Print results + logger.info("Generation complete") for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt}") - print(f"Generated text: {generated_text!r}") - print() + print(f"\nPrompt: {prompt}") + print(f"Generated text: {generated_text!r}\n") if __name__ == "__main__": diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py index 7c84d7efb3..080c028036 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py +++ b/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py @@ -14,7 +14,7 @@ from torchtitan.protocols.train_spec import get_train_spec, TrainSpec from .attention import VLLMCompatibleFlashAttention, VLLMPagedFlashAttention -from .base_wrapper import TorchTitanVLLMModel +from .vllm_wrapper import TorchTitanVLLMModel logger = init_logger(__name__) diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/attention.py b/torchtitan/experiments/deterministic_vllm_rl/models/attention.py index 49ce1ccd4a..478884f9b7 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/attention.py +++ b/torchtitan/experiments/deterministic_vllm_rl/models/attention.py @@ -228,40 +228,35 @@ def __init__( # Handle tensor parallelism: adjust num_heads and num_kv_heads for TP # NOTE(jianiw): As we use local tensor for this region, we need to manually - try: - from 
vllm.config import get_current_vllm_config - from vllm.logger import init_logger - - logger = init_logger(__name__) - vllm_config = get_current_vllm_config() - tp_size = vllm_config.parallel_config.tensor_parallel_size - - if tp_size > 1: - if num_kv_heads % tp_size != 0: - # Pad num_kv_heads and num_heads to be divisible by tp_size - assert num_heads % num_kv_heads == 0 - padded_size = tp_size - num_kv_heads % tp_size - padded_num_kv_heads = num_kv_heads + padded_size - padded_num_heads = ( - num_heads + padded_size * num_heads // num_kv_heads - ) - assert padded_num_heads % tp_size == 0 - assert padded_num_kv_heads % tp_size == 0 - - logger.info( - f"Padding attention heads for tensor parallelism: " - f"{num_heads=}, {padded_num_heads=}, " - f"{num_kv_heads=}, {padded_num_kv_heads=}" - ) - - num_heads = padded_num_heads // tp_size - num_kv_heads = padded_num_kv_heads // tp_size - else: - num_heads //= tp_size - num_kv_heads //= tp_size - except (ImportError, RuntimeError, AttributeError): - # Not in vLLM context - use original values - pass + + from vllm.config import get_current_vllm_config + from vllm.logger import init_logger + + logger = init_logger(__name__) + vllm_config = get_current_vllm_config() + tp_size = vllm_config.parallel_config.tensor_parallel_size + + if tp_size > 1: + if num_kv_heads % tp_size != 0: + # Pad num_kv_heads and num_heads to be divisible by tp_size + assert num_heads % num_kv_heads == 0 + padded_size = tp_size - num_kv_heads % tp_size + padded_num_kv_heads = num_kv_heads + padded_size + padded_num_heads = num_heads + padded_size * num_heads // num_kv_heads + assert padded_num_heads % tp_size == 0 + assert padded_num_kv_heads % tp_size == 0 + + logger.info( + f"Padding attention heads for tensor parallelism: " + f"{num_heads=}, {padded_num_heads=}, " + f"{num_kv_heads=}, {padded_num_kv_heads=}" + ) + + num_heads = padded_num_heads // tp_size + num_kv_heads = padded_num_kv_heads // tp_size + else: + num_heads //= tp_size + num_kv_heads //= tp_size self.num_heads = num_heads self.num_kv_heads = num_kv_heads @@ -274,32 +269,25 @@ def __init__( self.scale = scale # Create vLLM Attention layer - try: - from vllm.config import get_current_vllm_config - - config = get_current_vllm_config() - cache_config = ( - config.cache_config if hasattr(config, "cache_config") else None - ) - - # Generate unique prefix for this attention layer - # vLLM expects format "layers.X" for layer index extraction - layer_idx = next(VLLMPagedFlashAttention._layer_counter) - prefix = f"layers.{layer_idx}" - - self.vllm_attn = Attention( - num_heads=num_heads, - head_size=head_dim, - scale=self.scale, - num_kv_heads=num_kv_heads, - cache_config=cache_config, - quant_config=None, - prefix=prefix, - ) - - except (ImportError, RuntimeError, AttributeError): - # Not in vLLM context - will need to set up manually - self.vllm_attn = None + from vllm.config import get_current_vllm_config + + config = get_current_vllm_config() + cache_config = config.cache_config if hasattr(config, "cache_config") else None + + # Generate unique prefix for this attention layer + # vLLM expects format "layers.X" for layer index extraction + layer_idx = next(VLLMPagedFlashAttention._layer_counter) + prefix = f"layers.{layer_idx}" + + self.vllm_attn = Attention( + num_heads=num_heads, + head_size=head_dim, + scale=self.scale, + num_kv_heads=num_kv_heads, + cache_config=cache_config, + quant_config=None, + prefix=prefix, + ) # KV cache - will be populated by vLLM during model loading self.kv_cache: list[torch.Tensor] | 
None = None @@ -317,26 +305,20 @@ def _auto_register_for_kv_cache(self): # Initialize layer_name attribute self.layer_name: str | None = None - try: - from vllm.config import get_current_vllm_config - - config = get_current_vllm_config() - compilation_config = config.compilation_config + from vllm.config import get_current_vllm_config - # Generate unique layer name using class counter - # Format: "layers.{index}" for compatibility with extract_layer_index() - layer_name = f"layers.{next(VLLMPagedFlashAttention._layer_counter)}" + config = get_current_vllm_config() + compilation_config = config.compilation_config - # Register this layer in static forward context - if layer_name in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {layer_name}") - compilation_config.static_forward_context[layer_name] = self - self.layer_name = layer_name + # Generate unique layer name using class counter + # Format: "layers.{index}" for compatibility with extract_layer_index() + layer_name = f"layers.{next(VLLMPagedFlashAttention._layer_counter)}" - except (ImportError, RuntimeError, AttributeError): - # Not in vLLM context - this is fine! - # Layer will work normally for training/inference without vLLM - pass + # Register this layer in static forward context + if layer_name in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {layer_name}") + compilation_config.static_forward_context[layer_name] = self + self.layer_name = layer_name def forward( self, diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/base_wrapper.py b/torchtitan/experiments/deterministic_vllm_rl/models/vllm_wrapper.py similarity index 97% rename from torchtitan/experiments/deterministic_vllm_rl/models/base_wrapper.py rename to torchtitan/experiments/deterministic_vllm_rl/models/vllm_wrapper.py index aabde774f2..1fcf81e402 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/base_wrapper.py +++ b/torchtitan/experiments/deterministic_vllm_rl/models/vllm_wrapper.py @@ -298,18 +298,18 @@ def forward( max_position = 0 rope_cache = self._extend_rope_cache_if_needed(rope_attr, max_position) - rope_cache = rope_cache[positions] + positions = positions.unsqueeze(0) # Pass through transformer layers for layer in self.model.layers.values(): - h = layer(h, rope_cache, attention_masks=None) + h = layer(h, rope_cache, attention_masks=None, positions=positions) # Convert to vLLM format: [total_tokens, hidden_size] if h.dim() == 3: batch_size, seq_len, hidden_size = h.shape h = h.view(batch_size * seq_len, hidden_size) - # Convert DTensor to regular tensor + # TODO(jianiw): finish this conversion when TP is applied if isinstance(h, DTensor): h = h.full_tensor() @@ -324,8 +324,9 @@ def compute_logits( h = self.model.norm(hidden_states) logits = self.model.output(h) + # TODO: This part is to work with TP integration if isinstance(logits, DTensor): - logits = logits.full_tensor() + logits = logits.to_local() return logits From 810de594fa579a1e83296fb54a1dc262b6178891 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 15 Dec 2025 00:41:57 -0800 Subject: [PATCH 17/19] refactor v3 --- .../deterministic_vllm_rl/README.md | 314 +----------------- .../deterministic_vllm_rl/__init__.py | 36 +- .../deterministic_vllm_rl/unified/README.md | 67 ++++ .../{models => unified}/__init__.py | 42 ++- .../unified/attention.py | 93 ++++++ .../{ => unified}/infer.py | 9 +- .../{models => unified}/utils.py | 51 +-- .../{models => unified}/vllm_wrapper.py | 94 ++---- 
.../vllm_compat/README.md | 262 +++++++++++++++ .../vllm_compat/__init__.py | 29 ++ .../batch_invariant_backward.py | 0 .../vllm_compat/models/__init__.py | 18 + .../{ => vllm_compat}/models/attention.py | 174 ---------- .../models/qwen3/__init__.py | 2 +- .../models/qwen3/model_vllm_compat.py} | 2 +- .../{ => vllm_compat}/simple_rl.py | 8 +- .../{ => vllm_compat}/tests/__init__.py | 0 .../tests/test_batch_invariant_backward.py | 4 +- .../tests/test_exact_determinism.py | 2 +- .../{ => vllm_compat}/weights/README.md | 0 .../{ => vllm_compat}/weights/__init__.py | 0 .../{ => vllm_compat}/weights/converter.py | 0 .../{ => vllm_compat}/weights_vllm_compat.py | 0 23 files changed, 562 insertions(+), 645 deletions(-) create mode 100644 torchtitan/experiments/deterministic_vllm_rl/unified/README.md rename torchtitan/experiments/deterministic_vllm_rl/{models => unified}/__init__.py (56%) create mode 100644 torchtitan/experiments/deterministic_vllm_rl/unified/attention.py rename torchtitan/experiments/deterministic_vllm_rl/{ => unified}/infer.py (90%) rename torchtitan/experiments/deterministic_vllm_rl/{models => unified}/utils.py (54%) rename torchtitan/experiments/deterministic_vllm_rl/{models => unified}/vllm_wrapper.py (78%) create mode 100644 torchtitan/experiments/deterministic_vllm_rl/vllm_compat/README.md create mode 100644 torchtitan/experiments/deterministic_vllm_rl/vllm_compat/__init__.py rename torchtitan/experiments/deterministic_vllm_rl/{ => vllm_compat}/batch_invariant_backward.py (100%) create mode 100644 torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/__init__.py rename torchtitan/experiments/deterministic_vllm_rl/{ => vllm_compat}/models/attention.py (54%) rename torchtitan/experiments/deterministic_vllm_rl/{ => vllm_compat}/models/qwen3/__init__.py (85%) rename torchtitan/experiments/deterministic_vllm_rl/{models/qwen3/model_batch_invariant.py => vllm_compat/models/qwen3/model_vllm_compat.py} (99%) rename torchtitan/experiments/deterministic_vllm_rl/{ => vllm_compat}/simple_rl.py (99%) rename torchtitan/experiments/deterministic_vllm_rl/{ => vllm_compat}/tests/__init__.py (100%) rename torchtitan/experiments/deterministic_vllm_rl/{ => vllm_compat}/tests/test_batch_invariant_backward.py (97%) rename torchtitan/experiments/deterministic_vllm_rl/{ => vllm_compat}/tests/test_exact_determinism.py (98%) rename torchtitan/experiments/deterministic_vllm_rl/{ => vllm_compat}/weights/README.md (100%) rename torchtitan/experiments/deterministic_vllm_rl/{ => vllm_compat}/weights/__init__.py (100%) rename torchtitan/experiments/deterministic_vllm_rl/{ => vllm_compat}/weights/converter.py (100%) rename torchtitan/experiments/deterministic_vllm_rl/{ => vllm_compat}/weights_vllm_compat.py (100%) diff --git a/torchtitan/experiments/deterministic_vllm_rl/README.md b/torchtitan/experiments/deterministic_vllm_rl/README.md index 02107db5df..c44bb45cfb 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/README.md +++ b/torchtitan/experiments/deterministic_vllm_rl/README.md @@ -1,308 +1,12 @@ -# Deterministic RL Training with vLLM +# Deterministic vLLM RL Training -This experiment combines vLLM's deterministic kernels with PyTorch autograd to enable reinforcement learning training where forward passes produce bitwise-identical results across runs. +This package provides two approaches for integrating TorchTitan models with vLLM: -## Overview +1. 
vllm_compat/ - vLLM-Compatible approach + - Separate model definition matching vLLM's weight format + - Support batch-invariant and bit-wise identity between train and inference + - Custom backward passes for attention gradient computation -RL training requires both fast inference for generating rollouts and gradient computation for policy updates. vLLM provides deterministic forward passes but does not support gradients. This experiment adds backward passes to vLLM's operations. - -The implementation: -1. Uses vLLM's batch-invariant kernels for forward passes -2. Implements custom backward passes for gradient computation -3. Provides weight conversion utilities between TorchTitan and vLLM formats - -### Features - -- Bitwise determinism: Same inputs produce identical outputs across runs -- Gradient support: Backward passes through vLLM operations -- Weight conversion: Utilities to convert between model formats - -Note: Currently supports single-device training only. - -## Architecture - -### Components - -1. `models/attention.py`: VLLMCompatibleFlashAttention - - Uses vLLM's Flash Attention for forward pass - - Implements custom backward pass for gradient computation - - Uses `num_splits=1` for deterministic behavior - -2. `models/qwen3/model_batch_invariant.py`: Qwen3VLLMCompatModel - - Qwen3 model with merged gate/up projections matching vLLM format - - Uses VLLMRMSNorm with gradient support - -3. `batch_invariant_backward.py`: Backward passes for vLLM operations - - Registers gradients for vLLM's batch-invariant operations - - Supports matmul, linear, and RMSNorm - - Patches Flash Attention for autograd - -4. `weights_vllm_compat.py`: Weight conversion utilities - - Converts between TorchTitan format (separate w1, w2, w3) and vLLM format (merged gate_up_proj) - - Provides bidirectional conversion functions - -5. `simple_rl.py`: RL training loop - - Generates rollouts using vLLM engine - - Computes advantages using GRPO-style ranking - - Updates policy using PPO - -## Installation - -### Prerequisites - -```bash -# Install vLLM with deterministic support -pip install vllm - -# Install TorchTitan (from the repository root) -pip install -e . - -# Install additional dependencies -pip install transformers safetensors huggingface_hub tensorboard -``` - -### Enable Batch Invariance - -Initialize vLLM's batch-invariant mode before training: - -```python -from vllm.model_executor.layers.batch_invariant import init_batch_invariance -init_batch_invariance() -``` - -## Usage - -### Quick Start - -```python -import torch -from vllm.model_executor.layers.batch_invariant import init_batch_invariance -from torchtitan.experiments.deterministic_vllm_rl import ( - enable_batch_invariant_backward_mode, - Qwen3VLLMCompatModel, -) - -# 1. Enable deterministic mode -init_batch_invariance() -enable_batch_invariant_backward_mode() - -# 2. Load model -from torchtitan.models.qwen3.model.args import Qwen3ModelArgs -model_args = Qwen3ModelArgs( - dim=2048, - n_layers=24, - n_heads=16, - n_kv_heads=2, - vocab_size=151936, -) -model = Qwen3VLLMCompatModel(model_args) - -# 3. Forward pass (deterministic) -input_ids = torch.randint(0, 151936, (2, 128), device='cuda') -logits = model(input_ids) - -# 4. Backward pass -loss = logits.sum() -loss.backward() -``` - -### Full RL Training - -Run the RL training loop: - -```bash -VLLM_BATCH_INVARIANT=1 VLLM_FLASH_ATTN_VERSION=3 python -m torchtitan.experiments.deterministic_vllm_rl.simple_rl -``` - -This will: -1. Download Qwen3-1.7B from HuggingFace -2. 
Initialize vLLM engine for rollouts -3. Generate samples for training prompts -4. Compute rewards and advantages -5. Update the policy using PPO -6. Log metrics to TensorBoard - -View training progress: -```bash -tensorboard --logdir=./outputs/rl_training -``` - -## How It Works - -### Deterministic Forward Pass - -vLLM's batch-invariant mode makes operations deterministic: - -```python -# These operations are deterministic when batch_invariance is enabled -y = torch.matmul(a, b) # Uses vLLM's deterministic matmul -output = flash_attn_varlen_func(q, k, v, num_splits=1) # Deterministic FA -``` - -### Backward Pass with Gradients - -Custom backward passes: -1. Re-compute attention weights deterministically -2. Use standard chain rule for gradients -3. Apply gradients through vLLM's deterministic operations - -```python -class FlashAttnWithBackward(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, v, ...): - # Use vLLM's forward implementation - return flash_attn_varlen_func(q, k, v, num_splits=1, ...) - - @staticmethod - def backward(ctx, grad_output): - # Compute gradients deterministically - # (re-compute attention weights and apply chain rule) - return grad_q, grad_k, grad_v, ... -``` - -### Bitwise Determinism Verification - -The training loop compares logprobs from vLLM and TorchTitan: - -```python -# During training, compare logprobs -vllm_logprobs = [from vLLM rollout] -titan_logprobs = [from TorchTitan forward pass] - -assert torch.equal(vllm_logprobs, titan_logprobs) -``` - -## Testing - -Run the test suite: - -```bash -cd torchtitan/experiments/deterministic_vllm_rl/tests - -# Test backward passes -python test_batch_invariant_backward.py - -# Test determinism -python test_exact_determinism.py -``` - -## Technical Details - -### Why Determinism Matters for RL - -RL training steps: -1. Generate rollouts by sampling from the policy -2. Compute rewards based on the samples -3. Update the policy using gradients - -If the forward pass during training differs from the forward pass during rollout, policy gradients may be incorrect. This matters for algorithms like PPO that compare old and new policy probabilities. - -This implementation uses the same kernels for both rollouts (vLLM) and training (TorchTitan) to ensure `logprobs_rollout == logprobs_training` bitwise. - -### Performance - -- Rollout speed: Uses vLLM's optimized kernels -- Training speed: Similar to standard TorchTitan -- Memory: Saves activations for custom backward passes - -### Limitations - -1. Custom backward requires uniform sequence lengths -2. Only causal attention is supported -3. Requires NVIDIA GPUs with Flash Attention support - - -## TODO - -- `FlashAttnWithBackward` will need to become more composable and should not live exclusively within this directory. -- vLLM integration will need to become more generic with a provided Attention operator that is KV-cache compatible. -- vLLM parallelism will need to add generic parallelism initialization to support Monarch managed TP/DP. - -# Run vLLM inference with TorchTitan Qwen3 Model - -This directory contains code to run a single canonical model definition (TorchTitan model definition) with vLLM inference engine (not batch-invariant yet, working in progress). -This work is inspired by https://github.com/vllm-project/vllm/pull/28685. - -## Overview -The integration consists of two main components: - -1. **Model Adapter** (`model/qwen3.py`): A custom model class that extends vLLM's `Qwen3ForCausalLM` to handle TorchTitan checkpoint naming conventions -2. 
**Inference Script** (`infer.py`): A simple script to register the model and run inference - - -## Quick Start -### Prerequisites - -1. Install PyTorch nightly for torchtitan: -``` -pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall -``` - - -2. Install vLLM from source [vllm-use-an-existing-pytorch-installation](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#use-an-existing-pytorch-installation): -```bash -# install PyTorch first, either from PyPI or from source -git clone https://github.com/vllm-project/vllm.git -cd vllm -python use_existing_torch.py -uv pip install -r requirements/build.txt -uv pip install --no-build-isolation -e . -``` - - -NOTE: If `flash_attn_varlen_func` hits error "torch.AcceleratorError: CUDA error: the provided PTX was compiled with an unsupported toolchain" during forward path, this is due to GPU driver version is not compatible with vLLM/PyTorch compiled version. Use the following command to recompile vLLM. - -``` -# Set CUDA version environment variable -export CUDA_HOME=/usr/local/cuda-12.4 -export PATH=/usr/local/cuda-12.4/bin:$PATH -export LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64:$LD_LIBRARY_PATH - -# Clean previous build -rm -rf build dist *.egg-info -uv pip uninstall -y vllm - -# Rebuild vLLM from source with CUDA 12.4 -pip install -e . - -``` - -3. Download Qwen3/Qwen3-0.6b checkpoint from HuggingFace and put into `example_checkpoint` folder. Make sure to change the "architecture" field in `config.json` to be `Qwen3TorchTitanForCausalLM` so vllm engine could use torchtitan model. - - -4. Run inference: -``` -python torchtitan/experiments/deterministic_vllm_rl/infer.py --model torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B -``` - -Run with TP: (work in progress) -``` -python torchtitan/experiments/deterministic_vllm_rl/infer.py --model torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B --tensor-parallel-size 2 - -``` - -## TODO -1. Rewrite attention part to use vllm.Attention() with backward as the only attention path. -2. Integrate with simple_rl.py to run end-to-end RL with one canonical model definition. -3. Leverage batch-invariant kernels into model definition. - - - -## Contributing - -This experiment is part of TorchTitan. To contribute: - -1. Test your changes with `pytest tests/` -2. Verify bitwise determinism is maintained -3. Update this README if adding new features - -## References - -- [vLLM Documentation](https://docs.vllm.ai/) -- [Flash Attention Paper](https://arxiv.org/abs/2205.14135) -- [PPO Algorithm](https://arxiv.org/abs/1707.06347) -- [GRPO: Group Relative Policy Optimization](https://arxiv.org/abs/2402.03300) - -## License - -This code is licensed under the BSD-style license found in the LICENSE file in the TorchTitan repository root directory. +2. unified/ - Unified approach + - Uses canonical TorchTitan model definition for inference directly + - Replaces attention with vLLM Compatible attention for inference diff --git a/torchtitan/experiments/deterministic_vllm_rl/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/__init__.py index 8525e2a387..66c1de78a5 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/__init__.py +++ b/torchtitan/experiments/deterministic_vllm_rl/__init__.py @@ -4,40 +4,30 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -""" -Deterministic RL training with vLLM experiment. 
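With this restructuring, the batch-invariant training pieces move under `vllm_compat/` and the unified-inference pieces under `unified/`, and the package `__init__` below re-exports both. A quick sketch of the resulting import surface (importing the package pulls in vLLM, since the unified subpackage registers its models with vLLM at import time):

```python
# Import surface after the vllm_compat/ + unified/ split introduced here.
# Names mirror the re-exports in this patch's package __init__.
from torchtitan.experiments.deterministic_vllm_rl import (
    # unified: run the canonical TorchTitan model under the vLLM engine
    TorchTitanVLLMModelWrapper,
    create_parallel_dims_from_vllm_config,
    register_torchtitan_model_from_train_spec,
    # vllm_compat: batch-invariant training-side components
    Qwen3VLLMCompatModel,
    VLLMCompatibleFlashAttention,
    enable_batch_invariant_backward_mode,
)
```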
-This experiment provides tools for bitwise-deterministic reinforcement learning -training using vLLM for fast rollouts and TorchTitan for training. - -Key components: -- VLLMCompatibleFlashAttention: Flash attention with custom backward pass -- Qwen3VLLMCompatModel: vLLM-compatible model with merged projections -- batch_invariant_backward: Gradient support for vLLM's deterministic operations -- simple_rl: End-to-end RL training loop -- TorchTitanVLLMModel: Generic wrapper for TorchTitan models with vLLM - -For vLLM inference with TorchTitan models, see: -- models/vllm_wrapper.py: Core vLLM wrapper -- models/__init__.py: Auto-registration with vLLM -- infer.py: Example inference script -""" - -from .batch_invariant_backward import ( +from .unified import ( + create_parallel_dims_from_vllm_config, + register_torchtitan_model_from_train_spec, + TorchTitanVLLMModelWrapper, +) +from .vllm_compat import ( enable_batch_invariant_backward_mode, + Qwen3VLLMCompatModel, rms_norm_with_gradients, silu_and_mul_with_gradients, + VLLMCompatibleFlashAttention, ) -from .models import VLLMCompatibleFlashAttention -from .models.qwen3 import Qwen3VLLMCompatModel -from .models.vllm_wrapper import TorchTitanVLLMModel __all__ = [ + # vllm_compat exports "VLLMCompatibleFlashAttention", "Qwen3VLLMCompatModel", "enable_batch_invariant_backward_mode", "rms_norm_with_gradients", "silu_and_mul_with_gradients", - "TorchTitanVLLMModel", + # unified exports + "TorchTitanVLLMModelWrapper", + "create_parallel_dims_from_vllm_config", + "register_torchtitan_model_from_train_spec", ] diff --git a/torchtitan/experiments/deterministic_vllm_rl/unified/README.md b/torchtitan/experiments/deterministic_vllm_rl/unified/README.md new file mode 100644 index 0000000000..68572927ab --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/unified/README.md @@ -0,0 +1,67 @@ +# Run vLLM inference with TorchTitan Qwen3 Model + +This directory contains code to run a single canonical model definition (TorchTitan model definition) with vLLM inference engine (not batch-invariant yet, working in progress). +This work is inspired by https://github.com/vllm-project/vllm/pull/28685. + +## Overview +The integration consists of two main components: + +1. **Model Adapter** (`model/qwen3.py`): A custom model class that extends vLLM's `Qwen3ForCausalLM` to handle TorchTitan checkpoint naming conventions +2. **Inference Script** (`infer.py`): A simple script to register the model and run inference + + +## Quick Start +### Prerequisites + +1. Install PyTorch nightly for torchtitan: +``` +pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall +``` + + +2. Install vLLM from source [vllm-use-an-existing-pytorch-installation](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#use-an-existing-pytorch-installation): +```bash +# install PyTorch first, either from PyPI or from source +git clone https://github.com/vllm-project/vllm.git +cd vllm +python use_existing_torch.py +uv pip install -r requirements/build.txt +uv pip install --no-build-isolation -e . +``` + + +NOTE: If `flash_attn_varlen_func` hits error "torch.AcceleratorError: CUDA error: the provided PTX was compiled with an unsupported toolchain" during forward path, this is due to GPU driver version is not compatible with vLLM/PyTorch compiled version. Use the following command to recompile vLLM. 
+ +``` +# Set CUDA version environment variable +export CUDA_HOME=/usr/local/cuda-12.4 +export PATH=/usr/local/cuda-12.4/bin:$PATH +export LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64:$LD_LIBRARY_PATH + +# Clean previous build +rm -rf build dist *.egg-info +uv pip uninstall -y vllm + +# Rebuild vLLM from source with CUDA 12.4 +pip install -e . + +``` + +3. Download Qwen3/Qwen3-0.6b checkpoint from HuggingFace and put into `example_checkpoint` folder. + + +4. Run inference: +``` +python torchtitan/experiments/deterministic_vllm_rl/infer.py --model torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B +``` + +Run with TP: (work in progress) +``` +python torchtitan/experiments/deterministic_vllm_rl/infer.py --model torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B --tensor-parallel-size 2 + +``` + +## TODO +1. Rewrite attention part to use vllm.Attention() with backward as the only attention path. +2. Integrate with simple_rl.py to run end-to-end RL with one canonical model definition. +3. Leverage batch-invariant kernels into model definition. diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/unified/__init__.py similarity index 56% rename from torchtitan/experiments/deterministic_vllm_rl/models/__init__.py rename to torchtitan/experiments/deterministic_vllm_rl/unified/__init__.py index 080c028036..f9862db824 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py +++ b/torchtitan/experiments/deterministic_vllm_rl/unified/__init__.py @@ -5,16 +5,18 @@ # LICENSE file in the root directory of this source tree. """ -Models for deterministic vLLM RL training. +Unified approach for running TorchTitan models with vLLM inference. This module automatically registers TorchTitan models with vLLM when imported. +Uses the canonical TorchTitan model definition directly with vLLM inference engine. """ from vllm.logger import init_logger from torchtitan.protocols.train_spec import get_train_spec, TrainSpec -from .attention import VLLMCompatibleFlashAttention, VLLMPagedFlashAttention -from .vllm_wrapper import TorchTitanVLLMModel + +from .utils import create_parallel_dims_from_vllm_config +from .vllm_wrapper import TorchTitanVLLMModelWrapper logger = init_logger(__name__) @@ -23,6 +25,7 @@ def register_torchtitan_model_from_train_spec( train_spec: TrainSpec, model_name: str, + model_flavor: str, ) -> None: """ Register a TorchTitan model with vLLM using a TrainSpec. @@ -30,25 +33,30 @@ def register_torchtitan_model_from_train_spec( Args: train_spec: TorchTitan TrainSpec containing model components model_name: Name to register in vLLM (e.g., "Qwen3TorchTitanForCausalLM") + model_flavor: Model flavor key (e.g., "0.6B") to select from qwen3_args """ from vllm.model_executor.models.registry import ModelRegistry - # Extract model_args from TrainSpec - # TrainSpec has model_args as a Mapping, get the first value + # Get model_args directly from TrainSpec.model_args dict using flavor key if isinstance(train_spec.model_args, dict): - model_args_cls = type(next(iter(train_spec.model_args.values()))) + if model_flavor not in train_spec.model_args: + raise ValueError( + f"Model flavor '{model_flavor}' not found in train_spec.model_args. 
" + f"Available flavors: {list(train_spec.model_args.keys())}" + ) + model_args = train_spec.model_args[model_flavor] else: - model_args_cls = train_spec.model_args + raise ValueError( + "train_spec.model_args must be a dict mapping flavor names to ModelArgs" + ) # Create dynamic model class directly from TrainSpec components - class TorchTitanVLLMModelFromSpec(TorchTitanVLLMModel): - """Dynamically created vLLM model from TrainSpec.""" - + class TorchTitanVLLMModelFromSpec(TorchTitanVLLMModelWrapper): def __init__(self, *, vllm_config, prefix=""): super().__init__( model_cls=train_spec.model_cls, - model_args_cls=model_args_cls, + model_args=model_args, state_dict_adapter=train_spec.state_dict_adapter, parallelize_fn=train_spec.parallelize_fn, vllm_config=vllm_config, @@ -64,7 +72,7 @@ def __init__(self, *, vllm_config, prefix=""): logger.info( f"Successfully registered {model_name} with vLLM using TrainSpec " - f"(model_cls={train_spec.model_cls.__name__})" + f"(model_cls={train_spec.model_cls.__name__}, flavor={model_flavor})" ) @@ -72,11 +80,15 @@ def __init__(self, *, vllm_config, prefix=""): register_torchtitan_model_from_train_spec( train_spec=get_train_spec("qwen3"), model_name="Qwen3TorchTitanForCausalLM", + # TODO: Remove the model_flavor args when registering model, + # allow passing model flavor option from config system. Now we have to specify + # model_flavor during registration because we can not pass torchtitan job_config from LLM() Api + model_flavor="0.6B", ) __all__ = [ - "VLLMCompatibleFlashAttention", - "VLLMPagedFlashAttention", - "TorchTitanVLLMModel", + "TorchTitanVLLMModelWrapper", + "create_parallel_dims_from_vllm_config", + "register_torchtitan_model_from_train_spec", ] diff --git a/torchtitan/experiments/deterministic_vllm_rl/unified/attention.py b/torchtitan/experiments/deterministic_vllm_rl/unified/attention.py new file mode 100644 index 0000000000..1a03b882cb --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/unified/attention.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from vllm.attention.layer import Attention + + +class VLLMAttention(torch.nn.Module): + """ + Wrapper around vLLM's Attention. Compatible with TorchTitan input shape. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + layer_name: str, + scale: float | None = None, + ) -> None: + super().__init__() + + self.hidden_size = hidden_size + self.layer_name = layer_name + + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + + if scale is None: + self.scale = head_dim**-0.5 + else: + self.scale = scale + + cache_config = ( + vllm_config.cache_config if hasattr(vllm_config, "cache_config") else None + ) + + self.vllm_attn = Attention( + num_heads=num_heads, + head_size=head_dim, + scale=self.scale, + num_kv_heads=num_kv_heads, + cache_config=cache_config, + quant_config=None, + prefix=f"model.layers.{layer_name}.attention.inner_attention", + ) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + scale: float | None = None, + ) -> torch.Tensor: + """ + Forward pass using vLLM's Attention layer for inference. 
+ + Args: + q: Query tensor [batch, num_heads, seq_len, head_dim] + k: Key tensor [batch, num_kv_heads, seq_len, head_dim] + v: Value tensor [batch, num_kv_heads, seq_len, head_dim] + scale: Optional attention scale override (unused, vLLM uses internal scale) + + Returns: + output: [batch, num_heads, seq_len, head_dim] + """ + # Input is (batch, num_heads, seq_len, head_dim) + batch_size, num_heads, seq_len, head_dim = q.shape + + # Transpose to (batch, seq_len, num_heads, head_dim) for vLLM + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + output_varlen = self.vllm_attn(q, k, v) + + # Reshape back to batch format + output = output_varlen.view(batch_size, seq_len, num_heads, head_dim) + + # Transpose back to TorchTitan format: (batch, num_heads, seq_len, head_dim) + output = output.transpose(1, 2) + + return output diff --git a/torchtitan/experiments/deterministic_vllm_rl/infer.py b/torchtitan/experiments/deterministic_vllm_rl/unified/infer.py similarity index 90% rename from torchtitan/experiments/deterministic_vllm_rl/infer.py rename to torchtitan/experiments/deterministic_vllm_rl/unified/infer.py index 3863b228bd..e52e7d2651 100755 --- a/torchtitan/experiments/deterministic_vllm_rl/infer.py +++ b/torchtitan/experiments/deterministic_vllm_rl/unified/infer.py @@ -10,8 +10,8 @@ from vllm import LLM, SamplingParams from vllm.logger import init_logger -# Import models module - this automatically registers TorchTitan models with vLLM -from torchtitan.experiments.deterministic_vllm_rl import models # noqa: F401 +# Import unified module - this automatically registers TorchTitan models with vLLM +from torchtitan.experiments.deterministic_vllm_rl import unified # noqa: F401 logger = init_logger(__name__) @@ -25,7 +25,7 @@ def parse_args(): parser.add_argument( "--model_ckpt_path", type=str, - default="torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B", + default="torchtitan/experiments/deterministic_vllm_rl/example_checkpoint", help="Path to TorchTitan checkpoint directory", ) parser.add_argument( @@ -74,7 +74,8 @@ def main(): llm = LLM( model=args.model_ckpt_path, # Model checkpoint path hf_overrides={ - "checkpoint_dir": args.model_ckpt_path, + # Override architectures to use our registered TorchTitan model class + "architectures": ["Qwen3TorchTitanForCausalLM"], }, dtype="bfloat16", trust_remote_code=True, diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/utils.py b/torchtitan/experiments/deterministic_vllm_rl/unified/utils.py similarity index 54% rename from torchtitan/experiments/deterministic_vllm_rl/models/utils.py rename to torchtitan/experiments/deterministic_vllm_rl/unified/utils.py index e3657a9e92..0b56d81be4 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/utils.py +++ b/torchtitan/experiments/deterministic_vllm_rl/unified/utils.py @@ -23,9 +23,12 @@ def create_parallel_dims_from_vllm_config(vllm_config: VllmConfig) -> ParallelDims: """ - Create ParallelDims from vLLM config. + Create ParallelDims from vLLM config and maps vLLM parallelism settings to TorchTitan's ParallelDims dataclass. - Maps vLLM parallelism settings to TorchTitan's ParallelDims dataclass. + This function is needed because vLLM doesn't separate model creation and + parallelism application - it requires parallelization to be done inside + the model constructor, so we are creating parallel_dims and apply parallelism + in TorchTitanVLLMModelWrapper.__init__ function. 
Args: vllm_config: vLLM configuration object @@ -58,47 +61,3 @@ def create_parallel_dims_from_vllm_config(vllm_config: VllmConfig) -> ParallelDi ) return parallel_dims - - -def build_device_mesh_and_parallelize( - model, - parallelize_fn, - parallel_dims: ParallelDims, -): - """ - Build device mesh and apply parallelization to the model. - - Uses TorchTitan's ParallelDims to build the device mesh with proper validation - and submesh creation, then applies tensor parallelism to the model using the - provided parallelize function. - - Args: - model: The TorchTitan model to parallelize - parallelize_fn: Function to apply tensor parallelism (e.g., apply_qwen3_tp) - parallel_dims: ParallelDims object with validated parallelism settings - - Returns: - The device mesh object - """ - # Use ParallelDims to build the device mesh - # This handles all the complexity of: - # - Validation of parallel dimensions - # - Building multi-dimensional device mesh - # - Creating all required submeshes (dp, dp_shard_cp, dp_cp, etc.) - world_mesh = parallel_dims.world_mesh - - logger.info(f"Built device mesh using ParallelDims: {world_mesh}") - - # Apply tensor parallelism using provided function - if parallel_dims.tp_enabled: - tp_mesh = world_mesh["tp"] - parallelize_fn( - model=model, - tp_mesh=tp_mesh, - loss_parallel=False, - enable_float8_tensorwise_tp=False, - enable_async_tp=False, - ) - logger.info(f"Applied Tensor Parallelism with TP={parallel_dims.tp}") - - return world_mesh diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/vllm_wrapper.py b/torchtitan/experiments/deterministic_vllm_rl/unified/vllm_wrapper.py similarity index 78% rename from torchtitan/experiments/deterministic_vllm_rl/models/vllm_wrapper.py rename to torchtitan/experiments/deterministic_vllm_rl/unified/vllm_wrapper.py index 1fcf81e402..9e7253c9f9 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/vllm_wrapper.py +++ b/torchtitan/experiments/deterministic_vllm_rl/unified/vllm_wrapper.py @@ -25,28 +25,25 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from torchtitan.experiments.deterministic_vllm_rl.models.attention import ( - VLLMPagedFlashAttention, -) -from torchtitan.experiments.deterministic_vllm_rl.models.utils import ( - create_parallel_dims_from_vllm_config, -) +from torchtitan.experiments.deterministic_vllm_rl.unified.attention import VLLMAttention from torchtitan.models.qwen3.model.model import precompute_rope_cache from torchtitan.protocols.model import BaseModelArgs, ModelProtocol from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter +from .utils import create_parallel_dims_from_vllm_config + logger = init_logger(__name__) ParallelizeFunction: TypeAlias = Callable[..., nn.Module] -class TorchTitanVLLMModel(nn.Module): +class TorchTitanVLLMModelWrapper(nn.Module): """ Generic vLLM-compatible model wrapper for TorchTitan models. 
The wrapper handles: - - HF config to TorchTitan model args mapping + - Direct usage of TorchTitan model args (no HF config mapping needed) - Attention replacement with vLLM paged attention - Tensor parallelism setup - Weight loading from HF checkpoints @@ -60,8 +57,8 @@ class TorchTitanVLLMModel(nn.Module): def __init__( self, *, - model_cls: type[ModelProtocol], # passing types that is not instantiated - model_args_cls: type[BaseModelArgs], + model_cls: type[ModelProtocol], + model_args: BaseModelArgs, state_dict_adapter: type[BaseStateDictAdapter], parallelize_fn: ParallelizeFunction, vllm_config: VllmConfig, @@ -73,19 +70,13 @@ def __init__( # Store components self.model_cls = model_cls - self.model_args_cls = model_args_cls self.state_dict_adapter = state_dict_adapter self.parallelize_fn = parallelize_fn - # Map HF config to TorchTitan ModelArgs - hf_config = vllm_config.model_config.hf_config - logger.info(f"Mapping HF config to {self.model_args_cls.__name__}") - model_args = self._map_hf_config_to_model_args(hf_config, self.model_args_cls) - - # Create TorchTitan model + # Use TorchTitan model args directly (no HF config mapping) + self.config = model_args logger.info(f"Creating {self.model_cls.__name__} with config: {model_args}") self.model = self.model_cls(model_args) - self.config = model_args # Setup RoPE cache extension function if provided self.rope_cache_extension_fn = partial( @@ -94,7 +85,7 @@ def __init__( base=self.config.rope_theta, ) # Replace attention with vLLM paged attention - self._replace_with_vllm_paged_attention(model_args) + self._replace_with_vllm_attention(model_args) # Create ParallelDims from vLLM config and apply parallelization # NOTE: We need to apply parallelize within model.__init__ because w @@ -115,74 +106,44 @@ def __init__( else: logger.info("Single GPU mode - no parallelization needed") - def _map_hf_config_to_model_args(self, hf_config, model_args_cls): - """ - Map HuggingFace config to TorchTitan ModelArgs. - - Default implementation that handles common model args fields. - Override in subclass if custom mapping is needed. - """ - # Maps TorchTitan parameter name to HF config attribute name - mapping = { - "vocab_size": "vocab_size", - "dim": "hidden_size", - "n_layers": "num_hidden_layers", - "n_heads": "num_attention_heads", - "n_kv_heads": "num_key_value_heads", - "head_dim": "head_dim", - "hidden_dim": "intermediate_size", - "norm_eps": "rms_norm_eps", - "max_seq_len": "max_position_embeddings", - "rope_theta": "rope_theta", - "qk_norm": "qk_norm", - } - - # Build kwargs for model args from mapping - kwargs = {} - for torchtitan_param, hf_attr in mapping.items(): - # Try to get value from HF config - if hasattr(hf_config, hf_attr): - kwargs[torchtitan_param] = getattr(hf_config, hf_attr) - - return model_args_cls(**kwargs) - - def _replace_with_vllm_paged_attention(self, model_args): + def _replace_with_vllm_attention(self, model_args): """ Replace TorchTitan attention with vLLM paged attention. Assumes model has .layers dict with .attention.inner_attention structure. Override in subclass if different structure. 
""" - if not hasattr(self.model, "layers"): - raise AttributeError( - f"Model {type(self.model).__name__} must have .layers attribute" - ) + assert hasattr( + self.model, "layers" + ), f"Model {type(self.model).__name__} must have .layers attribute" for layer_name, layer in self.model.layers.items(): - if not hasattr(layer, "attention"): - raise ValueError(f"Layer {layer_name} must have .attention attribute") + assert hasattr( + layer, "attention" + ), f"Layer {layer_name} must have .attention attribute" - # Create vLLM paged attention - vllm_attn = VLLMPagedFlashAttention( + vllm_attn = VLLMAttention( hidden_size=model_args.dim, num_heads=model_args.n_heads, num_kv_heads=model_args.n_heads, # Use n_heads (already replicated) head_dim=model_args.head_dim, - causal=True, + layer_name=layer_name, + scale=model_args.head_dim**-0.5, ) # Replace inner attention layer.attention.inner_attention = vllm_attn logger.info( - "Successfully replaced TorchTitan attention with vLLM PagedFlashAttention" + f"Successfully replaced TorchTitan attention with VLLMAttention " + f"({len(self.model.layers)} layers)" ) def _extend_rope_cache_if_needed( self, rope_cache: torch.Tensor, max_position: int ) -> torch.Tensor: """ - Extend RoPE cache if needed during vLLM profiling. + Extend RoPE cache if needed during vLLM profiling stage. Args: rope_cache: Current RoPE cache tensor @@ -309,10 +270,6 @@ def forward( batch_size, seq_len, hidden_size = h.shape h = h.view(batch_size * seq_len, hidden_size) - # TODO(jianiw): finish this conversion when TP is applied - if isinstance(h, DTensor): - h = h.full_tensor() - return h def compute_logits( @@ -324,15 +281,12 @@ def compute_logits( h = self.model.norm(hidden_states) logits = self.model.output(h) - # TODO: This part is to work with TP integration - if isinstance(logits, DTensor): - logits = logits.to_local() - return logits def load_weights(self, weights_iter): """ Load weights from HF checkpoint using the provided state dict adapter. + vLLM engine would call this function to load model weights. Args: weights_iter: Iterator of (name, tensor) pairs from HF checkpoint diff --git a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/README.md b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/README.md new file mode 100644 index 0000000000..d2ef719c0d --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/README.md @@ -0,0 +1,262 @@ +# Deterministic RL Training with vLLM + +This experiment combines vLLM's deterministic kernels with PyTorch autograd to enable reinforcement learning training where forward passes produce bitwise-identical results across runs. + +## Overview + +RL training requires both fast inference for generating rollouts and gradient computation for policy updates. vLLM provides deterministic forward passes but does not support gradients. This experiment adds backward passes to vLLM's operations. + +The implementation: +1. Uses vLLM's batch-invariant kernels for forward passes +2. Implements custom backward passes for gradient computation +3. Provides weight conversion utilities between TorchTitan and vLLM formats + +### Features + +- Bitwise determinism: Same inputs produce identical outputs across runs +- Gradient support: Backward passes through vLLM operations +- Weight conversion: Utilities to convert between model formats + +Note: Currently supports single-device training only. + +## Architecture + +### Components + +1. 
`models/attention.py`: VLLMCompatibleFlashAttention + - Uses vLLM's Flash Attention for forward pass + - Implements custom backward pass for gradient computation + - Uses `num_splits=1` for deterministic behavior + +2. `models/qwen3/model_vllm_compat.py`: Qwen3VLLMCompatModel + - Qwen3 model with merged gate/up projections matching vLLM format + - Uses VLLMRMSNorm with gradient support + +3. `batch_invariant_backward.py`: Backward passes for vLLM operations + - Registers gradients for vLLM's batch-invariant operations + - Supports matmul, linear, and RMSNorm + - Patches Flash Attention for autograd + +4. `weights_vllm_compat.py`: Weight conversion utilities + - Converts between TorchTitan format (separate w1, w2, w3) and vLLM format (merged gate_up_proj) + - Provides bidirectional conversion functions + +5. `simple_rl.py`: RL training loop + - Generates rollouts using vLLM engine + - Computes advantages using GRPO-style ranking + - Updates policy using PPO + +## Installation + +### Prerequisites + +```bash +# Install vLLM with deterministic support +pip install vllm + +# Install TorchTitan (from the repository root) +pip install -e . + +# Install additional dependencies +pip install transformers safetensors huggingface_hub tensorboard +``` + +### Enable Batch Invariance + +Initialize vLLM's batch-invariant mode before training: + +```python +from vllm.model_executor.layers.batch_invariant import init_batch_invariance +init_batch_invariance() +``` + +## Usage + +### Quick Start + +```python +import torch +from vllm.model_executor.layers.batch_invariant import init_batch_invariance +from torchtitan.experiments.deterministic_vllm_rl import ( + enable_batch_invariant_backward_mode, + Qwen3VLLMCompatModel, +) + +# 1. Enable deterministic mode +init_batch_invariance() +enable_batch_invariant_backward_mode() + +# 2. Load model +from torchtitan.models.qwen3.model.args import Qwen3ModelArgs +model_args = Qwen3ModelArgs( + dim=2048, + n_layers=24, + n_heads=16, + n_kv_heads=2, + vocab_size=151936, +) +model = Qwen3VLLMCompatModel(model_args) + +# 3. Forward pass (deterministic) +input_ids = torch.randint(0, 151936, (2, 128), device='cuda') +logits = model(input_ids) + +# 4. Backward pass +loss = logits.sum() +loss.backward() +``` + +### Full RL Training + +Run the RL training loop: + +```bash +VLLM_BATCH_INVARIANT=1 VLLM_FLASH_ATTN_VERSION=3 python -m torchtitan.experiments.deterministic_vllm_rl.simple_rl +``` + +This will: +1. Download Qwen3-1.7B from HuggingFace +2. Initialize vLLM engine for rollouts +3. Generate samples for training prompts +4. Compute rewards and advantages +5. Update the policy using PPO +6. Log metrics to TensorBoard + +View training progress: +```bash +tensorboard --logdir=./outputs/rl_training +``` + +## How It Works + +### Deterministic Forward Pass + +vLLM's batch-invariant mode makes operations deterministic: + +```python +# These operations are deterministic when batch_invariance is enabled +y = torch.matmul(a, b) # Uses vLLM's deterministic matmul +output = flash_attn_varlen_func(q, k, v, num_splits=1) # Deterministic FA +``` + +### Backward Pass with Gradients + +Custom backward passes: +1. Re-compute attention weights deterministically +2. Use standard chain rule for gradients +3. Apply gradients through vLLM's deterministic operations + +```python +class FlashAttnWithBackward(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, ...): + # Use vLLM's forward implementation + return flash_attn_varlen_func(q, k, v, num_splits=1, ...) 
+ + @staticmethod + def backward(ctx, grad_output): + # Compute gradients deterministically + # (re-compute attention weights and apply chain rule) + return grad_q, grad_k, grad_v, ... +``` + +### Bitwise Determinism Verification + +The training loop compares logprobs from vLLM and TorchTitan: + +```python +# During training, compare logprobs +vllm_logprobs = [from vLLM rollout] +titan_logprobs = [from TorchTitan forward pass] + +assert torch.equal(vllm_logprobs, titan_logprobs) +``` + +## Testing + +Run the test suite: + +```bash +cd torchtitan/experiments/deterministic_vllm_rl/tests + +# Test backward passes +python test_batch_invariant_backward.py + +# Test determinism +python test_exact_determinism.py +``` + +## Technical Details + +### Why Determinism Matters for RL + +RL training steps: +1. Generate rollouts by sampling from the policy +2. Compute rewards based on the samples +3. Update the policy using gradients + +If the forward pass during training differs from the forward pass during rollout, policy gradients may be incorrect. This matters for algorithms like PPO that compare old and new policy probabilities. + +This implementation uses the same kernels for both rollouts (vLLM) and training (TorchTitan) to ensure `logprobs_rollout == logprobs_training` bitwise. + +### Performance + +- Rollout speed: Uses vLLM's optimized kernels +- Training speed: Similar to standard TorchTitan +- Memory: Saves activations for custom backward passes + +### Limitations + +1. Custom backward requires uniform sequence lengths +2. Only causal attention is supported +3. Requires NVIDIA GPUs with Flash Attention support + +## Project Structure + +``` +deterministic_vllm_rl/ +├── README.md # Documentation +├── __init__.py # Package initialization +├── batch_invariant_backward.py # Backward passes for vLLM ops +├── weights_vllm_compat.py # Weight conversion utilities +├── simple_rl.py # RL training loop +├── models/ +│ ├── __init__.py +│ ├── attention.py # VLLMCompatibleFlashAttention +│ └── qwen3/ +│ ├── __init__.py +│ └── model_vllm_compat.py # vLLM-compatible Qwen3 model +├── weights/ +│ ├── __init__.py +│ ├── converter.py # Weight conversion script +│ └── README.md # Weight conversion documentation +└── tests/ + ├── __init__.py + ├── test_batch_invariant_backward.py # Test backward passes + └── test_exact_determinism.py # Test determinism +``` + +## TODO + +- `FlashAttnWithBackward` will need to become more composable and should not live exclusively within this directory. +- vLLM integration will need to become more generic with a provided Attention operator that is KV-cache compatible. +- vLLM parallelism will need to add generic parallelism initialization to support Monarch managed TP/DP. + +## Contributing + +This experiment is part of TorchTitan. To contribute: + +1. Test your changes with `pytest tests/` +2. Verify bitwise determinism is maintained +3. Update this README if adding new features + +## References + +- [vLLM Documentation](https://docs.vllm.ai/) +- [Flash Attention Paper](https://arxiv.org/abs/2205.14135) +- [PPO Algorithm](https://arxiv.org/abs/1707.06347) +- [GRPO: Group Relative Policy Optimization](https://arxiv.org/abs/2402.03300) + +## License + +This code is licensed under the BSD-style license found in the LICENSE file in the TorchTitan repository root directory. 
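
> Editor's note: the weight-format description in the README above (separate `w1`/`w3`/`w2` projections in TorchTitan versus a merged `gate_up_proj` in vLLM) can be made concrete with a small sketch. The helper names below are illustrative assumptions for this note, not the actual code in `weights_vllm_compat.py` or `weights/converter.py`; the sketch only shows why concatenating the gate and up projections along the output dimension is equivalent to keeping them separate.

```python
import torch
import torch.nn.functional as F

def merge_gate_up(w1: torch.Tensor, w3: torch.Tensor) -> torch.Tensor:
    """Illustrative helper (assumed name): stack TorchTitan's gate (w1) and
    up (w3) projection weights into one vLLM-style gate_up_proj tensor."""
    # w1, w3: [hidden_dim, dim] -> merged: [2 * hidden_dim, dim]
    return torch.cat([w1, w3], dim=0)

def swiglu_merged(x: torch.Tensor, gate_up: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
    """SwiGLU MLP using the merged projection: one matmul, then split the
    output into gate/up halves before applying SiLU and the down projection."""
    gate, up = (x @ gate_up.t()).chunk(2, dim=-1)
    return (F.silu(gate) * up) @ w2.t()

# Sanity check: the merged layout matches the separate-projection formulation.
dim, hidden_dim = 8, 16
w1, w3 = torch.randn(hidden_dim, dim), torch.randn(hidden_dim, dim)
w2 = torch.randn(dim, hidden_dim)
x = torch.randn(2, dim)
reference = (F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t()
assert torch.allclose(swiglu_merged(x, merge_gate_up(w1, w3), w2), reference, atol=1e-5)
```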
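> Editor's note: the rollout-to-update loop described under "Why Determinism Matters for RL" above can likewise be summarized in a short sketch. It follows the group-relative advantage from the referenced GRPO paper and the standard PPO clipped objective; it is a simplified illustration under those assumptions, and `simple_rl.py` may differ in details such as reward definition, token masking, or KL regularization.

```python
import torch

def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Group-relative advantages: normalize each sample's reward against the
    other samples drawn for the same prompt (rewards: [num_prompts, group_size])."""
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    return (rewards - mean) / (std + eps)

def ppo_clipped_loss(
    new_logprobs: torch.Tensor,
    old_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    clip_eps: float = 0.2,
) -> torch.Tensor:
    """PPO surrogate loss on per-sample log-probabilities. With bitwise-identical
    rollout and training forward passes, the ratio is exactly 1 on the first update."""
    ratio = torch.exp(new_logprobs - old_logprobs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()
```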
diff --git a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/__init__.py new file mode 100644 index 0000000000..b86721fba5 --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +vLLM-Compatible approach for deterministic RL training. + +This module provides models that match vLLM's weight format (e.g., merged gate_up_proj) +with custom backward passes for gradient computation during training. +""" + +from .batch_invariant_backward import ( + enable_batch_invariant_backward_mode, + rms_norm_with_gradients, + silu_and_mul_with_gradients, +) +from .models.attention import VLLMCompatibleFlashAttention +from .models.qwen3 import Qwen3VLLMCompatModel + + +__all__ = [ + "VLLMCompatibleFlashAttention", + "Qwen3VLLMCompatModel", + "enable_batch_invariant_backward_mode", + "rms_norm_with_gradients", + "silu_and_mul_with_gradients", +] diff --git a/torchtitan/experiments/deterministic_vllm_rl/batch_invariant_backward.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/batch_invariant_backward.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/batch_invariant_backward.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/batch_invariant_backward.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/__init__.py new file mode 100644 index 0000000000..2e7a5fa6af --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Models for deterministic vLLM RL training. + +This module provides vLLM-compatible model components. +""" + +from .attention import VLLMCompatibleFlashAttention + + +__all__ = [ + "VLLMCompatibleFlashAttention", +] diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/attention.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/attention.py similarity index 54% rename from torchtitan/experiments/deterministic_vllm_rl/models/attention.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/attention.py index 478884f9b7..11e6d3af67 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/attention.py +++ b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/attention.py @@ -4,15 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -""" -vLLM-compatible Flash Attention implementation for deterministic RL training. -""" - -import itertools import torch - -from vllm.attention.layer import Attention from vllm.attention.utils.fa_utils import flash_attn_varlen_func @@ -198,170 +191,3 @@ def backward(ctx, grad_output): output = output.transpose(1, 2) return output - - -class VLLMPagedFlashAttention(torch.nn.Module): - """ - Wrapper around vLLM's Attention with custom backward pass. 
- - Forward: Uses vLLM's optimized Attention layer (flash attention kernels) - Backward: Custom implementation for deterministic gradients - - TODO: This class need to be future refined. - """ - - # Class variable for auto-generating unique layer names (thread-safe) - _layer_counter = itertools.count() - - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - scale: float | None = None, - causal: bool = True, - ) -> None: - super().__init__() - - self.hidden_size = hidden_size - - # Handle tensor parallelism: adjust num_heads and num_kv_heads for TP - # NOTE(jianiw): As we use local tensor for this region, we need to manually - - from vllm.config import get_current_vllm_config - from vllm.logger import init_logger - - logger = init_logger(__name__) - vllm_config = get_current_vllm_config() - tp_size = vllm_config.parallel_config.tensor_parallel_size - - if tp_size > 1: - if num_kv_heads % tp_size != 0: - # Pad num_kv_heads and num_heads to be divisible by tp_size - assert num_heads % num_kv_heads == 0 - padded_size = tp_size - num_kv_heads % tp_size - padded_num_kv_heads = num_kv_heads + padded_size - padded_num_heads = num_heads + padded_size * num_heads // num_kv_heads - assert padded_num_heads % tp_size == 0 - assert padded_num_kv_heads % tp_size == 0 - - logger.info( - f"Padding attention heads for tensor parallelism: " - f"{num_heads=}, {padded_num_heads=}, " - f"{num_kv_heads=}, {padded_num_kv_heads=}" - ) - - num_heads = padded_num_heads // tp_size - num_kv_heads = padded_num_kv_heads // tp_size - else: - num_heads //= tp_size - num_kv_heads //= tp_size - - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = head_dim - self.causal = causal - - if scale is None: - self.scale = head_dim**-0.5 - else: - self.scale = scale - - # Create vLLM Attention layer - from vllm.config import get_current_vllm_config - - config = get_current_vllm_config() - cache_config = config.cache_config if hasattr(config, "cache_config") else None - - # Generate unique prefix for this attention layer - # vLLM expects format "layers.X" for layer index extraction - layer_idx = next(VLLMPagedFlashAttention._layer_counter) - prefix = f"layers.{layer_idx}" - - self.vllm_attn = Attention( - num_heads=num_heads, - head_size=head_dim, - scale=self.scale, - num_kv_heads=num_kv_heads, - cache_config=cache_config, - quant_config=None, - prefix=prefix, - ) - - # KV cache - will be populated by vLLM during model loading - self.kv_cache: list[torch.Tensor] | None = None - - # Auto-register for vLLM KV cache if in vLLM context - self._auto_register_for_kv_cache() - - def _auto_register_for_kv_cache(self): - """Automatically register this layer for vLLM KV cache allocation. - - This is called during __init__ and will register the layer if we're in - a vLLM context. If not in vLLM context (e.g., pure PyTorch training), - this silently does nothing. 
- """ - # Initialize layer_name attribute - self.layer_name: str | None = None - - from vllm.config import get_current_vllm_config - - config = get_current_vllm_config() - compilation_config = config.compilation_config - - # Generate unique layer name using class counter - # Format: "layers.{index}" for compatibility with extract_layer_index() - layer_name = f"layers.{next(VLLMPagedFlashAttention._layer_counter)}" - - # Register this layer in static forward context - if layer_name in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {layer_name}") - compilation_config.static_forward_context[layer_name] = self - self.layer_name = layer_name - - def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - *, - scale: float | None = None, - ) -> torch.Tensor: - """ - Forward pass using vLLM's Attention layer for inference. - - Args: - q: Query tensor [batch, num_heads, seq_len, head_dim] - k: Key tensor [batch, num_kv_heads, seq_len, head_dim] - v: Value tensor [batch, num_kv_heads, seq_len, head_dim] - scale: Optional attention scale override (unused, vLLM uses internal scale) - - Returns: - output: [batch, num_heads, seq_len, head_dim] - """ - # Input is (batch, num_heads, seq_len, head_dim) - batch_size, num_heads, seq_len, head_dim = q.shape - - if self.vllm_attn is None: - raise RuntimeError( - "vLLM attention not initialized. This module requires vLLM context." - ) - - # Transpose to (batch, seq_len, num_heads, head_dim) for vLLM - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - # NOTE(jianiw): vllm_attention can take input as shape (batch, seq_len, num_heads, head_dim) - # and handle the reshaping internally - # Use vLLM's Attention layer for paged attention and - output_varlen = self.vllm_attn(q, k, v) - - # Reshape back to batch format - output = output_varlen.view(batch_size, seq_len, num_heads, head_dim) - - # Transpose back to TorchTitan format: (batch, num_heads, seq_len, head_dim) - output = output.transpose(1, 2) - - return output diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/qwen3/__init__.py similarity index 85% rename from torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/qwen3/__init__.py index 4ac63df232..10f49db8b5 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py +++ b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/qwen3/__init__.py @@ -8,6 +8,6 @@ Qwen3 model with vLLM compatibility for deterministic RL training. 
""" -from .model_batch_invariant import Qwen3VLLMCompatModel +from .model_vllm_compat import Qwen3VLLMCompatModel __all__ = ["Qwen3VLLMCompatModel"] diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_batch_invariant.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/qwen3/model_vllm_compat.py similarity index 99% rename from torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_batch_invariant.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/qwen3/model_vllm_compat.py index dd84665091..b4967fbbd9 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_batch_invariant.py +++ b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/qwen3/model_vllm_compat.py @@ -13,7 +13,7 @@ from torchtitan.components.tokenizer import BaseTokenizer # Import gradient-enabled operations from experiment utilities -from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( +from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.batch_invariant_backward import ( rms_norm_with_gradients, silu_and_mul_with_gradients, ) diff --git a/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/simple_rl.py similarity index 99% rename from torchtitan/experiments/deterministic_vllm_rl/simple_rl.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/simple_rl.py index ffc7d52eb0..3f938eba85 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py +++ b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/simple_rl.py @@ -30,11 +30,11 @@ from vllm import LLM, SamplingParams from vllm.model_executor.layers.batch_invariant import init_batch_invariance -from torchtitan.experiments.deterministic_vllm_rl.weights.converter import ( +from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.weights.converter import ( torchtitan_to_vllm, vllm_to_torchtitan, ) -from torchtitan.experiments.deterministic_vllm_rl.weights_vllm_compat import ( +from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.weights_vllm_compat import ( torchtitan_to_vllm_compat, ) @@ -340,7 +340,7 @@ def load_model(checkpoint_path: str, model_path: str, use_vllm_compat: bool = Tr if use_vllm_compat: # Create and load model (using vLLM-compat for bitwise determinism) - from torchtitan.experiments.deterministic_vllm_rl.models.qwen3 import ( + from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.models.qwen3 import ( Qwen3VLLMCompatModel, ) @@ -1058,7 +1058,7 @@ def main(): print("✓ Batch invariance detected - using vLLM-compatible model") # Add backward pass support to vLLM's batch_invariant mode print(" Adding gradient support to vLLM's batch_invariant mode...") - from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( + from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.batch_invariant_backward import ( enable_batch_invariant_backward_mode, ) diff --git a/torchtitan/experiments/deterministic_vllm_rl/tests/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/__init__.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/tests/__init__.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/__init__.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/tests/test_batch_invariant_backward.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/test_batch_invariant_backward.py similarity 
index 97% rename from torchtitan/experiments/deterministic_vllm_rl/tests/test_batch_invariant_backward.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/test_batch_invariant_backward.py index 3ed9604d10..ddf8b01514 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/tests/test_batch_invariant_backward.py +++ b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/test_batch_invariant_backward.py @@ -8,9 +8,11 @@ Test batch_invariant_backward module to ensure it works correctly. """ +import sys + import torch -from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( +from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.batch_invariant_backward import ( disable_batch_invariant_backward_mode, enable_batch_invariant_backward_mode, linear_batch_invariant_backward, diff --git a/torchtitan/experiments/deterministic_vllm_rl/tests/test_exact_determinism.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/test_exact_determinism.py similarity index 98% rename from torchtitan/experiments/deterministic_vllm_rl/tests/test_exact_determinism.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/test_exact_determinism.py index 8d0ac3133e..bfb7954a2a 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/tests/test_exact_determinism.py +++ b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/test_exact_determinism.py @@ -13,7 +13,7 @@ import torch from vllm.model_executor.layers.batch_invariant import disable_batch_invariant_mode -from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( +from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.batch_invariant_backward import ( enable_batch_invariant_backward_mode, ) diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights/README.md b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights/README.md similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/weights/README.md rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights/README.md diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights/__init__.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/weights/__init__.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights/__init__.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights/converter.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights/converter.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/weights/converter.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights/converter.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights_vllm_compat.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights_vllm_compat.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/weights_vllm_compat.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights_vllm_compat.py From e9c74e8ce29b27134169aad50d90edfed1707d77 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 15 Dec 2025 20:48:04 -0800 Subject: [PATCH 18/19] your commit message --- torchtitan/experiments/README.md | 1 + .../deterministic_vllm_rl/__init__.py | 33 ------------------- .../{deterministic_vllm_rl => rl}/README.md | 2 +- .../unified/README.md | 7 ++-- .../unified/__init__.py | 3 
+- .../unified/attention.py | 0 .../unified/infer.py | 6 ++-- .../unified/utils.py | 4 +-- .../unified/vllm_wrapper.py | 10 +++--- .../vllm_compat/README.md | 8 ++--- .../vllm_compat/__init__.py | 0 .../vllm_compat/batch_invariant_backward.py | 0 .../vllm_compat/models/__init__.py | 0 .../vllm_compat/models/attention.py | 0 .../vllm_compat/models/qwen3/__init__.py | 0 .../models/qwen3/model_vllm_compat.py | 2 +- .../vllm_compat/simple_rl.py | 12 +++---- .../vllm_compat/tests/__init__.py | 0 .../tests/test_batch_invariant_backward.py | 0 .../tests/test_exact_determinism.py | 2 +- .../vllm_compat/weights/README.md | 0 .../vllm_compat/weights/__init__.py | 0 .../vllm_compat/weights/converter.py | 0 .../vllm_compat/weights_vllm_compat.py | 0 24 files changed, 28 insertions(+), 62 deletions(-) delete mode 100644 torchtitan/experiments/deterministic_vllm_rl/__init__.py rename torchtitan/experiments/{deterministic_vllm_rl => rl}/README.md (92%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/unified/README.md (84%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/unified/__init__.py (99%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/unified/attention.py (100%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/unified/infer.py (100%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/unified/utils.py (100%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/unified/vllm_wrapper.py (99%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/vllm_compat/README.md (97%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/vllm_compat/__init__.py (100%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/vllm_compat/batch_invariant_backward.py (100%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/vllm_compat/models/__init__.py (100%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/vllm_compat/models/attention.py (100%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/vllm_compat/models/qwen3/__init__.py (100%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/vllm_compat/models/qwen3/model_vllm_compat.py (99%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/vllm_compat/simple_rl.py (99%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/vllm_compat/tests/__init__.py (100%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/vllm_compat/tests/test_batch_invariant_backward.py (100%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/vllm_compat/tests/test_exact_determinism.py (100%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/vllm_compat/weights/README.md (100%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/vllm_compat/weights/__init__.py (100%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/vllm_compat/weights/converter.py (100%) rename torchtitan/experiments/{deterministic_vllm_rl => rl}/vllm_compat/weights_vllm_compat.py (100%) diff --git a/torchtitan/experiments/README.md b/torchtitan/experiments/README.md index 10b90ac1d4..ff3b7b61d5 100644 --- a/torchtitan/experiments/README.md +++ b/torchtitan/experiments/README.md @@ -32,4 +32,5 @@ We provide this `experiments/` folder to host experiments that add significant v | [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) | | [compiler_toolkit](./compiler_toolkit/) | [![Compiler Toolkit 8 GPU Integration 
Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) | | [transformers_modeling_backend](./transformers_modeling_backend/) | [![Transformers modeling backend 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml?query=branch%3Amain) | [@3outeille](https://github.com/3outeille) | +| [rl](./rl/) | TBA | [@bwasti](https://github.com/bwasti)[@wwwjn](https://github.com/wwwjn) | | [autoparallel](./autoparallel/) | [![Auto Parallel 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_autoparallel.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_autoparallel.yaml?query=branch%3Amain) | [@wconstab](https://github.com/wconstab) [@xmfan](https://github.com/xmfan) | diff --git a/torchtitan/experiments/deterministic_vllm_rl/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/__init__.py deleted file mode 100644 index 66c1de78a5..0000000000 --- a/torchtitan/experiments/deterministic_vllm_rl/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- - -from .unified import ( - create_parallel_dims_from_vllm_config, - register_torchtitan_model_from_train_spec, - TorchTitanVLLMModelWrapper, -) -from .vllm_compat import ( - enable_batch_invariant_backward_mode, - Qwen3VLLMCompatModel, - rms_norm_with_gradients, - silu_and_mul_with_gradients, - VLLMCompatibleFlashAttention, -) - - -__all__ = [ - # vllm_compat exports - "VLLMCompatibleFlashAttention", - "Qwen3VLLMCompatModel", - "enable_batch_invariant_backward_mode", - "rms_norm_with_gradients", - "silu_and_mul_with_gradients", - # unified exports - "TorchTitanVLLMModelWrapper", - "create_parallel_dims_from_vllm_config", - "register_torchtitan_model_from_train_spec", -] diff --git a/torchtitan/experiments/deterministic_vllm_rl/README.md b/torchtitan/experiments/rl/README.md similarity index 92% rename from torchtitan/experiments/deterministic_vllm_rl/README.md rename to torchtitan/experiments/rl/README.md index c44bb45cfb..72b3d2ad11 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/README.md +++ b/torchtitan/experiments/rl/README.md @@ -1,4 +1,4 @@ -# Deterministic vLLM RL Training +# Deterministic RL Training with vLLM This package provides two approaches for integrating TorchTitan models with vLLM: diff --git a/torchtitan/experiments/deterministic_vllm_rl/unified/README.md b/torchtitan/experiments/rl/unified/README.md similarity index 84% rename from torchtitan/experiments/deterministic_vllm_rl/unified/README.md rename to torchtitan/experiments/rl/unified/README.md index 68572927ab..5cea3918ae 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/unified/README.md +++ b/torchtitan/experiments/rl/unified/README.md @@ -1,6 +1,7 @@ # Run vLLM inference with TorchTitan Qwen3 Model -This directory contains code to run a single canonical model definition (TorchTitan model definition) with vLLM inference engine (not batch-invariant yet, working in progress). +This directory contains code to run a single canonical model definition (TorchTitan model definition) with vLLM inference engine (not batch-invariant yet, working in progress). This work is actively developing and only supports inference for now. + This work is inspired by https://github.com/vllm-project/vllm/pull/28685. ## Overview @@ -52,12 +53,12 @@ pip install -e . 4. Run inference: ``` -python torchtitan/experiments/deterministic_vllm_rl/infer.py --model torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B +python torchtitan/experiments/rl/unified/infer.py --model torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B ``` Run with TP: (work in progress) ``` -python torchtitan/experiments/deterministic_vllm_rl/infer.py --model torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B --tensor-parallel-size 2 +python torchtitan/experiments/rl/unified/infer.py --model torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B --tensor-parallel-size 2 ``` diff --git a/torchtitan/experiments/deterministic_vllm_rl/unified/__init__.py b/torchtitan/experiments/rl/unified/__init__.py similarity index 99% rename from torchtitan/experiments/deterministic_vllm_rl/unified/__init__.py rename to torchtitan/experiments/rl/unified/__init__.py index f9862db824..6c34556112 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/unified/__init__.py +++ b/torchtitan/experiments/rl/unified/__init__.py @@ -11,9 +11,8 @@ Uses the canonical TorchTitan model definition directly with vLLM inference engine. 
""" -from vllm.logger import init_logger - from torchtitan.protocols.train_spec import get_train_spec, TrainSpec +from vllm.logger import init_logger from .utils import create_parallel_dims_from_vllm_config from .vllm_wrapper import TorchTitanVLLMModelWrapper diff --git a/torchtitan/experiments/deterministic_vllm_rl/unified/attention.py b/torchtitan/experiments/rl/unified/attention.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/unified/attention.py rename to torchtitan/experiments/rl/unified/attention.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/unified/infer.py b/torchtitan/experiments/rl/unified/infer.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/unified/infer.py rename to torchtitan/experiments/rl/unified/infer.py index e52e7d2651..19770ecc22 100755 --- a/torchtitan/experiments/deterministic_vllm_rl/unified/infer.py +++ b/torchtitan/experiments/rl/unified/infer.py @@ -7,12 +7,12 @@ import argparse -from vllm import LLM, SamplingParams -from vllm.logger import init_logger - # Import unified module - this automatically registers TorchTitan models with vLLM from torchtitan.experiments.deterministic_vllm_rl import unified # noqa: F401 +from vllm import LLM, SamplingParams +from vllm.logger import init_logger + logger = init_logger(__name__) diff --git a/torchtitan/experiments/deterministic_vllm_rl/unified/utils.py b/torchtitan/experiments/rl/unified/utils.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/unified/utils.py rename to torchtitan/experiments/rl/unified/utils.py index 0b56d81be4..e997c387d9 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/unified/utils.py +++ b/torchtitan/experiments/rl/unified/utils.py @@ -12,10 +12,10 @@ """ import torch.distributed as dist -from vllm.config import VllmConfig -from vllm.logger import init_logger from torchtitan.distributed.parallel_dims import ParallelDims +from vllm.config import VllmConfig +from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/torchtitan/experiments/deterministic_vllm_rl/unified/vllm_wrapper.py b/torchtitan/experiments/rl/unified/vllm_wrapper.py similarity index 99% rename from torchtitan/experiments/deterministic_vllm_rl/unified/vllm_wrapper.py rename to torchtitan/experiments/rl/unified/vllm_wrapper.py index 9e7253c9f9..e92903c744 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/unified/vllm_wrapper.py +++ b/torchtitan/experiments/rl/unified/vllm_wrapper.py @@ -12,7 +12,6 @@ """ from functools import partial -from typing import Callable, TypeAlias import torch import torch.nn as nn @@ -22,21 +21,20 @@ StateDictOptions, ) -from vllm.config import VllmConfig -from vllm.logger import init_logger - from torchtitan.experiments.deterministic_vllm_rl.unified.attention import VLLMAttention from torchtitan.models.qwen3.model.model import precompute_rope_cache from torchtitan.protocols.model import BaseModelArgs, ModelProtocol from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter +from torchtitan.protocols.train_spec import ParallelizeFunction + +from vllm.config import VllmConfig +from vllm.logger import init_logger from .utils import create_parallel_dims_from_vllm_config logger = init_logger(__name__) -ParallelizeFunction: TypeAlias = Callable[..., nn.Module] - class TorchTitanVLLMModelWrapper(nn.Module): """ diff --git a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/README.md b/torchtitan/experiments/rl/vllm_compat/README.md similarity index 
97% rename from torchtitan/experiments/deterministic_vllm_rl/vllm_compat/README.md rename to torchtitan/experiments/rl/vllm_compat/README.md index d2ef719c0d..bf56f4afbe 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/README.md +++ b/torchtitan/experiments/rl/vllm_compat/README.md @@ -77,7 +77,7 @@ init_batch_invariance() ```python import torch from vllm.model_executor.layers.batch_invariant import init_batch_invariance -from torchtitan.experiments.deterministic_vllm_rl import ( +from torchtitan.experiments.rl.vllm_compat import ( enable_batch_invariant_backward_mode, Qwen3VLLMCompatModel, ) @@ -111,7 +111,7 @@ loss.backward() Run the RL training loop: ```bash -VLLM_BATCH_INVARIANT=1 VLLM_FLASH_ATTN_VERSION=3 python -m torchtitan.experiments.deterministic_vllm_rl.simple_rl +VLLM_BATCH_INVARIANT=1 VLLM_FLASH_ATTN_VERSION=3 python -m torchtitan.experiments.rl.vllm_compat.simple_rl ``` This will: @@ -177,7 +177,7 @@ assert torch.equal(vllm_logprobs, titan_logprobs) Run the test suite: ```bash -cd torchtitan/experiments/deterministic_vllm_rl/tests +cd torchtitan/experiments/rl/vllm_compat/tests # Test backward passes python test_batch_invariant_backward.py @@ -214,7 +214,7 @@ This implementation uses the same kernels for both rollouts (vLLM) and training ## Project Structure ``` -deterministic_vllm_rl/ +rl/vllm_compat/ ├── README.md # Documentation ├── __init__.py # Package initialization ├── batch_invariant_backward.py # Backward passes for vLLM ops diff --git a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/__init__.py b/torchtitan/experiments/rl/vllm_compat/__init__.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/vllm_compat/__init__.py rename to torchtitan/experiments/rl/vllm_compat/__init__.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/batch_invariant_backward.py b/torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/vllm_compat/batch_invariant_backward.py rename to torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/__init__.py b/torchtitan/experiments/rl/vllm_compat/models/__init__.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/__init__.py rename to torchtitan/experiments/rl/vllm_compat/models/__init__.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/attention.py b/torchtitan/experiments/rl/vllm_compat/models/attention.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/attention.py rename to torchtitan/experiments/rl/vllm_compat/models/attention.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/qwen3/__init__.py b/torchtitan/experiments/rl/vllm_compat/models/qwen3/__init__.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/qwen3/__init__.py rename to torchtitan/experiments/rl/vllm_compat/models/qwen3/__init__.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/qwen3/model_vllm_compat.py b/torchtitan/experiments/rl/vllm_compat/models/qwen3/model_vllm_compat.py similarity index 99% rename from torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/qwen3/model_vllm_compat.py rename to torchtitan/experiments/rl/vllm_compat/models/qwen3/model_vllm_compat.py index 
b4967fbbd9..2c9742b1fa 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/qwen3/model_vllm_compat.py +++ b/torchtitan/experiments/rl/vllm_compat/models/qwen3/model_vllm_compat.py @@ -13,7 +13,7 @@ from torchtitan.components.tokenizer import BaseTokenizer # Import gradient-enabled operations from experiment utilities -from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.batch_invariant_backward import ( +from torchtitan.experiments.rl.vllm_compat.batch_invariant_backward import ( rms_norm_with_gradients, silu_and_mul_with_gradients, ) diff --git a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/simple_rl.py b/torchtitan/experiments/rl/vllm_compat/simple_rl.py similarity index 99% rename from torchtitan/experiments/deterministic_vllm_rl/vllm_compat/simple_rl.py rename to torchtitan/experiments/rl/vllm_compat/simple_rl.py index 3f938eba85..508868c0d4 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/simple_rl.py +++ b/torchtitan/experiments/rl/vllm_compat/simple_rl.py @@ -25,20 +25,20 @@ from huggingface_hub import snapshot_download from safetensors.torch import load_file, save_file from torch.utils.tensorboard import SummaryWriter -from transformers import AutoConfig, AutoTokenizer - -from vllm import LLM, SamplingParams -from vllm.model_executor.layers.batch_invariant import init_batch_invariance -from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.weights.converter import ( +from torchtitan.experiments.rl.vllm_compat.weights.converter import ( torchtitan_to_vllm, vllm_to_torchtitan, ) -from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.weights_vllm_compat import ( +from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( torchtitan_to_vllm_compat, ) from torchtitan.models.qwen3.model.args import Qwen3ModelArgs +from transformers import AutoConfig, AutoTokenizer + +from vllm import LLM, SamplingParams +from vllm.model_executor.layers.batch_invariant import init_batch_invariance init_batch_invariance() diff --git a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/__init__.py b/torchtitan/experiments/rl/vllm_compat/tests/__init__.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/__init__.py rename to torchtitan/experiments/rl/vllm_compat/tests/__init__.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/test_batch_invariant_backward.py b/torchtitan/experiments/rl/vllm_compat/tests/test_batch_invariant_backward.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/test_batch_invariant_backward.py rename to torchtitan/experiments/rl/vllm_compat/tests/test_batch_invariant_backward.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/test_exact_determinism.py b/torchtitan/experiments/rl/vllm_compat/tests/test_exact_determinism.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/test_exact_determinism.py rename to torchtitan/experiments/rl/vllm_compat/tests/test_exact_determinism.py index bfb7954a2a..2a9863ab2f 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/test_exact_determinism.py +++ b/torchtitan/experiments/rl/vllm_compat/tests/test_exact_determinism.py @@ -11,11 +11,11 @@ """ import torch -from vllm.model_executor.layers.batch_invariant import disable_batch_invariant_mode from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.batch_invariant_backward 
import ( enable_batch_invariant_backward_mode, ) +from vllm.model_executor.layers.batch_invariant import disable_batch_invariant_mode print("Enabling batch_invariant_backward mode...") disable_batch_invariant_mode() diff --git a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights/README.md b/torchtitan/experiments/rl/vllm_compat/weights/README.md similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights/README.md rename to torchtitan/experiments/rl/vllm_compat/weights/README.md diff --git a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights/__init__.py b/torchtitan/experiments/rl/vllm_compat/weights/__init__.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights/__init__.py rename to torchtitan/experiments/rl/vllm_compat/weights/__init__.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights/converter.py b/torchtitan/experiments/rl/vllm_compat/weights/converter.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights/converter.py rename to torchtitan/experiments/rl/vllm_compat/weights/converter.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights_vllm_compat.py b/torchtitan/experiments/rl/vllm_compat/weights_vllm_compat.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights_vllm_compat.py rename to torchtitan/experiments/rl/vllm_compat/weights_vllm_compat.py From 7e1b1b83c19ef7f35c67bf8e8075808a8bef5cc0 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 15 Dec 2025 20:58:34 -0800 Subject: [PATCH 19/19] fix readme --- torchtitan/experiments/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/experiments/README.md b/torchtitan/experiments/README.md index ff3b7b61d5..53df45dd84 100644 --- a/torchtitan/experiments/README.md +++ b/torchtitan/experiments/README.md @@ -29,8 +29,8 @@ We provide this `experiments/` folder to host experiments that add significant v | [forge](./forge/) | TBA | [@allenwang28](https://github.com/allenwang28) [@ebsmothers](https://github.com/ebsmothers) [@joecummings](https://github.com/joecummings) [@pbontrager](https://github.com/pbontrager) | | [torchcomms](./torchcomms/) | [![TorchComms 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_torchcomms.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_torchcomms.yaml?query=branch%3Amain) | [@d4l3k](https://https://github.com/d4l3k) [@fduwjj](https://github.com/fduwjj) [@mori360 ](https://github.com/mori360) | | [moe_symm_mem_kernels](./moe_symm_mem_kernels/) | TBA | [@kwen2501](https://github.com/kwen2501) | -| [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) | +| [gpt_oss](./gpt_oss/) | TBA | [@wwwjn](https://github.com/wwwjn) | | [compiler_toolkit](./compiler_toolkit/) | [![Compiler Toolkit 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) | | [transformers_modeling_backend](./transformers_modeling_backend/) | [![Transformers modeling backend 8 GPU Integration 
Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml?query=branch%3Amain) | [@3outeille](https://github.com/3outeille) | -| [rl](./rl/) | TBA | [@bwasti](https://github.com/bwasti)[@wwwjn](https://github.com/wwwjn) | +| [rl](./rl/) | TBA | [@bwasti](https://github.com/bwasti) [@wwwjn](https://github.com/wwwjn) | | [autoparallel](./autoparallel/) | [![Auto Parallel 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_autoparallel.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_autoparallel.yaml?query=branch%3Amain) | [@wconstab](https://github.com/wconstab) [@xmfan](https://github.com/xmfan) |