
Commit 4f8e621

clean up train.py
1 parent 6e4ef27 commit 4f8e621

4 files changed: +72 -106 lines changed

torchtitan/distributed/expert_parallel.py

Lines changed: 3 additions & 3 deletions
@@ -159,9 +159,9 @@ def _token_combine(self, mod, routed_output, device_mesh):
 
     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         """
-        hooks are called in the order they are registered:
-        A, dispatch, B (pre hooks)
-        C, combine, D (post hooks)
+        Hooks are called in the order they are registered:
+        SyncHookA, _token_dispatch, SyncHookB (pre hooks)
+        SyncHookC, _token_combine, SyncHookD (post hooks)
         """
         inner_wrapped_module = self._wrap_with_inner_hooks(module)
         distributed_module = distribute_module(
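
Note: the ordering the updated docstring relies on comes straight from nn.Module hook semantics: forward pre-hooks fire in registration order before forward(), and forward (post) hooks fire in registration order after it. A minimal standalone sketch (not torchtitan code; the hook names below only mirror the docstring):

    import torch
    import torch.nn as nn

    calls = []

    def pre_hook(name):
        def hook(module, args):          # runs before module.forward
            calls.append(name)
        return hook

    def post_hook(name):
        def hook(module, args, output):  # runs after module.forward
            calls.append(name)
        return hook

    m = nn.Linear(4, 4)
    for name in ("SyncHookA", "_token_dispatch", "SyncHookB"):
        m.register_forward_pre_hook(pre_hook(name))
    for name in ("SyncHookC", "_token_combine", "SyncHookD"):
        m.register_forward_hook(post_hook(name))

    m(torch.randn(2, 4))
    print(calls)
    # ['SyncHookA', '_token_dispatch', 'SyncHookB', 'SyncHookC', '_token_combine', 'SyncHookD']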

torchtitan/distributed/pipeline_parallel.py

Lines changed: 58 additions & 85 deletions
@@ -455,14 +455,17 @@ def _count_moe_modules(model):
     from torchtitan.models.moe import MoE
 
     moe_count = 0
-    for name, module in model.named_modules():
+    for _, module in model.named_modules():
         if isinstance(module, MoE):
             moe_count += 1
     return moe_count
 
 
 def overlap_callback(action: _Action, ctx: _PipelineContext):
-    """Custom callback for OVERLAP_F_B computation that mimics the original implementation."""
+    """
+    Custom callback for OVERLAP_F_B computation that allows expert parallel communication
+    and pipeline parallel computation to overlap.
+    """
     schedule = ctx.schedule_ref
     assert isinstance(schedule, _PipelineScheduleRuntime)
     stage_index_to_stage: dict[int, _PipelineStageBase] = {
@@ -482,6 +485,7 @@ def overlap_callback(action: _Action, ctx: _PipelineContext):
     # Forward setup
     arg_mbs = ctx.arg_mbs
     kwarg_mbs = ctx.kwarg_mbs
+    assert arg_mbs is not None and kwarg_mbs is not None
     fwd_recv_ops = schedule.fwd_recv_ops
     forward_stage = stage_index_to_stage[forward_stage_index]
     forward_is_next_stage_on_this_rank = forward_stage_index + 1 in stage_index_to_stage
@@ -498,13 +502,6 @@ def overlap_callback(action: _Action, ctx: _PipelineContext):
     assert backward_mb_index is not None
     bwd_recv_ops = schedule.bwd_recv_ops
 
-    # print(
-    #     f"overlap_callback begin {forward_stage_index}:{forward_mb_index}, {backward_stage_index}:{backward_mb_index}",
-    #     "=" * 80,
-    #     torch.distributed.get_rank(),
-    # )
-    # PP communication ========================================================
-
     # Fwd receives
     if (
         not forward_stage.is_first
@@ -529,85 +526,61 @@ def overlap_callback(action: _Action, ctx: _PipelineContext):
     ) in bwd_recv_ops, f"Attempted to run compute {action=} before receiving input"
     _wait_batch_p2p(bwd_recv_ops.pop((backward_stage_index, backward_mb_index)))
 
+    # We count num layers in case the stage layers differ
+    # If they differ than we only want coordination to happen for the min amount of layers
+    min_num_layers = min(
+        _count_moe_modules(forward_stage.submod),
+        _count_moe_modules(backward_stage.submod),
+    )
     # PP computation ========================================================
-    def forward_backward_overlapped():
-        from torchtitan.distributed.pipeline_parallel import _hook_coordinator
+    _hook_coordinator.enable_coordination(num_layers=min_num_layers)
+    main_cuda_stream = torch.cuda.current_stream()
+
+    def run_backward():
+        # Set the backward thread to use the same stream as forward
+        torch.cuda.set_stream(main_cuda_stream)
+        with record_function(
+            f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}"
+        ):
+            loss = schedule._maybe_get_loss(backward_stage, backward_mb_index)
+            schedule.backward_counter[backward_stage_index] += 1
+            last_backward = (
+                schedule.backward_counter[backward_stage_index]
+                == schedule._n_microbatches
+            )
+            backward_stage.backward_one_chunk(
+                backward_mb_index,
+                loss=loss,
+                full_backward=True,
+                last_backward=last_backward,
+            )
+            grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1
+            if last_backward:
+                backward_stage.scale_grads(grad_scale_factor)
 
-        # TODO: Num layers is needed in case the stage layers differ, we need to ensure there is no coordination
-        min_num_layers = min(
-            _count_moe_modules(forward_stage.submod),
-            _count_moe_modules(backward_stage.submod),
-        )
-        _hook_coordinator.enable_coordination(num_layers=min_num_layers)
-        main_cuda_stream = torch.cuda.current_stream()
-
-        def run_backward():
-            # Set the backward thread to use the same stream as forward
-            torch.cuda.set_stream(main_cuda_stream)
-            # Backward ========================================================
-            with record_function(
-                f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}"
-            ):
-                loss = schedule._maybe_get_loss(backward_stage, backward_mb_index)
-                schedule.backward_counter[backward_stage_index] += 1
-                last_backward = (
-                    schedule.backward_counter[backward_stage_index]
-                    == schedule._n_microbatches
-                )
-                backward_stage.backward_one_chunk(
+        if backward_is_prev_stage_on_this_rank:
+            stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
+                backward_stage.get_local_bwd_output(backward_mb_index),
                 backward_mb_index,
-                    loss=loss,
-                    full_backward=True,
-                    last_backward=last_backward,
             )
-                grad_scale_factor = (
-                    schedule._n_microbatches if schedule.scale_grads else 1
-                )
-                if last_backward:
-                    backward_stage.scale_grads(grad_scale_factor)
-
-            if backward_is_prev_stage_on_this_rank:
-                stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
-                    backward_stage.get_local_bwd_output(backward_mb_index),
-                    backward_mb_index,
-                )
-
-        # Forward ========================================================
-        def run_forward():
-            output = forward_stage.forward_one_chunk(
-                forward_mb_index,
-                arg_mbs[forward_mb_index],
-                kwarg_mbs[forward_mb_index],
-            )
-            schedule._maybe_compute_loss(
-                forward_stage, output, ctx.target_mbs, forward_mb_index
+
+    def run_forward():
+        output = forward_stage.forward_one_chunk(
+            forward_mb_index,
+            arg_mbs[forward_mb_index],
+            kwarg_mbs[forward_mb_index],
+        )
+        schedule._maybe_compute_loss(
+            forward_stage, output, ctx.target_mbs, forward_mb_index
+        )
+        if forward_is_next_stage_on_this_rank:
+            stage_index_to_stage[forward_stage_index + 1].set_local_fwd_input(
+                output, forward_mb_index
             )
-            if forward_is_next_stage_on_this_rank:
-                stage_index_to_stage[forward_stage_index + 1].set_local_fwd_input(
-                    output, forward_mb_index
-                )
 
-        # Run forward and backward in parallel
-        # if _hook_coordinator.is_coordination_enabled():
-        thread = threading.Thread(target=run_backward, daemon=True)
-        thread.start()
-        run_forward()
-        thread.join()
-        # with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
-        #     forward_future = executor.submit(run_forward)
-        #     backward_future = executor.submit(run_backward)
-
-        #     # Wait for both to complete simultaneously
-        #     done, not_done = concurrent.futures.wait([forward_future, backward_future])
-        #     output = forward_future.result()
-        # else:
-        #     run_forward()
-        #     run_backward()
-
-        _hook_coordinator.disable_coordination()
-
-    forward_backward_overlapped()
-    # print(
-    #     f"overlap_callback end {forward_stage_index}:{forward_mb_index}, {backward_stage_index}:{backward_mb_index}",
-    #     "=" * 80,
-    # )
+    # Run forward and backward in parallel
+    thread = threading.Thread(target=run_backward, daemon=True)
+    thread.start()
+    run_forward()
+    thread.join()
+    _hook_coordinator.disable_coordination()
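
Note: stripped of the pipeline bookkeeping, the control flow this hunk sets up is "run the backward chunk on a side thread that shares the main thread's CUDA stream, run the forward chunk on the main thread, then join". A simplified sketch with dummy work standing in for the real stage calls (assumed sizes; run_backward/run_forward here are illustrative, not the functions above):

    import threading
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    main_stream = torch.cuda.current_stream() if device == "cuda" else None

    def run_backward():
        if main_stream is not None:
            # A new thread starts on the default stream, so explicitly reuse the
            # main thread's stream to keep forward/backward kernels ordered on it.
            torch.cuda.set_stream(main_stream)
        x = torch.randn(1024, 1024, device=device, requires_grad=True)
        (x @ x).sum().backward()  # stand-in for backward_one_chunk(...)

    def run_forward():
        y = torch.randn(1024, 1024, device=device)
        _ = y @ y                 # stand-in for forward_one_chunk(...)

    thread = threading.Thread(target=run_backward, daemon=True)
    thread.start()   # backward proceeds on its own Python thread
    run_forward()    # forward runs concurrently on the main thread
    thread.join()    # both finish before coordination is disabled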

torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml

Lines changed: 2 additions & 3 deletions
@@ -4,7 +4,7 @@ description = "DeepSeek-V3 16B model training"
 print_args = false
 
 [profiling]
-enable_profiling = false
+enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 10
 enable_memory_snapshot = false
@@ -56,15 +56,14 @@ expert_tensor_parallel_degree = 1
 enable = false
 folder = "checkpoint"
 interval = 10
-last_save_model_only = false # This does stuff with causing compile?
+last_save_model_only = true
 export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]"
 
 [activation_checkpoint]
 mode = "none" # ["none", "selective", "full"]
 selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
-# we cannot compile model with dI-dW split
 [compile]
 enable=true
 components = ["loss"] # ["model", "loss"]

torchtitan/train.py

Lines changed: 9 additions & 15 deletions
@@ -493,17 +493,15 @@ def train_step(
             loss = self.forward_backward_step(input_dict, labels)
             accumulated_losses.append(loss.detach())
 
-        # TODO: parameters are not DTensors which im not sure why
-        # grad_norm = dist_utils.clip_grad_norm_(
-        #     [p for m in self.model_parts for p in m.parameters()],
-        #     self.job_config.training.max_norm,
-        #     foreach=True,
-        #     pp_mesh=(
-        #         parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None
-        #     ),
-        #     ep_enabled=parallel_dims.ep_enabled,
-        # )
-        grad_norm = torch.tensor([0.0], device=self.device)
+        grad_norm = dist_utils.clip_grad_norm_(
+            [p for m in self.model_parts for p in m.parameters()],
+            self.job_config.training.max_norm,
+            foreach=True,
+            pp_mesh=(
+                parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None
+            ),
+            ep_enabled=parallel_dims.ep_enabled,
+        )
         self.checkpointer.maybe_wait_for_staging()
         self.optimizers.step()
         self.lr_schedulers.step()
@@ -648,10 +646,6 @@ def close(self) -> None:
        self.metrics_processor.close()
 
 
-import fbvscode
-
-fbvscode.attach_debugger()
-
 if __name__ == "__main__":
     init_logger()
     config_manager = ConfigManager()
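
Note: dist_utils.clip_grad_norm_ is torchtitan's distributed wrapper, which, per its arguments here, also accounts for the pp mesh and expert parallelism. As a rough single-process analogue of the restored call, using the stock PyTorch utility and made-up model parts and max_norm:

    import torch
    import torch.nn as nn

    model_parts = [nn.Linear(8, 8), nn.Linear(8, 8)]  # stand-ins for pipeline stages
    max_norm = 1.0                                    # cf. job_config.training.max_norm

    # Dummy forward/backward so every parameter has a gradient to clip.
    loss = sum(m(torch.randn(4, 8)).sum() for m in model_parts)
    loss.backward()

    grad_norm = torch.nn.utils.clip_grad_norm_(
        [p for m in model_parts for p in m.parameters()],
        max_norm,
        foreach=True,  # multi-tensor implementation, as in the diff
    )
    print(f"total grad norm before clipping: {grad_norm:.4f}")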
