@@ -236,10 +236,10 @@ def forward(
         img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
         txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
 
-        img_normed = self.img_norm1(hidden_states)
-        img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
-        txt_normed = self.txt_norm1(encoder_hidden_states)
-        txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
+        img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
+        del img_mod1
+        txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
+        del txt_mod1
 
         img_attn_output, txt_attn_output = self.attn(
             hidden_states=img_modulated,
@@ -248,16 +248,20 @@ def forward(
             image_rotary_emb=image_rotary_emb,
             transformer_options=transformer_options,
         )
+        del img_modulated
+        del txt_modulated
 
         hidden_states = hidden_states + img_gate1 * img_attn_output
         encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
+        del img_attn_output
+        del txt_attn_output
+        del img_gate1
+        del txt_gate1
 
-        img_normed2 = self.img_norm2(hidden_states)
-        img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
+        img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
         hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
 
-        txt_normed2 = self.txt_norm2(encoder_hidden_states)
-        txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
+        txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
         encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
 
         return encoder_hidden_states, hidden_states
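
For context, a minimal sketch of the two memory patterns this diff leans on: eagerly `del`-ing the last reference to an intermediate tensor so PyTorch's caching allocator can reuse that block before the next large allocation, and expressing the gated residual `x + gate * y` as a single `torch.addcmul(x, gate, y)` so the `gate * y` temporary is never materialized on its own. The tensor names and shapes below are hypothetical, not taken from the model, and note that under autograd the graph still holds references to intermediates, so the `del` savings mainly show up in no-grad inference.

import torch

def gated_residual(x, gate, y):
    # torch.addcmul(input, t1, t2) returns input + t1 * t2 in one op,
    # skipping the standalone buffer that `x + gate * y` would allocate
    # for the elementwise product.
    return torch.addcmul(x, gate, y)

with torch.no_grad():
    # Hypothetical activation-sized tensors: (batch, tokens, channels).
    hidden = torch.randn(1, 4096, 3072)
    gate = torch.randn(1, 1, 3072)        # broadcasts across tokens
    mlp_out = torch.randn(1, 4096, 3072)

    hidden = gated_residual(hidden, gate, mlp_out)
    # Drop the last references so the caching allocator can hand this
    # memory to the next allocation instead of growing peak usage.
    del gate, mlp_out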