diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py index 24bfdb0605..061767660b 100644 --- a/nemo_rl/models/megatron/setup.py +++ b/nemo_rl/models/megatron/setup.py @@ -527,15 +527,6 @@ def _validate_training_config(config: PolicyConfig, model_cfg: Any) -> None: model_cfg.calculate_per_token_loss = True model_cfg.perform_initialization = True - # MoE aux loss validation - assert ( - "aux_loss" not in model_cfg.moe_router_load_balancing_type - or model_cfg.moe_aux_loss_coeff == 0 - ), ( - "MoE aux loss is currently not supported due to a known bug in Megatron-LM. " - "See https://github.com/NVIDIA/Megatron-LM/issues/1984 for more details." - ) - def _validate_dtype_config( dtype: torch.dtype, model_cfg: Any, optimizer_cfg: Any