diff --git a/open_lm/model.py b/open_lm/model.py
index 0c979c40..c89aa9a0 100644
--- a/open_lm/model.py
+++ b/open_lm/model.py
@@ -97,15 +97,16 @@ class Params:
     moe_top_k: int = 2
     moe_freq: int = 0
     positional_embedding_type: str = "rotary"
+    rotary_freq: float = 10000
     ffn_type: str = "swiglu"


 def get_pos_embed(args: Params):
     head_dim = args.dim // args.n_heads
     if args.positional_embedding_type == "rotary":
-        return RotaryWithCast(head_dim, args.seq_len)
+        return RotaryWithCast(head_dim, args.seq_len, args.rotary_freq)
     elif args.positional_embedding_type == "llama_rotary":
-        return LLaMARotaryWithCast(head_dim, args.n_heads, args.seq_len)
+        return LLaMARotaryWithCast(head_dim, args.n_heads, args.seq_len, args.rotary_freq)
     elif args.positional_embedding_type == "head_rotary":
         return HeadRotaryWithCast(head_dim, args.seq_len)
     elif args.positional_embedding_type == "none":
@@ -461,6 +462,7 @@ def create_params(args):
         ),
         apply_qk_norm=cfg.get("qk_norm", args.qk_norm),
         positional_embedding_type=cfg.get("positional_embedding_type", args.positional_embedding_type),
+        rotary_freq=cfg.get("rotary_freq", args.rotary_freq),
         ffn_type=cfg.get("ffn_type", args.ffn_type),
         moe_num_experts=cfg.get("moe_num_experts", args.moe_num_experts),
         moe_loss_weight=cfg.get("moe_loss_weight", args.moe_loss_weight),
diff --git a/open_lm/params.py b/open_lm/params.py
index 0a7a3f64..2a0640fc 100644
--- a/open_lm/params.py
+++ b/open_lm/params.py
@@ -58,6 +58,12 @@ def add_model_args(parser):
         default="rotary",
         help="Type of positional embedding to use. This might be overridden by the model config.",
     )
+    parser.add_argument(
+        "--rotary-freq",
+        type=float,
+        default=10000,
+        help="Frequency for rotary positional embeddings. This might be overridden by the model config.",
+    )
     parser.add_argument(
         "--moe-freq",
         type=int,
diff --git a/open_lm/positional_embedding/llama_rotary.py b/open_lm/positional_embedding/llama_rotary.py
index 776b3cc5..0fd27579 100644
--- a/open_lm/positional_embedding/llama_rotary.py
+++ b/open_lm/positional_embedding/llama_rotary.py
@@ -112,7 +112,7 @@ class LLaMARotaryEmbedding(torch.nn.Module):
     (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
     """

-    def __init__(self, head_dim: int, num_heads: int, seq_len: int, *_, **__):
+    def __init__(self, head_dim: int, num_heads: int, seq_len: int, frequency: float = 10000, *_, **__):
         super().__init__()
         # Generate and save the inverse frequency buffer (non trainable)
         self.freqs_cis = precompute_freqs_cis(
@@ -120,6 +120,7 @@ def __init__(self, head_dim: int, num_heads: int, seq_len: int, *_, **__):
             # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
             head_dim,
             seq_len * 2,
+            theta=frequency,
         )

     def reset_parameters(self):
diff --git a/open_lm/positional_embedding/rotary.py b/open_lm/positional_embedding/rotary.py
index b48ed890..38ab89a9 100644
--- a/open_lm/positional_embedding/rotary.py
+++ b/open_lm/positional_embedding/rotary.py
@@ -44,22 +44,24 @@ class RotaryEmbedding(torch.nn.Module):
     (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
     """

-    def __init__(self, dim_model: int, seq_len: int, *_, **__):
+    def __init__(self, dim_model: int, seq_len: int, frequency: float = 10000, *_, **__):
         super().__init__()
         # Generate and save the inverse frequency buffer (non trainable)
         self.dim_model = dim_model
-        self.register_buffer("inv_freq", torch.zeros(self.dim_model // 2))
         self._cos_cached = None
         self._sin_cached = None
         self._seq_len_cached = 0
         self.seq_len = seq_len
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        self.inv_freq = 1.0 / (10000 ** (torch.arange(0, self.dim_model, 2).float() / self.dim_model))
+        self.frequency = frequency
         self._update_cos_sin_tables(self.seq_len)

+    def load_state_dict(self, state_dict, strict=True):
+        # The state dict is not used, as the parameters are not trainable.
+        # Previous versions registered an inv_freq buffer; we no longer need to load it.
+        # This override is kept for compatibility with checkpoints from those versions.
+        pass
+
     def _update_cos_sin_tables(self, seq_len: int = None, device: torch.device = None, dtype: torch.dtype = None):
         # If no seq_len is provided, use the cached one
         # If the seq_len is smaller than the cached one it is included in the cached one so no need to update
@@ -70,8 +72,9 @@ def _update_cos_sin_tables(self, seq_len: int = None, device: torch.device = Non
         # or if we're on a new device (possibly due to tracing for instance)
         if seq_len > self._seq_len_cached or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
             self._seq_len_cached = seq_len
+            inv_freq = 1.0 / (self.frequency ** (torch.arange(0, self.dim_model, 2).float() / self.dim_model))
             t = torch.arange(seq_len, device=device, dtype=torch.float32)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(dtype))
+            freqs = torch.einsum("i,j->ij", t, inv_freq.to(device=device, dtype=dtype))
             emb = torch.cat((freqs, freqs), dim=-1).to(device)

             self._cos_cached = emb.cos()[None, :, None, :].to(dtype)
diff --git a/tests/assets/rotary1_old.pt b/tests/assets/rotary1_old.pt
new file mode 100644
index 00000000..4761a092
Binary files /dev/null and b/tests/assets/rotary1_old.pt differ
diff --git a/tests/assets/rotary2_old.pt b/tests/assets/rotary2_old.pt
new file mode 100644
index 00000000..6c7da4f8
Binary files /dev/null and b/tests/assets/rotary2_old.pt differ
diff --git a/tests/test_rotary_freq.py b/tests/test_rotary_freq.py
new file mode 100644
index 00000000..e579587d
--- /dev/null
+++ b/tests/test_rotary_freq.py
@@ -0,0 +1,87 @@
+import torch
+import pytest
+from open_lm.positional_embedding.rotary import RotaryEmbedding
+
+
+@pytest.fixture
+def create_rotary_embedding():
+    def _create_rotary_embedding(dim_model, seq_len, frequency):
+        return RotaryEmbedding(dim_model, seq_len, frequency)
+
+    return _create_rotary_embedding
+
+
+def test_frequency_input(create_rotary_embedding):
+    dim_model = 32
+    seq_len = 64
+
+    # Create two rotary embeddings with different frequencies
+    freq1 = 10000
+    freq2 = 20000
+    rotary1 = create_rotary_embedding(dim_model, seq_len, freq1)
+    rotary2 = create_rotary_embedding(dim_model, seq_len, freq2)
+
+    # Generate some dummy data
+    q = torch.randn(1, seq_len, dim_model)
+    k = torch.randn(1, seq_len, dim_model)
+
+    # Forward pass with different frequencies
+    q1, k1 = rotary1(q, k)
+    q2, k2 = rotary2(q, k)
+
+    # Ensure the outputs are different
+    assert not torch.allclose(q1, q), "The outputs should not be close"
+    assert not torch.allclose(k1, k), "The outputs should not be close"
+    assert not torch.allclose(q1, q2), "The outputs for different frequencies should not be close"
+    assert not torch.allclose(k1, k2), "The outputs for different frequencies should not be close"
+
+    # Load the state dicts saved in the old format (with an inv_freq buffer)
+    state_dict1 = torch.load("tests/assets/rotary1_old.pt")
+    state_dict2 = torch.load("tests/assets/rotary2_old.pt")
+
+    # Build new rotary embeddings with exchanged frequencies
+    rotary1_loaded = create_rotary_embedding(dim_model, seq_len, freq2)
+    rotary2_loaded = create_rotary_embedding(dim_model, seq_len, freq1)
+
+    # Forward pass with the new models (before loading the state dicts)
+    q1_loaded, k1_loaded = rotary1_loaded(q, k)
+    q2_loaded, k2_loaded = rotary2_loaded(q, k)
+
+    # Ensure the outputs are the same
+    assert torch.allclose(
+        q1, q2_loaded
+    ), "The outputs should be the same for the same frequencies before loading the state dict"
+    assert torch.allclose(
+        k2, k1_loaded
+    ), "The outputs should be the same for the same frequencies before loading the state dict"
+
+    # Assert old state dict is in the old format
+    assert "inv_freq" in state_dict1, "The old state dict should contain the inv_freq buffer"
+
+    # Load the state dicts
+    rotary1_loaded.load_state_dict(state_dict1, strict=True)
+    rotary2_loaded.load_state_dict(state_dict2, strict=True)
+
+    # Ensure the frequencies are not overwritten
+    assert rotary1_loaded.frequency == freq2, "Frequency should not be overwritten by load_state_dict"
+    assert rotary2_loaded.frequency == freq1, "Frequency should not be overwritten by load_state_dict"
+
+    # Forward pass with loaded models
+    q1_loaded, k1_loaded = rotary1_loaded(q, k)
+    q2_loaded, k2_loaded = rotary2_loaded(q, k)
+
+    # Ensure the outputs are the same
+    assert torch.allclose(
+        q1, q2_loaded
+    ), "The outputs should be the same for the same frequencies after loading the state dict"
+    assert torch.allclose(
+        k2, k1_loaded
+    ), "The outputs should be the same for the same frequencies after loading the state dict"
+
+    # Ensure the outputs are still different
+    assert not torch.allclose(q1_loaded, q2_loaded), "The outputs for different frequencies should not be close"
+    assert not torch.allclose(k1_loaded, k2_loaded), "The outputs for different frequencies should not be close"
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
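Usage sketch (not part of the diff): the snippet below only mirrors the constructor and forward signatures exercised in tests/test_rotary_freq.py. The 500000 base value and the variable names are illustrative, not taken from the patch; a training run would instead set the new --rotary-freq flag, or a "rotary_freq" key in the model config.

    import torch
    from open_lm.positional_embedding.rotary import RotaryEmbedding

    # Build a rotary embedding with a non-default base frequency.
    # 500000 is an arbitrary example value; the default stays 10000.
    rope = RotaryEmbedding(dim_model=32, seq_len=64, frequency=500000)

    # Same query/key shapes as the unit test above.
    q = torch.randn(1, 64, 32)
    k = torch.randn(1, 64, 32)

    # Rotated outputs differ from those produced with the default base.
    q_rot, k_rot = rope(q, k)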