11"""
22orchestration engine for GraphGen
33"""
4-
4+ import queue
55import threading
66import traceback
77from enum import Enum , auto
88from functools import wraps
9- from typing import Any , Callable , List , Optional
9+ from typing import Callable , Dict , List
10+
11+
class OpType(Enum):
    """How an operation consumes the data produced by its upstream operations."""

    # Once data from upstream arrives, process it immediately.
    STREAMING = auto()

    # Wait for all upstream data to arrive before processing.
    BARRIER = auto()

    # TODO: implement batch processing
    # BATCH = auto()  # process data in batches
1018
1119
12- class AggMode ( Enum ):
13- MAP = auto () # whenever upstream produces a result, run
14- ALL_REDUCE = auto () # wait for all upstream results, then run
class EndOfStream:
    """Marker object that signals the end of a data stream."""
1523
1624
1725class Context (dict ):
@@ -31,18 +39,16 @@ def __init__(
3139 self ,
3240 name : str ,
3341 deps : List [str ],
34- compute_func : Callable [["OpNode" , Context ], Any ],
35- callback_func : Optional [Callable [["OpNode" , Context , List [Any ]], None ]] = None ,
36- agg_mode : AggMode = AggMode .ALL_REDUCE ,
42+ func : Callable ,
43+ op_type : OpType = OpType .BARRIER , # use barrier by default
3744 ):
3845 self .name = name
3946 self .deps = deps
40- self .compute_func = compute_func
41- self .callback_func = callback_func or (lambda self , ctx , results : None )
42- self .agg_mode = agg_mode
47+ self .func = func
48+ self .op_type = op_type
4349
4450
45- def op (name : str , deps = None , agg_mode : AggMode = AggMode . ALL_REDUCE ):
51+ def op (name : str , deps = None , op_type : OpType = OpType . BARRIER ):
4652 deps = deps or []
4753
4854 def decorator (func ):
@@ -51,157 +57,168 @@ def _wrapper(*args, **kwargs):
5157 return func (* args , ** kwargs )
5258
5359 _wrapper .op_node = OpNode (
54- name = name ,
55- deps = deps ,
56- compute_func = lambda self , ctx : func (self ),
57- callback_func = lambda self , ctx , results : None ,
58- agg_mode = agg_mode ,
60+ name ,
61+ deps ,
62+ func ,
63+ op_type = op_type ,
5964 )
6065 return _wrapper
6166
6267 return decorator
6368
6469
class Engine:
    """Run a DAG of operations with one daemon thread per operation.

    Every dependency edge ``(producer, consumer)`` is backed by a bounded
    ``queue.Queue``. A producer pushes each result item to all of its
    consumers and always terminates every outgoing edge with an
    ``EndOfStream`` marker, even when the operation itself failed.
    """

    def __init__(self, queue_size: int = 100):
        """Store the ``maxsize`` used for every channel queue."""
        self.queue_size = queue_size

    @staticmethod
    def _topo_sort(name2op: Dict[str, "OpNode"]) -> List[str]:
        """Return the operation names in dependency order (Kahn's algorithm).

        Raises:
            ValueError: if a dependency is not a key of ``name2op`` or the
                dependencies contain a cycle.
        """
        adj = {n: [] for n in name2op}
        in_degree = {n: 0 for n in name2op}

        for name, operation in name2op.items():
            for dep_name in operation.deps:
                if dep_name not in name2op:
                    raise ValueError(f"Dependency {dep_name} of {name} not found")
                adj[dep_name].append(name)
                in_degree[name] += 1

        # The result list doubles as the work queue: everything before
        # ``cursor`` is already expanded, everything after it is still pending.
        # This avoids the O(n) cost of ``list.pop(0)``.
        topo_order = [n for n in name2op if in_degree[n] == 0]
        cursor = 0
        while cursor < len(topo_order):
            current = topo_order[cursor]
            cursor += 1
            for child in adj[current]:
                in_degree[child] -= 1
                if in_degree[child] == 0:
                    topo_order.append(child)

        # Nodes that never reached in-degree 0 are part of (or behind) a cycle.
        if len(topo_order) != len(name2op):
            cycle_nodes = set(name2op.keys()) - set(topo_order)
            raise ValueError(f"Cyclic dependency detected among: {cycle_nodes}")
        return topo_order

    def _build_channels(self, name2op):
        """Return channels / consumers_of / producer_counts.

        * ``channels`` maps ``(producer, consumer)`` to a bounded queue.
        * ``consumers_of`` maps each name to the names that depend on it.
        * ``producer_counts`` maps each name to its number of distinct upstreams.

        Raises:
            ValueError: if a dependency is not a key of ``name2op``.
        """
        # Pre-populate both maps for every node: a consumer may be listed
        # before its producer, so ``consumers_of[dep]`` must already exist.
        consumers_of = {n: [] for n in name2op}
        producer_counts = {n: 0 for n in name2op}
        channels = {}

        for name, operator in name2op.items():
            for dep in operator.deps:
                if dep not in name2op:
                    raise ValueError(f"Dependency {dep} of {name} not found")
                if (dep, name) in channels:
                    # A dependency listed twice still gets a single edge;
                    # otherwise every item would be delivered twice.
                    continue
                channels[(dep, name)] = queue.Queue(maxsize=self.queue_size)
                consumers_of[dep].append(name)
                producer_counts[name] += 1
        return channels, consumers_of, producer_counts

    def _run_workers(self, ordered_ops, channels, consumers_of, producer_counts, ctx):
        """Start one daemon thread per node, wait for all of them, and return
        the ``{op_name: formatted traceback}`` dict filled in by failed workers.
        """
        exceptions, threads = {}, []
        for node in ordered_ops:
            t = threading.Thread(
                target=self._worker_loop,
                args=(node, channels, consumers_of, producer_counts, ctx, exceptions),
                daemon=True,
            )
            t.start()
            threads.append(t)
        for t in threads:
            t.join()
        return exceptions

    def _worker_loop(
        self, node, channels, consumers_of, producer_counts, ctx, exceptions
    ):
        """Execute a single node: gather input, call ``node.func``, fan out results.

        ``node.func`` is called as ``func(self, ctx, inputs=[...])`` for
        ``OpType.BARRIER`` and as ``func(self, ctx, input_stream=<generator>)``
        for ``OpType.STREAMING``. Any exception is printed and recorded in
        ``exceptions[node.name]`` instead of being raised.

        ``producer_counts`` is accepted for interface compatibility only; the
        set of upstream queues is derived from ``node.deps`` directly.
        """
        op_name = node.name

        def input_generator():
            # A source node has no upstream: yield None once so it runs once.
            if not node.deps:
                yield None
                return

            # One queue per distinct upstream (dict.fromkeys keeps order and
            # drops duplicates, mirroring _build_channels).
            pending = [
                channels[(dep_name, op_name)] for dep_name in dict.fromkeys(node.deps)
            ]

            # Poll only the queues whose producer has not finished yet. Every
            # wait is timed, so one finished upstream can never block us
            # while another upstream is still producing.
            while pending:
                still_open = []
                for q in pending:
                    try:
                        item = q.get(timeout=0.1)
                    except queue.Empty:
                        still_open.append(q)
                        continue
                    if isinstance(item, EndOfStream):
                        # This upstream is done; stop polling its queue.
                        continue
                    still_open.append(q)
                    yield item
                pending = still_open

        in_stream = input_generator()

        try:
            if node.op_type == OpType.BARRIER:
                # Consume the whole input before calling the operation.
                buffered_inputs = list(in_stream)
                res = node.func(self, ctx, inputs=buffered_inputs)
                if res is None:
                    result_iter = []
                elif isinstance(res, (list, tuple)):
                    result_iter = res
                else:
                    # A single non-sequence result is forwarded as one item.
                    result_iter = [res]
            elif node.op_type == OpType.STREAMING:
                # The operation pulls items from the stream itself.
                res = node.func(self, ctx, input_stream=in_stream)
                result_iter = res if res is not None else []
            else:
                raise ValueError(f"Unknown OpType {node.op_type} for {op_name}")

            # Output dispatch: every item goes to every downstream consumer.
            if result_iter:
                for item in result_iter:
                    for consumer_name in consumers_of[op_name]:
                        channels[(op_name, consumer_name)].put(item)

        except Exception:  # pylint: disable=broad-except
            traceback.print_exc()
            exceptions[op_name] = traceback.format_exc()

        finally:
            # Always signal end of stream so downstream workers can terminate.
            for consumer_name in consumers_of[op_name]:
                channels[(op_name, consumer_name)].put(EndOfStream())
            # Drain whatever input was left unread (failure, or an operation
            # that stopped early). Otherwise an upstream producer would block
            # forever on a full bounded queue and run() would never return.
            for _ in in_stream:
                pass

    def run(self, ops: List["OpNode"], ctx: "Context"):
        """Execute ``ops`` to completion.

        Raises:
            ValueError: on an unknown or cyclic dependency.
            RuntimeError: if at least one operation raised.
        """
        name2op = {op.name: op for op in ops}

        # Step 1: topo sort and validate
        sorted_op_names = self._topo_sort(name2op)

        # Step 2: build channels and tracking structures
        channels, consumers_of, producer_counts = self._build_channels(name2op)

        # Step 3: start worker threads using topo order
        ordered_ops = [name2op[name] for name in sorted_op_names]
        exceptions = self._run_workers(
            ordered_ops, channels, consumers_of, producer_counts, ctx
        )

        if exceptions:
            raise RuntimeError(f"Engine encountered exceptions: {exceptions}")
205222
206223
207224def collect_ops (config : dict , graph_gen ) -> List [OpNode ]:
@@ -217,13 +234,39 @@ def collect_ops(config: dict, graph_gen) -> List[OpNode]:
217234 op_node = method .op_node
218235
219236 # if there are runtime dependencies, override them
220- runtime_deps = stage .get ("deps" , op_node .deps )
221- op_node .deps = runtime_deps
237+ deps = stage .get ("deps" , op_node .deps )
238+ op_type = op_node .op_type
239+
240+ if op_type == OpType .BARRIER :
241+ if "params" in stage :
242+
243+ def func (self , ctx , inputs , m = method , sc = stage ):
244+ return m (sc .get ("params" , {}), inputs = inputs )
245+
246+ else :
247+
248+ def func (self , ctx , inputs , m = method ):
249+ return m (inputs = inputs )
250+
251+ elif op_type == OpType .STREAMING :
252+ if "params" in stage :
253+
254+ def func (self , ctx , input_stream , m = method , sc = stage ):
255+ return m (sc .get ("params" , {}), input_stream = input_stream )
256+
257+ else :
258+
259+ def func (self , ctx , input_stream , m = method ):
260+ return m (input_stream = input_stream )
222261
223- if "params" in stage :
224- params = stage ["params" ]
225- op_node .compute_func = lambda self , ctx , m = method , p = params : m (p )
226262 else :
227- op_node .compute_func = lambda self , ctx , m = method : m ()
228- ops .append (op_node )
263+ raise ValueError (f"Unknown OpType { op_type } for operation { name } " )
264+
265+ new_node = OpNode (
266+ name = name ,
267+ deps = deps ,
268+ func = func ,
269+ op_type = op_type ,
270+ )
271+ ops .append (new_node )
229272 return ops
0 commit comments