@@ -8,6 +8,7 @@
     PrepareModuleInput,
     RowwiseParallel,
     SequenceParallel,
+    PrepareModuleOutput,
 )
 
 if torch.__version__ >= "2.9":
@@ -22,7 +23,6 @@
 from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
 from torchtitan.experiments.llama4.infra.parallelize import (
     apply_fsdp,
-    apply_moe_ep_tp,
 )
 
 from torchtitan.tools.logging import logger
@@ -212,49 +212,52 @@ def apply_non_moe_tp(
             Float8ColwiseParallel,
             Float8RowwiseParallel,
             PrepareFloat8ModuleInput,
+            PrepareFloat8ModuleOutput,
         )
 
-        rowwise_parallel, colwise_parallel, prepare_module_input = (
+        rowwise_parallel, colwise_parallel, prepare_module_input, prepare_module_output = (
             Float8RowwiseParallel,
             Float8ColwiseParallel,
             PrepareFloat8ModuleInput,
+            PrepareFloat8ModuleOutput,
         )
     else:
-        rowwise_parallel, colwise_parallel, prepare_module_input = (
+        rowwise_parallel, colwise_parallel, prepare_module_input, prepare_module_output = (
             RowwiseParallel,
             ColwiseParallel,
             PrepareModuleInput,
+            PrepareModuleOutput,
         )
 
     # Apply tensor + sequence parallelism to every transformer block
     for transformer_block in model.layers.values():
         layer_plan = {
             "attention_norm": SequenceParallel(),
             "attention": prepare_module_input(
-                input_layouts=(Shard(1), Replicate()),
-                desired_input_layouts=(Replicate(), Replicate()),
+                input_layouts=(Shard(1), None),
+                desired_input_layouts=(Replicate(), None),
             ),
-            "attention.wq": colwise_parallel(use_local_output=False),
-            "attention.wk": colwise_parallel(use_local_output=False),
-            "attention.wv": colwise_parallel(use_local_output=False),
+            "attention.wq": colwise_parallel(),
+            "attention.wk": colwise_parallel(),
+            "attention.wv": colwise_parallel(),
+            "attention.attn": prepare_module_output(output_layouts=(Shard(1), Shard(1)), desired_output_layouts=(Shard(1), Shard(1)), use_local_output=False),
             "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
             "ffn_norm": SequenceParallel(),
         }
 
+        # shard attention.sinks across heads
+        attn = transformer_block.attention
+        attn.register_parameter(
+            "sinks",
+            nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Shard(0)])),
+        )
+
         parallelize_module(
             module=transformer_block,
             device_mesh=tp_mesh,
             parallelize_plan=layer_plan,
         )
 
-        # shard attention.sinks across heads
-        # TODO(jianiw): Fix the sink implementation
-        # attn = transformer_block.attention
-        # attn.register_parameter(
-        #     "sinks",
-        #     nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Replicate()])),
-        # )
-
     if enable_async_tp:
         from torch.distributed._symmetric_memory import enable_symm_mem_for_group
 
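Note on the new sinks sharding: gpt-oss keeps one learnable attention-sink logit per head, so placing the parameter with `Shard(0)` splits it along the head dimension, matching the head split introduced by the colwise-parallel wq/wk/wv above. Below is a minimal standalone sketch of that behavior; it is not part of this commit, the module, sizes, and script name are made up, and it assumes a 2-way TP mesh launched via torchrun.

```python
# sinks_shard_demo.py -- hypothetical illustration, run with:
#   torchrun --nproc-per-node=2 sinks_shard_demo.py
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

device_type = "cuda" if torch.cuda.is_available() else "cpu"
tp_mesh = init_device_mesh(device_type, (2,), mesh_dim_names=("tp",))

n_heads = 8  # assumed divisible by the TP degree
attn = nn.Module()
attn.sinks = nn.Parameter(torch.zeros(n_heads))  # one sink logit per attention head

# Re-register the parameter as a DTensor sharded on dim 0 (the head dim),
# so each TP rank stores only the sinks for its local heads.
attn.register_parameter(
    "sinks",
    nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Shard(0)])),
)

print(attn.sinks.placements, attn.sinks.to_local().shape)
# expected per-rank output (assumption): (Shard(dim=0),) torch.Size([4])
```

The commented-out block removed above placed the same parameter with Replicate(); using Shard(0) instead keeps the sinks aligned with the heads each TP rank actually owns.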