
Commit b93080b

[torchtitan][replicate] experimenting new replicate integration with torchtitan
ghstack-source-id: bfe9ee3 Pull Request resolved: #1714
1 parent: d240be0

3 files changed (+37, -14 lines)

torchtitan/distributed/parallel_dims.py

Lines changed: 3 additions & 1 deletion
@@ -151,7 +151,9 @@ def _build_mesh_without_ep(self) -> DeviceMesh:
             [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
             ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
         ):
-            if d > 1:
+            # Include dp_shard dimension even if it equals 1 when replicate > 1
+            # to make device_mesh compatible with replicate function
+            if d > 1 or (name == "dp_shard" and self.dp_replicate > 1):
                 dims.append(d)
                 names.append(name)
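With this change, a size-1 "dp_shard" dimension is still materialized in the device mesh whenever dp_replicate > 1, so the ("dp_replicate", "dp_shard") slice consumed by apply_ddp (next file) is always 2D. A minimal sketch of the resulting mesh, assuming an 8-rank launch (e.g. torchrun --nproc_per_node=8) and illustrative degrees dp_replicate=8, dp_shard=1:

# Sketch only: 8 ranks, pure replication; degrees chosen for illustration.
from torch.distributed.device_mesh import init_device_mesh

world_mesh = init_device_mesh(
    "cuda",
    mesh_shape=(8, 1),
    mesh_dim_names=("dp_replicate", "dp_shard"),
)
# 2D mesh (8 replicas x 1 shard); without the size-1 dp_shard dimension
# this slice would be 1D and could not feed the FSDP-style replicate path.
dp_mesh = world_mesh[("dp_replicate", "dp_shard")]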

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 33 additions & 12 deletions
@@ -9,7 +9,7 @@
 
 import torch
 import torch.nn as nn
-from torch.distributed._composable.replicate import replicate
+from torch.distributed._composable.replicate_with_fsdp import replicate
 
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
@@ -135,13 +135,18 @@ def parallelize_llama(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
+        # if world_mesh.ndim > 1:
+        #     raise RuntimeError("DDP has not supported > 1D parallelism")
+
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
         apply_ddp(
             model,
-            world_mesh,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
             enable_compile=model_compile_enabled,
             enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            cpu_offload=job_config.training.enable_cpu_offload,
         )
 
     return model
@@ -317,17 +322,33 @@ def apply_fsdp(
 def apply_ddp(
     model: nn.Module,
     dp_mesh: DeviceMesh,
+    param_dtype: torch.dtype,
+    reduce_dtype: torch.dtype,
     enable_compile: bool,
     enable_compiled_autograd: bool,
+    cpu_offload: bool = False,
 ):
-    if enable_compile:
-        if enable_compiled_autograd:
-            torch._dynamo.config.optimize_ddp = (
-                "python_reducer_without_compiled_forward"
-            )
-        else:
-            torch._dynamo.config.optimize_ddp = "ddp_optimizer"
+    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
+    replicate_config = {"device_mesh": dp_mesh, "mp_policy": mp_policy}
+    if cpu_offload:
+        replicate_config["offload_policy"] = CPUOffloadPolicy()
 
-    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+    if model.tok_embeddings is not None:
+        replicate(
+            model.tok_embeddings,
+            **replicate_config,
+        )
+    for layer_id, transformer_block in model.layers.items():
+        replicate(
+            transformer_block,
+            **replicate_config,
+        )
+
+    if model.norm is not None and model.output is not None:
+        replicate(
+            [model.norm, model.output],
+            **replicate_config,
+        )
+    replicate(model, **replicate_config)
 
     logger.info("Applied DDP to the model")
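The rewritten apply_ddp follows the same per-module wrapping pattern as apply_fsdp: the embeddings, each transformer block, and the norm/output pair are wrapped first, then the root module, with one shared config carrying mixed precision and optional CPU offload. Below is a self-contained sketch of that pattern on a toy module; it assumes the experimental torch.distributed._composable.replicate_with_fsdp API accepts the keyword arguments used in this commit (device_mesh, mp_policy, offload_policy), and ToyModel/apply_replicate are illustrative names, not torchtitan code. Run under torchrun with 8 ranks:

# Sketch under stated assumptions; not a drop-in torchtitan implementation.
import torch
import torch.nn as nn
from torch.distributed._composable.replicate_with_fsdp import replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(64, 64) for _ in range(4))
        self.output = nn.Linear(64, 8)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.output(x)


def apply_replicate(model: nn.Module, dp_mesh) -> nn.Module:
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    )
    cfg = {"device_mesh": dp_mesh, "mp_policy": mp_policy}
    # Wrap leaf blocks first, then the root module, mirroring apply_ddp above.
    for block in model.layers:
        replicate(block, **cfg)
    replicate(model, **cfg)
    return model


if __name__ == "__main__":
    mesh = init_device_mesh(
        "cuda", (8, 1), mesh_dim_names=("dp_replicate", "dp_shard")
    )
    model = apply_replicate(ToyModel().cuda(), mesh)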

torchtitan/models/llama3/train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ steps = 10
 dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)
 
 [parallelism]
-data_parallel_replicate_degree = 1
+data_parallel_replicate_degree = 8
 data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default"  # default / never / always
 tensor_parallel_degree = 1
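On an 8-GPU debug run, data_parallel_replicate_degree = 8 with data_parallel_shard_degree = -1 leaves dp_shard resolving to 1 (the remaining world size after replication). Combined with the parallel_dims.py change above, the job then builds the 2D ("dp_replicate", "dp_shard") mesh that the new apply_ddp slices from world_mesh, i.e. pure replication across all 8 ranks.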
