Commit f1d41a1

Support rope cache indexing using positions (#2112)
Add support for indexing the RoPE cache using `position_ids`. This may be needed for: 1. inference, where `position_ids` is passed into the transformer forward; 2. CP load balancing, where the RoPE cache must be indexed by the given position ids.

Test: ran dpskv3 16b base (screenshot: https://github.com/user-attachments/assets/6f463d65-a0de-413d-ab19-770db9983dbb). Also tested when passing position_ids in https://github.com/wwwjn/torchtitan/pull/1/files (screenshot: https://github.com/user-attachments/assets/70e4bddc-0334-4dbf-b00d-6e4b49a94655).

---------

Co-authored-by: JessicaZhong <[email protected]>
1 parent 1ebd914 commit f1d41a1
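
As context for the commit message above, a minimal, self-contained sketch (not part of the commit; the concrete lengths and position values are made up) of the kinds of `positions` tensors the new argument is meant to accept, for the inference and CP load-balancing cases:

import torch

seqlen, prompt_len = 4, 16

# 1) Inference: the new chunk of tokens starts right after an already-processed
#    prompt, so the RoPE cache should be read at rows [prompt_len, prompt_len + seqlen).
infer_positions = torch.arange(prompt_len, prompt_len + seqlen).unsqueeze(0)  # (1, seqlen)

# 2) CP load balancing: each sample may hold a shuffled / non-contiguous slice of
#    the full sequence, so every row carries its own position ids.
cp_positions = torch.stack(
    [
        torch.tensor([0, 1, 14, 15]),  # this rank holds the head and tail of sample 0
        torch.tensor([6, 7, 8, 9]),    # and a middle chunk of sample 1
    ]
)  # (bsz, seqlen)

# Either tensor can be handed to the model forward added by this commit,
# e.g. model(tokens, positions=infer_positions), where `model` is a torchtitan
# Transformer built elsewhere.
print(infer_positions.shape, cp_positions.shape)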

8 files changed (+226, -47 lines)

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 4 additions & 2 deletions
@@ -224,9 +224,11 @@ def apply_non_moe_tp(
     for transformer_block in model.layers.values():
         layer_plan = {
             "attention_norm": SequenceParallel(),
+            # NOTE: when the fourth argument (positions) is not None, its input layout
+            # and desired input layout should be Replicate()
             "attention": prepare_module_input(
-                input_layouts=(Shard(1), Replicate(), None),
-                desired_input_layouts=(Replicate(), Replicate(), None),
+                input_layouts=(Shard(1), Replicate(), None, None),
+                desired_input_layouts=(Replicate(), Replicate(), None, None),
             ),
             # NOTE: use_local_output=False make the output to be a DTensor instead of a plain Tensor
             # so that the intermedidate results k is generated as a DTensor and its gradient is
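
To make the NOTE in this plan concrete, a standalone sketch (not part of the commit) of what the "attention" entry would look like for a caller that actually passes `positions`. It uses PyTorch's `PrepareModuleInput` directly, whereas torchtitan goes through its `prepare_module_input` helper, so treat the exact construction as an assumption:

from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import PrepareModuleInput

# Hypothetical variant of the "attention" entry when positions is not None:
# positions is small and identical on every TP rank, so both its incoming and
# desired layouts are declared Replicate() rather than None.
attention_plan_with_positions = PrepareModuleInput(
    input_layouts=(Shard(1), Replicate(), None, Replicate()),
    desired_input_layouts=(Replicate(), Replicate(), None, Replicate()),
)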

torchtitan/models/deepseek_v3/model/model.py

Lines changed: 68 additions & 6 deletions
@@ -126,20 +126,71 @@ def linear_ramp_factor(min: float, max: float, dim: int) -> torch.Tensor:
     return freqs_cis


-def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+def reshape_for_broadcast(
+    freqs_cis: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None
+) -> torch.Tensor:
+    """
+    Reshape frequency tensor for broadcasting it with another tensor.
+
+    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
+    for the purpose of broadcasting the frequency tensor during element-wise operations.
+
+    The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim // 2),
+    and the first seqlen elements will be sliced, but dim must match x.
+
+    Args:
+        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
+        x (torch.Tensor): Target tensor for broadcasting compatibility.
+        positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache.
+            Shape is (1, seqlen) or (bz, seqlen). Defaults to None.
+
+    Returns:
+        torch.Tensor: Reshaped frequency tensor.
+    """
+    ndim = x.ndim
+    assert ndim > 1
+    seqlen = x.shape[1]
+    if positions is None:
+        freqs_cis = freqs_cis[0:seqlen]
+        assert freqs_cis.shape == (seqlen, x.shape[-1])
+        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+        return freqs_cis.view(*shape)
+    elif positions.size(0) == 1:
+        assert positions.shape == (1, seqlen)
+        freqs_cis = freqs_cis[positions.squeeze(0)]
+        assert freqs_cis.shape == (seqlen, x.shape[-1])
+        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+        return freqs_cis.view(*shape)
+    else:
+        assert positions.shape == (x.shape[0], seqlen)
+        freqs_cis_expanded = freqs_cis[None, :, None, :].expand(x.shape[0], -1, -1, -1)
+        freqs_cis = torch.gather(
+            freqs_cis_expanded,
+            dim=1,
+            index=positions.view(x.shape[0], seqlen, 1, 1).expand(
+                x.shape[0], seqlen, 1, freqs_cis_expanded.shape[-1]
+            ),
+        )
+        return freqs_cis
+
+
+def apply_rotary_emb(
+    x: torch.Tensor, freqs_cis: torch.Tensor, positions: torch.Tensor | None = None
+) -> torch.Tensor:
     """
     Applies rotary positional embeddings to the input tensor.

     Args:
         x (torch.Tensor): Input tensor with positional embeddings to be applied.
         freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.
+        positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None.

     Returns:
         torch.Tensor: Tensor with rotary embeddings applied.
     """
     dtype = x.dtype
     x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
-    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
+    freqs_cis = reshape_for_broadcast(freqs_cis, x, positions)
     y = torch.view_as_real(x * freqs_cis).flatten(3)
     return y.to(dtype)

@@ -196,13 +247,16 @@ def forward(
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
         attention_masks: AttentionMasksType | None,
+        positions: torch.Tensor | None = None,
     ):
         """
         Forward pass for the Multi-Head Latent Attention (MLA) Layer.

         Args:
             x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
             freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
+            attention_masks (AttentionMasksType | None): Masks used when calculating attention scores.
+            positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None.

         Returns:
             torch.Tensor: Output tensor with the same shape as the input.

@@ -222,15 +276,15 @@ def forward(
         q_nope, q_pe = torch.split(
             q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
         )
-        q_pe = apply_rotary_emb(q_pe, freqs_cis)
+        q_pe = apply_rotary_emb(q_pe, freqs_cis, positions)
         q = torch.cat([q_nope, q_pe], dim=-1)  # (bsz, seqlen, n_heads, qk_head_dim)

         # Key-value projection
         kv = self.wkv_a(x)  # (bsz, seqlen, kv_lora_rank + qk_rope_head_dim)
         kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

         k_pe = apply_rotary_emb(
-            k_pe.unsqueeze(2), freqs_cis
+            k_pe.unsqueeze(2), freqs_cis, positions
         )  # (bsz, seqlen, 1, qk_rope_head_dim)

         kv = self.wkv_b(

@@ -312,18 +366,23 @@ def forward(
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
         attention_masks: AttentionMasksType | None,
+        positions: torch.Tensor | None = None,
     ):
         """
         Forward pass for the Transformer block.

         Args:
             x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
             freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
+            attention_masks (AttentionMasksType | None): Masks used when calculating attention scores.
+            positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None.

         Returns:
             torch.Tensor: Output tensor with the same shape as the input.
         """
-        x = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks)
+        x = x + self.attention(
+            self.attention_norm(x), freqs_cis, attention_masks, positions
+        )
         if self.moe_enabled:
             x = x + self.moe(self.ffn_norm(x))
         else:

@@ -413,6 +472,7 @@ def forward(
         self,
         tokens: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
+        positions: torch.Tensor | None = None,
     ):
         """
         Forward pass for the Transformer model.

@@ -422,6 +482,8 @@ def forward(
                 If pipeline parallelism is enabled, this will be the input token indices
                 for the ranks on the first pipeline stage. This will be the activation of the
                 previous pipeline stage if the current rank is not on the first stage.
+            attention_masks (AttentionMasksType | None): Masks used when calculating attention scores.
+            positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None.

         Returns:
             torch.Tensor: Logits tensor of shape (batch_size, vocab_size).

@@ -430,7 +492,7 @@ def forward(
         h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens

         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis, attention_masks)
+            h = layer(h, self.freqs_cis, attention_masks, positions)
         h = self.norm(h) if self.norm is not None else h
         output = self.output(h) if self.output is not None else h
         return output
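
To illustrate the batched branch that `reshape_for_broadcast` gains above, a self-contained sketch (sizes and position values are illustrative; it re-implements only the gather step rather than importing torchtitan):

import torch

max_seqlen, half_dim, bsz, seqlen = 32, 4, 2, 5

# Stand-in complex RoPE cache of shape (max_seqlen, dim // 2).
freqs_cis = torch.polar(torch.ones(max_seqlen, half_dim), torch.rand(max_seqlen, half_dim))

# Per-sample position ids, e.g. two CP shards holding different offsets.
positions = torch.stack([torch.arange(0, seqlen), torch.arange(10, 10 + seqlen)])  # (bsz, seqlen)

# Broadcast the cache to (bsz, max_seqlen, 1, half_dim) and gather rows per sample,
# the same shape juggling the batched branch above performs.
expanded = freqs_cis[None, :, None, :].expand(bsz, -1, -1, -1)
index = positions.view(bsz, seqlen, 1, 1).expand(bsz, seqlen, 1, half_dim)
gathered = torch.gather(expanded, dim=1, index=index)  # (bsz, seqlen, 1, half_dim)

# Sanity check: gathering row positions[b, s] is the same as plain row indexing.
assert torch.equal(gathered.squeeze(2), freqs_cis[positions])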

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 4 additions & 2 deletions
@@ -205,9 +205,11 @@ def apply_tp(
     for transformer_block in model.layers.values():
         layer_plan = {
             "attention_norm": SequenceParallel(),
+            # NOTE: when the fourth argument (positions) is not None, its input layout
+            # and desired input layout should be Replicate()
             "attention": prepare_module_input(
-                input_layouts=(Shard(1), None, None),
-                desired_input_layouts=(Replicate(), None, None),
+                input_layouts=(Shard(1), None, None, None),
+                desired_input_layouts=(Replicate(), None, None, None),
             ),
             "attention.wq": colwise_parallel(),
             "attention.wk": colwise_parallel(),

torchtitan/models/llama3/model/model.py

Lines changed: 47 additions & 10 deletions
@@ -88,36 +88,59 @@ def precompute_freqs_cis(
     return freqs_cis


-def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+def reshape_for_broadcast(
+    freqs_cis: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None
+) -> torch.Tensor:
     """
     Reshape frequency tensor for broadcasting it with another tensor.

     This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
     for the purpose of broadcasting the frequency tensor during element-wise operations.

-    The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
+    The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim // 2),
     and the first seqlen elements will be sliced, but dim must match x.

     Args:
         freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
         x (torch.Tensor): Target tensor for broadcasting compatibility.
+        positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache.
+            Shape is (1, seqlen) or (bz, seqlen). Defaults to None.

     Returns:
         torch.Tensor: Reshaped frequency tensor.
     """
     ndim = x.ndim
     assert ndim > 1
     seqlen = x.shape[1]
-    freqs_cis = freqs_cis[0:seqlen]
-    assert freqs_cis.shape == (seqlen, x.shape[-1])
-    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-    return freqs_cis.view(*shape)
+    if positions is None:
+        freqs_cis = freqs_cis[0:seqlen]
+        assert freqs_cis.shape == (seqlen, x.shape[-1])
+        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+        return freqs_cis.view(*shape)
+    elif positions.size(0) == 1:
+        assert positions.shape == (1, seqlen)
+        freqs_cis = freqs_cis[positions.squeeze(0)]
+        assert freqs_cis.shape == (seqlen, x.shape[-1])
+        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+        return freqs_cis.view(*shape)
+    else:
+        assert positions.shape == (x.shape[0], seqlen)
+        freqs_cis_expanded = freqs_cis[None, :, None, :].expand(x.shape[0], -1, -1, -1)
+        freqs_cis = torch.gather(
+            freqs_cis_expanded,
+            dim=1,
+            index=positions.view(x.shape[0], seqlen, 1, 1).expand(
+                x.shape[0], seqlen, 1, freqs_cis_expanded.shape[-1]
+            ),
+        )
+        return freqs_cis


 def apply_rotary_emb(
     xq: torch.Tensor,
     xk: torch.Tensor,
     freqs_cis: torch.Tensor,
+    positions: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Apply rotary embeddings to input tensors using the given frequency tensor.

@@ -131,13 +154,14 @@ def apply_rotary_emb(
         xq (torch.Tensor): Query tensor to apply rotary embeddings.
         xk (torch.Tensor): Key tensor to apply rotary embeddings.
         freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
+        positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None.

     Returns:
         tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
     """
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+    freqs_cis = reshape_for_broadcast(freqs_cis, xq_, positions)
     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
     return xq_out.type_as(xq), xk_out.type_as(xk)

@@ -213,13 +237,16 @@ def forward(
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
         attention_masks: AttentionMasksType | None,
+        positions: torch.Tensor | None = None,
     ):
         """
         Forward pass of the attention module.

         Args:
             x (torch.Tensor): Input tensor.
             freqs_cis (torch.Tensor): Precomputed frequency tensor.
+            attention_masks (AttentionMasksType | None): Masks used when calculating attention scores.
+            positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None.

         Returns:
             torch.Tensor: Output tensor after attention.

@@ -236,7 +263,7 @@ def forward(
         xk = xk.view(bs, seqlen, -1, self.head_dim)
         xv = xv.view(bs, seqlen, -1, self.head_dim)

-        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, positions=positions)

         # repeat k/v heads if n_kv_heads < n_heads
         keys = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

@@ -360,19 +387,24 @@ def forward(
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
         attention_masks: AttentionMasksType | None,
+        positions: torch.Tensor | None = None,
     ):
         """
         Perform a forward pass through the TransformerBlock.

         Args:
             x (torch.Tensor): Input tensor.
             freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
+            attention_masks (AttentionMasksType | None): Masks used when calculating attention scores.
+            positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None.

         Returns:
             torch.Tensor: Output tensor after applying attention and feedforward layers.

         """
-        h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks)
+        h = x + self.attention(
+            self.attention_norm(x), freqs_cis, attention_masks, positions
+        )
         out = h + self.feed_forward(self.ffn_norm(h))
         return out

@@ -519,6 +551,7 @@ def forward(
         self,
         tokens: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
+        positions: torch.Tensor | None = None,
     ):
         """
         Perform a forward pass through the Transformer model.

@@ -528,6 +561,8 @@ def forward(
                 If pipeline parallelism is enabled, this will be the input token indices
                 for the ranks on the first pipeline stage. This will be the activation of the
                 previous pipeline stage if the current rank is not on the first stage.
+            attention_masks (AttentionMasksType | None): Masks used when calculating attention scores.
+            positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None.

         Returns:
             torch.Tensor: Output logits after applying the Transformer model.

@@ -537,7 +572,9 @@ def forward(
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis, attention_masks=attention_masks)
+            h = layer(
+                h, self.freqs_cis, attention_masks=attention_masks, positions=positions
+            )
         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
         return output
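
A small sketch of the shared-row (1, seqlen) branch the same reshape_for_broadcast handles, e.g. a decode step that starts at an offset into the cache (again self-contained, with illustrative sizes, rather than an import of the llama3 module):

import torch

max_seqlen, half_dim, offset, seqlen = 64, 8, 10, 3

# Stand-in RoPE cache of shape (max_seqlen, dim // 2), as in the docstring above.
freqs_cis = torch.polar(torch.ones(max_seqlen, half_dim), torch.rand(max_seqlen, half_dim))

# One row of position ids shared by the whole batch, starting at `offset`.
positions = torch.arange(offset, offset + seqlen).unsqueeze(0)  # (1, seqlen)

# reshape_for_broadcast indexes the cache with positions.squeeze(0); for a
# contiguous range this is just a slice at that offset...
selected = freqs_cis[positions.squeeze(0)]
assert torch.equal(selected, freqs_cis[offset : offset + seqlen])

# ...whereas without positions the function falls back to freqs_cis[0:seqlen],
# which is only correct when the chunk really starts at position 0.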

torchtitan/models/llama4/infra/parallelize.py

Lines changed: 4 additions & 2 deletions
@@ -240,9 +240,11 @@ def apply_non_moe_tp(
     for transformer_block in model.layers.values():
         layer_plan = {
             "attention_norm": SequenceParallel(),
+            # NOTE: when the fourth argument (positions) is not None, its input layout
+            # and desired input layout should be Replicate()
             "attention": prepare_module_input(
-                input_layouts=(Shard(1), None, None),
-                desired_input_layouts=(Replicate(), None, None),
+                input_layouts=(Shard(1), None, None, None),
+                desired_input_layouts=(Replicate(), None, None, None),
             ),
             "attention.wq": colwise_parallel(),
             "attention.wk": colwise_parallel(),
