
Commit 07c0ff4

fix sink
1 parent bb8ee6f commit 07c0ff4

3 files changed: +30, -21 lines changed


torchtitan/experiments/gpt_oss/infra/parallelize.py

Lines changed: 19 additions & 16 deletions
@@ -8,6 +8,7 @@
     PrepareModuleInput,
     RowwiseParallel,
     SequenceParallel,
+    PrepareModuleOutput,
 )
 
 if torch.__version__ >= "2.9":
@@ -22,7 +23,6 @@
 from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
 from torchtitan.experiments.llama4.infra.parallelize import (
     apply_fsdp,
-    apply_moe_ep_tp,
 )
 
 from torchtitan.tools.logging import logger
@@ -212,49 +212,52 @@ def apply_non_moe_tp(
             Float8ColwiseParallel,
             Float8RowwiseParallel,
             PrepareFloat8ModuleInput,
+            PrepareFloat8ModuleOutput,
         )
 
-        rowwise_parallel, colwise_parallel, prepare_module_input = (
+        rowwise_parallel, colwise_parallel, prepare_module_input, prepare_module_output = (
            Float8RowwiseParallel,
            Float8ColwiseParallel,
            PrepareFloat8ModuleInput,
+           PrepareFloat8ModuleOutput,
        )
    else:
-        rowwise_parallel, colwise_parallel, prepare_module_input = (
+        rowwise_parallel, colwise_parallel, prepare_module_input, prepare_module_output = (
            RowwiseParallel,
            ColwiseParallel,
            PrepareModuleInput,
+           PrepareModuleOutput,
        )
 
    # Apply tensor + sequence parallelism to every transformer block
    for transformer_block in model.layers.values():
        layer_plan = {
            "attention_norm": SequenceParallel(),
            "attention": prepare_module_input(
-                input_layouts=(Shard(1), Replicate()),
-                desired_input_layouts=(Replicate(), Replicate()),
+                input_layouts=(Shard(1), None),
+                desired_input_layouts=(Replicate(), None),
            ),
-            "attention.wq": colwise_parallel(use_local_output=False),
-            "attention.wk": colwise_parallel(use_local_output=False),
-            "attention.wv": colwise_parallel(use_local_output=False),
+            "attention.wq": colwise_parallel(),
+            "attention.wk": colwise_parallel(),
+            "attention.wv": colwise_parallel(),
+            "attention.attn": prepare_module_output(output_layouts=(Shard(1), Shard(1)), desired_output_layouts=(Shard(1), Shard(1)), use_local_output=False),
            "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
            "ffn_norm": SequenceParallel(),
        }
 
+        # shard attention.sinks across heads
+        attn = transformer_block.attention
+        attn.register_parameter(
+            "sinks",
+            nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Shard(0)])),
+        )
+
        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=layer_plan,
        )
 
-        # shard attention.sinks across heads
-        # TODO(jianiw): Fix the sink implementation
-        # attn = transformer_block.attention
-        # attn.register_parameter(
-        #     "sinks",
-        #     nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Replicate()])),
-        # )
-
    if enable_async_tp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

torchtitan/experiments/gpt_oss/model/model.py

Lines changed: 6 additions & 2 deletions
@@ -245,14 +245,18 @@ def forward(
                 k,
                 v,
                 scale=None,
-                return_lse=False,
+                return_lse=True,
            )
 
            # Apply attention sink rescaling: rescale by σ(lse - w[h])
            # This is mathematically equivalent to concatenating learnable sink weights
+            # TODO: self.sinks is registered as a DTensor under TP, while lse is a plain tensor
+            # q, k, v are already sharded by TP: [batch, local_heads, seq_len, head_dim] (plain tensor)
+            # sinks shape needs to match: [local_heads]
+            # [rank0]: lse.shape torch.Size([8, 32, 2048]), <class 'torch.Tensor'>
            sink_scale = torch.sigmoid(lse - self.sinks.view(1, -1, 1)).unsqueeze(
                -1
-            )  # [B,H,S,1]
+            )
            output = output * sink_scale.to(output.dtype)
 
        else:
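
The equivalence stated in the comment above can be checked numerically: appending a per-head sink logit w[h] as one extra attention score and softmaxing (then dropping the sink column) matches plain softmax attention rescaled by sigmoid(lse - w[h]). A small self-contained check with made-up shapes (not the model's code):

import torch

torch.manual_seed(0)
B, H, S, D = 2, 4, 5, 8
scores = torch.randn(B, H, S, S)            # q @ k^T / sqrt(D), pre-softmax
v = torch.randn(B, H, S, D)
sinks = torch.randn(H)                      # one learnable sink logit per head

# Reference: treat the sink as an extra "key" that carries no value.
sink_col = sinks.view(1, H, 1, 1).expand(B, H, S, 1)
probs = torch.softmax(torch.cat([scores, sink_col], dim=-1), dim=-1)
ref = probs[..., :-1] @ v                   # drop the sink column's weight

# Rescaling form used in the model: softmax(scores) @ v, scaled by sigmoid(lse - w[h]).
lse = torch.logsumexp(scores, dim=-1)       # [B, H, S]
out = torch.softmax(scores, dim=-1) @ v
out = out * torch.sigmoid(lse - sinks.view(1, -1, 1)).unsqueeze(-1)

print(torch.allclose(ref, out, atol=1e-6))  # expected: True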

torchtitan/experiments/gpt_oss/train_configs/debug_model.toml

Lines changed: 5 additions & 3 deletions
@@ -46,10 +46,12 @@ dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default" # default / never / always
-tensor_parallel_degree = 2
+tensor_parallel_degree = 1
 enable_async_tensor_parallel = false
-pipeline_parallel_degree = 1
-context_parallel_degree = 1
+expert_parallel_degree = 4
+
+
+
 
 [checkpoint]
 enable = false
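
A quick, illustrative sanity check that the edited snippet parses and now carries expert_parallel_degree in place of the pipeline/context degrees; the [parallelism] table name is an assumption here, and the keys are copied from the hunk above:

import tomllib  # Python 3.11+

snippet = """
[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
fsdp_reshard_after_forward = "default"
tensor_parallel_degree = 1
enable_async_tensor_parallel = false
expert_parallel_degree = 4
"""

cfg = tomllib.loads(snippet)["parallelism"]  # table name assumed for this sketch
assert cfg["tensor_parallel_degree"] == 1 and cfg["expert_parallel_degree"] == 4
assert "pipeline_parallel_degree" not in cfg and "context_parallel_degree" not in cfg
print(cfg)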
