Commit 06a9760

Fix kimi yarn settings for draft model (#54)
1 parent ebf869e commit 06a9760

5 files changed

Lines changed: 211 additions & 2 deletions


configs/draft_models/kimi_k25_eagle3.json

Lines changed: 10 additions & 2 deletions
@@ -18,8 +18,16 @@
   "num_hidden_layers": 1,
   "num_key_value_heads": 64,
   "rms_norm_eps": 1e-06,
-  "rope_scaling": null,
-  "rope_theta": 1000000,
+  "rope_scaling": {
+    "beta_fast": 32.0,
+    "beta_slow": 1.0,
+    "factor": 64.0,
+    "mscale": 1.0,
+    "mscale_all_dim": 1.0,
+    "original_max_position_embeddings": 4096,
+    "type": "yarn"
+  },
+  "rope_theta": 50000.0,
   "sliding_window": null,
   "tie_word_embeddings": false,
   "torch_dtype": "bfloat16",

tests/test_draft_model_config.py

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+from types import SimpleNamespace
+
+from torchspec.config.utils import generate_draft_model_config
+
+
+def test_generate_draft_model_config_preserves_rope_fields(monkeypatch):
+    rope_scaling = {
+        "type": "yarn",
+        "factor": 64.0,
+        "original_max_position_embeddings": 4096,
+        "beta_fast": 32.0,
+        "beta_slow": 1.0,
+        "mscale": 1.0,
+        "mscale_all_dim": 1.0,
+    }
+    text_config = SimpleNamespace(
+        vocab_size=32000,
+        hidden_size=4096,
+        num_attention_heads=32,
+        num_key_value_heads=8,
+        intermediate_size=14336,
+        max_position_embeddings=262144,
+        rope_theta=5000000,
+        rope_scaling=rope_scaling,
+        rms_norm_eps=1e-6,
+        hidden_act="silu",
+        bos_token_id=1,
+        eos_token_id=2,
+        torch_dtype="bfloat16",
+    )
+    target_config = SimpleNamespace(text_config=text_config)
+
+    class DummyTokenizer:
+        def __len__(self):
+            return 32000
+
+    monkeypatch.setattr(
+        "torchspec.config.utils.AutoConfig.from_pretrained",
+        lambda *args, **kwargs: target_config,
+    )
+    monkeypatch.setattr(
+        "torchspec.config.utils.AutoTokenizer.from_pretrained",
+        lambda *args, **kwargs: DummyTokenizer(),
+    )
+
+    draft_config = generate_draft_model_config("dummy-model")
+
+    assert draft_config["max_position_embeddings"] == 262144
+    assert draft_config["rope_theta"] == 5000000
+    assert draft_config["rope_scaling"] == rope_scaling
+
+
+def test_generate_draft_model_config_copies_rope_scaling(monkeypatch):
+    rope_scaling = {"type": "yarn", "factor": 8.0, "original_max_position_embeddings": 8192}
+    text_config = SimpleNamespace(
+        vocab_size=32000,
+        hidden_size=2048,
+        num_attention_heads=16,
+        num_key_value_heads=4,
+        intermediate_size=8192,
+        max_position_embeddings=65536,
+        rope_theta=1000000,
+        rope_scaling=rope_scaling,
+        rms_norm_eps=1e-6,
+        hidden_act="silu",
+        bos_token_id=1,
+        eos_token_id=2,
+        torch_dtype="bfloat16",
+    )
+    target_config = SimpleNamespace(text_config=text_config)
+
+    class DummyTokenizer:
+        def __len__(self):
+            return 32000
+
+    monkeypatch.setattr(
+        "torchspec.config.utils.AutoConfig.from_pretrained",
+        lambda *args, **kwargs: target_config,
+    )
+    monkeypatch.setattr(
+        "torchspec.config.utils.AutoTokenizer.from_pretrained",
+        lambda *args, **kwargs: DummyTokenizer(),
+    )
+
+    draft_config = generate_draft_model_config("dummy-model")
+    rope_scaling["factor"] = 999.0
+
+    assert draft_config["rope_scaling"]["factor"] == 8.0
+
+
+def test_generate_draft_model_config_fills_yarn_defaults(monkeypatch):
+    rope_scaling = {"type": "yarn", "factor": 8.0, "original_max_position_embeddings": 8192}
+    text_config = SimpleNamespace(
+        vocab_size=32000,
+        hidden_size=2048,
+        num_attention_heads=16,
+        num_key_value_heads=4,
+        intermediate_size=8192,
+        max_position_embeddings=65536,
+        rope_theta=1000000,
+        rope_scaling=rope_scaling,
+        rms_norm_eps=1e-6,
+        hidden_act="silu",
+        bos_token_id=1,
+        eos_token_id=2,
+        torch_dtype="bfloat16",
+    )
+    target_config = SimpleNamespace(text_config=text_config)
+
+    class DummyTokenizer:
+        def __len__(self):
+            return 32000
+
+    monkeypatch.setattr(
+        "torchspec.config.utils.AutoConfig.from_pretrained",
+        lambda *args, **kwargs: target_config,
+    )
+    monkeypatch.setattr(
+        "torchspec.config.utils.AutoTokenizer.from_pretrained",
+        lambda *args, **kwargs: DummyTokenizer(),
+    )
+
+    draft_config = generate_draft_model_config("dummy-model")
+
+    assert draft_config["rope_scaling"]["beta_fast"] == 32.0
+    assert draft_config["rope_scaling"]["beta_slow"] == 1.0
+    assert draft_config["rope_scaling"]["mscale"] == 1.0
+    assert draft_config["rope_scaling"]["mscale_all_dim"] == 0.0
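
The second test above is about aliasing: generate_draft_model_config must not hand back a reference to the target config's own rope_scaling dict. A minimal sketch of the difference it guards against (variable names here are illustrative, not from the repo; copy.deepcopy is what the diff's _copy_config_value uses for plain dicts):

import copy

source = {"type": "yarn", "factor": 8.0}
aliased = source                  # later mutations of `source` leak through
isolated = copy.deepcopy(source)  # an independent copy

source["factor"] = 999.0
print(aliased["factor"])   # 999.0 -- the failure mode the test guards against
print(isolated["factor"])  # 8.0   -- the behaviour the test asserts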

tests/test_eagle3_loss.py

Lines changed: 36 additions & 0 deletions
@@ -294,6 +294,42 @@ def test_losses_match_cuda(self):
         )
 
 
+class TestRotaryConfigWiring(unittest.TestCase):
+    """Model config should fully wire RoPE settings into rotary embeddings."""
+
+    def test_yarn_uses_rope_theta_as_base(self):
+        config = LlamaConfig(
+            hidden_size=128,
+            num_attention_heads=4,
+            num_key_value_heads=4,
+            intermediate_size=512,
+            max_position_embeddings=262144,
+            vocab_size=256,
+            hidden_act="silu",
+            rms_norm_eps=1e-6,
+            rope_theta=50000.0,
+            rope_scaling={
+                "type": "yarn",
+                "factor": 64.0,
+                "original_max_position_embeddings": 4096,
+                "beta_fast": 32.0,
+                "beta_slow": 1.0,
+                "mscale": 1.0,
+                "mscale_all_dim": 1.0,
+            },
+            pretraining_tp=1,
+            pad_token_id=0,
+        )
+        config.draft_vocab_size = 256
+
+        model = LlamaForCausalLMEagle3(config, attention_backend="sdpa")
+        rotary = model.midlayer.self_attn.rotary_emb
+
+        self.assertEqual(rotary.base, 50000.0)
+        self.assertEqual(rotary.original_max_position_embeddings, 4096)
+        self.assertEqual(rotary.scaling_factor, 64.0)
+
+
 def _make_mask_patterns(BT):
     """Return (name, valid_idx) pairs covering diverse masking patterns."""
     patterns = []

torchspec/config/utils.py

Lines changed: 36 additions & 0 deletions
@@ -18,6 +18,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
+import copy
 import json
 import logging
 import warnings
@@ -28,6 +29,35 @@
 logger = logging.getLogger(__name__)
 
 
+def _copy_config_value(value):
+    if hasattr(value, "to_dict"):
+        return value.to_dict()
+    return copy.deepcopy(value)
+
+
+def _normalize_rope_scaling(rope_scaling):
+    if rope_scaling is None:
+        return None
+
+    normalized = _copy_config_value(rope_scaling)
+    if not isinstance(normalized, dict):
+        return normalized
+
+    scaling_type = normalized.get("rope_type", normalized.get("type"))
+    if scaling_type == "yarn":
+        yarn_defaults = {
+            "beta_fast": 32.0,
+            "beta_slow": 1.0,
+            "mscale": 1.0,
+            "mscale_all_dim": 0.0,
+        }
+        for key, default in yarn_defaults.items():
+            if normalized.get(key) is None:
+                normalized[key] = default
+
+    return normalized
+
+
 def generate_draft_model_config(
     target_model_path: str, template_config_path: str = None, cache_dir: str = None
 ):
@@ -85,6 +115,8 @@ def generate_draft_model_config(
         "num_key_value_heads": "num_key_value_heads",
         "intermediate_size": "intermediate_size",
         "max_position_embeddings": "max_position_embeddings",
+        "rope_theta": "rope_theta",
+        "rope_scaling": "rope_scaling",
         "rms_norm_eps": "rms_norm_eps",
         "hidden_act": "hidden_act",
         "bos_token_id": "bos_token_id",
@@ -101,6 +133,10 @@ def generate_draft_model_config(
             continue
         if target_param == "torch_dtype" and isinstance(value, torch.dtype):
            value = str(value).replace("torch.", "")
+        else:
+            value = _copy_config_value(value)
+        if target_param == "rope_scaling":
+            value = _normalize_rope_scaling(value)
         draft_config[draft_param] = value
 
     draft_config["num_hidden_layers"] = 1

torchspec/models/draft/llama3_eagle.py

Lines changed: 1 addition & 0 deletions
@@ -1071,6 +1071,7 @@ def rope_get(key, default=None):
             self.rotary_emb = LlamaYarnRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
+                base=getattr(self.config, "rope_theta", 10000),
                 original_max_position_embeddings=rope_get("original_max_position_embeddings"),
                 scaling_factor=scaling_factor,
                 beta_fast=rope_get("beta_fast"),
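
Why passing base matters: the rotary inverse frequencies are a direct function of the base, so without this argument the YaRN embedding would presumably fall back to its default base (10000, the same fallback used in the getattr above) instead of Kimi's rope_theta of 50000. A standalone sketch of that dependence using the standard RoPE frequency formula, not repo code:

import torch

def rope_inv_freq(base: float, head_dim: int = 128) -> torch.Tensor:
    # inv_freq_i = 1 / base^(2i / head_dim), the standard RoPE frequency schedule
    return 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

print(torch.allclose(rope_inv_freq(10000.0), rope_inv_freq(50000.0)))  # False: different bases, different frequencies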
