Add yarn option for qwen2 #1332

Open · wants to merge 1 commit into main
28 changes: 7 additions & 21 deletions llms/mlx_lm/models/qwen2.py
@@ -7,6 +7,7 @@
 import mlx.nn as nn

 from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .rope_utils import initialize_rope


 @dataclass
@@ -18,24 +19,13 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: Optional[int] = None
+    num_key_value_heads: int
+    max_position_embeddings: int = 32768
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
     tie_word_embeddings: bool = True

-    def __post_init__(self):
-        if self.num_key_value_heads is None:
-            self.num_key_value_heads = self.num_attention_heads
-
-        if self.rope_scaling:
-            required_keys = {"factor", "type"}
-            if not all(key in self.rope_scaling for key in required_keys):
-                raise ValueError(f"rope_scaling must contain keys {required_keys}")
-
-            if self.rope_scaling["type"] != "linear":
-                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
-

 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
@@ -54,16 +44,12 @@ def __init__(self, args: ModelArgs):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

-        rope_scale = (
-            1 / args.rope_scaling["factor"]
-            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
-            else 1
-        )
-        self.rope = nn.RoPE(
+        self.rope = initialize_rope(
             head_dim,
-            traditional=args.rope_traditional,
             base=args.rope_theta,
-            scale=rope_scale,
+            traditional=args.rope_traditional,
+            scaling_config=args.rope_scaling,
+            max_position_embeddings=args.max_position_embeddings,
         )

     def __call__(
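Note: with `__post_init__` removed, qwen2 no longer restricts `rope_scaling` to `"type": "linear"`; whatever dict the checkpoint config provides is handed to `initialize_rope`. A minimal sketch of what this enables, assuming the package is importable as `mlx_lm` and that the dict uses the same `"type"`/`"factor"` keys the old validation expected; all sizes and scaling values below are made up for illustration:

```python
# Hypothetical config, for illustration only: a rope_scaling entry that the old
# qwen2.ModelArgs.__post_init__ would have rejected ("type" != "linear"), but
# which the refactored Attention now forwards to initialize_rope.
from mlx_lm.models.qwen2 import ModelArgs

args = ModelArgs(
    model_type="qwen2",
    hidden_size=1024,              # illustrative sizes, not a real checkpoint
    num_hidden_layers=4,
    intermediate_size=2048,
    num_attention_heads=8,
    num_key_value_heads=8,         # now required explicitly; the None fallback is gone
    rms_norm_eps=1e-5,
    vocab_size=32_000,
    max_position_embeddings=131072,
    rope_scaling={"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768},
)
```

Inside `Attention`, `initialize_rope` dispatches on the scaling type and, for `"yarn"`, returns the `YarnRoPE` defined below.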
93 changes: 93 additions & 0 deletions llms/mlx_lm/models/rope_utils.py
@@ -1,5 +1,6 @@
 # Copyright © 2023-2024 Apple Inc.

+import math
 from typing import Optional

 import mlx.core as mx
@@ -61,6 +62,78 @@ def __call__(self, x, offset: int = 0):
         )


+class YarnRoPE(nn.Module):
+    def __init__(
+        self,
+        dims,
+        traditional=False,
+        max_position_embeddings=2048,
+        base=10000,
+        scaling_factor=1.0,
+        original_max_position_embeddings=4096,
+        beta_fast=32,
+        beta_slow=1,
+        mscale=1,
+        mscale_all_dim=0,
+    ):
+        super().__init__()
+
+        def yarn_find_correction_dim(num_rotations):
+            return (
+                dims
+                * math.log(
+                    original_max_position_embeddings / (num_rotations * 2 * math.pi)
+                )
+            ) / (2 * math.log(base))
+
+        def yarn_find_correction_range():
+            low = math.floor(yarn_find_correction_dim(beta_fast))
+            high = math.ceil(yarn_find_correction_dim(beta_slow))
+            return max(low, 0), min(high, dims - 1)
+
+        def yarn_get_mscale(scale=1, mscale=1):
+            if scale <= 1:
+                return 1.0
+            return 0.1 * mscale * math.log(scale) + 1.0
+
+        def yarn_linear_ramp_mask(min_val, max_val, dim):
+            if min_val == max_val:
+                max_val += 0.001  # Prevent singularity
+
+            linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (
+                max_val - min_val
+            )
+            return mx.clip(linear_func, 0, 1)
+
+        self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
+            scaling_factor, mscale_all_dim
+        )
+        freq_extra = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
+        freq_inter = scaling_factor * base ** (
+            mx.arange(0, dims, 2, dtype=mx.float32) / dims
+        )
+        low, high = yarn_find_correction_range()
+        freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dims // 2)
+        self._freqs = (freq_inter * freq_extra) / (
+            freq_inter * freq_mask + freq_extra * (1 - freq_mask)
+        )
+        self.dims = dims
+        self.traditional = traditional
+
+    def __call__(self, x, offset=0):
+        if self.mscale != 1.0:
+            x[..., : self.dims] = self.mscale * x[..., : self.dims]
+        return mx.fast.rope(
+            x,
+            self.dims,
+            traditional=self.traditional,
+            base=None,
+            scale=1.0,
+            offset=offset,
+            freqs=self._freqs,
+        )
+
+
 def initialize_rope(
     dims,
     base,
@@ -87,5 +160,25 @@ def initialize_rope(
             base=base,
             scaling_config=scaling_config,
         )
+    elif rope_type == "yarn":
+        scaling_factor = scaling_config["factor"]
+        rope_kwargs = {
+            key: scaling_config[key]
+            for key in [
+                "original_max_position_embeddings",
+                "beta_fast",
+                "beta_slow",
+                "mscale",
+                "mscale_all_dim",
+            ]
+            if key in scaling_config
+        }
+        return YarnRoPE(
+            dims=dims,
+            max_position_embeddings=max_position_embeddings,
+            traditional=traditional,
+            base=base,
+            scaling_factor=scaling_factor,
+            **rope_kwargs,
+        )
     else:
         raise ValueError(f"Unsupported RoPE type {rope_type}")
1 change: 1 addition & 0 deletions llms/tests/test_models.py
@@ -336,6 +336,7 @@ def test_qwen2(self):
             num_hidden_layers=4,
             intermediate_size=2048,
             num_attention_heads=4,
+            num_key_value_heads=4,
             rms_norm_eps=1e-5,
             vocab_size=10_000,
         )
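The updated test only exercises the default rope configuration. A possible follow-up test for the yarn branch, sketched in the same style; the values are arbitrary and the `model_test_runner` helper is assumed from the existing test class:

```python
# Hypothetical addition to test_models.py, not part of this PR.
def test_qwen2_yarn(self):
    from mlx_lm.models import qwen2

    args = qwen2.ModelArgs(
        model_type="qwen2",
        hidden_size=128,
        num_hidden_layers=2,
        intermediate_size=256,
        num_attention_heads=4,
        num_key_value_heads=2,
        rms_norm_eps=1e-5,
        vocab_size=1_000,
        max_position_embeddings=131072,
        rope_scaling={
            "type": "yarn",
            "factor": 4.0,
            "original_max_position_embeddings": 32768,
        },
    )
    model = qwen2.Model(args)
    self.model_test_runner(
        model, args.model_type, args.vocab_size, args.num_hidden_layers
    )
```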