diff --git a/comfy/model_management.py b/comfy/model_management.py index 146c00925d3c..709ebc40b256 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -345,9 +345,9 @@ def amd_min_version(device=None, min_rdna_version=0): if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 ENABLE_PYTORCH_ATTENTION = True -# if torch_version_numeric >= (2, 8): -# if any((a in arch) for a in ["gfx1201"]): -# ENABLE_PYTORCH_ATTENTION = True + if rocm_version >= (7, 0): + if any((a in arch) for a in ["gfx1201"]): + ENABLE_PYTORCH_ATTENTION = True if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4): if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx942", "gfx950"]): # TODO: more arches SUPPORT_FP8_OPS = True diff --git a/comfy/ops.py b/comfy/ops.py index 2415c96bf486..b2096b40ee37 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -24,6 +24,8 @@ import comfy.rmsnorm import contextlib +def run_every_op(): + comfy.model_management.throw_exception_if_processing_interrupted() def scaled_dot_product_attention(q, k, v, *args, **kwargs): return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) @@ -109,6 +111,7 @@ def forward_comfy_cast_weights(self, input): return torch.nn.functional.linear(input, weight, bias) def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -123,6 +126,7 @@ def forward_comfy_cast_weights(self, input): return self._conv_forward(input, weight, bias) def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -137,6 +141,7 @@ def forward_comfy_cast_weights(self, input): return self._conv_forward(input, weight, bias) def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -151,6 +156,7 @@ def forward_comfy_cast_weights(self, input): return self._conv_forward(input, weight, bias) def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -165,6 +171,7 @@ def forward_comfy_cast_weights(self, input): return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -183,6 +190,7 @@ def forward_comfy_cast_weights(self, input): return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -202,6 +210,7 @@ def forward_comfy_cast_weights(self, input): # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps) def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -223,6 +232,7 @@ def forward_comfy_cast_weights(self, input, output_size=None): output_padding, self.groups, self.dilation) def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -244,6 +254,7 @@ def forward_comfy_cast_weights(self, input, output_size=None): output_padding, self.groups, self.dilation) def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -262,6 +273,7 @@ def forward_comfy_cast_weights(self, input, out_dtype=None): return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype) def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: diff --git a/comfy/sd.py b/comfy/sd.py index b9c2e995e759..28bee248dae1 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -276,8 +276,13 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None) if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) - self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower) - self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) + if model_management.is_amd(): + VAE_KL_MEM_RATIO = 2.73 + else: + VAE_KL_MEM_RATIO = 1.0 + + self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO #These are for AutoencoderKL and need tweaking (should be lower) + self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO self.downscale_ratio = 8 self.upscale_ratio = 8 self.latent_channels = 4 diff --git a/comfy/utils.py b/comfy/utils.py index fab28cf088c0..0fd03f165b7c 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -39,7 +39,11 @@ class ModelCheckpoint: pass ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint" - from numpy.core.multiarray import scalar + def scalar(*args, **kwargs): + from numpy.core.multiarray import scalar as sc + return sc(*args, **kwargs) + scalar.__module__ = "numpy.core.multiarray" + from numpy import dtype from numpy.dtypes import Float64DType from _codecs import encode diff --git a/comfyui_version.py b/comfyui_version.py index da5cde02d016..d39c1fdc46f5 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.64" +__version__ = "0.3.65" diff --git a/pyproject.toml b/pyproject.toml index 5dcc49a47cff..653604e24f0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.64" +version = "0.3.65" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9"