Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions comfy_api/input/video_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from abc import ABC, abstractmethod
from typing import Optional, Union
import io
import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents

class VideoInput(ABC):
Expand Down Expand Up @@ -70,3 +71,15 @@ def get_duration(self) -> float:
components = self.get_components()
frame_count = components.images.shape[0]
return float(frame_count / components.frame_rate)

def get_container_format(self) -> str:
"""
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').

Returns:
Container format as string
"""
# Default implementation - subclasses should override for better performance
source = self.get_stream_source()
with av.open(source, mode="r") as container:
return container.format.name
12 changes: 12 additions & 0 deletions comfy_api/input_impl/video_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ def get_duration(self) -> float:

raise ValueError(f"Could not determine duration for file '{self.__file}'")

def get_container_format(self) -> str:
"""
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').

Returns:
Container format as string
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode='r') as container:
return container.format.name

def get_components_internal(self, container: InputContainer) -> VideoComponents:
# Get video frames
frames = []
Expand Down
225 changes: 120 additions & 105 deletions comfy_api_nodes/nodes_moonvalley.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from comfy_api_nodes.util.validation_utils import (
get_image_dimensions,
validate_image_dimensions,
validate_video_dimensions,
)


Expand Down Expand Up @@ -176,54 +175,76 @@ def validate_input_image(
)


def validate_input_video(
video: VideoInput, num_frames_out: int, with_frame_conditioning: bool = False
):
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
"""
Validates and processes video input for Moonvalley Video-to-Video generation.

Args:
video: Input video to validate

Returns:
Validated and potentially trimmed video

Raises:
ValueError: If video doesn't meet requirements
MoonvalleyApiError: If video duration is too short
"""
width, height = _get_video_dimensions(video)
_validate_video_dimensions(width, height)
_validate_container_format(video)

return _validate_and_trim_duration(video)


def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
"""Extracts video dimensions with error handling."""
try:
width, height = video.get_dimensions()
return video.get_dimensions()
except Exception as e:
logging.error("Error getting dimensions of video: %s", e)
raise ValueError(f"Cannot get video dimensions: {e}") from e

validate_input_media(width, height, with_frame_conditioning)
validate_video_dimensions(
video,
min_width=MIN_VID_WIDTH,
min_height=MIN_VID_HEIGHT,
max_width=MAX_VID_WIDTH,
max_height=MAX_VID_HEIGHT,
)

trimmed_video = validate_input_video_length(video, num_frames_out)
return trimmed_video
def _validate_video_dimensions(width: int, height: int) -> None:
"""Validates video dimensions meet Moonvalley V2V requirements."""
supported_resolutions = {
(1920, 1080), (1080, 1920), (1152, 1152),
(1536, 1152), (1152, 1536)
}

if (width, height) not in supported_resolutions:
supported_list = ', '.join([f'{w}x{h}' for w, h in sorted(supported_resolutions)])
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")

def validate_input_video_length(video: VideoInput, num_frames: int):

if video.get_duration() > 60:
raise MoonvalleyApiError(
"Input Video lenth should be less than 1min. Please trim."
)
def _validate_container_format(video: VideoInput) -> None:
"""Validates video container format is MP4."""
container_format = video.get_container_format()
if container_format not in ['mp4', 'mov,mp4,m4a,3gp,3g2,mj2']:
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")

if num_frames == 128:
if video.get_duration() < 5:
raise MoonvalleyApiError(
"Input Video length is less than 5s. Please use a video longer than or equal to 5s."
)
if video.get_duration() > 5:
# trim video to 5s
video = trim_video(video, 5)
if num_frames == 256:
if video.get_duration() < 10:
raise MoonvalleyApiError(
"Input Video length is less than 10s. Please use a video longer than or equal to 10s."
)
if video.get_duration() > 10:
# trim video to 10s
video = trim_video(video, 10)

def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
"""Validates video duration and trims to 5 seconds if needed."""
duration = video.get_duration()
_validate_minimum_duration(duration)
return _trim_if_too_long(video, duration)


def _validate_minimum_duration(duration: float) -> None:
"""Ensures video is at least 5 seconds long."""
if duration < 5:
raise MoonvalleyApiError("Input video must be at least 5 seconds long.")


def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
"""Trims video to 5 seconds if longer."""
if duration > 5:
return trim_video(video, 5)
return video



def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
"""
Returns a new VideoInput object trimmed from the beginning to the specified duration,
Expand Down Expand Up @@ -278,15 +299,13 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels"
)

# Calculate target frame count that's divisible by 32
# Calculate target frame count that's divisible by 16
fps = input_container.streams.video[0].average_rate
estimated_frames = int(duration_sec * fps)
target_frames = (
estimated_frames // 32
) * 32 # Round down to nearest multiple of 32
target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16

if target_frames == 0:
raise ValueError("Video too short: need at least 32 frames for Moonvalley")
raise ValueError("Video too short: need at least 16 frames for Moonvalley")

frame_count = 0
audio_frame_count = 0
Expand Down Expand Up @@ -353,8 +372,8 @@ def parseWidthHeightFromRes(self, resolution: str):
"16:9 (1920 x 1080)": {"width": 1920, "height": 1080},
"9:16 (1080 x 1920)": {"width": 1080, "height": 1920},
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
"4:3 (1440 x 1080)": {"width": 1440, "height": 1080},
"3:4 (1080 x 1440)": {"width": 1080, "height": 1440},
"4:3 (1536 x 1152)": {"width": 1536, "height": 1152},
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
}
if resolution in res_map:
Expand Down Expand Up @@ -494,7 +513,6 @@ def generate(
image = kwargs.get("image", None)
if image is None:
raise MoonvalleyApiError("image is required")
total_frames = get_total_frames_from_length()

validate_input_image(image, True)
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
Expand All @@ -505,7 +523,7 @@ def generate(
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
num_frames=total_frames,
num_frames=128,
width=width_height.get("width"),
height=width_height.get("height"),
use_negative_prompts=True,
Expand Down Expand Up @@ -549,68 +567,76 @@ def __init__(self):

@classmethod
def INPUT_TYPES(cls):
input_types = super().INPUT_TYPES()
for param in ["resolution", "image"]:
if param in input_types["required"]:
del input_types["required"][param]
if param in input_types["optional"]:
del input_types["optional"][param]
input_types["optional"] = {
"video": (
IO.VIDEO,
{
"default": "",
"multiline": False,
"tooltip": "The reference video used to generate the output video. Input a 5s video for 128 frames and a 10s video for 256 frames. Longer videos will be trimmed automatically.",
},
),
"control_type": (
["Motion Transfer", "Pose Transfer"],
{"default": "Motion Transfer"},
),
"motion_intensity": (
"INT",
{
"default": 100,
"step": 1,
"min": 0,
"max": 100,
"tooltip": "Only used if control_type is 'Motion Transfer'",
},
),
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, MoonvalleyVideoToVideoRequest, "prompt_text",
multiline=True
),
"negative_prompt": model_field_to_node_input(
IO.STRING,
MoonvalleyVideoToVideoInferenceParams,
"negative_prompt",
multiline=True,
default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts"
),
"seed": model_field_to_node_input(IO.INT,MoonvalleyVideoToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
"optional": {
"video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported."}),
"control_type": (
["Motion Transfer", "Pose Transfer"],
{"default": "Motion Transfer"},
),
"motion_intensity": (
"INT",
{
"default": 100,
"step": 1,
"min": 0,
"max": 100,
"tooltip": "Only used if control_type is 'Motion Transfer'",
},
)
}
}

return input_types

RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",)

def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
video = kwargs.get("video")
num_frames = get_total_frames_from_length()

if not video:
raise MoonvalleyApiError("video is required")

"""Validate video input"""
video_url = ""
if video:
validated_video = validate_input_video(video, num_frames, False)
validated_video = validate_video_to_video_input(video)
video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)

control_type = kwargs.get("control_type")
motion_intensity = kwargs.get("motion_intensity")

"""Validate prompts and inference input"""
validate_prompts(prompt, negative_prompt)
inference_params = MoonvalleyVideoToVideoInferenceParams(

# Only include motion_intensity for Motion Transfer
control_params = {}
if control_type == "Motion Transfer" and motion_intensity is not None:
control_params['motion_intensity'] = motion_intensity

inference_params=MoonvalleyVideoToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
control_params={"motion_intensity": motion_intensity},
control_params=control_params
)

control = self.parseControlParameter(control_type)
Expand Down Expand Up @@ -667,17 +693,16 @@ def generate(
):
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
num_frames = get_total_frames_from_length()

inference_params = MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
num_frames=num_frames,
width=width_height.get("width"),
height=width_height.get("height"),
)
inference_params=MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
num_frames=128,
width=width_height.get("width"),
height=width_height.get("height"),
)
request = MoonvalleyTextToVideoRequest(
prompt_text=prompt, inference_params=inference_params
)
Expand Down Expand Up @@ -707,22 +732,12 @@ def generate(
NODE_CLASS_MAPPINGS = {
"MoonvalleyImg2VideoNode": MoonvalleyImg2VideoNode,
"MoonvalleyTxt2VideoNode": MoonvalleyTxt2VideoNode,
# "MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode,
"MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode,
}


NODE_DISPLAY_NAME_MAPPINGS = {
"MoonvalleyImg2VideoNode": "Moonvalley Marey Image to Video",
"MoonvalleyTxt2VideoNode": "Moonvalley Marey Text to Video",
# "MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video",
"MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video",
}


def get_total_frames_from_length(length="5s"):
# if length == '5s':
# return 128
# elif length == '10s':
# return 256
return 128
# else:
# raise MoonvalleyApiError("length is required")
Loading