Commit 671ca2b

feat: add residual_weight for attn

Parent: 343102a

1 file changed: axlearn/common/attention.py (18 additions, 6 deletions)
@@ -2582,6 +2582,8 @@ class Config(BaseLayer.Config):
         # TODO (bwzhang@) Adding a unittest for the hybridnorm.
         # v2: see comments on NormPosition for details.
         structure: str = "prenorm"
+        # outputs = inputs + residual_weight * x.
+        residual_weight: Optional[float] = None

     def __init__(self, cfg: Config, *, parent: Module):
         super().__init__(cfg, parent=parent)
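
For context, a minimal sketch of how the new field might be set through axlearn's usual default_config().set(...) pattern. The enclosing layer name TransformerAttentionLayer is an assumption on my part; the diff shows only its nested Config, so adjust the import if the field lives elsewhere.

    # A hedged sketch, assuming the Config above belongs to
    # axlearn's TransformerAttentionLayer (not confirmed by the diff).
    from axlearn.common.attention import TransformerAttentionLayer

    cfg = TransformerAttentionLayer.default_config().set(
        name="self_attention",
        structure="prenorm",
        # Scales the residual branch: outputs = inputs + 0.5 * x.
        residual_weight=0.5,
    )

Leaving residual_weight as None (the default) preserves the previous behavior, since the scaling is skipped entirely in that case.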
@@ -2724,25 +2726,35 @@ def attention_thunk(target: Tensor) -> tuple[Optional[NestedTensor], Tensor]:
             skip_input = target  # pre-norm: where normalization happens within the residual part.
             norm_target = self.norm(target)
             atten_state, atten_output = attention_thunk(norm_target)
-            data = skip_input + self.stochastic_depth(self.dropout(atten_output.data))
+            data = self.stochastic_depth(self.dropout(atten_output.data))
+            if cfg.residual_weight is not None and cfg.residual_weight != 1:
+                data *= cfg.residual_weight
+            data = skip_input + data
         elif cfg.structure == "postnorm":
             # This is the structure used by the original Transformer, BERT, and RoBERTa.
             atten_state, atten_output = attention_thunk(target)
             # Post-norm: norm applied on the sum of input and attention output.
-            data = self.norm(target + self.stochastic_depth(self.dropout(atten_output.data)))
+            data = self.stochastic_depth(self.dropout(atten_output.data))
+            if cfg.residual_weight is not None and cfg.residual_weight != 1:
+                data *= cfg.residual_weight
+            data = self.norm(target + data)
         elif cfg.structure == "hybridnorm":
             skip_input = target  # pre-norm: where normalization happens within the residual part.
             norm_target = self.prenorm(target)
             atten_state, atten_output = attention_thunk(norm_target)
-            data = skip_input + self.stochastic_depth(
-                self.dropout(self.postnorm(atten_output.data))
-            )
+            data = self.stochastic_depth(self.dropout(self.postnorm(atten_output.data)))
+            if cfg.residual_weight is not None and cfg.residual_weight != 1:
+                data *= cfg.residual_weight
+            data = skip_input + data
         elif cfg.structure == "v2":
             norm_target = self.in_norm(target) if NormPosition.IN_NORM in cfg.norm else target
             atten_state, atten_output = attention_thunk(norm_target)
             data = atten_output.data
             data = self.res_norm(data) if NormPosition.RES_NORM in cfg.norm else data
-            data = target + self.stochastic_depth(self.dropout(data))
+            data = self.stochastic_depth(self.dropout(data))
+            if cfg.residual_weight is not None and cfg.residual_weight != 1:
+                data *= cfg.residual_weight
+            data = target + data
             data = self.out_norm(data) if NormPosition.OUT_NORM in cfg.norm else data
         else:
             raise NotImplementedError(cfg.structure)
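
All four branches follow the same pattern: compute the attention branch output, optionally scale it by residual_weight, then apply the skip connection. To make the arithmetic concrete, here is a small self-contained sketch in plain JAX; residual_block and branch_fn are toy stand-ins of my own (the norm/attention/dropout stack is collapsed into one function) and do not come from axlearn.

    import jax.numpy as jnp

    def residual_block(inputs, branch_fn, residual_weight=None):
        # Mirrors the diff's pattern: scale the branch output
        # before adding the skip connection.
        x = branch_fn(inputs)
        if residual_weight is not None and residual_weight != 1:
            x *= residual_weight
        return inputs + x

    inputs = jnp.ones((2, 4))
    # With a branch that doubles its input and residual_weight=0.5,
    # each element is 1 + 0.5 * 2 = 2.
    out = residual_block(inputs, branch_fn=lambda t: 2.0 * t, residual_weight=0.5)

The `residual_weight != 1` guard avoids a redundant multiply when the weight is exactly 1, in which case the output is identical to the unscaled residual.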
