diff --git a/configs/draft_models/kimi_k25_eagle3_mla.json b/configs/draft_models/kimi_k25_eagle3_mla.json new file mode 100644 index 0000000..6f2440c --- /dev/null +++ b/configs/draft_models/kimi_k25_eagle3_mla.json @@ -0,0 +1,41 @@ +{ + "architectures": [ + "Eagle3DeepseekV2ForCausalLM" + ], + "model_type": "kimi_k2", + "hidden_size": 7168, + "intermediate_size": 18432, + "num_hidden_layers": 1, + "num_attention_heads": 64, + "num_key_value_heads": 64, + "q_lora_rank": 1536, + "kv_lora_rank": 512, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "v_head_dim": 128, + "hidden_act": "silu", + "rms_norm_eps": 1e-05, + "vocab_size": 163840, + "draft_vocab_size": 163840, + "torch_dtype": "bfloat16", + "rope_theta": 50000.0, + "rope_scaling": { + "beta_fast": 1.0, + "beta_slow": 1.0, + "factor": 64.0, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 4096, + "type": "yarn" + }, + "eagle_config": { + "eagle_aux_hidden_state_layer_ids": [1, 29, 57], + "use_aux_hidden_state": true, + "use_input_layernorm_in_first_layer": true, + "use_last_layernorm": true, + "use_mtp_layernorm": false + }, + "bos_token_id": 163584, + "eos_token_id": 163585, + "pad_token_id": 0 +} diff --git a/configs/draft_models/qwen3_8b_eagle3_mla.json b/configs/draft_models/qwen3_8b_eagle3_mla.json new file mode 100644 index 0000000..ba29701 --- /dev/null +++ b/configs/draft_models/qwen3_8b_eagle3_mla.json @@ -0,0 +1,35 @@ +{ + "architectures": [ + "DeepSeekForCausalLMEagle3" + ], + "model_type": "deepseek_v3", + "hidden_size": 4096, + "num_attention_heads": 32, + "num_key_value_heads": 32, + "num_hidden_layers": 1, + "intermediate_size": 12288, + "hidden_act": "silu", + "rms_norm_eps": 1e-06, + "q_lora_rank": 1536, + "kv_lora_rank": 512, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "v_head_dim": 128, + "max_position_embeddings": 262144, + "rope_theta": 1000000, + "rope_scaling": { + "type": "yarn", + "factor": 64.0, + "original_max_position_embeddings": 
4096, + "beta_fast": 1.0, + "beta_slow": 1.0, + "mscale": 1.0, + "mscale_all_dim": 1.0 + }, + "vocab_size": 151936, + "tie_word_embeddings": false, + "pretraining_tp": 1, + "bos_token_id": 151643, + "eos_token_id": 151645, + "pad_token_id": 0 +} diff --git a/configs/sglang_qwen3_8b.yaml b/configs/sglang_qwen3_8b.yaml index feb1652..8e0975b 100644 --- a/configs/sglang_qwen3_8b.yaml +++ b/configs/sglang_qwen3_8b.yaml @@ -15,8 +15,8 @@ model: dataset: train_data_path: ../examples/data/sample_conversations.jsonl - eval_data_path: ../examples/data/eval_conversations.jsonl - eval_interval: 100 + # eval_data_path: ../examples/data/eval_conversations.jsonl + # eval_interval: 100 chat_template: qwen prompt_key: conversations diff --git a/configs/sglang_qwen3_8b_mla_draft.yaml b/configs/sglang_qwen3_8b_mla_draft.yaml new file mode 100644 index 0000000..023bae4 --- /dev/null +++ b/configs/sglang_qwen3_8b_mla_draft.yaml @@ -0,0 +1,61 @@ +# Configuration for Qwen3-8B target model with MLA draft model +# +# Same as sglang_qwen3_8b.yaml but uses DeepSeek MLA attention in the draft model. 
+# +# Usage: +# python -m torchspec.train_entry --config configs/sglang_qwen3_8b_mla_draft.yaml +# ./examples/qwen3-8b-single-node/run.sh configs/sglang_qwen3_8b_mla_draft.yaml + +model: + target_model_path: Qwen/Qwen3-8B + draft_model_config: configs/draft_models/qwen3_8b_eagle3_mla.json + trust_remote_code: true + +dataset: + train_data_path: ../examples/data/sample_conversations.jsonl + chat_template: qwen + prompt_key: conversations + +training: + attention_backend: flex_attention + micro_batch_size: 1 + draft_accumulation_steps: 1 + learning_rate: 1e-4 + max_concurrent_batches: 1 + max_grad_norm: 0.5 + max_seq_length: 16384 + num_epochs: 1 + seed: 42 + training_num_gpus_per_node: 2 + training_num_nodes: 1 + ttt_length: 7 + save_per_epoch: true + warmup_ratio: 0.015 + +inference: + inference_engine_type: sgl + inference_num_gpus: 1 + inference_num_gpus_per_engine: 1 + inference_num_gpus_per_node: 1 + max_sample_pool_size: 64 + inference_buffer_threshold: 32 + inference_batch_size: 8 + sglang: + tp_size: 1 + mem_fraction_static: 0.7 + +mooncake: + master_server_address: null + metadata_server: null + protocol: tcp + global_segment_size: 16GB + local_buffer_size: 4GB + +output_dir: ./outputs/qwen3-8b-mla-draft +cache_dir: ./cache/qwen3-8b-mla-draft +model_download_dir: null + +debug: + save_debug_train_data: null + debug_train_only: false + debug_inference_only: false diff --git a/examples/qwen3-8b-single-node/run.sh b/examples/qwen3-8b-single-node/run.sh index 14675bb..6649200 100755 --- a/examples/qwen3-8b-single-node/run.sh +++ b/examples/qwen3-8b-single-node/run.sh @@ -56,12 +56,14 @@ echo "Local IP: $LOCAL_IP" echo "Extra args: $*" echo "==============================================" +# TODO: unify tp_size config across sglang/vllm backends python3 -m torchspec.train_entry \ --config "$CONFIG_FILE" \ training.training_num_gpus_per_node="$TRAIN_GPUS" \ inference.inference_num_gpus="$INFERENCE_GPUS" \ inference.inference_num_gpus_per_engine=2 \ 
inference.inference_num_gpus_per_node="$TOTAL_GPUS" \ + inference.sglang.tp_size=2 \ "$@" echo "==============================================" diff --git a/tests/test_deepseek_eagle.py b/tests/test_deepseek_eagle.py new file mode 100644 index 0000000..df1aa13 --- /dev/null +++ b/tests/test_deepseek_eagle.py @@ -0,0 +1,464 @@ +"""Tests for DeepSeek MLA Eagle3 draft model. + +Verifies that: +1. Forward pass produces correct output shapes (with and without cache). +2. Backward pass computes gradients for all trainable parameters. +3. q_lora_rank=None path works correctly. +4. Config dispatch correctly routes to DeepSeekForCausalLMEagle3. +5. Softmax scale is computed correctly with YaRN mscale. +6. Eagle3Model TTT loop works end-to-end with MLA draft model. +""" + +import math +import unittest + +import torch +import torch.nn.functional as F +from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config + +from torchspec.models.draft.auto import AutoDraftModelConfig, AutoEagle3DraftModel +from torchspec.models.draft.deepseek_eagle import ( + DeepSeekForCausalLMEagle3, + DeepSeekMLAAttention, +) +from torchspec.models.draft.llama3_eagle import yarn_get_mscale +from torchspec.models.eagle3 import ( + Eagle3Model, + PrecomputedTarget, + compute_lazy_target_padded, +) + + +def _make_config( + H=64, + V=256, + draft_V=None, + num_heads=4, + qk_nope=16, + qk_rope=8, + v_head=16, + kv_lora=32, + q_lora=48, + rope_scaling=None, +): + config = DeepseekV3Config( + hidden_size=H, + num_attention_heads=num_heads, + num_key_value_heads=num_heads, + intermediate_size=H * 4, + max_position_embeddings=1024, + vocab_size=V, + hidden_act="silu", + rms_norm_eps=1e-6, + rope_scaling=rope_scaling, + pretraining_tp=1, + pad_token_id=0, + q_lora_rank=q_lora, + kv_lora_rank=kv_lora, + qk_nope_head_dim=qk_nope, + qk_rope_head_dim=qk_rope, + v_head_dim=v_head, + num_hidden_layers=1, + # MoE fields (unused by draft model, use small defaults) + n_routed_experts=1, + 
n_shared_experts=0, + first_k_dense_replace=0, + num_experts_per_tok=1, + ) + config.draft_vocab_size = draft_V or V + config.target_hidden_size = H + return config + + +def _make_model(config, length=3, device="cpu", attention_backend="sdpa"): + draft_model = DeepSeekForCausalLMEagle3(config, attention_backend=attention_backend) + draft_model = draft_model.to(device=device, dtype=torch.bfloat16) + model = Eagle3Model( + draft_model, + length=length, + attention_backend=attention_backend, + ) + model.eval() + return model + + +def _make_batch(B, T, H, V, device="cpu"): + input_ids = torch.randint(0, V, (B, T), device=device) + attention_mask = torch.ones(B, T, dtype=torch.long, device=device) + loss_mask = torch.zeros(B, T, device=device) + loss_mask[:, T // 4 :] = 1.0 + hidden_states = torch.randn(B, T, H * 3, device=device, dtype=torch.bfloat16) + target_hidden_states = torch.randn(B, T, H, device=device, dtype=torch.bfloat16) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "loss_mask": loss_mask, + "hidden_states": hidden_states, + "target_hidden_states": target_hidden_states, + } + + +class TestForwardShapeNoCache(unittest.TestCase): + """backbone() without cache returns correct shapes.""" + + def test_output_shape(self): + torch.manual_seed(42) + H, V, B, T = 64, 256, 2, 16 + config = _make_config(H=H, V=V) + draft = DeepSeekForCausalLMEagle3(config).to(torch.bfloat16) + + input_emb = torch.randn(B, T, H, dtype=torch.bfloat16) + hidden = torch.randn(B, T, H, dtype=torch.bfloat16) + pos_ids = torch.arange(T).unsqueeze(0).expand(B, -1) + + out, ck, cv = draft.backbone( + input_embeds=input_emb, + hidden_states=hidden, + attention_mask=None, + position_ids=pos_ids, + cache_keys=None, + cache_values=None, + use_cache=False, + ) + self.assertEqual(out.shape, (B, T, H)) + self.assertIsNone(ck) + self.assertIsNone(cv) + + +class TestForwardShapeWithCache(unittest.TestCase): + """backbone() with cache returns correct 5D cache shapes.""" + + 
def test_cache_shapes_across_steps(self): + torch.manual_seed(42) + H, V, B, T = 64, 256, 2, 8 + num_heads, qk_nope, qk_rope, v_head = 4, 16, 8, 16 + config = _make_config( + H=H, + V=V, + num_heads=num_heads, + qk_nope=qk_nope, + qk_rope=qk_rope, + v_head=v_head, + ) + draft = DeepSeekForCausalLMEagle3(config).to(torch.bfloat16) + + from torchspec.models.draft.base import prepare_decoder_attention_mask + + input_emb = torch.randn(B, T, H, dtype=torch.bfloat16) + hidden = torch.randn(B, T, H, dtype=torch.bfloat16) + pos_ids = torch.arange(T).unsqueeze(0).expand(B, -1) + attn_mask_base = torch.ones(B, T, dtype=torch.long) + + cache_keys = None + cache_values = None + + for step in range(3): + attn_mask = prepare_decoder_attention_mask(attn_mask_base, (B, T), hidden, 0) + _, cache_keys, cache_values = draft.backbone( + input_embeds=input_emb, + hidden_states=hidden, + attention_mask=attn_mask, + position_ids=pos_ids, + cache_keys=cache_keys, + cache_values=cache_values, + use_cache=True, + ) + + expected_k_shape = (B, num_heads, step + 1, T, qk_nope + qk_rope) + expected_v_shape = (B, num_heads, step + 1, T, v_head) + self.assertEqual(cache_keys.shape, expected_k_shape, f"step {step}") + self.assertEqual(cache_values.shape, expected_v_shape, f"step {step}") + + +class TestBackward(unittest.TestCase): + """All trainable parameters receive gradients.""" + + def test_gradients(self): + torch.manual_seed(42) + H, V, B, T = 64, 256, 2, 8 + config = _make_config(H=H, V=V) + draft = DeepSeekForCausalLMEagle3(config).to(torch.bfloat16) + draft.freeze_embedding() + draft.train() + + input_emb = torch.randn(B, T, H, dtype=torch.bfloat16) + hidden = torch.randn(B, T, H, dtype=torch.bfloat16) + pos_ids = torch.arange(T).unsqueeze(0).expand(B, -1) + + from torchspec.models.draft.base import prepare_decoder_attention_mask + + attn_mask = prepare_decoder_attention_mask( + torch.ones(B, T, dtype=torch.long), (B, T), hidden, 0 + ) + + out, _, _ = draft.backbone( + 
input_embeds=input_emb, + hidden_states=hidden, + attention_mask=attn_mask, + position_ids=pos_ids, + use_cache=True, + ) + loss = out.sum() + loss.backward() + + # Embedding should be frozen + self.assertIsNone(draft.embed_tokens.weight.grad) + + # All params in backbone's forward path (midlayer.*) should have gradients. + # embed_tokens, fc, norm, lm_head are used outside backbone(). + for name, param in draft.named_parameters(): + if "midlayer" not in name: + continue + if param.requires_grad: + self.assertIsNotNone(param.grad, f"Parameter {name} has no gradient") + + +class TestQLoraNone(unittest.TestCase): + """q_lora_rank=None uses direct Q projection.""" + + def test_forward_works(self): + torch.manual_seed(42) + H, V, B, T = 64, 256, 2, 8 + config = _make_config(H=H, V=V, q_lora=None) + draft = DeepSeekForCausalLMEagle3(config).to(torch.bfloat16) + + self.assertFalse(hasattr(draft.midlayer.self_attn, "q_a_proj")) + self.assertTrue(hasattr(draft.midlayer.self_attn, "q_proj")) + + input_emb = torch.randn(B, T, H, dtype=torch.bfloat16) + hidden = torch.randn(B, T, H, dtype=torch.bfloat16) + pos_ids = torch.arange(T).unsqueeze(0).expand(B, -1) + + out, _, _ = draft.backbone( + input_embeds=input_emb, + hidden_states=hidden, + attention_mask=None, + position_ids=pos_ids, + use_cache=False, + ) + self.assertEqual(out.shape, (B, T, H)) + + +class TestConfigDispatch(unittest.TestCase): + """AutoDraftModelConfig + AutoEagle3DraftModel correctly dispatch to DeepSeek.""" + + def test_dispatch(self): + config_dict = { + "architectures": ["DeepSeekForCausalLMEagle3"], + "model_type": "deepseek_v3", + "hidden_size": 64, + "num_attention_heads": 4, + "num_key_value_heads": 4, + "num_hidden_layers": 1, + "intermediate_size": 256, + "hidden_act": "silu", + "rms_norm_eps": 1e-6, + "q_lora_rank": 48, + "kv_lora_rank": 32, + "qk_nope_head_dim": 16, + "qk_rope_head_dim": 8, + "v_head_dim": 16, + "vocab_size": 256, + "n_routed_experts": 1, + "n_shared_experts": 0, + 
"first_k_dense_replace": 0, + "num_experts_per_tok": 1, + } + config = AutoDraftModelConfig.from_dict(config_dict) + self.assertIsInstance(config, DeepseekV3Config) + + model = AutoEagle3DraftModel.from_config(config, torch_dtype=torch.bfloat16) + self.assertIsInstance(model, DeepSeekForCausalLMEagle3) + + +class TestSoftmaxScale(unittest.TestCase): + """Softmax scale computation with YaRN mscale.""" + + def test_yarn_mscale(self): + config = _make_config( + qk_nope=128, + qk_rope=64, + rope_scaling={ + "type": "yarn", + "factor": 40, + "original_max_position_embeddings": 4096, + "beta_fast": 32, + "beta_slow": 1, + "mscale": 1, + "mscale_all_dim": 0.1, + }, + ) + attn = DeepSeekMLAAttention(config) + + mscale = yarn_get_mscale(40, 0.1) + expected = (mscale * mscale) / math.sqrt(128 + 64) + self.assertAlmostEqual(attn.softmax_scale, expected, places=6) + + def test_no_rope_scaling(self): + config = _make_config(qk_nope=16, qk_rope=8, rope_scaling=None) + attn = DeepSeekMLAAttention(config) + expected = 1.0 / math.sqrt(16 + 8) + self.assertAlmostEqual(attn.softmax_scale, expected, places=6) + + +class TestEagle3ModelTTT(unittest.TestCase): + """Eagle3Model TTT loop with MLA draft model: Lazy vs Precomputed target.""" + + def _run_both_paths(self, device="cpu"): + torch.manual_seed(42) + H, V, B, T, length = 64, 256, 1, 32, 3 + + config = _make_config(H=H, V=V) + model = _make_model(config, length=length, device=device) + batch = _make_batch(B, T, H, V, device=device) + + draft_model = model.draft_model + _, lm_head_weight, _ = draft_model.get_lm_head_params() + + with torch.no_grad(): + target_logits = F.linear(batch["target_hidden_states"], lm_head_weight.detach()) + target_p = F.softmax(target_logits.float(), dim=-1) + target_p_padded = F.pad(target_p, (0, 0, 0, length), value=0.0) + + precomputed = PrecomputedTarget(target_p_padded) + with torch.no_grad(): + plosses_pre, _, acces_pre = model( + input_ids=batch["input_ids"], + 
attention_mask=batch["attention_mask"], + target=precomputed, + loss_mask=batch["loss_mask"], + hidden_states=batch["hidden_states"], + ) + + lazy = compute_lazy_target_padded( + batch["target_hidden_states"], + lm_head_weight, + length, + ) + with torch.no_grad(): + plosses_lazy, _, acces_lazy = model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + target=lazy, + loss_mask=batch["loss_mask"], + hidden_states=batch["hidden_states"], + ) + + return plosses_pre, acces_pre, plosses_lazy, acces_lazy + + def test_losses_finite_and_non_negative_cpu(self): + plosses_pre, _, _, _ = self._run_both_paths("cpu") + for i, loss in enumerate(plosses_pre): + self.assertTrue(torch.isfinite(loss), f"Loss {i} is not finite: {loss}") + self.assertGreaterEqual(loss.item(), 0.0, f"Loss {i} is negative: {loss}") + + def test_lazy_matches_precomputed_cpu(self): + plosses_pre, acces_pre, plosses_lazy, acces_lazy = self._run_both_paths("cpu") + for i, (pre, lazy) in enumerate(zip(plosses_pre, plosses_lazy)): + torch.testing.assert_close( + pre, + lazy, + atol=1e-4, + rtol=1e-4, + msg=f"Loss mismatch at position {i}", + ) + for i, (pre, lazy) in enumerate(zip(acces_pre, acces_lazy)): + torch.testing.assert_close( + pre, + lazy, + atol=1e-4, + rtol=1e-4, + msg=f"Accuracy mismatch at position {i}", + ) + + @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available") + def test_losses_match_cuda(self): + plosses_pre, acces_pre, plosses_lazy, acces_lazy = self._run_both_paths("cuda") + for i, (pre, lazy) in enumerate(zip(plosses_pre, plosses_lazy)): + torch.testing.assert_close( + pre, + lazy, + atol=1e-3, + rtol=1e-3, + msg=f"Loss mismatch at position {i}", + ) + + +class TestFlexAttentionTTT(unittest.TestCase): + """Eagle3Model TTT loop with MLA + flex_attention backend.""" + + @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available") + def test_flex_losses_finite_cuda(self): + torch.manual_seed(42) + H, V, B, T, length = 64, 256, 1, 32, 3 + 
config = _make_config(H=H, V=V) + model = _make_model( + config, length=length, device="cuda", attention_backend="flex_attention" + ) + batch = _make_batch(B, T, H, V, device="cuda") + + draft_model = model.draft_model + _, lm_head_weight, _ = draft_model.get_lm_head_params() + + with torch.no_grad(): + target_logits = F.linear(batch["target_hidden_states"], lm_head_weight.detach()) + target_p = F.softmax(target_logits.float(), dim=-1) + target_p_padded = F.pad(target_p, (0, 0, 0, length), value=0.0) + + precomputed = PrecomputedTarget(target_p_padded) + with torch.no_grad(): + plosses, _, acces = model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + target=precomputed, + loss_mask=batch["loss_mask"], + hidden_states=batch["hidden_states"], + ) + + for i, loss in enumerate(plosses): + self.assertTrue(torch.isfinite(loss), f"Loss {i} is not finite: {loss}") + self.assertGreaterEqual(loss.item(), 0.0, f"Loss {i} is negative: {loss}") + + @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available") + def test_flex_matches_sdpa_cuda(self): + """Flex attention and SDPA should produce similar losses.""" + torch.manual_seed(42) + H, V, B, T, length = 64, 256, 1, 32, 3 + config = _make_config(H=H, V=V) + batch = _make_batch(B, T, H, V, device="cuda") + + results = {} + for backend in ("sdpa", "flex_attention"): + torch.manual_seed(42) + model = _make_model(config, length=length, device="cuda", attention_backend=backend) + draft_model = model.draft_model + _, lm_head_weight, _ = draft_model.get_lm_head_params() + + with torch.no_grad(): + target_logits = F.linear(batch["target_hidden_states"], lm_head_weight.detach()) + target_p = F.softmax(target_logits.float(), dim=-1) + target_p_padded = F.pad(target_p, (0, 0, 0, length), value=0.0) + + precomputed = PrecomputedTarget(target_p_padded) + with torch.no_grad(): + plosses, _, _ = model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + target=precomputed, + 
loss_mask=batch["loss_mask"], + hidden_states=batch["hidden_states"], + ) + results[backend] = plosses + + for i, (sdpa_loss, flex_loss) in enumerate(zip(results["sdpa"], results["flex_attention"])): + torch.testing.assert_close( + sdpa_loss, + flex_loss, + atol=1e-2, + rtol=1e-2, + msg=f"SDPA vs Flex loss mismatch at step {i}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/convert_to_hf.py b/tools/convert_to_hf.py index 7942b6e..ce10758 100644 --- a/tools/convert_to_hf.py +++ b/tools/convert_to_hf.py @@ -235,13 +235,30 @@ def _save_with_vocab_pruning( lm_head_key = "lm_head.weight" if lm_head_key in tensors: - target_ids = torch.arange(draft_vocab_size) + d2t - logger.info( - "Trimming lm_head from %d to %d", - tensors[lm_head_key].shape[0], - draft_vocab_size, - ) - tensors[lm_head_key] = tensors[lm_head_key][target_ids] + current_size = tensors[lm_head_key].shape[0] + if current_size == vocab_size and current_size != draft_vocab_size: + # lm_head is full vocab — prune to draft vocab using d2t mapping + logger.info( + "Trimming lm_head from %d to %d using d2t mapping", + current_size, + draft_vocab_size, + ) + tensors[lm_head_key] = tensors[lm_head_key][d2t] + elif current_size == draft_vocab_size: + raise ValueError( + f"lm_head is already pruned to draft_vocab_size ({current_size}). " + f"This model was trained with vocabulary pruning, so the lm_head weight " + f"ordering is tied to the training data's token mapping. " + f"Post-training re-pruning with a different dataset is not supported. " + f"Retrain with draft_vocab_size == vocab_size (full vocab) for post-training pruning." 
+ ) + else: + logger.warning( + "lm_head size (%d) matches neither vocab_size (%d) nor draft_vocab_size (%d), skipping trim", + current_size, + vocab_size, + draft_vocab_size, + ) save_file(tensors, os.path.join(output_dir, "model.safetensors")) @@ -434,8 +451,15 @@ def _convert_fsdp_to_hf( with open(config_path) as f: raw_config = json.load(f) vocab_size = raw_config["vocab_size"] + config_draft_vocab_size = raw_config.get("draft_vocab_size") or vocab_size if not prune_vocab: + if config_draft_vocab_size != vocab_size: + raise ValueError( + f"draft_vocab_size ({config_draft_vocab_size}) != vocab_size ({vocab_size}) " + f"in {config_path}. This model was trained with vocabulary pruning. " + f"Use --prune-vocab to generate t2d/d2t token mappings during conversion." + ) _save_without_vocab_pruning(hf_model, output_dir, config_path, vocab_size) return diff --git a/torchspec/config/utils.py b/torchspec/config/utils.py index 1c9b618..630cd4f 100644 --- a/torchspec/config/utils.py +++ b/torchspec/config/utils.py @@ -135,8 +135,8 @@ def generate_draft_model_config( value = str(value).replace("torch.", "") else: value = _copy_config_value(value) - if target_param == "rope_scaling": - value = _normalize_rope_scaling(value) + if target_param == "rope_scaling": + value = _normalize_rope_scaling(value) draft_config[draft_param] = value draft_config["num_hidden_layers"] = 1 diff --git a/torchspec/models/draft/__init__.py b/torchspec/models/draft/__init__.py index 0b5afbe..631fb8f 100644 --- a/torchspec/models/draft/__init__.py +++ b/torchspec/models/draft/__init__.py @@ -20,11 +20,13 @@ from torchspec.models.draft.auto import AutoDraftModelConfig, AutoEagle3DraftModel from torchspec.models.draft.base import Eagle3DraftModel +from torchspec.models.draft.deepseek_eagle import DeepSeekForCausalLMEagle3 from torchspec.models.draft.llama3_eagle import LlamaForCausalLMEagle3 __all__ = [ "AutoDraftModelConfig", "AutoEagle3DraftModel", + "DeepSeekForCausalLMEagle3", "Eagle3DraftModel", 
"LlamaForCausalLMEagle3", ] diff --git a/torchspec/models/draft/auto.py b/torchspec/models/draft/auto.py index 6bb0836..5f20623 100644 --- a/torchspec/models/draft/auto.py +++ b/torchspec/models/draft/auto.py @@ -24,7 +24,9 @@ from transformers import AutoModelForCausalLM as AutoModelForCausalLMBase from transformers import LlamaConfig, PretrainedConfig, modeling_utils +from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config +from torchspec.models.draft.deepseek_eagle import DeepSeekForCausalLMEagle3 from torchspec.models.draft.llama3_eagle import LlamaForCausalLMEagle3 from torchspec.utils.logging import logger @@ -32,6 +34,7 @@ class AutoEagle3DraftModel(AutoModelForCausalLMBase): _model_mapping = { LlamaConfig: LlamaForCausalLMEagle3, + DeepseekV3Config: DeepSeekForCausalLMEagle3, } @classmethod @@ -70,6 +73,8 @@ def filtered_warning(msg): class AutoDraftModelConfig: _config_mapping = { "LlamaForCausalLMEagle3": LlamaConfig, + "DeepSeekForCausalLMEagle3": DeepseekV3Config, + "Eagle3DeepseekV2ForCausalLM": DeepseekV3Config, } @classmethod diff --git a/torchspec/models/draft/deepseek_eagle.py b/torchspec/models/draft/deepseek_eagle.py new file mode 100644 index 0000000..d40dc41 --- /dev/null +++ b/torchspec/models/draft/deepseek_eagle.py @@ -0,0 +1,555 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""DeepSeek MLA (Multi-head Latent Attention) Eagle3 draft model for training.""" + +import math +from functools import partial +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.attention.flex_attention import create_block_mask, flex_attention +from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config + +from torchspec.models.draft.base import Eagle3DraftModel + +# TODO: Extract shared components into a common module to reduce duplication: +# - LlamaMLP, LlamaRMSNorm, RoPE classes → torchspec/models/draft/modules.py +# - _init_rope() is near-identical to LlamaAttention._init_rope (~60 lines) +# - DecoderLayer.forward() is line-for-line identical to LlamaDecoderLayer.forward() +# - embed_input_ids/project_hidden_states/compute_logits/backbone are identical +# to LlamaForCausalLMEagle3 and could live in Eagle3DraftModel base class +# - Suffix attention loop could be batched (einsum instead of Python loop) for +# both this file and llama3_eagle.py +from torchspec.models.draft.llama3_eagle import ( + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaLinearScalingRotaryEmbedding, + LlamaMLP, + LlamaRMSNorm, + LlamaRotaryEmbedding, + LlamaYarnRotaryEmbedding, + apply_rotary_pos_emb, + yarn_get_mscale, +) +from torchspec.models.ops.flex_attention import ( + compile_friendly_create_block_mask, + compile_friendly_flex_attention, + generate_eagle3_mask, +) +from torchspec.utils.logging import logger, 
print_with_rank + + +def _rope_config_get(rope_scaling, key, default=None): + """Get a value from rope_scaling config (dict or object).""" + if isinstance(rope_scaling, dict): + return rope_scaling.get(key, default) + return getattr(rope_scaling, key, default) + + +class DeepSeekMLAAttention(nn.Module): + """MLA attention (SDPA backend) for Eagle3 draft model training. + + Implements the MLA forward path from DeepSeek-V2/V3: + Q: down_proj -> layernorm -> up_proj (optional LoRA compression) + KV: down_proj -> split(compressed, k_rope) -> layernorm -> up_proj -> split(k_nope, value) + RoPE applied only to qk_rope_head_dim dimensions. + + Supports both cached (EAGLE3 suffix pattern) and non-cached attention paths. + """ + + def __init__(self, config: DeepseekV3Config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.q_lora_rank = config.q_lora_rank + self.kv_lora_rank = config.kv_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.v_head_dim = config.v_head_dim + self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + self.max_position_embeddings = config.max_position_embeddings + + # Eagle3: attention input is cat(input_emb, hidden_states) = hidden_size * 2 + input_dim = self.hidden_size * 2 + + # Q path + if self.q_lora_rank is not None: + self.q_a_proj = nn.Linear(input_dim, self.q_lora_rank, bias=False) + self.q_a_layernorm = LlamaRMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = nn.Linear( + self.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False + ) + else: + self.q_proj = nn.Linear(input_dim, self.num_heads * self.qk_head_dim, bias=False) + + # KV path + self.kv_a_proj_with_mqa = nn.Linear( + input_dim, self.kv_lora_rank + self.qk_rope_head_dim, bias=False + ) + self.kv_a_layernorm = LlamaRMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) + self.kv_b_proj = nn.Linear( + 
self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + # Output projection: v_head_dim per head (NOT qk_head_dim) + self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias=False) + + # RoPE on qk_rope_head_dim only + self._init_rope() + + # Softmax scale with optional YaRN mscale + self.softmax_scale = self._compute_softmax_scale() + + def _init_rope(self): + """Initialize rotary embeddings with qk_rope_head_dim as the dimension.""" + rope_dim = self.qk_rope_head_dim + rope_scaling = self.config.rope_scaling + rope_theta = getattr(self.config, "rope_theta", 10000) + + if rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + rope_dim, + max_position_embeddings=self.max_position_embeddings, + base=rope_theta, + ) + else: + rget = partial(_rope_config_get, rope_scaling) + scaling_type = rget("rope_type", rget("type")) + scaling_factor = rget("factor") + + if scaling_type in (None, "default"): + # DeepseekV3Config may set rope_scaling={"rope_type": "default"} + # when no explicit scaling is provided. Treat as standard RoPE. 
            # Fallthrough branch of the RoPE-type dispatch (the opening `if`
            # is above this chunk): plain rotary embedding, no scaling.
            self.rotary_emb = LlamaRotaryEmbedding(
                rope_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=rope_theta,
            )
        elif scaling_type == "linear":
            # Linear position-interpolation scaling.
            self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                rope_dim,
                max_position_embeddings=self.max_position_embeddings,
                scaling_factor=scaling_factor,
            )
        elif scaling_type == "dynamic":
            # Dynamic NTK-aware scaling.
            self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                rope_dim,
                max_position_embeddings=self.max_position_embeddings,
                scaling_factor=scaling_factor,
            )
        elif scaling_type == "llama3":
            # Llama-3 style frequency-banded scaling; extra knobs come from
            # the rope_scaling config via `rget`.
            self.rotary_emb = LlamaRotaryEmbedding(
                rope_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=rope_theta,
                scaling_factor=scaling_factor if scaling_factor is not None else 1.0,
                low_freq_factor=rget("low_freq_factor"),
                high_freq_factor=rget("high_freq_factor"),
                orig_max_position=rget("original_max_position_embeddings"),
            )
        elif scaling_type == "yarn":
            # YaRN scaling (used by the DeepSeek/Kimi draft-model configs in
            # this repo); mscale parameters also feed _compute_softmax_scale.
            self.rotary_emb = LlamaYarnRotaryEmbedding(
                rope_dim,
                max_position_embeddings=self.max_position_embeddings,
                original_max_position_embeddings=rget("original_max_position_embeddings"),
                scaling_factor=scaling_factor,
                beta_fast=rget("beta_fast"),
                beta_slow=rget("beta_slow"),
                mscale=rget("mscale"),
                mscale_all_dim=rget("mscale_all_dim"),
            )
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _compute_softmax_scale(self) -> float:
        """Compute softmax scale, incorporating YaRN mscale if applicable.

        For YaRN-scaled RoPE the attention logits are rescaled by mscale**2
        on top of the usual 1/sqrt(qk_head_dim); for every other (or absent)
        scaling type the plain 1/sqrt(qk_head_dim) is returned.

        Returns:
            The scalar multiplier applied to Q·K^T logits.
        """
        rope_scaling = self.config.rope_scaling
        if rope_scaling is not None:
            rget = partial(_rope_config_get, rope_scaling)
            # Accept both the newer "rope_type" key and the legacy "type" key.
            scaling_type = rget("rope_type", rget("type"))
            if scaling_type == "yarn":
                factor = rget("factor", 1.0)
                mscale_all_dim = rget("mscale_all_dim", 0)
                mscale = yarn_get_mscale(factor, mscale_all_dim)
                return (mscale * mscale) / math.sqrt(self.qk_head_dim)
        return 1.0 / math.sqrt(self.qk_head_dim)

    def _project_qkv(
        self, hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project hidden_states to Q, K_nope, K_rope_raw, V.

        MLA splits keys into a per-head "nope" part and a single shared
        rope part; values are produced by the same low-rank KV up-projection.

        Returns:
            query: [B, H, S, qk_head_dim]
            k_nope: [B, H, S, qk_nope_head_dim]
            k_rope_raw: [B, 1, S, qk_rope_head_dim] (single head, before RoPE)
            value: [B, H, S, v_head_dim]
        """
        bsz, q_len, _ = hidden_states.size()

        # Q projection: low-rank two-step path when q_lora_rank is set,
        # otherwise a single dense projection.
        if self.q_lora_rank is not None:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        else:
            q = self.q_proj(hidden_states)
        q = q.view(bsz, q_len, self.num_heads, self.qk_head_dim).transpose(1, 2)

        # KV down projection + split into compressed latent and raw rope key.
        kv_combined = self.kv_a_proj_with_mqa(hidden_states)
        kv_compressed, k_rope_raw = torch.split(
            kv_combined, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )

        # KV up projection: latent -> per-head (k_nope | value).
        kv = self.kv_b_proj(self.kv_a_layernorm(kv_compressed))
        kv = kv.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, value = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # Reshape: k_nope [B,S,H,D] -> [B,H,S,D], value same
        k_nope = k_nope.transpose(1, 2)
        value = value.transpose(1, 2)

        # k_rope: [B, S, rope_dim] -> [B, 1, S, rope_dim]
        k_rope_raw = k_rope_raw.unsqueeze(1)

        return q, k_nope, k_rope_raw, value

    def _apply_rope_and_assemble(
        self,
        query_states: torch.Tensor,
        k_nope: torch.Tensor,
        k_rope_raw: torch.Tensor,
        position_ids: torch.Tensor,
        lck: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE to rope dims and assemble full Q, K.

        Args:
            query_states: [B, H, S, qk_head_dim]; the last qk_rope_head_dim
                channels receive RoPE.
            k_nope: [B, H, S, qk_nope_head_dim] position-free key part.
            k_rope_raw: [B, 1, S, qk_rope_head_dim] shared key part, pre-RoPE.
            position_ids: base positions; shifted by `lck` for cached steps.
            lck: number of cached EAGLE3 suffix steps (0 on the first pass).

        Returns:
            query_states: [B, H, S, qk_head_dim]
            key_states: [B, H, S, qk_head_dim]
        """
        bsz, num_heads, q_len, _ = query_states.shape

        q_nope = query_states[..., : self.qk_nope_head_dim]
        q_rope = query_states[..., self.qk_nope_head_dim :]

        # seq_len must cover the shifted positions (q_len + lck).
        cos, sin = self.rotary_emb(q_rope, seq_len=q_len + lck)
        cos, sin = cos.to(q_rope.device), sin.to(q_rope.device)
        q_rope, k_rope = apply_rotary_pos_emb(q_rope, k_rope_raw, cos, sin, position_ids + lck)

        # Expand k_rope from [B, 1, S, rope_dim] to [B, H, S, rope_dim]
        k_rope = k_rope.expand(-1, self.num_heads, -1, -1)

        query_states = torch.cat([q_nope, q_rope], dim=-1)
        key_states = torch.cat([k_nope, k_rope], dim=-1)

        return query_states, key_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_keys: Optional[torch.Tensor] = None,
        cache_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """MLA attention forward.

        Without caching this is standard SDPA attention. With `use_cache`,
        keys/values accumulate in a 5D cache indexed by EAGLE3 draft step;
        step 0 gets full causal attention and later steps contribute one
        diagonal (per-position) attention term each.

        Returns:
            (attn_output [B, S, H*v_head_dim via o_proj], cache_keys, cache_values)
        """
        bsz, q_len, _ = hidden_states.size()

        query_states, k_nope, k_rope_raw, value_states = self._project_qkv(hidden_states)

        if not use_cache:
            query_states, key_states = self._apply_rope_and_assemble(
                query_states, k_nope, k_rope_raw, position_ids, lck=0
            )

            # NOTE(review): when attention_mask is None this falls back to
            # SDPA's built-in causal masking.
            attn_output = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                is_causal=attention_mask is None,
                dropout_p=0.0,
                scale=self.softmax_scale,
            )

        else:
            # Cached path with EAGLE3 suffix attention pattern
            # cache_keys shape: [bsz, num_heads, num_cached, seq_len, qk_head_dim]
            # cache_values shape: [bsz, num_heads, num_cached, seq_len, v_head_dim]
            lck = 0 if cache_keys is None else cache_keys.shape[2]

            query_states, key_states = self._apply_rope_and_assemble(
                query_states, k_nope, k_rope_raw, position_ids, lck=lck
            )

            # Append to 5D tensor cache (K and V have different last dims)
            if cache_keys is None:
                cache_keys = key_states.unsqueeze(2)
                cache_values = value_states.unsqueeze(2)
            else:
                cache_keys = torch.cat([cache_keys, key_states.unsqueeze(2)], dim=2)
                cache_values = torch.cat([cache_values, value_states.unsqueeze(2)], dim=2)

            lck = cache_keys.shape[2]
            k0 = cache_keys[:, :, 0]
            v0 = cache_values[:, :, 0]

            # Causal attention on k0
            # NOTE(review): this branch adds attention_mask unconditionally,
            # so the cached path requires an additive mask — confirm callers
            # never pass attention_mask=None with use_cache=True.
            attn_weights = torch.matmul(query_states, k0.transpose(2, 3)) * self.softmax_scale
            attn_weights = attn_weights + attention_mask

            # Suffix diagonal attention on k1..kN: each cached step i only
            # contributes its same-position key, so the score is an
            # elementwise dot product (one extra logit column per step).
            for i in range(1, lck):
                ki = cache_keys[:, :, i]
                attn_weightsi = (query_states * ki).sum(-1) * self.softmax_scale
                attn_weights = torch.cat((attn_weights, attn_weightsi[..., None]), dim=-1)

            # Upcast to fp32 for softmax
            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
                query_states.dtype
            )
            # First q_len columns attend over step-0 values...
            attn_weights0 = attn_weights[..., :q_len]
            attn_output = torch.matmul(attn_weights0, v0)

            # ...and each trailing column scales the matching suffix value.
            for i in range(1, lck):
                vi = cache_values[:, :, i]
                attn_weightsi = attn_weights[..., q_len + i - 1]
                attn_outputi = attn_weightsi[..., None] * vi
                attn_output = attn_output + attn_outputi

        # [B, H, S, v_head_dim] -> [B, S, H*v_head_dim] -> output projection.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
        attn_output = self.o_proj(attn_output)

        return attn_output, cache_keys, cache_values


class DeepSeekMLAFlexAttention(DeepSeekMLAAttention):
    """MLA attention with flex_attention backend.

    Cache is concatenated along the seq dimension (not 5D):
        cache_keys: [B, H, total_seq, qk_head_dim]
        cache_values: [B, H, total_seq, v_head_dim]

    EAGLE3 mask pattern is handled by generate_eagle3_mask + create_block_mask.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_keys: Optional[torch.Tensor] = None,
        cache_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        bsz, q_len, _ = hidden_states.size()

        query_states, k_nope, k_rope_raw, value_states = self._project_qkv(hidden_states)

        # cache_keys shape: [bsz, num_heads, past_seq_len, qk_head_dim] (concatenated along seq)
        past_seen_tokens = cache_keys.shape[2] if cache_keys is not None else 0
        # Assumes every cached EAGLE3 step had length q_len, so the step
        # count is past_seq_len // q_len — TODO confirm with callers.
        lck = past_seen_tokens // q_len

        query_states, key_states = self._apply_rope_and_assemble(
            query_states, k_nope, k_rope_raw, position_ids, lck=lck
        )

        # Concatenate along seq dimension
        if cache_keys is not None:
            key_cache = torch.cat([cache_keys, key_states], dim=2)
            value_cache = torch.cat([cache_values, value_states], dim=2)
        else:
            key_cache = key_states
            value_cache = value_states

        # Build EAGLE3 block mask from attention_mask (seq_lengths)
        # NOTE(review): here attention_mask is treated as a padding mask
        # summed to per-sample lengths — a different contract than the
        # additive mask used by the SDPA backend; verify call sites.
        seq_lengths = attention_mask.sum(dim=-1)
        seq_lengths -= lck

        # Small sequences use the eager flex kernels; larger ones use the
        # compile-friendly wrappers — presumably to amortize torch.compile
        # overhead (TODO confirm the 128 threshold).
        if q_len <= 128:
            create_block_mask_func = create_block_mask
            flex_attention_func = flex_attention
        else:
            create_block_mask_func = compile_friendly_create_block_mask
            flex_attention_func = compile_friendly_flex_attention

        block_mask = create_block_mask_func(
            mask_mod=generate_eagle3_mask(
                seq_lengths=seq_lengths,
                Q_LEN=q_len,
                KV_LEN=key_cache.shape[-2],
                lck=lck,
            ),
            B=bsz,
            H=1,  # Rely on broadcast
            Q_LEN=q_len,
            KV_LEN=key_cache.shape[-2],
            device=query_states.device,
        )

        attn_output = flex_attention_func(
            query=query_states,
            key=key_cache.contiguous(),
            value=value_cache.contiguous(),
            block_mask=block_mask,
            scale=self.softmax_scale,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
        attn_output = self.o_proj(attn_output)

        return attn_output, key_cache, value_cache


class DeepSeekDecoderLayer(nn.Module):
    """Single EAGLE3 decoder layer: MLA self-attention over the concatenation
    of (normalized) input embeddings and target hidden states, then an MLP,
    both with residual connections."""

    def __init__(self, config: DeepseekV3Config, attention_backend: str = "sdpa"):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Select the attention implementation; both share projections/weights
        # layout and differ only in cache format and kernel.
        if attention_backend == "sdpa":
            self.self_attn = DeepSeekMLAAttention(config=config)
        elif attention_backend == "flex_attention":
            print_with_rank("Using flex attention on MLA draft model training!")
            self.self_attn = DeepSeekMLAFlexAttention(config=config)
        else:
            raise ValueError(
                f"DeepSeekDecoderLayer supports 'sdpa' and 'flex_attention' backends, "
                f"got '{attention_backend}'"
            )

        self.attention_backend = attention_backend
        self.mlp = LlamaMLP(config)
        # hidden_norm normalizes the target hidden states; input_layernorm
        # normalizes the token embeddings before the EAGLE3 concat.
        self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_emb: torch.Tensor,
        hidden_states: torch.Tensor,
        cache_keys: Optional[torch.Tensor] = None,
        cache_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Run one decoder layer; returns (hidden_states, cache_keys, cache_values)."""
        residual = hidden_states

        hidden_states = self.hidden_norm(hidden_states)
        input_emb = self.input_layernorm(input_emb)

        # Eagle3: concatenate input embedding and hidden states
        hidden_states = torch.cat((input_emb, hidden_states), dim=-1)

        # Self Attention
        hidden_states, cache_keys, cache_values = self.self_attn(
            hidden_states=hidden_states,
            cache_keys=cache_keys,
            cache_values=cache_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, cache_keys, cache_values


class DeepSeekForCausalLMEagle3(Eagle3DraftModel):
    """Eagle3 draft model using DeepSeek MLA attention."""

    config_class = DeepseekV3Config

    def __init__(self, config: DeepseekV3Config, attention_backend: str = "sdpa") -> None:
        super().__init__(config)

        # Draft vocab may be a subset of the target vocab; fall back to the
        # full vocab when draft_vocab_size is absent or falsy.
        self.target_vocab_size = config.vocab_size
        self.vocab_size = getattr(config, "draft_vocab_size", None) or config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
        self.midlayer = DeepSeekDecoderLayer(config, attention_backend=attention_backend)

        # fc maps the concatenation of 3 target-layer hidden states
        # (eagle_aux_hidden_state_layer_ids) down to the draft hidden size.
        target_hidden_size = getattr(config, "target_hidden_size", config.hidden_size)
        self.fc = nn.Linear(target_hidden_size * 3, config.hidden_size, bias=False)

        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)

        # Vocab-mapping buffers (target->draft mask, draft->target offsets)
        # are only needed when the vocabularies differ.
        if self.vocab_size != self.target_vocab_size:
            self.register_buffer("t2d", torch.ones(self.target_vocab_size, dtype=torch.bool))
            self.register_buffer("d2t", torch.zeros(self.vocab_size, dtype=torch.int64))

        logger.info(
            f"DeepSeekForCausalLMEagle3: hidden_size={config.hidden_size}, "
            f"num_heads={config.num_attention_heads}, "
            f"kv_lora_rank={config.kv_lora_rank}, "
            f"q_lora_rank={config.q_lora_rank}, "
            f"qk_nope={config.qk_nope_head_dim}, qk_rope={config.qk_rope_head_dim}, "
            f"v_head={config.v_head_dim}, "
            f"vocab={self.vocab_size}/{self.target_vocab_size}"
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for the draft model's input ids."""
        return self.embed_tokens(input_ids)

    def project_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project concatenated target hidden states to the draft hidden size.

        Raises:
            ValueError: if the last dim does not match fc's expected input
                (target_hidden_size * 3).
        """
        expected_size = self.fc.in_features
        if hidden_states.size(-1) != expected_size:
            raise ValueError(
                f"Target hidden states size mismatch: {hidden_states.size(-1)} != expected: {expected_size}"
            )
        return self.fc(hidden_states)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Final RMSNorm then lm_head -> draft-vocab logits."""
        norm_hidden_states = self.norm(hidden_states)
        return self.lm_head(norm_hidden_states)

    def backbone(
        self,
        input_embeds: torch.Tensor,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        cache_keys: Optional[torch.Tensor] = None,
        cache_values: Optional[torch.Tensor] = None,
        use_cache: bool = True,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Delegate one forward pass to the single decoder layer."""
        return self.midlayer(
            input_emb=input_embeds,
            hidden_states=hidden_states,
            cache_keys=cache_keys,
            cache_values=cache_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
        )