diff --git a/README.md b/README.md
index b0731db33bb0..c9a0644e33b5 100644
--- a/README.md
+++ b/README.md
@@ -197,7 +197,7 @@ comfy install
 
 ## Manual Install (Windows, Linux)
 
-Python 3.14 will work if you comment out the `kornia` dependency in the requirements.txt file (breaks the canny node) and install pytorch nightly but it is not recommended.
+Python 3.14 will work if you comment out the `kornia` dependency in the requirements.txt file (this breaks the Canny node), but it is not recommended.
 
 Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py
index 41224ce3b82e..566bc3f9c74a 100644
--- a/comfy_execution/caching.py
+++ b/comfy_execution/caching.py
@@ -265,6 +265,26 @@ async def ensure_subcache_for(self, node_id, children_ids):
         assert cache is not None
         return await cache._ensure_subcache(node_id, children_ids)
 
+class NullCache:
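+    """No-op cache used when caching is disabled (--cache-none): get() always
+    misses, set() discards, and ensure_subcache_for() returns itself, so node
+    outputs are never retained."""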
+
+    async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
+        pass
+
+    def all_node_ids(self):
+        return []
+
+    def clean_unused(self):
+        pass
+
+    def get(self, node_id):
+        return None
+
+    def set(self, node_id, value):
+        pass
+
+    async def ensure_subcache_for(self, node_id, children_ids):
+        return self
+
 class LRUCache(BasicCache):
     def __init__(self, key_class, max_size=100):
         super().__init__(key_class)
@@ -316,157 +336,3 @@ async def ensure_subcache_for(self, node_id, children_ids):
             self._mark_used(child_id)
             self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
         return self
-
-
-class DependencyAwareCache(BasicCache):
-    """
-    A cache implementation that tracks dependencies between nodes and manages
-    their execution and caching accordingly. It extends the BasicCache class.
-    Nodes are removed from this cache once all of their descendants have been
-    executed.
-    """
-
-    def __init__(self, key_class):
-        """
-        Initialize the DependencyAwareCache.
-
-        Args:
-            key_class: The class used for generating cache keys.
-        """
-        super().__init__(key_class)
-        self.descendants = {}  # Maps node_id -> set of descendant node_ids
-        self.ancestors = {}  # Maps node_id -> set of ancestor node_ids
-        self.executed_nodes = set()  # Tracks nodes that have been executed
-
-    async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
-        """
-        Clear the entire cache and rebuild the dependency graph.
-
-        Args:
-            dynprompt: The dynamic prompt object containing node information.
-            node_ids: List of node IDs to initialize the cache for.
-            is_changed_cache: Flag indicating if the cache has changed.
-        """
-        # Clear all existing cache data
-        self.cache.clear()
-        self.subcaches.clear()
-        self.descendants.clear()
-        self.ancestors.clear()
-        self.executed_nodes.clear()
-
-        # Call the parent method to initialize the cache with the new prompt
-        await super().set_prompt(dynprompt, node_ids, is_changed_cache)
-
-        # Rebuild the dependency graph
-        self._build_dependency_graph(dynprompt, node_ids)
-
-    def _build_dependency_graph(self, dynprompt, node_ids):
-        """
-        Build the dependency graph for all nodes.
-
-        Args:
-            dynprompt: The dynamic prompt object containing node information.
-            node_ids: List of node IDs to build the graph for.
-        """
-        self.descendants.clear()
-        self.ancestors.clear()
-        for node_id in node_ids:
-            self.descendants[node_id] = set()
-            self.ancestors[node_id] = set()
-
-        for node_id in node_ids:
-            inputs = dynprompt.get_node(node_id)["inputs"]
-            for input_data in inputs.values():
-                if is_link(input_data):  # Check if the input is a link to another node
-                    ancestor_id = input_data[0]
-                    self.descendants[ancestor_id].add(node_id)
-                    self.ancestors[node_id].add(ancestor_id)
-
-    def set(self, node_id, value):
-        """
-        Mark a node as executed and store its value in the cache.
-
-        Args:
-            node_id: The ID of the node to store.
-            value: The value to store for the node.
-        """
-        self._set_immediate(node_id, value)
-        self.executed_nodes.add(node_id)
-        self._cleanup_ancestors(node_id)
-
-    def get(self, node_id):
-        """
-        Retrieve the cached value for a node.
-
-        Args:
-            node_id: The ID of the node to retrieve.
-
-        Returns:
-            The cached value for the node.
-        """
-        return self._get_immediate(node_id)
-
-    async def ensure_subcache_for(self, node_id, children_ids):
-        """
-        Ensure a subcache exists for a node and update dependencies.
-
-        Args:
-            node_id: The ID of the parent node.
-            children_ids: List of child node IDs to associate with the parent node.
-
-        Returns:
-            The subcache object for the node.
-        """
-        subcache = await super()._ensure_subcache(node_id, children_ids)
-        for child_id in children_ids:
-            self.descendants[node_id].add(child_id)
-            self.ancestors[child_id].add(node_id)
-        return subcache
-
-    def _cleanup_ancestors(self, node_id):
-        """
-        Check if ancestors of a node can be removed from the cache.
-
-        Args:
-            node_id: The ID of the node whose ancestors are to be checked.
-        """
-        for ancestor_id in self.ancestors.get(node_id, []):
-            if ancestor_id in self.executed_nodes:
-                # Remove ancestor if all its descendants have been executed
-                if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
-                    self._remove_node(ancestor_id)
-
-    def _remove_node(self, node_id):
-        """
-        Remove a node from the cache.
-
-        Args:
-            node_id: The ID of the node to remove.
-        """
-        cache_key = self.cache_key_set.get_data_key(node_id)
-        if cache_key in self.cache:
-            del self.cache[cache_key]
-        subcache_key = self.cache_key_set.get_subcache_key(node_id)
-        if subcache_key in self.subcaches:
-            del self.subcaches[subcache_key]
-
-    def clean_unused(self):
-        """
-        Clean up unused nodes. This is a no-op for this cache implementation.
-        """
-        pass
-
-    def recursive_debug_dump(self):
-        """
-        Dump the cache and dependency graph for debugging.
-
-        Returns:
-            A list containing the cache state and dependency graph.
-        """
-        result = super().recursive_debug_dump()
-        result.append({
-            "descendants": self.descendants,
-            "ancestors": self.ancestors,
-            "executed_nodes": list(self.executed_nodes),
-        })
-        return result
diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py
index f4b427265da7..d5bbacde3a02 100644
--- a/comfy_execution/graph.py
+++ b/comfy_execution/graph.py
@@ -153,8 +153,9 @@ def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
                         continue
                     _, _, input_info = self.get_input_info(unique_id, input_name)
                     is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
-                    if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
-                        node_ids.append(from_node_id)
+                    if (include_lazy or not is_lazy):
+                        if not self.is_cached(from_node_id):
+                            node_ids.append(from_node_id)
                     links.append((from_node_id, from_socket, unique_id))
 
         for link in links:
@@ -194,10 +195,34 @@ def __init__(self, dynprompt, output_cache):
         super().__init__(dynprompt)
         self.output_cache = output_cache
         self.staged_node_id = None
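+        # execution_cache holds per-consumer snapshots of upstream outputs
+        # (to_node_id -> {from_node_id: value}); execution_cache_listeners
+        # maps from_node_id -> the consumers whose snapshots must be
+        # refreshed when that node produces a new value.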
+        self.execution_cache = {}
+        self.execution_cache_listeners = {}
 
     def is_cached(self, node_id):
         return self.output_cache.get(node_id) is not None
 
+    def cache_link(self, from_node_id, to_node_id):
+        if to_node_id not in self.execution_cache:
+            self.execution_cache[to_node_id] = {}
+        self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
+        if from_node_id not in self.execution_cache_listeners:
+            self.execution_cache_listeners[from_node_id] = set()
+        self.execution_cache_listeners[from_node_id].add(to_node_id)
+
+    def get_output_cache(self, from_node_id, to_node_id):
+        if to_node_id not in self.execution_cache:
+            return None
+        return self.execution_cache[to_node_id].get(from_node_id)
+
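+    # Refresh the snapshot held by every consumer that linked against
+    # node_id once it produces a fresh value.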
+    def cache_update(self, node_id, value):
+        if node_id in self.execution_cache_listeners:
+            for to_node_id in self.execution_cache_listeners[node_id]:
+                self.execution_cache[to_node_id][node_id] = value
+
+    def add_strong_link(self, from_node_id, from_socket, to_node_id):
+        super().add_strong_link(from_node_id, from_socket, to_node_id)
+        self.cache_link(from_node_id, to_node_id)
+
     async def stage_node_execution(self):
         assert self.staged_node_id is None
         if self.is_empty():
@@ -277,6 +302,8 @@ def unstage_node_execution(self):
     def complete_node_execution(self):
         node_id = self.staged_node_id
         self.pop_node(node_id)
+        self.execution_cache.pop(node_id, None)
+        self.execution_cache_listeners.pop(node_id, None)
         self.staged_node_id = None
 
     def get_nodes_in_cycle(self):
diff --git a/comfy_extras/nodes_controlnet.py b/comfy_extras/nodes_controlnet.py
index 2d20e1fed7c2..e835feed71fc 100644
--- a/comfy_extras/nodes_controlnet.py
+++ b/comfy_extras/nodes_controlnet.py
@@ -1,20 +1,26 @@
 from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
 import nodes
 import comfy.utils
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io
 
-class SetUnionControlNetType:
+class SetUnionControlNetType(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"control_net": ("CONTROL_NET", ),
-                             "type": (["auto"] + list(UNION_CONTROLNET_TYPES.keys()),)
-                             }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SetUnionControlNetType",
+            category="conditioning/controlnet",
+            inputs=[
+                io.ControlNet.Input("control_net"),
+                io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())),
+            ],
+            outputs=[
+                io.ControlNet.Output(),
+            ],
+        )
 
-    CATEGORY = "conditioning/controlnet"
-    RETURN_TYPES = ("CONTROL_NET",)
-
-    FUNCTION = "set_controlnet_type"
-
-    def set_controlnet_type(self, control_net, type):
+    @classmethod
+    def execute(cls, control_net, type) -> io.NodeOutput:
         control_net = control_net.copy()
         type_number = UNION_CONTROLNET_TYPES.get(type, -1)
         if type_number >= 0:
@@ -22,27 +28,36 @@ def set_controlnet_type(self, control_net, type):
         else:
             control_net.set_extra_arg("control_type", [])
 
-        return (control_net,)
+        return io.NodeOutput(control_net)
+
+    set_controlnet_type = execute  # TODO: remove
+
 
-class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
+class ControlNetInpaintingAliMamaApply(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "control_net": ("CONTROL_NET", ),
-                             "vae": ("VAE", ),
-                             "image": ("IMAGE", ),
-                             "mask": ("MASK", ),
-                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
-                             "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
-                             "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
-                             }}
-
-    FUNCTION = "apply_inpaint_controlnet"
-
-    CATEGORY = "conditioning/controlnet"
-
-    def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ControlNetInpaintingAliMamaApply",
+            category="conditioning/controlnet",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.ControlNet.Input("control_net"),
+                io.Vae.Input("vae"),
+                io.Image.Input("image"),
+                io.Mask.Input("mask"),
+                io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
+                io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
+                io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent) -> io.NodeOutput:
         extra_concat = []
         if control_net.concat_mask:
             mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
@@ -50,11 +65,20 @@ def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image,
             image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
             extra_concat = [mask]
 
-        return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
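+        # Delegate to the legacy ControlNetApplyAdvanced implementation and
+        # unpack its (positive, negative) tuple into a v3 NodeOutput.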
+        result = nodes.ControlNetApplyAdvanced().apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
+        return io.NodeOutput(result[0], result[1])
+
+    apply_inpaint_controlnet = execute  # TODO: remove
+
+class ControlNetExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            SetUnionControlNetType,
+            ControlNetInpaintingAliMamaApply,
+        ]
 
-NODE_CLASS_MAPPINGS = {
-    "SetUnionControlNetType": SetUnionControlNetType,
-    "ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
-}
+async def comfy_entrypoint() -> ControlNetExtension:
+    return ControlNetExtension()
diff --git a/execution.py b/execution.py
index 1dc35738b823..78c36a4b0556 100644
--- a/execution.py
+++ b/execution.py
@@ -18,7 +18,7 @@
     BasicCache,
     CacheKeySetID,
     CacheKeySetInputSignature,
-    DependencyAwareCache,
+    NullCache,
     HierarchicalCache,
     LRUCache,
 )
@@ -91,13 +91,13 @@ async def get(self, node_id):
 class CacheType(Enum):
     CLASSIC = 0
     LRU = 1
-    DEPENDENCY_AWARE = 2
+    NONE = 2
 
 
 class CacheSet:
     def __init__(self, cache_type=None, cache_size=None):
-        if cache_type == CacheType.DEPENDENCY_AWARE:
-            self.init_dependency_aware_cache()
+        if cache_type == CacheType.NONE:
+            self.init_null_cache()
             logging.info("Disabling intermediate node cache.")
         elif cache_type == CacheType.LRU:
             if cache_size is None:
@@ -120,11 +120,12 @@ def init_lru_cache(self, cache_size):
         self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
         self.objects = HierarchicalCache(CacheKeySetID)
 
-    # only hold cached items while the decendents have not executed
-    def init_dependency_aware_cache(self):
-        self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
-        self.ui = DependencyAwareCache(CacheKeySetInputSignature)
-        self.objects = DependencyAwareCache(CacheKeySetID)
+    def init_null_cache(self):
+        self.outputs = NullCache()
+        # The UI cache is expected to be iterable at the end of each workflow,
+        # so it must cache at least a full workflow. Use HierarchicalCache.
+        self.ui = HierarchicalCache(CacheKeySetInputSignature)
+        self.objects = NullCache()
 
     def recursive_debug_dump(self):
         result = {
@@ -135,7 +136,7 @@ def recursive_debug_dump(self):
 
 SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
 
-def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
+def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
     is_v3 = issubclass(class_def, _ComfyNodeInternal)
     if is_v3:
         valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
@@ -153,10 +154,10 @@ def mark_missing():
         if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
             input_unique_id = input_data[0]
             output_index = input_data[1]
-            if outputs is None:
+            if execution_list is None:
                 mark_missing()
                 continue # This might be a lazily-evaluated input
-            cached_output = outputs.get(input_unique_id)
+            cached_output = execution_list.get_output_cache(input_unique_id, unique_id)
             if cached_output is None:
                 mark_missing()
                 continue
@@ -405,6 +406,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
         cached_output = caches.ui.get(unique_id) or {}
         server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
         get_progress_state().finish_progress(unique_id)
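+        # Cache hit: re-publish the stored output so downstream execution
+        # snapshots are populated without re-running this node.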
+        execution_list.cache_update(unique_id, caches.outputs.get(unique_id))
         return (ExecutionResult.SUCCESS, None, None)
 
     input_data_all = None
@@ -434,7 +436,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                 for r in result:
                     if is_link(r):
                         source_node, source_output = r[0], r[1]
-                        node_output = caches.outputs.get(source_node)[source_output]
+                        node_output = execution_list.get_output_cache(source_node, unique_id)[source_output]
                         for o in node_output:
                             resolved_output.append(o)
 
@@ -446,7 +448,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             has_subgraph = False
         else:
             get_progress_state().start_progress(unique_id)
-            input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
+            input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
             if server.client_id is not None:
                 server.last_node_id = display_node_id
                 server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@@ -549,11 +551,15 @@ async def await_completion():
                 subcache.clean_unused()
             for node_id in new_output_ids:
                 execution_list.add_node(node_id)
+                execution_list.cache_link(node_id, unique_id)
             for link in new_output_links:
                 execution_list.add_strong_link(link[0], link[1], unique_id)
             pending_subgraph_results[unique_id] = cached_outputs
             return (ExecutionResult.PENDING, None, None)
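+        # Persist the freshly computed output (a no-op under NullCache) and
+        # push it to every downstream node that linked against this one.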
+        caches.outputs.set(unique_id, output_data)
+        execution_list.cache_update(unique_id, output_data)
+
     except comfy.model_management.InterruptProcessingException as iex:
         logging.info("Processing interrupted")
diff --git a/main.py b/main.py
index 35857dba8a6b..4b4c5dcc4294 100644
--- a/main.py
+++ b/main.py
@@ -173,7 +173,7 @@ def prompt_worker(q, server_instance):
     if args.cache_lru > 0:
         cache_type = execution.CacheType.LRU
     elif args.cache_none:
-        cache_type = execution.CacheType.DEPENDENCY_AWARE
+        cache_type = execution.CacheType.NONE
 
     e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
     last_gc_collect = 0
diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py
index ef73ad9fdd83..ace0d2279093 100644
--- a/tests/execution/test_execution.py
+++ b/tests/execution/test_execution.py
@@ -152,12 +152,12 @@ class TestExecution:
     # Initialize server and client
     #
     @fixture(scope="class", autouse=True, params=[
-        # (use_lru, lru_size)
-        (False, 0),
-        (True, 0),
-        (True, 100),
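+        # Each entry pairs extra server CLI flags with whether node results
+        # are expected to be served from the cache on a repeat run.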
+        { "extra_args" : [], "should_cache_results" : True },
+        { "extra_args" : ["--cache-lru", 0], "should_cache_results" : True },
+        { "extra_args" : ["--cache-lru", 100], "should_cache_results" : True },
+        { "extra_args" : ["--cache-none"], "should_cache_results" : False },
     ])
-    def _server(self, args_pytest, request):
+    def server(self, args_pytest, request):
         # Start server
         pargs = [
             'python','main.py',
@@ -167,12 +167,10 @@ def _server(self, args_pytest, request):
             '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
             '--cpu',
         ]
-        use_lru, lru_size = request.param
-        if use_lru:
-            pargs += ['--cache-lru', str(lru_size)]
+        pargs += [ str(param) for param in request.param["extra_args"] ]
         print("Running server with args:", pargs)  # noqa: T201
         p = subprocess.Popen(pargs)
-        yield
+        yield request.param
         p.kill()
         torch.cuda.empty_cache()
@@ -193,7 +191,7 @@ def start_client(self, listen:str, port:int):
         return comfy_client
 
     @fixture(scope="class", autouse=True)
-    def shared_client(self, args_pytest, _server):
+    def shared_client(self, args_pytest, server):
         client = self.start_client(args_pytest["listen"], args_pytest["port"])
         yield client
         del client
@@ -225,7 +223,7 @@ def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder):
         assert result.did_run(mask)
         assert result.did_run(lazy_mix)
 
-    def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
+    def test_full_cache(self, client: ComfyClient, builder: GraphBuilder, server):
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
@@ -237,9 +235,12 @@ def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
         client.run(g)
         result2 = client.run(g)
         for node_id, node in g.nodes.items():
-            assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
+            if server["should_cache_results"]:
+                assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
+            else:
+                assert result2.did_run(node), f"Node {node_id} was cached, but should have been run"
 
-    def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
+    def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder, server):
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
@@ -251,8 +252,12 @@ def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
         client.run(g)
         mask.inputs['value'] = 0.4
         result2 = client.run(g)
-        assert not result2.did_run(input1), "Input1 should have been cached"
-        assert not result2.did_run(input2), "Input2 should have been cached"
+        if server["should_cache_results"]:
+            assert not result2.did_run(input1), "Input1 should have been cached"
+            assert not result2.did_run(input2), "Input2 should have been cached"
+        else:
+            assert result2.did_run(input1), "Input1 should have been rerun"
+            assert result2.did_run(input2), "Input2 should have been rerun"
 
     def test_error(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
@@ -411,7 +416,7 @@ def test_missing_node_error(self, client: ComfyClient, builder: GraphBuilder):
         input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
         client.run(g)
 
-    def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
+    def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder, server):
         g = builder
         # Creating the nodes in this specific order previously caused a bug
         save = g.node("SaveImage")
@@ -427,7 +432,10 @@ def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
         result3 = client.run(g)
         result4 = client.run(g)
         assert result1.did_run(is_changed), "is_changed should have been run"
-        assert not result2.did_run(is_changed), "is_changed should have been cached"
+        if server["should_cache_results"]:
+            assert not result2.did_run(is_changed), "is_changed should have been cached"
+        else:
+            assert result2.did_run(is_changed), "is_changed should have been re-run"
         assert result3.did_run(is_changed), "is_changed should have been re-run"
         assert result4.did_run(is_changed), "is_changed should not have been cached"
@@ -514,7 +522,7 @@ def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder):
         assert len(images2) == 1, "Should have 1 image"
 
     # This tests that only constant outputs are used in the call to `IS_CHANGED`
-    def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder):
+    def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder, server):
         g = builder
         input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
         test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)
@@ -530,7 +538,11 @@ def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilde
         images = result.get_images(output)
         assert len(images) == 1, "Should have 1 image"
         assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
-        assert not result.did_run(test_node), "The execution should have been cached"
+        if server["should_cache_results"]:
+            assert not result.did_run(test_node), "The execution should have been cached"
+        else:
+            assert result.did_run(test_node), "The execution should have been re-run"
 
     def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
         # Warmup execution to ensure server is fully initialized