diff --git a/gpt_builders.py b/gpt_builders.py
index dfe41f7b88e..c8b4efa3075 100644
--- a/gpt_builders.py
+++ b/gpt_builders.py
@@ -136,6 +136,7 @@ def _get_transformer_layer_spec(use_te, config):
         moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
         qk_l2_norm=args.qk_l2_norm,
         use_kitchen=config.use_kitchen,
+        use_te_activation_func=config.use_te_activation_func,
         use_kitchen_attention=config.use_kitchen_attention,
         kitchen_attention_backend=config.kitchen_attention_backend,
     )
diff --git a/megatron/core/models/gpt/moe_module_specs.py b/megatron/core/models/gpt/moe_module_specs.py
index 62ee4537cfc..b558d5ae886 100755
--- a/megatron/core/models/gpt/moe_module_specs.py
+++ b/megatron/core/models/gpt/moe_module_specs.py
@@ -41,7 +41,7 @@ def get_moe_module_spec_for_backend(
     linear_fc1 = backend.column_parallel_linear()
     linear_fc2 = backend.row_parallel_linear()

-    activation_func = backend.activation_func()
+    activation_func = backend.activation_func() if use_te_activation_func else None

     mlp = MLPSubmodules(
         linear_fc1=linear_fc1, linear_fc2=linear_fc2, activation_func=activation_func