@@ -27,6 +27,7 @@ class Llama2Config:
2727 rms_norm_add = False
2828 mlp_activation = "silu"
2929 qkv_bias = False
30+ rope_dims = None
3031
3132@dataclass
3233class Qwen25_3BConfig :
@@ -44,6 +45,7 @@ class Qwen25_3BConfig:
4445 rms_norm_add = False
4546 mlp_activation = "silu"
4647 qkv_bias = True
48+ rope_dims = None
4749
4850@dataclass
4951class Qwen25_7BVLI_Config :
@@ -61,6 +63,7 @@ class Qwen25_7BVLI_Config:
6163 rms_norm_add = False
6264 mlp_activation = "silu"
6365 qkv_bias = True
66+ rope_dims = [16 , 24 , 24 ]
6467
6568@dataclass
6669class Gemma2_2B_Config :
@@ -78,6 +81,7 @@ class Gemma2_2B_Config:
7881 rms_norm_add = True
7982 mlp_activation = "gelu_pytorch_tanh"
8083 qkv_bias = False
84+ rope_dims = None
8185
8286class RMSNorm (nn .Module ):
8387 def __init__ (self , dim : int , eps : float = 1e-5 , add = False , device = None , dtype = None ):
@@ -102,7 +106,7 @@ def rotate_half(x):
102106 return torch .cat ((- x2 , x1 ), dim = - 1 )
103107
104108
105- def precompute_freqs_cis (head_dim , position_ids , theta , device = None ):
109+ def precompute_freqs_cis (head_dim , position_ids , theta , rope_dims = None , device = None ):
106110 theta_numerator = torch .arange (0 , head_dim , 2 , device = device ).float ()
107111 inv_freq = 1.0 / (theta ** (theta_numerator / head_dim ))
108112
@@ -112,12 +116,20 @@ def precompute_freqs_cis(head_dim, position_ids, theta, device=None):
112116 emb = torch .cat ((freqs , freqs ), dim = - 1 )
113117 cos = emb .cos ()
114118 sin = emb .sin ()
119+ if rope_dims is not None and position_ids .shape [0 ] > 1 :
120+ mrope_section = rope_dims * 2
121+ cos = torch .cat ([m [i % 3 ] for i , m in enumerate (cos .split (mrope_section , dim = - 1 ))], dim = - 1 ).unsqueeze (0 )
122+ sin = torch .cat ([m [i % 3 ] for i , m in enumerate (sin .split (mrope_section , dim = - 1 ))], dim = - 1 ).unsqueeze (0 )
123+ else :
124+ cos = cos .unsqueeze (1 )
125+ sin = sin .unsqueeze (1 )
126+
115127 return (cos , sin )
116128
117129
def apply_rope(xq, xk, freqs_cis):
    """Rotate the query and key tensors with precomputed cos/sin tables.

    Args:
        xq: Query tensor.
        xk: Key tensor.
        freqs_cis: Indexable pair whose element 0 is the cosine table and
            element 1 is the sine table. Both are used exactly as given
            (no reshaping happens here), so they must already broadcast
            against ``xq`` and ``xk``.

    Returns:
        A ``(q_embed, k_embed)`` tuple holding the rotated query and key.
    """
    cos_table = freqs_cis[0]
    sin_table = freqs_cis[1]

    def _rotate(t):
        # Same rotation formula for both inputs: t * cos + rotate_half(t) * sin.
        return (t * cos_table) + (rotate_half(t) * sin_table)

    return _rotate(xq), _rotate(xk)
@@ -292,6 +304,7 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
292304 freqs_cis = precompute_freqs_cis (self .config .head_dim ,
293305 position_ids ,
294306 self .config .rope_theta ,
307+ self .config .rope_dims ,
295308 device = x .device )
296309
297310 mask = None
0 commit comments