@@ -52,6 +52,16 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
 except (ModuleNotFoundError, TypeError):
     logging.warning("Could not set sdpa backend priority.")

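+# Flag for working around an NVIDIA cuDNN conv3d memory bug; only enabled on
+# NVIDIA GPUs with cuDNN >= 91300 (9.13.0) and PyTorch 2.9-2.10 (see the version gate below).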
+NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False
+try:
+    if comfy.model_management.is_nvidia():
+        if torch.backends.cudnn.version() >= 91300 and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
+            #TODO: change upper bound version once it's fixed
+            NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True
+            logging.info("working around nvidia conv3d memory bug.")
+except:
+    pass
+
 cast_to = comfy.model_management.cast_to #TODO: remove once no more references

 if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
@@ -151,6 +161,15 @@ class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
     def reset_parameters(self):
         return None

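+    # With the workaround active, fp16/bf16 convolutions bypass the default
+    # _conv_forward path and call torch.cudnn_convolution directly.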
+    def _conv_forward(self, input, weight, bias, *args, **kwargs):
+        if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
+            out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
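+            # torch.cudnn_convolution takes no bias argument, so the bias is
+            # broadcast-added over the channel dimension afterwards.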
+            if bias is not None:
+                out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))
+            return out
+        else:
+            return super()._conv_forward(input, weight, bias, *args, **kwargs)
+
     def forward_comfy_cast_weights(self, input):
         weight, bias = cast_bias_weight(self, input)
         return self._conv_forward(input, weight, bias)