44
55import threading
66import traceback
7+ from enum import Enum , auto
78from functools import wraps
8- from typing import Any , Callable , List
9+ from typing import Any , Callable , List , Optional
10+
11+
class AggMode(Enum):
    """Aggregation policy deciding when an operation runs relative to its upstream results."""

    # Run whenever an upstream produces a result.
    MAP = auto()
    # Wait for all upstream results, then run.
    ALL_REDUCE = auto()
915
1016
1117class Context (dict ):
@@ -22,9 +28,18 @@ def get(self, k, default=None):
2228
class OpNode:
    """
    A single named operation together with the names of the operations it depends on.

    Attributes:
        name: identifier of this operation.
        deps: names of the operations this one depends on.
        compute_func: callable invoked as ``compute_func(node, ctx)``.
        callback_func: callable invoked as ``callback_func(node, ctx, results)``;
            a do-nothing callable is installed when none (or a falsy value) is given.
        agg_mode: aggregation policy for this node (``AggMode.ALL_REDUCE`` by default).
    """

    def __init__(
        self,
        name: str,
        deps: List[str],
        compute_func: Callable[["OpNode", Context], Any],
        callback_func: Optional[Callable[["OpNode", Context, List[Any]], None]] = None,
        agg_mode: AggMode = AggMode.ALL_REDUCE,
    ):
        # Fall back to a no-op so the attribute is always callable.
        if not callback_func:
            callback_func = lambda self, ctx, results: None  # noqa: E731

        self.agg_mode = agg_mode
        self.callback_func = callback_func
        self.compute_func = compute_func
        self.deps = deps
        self.name = name
2843
2944
3045def op (name : str , deps = None ):
@@ -97,6 +112,67 @@ def _exec(n: str):
97112 )
98113
99114
class Bucket:
    """
    Bucket for a single operation, collecting computation results and triggering downstream ops.

    Behaviour depends on ``mode``:

    * ``AggMode.MAP``: every ``put()`` triggers the callback with a one-element
      list holding just that result. The bucket never latches, so later results
      are not dropped.
    * any other mode (``AggMode.ALL_REDUCE``): results are accumulated and the
      callback is triggered exactly once, when at least ``size`` results have
      been collected; subsequent ``put()`` calls are ignored.

    The callback always runs on its own daemon thread. Exceptions raised by the
    callback are printed with ``traceback.print_exc()`` and not propagated.
    """

    def __init__(
        self, name: str, size: int, mode: AggMode, callback: Callable[[List[Any]], None]
    ):
        """
        Args:
            name: name of the bucket (stored only; not used by the bucket itself).
            size: number of results required before an accumulating bucket fires.
            mode: aggregation policy, see the class docstring.
            callback: callable receiving the list of results.
        """
        self.name = name
        self.size = size
        self.mode = mode
        self.callback = callback
        self._lock = threading.Lock()
        self._results: List[Any] = []
        self._done = False

    def put(self, result: Any):
        """Record one upstream result and trigger the callback if the mode calls for it."""
        with self._lock:
            if self._done:
                return

            if self.mode == AggMode.MAP:
                # MAP must run once per upstream result; marking the bucket done
                # here would silently discard every result after the first.
                self._dispatch([result])
                return

            self._results.append(result)
            # ">=" rather than "==" so a bucket with size < 1 fires on the first
            # result instead of accumulating forever without ever firing.
            if len(self._results) >= self.size:
                self._fire()

    def _fire(self):
        """Latch the bucket as done and hand the collected results to the callback."""
        self._done = True
        # Pass a snapshot so the callback thread does not share the internal list.
        self._dispatch(list(self._results))

    def _dispatch(self, results: List[Any]):
        """Run the callback with ``results`` on a new daemon thread."""
        threading.Thread(
            target=self._callback_wrapper, args=(results,), daemon=True
        ).start()

    def _callback_wrapper(self, results: Optional[List[Any]] = None):
        """Invoke the callback, printing (not raising) any exception it throws."""
        payload = self._results if results is None else results
        try:
            self.callback(payload)
        except Exception:  # pylint: disable=broad-except
            traceback.print_exc()
150+
class BucketManager:
    """
    Registry of ``Bucket`` objects keyed by node name.

    Every read and write of the underlying mapping happens while holding an
    internal lock.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._buckets: dict[str, Bucket] = {}

    def register(
        self,
        node_name: str,
        bucket_size: int,
        mode: AggMode,
        callback: Callable[[List[Any]], None],
    ):
        """
        Create and store a new bucket for ``node_name`` and return it.

        Raises:
            RuntimeError: if a bucket is already registered under ``node_name``.
        """
        with self._lock:
            already_present = node_name in self._buckets
            if already_present:
                raise RuntimeError(f"Bucket {node_name} already registered")

            new_bucket = Bucket(
                name=node_name, size=bucket_size, mode=mode, callback=callback
            )
            self._buckets[node_name] = new_bucket
            return new_bucket

    def get(self, node_name: str) -> Optional[Bucket]:
        """Return the bucket registered under ``node_name``, or ``None`` if there is none."""
        with self._lock:
            found = self._buckets.get(node_name)
        return found
174+
175+
100176def collect_ops (config : dict , graph_gen ) -> List [OpNode ]:
101177 """
102178 build operation nodes from yaml config
0 commit comments