Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def __init__(
)

if is_te_min_version("0.8.0"):
if self.config.tp_comm_overlap:
if self.config.tp_comm_overlap and parallel_mode != "duplicated":
if is_te_min_version("1.5.0"):
# Use old overlap flags if they were supplied instead
extra_kwargs["ub_overlap_ag"] = (
Expand Down
7 changes: 5 additions & 2 deletions megatron/core/transformer/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,12 @@ def __init__(
if self.config.gated_linear_unit:
ffn_hidden_size *= 2

# Use moe_latent_size only for routed experts. 'is_expert' is false for shared_experts
use_latent_size = (self.config.moe_latent_size is not None) and is_expert

self.linear_fc1 = build_module(
submodules.linear_fc1,
self.input_size,
self.input_size if not use_latent_size else self.config.moe_latent_size,
ffn_hidden_size,
config=self.config,
init_method=self.config.init_method,
Expand All @@ -126,7 +129,7 @@ def __init__(
self.linear_fc2 = build_module(
submodules.linear_fc2,
self.config.ffn_hidden_size,
self.config.hidden_size,
self.config.hidden_size if not use_latent_size else self.config.moe_latent_size,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
Expand Down
11 changes: 9 additions & 2 deletions megatron/core/transformer/moe/experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ def __init__(
assert (
config.add_bias_linear == False
), "bias not supported in Grouped GEMM yet, please set '--disable-bias-linear' instead."
assert (
config.moe_latent_size is None
), "MoE latent projection not supported in GroupedMLP yet."

self.expert_parallel = config.expert_model_parallel_size > 1
if self.config.gated_linear_unit:
Expand Down Expand Up @@ -778,7 +781,7 @@ def __init__(
self.linear_fc1 = build_module(
submodules.linear_fc1,
self.num_local_experts,
self.input_size,
self.input_size if self.config.moe_latent_size is None else self.config.moe_latent_size,
ffn_hidden_size,
config=self.config,
init_method=self.config.init_method,
Expand All @@ -799,7 +802,11 @@ def __init__(
submodules.linear_fc2,
self.num_local_experts,
self.config.moe_ffn_hidden_size,
self.config.hidden_size,
(
self.config.hidden_size
if self.config.moe_latent_size is None
else self.config.moe_latent_size
),
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
Expand Down
40 changes: 39 additions & 1 deletion megatron/core/transformer/moe/moe_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
try:
import transformer_engine as te # pylint: disable=unused-import

from megatron.core.extensions.transformer_engine import te_checkpoint
from megatron.core.extensions.transformer_engine import TELinear, te_checkpoint

HAVE_TE = True
except ImportError:
Expand Down Expand Up @@ -123,6 +123,32 @@ def __init__(
# Initialize router
self.router = TopKRouter(config=self.config, pg_collection=pg_collection)

# Initialize latent projections
if self.config.moe_latent_size:
assert HAVE_TE, "TransformerEngine is required for MoE latent projections."
self.fc1_latent_proj = TELinear(
self.config.hidden_size,
self.config.moe_latent_size,
parallel_mode="duplicated",
config=self.config,
init_method=self.config.init_method,
bias=self.config.add_bias_linear,
skip_bias_add=False,
skip_weight_param_allocation=False,
is_expert=False,
)
self.fc2_latent_proj = TELinear(
self.config.moe_latent_size,
self.config.hidden_size,
parallel_mode="duplicated",
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
skip_bias_add=False,
skip_weight_param_allocation=False,
is_expert=False,
)

# Initialize token dispatcher
if config.moe_token_dispatcher_type == "allgather":
self.token_dispatcher = MoEAllGatherTokenDispatcher(
Expand Down Expand Up @@ -176,6 +202,12 @@ def router_and_preprocess(self, hidden_states: torch.Tensor):
"""
residual = hidden_states
probs, routing_map = self.router(hidden_states)
# Project the hidden_states from the hidden dimension down to the latent dimension.
if self.config.moe_latent_size:
assert (
not self.shared_expert_overlap
), "Shared expert overlap not supported when MoE latent projections are used."
hidden_states, _ = self.fc1_latent_proj(hidden_states)
hidden_states, probs = self.token_dispatcher.dispatch_preprocess(
hidden_states, routing_map, probs
)
Expand Down Expand Up @@ -243,6 +275,10 @@ def combine(self, output: torch.Tensor, shared_expert_output: Optional[torch.Ten
"""
output = self.token_dispatcher.token_combine(output)
output = self.token_dispatcher.combine_postprocess(output)
# Project the output from the latent dimension back up to the hidden dimension
# after the combine, which runs in the latent dimension.
if self.config.moe_latent_size:
output, _ = self.fc2_latent_proj(output)
if shared_expert_output is not None:
output = output + shared_expert_output
return output
Expand Down Expand Up @@ -274,7 +310,9 @@ def custom_forward(hidden_states):
hidden_states, probs, residual = self.router_and_preprocess(hidden_states)
dispatched_input, probs = self.dispatch(hidden_states, probs)
output, mlp_bias = self.routed_experts_compute(dispatched_input, probs, residual)
assert mlp_bias is None, f"mlp_bias is not supported for {type(self.token_dispatcher)}"
output = self.combine(output, shared_expert_output)

return output, mlp_bias

if self.moe_layer_recompute:
Expand Down
3 changes: 3 additions & 0 deletions megatron/core/transformer/transformer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,9 @@ class TransformerConfig(ModelParallelConfig):
"""Number of SMs to use for HybridEP. In pure NVL scenarios,
16 SMs can generally achieve good bandwidth."""

moe_latent_size: Optional[int] = None
"""Latent projection dimension for MoE. If None, MoE latent projections are not used."""

####################
# initialization
####################
Expand Down
11 changes: 11 additions & 0 deletions megatron/training/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1225,6 +1225,13 @@ def validate_args(args, defaults={}):
args.recompute_granularity != 'full'
), 'recompute_granularity must not be full when CUDA Graphs are enabled.'

# MoE latent projections
if args.moe_latent_size is not None:
assert args.moe_latent_size > 0, "MoE latent projection dimension has to be greater than zero."
assert args.num_experts is not None, "MoE latent projections are applicable only for MoE models."
assert not args.use_legacy_models, "MoE latent projections are only supported for mcore models."
assert not args.moe_use_legacy_grouped_gemm, "MoE latent projection is not supported yet with legacy grouped GEMM."

if args.tiktoken_special_tokens and not args.tokenizer_special_tokens:
warn_rank_0(
"--tiktoken-special-tokens argument is deprecated and will be removed soon. "
Expand Down Expand Up @@ -1331,6 +1338,8 @@ def core_transformer_config_from_args(args, config_class=None):
kw_args['use_kitchen'] = True
kw_args['quant_recipe'] = kitchen_quantization_recipe_config(args.kitchen_recipe_number)

kw_args['moe_latent_size'] = args.moe_latent_size

if args.te_precision_config_file:
assert not 'quant_recipe' in kw_args, "Quantization recipe already configured."
# TODO(kwyss): Prohibit fp8_params or fp4_params with this flexibility
Expand Down Expand Up @@ -1719,6 +1728,8 @@ def _add_network_size_args(parser):
'We compute the average of the MTP losses across all depths, '
'and multiply it the scaling factor to obtain the overall MTP loss, '
'which serves as an additional training objective.')
group.add_argument('--moe-latent-size', type=int, default=None,
help='Latent projection dimension for MoE. If None, MoE latent projections are not used.')
return parser


Expand Down
3 changes: 3 additions & 0 deletions megatron/training/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1341,6 +1341,9 @@ def _set_arg(arg_name, old_arg_name=None, force=False):
_set_arg('heterogeneous_layers_config_path', force=True)
_set_arg('heterogeneous_layers_config_encoded_json', force=True)

# MoE latent projection
_set_arg('moe_latent_size', force=True)

# Tokenizer args.
_set_arg('tokenizer_type', force=True)
# Using checkpoint version might not always be safe (e.g., if running on different cluster).
Expand Down
19 changes: 15 additions & 4 deletions megatron/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,19 @@ def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=Fals
return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2

def moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size,
shared_expert_ffn_hidden_size, num_experts_routed_to, swiglu=False):
shared_expert_ffn_hidden_size, num_experts_routed_to,
moe_latent_size=None, swiglu=False):
"""Calculate FLOPs for an MoE layer."""
scale_factor = 3.0 / 2.0 if swiglu else 1.0
routed_flops = (4 * batch_size * seq_len * hidden_size *
moe_ffn_hidden_size * num_experts_routed_to * scale_factor)
if moe_latent_size is None:
routed_flops = (4 * batch_size * seq_len * hidden_size *
moe_ffn_hidden_size * num_experts_routed_to * scale_factor)
else:
# Routed experts run on moe_latent_size.
routed_flops = (4 * batch_size * seq_len * moe_latent_size *
moe_ffn_hidden_size * num_experts_routed_to * scale_factor)
# Up proj and down proj.
routed_flops += (4 * batch_size * seq_len * hidden_size * moe_latent_size)
shared_flops = 4 * batch_size * seq_len * hidden_size * shared_expert_ffn_hidden_size * scale_factor
return routed_flops + shared_flops

Expand Down Expand Up @@ -230,6 +238,7 @@ def hybrid_flops(batch_size, seq_len, hidden_size,
num_attn_heads=32, gqa=True,
gqa_groups=8, kv_channels=None,
mlp_expansion=4.0, swiglu=False,
moe_latent_size=None,
moe_ffn_hidden_size=2048, shared_expert_ffn_hidden_size=2048, num_experts_routed_to=1,
vocab_size=256000):
"""Calculate total FLOPs for the hybrid model."""
Expand All @@ -242,7 +251,8 @@ def hybrid_flops(batch_size, seq_len, hidden_size,
mamba_state_dim, mamba_head_dim,
mamba_num_groups, mamba_num_heads) +
num_moe_layers * moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size,
shared_expert_ffn_hidden_size, num_experts_routed_to, swiglu) +
shared_expert_ffn_hidden_size, num_experts_routed_to,
moe_latent_size, swiglu) +
(2 * batch_size * seq_len * hidden_size * vocab_size) # logits computation
)
return flops_fwd * 3
Expand Down Expand Up @@ -447,6 +457,7 @@ def transformer_flops():
kv_channels=args.kv_channels,
mlp_expansion=args.ffn_hidden_size / args.hidden_size,
swiglu=args.swiglu,
moe_latent_size=args.moe_latent_size,
moe_ffn_hidden_size=(args.moe_ffn_hidden_size if args.moe_ffn_hidden_size is not None
else args.ffn_hidden_size),
shared_expert_ffn_hidden_size=(0 if args.moe_shared_expert_intermediate_size is None
Expand Down
20 changes: 18 additions & 2 deletions scripts/check_api_backwards_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,22 @@
# Decorators that exempt objects from compatibility checks
EXEMPT_DECORATORS = ['internal_api', 'deprecated', 'experimental_api']

# Breakage kinds to ignore (not actual API signature changes)
# Breakage kinds to ignore globally (not actual API signature changes)
# AttributeChangedValueBreakage: Changing constant values (e.g., VERSION = "1.0" -> "2.0")
# is not a breaking API change - the constant still exists with the same name
IGNORED_BREAKAGE_KINDS = [
'AttributeChangedValueBreakage',
]

# Breakage kinds to ignore only for __init__ methods
# ParameterMovedBreakage: Reordering parameters in __init__ is generally safe because:
# - Config dataclasses should always be initialized with keyword arguments
# - Adding fields to parent dataclasses shifts child __init__ params (inheritance artifact)
# - Nobody should call Config(4096, 32, ...) with positional args
IGNORED_FOR_INIT_METHODS = [
'ParameterMovedBreakage',
]


def has_exempt_decorator(obj: Object) -> bool:
"""Check if a Griffe object has any exempt decorator.
Expand Down Expand Up @@ -217,6 +226,7 @@ def should_skip_change(change, filtered_paths: set) -> bool:

A change is skipped if:
- The change kind is in IGNORED_BREAKAGE_KINDS (not a signature change)
- The change kind is in IGNORED_FOR_INIT_METHODS and affects an __init__ method
- The changed object itself is in filtered_paths (exact match)
- The changed object is a child of an exempt object (prefix match)

Expand All @@ -227,7 +237,7 @@ def should_skip_change(change, filtered_paths: set) -> bool:
Returns:
bool: True if the change should be skipped (filtered out)
"""
# Check if this breakage kind should be ignored (not a signature change)
# Check if this breakage kind should be ignored globally (not a signature change)
change_kind = type(change).__name__
if change_kind in IGNORED_BREAKAGE_KINDS:
return True
Expand All @@ -240,6 +250,12 @@ def should_skip_change(change, filtered_paths: set) -> bool:
# e.g., "Class.__init__(param)" -> "Class.__init__"
clean_path = path.split('(')[0] if '(' in path else path

# Check if this is a breakage kind we ignore for __init__ methods
# Config dataclasses should use keyword args, so parameter reordering is safe
if change_kind in IGNORED_FOR_INIT_METHODS:
if '.__init__' in clean_path:
return True

# Check exact match
if clean_path in filtered_paths or path in filtered_paths:
return True
Expand Down
Loading