Skip to content

Commit 899cf01

Browse files
feat: add bucket for map and all-reduce
1 parent 18a616d commit 899cf01

File tree

1 file changed

+79
-3
lines changed

1 file changed

+79
-3
lines changed

graphgen/engine.py

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,14 @@
44

55
import threading
66
import traceback
7+
from enum import Enum, auto
78
from functools import wraps
8-
from typing import Any, Callable, List
9+
from typing import Any, Callable, List, Optional
10+
11+
12+
class AggMode(Enum):
    """Aggregation strategy that decides when a node runs relative to upstream output."""

    # Run whenever upstream produces a result.
    MAP = auto()
    # Wait for all upstream results, then run.
    ALL_REDUCE = auto()
915

1016

1117
class Context(dict):
@@ -22,9 +28,18 @@ def get(self, k, default=None):
2228

2329
class OpNode:
    """
    A single named operation together with the names of the operations it depends on.

    Attributes are plain instance fields:
      name          -- identifier of this operation
      deps          -- names of the operations this one depends on
      compute_func  -- callable taking (node, context)
      callback_func -- callable taking (node, context, results); replaced by a
                       no-op when the caller supplies nothing (or a falsy value)
      agg_mode      -- AggMode value, ALL_REDUCE unless overridden
    """

    def __init__(
        self,
        name: str,
        deps: List[str],
        compute_func: Callable[["OpNode", Context], Any],
        callback_func: Optional[Callable[["OpNode", Context, List[Any]], None]] = None,
        agg_mode: AggMode = AggMode.ALL_REDUCE,
    ):
        # Substitute a callback that ignores its arguments and returns None
        # when no (truthy) callback was given.
        if not callback_func:
            callback_func = lambda self, ctx, results: None  # noqa: E731

        self.name, self.deps = name, deps
        self.compute_func, self.callback_func = compute_func, callback_func
        self.agg_mode = agg_mode
2843

2944

3045
def op(name: str, deps=None):
@@ -97,6 +112,67 @@ def _exec(n: str):
97112
)
98113

99114

115+
class Bucket:
    """
    Bucket for a single operation, collecting computation results and triggering downstream ops

    Behaviour depends on ``mode``:

    * ``AggMode.MAP``: the callback is started once for *every* ``put``,
      each time with a one-element list holding just that result.
    * any other mode (``AggMode.ALL_REDUCE``): results are accumulated and the
      callback is started exactly once, with a copy of all collected results,
      when the number of results equals ``size``. Later ``put`` calls are ignored.

    Each callback invocation runs on its own daemon thread; exceptions raised
    by the callback are printed with ``traceback.print_exc()`` and not re-raised.
    """

    def __init__(
        self, name: str, size: int, mode: AggMode, callback: Callable[[List[Any]], None]
    ):
        self.name = name
        self.size = size
        self.mode = mode
        self.callback = callback
        self._lock = threading.Lock()
        # Results accumulated so far (only used outside MAP mode).
        self._results: List[Any] = []
        # True once the one-shot (non-MAP) callback has been dispatched.
        self._done = False

    def put(self, result: Any):
        """
        Hand one upstream result to the bucket.

        :param result: value produced by an upstream operation
        """
        with self._lock:
            if self._done:
                return

            if self.mode == AggMode.MAP:
                # MAP must react to every result, so the bucket is never
                # marked done and nothing is accumulated.
                payload = [result]
            else:
                self._results.append(result)
                if len(self._results) != self.size:
                    return
                self._done = True
                # Snapshot so the callback thread owns its own list.
                payload = list(self._results)

        # Start the worker after releasing the lock so other producers are
        # not blocked on thread creation.
        self._fire(payload)

    def _fire(self, payload: Optional[List[Any]] = None):
        """
        Start a daemon thread that runs the callback with ``payload``.

        Called without a payload, it marks the bucket done and dispatches a
        copy of everything collected so far.
        """
        if payload is None:
            self._done = True
            payload = list(self._results)
        threading.Thread(
            target=self._callback_wrapper, args=(payload,), daemon=True
        ).start()

    def _callback_wrapper(self, payload: Optional[List[Any]] = None):
        """Invoke the callback, printing (not propagating) any exception it raises."""
        if payload is None:
            payload = self._results
        try:
            self.callback(payload)
        except Exception:  # pylint: disable=broad-except
            # Runs at a thread boundary: report the failure instead of losing it.
            traceback.print_exc()
149+
150+
151+
class BucketManager:
    """Registry of Bucket objects keyed by node name; all access goes through a lock."""

    def __init__(self):
        self._lock = threading.Lock()
        self._buckets: dict[str, Bucket] = {}

    def register(
        self,
        node_name: str,
        bucket_size: int,
        mode: AggMode,
        callback: Callable[[List[Any]], None],
    ):
        """
        Create a Bucket for ``node_name``, store it and return it.

        :raises RuntimeError: if a bucket is already stored under ``node_name``
        """
        with self._lock:
            if node_name in self._buckets:
                raise RuntimeError(f"Bucket {node_name} already registered")
            bucket = Bucket(
                name=node_name, size=bucket_size, mode=mode, callback=callback
            )
            self._buckets[node_name] = bucket
        return bucket

    def get(self, node_name: str) -> Optional[Bucket]:
        """Return the bucket stored under ``node_name``, or None when there is none."""
        with self._lock:
            found = self._buckets.get(node_name)
        return found
174+
175+
100176
def collect_ops(config: dict, graph_gen) -> List[OpNode]:
101177
"""
102178
build operation nodes from yaml config

0 commit comments

Comments
 (0)