Commit ff6ea8a

committed
[torchtitan][replicate] experimenting new replicate integration with torchtitan
ghstack-source-id: 88bd5ed
Pull Request resolved: #1714
1 parent 0f34257 commit ff6ea8a

7 files changed: +63 −67 lines changed
torchtitan/distributed/parallel_dims.py

Lines changed: 3 additions & 1 deletion
@@ -151,7 +151,9 @@ def _build_mesh_without_ep(self) -> DeviceMesh:
             [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
             ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
         ):
-            if d > 1:
+            # Include dp_shard dimension even if it equals 1 when replicate > 1
+            # to make device_mesh compatible with replicate function
+            if d > 1 or (name == "dp_shard" and self.dp_replicate > 1):
                 dims.append(d)
                 names.append(name)
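A standalone sketch of the new filter (not the real ParallelDims class; the parallelism degrees are made up), showing that dp_shard is now kept as a size-1 mesh dimension whenever dp_replicate > 1:

# Standalone sketch of the dimension filter above; degrees are illustrative.
pp, dp_replicate, dp_shard, cp, tp = 1, 4, 1, 1, 1

dims, names = [], []
for d, name in zip(
    [pp, dp_replicate, dp_shard, cp, tp],
    ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
):
    # Keep dp_shard (even at size 1) so the mesh has the 2D shape replicate expects.
    if d > 1 or (name == "dp_shard" and dp_replicate > 1):
        dims.append(d)
        names.append(name)

print(dims, names)  # [4, 1] ['dp_replicate', 'dp_shard']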

torchtitan/distributed/utils.py

Lines changed: 8 additions & 18 deletions
@@ -217,24 +217,16 @@ def context(cp_context: Generator[None, None, None] | None = None):
 def maybe_enable_amp(
     parallel_dims: ParallelDims, mixed_precision_param: str, device_type: torch.device
 ) -> Generator[None, None, None]:
-    if parallel_dims.fsdp_enabled:
+    if parallel_dims.fsdp_enabled or parallel_dims.dp_replicate_enabled:
         # FSDP handles mixed precision internally
-        logger.info("Mixed precision training is handled by fully_shard")
+        logger.info("Mixed precision training is handled by fully_shard or replicate")
         return contextlib.nullcontext()
     else:
-        if parallel_dims.tp_enabled or parallel_dims.pp_enabled:
-            logger.warning(
-                "Mixed precision training with TP or PP is only supported when FSDP/HSDP/CP is enabled."
-            )
-            logger.info("Mixed precision training is disabled")
-            return contextlib.nullcontext()
-        else:
-            # the following code will only be executed for DDP or single-device training
-            logger.info("Mixed precision training is handled by AMP")
-            return torch.autocast(
-                device_type,
-                dtype=TORCH_DTYPE_MAP[mixed_precision_param],
-            )
+        logger.warning(
+            "Mixed precision training with TP or PP is only supported when FSDP/HSDP/CP/replicate is enabled."
+        )
+        logger.info("Mixed precision training is disabled")
+        return contextlib.nullcontext()


 def init_distributed(
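A condensed sketch of what the simplified maybe_enable_amp now means from the trainer's point of view; the boolean flag and the toy forward pass are stand-ins for the real ParallelDims checks and model:

# Illustrative only: after this change, autocast is never entered by maybe_enable_amp.
import contextlib
import torch

fsdp_or_replicate_enabled = True  # stand-in for parallel_dims.fsdp_enabled / dp_replicate_enabled

if fsdp_or_replicate_enabled:
    # fully_shard / replicate cast per their MixedPrecisionPolicy, so no autocast here.
    maybe_amp = contextlib.nullcontext()
else:
    maybe_amp = contextlib.nullcontext()  # every other configuration is now "disabled"

with maybe_amp:
    loss = torch.randn(4, 4).sum()  # placeholder for the forward pass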
@@ -432,9 +424,7 @@ def _clip_grad_norm_with_ep(
     if math.isinf(norm_type):
         total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm)
     else:
-        total_norm = (
-            ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
-        )
+        total_norm = ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
         total_norm **= 1.0 / norm_type

     if pp_mesh is not None:
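The collapsed expression still computes the same combined p-norm, total_norm = (ep_norm**p + non_ep_norm**p) ** (1/p); a quick standalone check with made-up partial norms:

import torch

norm_type = 2.0
ep_grads_total_norm = torch.tensor(3.0)      # made-up partial norm
non_ep_grads_total_norm = torch.tensor(4.0)  # made-up partial norm

total_norm = ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
total_norm **= 1.0 / norm_type
print(total_norm)  # tensor(5.) -> sqrt(3**2 + 4**2)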

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 6 additions & 8 deletions
@@ -28,7 +28,7 @@
     TensorParallel,
 )
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
-from torchtitan.models.llama3.infra.parallelize import apply_ddp
+from torchtitan.models.llama3.infra.parallelize import apply_replicate
 from torchtitan.tools.logging import logger


@@ -169,14 +169,12 @@ def parallelize_llama(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
-        dp_mesh = world_mesh
-        apply_ddp(
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+        apply_replicate(
             model,
-            dp_mesh,
-            enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )

     return model
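The world_mesh[tuple(dp_mesh_dim_names)] expression selects the 2D data-parallel sub-mesh by name. A hedged standalone sketch of the same slicing pattern outside torchtitan; the CPU device type, four ranks, and the (4, 1) shape are assumptions, and it must run under torchrun:

# Sketch only: slice a named (dp_replicate, dp_shard) sub-mesh as the diff does.
# Assumes e.g. `torchrun --nproc-per-node 4 this_script.py`.
from torch.distributed.device_mesh import init_device_mesh

world_mesh = init_device_mesh("cpu", (4, 1), mesh_dim_names=("dp_replicate", "dp_shard"))

dp_mesh_dim_names = ("dp_replicate", "dp_shard")
dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]  # 2D sub-mesh handed to apply_replicate
print(dp_mesh.mesh_dim_names)  # ('dp_replicate', 'dp_shard')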

torchtitan/experiments/qwen3/infra/parallelize.py

Lines changed: 6 additions & 7 deletions
@@ -28,7 +28,7 @@
     apply_fsdp,
     apply_moe_ep_tp,
 )
-from torchtitan.models.llama3.infra.parallelize import apply_ddp
+from torchtitan.models.llama3.infra.parallelize import apply_replicate
 from torchtitan.tools.logging import logger


@@ -164,13 +164,12 @@ def parallelize_qwen3(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
-        apply_ddp(
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+        apply_replicate(
             model,
-            world_mesh,
-            enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )

     # Enable weight tying after applying parallelisms

torchtitan/experiments/vlm/infra/parallelize.py

Lines changed: 6 additions & 7 deletions
@@ -20,7 +20,7 @@
 from torchtitan.models.llama3.infra.parallelize import (
     _save_list as sac_save_list,
     apply_compile,
-    apply_ddp,
+    apply_replicate,
 )
 from torchtitan.tools.logging import logger

@@ -101,13 +101,12 @@ def parallelize_vlm(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
-        apply_ddp(
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+        apply_replicate(
             model,
-            world_mesh,
-            enable_compile=job_config.compile.enable,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )

     return model

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 6 additions & 8 deletions
@@ -25,7 +25,7 @@
     apply_fsdp,
     apply_moe_ep_tp,
 )
-from torchtitan.models.llama3.infra.parallelize import apply_ddp
+from torchtitan.models.llama3.infra.parallelize import apply_replicate
 from torchtitan.tools.logging import logger


@@ -162,14 +162,12 @@ def parallelize_deepseekv3(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
-        dp_mesh = world_mesh
-        apply_ddp(
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+        apply_replicate(
             model,
-            dp_mesh,
-            enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )

     return model

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 28 additions & 18 deletions
@@ -9,7 +9,7 @@

 import torch
 import torch.nn as nn
-from torch.distributed._composable.replicate import replicate
+from torch.distributed._composable.replicate_with_fsdp import replicate

 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
@@ -135,13 +135,12 @@ def parallelize_llama(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
-        apply_ddp(
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+        apply_replicate(
             model,
-            world_mesh,
-            enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )

     return model
@@ -314,20 +313,31 @@ def apply_fsdp(
     fully_shard(model, **fsdp_config)


-def apply_ddp(
+def apply_replicate(
     model: nn.Module,
     dp_mesh: DeviceMesh,
-    enable_compile: bool,
-    enable_compiled_autograd: bool,
+    param_dtype: torch.dtype,
+    reduce_dtype: torch.dtype,
 ):
-    if enable_compile:
-        if enable_compiled_autograd:
-            torch._dynamo.config.optimize_ddp = (
-                "python_reducer_without_compiled_forward"
-            )
-        else:
-            torch._dynamo.config.optimize_ddp = "ddp_optimizer"
+    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
+    replicate_config = {"device_mesh": dp_mesh, "mp_policy": mp_policy}
+
+    if model.tok_embeddings is not None:
+        replicate(
+            model.tok_embeddings,
+            **replicate_config,
+        )
+    for layer_id, transformer_block in model.layers.items():
+        replicate(
+            transformer_block,
+            **replicate_config,
+        )

-    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+    if model.norm is not None and model.output is not None:
+        replicate(
+            [model.norm, model.output],
+            **replicate_config,
+        )
+    replicate(model, **replicate_config)

     logger.info("Applied DDP to the model")
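For orientation, a hedged sketch of how the new apply_replicate wiring could be exercised on a toy module. It assumes the private torch.distributed._composable.replicate_with_fsdp module imported in the hunk above (only present in matching PyTorch builds), an initialized ("dp_replicate", "dp_shard") device mesh, and launch via torchrun; the toy model and mesh sizes are illustrative, not part of this commit.

# Sketch only: mirrors apply_replicate's per-block replicate() calls on a toy nn.Sequential.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed._composable.replicate_with_fsdp import replicate  # private/experimental


def replicate_toy_model(model, dp_mesh, param_dtype, reduce_dtype):
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    replicate_config = {"device_mesh": dp_mesh, "mp_policy": mp_policy}
    for block in model:
        replicate(block, **replicate_config)  # one replicate group per block, as above
    replicate(model, **replicate_config)      # root call last, matching apply_replicate
    return model


# Usage (illustrative; run under torchrun, e.g. --nproc-per-node 2):
# mesh = init_device_mesh("cuda", (2, 1), mesh_dim_names=("dp_replicate", "dp_shard"))
# toy = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
# replicate_toy_model(toy, mesh, torch.bfloat16, torch.float32)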
