diff --git a/megatron/core/tensor_parallel/mappings.py b/megatron/core/tensor_parallel/mappings.py index 9ff69c9dc31..121f7c45a8e 100644 --- a/megatron/core/tensor_parallel/mappings.py +++ b/megatron/core/tensor_parallel/mappings.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import torch @@ -419,11 +419,12 @@ def backward(ctx, grad_output): class _AllToAll(torch.autograd.Function): @staticmethod - def forward(ctx, group, input, output_split_sizes, input_split_sizes): + def forward(ctx, group, input, output_split_sizes, input_split_sizes, use_nccl_stream=False): """Forward function.""" ctx.group = group ctx.output_split_sizes = output_split_sizes ctx.input_split_sizes = input_split_sizes + ctx.use_nccl_stream = use_nccl_stream world_size = group.size() # Bypass the function if we are using only 1 GPU. @@ -441,13 +442,24 @@ def forward(ctx, group, input, output_split_sizes, input_split_sizes): dtype=input.dtype, device=torch.cuda.current_device(), ) - torch.distributed.all_to_all_single( - output, - input, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - ) + if use_nccl_stream: + handle = torch.distributed.all_to_all_single( + output, + input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=True, + ) + handle.wait() + else: + torch.distributed.all_to_all_single( + output, + input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + ) return output @staticmethod @@ -455,7 +467,14 @@ def backward(ctx, *grad_output): """Backward function.""" return ( None, - _AllToAll.apply(ctx.group, *grad_output, ctx.input_split_sizes, ctx.output_split_sizes), + _AllToAll.apply( + ctx.group, + *grad_output, + ctx.input_split_sizes, + ctx.output_split_sizes, + ctx.use_nccl_stream, + ), + None, None, None, ) @@ -532,10 +551,12 @@ def reduce_scatter_last_dim_to_tensor_parallel_region(input_, group=None): return _ReduceScatterToTensorParallelRegion.apply(input_, group) -def all_to_all(group, input_, output_split_sizes_=None, input_split_sizes=None): +def all_to_all( + group, input_, output_split_sizes_=None, input_split_sizes=None, use_nccl_stream=False +): """Wrapper for autograd function""" assert group is not None, "group should not be None" - return _AllToAll.apply(group, input_, output_split_sizes_, input_split_sizes) + return _AllToAll.apply(group, input_, output_split_sizes_, input_split_sizes, use_nccl_stream) def all_to_all_sp2hp(input_, group=None): diff --git a/megatron/core/transformer/moe/shared_experts.py b/megatron/core/transformer/moe/shared_experts.py index 25d5db0b979..2443624919e 100644 --- a/megatron/core/transformer/moe/shared_experts.py +++ b/megatron/core/transformer/moe/shared_experts.py @@ -1,7 +1,9 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import warnings from copy import deepcopy +from enum import Enum +from functools import wraps from typing import Optional import torch @@ -27,6 +29,63 @@ ) +class SharedExpertState(Enum): + """State machine states for SharedExpertMLP overlapped forward pass.""" + + IDLE = 0 + PRE_FORWARD_COMM_DONE = 1 + FC1_FORWARD_DONE = 2 + FC2_FORWARD_DONE = 3 + POST_FORWARD_COMM_DONE = 4 + + +def overlap_state_check(required_state: "SharedExpertState", next_state: "SharedExpertState"): + """ + Decorator to validate overlap state and cached variables before method execution, + and update state after method execution. + + Args: + required_state: The expected SharedExpertState before this method runs. + next_state: The SharedExpertState to transition to after method execution. + """ + + def decorator(method): + @wraps(method) + def wrapper(self, *args, **kwargs): + # Check overlap is enabled + assert ( + self.config.moe_shared_expert_overlap + ), f"{method.__name__} requires --moe-shared-expert-overlap to be set" + # Check state machine + assert self._overlap_state == required_state, ( + f"{method.__name__} must be called from {required_state.name} state, " + f"but current state is {self._overlap_state.name}" + ) + # Execute method + result = method(self, *args, **kwargs) + # Update state after method execution + self._overlap_state = next_state + return result + + return wrapper + + return decorator + + +class _BackwardStreamWait(torch.autograd.Function): + @staticmethod + def forward(ctx, input, stream): + """forward""" + ctx.stream = stream + return input + + @staticmethod + def backward(ctx, grad_output): + """backward with stream wait""" + ctx.stream.wait_stream(torch.cuda.current_stream()) + return grad_output, None + + class SharedExpertMLP(MLP): """ MLP layer for Shared Experts. @@ -117,8 +176,11 @@ def __init__( self.cached_output = None self.gate_score = None - if self.stream is None: - self.stream = torch.cuda.Stream() + # State machine to ensure correct calling order of overlapped forward methods + self._overlap_state = SharedExpertState.IDLE + + if SharedExpertMLP.stream is None: + SharedExpertMLP.stream = torch.cuda.Stream() def forward(self, hidden_states): """Forward function""" @@ -145,15 +207,19 @@ def sharded_state_dict( sharded_state_dict.update(sub_sd) return sharded_state_dict - def pre_forward_comm(self, input): + def wait_current_stream(self): + """Wait for the current stream to complete.""" + self.stream.wait_stream(torch.cuda.current_stream()) + + @overlap_state_check(SharedExpertState.IDLE, SharedExpertState.PRE_FORWARD_COMM_DONE) + def pre_forward_comm(self, input, wait_current_stream=True): """ All Gather for SP before forward. This function is used to overlap shared experts with the dispatcher. It is only useful when --moe-shared-expert-overlap is set and may be changed. """ - assert self.config.moe_shared_expert_overlap - assert self.cached_output is None - self.stream.wait_stream(torch.cuda.current_stream()) + if wait_current_stream: + self.wait_current_stream() with torch.cuda.stream(self.stream): if self.use_shared_expert_gate: logits = torch.nn.functional.linear(input, self.gate_weight) @@ -166,16 +232,15 @@ def pre_forward_comm(self, input): self.cached_fc1_input = copy_to_tensor_model_parallel_region(input) set_tensor_grad_fn_sequence_sr(self.cached_fc1_input, torch.iinfo(torch.int).max) + @overlap_state_check( + SharedExpertState.PRE_FORWARD_COMM_DONE, SharedExpertState.FC1_FORWARD_DONE + ) def linear_fc1_forward_and_act(self, overlapped_comm_output=None): """ Do Linear FC1 and activation function forward. This function is used to overlap shared experts with the dispatcher. It is only useful when --moe-shared-expert-overlap is set and may be changed. """ - assert self.config.moe_shared_expert_overlap - assert self.cached_fc1_input is not None - if overlapped_comm_output is not None: - set_tensor_grad_fn_sequence_sr(overlapped_comm_output, torch.iinfo(torch.int).max) with torch.cuda.stream(self.stream): # [s, b, 4 * h/p] intermediate_parallel, bias_parallel = self.linear_fc1(self.cached_fc1_input) @@ -216,15 +281,22 @@ def glu(x): intermediate_parallel = self.activation_func(intermediate_parallel) self.cached_fc2_input = intermediate_parallel + # Tensor sequence number is used to control the backward order. + # Decrease the sequence number of the expert output to make the comm launched first + # in the backward order. + if overlapped_comm_output is not None and overlapped_comm_output.grad_fn is not None: + target_sequence_nr = overlapped_comm_output.grad_fn._sequence_nr() - 1 + set_tensor_grad_fn_sequence_sr(intermediate_parallel, target_sequence_nr) + # Make sure the shared expert fc1 backward is launched after the routed fc1 backward + self.cached_fc2_input = _BackwardStreamWait.apply(intermediate_parallel, self.stream) + @overlap_state_check(SharedExpertState.FC1_FORWARD_DONE, SharedExpertState.FC2_FORWARD_DONE) def linear_fc2_forward(self, overlapped_comm_output=None): """ Do Linear FC2 forward. This function is used to overlap shared experts with the dispatcher. It is only useful when --moe-shared-expert-overlap is set and may be changed. """ - assert self.config.moe_shared_expert_overlap - assert self.cached_fc2_input is not None if overlapped_comm_output is not None: set_tensor_grad_fn_sequence_sr(overlapped_comm_output, torch.iinfo(torch.int).max) with torch.cuda.stream(self.stream): @@ -232,14 +304,15 @@ def linear_fc2_forward(self, overlapped_comm_output=None): self.cached_fc2_output, _ = self.linear_fc2(self.cached_fc2_input) self.cached_fc2_input = None + @overlap_state_check( + SharedExpertState.FC2_FORWARD_DONE, SharedExpertState.POST_FORWARD_COMM_DONE + ) def post_forward_comm(self): """ Reduce scatter for SP after forward. This function is used to overlap shared experts with the dispatcher. It is only useful when --moe-shared-expert-overlap is set and may be changed. """ - assert self.config.moe_shared_expert_overlap - assert self.cached_fc2_output is not None with torch.cuda.stream(self.stream): if self.config.sequence_parallel: self.cached_output = reduce_scatter_to_sequence_parallel_region( @@ -252,14 +325,13 @@ def post_forward_comm(self): self.cached_fc2_output = None set_tensor_grad_fn_sequence_sr(self.cached_output, torch.iinfo(torch.int).max) + @overlap_state_check(SharedExpertState.POST_FORWARD_COMM_DONE, SharedExpertState.IDLE) def get_output(self): """ Gets the module forward output. This function is used to overlap shared experts with the dispatcher. It is only useful when --moe-shared-expert-overlap is set and may be changed. """ - assert self.config.moe_shared_expert_overlap - assert self.cached_output is not None with torch.cuda.stream(self.stream): if self.use_shared_expert_gate: assert self.gate_score is not None diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index c7c7ff147e5..a3d6999e9cb 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -66,6 +66,8 @@ def __init__( """ self.config = config self.shared_experts: Optional[SharedExpertMLP] = None + # Whether to use NCCL stream for A2A communication, otherwise default stream is used. + self.use_nccl_stream = False # Will be set to True when shared_experts is set. self.ep_group = pg_collection.ep # use pg_collection.expt_tp_group as tensor parallel group in this module. @@ -197,6 +199,7 @@ def set_shared_experts(self, shared_experts): """Set shared expert to the dispatcher.""" assert self.config.moe_shared_expert_overlap self.shared_experts = shared_experts + self.use_nccl_stream = True class MoEAllGatherTokenDispatcher(MoETokenDispatcher): @@ -624,15 +627,33 @@ def token_dispatch(self, permutated_local_input_tokens, permuted_probs): Returns: A tuple of tokens and probabilities after All-to-All. """ + # Make sure the shared experts fc1 is overlapped with dispatch A2A + # when CUDA_DEVICE_MAX_CONNECTIONS>1. + if self.shared_experts is not None: + self.shared_experts.wait_current_stream() # Perform expert parallel AlltoAll communication self.tokens_per_expert = self._maybe_dtoh_and_synchronize( "before_ep_alltoall", self.tokens_per_expert ) global_input_tokens = all_to_all( - self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits + self.ep_group, + permutated_local_input_tokens, + self.output_splits, + self.input_splits, + use_nccl_stream=self.use_nccl_stream, ) + # Move the shared experts fc1 right after the tokens A2A, to prevent the probs A2A + # block the launch of fc1 GEMM when CUDA_DEVICE_MAX_CONNECTIONS=1. + # Forward launch order: tokens A2A -> shared experts fc1 -> probs A2A + # Backward launch order: probs A2A -> tokens A2A -> shared experts fc1 + if self.shared_experts is not None: + self.shared_experts.linear_fc1_forward_and_act(global_input_tokens) global_probs = all_to_all( - self.ep_group, permuted_probs, self.output_splits, self.input_splits + self.ep_group, + permuted_probs, + self.output_splits, + self.input_splits, + use_nccl_stream=self.use_nccl_stream, ) return global_input_tokens, global_probs @@ -650,9 +671,6 @@ def dispatch_postprocess(self, global_input_tokens, global_probs): Returns: A tuple of processed tokens, token counts per expert, and processed probabilities. """ - if self.shared_experts is not None: - self.shared_experts.linear_fc1_forward_and_act(global_input_tokens) - if self.tp_size > 1: if self.output_splits_tp is None: output_split_sizes = None @@ -770,11 +788,22 @@ def token_combine( Returns: Tokens after the All-to-All communication for combining. """ + # Make sure the shared experts fc2 is not overlapped with routed experts fc1 + # when CUDA_DEVICE_MAX_CONNECTIONS>1. + if self.shared_experts is not None: + self.shared_experts.wait_current_stream() # Perform expert parallel AlltoAll communication # hidden_states: [SEQL, H] -> [SEQL, H/TP] permutated_local_input_tokens = all_to_all( - self.ep_group, hidden_states, self.input_splits, self.output_splits + self.ep_group, + hidden_states, + self.input_splits, + self.output_splits, + use_nccl_stream=self.use_nccl_stream, ) + if self.shared_experts is not None: + self.shared_experts.linear_fc2_forward(permutated_local_input_tokens) + self.shared_experts.post_forward_comm() return permutated_local_input_tokens def combine_postprocess(self, permutated_local_input_tokens): @@ -790,9 +819,6 @@ def combine_postprocess(self, permutated_local_input_tokens): Returns: The final MoE layer output reshaped to its original dimensions. """ - if self.shared_experts is not None: - self.shared_experts.linear_fc2_forward(permutated_local_input_tokens) - self.shared_experts.post_forward_comm() # Unpermutation 1: AlltoAll output to output output = unpermute( @@ -806,6 +832,9 @@ def combine_postprocess(self, permutated_local_input_tokens): # Reshape the output tensor output = output.view(self.hidden_shape) + # Manually release the metadata to avoid memory leak. + self.probs = None + self.routing_map = None # Add shared experts output if self.shared_experts is not None: @@ -1217,6 +1246,9 @@ def combine( ) # Release the handle after combine operation self.handle = None + # Manually release the metadata to avoid memory leak. + self.dispatched_indices = None + self.dispatched_probs = None return hidden_states def _pad_routing_map( @@ -1337,11 +1369,6 @@ def __init__( "--moe-flex-dispatcher-backend=hybridep" ) - def set_shared_experts(self, shared_experts): - raise NotImplementedError( - "Shared expert overlap is not supported in Flex Token Dispatcher." - ) - def _initialize_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor) -> torch.Tensor: """ Initialize the routing map and probs to a unified format covering the TPxEP group. @@ -1420,10 +1447,16 @@ def token_dispatch( Returns: A tuple of dispatched tokens and probabilities. """ - return ( - self._comm_manager.dispatch(hidden_states, async_finish, allocate_on_comm_stream), - self._comm_manager.dispatched_probs, + if self.shared_experts is not None: + self.shared_experts.wait_current_stream() + dispatched_hidden_states = self._comm_manager.dispatch( + hidden_states, async_finish, allocate_on_comm_stream ) + if self.shared_experts is not None: + self.shared_experts.pre_forward_comm(hidden_states, wait_current_stream=False) + self.shared_experts.linear_fc1_forward_and_act(dispatched_hidden_states) + + return dispatched_hidden_states, self._comm_manager.dispatched_probs def dispatch_postprocess(self, hidden_states: torch.Tensor, probs: torch.Tensor): """Converts dispatched tokens to a per-expert format for expert processing. @@ -1471,6 +1504,10 @@ def token_combine( Returns: Combined tokens after fused un-permutation and communication. """ + # Make sure the shared experts fc2 is not overlapped with routed experts GEMM + # when CUDA_DEVICE_MAX_CONNECTIONS>1. + if self.shared_experts is not None: + self.shared_experts.wait_current_stream() return self._comm_manager.combine(hidden_states, async_finish, allocate_on_comm_stream) def combine_postprocess(self, hidden_states: torch.Tensor): @@ -1486,4 +1523,8 @@ def combine_postprocess(self, hidden_states: torch.Tensor): Returns: The final MoE layer output reshaped to its original dimensions. """ + if self.shared_experts is not None: + self.shared_experts.linear_fc2_forward(hidden_states) + self.shared_experts.post_forward_comm() + hidden_states += self.shared_experts.get_output() return hidden_states.view(self.hidden_shape) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index cf36c0fc631..d0a0e844e13 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -888,10 +888,11 @@ def __post_init__(self): f"but got {self.moe_shared_expert_intermediate_size}" ) if self.moe_shared_expert_overlap and self.moe_token_dispatcher_type not in [ - "alltoall" + "alltoall", + "flex", ]: raise ValueError( - f"moe_shared_expert_overlap only works with alltoall token dispatcher." + f"moe_shared_expert_overlap only works with alltoall or flex token dispatcher." ) if isinstance(self.moe_router_load_balancing_type, list): diff --git a/tests/unit_tests/transformer/moe/test_shared_experts.py b/tests/unit_tests/transformer/moe/test_shared_experts.py index f721c482937..8cff089f74e 100644 --- a/tests/unit_tests/transformer/moe/test_shared_experts.py +++ b/tests/unit_tests/transformer/moe/test_shared_experts.py @@ -1,126 +1,132 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +import dataclasses import pytest import torch from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec +from megatron.core.parallel_state import get_tensor_model_parallel_world_size from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.transformer_config import TransformerConfig from tests.unit_tests.test_utilities import Utils -class TestSharedExperts: +def is_deep_ep_available(): + from megatron.core.transformer.moe.fused_a2a import HAVE_DEEP_EP - def setup_method(self, method): - pass + return HAVE_DEEP_EP - def teardown_method(self, method): - Utils.destroy_model_parallel() - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.internal - def test_gpu_forward(self): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - print("done intializing") - num_moe_experts = 2 - transformer_config = TransformerConfig( +if is_deep_ep_available(): + TOKEN_DISPATCHER_TYPES = ["alltoall", "flex"] +else: + TOKEN_DISPATCHER_TYPES = ["alltoall"] + + +class TestSharedExperts: + def setup_method(self, method): + self.config = TransformerConfig( num_layers=1, - hidden_size=12, + hidden_size=32, num_attention_heads=4, - num_moe_experts=num_moe_experts, + num_moe_experts=16, moe_shared_expert_intermediate_size=32, + moe_shared_expert_overlap=False, + moe_token_dispatcher_type="alltoall", use_cpu_initialization=True, activation_func=torch.nn.functional.silu, gated_linear_unit=True, bias_activation_fusion=True, - moe_router_load_balancing_type="sinkhorn", - moe_router_topk=1, + moe_router_load_balancing_type="aux_loss", + moe_router_topk=4, add_bias_linear=False, ) + + def get_moe_layer(self, **kargs) -> MoELayer: transformer_layer_spec = get_gpt_layer_local_spec( - num_experts=num_moe_experts, moe_grouped_gemm=False + num_experts=self.config.num_moe_experts, moe_grouped_gemm=False ) - self.moe_layer = MoELayer( - transformer_config, transformer_layer_spec.submodules.mlp.submodules - ) - - assert isinstance(self.moe_layer, MoELayer) - - num_weights = sum([p.numel() for p in self.moe_layer.parameters()]) - assert num_weights == 3480 + 1152 - assert self.moe_layer.shared_experts is not None - assert self.moe_layer.shared_experts.stream is None - assert self.moe_layer.token_dispatcher.shared_experts is None - - moe_layer = self.moe_layer + new_config = dataclasses.replace(self.config, **kargs) + if get_tensor_model_parallel_world_size() > 1: + new_config.sequence_parallel = True + if new_config.moe_token_dispatcher_type == "flex": + new_config.moe_enable_deepep = True + moe_layer = MoELayer(new_config, transformer_layer_spec.submodules.mlp.submodules) moe_layer.cuda() - # [sequence length, batch size, hidden size] - hidden_states = torch.ones((32, 2, moe_layer.config.hidden_size)) - hidden_states = hidden_states.cuda() - output, _ = moe_layer(hidden_states) - assert output.shape[0] == 32 - assert output.shape[1] == 2 - assert output.shape[2] == moe_layer.config.hidden_size - assert output.dtype == torch.float32 - assert output.device.type == 'cuda' - - -class TestSharedExpertsOverlap: - - def setup_method(self, method): - pass + return moe_layer def teardown_method(self, method): Utils.destroy_model_parallel() - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.internal - def test_gpu_forward(self): - Utils.initialize_model_parallel(1, 1) - model_parallel_cuda_manual_seed(123) - print("done intializing") - num_moe_experts = 2 - transformer_config = TransformerConfig( - num_layers=1, - hidden_size=12, - num_attention_heads=4, - num_moe_experts=num_moe_experts, - moe_shared_expert_intermediate_size=32, - moe_shared_expert_overlap=True, - moe_token_dispatcher_type="alltoall", - use_cpu_initialization=True, - activation_func=torch.nn.functional.silu, - gated_linear_unit=True, - bias_activation_fusion=True, - moe_router_load_balancing_type="sinkhorn", - moe_router_topk=1, - add_bias_linear=False, - ) - transformer_layer_spec = get_gpt_layer_local_spec( - num_experts=num_moe_experts, moe_grouped_gemm=False - ) - self.moe_layer = MoELayer( - transformer_config, transformer_layer_spec.submodules.mlp.submodules + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.parametrize("dispatcher_type", TOKEN_DISPATCHER_TYPES) + @pytest.mark.parametrize("tp_size, ep_size", [[1, 1], [4, 1], [1, 4], [2, 4]]) + def test_shared_expert_forward_backward(self, dispatcher_type: str, tp_size, ep_size): + """ + Tests that the MoELayer with and without shared expert overlap produce + identical outputs and gradients. + """ + if tp_size == 1 and ep_size == 1 and dispatcher_type == "flex": + pytest.skip("Flex dispatcher is not supported for tp=1, ep=1") + Utils.initialize_model_parallel( + tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size ) + # Create MoE layer with shared expert overlap enabled. + model_parallel_cuda_manual_seed(123) + moe_layer_overlap = self.get_moe_layer( + moe_shared_expert_overlap=True, moe_token_dispatcher_type=dispatcher_type + ).to(dtype=torch.bfloat16) - assert isinstance(self.moe_layer, MoELayer) - - num_weights = sum([p.numel() for p in self.moe_layer.parameters()]) - assert num_weights == 3480 + 1152 - assert self.moe_layer.shared_experts is not None - assert self.moe_layer.shared_experts.stream is not None - assert self.moe_layer.token_dispatcher.shared_experts is not None - - moe_layer = self.moe_layer - moe_layer.cuda() - # [sequence length, batch size, hidden size] - hidden_states = torch.ones((32, 2, moe_layer.config.hidden_size)) - hidden_states = hidden_states.cuda() - output, _ = moe_layer(hidden_states) - assert output.shape[0] == 32 - assert output.shape[1] == 2 - assert output.shape[2] == moe_layer.config.hidden_size - assert output.dtype == torch.float32 - assert output.device.type == 'cuda' + # Create MoE layer with shared expert overlap disabled. + model_parallel_cuda_manual_seed(123) + moe_layer_no_overlap = self.get_moe_layer( + moe_shared_expert_overlap=False, moe_token_dispatcher_type=dispatcher_type + ).to(dtype=torch.bfloat16) + moe_layer_no_overlap.load_state_dict(moe_layer_overlap.state_dict()) + + # Sanity check that the weights are identical. + for p_overlap, p_no_overlap in zip( + moe_layer_overlap.parameters(), moe_layer_no_overlap.parameters() + ): + assert torch.equal(p_overlap, p_no_overlap) + + # Verify attributes of the MoE layers. + num_weights_overlap = sum([p.numel() for p in moe_layer_overlap.parameters()]) + num_weights_no_overlap = sum([p.numel() for p in moe_layer_no_overlap.parameters()]) + assert num_weights_overlap == num_weights_no_overlap + + assert moe_layer_overlap.shared_experts is not None + assert moe_layer_overlap.shared_experts.stream is not None + assert moe_layer_overlap.token_dispatcher.shared_experts is not None + + assert moe_layer_no_overlap.shared_experts is not None + assert moe_layer_no_overlap.token_dispatcher.shared_experts is None + + # Create a dummy input tensor. + hidden_states = torch.randn( + (32, 2, self.config.hidden_size), + requires_grad=True, + device="cuda", + dtype=torch.bfloat16, + ) + hidden_states_no_overlap = hidden_states.clone().detach().requires_grad_(True) + + # Forward pass. + output_overlap, _ = moe_layer_overlap(hidden_states) + output_no_overlap, _ = moe_layer_no_overlap(hidden_states_no_overlap) + torch.testing.assert_close(output_overlap, output_no_overlap) + + # Backward pass. + output_overlap.mean().backward() + output_no_overlap.mean().backward() + + # Check gradients. + for p_overlap, p_no_overlap in zip( + moe_layer_overlap.parameters(), moe_layer_no_overlap.parameters() + ): + assert torch.allclose( + p_overlap.grad, p_no_overlap.grad + ), f"max diff: {torch.max(torch.abs(p_overlap.grad - p_no_overlap.grad))}"