Skip to content

Commit c9e7097

Browse files
authored
Merge pull request #3331 from HippocampusGirl/has-node-performance
[REF] Cache nodes in workflow to speed up construction, other optimizations
2 parents 4e8f54d + 0c485d3 commit c9e7097

File tree

2 files changed: +50 additions, −30 deletions

nipype/pipeline/engine/utils.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ def _merge_graphs(
753753
# nodes of the supergraph.
754754
supernodes = supergraph.nodes()
755755
ids = [n._hierarchy + n._id for n in supernodes]
756-
if len(np.unique(ids)) != len(ids):
756+
if len(set(ids)) != len(ids):
757757
# This should trap the problem of miswiring when multiple iterables are
758758
# used at the same level. The use of the template below for naming
759759
# updates to nodes is the general solution.
@@ -1098,11 +1098,12 @@ def make_field_func(*pair):
10981098
old_edge_dict = jedge_dict[jnode]
10991099
# the edge source node replicates
11001100
expansions = defaultdict(list)
1101-
for node in graph_in.nodes():
1101+
for node in graph_in:
11021102
for src_id in list(old_edge_dict.keys()):
11031103
# Drop the original JoinNodes; only concerned with
11041104
# generated Nodes
1105-
if hasattr(node, "joinfield") and node.itername == src_id:
1105+
itername = node.itername
1106+
if hasattr(node, "joinfield") and itername == src_id:
11061107
continue
11071108
# Patterns:
11081109
# - src_id : Non-iterable node
@@ -1111,10 +1112,10 @@ def make_field_func(*pair):
11111112
# - src_id.[a-z]I.[a-z]\d+ :
11121113
# Non-IdentityInterface w/ iterables
11131114
# - src_idJ\d+ : JoinNode(IdentityInterface)
1114-
if re.match(
1115-
src_id + r"((\.[a-z](I\.[a-z])?|J)\d+)?$", node.itername
1116-
):
1117-
expansions[src_id].append(node)
1115+
if itername.startswith(src_id):
1116+
suffix = itername[len(src_id):]
1117+
if re.fullmatch(r"((\.[a-z](I\.[a-z])?|J)\d+)?", suffix):
1118+
expansions[src_id].append(node)
11181119
for in_id, in_nodes in list(expansions.items()):
11191120
logger.debug(
11201121
"The join node %s input %s was expanded" " to %d nodes.",

nipype/pipeline/engine/workflows.py

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ def __init__(self, name, base_dir=None):
5959
super(Workflow, self).__init__(name, base_dir)
6060
self._graph = nx.DiGraph()
6161

62+
self._nodes_cache = set()
63+
self._nested_workflows_cache = set()
64+
6265
# PUBLIC API
6366
def clone(self, name):
6467
"""Clone a workflow
@@ -141,7 +144,7 @@ def connect(self, *args, **kwargs):
141144
self.disconnect(connection_list)
142145
return
143146

144-
newnodes = []
147+
newnodes = set()
145148
for srcnode, destnode, _ in connection_list:
146149
if self in [srcnode, destnode]:
147150
msg = (
@@ -151,9 +154,9 @@ def connect(self, *args, **kwargs):
151154

152155
raise IOError(msg)
153156
if (srcnode not in newnodes) and not self._has_node(srcnode):
154-
newnodes.append(srcnode)
157+
newnodes.add(srcnode)
155158
if (destnode not in newnodes) and not self._has_node(destnode):
156-
newnodes.append(destnode)
159+
newnodes.add(destnode)
157160
if newnodes:
158161
self._check_nodes(newnodes)
159162
for node in newnodes:
@@ -163,15 +166,16 @@ def connect(self, *args, **kwargs):
163166
connected_ports = {}
164167
for srcnode, destnode, connects in connection_list:
165168
if destnode not in connected_ports:
166-
connected_ports[destnode] = []
169+
connected_ports[destnode] = set()
167170
# check to see which ports of destnode are already
168171
# connected.
169172
if not disconnect and (destnode in self._graph.nodes()):
170173
for edge in self._graph.in_edges(destnode):
171174
data = self._graph.get_edge_data(*edge)
172-
for sourceinfo, destname in data["connect"]:
173-
if destname not in connected_ports[destnode]:
174-
connected_ports[destnode] += [destname]
175+
connected_ports[destnode].update(
176+
destname
177+
for _, destname in data["connect"]
178+
)
175179
for source, dest in connects:
176180
# Currently datasource/sink/grabber.io modules
177181
# determine their inputs/outputs depending on
@@ -226,7 +230,7 @@ def connect(self, *args, **kwargs):
226230
)
227231
if sourcename and not srcnode._check_outputs(sourcename):
228232
not_found.append(["out", srcnode.name, sourcename])
229-
connected_ports[destnode] += [dest]
233+
connected_ports[destnode].add(dest)
230234
infostr = []
231235
for info in not_found:
232236
infostr += [
@@ -269,6 +273,9 @@ def connect(self, *args, **kwargs):
269273
"(%s, %s): new edge data: %s", srcnode, destnode, str(edge_data)
270274
)
271275

276+
if newnodes:
277+
self._update_node_cache()
278+
272279
def disconnect(self, *args):
273280
"""Disconnect nodes
274281
See the docstring for connect for format.
@@ -325,7 +332,7 @@ def add_nodes(self, nodes):
325332
newnodes = []
326333
all_nodes = self._get_all_nodes()
327334
for node in nodes:
328-
if self._has_node(node):
335+
if node in all_nodes:
329336
raise IOError("Node %s already exists in the workflow" % node)
330337
if isinstance(node, Workflow):
331338
for subnode in node._get_all_nodes():
@@ -346,6 +353,7 @@ def add_nodes(self, nodes):
346353
if node._hierarchy is None:
347354
node._hierarchy = self.name
348355
self._graph.add_nodes_from(newnodes)
356+
self._update_node_cache()
349357

350358
def remove_nodes(self, nodes):
351359
"""Remove nodes from a workflow
@@ -356,6 +364,7 @@ def remove_nodes(self, nodes):
356364
A list of EngineBase-based objects
357365
"""
358366
self._graph.remove_nodes_from(nodes)
367+
self._update_node_cache()
359368

360369
# Input-Output access
361370
@property
@@ -895,22 +904,32 @@ def _set_node_input(self, node, param, source, sourceinfo):
895904
node.set_input(param, deepcopy(newval))
896905

897906
def _get_all_nodes(self):
898-
allnodes = []
899-
for node in self._graph.nodes():
900-
if isinstance(node, Workflow):
901-
allnodes.extend(node._get_all_nodes())
902-
else:
903-
allnodes.append(node)
907+
allnodes = self._nodes_cache - self._nested_workflows_cache
908+
for node in self._nested_workflows_cache:
909+
allnodes |= node._get_all_nodes()
904910
return allnodes
905911

906-
def _has_node(self, wanted_node):
907-
for node in self._graph.nodes():
908-
if wanted_node == node:
909-
return True
912+
def _update_node_cache(self):
913+
nodes = set(self._graph)
914+
915+
added_nodes = nodes.difference(self._nodes_cache)
916+
removed_nodes = self._nodes_cache.difference(nodes)
917+
918+
self._nodes_cache = nodes
919+
self._nested_workflows_cache.difference_update(removed_nodes)
920+
921+
for node in added_nodes:
910922
if isinstance(node, Workflow):
911-
if node._has_node(wanted_node):
912-
return True
913-
return False
923+
self._nested_workflows_cache.add(node)
924+
925+
def _has_node(self, wanted_node):
926+
return (
927+
wanted_node in self._nodes_cache or
928+
any(
929+
wf._has_node(wanted_node)
930+
for wf in self._nested_workflows_cache
931+
)
932+
)
914933

915934
def _create_flat_graph(self):
916935
"""Make a simple DAG where no node is a workflow."""
@@ -939,7 +958,7 @@ def _generate_flatgraph(self):
939958
raise Exception(
940959
("Workflow: %s is not a directed acyclic graph " "(DAG)") % self.name
941960
)
942-
nodes = list(nx.topological_sort(self._graph))
961+
nodes = list(self._graph.nodes)
943962
for node in nodes:
944963
logger.debug("processing node: %s", node)
945964
if isinstance(node, Workflow):

0 commit comments

Comments (0)