
Commit 4e6a1b6

speed up and reduce VRAM of QWEN VAE and WAN (less so) (Comfy-Org#12036)
* ops: introduce autopad for conv3d

  This works around PyTorch's missing ability to causal-pad as part of the kernel, and avoids massive weight duplications for padding.

* wan-vae: rework causal padding

  This currently uses F.pad, which takes a full deep copy and is liable to be the VRAM peak. Instead, kick spatial padding back to the op and consolidate the temporal padding with the cat for the cache.

* wan-vae: implement zero pad fast path

  The WAN VAE is also the QWEN VAE, where it is used single-image. These convolutions are, however, zero-padded 3D convolutions, which means the VAE is effectively 2D, using only the last element of the conv weight along the temporal dimension. Fast-path this to avoid adding zeros that just evaporate in the convolution math but cost computation.
1 parent 9cf299a commit 4e6a1b6
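
For context on what "causal pad" means here, a minimal standalone sketch (hypothetical shapes, not the ComfyUI code): a causal Conv3d zero-pads only the front of the time axis, so an output frame depends only on current and earlier frames.

import torch
import torch.nn.functional as F

# A causal Conv3d pads 2 * (kernel_size // 2) zero frames at the *front* of
# the time axis only, plus symmetric spatial padding, so output frame t
# sees only input frames <= t.
conv = torch.nn.Conv3d(4, 4, kernel_size=3, padding=0)
x = torch.randn(1, 4, 8, 16, 16)        # (batch, channels, time, h, w)

# F.pad order for 5D input: (w_left, w_right, h_top, h_bottom, t_front, t_back)
y = conv(F.pad(x, (1, 1, 1, 1, 2, 0)))
print(y.shape)                          # torch.Size([1, 4, 8, 16, 16])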

File tree: 2 files changed (+23 −14 lines)

comfy/ldm/wan/vae.py

Lines changed: 17 additions & 10 deletions
@@ -5,7 +5,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from comfy.ldm.modules.diffusionmodules.model import vae_attention
+from comfy.ldm.modules.diffusionmodules.model import vae_attention, torch_cat_if_needed

 import comfy.ops
 ops = comfy.ops.disable_weight_init
@@ -20,22 +20,29 @@ class CausalConv3d(ops.Conv3d):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self._padding = (self.padding[2], self.padding[2], self.padding[1],
-                         self.padding[1], 2 * self.padding[0], 0)
-        self.padding = (0, 0, 0)
+        self._padding = 2 * self.padding[0]
+        self.padding = (0, self.padding[1], self.padding[2])

     def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
         if cache_list is not None:
             cache_x = cache_list[cache_idx]
             cache_list[cache_idx] = None

-        padding = list(self._padding)
-        if cache_x is not None and self._padding[4] > 0:
-            cache_x = cache_x.to(x.device)
-            x = torch.cat([cache_x, x], dim=2)
-            padding[4] -= cache_x.shape[2]
+        if cache_x is None and x.shape[2] == 1:
+            # Fast path - the op will pad for us by truncating the weight
+            # and save math on a pile of zeros.
+            return super().forward(x, autopad="causal_zero")
+
+        if self._padding > 0:
+            padding_needed = self._padding
+            if cache_x is not None:
+                cache_x = cache_x.to(x.device)
+                padding_needed = max(0, padding_needed - cache_x.shape[2])
+            padding_shape = list(x.shape)
+            padding_shape[2] = padding_needed
+            padding = torch.zeros(padding_shape, device=x.device, dtype=x.dtype)
+            x = torch_cat_if_needed([padding, cache_x, x], dim=2)
         del cache_x
-        x = F.pad(x, padding)

         return super().forward(x)
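The VRAM saving in the rework above comes from materialising only the missing temporal frames instead of an F.pad deep copy of the whole tensor. A rough equivalence sketch, assuming kernel_size=3 with padding 1 per axis, and using plain torch.cat in place of torch_cat_if_needed (assumed here to behave like torch.cat with None entries dropped):

import torch
import torch.nn.functional as F

# Shared weights: one conv does no padding (old path pads everything via
# F.pad), the other pads spatially in-kernel (new path).
conv_flat = torch.nn.Conv3d(4, 4, kernel_size=3, padding=0, bias=False)
conv_spatial = torch.nn.Conv3d(4, 4, kernel_size=3, padding=(0, 1, 1), bias=False)
conv_spatial.weight = conv_flat.weight

x = torch.randn(1, 4, 5, 8, 8)

# Old: F.pad deep-copies x into a fully padded (1, 4, 7, 10, 10) buffer.
y_old = conv_flat(F.pad(x, (1, 1, 1, 1, 2, 0)))

# New: only the 2 leading zero frames are materialised and concatenated;
# the spatial padding never exists as a separate tensor.
zeros = torch.zeros(1, 4, 2, 8, 8)
y_new = conv_spatial(torch.cat([zeros, x], dim=2))

print(torch.allclose(y_old, y_new, atol=1e-5))   # True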

comfy/ops.py

Lines changed: 6 additions & 4 deletions
@@ -203,7 +203,9 @@ class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
     def reset_parameters(self):
         return None

-    def _conv_forward(self, input, weight, bias, *args, **kwargs):
+    def _conv_forward(self, input, weight, bias, autopad=None, *args, **kwargs):
+        if autopad == "causal_zero":
+            weight = weight[:, :, -input.shape[2]:, :, :]
         if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
             out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
             if bias is not None:
@@ -212,15 +214,15 @@ def _conv_forward(self, input, weight, bias, *args, **kwargs):
         else:
             return super()._conv_forward(input, weight, bias, *args, **kwargs)

-    def forward_comfy_cast_weights(self, input):
+    def forward_comfy_cast_weights(self, input, autopad=None):
         weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
-        x = self._conv_forward(input, weight, bias)
+        x = self._conv_forward(input, weight, bias, autopad=autopad)
         uncast_bias_weight(self, weight, bias, offload_stream)
         return x

     def forward(self, *args, **kwargs):
         run_every_op()
-        if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+        if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or "autopad" in kwargs:
             return self.forward_comfy_cast_weights(*args, **kwargs)
         else:
             return super().forward(*args, **kwargs)
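
To see why the weight slice behind autopad="causal_zero" is safe: with a single input frame, the leading causal zero frames only ever multiply the leading temporal taps of the kernel, so those taps can be dropped instead of computing on zeros. A standalone sanity check with assumed shapes (not the ComfyUI op itself):

import torch
import torch.nn.functional as F

w = torch.randn(4, 4, 3, 3, 3)   # (out_c, in_c, kT, kH, kW)
x = torch.randn(1, 4, 1, 8, 8)   # single frame, T == 1

# Slow path: materialise the 2 causal zero frames, run the full kernel.
y_full = F.conv3d(F.pad(x, (1, 1, 1, 1, 2, 0)), w)

# Fast path: keep only the last input.shape[2] temporal taps of the weight;
# the dropped taps would only have multiplied zeros.
y_fast = F.conv3d(x, w[:, :, -x.shape[2]:, :, :], padding=(0, 1, 1))

print(torch.allclose(y_full, y_fast, atol=1e-5))   # True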
