Commit 4fc78a3: check and replace rope
1 parent 589ce62

1 file changed
torchtitan/experiments/gpt_oss/model/model.py

Lines changed: 77 additions & 164 deletions
@@ -15,171 +15,78 @@
 from .moe import GptOssMoE
 
 
-# Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294
-def precompute_freqs_cis(args: GptOssModelArgs) -> torch.Tensor:
-    """
-    Precomputes frequency-based complex exponential values for rotary positional embeddings.
-
-    Args:
-        args (GptOssModelArgs): Model arguments containing positional embedding parameters.
-
-    Returns:
-        torch.Tensor: Precomputed complex exponential values for positional embeddings.
-    """
-    dim = args.head_dim
-    seqlen = args.max_seq_len
-    beta_fast = args.beta_fast
-    beta_slow = args.beta_slow
-    base = args.rope_theta
-    factor = args.rope_factor
-    original_seq_len = args.original_seq_len
-
-    # YaRN default m-scale (attention_factor). Matches HF when attention_factor is None.
-    mscale = 0.1 * math.log(factor) + 1.0
-
-    def find_correction_dim(
-        num_rotations: float, dim: int, base: float, max_seq_len: int
-    ) -> float:
-        """
-        Computes the correction dimension for a given number of rotations in the rotary positional embedding.
-
-        Args:
-            num_rotations (float): Number of rotations to compute the correction for.
-            dim (int): Dimensionality of the embedding space.
-            base (float): Base value for the exponential computation.
-            max_seq_len (int): Maximum sequence length.
-
-        Returns:
-            float: The correction dimension based on the input parameters.
-        """
-        return (
-            dim
-            * math.log(max_seq_len / (num_rotations * 2 * math.pi))
-            / (2 * math.log(base))
-        )
-
-    def find_correction_range(
-        low_rot: float, high_rot: float, dim: int, base: float, max_seq_len: int
-    ) -> Tuple[int, int]:
-        """
-        Computes the range of correction dimensions for rotary positional embeddings.
-
-        Args:
-            low_rot (float): Lower bound for the number of rotations.
-            high_rot (float): Upper bound for the number of rotations.
-            dim (int): Dimensionality of the embedding space.
-            base (float): Base value for the exponential computation.
-            max_seq_len (int): Maximum sequence length.
+def precompute_rope_cache(
+    dim: int, max_seq_len: int, base: float = 1_000_000.0
+) -> torch.Tensor:
+    freqs = 1.0 / (base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    # Create position indexes `[0, 1, ..., max_seq_len - 1]`
+    t = torch.arange(max_seq_len, dtype=freqs.dtype, device=freqs.device)
 
-        Returns:
-            Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
-        """
-        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
-        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
-        return max(low, 0), min(high, dim - 1)
+    # Outer product of theta and position index; output tensor has
+    # a shape of [max_seq_len, dim // 2]
+    idx_theta = torch.outer(t, freqs).float()
 
-    def linear_ramp_factor(min: float, max: float, dim: int) -> torch.Tensor:
-        """
-        Computes a linear ramp function used to smooth values between a minimum and maximum range.
-
-        Args:
-            min (float): Minimum value for the ramp function.
-            max (float): Maximum value for the ramp function.
-            dim (int): Dimensionality of the ramp tensor.
-
-        Returns:
-            torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
-                clamped to the range [0, 1].
-        """
-        if min == max:
-            max += 0.001
-        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
-        ramp_func = torch.clamp(linear_func, 0, 1)
-        return ramp_func
-
-    # Basic RoPE frequency calculation
-    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
-
-    # YaRN scaling for extended context. YaRN is used to extend the context length after pre-training.
-    if seqlen > original_seq_len:
-        low, high = find_correction_range(
-            beta_fast, beta_slow, dim, base, original_seq_len
-        )
-        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
-        freqs = freqs / factor * (1 - smooth) + freqs * smooth
+    # We cache the cos and sin embeddings instead of the IDs. This helps
+    # ensure we have correct behavior when training with bf16.
+    # Size: [max_seq_len, (dim * 2)]
+    freqs = torch.cat([idx_theta, idx_theta], dim=-1)
+    rope_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)
+    return rope_cache
 
-    # Create position indices
-    t = torch.arange(seqlen)
 
-    # Outer product: [positions] × [frequencies]
-    freqs = torch.outer(t, freqs)
+def rotate_half(x: torch.Tensor) -> torch.Tensor:
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
 
-    # Convert to complex exponentials: e^(i*freq*pos)
-    freqs_cis = torch.polar(torch.full_like(freqs, fill_value=mscale), freqs)
 
-    return freqs_cis
+def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+    """
+    Reshape frequency tensor (represented by cos, sin) for broadcasting it with another tensor.
 
+    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
+    for the purpose of broadcasting the frequency tensor during element-wise operations.
 
-def apply_rotary_emb_inner(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-    """
-    Applies rotary positional embeddings to the input tensor.
+    The input rope_cache tensor is assumed to be of shape (max_seqlen, head_dim * 2),
+    and the first seqlen elements will be sliced, but dim must match x.
 
     Args:
-        x (torch.Tensor): Input tensor with positional embeddings to be applied.
-        freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.
+        rope_cache (torch.Tensor): RoPE tensor (cos and sin) to be reshaped.
+        x (torch.Tensor): Target tensor for broadcasting compatibility.
 
     Returns:
-        torch.Tensor: Tensor with rotary embeddings applied.
+        torch.Tensor: Reshaped frequency tensor.
     """
-    dtype = x.dtype
-    x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
-    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
-    y = torch.view_as_real(x * freqs_cis).flatten(3)
-    return y.to(dtype)
+    ndim = x.ndim
+    assert ndim > 1
+    _, seqlen, _, head_dim = x.shape
+    rope_cache = rope_cache[0:seqlen]
+    # The shape of rope_cache is (seqlen, head_dim * 2) because we concatenate cos and sin
+    assert rope_cache.shape == (seqlen, head_dim * 2)
+    shape = [-1, seqlen, 1, head_dim * 2]
+    return rope_cache.view(*shape)
+
+
+def apply_rotary_emb(
+    xq: torch.Tensor, xk: torch.Tensor, rope_cache: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+    # input tensor x has shape [bsz, seq_len, num_heads, head_dim]
+    head_dim = xq.shape[-1]
+
+    # reshape for broadcast
+    rope_cache = reshape_for_broadcast(rope_cache, xq)
+
+    # [bsz, seq_len, 1, head_dim]
+    cos = rope_cache[..., :head_dim].to(dtype=xq.dtype, device=xq.device)
+    sin = rope_cache[..., head_dim:].to(dtype=xq.dtype, device=xq.device)
+
+    # xq: [bsz, seq_len, num_heads, head_dim]
+    # xk: [bsz, seq_len, num_kv_heads, head_dim]
+    xq_out = (xq * cos) + (rotate_half(xq) * sin)
+    xk_out = (xk * cos) + (rotate_half(xk) * sin)
+    return xq_out.type_as(xq), xk_out.type_as(xk)
 
-def apply_rotary_emb(q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor):
-    """
-    HF-style inputs (half-split last dim) -> interleave -> Torchtitan complex RoPE -> de-interleave.
-    Shapes:
-        q, k: [B, T, H, D] with D even (HF half-split: first D/2 real, last D/2 imag)
-        freqs_cis: complex, last dim == D/2. Typically [T, D/2] or [1, T, D/2].
-    Returns:
-        q_out, k_out in HF half-split layout (same shape as q, k).
-    """
-    B, T, H, D = q.shape
-    assert D % 2 == 0, "head_dim must be even for RoPE"
-    rot = D // 2
-    assert freqs_cis.shape[-1] == rot, "freqs_cis last dim must be D/2"
-    freqs_cis = freqs_cis[:T, :]
-
-    # Memory layout comparison for head_dim=8:
-    # HF Format:   [r0][r1][r2][r3][i0][i1][i2][i3]
-    #              ↑-- reals --↑   ↑-- imags --↑
-    # Interleaved: [r0][i0][r1][i1][r2][i2][r3][i3]
-    #              ↑-pair-↑ ↑-pair-↑ ↑-pair-↑ ↑-pair-↑
-    # --- inline: HF half-split -> interleaved (real0, imag0, real1, imag1, ...)
-    # q_i, k_i: [B, T, H, D]
-    q_i = torch.empty_like(q)
-    k_i = torch.empty_like(k)
-    q_i[..., 0::2] = q[..., :rot]
-    q_i[..., 1::2] = q[..., rot:]
-    k_i[..., 0::2] = k[..., :rot]
-    k_i[..., 1::2] = k[..., rot:]
-
-    # --- Torchtitan default complex apply (expects interleaved last dim)
-    # freqs_cis will be reshaped inside to [1, T, 1, rot]
-    # TODO(jianiw): I think we should go with sin/cos representation to simplify the conversion between paired real/imaginary <-> half-split real/imaginary
-    q_rot_i = apply_rotary_emb_inner(q_i, freqs_cis)  # uses TT's complex path
-    k_rot_i = apply_rotary_emb_inner(k_i, freqs_cis)
-
-    # --- inline: interleaved -> HF half-split
-    # TODO(jianiw): convert it back
-    q_out = torch.cat([q_rot_i[..., 0::2], q_rot_i[..., 1::2]], dim=-1)
-    k_out = torch.cat([k_rot_i[..., 0::2], k_rot_i[..., 1::2]], dim=-1)
-    return q_out, k_out
-
-# Torch Attention backup implementation (for debugging and sampling) from HuggingFace
 def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
     """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
     bs, slen, n_kv_heads, head_dim = x.shape
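
Note: for plain RoPE (no YaRN scaling, unit magnitude), the half-split cos/sin rotation introduced above is numerically identical to the complex-exponential path it replaces, so the interleave/de-interleave round trip in the old apply_rotary_emb becomes unnecessary. A minimal, self-contained sketch of that equivalence (illustrative sizes; not part of the commit):

import torch

bsz, seqlen, n_heads, head_dim, base = 2, 16, 4, 8, 1_000_000.0

# Cache as built by precompute_rope_cache, before the cos/sin concat
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(seqlen).float(), inv_freq)  # [seqlen, head_dim // 2]
angles2 = torch.cat([angles, angles], dim=-1)                 # [seqlen, head_dim]
cos, sin = angles2.cos(), angles2.sin()

x = torch.randn(bsz, seqlen, n_heads, head_dim)

# New path: x * cos + rotate_half(x) * sin on the half-split layout
x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2 :]
rotated = torch.cat((-x2, x1), dim=-1)
out_new = x * cos.view(1, seqlen, 1, head_dim) + rotated * sin.view(1, seqlen, 1, head_dim)

# Old path: pair the halves as (real, imag) and multiply by e^(i*theta)
pairs = torch.stack((x1, x2), dim=-1)  # [..., head_dim // 2, 2]
z = torch.view_as_complex(pairs)       # [..., head_dim // 2]
zr = z * torch.polar(torch.ones_like(angles), angles).view(1, seqlen, 1, -1)
out_old = torch.cat((zr.real, zr.imag), dim=-1)

torch.testing.assert_close(out_new, out_old)  # the two formulations agree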
@@ -215,7 +122,7 @@ def eager_attention_forward(
         add_mask = attention_mask.to(attn_weights.dtype)
 
         # Truncate to current key length and add (broadcasts if needed)
-        add_mask = add_mask[..., : key_states.shape[-2]]
+        add_mask = add_mask[..., : key.shape[-2]]
         attn_weights = attn_weights + add_mask
 
         sinks = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
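
The only change in this hunk swaps key_states, a name left over from the HuggingFace code this backup path was adapted from, for the function's actual key parameter. For reference, a self-contained sketch of the truncation pattern (made-up shapes, hypothetical variable names):

import torch

bsz, n_heads, q_len, kv_len, max_len = 1, 2, 5, 5, 8
attn_weights = torch.randn(bsz, n_heads, q_len, kv_len)

# Additive causal mask built for the maximum length: -inf above the diagonal
causal = torch.full((max_len, max_len), float("-inf")).triu(1)

add_mask = causal[:q_len]          # rows for the current queries
add_mask = add_mask[..., :kv_len]  # truncate columns to the current key length
attn_weights = attn_weights + add_mask  # broadcasts over batch and heads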
@@ -275,14 +182,14 @@ def __init__(self, model_args: GptOssModelArgs, use_sliding_attention: bool = Fa
     def forward(
         self,
         x: torch.Tensor,
-        freqs_cis: torch.Tensor,
+        rope_cache: torch.Tensor,
     ):
         """
         Forward pass for the Multi-Head Latent Attention (MLA) Layer.
 
         Args:
             x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
-            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
+            rope_cache (torch.Tensor): Precomputed cosine and sine frequencies for rope embedding.
 
         Returns:
             torch.Tensor: Output tensor with the same shape as the input.
@@ -294,7 +201,7 @@ def forward(
         k = self.wk(x).view(hidden_shape)
         v = self.wv(x).view(hidden_shape)
 
-        q, k = apply_rotary_emb(q, k, freqs_cis)
+        q, k = apply_rotary_emb(q, k, rope_cache)
 
         # repeat k/v heads if n_kv_heads < n_heads
         keys = repeat_kv(k, self.n_rep)
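
repeat_kv itself is unchanged here; per its docstring it is equivalent to torch.repeat_interleave(x, dim=2, repeats=n_rep), expanding the KV heads so k and v line up with the query heads. A small sketch with assumed head counts:

import torch

bsz, seqlen, n_heads, n_kv_heads, head_dim = 2, 16, 8, 2, 64
n_rep = n_heads // n_kv_heads  # each KV head serves 4 query heads

k = torch.randn(bsz, seqlen, n_kv_heads, head_dim)
keys = torch.repeat_interleave(k, dim=2, repeats=n_rep)
assert keys.shape == (bsz, seqlen, n_heads, head_dim)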
@@ -369,18 +276,18 @@ def __init__(self, layer_id: int, model_args: GptOssModelArgs):
         self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5
         self.layer_id = layer_id
 
-    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
+    def forward(self, x: torch.Tensor, rope_cache: torch.Tensor):
         """
         Forward pass for the Transformer block.
 
         Args:
             x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
-            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
+            rope_cache (torch.Tensor): Precomputed cosine and sine frequencies.
 
         Returns:
             torch.Tensor: Output tensor with the same shape as the input.
         """
-        x = x + self.attention(self.attention_norm(x), freqs_cis)
+        x = x + self.attention(self.attention_norm(x), rope_cache)
         x = x + self.moe(self.ffn_norm(x))
         return x
 
@@ -398,16 +305,16 @@ class GptOssModel(nn.Module, ModelProtocol):
 
     def __init__(self, model_args: GptOssModelArgs):
         super().__init__()
+        self.model_args = model_args
         self.max_seq_len = model_args.max_seq_len
         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.hidden_size)
         self.register_buffer(
-            "freqs_cis", precompute_freqs_cis(model_args), persistent=True
+            "rope_cache", self._precompute_rope_cache(), persistent=False
         )
 
         self.layers = torch.nn.ModuleDict()
         for layer_id in range(model_args.num_hidden_layers):
             self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args).to(torch.bfloat16)
-            # convert_submodules_to_bf16(self.layers[str(layer_id)])
 
         self.norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps)
         self.output = nn.Linear(
@@ -418,12 +325,11 @@ def __init__(self, model_args: GptOssModelArgs):
         )
         self.model_args = model_args
         self.init_weights()
-        # convert_submodules_to_bf16(self)
 
     def init_weights(self, buffer_device: torch.device | None = None) -> None:
-        buffer_device = buffer_device or self.freqs_cis.device
+        buffer_device = buffer_device or self.rope_cache.device
         with torch.device(buffer_device):
-            self.freqs_cis = precompute_freqs_cis(self.model_args)
+            self.rope_cache = self._precompute_rope_cache()
         if self.tok_embeddings is not None:
            nn.init.normal_(self.tok_embeddings.weight)
         for layer in self.layers.values():
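
Two related changes above: the buffer is now registered with persistent=False, and init_weights recomputes it via _precompute_rope_cache on the buffer device. Non-persistent buffers are omitted from state_dict, so checkpoints no longer carry the RoPE table; it is rebuilt instead. A toy illustration (the module below is hypothetical, not from this file):

import torch
import torch.nn as nn

class WithCache(nn.Module):
    def __init__(self):
        super().__init__()
        # Carried on the module and moved by .to(...), but not checkpointed
        self.register_buffer("rope_cache", torch.randn(8, 16), persistent=False)

m = WithCache()
assert "rope_cache" not in m.state_dict()  # excluded from the checkpoint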
@@ -442,6 +348,13 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None:
                 b=cutoff_factor * final_out_std,
             )
 
+    def _precompute_rope_cache(self) -> torch.Tensor:
+        return precompute_rope_cache(
+            self.model_args.head_dim,
+            self.model_args.max_seq_len,
+            self.model_args.rope_theta,
+        )
+
     def forward(self, tokens: torch.Tensor):
         """
         Forward pass for the Transformer model.
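
The new _precompute_rope_cache helper just forwards head_dim, max_seq_len, and rope_theta from the stored model args into precompute_rope_cache. A standalone shape sketch (assumed sizes) of the buffer it produces and how apply_rotary_emb consumes it:

import torch

head_dim, max_seq_len, seqlen = 64, 4096, 100

# Stand-in for the real cache: [max_seq_len, head_dim * 2], cos in the
# first half of the last dim, sin in the second half
rope_cache = torch.randn(max_seq_len, head_dim * 2)

# apply_rotary_emb slices the first seq_len rows, then views the result
# as [1, seq_len, 1, head_dim * 2] to broadcast over batch and heads
sliced = rope_cache[:seqlen].view(-1, seqlen, 1, head_dim * 2)
cos, sin = sliced[..., :head_dim], sliced[..., head_dim:]
assert cos.shape == sin.shape == (1, seqlen, 1, head_dim)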
@@ -455,7 +368,7 @@ def forward(self, tokens: torch.Tensor):
         h = self.tok_embeddings(tokens)
 
         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis)
+            h = layer(h, self.rope_cache)
         h = self.norm(h)
         output = self.output(h)
         return output
