@@ -455,14 +455,17 @@ def _count_moe_modules(model):
     from torchtitan.models.moe import MoE

     moe_count = 0
-    for name, module in model.named_modules():
+    for _, module in model.named_modules():
         if isinstance(module, MoE):
             moe_count += 1
     return moe_count


 def overlap_callback(action: _Action, ctx: _PipelineContext):
-    """Custom callback for OVERLAP_F_B computation that mimics the original implementation."""
+    """
+    Custom callback for OVERLAP_F_B computation that allows expert parallel communication
+    and pipeline parallel computation to overlap.
+    """
     schedule = ctx.schedule_ref
     assert isinstance(schedule, _PipelineScheduleRuntime)
     stage_index_to_stage: dict[int, _PipelineStageBase] = {
@@ -482,6 +485,7 @@ def overlap_callback(action: _Action, ctx: _PipelineContext):
     # Forward setup
     arg_mbs = ctx.arg_mbs
     kwarg_mbs = ctx.kwarg_mbs
+    assert arg_mbs is not None and kwarg_mbs is not None
     fwd_recv_ops = schedule.fwd_recv_ops
     forward_stage = stage_index_to_stage[forward_stage_index]
     forward_is_next_stage_on_this_rank = forward_stage_index + 1 in stage_index_to_stage
@@ -498,13 +502,6 @@ def overlap_callback(action: _Action, ctx: _PipelineContext):
     assert backward_mb_index is not None
     bwd_recv_ops = schedule.bwd_recv_ops

-    # print(
-    #     f"overlap_callback begin {forward_stage_index}:{forward_mb_index}, {backward_stage_index}:{backward_mb_index}",
-    #     "=" * 80,
-    #     torch.distributed.get_rank(),
-    # )
-    # PP communication ========================================================
-
     # Fwd receives
     if (
         not forward_stage.is_first
@@ -529,85 +526,61 @@ def overlap_callback(action: _Action, ctx: _PipelineContext):
         ) in bwd_recv_ops, f"Attempted to run compute {action=} before receiving input"
         _wait_batch_p2p(bwd_recv_ops.pop((backward_stage_index, backward_mb_index)))

+    # Count the MoE layers in each stage in case the stage layer counts differ.
+    # If they differ, coordination should only happen for the minimum number of layers.
+    min_num_layers = min(
+        _count_moe_modules(forward_stage.submod),
+        _count_moe_modules(backward_stage.submod),
+    )
     # PP computation ========================================================
-    def forward_backward_overlapped():
-        from torchtitan.distributed.pipeline_parallel import _hook_coordinator
+    _hook_coordinator.enable_coordination(num_layers=min_num_layers)
+    main_cuda_stream = torch.cuda.current_stream()
+
+    def run_backward():
+        # Set the backward thread to use the same stream as forward
+        torch.cuda.set_stream(main_cuda_stream)
+        with record_function(
+            f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}"
+        ):
+            loss = schedule._maybe_get_loss(backward_stage, backward_mb_index)
+            schedule.backward_counter[backward_stage_index] += 1
+            last_backward = (
+                schedule.backward_counter[backward_stage_index]
+                == schedule._n_microbatches
+            )
+            backward_stage.backward_one_chunk(
+                backward_mb_index,
+                loss=loss,
+                full_backward=True,
+                last_backward=last_backward,
+            )
+            grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1
+            if last_backward:
+                backward_stage.scale_grads(grad_scale_factor)

-        # TODO: Num layers is needed in case the stage layers differ, we need to ensure there is no coordination
-        min_num_layers = min(
-            _count_moe_modules(forward_stage.submod),
-            _count_moe_modules(backward_stage.submod),
-        )
-        _hook_coordinator.enable_coordination(num_layers=min_num_layers)
-        main_cuda_stream = torch.cuda.current_stream()
-
-        def run_backward():
-            # Set the backward thread to use the same stream as forward
-            torch.cuda.set_stream(main_cuda_stream)
-            # Backward ========================================================
-            with record_function(
-                f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}"
-            ):
-                loss = schedule._maybe_get_loss(backward_stage, backward_mb_index)
-                schedule.backward_counter[backward_stage_index] += 1
-                last_backward = (
-                    schedule.backward_counter[backward_stage_index]
-                    == schedule._n_microbatches
-                )
-                backward_stage.backward_one_chunk(
+        if backward_is_prev_stage_on_this_rank:
+            stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
+                backward_stage.get_local_bwd_output(backward_mb_index),
                 backward_mb_index,
-                    loss=loss,
-                    full_backward=True,
-                    last_backward=last_backward,
             )
-                grad_scale_factor = (
-                    schedule._n_microbatches if schedule.scale_grads else 1
-                )
-                if last_backward:
-                    backward_stage.scale_grads(grad_scale_factor)
-
-            if backward_is_prev_stage_on_this_rank:
-                stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
-                    backward_stage.get_local_bwd_output(backward_mb_index),
-                    backward_mb_index,
-                )
-
-        # Forward ========================================================
-        def run_forward():
-            output = forward_stage.forward_one_chunk(
-                forward_mb_index,
-                arg_mbs[forward_mb_index],
-                kwarg_mbs[forward_mb_index],
-            )
-            schedule._maybe_compute_loss(
-                forward_stage, output, ctx.target_mbs, forward_mb_index
+
+    def run_forward():
+        output = forward_stage.forward_one_chunk(
+            forward_mb_index,
+            arg_mbs[forward_mb_index],
+            kwarg_mbs[forward_mb_index],
+        )
+        schedule._maybe_compute_loss(
+            forward_stage, output, ctx.target_mbs, forward_mb_index
+        )
+        if forward_is_next_stage_on_this_rank:
+            stage_index_to_stage[forward_stage_index + 1].set_local_fwd_input(
+                output, forward_mb_index
             )
-            if forward_is_next_stage_on_this_rank:
-                stage_index_to_stage[forward_stage_index + 1].set_local_fwd_input(
-                    output, forward_mb_index
-                )

-        # Run forward and backward in parallel
-        # if _hook_coordinator.is_coordination_enabled():
-        thread = threading.Thread(target=run_backward, daemon=True)
-        thread.start()
-        run_forward()
-        thread.join()
-        # with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
-        #     forward_future = executor.submit(run_forward)
-        #     backward_future = executor.submit(run_backward)
-
-        #     # Wait for both to complete simultaneously
-        #     done, not_done = concurrent.futures.wait([forward_future, backward_future])
-        #     output = forward_future.result()
-        # else:
-        #     run_forward()
-        #     run_backward()
-
-        _hook_coordinator.disable_coordination()
-
-    forward_backward_overlapped()
-    # print(
-    #     f"overlap_callback end {forward_stage_index}:{forward_mb_index}, {backward_stage_index}:{backward_mb_index}",
-    #     "=" * 80,
-    # )
+    # Run forward and backward in parallel
+    thread = threading.Thread(target=run_backward, daemon=True)
+    thread.start()
+    run_forward()
+    thread.join()
+    _hook_coordinator.disable_coordination()
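
For reference, below is a minimal, hypothetical sketch of the kind of coordinator the callback above relies on via _hook_coordinator.enable_coordination / disable_coordination. The class name, the sync_layer method, and the barrier-based hand-off are illustrative assumptions, not torchtitan's actual implementation: the idea is that per-MoE-layer hooks on the forward and backward threads rendezvous so one thread's expert-parallel all-to-all can overlap with the other thread's compute, and coordination only happens for num_layers layers (the minimum across the two stages).

import threading


class _SketchHookCoordinator:
    """Illustrative only: a barrier-based coordinator for the forward and
    backward threads launched by overlap_callback (not torchtitan's API)."""

    def __init__(self) -> None:
        self._enabled = False
        self._num_layers = 0
        self._barrier: threading.Barrier | None = None

    def enable_coordination(self, num_layers: int) -> None:
        # Two participants: the forward thread and the backward thread.
        self._enabled = num_layers > 0
        self._num_layers = num_layers
        self._barrier = threading.Barrier(2) if self._enabled else None

    def disable_coordination(self) -> None:
        self._enabled = False
        self._barrier = None

    def sync_layer(self, layer_idx: int) -> None:
        # Hypothetical per-MoE-layer hook: both threads meet here, so the
        # thread that reaches its expert-parallel dispatch first can launch
        # the all-to-all while the other thread keeps computing.
        if self._enabled and layer_idx < self._num_layers and self._barrier is not None:
            self._barrier.wait()

In such a sketch, each MoE layer's forward/backward hook would call sync_layer(i) for its i-th MoE layer; limiting coordination to min_num_layers keeps a thread whose stage has more MoE layers from waiting at a barrier the other thread never reaches, which would otherwise deadlock the overlap.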