27 changes: 14 additions & 13 deletions CODEOWNERS
@@ -5,20 +5,21 @@
 # Inlined the team members for now.

 # Maintainers
-*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
+/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
+/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
+/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
+/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
+/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
+/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
+/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill

 # Python web server
-/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
-/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
-/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
+/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
+/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
+/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill

 # Node developers
-/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
-/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
+/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
+/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
+/comfy_api_nodes/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill

29 changes: 21 additions & 8 deletions comfy/ldm/qwen_image/model.py
@@ -356,6 +356,7 @@ def forward(
         context,
         attention_mask=None,
         guidance: torch.Tensor = None,
+        transformer_options={},
         **kwargs
     ):
         timestep = timesteps
@@ -383,14 +384,26 @@
             else self.time_text_embed(timestep, guidance, hidden_states)
         )

-        for block in self.transformer_blocks:
-            encoder_hidden_states, hidden_states = block(
-                hidden_states=hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_hidden_states_mask=encoder_hidden_states_mask,
-                temb=temb,
-                image_rotary_emb=image_rotary_emb,
-            )
+        patches_replace = transformer_options.get("patches_replace", {})
+        blocks_replace = patches_replace.get("dit", {})
+
+        for i, block in enumerate(self.transformer_blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
+                    return out
+                out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
+                hidden_states = out["img"]
+                encoder_hidden_states = out["txt"]
+            else:
+                encoder_hidden_states, hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_hidden_states_mask=encoder_hidden_states_mask,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                )

         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
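
Note: this wires Qwen-Image's DiT loop into ComfyUI's existing patches_replace mechanism. When transformer_options["patches_replace"]["dit"] contains a ("double_block", i) key, the registered callback runs instead of the block itself, receiving the stream dict plus the original block under "original_block". A minimal sketch of a patch that could be registered this way (the callback body and block index are hypothetical, for illustration only):

def scale_img_patch(args, extra):
    # args = {"img": hidden_states, "txt": encoder_hidden_states,
    #         "vec": temb, "pe": image_rotary_emb}, as built in forward()
    out = extra["original_block"](args)  # run the wrapped transformer block
    out["img"] = out["img"] * 1.05       # hypothetical tweak to the image stream
    return out

transformer_options = {
    "patches_replace": {"dit": {("double_block", 3): scale_img_patch}},
}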

27 changes: 15 additions & 12 deletions comfy/ops.py
@@ -32,18 +32,21 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
 try:
     if torch.cuda.is_available():
         from torch.nn.attention import SDPBackend, sdpa_kernel
-
-        SDPA_BACKEND_PRIORITY = [
-            SDPBackend.FLASH_ATTENTION,
-            SDPBackend.EFFICIENT_ATTENTION,
-            SDPBackend.MATH,
-        ]
-
-        SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
-
-        @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True)
-        def scaled_dot_product_attention(q, k, v, *args, **kwargs):
-            return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+        import inspect
+        if "set_priority" in inspect.signature(sdpa_kernel).parameters:
+            SDPA_BACKEND_PRIORITY = [
+                SDPBackend.FLASH_ATTENTION,
+                SDPBackend.EFFICIENT_ATTENTION,
+                SDPBackend.MATH,
+            ]
+
+            SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
+
+            def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+                with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
+                    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+        else:
+            logging.warning("Torch version too old to set sdpa backend priority.")
 except (ModuleNotFoundError, TypeError):
     logging.warning("Could not set sdpa backend priority.")
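
Note: rather than pinning a torch version, the new guard feature-detects the set_priority keyword on sdpa_kernel via inspect.signature. The same pattern as a standalone sketch (the helper name is illustrative, not from the PR):

import inspect

def accepts_kwarg(fn, name: str) -> bool:
    # True if fn's signature declares a parameter called `name`.
    try:
        sig = inspect.signature(fn)
    except (TypeError, ValueError):
        # Some builtins / C extensions expose no introspectable signature.
        return False
    return name in sig.parameters

# Usage in the same spirit as the diff:
# if accepts_kwarg(sdpa_kernel, "set_priority"): use it; else warn and fall back.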

2 changes: 2 additions & 0 deletions comfy/rmsnorm.py
@@ -1,6 +1,7 @@
 import torch
 import comfy.model_management
 import numbers
+import logging

 RMSNorm = None

@@ -9,6 +10,7 @@
     RMSNorm = torch.nn.RMSNorm
 except:
     rms_norm_torch = None
+    logging.warning("Please update pytorch to use native RMSNorm")


 def rms_norm(x, weight=None, eps=1e-6):
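
Note: the new warning fires on torch builds without torch.nn.RMSNorm, where the module falls back to its own rms_norm. For reference, the usual RMSNorm formulation looks like this (a sketch, not necessarily the exact fallback implementation in comfy/rmsnorm.py):

import torch

def rms_norm_reference(x, weight=None, eps=1e-6):
    # Normalize by the root mean square over the last dimension,
    # then apply the optional learned scale.
    rrms = torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps)
    x = x * rrms
    return x if weight is None else x * weight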

128 changes: 91 additions & 37 deletions comfy_api_nodes/nodes_moonvalley.py
@@ -1,6 +1,5 @@
 import logging
 from typing import Any, Callable, Optional, TypeVar
-import random
 import torch
 from comfy_api_nodes.util.validation_utils import (
     get_image_dimensions,
@@ -208,20 +207,29 @@ def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
 def _validate_video_dimensions(width: int, height: int) -> None:
     """Validates video dimensions meet Moonvalley V2V requirements."""
     supported_resolutions = {
-        (1920, 1080), (1080, 1920), (1152, 1152),
-        (1536, 1152), (1152, 1536)
+        (1920, 1080),
+        (1080, 1920),
+        (1152, 1152),
+        (1536, 1152),
+        (1152, 1536),
     }

     if (width, height) not in supported_resolutions:
-        supported_list = ', '.join([f'{w}x{h}' for w, h in sorted(supported_resolutions)])
-        raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
+        supported_list = ", ".join(
+            [f"{w}x{h}" for w, h in sorted(supported_resolutions)]
+        )
+        raise ValueError(
+            f"Resolution {width}x{height} not supported. Supported: {supported_list}"
+        )


 def _validate_container_format(video: VideoInput) -> None:
     """Validates video container format is MP4."""
     container_format = video.get_container_format()
-    if container_format not in ['mp4', 'mov,mp4,m4a,3gp,3g2,mj2']:
-        raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
+    if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
+        raise ValueError(
+            f"Only MP4 container format supported. Got: {container_format}"
+        )


 def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
@@ -244,7 +252,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
     return video


-
 def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
     """
     Returns a new VideoInput object trimmed from the beginning to the specified duration,
@@ -302,7 +309,9 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
         # Calculate target frame count that's divisible by 16
         fps = input_container.streams.video[0].average_rate
         estimated_frames = int(duration_sec * fps)
-        target_frames = (estimated_frames // 16) * 16  # Round down to nearest multiple of 16
+        target_frames = (
+            estimated_frames // 16
+        ) * 16  # Round down to nearest multiple of 16

         if target_frames == 0:
             raise ValueError("Video too short: need at least 16 frames for Moonvalley")
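
Note: the target-frame computation above rounds down to a multiple of 16. A quick worked example with made-up numbers:

fps = 24
duration_sec = 5.3
estimated_frames = int(duration_sec * fps)     # int(127.2) -> 127
target_frames = (estimated_frames // 16) * 16  # (127 // 16) * 16 -> 112
# 112 is the largest multiple of 16 that fits within the trimmed duration.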
@@ -424,7 +433,7 @@ def INPUT_TYPES(cls):
                 MoonvalleyTextToVideoInferenceParams,
                 "negative_prompt",
                 multiline=True,
-                default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts",
+                default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring",
             ),
             "resolution": (
                 IO.COMBO,
@@ -441,12 +450,11 @@ def INPUT_TYPES(cls):
                     "tooltip": "Resolution of the output video",
                 },
             ),
-            # "length": (IO.COMBO,{"options":['5s','10s'], "default": '5s'}),
             "prompt_adherence": model_field_to_node_input(
                 IO.FLOAT,
                 MoonvalleyTextToVideoInferenceParams,
                 "guidance_scale",
-                default=7.0,
+                default=10.0,
                 step=1,
                 min=1,
                 max=20,
@@ -455,13 +463,12 @@
                 IO.INT,
                 MoonvalleyTextToVideoInferenceParams,
                 "seed",
-                default=random.randint(0, 2**32 - 1),
+                default=9,
                 min=0,
                 max=4294967295,
                 step=1,
                 display="number",
                 tooltip="Random seed value",
-                control_after_generate=True,
             ),
             "steps": model_field_to_node_input(
                 IO.INT,
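
Note on the seed change: the randomized schema default is replaced by a fixed default of 9, and control_after_generate is dropped here (set to False in the V2V node below). As a general Python illustration of why randomized defaults are fragile (not necessarily this PR's stated rationale), a default expression is evaluated once, when the definition is built, not per use:

import random

def input_types(seed_default=random.randint(0, 2**32 - 1)):
    # The default expression ran once, at definition time, so every
    # call that omits the argument reuses the exact same "random" value.
    return seed_default

print(input_types() == input_types())  # True: the default is frozen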
@@ -532,9 +539,11 @@ async def generate(
         # Get MIME type from tensor - assuming PNG format for image tensors
         mime_type = "image/png"

-        image_url = (await upload_images_to_comfyapi(
-            image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
-        ))[0]
+        image_url = (
+            await upload_images_to_comfyapi(
+                image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
+            )
+        )[0]

         request = MoonvalleyTextToVideoRequest(
             image_url=image_url, prompt_text=prompt, inference_params=inference_params
@@ -570,25 +579,54 @@ def INPUT_TYPES(cls):
         return {
             "required": {
                 "prompt": model_field_to_node_input(
-                    IO.STRING, MoonvalleyVideoToVideoRequest, "prompt_text",
-                    multiline=True
+                    IO.STRING,
+                    MoonvalleyVideoToVideoRequest,
+                    "prompt_text",
+                    multiline=True,
                 ),
                 "negative_prompt": model_field_to_node_input(
                     IO.STRING,
                     MoonvalleyVideoToVideoInferenceParams,
                     "negative_prompt",
                     multiline=True,
-                    default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts"
+                    default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring",
                 ),
+                "seed": model_field_to_node_input(
+                    IO.INT,
+                    MoonvalleyVideoToVideoInferenceParams,
+                    "seed",
+                    default=9,
+                    min=0,
+                    max=4294967295,
+                    step=1,
+                    display="number",
+                    tooltip="Random seed value",
+                    control_after_generate=False,
+                ),
+                "prompt_adherence": model_field_to_node_input(
+                    IO.FLOAT,
+                    MoonvalleyVideoToVideoInferenceParams,
+                    "guidance_scale",
+                    default=10.0,
+                    step=1,
+                    min=1,
+                    max=20,
+                ),
-                "seed": model_field_to_node_input(IO.INT,MoonvalleyVideoToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True),
             },
             "hidden": {
                 "auth_token": "AUTH_TOKEN_COMFY_ORG",
                 "comfy_api_key": "API_KEY_COMFY_ORG",
                 "unique_id": "UNIQUE_ID",
             },
             "optional": {
-                "video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported."}),
+                "video": (
+                    IO.VIDEO,
+                    {
+                        "default": "",
+                        "multiline": False,
+                        "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
+                    },
+                ),
                 "control_type": (
                     ["Motion Transfer", "Pose Transfer"],
                     {"default": "Motion Transfer"},
@@ -602,8 +640,14 @@ def INPUT_TYPES(cls):
                     "max": 100,
                     "tooltip": "Only used if control_type is 'Motion Transfer'",
                 },
-            )
-        }
+                ),
+                "image": model_field_to_node_input(
+                    IO.IMAGE,
+                    MoonvalleyTextToVideoRequest,
+                    "image_url",
+                    tooltip="The reference image used to generate the video",
+                ),
+            },
         }

     RETURN_TYPES = ("VIDEO",)
@@ -613,15 +657,24 @@ async def generate(
         self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
     ):
         video = kwargs.get("video")
+        image = kwargs.get("image", None)

         if not video:
             raise MoonvalleyApiError("video is required")

         video_url = ""
         if video:
             validated_video = validate_video_to_video_input(video)
-            video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)
+            video_url = await upload_video_to_comfyapi(
+                validated_video, auth_kwargs=kwargs
+            )
+        mime_type = "image/png"

+        if not image is None:
+            validate_input_image(image, with_frame_conditioning=True)
+            image_url = await upload_images_to_comfyapi(
+                image=image, auth_kwargs=kwargs, max_images=1, mime_type=mime_type
+            )
         control_type = kwargs.get("control_type")
         motion_intensity = kwargs.get("motion_intensity")

@@ -631,12 +684,12 @@ async def generate(
         # Only include motion_intensity for Motion Transfer
         control_params = {}
         if control_type == "Motion Transfer" and motion_intensity is not None:
-            control_params['motion_intensity'] = motion_intensity
+            control_params["motion_intensity"] = motion_intensity

-        inference_params=MoonvalleyVideoToVideoInferenceParams(
+        inference_params = MoonvalleyVideoToVideoInferenceParams(
             negative_prompt=negative_prompt,
             seed=kwargs.get("seed"),
-            control_params=control_params
+            control_params=control_params,
         )

         control = self.parseControlParameter(control_type)
@@ -647,6 +700,7 @@ async def generate(
             prompt_text=prompt,
             inference_params=inference_params,
         )
+        request.image_url = image_url if not image is None else None

         initial_operation = SynchronousOperation(
             endpoint=ApiEndpoint(
@@ -694,15 +748,15 @@ async def generate(
         validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
         width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))

-        inference_params=MoonvalleyTextToVideoInferenceParams(
-                        negative_prompt=negative_prompt,
-                        steps=kwargs.get("steps"),
-                        seed=kwargs.get("seed"),
-                        guidance_scale=kwargs.get("prompt_adherence"),
-                        num_frames=128,
-                        width=width_height.get("width"),
-                        height=width_height.get("height"),
-                    )
+        inference_params = MoonvalleyTextToVideoInferenceParams(
+            negative_prompt=negative_prompt,
+            steps=kwargs.get("steps"),
+            seed=kwargs.get("seed"),
+            guidance_scale=kwargs.get("prompt_adherence"),
+            num_frames=128,
+            width=width_height.get("width"),
+            height=width_height.get("height"),
+        )
         request = MoonvalleyTextToVideoRequest(
             prompt_text=prompt, inference_params=inference_params
         )
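
Review note (style only): the new guards spell the check as `not image is None`, which flake8 flags as E714; the equivalent form preferred by PEP 8 is:

image = None
if image is not None:  # same behavior as `not image is None`
    print("conditioning on reference image")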