Skip to content

Commit 2b653e8

Browse files
authored
Support for async node functions (Comfy-Org#8830)
* Support for async execution functions

  This commit adds support for node execution functions defined as async. When a node's execution function is defined as async, we can continue executing other nodes while it is processing.

  Standard uses of `await` should "just work", but people will still have to be careful if they spawn actual threads. Because torch doesn't really have async/await versions of functions, this won't particularly help with most locally-executing nodes, but it does work for e.g. web requests to other machines.

  In addition to the execute function, the `VALIDATE_INPUTS` and `check_lazy_status` functions can also be defined as async, though we'll only resolve one node at a time right now for those.

* Add the execution model tests to CI

* Add a missing file

  It looks like this got caught by .gitignore? There's probably a better place to put it, but I'm not sure what that is.

* Add the websocket library for automated tests

* Add additional tests for async error cases

  Also fixes one bug that was found when an async function throws an error after being scheduled on a task.

* Add a feature flags message to reduce bandwidth

  We now only send 1 preview message of the latest type the client can support. We'll add a console warning when the client fails to send a feature flags message at some point in the future.

* Add async tests to CI

* Don't actually add new tests in this PR

  Will do it in a separate PR

* Resolve unit test in GPU-less runner

* Just remove the tests that GHA can't handle

* Change line endings to UNIX-style

* Avoid loading model_management.py so early

  Because model_management.py has a top-level `logging.info`, we have to be careful not to import that file before we call `setup_logging`. If we do, we end up having the default logging handler registered in addition to our custom one.
1 parent 1fd3068 commit 2b653e8

File tree

19 files changed

+1898
-95
lines changed

19 files changed

+1898
-95
lines changed

comfy/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -998,11 +998,12 @@ def set_progress_bar_global_hook(function):
998998
PROGRESS_BAR_HOOK = function
999999

10001000
class ProgressBar:
1001-
def __init__(self, total):
1001+
def __init__(self, total, node_id=None):
10021002
global PROGRESS_BAR_HOOK
10031003
self.total = total
10041004
self.current = 0
10051005
self.hook = PROGRESS_BAR_HOOK
1006+
self.node_id = node_id
10061007

10071008
def update_absolute(self, value, total=None, preview=None):
10081009
if total is not None:
@@ -1011,7 +1012,7 @@ def update_absolute(self, value, total=None, preview=None):
10111012
value = self.total
10121013
self.current = value
10131014
if self.hook is not None:
1014-
self.hook(self.current, self.total, preview)
1015+
self.hook(self.current, self.total, preview, node_id=self.node_id)
10151016

10161017
def update(self, value):
10171018
self.update_absolute(self.current + value)

comfy_api/feature_flags.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""
2+
Feature flags module for ComfyUI WebSocket protocol negotiation.
3+
4+
This module handles capability negotiation between frontend and backend,
5+
allowing graceful protocol evolution while maintaining backward compatibility.
6+
"""
7+
8+
from typing import Any, Dict
9+
10+
from comfy.cli_args import args
11+
12+
# Default server capabilities
13+
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
14+
"supports_preview_metadata": True,
15+
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
16+
}
17+
18+
19+
def get_connection_feature(
20+
sockets_metadata: Dict[str, Dict[str, Any]],
21+
sid: str,
22+
feature_name: str,
23+
default: Any = False
24+
) -> Any:
25+
"""
26+
Get a feature flag value for a specific connection.
27+
28+
Args:
29+
sockets_metadata: Dictionary of socket metadata
30+
sid: Session ID of the connection
31+
feature_name: Name of the feature to check
32+
default: Default value if feature not found
33+
34+
Returns:
35+
Feature value or default if not found
36+
"""
37+
if sid not in sockets_metadata:
38+
return default
39+
40+
return sockets_metadata[sid].get("feature_flags", {}).get(feature_name, default)
41+
42+
43+
def supports_feature(
44+
sockets_metadata: Dict[str, Dict[str, Any]],
45+
sid: str,
46+
feature_name: str
47+
) -> bool:
48+
"""
49+
Check if a connection supports a specific feature.
50+
51+
Args:
52+
sockets_metadata: Dictionary of socket metadata
53+
sid: Session ID of the connection
54+
feature_name: Name of the feature to check
55+
56+
Returns:
57+
Boolean indicating if feature is supported
58+
"""
59+
return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
60+
61+
62+
def get_server_features() -> Dict[str, Any]:
63+
"""
64+
Get the server's feature flags.
65+
66+
Returns:
67+
Dictionary of server feature flags
68+
"""
69+
return SERVER_FEATURE_FLAGS.copy()

comfy_execution/caching.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import itertools
22
from typing import Sequence, Mapping, Dict
33
from comfy_execution.graph import DynamicPrompt
4+
from abc import ABC, abstractmethod
45

56
import nodes
67

@@ -16,12 +17,13 @@ def include_unique_id_in_input(class_type: str) -> bool:
1617
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
1718
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
1819

19-
class CacheKeySet:
20+
class CacheKeySet(ABC):
2021
def __init__(self, dynprompt, node_ids, is_changed_cache):
2122
self.keys = {}
2223
self.subcache_keys = {}
2324

24-
def add_keys(self, node_ids):
25+
@abstractmethod
26+
async def add_keys(self, node_ids):
2527
raise NotImplementedError()
2628

2729
def all_node_ids(self):
@@ -60,9 +62,8 @@ class CacheKeySetID(CacheKeySet):
6062
def __init__(self, dynprompt, node_ids, is_changed_cache):
6163
super().__init__(dynprompt, node_ids, is_changed_cache)
6264
self.dynprompt = dynprompt
63-
self.add_keys(node_ids)
6465

65-
def add_keys(self, node_ids):
66+
async def add_keys(self, node_ids):
6667
for node_id in node_ids:
6768
if node_id in self.keys:
6869
continue
@@ -77,37 +78,36 @@ def __init__(self, dynprompt, node_ids, is_changed_cache):
7778
super().__init__(dynprompt, node_ids, is_changed_cache)
7879
self.dynprompt = dynprompt
7980
self.is_changed_cache = is_changed_cache
80-
self.add_keys(node_ids)
8181

8282
def include_node_id_in_input(self) -> bool:
8383
return False
8484

85-
def add_keys(self, node_ids):
85+
async def add_keys(self, node_ids):
8686
for node_id in node_ids:
8787
if node_id in self.keys:
8888
continue
8989
if not self.dynprompt.has_node(node_id):
9090
continue
9191
node = self.dynprompt.get_node(node_id)
92-
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
92+
self.keys[node_id] = await self.get_node_signature(self.dynprompt, node_id)
9393
self.subcache_keys[node_id] = (node_id, node["class_type"])
9494

95-
def get_node_signature(self, dynprompt, node_id):
95+
async def get_node_signature(self, dynprompt, node_id):
9696
signature = []
9797
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
98-
signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
98+
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
9999
for ancestor_id in ancestors:
100-
signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
100+
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
101101
return to_hashable(signature)
102102

103-
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
103+
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
104104
if not dynprompt.has_node(node_id):
105105
# This node doesn't exist -- we can't cache it.
106106
return [float("NaN")]
107107
node = dynprompt.get_node(node_id)
108108
class_type = node["class_type"]
109109
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
110-
signature = [class_type, self.is_changed_cache.get(node_id)]
110+
signature = [class_type, await self.is_changed_cache.get(node_id)]
111111
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
112112
signature.append(node_id)
113113
inputs = node["inputs"]
@@ -150,9 +150,10 @@ def __init__(self, key_class):
150150
self.cache = {}
151151
self.subcaches = {}
152152

153-
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
153+
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
154154
self.dynprompt = dynprompt
155155
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
156+
await self.cache_key_set.add_keys(node_ids)
156157
self.is_changed_cache = is_changed_cache
157158
self.initialized = True
158159

@@ -201,13 +202,13 @@ def _get_immediate(self, node_id):
201202
else:
202203
return None
203204

204-
def _ensure_subcache(self, node_id, children_ids):
205+
async def _ensure_subcache(self, node_id, children_ids):
205206
subcache_key = self.cache_key_set.get_subcache_key(node_id)
206207
subcache = self.subcaches.get(subcache_key, None)
207208
if subcache is None:
208209
subcache = BasicCache(self.key_class)
209210
self.subcaches[subcache_key] = subcache
210-
subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
211+
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
211212
return subcache
212213

213214
def _get_subcache(self, node_id):
@@ -259,10 +260,10 @@ def set(self, node_id, value):
259260
assert cache is not None
260261
cache._set_immediate(node_id, value)
261262

262-
def ensure_subcache_for(self, node_id, children_ids):
263+
async def ensure_subcache_for(self, node_id, children_ids):
263264
cache = self._get_cache_for(node_id)
264265
assert cache is not None
265-
return cache._ensure_subcache(node_id, children_ids)
266+
return await cache._ensure_subcache(node_id, children_ids)
266267

267268
class LRUCache(BasicCache):
268269
def __init__(self, key_class, max_size=100):
@@ -273,8 +274,8 @@ def __init__(self, key_class, max_size=100):
273274
self.used_generation = {}
274275
self.children = {}
275276

276-
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
277-
super().set_prompt(dynprompt, node_ids, is_changed_cache)
277+
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
278+
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
278279
self.generation += 1
279280
for node_id in node_ids:
280281
self._mark_used(node_id)
@@ -303,11 +304,11 @@ def set(self, node_id, value):
303304
self._mark_used(node_id)
304305
return self._set_immediate(node_id, value)
305306

306-
def ensure_subcache_for(self, node_id, children_ids):
307+
async def ensure_subcache_for(self, node_id, children_ids):
307308
# Just uses subcaches for tracking 'live' nodes
308-
super()._ensure_subcache(node_id, children_ids)
309+
await super()._ensure_subcache(node_id, children_ids)
309310

310-
self.cache_key_set.add_keys(children_ids)
311+
await self.cache_key_set.add_keys(children_ids)
311312
self._mark_used(node_id)
312313
cache_key = self.cache_key_set.get_data_key(node_id)
313314
self.children[cache_key] = []
@@ -337,7 +338,7 @@ def __init__(self, key_class):
337338
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
338339
self.executed_nodes = set() # Tracks nodes that have been executed
339340

340-
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
341+
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
341342
"""
342343
Clear the entire cache and rebuild the dependency graph.
343344
@@ -354,7 +355,7 @@ def set_prompt(self, dynprompt, node_ids, is_changed_cache):
354355
self.executed_nodes.clear()
355356

356357
# Call the parent method to initialize the cache with the new prompt
357-
super().set_prompt(dynprompt, node_ids, is_changed_cache)
358+
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
358359

359360
# Rebuild the dependency graph
360361
self._build_dependency_graph(dynprompt, node_ids)
@@ -405,7 +406,7 @@ def get(self, node_id):
405406
"""
406407
return self._get_immediate(node_id)
407408

408-
def ensure_subcache_for(self, node_id, children_ids):
409+
async def ensure_subcache_for(self, node_id, children_ids):
409410
"""
410411
Ensure a subcache exists for a node and update dependencies.
411412
@@ -416,7 +417,7 @@ def ensure_subcache_for(self, node_id, children_ids):
416417
Returns:
417418
The subcache object for the node.
418419
"""
419-
subcache = super()._ensure_subcache(node_id, children_ids)
420+
subcache = await super()._ensure_subcache(node_id, children_ids)
420421
for child_id in children_ids:
421422
self.descendants[node_id].add(child_id)
422423
self.ancestors[child_id].add(node_id)

comfy_execution/graph.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Type, Literal
33

44
import nodes
5+
import asyncio
56
from comfy_execution.graph_utils import is_link
67
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
78

@@ -100,6 +101,8 @@ def __init__(self, dynprompt):
100101
self.pendingNodes = {}
101102
self.blockCount = {} # Number of nodes this node is directly blocked by
102103
self.blocking = {} # Which nodes are blocked by this node
104+
self.externalBlocks = 0
105+
self.unblockedEvent = asyncio.Event()
103106

104107
def get_input_info(self, unique_id, input_name):
105108
class_type = self.dynprompt.get_node(unique_id)["class_type"]
@@ -153,6 +156,16 @@ def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
153156
for link in links:
154157
self.add_strong_link(*link)
155158

159+
def add_external_block(self, node_id):
160+
assert node_id in self.blockCount, "Can't add external block to a node that isn't pending"
161+
self.externalBlocks += 1
162+
self.blockCount[node_id] += 1
163+
def unblock():
164+
self.externalBlocks -= 1
165+
self.blockCount[node_id] -= 1
166+
self.unblockedEvent.set()
167+
return unblock
168+
156169
def is_cached(self, node_id):
157170
return False
158171

@@ -181,11 +194,16 @@ def __init__(self, dynprompt, output_cache):
181194
def is_cached(self, node_id):
182195
return self.output_cache.get(node_id) is not None
183196

184-
def stage_node_execution(self):
197+
async def stage_node_execution(self):
185198
assert self.staged_node_id is None
186199
if self.is_empty():
187200
return None, None, None
188201
available = self.get_ready_nodes()
202+
while len(available) == 0 and self.externalBlocks > 0:
203+
# Wait for an external block to be released
204+
await self.unblockedEvent.wait()
205+
self.unblockedEvent.clear()
206+
available = self.get_ready_nodes()
189207
if len(available) == 0:
190208
cycled_nodes = self.get_nodes_in_cycle()
191209
# Because cycles composed entirely of static nodes are caught during initial validation,

0 commit comments

Comments
 (0)