Skip to content

Commit f13cff0

Browse files
authored
Add custom node published subgraphs endpoint (Comfy-Org#10438)
* Add get_subgraphs_dir to ComfyExtension and PUBLISHED_SUBGRAPH_DIRS to nodes.py * Created initial endpoints, although the returned paths are a bit off currently * Fix path and actually return real data * Sanitize returned /api/global_subgraphs entries * Remove leftover function from early prototyping * Remove added whitespace * Add None check for sanitize_entry
1 parent 9cdc649 commit f13cff0

File tree

2 files changed

+115
-0
lines changed

2 files changed

+115
-0
lines changed

app/subgraph_manager.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
from __future__ import annotations
2+
3+
from typing import TypedDict
4+
import os
5+
import folder_paths
6+
import glob
7+
from aiohttp import web
8+
import hashlib
9+
10+
11+
class Source:
12+
custom_node = "custom_node"
13+
14+
class SubgraphEntry(TypedDict):
15+
source: str
16+
"""
17+
Source of subgraph - custom_nodes vs templates.
18+
"""
19+
path: str
20+
"""
21+
Relative path of the subgraph file.
22+
For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
23+
"""
24+
name: str
25+
"""
26+
Name of subgraph file.
27+
"""
28+
info: CustomNodeSubgraphEntryInfo
29+
"""
30+
Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
31+
"""
32+
data: str
33+
34+
class CustomNodeSubgraphEntryInfo(TypedDict):
35+
node_pack: str
36+
"""Node pack name."""
37+
38+
class SubgraphManager:
39+
def __init__(self):
40+
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
41+
42+
async def load_entry_data(self, entry: SubgraphEntry):
43+
with open(entry['path'], 'r') as f:
44+
entry['data'] = f.read()
45+
return entry
46+
47+
async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
48+
if entry is None:
49+
return None
50+
entry = entry.copy()
51+
entry.pop('path', None)
52+
if remove_data:
53+
entry.pop('data', None)
54+
return entry
55+
56+
async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
57+
entries = entries.copy()
58+
for key in list(entries.keys()):
59+
entries[key] = await self.sanitize_entry(entries[key], remove_data)
60+
return entries
61+
62+
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
63+
# if not forced to reload and cached, return cache
64+
if not force_reload and self.cached_custom_node_subgraphs is not None:
65+
return self.cached_custom_node_subgraphs
66+
# Load subgraphs from custom nodes
67+
subfolder = "subgraphs"
68+
subgraphs_dict: dict[SubgraphEntry] = {}
69+
70+
for folder in folder_paths.get_folder_paths("custom_nodes"):
71+
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
72+
matched_files = glob.glob(pattern)
73+
for file in matched_files:
74+
# replace backslashes with forward slashes
75+
file = file.replace('\\', '/')
76+
info: CustomNodeSubgraphEntryInfo = {
77+
"node_pack": "custom_nodes." + file.split('/')[-3]
78+
}
79+
source = Source.custom_node
80+
# hash source + path to make sure id will be as unique as possible, but
81+
# reproducible across backend reloads
82+
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
83+
entry: SubgraphEntry = {
84+
"source": Source.custom_node,
85+
"name": os.path.splitext(os.path.basename(file))[0],
86+
"path": file,
87+
"info": info,
88+
}
89+
subgraphs_dict[id] = entry
90+
self.cached_custom_node_subgraphs = subgraphs_dict
91+
return subgraphs_dict
92+
93+
async def get_custom_node_subgraph(self, id: str, loadedModules):
94+
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
95+
entry: SubgraphEntry = subgraphs.get(id, None)
96+
if entry is not None and entry.get('data', None) is None:
97+
await self.load_entry_data(entry)
98+
return entry
99+
100+
def add_routes(self, routes, loadedModules):
101+
@routes.get("/global_subgraphs")
102+
async def get_global_subgraphs(request):
103+
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
104+
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
105+
# that's the reasoning for the current implementation
106+
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
107+
108+
@routes.get("/global_subgraphs/{id}")
109+
async def get_global_subgraph(request):
110+
id = request.match_info.get("id", None)
111+
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
112+
return web.json_response(await self.sanitize_entry(subgraph))

server.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from app.user_manager import UserManager
3636
from app.model_manager import ModelFileManager
3737
from app.custom_node_manager import CustomNodeManager
38+
from app.subgraph_manager import SubgraphManager
3839
from typing import Optional, Union
3940
from api_server.routes.internal.internal_routes import InternalRoutes
4041
from protocol import BinaryEventTypes
@@ -173,6 +174,7 @@ def __init__(self, loop):
173174
self.user_manager = UserManager()
174175
self.model_file_manager = ModelFileManager()
175176
self.custom_node_manager = CustomNodeManager()
177+
self.subgraph_manager = SubgraphManager()
176178
self.internal_routes = InternalRoutes(self)
177179
self.supports = ["custom_nodes_from_web"]
178180
self.prompt_queue = execution.PromptQueue(self)
@@ -819,6 +821,7 @@ def add_routes(self):
819821
self.user_manager.add_routes(self.routes)
820822
self.model_file_manager.add_routes(self.routes)
821823
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
824+
self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
822825
self.app.add_subapp('/internal', self.internal_routes.get_app())
823826

824827
# Prefix every route with /api for easier matching for delegation.

0 commit comments

Comments
 (0)