diff --git a/comfy/model_base.py b/comfy/model_base.py index 6b7978949f6d..4556ee138df5 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -106,10 +106,12 @@ class ModelSampling(s, c): return ModelSampling(model_config) -def convert_tensor(extra, dtype): +def convert_tensor(extra, dtype, device): if hasattr(extra, "dtype"): if extra.dtype != torch.int and extra.dtype != torch.long: - extra = extra.to(dtype) + extra = extra.to(dtype=dtype, device=device) + else: + extra = extra.to(device=device) return extra @@ -169,20 +171,21 @@ def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, tran dtype = self.manual_cast_dtype xc = xc.to(dtype) + device = xc.device t = self.model_sampling.timestep(t).float() if context is not None: - context = context.to(dtype) + context = context.to(dtype=dtype, device=device) extra_conds = {} for o in kwargs: extra = kwargs[o] if hasattr(extra, "dtype"): - extra = convert_tensor(extra, dtype) + extra = convert_tensor(extra, dtype, device) elif isinstance(extra, list): ex = [] for ext in extra: - ex.append(convert_tensor(ext, dtype)) + ex.append(convert_tensor(ext, dtype, device)) extra = ex extra_conds[o] = extra diff --git a/requirements.txt b/requirements.txt index 3828c5b91a4e..ffa7dce65028 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.23.4 -comfyui-workflow-templates==0.1.45 +comfyui-workflow-templates==0.1.47 comfyui-embedded-docs==0.2.4 torch torchsde