|  | 
| 7 | 7 |  | 
| 8 | 8 | import math | 
| 9 | 9 | import os | 
| 10 |  | -from typing import Callable | 
|  | 10 | +import threading | 
|  | 11 | +from typing import Callable, Optional | 
| 11 | 12 |  | 
| 12 | 13 | import torch | 
| 13 | 14 | import torch.nn as nn | 
| 14 | 15 | from torch.distributed.device_mesh import DeviceMesh | 
| 15 | 16 | from torch.distributed.pipelining import PipelineStage | 
| 16 | 17 |  | 
| 17 | 18 | from torch.distributed.pipelining.schedules import ( | 
|  | 19 | +    _Action, | 
|  | 20 | +    _PipelineContext, | 
| 18 | 21 |     _PipelineSchedule, | 
| 19 | 22 |     _PipelineScheduleRuntime, | 
|  | 23 | +    _wait_batch_p2p, | 
| 20 | 24 |     get_schedule_class, | 
|  | 25 | +    OVERLAP_F_B, | 
| 21 | 26 |     PipelineScheduleMulti, | 
| 22 | 27 |     PipelineScheduleSingle, | 
| 23 | 28 |     ScheduleDualPipeV, | 
| 24 | 29 |     ScheduleZBVZeroBubble, | 
| 25 | 30 | ) | 
|  | 31 | +from torch.distributed.pipelining.stage import _PipelineStageBase | 
|  | 32 | +from torch.profiler import record_function | 
| 26 | 33 |  | 
| 27 | 34 | from torchtitan.components.loss import LossFunction, rescale_accumulated_loss | 
| 28 | 35 | from torchtitan.config import JobConfig | 
|  | 
| 37 | 44 |     "pipeline_module_split", | 
| 38 | 45 | ] | 
| 39 | 46 |  | 
| 40 |  | - | 
| 41 | 49 | def pipeline_llm( | 
| 42 | 50 |     model: nn.Module, | 
| 43 | 51 |     parallel_dims: ParallelDims, | 
| @@ -209,6 +217,11 @@ def build_pipeline_schedule( | 
| 209 | 217 |         f"with {n_microbatches} microbatches and {num_total_stages} stages." | 
| 210 | 218 |     ) | 
| 211 | 219 |  | 
|  | 220 | +    if job_config.parallelism.pipeline_parallel_expert_parallel_overlap and isinstance( | 
|  | 221 | +        schedule, ScheduleDualPipeV | 
|  | 222 | +    ): | 
|  | 223 | +        schedule.register_custom_function(OVERLAP_F_B, overlap_callback) | 
|  | 224 | + | 
| 212 | 225 |     if pp_schedule_csv: | 
| 213 | 226 |         assert schedule_class in [ | 
| 214 | 227 |             PipelineScheduleSingle, | 
| @@ -473,3 +486,222 @@ def _get_stage_indices() -> tuple[int]: | 
| 473 | 486 |         models.append(model_chunk) | 
| 474 | 487 |  | 
| 475 | 488 |     return stages, models | 
|  | 489 | + | 
|  | 490 | + | 
|  | 491 | +# TODO: is there a better place to put this? | 
|  | 492 | +# Below are optimizations related to pipeline parallelism with expert parallelism | 
|  | 493 | + | 
|  | 494 | + | 
|  | 495 | +class HookCoordinator: | 
|  | 496 | +    def __init__(self): | 
|  | 497 | +        # Barrier for the two threads (forward and backward) to synchronize on | 
|  | 498 | +        # This ensures the threads alternate, so one runs a compute op while the other runs a comm op | 
|  | 499 | +        self._execution_barrier = threading.Barrier(2) | 
|  | 500 | + | 
|  | 501 | +        self._coordination_enabled = False | 
|  | 502 | +        self._cycle_count = 0 | 
|  | 503 | +        self._num_layers = None | 
|  | 504 | + | 
|  | 505 | +    def barrier(self): | 
|  | 506 | +        """Barrier for 2 threads to synchronize""" | 
|  | 507 | +        if not self.is_coordination_enabled(): | 
|  | 508 | +            return | 
|  | 509 | + | 
|  | 510 | +        try: | 
|  | 511 | +            self._execution_barrier.wait() | 
|  | 512 | +        except threading.BrokenBarrierError: | 
|  | 513 | +            pass | 
|  | 514 | + | 
|  | 515 | +    def enable_coordination(self, num_layers: Optional[int] = None): | 
|  | 516 | +        if num_layers is not None and num_layers > 0: | 
|  | 517 | +            self._coordination_enabled = True | 
|  | 518 | +            self._cycle_count = 0 | 
|  | 519 | + | 
|  | 520 | +            # Reset barrier | 
|  | 521 | +            self._execution_barrier = threading.Barrier(2) | 
|  | 522 | +            self._num_layers = num_layers | 
|  | 523 | + | 
|  | 524 | +    def disable_coordination(self): | 
|  | 525 | +        self._coordination_enabled = False | 
|  | 526 | +        self._cycle_count = 0 | 
|  | 527 | +        self._execution_barrier.abort()  # Break barrier to unblock threads | 
|  | 528 | + | 
|  | 529 | +    def check_should_continue_coordination(self): | 
|  | 530 | +        if self._num_layers is not None and self._cycle_count >= self._num_layers: | 
|  | 531 | +            return False | 
|  | 532 | +        return True | 
|  | 533 | + | 
|  | 534 | +    def is_coordination_enabled(self): | 
|  | 535 | +        return self._coordination_enabled | 
|  | 536 | + | 
|  | 537 | + | 
|  | 538 | +# Global coordinator | 
|  | 539 | +_hook_coordinator = HookCoordinator() | 
|  | 540 | + | 
|  | 541 | + | 
|  | 542 | +class SyncHook(torch.autograd.Function): | 
|  | 543 | +    @staticmethod | 
|  | 544 | +    def forward(ctx, x, hook_name=""): | 
|  | 545 | +        ctx.hook_name = hook_name | 
|  | 546 | +        # Transformer-layer boundary: "D" fires once per layer; stop coordinating after the last shared layer | 
|  | 547 | +        if _hook_coordinator._coordination_enabled and hook_name == "D": | 
|  | 548 | +            _hook_coordinator._cycle_count += 1 | 
|  | 549 | +            if not _hook_coordinator.check_should_continue_coordination(): | 
|  | 550 | +                _hook_coordinator.disable_coordination() | 
|  | 551 | +                return x | 
|  | 552 | + | 
|  | 553 | +        # print(f"forward {hook_name=} calling barrier") | 
|  | 554 | +        _hook_coordinator.barrier() | 
|  | 555 | +        return x | 
|  | 556 | + | 
|  | 557 | +    @staticmethod | 
|  | 558 | +    def backward(ctx, grad_output): | 
|  | 559 | +        hook_name = ctx.hook_name | 
|  | 560 | + | 
|  | 561 | +        # Edge case: skip the initial barrier; every subsequent backward hook will wait on it | 
|  | 562 | +        if hook_name == "D" and _hook_coordinator._cycle_count == 0: | 
|  | 563 | +            return grad_output, None | 
|  | 564 | + | 
|  | 565 | +        # print(f"backward {hook_name=} calling barrier") | 
|  | 566 | +        _hook_coordinator.barrier() | 
|  | 567 | +        return grad_output, None | 
|  | 568 | + | 
|  | 569 | + | 
|  | 570 | +def _count_moe_modules(model): | 
|  | 571 | +    """Count MoE modules directly""" | 
|  | 572 | +    from torchtitan.models.moe import MoE | 
|  | 573 | + | 
|  | 574 | +    moe_count = 0 | 
|  | 575 | +    for _, module in model.named_modules(): | 
|  | 576 | +        if isinstance(module, MoE): | 
|  | 577 | +            moe_count += 1 | 
|  | 578 | +    return moe_count | 
|  | 579 | + | 
|  | 580 | + | 
|  | 581 | +def overlap_callback(action: _Action, ctx: _PipelineContext): | 
|  | 582 | +    """ | 
|  | 583 | +    Custom callback for OVERLAP_F_B computation that allows expert parallel communication | 
|  | 584 | +    and pipeline parallel computation to overlap. | 
|  | 585 | +    """ | 
|  | 586 | +    schedule = ctx.schedule_ref | 
|  | 587 | +    assert isinstance(schedule, _PipelineScheduleRuntime) | 
|  | 588 | +    stage_index_to_stage: dict[int, _PipelineStageBase] = { | 
|  | 589 | +        stage.stage_index: stage for stage in schedule._stages | 
|  | 590 | +    } | 
|  | 591 | +    assert action.sub_actions is not None | 
|  | 592 | +    fwd_action = action.sub_actions[0] | 
|  | 593 | +    bwd_action = action.sub_actions[1] | 
|  | 594 | + | 
|  | 595 | +    # Get stages | 
|  | 596 | +    forward_stage_index = fwd_action.stage_index | 
|  | 597 | +    forward_mb_index = fwd_action.microbatch_index | 
|  | 598 | +    assert forward_mb_index is not None | 
|  | 599 | +    backward_stage_index = bwd_action.stage_index | 
|  | 600 | +    backward_stage = stage_index_to_stage[backward_stage_index] | 
|  | 601 | + | 
|  | 602 | +    # Forward setup | 
|  | 603 | +    arg_mbs = ctx.arg_mbs | 
|  | 604 | +    kwarg_mbs = ctx.kwarg_mbs | 
|  | 605 | +    assert arg_mbs is not None and kwarg_mbs is not None | 
|  | 606 | +    fwd_recv_ops = schedule.fwd_recv_ops | 
|  | 607 | +    forward_stage = stage_index_to_stage[forward_stage_index] | 
|  | 608 | +    forward_is_next_stage_on_this_rank = forward_stage_index + 1 in stage_index_to_stage | 
|  | 609 | +    forward_is_prev_stage_on_this_rank = forward_stage_index - 1 in stage_index_to_stage | 
|  | 610 | + | 
|  | 611 | +    # Backward setup | 
|  | 612 | +    backward_is_next_stage_on_this_rank = ( | 
|  | 613 | +        backward_stage.stage_index + 1 in stage_index_to_stage | 
|  | 614 | +    ) | 
|  | 615 | +    backward_is_prev_stage_on_this_rank = ( | 
|  | 616 | +        backward_stage.stage_index - 1 in stage_index_to_stage | 
|  | 617 | +    ) | 
|  | 618 | +    backward_mb_index = bwd_action.microbatch_index | 
|  | 619 | +    assert backward_mb_index is not None | 
|  | 620 | +    bwd_recv_ops = schedule.bwd_recv_ops | 
|  | 621 | + | 
|  | 622 | +    # Fwd receives | 
|  | 623 | +    if ( | 
|  | 624 | +        not forward_stage.is_first | 
|  | 625 | +        # no recv op expected for V-schedule special case (see [Note: V-schedule special case]) | 
|  | 626 | +        and not forward_is_prev_stage_on_this_rank | 
|  | 627 | +    ): | 
|  | 628 | +        assert ( | 
|  | 629 | +            forward_stage_index, | 
|  | 630 | +            forward_mb_index, | 
|  | 631 | +        ) in fwd_recv_ops, f"Computing {action=} before receiving input" | 
|  | 632 | +        _wait_batch_p2p(fwd_recv_ops.pop((forward_stage_index, forward_mb_index))) | 
|  | 633 | + | 
|  | 634 | +    # Bwd receives | 
|  | 635 | +    if ( | 
|  | 636 | +        not backward_stage.is_last | 
|  | 637 | +        # no recv op expected for V-schedule special case (see [Note: V-schedule special case]) | 
|  | 638 | +        and not backward_is_next_stage_on_this_rank | 
|  | 639 | +    ): | 
|  | 640 | +        assert ( | 
|  | 641 | +            backward_stage_index, | 
|  | 642 | +            backward_mb_index, | 
|  | 643 | +        ) in bwd_recv_ops, f"Attempted to run compute {action=} before receiving input" | 
|  | 644 | +        _wait_batch_p2p(bwd_recv_ops.pop((backward_stage_index, backward_mb_index))) | 
|  | 645 | + | 
|  | 646 | +    # Count the MoE layers on each stage in case the two stages differ; | 
|  | 647 | +    # if they do, coordinate only for the smaller number of layers | 
|  | 648 | +    min_num_layers = min( | 
|  | 649 | +        _count_moe_modules(forward_stage.submod), | 
|  | 650 | +        _count_moe_modules(backward_stage.submod), | 
|  | 651 | +    ) | 
|  | 652 | +    # PP computation ======================================================== | 
|  | 653 | +    _hook_coordinator.enable_coordination(num_layers=min_num_layers) | 
|  | 654 | +    main_cuda_stream = torch.cuda.current_stream() | 
|  | 655 | + | 
|  | 656 | +    # Backward half of the overlapped action; runs in its own thread below (exceptions raised there are not propagated) | 
|  | 657 | +    def run_backward(): | 
|  | 658 | +        schedule._assert_unsharded(backward_stage) | 
|  | 659 | +        # Set the backward thread to use the same stream as forward | 
|  | 660 | +        torch.cuda.set_stream(main_cuda_stream) | 
|  | 661 | +        with record_function( | 
|  | 662 | +            f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}" | 
|  | 663 | +        ): | 
|  | 664 | +            loss = schedule._maybe_get_loss(backward_stage, backward_mb_index) | 
|  | 665 | +            schedule.backward_counter[backward_stage_index] += 1 | 
|  | 666 | +            last_backward = ( | 
|  | 667 | +                schedule.backward_counter[backward_stage_index] | 
|  | 668 | +                == schedule._n_microbatches | 
|  | 669 | +            ) | 
|  | 670 | +            backward_stage.backward_one_chunk( | 
|  | 671 | +                backward_mb_index, | 
|  | 672 | +                loss=loss, | 
|  | 673 | +                full_backward=True, | 
|  | 674 | +                last_backward=last_backward, | 
|  | 675 | +            ) | 
|  | 676 | +            grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1 | 
|  | 677 | +            if last_backward: | 
|  | 678 | +                backward_stage.scale_grads(grad_scale_factor) | 
|  | 679 | + | 
|  | 680 | +            if backward_is_prev_stage_on_this_rank: | 
|  | 681 | +                stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input( | 
|  | 682 | +                    backward_stage.get_local_bwd_output(backward_mb_index), | 
|  | 683 | +                    backward_mb_index, | 
|  | 684 | +                ) | 
|  | 685 | + | 
|  | 686 | +    def run_forward(): | 
|  | 687 | +        schedule._assert_unsharded(forward_stage) | 
|  | 688 | +        output = forward_stage.forward_one_chunk( | 
|  | 689 | +            forward_mb_index, | 
|  | 690 | +            arg_mbs[forward_mb_index], | 
|  | 691 | +            kwarg_mbs[forward_mb_index], | 
|  | 692 | +        ) | 
|  | 693 | +        schedule._maybe_compute_loss( | 
|  | 694 | +            forward_stage, output, ctx.target_mbs, forward_mb_index | 
|  | 695 | +        ) | 
|  | 696 | +        if forward_is_next_stage_on_this_rank: | 
|  | 697 | +            stage_index_to_stage[forward_stage_index + 1].set_local_fwd_input( | 
|  | 698 | +                output, forward_mb_index | 
|  | 699 | +            ) | 
|  | 700 | + | 
|  | 701 | +    # Run forward and backward in parallel | 
|  | 702 | +    thread = threading.Thread(target=run_backward, daemon=True) | 
|  | 703 | +    thread.start() | 
|  | 704 | +    run_forward() | 
|  | 705 | +    thread.join() | 
|  | 706 | + | 
|  | 707 | +    _hook_coordinator.disable_coordination() | 
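
The diff does not show where SyncHook is spliced into the model. Below is a minimal, hypothetical sketch of how a transformer block could wrap its MoE layer with SyncHook so the forward thread and the backward thread launched by overlap_callback hand off at each layer boundary. The ToyMoEBlock module and the "moe_in" hook name are illustrative assumptions; only the "D" layer-boundary name comes from the code above, and the sketch assumes SyncHook is in scope (e.g., defined in the same module).

# Hypothetical usage sketch; assumes SyncHook (defined above) is importable.
import torch
import torch.nn as nn

class ToyMoEBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.attn = nn.Linear(dim, dim)  # stand-in for attention
        self.moe = nn.Linear(dim, dim)   # stand-in for the MoE feed-forward

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.attn(x)
        x = SyncHook.apply(x, "moe_in")  # rendezvous before the MoE dispatch
        x = self.moe(x)
        x = SyncHook.apply(x, "D")       # layer boundary; counted while coordination is on
        return x

With coordination enabled by overlap_callback, each barrier pairs the forward thread of one microbatch with the backward thread of another, so one thread can run expert-parallel communication while the other runs compute.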