diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index d9e76922f527..51b6d1da8325 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -3,8 +3,8 @@ import torch.nn.functional as F from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm -import model_management -import model_patcher +import comfy.model_management +import comfy.model_patcher class SRResidualCausalBlock3D(nn.Module): def __init__(self, channels: int): @@ -103,13 +103,13 @@ def forward(self, z): class HunyuanVideo15SRModel(): def __init__(self, model_type, config): - self.load_device = model_management.vae_device() - offload_device = model_management.vae_offload_device() - self.dtype = model_management.vae_dtype(self.load_device) + self.load_device = comfy.model_management.vae_device() + offload_device = comfy.model_management.vae_offload_device() + self.dtype = comfy.model_management.vae_dtype(self.load_device) self.model_class = UPSAMPLERS.get(model_type) self.model = self.model_class(**config).eval() - self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): return self.model.load_state_dict(sd, strict=True) @@ -118,5 +118,5 @@ def get_sd(self): return self.model.state_dict() def resample_latent(self, latent): - model_management.load_model_gpu(self.patcher) + comfy.model_management.load_model_gpu(self.patcher) return self.model(latent.to(self.load_device)) diff --git a/comfy/model_management.py b/comfy/model_management.py index 928282092573..e5de4a5b5238 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -22,7 +22,6 @@ from comfy.cli_args import args, PerformanceFeature import torch import sys -import importlib import platform import weakref import gc @@ -349,10 +348,22 @@ def amd_min_version(device=None, min_rdna_version=0): except: rocm_version = (6, -1) + def aotriton_supported(gpu_arch): + path = torch.__path__[0] + path = os.path.join(os.path.join(path, "lib"), "aotriton.images") + gfx = set(map(lambda a: a[4:], filter(lambda a: a.startswith("amd-gfx"), os.listdir(path)))) + if gpu_arch in gfx: + return True + if "{}x".format(gpu_arch[:-1]) in gfx: + return True + if "{}xx".format(gpu_arch[:-2]) in gfx: + return True + return False + logging.info("AMD arch: {}".format(arch)) logging.info("ROCm version: {}".format(rocm_version)) if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: - if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not. + if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton. if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 ENABLE_PYTORCH_ATTENTION = True