File tree 1 file changed +2
-7
lines changed
1 file changed +2
-7
lines changed Original file line number Diff line number Diff line change 14
14
from .utils import init_method_normal
15
15
from .utils import scaled_init_method_normal
16
16
17
- from megatron .model import LayerNorm
17
+ from megatron .model import LayerNorm , RMSNorm
18
18
from .language_model import EmbeddingPipe
19
19
from .transformer import ParallelTransformerLayerPipe , LMHeadPipe
20
20
from deepspeed .pipe import PipelineModule , LayerSpec , TiedLayerSpec
21
21
22
- try :
23
- from apex .normalization import MixedFusedRMSNorm
24
- except ImportError :
25
- MixedFusedRMSNorm = None
26
-
27
22
try :
28
23
from deepspeed .checkpoint import (
29
24
VOCABULARY_PARAMETER_PATTERNS ,
@@ -290,7 +285,7 @@ def _to_float16(inputs):
290
285
args .hidden_size ,
291
286
eps = args .layernorm_epsilon ))
292
287
else :
293
- self .specs .append (LayerSpec (MixedFusedRMSNorm , args .hidden_size , args .layernorm_epsilon ))
288
+ self .specs .append (LayerSpec (RMSNorm , args .hidden_size , args .layernorm_epsilon ))
294
289
295
290
def _logits_helper (embedding , lm_output ):
296
291
"""A wrapper to massage inputs/outputs from pipeline. """
You can’t perform that action at this time.
0 commit comments