Skip to content

Commit 19b4661

Browse files
Workaround for nvidia issue where VAE uses 3x more memory on torch 2.9 (Comfy-Org#10373)
1 parent bc0ad9b commit 19b4661

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

comfy/ops.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,16 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
5252
except (ModuleNotFoundError, TypeError):
5353
logging.warning("Could not set sdpa backend priority.")
5454

55+
# Workaround for an NVIDIA/cuDNN regression where fp16/bf16 Conv3d allocates
# ~3x the expected memory on torch 2.9 (hits VAE decode memory use; see the
# commit message).  When the flag is set, Conv3d._conv_forward routes the
# convolution through torch.cudnn_convolution directly.
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False
try:
    if comfy.model_management.is_nvidia():
        # Affected combo: cuDNN >= 9.13 with torch 2.9.x.
        # TODO: change upper bound version once it's fixed
        if torch.backends.cudnn.version() >= 91300 and (2, 9) <= comfy.model_management.torch_version_numeric <= (2, 10):
            NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True
            logging.info("working around nvidia conv3d memory bug.")
except Exception:
    # Best effort: version probing must never prevent this module from importing.
    # (Was a bare `except:`, which would also swallow KeyboardInterrupt/SystemExit.)
    pass
64+
5565
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
5666

5767
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
@@ -151,6 +161,15 @@ class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
151161
def reset_parameters(self):
    # Deliberately skip torch's default parameter initialization; parameters
    # are presumably populated elsewhere (e.g. loaded from a checkpoint) --
    # NOTE(review): confirm against the model loading path.
    return None
153163

164+
def _conv_forward(self, input, weight, bias, *args, **kwargs):
    """Conv3d forward that can bypass the default cuDNN dispatch.

    When NVIDIA_MEMORY_CONV_BUG_WORKAROUND is set and the weight is
    fp16/bf16, call torch.cudnn_convolution directly and add the bias by
    hand; otherwise defer to torch.nn.Conv3d's implementation.
    """
    bug_path = NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16)
    if not bug_path:
        return super()._conv_forward(input, weight, bias, *args, **kwargs)

    out = torch.cudnn_convolution(input, weight, self.padding, self.stride,
                                  self.dilation, self.groups, benchmark=False,
                                  deterministic=False, allow_tf32=True)
    if bias is not None:
        # Broadcast bias over every non-channel dimension; += keeps it in place.
        out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))
    return out
172+
154173
def forward_comfy_cast_weights(self, input):
    """Run the convolution with weight/bias cast for this input via cast_bias_weight."""
    w, b = cast_bias_weight(self, input)
    return self._conv_forward(input, w, b)

0 commit comments

Comments
 (0)