diff --git a/configs/vllm_qwen3_8b.yaml b/configs/vllm_qwen3_8b.yaml index 96ce140..bdb50fa 100644 --- a/configs/vllm_qwen3_8b.yaml +++ b/configs/vllm_qwen3_8b.yaml @@ -11,7 +11,7 @@ # Usage: # python -m torchspec.train_entry --config configs/vllm_qwen3_8b.yaml # -# Note: Uses vLLM Worker Extension to hook into model forward pass for hidden states capture. +# Note: Uses vLLM's extract_hidden_states speculative method with MooncakeHiddenStatesConnector. model: target_model_path: Qwen/Qwen3-8B @@ -51,17 +51,14 @@ inference: vllm: tp_size: 2 mem_fraction_static: 0.7 - use_worker_extension: true extra_args: - max_num_batched_tokens: 32768 - compilation_config: - max_cudagraph_capture_size: 8 + max_num_batched_tokens: 8192 mooncake: master_server_address: null metadata_server: null protocol: tcp - global_segment_size: 16GB + global_segment_size: 32GB local_buffer_size: 4GB output_dir: ./outputs/qwen3-8b-single-node diff --git a/examples/data/sample_conversations.jsonl b/examples/data/sample_conversations.jsonl index 75b9bd1..56c5b3f 100644 --- a/examples/data/sample_conversations.jsonl +++ b/examples/data/sample_conversations.jsonl @@ -998,3 +998,4 @@ {"id": "local_000997", "conversations": [{"role": "user", "content": "How do I get started with performance optimization?"}, {"role": "assistant", "content": "performance optimization is an important area in software development. Here are the key points:\n\n1. It involves understanding core principles and applying them systematically.\n2. Best practices include proper planning, documentation, and iterative improvement.\n3. Common tools include various frameworks and libraries that simplify implementation.\n4. Learning resources are widely available online, including tutorials and documentation.\n\nTo get started, I recommend beginning with the basics and gradually building up your knowledge through practice."}, {"role": "user", "content": "What resources do you recommend for learning more?"}, {"role": "assistant", "content": "For advanced techniques, consider:\n\n- **Optimization**: Focus on performance-critical paths\n- **Scalability**: Design for growth from the start\n- **Monitoring**: Implement comprehensive logging and metrics\n- **Automation**: Reduce manual processes where possible\n\nThese practices will help you build more robust solutions."}]} {"id": "local_000998", "conversations": [{"role": "user", "content": "Can you explain debugging and its key concepts?"}, {"role": "assistant", "content": "Great question about debugging!\n\nHere's what you need to know:\n\n1. **Overview**: debugging is a fundamental concept in modern software development.\n2. **Key Components**: It typically involves planning, implementation, and validation phases.\n3. **Common Patterns**: There are established patterns that most practitioners follow.\n4. **Resources**: Documentation and community resources are excellent starting points.\n\nLet me know if you'd like more specific information on any of these points."}]} {"id": "local_000999", "conversations": [{"role": "user", "content": "What are the benefits of database design?"}, {"role": "assistant", "content": "database design is an important area in software development. Here are the key points:\n\n1. It involves understanding core principles and applying them systematically.\n2. Best practices include proper planning, documentation, and iterative improvement.\n3. Common tools include various frameworks and libraries that simplify implementation.\n4. 
Learning resources are widely available online, including tutorials and documentation.\n\nTo get started, I recommend beginning with the basics and gradually building up your knowledge through practice."}, {"role": "user", "content": "What resources do you recommend for learning more?"}, {"role": "assistant", "content": "For advanced techniques, consider:\n\n- **Optimization**: Focus on performance-critical paths\n- **Scalability**: Design for growth from the start\n- **Monitoring**: Implement comprehensive logging and metrics\n- **Automation**: Reduce manual processes where possible\n\nThese practices will help you build more robust solutions."}]} +{"id": "long_chunked_prefill_test", "conversations": [{"role": "user", "content": "Here is the complete source code for our distributed training framework. Please review it and identify the top 3 most critical bugs or performance issues.\n\n```python\nimport os\nimport math\nimport time\nimport logging\nimport hashlib\nimport asyncio\nimport threading\nfrom typing import Any, Dict, List, Optional, Tuple, Union\nfrom dataclasses import dataclass, field\nfrom collections import defaultdict, OrderedDict\nfrom contextlib import contextmanager\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.distributed as dist\nfrom torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy\nfrom torch.distributed.device_mesh import init_device_mesh\nfrom torch.utils.checkpoint import checkpoint as activation_checkpoint\n\nlogger = logging.getLogger(__name__)\n\n\n# ===========================================================================\n# Configuration\n# ===========================================================================\n\n@dataclass\nclass ModelConfig:\n hidden_size: int = 4096\n intermediate_size: int = 14336\n num_attention_heads: int = 32\n num_key_value_heads: int = 8\n num_hidden_layers: int = 32\n vocab_size: int = 152064\n max_position_embeddings: int = 32768\n rms_norm_eps: float = 1e-6\n rope_theta: float = 1000000.0\n head_dim: int = 128\n hidden_act: str = \"silu\"\n tie_word_embeddings: bool = False\n\n@dataclass\nclass TrainingConfig:\n learning_rate: float = 1e-4\n min_lr_ratio: float = 0.1\n warmup_steps: int = 100\n total_steps: int = 10000\n max_grad_norm: float = 1.0\n weight_decay: float = 0.01\n gradient_accumulation_steps: int = 4\n micro_batch_size: int = 1\n max_seq_length: int = 16384\n seed: int = 42\n checkpoint_interval: int = 500\n log_interval: int = 10\n eval_interval: int = 100\n bf16: bool = True\n gradient_checkpointing: bool = True\n\n@dataclass\nclass ParallelConfig:\n dp_size: int = 1\n tp_size: int = 1\n pp_size: int = 1\n fsdp_cpu_offload: bool = False\n fsdp_backward_prefetch: str = \"backward_pre\"\n\n@dataclass\nclass InferenceConfig:\n inference_engine_type: str = \"vllm\"\n num_gpus: int = 2\n num_gpus_per_engine: int = 2\n batch_size: int = 8\n max_sample_pool_size: int = 64\n buffer_threshold: int = 32\n mem_fraction: float = 0.7\n aux_hidden_state_layers: Optional[List[int]] = None\n\n@dataclass\nclass MooncakeConfig:\n master_server_address: Optional[str] = None\n metadata_server: Optional[str] = None\n protocol: str = \"tcp\"\n global_segment_size: int = 4 * 1024**3\n local_buffer_size: int = 512 * 1024**2\n local_hostname: str = \"localhost\"\n\n def export_env(self):\n if self.master_server_address:\n host, port = self.master_server_address.rsplit(\":\", 1)\n os.environ[\"MOONCAKE_MASTER_HOST\"] = host\n os.environ[\"MOONCAKE_MASTER_PORT\"] = 
port\n os.environ[\"MOONCAKE_PROTOCOL\"] = self.protocol\n os.environ[\"MOONCAKE_GLOBAL_SEGMENT_SIZE\"] = str(self.global_segment_size)\n os.environ[\"MOONCAKE_LOCAL_BUFFER_SIZE\"] = str(self.local_buffer_size)\n\n @classmethod\n def from_env(cls):\n host = os.environ.get(\"MOONCAKE_MASTER_HOST\", \"localhost\")\n port = os.environ.get(\"MOONCAKE_MASTER_PORT\", \"50051\")\n return cls(\n master_server_address=f\"{host}:{port}\",\n metadata_server=os.environ.get(\"MOONCAKE_METADATA_SERVER\"),\n protocol=os.environ.get(\"MOONCAKE_PROTOCOL\", \"tcp\"),\n global_segment_size=int(os.environ.get(\"MOONCAKE_GLOBAL_SEGMENT_SIZE\", str(4 * 1024**3))),\n local_buffer_size=int(os.environ.get(\"MOONCAKE_LOCAL_BUFFER_SIZE\", str(512 * 1024**2))),\n )\n\n\n# ===========================================================================\n# Model Components\n# ===========================================================================\n\nclass RMSNorm(nn.Module):\n def __init__(self, hidden_size: int, eps: float = 1e-6):\n super().__init__()\n self.weight = nn.Parameter(torch.ones(hidden_size))\n self.eps = eps\n\n def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):\n if residual is not None:\n x = x + residual\n variance = x.pow(2).mean(-1, keepdim=True)\n x = x * torch.rsqrt(variance + self.eps)\n output = self.weight * x\n if residual is not None:\n return output, x\n return output\n\n\nclass RotaryEmbedding(nn.Module):\n def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 1000000.0):\n super().__init__()\n self.dim = dim\n self.max_position_embeddings = max_position_embeddings\n inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))\n self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n def forward(self, positions: torch.Tensor):\n freqs = torch.outer(positions.float(), self.inv_freq)\n emb = torch.cat([freqs, freqs], dim=-1)\n cos = emb.cos()\n sin = emb.sin()\n return cos, sin\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin):\n def rotate_half(x):\n x1, x2 = x.chunk(2, dim=-1)\n return torch.cat([-x2, x1], dim=-1)\n\n q_embed = q * cos + rotate_half(q) * sin\n k_embed = k * cos + rotate_half(k) * sin\n return q_embed, k_embed\n\n\nclass MLP(nn.Module):\n def __init__(self, config: ModelConfig):\n super().__init__()\n self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)\n self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)\n self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))\n\n\nclass Attention(nn.Module):\n def __init__(self, config: ModelConfig, layer_idx: int):\n super().__init__()\n self.hidden_size = config.hidden_size\n self.num_heads = config.num_attention_heads\n self.num_kv_heads = config.num_key_value_heads\n self.head_dim = config.head_dim\n self.num_kv_groups = self.num_heads // self.num_kv_heads\n\n self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)\n self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)\n self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)\n\n self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)\n self.layer_idx = layer_idx\n\n def 
forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:\n bsz, seq_len, _ = hidden_states.shape\n\n q = self.q_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)\n k = self.k_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)\n v = self.v_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)\n\n cos, sin = self.rotary_emb(positions)\n cos = cos.unsqueeze(0).unsqueeze(0)\n sin = sin.unsqueeze(0).unsqueeze(0)\n q, k = apply_rotary_pos_emb(q, k, cos, sin)\n\n if self.num_kv_groups > 1:\n k = k.unsqueeze(2).expand(-1, -1, self.num_kv_groups, -1, -1).reshape(bsz, self.num_heads, seq_len, self.head_dim)\n v = v.unsqueeze(2).expand(-1, -1, self.num_kv_groups, -1, -1).reshape(bsz, self.num_heads, seq_len, self.head_dim)\n\n attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)\n\n if attention_mask is not None:\n attn_weights = attn_weights + attention_mask\n\n attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)\n output = torch.matmul(attn_weights, v)\n output = output.transpose(1, 2).contiguous().view(bsz, seq_len, self.hidden_size)\n return self.o_proj(output)\n\n\nclass TransformerLayer(nn.Module):\n def __init__(self, config: ModelConfig, layer_idx: int):\n super().__init__()\n self.self_attn = Attention(config, layer_idx)\n self.mlp = MLP(config)\n self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)\n self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)\n\n def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:\n if residual is None:\n residual = hidden_states\n hidden_states = self.input_layernorm(hidden_states)\n else:\n hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n hidden_states = self.self_attn(hidden_states, positions, attention_mask)\n hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n hidden_states = self.mlp(hidden_states)\n return hidden_states, residual\n\n\nclass TransformerModel(nn.Module):\n def __init__(self, config: ModelConfig):\n super().__init__()\n self.config = config\n self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)\n self.layers = nn.ModuleList([TransformerLayer(config, i) for i in range(config.num_hidden_layers)])\n self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)\n\n def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:\n hidden_states = self.embed_tokens(input_ids)\n residual = None\n\n for layer in self.layers:\n if self.training:\n hidden_states, residual = activation_checkpoint(\n layer, hidden_states, positions, attention_mask, residual,\n use_reentrant=False,\n )\n else:\n hidden_states, residual = layer(hidden_states, positions, attention_mask, residual)\n\n hidden_states = self.norm(hidden_states, residual)\n if isinstance(hidden_states, tuple):\n hidden_states = hidden_states[0]\n return hidden_states\n\n\nclass CausalLM(nn.Module):\n def __init__(self, config: ModelConfig):\n super().__init__()\n self.model = TransformerModel(config)\n self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n if config.tie_word_embeddings:\n self.lm_head.weight = 
self.model.embed_tokens.weight\n\n def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:\n hidden_states = self.model(input_ids, positions, attention_mask)\n logits = self.lm_head(hidden_states)\n\n loss = None\n if labels is not None:\n shift_logits = logits[..., :-1, :].contiguous()\n shift_labels = labels[..., 1:].contiguous()\n loss = F.cross_entropy(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1), ignore_index=-100)\n\n return {\"loss\": loss, \"logits\": logits, \"hidden_states\": hidden_states}\n\n\n# ===========================================================================\n# Eagle3 Draft Model\n# ===========================================================================\n\nclass Eagle3DraftModel(nn.Module):\n def __init__(self, config: ModelConfig, aux_layer_ids: List[int]):\n super().__init__()\n self.aux_layer_ids = aux_layer_ids\n num_aux = len(aux_layer_ids)\n self.midlayer = nn.Linear(config.hidden_size * num_aux, config.hidden_size, bias=False)\n self.layers = nn.ModuleList([TransformerLayer(config, i) for i in range(2)])\n self.embed_tokens = None\n self.lm_head = None\n\n def combine_hidden_states(self, concat_hidden: torch.Tensor) -> torch.Tensor:\n return self.midlayer(concat_hidden)\n\n def forward(self, input_ids: torch.Tensor, hidden_states: torch.Tensor, positions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:\n if self.embed_tokens is None:\n raise RuntimeError(\"embed_tokens not set. Call load_embedding() first.\")\n\n token_embeds = self.embed_tokens(input_ids)\n x = token_embeds + hidden_states\n\n residual = None\n for layer in self.layers:\n x, residual = layer(x, positions, attention_mask, residual)\n\n if residual is not None:\n x = x + residual\n\n return x\n\n def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:\n if self.lm_head is None:\n raise RuntimeError(\"lm_head not set.\")\n return self.lm_head(hidden_states)\n\n def load_embedding(self, target_model: CausalLM):\n self.embed_tokens = target_model.model.embed_tokens\n self.lm_head = target_model.lm_head\n for param in self.embed_tokens.parameters():\n param.requires_grad = False\n\n\n# ===========================================================================\n# Training Utilities\n# ===========================================================================\n\nclass WarmupCosineScheduler:\n def __init__(self, optimizer, warmup_steps: int, total_steps: int, min_lr_ratio: float = 0.1):\n self.optimizer = optimizer\n self.warmup_steps = warmup_steps\n self.total_steps = total_steps\n self.min_lr_ratio = min_lr_ratio\n self.base_lrs = [pg[\"lr\"] for pg in optimizer.param_groups]\n self.step_count = 0\n\n def step(self):\n self.step_count += 1\n if self.step_count <= self.warmup_steps:\n scale = self.step_count / self.warmup_steps\n else:\n progress = (self.step_count - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps)\n scale = self.min_lr_ratio + 0.5 * (1 - self.min_lr_ratio) * (1 + math.cos(math.pi * progress))\n for pg, base_lr in zip(self.optimizer.param_groups, self.base_lrs):\n pg[\"lr\"] = base_lr * scale\n\n def get_lr(self):\n return self.optimizer.param_groups[0][\"lr\"]\n\n\nclass MetricsTracker:\n def __init__(self, window_size: int = 100):\n self.metrics = defaultdict(list)\n self.window_size = window_size\n\n def update(self, **kwargs):\n for key, value in kwargs.items():\n 
if isinstance(value, torch.Tensor):\n value = value.item()\n self.metrics[key].append(value)\n\n def get_smoothed(self, key: str) -> float:\n values = self.metrics.get(key, [])\n if not values:\n return 0.0\n recent = values[-self.window_size:]\n return sum(recent) / len(recent)\n\n def report(self, step: int, keys: Optional[List[str]] = None):\n if keys is None:\n keys = list(self.metrics.keys())\n parts = [f\"step={step}\"]\n for key in keys:\n parts.append(f\"{key}={self.get_smoothed(key):.4f}\")\n return \" | \".join(parts)\n\n\nclass CheckpointManager:\n def __init__(self, save_dir: str, max_keep: int = 3):\n self.save_dir = save_dir\n self.max_keep = max_keep\n self.saved_checkpoints = []\n os.makedirs(save_dir, exist_ok=True)\n\n def save(self, model, optimizer, scheduler, step: int, metrics: dict):\n path = os.path.join(self.save_dir, f\"checkpoint-{step}\")\n os.makedirs(path, exist_ok=True)\n\n torch.save({\n \"model_state_dict\": model.state_dict(),\n \"optimizer_state_dict\": optimizer.state_dict(),\n \"scheduler_state_dict\": {\"step_count\": scheduler.step_count},\n \"step\": step,\n \"metrics\": metrics,\n }, os.path.join(path, \"state.pt\"))\n\n self.saved_checkpoints.append(path)\n while len(self.saved_checkpoints) > self.max_keep:\n old_path = self.saved_checkpoints.pop(0)\n import shutil\n shutil.rmtree(old_path, ignore_errors=True)\n\n logger.info(f\"Saved checkpoint at step {step} to {path}\")\n\n def load_latest(self, model, optimizer, scheduler) -> int:\n if not self.saved_checkpoints:\n checkpoints = sorted(\n [d for d in os.listdir(self.save_dir) if d.startswith(\"checkpoint-\")],\n key=lambda x: int(x.split(\"-\")[1])\n )\n if not checkpoints:\n return 0\n self.saved_checkpoints = [os.path.join(self.save_dir, c) for c in checkpoints]\n\n latest = self.saved_checkpoints[-1]\n state = torch.load(os.path.join(latest, \"state.pt\"), map_location=\"cpu\")\n model.load_state_dict(state[\"model_state_dict\"])\n optimizer.load_state_dict(state[\"optimizer_state_dict\"])\n scheduler.step_count = state[\"scheduler_state_dict\"][\"step_count\"]\n logger.info(f\"Loaded checkpoint from {latest} at step {state['step']}\")\n return state[\"step\"]\n\n\n# ===========================================================================\n# Data Pipeline\n# ===========================================================================\n\nclass SamplePool:\n def __init__(self, max_size: int = 1000):\n self.max_size = max_size\n self.samples = []\n self.lock = threading.Lock()\n\n def add(self, samples: List[dict]):\n with self.lock:\n self.samples.extend(samples)\n if len(self.samples) > self.max_size:\n self.samples = self.samples[-self.max_size:]\n\n def get_batch(self, batch_size: int) -> List[dict]:\n with self.lock:\n if len(self.samples) < batch_size:\n return []\n batch = self.samples[:batch_size]\n self.samples = self.samples[batch_size:]\n return batch\n\n @property\n def size(self):\n return len(self.samples)\n\n\nclass DataCollator:\n def __init__(self, max_seq_length: int, pad_token_id: int = 0):\n self.max_seq_length = max_seq_length\n self.pad_token_id = pad_token_id\n\n def __call__(self, samples: List[dict]) -> dict:\n input_ids_list = []\n hidden_states_list = []\n last_hidden_states_list = []\n attention_mask_list = []\n loss_mask_list = []\n\n max_len = min(\n max(s[\"input_ids\"].shape[-1] for s in samples),\n self.max_seq_length,\n )\n\n for sample in samples:\n ids = sample[\"input_ids\"]\n if ids.dim() == 2:\n ids = ids.squeeze(0)\n seq_len = ids.shape[0]\n\n if 
seq_len > max_len:\n ids = ids[:max_len]\n seq_len = max_len\n\n pad_len = max_len - seq_len\n if pad_len > 0:\n ids = F.pad(ids, (0, pad_len), value=self.pad_token_id)\n\n input_ids_list.append(ids)\n\n if \"hidden_states\" in sample:\n hs = sample[\"hidden_states\"]\n if hs.shape[0] > max_len:\n hs = hs[:max_len]\n if hs.shape[0] < max_len:\n hs = F.pad(hs, (0, 0, 0, max_len - hs.shape[0]))\n hidden_states_list.append(hs)\n\n if \"last_hidden_states\" in sample:\n lhs = sample[\"last_hidden_states\"]\n if lhs.shape[0] > max_len:\n lhs = lhs[:max_len]\n if lhs.shape[0] < max_len:\n lhs = F.pad(lhs, (0, 0, 0, max_len - lhs.shape[0]))\n last_hidden_states_list.append(lhs)\n\n mask = torch.ones(max_len, dtype=torch.bool)\n if pad_len > 0:\n mask[-pad_len:] = False\n attention_mask_list.append(mask)\n\n if \"loss_mask\" in sample:\n lm = sample[\"loss_mask\"]\n if lm.shape[0] > max_len:\n lm = lm[:max_len]\n if lm.shape[0] < max_len:\n lm = F.pad(lm, (0, max_len - lm.shape[0]))\n loss_mask_list.append(lm)\n else:\n loss_mask_list.append(mask.float())\n\n batch = {\n \"input_ids\": torch.stack(input_ids_list),\n \"attention_mask\": torch.stack(attention_mask_list),\n \"loss_mask\": torch.stack(loss_mask_list),\n }\n if hidden_states_list:\n batch[\"hidden_states\"] = torch.stack(hidden_states_list)\n if last_hidden_states_list:\n batch[\"last_hidden_states\"] = torch.stack(last_hidden_states_list)\n\n return batch\n\n\n# ===========================================================================\n# Inference Manager\n# ===========================================================================\n\nclass InferenceManager:\n def __init__(self, config: InferenceConfig, mooncake_config: Optional[MooncakeConfig] = None):\n self.config = config\n self.mooncake_config = mooncake_config\n self.engines = {}\n self._pending_count = 0\n self._lock = threading.Lock()\n\n def init_engines(self, target_model_path: str):\n num_engines = self.config.num_gpus // self.config.num_gpus_per_engine\n for i in range(num_engines):\n base_gpu = i * self.config.num_gpus_per_engine\n engine = self._create_engine(i, base_gpu)\n engine.init(self.mooncake_config)\n self.engines[i] = engine\n\n def _create_engine(self, rank: int, base_gpu: int):\n raise NotImplementedError(\"Subclasses must implement engine creation\")\n\n def should_generate(self) -> bool:\n with self._lock:\n return self._pending_count < self.config.buffer_threshold\n\n def generate(self, data_batch: List[dict]) -> List[dict]:\n engine = self._select_engine()\n results = engine.generate(data_batch)\n with self._lock:\n self._pending_count += len(results)\n return results\n\n def consume(self, count: int):\n with self._lock:\n self._pending_count = max(0, self._pending_count - count)\n\n def _select_engine(self):\n return list(self.engines.values())[0]\n\n def shutdown(self):\n for engine in self.engines.values():\n engine.shutdown()\n self.engines.clear()\n\n\n# ===========================================================================\n# Training Loop\n# ===========================================================================\n\nclass Trainer:\n def __init__(self, config: TrainingConfig, parallel_config: ParallelConfig):\n self.config = config\n self.parallel_config = parallel_config\n self.model = None\n self.optimizer = None\n self.scheduler = None\n self.metrics = MetricsTracker()\n self.checkpoint_mgr = None\n self.global_step = 0\n\n def init_model(self, model: nn.Module, output_dir: str):\n if self.parallel_config.dp_size > 1:\n mesh = 
init_device_mesh(\"cuda\", (self.parallel_config.dp_size,))\n mp_policy = MixedPrecisionPolicy(\n param_dtype=torch.bfloat16 if self.config.bf16 else torch.float32,\n reduce_dtype=torch.float32,\n )\n for layer in model.layers if hasattr(model, \"layers\") else []:\n fully_shard(layer, mesh=mesh, mp_policy=mp_policy)\n fully_shard(model, mesh=mesh, mp_policy=mp_policy)\n\n self.model = model\n self.optimizer = torch.optim.AdamW(\n [p for p in model.parameters() if p.requires_grad],\n lr=self.config.learning_rate,\n weight_decay=self.config.weight_decay,\n betas=(0.9, 0.95),\n )\n self.scheduler = WarmupCosineScheduler(\n self.optimizer,\n self.config.warmup_steps,\n self.config.total_steps,\n self.config.min_lr_ratio,\n )\n self.checkpoint_mgr = CheckpointManager(output_dir)\n self.global_step = self.checkpoint_mgr.load_latest(model, self.optimizer, self.scheduler)\n\n def train_step(self, batch: dict) -> dict:\n self.model.train()\n total_loss = 0.0\n\n for micro_step in range(self.config.gradient_accumulation_steps):\n with torch.amp.autocast(\"cuda\", dtype=torch.bfloat16, enabled=self.config.bf16):\n outputs = self.model(**batch)\n loss = outputs[\"loss\"] / self.config.gradient_accumulation_steps\n\n loss.backward()\n total_loss += loss.item()\n\n grad_norm = torch.nn.utils.clip_grad_norm_(\n self.model.parameters(), self.config.max_grad_norm\n )\n\n self.optimizer.step()\n self.scheduler.step()\n self.optimizer.zero_grad()\n\n self.global_step += 1\n\n metrics = {\n \"loss\": total_loss,\n \"grad_norm\": grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm,\n \"lr\": self.scheduler.get_lr(),\n \"step\": self.global_step,\n }\n self.metrics.update(**metrics)\n\n if self.global_step % self.config.log_interval == 0:\n logger.info(self.metrics.report(self.global_step))\n\n if self.global_step % self.config.checkpoint_interval == 0:\n self.checkpoint_mgr.save(\n self.model, self.optimizer, self.scheduler,\n self.global_step, metrics,\n )\n\n return metrics\n\n def eval_step(self, batch: dict) -> dict:\n self.model.eval()\n with torch.no_grad():\n with torch.amp.autocast(\"cuda\", dtype=torch.bfloat16, enabled=self.config.bf16):\n outputs = self.model(**batch)\n return {\"eval_loss\": outputs[\"loss\"].item()}\n\n\n# ===========================================================================\n# Main Training Controller\n# ===========================================================================\n\nclass TrainingController:\n def __init__(\n self,\n training_config: TrainingConfig,\n parallel_config: ParallelConfig,\n inference_config: InferenceConfig,\n mooncake_config: Optional[MooncakeConfig] = None,\n ):\n self.training_config = training_config\n self.parallel_config = parallel_config\n self.inference_config = inference_config\n self.mooncake_config = mooncake_config\n self.trainer = Trainer(training_config, parallel_config)\n self.inference_mgr = InferenceManager(inference_config, mooncake_config)\n self.sample_pool = SamplePool(inference_config.max_sample_pool_size)\n self.collator = DataCollator(training_config.max_seq_length)\n self._stop_event = threading.Event()\n\n def run(self, model: nn.Module, target_model_path: str, output_dir: str, train_data):\n self.trainer.init_model(model, output_dir)\n self.inference_mgr.init_engines(target_model_path)\n\n inference_thread = threading.Thread(\n target=self._inference_loop, args=(train_data,), daemon=True\n )\n inference_thread.start()\n\n try:\n self._training_loop()\n finally:\n self._stop_event.set()\n 
inference_thread.join(timeout=30)\n self.inference_mgr.shutdown()\n\n def _inference_loop(self, train_data):\n data_iter = iter(train_data)\n while not self._stop_event.is_set():\n if not self.inference_mgr.should_generate():\n time.sleep(0.1)\n continue\n\n try:\n batch = [next(data_iter) for _ in range(self.inference_config.batch_size)]\n except StopIteration:\n data_iter = iter(train_data)\n continue\n\n try:\n results = self.inference_mgr.generate(batch)\n self.sample_pool.add(results)\n except Exception as e:\n logger.error(f\"Inference error: {e}\")\n time.sleep(1.0)\n\n def _training_loop(self):\n for step in range(self.trainer.global_step, self.training_config.total_steps):\n batch_data = self.sample_pool.get_batch(self.training_config.micro_batch_size)\n if not batch_data:\n time.sleep(0.1)\n continue\n\n batch = self.collator(batch_data)\n batch = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in batch.items()}\n\n metrics = self.trainer.train_step(batch)\n self.inference_mgr.consume(len(batch_data))\n\n if step % self.training_config.eval_interval == 0:\n eval_batch = self.sample_pool.get_batch(1)\n if eval_batch:\n eval_data = self.collator(eval_batch)\n eval_data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in eval_data.items()}\n eval_metrics = self.trainer.eval_step(eval_data)\n logger.info(f\"Eval at step {step}: {eval_metrics}\")\n```\n\nPlease identify the top 3 most critical issues in this codebase."}, {"role": "assistant", "content": "Here are the top 3 critical issues:\n\n1. **Gradient accumulation reuses the same batch**: In `train_step`, the inner loop runs `gradient_accumulation_steps` times but feeds the exact same `batch` each iteration instead of different micro-batches. This means you're computing the same gradients K times and dividing by K, which is mathematically equivalent to a single step with no accumulation benefit.\n\n2. **Eval consumes training samples**: `eval_step` calls `self.sample_pool.get_batch(1)` which destructively pops from the same pool used for training. This silently drops training samples and biases the training distribution. Eval should use a separate held-out pool.\n\n3. **Race condition in SamplePool**: `size` property reads `len(self.samples)` without acquiring the lock, while `add` and `get_batch` modify the list under the lock. This can cause the inference manager's `should_generate` to make decisions on stale data, leading to pool overflow or starvation."}, {"role": "user", "content": "Great analysis. Now here's the corresponding test suite. 
Review it for coverage gaps and incorrect assertions.\n\n```python\nimport pytest\nimport torch\nimport torch.nn as nn\nimport torch.distributed as dist\nfrom unittest.mock import MagicMock, patch, PropertyMock\nfrom collections import OrderedDict\nimport os\nimport tempfile\nimport shutil\nimport json\nimport math\nimport threading\nimport time\nfrom typing import List, Dict, Optional, Tuple\n\n# Assume the main module is importable\n# from training_framework import (\n# ModelConfig, TrainingConfig, ParallelConfig, InferenceConfig, MooncakeConfig,\n# RMSNorm, RotaryEmbedding, MLP, Attention, TransformerLayer, TransformerModel, CausalLM,\n# Eagle3DraftModel, WarmupCosineScheduler, MetricsTracker, CheckpointManager,\n# SamplePool, DataCollator, InferenceManager, Trainer, TrainingController,\n# apply_rotary_pos_emb,\n# )\n\n\nclass TestModelConfig:\n def test_default_values(self):\n config = ModelConfig()\n assert config.hidden_size == 4096\n assert config.intermediate_size == 14336\n assert config.num_attention_heads == 32\n assert config.num_key_value_heads == 8\n assert config.num_hidden_layers == 32\n assert config.vocab_size == 152064\n assert config.head_dim == 128\n assert config.rms_norm_eps == 1e-6\n\n def test_custom_values(self):\n config = ModelConfig(hidden_size=2048, num_hidden_layers=16)\n assert config.hidden_size == 2048\n assert config.num_hidden_layers == 16\n assert config.intermediate_size == 14336 # unchanged default\n\n\nclass TestTrainingConfig:\n def test_defaults(self):\n config = TrainingConfig()\n assert config.learning_rate == 1e-4\n assert config.gradient_accumulation_steps == 4\n assert config.bf16 is True\n assert config.gradient_checkpointing is True\n\n def test_custom(self):\n config = TrainingConfig(learning_rate=5e-5, total_steps=5000)\n assert config.learning_rate == 5e-5\n assert config.total_steps == 5000\n\n\nclass TestMooncakeConfig:\n def test_export_and_from_env(self):\n config = MooncakeConfig(\n master_server_address=\"10.0.0.1:50051\",\n metadata_server=\"http://10.0.0.1:8090\",\n protocol=\"rdma\",\n )\n config.export_env()\n\n assert os.environ.get(\"MOONCAKE_MASTER_HOST\") == \"10.0.0.1\"\n assert os.environ.get(\"MOONCAKE_MASTER_PORT\") == \"50051\"\n assert os.environ.get(\"MOONCAKE_PROTOCOL\") == \"rdma\"\n\n restored = MooncakeConfig.from_env()\n assert restored.master_server_address == \"10.0.0.1:50051\"\n assert restored.protocol == \"rdma\"\n\n def test_default_from_env(self):\n for key in [\"MOONCAKE_MASTER_HOST\", \"MOONCAKE_MASTER_PORT\", \"MOONCAKE_PROTOCOL\"]:\n os.environ.pop(key, None)\n config = MooncakeConfig.from_env()\n assert config.master_server_address == \"localhost:50051\"\n assert config.protocol == \"tcp\"\n\n\nclass TestRMSNorm:\n def test_forward_shape(self):\n norm = RMSNorm(64, eps=1e-6)\n x = torch.randn(2, 10, 64)\n out = norm(x)\n assert out.shape == (2, 10, 64)\n\n def test_forward_with_residual(self):\n norm = RMSNorm(64, eps=1e-6)\n x = torch.randn(2, 10, 64)\n residual = torch.randn(2, 10, 64)\n out, new_residual = norm(x, residual)\n assert out.shape == (2, 10, 64)\n assert new_residual.shape == (2, 10, 64)\n\n def test_unit_weight_is_identity_for_unit_norm(self):\n norm = RMSNorm(4, eps=0)\n x = torch.tensor([[[1.0, 0.0, 0.0, 0.0]]])\n out = norm(x)\n expected_rms = 1.0 / math.sqrt(4)\n expected = x / expected_rms\n assert torch.allclose(out, expected, atol=1e-5)\n\n\nclass TestRotaryEmbedding:\n def test_output_shape(self):\n rope = RotaryEmbedding(dim=128, max_position_embeddings=4096)\n positions = 
torch.arange(10)\n cos, sin = rope(positions)\n assert cos.shape == (10, 128)\n assert sin.shape == (10, 128)\n\n def test_cos_sin_range(self):\n rope = RotaryEmbedding(dim=64)\n positions = torch.arange(100)\n cos, sin = rope(positions)\n assert cos.min() >= -1.0\n assert cos.max() <= 1.0\n assert sin.min() >= -1.0\n assert sin.max() <= 1.0\n\n def test_position_zero(self):\n rope = RotaryEmbedding(dim=4)\n cos, sin = rope(torch.tensor([0]))\n assert torch.allclose(cos, torch.ones_like(cos))\n assert torch.allclose(sin, torch.zeros_like(sin), atol=1e-7)\n\n\nclass TestApplyRotaryPosEmb:\n def test_output_shape_preserved(self):\n q = torch.randn(1, 8, 16, 128)\n k = torch.randn(1, 2, 16, 128)\n cos = torch.randn(16, 128)\n sin = torch.randn(16, 128)\n cos = cos.unsqueeze(0).unsqueeze(0)\n sin = sin.unsqueeze(0).unsqueeze(0)\n q_out, k_out = apply_rotary_pos_emb(q, k, cos, sin)\n assert q_out.shape == q.shape\n assert k_out.shape == k.shape\n\n def test_zero_sin_is_identity_scaling(self):\n q = torch.randn(1, 1, 4, 8)\n k = torch.randn(1, 1, 4, 8)\n cos = torch.ones(1, 1, 4, 8)\n sin = torch.zeros(1, 1, 4, 8)\n q_out, k_out = apply_rotary_pos_emb(q, k, cos, sin)\n assert torch.allclose(q_out, q)\n assert torch.allclose(k_out, k)\n\n\nclass TestMLP:\n def test_forward_shape(self):\n config = ModelConfig(hidden_size=64, intermediate_size=256)\n mlp = MLP(config)\n x = torch.randn(2, 10, 64)\n out = mlp(x)\n assert out.shape == (2, 10, 64)\n\n def test_gradient_flow(self):\n config = ModelConfig(hidden_size=32, intermediate_size=64)\n mlp = MLP(config)\n x = torch.randn(1, 4, 32, requires_grad=True)\n out = mlp(x)\n out.sum().backward()\n assert x.grad is not None\n assert x.grad.shape == x.shape\n\n\nclass TestAttention:\n def test_forward_shape(self):\n config = ModelConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=2, head_dim=16)\n attn = Attention(config, layer_idx=0)\n x = torch.randn(1, 8, 64)\n positions = torch.arange(8)\n out = attn(x, positions)\n assert out.shape == (1, 8, 64)\n\n def test_causal_masking(self):\n config = ModelConfig(hidden_size=32, num_attention_heads=2, num_key_value_heads=2, head_dim=16)\n attn = Attention(config, layer_idx=0)\n x = torch.randn(1, 4, 32)\n positions = torch.arange(4)\n\n causal_mask = torch.triu(torch.full((4, 4), float(\"-inf\")), diagonal=1)\n causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)\n out_masked = attn(x, positions, attention_mask=causal_mask)\n assert out_masked.shape == (1, 4, 32)\n\n\nclass TestTransformerLayer:\n def test_forward_shapes(self):\n config = ModelConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=2, head_dim=16, intermediate_size=128)\n layer = TransformerLayer(config, layer_idx=0)\n x = torch.randn(1, 8, 64)\n positions = torch.arange(8)\n hidden, residual = layer(x, positions)\n assert hidden.shape == (1, 8, 64)\n assert residual.shape == (1, 8, 64)\n\n def test_with_residual_input(self):\n config = ModelConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=2, head_dim=16, intermediate_size=128)\n layer = TransformerLayer(config, layer_idx=0)\n x = torch.randn(1, 4, 64)\n residual = torch.randn(1, 4, 64)\n positions = torch.arange(4)\n hidden, new_residual = layer(x, positions, residual=residual)\n assert hidden.shape == (1, 4, 64)\n assert new_residual.shape == (1, 4, 64)\n\n\nclass TestTransformerModel:\n def test_forward_shape(self):\n config = ModelConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=2, head_dim=16, intermediate_size=128, 
num_hidden_layers=2, vocab_size=100)\n model = TransformerModel(config)\n input_ids = torch.randint(0, 100, (1, 8))\n positions = torch.arange(8)\n out = model(input_ids, positions)\n assert out.shape == (1, 8, 64)\n\n def test_deterministic_output(self):\n config = ModelConfig(hidden_size=32, num_attention_heads=2, num_key_value_heads=1, head_dim=16, intermediate_size=64, num_hidden_layers=1, vocab_size=50)\n model = TransformerModel(config)\n model.eval()\n input_ids = torch.tensor([[1, 2, 3, 4]])\n positions = torch.arange(4)\n with torch.no_grad():\n out1 = model(input_ids, positions)\n out2 = model(input_ids, positions)\n assert torch.allclose(out1, out2)\n\n\nclass TestCausalLM:\n def test_forward_without_labels(self):\n config = ModelConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=2, head_dim=16, intermediate_size=128, num_hidden_layers=2, vocab_size=100)\n model = CausalLM(config)\n input_ids = torch.randint(0, 100, (1, 8))\n positions = torch.arange(8)\n outputs = model(input_ids, positions)\n assert outputs[\"loss\"] is None\n assert outputs[\"logits\"].shape == (1, 8, 100)\n\n def test_forward_with_labels(self):\n config = ModelConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=2, head_dim=16, intermediate_size=128, num_hidden_layers=2, vocab_size=100)\n model = CausalLM(config)\n input_ids = torch.randint(0, 100, (1, 8))\n labels = torch.randint(0, 100, (1, 8))\n positions = torch.arange(8)\n outputs = model(input_ids, positions, labels=labels)\n assert outputs[\"loss\"] is not None\n assert outputs[\"loss\"].dim() == 0\n assert outputs[\"loss\"].item() > 0\n\n\nclass TestEagle3DraftModel:\n def test_forward_shape(self):\n config = ModelConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=2, head_dim=16, intermediate_size=128, vocab_size=100)\n draft = Eagle3DraftModel(config, aux_layer_ids=[2, 8, 15])\n target = CausalLM(config)\n draft.load_embedding(target)\n\n input_ids = torch.randint(0, 100, (1, 4))\n hidden_states = torch.randn(1, 4, 64)\n positions = torch.arange(4)\n out = draft(input_ids, hidden_states, positions)\n assert out.shape == (1, 4, 64)\n\n def test_combine_hidden_states(self):\n config = ModelConfig(hidden_size=64)\n draft = Eagle3DraftModel(config, aux_layer_ids=[2, 8, 15])\n concat = torch.randn(1, 4, 64 * 3)\n combined = draft.combine_hidden_states(concat)\n assert combined.shape == (1, 4, 64)\n\n def test_compute_logits(self):\n config = ModelConfig(hidden_size=64, vocab_size=100)\n draft = Eagle3DraftModel(config, aux_layer_ids=[2, 8])\n target = CausalLM(config)\n draft.load_embedding(target)\n hidden = torch.randn(1, 4, 64)\n logits = draft.compute_logits(hidden)\n assert logits.shape == (1, 4, 100)\n\n def test_raises_without_embedding(self):\n config = ModelConfig(hidden_size=64, vocab_size=100)\n draft = Eagle3DraftModel(config, aux_layer_ids=[2])\n with pytest.raises(RuntimeError, match=\"embed_tokens not set\"):\n draft(torch.tensor([[1]]), torch.randn(1, 1, 64), torch.tensor([0]))\n\n\nclass TestWarmupCosineScheduler:\n def test_warmup_phase(self):\n model = nn.Linear(10, 10)\n optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n scheduler = WarmupCosineScheduler(optimizer, warmup_steps=10, total_steps=100)\n\n lrs = []\n for _ in range(10):\n scheduler.step()\n lrs.append(scheduler.get_lr())\n\n assert lrs[0] < lrs[-1]\n assert abs(lrs[-1] - 0.1) < 1e-6\n\n def test_cosine_decay(self):\n model = nn.Linear(10, 10)\n optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n scheduler = 
WarmupCosineScheduler(optimizer, warmup_steps=0, total_steps=100, min_lr_ratio=0.0)\n\n lrs = []\n for _ in range(100):\n scheduler.step()\n lrs.append(scheduler.get_lr())\n\n assert lrs[0] > lrs[-1]\n assert lrs[-1] < 0.01\n\n def test_min_lr_ratio(self):\n model = nn.Linear(10, 10)\n optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n scheduler = WarmupCosineScheduler(optimizer, warmup_steps=0, total_steps=100, min_lr_ratio=0.5)\n\n for _ in range(200):\n scheduler.step()\n assert scheduler.get_lr() >= 0.05 - 1e-6\n\n\nclass TestMetricsTracker:\n def test_update_and_smoothed(self):\n tracker = MetricsTracker(window_size=5)\n for i in range(10):\n tracker.update(loss=float(i))\n smoothed = tracker.get_smoothed(\"loss\")\n assert abs(smoothed - 7.0) < 1e-6\n\n def test_empty_metric(self):\n tracker = MetricsTracker()\n assert tracker.get_smoothed(\"nonexistent\") == 0.0\n\n def test_report_format(self):\n tracker = MetricsTracker()\n tracker.update(loss=1.5, acc=0.9)\n report = tracker.report(step=10)\n assert \"step=10\" in report\n assert \"loss=\" in report\n assert \"acc=\" in report\n\n def test_tensor_values(self):\n tracker = MetricsTracker()\n tracker.update(loss=torch.tensor(2.5))\n assert tracker.get_smoothed(\"loss\") == 2.5\n\n\nclass TestCheckpointManager:\n def test_save_and_load(self):\n with tempfile.TemporaryDirectory() as tmpdir:\n mgr = CheckpointManager(tmpdir, max_keep=2)\n model = nn.Linear(10, 5)\n optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n scheduler = WarmupCosineScheduler(optimizer, 10, 100)\n\n mgr.save(model, optimizer, scheduler, step=100, metrics={\"loss\": 0.5})\n loaded_step = mgr.load_latest(model, optimizer, scheduler)\n assert loaded_step == 100\n\n def test_max_keep(self):\n with tempfile.TemporaryDirectory() as tmpdir:\n mgr = CheckpointManager(tmpdir, max_keep=2)\n model = nn.Linear(10, 5)\n optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n scheduler = WarmupCosineScheduler(optimizer, 10, 100)\n\n for step in [100, 200, 300]:\n mgr.save(model, optimizer, scheduler, step=step, metrics={})\n\n checkpoints = [d for d in os.listdir(tmpdir) if d.startswith(\"checkpoint-\")]\n assert len(checkpoints) == 2\n assert \"checkpoint-100\" not in checkpoints\n\n def test_load_empty_dir(self):\n with tempfile.TemporaryDirectory() as tmpdir:\n mgr = CheckpointManager(tmpdir)\n model = nn.Linear(10, 5)\n optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n scheduler = WarmupCosineScheduler(optimizer, 10, 100)\n step = mgr.load_latest(model, optimizer, scheduler)\n assert step == 0\n\n\nclass TestSamplePool:\n def test_add_and_get(self):\n pool = SamplePool(max_size=10)\n pool.add([{\"id\": i} for i in range(5)])\n assert pool.size == 5\n batch = pool.get_batch(3)\n assert len(batch) == 3\n assert pool.size == 2\n\n def test_overflow_truncation(self):\n pool = SamplePool(max_size=5)\n pool.add([{\"id\": i} for i in range(10)])\n assert pool.size == 5\n\n def test_empty_get(self):\n pool = SamplePool()\n batch = pool.get_batch(5)\n assert batch == []\n\n def test_thread_safety(self):\n pool = SamplePool(max_size=1000)\n errors = []\n\n def producer():\n try:\n for i in range(100):\n pool.add([{\"id\": f\"p_{i}\"}])\n time.sleep(0.001)\n except Exception as e:\n errors.append(e)\n\n def consumer():\n try:\n consumed = 0\n for _ in range(200):\n batch = pool.get_batch(1)\n consumed += len(batch)\n time.sleep(0.001)\n except Exception as e:\n errors.append(e)\n\n threads = [threading.Thread(target=producer) for _ in range(3)]\n threads 
+= [threading.Thread(target=consumer) for _ in range(2)]\n for t in threads:\n t.start()\n for t in threads:\n t.join()\n\n assert len(errors) == 0\n\n\nclass TestDataCollator:\n def test_basic_collation(self):\n collator = DataCollator(max_seq_length=16)\n samples = [\n {\"input_ids\": torch.tensor([1, 2, 3, 4, 5])},\n {\"input_ids\": torch.tensor([6, 7, 8])},\n ]\n batch = collator(samples)\n assert batch[\"input_ids\"].shape == (2, 5)\n assert batch[\"attention_mask\"].shape == (2, 5)\n assert batch[\"attention_mask\"][0].all()\n assert not batch[\"attention_mask\"][1][-2:].any()\n\n def test_truncation(self):\n collator = DataCollator(max_seq_length=4)\n samples = [{\"input_ids\": torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])}]\n batch = collator(samples)\n assert batch[\"input_ids\"].shape == (1, 4)\n\n def test_with_hidden_states(self):\n collator = DataCollator(max_seq_length=16)\n samples = [{\n \"input_ids\": torch.tensor([1, 2, 3]),\n \"hidden_states\": torch.randn(3, 64),\n \"last_hidden_states\": torch.randn(3, 64),\n }]\n batch = collator(samples)\n assert \"hidden_states\" in batch\n assert \"last_hidden_states\" in batch\n assert batch[\"hidden_states\"].shape[1] == 3\n\n def test_2d_input_ids(self):\n collator = DataCollator(max_seq_length=16)\n samples = [{\"input_ids\": torch.tensor([[1, 2, 3]])}]\n batch = collator(samples)\n assert batch[\"input_ids\"].shape == (1, 3)\n\n def test_loss_mask(self):\n collator = DataCollator(max_seq_length=16)\n samples = [\n {\"input_ids\": torch.tensor([1, 2, 3, 4]), \"loss_mask\": torch.tensor([0.0, 0.0, 1.0, 1.0])},\n ]\n batch = collator(samples)\n assert batch[\"loss_mask\"].shape == (1, 4)\n assert batch[\"loss_mask\"][0][0] == 0.0\n assert batch[\"loss_mask\"][0][2] == 1.0\n\n\nclass TestInferenceManager:\n def test_should_generate(self):\n config = InferenceConfig(buffer_threshold=10)\n mgr = InferenceManager(config)\n assert mgr.should_generate() is True\n mgr._pending_count = 15\n assert mgr.should_generate() is False\n\n def test_consume(self):\n config = InferenceConfig()\n mgr = InferenceManager(config)\n mgr._pending_count = 10\n mgr.consume(5)\n assert mgr._pending_count == 5\n mgr.consume(10)\n assert mgr._pending_count == 0\n```\n\nWhat coverage gaps and incorrect assertions do you see?"}, {"role": "assistant", "content": "Key coverage gaps and issues:\n\n1. **No tests for the `TrainingController` orchestration**: The most complex class with threading, the inference-training loop coordination, and error handling is completely untested.\n\n2. **`TestDataCollator.test_with_hidden_states` has a wrong assertion**: It checks `batch['hidden_states'].shape[1] == 3` but the collator pads to `max_len=3` (since there's only one sample), so this passes accidentally. It should also verify padding behavior when samples have different lengths.\n\n3. **No tests for the `Trainer.train_step` gradient accumulation bug**: The tests don't exercise the actual training loop, so the critical bug (same batch reused K times) goes undetected. Need an integration test that verifies gradients scale correctly with accumulation_steps.\n\n4. **`TestSamplePool.test_overflow_truncation` doesn't check which items survive**: It only asserts `size == 5` but doesn't verify the pool keeps the *last* 5 items (FIFO eviction). If the implementation changes to keep the first 5, the test still passes.\n\n5. 
**Missing edge cases**: No test for `CheckpointManager` with corrupted files, `DataCollator` with empty samples list, or `WarmupCosineScheduler` with `warmup_steps > total_steps`."}]} diff --git a/pyproject.toml b/pyproject.toml index 6a23e90..7a483d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ dev = [ ] vllm = [ - "vllm>=0.16.0", + "vllm>=0.18.0", ] fa = [ diff --git a/tests/test_vllm_engine.py b/tests/test_vllm_engine.py index 5f4ea85..1bebe73 100644 --- a/tests/test_vllm_engine.py +++ b/tests/test_vllm_engine.py @@ -18,15 +18,12 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -"""Tests for vLLM Worker Extension. +"""Tests for vLLM Engine and MooncakeHiddenStatesConnector. -This file contains both: -- Unit tests: Test logic with mocks (no GPU/vLLM/Mooncake needed) -- Integration tests: Test with real vLLM engine (requires GPU + infrastructure) +Unit tests: Test logic with mocks (no GPU/vLLM/Mooncake needed) """ -from dataclasses import dataclass -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest import torch @@ -36,31 +33,20 @@ # ============================================================================= -@dataclass -class MockArgs: - """Mock args for VllmWorkerExtension initialization.""" - - target_model_path: str = "Qwen/Qwen3-8B" - tensor_parallel_size: int = 2 - max_model_len: int = 2048 - trust_remote_code: bool = True - - -def _import_vllm_worker_extension(): - """Import VllmWorkerExtension, skipping test if dependencies unavailable.""" +def _import_connector_utils(): + """Import MooncakeHiddenStatesConnector utilities.""" try: - from torchspec.inference.engine.vllm_worker_extension import ( - VllmWorkerExtension, + from torchspec.inference.engine.mooncake_hidden_states_connector import ( _sanitize_mooncake_key, ) - return VllmWorkerExtension, _sanitize_mooncake_key + return _sanitize_mooncake_key except ImportError as e: - pytest.skip(f"VllmWorkerExtension import failed (missing deps): {e}") + pytest.skip(f"Connector import failed (missing deps): {e}") # ============================================================================= -# Unit Tests (No real vLLM/GPU/Mooncake needed) +# Unit Tests for _sanitize_mooncake_key # ============================================================================= @@ -68,235 +54,43 @@ class TestSanitizeMooncakeKey: """Unit tests for _sanitize_mooncake_key pure function.""" def test_alphanumeric_unchanged(self): - """Test alphanumeric keys pass through unchanged.""" - _, _sanitize = _import_vllm_worker_extension() + _sanitize = _import_connector_utils() assert _sanitize("req_abc_123") == "req_abc_123" def test_special_chars_replaced(self): - """Test special characters are replaced with underscores.""" - _, _sanitize = _import_vllm_worker_extension() + _sanitize = _import_connector_utils() assert _sanitize("req@abc#123") == "req_abc_123" assert _sanitize("req.id.name") == "req_id_name" assert _sanitize("req:name|value") == "req_name_value" def test_leading_digit_prefixed(self): - """Test leading digits get 'k' prefix.""" - _, _sanitize = _import_vllm_worker_extension() + _sanitize = _import_connector_utils() assert _sanitize("123_req") == "k123_req" assert _sanitize("1abc") == "k1abc" def test_empty_string(self): - """Test empty string handling.""" - _, _sanitize = _import_vllm_worker_extension() + _sanitize = _import_connector_utils() assert _sanitize("") == "" -class TestVllmWorkerExtensionState: - """Unit tests for 
VllmWorkerExtension state management.""" - - def test_init_stores_config(self): - """Test constructor initializes state correctly.""" - VllmWorkerExtension, _ = _import_vllm_worker_extension() - - ext = VllmWorkerExtension() - - assert ext._layer_ids == frozenset() - assert ext._captured_states is None - assert ext._request_metadata == [] - assert ext._current_request_metadata is None - assert ext._mooncake_store is None - assert ext._store_initialized is False - - def test_set_request_metadata(self): - """Test setting request metadata.""" - VllmWorkerExtension, _ = _import_vllm_worker_extension() - - ext = VllmWorkerExtension() - metadata = {"req_1": 100, "req_2": 200} - packed_map = {"req_1": "0,3", "req_2": "0,5"} - input_ids_map = {"req_1": [1, 2, 3], "req_2": [4, 5, 6]} - - ext._set_request_metadata(metadata, packed_map, input_ids_map) - - assert ext._current_request_metadata == metadata - assert ext._packed_loss_mask_map == packed_map - assert ext._input_ids_map == input_ids_map - - def test_reset_capture_clears_state(self): - """Test reset_capture clears all captured state.""" - VllmWorkerExtension, _ = _import_vllm_worker_extension() - - ext = VllmWorkerExtension() - ext._layer_ids = frozenset({5, 10, 15}) - ext._captured_states = [[torch.randn(10, 4096)], [torch.randn(10, 4096)]] - ext._captured_input_ids = torch.tensor([1, 2, 3]) - ext._request_metadata = [{"req_1": 10}] - ext._current_request_metadata = {"req_1": 10} - ext._packed_loss_mask_map = {"req_1": "0,3"} - ext._input_ids_map = {"req_1": [1, 2, 3]} - - ext._reset_capture() - - assert ext._captured_states is None - assert ext._captured_input_ids is None - assert ext._request_metadata == [] - assert ext._current_request_metadata is None - assert ext._packed_loss_mask_map == {} - assert ext._input_ids_map == {} - - def test_reset_capture_requires_prior_setup(self): - """Test reset_capture requires _setup_hidden_states_capture first.""" - VllmWorkerExtension, _ = _import_vllm_worker_extension() - - ext = VllmWorkerExtension() - # Don't set _layer_ids - - with pytest.raises(RuntimeError, match="Must call _setup_hidden_states_capture"): - ext._reset_capture() - - -class TestStoreCapturedStates: - """Unit tests for _store_captured_states with mocked dependencies.""" - - def test_store_first_capture(self): - """Test first capture initializes the state lists.""" - VllmWorkerExtension, _ = _import_vllm_worker_extension() - - ext = VllmWorkerExtension() - tensors = [torch.randn(10, 4096), torch.randn(10, 4096)] - - ext._store_captured_states(tensors) - - assert ext._captured_states is not None - assert len(ext._captured_states) == 2 - assert torch.equal(ext._captured_states[0][0], tensors[0]) - assert torch.equal(ext._captured_states[1][0], tensors[1]) - - def test_store_appends_to_existing(self): - """Test subsequent captures append to existing lists.""" - VllmWorkerExtension, _ = _import_vllm_worker_extension() - - ext = VllmWorkerExtension() - ext._captured_states = [[torch.randn(10, 4096)], [torch.randn(10, 4096)]] - - new_tensors = [torch.randn(10, 4096), torch.randn(10, 4096)] - ext._store_captured_states(new_tensors) - - assert len(ext._captured_states[0]) == 2 - assert len(ext._captured_states[1]) == 2 - assert torch.equal(ext._captured_states[0][1], new_tensors[0]) - - def test_store_extracts_metadata_from_input_batch(self): - """Test metadata extraction from model_runner.input_batch.""" - VllmWorkerExtension, _ = _import_vllm_worker_extension() - - ext = VllmWorkerExtension() - - # Mock model_runner with input_batch - 
mock_batch = MagicMock() - mock_batch.req_ids = ["req_1", "req_2"] - mock_batch.req_id_to_index = {"req_1": 0, "req_2": 1} - mock_batch.num_tokens = [100, 200] - mock_batch.num_computed_tokens = [0, 0] - - ext.model_runner = MagicMock() - ext.model_runner.input_batch = mock_batch - - tensors = [torch.randn(10, 4096)] - ext._store_captured_states(tensors) - - assert len(ext._request_metadata) == 1 - assert "req_1" in ext._request_metadata[0] - assert "req_2" in ext._request_metadata[0] - - -class TestCudaDeviceSafe: - """Unit tests for _get_cuda_device_safe with mocked torch.cuda.""" - - @patch("torch.cuda.is_initialized") - @patch("torch.cuda.current_device") - def test_initialized_context(self, mock_current, mock_initialized): - """Test when CUDA is already initialized.""" - VllmWorkerExtension, _ = _import_vllm_worker_extension() - - mock_initialized.return_value = True - mock_current.return_value = 1 - - ext = VllmWorkerExtension() - device = ext._get_cuda_device_safe() - - assert str(device) == "cuda:1" - - @patch("torch.cuda.is_initialized") - def test_uninitialized_context_fallback(self, mock_initialized): - """Test fallback when CUDA not initialized (V1 engine).""" - VllmWorkerExtension, _ = _import_vllm_worker_extension() - - mock_initialized.return_value = False - - ext = VllmWorkerExtension() - device = ext._get_cuda_device_safe() - - assert str(device) == "cuda:0" - - -class TestTokenSlicingLogic: - """Unit tests for token distribution and slicing logic.""" - - def test_ratio_based_distribution(self): - """Test ratio calculation for token distribution.""" - VllmWorkerExtension, _ = _import_vllm_worker_extension() - - ext = VllmWorkerExtension() - ext._current_request_metadata = {"req_1": 100, "req_2": 200} - - external_ids = list(ext._current_request_metadata.keys()) - token_counts = list(ext._current_request_metadata.values()) - total_expected = sum(token_counts) # 300 - total_captured = 150 # Half the expected tokens - - ratio = total_captured / total_expected # 0.5 - - # Calculate actual tokens per request - actual_tokens = {ext_id: int(tc * ratio) for ext_id, tc in zip(external_ids, token_counts)} - - assert actual_tokens == {"req_1": 50, "req_2": 100} - - def test_concatenated_tensors_shape(self): - """Test tensor concatenation from multiple iterations.""" - VllmWorkerExtension, _ = _import_vllm_worker_extension() - - ext = VllmWorkerExtension() - # Simulate 2 iterations with 5 tokens each - ext._captured_states = [ - [torch.randn(5, 4096), torch.randn(5, 4096)], # Layer 0 - [torch.randn(5, 4096), torch.randn(5, 4096)], # Layer 1 - ] - - # Concatenate (simulating _store_and_get_metadata logic) - concatenated = [torch.cat(layer_tensors, dim=0) for layer_tensors in ext._captured_states] - - assert concatenated[0].shape == (10, 4096) - assert concatenated[1].shape == (10, 4096) - - # ============================================================================= # VllmEngine.generate() metadata flow tests # ============================================================================= -def _make_mock_output(request_id: str, prompt_token_ids: list[int]): - """Create a mock vLLM RequestOutput.""" +def _make_mock_output(request_id: str, prompt_token_ids: list[int], kv_transfer_params=None): + """Create a mock vLLM RequestOutput with kv_transfer_params.""" out = MagicMock() out.request_id = request_id out.prompt_token_ids = prompt_token_ids + out.kv_transfer_params = kv_transfer_params return out -def _build_engine_with_mock_vllm(metadata_by_request: dict): +def 
_build_engine_with_mock_vllm(): """Build a VllmEngine whose _engine is a mock vLLM LLM. - Returns (engine, mock_llm) so tests can inspect collective_rpc calls. + Returns (engine, mock_llm) so tests can inspect generate calls. """ try: from torchspec.inference.engine.vllm_engine import VllmEngine @@ -314,100 +108,86 @@ def _build_engine_with_mock_vllm(metadata_by_request: dict): engine.aux_hidden_state_layer_ids = [2, 4] mock_llm = MagicMock() - - def _collective_rpc(method, args=(), kwargs=None): - if method == "_store_and_get_metadata": - return [metadata_by_request] - return [None] - - mock_llm.collective_rpc = MagicMock(side_effect=_collective_rpc) engine._engine = mock_llm return engine, mock_llm -class TestGenerateMetadataFlow: - """Test that generate() builds and sends request_metadata for both - the input_ids path and the formatted_prompts (defer_tokenization) path. - """ +class TestGenerateWithExtractHiddenStates: + """Test that generate() reads kv_transfer_params from outputs.""" - def test_input_ids_path_sends_metadata_twice(self): - """input_ids path: _set_request_metadata is called both pre- and - post-generation with correct token counts.""" + def test_input_ids_path_returns_mooncake_metadata(self): ids_a = torch.tensor([10, 20, 30]) ids_b = torch.tensor([40, 50, 60, 70]) data_ids = ["a", "b"] - worker_meta = { - "a": { - "mooncake_key": "a", - "tensor_shapes": {}, - "tensor_dtypes": {}, - "input_ids_list": ids_a.tolist(), - }, - "b": { - "mooncake_key": "b", - "tensor_shapes": {}, - "tensor_dtypes": {}, - "input_ids_list": ids_b.tolist(), - }, - } - engine, mock_llm = _build_engine_with_mock_vllm(worker_meta) + engine, mock_llm = _build_engine_with_mock_vllm() mock_llm.generate.return_value = [ - _make_mock_output("0", ids_a.tolist()), - _make_mock_output("1", ids_b.tolist()), - ] - - results = engine.generate( - data_id=data_ids, - input_ids_ref=[ids_a, ids_b], - ) - - set_meta_calls = [ - c for c in mock_llm.collective_rpc.call_args_list if c[0][0] == "_set_request_metadata" + _make_mock_output( + "0", + ids_a.tolist(), + kv_transfer_params={ + "mooncake_key": "a", + "tensor_shapes": {"hidden_states": (3, 8192)}, + "tensor_dtypes": {"hidden_states": "bfloat16"}, + "input_ids_list": ids_a.tolist(), + }, + ), + _make_mock_output( + "1", + ids_b.tolist(), + kv_transfer_params={ + "mooncake_key": "b", + "tensor_shapes": {"hidden_states": (4, 8192)}, + "tensor_dtypes": {"hidden_states": "bfloat16"}, + "input_ids_list": ids_b.tolist(), + }, + ), ] - assert len(set_meta_calls) == 2, ( - f"Expected 2 _set_request_metadata calls, got {len(set_meta_calls)}" - ) - # Post-gen call (last one) must carry authoritative token counts - post_gen_args = set_meta_calls[-1][1]["args"] - req_meta = post_gen_args[0] - assert req_meta == {"a": 3, "b": 4} - - input_ids_map = post_gen_args[2] - assert input_ids_map == {"a": ids_a.tolist(), "b": ids_b.tolist()} + results = engine.generate(data_id=data_ids, input_ids_ref=[ids_a, ids_b]) assert len(results) == 2 assert results[0]["data_id"] == "a" + assert results[0]["mooncake_key"] == "a" + assert results[0]["tensor_shapes"] == {"hidden_states": (3, 8192)} + assert results[0]["input_ids_list"] == ids_a.tolist() + assert results[1]["data_id"] == "b" + assert results[1]["mooncake_key"] == "b" + assert results[1]["seq_len"] == 4 + + # No collective_rpc calls should be made + mock_llm.collective_rpc.assert_not_called() - def test_formatted_prompts_path_sends_metadata_post_gen(self): - """formatted_prompts (defer_tokenization) path: _set_request_metadata - 
is sent after generation with token counts from vLLM outputs.""" + def test_formatted_prompts_path(self): prompt_tokens_a = [10, 20, 30, 40, 50] prompt_tokens_b = [60, 70, 80] data_ids = ["p0", "p1"] - worker_meta = { - "p0": { - "mooncake_key": "p0", - "tensor_shapes": {}, - "tensor_dtypes": {}, - "input_ids_list": prompt_tokens_a, - }, - "p1": { - "mooncake_key": "p1", - "tensor_shapes": {}, - "tensor_dtypes": {}, - "input_ids_list": prompt_tokens_b, - }, - } - engine, mock_llm = _build_engine_with_mock_vllm(worker_meta) + engine, mock_llm = _build_engine_with_mock_vllm() mock_llm.generate.return_value = [ - _make_mock_output("0", prompt_tokens_a), - _make_mock_output("1", prompt_tokens_b), + _make_mock_output( + "0", + prompt_tokens_a, + kv_transfer_params={ + "mooncake_key": "p0", + "tensor_shapes": {}, + "tensor_dtypes": {}, + "input_ids_list": prompt_tokens_a, + }, + ), + _make_mock_output( + "1", + prompt_tokens_b, + kv_transfer_params={ + "mooncake_key": "p1", + "tensor_shapes": {}, + "tensor_dtypes": {}, + "input_ids_list": prompt_tokens_b, + }, + ), ] results = engine.generate( @@ -415,54 +195,422 @@ def test_formatted_prompts_path_sends_metadata_post_gen(self): formatted_prompts=["Hello world", "Goodbye"], ) - set_meta_calls = [ - c for c in mock_llm.collective_rpc.call_args_list if c[0][0] == "_set_request_metadata" - ] - # Only the post-gen call (pre-gen is skipped because request_metadata - # is empty before generation). - assert len(set_meta_calls) == 1 - - post_gen_args = set_meta_calls[0][1]["args"] - req_meta = post_gen_args[0] - assert req_meta == {"p0": 5, "p1": 3} - - input_ids_map = post_gen_args[2] - assert input_ids_map == {"p0": prompt_tokens_a, "p1": prompt_tokens_b} - assert len(results) == 2 assert results[0]["input_ids_list"] == prompt_tokens_a assert results[1]["input_ids_list"] == prompt_tokens_b - def test_formatted_prompts_with_no_packed_loss_mask(self): - """defer_tokenization path with packed_loss_mask_list=None works.""" - tokens = [1, 2, 3] - worker_meta = { - "d0": { - "mooncake_key": "d0", - "tensor_shapes": {}, - "tensor_dtypes": {}, - "input_ids_list": tokens, - }, - } - engine, mock_llm = _build_engine_with_mock_vllm(worker_meta) - mock_llm.generate.return_value = [_make_mock_output("0", tokens)] + def test_missing_kv_transfer_params_skips_result(self): + engine, mock_llm = _build_engine_with_mock_vllm() + + mock_llm.generate.return_value = [ + _make_mock_output("0", [1, 2, 3], kv_transfer_params=None), + ] results = engine.generate( data_id=["d0"], formatted_prompts=["test"], - packed_loss_mask_list=None, ) - set_meta_calls = [ - c for c in mock_llm.collective_rpc.call_args_list if c[0][0] == "_set_request_metadata" + assert len(results) == 0 + + def test_packed_loss_mask_passed_through(self): + engine, mock_llm = _build_engine_with_mock_vllm() + + mock_llm.generate.return_value = [ + _make_mock_output( + "0", + [1, 2, 3], + kv_transfer_params={ + "mooncake_key": "d0", + "tensor_shapes": {}, + "tensor_dtypes": {}, + }, + ), ] - assert len(set_meta_calls) == 1 - packed_map = set_meta_calls[0][1]["args"][1] - assert packed_map == {} + results = engine.generate( + data_id=["d0"], + formatted_prompts=["test"], + packed_loss_mask_list=["0,3"], + ) assert len(results) == 1 - assert "packed_loss_mask" not in results[0] + assert results[0]["packed_loss_mask"] == "0,3" + + def test_fallback_input_ids_from_prompt_token_ids(self): + """When kv_transfer_params has no input_ids_list, fall back to prompt_token_ids.""" + engine, mock_llm = 
_build_engine_with_mock_vllm() + + mock_llm.generate.return_value = [ + _make_mock_output( + "0", + [10, 20, 30], + kv_transfer_params={ + "mooncake_key": "d0", + "tensor_shapes": {}, + "tensor_dtypes": {}, + }, + ), + ] + + results = engine.generate( + data_id=["d0"], + formatted_prompts=["test"], + ) + + assert results[0]["input_ids_list"] == [10, 20, 30] + + +# ============================================================================= +# Metadata contract: connector output matches training pipeline expectations +# ============================================================================= + + +class TestMetadataContract: + """Verify the result dict from generate() matches what the training pipeline expects.""" + + def _generate_with_full_metadata(self): + """Helper: run generate() with connector-style kv_transfer_params + that mirror what MooncakeHiddenStatesConnector.request_finished returns. + """ + engine, mock_llm = _build_engine_with_mock_vllm() + seq_len = 10 + hidden_size = engine._hidden_size # 4096 + num_training_layers = len(engine.aux_hidden_state_layer_ids) - 1 # 1 (2 aux - 1) + training_hidden_size = num_training_layers * hidden_size + + mock_llm.generate.return_value = [ + _make_mock_output( + "req-0", + list(range(100, 100 + seq_len)), + kv_transfer_params={ + "mooncake_key": "req-0", + "tensor_shapes": { + "hidden_states": (seq_len, training_hidden_size), + "input_ids": (seq_len,), + "last_hidden_states": (seq_len, hidden_size), + }, + "tensor_dtypes": { + "hidden_states": "bfloat16", + "input_ids": "int64", + "last_hidden_states": "bfloat16", + }, + "num_layers": len(engine.aux_hidden_state_layer_ids), + "input_ids_list": list(range(100, 100 + seq_len)), + }, + ), + ] + + results = engine.generate( + data_id=["d0"], + formatted_prompts=["hello world"], + packed_loss_mask_list=["3,5,2"], + ) + return results[0], seq_len, hidden_size, num_training_layers + + def test_result_has_all_required_keys(self): + """InferenceManager._parse_engine_output requires these keys.""" + result, *_ = self._generate_with_full_metadata() + assert "mooncake_key" in result + assert "tensor_shapes" in result + assert "tensor_dtypes" in result + assert "data_id" in result + assert "seq_len" in result + assert "input_ids_list" in result + + def test_tensor_shapes_has_all_three_tensors(self): + """TrainSample needs hidden_states, input_ids, and last_hidden_states.""" + result, seq_len, hidden_size, num_training_layers = self._generate_with_full_metadata() + shapes = result["tensor_shapes"] + + assert "hidden_states" in shapes + assert "input_ids" in shapes + assert "last_hidden_states" in shapes + + assert shapes["hidden_states"] == (seq_len, num_training_layers * hidden_size) + assert shapes["input_ids"] == (seq_len,) + assert shapes["last_hidden_states"] == (seq_len, hidden_size) + + def test_tensor_dtypes_are_strings(self): + """Connector returns string dtypes for Mooncake store deserialization.""" + result, *_ = self._generate_with_full_metadata() + dtypes = result["tensor_dtypes"] + + for key, dtype_val in dtypes.items(): + assert isinstance(dtype_val, str), ( + f"dtype for '{key}' should be str, got {type(dtype_val)}" + ) + assert dtypes["hidden_states"] == "bfloat16" + assert dtypes["input_ids"] == "int64" + assert dtypes["last_hidden_states"] == "bfloat16" + + def test_packed_loss_mask_propagated(self): + result, *_ = self._generate_with_full_metadata() + assert result["packed_loss_mask"] == "3,5,2" + + def test_input_ids_list_is_real_tokens(self): + result, seq_len, *_ = 
self._generate_with_full_metadata()
+        assert result["input_ids_list"] == list(range(100, 100 + seq_len))
+
+    def test_hidden_states_excludes_last_layer(self):
+        """hidden_states should be (N-1) aux layers for the draft model,
+        NOT all N layers. The Nth layer is in last_hidden_states."""
+        result, seq_len, hidden_size, num_training_layers = self._generate_with_full_metadata()
+        hs_shape = result["tensor_shapes"]["hidden_states"]
+        lhs_shape = result["tensor_shapes"]["last_hidden_states"]
+
+        assert hs_shape[1] == num_training_layers * hidden_size
+        assert lhs_shape[1] == hidden_size
+        # Together the two tensors must cover exactly all N captured aux layers.
+        assert hs_shape[1] + lhs_shape[1] == (num_training_layers + 1) * hidden_size
+
+
+# =============================================================================
+# Chunked-prefill: connector writes tensors exactly once per request
+# =============================================================================
+
+
+def _import_connector_internals():
+    """Import connector helpers for chunked-prefill tests."""
+    try:
+        from torchspec.inference.engine.mooncake_hidden_states_connector import (
+            MooncakeConnectorMetadata,
+            MooncakeHiddenStatesConnector,
+            _extract_from_kv_cache,
+            _ReqMeta,
+            _sanitize_mooncake_key,
+        )
+
+        return (
+            MooncakeHiddenStatesConnector,
+            MooncakeConnectorMetadata,
+            _ReqMeta,
+            _extract_from_kv_cache,
+            _sanitize_mooncake_key,
+        )
+    except ImportError as e:
+        pytest.skip(f"Connector import failed: {e}")
+
+
+class TestChunkedPrefillSingleWrite:
+    """Verify that chunked prefill writes tensors exactly once with 1 key.
+
+    Scenario: 10 tokens, block_size=4 → needs 3 blocks (12 slots).
+    Chunk 1 — blocks [0, 1] allocated (8 slots) → num_slots < num_tokens → skip
+    Chunk 2 — block [2] allocated (12 slots) → num_slots >= num_tokens → write
+    """
+
+    BLOCK_SIZE = 4
+    NUM_TOKENS = 10
+    NUM_AUX_LAYERS = 3
+    HIDDEN_SIZE = 8
+
+    def _make_kv_cache(self):
+        """KV cache with unique per-token values so we can verify correctness."""
+        num_pages = 3
+        kv = torch.zeros(num_pages, self.BLOCK_SIZE, self.NUM_AUX_LAYERS, self.HIDDEN_SIZE)
+        for t in range(self.NUM_TOKENS):
+            page, offset = divmod(t, self.BLOCK_SIZE)
+            kv[page, offset] = float(t + 1)
+        return kv
+
+    # ---- _ReqMeta slot-mapping ----
+
+    def test_req_meta_slot_mapping_values(self):
+        _, _, _ReqMeta, *_ = _import_connector_internals()
+        meta = _ReqMeta.make("r0", list(range(6)), [0, 2], block_size=4, new_req=True)
+        expected = torch.tensor([0, 1, 2, 3, 8, 9, 10, 11])
+        assert torch.equal(meta.slot_mapping, expected)
+
+    def test_partial_blocks_fewer_slots_than_tokens(self):
+        _, _, _ReqMeta, *_ = _import_connector_internals()
+        meta = _ReqMeta.make(
+            "r0",
+            list(range(self.NUM_TOKENS)),
+            [0, 1],
+            block_size=self.BLOCK_SIZE,
+            new_req=True,
+        )
+        assert meta.slot_mapping.shape[0] < meta.token_ids.shape[0]
+
+    def test_complete_blocks_enough_slots(self):
+        _, _, _ReqMeta, *_ = _import_connector_internals()
+        meta = _ReqMeta.make(
+            "r0",
+            list(range(self.NUM_TOKENS)),
+            [0, 1, 2],
+            block_size=self.BLOCK_SIZE,
+            new_req=True,
+        )
+        assert meta.slot_mapping.shape[0] >= meta.token_ids.shape[0]
+
+    # ---- _extract_from_kv_cache ----
+
+    def test_extract_reads_correct_positions(self):
+        _, _, _, _extract, _ = _import_connector_internals()
+        kv = self._make_kv_cache()
+        slot_mapping = torch.arange(self.NUM_TOKENS)
+        result = _extract(kv, slot_mapping, self.NUM_TOKENS)
+
+        assert result.shape == (self.NUM_TOKENS, self.NUM_AUX_LAYERS, self.HIDDEN_SIZE)
+        for t in range(self.NUM_TOKENS):
+            assert torch.allclose(result[t],
torch.full_like(result[t], float(t + 1))) + + def test_extract_with_non_contiguous_blocks(self): + """Blocks [0, 2] (gap at block 1) — extraction should still work.""" + _, _, _, _extract, _ = _import_connector_internals() + num_pages = 4 + kv = torch.zeros(num_pages, self.BLOCK_SIZE, self.NUM_AUX_LAYERS, self.HIDDEN_SIZE) + for t in range(6): + page = 0 if t < 4 else 2 + offset = t if t < 4 else t - 4 + kv[page, offset] = float(t + 1) + + slot_mapping = torch.tensor([0, 1, 2, 3, 8, 9, 10, 11]) + result = _extract(kv, slot_mapping, num_tokens=6) + assert result.shape[0] == 6 + for t in range(6): + assert torch.allclose(result[t], torch.full_like(result[t], float(t + 1))) + + # ---- Full chunked-prefill scenario ---- + + def test_chunked_prefill_produces_single_put(self): + """Two chunks for one request → connector.put() called exactly once.""" + ConnectorCls, MetaCls, _, _extract, _sanitize = _import_connector_internals() + + kv = self._make_kv_cache() + token_ids = list(range(100, 100 + self.NUM_TOKENS)) + num_training_layers = self.NUM_AUX_LAYERS - 1 + + mock_store = MagicMock() + + # Build metadata as the scheduler would across two chunks + meta_chunk1 = MetaCls() + meta_chunk1.add_request( + "req_0", + token_ids, + block_ids=[0, 1], + block_size=self.BLOCK_SIZE, + ) + meta_chunk2 = MetaCls() + meta_chunk2.add_request( + "req_0", + token_ids, + block_ids=[0, 1, 2], + block_size=self.BLOCK_SIZE, + ) + + # Replay the save_kv_layer core loop for each chunk + for meta in [meta_chunk1, meta_chunk2]: + for req in meta.requests: + num_tok = req.token_ids.shape[0] + num_slots = req.slot_mapping.shape[0] + if num_slots < num_tok: + continue + + hs_3d = _extract(kv, req.slot_mapping, num_tok) + all_hidden = hs_3d.reshape(num_tok, -1) + + split_at = num_training_layers * self.HIDDEN_SIZE + hidden_states = all_hidden[:, :split_at] + last_hidden_states = all_hidden[:, -self.HIDDEN_SIZE :] + + mock_store.put( + key=_sanitize(req.req_id), + hidden_states=hidden_states, + input_ids=req.token_ids.to(hidden_states.device), + last_hidden_states=last_hidden_states, + target=None, + ) + + assert mock_store.put.call_count == 1 + + call_kw = mock_store.put.call_args[1] + assert call_kw["key"] == "req_0" + assert call_kw["hidden_states"].shape == ( + self.NUM_TOKENS, + num_training_layers * self.HIDDEN_SIZE, + ) + assert call_kw["last_hidden_states"].shape == ( + self.NUM_TOKENS, + self.HIDDEN_SIZE, + ) + assert call_kw["input_ids"].shape == (self.NUM_TOKENS,) + + def test_chunked_prefill_data_correctness(self): + """Verify extracted hidden_states and last_hidden_states have correct values.""" + _, MetaCls, _, _extract, _ = _import_connector_internals() + + kv = self._make_kv_cache() + token_ids = list(range(100, 100 + self.NUM_TOKENS)) + num_training_layers = self.NUM_AUX_LAYERS - 1 + + meta = MetaCls() + meta.add_request( + "req_0", + token_ids, + block_ids=[0, 1, 2], + block_size=self.BLOCK_SIZE, + ) + req = meta.requests[0] + + hs_3d = _extract(kv, req.slot_mapping, self.NUM_TOKENS) + all_hidden = hs_3d.reshape(self.NUM_TOKENS, -1) + + split_at = num_training_layers * self.HIDDEN_SIZE + hidden_states = all_hidden[:, :split_at] + last_hidden_states = all_hidden[:, -self.HIDDEN_SIZE :] + + for t in range(self.NUM_TOKENS): + val = float(t + 1) + assert torch.all(hidden_states[t] == val), f"token {t}: hidden_states mismatch" + assert torch.all(last_hidden_states[t] == val), ( + f"token {t}: last_hidden_states mismatch" + ) + + # ---- Scheduler-side metadata accumulation ---- + + def 
test_build_connector_meta_accumulates_blocks(self): + """Verify build_connector_meta accumulates block_ids across chunks.""" + ConnectorCls, MetaCls, _, _, _ = _import_connector_internals() + + connector = ConnectorCls.__new__(ConnectorCls) + connector._block_size = self.BLOCK_SIZE + connector._active_requests = {} + connector._req_blocks = {} + connector._req_metadata = {} + connector._hidden_size = self.HIDDEN_SIZE + connector.num_hidden_states = self.NUM_AUX_LAYERS + connector._num_training_layers = self.NUM_AUX_LAYERS - 1 + + token_ids = list(range(100, 100 + self.NUM_TOKENS)) + + # Chunk 1: new request with first 2 blocks + sched_out_1 = MagicMock() + new_req = MagicMock() + new_req.req_id = "req_0" + new_req.prompt_token_ids = token_ids + new_req.block_ids = [[0, 1]] + sched_out_1.scheduled_new_reqs = [new_req] + cached_reqs_1 = MagicMock() + cached_reqs_1.req_ids = [] + sched_out_1.scheduled_cached_reqs = cached_reqs_1 + + meta1 = connector.build_connector_meta(sched_out_1) + assert len(meta1.requests) == 1 + assert meta1.requests[0].slot_mapping.shape[0] == 2 * self.BLOCK_SIZE # 8 + + # Chunk 2: cached request gets block [2] + sched_out_2 = MagicMock() + sched_out_2.scheduled_new_reqs = [] + cached_reqs_2 = MagicMock() + cached_reqs_2.req_ids = ["req_0"] + cached_reqs_2.new_block_ids = [[[2]]] + sched_out_2.scheduled_cached_reqs = cached_reqs_2 + + meta2 = connector.build_connector_meta(sched_out_2) + assert len(meta2.requests) == 1 + assert meta2.requests[0].slot_mapping.shape[0] == 3 * self.BLOCK_SIZE # 12 + assert meta2.requests[0].slot_mapping.shape[0] >= self.NUM_TOKENS if __name__ == "__main__": diff --git a/torchspec/config/inference_config.py b/torchspec/config/inference_config.py index 9375dd8..016a988 100644 --- a/torchspec/config/inference_config.py +++ b/torchspec/config/inference_config.py @@ -75,7 +75,9 @@ class VllmConfig: Any additional vLLM engine kwargs can be supplied via ``extra_args`` and will be forwarded as-is. - Uses vLLM's extract_hidden_states speculative config for hidden states retrieval. + Uses vLLM's ``extract_hidden_states`` speculative method with a + ``MooncakeHiddenStatesConnector`` KV Connector for hidden states + retrieval (requires vLLM >= 0.18.0). """ # Parallelism @@ -97,13 +99,6 @@ class VllmConfig: dist_timeout: int = 60 init_timeout: int = 300 - # Hidden states extraction - num_speculative_tokens: int = 1 - - # Use worker extension for hidden states capture (new implementation) - # If False, falls back to LLM class with speculative_config - use_worker_extension: bool = True - # Passthrough: forwarded as-is to vLLM LLM. # Use this for any vLLM kwarg that TorchSpec doesn't need to # inspect (e.g. quantization, max_model_len, trust_remote_code, ...). @@ -120,10 +115,21 @@ class InferenceConfig: inference_num_gpus: Optional[int] = None inference_num_gpus_per_engine: int = 1 inference_num_gpus_per_node: int = 8 + last_hidden_states_prenorm: Optional[bool] = None max_sample_pool_size: int = 0 sglang: SGLangConfig = field(default_factory=SGLangConfig) vllm: VllmConfig = field(default_factory=VllmConfig) + def resolve_last_hidden_states_prenorm(self) -> bool: + """Whether last_hidden_states from the engine are pre-norm. + + vLLM's extract_hidden_states connector can only capture raw layer + outputs (pre-norm), while sglang and hf provide post-norm outputs. 
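+
+        An explicit ``inference.last_hidden_states_prenorm`` value always
+        takes precedence; otherwise this defaults to True only when the vLLM
+        engine is selected, so the training side can decide whether the
+        target model's final norm still needs to be applied.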
+ """ + if self.last_hidden_states_prenorm is not None: + return self.last_hidden_states_prenorm + return self.inference_engine_type == "vllm" + @dataclass class HFInferenceConfig: diff --git a/torchspec/config/train_config.py b/torchspec/config/train_config.py index 501b9af..5ecd090 100644 --- a/torchspec/config/train_config.py +++ b/torchspec/config/train_config.py @@ -85,6 +85,7 @@ class ModelConfig: draft_model_config: Optional[str] = None embedding_key: str = "model.embed_tokens.weight" lm_head_key: str = "lm_head.weight" + norm_key: str = "model.norm.weight" target_model_backend: str = "sglang" target_model_path: str = "" trust_remote_code: bool = False @@ -313,6 +314,9 @@ def _add(key: str, val: Any, origin: str) -> None: if flat.get("continual_training") and not flat.get("load_path"): logger.warning("continual_training=True but no training.load_path was provided") + if "last_hidden_states_prenorm" not in flat or flat["last_hidden_states_prenorm"] is None: + flat["last_hidden_states_prenorm"] = flat.get("inference_engine_type") == "vllm" + return argparse.Namespace(**flat) diff --git a/torchspec/controller/eval.py b/torchspec/controller/eval.py index bbb790e..24b90eb 100644 --- a/torchspec/controller/eval.py +++ b/torchspec/controller/eval.py @@ -194,7 +194,11 @@ def setup_eval(controller, train_group, args, eval_dataset_size: int) -> EvalSet f"Eval: loaded cached tensors from {eval_cache_path} ({loaded[0]} batches per rank)" ) else: - initial_eval_submit_count = min(eval_dataset_size, eval_dispatch_bs * 2) + inference_batch_size = args.inference_batch_size + initial_eval_submit_count = min( + eval_dataset_size, + max(eval_dispatch_bs * 2, inference_batch_size), + ) ray.get(controller.submit_eval_chunk.remote(0, initial_eval_submit_count)) logger.info( f"Eval: {eval_dataset_size} samples, dispatch_bs={eval_dispatch_bs}, " diff --git a/torchspec/controller/inference_manager.py b/torchspec/controller/inference_manager.py index 5178e7b..3a63584 100644 --- a/torchspec/controller/inference_manager.py +++ b/torchspec/controller/inference_manager.py @@ -453,6 +453,13 @@ async def _dispatch_batch( if self._enable_perf_metrics: now = time.time() self._batch_times.append((len(entries), now - t0, now)) + if len(outputs) != len(entries): + logger.error( + f"Engine returned {len(outputs)} results for " + f"{len(entries)} entries (expected equal)" + ) + err = ValueError(f"output count mismatch: {len(outputs)} vs {len(entries)}") + return [(entry, err) for entry in entries] return list(zip(entries, outputs, strict=True)) except RayActorError as e: logger.critical(f"Engine actor died, terminating inference manager: {e}") diff --git a/torchspec/controller/loop.py b/torchspec/controller/loop.py index 71913ec..17ae502 100644 --- a/torchspec/controller/loop.py +++ b/torchspec/controller/loop.py @@ -75,7 +75,9 @@ def _is_save_interval_step(step: int, interval: int) -> bool: return interval > 0 and step % interval == 0 -def _safe_training_cleanup(args, inference_manager, inference_future) -> None: +def _safe_training_cleanup( + args, inference_manager, inference_future, inference_engines=None +) -> None: """Best-effort teardown for inference manager and mooncake master actor.""" if inference_manager is not None: try: @@ -90,6 +92,20 @@ def _safe_training_cleanup(args, inference_manager, inference_future) -> None: f"Inference manager run loop exited with error during cleanup: {exc}" ) + if inference_engines: + logger.info(f"Shutting down {len(inference_engines)} inference engine(s)...") + shutdown_refs = 
[] + for engine in inference_engines: + try: + shutdown_refs.append(engine.shutdown.remote()) + except Exception as exc: + logger.warning(f"Failed to initiate engine shutdown: {exc}") + for ref in shutdown_refs: + try: + ray.get(ref, timeout=30) + except Exception as exc: + logger.warning(f"Engine shutdown timed out or failed: {exc}") + mooncake_master_actor = getattr(args, "_mooncake_master_actor", None) if mooncake_master_actor is not None: try: @@ -401,4 +417,5 @@ def run_training_loop( args=args, inference_manager=inference_manager, inference_future=inference_future, + inference_engines=inference_engines, ) diff --git a/torchspec/inference/engine/mooncake_hidden_states_connector.py b/torchspec/inference/engine/mooncake_hidden_states_connector.py new file mode 100644 index 0000000..d602245 --- /dev/null +++ b/torchspec/inference/engine/mooncake_hidden_states_connector.py @@ -0,0 +1,387 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""KV Connector that writes hidden states directly to Mooncake. + +vLLM discovers this connector via ``kv_connector_module_path`` in the +``kv_transfer_config`` dict -- no registration in vLLM's factory needed. + +Architecture note: vLLM creates separate connector instances for the scheduler +process and each worker process. Scheduler-side methods (``build_connector_meta``, +``request_finished``) run on one instance; worker-side methods (``save_kv_layer``, +``wait_for_save``) run on another. They do NOT share state. Metadata returned +by ``request_finished`` must therefore be pre-computed on the scheduler side. 
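+
+The engine selects this connector with a ``kv_transfer_config`` like the one
+built in ``VllmEngine._init_engine``::
+
+    {
+        "kv_connector": "MooncakeHiddenStatesConnector",
+        "kv_connector_module_path": (
+            "torchspec.inference.engine.mooncake_hidden_states_connector"
+        ),
+        "kv_role": "kv_producer",
+    }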
+""" + +from __future__ import annotations + +import logging +import os +import re +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Optional + +import torch +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.v1.attention.backend import AttentionMetadata +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.request import Request + +logger = logging.getLogger(__name__) + +HIDDEN_STATES_DTYPE_STR = "bfloat16" + + +def _sanitize_mooncake_key(key: str) -> str: + sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", key) + if sanitized and sanitized[0].isdigit(): + sanitized = "k" + sanitized + return sanitized + + +def _extract_from_kv_cache( + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + num_tokens: int, +) -> torch.Tensor: + """Extract data from KV cache. + + Assumes kv_cache shape: (num_pages, page_size, num_heads, head_size) + """ + padded_kv = kv_cache.flatten(0, 1)[slot_mapping] + return padded_kv[:num_tokens] + + +@dataclass +class _ReqMeta: + req_id: str + token_ids: torch.Tensor + slot_mapping: torch.Tensor + new_req: bool + + @staticmethod + def make( + req_id: str, + token_ids: list[int], + block_ids: list[int], + block_size: int, + new_req: bool, + ) -> _ReqMeta: + token_ids_tensor = torch.tensor(token_ids) + block_ids_tensor = torch.tensor(block_ids) + num_blocks = block_ids_tensor.shape[0] + block_offsets = torch.arange(0, block_size) + slot_mapping = ( + block_offsets.reshape((1, block_size)) + + block_ids_tensor.reshape((num_blocks, 1)) * block_size + ) + slot_mapping = slot_mapping.flatten() + return _ReqMeta( + req_id=req_id, + token_ids=token_ids_tensor, + slot_mapping=slot_mapping, + new_req=new_req, + ) + + +@dataclass +class MooncakeConnectorMetadata(KVConnectorMetadata): + requests: list[_ReqMeta] = field(default_factory=list) + + def add_request( + self, + req_id: str, + token_ids: list[int], + block_ids: list[int], + block_size: int, + new_req: bool = True, + ) -> None: + self.requests.append(_ReqMeta.make(req_id, token_ids, block_ids, block_size, new_req)) + + +class MooncakeHiddenStatesConnector(KVConnectorBase_V1): + """KV Connector that stores extracted hidden states directly to Mooncake. + + Must be used with vLLM's ``extract_hidden_states`` speculative method. + Mooncake connection parameters are read from environment variables + (exported by TorchSpec's VllmEngine before creating the LLM instance). 
+ """ + + @property + def prefer_cross_layer_blocks(self) -> bool: + return False + + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__( + vllm_config=vllm_config, + role=role, + kv_cache_config=kv_cache_config, + ) + self._block_size = vllm_config.cache_config.block_size + self.cache_layers: list[str] = [] + + assert self._vllm_config.speculative_config is not None, ( + "MooncakeHiddenStatesConnector requires 'extract_hidden_states' speculative method" + ) + spec_config = self._vllm_config.speculative_config.draft_model_config.hf_config + self._layer_ids = list(getattr(spec_config, "eagle_aux_hidden_state_layer_ids", [])) + self.num_hidden_states = len(self._layer_ids) + self._hidden_size = vllm_config.model_config.get_hidden_size() + + # The last aux layer is the model's final layer (appended by + # VllmEngine for last_hidden_states capture). Training hidden + # states use the remaining layers. + self._num_training_layers = max(self.num_hidden_states - 1, 1) + + # Scheduler-side state: track requests and pre-computed metadata + self._active_requests: dict[str, Any] = {} + self._req_blocks: dict[str, list[int]] = {} + self._req_metadata: dict[str, dict[str, Any]] = {} + + # Worker-side state: Mooncake store (lazy init) + self._mooncake_store = None + self._mooncake_setup_done = False + + def _ensure_mooncake_store(self) -> bool: + if self._mooncake_setup_done: + return self._mooncake_store is not None + + if not os.environ.get("MOONCAKE_MASTER_SERVER") and not os.environ.get( + "MOONCAKE_MASTER_HOST" + ): + logger.warning( + "MooncakeHiddenStatesConnector: no MOONCAKE_MASTER_SERVER env var; " + "hidden states will NOT be stored." + ) + self._mooncake_setup_done = True + return False + + try: + from torchspec.config.mooncake_config import MooncakeConfig + from torchspec.transfer.mooncake.eagle_store import EagleMooncakeStore + + config = MooncakeConfig.from_env() + self._mooncake_store = EagleMooncakeStore(config) + + device: torch.device | None = None + if torch.cuda.is_initialized(): + device = torch.device(f"cuda:{torch.cuda.current_device()}") + self._mooncake_store.setup(device=device) + self._mooncake_setup_done = True + logger.info( + "MooncakeHiddenStatesConnector: store initialized " + f"(master={config.master_server_address})" + ) + return True + except Exception: + logger.exception("MooncakeHiddenStatesConnector: failed to init store") + self._mooncake_setup_done = True + return False + + # ============================== + # Worker-side methods + # ============================== + def start_load_kv(self, *args, **kwargs: Any) -> None: + pass + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + def wait_for_save(self): + if self._mooncake_store is not None: + self._mooncake_store.flush() + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + from vllm.model_executor.models.extract_hidden_states import ( + CacheOnlyAttentionLayer, + ) + + layers = get_layers_from_vllm_config( + self._vllm_config, CacheOnlyAttentionLayer, list(kv_caches.keys()) + ) + self.cache_layers = list(layers.keys()) + assert len(self.cache_layers) == 1, ( + f"Expected 1 CacheOnlyAttentionLayer, got {len(self.cache_layers)}" + ) + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: AttentionMetadata, + **kwargs: Any, + ) -> None: + if layer_name not in self.cache_layers: + return + + from vllm.model_executor.models.extract_hidden_states 
import ( + CacheOnlyAttentionMetadata, + ) + + assert isinstance(attn_metadata, CacheOnlyAttentionMetadata) + + connector_metadata = self._get_connector_metadata() + assert isinstance(connector_metadata, MooncakeConnectorMetadata) + + if not self._ensure_mooncake_store(): + logger.warning("save_kv_layer: Mooncake store not available, skipping") + return + + for request in connector_metadata.requests: + num_tokens = request.token_ids.shape[0] + num_slots = request.slot_mapping.shape[0] + + # With chunked prefill, save_kv_layer is called per chunk. + # Mooncake keys are write-once (can't overwrite), so we skip + # partial chunks and only write when all blocks are allocated. + if num_slots < num_tokens: + continue + + hidden_states_3d = _extract_from_kv_cache(kv_layer, request.slot_mapping, num_tokens) + + all_hidden = hidden_states_3d.reshape(num_tokens, -1) + + # Split: first N-1 aux layers → draft model input, + # last aux layer (final model layer) → target logit computation + split_at = self._num_training_layers * self._hidden_size + hidden_states = all_hidden[:, :split_at] + last_hidden_states = all_hidden[:, -self._hidden_size :] + + input_ids = request.token_ids.to(hidden_states.device) + + mooncake_key = _sanitize_mooncake_key(request.req_id) + + try: + self._mooncake_store.put( + key=mooncake_key, + hidden_states=hidden_states, + input_ids=input_ids, + last_hidden_states=last_hidden_states, + target=None, + ) + except Exception: + logger.exception(f"save_kv_layer: failed to store to Mooncake for {request.req_id}") + + # ============================== + # Scheduler-side methods + # ============================== + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int | None, bool]: + return 0, False + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + assert num_external_tokens == 0, "This connector is store-only" + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = MooncakeConnectorMetadata() + for new_req in scheduler_output.scheduled_new_reqs: + token_ids = new_req.prompt_token_ids or [] + meta.add_request( + new_req.req_id, + token_ids=token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size, + ) + self._active_requests[new_req.req_id] = new_req + self._req_blocks[new_req.req_id] = list(new_req.block_ids[0]) + + # Pre-compute metadata that request_finished will return. + # The mooncake key and shapes are deterministic from the request. 
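+            # For example, with three aux layer ids (the last being the
+            # appended final layer) and hidden_size 4096, a 10-token request
+            # is advertised as hidden_states (10, 2 * 4096), input_ids (10,),
+            # and last_hidden_states (10, 4096).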
+ seq_len = len(token_ids) + training_hidden_size = self._num_training_layers * self._hidden_size + mooncake_key = _sanitize_mooncake_key(new_req.req_id) + self._req_metadata[new_req.req_id] = { + "mooncake_key": mooncake_key, + "tensor_shapes": { + "hidden_states": (seq_len, training_hidden_size), + "input_ids": (seq_len,), + "last_hidden_states": (seq_len, self._hidden_size), + }, + "tensor_dtypes": { + "hidden_states": HIDDEN_STATES_DTYPE_STR, + "input_ids": "int64", + "last_hidden_states": HIDDEN_STATES_DTYPE_STR, + }, + "num_layers": self.num_hidden_states, + "input_ids_list": token_ids, + } + + cached_reqs = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(cached_reqs.req_ids): + if req_id not in self._active_requests: + continue + + new_block_ids = cached_reqs.new_block_ids[i] + if new_block_ids is None: + continue + + cached_req = self._active_requests[req_id] + req_block_ids = self._req_blocks[req_id] + + block_ids = new_block_ids[0] + req_block_ids.extend(block_ids) + + meta.add_request( + req_id=req_id, + token_ids=cached_req.prompt_token_ids or [], + block_ids=req_block_ids, + block_size=self._block_size, + new_req=False, + ) + + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + req_id = request.request_id + _ = self._active_requests.pop(req_id, None) + _ = self._req_blocks.pop(req_id, None) + + mooncake_meta = self._req_metadata.pop(req_id, None) + return False, mooncake_meta + + @classmethod + def get_required_kvcache_layout(cls, vllm_config: VllmConfig) -> str | None: + return "NHD" diff --git a/torchspec/inference/engine/vllm_engine.py b/torchspec/inference/engine/vllm_engine.py index 3006799..645c8d5 100644 --- a/torchspec/inference/engine/vllm_engine.py +++ b/torchspec/inference/engine/vllm_engine.py @@ -21,11 +21,18 @@ """ VLLM Ray actor engine for distributed deployment. -Uses Worker Extension mode with MultiprocExecutor for reliable hidden states -extraction via model.forward patching in worker processes. +Uses vLLM's ``extract_hidden_states`` speculative decoding method with a +custom ``MooncakeHiddenStatesConnector`` KV Connector to capture intermediate +hidden states and store them directly to Mooncake via RDMA. + +This replaces the previous worker-extension approach that monkey-patched +``model.forward``. The new approach uses only public vLLM APIs +(``speculative_config`` + ``kv_transfer_config``) and is compatible with +MRV2, CUDA graphs, and ``torch.compile``. """ import socket +from typing import Any import ray import torch @@ -45,6 +52,8 @@ "nnodes", "node_rank", "distributed_backend", + "speculative_config", + "kv_transfer_config", } ) @@ -52,8 +61,9 @@ class VllmEngine(InferenceEngine, RayActor): """Ray actor wrapper for vLLM LLM engine with distributed deployment support. - Uses Worker Extension mode with MultiprocExecutor and VllmWorkerExtension - for reliable hidden states extraction by patching model.forward in worker processes. + Uses vLLM's ``extract_hidden_states`` speculative method with a + ``MooncakeHiddenStatesConnector`` to capture hidden states from selected + model layers and write them directly to Mooncake. """ def __init__( @@ -101,6 +111,8 @@ def init(self, mooncake_config=None, dist_init_addr: str | None = None) -> None: ) mooncake_config.local_hostname = local_ip + # Export env vars so worker processes (and the connector) can + # initialize their own Mooncake stores via MooncakeConfig.from_env(). 
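+        # The LLM below runs workers under the "mp" executor, so the connector
+        # instances only see env vars that are exported before the worker
+        # processes are created.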
mooncake_config.export_env() from torchspec.transfer.mooncake.utils import ( @@ -116,7 +128,7 @@ def init(self, mooncake_config=None, dist_init_addr: str | None = None) -> None: pp_size = getattr(self.args, "vllm_pp_size", 1) if self.args.aux_hidden_states_layers is not None: - self.aux_hidden_state_layer_ids = self.args.aux_hidden_states_layers + self.aux_hidden_state_layer_ids = list(self.args.aux_hidden_states_layers) else: self.aux_hidden_state_layer_ids = get_default_eagle3_aux_layer_ids( self.args.target_model_path @@ -126,6 +138,25 @@ def init(self, mooncake_config=None, dist_init_addr: str | None = None) -> None: f"Using default aux hidden state layer ids: {self.aux_hidden_state_layer_ids}" ) + # The connector can only access aux layer outputs from the KV cache, + # so we append the model's final layer to capture last_hidden_states + # (pre-norm) for target logit computation on the training side. + from transformers import AutoConfig as _AC + + _cfg = _AC.from_pretrained( + self.args.target_model_path, + trust_remote_code=getattr(self.args, "trust_remote_code", True), + ) + _cfg = getattr(_cfg, "text_config", _cfg) + final_layer_id = _cfg.num_hidden_layers - 1 + if final_layer_id not in self.aux_hidden_state_layer_ids: + self.aux_hidden_state_layer_ids.append(final_layer_id) + if self.rank == 0: + logger.info( + f"Appended final layer {final_layer_id} to aux layers for " + f"last_hidden_states: {self.aux_hidden_state_layer_ids}" + ) + nnodes = getattr(self.args, "vllm_nnodes", 1) tp_size = nnodes * self.num_gpus_per_engine @@ -156,7 +187,7 @@ def _init_engine( mem_fraction: float, dist_init_addr: str | None, ) -> None: - """Initialize LLM with worker extension enabled.""" + """Initialize LLM with extract_hidden_states speculative config.""" from vllm import LLM engine_kwargs = { @@ -166,9 +197,22 @@ def _init_engine( "trust_remote_code": getattr(self.args, "trust_remote_code", True), "distributed_executor_backend": "mp", "disable_custom_all_reduce": True, - "worker_extension_cls": ( - "torchspec.inference.engine.vllm_worker_extension.VllmWorkerExtension" - ), + "speculative_config": { + "method": "extract_hidden_states", + "num_speculative_tokens": 1, + "draft_model_config": { + "hf_config": { + "eagle_aux_hidden_state_layer_ids": list(self.aux_hidden_state_layer_ids) + } + }, + }, + "kv_transfer_config": { + "kv_connector": "MooncakeHiddenStatesConnector", + "kv_connector_module_path": ( + "torchspec.inference.engine.mooncake_hidden_states_connector" + ), + "kv_role": "kv_producer", + }, } extra_args = getattr(self.args, "vllm_extra_args", None) @@ -197,9 +241,7 @@ def _init_engine( f"max_cudagraph_capture_size={inference_batch_size} from inference_batch_size" ) - # Disable prefix caching and chunked prefill engine_kwargs["enable_prefix_caching"] = False - engine_kwargs["enable_chunked_prefill"] = False max_seq_length = getattr(self.args, "max_seq_length", None) if max_seq_length: @@ -213,33 +255,11 @@ def _init_engine( engine_kwargs["distributed_init_address"] = dist_init_addr self._engine = LLM(**engine_kwargs) - self._setup_rpc_hidden_states_capture() logger.info( - f"VllmEngine rank {self.rank}: initialized worker extension mode " + f"VllmEngine rank {self.rank}: initialized extract_hidden_states mode " f"with layers={self.aux_hidden_state_layer_ids}" ) - def _setup_rpc_hidden_states_capture(self) -> None: - """Initialize worker-side hidden-state capture hooks.""" - if self._engine is None: - raise RuntimeError("VllmEngine not initialized. 
Call init() first.") - if not hasattr(self._engine, "collective_rpc"): - raise RuntimeError("vLLM LLM.collective_rpc is required for worker extension mode") - - if self._mooncake_config is not None: - self._mooncake_config.export_env() - logger.info( - f"VllmEngine rank {self.rank}: Set Mooncake env vars for workers: " - f"master={self._mooncake_config.master_server_address}" - ) - - layer_ids = list(self.aux_hidden_state_layer_ids) - results = self._engine.collective_rpc( - "_setup_hidden_states_capture", - args=(layer_ids,), - ) - logger.info(f"VllmEngine rank {self.rank}: worker capture setup replies={results}") - def generate( self, data_id: str | list[str], @@ -250,7 +270,13 @@ def generate( return_logits: bool = True, multimodal_inputs: list[dict] | None = None, ) -> list[dict]: - """Generate hidden states for training data using Worker Extension mode.""" + """Generate hidden states for training data. + + Hidden states are captured by vLLM's ``extract_hidden_states`` + speculative method and stored to Mooncake by the + ``MooncakeHiddenStatesConnector``. Metadata comes back in + ``output.kv_transfer_params``. + """ if self._engine is None: raise RuntimeError("VllmEngine not initialized. Call init() first.") @@ -285,122 +311,52 @@ def generate( from vllm import SamplingParams sampling_params = SamplingParams(max_tokens=1, temperature=0) - request_metadata = {} - if input_ids_list is not None: - for i, ids in enumerate(input_ids_list): - request_metadata[data_ids[i]] = int(self._normalize_input_ids(ids).numel()) - # Build packed_loss_mask_map for workers - packed_loss_mask_map = {} + # Build packed_loss_mask_map for result assembly + packed_loss_mask_map: dict[str, str | None] = {} if packed_loss_mask_list is not None: - for i, data_id in enumerate(data_ids): + for i, did in enumerate(data_ids): if i < len(packed_loss_mask_list): - packed_loss_mask_map[data_id] = packed_loss_mask_list[i] - - # Build input_ids_map for workers (pass real input_ids via RPC) - input_ids_map = {} - if input_ids_list is not None: - for i, data_id in enumerate(data_ids): - if i < len(input_ids_list): - ids = self._normalize_input_ids(input_ids_list[i]) - input_ids_map[data_id] = ids.cpu().tolist() - - try: - self._engine.collective_rpc("_reset_capture") - if request_metadata: - self._engine.collective_rpc( - "_set_request_metadata", - args=(request_metadata, packed_loss_mask_map, input_ids_map), - ) - except Exception as e: - logger.warning(f"Could not reset capture via worker extension: {e}") + packed_loss_mask_map[did] = packed_loss_mask_list[i] outputs = self._engine.generate(prompts, sampling_params, use_tqdm=False) - # outputs are sorted by int(request_id), matching submission order. - # Build mapping from vLLM's internal worker IDs ("{request_id}-{uuid}") - # to our external data_ids. - internal_to_external = {} - for i, output in enumerate(outputs): - internal_to_external[output.request_id] = data_ids[i] - - # Always build request_metadata and input_ids_map from the - # outputs. 
- for i, output in enumerate(outputs): - did = data_ids[i] - request_metadata[did] = len(output.prompt_token_ids) - input_ids_map[did] = list(output.prompt_token_ids) - try: - self._engine.collective_rpc( - "_set_request_metadata", - args=(request_metadata, packed_loss_mask_map, input_ids_map), - ) - except Exception as e: - logger.warning( - f"VllmEngine rank {self.rank}: Could not set post-generation request metadata: {e}" - ) - - # Get metadata from workers (tensors are already stored in Mooncake by workers) - metadata_by_request: dict[str, dict] = {} - try: - metadata_list = self._engine.collective_rpc( - "_store_and_get_metadata", args=(internal_to_external,) - ) - if isinstance(metadata_list, list): - for metadata in metadata_list: - if isinstance(metadata, dict): - metadata_by_request.update(metadata) - elif isinstance(metadata_list, dict): - metadata_by_request = metadata_list - except Exception as e: - logger.warning(f"Could not get metadata from worker extension: {e}") - - if not metadata_by_request: - logger.error( - f"VllmEngine rank {self.rank}: metadata_by_request is EMPTY for " - f"data_ids={data_ids}. Worker returned metadata_list={metadata_list!r}. " - f"use_prompts={use_prompts}, request_metadata_keys={list(request_metadata.keys())}, " - f"internal_to_external={internal_to_external}" - ) - results = [] for i, output in enumerate(outputs): seq_len = len(output.prompt_token_ids) - data_id = data_ids[i] + did = data_ids[i] - # Get metadata for this request - metadata = metadata_by_request.get(data_id) - if metadata is None: + kv_params = getattr(output, "kv_transfer_params", None) + if kv_params is None: logger.error( - f"VllmEngine rank {self.rank}: No metadata for data_id={data_id}. " - f"metadata_by_request has keys={list(metadata_by_request.keys())}. " - f"Training may be corrupted." + f"VllmEngine rank {self.rank}: No kv_transfer_params for data_id={did}. " + f"The MooncakeHiddenStatesConnector may not have stored this request." 
) continue - # Extract info from metadata (tensors are already in Mooncake) - mooncake_key = metadata.get("mooncake_key", data_id) - tensor_shapes = metadata.get("tensor_shapes", {}) - tensor_dtypes = metadata.get("tensor_dtypes", {}) + mooncake_key = kv_params.get("mooncake_key", did) + tensor_shapes = kv_params.get("tensor_shapes", {}) + tensor_dtypes = kv_params.get("tensor_dtypes", {}) - result = { + result: dict[str, Any] = { "mooncake_key": mooncake_key, "tensor_shapes": tensor_shapes, "tensor_dtypes": tensor_dtypes, - "data_id": data_id, + "data_id": did, "seq_len": seq_len, } - # Get packed_loss_mask from metadata (returned by worker) - packed_loss_mask = metadata.get("packed_loss_mask") + + packed_loss_mask = packed_loss_mask_map.get(did) if packed_loss_mask is not None: result["packed_loss_mask"] = packed_loss_mask - # Get input_ids_list from metadata (returned by worker via RPC) - input_ids_list = metadata.get("input_ids_list") - if input_ids_list is not None: - result["input_ids_list"] = input_ids_list - results.append(result) - # No need to flush here - workers already flushed after storing + input_ids_from_kv = kv_params.get("input_ids_list") + if input_ids_from_kv is not None: + result["input_ids_list"] = input_ids_from_kv + else: + result["input_ids_list"] = list(output.prompt_token_ids) + + results.append(result) logger.debug( f"VllmEngine rank {self.rank}: generated {len(results)} mooncake results " @@ -446,8 +402,17 @@ def shutdown(self) -> None: self._mooncake_store = None if self._engine is not None: - del self._engine - self._engine = None + try: + if hasattr(self._engine, "close"): + self._engine.close() + elif hasattr(self._engine, "llm_engine"): + llm_engine = self._engine.llm_engine + if hasattr(llm_engine, "shutdown"): + llm_engine.shutdown() + except Exception as e: + logger.warning(f"VllmEngine rank {self.rank}: Error during engine shutdown: {e}") + finally: + self._engine = None logger.info(f"VllmEngine rank {self.rank}: shutdown complete") diff --git a/torchspec/inference/engine/vllm_worker_extension.py b/torchspec/inference/engine/vllm_worker_extension.py deleted file mode 100644 index 04f1d08..0000000 --- a/torchspec/inference/engine/vllm_worker_extension.py +++ /dev/null @@ -1,774 +0,0 @@ -# Copyright (c) 2026 LightSeek Foundation -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -"""vLLM Worker Extension for Hidden States Capture. 
- -This module provides a TorchSpec-style worker extension for vLLM that enables -reliable hidden states extraction during inference. It patches the model's -forward method in each worker process to capture intermediate layer activations -and store them directly to Mooncake to avoid RPC serialization issues. - -Based on the vllm-speculators approach but integrated into TorchSpec's -architecture with Ray Actors and Mooncake storage. -""" - -import logging -import os -import re -import types -from collections import defaultdict -from itertools import islice -from typing import Any, Dict, List, Optional - -import torch -from vllm.distributed import get_pp_group, get_tp_group -from vllm.sequence import IntermediateTensors - -logger = logging.getLogger(__name__) - - -def _sanitize_mooncake_key(key: str) -> str: - """Sanitize a key for use with Mooncake store. - - Mooncake keys should only contain alphanumeric characters, hyphens, and underscores. - This function replaces invalid characters with underscores. - - Args: - key: The original key (e.g., vLLM req_id) - - Returns: - A sanitized key safe for Mooncake operations - """ - sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", key) - if sanitized and sanitized[0].isdigit(): - sanitized = "k" + sanitized - return sanitized - - -def _patched_forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: Any, -) -> Any: - """Patched forward pass that captures hidden states from specified layers. - - This function is dynamically bound to base_model instances via types.MethodType. - It expects base_model to have an _extension attribute pointing to the - VllmWorkerExtension instance. - - Args: - input_ids: Input token IDs - positions: Position IDs - intermediate_tensors: For pipeline parallelism - inputs_embeds: Pre-computed input embeddings (for multimodal) - **kwargs: Additional arguments - - Returns: - Hidden states or IntermediateTensors (for PP) - """ - # Get extension reference - extension = self._extension # noqa: SLF001 - - # Handle pipeline parallelism - first rank does embedding - if get_pp_group().is_first_rank: - hidden_states = ( - inputs_embeds if inputs_embeds is not None else self.embed_input_ids(input_ids) - ) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - - # Track auxiliary hidden states for capture - aux_hidden_states: List[torch.Tensor] = [] - - # Only capture on TP rank 0 to avoid duplicates - should_capture = get_tp_group().rank_in_group == 0 - target_layers = extension._layer_ids if should_capture else frozenset() # noqa: SLF001 - - # Capture input_ids only on first call (prefill phase) to avoid including generated tokens - if should_capture and get_pp_group().is_first_rank and extension._captured_input_ids is None: - # input_ids shape: (batch_size, seq_len) or (seq_len,) - if input_ids.dim() == 2: - # Flatten batch dimension - extension._captured_input_ids = input_ids.view(-1).clone() - else: - extension._captured_input_ids = input_ids.clone() - - # Process each layer - for idx, layer in enumerate(islice(self.layers, self.start_layer, self.end_layer)): - hidden_states, residual = layer( - hidden_states=hidden_states, - positions=positions, - residual=residual, - ) - absolute_layer_idx = self.start_layer + idx - - # Capture intermediate layers (not the last) before normalization - if 
absolute_layer_idx in target_layers: - # Add residual before capturing (matching speculators pattern) - captured = ( - (hidden_states + residual).clone() - if residual is not None - else hidden_states.clone() - ) - aux_hidden_states.append(captured) - - # Handle pipeline parallelism - return intermediate tensors if not last rank - if not get_pp_group().is_last_rank: - return IntermediateTensors({"hidden_states": hidden_states, "residual": residual}) - - # Final normalization (only on last PP rank) - hidden_states, _ = self.norm(hidden_states, residual) - - # Store captured states (only on last PP rank, TP rank 0, and during prefill) - if should_capture and not extension._prefill_complete: # noqa: SLF001 - if aux_hidden_states: - extension._store_captured_states(aux_hidden_states) # noqa: SLF001 - extension._store_last_hidden_states(hidden_states) # noqa: SLF001 - - return hidden_states - - -class VllmWorkerExtension: - """Worker extension that adds hidden states capture functionality to vLLM. - - This extension hooks into vLLM's Worker by being specified in the worker - initialization. It patches the model's forward pass to intercept and capture - intermediate layer hidden states during inference. - - Key behaviors: - - Only captures on tensor parallel (TP) rank 0 to avoid duplicate data when - using tensor parallelism. All TP ranks compute the same hidden states, so - capturing from rank 0 is sufficient. - - Stores captured states in GPU memory during batch processing, then writes - directly to Mooncake to avoid RPC serialization issues. - - Supports pipeline parallelism by handling IntermediateTensors correctly. - - Tracks request metadata to map captured states back to original requests - across chunked prefill iterations. - - Attributes: - _layer_ids: Frozenset of layer indices for O(1) lookup during capture - _captured_states: Accumulated hidden states per layer (GPU tensors) - _request_metadata: Metadata tracking tokens per request per iteration - _mooncake_store: EagleMooncakeStore instance for direct storage - model_runner: Reference to the vLLM model runner - """ - - def __init__(self): - """Initialize the worker extension with Mooncake store support.""" - self._layer_ids: frozenset = frozenset() - self._captured_states: Optional[List[List[torch.Tensor]]] = None - self._request_metadata: List[Dict[str, int]] = [] - self._current_request_metadata: Optional[Dict[str, int]] = None - self._mooncake_store: Optional[Any] = None - self._store_initialized: bool = False - self._store_setup_complete: bool = False - self._init_retry_count: int = 0 - self._max_init_retries: int = 3 - self.model_runner: Optional[Any] = None - - def _get_cuda_device_safe(self) -> torch.device: - """Safely get CUDA device, handling uninitialized context (V1 compatibility). - - In vLLM V1, CUDA context may not be initialized when this method is called. - This method safely handles both initialized and uninitialized contexts. - - Returns: - torch.device: The CUDA device to use. Falls back to cuda:0 if context - is not yet initialized (common in V1 engine). 
- """ - try: - if torch.cuda.is_initialized(): - current_device = torch.cuda.current_device() - logger.debug(f"CUDA initialized, using device cuda:{current_device}") - return torch.device(f"cuda:{current_device}") - else: - # CUDA not initialized yet (V1), use device 0 as fallback - # V1 will initialize context during model loading - logger.debug("CUDA not initialized yet (V1), falling back to cuda:0") - return torch.device("cuda:0") - except RuntimeError as e: - # CUDA context not available - logger.warning(f"Failed to get CUDA device: {e}, falling back to cuda:0") - return torch.device("cuda:0") - - def _init_mooncake_store(self) -> bool: - """Initialize Mooncake store connection in the worker. - - Uses environment variables set by the main process to connect to - the Mooncake master and metadata servers. - - Returns: - True if initialization successful, False otherwise. - """ - if self._store_initialized: - return True - - # Only initialize on TP rank 0 - other ranks don't capture hidden states - try: - if get_tp_group().rank_in_group != 0: - logger.debug("Skipping Mooncake store init on non-zero TP rank") - return False - except Exception: - # If we can't get TP group info, proceed anyway (for backward compatibility) - pass - - try: - from torchspec.config.mooncake_config import MooncakeConfig - from torchspec.transfer.mooncake.eagle_store import EagleMooncakeStore - - if not os.environ.get("MOONCAKE_MASTER_SERVER") and not os.environ.get( - "MOONCAKE_MASTER_HOST" - ): - logger.warning( - "Mooncake master address not available in worker environment. " - "Set MOONCAKE_MASTER_SERVER environment variable." - ) - return False - - config = MooncakeConfig.from_env() - - # Create store object but don't call setup() yet - # setup() will be called lazily when CUDA context is ready - self._mooncake_store = EagleMooncakeStore(config) - # Mark as initialized but not yet setup - # setup() will be called on first put() when CUDA context is ready - self._store_initialized = True - self._store_setup_complete = False - - logger.info( - f"Worker initialized Mooncake store (setup deferred): " - f"master={config.master_server_address}, protocol={config.protocol}" - ) - return True - - except Exception as e: - logger.error(f"Failed to initialize Mooncake store in worker: {e}", exc_info=True) - self._mooncake_store = None - self._store_initialized = False - return False - - def _ensure_mooncake_store(self) -> bool: - """Ensure Mooncake store is initialized and setup, with retry logic. - - This method handles lazy initialization for vLLM V1 compatibility. - In V1, CUDA context may not be ready during initial Worker initialization, - so we defer the actual setup() call until first use. - - Returns: - True if store is ready for use, False otherwise. 
- """ - # Ensure attributes exist (for vLLM V1 compatibility where __init__ may not be called) - if not hasattr(self, "_store_initialized"): - self._store_initialized = False - if not hasattr(self, "_store_setup_complete"): - self._store_setup_complete = False - if not hasattr(self, "_init_retry_count"): - self._init_retry_count = 0 - if not hasattr(self, "_max_init_retries"): - self._max_init_retries = 3 - if not hasattr(self, "_mooncake_store"): - self._mooncake_store = None - - # Already fully initialized and setup - if self._store_initialized and self._store_setup_complete: - return True - - # Check retry limit - if self._init_retry_count >= self._max_init_retries: - logger.error( - f"Max retries ({self._max_init_retries}) exceeded for Mooncake store initialization" - ) - return False - - try: - # Initialize store if not already done - if not self._store_initialized: - if not self._init_mooncake_store(): - self._init_retry_count += 1 - logger.warning( - f"Mooncake store init failed (attempt {self._init_retry_count}/{self._max_init_retries})" - ) - return False - - # Setup store if not already done - if not self._store_setup_complete and self._mooncake_store is not None: - try: - # Use safe CUDA device detection for V1 compatibility - device = self._get_cuda_device_safe() - logger.info(f"Setting up Mooncake store on device {device}") - self._mooncake_store.setup(device=device) - - try: - logger.info("Warming up Mooncake RDMA path...") - self._mooncake_store.warmup_rdma() - logger.info("Mooncake RDMA warmup completed successfully") - except Exception as warmup_error: - logger.warning(f"Mooncake RDMA warmup failed: {warmup_error}") - - self._store_setup_complete = True - logger.info("Mooncake store setup completed successfully") - return True - except Exception as e: - self._init_retry_count += 1 - # Check if this is a CUDA context error (common in V1) - error_msg = str(e).lower() - if "cuda" in error_msg or "device" in error_msg: - logger.warning( - f"CUDA context not ready (attempt {self._init_retry_count}/{self._max_init_retries}): {e}. " - f"Will retry on next put." - ) - else: - logger.error(f"Mooncake store setup failed: {e}") - return False - - return True - - except Exception as e: - self._init_retry_count += 1 - logger.error(f"Unexpected error in _ensure_mooncake_store: {e}", exc_info=True) - return False - - def _store_last_hidden_states(self, hidden_states: torch.Tensor) -> None: - """Store post-norm hidden states from a forward pass for use as last_hidden_states""" - if getattr(self, "_captured_last_hs", None) is None: - self._captured_last_hs = [hidden_states.clone()] - else: - self._captured_last_hs.append(hidden_states.clone()) - - def _store_captured_states(self, aux_hidden_states: List[torch.Tensor]) -> None: - """Store captured hidden states from a forward pass. 
- - Args: - aux_hidden_states: List of tensors, one per target layer - """ - if self._captured_states is None: - self._captured_states = [[h] for h in aux_hidden_states] - else: - for i, h in enumerate(aux_hidden_states): - self._captured_states[i].append(h) - - # Track per-request token counts for this scheduler step - model_runner = getattr(self, "model_runner", None) - input_batch = getattr(model_runner, "input_batch", None) - step_tokens: Dict[str, int] = {} - if input_batch is not None and hasattr(input_batch, "req_ids"): - for req_id in input_batch.req_ids: - num_tokens = 0 - req_idx = getattr(input_batch, "req_id_to_index", {}).get(req_id) - if req_idx is not None: - num_computed = getattr( - input_batch, "num_computed_tokens", [0] * len(input_batch.req_ids) - )[req_idx] - num_total = getattr(input_batch, "num_tokens", [0] * len(input_batch.req_ids))[ - req_idx - ] - num_tokens = num_total - num_computed - step_tokens[req_id] = num_tokens - self._request_metadata.append(step_tokens) - - # With max_tokens=1 the prefill forward pass already generates the - # single allowed token, so no decode step is scheduled by vLLM. - # This check handles chunked prefill where multiple forward calls - # sum up to the total prefill token count. - if self._current_request_metadata and not self._prefill_complete: - expected = sum(self._current_request_metadata.values()) - captured = sum(t.shape[0] for t in self._captured_states[0]) - if captured == expected: - self._prefill_complete = True - elif captured > expected: - logger.warning(f"Captured more tokens than expected: {captured} > {expected}") - - def _store_input_ids(self, input_ids: torch.Tensor) -> None: - """Store input_ids from a forward pass. - - Args: - input_ids: Input token IDs tensor (batch_size, seq_len) or (seq_len,) - """ - # Flatten if needed and store - if input_ids.dim() == 2: - # (batch_size, seq_len) - flatten to (batch_size * seq_len,) - input_ids = input_ids.view(-1) - if getattr(self, "_captured_input_ids", None) is None: - self._captured_input_ids = input_ids.clone() - else: - self._captured_input_ids = torch.cat([self._captured_input_ids, input_ids], dim=0) - - def _setup_hidden_states_capture(self, layer_ids: List[int]) -> None: - """Setup model to capture auxiliary hidden states from specific layers. - - This method patches the model's forward method to intercept hidden states - during the forward pass. 
- - Args: - layer_ids: List of layer indices to capture from - """ - self._layer_ids = frozenset(layer_ids) - self._captured_states = None - self._request_metadata = [] - self._current_request_metadata = None - self._packed_loss_mask_map: Dict[str, Optional[str]] = {} - self._store_initialized = False - self._store_setup_complete = False - self._init_retry_count = 0 - self._mooncake_store = None - - model_runner = getattr(self, "model_runner", None) - if model_runner is None and hasattr(self, "model"): - model_runner = self - if model_runner is None: - raise AttributeError("Could not find model_runner for worker extension setup") - - self.model_runner = model_runner - model = self.model_runner.model # type: ignore[attr-defined] - - # Handle vision-language models (e.g., Qwen-VL) - if hasattr(model, "get_language_model"): - base_model = model.get_language_model().model - # Handle standard text models - elif hasattr(model, "model") and hasattr(model.model, "layers"): - base_model = model.model - else: - # Try to find model with layers attribute - attrs = [a for a in dir(model) if not a.startswith("_")] - raise AttributeError( - f"Could not find base model with 'layers' attribute. " - f"Model type: {type(model).__name__}, " - f"Available attributes: {attrs}" - ) - - # Attach extension reference and patch forward method - base_model._extension = self # noqa: SLF001 - base_model.forward = types.MethodType(_patched_forward, base_model) - - logger.info(f"Hidden states capture setup complete for layers {layer_ids}") - - def _set_request_metadata( - self, - request_metadata: Dict[str, int], - packed_loss_mask_map: Optional[Dict[str, Optional[str]]] = None, - input_ids_map: Optional[Dict[str, List[int]]] = None, - ) -> None: - """Set request metadata for the next forward pass. - - This is called before each scheduler iteration to track which tokens - belong to which request. - - Args: - request_metadata: Dict mapping request_id -> num_prefill_tokens - packed_loss_mask_map: Optional dict mapping request_id -> packed_loss_mask - string (values may be None when loss masks are not available). - input_ids_map: Optional dict mapping request_id -> input_ids list (passed via RPC) - """ - self._current_request_metadata = request_metadata - self._packed_loss_mask_map = packed_loss_mask_map or {} - self._input_ids_map = input_ids_map or {} - - def _reset_capture(self) -> None: - """Reset captured states before starting a new batch. - - Must be called before processing a new batch of requests. - """ - if not hasattr(self, "_layer_ids") or len(self._layer_ids) == 0: - raise RuntimeError("Must call _setup_hidden_states_capture before capturing states") - self._captured_states = None - self._captured_last_hs: Optional[List[torch.Tensor]] = None - self._captured_input_ids: Optional[torch.Tensor] = None - self._prefill_complete = False - self._request_metadata = [] - self._current_request_metadata = None - self._packed_loss_mask_map = {} - self._input_ids_map = {} - - def _store_and_get_metadata( - self, internal_to_external: Optional[Dict[str, str]] = None - ) -> Optional[Dict[str, Dict[str, Any]]]: - """Store captured hidden states to Mooncake and return metadata. - - This method stores tensors directly to Mooncake from the worker process, - avoiding RPC serialization issues. It returns only lightweight metadata - that can be safely serialized and returned via collective_rpc. 
- - Returns: - Dict mapping request_id to metadata dict with keys: - - 'mooncake_key': str, the base key used for storage - - 'tensor_shapes': dict of tensor shapes - - 'tensor_dtypes': dict of dtype names - - 'num_layers': int, number of captured layers - or None if no states captured or not on TP rank 0. - """ - # Only TP rank 0 has captured data - if get_tp_group().rank_in_group != 0: - return None - if self._captured_states is None: - logger.warning( - "_store_and_get_metadata: captured_states is None " - "(forward patch may not be running or no prefill occurred)" - ) - return None - - # Ensure Mooncake store is initialized and setup (with retry for V1 compatibility) - if not self._ensure_mooncake_store(): - logger.warning( - "Failed to initialize/setup Mooncake store, cannot store hidden states. " - "This may be due to CUDA context not being ready in V1 engine." - ) - return None - - # Concatenate captured states from all scheduler iterations - concatenated_layers = [ - torch.cat(layer_tensors, dim=0) for layer_tensors in self._captured_states - ] - total_captured_tokens = concatenated_layers[0].shape[0] - - # Concatenate post-norm hidden states for last_hidden_states - concatenated_last_hs = None - if getattr(self, "_captured_last_hs", None): - concatenated_last_hs = torch.cat(self._captured_last_hs, dim=0) - - internal_to_external = internal_to_external or {} - ext_token_counts = ( - dict(self._current_request_metadata) if self._current_request_metadata else {} - ) - - # Build worker-visible ID -> external ID lookup once. - # In V1, the worker sees "{counter}-{uuid8}" while internal_to_external - # maps bare counter strings (from output.request_id) to external data_ids. - worker_to_ext: Dict[str, str] = dict(internal_to_external) - for step_meta in self._request_metadata: - for worker_id in step_meta: - if worker_id not in worker_to_ext: - for counter, ext_id in internal_to_external.items(): - if worker_id.startswith(f"{counter}-"): - worker_to_ext[worker_id] = ext_id - break - - request_slices: List[tuple] = [] # (external_id, num_tokens) - seen_ext_ids: set = set() - - for step_meta in self._request_metadata: - for int_id in step_meta.keys(): - ext_id = worker_to_ext.get(int_id, int_id) - if ext_id not in seen_ext_ids: - n_tokens = ext_token_counts.get(ext_id, 0) - if n_tokens > 0: - request_slices.append((ext_id, n_tokens)) - seen_ext_ids.add(ext_id) - - # Fallback if _request_metadata didn't produce results - if not request_slices and ext_token_counts: - logger.warning( - "Internal request metadata mapping failed; falling back to external order" - ) - for ext_id, n_tokens in ext_token_counts.items(): - request_slices.append((ext_id, n_tokens)) - - if not request_slices: - logger.warning( - f"_store_and_get_metadata: request_slices is empty — cannot map " - f"captured tokens to requests. 
" - f"total_captured_tokens={total_captured_tokens}, " - f"_request_metadata steps={len(self._request_metadata)}, " - f"internal_to_external keys={list(internal_to_external.keys())[:5]}, " - f"ext_token_counts keys={list(ext_token_counts.keys())[:5]}, " - f"current_request_metadata={self._current_request_metadata is not None}" - ) - - total_expected_tokens = sum(n for _, n in request_slices) - - if total_captured_tokens != total_expected_tokens and total_expected_tokens > 0: - logger.warning( - f"Token count mismatch: captured={total_captured_tokens}, " - f"expected={total_expected_tokens}" - ) - - num_aux_layers = len(concatenated_layers) - request_chunks: defaultdict[str, List[List[torch.Tensor]]] = defaultdict( - lambda: [[] for _ in range(num_aux_layers)] - ) - request_last_hs: defaultdict[str, List[torch.Tensor]] = defaultdict(list) - current_idx = 0 - - for external_id, num_tokens in request_slices: - if current_idx >= total_captured_tokens: - break - actual_tokens = min(num_tokens, total_captured_tokens - current_idx) - if actual_tokens > 0: - for layer_idx, layer_tensor in enumerate(concatenated_layers): - chunk = layer_tensor[current_idx : current_idx + actual_tokens] - request_chunks[external_id][layer_idx].append(chunk) - if concatenated_last_hs is not None: - request_last_hs[external_id].append( - concatenated_last_hs[current_idx : current_idx + actual_tokens] - ) - current_idx += actual_tokens - - # Store to Mooncake and collect metadata - result: Dict[str, Dict[str, Any]] = {} - for req_id, layer_chunks in request_chunks.items(): - mooncake_key = _sanitize_mooncake_key(req_id) - if mooncake_key != req_id: - logger.debug(f"Sanitized key '{req_id}' -> '{mooncake_key}'") - - layer_tensors = [torch.cat(chunks, dim=0) for chunks in layer_chunks] - - if len(layer_tensors) > 1: - hidden_states = torch.cat(layer_tensors, dim=-1) - else: - hidden_states = layer_tensors[0] - - if req_id in request_last_hs and request_last_hs[req_id]: - last_hidden_states = torch.cat(request_last_hs[req_id], dim=0) - else: - last_hidden_states = layer_tensors[-1] - - # Use real input_ids from RPC, otherwise create dummy - if req_id in self._input_ids_map: - input_ids_list = self._input_ids_map[req_id] - input_ids = torch.tensor( - input_ids_list, dtype=torch.long, device=hidden_states.device - ) - else: - seq_len = hidden_states.shape[0] - input_ids = torch.zeros(seq_len, dtype=torch.long, device=hidden_states.device) - - # Skip empty tensors - if hidden_states.numel() == 0: - logger.error(f"Request {req_id}: hidden_states is empty! 
Skipping.") - continue - - try: - logger.debug( - f"Storing to Mooncake: key={mooncake_key}, " - f"hidden_states_shape={hidden_states.shape}" - ) - - # Store to Mooncake - store_meta = self._mooncake_store.put( - key=mooncake_key, - hidden_states=hidden_states, - input_ids=input_ids, - last_hidden_states=last_hidden_states, - target=None, - ) - - logger.debug(f"Successfully stored to Mooncake: key={mooncake_key}") - - result[req_id] = { - "mooncake_key": mooncake_key, - "tensor_shapes": store_meta["shapes"], - "tensor_dtypes": { - k: str(v).replace("torch.", "") for k, v in store_meta["dtypes"].items() - }, - "num_layers": len(layer_tensors), - "packed_loss_mask": self._packed_loss_mask_map.get(req_id), - "input_ids_list": input_ids.cpu().tolist(), # Serialize via RPC instead of Mooncake - } - except Exception as e: - logger.warning( - f"Failed to store tensors to Mooncake for {req_id} (key={mooncake_key}): {e}" - ) - continue - - # Flush to ensure all writes are complete before returning - if self._mooncake_store is not None: - self._mooncake_store.flush() - - # Clear intermediate storage to free memory - self._captured_states = None - self._captured_last_hs = None - self._captured_input_ids = None - self._request_metadata = [] - self._input_ids_map = {} - - return result if result else None - - def _get_captured_states(self) -> Optional[Dict[str, List[torch.Tensor]]]: - """Legacy method - now delegates to _store_and_get_metadata. - - This method is kept for backward compatibility but should not be used - in production due to RPC serialization issues. Use _store_and_get_metadata - instead which stores tensors directly to Mooncake. - - Returns: - Dict mapping request_id to list of tensors (one per layer), - or None if no states captured. - """ - # If Mooncake store is available, use the new method - if self._store_initialized or self._init_mooncake_store(): - metadata = self._store_and_get_metadata() - if metadata is None: - return None - # Return empty dict to signal success - actual data is in Mooncake - return {} - - # Fallback to old behavior if Mooncake not available - if self._captured_states is None: - return None - - # Concatenate captured states from all scheduler iterations - concatenated_layers = [ - torch.cat(layer_tensors, dim=0) for layer_tensors in self._captured_states - ] - - # Slice and group by request - request_chunks: defaultdict[str, List[List[torch.Tensor]]] = defaultdict( - lambda: [[] for _ in range(len(concatenated_layers))] - ) - current_idx = 0 - - # Use external IDs for slicing - external_ids = ( - list(self._current_request_metadata.keys()) if self._current_request_metadata else [] - ) - token_counts = ( - list(self._current_request_metadata.values()) if self._current_request_metadata else [] - ) - - req_idx = 0 - for step_metadata in self._request_metadata: - step_tokens = sum(step_metadata.values()) if step_metadata else 0 - if step_tokens == 0 and req_idx < len(token_counts): - step_tokens = token_counts[req_idx] - - if req_idx < len(external_ids): - external_id = external_ids[req_idx] - for layer_idx, layer_tensor in enumerate(concatenated_layers): - chunk = layer_tensor[current_idx : current_idx + step_tokens].clone().cpu() - request_chunks[external_id][layer_idx].append(chunk) - current_idx += step_tokens - req_idx += 1 - - # Concatenate chunks for each request across iterations - result: Dict[str, List[torch.Tensor]] = { - req_id: [torch.cat(chunks, dim=0) for chunks in layer_chunks] - for req_id, layer_chunks in request_chunks.items() - } - - # Clear 
intermediate storage to free memory
-        self._captured_states = None
-        self._request_metadata = []
-
-        return result
diff --git a/torchspec/models/target/target_utils.py b/torchspec/models/target/target_utils.py
index 492be44..b8d76f4 100644
--- a/torchspec/models/target/target_utils.py
+++ b/torchspec/models/target/target_utils.py
@@ -34,18 +34,24 @@ class TargetLMHead(nn.Module):
     """
     Efficiently loads only the lm_head from a pretrained model.
    Used for computing logits from last_hidden_states in the trainer.
+
+    When ``load_norm=True``, also loads the final RMSNorm weights so the
+    trainer can normalise pre-norm hidden states before the lm_head projection.
     """
 
     def __init__(self, config):
         super().__init__()
         self.config = getattr(config, "text_config", config)
         self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
+        self.norm: nn.Module | None = None
 
     @classmethod
     def from_pretrained(
         cls,
         model_path: str,
         lm_head_key: str = "lm_head.weight",
+        norm_key: str = "model.norm.weight",
+        load_norm: bool = False,
         cache_dir: Optional[str] = None,
         device: str = "cuda",
         dtype: torch.dtype = torch.bfloat16,
@@ -65,6 +71,9 @@ def from_pretrained(
 
         instance._load_lm_head(local_model_path, lm_head_key)
 
+        if load_norm:
+            instance._init_and_load_norm(local_model_path, norm_key)
+
         instance.to(device=device, dtype=dtype)
         instance.eval()
         instance.requires_grad_(False)
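The `load_norm` plumbing above is easiest to see with a small usage sketch. Everything below is illustrative rather than part of the patch: the model id, the tensor shapes, and the direct call into `TargetLMHead` are assumptions that mirror the defaults in this file.

```python
import torch

# Illustrative only: model id and key names mirror the defaults in target_utils.py.
head = TargetLMHead.from_pretrained(
    model_path="Qwen/Qwen3-8B",
    lm_head_key="lm_head.weight",
    norm_key="model.norm.weight",
    load_norm=True,                 # also pulls in the final RMSNorm weight
    device="cuda",
    dtype=torch.bfloat16,
)

# head.norm is the target model's final norm (or None when not found), so
# pre-norm hidden states can be normalised before the lm_head projection.
hidden = torch.randn(4, head.config.hidden_size, device="cuda", dtype=torch.bfloat16)
if head.norm is not None:
    hidden = head.norm(hidden)
logits = head(hidden)  # forward() applies only the lm_head projection
```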
@@ -112,6 +121,114 @@ def _load_key_from_file(self, file_path: str, key: str):
         else:
             raise KeyError(f"Key {key} not found in {file_path}")
 
+    def _init_norm_structure(self) -> None:
+        """Create the norm module structure (no weights loaded).
+
+        Used by non-rank-0 processes so that ``parameters()`` yields the
+        same count as rank 0 before the broadcast sync.
+        """
+        import logging
+
+        _log = logging.getLogger(__name__)
+
+        try:
+            norm_module = self._extract_norm_from_architecture()
+            if norm_module is None:
+                return
+            self.norm = norm_module.to_empty(device="cpu")
+            torch.nn.init.ones_(self.norm.weight)
+        except Exception as e:
+            _log.warning(f"Failed to create verifier norm structure: {e}")
+            self.norm = None
+
+    def _init_and_load_norm(self, model_path: str, norm_key: str) -> None:
+        """Extract the final norm module from the target model architecture and load its weight.
+
+        Falls back to no-op if the architecture has no final norm or the
+        weight cannot be loaded — the trainer checks ``self.norm is not None``
+        before applying it.
+        """
+        import logging
+
+        _log = logging.getLogger(__name__)
+
+        try:
+            norm_module = self._extract_norm_from_architecture()
+            if norm_module is None:
+                _log.warning(
+                    "No final norm found in model architecture "
+                    f"(model_type={getattr(self.config, 'model_type', 'unknown')}). "
+                    "last_hidden_states will be used without normalization."
+                )
+                return
+
+            self.norm = norm_module.to_empty(device="cpu")
+            self._load_key_into(model_path, norm_key, self.norm.weight)
+
+        except Exception as e:
+            _log.warning(
+                f"Failed to load verifier norm: {e}. "
+                "last_hidden_states will be used without normalization."
+            )
+            self.norm = None
+
+    def _extract_norm_from_architecture(self) -> "nn.Module | None":
+        """Instantiate the model on meta device and return the final norm module."""
+        from transformers import AutoModelForCausalLM
+
+        with torch.device("meta"):
+            skeleton = AutoModelForCausalLM.from_config(
+                self.config,
+                trust_remote_code=True,
+                attn_implementation="eager",
+            )
+
+        inner = skeleton
+        for attr in ("model", "language_model", "model"):
+            inner = getattr(inner, attr, inner)
+        norm_module = None
+        for name in ("norm", "ln_f", "final_layer_norm"):
+            norm_module = getattr(inner, name, None)
+            if norm_module is not None:
+                break
+
+        del skeleton
+        return norm_module
+
+    def _load_key_into(self, model_path: str, key: str, param: torch.nn.Parameter) -> None:
+        """Load a single key from safetensors/bin files into a parameter."""
+        index_files = glob.glob(os.path.join(model_path, "*.index.json"))
+        if index_files:
+            with open(index_files[0], "r") as f:
+                index = json.load(f)
+            weight_map = index.get("weight_map", {})
+            if key in weight_map:
+                file_path = os.path.join(model_path, weight_map[key])
+            else:
+                raise KeyError(f"Key '{key}' not found in weight_map")
+        else:
+            safetensors = glob.glob(os.path.join(model_path, "*.safetensors"))
+            bins = glob.glob(os.path.join(model_path, "*.bin"))
+            file_path = safetensors[0] if safetensors else (bins[0] if bins else None)
+            if file_path is None:
+                raise FileNotFoundError(f"No checkpoint file found in {model_path}")
+
+        tensor = None
+        if file_path.endswith(".safetensors"):
+            with safe_open(file_path, framework="pt") as f:
+                if key in f.keys():
+                    tensor = f.get_tensor(key)
+        else:
+            state_dict = torch.load(file_path, map_location="cpu")
+            if key in state_dict:
+                tensor = state_dict[key]
+            del state_dict
+
+        if tensor is not None:
+            param.data.copy_(tensor)
+        else:
+            raise KeyError(f"Key {key} not found in {file_path}")
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """Compute logits from hidden states."""
         return self.lm_head(hidden_states)
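`_extract_norm_from_architecture` relies on meta-device instantiation: the full architecture is built with module structure and parameter shapes but no allocated storage, so the final norm can be located generically and then materialised with `to_empty`. A standalone sketch of that idea, assuming a recent PyTorch/transformers and using Qwen/Qwen3-8B purely as an example model id:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("Qwen/Qwen3-8B")  # example model id

# On the meta device the module tree exists, but no weight memory is allocated.
with torch.device("meta"):
    skeleton = AutoModelForCausalLM.from_config(config)

# The tree can be walked like a normal model to find the final norm.
final_norm = getattr(skeleton.model, "norm", None)
print(type(final_norm).__name__)   # e.g. Qwen3RMSNorm
print(final_norm.weight.device)    # device(type='meta')

# to_empty() materialises uninitialised storage on a real device; a single
# tensor copied from the checkpoint (see _load_key_into) then fills it.
norm = final_norm.to_empty(device="cpu")
norm.weight.data.fill_(1.0)        # placeholder values until the real weight is loaded
```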
diff --git a/torchspec/training/eagle3_trainer.py b/torchspec/training/eagle3_trainer.py
index cf75d4e..1204cfc 100644
--- a/torchspec/training/eagle3_trainer.py
+++ b/torchspec/training/eagle3_trainer.py
@@ -135,6 +135,8 @@ def init_model(
         checkpoint_payload = checkpoint.load(self)
         checkpoint.finalize_load(self, checkpoint_payload)
 
+        self._last_hs_prenorm = getattr(self.args, "last_hidden_states_prenorm", False)
+
         if getattr(self.args, "compute_logits_in_trainer", True):
             self._init_target_lm_head(target_model_path)
 
@@ -144,6 +146,7 @@ def init_model(
                 "Set compute_logits_in_trainer=True or provide a TargetLMHead."
             )
         self.target_lm_head_weight = self.target_lm_head.lm_head.weight
+        self.verifier_norm = self.target_lm_head.norm
 
         if getattr(self.args, "attention_backend", None) == "fa_experimental":
             from torchspec.models.draft.llama3_eagle import (
@@ -185,11 +188,16 @@ def _init_target_lm_head(self, target_model_path: str) -> None:
             self.target_lm_head = TargetLMHead.from_pretrained(
                 model_path=target_model_path,
                 lm_head_key=getattr(self.args, "lm_head_key", "lm_head.weight"),
+                norm_key=getattr(self.args, "norm_key", "model.norm.weight"),
+                load_norm=self._last_hs_prenorm,
                 device="cuda",
                 dtype=torch.bfloat16,
                 trust_remote_code=getattr(self.args, "trust_remote_code", True),
             )
-            logger.info(f"[Rank 0] TargetLMHead loaded from {target_model_path}")
+            logger.info(
+                f"[Rank 0] TargetLMHead loaded from {target_model_path}"
+                f"{' (with verifier norm)' if self._last_hs_prenorm else ''}"
+            )
         else:
             from transformers import AutoConfig
 
@@ -198,6 +206,8 @@ def _init_target_lm_head(self, target_model_path: str) -> None:
                 trust_remote_code=getattr(self.args, "trust_remote_code", True),
             )
             self.target_lm_head = TargetLMHead(config)
+            if self._last_hs_prenorm:
+                self.target_lm_head._init_norm_structure()
             self.target_lm_head.to(device="cuda", dtype=torch.bfloat16)
             self.target_lm_head.eval()
             self.target_lm_head.requires_grad_(False)
@@ -217,6 +227,10 @@ def _forward(self, batch: dict) -> Tuple[List[torch.Tensor], List[torch.Tensor]]
         input_ids = padding(batch["input_ids"], left=False).cuda()
         target_hidden_states = padding(batch["last_hidden_states"], left=False).cuda()
 
+        if self.verifier_norm is not None:
+            with torch.no_grad():
+                target_hidden_states = self.verifier_norm(target_hidden_states)
+
         loss_mask = batch["loss_mask"]
         if loss_mask.dim() == 3:
             loss_mask = loss_mask.squeeze(-1)
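Taken together, the trainer-side hunks mean the verifier-logits path is: pad the stored last_hidden_states, run them through the frozen verifier norm only when `last_hidden_states_prenorm` is set, then project with the frozen target lm_head. A rough, self-contained sketch of that flow; the `padding` helper and the batch layout are taken from the surrounding trainer code and are assumptions here, not part of the patch.

```python
import torch

@torch.no_grad()
def verifier_logits(trainer, batch: dict) -> torch.Tensor:
    """Hypothetical condensation of the _forward changes in this diff."""
    # batch["last_hidden_states"]: per-sample [seq_len, hidden_size] tensors,
    # right-padded into a single [batch, seq_len, hidden_size] tensor.
    # `padding` is the trainer's right-padding helper (assumed available).
    target_hidden_states = padding(batch["last_hidden_states"], left=False).cuda()

    # When the inference engine shipped pre-norm activations, apply the final
    # RMSNorm that TargetLMHead.from_pretrained(load_norm=True) loaded.
    if trainer.verifier_norm is not None:
        target_hidden_states = trainer.verifier_norm(target_hidden_states)

    # Frozen lm_head projection gives the target-model logits used as labels.
    return trainer.target_lm_head(target_hidden_states)
```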