diff --git a/torchtitan/experiments/README.md b/torchtitan/experiments/README.md
index 10b90ac1d4..53df45dd84 100644
--- a/torchtitan/experiments/README.md
+++ b/torchtitan/experiments/README.md
@@ -29,7 +29,8 @@ We provide this `experiments/` folder to host experiments that add significant v
 | [forge](./forge/) | TBA | [@allenwang28](https://github.com/allenwang28) [@ebsmothers](https://github.com/ebsmothers) [@joecummings](https://github.com/joecummings) [@pbontrager](https://github.com/pbontrager) |
 | [torchcomms](./torchcomms/) | [![TorchComms 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_torchcomms.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_torchcomms.yaml?query=branch%3Amain) | [@d4l3k](https://https://github.com/d4l3k) [@fduwjj](https://github.com/fduwjj) [@mori360 ](https://github.com/mori360) |
 | [moe_symm_mem_kernels](./moe_symm_mem_kernels/) | TBA | [@kwen2501](https://github.com/kwen2501) |
-| [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) |
+| [gpt_oss](./gpt_oss/) | TBA | [@wwwjn](https://github.com/wwwjn) |
 | [compiler_toolkit](./compiler_toolkit/) | [![Compiler Toolkit 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) |
 | [transformers_modeling_backend](./transformers_modeling_backend/) | [![Transformers modeling backend 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml?query=branch%3Amain) | [@3outeille](https://github.com/3outeille) |
+| [rl](./rl/) | TBA | [@bwasti](https://github.com/bwasti) [@wwwjn](https://github.com/wwwjn) |
 | [autoparallel](./autoparallel/) | [![Auto Parallel 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_autoparallel.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_autoparallel.yaml?query=branch%3Amain) | [@wconstab](https://github.com/wconstab) [@xmfan](https://github.com/xmfan) |
diff --git a/torchtitan/experiments/rl/README.md b/torchtitan/experiments/rl/README.md
new file mode 100644
index 0000000000..72b3d2ad11
--- /dev/null
+++ b/torchtitan/experiments/rl/README.md
@@ -0,0 +1,12 @@
+# Deterministic RL Training with vLLM
+
+This package provides two approaches for integrating TorchTitan models with vLLM:
+
+1. vllm_compat/ - vLLM-Compatible approach
+   - Separate model definition matching vLLM's weight format
+   - Supports batch invariance and bit-wise identity between training and inference
+   - Custom backward passes for attention gradient computation
+
+2. unified/ - Unified approach
+   - Uses the canonical TorchTitan model definition directly for inference
+   - Replaces attention with vLLM-compatible attention for inference
diff --git a/torchtitan/experiments/rl/unified/README.md b/torchtitan/experiments/rl/unified/README.md
new file mode 100644
index 0000000000..5cea3918ae
--- /dev/null
+++ b/torchtitan/experiments/rl/unified/README.md
@@ -0,0 +1,68 @@
+# Run vLLM inference with the TorchTitan Qwen3 model
+
+This directory contains code to run a single canonical model definition (the TorchTitan model definition) with the vLLM inference engine (not batch-invariant yet; work in progress). This code is under active development and currently supports inference only.
+
+This work is inspired by https://github.com/vllm-project/vllm/pull/28685.
+
+## Overview
+The integration consists of two main components:
+
+1. **Model Wrapper** (`vllm_wrapper.py`, registered with vLLM in `__init__.py`): `TorchTitanVLLMModelWrapper` adapts the TorchTitan model definition to vLLM's model interface, swaps in vLLM paged attention (`attention.py`), and loads HF checkpoints through the state dict adapter
+2. **Inference Script** (`infer.py`): A simple script that imports this package (which registers the model) and runs inference
+
+
+## Quick Start
+### Prerequisites
+
+1. Install PyTorch nightly for torchtitan:
+```
+pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall
+```
+
+
+2. Install vLLM from source, following [vllm-use-an-existing-pytorch-installation](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#use-an-existing-pytorch-installation):
+```bash
+# install PyTorch first, either from PyPI or from source
+git clone https://github.com/vllm-project/vllm.git
+cd vllm
+python use_existing_torch.py
+uv pip install -r requirements/build.txt
+uv pip install --no-build-isolation -e .
+```
+
+
+NOTE: If `flash_attn_varlen_func` fails with "torch.AcceleratorError: CUDA error: the provided PTX was compiled with an unsupported toolchain" during the forward pass, the GPU driver is not compatible with the CUDA toolkit that vLLM/PyTorch was compiled against. Use the following commands to recompile vLLM.
+
+```
+# Set CUDA version environment variable
+export CUDA_HOME=/usr/local/cuda-12.4
+export PATH=/usr/local/cuda-12.4/bin:$PATH
+export LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64:$LD_LIBRARY_PATH
+
+# Clean previous build
+rm -rf build dist *.egg-info
+uv pip uninstall -y vllm
+
+# Rebuild vLLM from source with CUDA 12.4
+pip install -e .
+
+```
+
+3. Download the Qwen3-0.6B checkpoint from Hugging Face and put it into the `example_checkpoint` folder.
+
+
+4. Run inference:
+```
+python torchtitan/experiments/rl/unified/infer.py --model_ckpt_path torchtitan/experiments/rl/example_checkpoint/qwen3-0.6B
+```
+
+Run with TP (work in progress):
+```
+python torchtitan/experiments/rl/unified/infer.py --model_ckpt_path torchtitan/experiments/rl/example_checkpoint/qwen3-0.6B --tensor-parallel-size 2
+
+```
+
+## TODO
+1. Rewrite the attention path to use vllm.Attention() (with a backward pass) as the only attention implementation.
+2. Integrate with simple_rl.py to run end-to-end RL with one canonical model definition.
+3. Integrate batch-invariant kernels into the model definition.
diff --git a/torchtitan/experiments/rl/unified/__init__.py b/torchtitan/experiments/rl/unified/__init__.py
new file mode 100644
index 0000000000..6c34556112
--- /dev/null
+++ b/torchtitan/experiments/rl/unified/__init__.py
@@ -0,0 +1,93 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Unified approach for running TorchTitan models with vLLM inference. + +This module automatically registers TorchTitan models with vLLM when imported. +Uses the canonical TorchTitan model definition directly with vLLM inference engine. +""" + +from torchtitan.protocols.train_spec import get_train_spec, TrainSpec +from vllm.logger import init_logger + +from .utils import create_parallel_dims_from_vllm_config +from .vllm_wrapper import TorchTitanVLLMModelWrapper + + +logger = init_logger(__name__) + + +def register_torchtitan_model_from_train_spec( + train_spec: TrainSpec, + model_name: str, + model_flavor: str, +) -> None: + """ + Register a TorchTitan model with vLLM using a TrainSpec. + + Args: + train_spec: TorchTitan TrainSpec containing model components + model_name: Name to register in vLLM (e.g., "Qwen3TorchTitanForCausalLM") + model_flavor: Model flavor key (e.g., "0.6B") to select from qwen3_args + + """ + from vllm.model_executor.models.registry import ModelRegistry + + # Get model_args directly from TrainSpec.model_args dict using flavor key + if isinstance(train_spec.model_args, dict): + if model_flavor not in train_spec.model_args: + raise ValueError( + f"Model flavor '{model_flavor}' not found in train_spec.model_args. " + f"Available flavors: {list(train_spec.model_args.keys())}" + ) + model_args = train_spec.model_args[model_flavor] + else: + raise ValueError( + "train_spec.model_args must be a dict mapping flavor names to ModelArgs" + ) + + # Create dynamic model class directly from TrainSpec components + class TorchTitanVLLMModelFromSpec(TorchTitanVLLMModelWrapper): + def __init__(self, *, vllm_config, prefix=""): + super().__init__( + model_cls=train_spec.model_cls, + model_args=model_args, + state_dict_adapter=train_spec.state_dict_adapter, + parallelize_fn=train_spec.parallelize_fn, + vllm_config=vllm_config, + prefix=prefix, + ) + + # Set the class name + TorchTitanVLLMModelFromSpec.__name__ = model_name + TorchTitanVLLMModelFromSpec.__qualname__ = model_name + + # Register with vLLM + ModelRegistry.register_model(model_name, TorchTitanVLLMModelFromSpec) + + logger.info( + f"Successfully registered {model_name} with vLLM using TrainSpec " + f"(model_cls={train_spec.model_cls.__name__}, flavor={model_flavor})" + ) + + +# Auto-register TorchTitan models with vLLM when this module is imported +register_torchtitan_model_from_train_spec( + train_spec=get_train_spec("qwen3"), + model_name="Qwen3TorchTitanForCausalLM", + # TODO: Remove the model_flavor args when registering model, + # allow passing model flavor option from config system. Now we have to specify + # model_flavor during registration because we can not pass torchtitan job_config from LLM() Api + model_flavor="0.6B", +) + + +__all__ = [ + "TorchTitanVLLMModelWrapper", + "create_parallel_dims_from_vllm_config", + "register_torchtitan_model_from_train_spec", +] diff --git a/torchtitan/experiments/rl/unified/attention.py b/torchtitan/experiments/rl/unified/attention.py new file mode 100644 index 0000000000..1a03b882cb --- /dev/null +++ b/torchtitan/experiments/rl/unified/attention.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
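+"""
+vLLM attention adapter used by the unified TorchTitan + vLLM path.
+
+`VLLMAttention` keeps TorchTitan's (batch, num_heads, seq_len, head_dim) calling
+convention while delegating the attention computation (and paged KV-cache
+handling) to vLLM's `Attention` layer.
+"""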
+ +import torch +from vllm.attention.layer import Attention + + +class VLLMAttention(torch.nn.Module): + """ + Wrapper around vLLM's Attention. Compatible with TorchTitan input shape. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + layer_name: str, + scale: float | None = None, + ) -> None: + super().__init__() + + self.hidden_size = hidden_size + self.layer_name = layer_name + + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + + if scale is None: + self.scale = head_dim**-0.5 + else: + self.scale = scale + + cache_config = ( + vllm_config.cache_config if hasattr(vllm_config, "cache_config") else None + ) + + self.vllm_attn = Attention( + num_heads=num_heads, + head_size=head_dim, + scale=self.scale, + num_kv_heads=num_kv_heads, + cache_config=cache_config, + quant_config=None, + prefix=f"model.layers.{layer_name}.attention.inner_attention", + ) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + scale: float | None = None, + ) -> torch.Tensor: + """ + Forward pass using vLLM's Attention layer for inference. + + Args: + q: Query tensor [batch, num_heads, seq_len, head_dim] + k: Key tensor [batch, num_kv_heads, seq_len, head_dim] + v: Value tensor [batch, num_kv_heads, seq_len, head_dim] + scale: Optional attention scale override (unused, vLLM uses internal scale) + + Returns: + output: [batch, num_heads, seq_len, head_dim] + """ + # Input is (batch, num_heads, seq_len, head_dim) + batch_size, num_heads, seq_len, head_dim = q.shape + + # Transpose to (batch, seq_len, num_heads, head_dim) for vLLM + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + output_varlen = self.vllm_attn(q, k, v) + + # Reshape back to batch format + output = output_varlen.view(batch_size, seq_len, num_heads, head_dim) + + # Transpose back to TorchTitan format: (batch, num_heads, seq_len, head_dim) + output = output.transpose(1, 2) + + return output diff --git a/torchtitan/experiments/rl/unified/infer.py b/torchtitan/experiments/rl/unified/infer.py new file mode 100755 index 0000000000..19770ecc22 --- /dev/null +++ b/torchtitan/experiments/rl/unified/infer.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
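+"""
+Example inference entry point for the unified TorchTitan + vLLM integration.
+
+Importing `torchtitan.experiments.rl.unified` registers the TorchTitan Qwen3
+model with vLLM's ModelRegistry; this script then builds a vLLM `LLM` engine
+pointed at that architecture ("Qwen3TorchTitanForCausalLM") and generates text
+for a single prompt.
+"""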
+
+import argparse
+
+# Import unified module - this automatically registers TorchTitan models with vLLM
+from torchtitan.experiments.rl import unified  # noqa: F401
+
+from vllm import LLM, SamplingParams
+from vllm.logger import init_logger
+
+
+logger = init_logger(__name__)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Run TorchTitan model inference with vLLM Engine",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--model_ckpt_path",
+        type=str,
+        default="torchtitan/experiments/rl/example_checkpoint",
+        help="Path to TorchTitan checkpoint directory",
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default="Hello, my name is",
+        help="Prompt text for generation",
+    )
+    parser.add_argument(
+        "--max-tokens",
+        type=int,
+        default=100,
+        help="Maximum number of tokens to generate",
+    )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=0.8,
+        help="Sampling temperature",
+    )
+    parser.add_argument(
+        "--tensor-parallel-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for tensor parallelism (default: 1 for single GPU)",
+    )
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+
+    logger.info("Initializing vLLM with TorchTitan model")
+    logger.info(f"Model: {args.model_ckpt_path}")
+    logger.info(f"Tensor Parallel Size: {args.tensor_parallel_size}")
+
+    # Initialize vLLM with custom TorchTitan model
+    # The LLM initialization will internally:
+    # 1. Load the TrainSpec for Qwen3 (registered in unified/__init__.py)
+    # 2. Create a TorchTitanVLLMModelWrapper instance
+    # 3. Create JobConfig and ParallelDims from vLLM config
+    # 4. Apply parallelization using parallelize_qwen3
+    # 5. Load model weights and prepare for inference
+    logger.info("Creating vLLM LLM engine...")
+
+    llm = LLM(
+        model=args.model_ckpt_path,  # Model checkpoint path
+        hf_overrides={
+            # Override architectures to use our registered TorchTitan model class
+            "architectures": ["Qwen3TorchTitanForCausalLM"],
+        },
+        dtype="bfloat16",
+        trust_remote_code=True,
+        enforce_eager=True,  # Use eager mode
+        tensor_parallel_size=args.tensor_parallel_size,
+    )
+
+    logger.info("vLLM engine initialized successfully")
+    logger.info(f"Prompt: {args.prompt}")
+
+    # Prepare prompt and sampling parameters
+    prompts = [args.prompt]
+    sampling_params = SamplingParams(
+        temperature=args.temperature,
+        top_p=0.95,
+        max_tokens=args.max_tokens,
+    )
+
+    # Generate text
+    logger.info("Generating text...")
+    outputs = llm.generate(
+        prompts=prompts,
+        sampling_params=sampling_params,
+    )
+
+    # Print results
+    logger.info("Generation complete")
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+
+        print(f"\nPrompt: {prompt}")
+        print(f"Generated text: {generated_text!r}\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/torchtitan/experiments/rl/unified/utils.py b/torchtitan/experiments/rl/unified/utils.py
new file mode 100644
index 0000000000..e997c387d9
--- /dev/null
+++ b/torchtitan/experiments/rl/unified/utils.py
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Parallelization utilities for vLLM + TorchTitan models.
+
+This module provides functions for setting up the device mesh and applying
+tensor parallelism to TorchTitan models in vLLM using TorchTitan's ParallelDims.
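+
+Example (mirroring how TorchTitanVLLMModelWrapper uses these helpers):
+
+    parallel_dims = create_parallel_dims_from_vllm_config(vllm_config)
+    if parallel_dims.tp_enabled:
+        tp_mesh = parallel_dims.world_mesh["tp"]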
+""" + +import torch.distributed as dist + +from torchtitan.distributed.parallel_dims import ParallelDims +from vllm.config import VllmConfig +from vllm.logger import init_logger + + +logger = init_logger(__name__) + + +def create_parallel_dims_from_vllm_config(vllm_config: VllmConfig) -> ParallelDims: + """ + Create ParallelDims from vLLM config and maps vLLM parallelism settings to TorchTitan's ParallelDims dataclass. + + This function is needed because vLLM doesn't separate model creation and + parallelism application - it requires parallelization to be done inside + the model constructor, so we are creating parallel_dims and apply parallelism + in TorchTitanVLLMModelWrapper.__init__ function. + + Args: + vllm_config: vLLM configuration object + + Returns: + ParallelDims object with parallelism settings validated + + Note: + vLLM doesn't use FSDP sharding (dp_shard=1) or expert parallelism (ep=1, etp=1) + in inference. These are set to default values. + """ + world_size = dist.get_world_size() + + # Map vLLM config to TorchTitan ParallelDims + parallel_dims = ParallelDims( + dp_replicate=vllm_config.parallel_config.data_parallel_size, + dp_shard=1, # vLLM doesn't use FSDP sharding + cp=vllm_config.parallel_config.decode_context_parallel_size, + tp=vllm_config.parallel_config.tensor_parallel_size, + pp=vllm_config.parallel_config.pipeline_parallel_size, + ep=1, # Expert parallelism not used in vLLM inference yet + etp=1, # Expert tensor parallelism not used in vLLM inference yet + world_size=world_size, + ) + + logger.info( + f"Created ParallelDims from vLLM config: " + f"DP={parallel_dims.dp_replicate}, TP={parallel_dims.tp}, " + f"CP={parallel_dims.cp}, PP={parallel_dims.pp}" + ) + + return parallel_dims diff --git a/torchtitan/experiments/rl/unified/vllm_wrapper.py b/torchtitan/experiments/rl/unified/vllm_wrapper.py new file mode 100644 index 0000000000..e92903c744 --- /dev/null +++ b/torchtitan/experiments/rl/unified/vllm_wrapper.py @@ -0,0 +1,329 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Base wrapper for TorchTitan models to work with vLLM V1 engine. + +This module provides TorchTitanVLLMModel: Core model class that adapts +TorchTitan models for vLLM. +""" + +from functools import partial + +import torch +import torch.nn as nn +from torch.distributed._tensor import DTensor, Replicate +from torch.distributed.checkpoint.state_dict import ( + set_model_state_dict, + StateDictOptions, +) + +from torchtitan.experiments.deterministic_vllm_rl.unified.attention import VLLMAttention +from torchtitan.models.qwen3.model.model import precompute_rope_cache +from torchtitan.protocols.model import BaseModelArgs, ModelProtocol +from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter +from torchtitan.protocols.train_spec import ParallelizeFunction + +from vllm.config import VllmConfig +from vllm.logger import init_logger + +from .utils import create_parallel_dims_from_vllm_config + + +logger = init_logger(__name__) + + +class TorchTitanVLLMModelWrapper(nn.Module): + """ + Generic vLLM-compatible model wrapper for TorchTitan models. 
+
+    The wrapper handles:
+    - Direct usage of TorchTitan model args (no HF config mapping needed)
+    - Attention replacement with vLLM paged attention
+    - Tensor parallelism setup
+    - Weight loading from HF checkpoints
+    - vLLM forward/compute_logits interface
+    """
+
+    is_text_generation_model = True  # Required for vLLM runner validation
+    supports_pp = False  # Pipeline parallelism not supported yet
+    supports_multimodal = False
+
+    def __init__(
+        self,
+        *,
+        model_cls: type[ModelProtocol],
+        model_args: BaseModelArgs,
+        state_dict_adapter: type[BaseStateDictAdapter],
+        parallelize_fn: ParallelizeFunction,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        assert vllm_config is not None, "vllm_config is required"
+
+        # Store components
+        self.model_cls = model_cls
+        self.state_dict_adapter = state_dict_adapter
+        self.parallelize_fn = parallelize_fn
+
+        # Use TorchTitan model args directly (no HF config mapping)
+        self.config = model_args
+        logger.info(f"Creating {self.model_cls.__name__} with config: {model_args}")
+        self.model = self.model_cls(model_args)
+
+        # Set up the RoPE cache extension function for this model
+        self.rope_cache_extension_fn = partial(
+            precompute_rope_cache,
+            dim=self.config.head_dim,
+            base=self.config.rope_theta,
+        )
+        # Replace attention with vLLM paged attention
+        self._replace_with_vllm_attention(model_args)
+
+        # Create ParallelDims from vLLM config and apply parallelization
+        # NOTE: We need to apply parallelism within model.__init__ because vLLM
+        # expects the model to be parallelized inside its constructor.
+        parallel_dims = create_parallel_dims_from_vllm_config(vllm_config)
+        if parallel_dims.tp_enabled:
+            self.world_mesh = parallel_dims.world_mesh
+            tp_mesh = self.world_mesh["tp"]
+            parallelize_fn(
+                model=self.model,
+                tp_mesh=tp_mesh,
+                loss_parallel=False,
+                enable_float8_tensorwise_tp=False,
+                enable_async_tp=False,
+            )
+            logger.info(
+                f"Successfully initialized model with TP={parallel_dims.tp}"
+            )
+        else:
+            logger.info("Single GPU mode - no parallelization needed")
+
+    def _replace_with_vllm_attention(self, model_args):
+        """
+        Replace TorchTitan attention with vLLM paged attention.
+
+        Assumes model has .layers dict with .attention.inner_attention structure.
+        Override in subclass if different structure.
+        """
+        assert hasattr(
+            self.model, "layers"
+        ), f"Model {type(self.model).__name__} must have .layers attribute"
+
+        for layer_name, layer in self.model.layers.items():
+            assert hasattr(
+                layer, "attention"
+            ), f"Layer {layer_name} must have .attention attribute"
+
+            vllm_attn = VLLMAttention(
+                hidden_size=model_args.dim,
+                num_heads=model_args.n_heads,
+                num_kv_heads=model_args.n_heads,  # Use n_heads (already replicated)
+                head_dim=model_args.head_dim,
+                layer_name=layer_name,
+                scale=model_args.head_dim**-0.5,
+            )
+
+            # Replace inner attention
+            layer.attention.inner_attention = vllm_attn
+
+        logger.info(
+            f"Successfully replaced TorchTitan attention with VLLMAttention "
+            f"({len(self.model.layers)} layers)"
+        )
+
+    def _extend_rope_cache_if_needed(
+        self, rope_cache: torch.Tensor, max_position: int
+    ) -> torch.Tensor:
+        """
+        Extend RoPE cache if needed during vLLM profiling stage.
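+
+        vLLM's memory-profiling run can request position indices beyond the
+        RoPE cache precomputed by the TorchTitan model, so the cache is
+        recomputed up to `max_position + 1` entries via
+        `rope_cache_extension_fn` (a partial of `precompute_rope_cache`).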
+ + Args: + rope_cache: Current RoPE cache tensor + max_position: Maximum position index needed + + Returns: + Extended RoPE cache if needed, otherwise original cache + """ + from torch.distributed._tensor import DTensor, Replicate + + required_len = max_position + 1 + + # No extension needed + if required_len <= rope_cache.shape[0]: + return rope_cache + + # If no extension function provided, return original cache + if self.rope_cache_extension_fn is None: + logger.warning( + f"RoPE cache extension needed (required_len={required_len}, " + f"current_len={rope_cache.shape[0]}) but no rope_cache_extension_fn provided. " + "Returning original cache." + ) + return rope_cache + + # Handle DTensor case + is_dtensor = isinstance(rope_cache, DTensor) + if is_dtensor: + device_mesh = rope_cache.device_mesh + local_rope_cache = rope_cache.to_local() + device = local_rope_cache.device + dtype = local_rope_cache.dtype + else: + device = rope_cache.device + dtype = rope_cache.dtype + + # Use provided extension function + try: + extended_cache = self.rope_cache_extension_fn(self.config, required_len) + extended_cache = extended_cache.to(device=device, dtype=dtype) + except Exception as e: + logger.warning( + f"Failed to extend RoPE cache using rope_cache_extension_fn: {e}. " + "Returning original cache." + ) + return rope_cache + + # Convert back to DTensor if needed + if is_dtensor: + rope_cache = DTensor.from_local( + extended_cache, + device_mesh=device_mesh, + placements=[Replicate()], + ) + else: + rope_cache = extended_cache + + return rope_cache + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + """Convert input token IDs to embeddings.""" + return self.model.tok_embeddings(input_ids) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + """Convert input token IDs to embeddings (deprecated vLLM interface).""" + return self.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + """ + Forward pass with vLLM interface. 
+ + Args: + input_ids: Token IDs [total_tokens] (1D varlen format) + positions: Position indices [total_tokens] (1D varlen format) + inputs_embeds: Pre-computed embeddings (optional) + **kwargs: Additional vLLM kwargs + + Returns: + hidden_states: Final hidden states [total_tokens, hidden_size] + """ + if inputs_embeds is not None: + raise NotImplementedError("inputs_embeds not yet supported") + + if input_ids is None: + raise ValueError("Either input_ids or inputs_embeds must be provided") + + # Convert vLLM interface to TorchTitan interface + # vLLM: [total_tokens] → TorchTitan: [batch_size, seq_len] + tokens_2d = input_ids.unsqueeze(0) + + # Get embeddings + h = self.model.tok_embeddings(tokens_2d) + + # Get RoPE cache (handle model-specific attribute names) + # Use hasattr to avoid ambiguous boolean value error with tensors + if hasattr(self.model, "rope_cache"): + rope_attr = self.model.rope_cache + elif hasattr(self.model, "freqs_cis"): + rope_attr = self.model.freqs_cis + else: + rope_attr = None + + # Extend RoPE cache if needed (vLLM profiling may use 2x max_seq_len) + if positions is not None: + max_position = positions.max().item() + else: + max_position = 0 + + rope_cache = self._extend_rope_cache_if_needed(rope_attr, max_position) + positions = positions.unsqueeze(0) + + # Pass through transformer layers + for layer in self.model.layers.values(): + h = layer(h, rope_cache, attention_masks=None, positions=positions) + + # Convert to vLLM format: [total_tokens, hidden_size] + if h.dim() == 3: + batch_size, seq_len, hidden_size = h.shape + h = h.view(batch_size * seq_len, hidden_size) + + return h + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata=None, + ) -> torch.Tensor | None: + """Compute logits from hidden states.""" + h = self.model.norm(hidden_states) + logits = self.model.output(h) + + return logits + + def load_weights(self, weights_iter): + """ + Load weights from HF checkpoint using the provided state dict adapter. + vLLM engine would call this function to load model weights. 
+ + Args: + weights_iter: Iterator of (name, tensor) pairs from HF checkpoint + + Returns: + Set of loaded parameter names + """ + # Collect weights from iterator + hf_state_dict = {} + for name, tensor in weights_iter: + hf_state_dict[name] = tensor + + # Use adapter to convert HF → TorchTitan format + adapter = self.state_dict_adapter( + model_args=self.config, + hf_assets_path=None, + ) + + torchtitan_state_dict = adapter.from_hf(hf_state_dict) + model_state_dict = {k: v for k, v in self.model.state_dict().items()} + + # Convert to DTensor if target is DTensor + for name, tensor in torchtitan_state_dict.items(): + if name in model_state_dict and isinstance(model_state_dict[name], DTensor): + target_dtensor = model_state_dict[name] + device_mesh = target_dtensor.device_mesh + torchtitan_state_dict[name] = DTensor.from_local( + tensor.to(device_mesh.device_type), + device_mesh=device_mesh, + placements=[Replicate()], + ) + + # Load state dict + set_model_state_dict( + model=self.model, + model_state_dict=torchtitan_state_dict, + options=StateDictOptions(strict=False), + ) + + loaded_params = {f"model.{name}" for name in torchtitan_state_dict.keys()} + + return loaded_params diff --git a/torchtitan/experiments/deterministic_vllm_rl/README.md b/torchtitan/experiments/rl/vllm_compat/README.md similarity index 97% rename from torchtitan/experiments/deterministic_vllm_rl/README.md rename to torchtitan/experiments/rl/vllm_compat/README.md index d2ef719c0d..bf56f4afbe 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/README.md +++ b/torchtitan/experiments/rl/vllm_compat/README.md @@ -77,7 +77,7 @@ init_batch_invariance() ```python import torch from vllm.model_executor.layers.batch_invariant import init_batch_invariance -from torchtitan.experiments.deterministic_vllm_rl import ( +from torchtitan.experiments.rl.vllm_compat import ( enable_batch_invariant_backward_mode, Qwen3VLLMCompatModel, ) @@ -111,7 +111,7 @@ loss.backward() Run the RL training loop: ```bash -VLLM_BATCH_INVARIANT=1 VLLM_FLASH_ATTN_VERSION=3 python -m torchtitan.experiments.deterministic_vllm_rl.simple_rl +VLLM_BATCH_INVARIANT=1 VLLM_FLASH_ATTN_VERSION=3 python -m torchtitan.experiments.rl.vllm_compat.simple_rl ``` This will: @@ -177,7 +177,7 @@ assert torch.equal(vllm_logprobs, titan_logprobs) Run the test suite: ```bash -cd torchtitan/experiments/deterministic_vllm_rl/tests +cd torchtitan/experiments/rl/vllm_compat/tests # Test backward passes python test_batch_invariant_backward.py @@ -214,7 +214,7 @@ This implementation uses the same kernels for both rollouts (vLLM) and training ## Project Structure ``` -deterministic_vllm_rl/ +rl/vllm_compat/ ├── README.md # Documentation ├── __init__.py # Package initialization ├── batch_invariant_backward.py # Backward passes for vLLM ops diff --git a/torchtitan/experiments/deterministic_vllm_rl/__init__.py b/torchtitan/experiments/rl/vllm_compat/__init__.py similarity index 53% rename from torchtitan/experiments/deterministic_vllm_rl/__init__.py rename to torchtitan/experiments/rl/vllm_compat/__init__.py index 067555251f..b86721fba5 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/__init__.py +++ b/torchtitan/experiments/rl/vllm_compat/__init__.py @@ -5,16 +5,10 @@ # LICENSE file in the root directory of this source tree. """ -Deterministic RL training with vLLM experiment. +vLLM-Compatible approach for deterministic RL training. 
-This experiment provides tools for bitwise-deterministic reinforcement learning -training using vLLM for fast rollouts and TorchTitan for training. - -Key components: -- VLLMCompatibleFlashAttention: Flash attention with custom backward pass -- Qwen3VLLMCompatModel: vLLM-compatible model with merged projections -- batch_invariant_backward: Gradient support for vLLM's deterministic operations -- simple_rl: End-to-end RL training loop +This module provides models that match vLLM's weight format (e.g., merged gate_up_proj) +with custom backward passes for gradient computation during training. """ from .batch_invariant_backward import ( @@ -22,9 +16,10 @@ rms_norm_with_gradients, silu_and_mul_with_gradients, ) -from .models import VLLMCompatibleFlashAttention +from .models.attention import VLLMCompatibleFlashAttention from .models.qwen3 import Qwen3VLLMCompatModel + __all__ = [ "VLLMCompatibleFlashAttention", "Qwen3VLLMCompatModel", diff --git a/torchtitan/experiments/deterministic_vllm_rl/batch_invariant_backward.py b/torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/batch_invariant_backward.py rename to torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py b/torchtitan/experiments/rl/vllm_compat/models/__init__.py similarity index 74% rename from torchtitan/experiments/deterministic_vllm_rl/models/__init__.py rename to torchtitan/experiments/rl/vllm_compat/models/__init__.py index c8c11a170a..2e7a5fa6af 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py +++ b/torchtitan/experiments/rl/vllm_compat/models/__init__.py @@ -6,8 +6,13 @@ """ Models for deterministic vLLM RL training. + +This module provides vLLM-compatible model components. """ from .attention import VLLMCompatibleFlashAttention -__all__ = ["VLLMCompatibleFlashAttention"] + +__all__ = [ + "VLLMCompatibleFlashAttention", +] diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/attention.py b/torchtitan/experiments/rl/vllm_compat/models/attention.py similarity index 98% rename from torchtitan/experiments/deterministic_vllm_rl/models/attention.py rename to torchtitan/experiments/rl/vllm_compat/models/attention.py index 33dd5a140d..11e6d3af67 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/attention.py +++ b/torchtitan/experiments/rl/vllm_compat/models/attention.py @@ -4,12 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -""" -vLLM-compatible Flash Attention implementation for deterministic RL training. 
-""" import torch -from vllm.vllm_flash_attn import flash_attn_varlen_func +from vllm.attention.utils.fa_utils import flash_attn_varlen_func class VLLMCompatibleFlashAttention(torch.nn.Module): diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py b/torchtitan/experiments/rl/vllm_compat/models/qwen3/__init__.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py rename to torchtitan/experiments/rl/vllm_compat/models/qwen3/__init__.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py b/torchtitan/experiments/rl/vllm_compat/models/qwen3/model_vllm_compat.py similarity index 99% rename from torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py rename to torchtitan/experiments/rl/vllm_compat/models/qwen3/model_vllm_compat.py index dd84665091..2c9742b1fa 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py +++ b/torchtitan/experiments/rl/vllm_compat/models/qwen3/model_vllm_compat.py @@ -13,7 +13,7 @@ from torchtitan.components.tokenizer import BaseTokenizer # Import gradient-enabled operations from experiment utilities -from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( +from torchtitan.experiments.rl.vllm_compat.batch_invariant_backward import ( rms_norm_with_gradients, silu_and_mul_with_gradients, ) diff --git a/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py b/torchtitan/experiments/rl/vllm_compat/simple_rl.py similarity index 99% rename from torchtitan/experiments/deterministic_vllm_rl/simple_rl.py rename to torchtitan/experiments/rl/vllm_compat/simple_rl.py index ffc7d52eb0..508868c0d4 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py +++ b/torchtitan/experiments/rl/vllm_compat/simple_rl.py @@ -25,20 +25,20 @@ from huggingface_hub import snapshot_download from safetensors.torch import load_file, save_file from torch.utils.tensorboard import SummaryWriter -from transformers import AutoConfig, AutoTokenizer - -from vllm import LLM, SamplingParams -from vllm.model_executor.layers.batch_invariant import init_batch_invariance -from torchtitan.experiments.deterministic_vllm_rl.weights.converter import ( +from torchtitan.experiments.rl.vllm_compat.weights.converter import ( torchtitan_to_vllm, vllm_to_torchtitan, ) -from torchtitan.experiments.deterministic_vllm_rl.weights_vllm_compat import ( +from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( torchtitan_to_vllm_compat, ) from torchtitan.models.qwen3.model.args import Qwen3ModelArgs +from transformers import AutoConfig, AutoTokenizer + +from vllm import LLM, SamplingParams +from vllm.model_executor.layers.batch_invariant import init_batch_invariance init_batch_invariance() @@ -340,7 +340,7 @@ def load_model(checkpoint_path: str, model_path: str, use_vllm_compat: bool = Tr if use_vllm_compat: # Create and load model (using vLLM-compat for bitwise determinism) - from torchtitan.experiments.deterministic_vllm_rl.models.qwen3 import ( + from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.models.qwen3 import ( Qwen3VLLMCompatModel, ) @@ -1058,7 +1058,7 @@ def main(): print("✓ Batch invariance detected - using vLLM-compatible model") # Add backward pass support to vLLM's batch_invariant mode print(" Adding gradient support to vLLM's batch_invariant mode...") - from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( + from 
torchtitan.experiments.deterministic_vllm_rl.vllm_compat.batch_invariant_backward import ( enable_batch_invariant_backward_mode, ) diff --git a/torchtitan/experiments/deterministic_vllm_rl/tests/__init__.py b/torchtitan/experiments/rl/vllm_compat/tests/__init__.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/tests/__init__.py rename to torchtitan/experiments/rl/vllm_compat/tests/__init__.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/tests/test_batch_invariant_backward.py b/torchtitan/experiments/rl/vllm_compat/tests/test_batch_invariant_backward.py similarity index 97% rename from torchtitan/experiments/deterministic_vllm_rl/tests/test_batch_invariant_backward.py rename to torchtitan/experiments/rl/vllm_compat/tests/test_batch_invariant_backward.py index 3ed9604d10..ddf8b01514 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/tests/test_batch_invariant_backward.py +++ b/torchtitan/experiments/rl/vllm_compat/tests/test_batch_invariant_backward.py @@ -8,9 +8,11 @@ Test batch_invariant_backward module to ensure it works correctly. """ +import sys + import torch -from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( +from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.batch_invariant_backward import ( disable_batch_invariant_backward_mode, enable_batch_invariant_backward_mode, linear_batch_invariant_backward, diff --git a/torchtitan/experiments/deterministic_vllm_rl/tests/test_exact_determinism.py b/torchtitan/experiments/rl/vllm_compat/tests/test_exact_determinism.py similarity index 98% rename from torchtitan/experiments/deterministic_vllm_rl/tests/test_exact_determinism.py rename to torchtitan/experiments/rl/vllm_compat/tests/test_exact_determinism.py index 8d0ac3133e..2a9863ab2f 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/tests/test_exact_determinism.py +++ b/torchtitan/experiments/rl/vllm_compat/tests/test_exact_determinism.py @@ -11,11 +11,11 @@ """ import torch -from vllm.model_executor.layers.batch_invariant import disable_batch_invariant_mode -from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( +from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.batch_invariant_backward import ( enable_batch_invariant_backward_mode, ) +from vllm.model_executor.layers.batch_invariant import disable_batch_invariant_mode print("Enabling batch_invariant_backward mode...") disable_batch_invariant_mode() diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights/README.md b/torchtitan/experiments/rl/vllm_compat/weights/README.md similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/weights/README.md rename to torchtitan/experiments/rl/vllm_compat/weights/README.md diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights/__init__.py b/torchtitan/experiments/rl/vllm_compat/weights/__init__.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/weights/__init__.py rename to torchtitan/experiments/rl/vllm_compat/weights/__init__.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights/converter.py b/torchtitan/experiments/rl/vllm_compat/weights/converter.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/weights/converter.py rename to torchtitan/experiments/rl/vllm_compat/weights/converter.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights_vllm_compat.py b/torchtitan/experiments/rl/vllm_compat/weights_vllm_compat.py similarity index 
100% rename from torchtitan/experiments/deterministic_vllm_rl/weights_vllm_compat.py rename to torchtitan/experiments/rl/vllm_compat/weights_vllm_compat.py