@@ -2582,6 +2582,8 @@ class Config(BaseLayer.Config):
         # TODO(bwzhang@) Adding a unittest for the hybridnorm.
         # v2: see comments on NormPosition for details.
         structure: str = "prenorm"
+        # outputs = inputs + residual_weight * x.
+        residual_weight: Optional[float] = None
 
     def __init__(self, cfg: Config, *, parent: Module):
         super().__init__(cfg, parent=parent)
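The new field only scales the non-skip branch before it rejoins the residual stream; the skip connection itself is never scaled, and `None` (the default) keeps the existing behavior. A minimal standalone sketch of the intended semantics (NumPy stands in for the real tensor types; `residual_add` and `branch` are illustrative names, with `branch` playing the role of the dropout/stochastic-depth output):

```python
import numpy as np

def residual_add(inputs: np.ndarray, branch: np.ndarray, residual_weight=None) -> np.ndarray:
    """Mirrors the config comment: outputs = inputs + residual_weight * x."""
    # residual_weight=None (or exactly 1) preserves the pre-change behavior.
    if residual_weight is not None and residual_weight != 1:
        branch = residual_weight * branch
    return inputs + branch

x = np.ones((2, 4))
assert np.allclose(residual_add(x, x), 2 * x)                         # default: plain residual add
assert np.allclose(residual_add(x, x, residual_weight=0.5), 1.5 * x)  # scaled branch
```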
@@ -2724,25 +2726,35 @@ def attention_thunk(target: Tensor) -> tuple[Optional[NestedTensor], Tensor]:
             skip_input = target  # pre-norm: where normalization happens within the residual part.
             norm_target = self.norm(target)
             atten_state, atten_output = attention_thunk(norm_target)
-            data = skip_input + self.stochastic_depth(self.dropout(atten_output.data))
+            data = self.stochastic_depth(self.dropout(atten_output.data))
+            if cfg.residual_weight is not None and cfg.residual_weight != 1:
+                data *= cfg.residual_weight
+            data = skip_input + data
         elif cfg.structure == "postnorm":
             # This is the structure used by the original Transformer, BERT, and RoBERTa.
             atten_state, atten_output = attention_thunk(target)
             # Post-norm: norm applied on the sum of input and attention output.
-            data = self.norm(target + self.stochastic_depth(self.dropout(atten_output.data)))
+            data = self.stochastic_depth(self.dropout(atten_output.data))
+            if cfg.residual_weight is not None and cfg.residual_weight != 1:
+                data *= cfg.residual_weight
+            data = self.norm(target + data)
         elif cfg.structure == "hybridnorm":
             skip_input = target  # pre-norm: where normalization happens within the residual part.
             norm_target = self.prenorm(target)
             atten_state, atten_output = attention_thunk(norm_target)
-            data = skip_input + self.stochastic_depth(
-                self.dropout(self.postnorm(atten_output.data))
-            )
+            data = self.stochastic_depth(self.dropout(self.postnorm(atten_output.data)))
+            if cfg.residual_weight is not None and cfg.residual_weight != 1:
+                data *= cfg.residual_weight
+            data = skip_input + data
         elif cfg.structure == "v2":
             norm_target = self.in_norm(target) if NormPosition.IN_NORM in cfg.norm else target
             atten_state, atten_output = attention_thunk(norm_target)
             data = atten_output.data
             data = self.res_norm(data) if NormPosition.RES_NORM in cfg.norm else data
-            data = target + self.stochastic_depth(self.dropout(data))
+            data = self.stochastic_depth(self.dropout(data))
+            if cfg.residual_weight is not None and cfg.residual_weight != 1:
+                data *= cfg.residual_weight
+            data = target + data
             data = self.out_norm(data) if NormPosition.OUT_NORM in cfg.norm else data
         else:
             raise NotImplementedError(cfg.structure)
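In all four structures the scaling is applied before the residual add, so for `postnorm` the norm still sees the weighted sum `target + w * branch`. For context, the knob would be set through the layer's config like any other field. A hedged usage sketch, assuming the enclosing layer is AXLearn's `TransformerAttentionLayer` (its name is not visible in these hunks) and the usual `default_config().set(...)` pattern:

```python
# Sketch only: the enclosing layer name is an assumption; structure and
# residual_weight are the fields touched by this diff.
atten_cfg = TransformerAttentionLayer.default_config().set(
    structure="prenorm",
    # Down-weights the attention branch: out = x + 0.25 * f(norm(x)).
    residual_weight=0.25,
)
```

Guarding the multiply with `!= 1` avoids an extra elementwise op when the weight is a no-op, at the cost of treating `residual_weight=1.0` and `residual_weight=None` identically.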