diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 40a0d5c16..b8193d909 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -10,6 +10,7 @@ Model Optimizer Changelog (Linux) **New Features** +- Add MoE (e.g. Qwen3-30B-A3B) pruning support for ``num_moe_experts``, ``moe_ffn_hidden_size`` and ``moe_shared_expert_intermediate_size`` parameters in Minitron pruning (``mcore_minitron``). - Add ``specdec_bench`` example to benchmark speculative decoding performance. See `examples/specdec_bench/README.md `_ for more details. 0.39 (2025-11-14) diff --git a/docs/source/guides/7_nas.rst b/docs/source/guides/7_nas.rst index 98d2b9729..8cd22d214 100644 --- a/docs/source/guides/7_nas.rst +++ b/docs/source/guides/7_nas.rst @@ -361,9 +361,12 @@ can be converted into searchable units: # search over the number of layers (depth) in the sequential layer. nn.Sequential - # We convert Megatron-core / NeMo GPT or Mamba style models (e.g. Llama3.1, NeMo Mistral, NeMotron-H, etc.) - # to automatically search over the MLP hidden size, number of attention heads, number of GQA groups, - # number of mamba heads, mamba head dimension, and depth of the model. + # We convert Megatron-core / NeMo GPT or MoE or Mamba Hybrid style models (e.g. Llama3, Nemotron-H, Qwen3-30B-A3B) + # to automatically search over the + # MLP hidden size, number of attention heads, number of GQA groups, + # number of mamba heads, mamba head dimension, + # number of moe experts, moe ffn hidden size, moe shared expert intermediate size, + # and depth of the model. megatron.core.models.gpt.GPTModel megatron.core.models.mamba.MambaModel nemo.collections.llm.gpt.model.base.GPTModel @@ -640,7 +643,7 @@ The difference between NAS and pruning is summarized below. [Advanced] Adding a new NAS/Prune Algorithm =========================================== -* Please refer to this `template `_ +* Please refer to this `template `_ for adding a new NAS algorithm. * Please refer to `mcore_minitron.py `_ for an actual example of adding Minitron Pruning algorithm. \ No newline at end of file diff --git a/examples/megatron-lm/README.md b/examples/megatron-lm/README.md index 6f88fad9d..254073675 100644 --- a/examples/megatron-lm/README.md +++ b/examples/megatron-lm/README.md @@ -20,7 +20,7 @@ | Model | Quantization | EAGLE3 | Q-LoRA | Pruning (PP only) | Distillation | | :---: | :---: | :---: | :---: | :---: | :---: | | `moonshotai/Kimi-K2-Instruct` | ✅ | **Online** | | | | -| `Qwen/Qwen3-{30B-A3B, 235B-A22B}` | **WAR** | **Online** | | | | +| `Qwen/Qwen3-{30B-A3B, 235B-A22B}` | **WAR** | **Online** | | ✅ | ✅ | | `Qwen/Qwen3-{0.6B, 8B}` | ✅ | **Online** | | ✅ | ✅ | | `deepseek-ai/DeepSeek-R1` | ✅ | **Online** | | | | | `meta-llama/Llama-{3.1-8B, 3.1-405B, 3.2-1B}-Instruct` | ✅ | **Online** | | ✅ | ✅ | @@ -112,7 +112,7 @@ Coming soon ... Checkout pruning [getting started section](../pruning/README.md#getting-started) and [guidelines](../pruning/README.md#pruning-guidelines) for configuring pruning parameters in the pruning README. -Pruning is supported for GPT and Mamba models in Pipeline Parallel mode. Available pruning options are: +Pruning is supported for GPT and Mamba models in Pipeline Parallel mode. Available pruning dimensions are: - `TARGET_FFN_HIDDEN_SIZE` - `TARGET_HIDDEN_SIZE` @@ -120,6 +120,9 @@ Pruning is supported for GPT and Mamba models in Pipeline Parallel mode. 
Availab - `TARGET_NUM_QUERY_GROUPS` - `TARGET_MAMBA_NUM_HEADS` - `TARGET_MAMBA_HEAD_DIM` +- `TARGET_NUM_MOE_EXPERTS` +- `TARGET_MOE_FFN_HIDDEN_SIZE` +- `TARGET_MOE_SHARED_EXPERT_INTERMEDIATE_SIZE` - `TARGET_NUM_LAYERS` - `LAYERS_TO_DROP` (comma separated, 1-indexed list of layer numbers to directly drop) @@ -137,6 +140,10 @@ bash Megatron-LM/examples/post_training/modelopt/prune.sh qwen/Qwen3-8B > If number of layers in the model is not divisible by pipeline parallel size (PP), you can configure uneven > PP by setting `MLM_EXTRA_ARGS="--decoder-first-pipeline-num-layers --decoder-last-pipeline-num-layers "` +> [!TIP] +> You can reuse pruning scores for pruning same model again to different architectures by setting +> `PRUNE_ARGS="--pruning-scores-path "` + ## Learn More About Configuration For simplicity, we use `shell` scripts and variables as arguments. Each script has at least 1 positional diff --git a/examples/pruning/README.md b/examples/pruning/README.md index 3efa9eb79..7ae42850b 100644 --- a/examples/pruning/README.md +++ b/examples/pruning/README.md @@ -6,7 +6,7 @@ Pruning can involve removal (prune) of Linear and Conv layers, and Transformer a This section focuses on applying Model Optimizer's state-of-the-art complementary pruning modes to enable you to search for the best subnet architecture from your provided base model: -1. [Minitron](https://arxiv.org/pdf/2408.11796): A pruning method developed by NVIDIA Research for pruning GPT, Mamba and Hybrid Transformer Mamba models in NVIDIA NeMo or Megatron-LM framework. It uses the activation magnitudes to prune the embedding hidden size, mlp ffn hidden size, transformer attention heads, GQA query groups, mamba heads and head dimension, and number of layers of the model. +1. [Minitron](https://arxiv.org/pdf/2408.11796): A pruning method developed by NVIDIA Research for pruning GPT, Mamba and Hybrid Transformer Mamba models in NVIDIA NeMo or Megatron-LM framework. It uses the activation magnitudes to prune the embedding hidden size; mlp ffn hidden size; transformer attention heads and GQA query groups; mamba heads and head dimension; MoE number of experts, ffn hidden size, and shared expert intermediate size; and number of layers of the model. 1. FastNAS: A pruning method recommended for Computer Vision models. Given a pretrained model, FastNAS finds the subnet which maximizes the score function while meeting the given constraints. 1. GradNAS: A light-weight pruning method recommended for language models like Hugging Face BERT, GPT-J. It uses the gradient information to prune the model's linear layers and attention heads to meet the given constraints. 
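A minimal sketch of how the new MoE dimensions plug into the Minitron pruning API (`modelopt.torch.prune`), mirroring the `mtp.prune` call exercised by the tests added later in this change; the model constructor, calibration loop, and target sizes below are illustrative placeholders, not part of this diff:

```python
import modelopt.torch.prune as mtp

# Placeholder: a Megatron-core GPT/MoE model (TP=1, pipeline-parallel) and a small
# calibration loop -- both are assumptions for illustration, not defined in this change.
model = build_mcore_moe_model()  # hypothetical helper

def forward_loop(m):
    for batch in calib_dataloader:  # hypothetical calibration dataloader
        m(batch)

# Width-prune the MoE dimensions (alongside the existing hparams) via mcore_minitron.
mtp.prune(
    model,
    mode="mcore_minitron",
    constraints={
        "export_config": {
            "num_moe_experts": 4,                         # drop the least-important experts
            "moe_ffn_hidden_size": 768,                   # per-expert FFN width
            "moe_shared_expert_intermediate_size": 2048,  # shared-expert width
        }
    },
    dummy_input=None,  # Not used
    config={"forward_loop": forward_loop, "scores_path": "minitron_scores.pth"},
)
```

The `scores_path` entry serves the same purpose as the new `--pruning-scores-path` option in the Megatron-LM example: activation scores collected once can be reused to re-prune the same model to a different target architecture without rerunning the forward loop.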
@@ -89,11 +89,11 @@ If your model parameters are already sorted, you can skip the sorting step by se | **Algorithm** | **Model** | **Pruning Constraints** | | :---: | :---: | :---: | -| Minitron | Megatron-core / NeMo based GPT / Mamba / Hybrid Models1 | Export config with width (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `num_query_groups`, `mamba_num_heads`, `mamba_head_dim`) and/or depth (`num_layers`) values | +| Minitron | Megatron-core / NeMo based GPT / Mamba / MoE / Hybrid Models1 | Export config with width (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `num_query_groups`, `mamba_num_heads`, `mamba_head_dim`, `num_moe_experts`, `moe_ffn_hidden_size`, `moe_shared_expert_intermediate_size`) and/or depth (`num_layers`) values | | FastNAS | Computer Vision models | flops, parameters | | GradNAS | HuggingFace BERT, GPT-J | flops, parameters | -> *1.Only Pipeline Parallel models are supported. Hugging Face models can be converted to NeMo format and used subsequently.* +> *1.Only Pipeline Parallel models are supported. Hugging Face models can be converted to Megatron-LM/NeMo format and used subsequently.* ## Pruning Guidelines @@ -122,7 +122,7 @@ Depth pruning reduces the number of layers (`num_layers`) in the model. #### Width Pruning -Width pruning reduces model dimensions per layer such as `hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `num_query_groups`, `mamba_num_heads`, and `mamba_head_dim`. +Width pruning reduces model dimensions per layer such as `hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `num_query_groups`, `mamba_num_heads`, `mamba_head_dim`, `num_moe_experts`, `moe_ffn_hidden_size`, and `moe_shared_expert_intermediate_size`. **Advantages:** diff --git a/modelopt/torch/nas/modules/container.py b/modelopt/torch/nas/modules/container.py index 7615f36f1..f001e16af 100644 --- a/modelopt/torch/nas/modules/container.py +++ b/modelopt/torch/nas/modules/container.py @@ -26,7 +26,7 @@ from ..registry import DMRegistry from ..traced_hp import TracedHp -__all__ = ["_DynamicSequential"] +__all__ = ["DynamicModuleList", "_DynamicSequential"] def _activate_depth(func: Callable) -> Callable: @@ -97,3 +97,35 @@ def modify(self, *, min_depth: int = 0): """ hp = self.get_hparam("depth") hp.choices = [d for d in hp.choices if d >= min_depth] + + +# NOTE: We provide a parent class since we do not register to DMRegistry and explicitly convert a module if needed. +class DynamicModuleList(DynamicModule, nn.ModuleList): + """An ``nn.ModuleList`` container with dynamic hyperparams and variable ``depth``. + + Unlike _DynamicSequential, this module supports sorting/reordering of modules based on + importance in addition to variable depth. 
+ """ + + def _setup(self): + # register hyperparameters + self._register_hparam("depth", TracedHp(list(range(1, len(self) + 1)))) + + # register _modules as a dynamic attribute + self._register_dynamic_attribute("_modules", self._get_modules) + + @staticmethod + def _get_modules(mod: "DynamicModuleList", modules: dict) -> dict: + """Get modules with dynamic depth and ordering applied based on active_slice.""" + hp = mod.get_hparam("depth") + active_slice = hp.active_slice + + items = list(modules.items()) + + if isinstance(active_slice, slice): + active_items = items[active_slice] + else: + active_items = [items[idx] for idx in active_slice.tolist()] + + # Re-create dict with keys as str(index) from 0 to len(active_items) + return {str(i): module for i, (_, module) in enumerate(active_items)} diff --git a/modelopt/torch/nas/plugins/megatron.py b/modelopt/torch/nas/plugins/megatron.py index a58ef2140..988bbad9c 100644 --- a/modelopt/torch/nas/plugins/megatron.py +++ b/modelopt/torch/nas/plugins/megatron.py @@ -16,6 +16,7 @@ """Plugin to add NAS/Pruning support for megatron-core Language models like GPT and Mamba.""" import types +from abc import ABC from collections.abc import Callable, Sequence from typing import Any @@ -46,8 +47,14 @@ from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.dot_product_attention import DotProductAttention from megatron.core.transformer.mlp import MLP +from megatron.core.transformer.moe import moe_utils +from megatron.core.transformer.moe.experts import SequentialMLP +from megatron.core.transformer.moe.moe_layer import MoELayer +from megatron.core.transformer.moe.router import TopKRouter +from megatron.core.transformer.moe.shared_experts import SharedExpertMLP from megatron.core.transformer.transformer_layer import TransformerLayer +from modelopt.torch.nas.modules import DynamicModuleList from modelopt.torch.opt.dynamic import DynamicModule from modelopt.torch.opt.hparam import HPType from modelopt.torch.opt.searcher import ConstraintsDict @@ -99,6 +106,7 @@ __all__ = ["drop_mcore_language_model_layers"] +# TODO: Allow passing setup_kwargs to DM.convert so we can reuse hparams directly during setup class _DynamicParallelLinear(DynamicModule): """A parallel linear layer with dynamic hyperparams.""" @@ -126,7 +134,7 @@ def _get_bias(mod: "_DynamicParallelLinear", bias: torch.Tensor | None) -> torch {ColumnParallelLinear: "megatron.core.tensor_parallel.layers.ColumnParallelLinear"} ) class _DynamicColumnParallelLinear(_DynamicParallelLinear): - """A ``megatron.core.tensor_parallel.layers.ColumnParallelLinear`` layer with dynamic hyperparams.""" + """A ColumnParallelLinear layer with dynamic hyperparams.""" def _setup(self): super()._setup() @@ -137,7 +145,7 @@ def _setup(self): @DMRegistry.register({RowParallelLinear: "megatron.core.tensor_parallel.layers.RowParallelLinear"}) class _DynamicRowParallelLinear(_DynamicParallelLinear): - """A ``megatron.core.tensor_parallel.layers.RowParallelLinear`` layer with dynamic hyperparams.""" + """A RowParallelLinear layer with dynamic hyperparams.""" def _setup(self): super()._setup() @@ -150,7 +158,7 @@ def _setup(self): {VocabParallelEmbedding: "megatron.core.tensor_parallel.layers.VocabParallelEmbedding"} ) class _DynamicVocabParallelEmbedding(DynamicModule): - """A ``megatron.core.tensor_parallel.layers.VocabParallelEmbedding`` layer with dynamic hyperparams.""" + """A VocabParallelEmbedding layer with dynamic hyperparams.""" def _setup(self): 
self._register_hparam("embedding_dim", TracedHp(list(range(1, self.embedding_dim + 1)))) @@ -164,7 +172,7 @@ def _get_weight(mod: "_DynamicVocabParallelEmbedding", weight: torch.Tensor) -> @DMRegistry.register({FusedLayerNorm: "megatron.core.fusions.fused_layer_norm.FusedLayerNorm"}) class _DynamicFusedLayerNorm(_DynamicLayerNorm): - """A ``megatron.core.fusions.fused_layer_norm.FusedLayerNorm`` layer with dynamic hyperparams.""" + """A FusedLayerNorm layer with dynamic hyperparams.""" def _setup(self): # construct hidden_size with Hparam as last dimension @@ -180,15 +188,29 @@ def _setup(self): self._register_dynamic_attribute("hidden_size", self._get_normalized_shape) -@DMRegistry.register({MLP: "megatron.core.transformer.mlp.MLP"}) +# MLP DynamicModule ################################################################################ +@DMRegistry.register( + { + MLP: "megatron.core.transformer.mlp.MLP", + SharedExpertMLP: "megatron.core.transformer.moe.shared_experts.SharedExpertMLP", + } +) class _DynamicMLP(DynamicModule): - """A ``megatron.core.transformer.mlp.MLP`` layer with dynamic hyperparams.""" + """An MLP layer with dynamic hyperparams. + + Use for standard MLP and inside MoE layers (SequentialMLP and SharedExpertMLP). + """ def _setup(self): assert self.input_size == self.config.hidden_size, ( "MLP input_size must be equal to hidden_size" ) - + if isinstance(self, SharedExpertMLP): + self.hparam_name = "moe_shared_expert_intermediate_size" + elif self.config.num_moe_experts is not None: + self.hparam_name = "moe_ffn_hidden_size" + else: + self.hparam_name = "ffn_hidden_size" self.linear_fc1 = DMRegistry.convert(self.linear_fc1) self.linear_fc2 = DMRegistry.convert(self.linear_fc2) @@ -199,7 +221,7 @@ def _setup(self): else ffn_hidden_size ) - self._register_hparam("ffn_hidden_size", ffn_hidden_size) + self._register_hparam(self.hparam_name, ffn_hidden_size) self.linear_fc1.output_size = fc1_output_size self.linear_fc2.input_size = ffn_hidden_size @@ -223,9 +245,11 @@ def _linear_fc2_forward_hook(self, module, input, output): # Gather input [seq_len, batch_size, ffn_hidden_size] over all TP regions # NOTE: This is not used at the moment since we restrict to TP=1 input = gather_from_tensor_model_parallel_region(input[0]).detach() - + if input.dim() == 2: + # For sparse experts, there is no batch dimension. + input = input[:, None, :] # Dont aggregate activations from non-max subnets (e.g. 
from profiling) - if input.shape[-1] != self.get_hparam("ffn_hidden_size").max: + if input.shape[-1] != self.get_hparam(self.hparam_name).max: return input = input.to(torch.float32) # use full precision to avoid overflow @@ -242,15 +266,26 @@ def _estimate_importance(self) -> TracedHp.Importance: # Convert squared sum to L2 norm return self._activations.pow(0.5) + def set_hidden_size_hp(self, hidden_size: TracedHp) -> None: + """Set hidden size for shared expert.""" + self.linear_fc1.input_size = hidden_size + self.linear_fc2.output_size = hidden_size + + def modify(self, ffn_hidden_size_divisor: int, **kwargs) -> None: + """Modify the ffn_hidden_size hparam choices based on search space config.""" + hp_mlp = self.get_hparam(self.hparam_name) + choices = {int(make_divisible(c, ffn_hidden_size_divisor)) for c in hp_mlp.choices} # type: ignore[arg-type] + hp_mlp.choices = list(set(hp_mlp.choices) & choices | {hp_mlp.original}) + def export(self) -> torch.nn.Module: """Export the dynamic module to a torch.nn.Module.""" self.hook_handle.remove() self.linear_fc1.export() self.linear_fc2.export() - super().export() - return self + return super().export() +# SelfAttention DynamicModules ##################################################################### def expand_head_indices(heads: torch.LongTensor, hidden_size_per_head: int) -> torch.LongTensor: """Expand each head index to hidden_size_per_head indices and offset by head * hidden_size_per_head.""" return ( @@ -464,7 +499,7 @@ def _get_bias( @DMRegistry.register({SelfAttention: "megatron.core.transformer.attention.SelfAttention"}) class _DynamicSelfAttention(DynamicModule): - """A ``megatron.core.transformer.attention.SelfAttention`` layer with dynamic hyperparams. + """A SelfAttention layer with dynamic hyperparams. 
NOTE: Layernorms apply on hidden_size_per_attention_head hence no need to convert to dynamic """ @@ -617,10 +652,204 @@ def export(self) -> torch.nn.Module: self.core_attention.export() self.linear_qkv.export() self.linear_proj.export() - super().export() - return self + return super().export() + + +# MoE DynamicModules ############################################################################### +# Add ABC to avoid TypeError: object layout differs (because parent if TopKRouter inherits from ABC) +@DMRegistry.register({TopKRouter: "megatron.core.transformer.moe.router.TopKRouter"}) +class _DynamicTopKRouter(ABC, DynamicModule): + """A TopKRouter with dynamic hyperparams.""" + + def _setup(self): + # Register hparams for router weight dimensions (will be overridden by _DynamicSequentialMLP's hp) + # Router weight shape: [num_moe_experts, hidden_size] + self._register_hparam("num_experts", TracedHp(list(range(1, self.weight.shape[0] + 1)))) + # Register hidden_size reference (will be overridden by _DynamicMoELayer's hidden_size) + self._register_hparam("hidden_size", TracedHp(list(range(1, self.weight.shape[1] + 1)))) + + # Register dynamic attributes + self._register_dynamic_attribute("weight", self._get_router_weight) + if self.enable_expert_bias: + self._register_dynamic_attribute( + "local_tokens_per_expert", self._get_slice_by_num_experts + ) + self._register_dynamic_attribute("expert_bias", self._get_slice_by_num_experts) + @staticmethod + def _get_router_weight(mod: "_DynamicTopKRouter", weight: torch.Tensor) -> torch.Tensor: + return get_sliced_tensor(mod, weight, "num_experts", "hidden_size") + + @staticmethod + def _get_slice_by_num_experts(mod: "_DynamicTopKRouter", bias: torch.Tensor) -> torch.Tensor: + return get_sliced_tensor(mod, bias, "num_experts") + + def set_hidden_size_hp(self, hidden_size: TracedHp) -> None: + """Set hidden_size hparam for router weights from global hidden_size hparam.""" + self.hidden_size = hidden_size + + +@DMRegistry.register({SequentialMLP: "megatron.core.transformer.moe.experts.SequentialMLP"}) +class _DynamicSequentialMLP(DynamicModule): + """A SequentialMLP with dynamic hyperparams.""" + + def _setup(self): + # Register hparam for number of active experts (will be shared with _DynamicTopKRouter's hp) + num_moe_experts = TracedHp(list(range(1, self.num_local_experts + 1))) + self._register_hparam("num_local_experts", num_moe_experts) + + # Convert local_experts list and each individual expert MLP to dynamic modules + self.local_experts = DynamicModuleList.convert(self.local_experts) + self.local_experts.depth = num_moe_experts # Reuse same hparam for depth + for i in range(len(self.local_experts)): + self.local_experts[i] = DMRegistry.convert(self.local_experts[i]) + + # Track forward activations for importance estimation. + # _activations name is needed for get_activations_and_layer_scores to save scores for re-running pruning. 
+ self._register_temp_attribute( + "_activations", + { + "expert_l2_scores": torch.zeros(self.num_local_experts), + "expert_sample_counts": torch.zeros(self.num_local_experts), + }, + ) + self.hook_handle = self.register_forward_hook(self._expert_l2_imp_forward_hook) + num_moe_experts.register_importance(self._estimate_expert_importance) + + def set_hidden_size_hp(self, hidden_size: TracedHp) -> None: + """Set hidden_size hparam for all expert MLPs from global hidden_size hparam.""" + for expert in self.local_experts: + expert.set_hidden_size_hp(hidden_size) + def _expert_l2_imp_forward_hook(self, module, input, output): + """Track expert importance based on L2 norms of expert outputs.""" + # Dont aggregate activations from non-max subnets (e.g. from profiling) + num_moe_experts = self.get_hparam("num_local_experts") + if num_moe_experts.active != num_moe_experts.max: + return + + # Split output back to per-expert outputs using torch.split + tokens_per_expert_list = input[1].tolist() + # use full precision to avoid overflow + output_local = output[0].to(torch.float32).detach() + + output_local_list = torch.split(output_local, tokens_per_expert_list) + + # Compute L2 norm for each expert's output + for expert_idx, expert_output in enumerate(output_local_list): + # Guard: if expert_output is empty tensor, add zero score + if expert_output.numel() == 0: + l2_norm = 0.0 + else: + # Compute L2 norm of expert output (router_prob * expert_output) + l2_norm = torch.linalg.vector_norm(expert_output, ord=2, dim=-1).sum().item() + + # Accumulate L2 scores and sample counts + self._activations["expert_l2_scores"][expert_idx] += l2_norm + self._activations["expert_sample_counts"][expert_idx] += tokens_per_expert_list[ + expert_idx + ] + + def _estimate_expert_importance(self) -> TracedHp.Importance: + """Estimate expert importance based on accumulated L2 norms.""" + assert self._activations["expert_sample_counts"].sum() > 0, ( + "No activations collected for importance estimation." + ) + # Average L2 scores across samples (avoid division by zero if some experts have no samples) + return self._activations["expert_l2_scores"] / ( + self._activations["expert_sample_counts"] + 1e-8 + ) + + def export(self) -> torch.nn.Module: + """Export the dynamic module to a standard SequentialMLP.""" + self.hook_handle.remove() + for expert in self.local_experts: + expert.export() + self.local_experts.export() + return super().export() + + +@DMRegistry.register({MoELayer: "megatron.core.transformer.moe.moe_layer.MoELayer"}) +class _DynamicMoELayer(DynamicModule): + """A MoELayer with dynamic hyperparams.""" + + def _setup(self): + # Convert to dynamic modules + # Reuse _DynamicSequentialMLP's num_moe_experts hparam for _DynamicTopKRouter's hparam so + # importance estimator and depth hparam is retained. 
+ self.router = DMRegistry.convert(self.router) + self.experts = DMRegistry.convert(self.experts) + num_moe_experts_hp = self.experts.get_hparam("num_local_experts") + + # NOTE: Use num_moe_experts hparam name in top-level module to match TransformerConfig's name + self._register_hparam("num_moe_experts", num_moe_experts_hp) + self._register_dynamic_attribute( + "num_local_experts", + lambda mod, val: num_moe_experts_hp.active, # EP = 1 + ) + self.router.num_experts = num_moe_experts_hp + if self.use_shared_expert: + self.shared_experts = DMRegistry.convert(self.shared_experts) + + def forward(self, *args, **kwargs): + """Forward pass for the MoE layer.""" + # Dont allow forward if model is sorted / trimmed unless exported (reinitializing token dispatcher correctly) + if isinstance(self, DynamicModule) and ( + self.get_hparam("num_moe_experts")._slice_order is not None + or self.get_hparam("num_moe_experts").active != self.get_hparam("num_moe_experts").max + ): + raise RuntimeError("Only run forward after exporting the pruned model") + return super().forward(*args, **kwargs) + + def set_hidden_size_hp(self, hidden_size: TracedHp) -> None: + """Set hidden size for all MoE components from global hidden_size hparam.""" + self.router.set_hidden_size_hp(hidden_size) + self.experts.set_hidden_size_hp(hidden_size) + if self.use_shared_expert: + self.shared_experts.set_hidden_size_hp(hidden_size) + + def modify( + self, *, num_moe_experts_divisor: int = 1, ffn_hidden_size_divisor: int = 1, **kwargs + ): + """Modify MoE hparam choices based on search space config.""" + # Modify num_moe_experts hparam choices (applies to both router and experts) + expert_hp = self.get_hparam("num_moe_experts") + choices = {int(make_divisible(c, num_moe_experts_divisor)) for c in expert_hp.choices} # type: ignore[arg-type] + expert_hp.choices = list(set(expert_hp.choices) & choices | {expert_hp.original}) + + # Modify expert FFN hparam choices + for expert in self.experts.local_experts: + expert.modify(ffn_hidden_size_divisor=ffn_hidden_size_divisor) + if self.use_shared_expert: + self.shared_experts.modify(ffn_hidden_size_divisor) + + def _export_reinit_token_dispatcher(self) -> None: + """Reinitialize the token dispatcher after pruning.""" + print_rank_0("Reinitializing token dispatcher after pruning") + if hasattr(moe_utils, "get_default_model_comm_pgs"): + model_comm_pgs = moe_utils.get_default_model_comm_pgs() + else: + model_comm_pgs = moe_utils.get_default_pg_collection() + # NOTE: Update config.num_moe_experts for correct router initialization. 
+ self.config.num_moe_experts = self.num_moe_experts + self.token_dispatcher = type(self.token_dispatcher)( + self.num_local_experts, list(range(self.num_local_experts)), self.config, model_comm_pgs + ) + + if self.use_shared_expert and self.shared_expert_overlap: + self.token_dispatcher.set_shared_experts(self.shared_experts) + + def export(self) -> torch.nn.Module: + """Export the dynamic module to a standard MoELayer.""" + self.router.export() + self.experts.export() + if self.use_shared_expert: + self.shared_experts.export() + self._export_reinit_token_dispatcher() + return super().export() + + +# TransformerLayer DynamicModule ################################################################### class MambaTransformerLayerMixin(nn.Module): """A mixin for MambaLayer and TransformerLayer to share the same logic.""" @@ -663,15 +892,16 @@ def _layer_imp_forward_hook(self, module, args, kwargs, output) -> None: {TransformerLayer: "megatron.core.transformer.transformer_layer.TransformerLayer"} ) class _DynamicTransformerLayer(DynamicModule, MambaTransformerLayerMixin): - """A ``megatron.core.transformer.transformer_layer.TransformerLayer`` layer with dynamic hyperparams.""" + """A TransformerLayer layer with dynamic hyperparams.""" def _setup(self): - # Convert the layernorms, self-attention, and mlp layers to dynamic modules + # Convert the layernorms, self-attention, and mlp/moe layers to dynamic modules # NOTE: Mamba stack layers have either Attention or MLP, not both unlike GPT models if isinstance(self.self_attention, SelfAttention): self.input_layernorm = DMRegistry.convert(self.input_layernorm) self.self_attention = DMRegistry.convert(self.self_attention) - if isinstance(self.mlp, MLP): + + if isinstance(self.mlp, (MLP, MoELayer)): self.pre_mlp_layernorm = DMRegistry.convert(self.pre_mlp_layernorm) self.mlp = DMRegistry.convert(self.mlp) @@ -683,10 +913,10 @@ def set_hidden_size_hp(self, hidden_size: TracedHp) -> None: self.input_layernorm.num_features = hidden_size self.self_attention.linear_qkv.input_size = hidden_size self.self_attention.linear_proj.output_size = hidden_size - if isinstance(self.mlp, MLP): + + if isinstance(self.mlp, (MLP, MoELayer)): self.pre_mlp_layernorm.num_features = hidden_size - self.mlp.linear_fc1.input_size = hidden_size - self.mlp.linear_fc2.output_size = hidden_size + self.mlp.set_hidden_size_hp(hidden_size) self._register_temp_attribute("max_hidden_size", hidden_size.max) @@ -696,9 +926,11 @@ def modify( num_heads_per_group_divisor: int = 1, num_query_groups_divisor: int = 1, ffn_hidden_size_divisor: int = 1, + num_moe_experts_divisor: int = 1, **kwargs, # Unused hparams ) -> None: - # Modify SelfAttention hparams + """Modify TransformerLayer hparam choices based on search space config.""" + # Modify SelfAttention hparam if isinstance(self.self_attention, SelfAttention): for hp_name, divisor in [ ("num_heads_per_group", num_heads_per_group_divisor), @@ -708,11 +940,12 @@ def modify( choices = {int(make_divisible(c, divisor)) for c in hp.choices} hp.choices = list(set(hp.choices) & choices | {hp.original}) - # Modify MLP hparams - if isinstance(self.mlp, MLP): - hp_mlp = self.mlp.get_hparam("ffn_hidden_size") - choices = {int(make_divisible(c, ffn_hidden_size_divisor)) for c in hp_mlp.choices} - hp_mlp.choices = list(set(hp_mlp.choices) & choices | {hp_mlp.original}) + # Modify MLP hparam (regular or MoE) + if isinstance(self.mlp, (MLP, MoELayer)): + self.mlp.modify( + ffn_hidden_size_divisor=ffn_hidden_size_divisor, + 
num_moe_experts_divisor=num_moe_experts_divisor, + ) def export(self): """Export the dynamic module to a torch.nn.Module.""" @@ -720,23 +953,13 @@ def export(self): if isinstance(self.self_attention, SelfAttention): self.input_layernorm.export() self.self_attention.export() - if isinstance(self.mlp, MLP): + if isinstance(self.mlp, (MLP, MoELayer)): self.pre_mlp_layernorm.export() self.mlp.export() - super().export() - return self - - def freeze(self): - """Freeze the dynamic module.""" - super().freeze() - if isinstance(self.self_attention, SelfAttention): - self.input_layernorm.freeze() - self.self_attention.freeze() - if isinstance(self.mlp, MLP): - self.pre_mlp_layernorm.freeze() - self.mlp.freeze() + return super().export() +# Mamba DynamicModules ############################################################################# class MambaNumHeadsHp(TracedHp): """An hparam for Mamba's num_heads. @@ -829,7 +1052,7 @@ def _resolve_dependencies( class _DynamicExtendedRMSNorm(DynamicModule): - """A ``megatron.core.ssm.mamba_mixer.ExtendedRMSNorm`` (GroupNorm) layer with dynamic hyperparams. + """An ExtendedRMSNorm (GroupNorm) layer with dynamic hyperparams. Very similar to _DynamicGroupNorm but with group_size dynamic attribute instead of num_groups. Will be registered to DMRegistry if Mamba is available. @@ -915,7 +1138,7 @@ def __setattr__(self, name, value): class _DynamicMambaMixer(DynamicModule): - """A ``megatron.core.ssm.mamba_mixer.MambaMixer`` layer with dynamic hyperparams. + """A MambaMixer layer with dynamic hyperparams. Will be registered to DMRegistry if Mamba is available. """ @@ -1076,12 +1299,11 @@ def export(self) -> torch.nn.Module: self.conv1d.export() if self.rmsnorm: self.norm.export() - super().export() - return self + return super().export() class _DynamicMambaLayer(DynamicModule, MambaTransformerLayerMixin): - """A ``megatron.core.ssm.mamba_layer.MambaLayer`` layer with dynamic hyperparams. + """A MambaLayer layer with dynamic hyperparams. Will be registered to DMRegistry if Mamba is available. """ @@ -1122,13 +1344,7 @@ def export(self): self._export_mixin() self.mixer.export() self.norm.export() - super().export() - return self - - def freeze(self): - """Freeze the hyperparameters.""" - self.mixer.freeze() - super().freeze() + return super().export() if HAS_MAMBA: @@ -1145,9 +1361,10 @@ def freeze(self): ) +# GPTModel / MambaModel DynamicModule ############################################################## @DMRegistry.register(SUPPORTED_MODELS) class _DynamicMCoreLanguageModel(DynamicModule): - """A ``megatron.core.models.gpt.GPTModel`` model with dynamic hyperparams.""" + """A GPTModel / MambaModel with dynamic hyperparams.""" def _setup(self): assert self.config.tensor_model_parallel_size == 1, "Only TP=1 is supported." @@ -1207,7 +1424,10 @@ def _setup(self): self._emb_layernorm_forward_hook ) ) - if isinstance(layer.mlp, MLP): + + # Handle both regular MLP and MoE layers + if isinstance(layer.mlp, (MLP, MoELayer)): + # MoE layer - register hook on pre_mlp_layernorm self.hook_handles.append( layer.pre_mlp_layernorm.register_forward_hook( self._emb_layernorm_forward_hook @@ -1262,6 +1482,7 @@ def modify( ffn_hidden_size_divisor: int = 1, mamba_num_heads_divisor: int = 1, mamba_head_dim_divisor: int = 1, + num_moe_experts_divisor: int = 1, ): """Modify the dynamic choices of the module according to provided keyword arguments. @@ -1272,6 +1493,7 @@ def modify( ffn_hidden_size_divisor: The divisor of the mlp ffn_hidden_size. 
mamba_num_heads_divisor: The divisor of the mamba num_heads. mamba_head_dim_divisor: The divisor of the mamba head_dim. + num_moe_experts_divisor: The divisor of the number of MoE experts. """ hp = self.get_hparam("hidden_size") choices = {int(make_divisible(c, hidden_size_divisor)) for c in hp.choices} # type: ignore[arg-type] @@ -1284,6 +1506,7 @@ def modify( ffn_hidden_size_divisor=ffn_hidden_size_divisor, mamba_num_heads_divisor=mamba_num_heads_divisor, mamba_head_dim_divisor=mamba_head_dim_divisor, + num_moe_experts_divisor=num_moe_experts_divisor, ) def _get_layer_scores(self) -> dict[int, torch.Tensor]: @@ -1331,14 +1554,7 @@ def export(self) -> torch.nn.Module: if is_pipeline_last_stage(): getattr(self.decoder, self.final_norm_attr_name).export() self.output_layer.export() - super().export() - return self - - def freeze(self) -> None: - """Freeze the dynamic module.""" - super().freeze() - for layer in self.decoder.layers: - layer.freeze() + return super().export() def get_activations_and_layer_scores( self, diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index 5f94c3175..887fb6ec4 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -60,14 +60,21 @@ from ..pruning import PruneModeRegistry SUPPORTED_HPARAMS = { - # Width pruning + # 1. Width pruning + "hidden_size", + # MLP "ffn_hidden_size", + # Attention "num_attention_heads", "num_query_groups", - "hidden_size", + # Mamba "mamba_num_heads", "mamba_head_dim", - # Depth pruning + # MoE + "moe_ffn_hidden_size", + "moe_shared_expert_intermediate_size", + "num_moe_experts", + # 2. Depth pruning "num_layers", } @@ -144,12 +151,18 @@ def before_search(self) -> None: ) self.hps_to_sort.add("num_heads_per_group") + configurable_hp_names = (SUPPORTED_HPARAMS | {"num_heads_per_group"}) - { + "num_attention_heads" + } for n, hp in named_hparams(self.model, unique=True): hp_name = n.split(".")[-1] - if hp.is_configurable and hp_name in export_config: - assert export_config[hp_name] in hp.choices, ( - f"Invalid choice {export_config[hp_name]} for {n}! Available choices: {hp.choices}" - ) + if hp.is_configurable: + # Make sure configurable hparams are the ones with right names else implementation needs to be fixed! + assert hp_name in configurable_hp_names, f"[ImplError] Invalid hparam {hp_name}!" + if hp_name in export_config: + assert export_config[hp_name] in hp.choices, ( + f"Invalid choice {export_config[hp_name]} for {n}! Available choices: {hp.choices}" + ) hp.reset_choices() # Make sure ConcatHparam choices are updated after modify() def run_search(self) -> None: @@ -218,6 +231,7 @@ def run_search(self) -> None: "num_heads_per_group_divisor": 1, "num_query_groups_divisor": 1, "ffn_hidden_size_divisor": 64, + "num_moe_experts_divisor": 1, }, **( { @@ -228,6 +242,7 @@ def run_search(self) -> None: "ffn_hidden_size_divisor": 64, "mamba_num_heads_divisor": 4, "mamba_head_dim_divisor": 4, + "num_moe_experts_divisor": 1, } } if HAS_MAMBA diff --git a/tests/_test_utils/torch/megatron/models.py b/tests/_test_utils/torch/megatron/models.py index 9607d7cc2..2457bb8b3 100644 --- a/tests/_test_utils/torch/megatron/models.py +++ b/tests/_test_utils/torch/megatron/models.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import textwrap from warnings import warn import torch @@ -30,6 +31,7 @@ ) from megatron.core.models.mamba import MambaModel from megatron.core.parallel_state import is_pipeline_first_stage, is_pipeline_last_stage +from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.transformer_config import TransformerConfig @@ -140,10 +142,13 @@ def get_mcore_gpt_model( normalization: str = "LayerNorm", transformer_impl: str = "modelopt" if HAS_TE else "local", use_cpu_initialization: bool = False, - num_moe_experts: int | None = None, - moe_grouped_gemm: bool = False, bf16: bool = True, use_te: bool = False, + # MoE-specific parameters + moe_grouped_gemm: bool = False, + moe_ffn_hidden_size: int | None = None, + moe_shared_expert_intermediate_size: int | None = None, + num_moe_experts: int | None = None, ) -> GPTModel: assert activation_func in ["swiglu", "squared_relu"] assert normalization in ["LayerNorm", "RMSNorm"] @@ -167,7 +172,6 @@ def squared_relu(x): expert_model_parallel_size=expert_model_parallel_size, expert_tensor_parallel_size=expert_tensor_parallel_size, sequence_parallel=False, - moe_grouped_gemm=moe_grouped_gemm, num_layers=num_layers, num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage, num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage, @@ -175,7 +179,6 @@ def squared_relu(x): num_attention_heads=num_attention_heads, num_query_groups=num_query_groups, ffn_hidden_size=ffn_hidden_size, - num_moe_experts=num_moe_experts, activation_func=squared_relu if activation_func == "squared_relu" else F.silu, normalization=normalization, gated_linear_unit=(activation_func == "swiglu"), @@ -183,6 +186,12 @@ def squared_relu(x): use_cpu_initialization=use_cpu_initialization, pipeline_dtype=torch.bfloat16 if bf16 else torch.float32, bf16=bf16, + # MoE-specific parameters + moe_grouped_gemm=moe_grouped_gemm, + moe_router_dtype="fp32", + moe_ffn_hidden_size=moe_ffn_hidden_size, + moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, + num_moe_experts=num_moe_experts, ) if transformer_impl == "local": @@ -217,11 +226,7 @@ def squared_relu(x): share_embeddings_and_output_weights=False, position_embedding_type="rope", ) - - if bf16: - model = model.to(torch.bfloat16) - - return model + return model.to(torch.bfloat16) if bf16 else model def get_mcore_qwen3_600m( @@ -275,7 +280,7 @@ def get_mcore_qwen3_600m( return model -def get_mcore_mamba_model( +def get_mcore_mamba_hybrid_model( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, initialize_megatron: bool = False, @@ -295,7 +300,19 @@ def get_mcore_mamba_model( mamba_state_dim: int = 32, mamba_head_dim: int = 16, mamba_num_groups: int = 2, + # MoE-specific parameters + skip_moe: bool = False, + moe_ffn_hidden_size: int | None = 64, + moe_shared_expert_intermediate_size: int | None = 32, + num_moe_experts: int | None = 8, ) -> MambaModel: + """Builds a Mamba model with hybrid layer allocation (Mamba, MoE, Attention, MLP blocks). + + Notable Args: + hybrid_override_pattern: The hybrid layer pattern to override with. + If None, a default pattern will be generated. + skip_moe: Whether to skip MoE blocks in default hybrid pattern. 
+ """ assert HAS_MAMBA, "Mamba not installed" if initialize_megatron: @@ -315,18 +332,46 @@ def get_mcore_mamba_model( mamba_state_dim=mamba_state_dim, mamba_head_dim=mamba_head_dim, mamba_num_groups=mamba_num_groups, + num_moe_experts=num_moe_experts, + moe_ffn_hidden_size=moe_ffn_hidden_size, + moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, + moe_router_enable_expert_bias=True, + moe_router_score_function="sigmoid", + add_bias_linear=False, pipeline_dtype=torch.bfloat16 if bf16 else torch.float32, bf16=bf16, ) + if not (skip_moe or "E" in Symbols.VALID): + warn("MoE blocks are not supported in current MambaModel. Skipping MoE blocks.") + skip_moe = True + if hybrid_override_pattern is None: - # Generate pattern by repeating "M*-" and trimming to match num_layers - # For num_layers=3, return "M*-" (Mamba -> Attention -> MLP) - # For num_layers=5, return "M*-M*" (Mamba -> Attention -> MLP -> Mamba -> Attention) - hybrid_override_pattern = ("M*-" * num_layers)[:num_layers] - else: - assert len(hybrid_override_pattern) == num_layers - print(f"Using `{hybrid_override_pattern=}` for building Mamba Model.") + # Generate pattern by repeating base_pattern and trimming to match num_layers + # E.g. for num_layers=3, return "MEM" (Mamba -> MoE -> Mamba) + # E.g. for num_layers=6, return "MEM*M-" (Mamba -> MoE -> Attention -> MoE -> MLP) + base_pattern = "M*M-" if skip_moe else "MEM*M-" + hybrid_override_pattern = (base_pattern * num_layers)[:num_layers] + + # Add | symbols for Pipeline parallelism (required for MCore 0.16+) + # E.g. MEM* with PP2 becomes ME|M* and MEM*M-ME with PP2 becomes MEM*|M-ME + if pipeline_model_parallel_size > 1 and "|" in Symbols.VALID: + if "|" not in hybrid_override_pattern: + assert ( + num_layers_in_first_pipeline_stage is None + and num_layers_in_last_pipeline_stage is None + ), "hybrid_override_pattern with `|` must be provided for uneven PP" + hybrid_override_pattern = "|".join( + textwrap.wrap( + hybrid_override_pattern, + width=num_layers // pipeline_model_parallel_size, + break_long_words=True, + break_on_hyphens=False, + ) + ) + assert hybrid_override_pattern.count("|") == pipeline_model_parallel_size - 1 + assert len(hybrid_override_pattern.replace("|", "")) == num_layers + print(f"Using `{hybrid_override_pattern=}` for building MambaModel") model = MambaModel( config=config, @@ -339,6 +384,4 @@ def get_mcore_mamba_model( share_embeddings_and_output_weights=False, position_embedding_type="none", ) - if bf16: - model = model.to(torch.bfloat16) - return model + return model.to(torch.bfloat16) if bf16 else model diff --git a/tests/_test_utils/torch/misc.py b/tests/_test_utils/torch/misc.py index 594d73897..a4ebe2f18 100644 --- a/tests/_test_utils/torch/misc.py +++ b/tests/_test_utils/torch/misc.py @@ -21,13 +21,21 @@ from modelopt.torch.utils import flatten_tree -def compare_outputs(out1, out2, rtol=1e-5, atol=1e-8): +def compare_outputs(out1, out2, rtol=1e-5, atol=1e-8, debug=False): out1, _ = flatten_tree(out1) out2, _ = flatten_tree(out2) - assert all( - torch.allclose(t1.to(torch.float32), t2.to(torch.float32), rtol, atol) - for t1, t2 in zip(out1, out2) - ) + for i, (t1, t2) in enumerate(zip(out1, out2)): + if debug: + diff = torch.abs(t1 - t2) + print(f"\n{i=}") + print(f"{t1=}") + print(f"{t2=}") + print(f"{diff=}") + print(f"{diff.shape=}") + print(f"{diff.min()=}") + print(f"{diff.max()=}") + print(f"{diff.mean()=}") + assert torch.allclose(t1.to(torch.float32), t2.to(torch.float32), rtol, atol) def set_seed(seed_value=42): diff 
--git a/tests/gpu/torch/export/test_unified_export_megatron.py b/tests/gpu/torch/export/test_unified_export_megatron.py index 2c63fae44..c07c2b565 100644 --- a/tests/gpu/torch/export/test_unified_export_megatron.py +++ b/tests/gpu/torch/export/test_unified_export_megatron.py @@ -61,7 +61,7 @@ def _test_unified_export_megatron(tmp_path, model_type, arch, algo, rank, size): activation_func=activation_func, normalization=normalization, transformer_impl="modelopt", - ) + ).cuda() if algo == "medusa": config = { @@ -150,7 +150,7 @@ def _test_unified_import_megatron(tiny_llama_dir, rank, size): vocab_size=vocab_size, activation_func=activation_func, normalization=normalization, - ) + ).cuda() import_mcore_gpt_from_hf(model, tiny_llama_dir) diff --git a/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py b/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py index 7fdabb22d..67158eac0 100644 --- a/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py +++ b/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py @@ -18,6 +18,7 @@ import pytest import torch from _test_utils.import_helper import skip_if_no_megatron +from _test_utils.torch.misc import compare_outputs skip_if_no_megatron(apex_or_te_required=True) @@ -36,22 +37,25 @@ from megatron.core.transformer.transformer_layer import TransformerLayer import modelopt.torch.nas as mtn +from modelopt.torch.nas.conversion import export_searchspace +from modelopt.torch.nas.modules import DynamicModuleList from modelopt.torch.nas.plugins.megatron import ( _DynamicColumnParallelLinear, _DynamicMCoreLanguageModel, _DynamicMLP, + _DynamicMoELayer, _DynamicProjRowParallelLinear, _DynamicQKVColumnParallelLinear, _DynamicRowParallelLinear, _DynamicSelfAttention, + _DynamicSequentialMLP, + _DynamicTopKRouter, _DynamicTransformerLayer, _DynamicVocabParallelEmbedding, expand_head_indices, ) -from modelopt.torch.nas.registry import DMRegistry from modelopt.torch.opt.utils import named_dynamic_modules, search_space_size from modelopt.torch.prune.plugins.mcore_minitron import _convert_model_to_dynamic_space -from modelopt.torch.utils import flatten_tree from modelopt.torch.utils.random import centroid SEED = 1234 @@ -82,7 +86,7 @@ def _test_gpt_search_space( vocab_size=vocab_size, activation_func=activation_func, normalization=normalization, - ) + ).cuda() model = mtn.convert(model, "mcore_minitron") @@ -171,7 +175,7 @@ def _test_gpt_parameter_sorting(activation_func, rank, size): vocab_size=vocab_size, activation_func=activation_func, bf16=False, - ) + ).cuda() # Randomize layernorm weights instead of all zeros or ones for n, m in model.named_modules(): @@ -198,18 +202,11 @@ def _test_gpt_parameter_sorting(activation_func, rank, size): # 3 hps per layer + 1 for hidden_size (num_layers is not sorted!) assert len(sortable_per_pp) == 3 * num_layers // size + 1 - # Export since sorting force reassigns SelfAttention weights which we dont want to re-sort! 
- # TODO: ideally we shouldn't need this - dynamic_space.export(DMRegistry) - # sanity check if the model functionality is preserved after sorting y2 = run_mcore_inference(model, prompt_tokens) # check if the inference results after sorting is the same - assert all( - torch.allclose(t1, t2, rtol=1e-5, atol=1e-3) - for t1, t2 in zip(flatten_tree(y1)[0], flatten_tree(y2)[0]) - ) + compare_outputs(y1, y2, rtol=1e-5, atol=1e-3) @pytest.mark.parametrize("activation_func", ["swiglu"]) @@ -228,7 +225,7 @@ def test_expand_head_indices(): assert expand_head_indices(heads, hidden_size_per_head).tolist() == [2, 3, 6, 7, 4, 5, 0, 1] -def test_megatron_self_attention_head_sorting(distributed_setup_size_1): +def test_self_attention_head_sorting(distributed_setup_size_1): model = get_mcore_gpt_model( tensor_model_parallel_size=1, pipeline_model_parallel_size=1, @@ -239,7 +236,7 @@ def test_megatron_self_attention_head_sorting(distributed_setup_size_1): num_query_groups=2, ffn_hidden_size=16, activation_func="squared_relu", - ) + ).cuda() model = mtn.convert(model, "mcore_minitron") @@ -289,3 +286,152 @@ def test_megatron_self_attention_head_sorting(distributed_setup_size_1): # Clean up since this is not a spawned process destroy_model_parallel() + + +def _test_gpt_moe_search_space(rank, size): + channel_divisor = 64 + + num_layers = min(size * 2, 8) + hidden_size = 256 + num_attention_heads = 8 + num_query_groups = 4 + moe_ffn_hidden_size = 128 + num_moe_experts = 4 + moe_shared_expert_intermediate_size = 256 + max_sequence_length = 16 + vocab_size = 64 + batch_size = 2 + + model = get_mcore_gpt_model( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=size, + initialize_megatron=True, + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + max_sequence_length=max_sequence_length, + vocab_size=vocab_size, + activation_func="squared_relu", + num_moe_experts=num_moe_experts, + moe_ffn_hidden_size=moe_ffn_hidden_size, + moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, + ).cuda() + + model = mtn.convert(model, "mcore_minitron") + + moe = model.decoder.layers[0].mlp + assert isinstance(moe, _DynamicMoELayer) + assert isinstance(moe.router, _DynamicTopKRouter) + assert isinstance(moe.experts, _DynamicSequentialMLP) + assert isinstance(moe.experts.local_experts, DynamicModuleList) + for expert in moe.experts.local_experts: + assert isinstance(expert, _DynamicMLP) + assert isinstance(moe.shared_experts, _DynamicMLP) + + # NOTE: `search_space_size` does not reduce across TP/PP groups + ss_size_per_pp = search_space_size(model) + moe_ffn_choices = moe_ffn_hidden_size // channel_divisor + moe_shared_ffn_choices = moe_shared_expert_intermediate_size // channel_divisor + hidden_size_choices = hidden_size // channel_divisor + num_layers_per_pp = num_layers // size + assert ( + ss_size_per_pp + == ( + num_attention_heads + * num_moe_experts + * moe_ffn_choices**num_moe_experts + * moe_shared_ffn_choices + ) + ** num_layers_per_pp + * num_layers + * hidden_size_choices + ) + + # Make sure forward pass works on min and centroid subnets + prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda() + output = run_mcore_inference(model, prompt_tokens) + assert output.shape == (batch_size, max_sequence_length, vocab_size) + + # Make sure export and forward pass works on centroid model + mtn.sample(model, centroid) + mtn.export(model) + _ = run_mcore_inference(model, prompt_tokens, 
model.hidden_size) + assert not any(named_dynamic_modules(model)) + + +def test_gpt_moe_search_space(): + spawn_multiprocess_job( + size=torch.cuda.device_count(), job=_test_gpt_moe_search_space, backend="nccl" + ) + + +def _test_gpt_moe_parameter_sorting(rank, size): + num_layers = min(size * 2, 8) + hidden_size = 256 + num_attention_heads = 8 + num_query_groups = 4 + moe_ffn_hidden_size = 128 + num_moe_experts = 4 + moe_shared_expert_intermediate_size = 256 + max_sequence_length = 16 + vocab_size = 64 + batch_size = 2 + + model = get_mcore_gpt_model( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=size, + initialize_megatron=True, + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + max_sequence_length=max_sequence_length, + vocab_size=vocab_size, + activation_func="squared_relu", + num_moe_experts=num_moe_experts, + moe_ffn_hidden_size=moe_ffn_hidden_size, + moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, + ).cuda() + + # Randomize layernorm weights instead of all zeros or ones + for n, m in model.named_modules(): + if "layernorm" in n and not isinstance(m, IdentityOp): + m.weight.data = torch.randn_like(m.weight) + + model.eval() + dynamic_space = _convert_model_to_dynamic_space(model) + + # Compute activations for sorting + for _ in range(10): + run_mcore_inference_with_dummy_input(model, batch_size) + + # Get the output of the original model + prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda() + y1 = run_mcore_inference(model, prompt_tokens) + + mtn.utils.sort_parameters(model) + + # check if all num_moe_experts, moe_ffn, moe_shared_ffn, num_heads_per_group, num_query_groups, hidden_size + # have been sorted + sortable_per_pp = [ + n for n, hp in dynamic_space.named_hparams(configurable=True) if hp.importance is not None + ] + # (num_moe_experts + 4) hps per layer + 1 for hidden_size (num_layers is not sorted!) 
+ assert len(sortable_per_pp) == (num_moe_experts + 4) * num_layers // size + 1 + + # sanity check if the model functionality is preserved after sorting + export_searchspace(model, mtn.get_subnet_config(model)) + y2 = run_mcore_inference(model, prompt_tokens) + + # check if the inference results after sorting is the same + compare_outputs(y1, y2, rtol=1e-5, atol=1e-2) + + +def test_gpt_moe_parameter_sorting(need_2_gpus): + set_seed(SEED) + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=_test_gpt_moe_parameter_sorting, + backend="nccl", + ) diff --git a/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py b/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py index 0ac82db18..4f5b0b852 100644 --- a/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py +++ b/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py @@ -16,11 +16,12 @@ import torch from _test_utils.import_helper import skip_if_no_megatron +from _test_utils.torch.misc import compare_outputs skip_if_no_megatron(apex_or_te_required=True, mamba_required=True) from _test_utils.torch.distributed.utils import spawn_multiprocess_job -from _test_utils.torch.megatron.models import get_mcore_mamba_model +from _test_utils.torch.megatron.models import get_mcore_mamba_hybrid_model from _test_utils.torch.megatron.utils import ( run_mcore_inference, run_mcore_inference_with_dummy_input, @@ -46,7 +47,6 @@ from modelopt.torch.nas.traced_hp import TracedHp from modelopt.torch.opt.utils import named_dynamic_modules, search_space_size from modelopt.torch.prune.plugins.mcore_minitron import _convert_model_to_dynamic_space -from modelopt.torch.utils import flatten_tree from modelopt.torch.utils.random import centroid SEED = 1234 @@ -67,7 +67,7 @@ def _test_mamba_search_space(rank, size): vocab_size = 32 batch_size = 2 - model = get_mcore_mamba_model( + model = get_mcore_mamba_hybrid_model( tensor_model_parallel_size=1, pipeline_model_parallel_size=size, initialize_megatron=True, @@ -79,7 +79,7 @@ def _test_mamba_search_space(rank, size): mamba_num_groups=mamba_num_groups, max_sequence_length=max_sequence_length, vocab_size=vocab_size, - ) + ).cuda() mamba_num_heads = model.decoder.layers[0].mixer.nheads model = mtn.convert(model, "mcore_minitron") @@ -142,7 +142,7 @@ def _test_mamba_parameter_sorting(rank, size): vocab_size = 64 batch_size = 2 - model = get_mcore_mamba_model( + model = get_mcore_mamba_hybrid_model( tensor_model_parallel_size=1, pipeline_model_parallel_size=size, initialize_megatron=True, @@ -155,7 +155,7 @@ def _test_mamba_parameter_sorting(rank, size): max_sequence_length=max_sequence_length, vocab_size=vocab_size, bf16=False, - ) + ).cuda() # Randomize norm weights instead of all zeros or ones for n, m in model.named_modules(): @@ -186,10 +186,7 @@ def _test_mamba_parameter_sorting(rank, size): y2 = run_mcore_inference(model, prompt_tokens) # check if the inference results after sorting is the same - assert all( - torch.allclose(t1, t2, rtol=1e-5, atol=1e-3) - for t1, t2 in zip(flatten_tree(y1)[0], flatten_tree(y2)[0]) - ) + compare_outputs(y1, y2, rtol=1e-5, atol=1e-3) def test_mamba_parameter_sorting(need_2_gpus): diff --git a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py index 81656506d..2efea7e3d 100644 --- a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py +++ b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py @@ -85,7 +85,7 @@ def 
_get_model(initialize_megatron=True): normalization=normalization, num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage, num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage, - ) + ).cuda() return model model = _get_model() @@ -233,7 +233,112 @@ def test_mcore_gpt_pruning( num_layers_div, uneven_pp, skip_sorting, - tmp_path / "modelopt_minitron_scores.pth" if test_ckpt else None, + tmp_path / "minitron_scores.pth" if test_ckpt else None, ), backend="nccl", ) + + +def _test_mcore_gpt_pruning_moe(ckpt_path, rank, size): + num_layers = size + hidden_size = 128 + moe_ffn_hidden_size = 128 + num_moe_experts = 4 + moe_shared_expert_intermediate_size = 256 + max_sequence_length = 16 + vocab_size = 64 + batch_size = 2 + + def _get_model(initialize_megatron=True): + model = get_mcore_gpt_model( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=size, + initialize_megatron=initialize_megatron, + num_layers=num_layers, + hidden_size=hidden_size, + max_sequence_length=max_sequence_length, + vocab_size=vocab_size, + activation_func="squared_relu", + num_moe_experts=num_moe_experts, + moe_ffn_hidden_size=moe_ffn_hidden_size, + moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, + ).cuda() + return model + + model = _get_model() + sd = model.state_dict() + + def forward_loop(m): + for _ in range(5): + run_mcore_inference_with_dummy_input(m, batch_size, hidden_size) + + pruned_hidden_size = hidden_size // 2 + pruned_moe_ffn = moe_ffn_hidden_size // 2 + pruned_moe_shared_ffn = moe_shared_expert_intermediate_size // 2 + pruned_num_moe_experts = num_moe_experts // 2 + + export_config = { + "hidden_size": pruned_hidden_size, + "moe_ffn_hidden_size": pruned_moe_ffn, + "moe_shared_expert_intermediate_size": pruned_moe_shared_ffn, + "num_moe_experts": pruned_num_moe_experts, + } + + mtp.prune( + model, + mode="mcore_minitron", + constraints={"export_config": export_config}, + dummy_input=None, # Not used + config={"scores_path": ckpt_path, "forward_loop": forward_loop}, + ) + + # Assert weights are pruned correctly + for layer in model.decoder.layers: + moe = layer.mlp + assert moe.router.num_experts == pruned_num_moe_experts + assert moe.router.weight.shape == (pruned_num_moe_experts, pruned_hidden_size) + assert moe.experts.num_local_experts == pruned_num_moe_experts + assert len(moe.experts.local_experts) == pruned_num_moe_experts + for expert in moe.experts.local_experts: + assert expert.linear_fc1.weight.shape == (pruned_moe_ffn, pruned_hidden_size) + assert expert.linear_fc2.weight.shape == (pruned_hidden_size, pruned_moe_ffn) + assert moe.shared_experts.linear_fc1.weight.shape == ( + pruned_moe_shared_ffn, + pruned_hidden_size, + ) + assert moe.shared_experts.linear_fc2.weight.shape == ( + pruned_hidden_size, + pruned_moe_shared_ffn, + ) + + # Assert model.config is updated for correct save/restoring + assert model.config.hidden_size == pruned_hidden_size + assert model.config.moe_ffn_hidden_size == pruned_moe_ffn + assert model.config.num_moe_experts == pruned_num_moe_experts + assert model.config.moe_shared_expert_intermediate_size == pruned_moe_shared_ffn + + # Assert forward pass works on the pruned model + prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda() + output = run_mcore_inference(model, prompt_tokens, pruned_hidden_size) + + # Assert re-pruning from scores_path works without running the forward loop again + model_rerun = _get_model(initialize_megatron=False) + model_rerun.load_state_dict(sd) + 
mtp.prune( + model_rerun, + mode="mcore_minitron", + constraints={"export_config": export_config}, + dummy_input=None, # Not used + config={"scores_path": ckpt_path}, + ) + + output_rerun = run_mcore_inference(model_rerun, prompt_tokens, pruned_hidden_size) + assert torch.allclose(output, output_rerun, atol=1e-5) + + +def test_mcore_gpt_pruning_moe(tmp_path): + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial(_test_mcore_gpt_pruning_moe, tmp_path / "minitron_scores.pth"), + backend="nccl", + ) diff --git a/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py b/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py index 96ff93537..ae6e5fcbf 100644 --- a/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py +++ b/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py @@ -22,14 +22,14 @@ skip_if_no_megatron(apex_or_te_required=True, mamba_required=True) from _test_utils.torch.distributed.utils import spawn_multiprocess_job -from _test_utils.torch.megatron.models import get_mcore_mamba_model +from _test_utils.torch.megatron.models import get_mcore_mamba_hybrid_model from _test_utils.torch.megatron.utils import run_mcore_inference_with_dummy_input from megatron.core.ssm.mamba_layer import MambaLayer import modelopt.torch.prune as mtp -def _test_mcore_mamba_pruning(ckpt_path, rank, size): +def _test_mcore_mamba_hybrid_pruning(ckpt_path, rank, size): num_layers = min(size * 2, 8) hidden_size = 256 ffn_hidden_size = 128 @@ -38,10 +38,11 @@ def _test_mcore_mamba_pruning(ckpt_path, rank, size): mamba_state_dim = 64 mamba_head_dim = 16 mamba_num_groups = 2 + num_moe_experts = 8 batch_size = 2 def _get_model(initialize_megatron=True): - model = get_mcore_mamba_model( + model = get_mcore_mamba_hybrid_model( tensor_model_parallel_size=1, pipeline_model_parallel_size=size, initialize_megatron=initialize_megatron, @@ -49,10 +50,14 @@ def _get_model(initialize_megatron=True): hidden_size=hidden_size, num_attention_heads=num_attention_heads, num_query_groups=num_query_groups, + ffn_hidden_size=ffn_hidden_size, mamba_state_dim=mamba_state_dim, mamba_head_dim=mamba_head_dim, mamba_num_groups=mamba_num_groups, - ) + moe_ffn_hidden_size=ffn_hidden_size, + moe_shared_expert_intermediate_size=ffn_hidden_size, + num_moe_experts=num_moe_experts, + ).cuda() return model model = _get_model() @@ -74,6 +79,7 @@ def forward_loop(m): pruned_num_attention_heads = num_attention_heads // 2 pruned_num_query_groups = num_query_groups // 2 pruned_hidden_size = hidden_size // 2 + pruned_num_moe_experts = num_moe_experts // 2 # Mamba-specific pruning parameters pruned_mamba_num_heads = mamba_num_heads // 2 @@ -87,6 +93,9 @@ def forward_loop(m): "hidden_size": pruned_hidden_size, "mamba_num_heads": pruned_mamba_num_heads, "mamba_head_dim": pruned_mamba_head_dim, + "moe_ffn_hidden_size": pruned_ffn_hidden_size, + "moe_shared_expert_intermediate_size": pruned_ffn_hidden_size, + "num_moe_experts": pruned_num_moe_experts, } mtp.prune( model, @@ -115,6 +124,9 @@ def forward_loop(m): assert model.config.hidden_size == pruned_hidden_size assert model.config.mamba_num_heads == pruned_mamba_num_heads assert model.config.mamba_head_dim == pruned_mamba_head_dim + assert model.config.moe_ffn_hidden_size == pruned_ffn_hidden_size + assert model.config.moe_shared_expert_intermediate_size == pruned_ffn_hidden_size + assert model.config.num_moe_experts == pruned_num_moe_experts # Assert forward pass works on the pruned model run_mcore_inference_with_dummy_input(model, 
batch_size, pruned_hidden_size) @@ -130,9 +142,9 @@ def forward_loop(m): ) -def test_mcore_mamba_pruning(tmp_path): +def test_mcore_mamba_hybrid_pruning(tmp_path): spawn_multiprocess_job( size=torch.cuda.device_count(), - job=partial(_test_mcore_mamba_pruning, tmp_path / "modelopt_minitron_scores.pth"), + job=partial(_test_mcore_mamba_hybrid_pruning, tmp_path / "modelopt_minitron_scores.pth"), backend="nccl", ) diff --git a/tests/unit/torch/nas/modules/test_sequential.py b/tests/unit/torch/nas/modules/test_container.py similarity index 73% rename from tests/unit/torch/nas/modules/test_sequential.py rename to tests/unit/torch/nas/modules/test_container.py index 58b55265b..1f235ab78 100644 --- a/tests/unit/torch/nas/modules/test_sequential.py +++ b/tests/unit/torch/nas/modules/test_container.py @@ -19,7 +19,9 @@ from torch import nn from torchvision.models.mobilenetv2 import InvertedResidual +from modelopt.torch.nas.modules import DynamicModuleList from modelopt.torch.nas.registry import DMRegistry +from modelopt.torch.nas.utils import sort_parameters class ModuleContainerWrapper: @@ -104,3 +106,37 @@ def test_contained_dynamic_module(): seq = dynamic_seq.export() out = seq(input) assert torch.allclose(out, out_dynamic) + + +def test_dynamic_module_list(): + m0 = nn.Conv2d(3, 8, 3, bias=False) + m1 = nn.Linear(4, 8) + m2 = nn.ReLU() + m = nn.ModuleList([m0, m1, m2]) + assert list(m.state_dict().keys()) == ["0.weight", "1.weight", "1.bias"] + + # Test convert + DynamicModuleList.convert(m) + assert isinstance(m, DynamicModuleList) + assert m.get_hparam("depth").choices == [1, 2, 3] + + # Test trimming depth + m.depth = 1 + assert len(m) == 1 and m[0] == m0 + assert list(m.state_dict().keys()) == ["0.weight"] + + # Test sorting by importance + hp_depth = m.get_hparam("depth") + hp_depth.register_importance(lambda: torch.tensor([0.8, 0.5, 1.0])) + sort_parameters(m) + + m.depth = 3 + assert m[0] == m2 and m[1] == m0 and m[2] == m1 + assert list(m.state_dict().keys()) == ["1.weight", "2.weight", "2.bias"] + + # Test export + hp_depth.active = 2 + m.export() + assert not isinstance(m, DynamicModuleList) and isinstance(m, nn.ModuleList) + assert len(m) == 2 and m[0] == m2 and m[1] == m0 + assert list(m.state_dict().keys()) == ["1.weight"]
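
Beyond the container behavior covered by the test above, the heart of the MoE support is the per-expert importance scoring in `_DynamicSequentialMLP`. Below is a minimal, self-contained sketch of that computation with made-up shapes and values: per-expert L2 norms of the expert outputs, normalized by the number of tokens each expert received.

```python
import torch

num_experts, hidden = 4, 8
tokens_per_expert = torch.tensor([5, 0, 3, 2])                      # tokens routed to each expert
expert_outputs = torch.randn(int(tokens_per_expert.sum()), hidden)  # concatenated expert outputs

l2_scores = torch.zeros(num_experts)
sample_counts = torch.zeros(num_experts)

# Split the concatenated output back into per-expert chunks and accumulate L2 norms.
for idx, out in enumerate(torch.split(expert_outputs, tokens_per_expert.tolist())):
    # An expert that received no tokens contributes a zero score for this batch.
    l2 = 0.0 if out.numel() == 0 else torch.linalg.vector_norm(out, ord=2, dim=-1).sum().item()
    l2_scores[idx] += l2
    sample_counts[idx] += tokens_per_expert[idx].item()

importance = l2_scores / (sample_counts + 1e-8)  # higher score = more important expert
print(importance)
```

Experts are then sorted by this importance via the shared `num_local_experts` hparam before the least important ones are dropped to reach the requested `num_moe_experts`.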