@@ -80,15 +80,13 @@ def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=N
8080 (img_mod1 , img_mod2 ), (txt_mod1 , txt_mod2 ) = vec
8181
8282 # prepare image for attention
83- img_modulated = self .img_norm1 (img )
84- img_modulated = (1 + img_mod1 .scale ) * img_modulated + img_mod1 .shift
83+ img_modulated = torch .addcmul (img_mod1 .shift , 1 + img_mod1 .scale , self .img_norm1 (img ))
8584 img_qkv = self .img_attn .qkv (img_modulated )
8685 img_q , img_k , img_v = img_qkv .view (img_qkv .shape [0 ], img_qkv .shape [1 ], 3 , self .num_heads , - 1 ).permute (2 , 0 , 3 , 1 , 4 )
8786 img_q , img_k = self .img_attn .norm (img_q , img_k , img_v )
8887
8988 # prepare txt for attention
90- txt_modulated = self .txt_norm1 (txt )
91- txt_modulated = (1 + txt_mod1 .scale ) * txt_modulated + txt_mod1 .shift
89+ txt_modulated = torch .addcmul (txt_mod1 .shift , 1 + txt_mod1 .scale , self .txt_norm1 (txt ))
9290 txt_qkv = self .txt_attn .qkv (txt_modulated )
9391 txt_q , txt_k , txt_v = txt_qkv .view (txt_qkv .shape [0 ], txt_qkv .shape [1 ], 3 , self .num_heads , - 1 ).permute (2 , 0 , 3 , 1 , 4 )
9492 txt_q , txt_k = self .txt_attn .norm (txt_q , txt_k , txt_v )
@@ -102,12 +100,12 @@ def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=N
102100 txt_attn , img_attn = attn [:, : txt .shape [1 ]], attn [:, txt .shape [1 ] :]
103101
104102 # calculate the img blocks
105- img = img + img_mod1 .gate * self .img_attn .proj (img_attn )
106- img = img + img_mod2 .gate * self .img_mlp (( 1 + img_mod2 .scale ) * self .img_norm2 (img ) + img_mod2 . shift )
103+ img . addcmul_ ( img_mod1 .gate , self .img_attn .proj (img_attn ) )
104+ img . addcmul_ ( img_mod2 .gate , self .img_mlp (torch . addcmul ( img_mod2 . shift , 1 + img_mod2 .scale , self .img_norm2 (img ))) )
107105
108106 # calculate the txt blocks
109- txt += txt_mod1 .gate * self .txt_attn .proj (txt_attn )
110- txt += txt_mod2 .gate * self .txt_mlp (( 1 + txt_mod2 .scale ) * self .txt_norm2 (txt ) + txt_mod2 . shift )
107+ txt . addcmul_ ( txt_mod1 .gate , self .txt_attn .proj (txt_attn ) )
108+ txt . addcmul_ ( txt_mod2 .gate , self .txt_mlp (torch . addcmul ( txt_mod2 . shift , 1 + txt_mod2 .scale , self .txt_norm2 (txt ))) )
111109
112110 if txt .dtype == torch .float16 :
113111 txt = torch .nan_to_num (txt , nan = 0.0 , posinf = 65504 , neginf = - 65504 )
@@ -152,7 +150,7 @@ def __init__(
152150
153151 def forward (self , x : Tensor , pe : Tensor , vec : Tensor , attn_mask = None ) -> Tensor :
154152 mod = vec
155- x_mod = ( 1 + mod .scale ) * self .pre_norm (x ) + mod . shift
153+ x_mod = torch . addcmul ( mod . shift , 1 + mod .scale , self .pre_norm (x ))
156154 qkv , mlp = torch .split (self .linear1 (x_mod ), [3 * self .hidden_size , self .mlp_hidden_dim ], dim = - 1 )
157155
158156 q , k , v = qkv .view (qkv .shape [0 ], qkv .shape [1 ], 3 , self .num_heads , - 1 ).permute (2 , 0 , 3 , 1 , 4 )
@@ -162,7 +160,7 @@ def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
162160 attn = attention (q , k , v , pe = pe , mask = attn_mask )
163161 # compute activation in mlp stream, cat again and run second linear layer
164162 output = self .linear2 (torch .cat ((attn , self .mlp_act (mlp )), 2 ))
165- x += mod .gate * output
163+ x . addcmul_ ( mod .gate , output )
166164 if x .dtype == torch .float16 :
167165 x = torch .nan_to_num (x , nan = 0.0 , posinf = 65504 , neginf = - 65504 )
168166 return x
@@ -178,6 +176,6 @@ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
178176 shift , scale = vec
179177 shift = shift .squeeze (1 )
180178 scale = scale .squeeze (1 )
181- x = ( 1 + scale [:, None , :]) * self . norm_final ( x ) + shift [:, None , :]
179+ x = torch . addcmul ( shift [:, None , :], 1 + scale [:, None , :], self . norm_final ( x ))
182180 x = self .linear (x )
183181 return x
0 commit comments