
Commit 1c7eaec

qwen: reduce VRAM usage (Comfy-Org#10725)
Clean up a bunch of stacked and no-longer-needed tensors at the QWEN VRAM peak (currently the FFN). With this I go from OOMing at B=37x1328x1328 to successfully running B=47 (RTX 5090).
1 parent 18e7d6d commit 1c7eaec
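
The change works because dropping the last Python reference to an intermediate tensor lets PyTorch's caching allocator reuse that memory for the FFN matmuls that follow, lowering the peak. Below is a minimal, hypothetical sketch (not code from this repository) of the same pattern, with a rough way to observe the effect via torch.cuda.max_memory_allocated; toy_block, mlp, and the shapes are made up for illustration.

# Hypothetical sketch: freeing no-longer-needed intermediates before the FFN
# lowers peak VRAM, because the caching allocator can reuse their blocks.
import torch

def toy_block(x, mlp, free_intermediates):
    normed = torch.nn.functional.layer_norm(x, x.shape[-1:])
    gate = torch.sigmoid(normed)          # stands in for the modulation gate
    y = x + gate * normed                 # stands in for the attention residual
    if free_intermediates:
        del normed, gate                  # drop references before the peak
    return y + mlp(y)                     # the FFN is where the peak occurs

if __name__ == "__main__" and torch.cuda.is_available():
    dim = 4096
    mlp = torch.nn.Sequential(
        torch.nn.Linear(dim, 4 * dim),
        torch.nn.GELU(),
        torch.nn.Linear(4 * dim, dim),
    ).cuda().half()
    x = torch.randn(32, 1024, dim, device="cuda", dtype=torch.half)
    for free in (False, True):
        torch.cuda.reset_peak_memory_stats()
        with torch.no_grad():
            toy_block(x, mlp, free)
        print(f"free_intermediates={free}: "
              f"peak {torch.cuda.max_memory_allocated() / 2**20:.0f} MiB")

The effect is clearest under torch.no_grad(), which matches inference; if autograd were recording, tensors saved for backward would stay alive regardless of del.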

File tree: 1 file changed (+12, -8)


comfy/ldm/qwen_image/model.py (12 additions, 8 deletions)
@@ -236,10 +236,10 @@ def forward(
         img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
         txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
 
-        img_normed = self.img_norm1(hidden_states)
-        img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
-        txt_normed = self.txt_norm1(encoder_hidden_states)
-        txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
+        img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
+        del img_mod1
+        txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
+        del txt_mod1
 
         img_attn_output, txt_attn_output = self.attn(
             hidden_states=img_modulated,
@@ -248,16 +248,20 @@ def forward(
             image_rotary_emb=image_rotary_emb,
             transformer_options=transformer_options,
         )
+        del img_modulated
+        del txt_modulated
 
         hidden_states = hidden_states + img_gate1 * img_attn_output
         encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
+        del img_attn_output
+        del txt_attn_output
+        del img_gate1
+        del txt_gate1
 
-        img_normed2 = self.img_norm2(hidden_states)
-        img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
+        img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
         hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
 
-        txt_normed2 = self.txt_norm2(encoder_hidden_states)
-        txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
+        txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
         encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
 
         return encoder_hidden_states, hidden_states
