Skip to content

Commit 166ecaf

Browse files
refactor: rework engine to use dataflow orchestration
1 parent 3927239 commit 166ecaf

File tree

1 file changed

+190
-147
lines changed

1 file changed

+190
-147
lines changed

graphgen/engine.py

Lines changed: 190 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,25 @@
11
"""
22
orchestration engine for GraphGen
33
"""
4-
4+
import queue
55
import threading
66
import traceback
77
from enum import Enum, auto
88
from functools import wraps
9-
from typing import Any, Callable, List, Optional
9+
from typing import Callable, Dict, List
10+
11+
12+
class OpType(Enum):
    """How an operation consumes the data produced by its upstream ops."""

    STREAMING = auto()  # handle each upstream item as soon as it arrives
    BARRIER = auto()  # buffer every upstream item first, then run once

    # TODO: implement batch processing
    # BATCH = auto()  # process data in batches
1018

1119

class EndOfStream:
    """Sentinel placed on a channel to signal the end of a data stream."""
1523

1624

1725
class Context(dict):
@@ -31,18 +39,16 @@ def __init__(
3139
self,
3240
name: str,
3341
deps: List[str],
34-
compute_func: Callable[["OpNode", Context], Any],
35-
callback_func: Optional[Callable[["OpNode", Context, List[Any]], None]] = None,
36-
agg_mode: AggMode = AggMode.ALL_REDUCE,
42+
func: Callable,
43+
op_type: OpType = OpType.BARRIER, # use barrier by default
3744
):
3845
self.name = name
3946
self.deps = deps
40-
self.compute_func = compute_func
41-
self.callback_func = callback_func or (lambda self, ctx, results: None)
42-
self.agg_mode = agg_mode
47+
self.func = func
48+
self.op_type = op_type
4349

4450

def op(name: str, deps=None, op_type: OpType = OpType.BARRIER):
    """Decorator that tags a function with an `OpNode` description.

    The wrapped function behaves exactly like the original; the engine
    later discovers it through the attached ``op_node`` attribute.
    """
    resolved_deps = deps or []

    def decorator(func):
        @wraps(func)
        def _wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        _wrapper.op_node = OpNode(
            name,
            resolved_deps,
            func,
            op_type=op_type,
        )
        return _wrapper

    return decorator
6368

6469

6570
class Engine:
    """Dataflow orchestration engine.

    Every operation runs in its own worker thread. Ops are wired together
    with bounded ``queue.Queue`` channels, one per (producer, consumer)
    edge; an ``EndOfStream`` sentinel on a channel tells the consumer that
    the producer has finished.
    """

    def __init__(self, queue_size: int = 100):
        # Capacity of each inter-op channel; a full channel applies
        # back-pressure to its producer.
        self.queue_size = queue_size

    @staticmethod
    def _topo_sort(name2op: "Dict[str, OpNode]") -> List[str]:
        """Return op names in topological order.

        Raises ValueError on an unknown dependency or a dependency cycle.
        """
        adj = {n: [] for n in name2op}
        in_degree = {n: 0 for n in name2op}

        for name, operation in name2op.items():
            for dep_name in operation.deps:
                if dep_name not in name2op:
                    raise ValueError(f"Dependency {dep_name} of {name} not found")
                adj[dep_name].append(name)
                in_degree[name] += 1

        # Kahn's algorithm for topological sorting
        queue_nodes = [n for n in name2op if in_degree[n] == 0]
        topo_order = []

        while queue_nodes:
            u = queue_nodes.pop(0)
            topo_order.append(u)

            for v in adj[u]:
                in_degree[v] -= 1
                if in_degree[v] == 0:
                    queue_nodes.append(v)

        # cycle detection: any node missing from the order sits on a cycle
        if len(topo_order) != len(name2op):
            cycle_nodes = set(name2op.keys()) - set(topo_order)
            raise ValueError(f"Cyclic dependency detected among: {cycle_nodes}")
        return topo_order

    def _build_channels(self, name2op):
        """Return channels / consumers_of / producer_counts"""
        channels = {}
        # Pre-seed every consumer list so the loop below may append to
        # consumers_of[dep] regardless of dict iteration order.
        # (Initialising lazily inside the loop raised KeyError whenever a
        # consumer was iterated before its dependency.)
        consumers_of = {n: [] for n in name2op}
        producer_counts = {n: 0 for n in name2op}
        for name, operator in name2op.items():
            for dep in operator.deps:
                if dep not in name2op:
                    raise ValueError(f"Dependency {dep} of {name} not found")
                channels[(dep, name)] = queue.Queue(maxsize=self.queue_size)
                consumers_of[dep].append(name)
                producer_counts[name] += 1
        return channels, consumers_of, producer_counts

    def _run_workers(self, ordered_ops, channels, consumers_of, producer_counts, ctx):
        """Run worker threads for each operation node.

        Returns a dict mapping op name -> formatted traceback for every
        worker that failed (empty dict on full success).
        """
        exceptions, threads = {}, []
        for node in ordered_ops:
            t = threading.Thread(
                target=self._worker_loop,
                args=(node, channels, consumers_of, producer_counts, ctx, exceptions),
                daemon=True,
            )
            t.start()
            threads.append(t)
        for t in threads:
            t.join()
        return exceptions

    def _worker_loop(
        self, node, channels, consumers_of, producer_counts, ctx, exceptions
    ):
        """Body of one worker thread: pull inputs, run the op, fan out results."""
        op_name = node.name

        def input_generator():
            # if no dependencies, yield None once so the op still fires
            if not node.deps:
                yield None
                return

            # One channel per upstream producer; retire a channel as soon
            # as its EndOfStream sentinel arrives.
            active_queues = [channels[(dep_name, op_name)] for dep_name in node.deps]

            # loop until every producer has been retired
            while active_queues:
                got_data = False
                for q in list(active_queues):
                    try:
                        item = q.get(timeout=0.1)
                    except queue.Empty:
                        continue
                    got_data = True
                    if isinstance(item, EndOfStream):
                        active_queues.remove(q)
                    else:
                        yield item

                if not got_data and active_queues:
                    # Nothing was ready anywhere: block on a queue that is
                    # still active instead of spinning. (Blocking on a
                    # retired producer's queue would deadlock, since that
                    # producer will never send again.)
                    item = active_queues[0].get()
                    if isinstance(item, EndOfStream):
                        active_queues.pop(0)
                    else:
                        yield item

        in_stream = input_generator()

        try:
            # execute the operation
            result_iter = []
            if node.op_type == OpType.BARRIER:
                # consume all input before invoking the op once
                buffered_inputs = list(in_stream)
                res = node.func(self, ctx, inputs=buffered_inputs)
                if res is not None:
                    result_iter = res if isinstance(res, (list, tuple)) else [res]

            elif node.op_type == OpType.STREAMING:
                # hand the op a lazy stream; it pulls items as they arrive
                res = node.func(self, ctx, input_stream=in_stream)
                if res is not None:
                    result_iter = res
            else:
                raise ValueError(f"Unknown OpType {node.op_type} for {op_name}")

            # output dispatch, send results to downstream consumers
            if result_iter:
                for item in result_iter:
                    for consumer_name in consumers_of[op_name]:
                        channels[(op_name, consumer_name)].put(item)

        except Exception:  # pylint: disable=broad-except
            traceback.print_exc()
            exceptions[op_name] = traceback.format_exc()

        finally:
            # always signal end of stream so downstream workers terminate,
            # even when this op failed
            for consumer_name in consumers_of[op_name]:
                channels[(op_name, consumer_name)].put(EndOfStream())

    def run(self, ops: "List[OpNode]", ctx: "Context"):
        """Validate the op graph, execute it, and raise on any failure."""
        name2op = {op.name: op for op in ops}

        # Step 1: topo sort and validate
        sorted_op_names = self._topo_sort(name2op)

        # Step 2: build channels and tracking structures
        channels, consumers_of, producer_counts = self._build_channels(name2op)

        # Step 3: start worker threads using topo order
        ordered_ops = [name2op[name] for name in sorted_op_names]
        exceptions = self._run_workers(
            ordered_ops, channels, consumers_of, producer_counts, ctx
        )

        if exceptions:
            raise RuntimeError(f"Engine encountered exceptions: {exceptions}")
205222

206223

207224
def collect_ops(config: dict, graph_gen) -> List[OpNode]:
@@ -217,13 +234,39 @@ def collect_ops(config: dict, graph_gen) -> List[OpNode]:
217234
op_node = method.op_node
218235

219236
# if there are runtime dependencies, override them
220-
runtime_deps = stage.get("deps", op_node.deps)
221-
op_node.deps = runtime_deps
237+
deps = stage.get("deps", op_node.deps)
238+
op_type = op_node.op_type
239+
240+
if op_type == OpType.BARRIER:
241+
if "params" in stage:
242+
243+
def func(self, ctx, inputs, m=method, sc=stage):
244+
return m(sc.get("params", {}), inputs=inputs)
245+
246+
else:
247+
248+
def func(self, ctx, inputs, m=method):
249+
return m(inputs=inputs)
250+
251+
elif op_type == OpType.STREAMING:
252+
if "params" in stage:
253+
254+
def func(self, ctx, input_stream, m=method, sc=stage):
255+
return m(sc.get("params", {}), input_stream=input_stream)
256+
257+
else:
258+
259+
def func(self, ctx, input_stream, m=method):
260+
return m(input_stream=input_stream)
222261

223-
if "params" in stage:
224-
params = stage["params"]
225-
op_node.compute_func = lambda self, ctx, m=method, p=params: m(p)
226262
else:
227-
op_node.compute_func = lambda self, ctx, m=method: m()
228-
ops.append(op_node)
263+
raise ValueError(f"Unknown OpType {op_type} for operation {name}")
264+
265+
new_node = OpNode(
266+
name=name,
267+
deps=deps,
268+
func=func,
269+
op_type=op_type,
270+
)
271+
ops.append(new_node)
229272
return ops

0 commit comments

Comments
 (0)