Merged
43 changes: 27 additions & 16 deletions comfy/ldm/lightricks/vae/causal_conv3d.py
@@ -1,11 +1,11 @@
from typing import Tuple, Union

import threading
import torch
import torch.nn as nn
import comfy.ops
ops = comfy.ops.disable_weight_init


class CausalConv3d(nn.Module):
    def __init__(
        self,
@@ -42,23 +42,34 @@ def __init__(
            padding_mode=spatial_padding_mode,
            groups=groups,
        )
        self.temporal_cache_state = {}

    def forward(self, x, causal: bool = True):
        if causal:
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, self.time_kernel_size - 1, 1, 1)
            )
            x = torch.concatenate((first_frame_pad, x), dim=2)
        else:
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
            )
            last_frame_pad = x[:, :, -1:, :, :].repeat(
                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
            )
            x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
        x = self.conv(x)
        return x
        tid = threading.get_ident()

        cached, is_end = self.temporal_cache_state.get(tid, (None, False))
        if cached is None:
            padding_length = self.time_kernel_size - 1
            if not causal:
                padding_length = padding_length // 2
            if x.shape[2] == 0:
                return x
            cached = x[:, :, :1, :, :].repeat((1, 1, padding_length, 1, 1))
        pieces = [cached, x]
        if is_end and not causal:
            pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))

        needs_caching = not is_end
        if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
            needs_caching = False
            self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)

        x = torch.cat(pieces, dim=2)

        if needs_caching:
            self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)

        return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]

    @property
    def weight(self):
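
The new forward above trades the old all-at-once padding for a per-thread temporal cache, so a long video can be pushed through in chunks. A minimal sketch of that idea, assuming a plain nn.Conv3d with no temporal padding (StreamingCausalConv3d and every shape below are illustrative, not from the PR): the concatenated chunk outputs match a single full pass.

import threading
import torch
import torch.nn as nn

class StreamingCausalConv3d(nn.Module):
    def __init__(self, channels, time_kernel_size=3):
        super().__init__()
        self.time_kernel_size = time_kernel_size
        # No temporal padding inside the conv; the lead-in frames come from the cache.
        self.conv = nn.Conv3d(channels, channels, (time_kernel_size, 1, 1))
        self.temporal_cache_state = {}

    def forward(self, x):
        tid = threading.get_ident()
        cached, _ = self.temporal_cache_state.get(tid, (None, False))
        if cached is None:
            # First chunk: replicate the first frame, like the old causal padding did.
            cached = x[:, :, :1].repeat(1, 1, self.time_kernel_size - 1, 1, 1)
        x = torch.cat([cached, x], dim=2)
        # Remember the trailing frames as the lead-in for the next chunk.
        self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):], False)
        return self.conv(x)

conv = StreamingCausalConv3d(4).eval()
video = torch.randn(1, 4, 9, 8, 8)
with torch.no_grad():
    full = conv(video)                       # one pass over all 9 frames
    conv.temporal_cache_state.clear()        # reset the stream
    chunked = torch.cat([conv(video[:, :, :5]), conv(video[:, :, 5:])], dim=2)
print(torch.allclose(full, chunked, atol=1e-6))  # True
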
176 changes: 129 additions & 47 deletions comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -1,17 +1,41 @@
from __future__ import annotations
import threading
import torch
from torch import nn
from functools import partial
import math
from einops import rearrange
from typing import List, Optional, Tuple, Union
from .conv_nd_factory import make_conv_nd, make_linear_nd
from .causal_conv3d import CausalConv3d
from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed

ops = comfy.ops.disable_weight_init

def mark_conv3d_ended(module):
    tid = threading.get_ident()
    for _, m in module.named_modules():
        if isinstance(m, CausalConv3d):
            current = m.temporal_cache_state.get(tid, (None, False))
            m.temporal_cache_state[tid] = (current[0], True)

def split2(tensor, split_point, dim=2):
    return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim)

def add_exchange_cache(dest, cache_in, new_input, dim=2):
    if dest is not None:
        if cache_in is not None:
            cache_to_dest = min(dest.shape[dim], cache_in.shape[dim])
            lead_in_dest, dest = split2(dest, cache_to_dest, dim=dim)
            lead_in_source, cache_in = split2(cache_in, cache_to_dest, dim=dim)
            lead_in_dest.add_(lead_in_source)
        body, new_input = split2(new_input, dest.shape[dim], dim)
        dest.add_(body)
    return torch_cat_if_needed([cache_in, new_input], dim=dim)

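split2 and add_exchange_cache implement the overlap-add bookkeeping for residual connections across chunk boundaries: the slice of new_input that already has matching dest frames is added in place, a previous cache is folded into dest's head, and the leftover frames come back as the next cache. A small worked example, assuming the three helpers above (including the imported torch_cat_if_needed) are in scope; the shapes are made up:

import torch

# Branch output covers 3 frames; the residual input already has 5 frames available.
dest = torch.zeros(1, 1, 3, 1, 1)
residual = torch.arange(5.0).view(1, 1, 5, 1, 1)
cache = add_exchange_cache(dest, None, residual)
print(dest.flatten())   # tensor([0., 1., 2.]) -> first 3 residual frames added in place
print(cache.flatten())  # tensor([3., 4.])     -> leftover carried to the next chunk

# Next chunk: 4 more frames of branch output arrive; the cache is folded into its head.
dest2 = torch.zeros(1, 1, 4, 1, 1)
new_residual = torch.arange(5.0, 9.0).view(1, 1, 4, 1, 1)
cache = add_exchange_cache(dest2, cache, new_residual)
print(dest2.flatten())  # tensor([3., 4., 5., 6.])
print(cache.flatten())  # tensor([7., 8.])
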
class Encoder(nn.Module):
r"""
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
@@ -205,7 +229,7 @@ def __init__(

        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
    def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class."""

        sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
@@ -254,6 +278,22 @@ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:

        return sample

    def forward(self, *args, **kwargs):
        # No encoder support, so just flag the end so it doesn't use the cache.
        mark_conv3d_ended(self)
        try:
            return self.forward_orig(*args, **kwargs)
        finally:
            for _, module in self.named_modules():
                # ComfyUI doesn't thread this kind of stuff today, but just in case
                # we key on the thread to make it thread safe.
                tid = threading.get_ident()
                if hasattr(module, "temporal_cache_state"):
                    module.temporal_cache_state.pop(tid, None)


MAX_CHUNK_SIZE = 128 * 1024 ** 2  # byte budget per activation chunk (128 MiB)

class Decoder(nn.Module):
r"""
@@ -341,18 +381,6 @@ def __init__(
                    timestep_conditioning=timestep_conditioning,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "attn_res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                    inject_noise=block_params.get("inject_noise", False),
                    timestep_conditioning=timestep_conditioning,
                    attention_head_dim=block_params["attention_head_dim"],
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "res_x_y":
                output_channel = output_channel // block_params.get("multiplier", 2)
                block = ResnetBlock3D(
@@ -428,15 +456,17 @@ def __init__(
        )
        self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))


    # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
    def forward(
    def forward_orig(
        self,
        sample: torch.FloatTensor,
        timestep: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        r"""The forward method of the `Decoder` class."""
        batch_size = sample.shape[0]

        mark_conv3d_ended(self.conv_in)
        sample = self.conv_in(sample, causal=self.causal)

        checkpoint_fn = (
@@ -445,24 +475,12 @@ def forward(
            else lambda x: x
        )

        scaled_timestep = None
        timestep_shift_scale = None
        if self.timestep_conditioning:
            assert (
                timestep is not None
            ), "should pass timestep with timestep_conditioning=True"
            scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)

        for up_block in self.up_blocks:
            if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
                sample = checkpoint_fn(up_block)(
                    sample, causal=self.causal, timestep=scaled_timestep
                )
            else:
                sample = checkpoint_fn(up_block)(sample, causal=self.causal)

        sample = self.conv_norm_out(sample)

        if self.timestep_conditioning:
            embedded_timestep = self.last_time_embedder(
                timestep=scaled_timestep.flatten(),
                resolution=None,
@@ -483,16 +501,62 @@
                embedded_timestep.shape[-2],
                embedded_timestep.shape[-1],
            )
            shift, scale = ada_values.unbind(dim=1)
            sample = sample * (1 + scale) + shift
            timestep_shift_scale = ada_values.unbind(dim=1)

        output = []

        def run_up(idx, sample, ended):
            if idx >= len(self.up_blocks):
                sample = self.conv_norm_out(sample)
                if timestep_shift_scale is not None:
                    shift, scale = timestep_shift_scale
                    sample = sample * (1 + scale) + shift
                sample = self.conv_act(sample)
                if ended:
                    mark_conv3d_ended(self.conv_out)
                sample = self.conv_out(sample, causal=self.causal)
                if sample is not None and sample.shape[2] > 0:
                    output.append(sample)
                return

            up_block = self.up_blocks[idx]
            if ended:
                mark_conv3d_ended(up_block)
            if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
                sample = checkpoint_fn(up_block)(
                    sample, causal=self.causal, timestep=scaled_timestep
                )
            else:
                sample = checkpoint_fn(up_block)(sample, causal=self.causal)

        sample = self.conv_act(sample)
        sample = self.conv_out(sample, causal=self.causal)
            if sample is None or sample.shape[2] == 0:
                return

            total_bytes = sample.numel() * sample.element_size()
            num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
            samples = torch.chunk(sample, chunks=num_chunks, dim=2)

            for chunk_idx, sample1 in enumerate(samples):
                run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)

        run_up(0, sample, True)
        sample = torch.cat(output, dim=2)

        sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)

        return sample

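forward_orig now drives decoding depth-first: run_up pushes one chunk through one up block, splits the output so no intermediate exceeds MAX_CHUNK_SIZE bytes, and recurses into the next block before the sibling chunk is produced. A hedged, stage-agnostic sketch of that pattern (the lambda stages and shapes are stand-ins, not the real up_blocks):

import torch

MAX_CHUNK_SIZE = 128 * 1024 ** 2  # 128 MiB, as defined above

stages = [lambda x: x * 2, lambda x: x + 1]  # stand-ins for self.up_blocks
output = []

def run_up(idx, sample):
    if idx >= len(stages):          # past the last stage: collect the finished chunk
        output.append(sample)
        return
    sample = stages[idx](sample)
    total_bytes = sample.numel() * sample.element_size()
    num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE  # ceil division
    for chunk in torch.chunk(sample, chunks=num_chunks, dim=2):        # split on time
        run_up(idx + 1, chunk)      # go deeper before the sibling chunk is produced

run_up(0, torch.ones(1, 4, 8, 16, 16))
print(torch.cat(output, dim=2).shape)  # torch.Size([1, 4, 8, 16, 16]), all values 3.0
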
    def forward(self, *args, **kwargs):
        try:
            return self.forward_orig(*args, **kwargs)
        finally:
            for _, module in self.named_modules():
                # ComfyUI doesn't thread this kind of stuff today, but just in case
                # we key on the thread to make it thread safe.
                tid = threading.get_ident()
                if hasattr(module, "temporal_cache_state"):
                    module.temporal_cache_state.pop(tid, None)

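Both wrappers share one cleanup contract: whatever happens inside forward_orig, the finally block walks every submodule and drops the current thread's cache entry, so a failed or finished call never leaks lead-in frames into the next one. A toy version of that scan (ToyCached is hypothetical):

import threading
import torch.nn as nn

class ToyCached(nn.Module):
    def __init__(self):
        super().__init__()
        self.temporal_cache_state = {}

model = nn.Sequential(ToyCached(), nn.Identity())
model[0].temporal_cache_state[threading.get_ident()] = ("stale lead-in", False)

for _, module in model.named_modules():
    if hasattr(module, "temporal_cache_state"):
        module.temporal_cache_state.pop(threading.get_ident(), None)

print(model[0].temporal_cache_state)  # {} -> no stale state survives the call
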

class UNetMidBlock3D(nn.Module):
"""
@@ -663,8 +727,22 @@ def __init__(
        )
        self.residual = residual
        self.out_channels_reduction_factor = out_channels_reduction_factor
        self.temporal_cache_state = {}

    def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
        tid = threading.get_ident()
        cached, drop_first_conv, drop_first_res = self.temporal_cache_state.get(tid, (None, True, True))
        y = self.conv(x, causal=causal)
        y = rearrange(
            y,
            "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
            p1=self.stride[0],
            p2=self.stride[1],
            p3=self.stride[2],
        )
        if self.stride[0] == 2 and y.shape[2] > 0 and drop_first_conv:
            y = y[:, :, 1:, :, :]
            drop_first_conv = False
        if self.residual:
            # Reshape and duplicate the input to match the output shape
            x_in = rearrange(
@@ -676,21 +754,20 @@ def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
            )
            num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
            x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
            if self.stride[0] == 2:
            if self.stride[0] == 2 and x_in.shape[2] > 0 and drop_first_res:
                x_in = x_in[:, :, 1:, :, :]
        x = self.conv(x, causal=causal)
        x = rearrange(
            x,
            "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
            p1=self.stride[0],
            p2=self.stride[1],
            p3=self.stride[2],
        )
        if self.stride[0] == 2:
            x = x[:, :, 1:, :, :]
        if self.residual:
            x = x + x_in
        return x
                drop_first_res = False

            if y.shape[2] == 0:
                y = None

            cached = add_exchange_cache(y, cached, x_in, dim=2)
            self.temporal_cache_state[tid] = (cached, drop_first_conv, drop_first_res)

        else:
            self.temporal_cache_state[tid] = (None, drop_first_conv, False)

        return y

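The drop_first_conv / drop_first_res flags exist because a stride-2 temporal upsample emits one surplus frame at the very start of the stream, and it must be dropped exactly once rather than once per chunk. A toy upsampler (repeat_interleave standing in for the conv plus depth-to-space; not the PR's code) shows why the flag has to persist across chunks:

import torch

def upsample(x, drop_first):
    y = x.repeat_interleave(2, dim=2)  # (b, c, 2t, h, w): double every frame
    if drop_first and y.shape[2] > 0:
        y = y[:, :, 1:]                # trim the surplus leading frame
    return y

video = torch.arange(4.0).view(1, 1, 4, 1, 1)
full = upsample(video, drop_first=True)             # 7 frames
first = upsample(video[:, :, :2], drop_first=True)  # drop happens here only
rest = upsample(video[:, :, 2:], drop_first=False)  # later chunks keep all frames
print(torch.equal(full, torch.cat([first, rest], dim=2)))  # True
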
class LayerNorm(nn.Module):
    def __init__(self, dim, eps, elementwise_affine=True) -> None:
@@ -807,6 +884,8 @@ def __init__(
            torch.randn(4, in_channels) / in_channels**0.5
        )

        self.temporal_cache_state = {}

    def _feed_spatial_noise(
        self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
    ) -> torch.FloatTensor:
@@ -880,9 +959,12 @@ def forward(

            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states
        tid = threading.get_ident()
        cached = self.temporal_cache_state.get(tid, None)
        cached = add_exchange_cache(hidden_states, cached, input_tensor, dim=2)
        self.temporal_cache_state[tid] = cached

        return output_tensor
        return hidden_states


def patchify(x, patch_size_hw, patch_size_t=1):
5 changes: 4 additions & 1 deletion comfy/ldm/modules/diffusionmodules/model.py
@@ -14,10 +14,13 @@
import xformers.ops

def torch_cat_if_needed(xl, dim):
    xl = [x for x in xl if x is not None and x.shape[dim] > 0]
    if len(xl) > 1:
        return torch.cat(xl, dim)
    else:
    elif len(xl) == 1:
        return xl[0]
    else:
        return None

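torch_cat_if_needed now filters out None entries and zero-length tensors, returns a lone survivor without copying, and yields None when nothing is left, which is what lets the cache helpers pass "maybe empty" values around. A quick check, assuming the function above is in scope:

import torch

a = torch.ones(1, 2)
empty = torch.ones(1, 0)
print(torch_cat_if_needed([None, a, empty], dim=1).shape)  # torch.Size([1, 2])
print(torch_cat_if_needed([a, a], dim=1).shape)            # torch.Size([1, 4])
print(torch_cat_if_needed([None, empty], dim=1))           # None
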
def get_timestep_embedding(timesteps, embedding_dim):
"""
2 changes: 1 addition & 1 deletion comfy_extras/nodes_load_3d.py
@@ -24,7 +24,7 @@ def define_schema(cls):
        files = [
            normalize_path(str(file_path.relative_to(base_path)))
            for file_path in input_path.rglob("*")
            if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
            if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl', '.spz', '.splat', '.ply', '.ksplat'}
        ]
        return IO.Schema(
            node_id="Load3D",
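
The Load3D change simply widens the allowed suffixes to cover Gaussian-splat formats (.spz, .splat, .ply, .ksplat). The same case-insensitive check outside the node, with made-up file names:

from pathlib import Path

SUPPORTED = {'.gltf', '.glb', '.obj', '.fbx', '.stl', '.spz', '.splat', '.ply', '.ksplat'}

for name in ["scene.glb", "points.PLY", "splat.ksplat", "notes.txt"]:
    p = Path(name)
    print(p.name, p.suffix.lower() in SUPPORTED)
# scene.glb True / points.PLY True / splat.ksplat True / notes.txt False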