from .moe import GptOssMoE


-# Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294
-def precompute_freqs_cis(args: GptOssModelArgs) -> torch.Tensor:
-    """
-    Precomputes frequency-based complex exponential values for rotary positional embeddings.
-
-    Args:
-        args (GptOssModelArgs): Model arguments containing positional embedding parameters.
-
-    Returns:
-        torch.Tensor: Precomputed complex exponential values for positional embeddings.
-    """
-    dim = args.head_dim
-    seqlen = args.max_seq_len
-    beta_fast = args.beta_fast
-    beta_slow = args.beta_slow
-    base = args.rope_theta
-    factor = args.rope_factor
-    original_seq_len = args.original_seq_len
-
-    # YaRN default m-scale (attention_factor). Matches HF when attention_factor is None.
-    mscale = 0.1 * math.log(factor) + 1.0
-
-    def find_correction_dim(
-        num_rotations: float, dim: int, base: float, max_seq_len: int
-    ) -> float:
-        """
-        Computes the correction dimension for a given number of rotations in the rotary positional embedding.
-
-        Args:
-            num_rotations (float): Number of rotations to compute the correction for.
-            dim (int): Dimensionality of the embedding space.
-            base (float): Base value for the exponential computation.
-            max_seq_len (int): Maximum sequence length.
-
-        Returns:
-            float: The correction dimension based on the input parameters.
-        """
-        return (
-            dim
-            * math.log(max_seq_len / (num_rotations * 2 * math.pi))
-            / (2 * math.log(base))
-        )
-
-    def find_correction_range(
-        low_rot: float, high_rot: float, dim: int, base: float, max_seq_len: int
-    ) -> Tuple[int, int]:
-        """
-        Computes the range of correction dimensions for rotary positional embeddings.
-
-        Args:
-            low_rot (float): Lower bound for the number of rotations.
-            high_rot (float): Upper bound for the number of rotations.
-            dim (int): Dimensionality of the embedding space.
-            base (float): Base value for the exponential computation.
-            max_seq_len (int): Maximum sequence length.
+def precompute_rope_cache(
+    dim: int, max_seq_len: int, base: float = 1_000_000.0
+) -> torch.Tensor:
+    freqs = 1.0 / (base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    # Create position indexes `[0, 1, ..., max_seq_len - 1]`
+    t = torch.arange(max_seq_len, dtype=freqs.dtype, device=freqs.device)
-
-        Returns:
-            Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
-        """
-        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
-        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
-        return max(low, 0), min(high, dim - 1)
+
+    # Outer product of theta and position index; output tensor has
+    # a shape of [max_seq_len, dim // 2]
+    idx_theta = torch.outer(t, freqs).float()
-
-    def linear_ramp_factor(min: float, max: float, dim: int) -> torch.Tensor:
-        """
-        Computes a linear ramp function used to smooth values between a minimum and maximum range.
-
-        Args:
-            min (float): Minimum value for the ramp function.
-            max (float): Maximum value for the ramp function.
-            dim (int): Dimensionality of the ramp tensor.
-
-        Returns:
-            torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
-                clamped to the range [0, 1].
-        """
-        if min == max:
-            max += 0.001
-        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
-        ramp_func = torch.clamp(linear_func, 0, 1)
-        return ramp_func
-
-    # Basic RoPE frequency calculation
-    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
-
-    # YaRN scaling for extended context. YaRN is used to extend the context length after pre-training.
-    if seqlen > original_seq_len:
-        low, high = find_correction_range(
-            beta_fast, beta_slow, dim, base, original_seq_len
-        )
-        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
-        freqs = freqs / factor * (1 - smooth) + freqs * smooth
+
+    # We cache the cos and sin embeddings instead of the IDs. This helps
+    # ensure we have correct behavior when training with bf16.
+    # Size: [max_seq_len, (dim * 2)]
+    freqs = torch.cat([idx_theta, idx_theta], dim=-1)
+    rope_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)
+    return rope_cache
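
Note: a quick sanity check on the cache layout produced above (a standalone sketch; the toy sizes head_dim=4 and max_seq_len=8 are arbitrary, and the function body is restated from this diff):

    import torch

    def precompute_rope_cache(dim: int, max_seq_len: int, base: float = 1_000_000.0) -> torch.Tensor:
        # restated from the diff above
        freqs = 1.0 / (base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(max_seq_len, dtype=freqs.dtype)
        idx_theta = torch.outer(t, freqs).float()              # [max_seq_len, dim // 2]
        freqs = torch.cat([idx_theta, idx_theta], dim=-1)      # [max_seq_len, dim]
        return torch.cat([freqs.cos(), freqs.sin()], dim=-1)   # [max_seq_len, dim * 2]

    cache = precompute_rope_cache(dim=4, max_seq_len=8)
    assert cache.shape == (8, 4 * 2)        # [max_seq_len, head_dim * 2]
    assert torch.all(cache[0, :4] == 1.0)   # position 0: cos(0) = 1
    assert torch.all(cache[0, 4:] == 0.0)   # position 0: sin(0) = 0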
-
-    # Create position indices
-    t = torch.arange(seqlen)
-
-    # Outer product: [positions] × [frequencies]
-    freqs = torch.outer(t, freqs)
+
+
+def rotate_half(x: torch.Tensor) -> torch.Tensor:
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
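
Note: rotate_half implements the half-split (GPT-NeoX / Hugging Face style) rotation: for x = [x1, x2] along the last dim it returns [-x2, x1], so x * cos + rotate_half(x) * sin rotates each (x1[i], x2[i]) pair by its angle. A toy check with arbitrarily chosen values:

    import torch

    def rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    x = torch.tensor([1.0, 2.0, 3.0, 4.0])   # halves: [1, 2] and [3, 4]
    assert torch.equal(rotate_half(x), torch.tensor([-3.0, -4.0, 1.0, 2.0]))

    # Rotating by 90 degrees (cos=0, sin=1) maps [x1, x2] -> [-x2, x1]
    cos, sin = torch.tensor(0.0), torch.tensor(1.0)
    out = x * cos + rotate_half(x) * sin
    assert torch.equal(out, torch.tensor([-3.0, -4.0, 1.0, 2.0]))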
-
-    # Convert to complex exponentials: e^(i*freq*pos)
-    freqs_cis = torch.polar(torch.full_like(freqs, fill_value=mscale), freqs)
-
-    return freqs_cis
+
+
+def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+    """
+    Reshape frequency tensor (represented by cos, sin) for broadcasting it with another tensor.
+
+    This function reshapes the frequency tensor to have the same shape as the target tensor `x`
+    for the purpose of broadcasting the frequency tensor during element-wise operations.
+
-
-def apply_rotary_emb_inner(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-    """
-    Applies rotary positional embeddings to the input tensor.
+    The input rope_cache tensor is assumed to be of shape (max_seqlen, head_dim * 2);
+    the first seqlen rows are sliced out, and head_dim must match x.

    Args:
-        x (torch.Tensor): Input tensor with positional embeddings to be applied.
-        freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.
+        rope_cache (torch.Tensor): RoPE tensor (cos and sin) to be reshaped.
+        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
-        torch.Tensor: Tensor with rotary embeddings applied.
+        torch.Tensor: Reshaped frequency tensor.
    """
-    dtype = x.dtype
-    x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
-    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
-    y = torch.view_as_real(x * freqs_cis).flatten(3)
-    return y.to(dtype)
+    ndim = x.ndim
+    assert ndim > 1
+    _, seqlen, _, head_dim = x.shape
+    rope_cache = rope_cache[0:seqlen]
+    # The shape of rope_cache is (seqlen, head_dim * 2) because we concatenate cos and sin
+    assert rope_cache.shape == (seqlen, head_dim * 2)
+    shape = [-1, seqlen, 1, head_dim * 2]
+    return rope_cache.view(*shape)
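
Note: a shape-only sketch of the broadcast this sets up (toy sizes; the random tensor stands in for a real cache):

    import torch

    bsz, seqlen, num_heads, head_dim, max_seq_len = 2, 6, 3, 4, 8
    rope_cache = torch.randn(max_seq_len, head_dim * 2)
    x = torch.randn(bsz, seqlen, num_heads, head_dim)

    # Slice to the current sequence length, then add singleton batch/head dims
    reshaped = rope_cache[0:seqlen].view(-1, seqlen, 1, head_dim * 2)
    assert reshaped.shape == (1, seqlen, 1, head_dim * 2)

    # cos/sin halves broadcast against [bsz, seqlen, num_heads, head_dim]
    cos = reshaped[..., :head_dim]
    assert (x * cos).shape == x.shape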
+
+
+def apply_rotary_emb(
+    xq: torch.Tensor, xk: torch.Tensor, rope_cache: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+    # input tensors have shape [bsz, seq_len, num_heads, head_dim]
+    head_dim = xq.shape[-1]
+
+    # reshape for broadcast
+    rope_cache = reshape_for_broadcast(rope_cache, xq)
+
+    # [bsz, seq_len, 1, head_dim]
+    cos = rope_cache[..., :head_dim].to(dtype=xq.dtype, device=xq.device)
+    sin = rope_cache[..., head_dim:].to(dtype=xq.dtype, device=xq.device)
+
+    # xq: [bsz, seq_len, num_heads, head_dim]
+    # xk: [bsz, seq_len, num_kv_heads, head_dim]
+    xq_out = (xq * cos) + (rotate_half(xq) * sin)
+    xk_out = (xk * cos) + (rotate_half(xk) * sin)
+    return xq_out.type_as(xq), xk_out.type_as(xk)
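
Note: a usage sketch of the new pipeline (assumes precompute_rope_cache, rotate_half, reshape_for_broadcast, and apply_rotary_emb from this diff are in scope; toy sizes, with num_kv_heads < num_heads to exercise the GQA shapes):

    import torch

    bsz, seqlen, num_heads, num_kv_heads, head_dim = 2, 16, 8, 2, 64
    rope_cache = precompute_rope_cache(head_dim, max_seq_len=32)

    xq = torch.randn(bsz, seqlen, num_heads, head_dim)
    xk = torch.randn(bsz, seqlen, num_kv_heads, head_dim)
    xq_out, xk_out = apply_rotary_emb(xq, xk, rope_cache)

    assert xq_out.shape == xq.shape and xk_out.shape == xk.shape
    # A rotation never changes vector length: per-head norms are preserved
    assert torch.allclose(xq_out.norm(dim=-1), xq.norm(dim=-1), atol=1e-4)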
-
-def apply_rotary_emb(q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor):
-    """
-    HF-style inputs (half-split last dim) -> interleave -> Torchtitan complex RoPE -> de-interleave.
-    Shapes:
-        q, k: [B, T, H, D] with D even (HF half-split: first D/2 real, last D/2 imag)
-        freqs_cis: complex, last dim == D/2. Typically [T, D/2] or [1, T, D/2].
-    Returns:
-        q_out, k_out in HF half-split layout (same shape as q, k).
-    """
-    B, T, H, D = q.shape
-    assert D % 2 == 0, "head_dim must be even for RoPE"
-    rot = D // 2
-    assert freqs_cis.shape[-1] == rot, "freqs_cis last dim must be D/2"
-    freqs_cis = freqs_cis[:T, :]
-
-    # Memory layout comparison for head_dim=8:
-    # HF Format:   [r0][r1][r2][r3][i0][i1][i2][i3]
-    #              ↑-- reals --↑   ↑-- imags --↑
-
-    # Interleaved: [r0][i0][r1][i1][r2][i2][r3][i3]
-    #              ↑-pair-↑ ↑-pair-↑ ↑-pair-↑ ↑-pair-↑
-    # --- inline: HF half-split -> interleaved (real0, imag0, real1, imag1, ...)
-    # q_i, k_i: [B, T, H, D]
-    q_i = torch.empty_like(q)
-    k_i = torch.empty_like(k)
-    q_i[..., 0::2] = q[..., :rot]
-    q_i[..., 1::2] = q[..., rot:]
-    k_i[..., 0::2] = k[..., :rot]
-    k_i[..., 1::2] = k[..., rot:]
-
-    # --- Torchtitan default complex apply (expects interleaved last dim)
-    # freqs_cis will be reshaped inside to [1, T, 1, rot]
-    # TODO(jianiw): I think we should go with sin/cos representation to simplify the conversion between paired real/imaginary <-> half-split real/imaginary
-    q_rot_i = apply_rotary_emb_inner(q_i, freqs_cis)  # uses TT's complex path
-    k_rot_i = apply_rotary_emb_inner(k_i, freqs_cis)
-
-    # --- inline: interleaved -> HF half-split
-    # TODO(jianiw): convert it back
-    q_out = torch.cat([q_rot_i[..., 0::2], q_rot_i[..., 1::2]], dim=-1)
-    k_out = torch.cat([k_rot_i[..., 0::2], k_rot_i[..., 1::2]], dim=-1)
-    return q_out, k_out
-
-# Torch Attention backup implementation (for debugging and sampling) from HuggingFace
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
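
Note: the removed wrapper converted HF's half-split layout into torchtitan's interleaved complex layout and back on every call; the new rotate_half/cos-sin path operates on the half-split layout directly, as the removed TODO suggested. A standalone sketch checking that the two formulations agree on toy data (plain RoPE only: magnitude mscale = 1 and no YaRN frequency scaling, which the new cache no longer implements):

    import torch

    B, T, H, D = 1, 5, 1, 8
    rot = D // 2
    base = 10_000.0

    inv = 1.0 / (base ** (torch.arange(0, D, 2).float() / D))  # [D/2]
    angles = torch.outer(torch.arange(T).float(), inv)         # [T, D/2]

    q = torch.randn(B, T, H, D)  # HF half-split layout

    # Old path: half-split -> interleaved -> complex rotate -> half-split
    q_i = torch.empty_like(q)
    q_i[..., 0::2], q_i[..., 1::2] = q[..., :rot], q[..., rot:]
    qc = torch.view_as_complex(q_i.float().view(*q_i.shape[:-1], -1, 2))
    freqs_cis = torch.polar(torch.ones_like(angles), angles).view(1, T, 1, rot)
    q_old_i = torch.view_as_real(qc * freqs_cis).flatten(3)
    q_old = torch.cat([q_old_i[..., 0::2], q_old_i[..., 1::2]], dim=-1)

    # New path: cos/sin applied to the half-split layout directly
    cos = torch.cat([angles, angles], dim=-1).cos().view(1, T, 1, D)
    sin = torch.cat([angles, angles], dim=-1).sin().view(1, T, 1, D)
    q_new = q * cos + torch.cat((-q[..., rot:], q[..., :rot]), dim=-1) * sin

    assert torch.allclose(q_old, q_new, atol=1e-6)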
@@ -215,7 +122,7 @@ def eager_attention_forward(
        add_mask = attention_mask.to(attn_weights.dtype)

        # Truncate to current key length and add (broadcasts if needed)
-        add_mask = add_mask[..., : key_states.shape[-2]]
+        add_mask = add_mask[..., : key.shape[-2]]
        attn_weights = attn_weights + add_mask

    sinks = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
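
Note: the rename fixes what looks like a stale Hugging Face variable name (`key_states`) in a scope whose tensor is `key`. The truncate-and-add itself is a plain slice plus broadcast; a toy-shape sketch (all sizes hypothetical):

    import torch

    bsz, n_heads, q_len, kv_len, max_len = 2, 4, 3, 5, 8
    attn_weights = torch.randn(bsz, n_heads, q_len, kv_len)
    attention_mask = torch.zeros(bsz, 1, q_len, max_len)  # additive mask, 0 = keep

    add_mask = attention_mask[..., :kv_len]  # truncate to current key length
    assert (attn_weights + add_mask).shape == attn_weights.shape  # broadcasts over heads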
@@ -275,14 +182,14 @@ def __init__(self, model_args: GptOssModelArgs, use_sliding_attention: bool = Fa
    def forward(
        self,
        x: torch.Tensor,
-        freqs_cis: torch.Tensor,
+        rope_cache: torch.Tensor,
    ):
        """
        Forward pass for the attention layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
-            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
+            rope_cache (torch.Tensor): Precomputed cosine and sine frequencies for rope embedding.

        Returns:
            torch.Tensor: Output tensor with the same shape as the input.
@@ -294,7 +201,7 @@ def forward(
        k = self.wk(x).view(hidden_shape)
        v = self.wv(x).view(hidden_shape)

-        q, k = apply_rotary_emb(q, k, freqs_cis)
+        q, k = apply_rotary_emb(q, k, rope_cache)

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(k, self.n_rep)
@@ -369,18 +276,18 @@ def __init__(self, layer_id: int, model_args: GptOssModelArgs):
        self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5
        self.layer_id = layer_id

-    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
+    def forward(self, x: torch.Tensor, rope_cache: torch.Tensor):
        """
        Forward pass for the Transformer block.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
-            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
+            rope_cache (torch.Tensor): Precomputed cosine and sine frequencies.

        Returns:
            torch.Tensor: Output tensor with the same shape as the input.
        """
-        x = x + self.attention(self.attention_norm(x), freqs_cis)
+        x = x + self.attention(self.attention_norm(x), rope_cache)
        x = x + self.moe(self.ffn_norm(x))
        return x
@@ -398,16 +305,16 @@ class GptOssModel(nn.Module, ModelProtocol):

    def __init__(self, model_args: GptOssModelArgs):
        super().__init__()
+        self.model_args = model_args
        self.max_seq_len = model_args.max_seq_len
        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.hidden_size)
        self.register_buffer(
-            "freqs_cis", precompute_freqs_cis(model_args), persistent=True
+            "rope_cache", self._precompute_rope_cache(), persistent=False
        )

        self.layers = torch.nn.ModuleDict()
        for layer_id in range(model_args.num_hidden_layers):
            self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args).to(torch.bfloat16)
-            # convert_submodules_to_bf16(self.layers[str(layer_id)])

        self.norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps)
        self.output = nn.Linear(
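
Note: switching the buffer to persistent=False keeps rope_cache out of the model's state_dict; since the cache is derived purely from model_args, it is recomputed in init_weights below rather than loaded from checkpoints. A minimal illustration of the PyTorch behavior being relied on:

    import torch
    import torch.nn as nn

    class Demo(nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("kept", torch.ones(2), persistent=True)
            self.register_buffer("dropped", torch.ones(2), persistent=False)

    m = Demo()
    assert "kept" in m.state_dict()
    assert "dropped" not in m.state_dict()  # not saved; must be recomputable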
@@ -418,12 +325,11 @@ def __init__(self, model_args: GptOssModelArgs):
        )
        self.model_args = model_args
        self.init_weights()
-        # convert_submodules_to_bf16(self)

    def init_weights(self, buffer_device: torch.device | None = None) -> None:
-        buffer_device = buffer_device or self.freqs_cis.device
+        buffer_device = buffer_device or self.rope_cache.device
        with torch.device(buffer_device):
-            self.freqs_cis = precompute_freqs_cis(self.model_args)
+            self.rope_cache = self._precompute_rope_cache()
        if self.tok_embeddings is not None:
            nn.init.normal_(self.tok_embeddings.weight)
        for layer in self.layers.values():
@@ -442,6 +348,13 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None:
                b=cutoff_factor * final_out_std,
            )

+    def _precompute_rope_cache(self) -> torch.Tensor:
+        return precompute_rope_cache(
+            self.model_args.head_dim,
+            self.model_args.max_seq_len,
+            self.model_args.rope_theta,
+        )
+
    def forward(self, tokens: torch.Tensor):
        """
        Forward pass for the Transformer model.
@@ -455,7 +368,7 @@ def forward(self, tokens: torch.Tensor):
        h = self.tok_embeddings(tokens)

        for layer in self.layers.values():
-            h = layer(h, self.freqs_cis)
+            h = layer(h, self.rope_cache)
        h = self.norm(h)
        output = self.output(h)
        return output