6 changes: 6 additions & 0 deletions comfy/ldm/chroma/model.py
@@ -179,7 +179,10 @@ def forward_orig(
         pe = self.pe_embedder(ids)

         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.double_blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
+            transformer_options["block_index"] = i
             if i not in self.skip_mmdit:
                 double_mod = (
                     self.get_modulations(mod_vectors, "double_img", idx=i),
@@ -222,7 +225,10 @@ def block_wrap(args):

         img = torch.cat((txt, img), 1)

+        transformer_options["total_blocks"] = len(self.single_blocks)
+        transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
+            transformer_options["block_index"] = i
             if i not in self.skip_dit:
                 single_mod = self.get_modulations(mod_vectors, "single", idx=i)
                 if ("single_block", i) in blocks_replace:
6 changes: 6 additions & 0 deletions comfy/ldm/hunyuan_video/model.py
@@ -389,7 +389,10 @@ def forward_orig(
             attn_mask = None

         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.double_blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -411,7 +414,10 @@ def block_wrap(args):

         img = torch.cat((img, txt), 1)

+        transformer_options["total_blocks"] = len(self.single_blocks)
+        transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
+            transformer_options["block_index"] = i
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
3 changes: 3 additions & 0 deletions comfy/ldm/qwen_image/model.py
@@ -439,7 +439,10 @@ def _forward(
         patches = transformer_options.get("patches", {})
         blocks_replace = patches_replace.get("dit", {})

+        transformer_options["total_blocks"] = len(self.transformer_blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.transformer_blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
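The three model diffs above (Chroma, Hunyuan Video, Qwen-Image) make the same change: before iterating a block list, the model publishes total_blocks and block_type into transformer_options, and each loop iteration publishes the current block_index. A minimal sketch of a block-level patch that consumes these keys — the callback shape and its registration mechanism are assumptions for illustration; only the three dictionary keys come from this PR:

def my_block_patch(img, transformer_options):
    # "double" while the double/transformer blocks run, "single" afterwards
    block_type = transformer_options.get("block_type")
    index = transformer_options.get("block_index", 0)
    total = transformer_options.get("total_blocks", 1)
    progress = (index + 1) / total  # fraction of this block list already processed
    if block_type == "single" and progress > 0.9:
        pass  # e.g. skip or cache work in the last 10% of single blocks
    return img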
28 changes: 28 additions & 0 deletions comfy_api/latest/_input/video_types.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 from abc import ABC, abstractmethod
+from fractions import Fraction
 from typing import Optional, Union, IO
 import io
 import av
@@ -72,6 +73,33 @@ def get_duration(self) -> float:
         frame_count = components.images.shape[0]
         return float(frame_count / components.frame_rate)

+    def get_frame_count(self) -> int:
+        """
+        Returns the number of frames in the video.
+
+        Default implementation uses :meth:`get_components`, which may require
+        loading all frames into memory. File-based implementations should
+        override this method and use container/stream metadata instead.
+
+        Returns:
+            Total number of frames as an integer.
+        """
+        return int(self.get_components().images.shape[0])
+
+    def get_frame_rate(self) -> Fraction:
+        """
+        Returns the frame rate of the video.
+
+        Default implementation materializes the video into memory via
+        `get_components()`. Subclasses that can inspect the underlying
+        container (e.g. `VideoFromFile`) should override this with a more
+        efficient implementation.
+
+        Returns:
+            Frame rate as a Fraction.
+        """
+        return self.get_components().frame_rate
+
     def get_container_format(self) -> str:
         """
         Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
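Because both new methods ship with defaults on the abstract base class, existing VideoInput implementations keep working unchanged; they just pay the get_components() cost until they override. From a caller's perspective (the video object below stands for any VideoInput):

# Before this PR, callers had to materialize every frame to learn these values:
components = video.get_components()      # decodes all frames into tensors
n_frames_old = components.images.shape[0]
fps_old = components.frame_rate

# After this PR the same facts come from the interface directly; file-backed
# implementations answer from container metadata without decoding frames.
n_frames = video.get_frame_count()       # int
fps = video.get_frame_rate()             # fractions.Fraction
duration = n_frames / float(fps)         # seconds, consistent with get_duration()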
72 changes: 72 additions & 0 deletions comfy_api/latest/_input_impl/video_types.py
@@ -121,6 +121,71 @@ def get_duration(self) -> float:

         raise ValueError(f"Could not determine duration for file '{self.__file}'")

+    def get_frame_count(self) -> int:
+        """
+        Returns the number of frames in the video without materializing them as
+        torch tensors.
+        """
+        if isinstance(self.__file, io.BytesIO):
+            self.__file.seek(0)
+
+        with av.open(self.__file, mode="r") as container:
+            video_stream = self._get_first_video_stream(container)
+            # 1. Prefer the frames field if available
+            if video_stream.frames and video_stream.frames > 0:
+                return int(video_stream.frames)
+
+            # 2. Try to estimate from duration and average_rate using only metadata
+            if container.duration is not None and video_stream.average_rate:
+                duration_seconds = float(container.duration / av.time_base)
+                estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
+                if estimated_frames > 0:
+                    return estimated_frames
+
+            if (
+                getattr(video_stream, "duration", None) is not None
+                and getattr(video_stream, "time_base", None) is not None
+                and video_stream.average_rate
+            ):
+                duration_seconds = float(video_stream.duration * video_stream.time_base)
+                estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
+                if estimated_frames > 0:
+                    return estimated_frames
+
+            # 3. Last resort: decode frames and count them (streaming)
+            frame_count = 0
+            container.seek(0)
+            for packet in container.demux(video_stream):
+                for _ in packet.decode():
+                    frame_count += 1
+
+            if frame_count == 0:
+                raise ValueError(f"Could not determine frame count for file '{self.__file}'")
+            return frame_count
+
+    def get_frame_rate(self) -> Fraction:
+        """
+        Returns the average frame rate of the video using container metadata
+        without decoding all frames.
+        """
+        if isinstance(self.__file, io.BytesIO):
+            self.__file.seek(0)
+
+        with av.open(self.__file, mode="r") as container:
+            video_stream = self._get_first_video_stream(container)
+            # Preferred: use PyAV's average_rate (usually already a Fraction-like)
+            if video_stream.average_rate:
+                return Fraction(video_stream.average_rate)
+
+            # Fallback: estimate from frames + duration if available
+            if video_stream.frames and container.duration:
+                duration_seconds = float(container.duration / av.time_base)
+                if duration_seconds > 0:
+                    return Fraction(video_stream.frames / duration_seconds).limit_denominator()
+
+            # Last resort: match get_components_internal default
+            return Fraction(1)
+
     def get_container_format(self) -> str:
         """
         Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
@@ -238,6 +303,13 @@ def save_to(
                     packet.stream = stream_map[packet.stream]
                     output_container.mux(packet)

+    def _get_first_video_stream(self, container: InputContainer):
+        video_stream = next((s for s in container.streams if s.type == "video"), None)
+        if video_stream is None:
+            raise ValueError(f"No video stream found in file '{self.__file}'")
+        return video_stream
+
+
 class VideoFromComponents(VideoInput):
     """
     Class representing video input from tensors.
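For file-backed videos the overrides answer from PyAV container metadata, so in the common case neither call decodes a single frame. A usage sketch — the module path is the one shown in this diff, but the constructor argument is an assumption:

from fractions import Fraction
from comfy_api.latest._input_impl.video_types import VideoFromFile

video = VideoFromFile("example.mp4")   # hypothetical local file
fps = video.get_frame_rate()           # Fraction from stream.average_rate, no decoding
count = video.get_frame_count()        # stream.frames, else duration * fps, else a demux pass
assert isinstance(fps, Fraction)
# Only when both the frames field and the duration metadata are missing does
# get_frame_count() fall back to decoding packets, and even then no torch
# tensors are created.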
6 changes: 3 additions & 3 deletions comfy_api_nodes/apis/gemini_api.py
@@ -113,9 +113,9 @@ class GeminiGenerationConfig(BaseModel):
     maxOutputTokens: int | None = Field(None, ge=16, le=8192)
     seed: int | None = Field(None)
     stopSequences: list[str] | None = Field(None)
-    temperature: float | None = Field(1, ge=0.0, le=2.0)
-    topK: int | None = Field(40, ge=1)
-    topP: float | None = Field(0.95, ge=0.0, le=1.0)
+    temperature: float | None = Field(None, ge=0.0, le=2.0)
+    topK: int | None = Field(None, ge=1)
+    topP: float | None = Field(None, ge=0.0, le=1.0)


 class GeminiImageConfig(BaseModel):
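Changing the defaults from concrete values (1, 40, 0.95) to None means the client no longer silently overrides Gemini's server-side sampling defaults: assuming the request body is serialized with exclude_none, as is typical for these Pydantic API models, unset fields are simply omitted. A minimal standalone sketch of that behavior:

from pydantic import BaseModel, Field

class GenerationConfig(BaseModel):
    temperature: float | None = Field(None, ge=0.0, le=2.0)
    topK: int | None = Field(None, ge=1)
    topP: float | None = Field(None, ge=0.0, le=1.0)

print(GenerationConfig().model_dump(exclude_none=True))                 # {}
print(GenerationConfig(temperature=0.2).model_dump(exclude_none=True))  # {'temperature': 0.2}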
21 changes: 11 additions & 10 deletions comfy_api_nodes/nodes_gemini.py
@@ -104,14 +104,14 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
         List of response parts matching the requested type.
     """
     if response.candidates is None:
-        if response.promptFeedback.blockReason:
+        if response.promptFeedback and response.promptFeedback.blockReason:
             feedback = response.promptFeedback
             raise ValueError(
                 f"Gemini API blocked the request. Reason: {feedback.blockReason} ({feedback.blockReasonMessage})"
             )
-        raise NotImplementedError(
-            "Gemini returned no response candidates. "
-            "Please report to ComfyUI repository with the example of workflow to reproduce this."
+        raise ValueError(
+            "Gemini API returned no response candidates. If you are using the `IMAGE` modality, "
+            "try changing it to `IMAGE+TEXT` to view the model's reasoning and understand why image generation failed."
         )
     parts = []
     for part in response.candidates[0].content.parts:
@@ -182,11 +182,12 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
     else:
         return None
     final_price = response.usageMetadata.promptTokenCount * input_tokens_price
-    for i in response.usageMetadata.candidatesTokensDetails:
-        if i.modality == Modality.IMAGE:
-            final_price += output_image_tokens_price * i.tokenCount  # for Nano Banana models
-        else:
-            final_price += output_text_tokens_price * i.tokenCount
+    if response.usageMetadata.candidatesTokensDetails:
+        for i in response.usageMetadata.candidatesTokensDetails:
+            if i.modality == Modality.IMAGE:
+                final_price += output_image_tokens_price * i.tokenCount  # for Nano Banana models
+            else:
+                final_price += output_text_tokens_price * i.tokenCount
     if response.usageMetadata.thoughtsTokenCount:
         final_price += output_text_tokens_price * response.usageMetadata.thoughtsTokenCount
     return final_price / 1_000_000.0
@@ -645,7 +646,7 @@ def define_schema(cls):
                 options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
                 default="auto",
                 tooltip="If set to 'auto', matches your input image's aspect ratio; "
-                "if no image is provided, generates a 1:1 square.",
+                "if no image is provided, a 16:9 image is usually generated.",
             ),
             IO.Combo.Input(
                 "resolution",
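The new guard in calculate_tokens_price matters because usageMetadata.candidatesTokensDetails can be absent (for example on blocked or empty responses), which previously crashed the loop. For the pricing arithmetic itself, a worked example with made-up per-million-token prices and counts:

# illustrative only: input $0.30/M tokens, text output $2.50/M, image output $30.00/M
input_tokens_price, output_text_tokens_price, output_image_tokens_price = 0.30, 2.50, 30.00
prompt_tokens, image_tokens, thought_tokens = 1_000, 1_290, 500

final_price = prompt_tokens * input_tokens_price            # 300.0
final_price += image_tokens * output_image_tokens_price     # + 38700.0
final_price += thought_tokens * output_text_tokens_price    # + 1250.0
print(final_price / 1_000_000.0)                            # 0.04025 USD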
13 changes: 5 additions & 8 deletions comfy_api_nodes/nodes_topaz.py
@@ -5,8 +5,7 @@
 import torch
 from typing_extensions import override

-from comfy_api.input.video_types import VideoInput
-from comfy_api.latest import IO, ComfyExtension
+from comfy_api.latest import IO, ComfyExtension, Input
 from comfy_api_nodes.apis import topaz_api
 from comfy_api_nodes.util import (
     ApiEndpoint,
@@ -282,7 +281,7 @@ def define_schema(cls):
     @classmethod
     async def execute(
         cls,
-        video: VideoInput,
+        video: Input.Video,
         upscaler_enabled: bool,
         upscaler_model: str,
         upscaler_resolution: str,
@@ -297,12 +296,10 @@ async def execute(
     ) -> IO.NodeOutput:
         if upscaler_enabled is False and interpolation_enabled is False:
             raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
-        validate_container_format_is_mp4(video)
         src_width, src_height = video.get_dimensions()
-        video_components = video.get_components()
-        src_frame_rate = int(video_components.frame_rate)
+        src_frame_rate = int(video.get_frame_rate())
         duration_sec = video.get_duration()
-        estimated_frames = int(duration_sec * src_frame_rate)
+        validate_container_format_is_mp4(video)
         src_video_stream = video.get_stream_source()
         target_width = src_width
         target_height = src_height
@@ -338,7 +335,7 @@ async def execute(
                     container="mp4",
                     size=get_fs_object_size(src_video_stream),
                     duration=int(duration_sec),
-                    frameCount=estimated_frames,
+                    frameCount=video.get_frame_count(),
                     frameRate=src_frame_rate,
                     resolution=topaz_api.Resolution(width=src_width, height=src_height),
                 ),
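The Topaz node now reads frame rate and frame count through the new VideoInput helpers instead of decoding every frame via get_components(), and it sends an exact frameCount rather than the duration-times-fps estimate. The old estimate was also vulnerable to integer truncation of non-integer rates; illustrative numbers:

from fractions import Fraction

# a 300-frame NTSC clip: 30000/1001 ≈ 29.97 fps, duration = 300 * 1001/30000 = 10.01 s
fps = Fraction(30000, 1001)
duration_sec = 300 * 1001 / 30000            # 10.01

old_rate = int(29.97)                        # 29 (truncated)
old_estimate = int(duration_sec * old_rate)  # int(290.29) = 290 frames
new_count = 300                              # get_frame_count() reads stream.frames directly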