Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1729,6 +1729,13 @@ def __init__(

self.explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)

# Save original parallel_mode before clearing it for explicit_expert_comm.
# When explicit_expert_comm is True, Megatron handles TP communication externally
# and passes parallel_mode=None to TE. This causes TE to set partition_dim=0 on
# all weights (its default for non-parallel mode). We need to fix this after init
# so that refit/resharding can correctly identify which dimension is TP-partitioned.
original_parallel_mode = parallel_mode

if self.explicit_expert_comm:
if parallel_mode == "column":
output_size = divide(output_size, tp_size)
Expand Down Expand Up @@ -1759,6 +1766,15 @@ def __init__(
for param in self.parameters():
setattr(param, "allreduce", not (is_expert and self.expert_parallel))

# Fix partition_dim when explicit_expert_comm cleared parallel_mode.
# TE defaults to partition_dim=0 when parallel_mode=None, but row-parallel
# weights are partitioned along dim=1 (input dimension).
if self.explicit_expert_comm and original_parallel_mode == "row":
for i in range(num_gemms):
weight = getattr(self, f"weight{i}", None)
if weight is not None and hasattr(weight, "partition_dim"):
weight.partition_dim = 1

def merge_extra_states(
self,
state_dict,
Expand Down
5 changes: 4 additions & 1 deletion megatron/core/models/gpt/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,10 @@ def __init__(

if self.mtp_process:
self.mtp = MultiTokenPredictionBlock(
config=self.config, spec=self.mtp_block_spec, vp_stage=vp_stage
config=self.config,
spec=self.mtp_block_spec,
vp_stage=vp_stage,
pg_collection=self.pg_collection,
)

# Output
Expand Down
130 changes: 130 additions & 0 deletions megatron/core/resharding/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,126 @@ def dst_key(op):
return ops


def _plan_block_interleaved(
    param_name: str,
    src_metadata: ParameterMetadata,
    dst_metadata: ParameterMetadata,
    descriptors: list[ShardingDescriptor],
    my_global_rank: int,
) -> list[tuple[int, tuple[slice, ...], tuple[slice, ...]]]:
    """
    Block-interleaved TP planner for parameters with ``partition_sizes``.

    When a parameter packs multiple independently-sharded components of
    *different* sizes (e.g. Mamba in_proj packs z, x, B, C, dt), a simple
    contiguous concat produces the wrong layout. This function treats each
    block independently: it gathers (or scatters) each block across TP ranks
    before moving to the next block.

    ``partition_sizes`` lists the per-TP-rank block sizes along the partition
    dim. Block *i* occupies ``[sum(sizes[:i]), sum(sizes[:i+1]))`` in the
    local tensor on every TP rank. In the *full* (TP-gathered) tensor, block
    *i* occupies ``[sum(full_sizes[:i]), sum(full_sizes[:i+1]))`` where
    ``full_sizes[i] = sizes[i] * src_tp_world``.

    Args:
        param_name: Parameter name, used only in error messages.
        src_metadata: Source-side metadata; ``shape`` and (optionally)
            ``partition_sizes`` are read.
        dst_metadata: Destination-side metadata; same fields are read.
        descriptors: Sharding descriptors; only a leading ``"tp"``
            descriptor is handled here.
        my_global_rank: This process's global rank.

    Returns:
        Transfer ops ``(src_global_rank, src_slices, dst_slices)`` sorted by
        destination offset along the partition dim. Empty when there is no
        TP descriptor or this rank is not a destination rank.

    Raises:
        RuntimeError: if neither side carries ``partition_sizes``, if a
            block's total size is not conserved between src and dst, or if a
            block cannot be tiled evenly across both TP world sizes.
    """
    if not descriptors or descriptors[0].name != "tp":
        return []
    d = descriptors[0]
    if my_global_rank not in d.dst_dim_ranks:
        return []

    dim = d.dim
    src_shape = tuple(src_metadata.shape)
    dst_shape = tuple(dst_metadata.shape)
    src_world = len(d.src_dim_ranks)
    dst_world = len(d.dst_dim_ranks)
    dst_local_rank = _get_rank_in_group(my_global_rank, d.dst_dim_ranks)

    # Use partition_sizes from whichever side has it (prefer src)
    src_sizes = src_metadata.partition_sizes
    dst_sizes = dst_metadata.partition_sizes

    if src_sizes is None and dst_sizes is None:
        raise RuntimeError(f"{param_name}: _plan_block_interleaved called without partition_sizes")

    # Derive the full (un-sharded) block sizes
    if src_sizes is not None:
        num_blocks = len(src_sizes)
        full_sizes = [s * src_world for s in src_sizes]
    else:
        num_blocks = len(dst_sizes)
        full_sizes = [s * dst_world for s in dst_sizes]

    # Compute per-rank block sizes for both sides
    if src_sizes is None:
        src_sizes = [f // src_world for f in full_sizes]
    if dst_sizes is None:
        dst_sizes = [f // dst_world for f in full_sizes]

    # Validate conservation: both sides must describe the same full tensor
    for i in range(num_blocks):
        if src_sizes[i] * src_world != dst_sizes[i] * dst_world:
            raise RuntimeError(
                f"{param_name}: block {i} size mismatch: "
                f"src_sizes[{i}]={src_sizes[i]}*{src_world} != "
                f"dst_sizes[{i}]={dst_sizes[i]}*{dst_world}"
            )

    ops: list[tuple[int, tuple[slice, ...], tuple[slice, ...]]] = []

    # For each block, compute the transfer ops independently
    src_block_offset = 0  # cumulative offset in source local tensor
    dst_block_offset = 0  # cumulative offset in destination local tensor

    for blk in range(num_blocks):
        full_blk_sz = full_sizes[blk]

        # Within this block, use simple LCM tiling (stride=1): split the full
        # block into L micro-segments of `unit` elements so that both TP
        # worlds own a whole number of micro-segments.
        Ns = src_world
        Nd = dst_world
        L = math.lcm(Ns, Nd)
        if full_blk_sz % L != 0:
            raise RuntimeError(
                f"{param_name}: block {blk} full_size {full_blk_sz} not divisible by LCM {L}"
            )
        unit = full_blk_sz // L
        cps = L // Ns  # micro-segments per src rank
        cpd = L // Nd  # micro-segments per dst rank

        # This dst rank owns micro-segments [dst_local_rank*cpd, (dst_local_rank+1)*cpd)
        for off in range(cpd):
            g_micro = dst_local_rank * cpd + off
            # g_micro < L and cps = L // Ns, so `g_micro // cps` is already a
            # valid index into d.src_dim_ranks — no further modulo needed.
            s_idx = g_micro // cps  # owning src rank within the TP dim
            in_seg = g_micro % cps  # position within that rank's local block
            src_global_rank = d.src_dim_ranks[s_idx]
            src_start = src_block_offset + in_seg * unit
            dst_start = dst_block_offset + off * unit

            src_slice = [slice(None)] * len(src_shape)
            dst_slice = [slice(None)] * len(dst_shape)
            src_slice[dim] = slice(src_start, src_start + unit)
            dst_slice[dim] = slice(dst_start, dst_start + unit)
            ops.append((src_global_rank, tuple(src_slice), tuple(dst_slice)))

        src_block_offset += src_sizes[blk]
        dst_block_offset += dst_sizes[blk]

    # Stable sort by destination offset; every op's dst slice at `dim` is a
    # concrete slice built above, so `.start` is always an int.
    ops.sort(key=lambda op: op[2][dim].start)
    return ops


def _finalize_dp_transfers(
param_name: str,
src_metadata: ParameterMetadata,
Expand Down Expand Up @@ -210,6 +330,16 @@ def _determine_source_ranks_for_dst_param(
# Regular TP/DP planning with EP-resolved metadata
descriptors = _build_descriptors_for_param(src_metadata=src_metadata, dst_metadata=dst_metadata)
if descriptors:
# Use block-interleaved planner when partition_sizes is present
# (e.g. Mamba in_proj packs components of different sizes)
if src_metadata.partition_sizes is not None or dst_metadata.partition_sizes is not None:
return _plan_block_interleaved(
param_name=param_name,
src_metadata=src_metadata,
dst_metadata=dst_metadata,
descriptors=descriptors,
my_global_rank=my_global_rank,
)
return _plan_multi_dim_lcm(
param_name=param_name,
src_metadata=src_metadata,
Expand Down
9 changes: 9 additions & 0 deletions megatron/core/resharding/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ class ParameterMetadata:
is_tp: bool = False
partition_dim: int = 0
partition_stride: int = 1
# For parameters that pack multiple independently-sharded components of
# different sizes (e.g. Mamba in_proj packs z, x, B, C, dt). When present,
# lists the per-TP-rank block sizes along partition_dim. The refit planner
# interleaves these blocks rather than doing a simple contiguous concat.
partition_sizes: list[int] | None = None

# EP sharding info (fused/grouped MoE)
is_ep: bool = False
Expand Down Expand Up @@ -258,6 +263,9 @@ def extract_param_metadata(
is_tp = bool(getattr(param, 'tensor_model_parallel', False))
partition_dim = int(getattr(param, 'partition_dim', 0))
partition_stride = int(getattr(param, 'partition_stride', 1))
partition_sizes = getattr(param, 'partition_sizes', None)
if partition_sizes is not None:
partition_sizes = list(partition_sizes)

# SwiGLU/GLU compatibility: For gated linear units, fc1 stores interleaved [gate, up] portions
# and requires partition_stride=2 for correct resharding. New models set this at construction
Expand Down Expand Up @@ -318,6 +326,7 @@ def _offset_ranks(ranks: list[int]) -> list[int]:
is_tp=is_tp,
partition_dim=partition_dim,
partition_stride=partition_stride,
partition_sizes=partition_sizes,
is_ep=is_ep,
num_experts=num_experts,
owner_rank=owner_rank,
Expand Down
27 changes: 24 additions & 3 deletions megatron/core/ssm/mamba_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,19 @@ def __init__(
tp_comm_buffer_name="fc1",
tp_group=self.pg_collection.tp,
)
# in_proj packs [z, x, B, C, dt] into one ColumnParallelLinear. Each
# component is independently TP-sharded but with different sizes. When
# resharding across different TP sizes the planner must interleave
# per-component blocks rather than doing a contiguous concat.
# partition_sizes lists the per-TP-rank block sizes along partition_dim.
in_proj_partition_sizes = [
self.d_inner_local_tp, # z
self.d_inner_local_tp, # x
self.ngroups_local_tp * self.d_state, # B
self.ngroups_local_tp * self.d_state, # C
self.nheads_local_tp, # dt
]
setattr(self.in_proj.weight, "partition_sizes", in_proj_partition_sizes)

if not self.use_mem_eff_path:
log_single_rank(
Expand Down Expand Up @@ -286,6 +299,17 @@ def __init__(
setattr(self.conv1d.weight, "partition_dim", 0)
setattr(self.conv1d.bias, "tensor_model_parallel", True)
setattr(self.conv1d.bias, "partition_dim", 0)
# partition_sizes describes the per-TP-rank block sizes along the
# partition dim. conv1d packs [x, B, C] whose local sizes differ,
# so a plain contiguous concat would produce the wrong layout when
# resharding across different TP sizes.
conv_partition_sizes = [
self.d_inner_local_tp,
self.ngroups_local_tp * self.d_state,
self.ngroups_local_tp * self.d_state,
]
setattr(self.conv1d.weight, "partition_sizes", conv_partition_sizes)
setattr(self.conv1d.bias, "partition_sizes", conv_partition_sizes)
if self.config.perform_initialization:
if self.conv_init is not None:
nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
Expand Down Expand Up @@ -323,7 +347,6 @@ def __init__(
self.A_log = nn.Parameter(A_log)
setattr(self.A_log, "tensor_model_parallel", True)
setattr(self.A_log, "partition_dim", 0)

# D "skip" parameter
self.D = nn.Parameter(
torch.ones(
Expand All @@ -333,7 +356,6 @@ def __init__(
) # Keep in fp32
setattr(self.D, "tensor_model_parallel", True)
setattr(self.D, "partition_dim", 0)

if self.rmsnorm:
assert RMSNormGated is not None
self.norm = ExtendedRMSNorm(
Expand All @@ -346,7 +368,6 @@ def __init__(
)
setattr(self.norm.weight, "tensor_model_parallel", True)
setattr(self.norm.weight, "partition_dim", 0)

# Assume sequence parallelism: input is partitioned along d_inner and
# output is partitioned along the sequence dimension
self.out_proj = build_module(
Expand Down
8 changes: 5 additions & 3 deletions megatron/core/transformer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,8 +1420,10 @@ def get_query_key_value_tensors(
# 4. Further index into query to get only the q_heads that this rank is
# responsible for (e.g., q1).
# The block of code below performs steps 1 and 2.
mixed_qkv = all_gather_last_dim_from_tensor_parallel_region(mixed_qkv)
idx = get_tensor_model_parallel_rank() // (
mixed_qkv = all_gather_last_dim_from_tensor_parallel_region(
mixed_qkv, group=self.pg_collection.tp
)
idx = get_pg_rank(self.pg_collection.tp) // (
self.world_size // self.config.num_query_groups
)
size = mixed_qkv.size()[-1] // self.config.num_query_groups
Expand Down Expand Up @@ -1478,7 +1480,7 @@ def get_query_key_value_tensors(
# query above corresponds to (num_q_heads / num_kv_heads) q_heads.
# Index appropriately into query to get (num_q_heads / tp_size) q_heads.
# This is step 4 in the list of steps above.
idx = get_tensor_model_parallel_rank() % (
idx = get_pg_rank(self.pg_collection.tp) % (
self.world_size // self.config.num_query_groups
)
size = self.num_attention_heads_per_partition // (
Expand Down
9 changes: 6 additions & 3 deletions megatron/core/transformer/multi_token_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,7 @@ def __init__(
self.layer_number = layer_number + get_mtp_layer_offset(self.config, vp_stage)
self.vp_stage = vp_stage
self.cp_group = pg_collection.cp
self.tp_group = pg_collection.tp if pg_collection is not None else None
self.mtp_layer_pattern = mtp_layer_pattern

# Validate attention mask type if using transformer-based inner layers
Expand Down Expand Up @@ -807,6 +808,7 @@ def __init__(
skip_bias_add=False,
is_expert=False,
tp_comm_buffer_name="mtp_eh_proj",
tp_group=pg_collection.tp if pg_collection is not None else None,
)

# Build inner layers: two possible paths
Expand Down Expand Up @@ -838,6 +840,7 @@ def __init__(
vp_stage=self.vp_stage,
layer_number=self.layer_number,
is_mtp_layer=True,
pg_collection=pg_collection,
)

self.final_layernorm = self.submodules.layer_norm(
Expand Down Expand Up @@ -909,10 +912,10 @@ def _concat_embeddings(self, hidden_states: torch.Tensor, decoder_input: torch.T
# `all_gather_last_dim_from_tensor_parallel_region`, but that utility reduces
# the gradient in backward pass and was therefore incorrect in this context.
# It has been replaced with the correct `gather_from_tensor_model_parallel_region`.
hidden_states = gather_from_tensor_model_parallel_region(hidden_states)
hidden_states = gather_from_tensor_model_parallel_region(hidden_states, group=self.tp_group)
# For sequence parallel, scatter after linear_fc and before transformer layer.
if self.sequence_parallel:
hidden_states = scatter_to_sequence_parallel_region(hidden_states)
hidden_states = scatter_to_sequence_parallel_region(hidden_states, group=self.tp_group)
return hidden_states

def _proj_and_transformer_layer(
Expand Down Expand Up @@ -1295,7 +1298,7 @@ def __init__(
# to the roll_tensor function for proper boundary communication
if pg_collection is None:
# Use default MPU process groups if not provided
pg_collection = ProcessGroupCollection.use_mpu_process_groups(required_pgs=['cp'])
pg_collection = ProcessGroupCollection.use_mpu_process_groups(required_pgs=['cp', 'tp'])
else:
# Ensure the provided process groups include CP
assert hasattr(
Expand Down
Loading
Loading