Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions comfy/ldm/wan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,15 @@ def forward(self, x, context, context_img_len):
}


def repeat_e(e, x):
    """Expand ``e`` along dim 1 so it can be combined with ``x``.

    When ``e`` has more than one entry along dim 1, each entry is repeated
    ``x.shape[1] // e.shape[1]`` times (consecutively, via
    ``torch.repeat_interleave``). When ``e`` has at most one entry along
    dim 1, or the computed factor is 1, ``e`` is returned unchanged (the
    same object, no copy).

    Args:
        e: Tensor with at least 2 dimensions.
        x: Tensor with at least 2 dimensions; only ``x.shape[1]`` is read.

    Returns:
        ``e`` itself, or a new tensor with dim 1 repeated.
    """
    e_len = e.shape[1]
    # A single entry along dim 1 is left alone; nothing to repeat.
    if e_len <= 1:
        return e

    # Integer division: any remainder in x.shape[1] is ignored here.
    factor = x.shape[1] // e_len
    return e if factor == 1 else torch.repeat_interleave(e, factor, dim=1)


class WanAttentionBlock(nn.Module):

def __init__(self,
Expand Down Expand Up @@ -201,6 +210,7 @@ def forward(
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
# assert e.dtype == torch.float32

if e.ndim < 4:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
else:
Expand All @@ -209,15 +219,15 @@ def forward(

# self-attention
y = self.self_attn(
self.norm1(x) * (1 + e[1]) + e[0],
self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
freqs)

x = x + y * e[2]
x = x + y * repeat_e(e[2], x)

# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
x = x + y * e[5]
y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
x = x + y * repeat_e(e[5], x)
return x


Expand Down Expand Up @@ -331,7 +341,8 @@ def forward(self, x, e):
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))

x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
return x


Expand Down
2 changes: 1 addition & 1 deletion comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,7 +1202,7 @@ def extra_conds(self, **kwargs):
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
if denoise_mask is None:
return timestep
temp_ts = (torch.mean(denoise_mask[:, :, :, ::2, ::2], dim=1, keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
temp_ts = (torch.mean(denoise_mask[:, :, :, :, :], dim=(1, 3, 4), keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
return temp_ts

def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
Expand Down
86 changes: 86 additions & 0 deletions comfy_api/generate_api_stubs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#!/usr/bin/env python3
"""
Script to generate .pyi stub files for the synchronous API wrappers.
This allows generating stubs without running the full ComfyUI application.
"""

import os
import sys
import logging
import importlib

# Add ComfyUI to path so we can import modules
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from comfy_api.internal.async_to_sync import AsyncToSyncConverter
from comfy_api.version_list import supported_versions


def generate_stubs_for_module(module_name: str) -> None:
    """Import ``module_name`` and write a stub file for its API classes.

    The module is expected to expose ``ComfyAPI`` and, optionally,
    ``ComfyAPISync``. If only ``ComfyAPI`` is present, a sync class is
    built with ``create_sync_class`` first. Modules exporting neither name
    (or ``ComfyAPISync`` without a truthy ``ComfyAPI``) are reported with a
    warning and skipped.

    Any exception raised along the way is logged and its traceback printed;
    it is not propagated, so one failing module does not stop the caller.
    """
    try:
        module = importlib.import_module(module_name)

        if hasattr(module, "ComfyAPISync"):
            # The module ships its own sync wrapper; pair it with the async API.
            api_class = getattr(module, "ComfyAPI", None)
            sync_class = getattr(module, "ComfyAPISync")
            if not api_class:
                logging.warning(
                    f"Module {module_name} has ComfyAPISync but no ComfyAPI"
                )
                return
        elif hasattr(module, "ComfyAPI"):
            # Async-only module: derive the sync wrapper before generating.
            from comfy_api.internal.async_to_sync import create_sync_class

            api_class = getattr(module, "ComfyAPI")
            sync_class = create_sync_class(api_class)
        else:
            logging.warning(
                f"Module {module_name} does not export ComfyAPI or ComfyAPISync"
            )
            return

        # Both successful paths end the same way.
        AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
        logging.info(f"Generated stub file for {module_name}")

    except Exception as exc:
        logging.error(f"Failed to generate stub for {module_name}: {exc}")
        import traceback

        traceback.print_exc()


def main():
    """Generate stub files for every module listed in ``supported_versions``.

    Configures logging at INFO level, collects the distinct defining modules
    of the classes in ``supported_versions`` (first-seen order), and runs
    ``generate_stubs_for_module`` on each.
    """
    logging.basicConfig(level=logging.INFO)

    logging.info("Starting stub generation...")

    # dict.fromkeys drops duplicate module names while keeping first-seen order.
    api_modules = list(
        dict.fromkeys(api_class.__module__ for api_class in supported_versions)
    )

    logging.info(f"Found {len(api_modules)} API modules: {api_modules}")

    for module_name in api_modules:
        generate_stubs_for_module(module_name)

    logging.info("Stub generation complete!")


if __name__ == "__main__":
main()
12 changes: 10 additions & 2 deletions comfy_api/input/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
from .basic_types import ImageInput, AudioInput
from .video_types import VideoInput
# This file only exists for backwards compatibility.
from comfy_api.latest._input import (
ImageInput,
AudioInput,
MaskInput,
LatentInput,
VideoInput,
)

__all__ = [
"ImageInput",
"AudioInput",
"MaskInput",
"LatentInput",
"VideoInput",
]
34 changes: 14 additions & 20 deletions comfy_api/input/basic_types.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
import torch
from typing import TypedDict

ImageInput = torch.Tensor
"""
An image in format [B, H, W, C] where B is the batch size, H is the height, W is the width, and C is the number of channels.
"""

class AudioInput(TypedDict):
    """
    TypedDict representing audio input: a waveform tensor plus its sample rate.
    """

    waveform: torch.Tensor
    """
    Tensor in the format [B, C, T] where B is the batch size, C is the number of channels,
    """
    # NOTE(review): the description above is cut off after "channels,"; T is
    # presumably the number of samples along the time axis — confirm.

    # Sample rate of `waveform`; presumably in Hz — confirm against callers.
    sample_rate: int

# This file only exists for backwards compatibility.
from comfy_api.latest._input.basic_types import (
ImageInput,
AudioInput,
MaskInput,
LatentInput,
)

__all__ = [
"ImageInput",
"AudioInput",
"MaskInput",
"LatentInput",
]
89 changes: 5 additions & 84 deletions comfy_api/input/video_types.py
Original file line number Diff line number Diff line change
@@ -1,85 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Union
import io
import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
# This file only exists for backwards compatibility.
from comfy_api.latest._input.video_types import VideoInput

class VideoInput(ABC):
    """
    Abstract base class for video input types.

    Subclasses must implement `get_components` and `save_to`. The remaining
    methods have default implementations built on those two, which subclasses
    may override with cheaper versions.
    """

    @abstractmethod
    def get_components(self) -> VideoComponents:
        """
        Abstract method to get the video components (images, audio, and frame rate).

        Returns:
            VideoComponents containing images, audio, and frame rate
        """
        pass

    @abstractmethod
    def save_to(
        self,
        path: Union[str, io.BytesIO],
        format: VideoContainer = VideoContainer.AUTO,
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
    ):
        """
        Abstract method to save the video input to a file.

        Args:
            path: Destination to write to. `get_stream_source` passes an
                in-memory `io.BytesIO` here, so implementations need to
                accept a buffer as well as a file path.
            format: Container format to write; defaults to `VideoContainer.AUTO`.
            codec: Video codec to use; defaults to `VideoCodec.AUTO`.
            metadata: Optional metadata dictionary (presumably embedded in the
                output — confirm against implementations).
        """
        pass

    def get_stream_source(self) -> Union[str, io.BytesIO]:
        """
        Get a streamable source for the video. This allows processing without
        loading the entire video into memory.

        Returns:
            Either a file path (str) or a BytesIO object that can be opened with av.

        Default implementation creates a BytesIO buffer, but subclasses should
        override this for better performance when possible.
        """
        buffer = io.BytesIO()
        # Serialize via save_to with its default format/codec arguments.
        self.save_to(buffer)
        # Rewind so the caller reads from the start of the written data.
        buffer.seek(0)
        return buffer

    # Provide a default implementation, but subclasses can provide optimized versions
    # if possible.
    def get_dimensions(self) -> tuple[int, int]:
        """
        Returns the dimensions of the video input.

        The default implementation calls `get_components()` and reads the
        shape of its `images`.

        Returns:
            Tuple of (width, height)
        """
        components = self.get_components()
        # Assumes images are laid out as [frames, height, width, ...] — TODO confirm.
        return components.images.shape[2], components.images.shape[1]

    def get_duration(self) -> float:
        """
        Returns the duration of the video in seconds.

        Computed as the size of the first dimension of `images` (taken as the
        frame count) divided by `frame_rate`.

        Returns:
            Duration in seconds
        """
        components = self.get_components()
        frame_count = components.images.shape[0]
        return float(frame_count / components.frame_rate)

    def get_container_format(self) -> str:
        """
        Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').

        Returns:
            Container format as string
        """
        # Default implementation - subclasses should override for better performance
        source = self.get_stream_source()
        # Open the source with av just to read the container's format name.
        with av.open(source, mode="r") as container:
            return container.format.name
__all__ = [
"VideoInput",
]
4 changes: 2 additions & 2 deletions comfy_api/input_impl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .video_types import VideoFromFile, VideoFromComponents
# This file only exists for backwards compatibility.
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents

__all__ = [
# Implementations
"VideoFromFile",
"VideoFromComponents",
]
Loading
Loading