diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py
index 62fb7a148c8..cf15faeaf90 100644
--- a/megatron/core/transformer/moe/experts.py
+++ b/megatron/core/transformer/moe/experts.py
@@ -222,7 +222,12 @@ def forward(
         permuted_probs: torch.Tensor,
     ):
         """Forward step of the GroupedMLP."""
-        assert self.config.bf16, "Currently GroupedMLP for MoE only supports bf16."
+        assert self.config.bf16, (
+            "The legacy GroupedMLP only supports bf16. "
+            "For FP16/FP8 support, please use TEGroupedMLP instead, which is adopted by default "
+            "when TransformerEngine >= 1.9 is installed and '--moe-use-legacy-grouped-gemm' is "
+            "*not* set."
+        )
         if self.activation_recompute:
             self.activation_checkpoint = tensor_parallel.CheckpointWithoutOutput()
 
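For context, below is a minimal, hypothetical sketch of the selection behavior the new message describes: the TE-backed implementation is used when TransformerEngine >= 1.9 is available and the legacy flag is not set, while the legacy path fails fast on anything other than bf16 via the assert above. The helper name select_grouped_mlp and the config attribute moe_use_legacy_grouped_gemm are illustrative assumptions, not the actual Megatron-LM dispatch code.

    # Hypothetical sketch of the dispatch the assertion message describes;
    # the real Megatron-LM selection logic lives in the MoE layer setup.
    def select_grouped_mlp(config, te_version: tuple) -> str:
        """Pick the grouped-GEMM expert implementation.

        Assumes config.moe_use_legacy_grouped_gemm mirrors the
        '--moe-use-legacy-grouped-gemm' CLI flag, and te_version is a
        (major, minor) tuple for the installed TransformerEngine.
        """
        if te_version >= (1, 9) and not config.moe_use_legacy_grouped_gemm:
            return "TEGroupedMLP"  # default path: supports bf16, fp16, and fp8
        return "GroupedMLP"  # legacy path: bf16 only, enforced by the assert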