Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions comfy/k_diffusion/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,6 +853,11 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
return x


@torch.no_grad()
def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)


@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""DPM-Solver++(3M) SDE."""
Expand Down Expand Up @@ -925,6 +930,16 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)


@torch.no_grad()
def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)


@torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1:
Expand Down
2 changes: 2 additions & 0 deletions comfy/ldm/wan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1278,6 +1278,7 @@ def forward_orig(
x = torch.cat([x, ref], dim=1)
freqs = torch.cat([freqs, freqs_ref], dim=1)
t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1)
del ref, freqs_ref

if reference_motion is not None:
motion_encoded, freqs_motion = self.frame_packer(reference_motion, self)
Expand All @@ -1287,6 +1288,7 @@ def forward_orig(

t = torch.repeat_interleave(t, 2, dim=1)
t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1)
del motion_encoded, freqs_motion

# time embeddings
e = self.time_embedding(
Expand Down
25 changes: 23 additions & 2 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
logging.debug("adm {}".format(self.adm_channels))
self.memory_usage_factor = model_config.memory_usage_factor
self.memory_usage_factor_conds = ()
self.memory_usage_shape_process = {}

def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
Expand Down Expand Up @@ -350,8 +351,15 @@ def memory_required(self, input_shape, cond_shapes={}):
input_shapes = [input_shape]
for c in self.memory_usage_factor_conds:
shape = cond_shapes.get(c, None)
if shape is not None and len(shape) > 0:
input_shapes += shape
if shape is not None:
if c in self.memory_usage_shape_process:
out = []
for s in shape:
out.append(self.memory_usage_shape_process[c](s))
shape = out

if len(shape) > 0:
input_shapes += shape

if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
Expand Down Expand Up @@ -1204,6 +1212,8 @@ def extra_conds(self, **kwargs):
class WAN22_S2V(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
self.memory_usage_factor_conds = ("reference_latent", "reference_motion")
self.memory_usage_shape_process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}

def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
Expand All @@ -1224,6 +1234,17 @@ def extra_conds(self, **kwargs):
out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video))
return out

def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])

reference_motion = kwargs.get("reference_motion", None)
if reference_motion is not None:
out['reference_motion'] = reference_motion.shape
return out

class WAN22(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
Expand Down
2 changes: 1 addition & 1 deletion comfy/samplers.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ def max_denoise(self, model_wrap, sigmas):

KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]

Expand Down
2 changes: 1 addition & 1 deletion comfy/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ class Flux(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.Flux

memory_usage_factor = 2.8
memory_usage_factor = 3.1 # TODO: debug why flux mem usage is so weird on windows.

supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

Expand Down
33 changes: 33 additions & 0 deletions comfy_extras/nodes_latent.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,38 @@ def op(self, samples1, samples2, ratio):
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,)

class LatentConcat:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",), "dim": (["x", "-x", "y", "-y", "t", "-t"], )}}

RETURN_TYPES = ("LATENT",)
FUNCTION = "op"

CATEGORY = "latent/advanced"

def op(self, samples1, samples2, dim):
samples_out = samples1.copy()

s1 = samples1["samples"]
s2 = samples2["samples"]
s2 = comfy.utils.repeat_to_batch_size(s2, s1.shape[0])

if "-" in dim:
c = (s2, s1)
else:
c = (s1, s2)

if "x" in dim:
dim = -1
elif "y" in dim:
dim = -2
elif "t" in dim:
dim = -3

samples_out["samples"] = torch.cat(c, dim=dim)
return (samples_out,)

class LatentBatch:
@classmethod
def INPUT_TYPES(s):
Expand Down Expand Up @@ -279,6 +311,7 @@ def sharpen(latent, **kwargs):
"LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate,
"LatentConcat": LatentConcat,
"LatentBatch": LatentBatch,
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
"LatentApplyOperation": LatentApplyOperation,
Expand Down
2 changes: 1 addition & 1 deletion comfyui_version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.52"
__version__ = "0.3.53"
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.52"
version = "0.3.53"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"
Expand Down
Loading