diff --git a/torchtitan/experiments/compat/README.md b/torchtitan/experiments/compat/README.md new file mode 100644 index 000000000..07b9bafbd --- /dev/null +++ b/torchtitan/experiments/compat/README.md @@ -0,0 +1,139 @@ +# PyTorch Compatibility Shim System (Experimental) + +This document describes the experimental compatibility shim system that allows TorchTitan to run on both PyTorch nightly and stable releases (e.g., PyTorch 2.8.0). + +## Overview + +The shim system is implemented in `torchtitan/experiments/compat/compat.py` and automatically patches missing PyTorch APIs when the package is imported. This allows developers using stable PyTorch releases to use TorchTitan without requiring PyTorch nightly. + +## How It Works + +The compatibility system uses two approaches: + +### 1. Import Hook for Missing Modules +For completely missing modules (like `torch.distributed.checkpoint._consolidate_hf_safetensors`), a custom meta path finder intercepts imports and provides shim modules with stub implementations. + +### 2. Runtime Patching for Missing Classes +For existing modules that are missing specific classes (like `DefaultStager` in `torch.distributed.checkpoint.staging`), the shim system directly adds the missing classes to the existing module at import time. + +## Automatic Activation + +The shim system is automatically activated when you import the `torchtitan` package: + +```python +import torchtitan # Shims are installed automatically +``` + +This happens in `torchtitan/__init__.py`, which imports `torchtitan.experiments.compat` before anything else. + +## Currently Shimmed APIs + +### 1. Checkpoint Consolidation (`torch.distributed.checkpoint._consolidate_hf_safetensors`) +- `consolidate_safetensor_files` - Raises NotImplementedError +- `consolidate_safetensors_files_on_every_rank` - Raises NotImplementedError + +**Note:** HuggingFace checkpoint export requires PyTorch nightly. + +### 2. Checkpoint Staging (`torch.distributed.checkpoint.staging`) +- `StagingOptions` - Simple placeholder for staging configuration +- `DefaultStager` - Falls back to `BlockingAsyncStager` if available + +### 3. Pipeline Schedules (`torch.distributed.pipelining.schedules`) +- `ScheduleDualPipeV` - Raises NotImplementedError if instantiated + +**Note:** Use a different pipeline schedule if you hit this error. + +### 4. Flex Attention (`torch.nn.attention.flex_attention`) +- `AuxOutput` - NamedTuple for auxiliary flex_attention outputs + +### 5. Checkpoint Wrapper (`torch.distributed.algorithms._checkpoint.checkpoint_wrapper`) +- Wraps `checkpoint_wrapper` function to filter out the `early_stop` parameter which is not available in PyTorch 2.8.0 +- The `early_stop` parameter is silently ignored in stable PyTorch + +## Adding New Shims + +If you encounter a new missing API when using stable PyTorch, you can add a shim by: + +1. **For missing modules:** Add a factory function to `torchtitan/experiments/compat/compat.py` and register it with `register_shim()` + +```python +def _shim_new_module(): + module = ModuleType('torch.some.missing.module') + # Add functions/classes to the module + return module + +# In install_shims(): +register_shim('torch.some.missing.module', _shim_new_module) +``` + +2. 
**For missing classes in existing modules:** Add a function that patches the existing module + +```python +def _shim_existing_module(): + from torch.some import existing_module + + class MissingClass: + # Implementation or stub + pass + + existing_module.MissingClass = MissingClass + return existing_module + +# In install_shims(): +_shim_existing_module() +``` + +## Testing + +To verify the shim system works: + +```bash +# Should succeed with PyTorch 2.8.0 +python -c "import torchtitan; print('Shims loaded successfully')" + +# Try importing a shimmed module +python -c "from torch.distributed.checkpoint._consolidate_hf_safetensors import consolidate_safetensors_files_on_every_rank" + +# Run the test suite +python -m torchtitan.experiments.compat.test_compat +``` + +## Known Limitations + +1. **HuggingFace Checkpoint Export:** Not supported in stable PyTorch. Set `checkpoint.last_save_in_hf = false` in your config. + +2. **ScheduleDualPipeV:** Not available in stable PyTorch. Use a different pipeline schedule. + +3. **Async Checkpoint Staging:** Limited functionality with the shim. Some advanced features may not work. + +## Version Compatibility + +- **PyTorch Nightly:** All features work natively, shims are harmless +- **PyTorch 2.8.0:** Tested and working with limitations noted above +- **Older versions:** May require additional shims + +## Philosophy + +The shim system follows these principles: + +1. **Simple and Transparent:** Easy to understand and extend +2. **Fail-Fast:** Unsupported features raise clear errors explaining limitations +3. **Non-Intrusive:** Works automatically without code changes +4. **Compatible:** Harmless when used with PyTorch nightly + +## Troubleshooting + +If you encounter an import error: + +1. Check if it's a PyTorch API that's missing in your version +2. Add a shim following the patterns in `torchtitan/experiments/compat/compat.py` +3. Test that both stable and nightly PyTorch work with your shim + +For feature limitations, the error messages will guide you to either: +- Upgrade to PyTorch nightly +- Use an alternative feature +- Disable the feature in your configuration + +## Experimental Status + +This compatibility system is experimental and may change in future releases. It is designed to help users who cannot use PyTorch nightly for various reasons (e.g., stability requirements, deployment constraints). diff --git a/torchtitan/experiments/compat/__init__.py b/torchtitan/experiments/compat/__init__.py new file mode 100644 index 000000000..b141b4ebb --- /dev/null +++ b/torchtitan/experiments/compat/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +PyTorch compatibility shims for non-nightly versions. + +This experimental module provides compatibility between PyTorch nightly and stable releases +by shimming missing modules and functions. + +Usage: + import torchtitan.experiments.compat # noqa: F401 + +The shims are automatically installed when this module is imported. +""" + +# Import compat to auto-install shims +from . import compat # noqa: F401 diff --git a/torchtitan/experiments/compat/compat.py b/torchtitan/experiments/compat/compat.py new file mode 100644 index 000000000..83575e283 --- /dev/null +++ b/torchtitan/experiments/compat/compat.py @@ -0,0 +1,256 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +PyTorch compatibility shims for non-nightly versions (Experimental). + +This experimental module provides compatibility between PyTorch nightly and stable releases +by shimming missing modules and functions. Import this module early in your +application to enable automatic shimming. + +The shims are automatically installed when torchtitan is imported. + +Usage: + import torchtitan.experiments.compat # noqa: F401 +""" + +import sys +from importlib.abc import Loader, MetaPathFinder +from importlib.machinery import ModuleSpec +from types import ModuleType + + +class CompatShimLoader(Loader): + """Loader that provides shim modules for missing PyTorch features.""" + + def __init__(self, module_name: str, shim_factory): + self.module_name = module_name + self.shim_factory = shim_factory + + def create_module(self, spec): + """Create the shim module.""" + return self.shim_factory() + + def exec_module(self, module): + """Module is already populated by create_module.""" + pass + + +class CompatMetaPathFinder(MetaPathFinder): + """Meta path finder that intercepts imports and provides compatibility shims.""" + + # Registry of shims: module_name -> factory function + SHIMS = {} + + def find_spec(self, fullname, path, target=None): + """Find module spec for shimmed modules.""" + if fullname in self.SHIMS: + return ModuleSpec( + fullname, + CompatShimLoader(fullname, self.SHIMS[fullname]), + origin="torchtitan-compat-shim", + ) + return None + + +def register_shim(module_name: str, factory): + """Register a shim factory for a module. + + Args: + module_name: Full module name to shim (e.g., 'torch.foo.bar') + factory: Callable that returns a module object with the shimmed functionality + """ + CompatMetaPathFinder.SHIMS[module_name] = factory + + +# ============================================================================ +# Shim Definitions +# ============================================================================ + + +def _shim_checkpoint_staging(): + """Shim for torch.distributed.checkpoint.staging missing classes""" + from torch.distributed.checkpoint import staging + + # Create wrapper for StagingOptions + class StagingOptions: + """Shim for StagingOptions from PyTorch nightly.""" + + __slots__ = ("args", "kwargs") + + def __init__(self, *args, **kwargs): + # Store the arguments for potential future use + self.args = args + self.kwargs = kwargs + + # Create wrapper for DefaultStager + class DefaultStager: + """Shim for DefaultStager from PyTorch nightly.""" + + def __init__(self, options=None): + # In PyTorch 2.8, we can use BlockingAsyncStager as a fallback + if hasattr(staging, "BlockingAsyncStager"): + self._stager = staging.BlockingAsyncStager() + else: + self._stager = None + self.options = options + + def __getattr__(self, name): + # Delegate to the underlying stager if it exists + if self._stager: + return getattr(self._stager, name) + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + def close(self): + """Close the stager.""" + if self._stager and hasattr(self._stager, "close"): + self._stager.close() + + # Add the classes to the staging module + staging.StagingOptions = StagingOptions + staging.DefaultStager = DefaultStager + + return staging + + +def _shim_pipelining_schedules(): + """Shim for torch.distributed.pipelining.schedules missing classes""" + from torch.distributed.pipelining import 
schedules + + # ScheduleDualPipeV is a nightly-only schedule class + # For compatibility, we create a placeholder that raises an error if used + # but allows the import to succeed + class ScheduleDualPipeV: + """Shim for ScheduleDualPipeV from PyTorch nightly. + + This is a placeholder to allow imports to succeed. If this schedule is + actually used at runtime, it will raise an error. + """ + + def __init__(self, *args, **kwargs): + raise NotImplementedError( + "ScheduleDualPipeV requires PyTorch nightly. " + "This schedule is not available in PyTorch 2.8.0. " + "Please use a different pipeline schedule or upgrade to PyTorch nightly." + ) + + # Add the class to the schedules module + schedules.ScheduleDualPipeV = ScheduleDualPipeV + + return schedules + + +def _shim_flex_attention(): + """Shim for torch.nn.attention.flex_attention missing classes""" + from typing import NamedTuple + + import torch + from torch.nn.attention import flex_attention + + # AuxOutput is used for auxiliary outputs from flex_attention + # It's a NamedTuple that contains logsumexp and per_sample_seed + class AuxOutput(NamedTuple): + """Shim for AuxOutput from PyTorch nightly. + + This is a simple NamedTuple to match the structure expected by flex_attention. + """ + + logsumexp: torch.Tensor + per_sample_seed: torch.Tensor | None = None + + # Add the class to the flex_attention module + flex_attention.AuxOutput = AuxOutput + + return flex_attention + + +def _shim_checkpoint_wrapper(): + """Shim for torch.distributed.algorithms._checkpoint.checkpoint_wrapper early_stop parameter""" + import functools + + from torch.distributed.algorithms._checkpoint import ( + checkpoint_wrapper as checkpoint_module, + ) + + # Save the original checkpoint_wrapper + _original_checkpoint_wrapper = checkpoint_module.checkpoint_wrapper + + @functools.wraps(_original_checkpoint_wrapper) + def checkpoint_wrapper_shim(module, *args, early_stop=None, **kwargs): + """Wrapper that filters out the early_stop parameter not supported in PyTorch 2.8.0. + + The early_stop parameter is a nightly-only feature for activation checkpointing. + In PyTorch 2.8.0, we simply ignore it. + """ + # Filter out early_stop parameter and pass everything else through + return _original_checkpoint_wrapper(module, *args, **kwargs) + + # Replace the checkpoint_wrapper function with our shim + checkpoint_module.checkpoint_wrapper = checkpoint_wrapper_shim + + return checkpoint_module + + +def _shim_consolidate_hf_safetensors(): + """Shim for torch.distributed.checkpoint._consolidate_hf_safetensors""" + module = ModuleType("torch.distributed.checkpoint._consolidate_hf_safetensors") + + def consolidate_safetensor_files(checkpoint_id, save_path, *args, **kwargs): + """Stub implementation that raises a helpful error.""" + raise NotImplementedError( + "consolidate_safetensor_files requires PyTorch nightly. " + "This feature is not available in PyTorch 2.8.0. " + "Please either upgrade to PyTorch nightly or disable this feature." + ) + + def consolidate_safetensors_files_on_every_rank( + input_dir, output_dir, fqn_to_index_mapping, num_threads=5, *args, **kwargs + ): + """Stub implementation that raises a helpful error.""" + raise NotImplementedError( + "consolidate_safetensors_files_on_every_rank requires PyTorch nightly. " + "This feature is not available in PyTorch 2.8.0. " + "Please either upgrade to PyTorch nightly or disable HuggingFace checkpoint export." 
+ ) + + module.consolidate_safetensor_files = consolidate_safetensor_files + module.consolidate_safetensors_files_on_every_rank = ( + consolidate_safetensors_files_on_every_rank + ) + + return module + + +# ============================================================================ +# Registration +# ============================================================================ + + +def install_shims(): + """Install all compatibility shims.""" + # Register shims for missing modules + register_shim( + "torch.distributed.checkpoint._consolidate_hf_safetensors", + _shim_consolidate_hf_safetensors, + ) + + # Install shims for existing modules with missing classes/parameters + _shim_checkpoint_staging() + _shim_pipelining_schedules() + _shim_flex_attention() + _shim_checkpoint_wrapper() + + # Install the meta path finder (only once) + finder = CompatMetaPathFinder() + if finder not in sys.meta_path: + # Insert at the beginning so we can intercept before the import fails + sys.meta_path.insert(0, finder) + + +# Auto-install shims when this module is imported +install_shims() diff --git a/torchtitan/experiments/compat/test_compat.py b/torchtitan/experiments/compat/test_compat.py new file mode 100644 index 000000000..f8c9eeb71 --- /dev/null +++ b/torchtitan/experiments/compat/test_compat.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Test script to verify PyTorch compatibility shims are working correctly. +""" + +import sys + + +def test_shims(): + """Test all compatibility shims.""" + print("Testing PyTorch compatibility shims...\n") + + # Test 1: Import torchtitan (which auto-installs shims) + print("1. Testing torchtitan import...") + try: + import torchtitan # noqa: F401 + + print(" ✓ torchtitan imported successfully\n") + except Exception as e: + print(f" ✗ Failed to import torchtitan: {e}\n") + return False + + # Test 2: Consolidate HF safetensors module + print("2. Testing torch.distributed.checkpoint._consolidate_hf_safetensors...") + try: + from torch.distributed.checkpoint._consolidate_hf_safetensors import ( # noqa: F401 + consolidate_safetensor_files, + consolidate_safetensors_files_on_every_rank, + ) + + print(" ✓ Module and functions imported successfully\n") + except Exception as e: + print(f" ✗ Failed: {e}\n") + return False + + # Test 3: Checkpoint staging classes + print("3. Testing torch.distributed.checkpoint.staging classes...") + try: + from torch.distributed.checkpoint.staging import ( # noqa: F401 + DefaultStager, + StagingOptions, + ) + + print(" ✓ DefaultStager and StagingOptions imported successfully\n") + except Exception as e: + print(f" ✗ Failed: {e}\n") + return False + + # Test 4: Pipeline schedule class + print("4. Testing torch.distributed.pipelining.schedules.ScheduleDualPipeV...") + try: + from torch.distributed.pipelining.schedules import ( # noqa: F401 + ScheduleDualPipeV, + ) + + print(" ✓ ScheduleDualPipeV imported successfully\n") + except Exception as e: + print(f" ✗ Failed: {e}\n") + return False + + # Test 5: Flex attention AuxOutput + print("5. 
Testing torch.nn.attention.flex_attention.AuxOutput...") + try: + from torch.nn.attention.flex_attention import AuxOutput # noqa: F401 + + print(" ✓ AuxOutput imported successfully\n") + except Exception as e: + print(f" ✗ Failed: {e}\n") + return False + + # Test 6: Checkpoint wrapper with early_stop parameter + print("6. Testing checkpoint_wrapper with early_stop parameter...") + try: + import torch.nn as nn + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, + ) + + # Create a simple module and wrap it with early_stop parameter + module = nn.Linear(10, 10) + _ = checkpoint_wrapper(module, preserve_rng_state=False, early_stop=True) + print(" ✓ checkpoint_wrapper accepts early_stop parameter\n") + except Exception as e: + print(f" ✗ Failed: {e}\n") + return False + + print("=" * 60) + print("All compatibility shims are working correctly! ✓") + print("=" * 60) + return True + + +if __name__ == "__main__": + success = test_shims() + sys.exit(0 if success else 1) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index f941430e3..2bc803e1e 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -20,6 +20,7 @@ create_block_mask, flex_attention, ) +from vllm.vllm_flash_attn import flash_attn_varlen_func __all__ = [ @@ -103,6 +104,72 @@ def forward( with sdpa_kernel(self.sdpa_backends, set_priority=True): return F.scaled_dot_product_attention(q, k, v, scale=scale, is_causal=True) +class VLLMCompatibleFlashAttention(torch.nn.Module): + """Wrapper around FlashAttention as used by VLLM""" + def __init__(self) -> None: + super().__init__() + self.flash_attn_varlen_func = flash_attn_varlen_func + from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant + from vllm.attention.utils.fa_utils import get_flash_attn_version + self.vllm_is_batch_invariant = vllm_is_batch_invariant + self.fa_version = get_flash_attn_version() + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + scale: float | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, AuxOutput]: + # Flash Attention varlen expects: (batch, seqlen, nheads, headdim) + # The input from TorchTitan is always (batch, num_heads, seq_len, head_dim) + # We need to transpose to (batch, seq_len, num_heads, head_dim) + + # Input is (batch, num_heads, seq_len, head_dim) - need to transpose + q = q.transpose(1, 2) # -> (batch, seq_len, num_heads, head_dim) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Get dimensions + batch_size, seq_len, num_heads, head_dim = q.shape + + # Convert to varlen format: flatten batch and sequence dimensions + # (batch, seqlen, nheads, headdim) -> (total_tokens, nheads, headdim) + q_varlen = q.reshape(-1, num_heads, head_dim) + k_varlen = k.reshape(-1, k.shape[2], head_dim) + v_varlen = v.reshape(-1, v.shape[2], head_dim) + + # Create cumulative sequence lengths + # cu_seqlens: [0, seq_len, 2*seq_len, ..., batch_size*seq_len] + cu_seqlens = torch.arange( + 0, (batch_size + 1) * seq_len, seq_len, + dtype=torch.int32, device=q.device + ) + + # Call Flash Attention varlen (works with both standard flash-attn and vLLM's wrapper) + output_varlen = self.flash_attn_varlen_func( + q_varlen, k_varlen, v_varlen, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=seq_len, + max_seqlen_k=seq_len, + softmax_scale=scale, + causal=True, + num_splits=1 if self.vllm_is_batch_invariant() else 0, + fa_version=self.fa_version, + ) + + # Convert back to batch format + # 
(total_tokens, nheads, headdim) -> (batch, seqlen, nheads, headdim) + output = output_varlen.reshape(batch_size, seq_len, num_heads, head_dim) + + # Transpose back to (batch, num_heads, seq_len, head_dim) to match input format + output = output.transpose(1, 2) + + return output + + # We cannot do inner function/closure because we won't be able to cache it -- # if we an inner function, a new closure will be created every time diff --git a/torchtitan/models/qwen3/model/model_vllm_compat.py b/torchtitan/models/qwen3/model/model_vllm_compat.py new file mode 100644 index 000000000..95e13e78b --- /dev/null +++ b/torchtitan/models/qwen3/model/model_vllm_compat.py @@ -0,0 +1,446 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Qwen3 model compatible with vLLM's implementation +# Uses merged gate_up projections and vLLM Flash Attention + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.attention.flex_attention import and_masks, BlockMask + +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.models.attention import ( + create_attention_mask, + get_causal_mask_mod, + get_document_mask_mod, + VLLMCompatibleFlashAttention, +) +from torchtitan.protocols.model import AttentionMasksType +from torchtitan.protocols.train_spec import ModelProtocol + +from .args import Qwen3ModelArgs + +# Import vLLM's exact operations for bitwise determinism +from vllm.model_executor.layers.activation import SiluAndMul as VLLMSiluAndMul +from vllm.model_executor.layers.batch_invariant import rms_norm as vllm_rms_norm + + +# RoPE functions (same as original) +def precompute_rope_cache( + dim: int, max_seq_len: int, base: float = 1_000_000.0 +) -> torch.Tensor: + freqs = 1.0 / (base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(max_seq_len, dtype=freqs.dtype, device=freqs.device) + idx_theta = torch.outer(t, freqs).float() + freqs = torch.cat([idx_theta, idx_theta], dim=-1) + rope_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) + return rope_cache + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """Reshape frequency tensor for broadcasting.""" + ndim = x.ndim + assert ndim > 1 + _, seqlen, _, head_dim = x.shape + rope_cache = rope_cache[0:seqlen] + assert rope_cache.shape == (seqlen, head_dim * 2) + shape = [-1, seqlen, 1, head_dim * 2] + return rope_cache.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, xk: torch.Tensor, rope_cache: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + head_dim = xq.shape[-1] + rope_cache = reshape_for_broadcast(rope_cache, xq) + cos = rope_cache[..., :head_dim].to(dtype=xq.dtype, device=xq.device) + sin = rope_cache[..., head_dim:].to(dtype=xq.dtype, device=xq.device) + xq_out = (xq * cos) + (rotate_half(xq) * sin) + xk_out = (xk * cos) + (rotate_half(xk) * sin) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bs, slen, 
n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +# Use vLLM's exact SiluAndMul kernel for bitwise determinism +SiluAndMul = VLLMSiluAndMul + + +class RMSNormFunction(torch.autograd.Function): + """ + Autograd function for RMS normalization using vLLM's Triton kernel in forward + and batch-invariant operations in backward. + """ + + @staticmethod + def forward(ctx, input, weight, eps): + """ + Forward pass using vLLM's rms_norm Triton kernel. + + Args: + input: Input tensor [*, hidden_size] + weight: Weight tensor [hidden_size] + eps: Epsilon for numerical stability + + Returns: + output: Normalized and scaled tensor [*, hidden_size] + """ + # Use vLLM's Triton kernel for forward (deterministic) + output = vllm_rms_norm(input, weight, eps) + + # Save for backward + ctx.save_for_backward(input, weight) + ctx.eps = eps + + return output + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass using batch-invariant PyTorch operations. + + Returns: + (grad_input, grad_weight, None) + """ + input, weight = ctx.saved_tensors + eps = ctx.eps + + # Compute forward pass values needed for backward + # variance = mean(x^2) along last dim + variance = (input * input).mean(dim=-1, keepdim=True) + rms = torch.sqrt(variance + eps) + x_norm = input / rms + + # Gradient w.r.t. weight + # grad_weight = sum(grad_output * x_norm) over all dims except last + grad_weight = (grad_output * x_norm).sum(dim=tuple(range(grad_output.ndim - 1))) + + # Gradient w.r.t. input + # grad_x_norm = grad_output * weight + grad_x_norm = grad_output * weight + + # grad_x = (grad_x_norm - mean(grad_x_norm * x_norm) * x_norm) / rms + mean_term = (grad_x_norm * x_norm).mean(dim=-1, keepdim=True) + grad_input = (grad_x_norm - mean_term * x_norm) / rms + + return grad_input, grad_weight, None + + +def rms_norm_with_gradients(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + """ + RMS normalization with gradient support. + + Uses vLLM's Triton kernel for forward pass (deterministic) and + batch-invariant PyTorch operations for backward pass. + + Args: + input: Input tensor [*, hidden_size] + weight: Weight tensor [hidden_size] + eps: Epsilon for numerical stability + + Returns: + output: Normalized and scaled tensor [*, hidden_size] + """ + return RMSNormFunction.apply(input, weight, eps) + + +class VLLMRMSNorm(nn.Module): + """ + RMSNorm using vLLM's exact Triton kernel for bitwise determinism. + Compatible with PyTorch's nn.RMSNorm interface but uses vLLM's implementation. + """ + + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(hidden_size)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Use vLLM's RMSNorm with gradient support for training + return rms_norm_with_gradients(x, self.weight, self.eps) + + def reset_parameters(self): + nn.init.ones_(self.weight) + + +class FeedForwardVLLMCompat(nn.Module): + """ + FeedForward module compatible with vLLM implementation. + Uses merged gate_up projection like vLLM. 
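+    The underlying SiluAndMul activation splits its input in half along the last dimension and returns silu(gate) * up, which is why gate_up_proj produces 2 * hidden_dim output features.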
+ """ + + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + + # Merged gate and up projections (like vLLM's gate_up_proj) + self.gate_up_proj = nn.Linear(dim, hidden_dim * 2, bias=False) + + # Down projection (like vLLM's down_proj) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + + # vLLM's activation + self.act_fn = SiluAndMul() + + def forward(self, x): + # Project to gate and up in one go + gate_up = self.gate_up_proj(x) + # Apply SiluAndMul activation + activated = self.act_fn(gate_up) + # Project down + output = self.down_proj(activated) + return output + + def init_weights(self, init_std: float): + # Initialize like vLLM + nn.init.trunc_normal_(self.gate_up_proj.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.down_proj.weight, mean=0.0, std=init_std) + + +class Attention(nn.Module): + """ + Multi-head attention module compatible with vLLM. + """ + + def __init__(self, model_args: Qwen3ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.n_kv_heads = ( + model_args.n_heads + if model_args.n_kv_heads is None + else model_args.n_kv_heads + ) + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = model_args.head_dim + self.scaling = self.head_dim**-0.5 + + # QK norm (Qwen3 specific) - use vLLM's RMSNorm + if model_args.qk_norm: + self.q_norm = VLLMRMSNorm(self.head_dim, eps=model_args.norm_eps) + self.k_norm = VLLMRMSNorm(self.head_dim, eps=model_args.norm_eps) + else: + self.q_norm = None + self.k_norm = None + + # QKV projections + self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False) + + # Always use vLLM compatible flash attention + self.inner_attention = VLLMCompatibleFlashAttention() + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + if self.q_norm is not None: + self.q_norm.reset_parameters() + if self.k_norm is not None: + self.k_norm.reset_parameters() + + def forward( + self, + x: torch.Tensor, + rope_cache: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + # Reshape to heads + xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + # Apply QK norm + if self.q_norm: + xq = self.q_norm(xq) + if self.k_norm: + xk = self.k_norm(xk) + + # Apply rotary embedding + xq, xk = apply_rotary_emb(xq, xk, rope_cache) + + # Repeat k/v heads if needed + keys = repeat_kv(xk, self.n_rep) + values = repeat_kv(xv, self.n_rep) + + # Transpose for attention + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) + xv = values.transpose(1, 2) + + # Apply flash attention (vLLM compatible, no flex attention) + assert attention_masks is None, "vLLM compat mode doesn't use flex attention masks" + output = self.inner_attention(xq, xk, xv, scale=self.scaling) + + # Transpose back + output = output.transpose(1, 2).contiguous() + output = output.view(bs, seqlen, -1) + + return self.wo(output) + + +class TransformerBlock(nn.Module): + """ + TransformerBlock with 
vLLM-compatible FFN. + """ + + def __init__(self, layer_id: int, model_args: Qwen3ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.dim = model_args.dim + + self.attention = Attention(model_args) + + # Use vLLM-compatible FFN with merged projections + self.feed_forward = FeedForwardVLLMCompat( + dim=model_args.dim, hidden_dim=model_args.hidden_dim + ) + + self.attention_norm = VLLMRMSNorm(model_args.dim, eps=model_args.norm_eps) + self.ffn_norm = VLLMRMSNorm(model_args.dim, eps=model_args.norm_eps) + + if model_args.depth_init: + self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5 + + def forward( + self, + x: torch.Tensor, + rope_cache: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): + # Self attention with residual + attn_norm_out = self.attention_norm(x) + x = x + self.attention(attn_norm_out, rope_cache, attention_masks) + + # FFN with residual + ffn_norm_out = self.ffn_norm(x) + x = x + self.feed_forward(ffn_norm_out) + + return x + + def init_weights(self, buffer_device: torch.device): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.feed_forward.init_weights(self.weight_init_std) + + +class Qwen3VLLMCompatModel(nn.Module, ModelProtocol): + """ + Qwen3 model with vLLM-compatible implementation. + Uses merged gate_up projections and vLLM Flash Attention. + """ + + def __init__(self, model_args: Qwen3ModelArgs): + super().__init__() + self.model_args = model_args + self.vocab_size = model_args.vocab_size + self.n_layers = model_args.n_layers + self.eos_id = model_args.eos_id + self.head_dim = model_args.head_dim + + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + + self.register_buffer( + "rope_cache", self._precompute_rope_cache(), persistent=False + ) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) + + self.norm = VLLMRMSNorm(model_args.dim, eps=model_args.norm_eps) + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) + + # IMPORTANT: To match vLLM's behavior and Qwen3's config + # (tie_word_embeddings: true), tie output layer weights to + # embedding weights. 
When either weight updates during training, + # both update together + self.output.weight = self.tok_embeddings.weight + + def init_weights( + self, + buffer_device: torch.device | None = None, + ): + buffer_device = buffer_device or self.rope_cache.device + with torch.device(buffer_device): + self.rope_cache = self._precompute_rope_cache() + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights(buffer_device) + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def _precompute_rope_cache(self) -> torch.Tensor: + return precompute_rope_cache( + self.model_args.head_dim, + self.model_args.max_seq_len, + self.model_args.rope_theta, + ) + + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + # vLLM compat mode: no flex attention masks + return None + + def forward( + self, + tokens: torch.Tensor, + attention_masks: AttentionMasksType | None = None, + ): + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + for layer in self.layers.values(): + h = layer(h, self.rope_cache, attention_masks) + + h = self.norm(h) if self.norm else h + output = self.output(h) if self.output else h + + return output
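
For reference, the following minimal sketch shows the shim-extension pattern described in the compat README above, using only `register_shim` and the meta path finder that `install_shims()` registers automatically when `torchtitan` is imported. The module name `_hypothetical_missing_api` and the function `hypothetical_function` are invented purely for illustration; they are not real PyTorch APIs.

```python
# Minimal sketch of extending the compat shim registry (see compat.py above).
# NOTE: "torch.distributed.checkpoint._hypothetical_missing_api" is a made-up
# module name used only for illustration; it is not a real PyTorch API.
from types import ModuleType

import torchtitan  # noqa: F401  # importing torchtitan installs the meta path finder
from torchtitan.experiments.compat.compat import register_shim


def _shim_hypothetical_missing_api() -> ModuleType:
    module = ModuleType("torch.distributed.checkpoint._hypothetical_missing_api")

    def hypothetical_function(*args, **kwargs):
        # Fail fast with a clear message, mirroring the existing stubs.
        raise NotImplementedError(
            "hypothetical_function requires PyTorch nightly; "
            "upgrade or disable the feature that needs it."
        )

    module.hypothetical_function = hypothetical_function
    return module


register_shim(
    "torch.distributed.checkpoint._hypothetical_missing_api",
    _shim_hypothetical_missing_api,
)

# The import now resolves to the shim module instead of raising ImportError.
from torch.distributed.checkpoint._hypothetical_missing_api import (  # noqa: E402
    hypothetical_function,
)
```

For a class missing from a module that does exist, the registry is not needed: the factory imports the real module and attaches the class in place, as `_shim_checkpoint_staging` and `_shim_pipelining_schedules` do above.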