diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 32845980c14..a7259570077 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -535,7 +535,7 @@ def __init__( ) if is_te_min_version("0.8.0"): - if self.config.tp_comm_overlap: + if self.config.tp_comm_overlap and parallel_mode != "duplicated": if is_te_min_version("1.5.0"): # Use old overlap flags if they were supplied instead extra_kwargs["ub_overlap_ag"] = ( diff --git a/megatron/core/transformer/mlp.py b/megatron/core/transformer/mlp.py index 9602beb2f71..4a95134b8ed 100644 --- a/megatron/core/transformer/mlp.py +++ b/megatron/core/transformer/mlp.py @@ -104,9 +104,12 @@ def __init__( if self.config.gated_linear_unit: ffn_hidden_size *= 2 + # Use moe_latent_size only for routed experts. 'is_expert' is false for shared_experts + use_latent_size = (self.config.moe_latent_size is not None) and is_expert + self.linear_fc1 = build_module( submodules.linear_fc1, - self.input_size, + self.input_size if not use_latent_size else self.config.moe_latent_size, ffn_hidden_size, config=self.config, init_method=self.config.init_method, @@ -126,7 +129,7 @@ def __init__( self.linear_fc2 = build_module( submodules.linear_fc2, self.config.ffn_hidden_size, - self.config.hidden_size, + self.config.hidden_size if not use_latent_size else self.config.moe_latent_size, config=self.config, init_method=self.config.output_layer_init_method, bias=self.config.add_bias_linear, diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 68a3d53d2be..9ea26e3e2ee 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -118,6 +118,9 @@ def __init__( assert ( config.add_bias_linear == False ), "bias not supported in Grouped GEMM yet, please set '--disable-bias-linear' instead." 
+ assert ( + config.moe_latent_size is None + ), "MoE latent projection not supported in GroupedMLP yet." self.expert_parallel = config.expert_model_parallel_size > 1 if self.config.gated_linear_unit: @@ -778,7 +781,7 @@ def __init__( self.linear_fc1 = build_module( submodules.linear_fc1, self.num_local_experts, - self.input_size, + self.input_size if self.config.moe_latent_size is None else self.config.moe_latent_size, ffn_hidden_size, config=self.config, init_method=self.config.init_method, @@ -799,7 +802,11 @@ def __init__( submodules.linear_fc2, self.num_local_experts, self.config.moe_ffn_hidden_size, - self.config.hidden_size, + ( + self.config.hidden_size + if self.config.moe_latent_size is None + else self.config.moe_latent_size + ), config=self.config, init_method=self.config.output_layer_init_method, bias=self.config.add_bias_linear, diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index e3de8220a54..a9349729caa 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -23,7 +23,7 @@ try: import transformer_engine as te # pylint: disable=unused-import - from megatron.core.extensions.transformer_engine import te_checkpoint + from megatron.core.extensions.transformer_engine import TELinear, te_checkpoint HAVE_TE = True except ImportError: @@ -123,6 +123,32 @@ def __init__( # Initialize router self.router = TopKRouter(config=self.config, pg_collection=pg_collection) + # Initialize latent projections + if self.config.moe_latent_size: + assert HAVE_TE, "TransformerEngine is required for MoE latent projections." 
+ self.fc1_latent_proj = TELinear(
+ self.config.hidden_size,
+ self.config.moe_latent_size,
+ parallel_mode="duplicated",
+ config=self.config,
+ init_method=self.config.init_method,
+ bias=self.config.add_bias_linear,
+ skip_bias_add=False,
+ skip_weight_param_allocation=False,
+ is_expert=False,
+ )
+ self.fc2_latent_proj = TELinear(
+ self.config.moe_latent_size,
+ self.config.hidden_size,
+ parallel_mode="duplicated",
+ config=self.config,
+ init_method=self.config.output_layer_init_method,
+ bias=self.config.add_bias_linear,
+ skip_bias_add=False,
+ skip_weight_param_allocation=False,
+ is_expert=False,
+ )
+
 # Initialize token dispatcher
 if config.moe_token_dispatcher_type == "allgather":
 self.token_dispatcher = MoEAllGatherTokenDispatcher(
@@ -176,6 +202,12 @@ def router_and_preprocess(self, hidden_states: torch.Tensor):
 """
 residual = hidden_states
 probs, routing_map = self.router(hidden_states)
+ # Project the hidden_states from hidden dimension down to latent dimension.
+ if self.config.moe_latent_size:
+ assert (
+ not self.shared_expert_overlap
+ ), "Shared expert overlap not supported when MoE latent projections are used."
+ hidden_states, _ = self.fc1_latent_proj(hidden_states)
 hidden_states, probs = self.token_dispatcher.dispatch_preprocess(
 hidden_states, routing_map, probs
 )
@@ -243,6 +275,10 @@ def combine(self, output: torch.Tensor, shared_expert_output: Optional[torch.Ten
 """
 output = self.token_dispatcher.token_combine(output)
 output = self.token_dispatcher.combine_postprocess(output)
+ # Project the output back from latent dimension to hidden dimension after combining
+ # in the latent dimension. 
+ if self.config.moe_latent_size: + output, _ = self.fc2_latent_proj(output) if shared_expert_output is not None: output = output + shared_expert_output return output @@ -274,7 +310,9 @@ def custom_forward(hidden_states): hidden_states, probs, residual = self.router_and_preprocess(hidden_states) dispatched_input, probs = self.dispatch(hidden_states, probs) output, mlp_bias = self.routed_experts_compute(dispatched_input, probs, residual) + assert mlp_bias is None, f"mlp_bias is not supported for {type(self.token_dispatcher)}" output = self.combine(output, shared_expert_output) + return output, mlp_bias if self.moe_layer_recompute: diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index b0df17ccf4e..ddb97472cd2 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -216,6 +216,9 @@ class TransformerConfig(ModelParallelConfig): """Number of SMs to use for HybridEP. In pure NVL scenarios, 16 SMs can generally achieve good bandwidth.""" + moe_latent_size: Optional[int] = None + """Latent projection dimension for MoE. If None, MoE latent projections are not used.""" + #################### # initialization #################### diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 2662f0f9866..0d78ea9df54 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1225,6 +1225,13 @@ def validate_args(args, defaults={}): args.recompute_granularity != 'full' ), 'recompute_granularity must not be full when CUDA Graphs are enabled.' + # MoE latent projections + if args.moe_latent_size is not None: + assert args.moe_latent_size > 0, "MoE latent projection dimension has to be greater than zero." + assert args.num_experts is not None, "MoE latent projections are applicable only for MoE models." + assert not args.use_legacy_models, "MoE latent projections are only supported for mcore models." 
+ assert not args.moe_use_legacy_grouped_gemm, "MoE latent projection is not supported yet with legacy grouped GEMM." + if args.tiktoken_special_tokens and not args.tokenizer_special_tokens: warn_rank_0( "--tiktoken-special-tokens argument is deprecated and will be removed soon. " @@ -1331,6 +1338,8 @@ def core_transformer_config_from_args(args, config_class=None): kw_args['use_kitchen'] = True kw_args['quant_recipe'] = kitchen_quantization_recipe_config(args.kitchen_recipe_number) + kw_args['moe_latent_size'] = args.moe_latent_size + if args.te_precision_config_file: assert not 'quant_recipe' in kw_args, "Quantization recipe already configured." # TODO(kwyss): Prohibit fp8_params or fp4_params with this flexibility @@ -1719,6 +1728,8 @@ def _add_network_size_args(parser): 'We compute the average of the MTP losses across all depths, ' 'and multiply it the scaling factor to obtain the overall MTP loss, ' 'which serves as an additional training objective.') + group.add_argument('--moe-latent-size', type=int, default=None, + help='Latent projection dimension for MoE. If None, MoE latent projections are not used.') return parser diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index eb23e7cc092..d08aac734dc 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -1341,6 +1341,9 @@ def _set_arg(arg_name, old_arg_name=None, force=False): _set_arg('heterogeneous_layers_config_path', force=True) _set_arg('heterogeneous_layers_config_encoded_json', force=True) + # MoE latent projection + _set_arg('moe_latent_size', force=True) + # Tokenizer args. _set_arg('tokenizer_type', force=True) # Using checkpoint version might not always be safe (e.g., if running on different cluster). 
diff --git a/megatron/training/training.py b/megatron/training/training.py index eb7a903561b..5e03263a9dd 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -178,11 +178,19 @@ def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=Fals return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2 def moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size, - shared_expert_ffn_hidden_size, num_experts_routed_to, swiglu=False): + shared_expert_ffn_hidden_size, num_experts_routed_to, + moe_latent_size=None, swiglu=False): """Calculate FLOPs for an MoE layer.""" scale_factor = 3.0 / 2.0 if swiglu else 1.0 - routed_flops = (4 * batch_size * seq_len * hidden_size * - moe_ffn_hidden_size * num_experts_routed_to * scale_factor) + if moe_latent_size is None: + routed_flops = (4 * batch_size * seq_len * hidden_size * + moe_ffn_hidden_size * num_experts_routed_to * scale_factor) + else: + # Routed experts run on moe_latent_size. + routed_flops = (4 * batch_size * seq_len * moe_latent_size * + moe_ffn_hidden_size * num_experts_routed_to * scale_factor) + # Up proj and down proj. 
+ routed_flops += (4 * batch_size * seq_len * hidden_size * moe_latent_size) shared_flops = 4 * batch_size * seq_len * hidden_size * shared_expert_ffn_hidden_size * scale_factor return routed_flops + shared_flops @@ -230,6 +238,7 @@ def hybrid_flops(batch_size, seq_len, hidden_size, num_attn_heads=32, gqa=True, gqa_groups=8, kv_channels=None, mlp_expansion=4.0, swiglu=False, + moe_latent_size=None, moe_ffn_hidden_size=2048, shared_expert_ffn_hidden_size=2048, num_experts_routed_to=1, vocab_size=256000): """Calculate total FLOPs for the hybrid model.""" @@ -242,7 +251,8 @@ def hybrid_flops(batch_size, seq_len, hidden_size, mamba_state_dim, mamba_head_dim, mamba_num_groups, mamba_num_heads) + num_moe_layers * moe_layer_flops(batch_size, seq_len, hidden_size, moe_ffn_hidden_size, - shared_expert_ffn_hidden_size, num_experts_routed_to, swiglu) + + shared_expert_ffn_hidden_size, num_experts_routed_to, + moe_latent_size, swiglu) + (2 * batch_size * seq_len * hidden_size * vocab_size) # logits computation ) return flops_fwd * 3 @@ -447,6 +457,7 @@ def transformer_flops(): kv_channels=args.kv_channels, mlp_expansion=args.ffn_hidden_size / args.hidden_size, swiglu=args.swiglu, + moe_latent_size=args.moe_latent_size, moe_ffn_hidden_size=(args.moe_ffn_hidden_size if args.moe_ffn_hidden_size is not None else args.ffn_hidden_size), shared_expert_ffn_hidden_size=(0 if args.moe_shared_expert_intermediate_size is None diff --git a/scripts/check_api_backwards_compatibility.py b/scripts/check_api_backwards_compatibility.py index 4977b806433..3c66f00b619 100644 --- a/scripts/check_api_backwards_compatibility.py +++ b/scripts/check_api_backwards_compatibility.py @@ -46,13 +46,22 @@ # Decorators that exempt objects from compatibility checks EXEMPT_DECORATORS = ['internal_api', 'deprecated', 'experimental_api'] -# Breakage kinds to ignore (not actual API signature changes) +# Breakage kinds to ignore globally (not actual API signature changes) # AttributeChangedValueBreakage: Changing 
constant values (e.g., VERSION = "1.0" -> "2.0") # is not a breaking API change - the constant still exists with the same name IGNORED_BREAKAGE_KINDS = [ 'AttributeChangedValueBreakage', ] +# Breakage kinds to ignore only for __init__ methods +# ParameterMovedBreakage: Reordering parameters in __init__ is generally safe because: +# - Config dataclasses should always be initialized with keyword arguments +# - Adding fields to parent dataclasses shifts child __init__ params (inheritance artifact) +# - Nobody should call Config(4096, 32, ...) with positional args +IGNORED_FOR_INIT_METHODS = [ + 'ParameterMovedBreakage', +] + def has_exempt_decorator(obj: Object) -> bool: """Check if a Griffe object has any exempt decorator. @@ -217,6 +226,7 @@ def should_skip_change(change, filtered_paths: set) -> bool: A change is skipped if: - The change kind is in IGNORED_BREAKAGE_KINDS (not a signature change) + - The change kind is in IGNORED_FOR_INIT_METHODS and affects an __init__ method - The changed object itself is in filtered_paths (exact match) - The changed object is a child of an exempt object (prefix match) @@ -227,7 +237,7 @@ def should_skip_change(change, filtered_paths: set) -> bool: Returns: bool: True if the change should be skipped (filtered out) """ - # Check if this breakage kind should be ignored (not a signature change) + # Check if this breakage kind should be ignored globally (not a signature change) change_kind = type(change).__name__ if change_kind in IGNORED_BREAKAGE_KINDS: return True @@ -240,6 +250,12 @@ def should_skip_change(change, filtered_paths: set) -> bool: # e.g., "Class.__init__(param)" -> "Class.__init__" clean_path = path.split('(')[0] if '(' in path else path + # Check if this is a breakage kind we ignore for __init__ methods + # Config dataclasses should use keyword args, so parameter reordering is safe + if change_kind in IGNORED_FOR_INIT_METHODS: + if '.__init__' in clean_path: + return True + # Check exact match if clean_path in 
filtered_paths or path in filtered_paths: return True