@@ -9,7 +9,7 @@
 
 import torch
 import torch.nn as nn
-from torch.distributed._composable.replicate import replicate
+from torch.distributed._composable.replicate_with_fsdp import replicate
 
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
@@ -135,13 +135,12 @@ def parallelize_llama(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
-        apply_ddp(
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+        apply_replicate(
             model,
-            world_mesh,
-            enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )
 
     return model
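For context on the new call site, here is a minimal sketch of the pieces it relies on: a dtype lookup like `TORCH_DTYPE_MAP` and the `("dp_replicate", "dp_shard")` slice of the world mesh. The mesh sizes, dtype strings, and the stand-in `TORCH_DTYPE_MAP` dict are illustrative assumptions, not torchtitan's actual definitions; only the slicing pattern and keyword names come from the diff above.

```python
# Sketch under assumptions: run with torchrun so that world_size == 2 * 4.
import torch
from torch.distributed.device_mesh import init_device_mesh

# Illustrative stand-in for torchtitan's TORCH_DTYPE_MAP config lookup.
TORCH_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float32": torch.float32}

# Example 2D data-parallel mesh: 2-way replicate x 4-way shard (sizes are made up).
world_mesh = init_device_mesh(
    "cuda",
    (2, 4),
    mesh_dim_names=("dp_replicate", "dp_shard"),
)

dp_mesh_dim_names = ("dp_replicate", "dp_shard")
dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]  # 2D sub-mesh handed to apply_replicate

param_dtype = TORCH_DTYPE_MAP["bfloat16"]   # e.g. job_config.training.mixed_precision_param
reduce_dtype = TORCH_DTYPE_MAP["float32"]   # e.g. job_config.training.mixed_precision_reduce
```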
@@ -314,20 +313,31 @@ def apply_fsdp(
     fully_shard(model, **fsdp_config)
 
 
-def apply_ddp(
+def apply_replicate(
     model: nn.Module,
     dp_mesh: DeviceMesh,
-    enable_compile: bool,
-    enable_compiled_autograd: bool,
+    param_dtype: torch.dtype,
+    reduce_dtype: torch.dtype,
 ):
-    if enable_compile:
-        if enable_compiled_autograd:
-            torch._dynamo.config.optimize_ddp = (
-                "python_reducer_without_compiled_forward"
-            )
-        else:
-            torch._dynamo.config.optimize_ddp = "ddp_optimizer"
+    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
+    replicate_config = {"device_mesh": dp_mesh, "mp_policy": mp_policy}
+
+    if model.tok_embeddings is not None:
+        replicate(
+            model.tok_embeddings,
+            **replicate_config,
+        )
+    for layer_id, transformer_block in model.layers.items():
+        replicate(
+            transformer_block,
+            **replicate_config,
+        )
 
-    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+    if model.norm is not None and model.output is not None:
+        replicate(
+            [model.norm, model.output],
+            **replicate_config,
+        )
+    replicate(model, **replicate_config)
 
     logger.info("Applied DDP to the model")