5 changes: 3 additions & 2 deletions comfy/utils.py
@@ -998,11 +998,12 @@ def set_progress_bar_global_hook(function):
     PROGRESS_BAR_HOOK = function
 
 class ProgressBar:
-    def __init__(self, total):
+    def __init__(self, total, node_id=None):
         global PROGRESS_BAR_HOOK
         self.total = total
         self.current = 0
         self.hook = PROGRESS_BAR_HOOK
+        self.node_id = node_id
 
     def update_absolute(self, value, total=None, preview=None):
         if total is not None:
@@ -1011,7 +1012,7 @@ def update_absolute(self, value, total=None, preview=None):
             value = self.total
         self.current = value
         if self.hook is not None:
-            self.hook(self.current, self.total, preview)
+            self.hook(self.current, self.total, preview, node_id=self.node_id)
 
     def update(self, value):
         self.update_absolute(self.current + value)
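With this change, the globally registered progress hook receives the reporting node's id as a keyword argument on every call. A minimal sketch of a compatible hook (the hook body and the printed output are illustrative, not part of the PR; only the signature is dictated by the diff):

    import comfy.utils

    def progress_hook(value, total, preview, node_id=None):
        # node_id is None for ProgressBar instances created without one
        print(f"node {node_id}: {value}/{total}")

    comfy.utils.set_progress_bar_global_hook(progress_hook)
    pbar = comfy.utils.ProgressBar(10, node_id="42")
    pbar.update(1)  # hook is called with node_id="42"

Note that node_id is now always passed as a keyword, so hooks written against the old three-argument signature need to grow the extra parameter.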
69 changes: 69 additions & 0 deletions comfy_api/feature_flags.py
@@ -0,0 +1,69 @@
"""
Feature flags module for ComfyUI WebSocket protocol negotiation.

This module handles capability negotiation between frontend and backend,
allowing graceful protocol evolution while maintaining backward compatibility.
"""

from typing import Any, Dict

from comfy.cli_args import args

# Default server capabilities
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
}


def get_connection_feature(
sockets_metadata: Dict[str, Dict[str, Any]],
sid: str,
feature_name: str,
default: Any = False
) -> Any:
"""
Get a feature flag value for a specific connection.

Args:
sockets_metadata: Dictionary of socket metadata
sid: Session ID of the connection
feature_name: Name of the feature to check
default: Default value if feature not found

Returns:
Feature value or default if not found
"""
if sid not in sockets_metadata:
return default

return sockets_metadata[sid].get("feature_flags", {}).get(feature_name, default)


def supports_feature(
sockets_metadata: Dict[str, Dict[str, Any]],
sid: str,
feature_name: str
) -> bool:
"""
Check if a connection supports a specific feature.

Args:
sockets_metadata: Dictionary of socket metadata
sid: Session ID of the connection
feature_name: Name of the feature to check

Returns:
Boolean indicating if feature is supported
"""
return get_connection_feature(sockets_metadata, sid, feature_name, False) is True


def get_server_features() -> Dict[str, Any]:
"""
Get the server's feature flags.

Returns:
Dictionary of server feature flags
"""
return SERVER_FEATURE_FLAGS.copy()
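A sketch of how these helpers are meant to be used around the WebSocket handshake (the handshake message shape and the server-side storage are assumptions for illustration; only the three functions above come from this file):

    from comfy_api.feature_flags import get_server_features, supports_feature

    # On connect (assumed handshake): store whatever flags the client
    # announced and answer with the server's own capabilities.
    sockets_metadata = {"sid-1": {"feature_flags": {"supports_preview_metadata": True}}}
    reply = {"type": "feature_flags", "data": get_server_features()}

    # Later, on the send path, branch per connection:
    if supports_feature(sockets_metadata, "sid-1", "supports_preview_metadata"):
        pass  # send the richer preview message
    else:
        pass  # fall back to the legacy preview format

Because unknown flags default to False, a frontend that never sends feature_flags is treated as supporting nothing new, which is what keeps old clients working.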
2 changes: 2 additions & 0 deletions comfy_api_nodes/nodes_kling.py
@@ -132,6 +132,8 @@ def poll_until_finished(
         result_url_extractor=result_url_extractor,
         estimated_duration=estimated_duration,
         node_id=node_id,
+        poll_interval=16.0,
+        max_poll_attempts=256,
     ).execute()


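The two new arguments make the Kling polling budget explicit: 256 attempts at a 16-second interval allow roughly 256 × 16 s ≈ 68 minutes before a long-running task is given up on. A generic sketch of the loop these parameters describe (PollingOperation's real implementation lives in the shared API-node helpers and is not shown in this diff):

    import time

    def poll(check, poll_interval=16.0, max_poll_attempts=256):
        # check() returns a result when the task is finished, else None
        for _ in range(max_poll_attempts):
            result = check()
            if result is not None:
                return result
            time.sleep(poll_interval)
        raise TimeoutError(f"gave up after {poll_interval * max_poll_attempts:.0f}s")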
53 changes: 27 additions & 26 deletions comfy_execution/caching.py
@@ -1,6 +1,7 @@
 import itertools
 from typing import Sequence, Mapping, Dict
 from comfy_execution.graph import DynamicPrompt
+from abc import ABC, abstractmethod
 
 import nodes
 
@@ -16,12 +17,13 @@ def include_unique_id_in_input(class_type: str) -> bool:
     NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
     return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
 
-class CacheKeySet:
+class CacheKeySet(ABC):
     def __init__(self, dynprompt, node_ids, is_changed_cache):
         self.keys = {}
         self.subcache_keys = {}
 
-    def add_keys(self, node_ids):
+    @abstractmethod
+    async def add_keys(self, node_ids):
         raise NotImplementedError()
 
     def all_node_ids(self):
@@ -60,9 +62,8 @@ class CacheKeySetID(CacheKeySet):
     def __init__(self, dynprompt, node_ids, is_changed_cache):
         super().__init__(dynprompt, node_ids, is_changed_cache)
         self.dynprompt = dynprompt
-        self.add_keys(node_ids)
 
-    def add_keys(self, node_ids):
+    async def add_keys(self, node_ids):
         for node_id in node_ids:
             if node_id in self.keys:
                 continue
@@ -77,37 +78,36 @@ def __init__(self, dynprompt, node_ids, is_changed_cache):
         super().__init__(dynprompt, node_ids, is_changed_cache)
         self.dynprompt = dynprompt
         self.is_changed_cache = is_changed_cache
-        self.add_keys(node_ids)
 
     def include_node_id_in_input(self) -> bool:
         return False
 
-    def add_keys(self, node_ids):
+    async def add_keys(self, node_ids):
         for node_id in node_ids:
             if node_id in self.keys:
                 continue
             if not self.dynprompt.has_node(node_id):
                 continue
             node = self.dynprompt.get_node(node_id)
-            self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
+            self.keys[node_id] = await self.get_node_signature(self.dynprompt, node_id)
             self.subcache_keys[node_id] = (node_id, node["class_type"])
 
-    def get_node_signature(self, dynprompt, node_id):
+    async def get_node_signature(self, dynprompt, node_id):
         signature = []
         ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
-        signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
+        signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
         for ancestor_id in ancestors:
-            signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
+            signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
         return to_hashable(signature)
 
-    def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
+    async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
         if not dynprompt.has_node(node_id):
             # This node doesn't exist -- we can't cache it.
             return [float("NaN")]
         node = dynprompt.get_node(node_id)
         class_type = node["class_type"]
         class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-        signature = [class_type, self.is_changed_cache.get(node_id)]
+        signature = [class_type, await self.is_changed_cache.get(node_id)]
         if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
             signature.append(node_id)
         inputs = node["inputs"]
@@ -150,9 +150,10 @@ def __init__(self, key_class):
         self.cache = {}
         self.subcaches = {}
 
-    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
+    async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
         self.dynprompt = dynprompt
         self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
+        await self.cache_key_set.add_keys(node_ids)
         self.is_changed_cache = is_changed_cache
         self.initialized = True
 
@@ -201,13 +202,13 @@ def _get_immediate(self, node_id):
         else:
             return None
 
-    def _ensure_subcache(self, node_id, children_ids):
+    async def _ensure_subcache(self, node_id, children_ids):
         subcache_key = self.cache_key_set.get_subcache_key(node_id)
         subcache = self.subcaches.get(subcache_key, None)
         if subcache is None:
             subcache = BasicCache(self.key_class)
             self.subcaches[subcache_key] = subcache
-        subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
+        await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
         return subcache
 
     def _get_subcache(self, node_id):
@@ -259,10 +260,10 @@ def set(self, node_id, value):
         assert cache is not None
         cache._set_immediate(node_id, value)
 
-    def ensure_subcache_for(self, node_id, children_ids):
+    async def ensure_subcache_for(self, node_id, children_ids):
         cache = self._get_cache_for(node_id)
         assert cache is not None
-        return cache._ensure_subcache(node_id, children_ids)
+        return await cache._ensure_subcache(node_id, children_ids)
 
 class LRUCache(BasicCache):
     def __init__(self, key_class, max_size=100):
@@ -273,8 +274,8 @@ def __init__(self, key_class, max_size=100):
         self.used_generation = {}
         self.children = {}
 
-    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
-        super().set_prompt(dynprompt, node_ids, is_changed_cache)
+    async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
+        await super().set_prompt(dynprompt, node_ids, is_changed_cache)
         self.generation += 1
         for node_id in node_ids:
             self._mark_used(node_id)
@@ -303,11 +304,11 @@ def set(self, node_id, value):
         self._mark_used(node_id)
         return self._set_immediate(node_id, value)
 
-    def ensure_subcache_for(self, node_id, children_ids):
+    async def ensure_subcache_for(self, node_id, children_ids):
         # Just uses subcaches for tracking 'live' nodes
-        super()._ensure_subcache(node_id, children_ids)
+        await super()._ensure_subcache(node_id, children_ids)
 
-        self.cache_key_set.add_keys(children_ids)
+        await self.cache_key_set.add_keys(children_ids)
         self._mark_used(node_id)
         cache_key = self.cache_key_set.get_data_key(node_id)
         self.children[cache_key] = []
@@ -337,7 +338,7 @@ def __init__(self, key_class):
         self.ancestors = {} # Maps node_id -> set of ancestor node_ids
         self.executed_nodes = set() # Tracks nodes that have been executed
 
-    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
+    async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
         """
         Clear the entire cache and rebuild the dependency graph.
 
@@ -354,7 +355,7 @@ def set_prompt(self, dynprompt, node_ids, is_changed_cache):
         self.executed_nodes.clear()
 
         # Call the parent method to initialize the cache with the new prompt
-        super().set_prompt(dynprompt, node_ids, is_changed_cache)
+        await super().set_prompt(dynprompt, node_ids, is_changed_cache)
 
         # Rebuild the dependency graph
         self._build_dependency_graph(dynprompt, node_ids)
@@ -405,7 +406,7 @@ def get(self, node_id):
         """
         return self._get_immediate(node_id)
 
-    def ensure_subcache_for(self, node_id, children_ids):
+    async def ensure_subcache_for(self, node_id, children_ids):
         """
         Ensure a subcache exists for a node and update dependencies.
 
@@ -416,7 +417,7 @@ def ensure_subcache_for(self, node_id, children_ids):
         Returns:
             The subcache object for the node.
         """
-        subcache = super()._ensure_subcache(node_id, children_ids)
+        subcache = await super()._ensure_subcache(node_id, children_ids)
         for child_id in children_ids:
             self.descendants[node_id].add(child_id)
             self.ancestors[child_id].add(node_id)
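The thread through all of these changes: cache-key construction may now await IS_CHANGED evaluation, so everything from add_keys up through set_prompt and ensure_subcache_for becomes a coroutine. A minimal sketch of the new calling convention (class names as in this file; dynprompt, node_ids, and is_changed_cache are assumed to come from the executor):

    from comfy_execution.caching import HierarchicalCache, CacheKeySetInputSignature

    async def prepare_cache(dynprompt, node_ids, is_changed_cache):
        cache = HierarchicalCache(CacheKeySetInputSignature)
        # set_prompt is now a coroutine: building keys may await IS_CHANGED
        await cache.set_prompt(dynprompt, node_ids, is_changed_cache)
        return cache

Note the related structural fix: __init__ can no longer populate keys itself (constructors cannot await), which is why the add_keys calls moved out of the CacheKeySet constructors and into the awaited set_prompt.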
20 changes: 19 additions & 1 deletion comfy_execution/graph.py
@@ -2,6 +2,7 @@
 from typing import Type, Literal
 
 import nodes
+import asyncio
 from comfy_execution.graph_utils import is_link
 from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
 
@@ -100,6 +101,8 @@ def __init__(self, dynprompt):
         self.pendingNodes = {}
         self.blockCount = {} # Number of nodes this node is directly blocked by
         self.blocking = {} # Which nodes are blocked by this node
+        self.externalBlocks = 0
+        self.unblockedEvent = asyncio.Event()
 
     def get_input_info(self, unique_id, input_name):
         class_type = self.dynprompt.get_node(unique_id)["class_type"]
@@ -153,6 +156,16 @@ def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
         for link in links:
             self.add_strong_link(*link)
 
+    def add_external_block(self, node_id):
+        assert node_id in self.blockCount, "Can't add external block to a node that isn't pending"
+        self.externalBlocks += 1
+        self.blockCount[node_id] += 1
+        def unblock():
+            self.externalBlocks -= 1
+            self.blockCount[node_id] -= 1
+            self.unblockedEvent.set()
+        return unblock
+
     def is_cached(self, node_id):
         return False
 
@@ -181,11 +194,16 @@ def __init__(self, dynprompt, output_cache):
     def is_cached(self, node_id):
         return self.output_cache.get(node_id) is not None
 
-    def stage_node_execution(self):
+    async def stage_node_execution(self):
         assert self.staged_node_id is None
         if self.is_empty():
             return None, None, None
         available = self.get_ready_nodes()
+        while len(available) == 0 and self.externalBlocks > 0:
+            # Wait for an external block to be released
+            await self.unblockedEvent.wait()
+            self.unblockedEvent.clear()
+            available = self.get_ready_nodes()
         if len(available) == 0:
             cycled_nodes = self.get_nodes_in_cycle()
             # Because cycles composed entirely of static nodes are caught during initial validation,
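Together these pieces let work that lives outside the graph keep a node pending without tripping the cycle detector: add_external_block bumps the node's block count and returns an unblock callback, and stage_node_execution (now async) sleeps on unblockedEvent instead of reporting an error while external blocks remain. A sketch of the intended handshake (the surrounding executor and the async work are assumptions; execution_list stands for an instance of the class whose is_cached override appears above):

    # Producer side: park node "7" until some out-of-graph work finishes.
    unblock = execution_list.add_external_block("7")

    async def background_task():
        await do_async_work()  # hypothetical long-running coroutine
        unblock()              # decrements both counters and sets unblockedEvent

    # Executor side: stage_node_execution must now be awaited; when nothing is
    # ready but external blocks exist, it waits on the event and retries.
    staged = await execution_list.stage_node_execution()

One subtlety worth noting: unblock() only sets the event; the while loop re-checks get_ready_nodes() after each wakeup, so spurious wakeups are harmless.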