diff --git a/configs/draft_models/kimi_k25_eagle3.json b/configs/draft_models/kimi_k25_eagle3.json
index b61fd04..d46e79f 100644
--- a/configs/draft_models/kimi_k25_eagle3.json
+++ b/configs/draft_models/kimi_k25_eagle3.json
@@ -18,8 +18,16 @@
   "num_hidden_layers": 1,
   "num_key_value_heads": 64,
   "rms_norm_eps": 1e-06,
-  "rope_scaling": null,
-  "rope_theta": 1000000,
+  "rope_scaling": {
+    "beta_fast": 32.0,
+    "beta_slow": 1.0,
+    "factor": 64.0,
+    "mscale": 1.0,
+    "mscale_all_dim": 1.0,
+    "original_max_position_embeddings": 4096,
+    "type": "yarn"
+  },
+  "rope_theta": 50000.0,
   "sliding_window": null,
   "tie_word_embeddings": false,
   "torch_dtype": "bfloat16",
diff --git a/tests/test_draft_model_config.py b/tests/test_draft_model_config.py
new file mode 100644
index 0000000..7504829
--- /dev/null
+++ b/tests/test_draft_model_config.py
@@ -0,0 +1,128 @@
+from types import SimpleNamespace
+
+from torchspec.config.utils import generate_draft_model_config
+
+
+def test_generate_draft_model_config_preserves_rope_fields(monkeypatch):
+    rope_scaling = {
+        "type": "yarn",
+        "factor": 64.0,
+        "original_max_position_embeddings": 4096,
+        "beta_fast": 32.0,
+        "beta_slow": 1.0,
+        "mscale": 1.0,
+        "mscale_all_dim": 1.0,
+    }
+    text_config = SimpleNamespace(
+        vocab_size=32000,
+        hidden_size=4096,
+        num_attention_heads=32,
+        num_key_value_heads=8,
+        intermediate_size=14336,
+        max_position_embeddings=262144,
+        rope_theta=5000000,
+        rope_scaling=rope_scaling,
+        rms_norm_eps=1e-6,
+        hidden_act="silu",
+        bos_token_id=1,
+        eos_token_id=2,
+        torch_dtype="bfloat16",
+    )
+    target_config = SimpleNamespace(text_config=text_config)
+
+    class DummyTokenizer:
+        def __len__(self):
+            return 32000
+
+    monkeypatch.setattr(
+        "torchspec.config.utils.AutoConfig.from_pretrained",
+        lambda *args, **kwargs: target_config,
+    )
+    monkeypatch.setattr(
+        "torchspec.config.utils.AutoTokenizer.from_pretrained",
+        lambda *args, **kwargs: DummyTokenizer(),
+    )
+
+    draft_config = generate_draft_model_config("dummy-model")
+
+    assert draft_config["max_position_embeddings"] == 262144
+    assert draft_config["rope_theta"] == 5000000
+    assert draft_config["rope_scaling"] == rope_scaling
+
+
+def test_generate_draft_model_config_copies_rope_scaling(monkeypatch):
+    rope_scaling = {"type": "yarn", "factor": 8.0, "original_max_position_embeddings": 8192}
+    text_config = SimpleNamespace(
+        vocab_size=32000,
+        hidden_size=2048,
+        num_attention_heads=16,
+        num_key_value_heads=4,
+        intermediate_size=8192,
+        max_position_embeddings=65536,
+        rope_theta=1000000,
+        rope_scaling=rope_scaling,
+        rms_norm_eps=1e-6,
+        hidden_act="silu",
+        bos_token_id=1,
+        eos_token_id=2,
+        torch_dtype="bfloat16",
+    )
+    target_config = SimpleNamespace(text_config=text_config)
+
+    class DummyTokenizer:
+        def __len__(self):
+            return 32000
+
+    monkeypatch.setattr(
+        "torchspec.config.utils.AutoConfig.from_pretrained",
+        lambda *args, **kwargs: target_config,
+    )
+    monkeypatch.setattr(
+        "torchspec.config.utils.AutoTokenizer.from_pretrained",
+        lambda *args, **kwargs: DummyTokenizer(),
+    )
+
+    draft_config = generate_draft_model_config("dummy-model")
+    rope_scaling["factor"] = 999.0
+
+    assert draft_config["rope_scaling"]["factor"] == 8.0
+
+
+def test_generate_draft_model_config_fills_yarn_defaults(monkeypatch):
+    rope_scaling = {"type": "yarn", "factor": 8.0, "original_max_position_embeddings": 8192}
+    text_config = SimpleNamespace(
+        vocab_size=32000,
+        hidden_size=2048,
+        num_attention_heads=16,
+        num_key_value_heads=4,
+        intermediate_size=8192,
+        max_position_embeddings=65536,
+        rope_theta=1000000,
+        rope_scaling=rope_scaling,
+        rms_norm_eps=1e-6,
+        hidden_act="silu",
+        bos_token_id=1,
+        eos_token_id=2,
+        torch_dtype="bfloat16",
+    )
+    target_config = SimpleNamespace(text_config=text_config)
+
+    class DummyTokenizer:
+        def __len__(self):
+            return 32000
+
+    monkeypatch.setattr(
+        "torchspec.config.utils.AutoConfig.from_pretrained",
+        lambda *args, **kwargs: target_config,
+    )
+    monkeypatch.setattr(
+        "torchspec.config.utils.AutoTokenizer.from_pretrained",
+        lambda *args, **kwargs: DummyTokenizer(),
+    )
+
+    draft_config = generate_draft_model_config("dummy-model")
+
+    assert draft_config["rope_scaling"]["beta_fast"] == 32.0
+    assert draft_config["rope_scaling"]["beta_slow"] == 1.0
+    assert draft_config["rope_scaling"]["mscale"] == 1.0
+    assert draft_config["rope_scaling"]["mscale_all_dim"] == 0.0
diff --git a/tests/test_eagle3_loss.py b/tests/test_eagle3_loss.py
index ab0feb0..aa148b6 100644
--- a/tests/test_eagle3_loss.py
+++ b/tests/test_eagle3_loss.py
@@ -294,6 +294,42 @@ def test_losses_match_cuda(self):
         )
 
 
+class TestRotaryConfigWiring(unittest.TestCase):
+    """Model config should fully wire RoPE settings into rotary embeddings."""
+
+    def test_yarn_uses_rope_theta_as_base(self):
+        config = LlamaConfig(
+            hidden_size=128,
+            num_attention_heads=4,
+            num_key_value_heads=4,
+            intermediate_size=512,
+            max_position_embeddings=262144,
+            vocab_size=256,
+            hidden_act="silu",
+            rms_norm_eps=1e-6,
+            rope_theta=50000.0,
+            rope_scaling={
+                "type": "yarn",
+                "factor": 64.0,
+                "original_max_position_embeddings": 4096,
+                "beta_fast": 32.0,
+                "beta_slow": 1.0,
+                "mscale": 1.0,
+                "mscale_all_dim": 1.0,
+            },
+            pretraining_tp=1,
+            pad_token_id=0,
+        )
+        config.draft_vocab_size = 256
+
+        model = LlamaForCausalLMEagle3(config, attention_backend="sdpa")
+        rotary = model.midlayer.self_attn.rotary_emb
+
+        self.assertEqual(rotary.base, 50000.0)
+        self.assertEqual(rotary.original_max_position_embeddings, 4096)
+        self.assertEqual(rotary.scaling_factor, 64.0)
+
+
 def _make_mask_patterns(BT):
     """Return (name, valid_idx) pairs covering diverse masking patterns."""
     patterns = []
diff --git a/torchspec/config/utils.py b/torchspec/config/utils.py
index 2a4d030..1c9b618 100644
--- a/torchspec/config/utils.py
+++ b/torchspec/config/utils.py
@@ -18,6 +18,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
+import copy
 import json
 import logging
 import warnings
@@ -28,6 +29,35 @@
 logger = logging.getLogger(__name__)
 
 
+def _copy_config_value(value):
+    if hasattr(value, "to_dict"):
+        return value.to_dict()
+    return copy.deepcopy(value)
+
+
+def _normalize_rope_scaling(rope_scaling):
+    if rope_scaling is None:
+        return None
+
+    normalized = _copy_config_value(rope_scaling)
+    if not isinstance(normalized, dict):
+        return normalized
+
+    scaling_type = normalized.get("rope_type", normalized.get("type"))
+    if scaling_type == "yarn":
+        yarn_defaults = {
+            "beta_fast": 32.0,
+            "beta_slow": 1.0,
+            "mscale": 1.0,
+            "mscale_all_dim": 0.0,
+        }
+        for key, default in yarn_defaults.items():
+            if normalized.get(key) is None:
+                normalized[key] = default
+
+    return normalized
+
+
 def generate_draft_model_config(
     target_model_path: str, template_config_path: str = None, cache_dir: str = None
 ):
@@ -85,6 +115,8 @@ def generate_draft_model_config(
         "num_key_value_heads": "num_key_value_heads",
         "intermediate_size": "intermediate_size",
         "max_position_embeddings": "max_position_embeddings",
+        "rope_theta": "rope_theta",
+        "rope_scaling": "rope_scaling",
         "rms_norm_eps": "rms_norm_eps",
         "hidden_act": "hidden_act",
         "bos_token_id": "bos_token_id",
@@ -101,6 +133,10 @@ def generate_draft_model_config(
             continue
         if target_param == "torch_dtype" and isinstance(value, torch.dtype):
            value = str(value).replace("torch.", "")
+        else:
+            value = _copy_config_value(value)
+            if target_param == "rope_scaling":
+                value = _normalize_rope_scaling(value)
         draft_config[draft_param] = value
 
     draft_config["num_hidden_layers"] = 1
diff --git a/torchspec/models/draft/llama3_eagle.py b/torchspec/models/draft/llama3_eagle.py
index e8b73f2..adad71a 100644
--- a/torchspec/models/draft/llama3_eagle.py
+++ b/torchspec/models/draft/llama3_eagle.py
@@ -1071,6 +1071,7 @@ def rope_get(key, default=None):
             self.rotary_emb = LlamaYarnRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
+                base=getattr(self.config, "rope_theta", 10000),
                 original_max_position_embeddings=rope_get("original_max_position_embeddings"),
                 scaling_factor=scaling_factor,
                 beta_fast=rope_get("beta_fast"),
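
A minimal sketch of the normalization path this patch adds, assuming torchspec is installed with the patch applied; it imports the private _normalize_rope_scaling helper for illustration only and mirrors test_generate_draft_model_config_fills_yarn_defaults:

    from torchspec.config.utils import _normalize_rope_scaling

    # A partial yarn config, as a target model checkpoint might ship it.
    partial = {"type": "yarn", "factor": 8.0, "original_max_position_embeddings": 8192}
    normalized = _normalize_rope_scaling(partial)

    # Missing yarn knobs are filled with the helper's defaults...
    assert normalized["beta_fast"] == 32.0
    assert normalized["beta_slow"] == 1.0
    assert normalized["mscale"] == 1.0
    assert normalized["mscale_all_dim"] == 0.0
    # ...and the input dict is deep-copied, never mutated in place.
    assert "beta_fast" not in partial

Note also that the new Kimi config is internally consistent: original_max_position_embeddings (4096) times the yarn factor (64.0) gives exactly the 262144 max_position_embeddings exercised in TestRotaryConfigWiring.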