diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index 381767c42..bb31b4a86 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -7,6 +7,7 @@ import mlx.nn as nn
 import mlx.nn as nn
 
 from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .rope_utils import initialize_rope
 
 
 @dataclass
@@ -18,24 +19,13 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: Optional[int] = None
+    num_key_value_heads: int
+    max_position_embeddings: int = 32768
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
     tie_word_embeddings: bool = True
 
-    def __post_init__(self):
-        if self.num_key_value_heads is None:
-            self.num_key_value_heads = self.num_attention_heads
-
-        if self.rope_scaling:
-            required_keys = {"factor", "type"}
-            if not all(key in self.rope_scaling for key in required_keys):
-                raise ValueError(f"rope_scaling must contain keys {required_keys}")
-
-            if self.rope_scaling["type"] != "linear":
-                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
-
 
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
@@ -54,16 +44,12 @@ def __init__(self, args: ModelArgs):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
 
-        rope_scale = (
-            1 / args.rope_scaling["factor"]
-            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
-            else 1
-        )
-        self.rope = nn.RoPE(
+        self.rope = initialize_rope(
             head_dim,
-            traditional=args.rope_traditional,
             base=args.rope_theta,
-            scale=rope_scale,
+            traditional=args.rope_traditional,
+            scaling_config=args.rope_scaling,
+            max_position_embeddings=args.max_position_embeddings,
         )
 
     def __call__(
diff --git a/llms/mlx_lm/models/rope_utils.py b/llms/mlx_lm/models/rope_utils.py
index d30b432df..5be172623 100644
--- a/llms/mlx_lm/models/rope_utils.py
+++ b/llms/mlx_lm/models/rope_utils.py
@@ -1,5 +1,6 @@
 # Copyright © 2023-2024 Apple Inc.
 
+import math
 from typing import Optional
 
 import mlx.core as mx
@@ -61,6 +62,78 @@ def __call__(self, x, offset: int = 0):
         )
 
 
+class YarnRoPE(nn.Module):
+    def __init__(
+        self,
+        dims,
+        traditional=False,
+        max_position_embeddings=2048,
+        base=10000,
+        scaling_factor=1.0,
+        original_max_position_embeddings=4096,
+        beta_fast=32,
+        beta_slow=1,
+        mscale=1,
+        mscale_all_dim=0,
+    ):
+        super().__init__()
+
+        def yarn_find_correction_dim(num_rotations):
+            return (
+                dims
+                * math.log(
+                    original_max_position_embeddings / (num_rotations * 2 * math.pi)
+                )
+            ) / (2 * math.log(base))
+
+        def yarn_find_correction_range():
+            low = math.floor(yarn_find_correction_dim(beta_fast))
+            high = math.ceil(yarn_find_correction_dim(beta_slow))
+            return max(low, 0), min(high, dims - 1)
+
+        def yarn_get_mscale(scale=1, mscale=1):
+            if scale <= 1:
+                return 1.0
+            return 0.1 * mscale * math.log(scale) + 1.0
+
+        def yarn_linear_ramp_mask(min_val, max_val, dim):
+            if min_val == max_val:
+                max_val += 0.001  # Prevent singularity
+
+            linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (
+                max_val - min_val
+            )
+            return mx.clip(linear_func, 0, 1)
+
+        self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
+            scaling_factor, mscale_all_dim
+        )
+        freq_extra = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
+        freq_inter = scaling_factor * base ** (
+            mx.arange(0, dims, 2, dtype=mx.float32) / dims
+        )
+        low, high = yarn_find_correction_range()
+        freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dims // 2)
+        self._freqs = (freq_inter * freq_extra) / (
+            freq_inter * freq_mask + freq_extra * (1 - freq_mask)
+        )
+        self.dims = dims
+        self.traditional = traditional
+
+    def __call__(self, x, offset=0):
+        if self.mscale != 1.0:
+            x[..., : self.dims] = self.mscale * x[..., : self.dims]
+        return mx.fast.rope(
+            x,
+            self.dims,
+            traditional=self.traditional,
+            base=None,
+            scale=1.0,
+            offset=offset,
+            freqs=self._freqs,
+        )
+
+
 def initialize_rope(
     dims,
     base,
@@ -87,5 +160,26 @@ def initialize_rope(
             base=base,
             scaling_config=scaling_config,
         )
+    elif rope_type == "yarn":
+        scaling_factor = scaling_config["factor"]
+        rope_kwargs = {
+            key: scaling_config[key]
+            for key in [
+                "original_max_position_embeddings",
+                "beta_fast",
+                "beta_slow",
+                "mscale",
+                "mscale_all_dim",
+            ]
+            if key in scaling_config
+        }
+        return YarnRoPE(
+            dims=dims,
+            max_position_embeddings=max_position_embeddings,
+            traditional=traditional,
+            base=base,
+            scaling_factor=scaling_factor,
+            **rope_kwargs,
+        )
     else:
         raise ValueError(f"Unsupported RoPE type {rope_type}")
diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py
index 0c0fc6018..34e584647 100644
--- a/llms/tests/test_models.py
+++ b/llms/tests/test_models.py
@@ -336,6 +336,7 @@ def test_qwen2(self):
             num_hidden_layers=4,
             intermediate_size=2048,
             num_attention_heads=4,
+            num_key_value_heads=4,
             rms_norm_eps=1e-5,
             vocab_size=10_000,
         )
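
For reference, a minimal sketch (not part of the patch) of how the new yarn branch is reached: a Qwen2-style rope_scaling dict flows from ModelArgs through Attention into initialize_rope, which forwards the YaRN-specific keys to YarnRoPE. The head dimension and scaling values below are illustrative, and the use of the "type" key assumes the existing dispatch reads it the same way the old linear path did.

# Illustrative sketch only -- not part of the patch; values are made up.
from mlx_lm.models.rope_utils import initialize_rope

rope = initialize_rope(
    128,  # head_dim
    base=1000000,
    traditional=False,
    scaling_config={
        "type": "yarn",
        "factor": 4.0,
        "original_max_position_embeddings": 32768,
    },
    max_position_embeddings=131072,
)
# rope is a YarnRoPE instance; rope(x, offset=0) applies the scaled rotary
# embedding, mirroring the call site in qwen2.Attention.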