diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 9238d3aa0ec..ea09f3de6df 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -1729,6 +1729,13 @@ def __init__( self.explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel) + # Save original parallel_mode before clearing it for explicit_expert_comm. + # When explicit_expert_comm is True, Megatron handles TP communication externally + # and passes parallel_mode=None to TE. This causes TE to set partition_dim=0 on + # all weights (its default for non-parallel mode). We need to fix this after init + # so that refit/resharding can correctly identify which dimension is TP-partitioned. + original_parallel_mode = parallel_mode + if self.explicit_expert_comm: if parallel_mode == "column": output_size = divide(output_size, tp_size) @@ -1759,6 +1766,21 @@ def __init__( for param in self.parameters(): setattr(param, "allreduce", not (is_expert and self.expert_parallel)) + # Explicitly stamp partition_dim and partition_stride on expert weight + # tensors when explicit_expert_comm cleared parallel_mode. TE ≤2.12 + # set these internally; TE ≥2.13 no longer does (parallel_mode=None + # is passed due to explicit_expert_comm). The resharding/refit planner + # relies on partition_dim to correctly plan TP gather/scatter operations. + # NOTE: we intentionally do NOT stamp tensor_model_parallel here — + # doing so would change num-zeros gradient counting. + if self.explicit_expert_comm and original_parallel_mode in ("column", "row"): + part_dim = 0 if original_parallel_mode == "column" else 1 + for i in range(num_gemms): + weight = getattr(self, f"weight{i}", None) + if weight is not None: + setattr(weight, "partition_dim", part_dim) + setattr(weight, "partition_stride", 1) + def merge_extra_states( self, state_dict, diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 10371726f1c..d63b2c1ddfa 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -217,7 +217,10 @@ def __init__( if self.mtp_process: self.mtp = MultiTokenPredictionBlock( - config=self.config, spec=self.mtp_block_spec, vp_stage=vp_stage + config=self.config, + spec=self.mtp_block_spec, + vp_stage=vp_stage, + pg_collection=self.pg_collection, ) # Output diff --git a/megatron/core/resharding/planner.py b/megatron/core/resharding/planner.py index 444ea5a673a..1921a0290ec 100644 --- a/megatron/core/resharding/planner.py +++ b/megatron/core/resharding/planner.py @@ -152,6 +152,126 @@ def dst_key(op): return ops +def _plan_block_interleaved( + param_name: str, + src_metadata: ParameterMetadata, + dst_metadata: ParameterMetadata, + descriptors: list[ShardingDescriptor], + my_global_rank: int, +) -> list[tuple[int, tuple[slice, ...], tuple[slice, ...]]]: + """ + Block-interleaved TP planner for parameters with ``partition_sizes``. + + When a parameter packs multiple independently-sharded components of + *different* sizes (e.g. Mamba in_proj packs z, x, B, C, dt), a simple + contiguous concat produces the wrong layout. This function treats each + block independently: it gathers (or scatters) each block across TP ranks + before moving to the next block. + + ``partition_sizes`` lists the per-TP-rank block sizes along the partition + dim. Block *i* occupies ``[sum(sizes[:i]), sum(sizes[:i+1]))`` in the + local tensor on every TP rank. In the *full* (TP-gathered) tensor, block + *i* occupies ``[sum(full_sizes[:i]), sum(full_sizes[:i+1]))`` where + ``full_sizes[i] = sizes[i] * src_tp_world``. + """ + if not descriptors or descriptors[0].name != "tp": + return [] + d = descriptors[0] + if my_global_rank not in d.dst_dim_ranks: + return [] + + dim = d.dim + src_shape = tuple(src_metadata.shape) + dst_shape = tuple(dst_metadata.shape) + src_world = len(d.src_dim_ranks) + dst_world = len(d.dst_dim_ranks) + dst_local_rank = _get_rank_in_group(my_global_rank, d.dst_dim_ranks) + + # Use partition_sizes from whichever side has it (prefer src) + src_sizes = src_metadata.partition_sizes + dst_sizes = dst_metadata.partition_sizes + + if src_sizes is None and dst_sizes is None: + raise RuntimeError(f"{param_name}: _plan_block_interleaved called without partition_sizes") + + # Derive the full (un-sharded) block sizes + if src_sizes is not None: + num_blocks = len(src_sizes) + full_sizes = [s * src_world for s in src_sizes] + else: + num_blocks = len(dst_sizes) + full_sizes = [s * dst_world for s in dst_sizes] + + # Compute per-rank block sizes for both sides + if src_sizes is None: + src_sizes = [f // src_world for f in full_sizes] + if dst_sizes is None: + dst_sizes = [f // dst_world for f in full_sizes] + + # Validate conservation + for i in range(num_blocks): + if src_sizes[i] * src_world != dst_sizes[i] * dst_world: + raise RuntimeError( + f"{param_name}: block {i} size mismatch: " + f"src_sizes[{i}]={src_sizes[i]}*{src_world} != " + f"dst_sizes[{i}]={dst_sizes[i]}*{dst_world}" + ) + + ops: list[tuple[int, tuple[slice, ...], tuple[slice, ...]]] = [] + + # For each block, compute the transfer ops independently + src_block_offset = 0 # cumulative offset in source local tensor + dst_block_offset = 0 # cumulative offset in destination local tensor + + for blk in range(num_blocks): + src_blk_sz = src_sizes[blk] # per-src-rank size of this block + dst_blk_sz = dst_sizes[blk] # per-dst-rank size of this block + full_blk_sz = full_sizes[blk] + + # Within this block, use simple LCM tiling (stride=1) + Ns = src_world + Nd = dst_world + g = math.gcd(Ns, Nd) + L = (Ns // g) * Nd + if full_blk_sz % L != 0: + raise RuntimeError( + f"{param_name}: block {blk} full_size {full_blk_sz} not divisible by LCM {L}" + ) + unit = full_blk_sz // L + cps = L // Ns + cpd = L // Nd + + # This dst rank's segment within the block + g_dst_seg = dst_local_rank + for off in range(cpd): + g_micro = g_dst_seg * cpd + off + s_idx = g_micro // cps + in_seg = g_micro % cps + src_owner_in_dim = s_idx % src_world + src_global_rank = d.src_dim_ranks[src_owner_in_dim] + src_local_seg_idx = s_idx // src_world + src_start = src_block_offset + src_local_seg_idx * (cps * unit) + in_seg * unit + dst_start = dst_block_offset + off * unit + + src_slice = [slice(None)] * len(src_shape) + dst_slice = [slice(None)] * len(dst_shape) + src_slice[dim] = slice(src_start, src_start + unit) + dst_slice[dim] = slice(dst_start, dst_start + unit) + ops.append((src_global_rank, tuple(src_slice), tuple(dst_slice))) + + src_block_offset += src_blk_sz + dst_block_offset += dst_blk_sz + + # Stable sort by destination offset + def dst_key(op): + _, _, dsl = op + s = dsl[dim] + return s.start if isinstance(s, slice) else 0 + + ops.sort(key=dst_key) + return ops + + def _finalize_dp_transfers( param_name: str, src_metadata: ParameterMetadata, @@ -210,6 +330,16 @@ def _determine_source_ranks_for_dst_param( # Regular TP/DP planning with EP-resolved metadata descriptors = _build_descriptors_for_param(src_metadata=src_metadata, dst_metadata=dst_metadata) if descriptors: + # Use block-interleaved planner when partition_sizes is present + # (e.g. Mamba in_proj packs components of different sizes) + if src_metadata.partition_sizes is not None or dst_metadata.partition_sizes is not None: + return _plan_block_interleaved( + param_name=param_name, + src_metadata=src_metadata, + dst_metadata=dst_metadata, + descriptors=descriptors, + my_global_rank=my_global_rank, + ) return _plan_multi_dim_lcm( param_name=param_name, src_metadata=src_metadata, diff --git a/megatron/core/resharding/utils.py b/megatron/core/resharding/utils.py index 9edcad9d51b..5836ca608f7 100644 --- a/megatron/core/resharding/utils.py +++ b/megatron/core/resharding/utils.py @@ -44,6 +44,11 @@ class ParameterMetadata: is_tp: bool = False partition_dim: int = 0 partition_stride: int = 1 + # For parameters that pack multiple independently-sharded components of + # different sizes (e.g. Mamba in_proj packs z, x, B, C, dt). When present, + # lists the per-TP-rank block sizes along partition_dim. The refit planner + # interleaves these blocks rather than doing a simple contiguous concat. + partition_sizes: list[int] | None = None # EP sharding info (fused/grouped MoE) is_ep: bool = False @@ -258,6 +263,9 @@ def extract_param_metadata( is_tp = bool(getattr(param, 'tensor_model_parallel', False)) partition_dim = int(getattr(param, 'partition_dim', 0)) partition_stride = int(getattr(param, 'partition_stride', 1)) + partition_sizes = getattr(param, 'partition_sizes', None) + if partition_sizes is not None: + partition_sizes = list(partition_sizes) # SwiGLU/GLU compatibility: For gated linear units, fc1 stores interleaved [gate, up] portions # and requires partition_stride=2 for correct resharding. New models set this at construction @@ -271,6 +279,15 @@ def extract_param_metadata( # EP detection: Megatron convention - expert params are not allreduced is_ep = not bool(getattr(param, 'allreduce', True)) + # Expert-param detection for TP inference. When explicit_expert_comm is + # active (is_expert and (tp_size>1 or ep)), TE clears parallel_mode so + # tensor_model_parallel is never stamped — yet the weight IS TP-sharded + # when tp_size > 1. We detect expert params via num_experts + the + # per-expert naming convention (weightK / biasK in TEGroupedLinear). + is_expert_param = ( + num_experts is not None and _detect_expert_index_from_param_name(param_name) is not None + ) + tensor_parallel_group_ranks: list[int] | None = None expert_parallel_group_ranks: list[int] | None = None data_parallel_group_ranks: list[int] | None = None @@ -279,17 +296,22 @@ def extract_param_metadata( def _offset_ranks(ranks: list[int]) -> list[int]: return [r + rank_offset for r in ranks] if rank_offset else ranks - if is_ep: - expert_parallel_group_ranks = _offset_ranks(dist.get_process_group_ranks(pg_collection.ep)) - # For MoE params, prefer expert TP group when available, else regular TP - if is_tp and hasattr(pg_collection, 'expt_tp') and pg_collection.expt_tp is not None: - tensor_parallel_group_ranks = _offset_ranks( - dist.get_process_group_ranks(pg_collection.expt_tp) - ) - elif is_tp and hasattr(pg_collection, 'tp') and pg_collection.tp is not None: - tensor_parallel_group_ranks = _offset_ranks( - dist.get_process_group_ranks(pg_collection.tp) + if is_ep or is_expert_param: + if is_ep: + expert_parallel_group_ranks = _offset_ranks( + dist.get_process_group_ranks(pg_collection.ep) ) + # For expert params, always provide TP group ranks so the planner can + # handle TP size transitions (e.g., TP2→TP1). When explicit_expert_comm + # clears TE's parallel_mode, tensor_model_parallel may not be set even + # though the weight IS TP-sharded. Detect TP via group size instead. + expt_tp = getattr(pg_collection, 'expt_tp', None) + tp_grp = expt_tp if expt_tp is not None else getattr(pg_collection, 'tp', None) + if tp_grp is not None: + tp_ranks = _offset_ranks(dist.get_process_group_ranks(tp_grp)) + tensor_parallel_group_ranks = tp_ranks + if not is_tp and len(tp_ranks) > 1: + is_tp = True data_parallel_group_ranks = _offset_ranks(dist.get_process_group_ranks(pg_collection.dp)) elif is_tp: # Non-EP: use regular TP group @@ -301,6 +323,17 @@ def _offset_ranks(ranks: list[int]) -> list[int]: else: data_parallel_group_ranks = _offset_ranks(dist.get_process_group_ranks(pg_collection.dp)) + # Always provide TP group ranks so the planner can handle TP size transitions + # (e.g., TP2→TP1). When is_tp=False the param is replicated across the TP group, + # but the planner still needs to know the TP topology to plan gather/scatter ops + # when the *other* side of the reshard IS TP-sharded. + if ( + tensor_parallel_group_ranks is None + and hasattr(pg_collection, 'tp') + and pg_collection.tp is not None + ): + tensor_parallel_group_ranks = _offset_ranks(dist.get_process_group_ranks(pg_collection.tp)) + if hasattr(pg_collection, 'pp') and pg_collection.pp is not None: pipeline_parallel_group_ranks = _offset_ranks( dist.get_process_group_ranks(pg_collection.pp) @@ -318,6 +351,7 @@ def _offset_ranks(ranks: list[int]) -> list[int]: is_tp=is_tp, partition_dim=partition_dim, partition_stride=partition_stride, + partition_sizes=partition_sizes, is_ep=is_ep, num_experts=num_experts, owner_rank=owner_rank, diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 04805a0087e..7ea61b38b1f 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -257,6 +257,19 @@ def __init__( tp_comm_buffer_name="fc1", tp_group=self.pg_collection.tp, ) + # in_proj packs [z, x, B, C, dt] into one ColumnParallelLinear. Each + # component is independently TP-sharded but with different sizes. When + # resharding across different TP sizes the planner must interleave + # per-component blocks rather than doing a contiguous concat. + # partition_sizes lists the per-TP-rank block sizes along partition_dim. + in_proj_partition_sizes = [ + self.d_inner_local_tp, # z + self.d_inner_local_tp, # x + self.ngroups_local_tp * self.d_state, # B + self.ngroups_local_tp * self.d_state, # C + self.nheads_local_tp, # dt + ] + setattr(self.in_proj.weight, "partition_sizes", in_proj_partition_sizes) if not self.use_mem_eff_path: log_single_rank( @@ -286,6 +299,17 @@ def __init__( setattr(self.conv1d.weight, "partition_dim", 0) setattr(self.conv1d.bias, "tensor_model_parallel", True) setattr(self.conv1d.bias, "partition_dim", 0) + # partition_sizes describes the per-TP-rank block sizes along the + # partition dim. conv1d packs [x, B, C] whose local sizes differ, + # so a plain contiguous concat would produce the wrong layout when + # resharding across different TP sizes. + conv_partition_sizes = [ + self.d_inner_local_tp, + self.ngroups_local_tp * self.d_state, + self.ngroups_local_tp * self.d_state, + ] + setattr(self.conv1d.weight, "partition_sizes", conv_partition_sizes) + setattr(self.conv1d.bias, "partition_sizes", conv_partition_sizes) if self.config.perform_initialization: if self.conv_init is not None: nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init) @@ -323,7 +347,6 @@ def __init__( self.A_log = nn.Parameter(A_log) setattr(self.A_log, "tensor_model_parallel", True) setattr(self.A_log, "partition_dim", 0) - # D "skip" parameter self.D = nn.Parameter( torch.ones( @@ -333,7 +356,6 @@ def __init__( ) # Keep in fp32 setattr(self.D, "tensor_model_parallel", True) setattr(self.D, "partition_dim", 0) - if self.rmsnorm: assert RMSNormGated is not None self.norm = ExtendedRMSNorm( @@ -346,7 +368,6 @@ def __init__( ) setattr(self.norm.weight, "tensor_model_parallel", True) setattr(self.norm.weight, "partition_dim", 0) - # Assume sequence parallelism: input is partitioned along d_inner and # output is partitioned along the sequence dimension self.out_proj = build_module( diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 6d409632c98..24fdc219772 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -1420,8 +1420,10 @@ def get_query_key_value_tensors( # 4. Further index into query to get only the q_heads that this rank is # responsible for (e.g., q1). # The block of code below performs steps 1 and 2. - mixed_qkv = all_gather_last_dim_from_tensor_parallel_region(mixed_qkv) - idx = get_tensor_model_parallel_rank() // ( + mixed_qkv = all_gather_last_dim_from_tensor_parallel_region( + mixed_qkv, group=self.pg_collection.tp + ) + idx = get_pg_rank(self.pg_collection.tp) // ( self.world_size // self.config.num_query_groups ) size = mixed_qkv.size()[-1] // self.config.num_query_groups @@ -1478,7 +1480,7 @@ def get_query_key_value_tensors( # query above corresponds to (num_q_heads / num_kv_heads) q_heads. # Index appropriately into query to get (num_q_heads / tp_size) q_heads. # This is step 4 in the list of steps above. - idx = get_tensor_model_parallel_rank() % ( + idx = get_pg_rank(self.pg_collection.tp) % ( self.world_size // self.config.num_query_groups ) size = self.num_attention_heads_per_partition // ( diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index c82c161a746..3648b338b3f 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -749,6 +749,7 @@ def __init__( self.layer_number = layer_number + get_mtp_layer_offset(self.config, vp_stage) self.vp_stage = vp_stage self.cp_group = pg_collection.cp + self.tp_group = pg_collection.tp if pg_collection is not None else None self.mtp_layer_pattern = mtp_layer_pattern # Validate attention mask type if using transformer-based inner layers @@ -807,6 +808,7 @@ def __init__( skip_bias_add=False, is_expert=False, tp_comm_buffer_name="mtp_eh_proj", + tp_group=pg_collection.tp if pg_collection is not None else None, ) # Build inner layers: two possible paths @@ -838,6 +840,7 @@ def __init__( vp_stage=self.vp_stage, layer_number=self.layer_number, is_mtp_layer=True, + pg_collection=pg_collection, ) self.final_layernorm = self.submodules.layer_norm( @@ -909,10 +912,10 @@ def _concat_embeddings(self, hidden_states: torch.Tensor, decoder_input: torch.T # `all_gather_last_dim_from_tensor_parallel_region`, but that utility reduces # the gradient in backward pass and was therefore incorrect in this context. # It has been replaced with the correct `gather_from_tensor_model_parallel_region`. - hidden_states = gather_from_tensor_model_parallel_region(hidden_states) + hidden_states = gather_from_tensor_model_parallel_region(hidden_states, group=self.tp_group) # For sequence parallel, scatter after linear_fc and before transformer layer. if self.sequence_parallel: - hidden_states = scatter_to_sequence_parallel_region(hidden_states) + hidden_states = scatter_to_sequence_parallel_region(hidden_states, group=self.tp_group) return hidden_states def _proj_and_transformer_layer( @@ -1295,7 +1298,7 @@ def __init__( # to the roll_tensor function for proper boundary communication if pg_collection is None: # Use default MPU process groups if not provided - pg_collection = ProcessGroupCollection.use_mpu_process_groups(required_pgs=['cp']) + pg_collection = ProcessGroupCollection.use_mpu_process_groups(required_pgs=['cp', 'tp']) else: # Ensure the provided process groups include CP assert hasattr( diff --git a/tests/unit_tests/resharding/test_model_swap.py b/tests/unit_tests/resharding/test_model_swap.py index 19cb2306bf7..70d81d97829 100644 --- a/tests/unit_tests/resharding/test_model_swap.py +++ b/tests/unit_tests/resharding/test_model_swap.py @@ -1,7 +1,9 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. import copy +import gc import os import types +from dataclasses import fields from typing import List, Optional, Tuple import pytest @@ -14,6 +16,7 @@ from megatron.core.models.gpt.gpt_layer_specs import ( get_gpt_layer_local_spec, get_gpt_layer_with_transformer_engine_spec, + get_gpt_mtp_block_spec, ) from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.process_groups_config import ProcessGroupCollection @@ -31,6 +34,16 @@ except Exception: has_nvshmem = False +try: + import mamba_ssm # noqa: F401 + + from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec + from megatron.core.models.mamba.mamba_model import MambaModel + + has_mamba_deps = True +except Exception: + has_mamba_deps = False + def _build_pg_collection( tp_size: int, pp_size: int = None, ep_size: int = 1 @@ -82,6 +95,70 @@ def _build_pg_collection( ) +def _destroy_pg_collection(pgc: ProcessGroupCollection): + """Destroy all process groups in a ProcessGroupCollection to free NCCL communicator memory.""" + destroyed = set() + for f in fields(pgc): + pg = getattr(pgc, f.name, None) + if pg is not None and id(pg) not in destroyed: + destroyed.add(id(pg)) + dist.destroy_process_group(pg) + + +def _pp_flags(pg_collection) -> Tuple[bool, bool]: + """Return (pre_process, post_process) based on pipeline-parallel rank.""" + pp_group = pg_collection.pp + pp_rank = dist.get_rank(pp_group) + pp_size = dist.get_world_size(pp_group) + return pp_rank == 0, pp_rank == pp_size - 1 + + +def _run_forward(model, tokens, position_ids, attention_mask, pg_collection): + """Run a forward pass using Megatron's pipeline schedule. + + For PP=1 this is a simple forward call. For PP>1 this delegates to the + Megatron pipeline schedule which handles P2P communication between stages. + + Returns logits on the last PP stage, None on other stages. + """ + from megatron.core.pipeline_parallel import get_forward_backward_func + from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator + + pp_group = pg_collection.pp + pp_size = dist.get_world_size(pp_group) + batch, seq_len = tokens.shape + + def forward_step_func(data_iterator, model): + output = model(tokens, position_ids, attention_mask) + + def loss_func(output_tensor, non_loss_data=False): + if non_loss_data: + return output_tensor + return output_tensor.sum(), {"logits": output_tensor} + + return output, loss_func + + forward_backward_func = get_forward_backward_func(pp_size=pp_size, vp_size=None) + kwargs = dict( + forward_step_func=forward_step_func, + data_iterator=iter([None]), + model=[model], + num_microbatches=1, + seq_length=seq_len, + micro_batch_size=batch, + forward_only=True, + collect_non_loss_data=True, + pg_collection=pg_collection, + ) + if pp_size > 1: + kwargs["p2p_communicator"] = P2PCommunicator(pp_group, model.config) + result = forward_backward_func(**kwargs) + # result is a list of per-microbatch outputs; only populated on last PP stage + if result and result[0] is not None: + return result[0] + return None + + def _build_gpt( config: TransformerConfig, vocab_size: int, @@ -90,25 +167,67 @@ def _build_gpt( parallel_output: bool = True, num_moe_experts: Optional[int] = None, ) -> GPTModel: + layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=num_moe_experts, moe_grouped_gemm=(num_moe_experts is not None) + ) + mtp_block_spec = None + if config.mtp_num_layers: + mtp_block_spec = get_gpt_mtp_block_spec( + config=config, spec=layer_spec, use_transformer_engine=True + ) + pre_process, post_process = _pp_flags(pg_collection) model = GPTModel( config=config, - transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec( - num_experts=num_moe_experts, moe_grouped_gemm=(num_moe_experts is not None) - ), + transformer_layer_spec=layer_spec, vocab_size=vocab_size, max_sequence_length=seq_len, - pre_process=True, - post_process=True, + pre_process=pre_process, + post_process=post_process, fp16_lm_cross_entropy=False, parallel_output=parallel_output, - share_embeddings_and_output_weights=True, + share_embeddings_and_output_weights=False, position_embedding_type="rope", rotary_percent=1.0, pg_collection=pg_collection, + mtp_block_spec=mtp_block_spec, ) return model +def _build_mamba( + config: TransformerConfig, + vocab_size: int, + seq_len: int, + pg_collection, + hybrid_layer_pattern: str, + parallel_output: bool = True, +): + pre_process, post_process = _pp_flags(pg_collection) + model = MambaModel( + config=config, + mamba_stack_spec=mamba_stack_spec, + vocab_size=vocab_size, + max_sequence_length=seq_len, + hybrid_layer_pattern=hybrid_layer_pattern, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=False, + parallel_output=parallel_output, + share_embeddings_and_output_weights=False, + pg_collection=pg_collection, + ) + return model + + +def _mamba_layer_pattern(base: str, num_layers: int, pp_size: int) -> str: + """Build hybrid_layer_pattern with '|' pipeline stage boundaries.""" + layers_per_stage = num_layers // pp_size + unit_len = len(base) + repeats_per_stage = layers_per_stage // unit_len + stage = base * repeats_per_stage + return "|".join([stage] * pp_size) + + def _mp_config() -> ModelParallelConfig: return ModelParallelConfig( params_dtype=torch.float32, @@ -138,25 +257,41 @@ def _set_pg_collection(module, tp_group, dp_group): ], ) @pytest.mark.parametrize( - "src_tp,src_pp,src_ep,dst_tp,dst_pp,dst_ep,num_experts", + "src_tp,src_pp,src_ep,dst_tp,dst_pp,dst_ep,num_experts,moe_mode", [ - # TP only changes - (2, 1, 1, 1, 1, 1, None), # TP2 -> TP1 - (1, 1, 1, 2, 1, 1, None), # TP1 -> TP2 - (2, 1, 1, 4, 1, 1, None), # TP2 -> TP4 - # # PP only changes - (1, 2, 1, 1, 1, 1, None), # PP2 -> PP1 - (1, 1, 1, 1, 2, 1, None), # PP1 -> PP2 - # # Both TP and PP change - (2, 2, 1, 1, 1, 1, None), # TP2,PP2 -> TP1,PP1 - (1, 1, 1, 2, 2, 1, None), # TP1,PP1 -> TP2,PP2 - (2, 1, 1, 1, 2, 1, None), # TP2,PP1 -> TP1,PP2 - (1, 2, 1, 2, 1, 1, None), # TP1,PP2 -> TP2,PP1 - (1, 2, 1, 2, 4, 1, None), # TP1,PP2 -> TP2,PP4 - (1, 1, 2, 1, 1, 4, 4), # EP2 -> EP4 - (1, 1, 2, 1, 1, 1, 4), # EP2 -> EP1 - (1, 1, 1, 1, 1, 2, 4), - (1, 1, 2, 1, 2, 2, 4), + # ---- Non-MoE: TP only changes ---- + (2, 1, 1, 1, 1, 1, None, None), # TP2 -> TP1 + (1, 1, 1, 2, 1, 1, None, None), # TP1 -> TP2 + (2, 1, 1, 4, 1, 1, None, None), # TP2 -> TP4 + # ---- Non-MoE: PP only changes ---- + (1, 2, 1, 1, 1, 1, None, None), # PP2 -> PP1 + (1, 1, 1, 1, 2, 1, None, None), # PP1 -> PP2 + # ---- Non-MoE: Both TP and PP change ---- + (2, 2, 1, 1, 1, 1, None, None), # TP2,PP2 -> TP1,PP1 + (1, 1, 1, 2, 2, 1, None, None), # TP1,PP1 -> TP2,PP2 + (2, 1, 1, 1, 2, 1, None, None), # TP2,PP1 -> TP1,PP2 + (1, 2, 1, 2, 1, 1, None, None), # TP1,PP2 -> TP2,PP1 + (1, 2, 1, 2, 4, 1, None, None), # TP1,PP2 -> TP2,PP4 + # ---- MoE: EP changes (standard) ---- + (1, 1, 2, 1, 1, 4, 4, None), # EP2 -> EP4 + (1, 1, 2, 1, 1, 1, 4, None), # EP2 -> EP1 + (1, 1, 1, 1, 1, 2, 4, None), # EP1 -> EP2 + (1, 1, 2, 1, 2, 2, 4, None), # EP2 -> PP2,EP2 + # ---- MoE: mixed TP + EP (standard) ---- + (2, 1, 2, 1, 1, 1, 4, None), # TP2,EP2 -> TP1,EP1 + (1, 1, 1, 2, 1, 2, 4, None), # TP1,EP1 -> TP2,EP2 + (4, 1, 1, 2, 1, 2, 4, None), # TP4,EP1 -> TP2,EP2 + (2, 1, 2, 4, 1, 1, 4, None), # TP2,EP2 -> TP4,EP1 + (4, 1, 1, 1, 1, 4, 4, None), # TP4,EP1 -> TP1,EP4 + (1, 1, 4, 4, 1, 1, 4, None), # EP4 -> TP4,EP1 + # ---- MoE latent: representative configs ---- + (1, 1, 2, 1, 1, 1, 4, "latent"), # EP2 -> EP1 + (2, 1, 2, 1, 1, 1, 4, "latent"), # TP2,EP2 -> TP1,EP1 + (1, 1, 1, 2, 1, 2, 4, "latent"), # TP1,EP1 -> TP2,EP2 + # ---- MoE latent + MTP: representative configs ---- + (1, 1, 1, 1, 1, 2, 4, "latent_mtp"), # EP1 -> EP2 + (2, 1, 2, 1, 1, 1, 4, "latent_mtp"), # TP2,EP2 -> TP1,EP1 + (1, 1, 1, 2, 1, 2, 4, "latent_mtp"), # TP1,EP1 -> TP2,EP2 ], ) def test_swap_gpt_parametrized( @@ -168,20 +303,20 @@ def test_swap_gpt_parametrized( dst_pp: int, dst_ep: int, num_experts: Optional[int], + moe_mode: Optional[str], ): - # Initialize environment with source MP sizing + Utils.initialize_model_parallel( tensor_model_parallel_size=src_tp, pipeline_model_parallel_size=src_pp ) - # Validate divisibility post-init using the default PG safely world = dist.get_world_size() if (world % (src_tp * src_pp * src_ep) != 0) or (world % (dst_tp * dst_pp * dst_ep) != 0): Utils.destroy_model_parallel() pytest.skip( "WORLD_SIZE must be divisible by both src_tp*src_pp*src_ep and dst_tp*dst_pp*dst_ep" ) - model_parallel_cuda_manual_seed(1234) + model_parallel_cuda_manual_seed(1234) torch.manual_seed(1234) device = torch.device(f"cuda:{torch.cuda.current_device()}") @@ -205,29 +340,32 @@ def test_swap_gpt_parametrized( # Build PGs and models (always use unified PG builder so we can set EP) src_pgs = _build_pg_collection(tp_size=src_tp, pp_size=src_pp, ep_size=src_ep) dst_pgs = _build_pg_collection(tp_size=dst_tp, pp_size=dst_pp, ep_size=dst_ep) - # Apply EP configuration to TransformerConfigs when MoE is requested + # Apply PP/EP configuration to TransformerConfigs src_cfg = copy.deepcopy(cfg) dst_cfg = copy.deepcopy(cfg) + src_cfg.pipeline_model_parallel_size = src_pp + dst_cfg.pipeline_model_parallel_size = dst_pp + if num_experts is not None: - src_cfg.num_moe_experts = num_experts - dst_cfg.num_moe_experts = num_experts - # Ensure MoE MLP has an intermediate size; __post_init__ won't rerun after manual mutation - src_cfg.moe_ffn_hidden_size = src_cfg.ffn_hidden_size - dst_cfg.moe_ffn_hidden_size = dst_cfg.ffn_hidden_size - src_cfg.expert_model_parallel_size = src_ep - dst_cfg.expert_model_parallel_size = dst_ep - # Force grouped MLP path under Transformer Engine and satisfy requirements - src_cfg.moe_grouped_gemm = True - dst_cfg.moe_grouped_gemm = True - src_cfg.add_bias_linear = False - dst_cfg.add_bias_linear = False - # Require Transformer Engine for TEGroupedMLP; skip if unavailable + for c, ep in [(src_cfg, src_ep), (dst_cfg, dst_ep)]: + c.num_moe_experts = num_experts + c.moe_ffn_hidden_size = c.ffn_hidden_size + c.expert_model_parallel_size = ep + c.moe_grouped_gemm = True + c.add_bias_linear = False + if moe_mode in ("latent", "latent_mtp"): + c.moe_latent_size = 16 + c.moe_shared_expert_intermediate_size = 64 + c.activation_func = torch.nn.functional.silu + c.gated_linear_unit = True + if moe_mode == "latent_mtp": + c.mtp_num_layers = 1 try: import transformer_engine except Exception: Utils.destroy_model_parallel() - pytest.skip("Transformer Engine not available; skipping TE-grouped MoE test") - # Use parallel_output=False to gather TP logits inside model and emit only on last PP stage + pytest.skip("Transformer Engine not available; skipping MoE refit test") + src_model = ( _build_gpt( src_cfg, @@ -263,39 +401,184 @@ def test_swap_gpt_parametrized( ) attention_mask = torch.ones((batch, 1, seq_len, seq_len), device=device, dtype=torch.bool) - # Collect source reference logits (parallel_output=False ensures full vocab on last PP stage) + # Collect source reference logits ref_logits = torch.empty(batch, seq_len, vocab_size, device=device, dtype=torch.float32) src_pp_ranks = dist.get_process_group_ranks(src_pgs.pp) src_last_pp_rank = src_pp_ranks[-1] with torch.no_grad(): - src_out = src_model(tokens, position_ids, attention_mask) + src_out = _run_forward(src_model, tokens, position_ids, attention_mask, src_pgs) if dist.get_rank() == src_last_pp_rank: - ref = src_out # [b, s, vocab] - ref_logits.copy_(ref) + ref_logits.copy_(src_out) dist.broadcast(ref_logits, src=src_last_pp_rank, group=src_pgs.pp) # Swap weights swap_model_weights([src_model], [dst_model], refit_method=refit_backend) - # Collect destination logits (parallel_output=False ensures full vocab on last PP stage) + # Collect destination logits dst_logits = torch.empty(batch, seq_len, vocab_size, device=device, dtype=torch.float32) dst_pp_ranks = dist.get_process_group_ranks(dst_pgs.pp) dst_last_pp_rank = dst_pp_ranks[-1] with torch.no_grad(): - dst_out = dst_model( - tokens, position_ids, attention_mask - ) # last stage returns tensor, others return None + dst_out = _run_forward(dst_model, tokens, position_ids, attention_mask, dst_pgs) if dist.get_rank() == dst_last_pp_rank: - dst_logits.copy_(dst_out) # [b, s, vocab] + dst_logits.copy_(dst_out) dist.broadcast(dst_logits, src=dst_last_pp_rank, group=dst_pgs.pp) # Compare assert ref_logits.shape == dst_logits.shape - assert torch.allclose( - dst_logits, ref_logits, atol=1e-4, rtol=1e-4 - ), f"Refit src(TP={src_tp},PP={src_pp})->dst(TP={dst_tp},PP={dst_pp}) GPT outputs differ" + max_diff = (dst_logits - ref_logits).abs().max().item() + assert torch.allclose(dst_logits, ref_logits, atol=5e-4, rtol=5e-4), ( + f"Refit src(TP={src_tp},PP={src_pp},EP={src_ep})" + f"->dst(TP={dst_tp},PP={dst_pp},EP={dst_ep}) " + f"moe_mode={moe_mode} outputs differ (max_diff={max_diff:.6f})" + ) dist.barrier() + # Free GPU memory to prevent OOM across the many parametrized test cases + del src_model, dst_model # Clear refit caches before destroying model parallel to avoid stale plans clear_all_caches() + _destroy_pg_collection(src_pgs) + _destroy_pg_collection(dst_pgs) + Utils.destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() + + +@pytest.mark.parametrize( + "refit_backend", + [ + pytest.param( + "nvshmem", + marks=pytest.mark.skipif( + not has_nvshmem, + reason="nvshmem.core is not available (NVSHMEM Python bindings not installed)", + ), + ), + "nccl", + "gloo", + ], +) +@pytest.mark.parametrize( + "src_tp,src_pp,dst_tp,dst_pp", + [ + # TP only changes (exercises block-interleaved planner for Mamba in_proj) + (2, 1, 1, 1), # TP2 -> TP1 + (1, 1, 2, 1), # TP1 -> TP2 + (2, 1, 4, 1), # TP2 -> TP4 + # TP + PP change together + (1, 1, 2, 2), # TP1,PP1 -> TP2,PP2 + (2, 1, 1, 2), # TP2,PP1 -> TP1,PP2 + ], +) +def test_swap_mamba_parametrized( + refit_backend: str, src_tp: int, src_pp: int, dst_tp: int, dst_pp: int +): + if not has_mamba_deps: + pytest.skip("Mamba dependencies (mamba_ssm, einops) not available") + + Utils.initialize_model_parallel( + tensor_model_parallel_size=src_tp, pipeline_model_parallel_size=src_pp + ) + world = dist.get_world_size() + if (world % (src_tp * src_pp) != 0) or (world % (dst_tp * dst_pp) != 0): + Utils.destroy_model_parallel() + pytest.skip("WORLD_SIZE must be divisible by both src_tp*src_pp and dst_tp*dst_pp") + + model_parallel_cuda_manual_seed(1234) + torch.manual_seed(1234) + device = torch.device(f"cuda:{torch.cuda.current_device()}") + + # Small Mamba config — use "M*" hybrid pattern to test both Mamba layers + # (block-interleaved in_proj resharding) and attention layers together. + seq_len = 8 + vocab_size = 128 + base_pattern = "M*" + # Ensure enough layers for both PP configs (at least len(base_pattern) per stage) + min_layers = max(src_pp, dst_pp) * len(base_pattern) + num_layers = max(min_layers, 4 if (src_pp > 1 or dst_pp > 1) else 2) + # Round up to be divisible by both pp_size * unit_len + from math import lcm + + factor = lcm(src_pp, dst_pp) * len(base_pattern) + num_layers = ((num_layers + factor - 1) // factor) * factor + + cfg = TransformerConfig( + num_layers=num_layers, + hidden_size=256, + num_attention_heads=8, + num_query_groups=4, + use_cpu_initialization=True, + pipeline_dtype=torch.float32, + hidden_dropout=0.0, + attention_dropout=0.0, + ) + + src_pgs = _build_pg_collection(tp_size=src_tp, pp_size=src_pp) + dst_pgs = _build_pg_collection(tp_size=dst_tp, pp_size=dst_pp) + + src_pattern = _mamba_layer_pattern(base_pattern, num_layers, src_pp) + dst_pattern = _mamba_layer_pattern(base_pattern, num_layers, dst_pp) + + src_model = ( + _build_mamba(cfg, vocab_size, seq_len, src_pgs, src_pattern, parallel_output=False) + .to(device) + .eval() + ) + dst_model = ( + _build_mamba(cfg, vocab_size, seq_len, dst_pgs, dst_pattern, parallel_output=False) + .to(device) + .eval() + ) + + # Inputs + batch = 2 + tokens = torch.randint( + low=0, high=vocab_size, size=(batch, seq_len), device=device, dtype=torch.long + ) + position_ids = ( + torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(batch, -1) + ) + attention_mask = torch.ones((batch, 1, seq_len, seq_len), device=device, dtype=torch.bool) + + # Collect source reference logits + ref_logits = torch.empty(batch, seq_len, vocab_size, device=device, dtype=torch.float32) + src_pp_ranks = dist.get_process_group_ranks(src_pgs.pp) + src_last_pp_rank = src_pp_ranks[-1] + with torch.no_grad(): + src_out = _run_forward(src_model, tokens, position_ids, attention_mask, src_pgs) + if dist.get_rank() == src_last_pp_rank: + ref_logits.copy_(src_out) + dist.broadcast(ref_logits, src=src_last_pp_rank, group=src_pgs.pp) + + # Swap weights + swap_model_weights([src_model], [dst_model], refit_method=refit_backend) + + # Collect destination logits + dst_logits = torch.empty(batch, seq_len, vocab_size, device=device, dtype=torch.float32) + dst_pp_ranks = dist.get_process_group_ranks(dst_pgs.pp) + dst_last_pp_rank = dst_pp_ranks[-1] + with torch.no_grad(): + dst_out = _run_forward(dst_model, tokens, position_ids, attention_mask, dst_pgs) + if dist.get_rank() == dst_last_pp_rank: + dst_logits.copy_(dst_out) + dist.broadcast(dst_logits, src=dst_last_pp_rank, group=dst_pgs.pp) + + # Compare + assert ref_logits.shape == dst_logits.shape + max_diff = (dst_logits - ref_logits).abs().max().item() + assert torch.allclose(dst_logits, ref_logits, atol=1e-3, rtol=1e-3), ( + f"Mamba refit src(TP={src_tp},PP={src_pp})" + f"->dst(TP={dst_tp},PP={dst_pp}) " + f"outputs differ (max_diff={max_diff:.6f})" + ) + dist.barrier() + + # Free GPU memory to prevent OOM across the many parametrized test cases + del src_model, dst_model + clear_all_caches() + _destroy_pg_collection(src_pgs) + _destroy_pg_collection(dst_pgs) Utils.destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache()