@@ -9,7 +9,7 @@

 import torch
 import torch.nn as nn
-from torch.distributed._composable.replicate import replicate
+from torch.distributed._composable.replicate_with_fsdp import replicate

 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
@@ -135,13 +135,18 @@ def parallelize_llama(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
+        # if world_mesh.ndim > 1:
+        #    raise RuntimeError("DDP has not supported > 1D parallelism")
+
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
         apply_ddp(
             model,
-            world_mesh,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
             enable_compile=model_compile_enabled,
             enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            cpu_offload=job_config.training.enable_cpu_offload,
         )

     return model
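
Context for the mesh slice above: world_mesh[tuple(dp_mesh_dim_names)] presumes the world mesh was initialized with "dp_replicate" and "dp_shard" as named dimensions. A minimal sketch of that setup, with illustrative device type and dimension sizes that are not taken from this change (it would need to run under torchrun with 8 ranks):

from torch.distributed.device_mesh import init_device_mesh

# Hypothetical 8-rank layout: 2 replica groups x 4 shards per group.
# The dim names must match the ones sliced in parallelize_llama.
world_mesh = init_device_mesh(
    "cuda",
    (2, 4),
    mesh_dim_names=("dp_replicate", "dp_shard"),
)

# Selecting both named dims yields the 2D sub-mesh handed to apply_ddp.
dp_mesh = world_mesh[("dp_replicate", "dp_shard")]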
@@ -317,17 +322,33 @@ def apply_fsdp(
 def apply_ddp(
     model: nn.Module,
     dp_mesh: DeviceMesh,
+    param_dtype: torch.dtype,
+    reduce_dtype: torch.dtype,
     enable_compile: bool,
     enable_compiled_autograd: bool,
+    cpu_offload: bool = False,
 ):
-    if enable_compile:
-        if enable_compiled_autograd:
-            torch._dynamo.config.optimize_ddp = (
-                "python_reducer_without_compiled_forward"
-            )
-        else:
-            torch._dynamo.config.optimize_ddp = "ddp_optimizer"
+    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
+    replicate_config = {"device_mesh": dp_mesh, "mp_policy": mp_policy}
+    if cpu_offload:
+        replicate_config["offload_policy"] = CPUOffloadPolicy()

-    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+    if model.tok_embeddings is not None:
+        replicate(
+            model.tok_embeddings,
+            **replicate_config,
+        )
+    for layer_id, transformer_block in model.layers.items():
+        replicate(
+            transformer_block,
+            **replicate_config,
+        )
+
+    if model.norm is not None and model.output is not None:
+        replicate(
+            [model.norm, model.output],
+            **replicate_config,
+        )
+    replicate(model, **replicate_config)

     logger.info("Applied DDP to the model")
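
For reference, a standalone sketch of the policy objects that replicate_config carries in the new apply_ddp. The dtypes are illustrative examples rather than values fixed by this change; in torchtitan they come from TORCH_DTYPE_MAP and the job config:

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy

# Same illustrative 2x4 layout as the mesh sketch above.
dp_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_replicate", "dp_shard"))

# Example: bf16 parameters with fp32 gradient reduction.
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

replicate_config = {"device_mesh": dp_mesh, "mp_policy": mp_policy}

# Added only when job_config.training.enable_cpu_offload is set.
replicate_config["offload_policy"] = CPUOffloadPolicy()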