# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import threading
from typing import Optional

import torch
import torch.nn as nn

from torch.distributed.pipelining.schedules import (
    _Action,
    _PipelineContext,
    _PipelineScheduleRuntime,
    _wait_batch_p2p,
)
from torch.distributed.pipelining.stage import _PipelineStageBase
from torch.distributed.tensor import DeviceMesh, distribute_module
from torch.profiler import record_function

from torchtitan.distributed.expert_parallel import ExpertParallel

from torchtitan.tools.utils import get_device_info

|  | 26 | +""" | 
|  | 27 | +Below are optimizations related to pipeline parallelism with expert parallelism | 
|  | 28 | +""" | 
|  | 29 | + | 
|  | 30 | + | 
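# How the pieces below fit together:
#   * DualPipeExpertParallel wraps an expert module so that SyncHook barriers
#     (A/B/C/D) surround its token dispatch, expert compute, and token combine.
#   * HookCoordinator / SyncHook rendezvous the forward thread and the backward
#     thread at those barriers, so one thread's EP communication overlaps with
#     the other thread's computation.
#   * overlap_callback is the OVERLAP_F_B schedule callback that launches the
#     backward chunk on a second thread and drives the coordination.
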
class DualPipeExpertParallel(ExpertParallel):
    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        """
        The execution order is:
        A -> dispatch -> B -> module forward -> C -> combine -> D

        A, B, C, and D are SyncHook barriers used to coordinate with the
        concurrently running backward thread (see HookCoordinator).

        Hooks are called in the order they are registered:
        SyncHookA, _token_dispatch, SyncHookB (pre hooks)
        SyncHookC, _token_combine, SyncHookD (post hooks)
        """
        inner_wrapped_module = self._wrap_with_pre_comm_hooks(module)
        distributed_module = distribute_module(
            inner_wrapped_module,
            device_mesh,
            partition_fn=ExpertParallel._partition_fn,
            input_fn=self._token_dispatch,
            output_fn=self._token_combine,
        )
        final_module = self._wrap_with_post_comm_hooks(distributed_module)
        return final_module

    def _wrap_with_pre_comm_hooks(self, module):
        def inner_pre_hook(module, input):
            return (SyncHook.apply(input[0], "A"),) + input[1:]

        def inner_post_hook(module, input, output):
            return SyncHook.apply(output, "C")

        module.register_forward_pre_hook(inner_pre_hook)
        module.register_forward_hook(inner_post_hook)
        return module

    def _wrap_with_post_comm_hooks(self, module):
        def outer_pre_hook(module, input):
            return (SyncHook.apply(input[0], "B"),) + input[1:]

        def outer_post_hook(module, input, output):
            return SyncHook.apply(output, "D")

        module.register_forward_pre_hook(outer_pre_hook)
        module.register_forward_hook(outer_post_hook)
        return module


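# A minimal sketch of applying this style, assuming the torchtitan convention of
# parallelizing each transformer block's `moe.experts` submodule over an expert
# parallel mesh (`ep_mesh`, `model.layers`, and the attribute path are
# assumptions; adapt to your model layout):
#
#   from torch.distributed.tensor.parallel import parallelize_module
#
#   for block in model.layers.values():
#       parallelize_module(
#           module=block.moe.experts,
#           device_mesh=ep_mesh,
#           parallelize_plan=DualPipeExpertParallel(),
#       )
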
class HookCoordinator:
    def __init__(self):
        # Barrier for the 2 threads (forward and backward) to synchronize on.
        # It makes the two threads alternate, so that one thread's compute op
        # always runs alongside the other thread's communication op.
        self._execution_barrier = threading.Barrier(2)

        self._coordination_enabled = False
        self._cycle_count = 0
        self._num_layers: Optional[int] = None

    def barrier(self):
        """Rendezvous point for the forward and backward threads.

        No-op when coordination is disabled; a broken barrier (e.g. after
        disable_coordination aborts it) is ignored so neither thread blocks.
        """
        if not self.is_coordination_enabled():
            return

        try:
            self._execution_barrier.wait()
        except threading.BrokenBarrierError:
            pass

    def enable_coordination(self, num_layers: Optional[int] = None):
        if num_layers is not None and num_layers > 0:
            self._coordination_enabled = True
            self._cycle_count = 0

            # Reset barrier
            self._execution_barrier = threading.Barrier(2)
            self._num_layers = num_layers

    def disable_coordination(self):
        self._coordination_enabled = False
        self._cycle_count = 0
        self._execution_barrier.abort()  # Break barrier to unblock threads

    def check_should_continue_coordination(self):
        if self._num_layers is not None and self._cycle_count >= self._num_layers:
            return False
        return True

    def is_coordination_enabled(self):
        return self._coordination_enabled


# Global coordinator shared by the forward and backward threads
_hook_coordinator = HookCoordinator()


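# Coordination lifecycle, as driven by overlap_callback further below:
#
#   _hook_coordinator.enable_coordination(num_layers=...)  # reset barrier/counter
#   ... the forward and backward threads meet at _hook_coordinator.barrier()
#       each time a SyncHook fires ...
#   _hook_coordinator.disable_coordination()               # abort barrier, unblock
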
class SyncHook(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, hook_name=""):
        ctx.hook_name = hook_name
        # Hook "D" marks the end of one MoE layer's forward pass: count it, and
        # stop coordinating once num_layers forward layers have completed
        # (num_layers is the smaller stage's MoE count; see overlap_callback).
        if _hook_coordinator._coordination_enabled and hook_name == "D":
            _hook_coordinator._cycle_count += 1
            if not _hook_coordinator.check_should_continue_coordination():
                _hook_coordinator.disable_coordination()
                return x

        _hook_coordinator.barrier()
        return x

    @staticmethod
    def backward(ctx, grad_output):
        hook_name = ctx.hook_name

        # Edge case: a backward "D" hook that fires before the forward thread has
        # completed any layer (cycle_count == 0) skips the barrier; all subsequent
        # backward hooks participate in the barrier.
        if hook_name == "D" and _hook_coordinator._cycle_count == 0:
            return grad_output, None

        _hook_coordinator.barrier()
        return grad_output, None


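# Intended interleaving (illustrative): the forward thread hits barriers A/B/C/D
# around its token dispatch and combine, while the backward thread hits the
# mirrored barriers (D/C/B/A) around its gradient dispatch and combine. Because
# both threads must meet at every barrier, one thread's EP all-to-all runs while
# the other thread's expert computation runs, rather than both contending for
# the device at the same time.
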
def _count_moe_modules(model):
    """Count the MoE modules in a pipeline stage submodule."""
    from torchtitan.models.moe import MoE

    moe_count = 0
    for _, module in model.named_modules():
        if isinstance(module, MoE):
            moe_count += 1
    return moe_count


device_type, device_module = get_device_info()


def overlap_callback(action: _Action, ctx: _PipelineContext):
    """
    Custom callback for the OVERLAP_F_B action: runs one forward microbatch and
    one backward microbatch concurrently so that expert parallel communication
    overlaps with pipeline parallel computation.
    """
    schedule = ctx.schedule_ref
    assert isinstance(schedule, _PipelineScheduleRuntime)
    stage_index_to_stage: dict[int, _PipelineStageBase] = {
        stage.stage_index: stage for stage in schedule._stages
    }
    assert action.sub_actions is not None
    fwd_action = action.sub_actions[0]
    bwd_action = action.sub_actions[1]

    # Get stages
    forward_stage_index = fwd_action.stage_index
    forward_mb_index = fwd_action.microbatch_index
    assert forward_mb_index is not None
    backward_stage_index = bwd_action.stage_index
    backward_stage = stage_index_to_stage[backward_stage_index]

    # Forward setup
    arg_mbs = ctx.arg_mbs
    kwarg_mbs = ctx.kwarg_mbs
    assert arg_mbs is not None and kwarg_mbs is not None
    fwd_recv_ops = schedule.fwd_recv_ops
    forward_stage = stage_index_to_stage[forward_stage_index]
    forward_is_next_stage_on_this_rank = forward_stage_index + 1 in stage_index_to_stage
    forward_is_prev_stage_on_this_rank = forward_stage_index - 1 in stage_index_to_stage

    # Backward setup
    backward_is_next_stage_on_this_rank = (
        backward_stage.stage_index + 1 in stage_index_to_stage
    )
    backward_is_prev_stage_on_this_rank = (
        backward_stage.stage_index - 1 in stage_index_to_stage
    )
    backward_mb_index = bwd_action.microbatch_index
    assert backward_mb_index is not None
    bwd_recv_ops = schedule.bwd_recv_ops

    # Fwd receives
    if (
        not forward_stage.is_first
        # no recv op expected for V-schedule special case
        and not forward_is_prev_stage_on_this_rank
    ):
        assert (
            forward_stage_index,
            forward_mb_index,
        ) in fwd_recv_ops, f"Computing {action=} before receiving input"
        _wait_batch_p2p(fwd_recv_ops.pop((forward_stage_index, forward_mb_index)))

    # Bwd receives
    if (
        not backward_stage.is_last
        # no recv op expected for V-schedule special case
        and not backward_is_next_stage_on_this_rank
    ):
        assert (
            backward_stage_index,
            backward_mb_index,
        ) in bwd_recv_ops, f"Attempted to run compute {action=} before receiving input"
        _wait_batch_p2p(bwd_recv_ops.pop((backward_stage_index, backward_mb_index)))

    # Count MoE layers on each stage in case the two stages differ; coordination
    # should only run for the smaller of the two counts.
    min_num_layers = min(
        _count_moe_modules(forward_stage.submod),
        _count_moe_modules(backward_stage.submod),
    )
    # PP computation ========================================================
    _hook_coordinator.enable_coordination(num_layers=min_num_layers)
    main_stream = torch.accelerator.current_stream(device_module)

    # The backward chunk runs on a separate thread; the forward chunk runs on
    # the main thread below, and the two rendezvous at the SyncHook barriers.
    def run_backward():
        schedule._assert_unsharded(backward_stage)
        # Set the backward thread to use the same stream as forward
        device_module.set_stream(main_stream)
        with record_function(
            f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}"
        ):
            loss = schedule._maybe_get_loss(backward_stage, backward_mb_index)
            schedule.backward_counter[backward_stage_index] += 1
            last_backward = (
                schedule.backward_counter[backward_stage_index]
                == schedule._n_microbatches
            )
            backward_stage.backward_one_chunk(
                backward_mb_index,
                loss=loss,
                full_backward=True,
                last_backward=last_backward,
            )

            if backward_is_prev_stage_on_this_rank:
                stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
                    backward_stage.get_local_bwd_output(backward_mb_index),
                    backward_mb_index,
                )

    def run_forward():
        schedule._assert_unsharded(forward_stage)
        output = forward_stage.forward_one_chunk(
            forward_mb_index,
            arg_mbs[forward_mb_index],
            kwarg_mbs[forward_mb_index],
        )
        schedule._maybe_compute_loss(
            forward_stage, output, ctx.target_mbs, forward_mb_index
        )
        if forward_is_next_stage_on_this_rank:
            stage_index_to_stage[forward_stage_index + 1].set_local_fwd_input(
                output, forward_mb_index
            )

    # Run forward and backward in parallel
    thread = threading.Thread(target=run_backward, daemon=True)
    thread.start()
    run_forward()
    thread.join()

    _hook_coordinator.disable_coordination()
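

# A minimal sketch of wiring this callback into a runtime schedule. The
# OVERLAP_F_B computation type and `register_custom_function` are assumptions
# about what your PyTorch build exposes; adjust to whatever hook point your
# version of _PipelineScheduleRuntime provides.
#
#   from torch.distributed.pipelining.schedules import _ComputationType
#
#   assert isinstance(schedule, _PipelineScheduleRuntime)
#   schedule.register_custom_function(
#       _ComputationType.OVERLAP_F_B, overlap_callback
#   )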