@@ -59,6 +59,9 @@ def __init__(self, name, base_dir=None):
59
59
super (Workflow , self ).__init__ (name , base_dir )
60
60
self ._graph = nx .DiGraph ()
61
61
62
+ self ._nodes_cache = set ()
63
+ self ._nested_workflows_cache = set ()
64
+
62
65
# PUBLIC API
63
66
def clone (self , name ):
64
67
"""Clone a workflow
@@ -141,7 +144,7 @@ def connect(self, *args, **kwargs):
141
144
self .disconnect (connection_list )
142
145
return
143
146
144
- newnodes = []
147
+ newnodes = set ()
145
148
for srcnode , destnode , _ in connection_list :
146
149
if self in [srcnode , destnode ]:
147
150
msg = (
@@ -151,9 +154,9 @@ def connect(self, *args, **kwargs):
151
154
152
155
raise IOError (msg )
153
156
if (srcnode not in newnodes ) and not self ._has_node (srcnode ):
154
- newnodes .append (srcnode )
157
+ newnodes .add (srcnode )
155
158
if (destnode not in newnodes ) and not self ._has_node (destnode ):
156
- newnodes .append (destnode )
159
+ newnodes .add (destnode )
157
160
if newnodes :
158
161
self ._check_nodes (newnodes )
159
162
for node in newnodes :
@@ -163,15 +166,16 @@ def connect(self, *args, **kwargs):
163
166
connected_ports = {}
164
167
for srcnode , destnode , connects in connection_list :
165
168
if destnode not in connected_ports :
166
- connected_ports [destnode ] = []
169
+ connected_ports [destnode ] = set ()
167
170
# check to see which ports of destnode are already
168
171
# connected.
169
172
if not disconnect and (destnode in self ._graph .nodes ()):
170
173
for edge in self ._graph .in_edges (destnode ):
171
174
data = self ._graph .get_edge_data (* edge )
172
- for sourceinfo , destname in data ["connect" ]:
173
- if destname not in connected_ports [destnode ]:
174
- connected_ports [destnode ] += [destname ]
175
+ connected_ports [destnode ].update (
176
+ destname
177
+ for _ , destname in data ["connect" ]
178
+ )
175
179
for source , dest in connects :
176
180
# Currently datasource/sink/grabber.io modules
177
181
# determine their inputs/outputs depending on
@@ -226,7 +230,7 @@ def connect(self, *args, **kwargs):
226
230
)
227
231
if sourcename and not srcnode ._check_outputs (sourcename ):
228
232
not_found .append (["out" , srcnode .name , sourcename ])
229
- connected_ports [destnode ] += [ dest ]
233
+ connected_ports [destnode ]. add ( dest )
230
234
infostr = []
231
235
for info in not_found :
232
236
infostr += [
@@ -269,6 +273,9 @@ def connect(self, *args, **kwargs):
269
273
"(%s, %s): new edge data: %s" , srcnode , destnode , str (edge_data )
270
274
)
271
275
276
+ if newnodes :
277
+ self ._update_node_cache ()
278
+
272
279
def disconnect (self , * args ):
273
280
"""Disconnect nodes
274
281
See the docstring for connect for format.
@@ -325,7 +332,7 @@ def add_nodes(self, nodes):
325
332
newnodes = []
326
333
all_nodes = self ._get_all_nodes ()
327
334
for node in nodes :
328
- if self . _has_node ( node ) :
335
+ if node in all_nodes :
329
336
raise IOError ("Node %s already exists in the workflow" % node )
330
337
if isinstance (node , Workflow ):
331
338
for subnode in node ._get_all_nodes ():
@@ -346,6 +353,7 @@ def add_nodes(self, nodes):
346
353
if node ._hierarchy is None :
347
354
node ._hierarchy = self .name
348
355
self ._graph .add_nodes_from (newnodes )
356
+ self ._update_node_cache ()
349
357
350
358
def remove_nodes (self , nodes ):
351
359
"""Remove nodes from a workflow
@@ -356,6 +364,7 @@ def remove_nodes(self, nodes):
356
364
A list of EngineBase-based objects
357
365
"""
358
366
self ._graph .remove_nodes_from (nodes )
367
+ self ._update_node_cache ()
359
368
360
369
# Input-Output access
361
370
@property
@@ -895,22 +904,32 @@ def _set_node_input(self, node, param, source, sourceinfo):
895
904
node .set_input (param , deepcopy (newval ))
896
905
897
906
def _get_all_nodes (self ):
898
- allnodes = []
899
- for node in self ._graph .nodes ():
900
- if isinstance (node , Workflow ):
901
- allnodes .extend (node ._get_all_nodes ())
902
- else :
903
- allnodes .append (node )
907
+ allnodes = self ._nodes_cache - self ._nested_workflows_cache
908
+ for node in self ._nested_workflows_cache :
909
+ allnodes |= node ._get_all_nodes ()
904
910
return allnodes
905
911
906
- def _has_node (self , wanted_node ):
907
- for node in self ._graph .nodes ():
908
- if wanted_node == node :
909
- return True
912
+ def _update_node_cache (self ):
913
+ nodes = set (self ._graph )
914
+
915
+ added_nodes = nodes .difference (self ._nodes_cache )
916
+ removed_nodes = self ._nodes_cache .difference (nodes )
917
+
918
+ self ._nodes_cache = nodes
919
+ self ._nested_workflows_cache .difference_update (removed_nodes )
920
+
921
+ for node in added_nodes :
910
922
if isinstance (node , Workflow ):
911
- if node ._has_node (wanted_node ):
912
- return True
913
- return False
923
+ self ._nested_workflows_cache .add (node )
924
+
925
+ def _has_node (self , wanted_node ):
926
+ return (
927
+ wanted_node in self ._nodes_cache or
928
+ any (
929
+ wf ._has_node (wanted_node )
930
+ for wf in self ._nested_workflows_cache
931
+ )
932
+ )
914
933
915
934
def _create_flat_graph (self ):
916
935
"""Make a simple DAG where no node is a workflow."""
@@ -939,7 +958,7 @@ def _generate_flatgraph(self):
939
958
raise Exception (
940
959
("Workflow: %s is not a directed acyclic graph " "(DAG)" ) % self .name
941
960
)
942
- nodes = list (nx . topological_sort ( self ._graph ) )
961
+ nodes = list (self ._graph . nodes )
943
962
for node in nodes :
944
963
logger .debug ("processing node: %s" , node )
945
964
if isinstance (node , Workflow ):
0 commit comments