|
| 1 | +""" |
| 2 | +orchestration engine for GraphGen |
| 3 | +""" |
| 4 | + |
| 5 | +import threading |
| 6 | +import traceback |
| 7 | +from functools import wraps |
| 8 | +from typing import Any, Callable, List |
| 9 | + |
| 10 | + |
class Context(dict):
    """Dictionary with lock-guarded ``set`` and ``get`` helpers.

    The lock is a class attribute, so the same lock object is used by every
    ``Context`` instance. Only ``set`` and ``get`` acquire it; all other
    ``dict`` operations are inherited unchanged.
    """

    # Created once, when the class body is executed.
    _lock = threading.Lock()

    def set(self, k, v):
        """Store ``v`` under key ``k`` while holding the lock."""
        with self._lock:
            self.__setitem__(k, v)

    def get(self, k, default=None):
        """Return the value stored under ``k``, or ``default`` if absent."""
        with self._lock:
            value = dict.get(self, k, default)
        return value
| 21 | + |
| 22 | + |
class OpNode:
    """A named operation together with its dependency names and callable.

    Attributes:
        name: identifier of the operation.
        deps: names of the operations this one depends on.
        func: callable taking the node and a context.
    """

    def __init__(
        self, name: str, deps: List[str], func: Callable[["OpNode", Context], Any]
    ):
        self.name = name
        self.deps = deps
        self.func = func
| 28 | + |
| 29 | + |
def op(name: str, deps=None):
    """Build a decorator that attaches an ``OpNode`` to the decorated callable.

    :param name: name given to the created ``OpNode``
    :param deps: dependency names for the node; a falsy value becomes ``[]``
    :return: a decorator whose result forwards calls to the original callable
        and exposes the node as the ``op_node`` attribute
    """
    if not deps:
        deps = []

    def decorator(func):
        @wraps(func)
        def _call_through(*args, **kwargs):
            return func(*args, **kwargs)

        # The node's callable expands the context mapping into keyword
        # arguments for the wrapped function.
        node = OpNode(name, deps, lambda self, ctx: func(self, **ctx))
        _call_through.op_node = node
        return _call_through

    return decorator
| 42 | + |
| 43 | + |
class Engine:
    """Run operations on threads while honouring their dependencies.

    Every operation gets its own thread, but at most ``max_workers``
    operation callables execute at the same time.
    """

    def __init__(self, max_workers: int = 4):
        """Store the maximum number of operations allowed to run at once."""
        self.max_workers = max_workers

    def run(self, ops: "List[OpNode]", ctx: "Context"):
        """Execute ``ops`` in dependency order.

        :param ops: operations exposing ``name``, ``deps`` and ``func``;
            ``func`` is called as ``func(operation, ctx)``
        :param ctx: object handed unchanged to every operation callable
        :raises ValueError: on duplicate operation names, on dependencies
            that name no operation in ``ops``, or on cyclic dependencies
        :raises RuntimeError: if at least one operation raised or was skipped
            because one of its dependencies failed
        """
        name2op = {operation.name: operation for operation in ops}
        if len(name2op) != len(ops):
            raise ValueError(
                "Duplicate operation names detected. "
                "Please check your configuration."
            )

        # Snapshot the dependency lists once; a missing/None list means "none".
        deps_of = {n: list(name2op[n].deps or ()) for n in name2op}

        unknown = {
            n: [d for d in deps if d not in name2op] for n, deps in deps_of.items()
        }
        unknown = {n: missing for n, missing in unknown.items() if missing}
        if unknown:
            raise ValueError(
                "Unknown dependencies: "
                + "; ".join(f"{n} -> {missing}" for n, missing in unknown.items())
                + ". Please check your configuration."
            )

        # topological sort (Kahn's algorithm); ``topo`` doubles as the queue
        pending = {n: set(deps) for n, deps in deps_of.items()}
        dependents = {n: [] for n in name2op}
        for n, deps in pending.items():
            for d in deps:
                dependents[d].append(n)
        topo = [n for n, deps in pending.items() if not deps]
        head = 0
        while head < len(topo):
            cur = topo[head]
            head += 1
            for child in dependents[cur]:
                pending[child].discard(cur)
                if not pending[child]:
                    topo.append(child)

        if len(topo) != len(name2op):
            raise ValueError(
                "Cyclic dependencies detected among operations. "
                "Please check your configuration."
            )

        # semaphore for max_workers
        sem = threading.Semaphore(self.max_workers)
        done = {n: threading.Event() for n in name2op}
        exc = {}

        def _exec(n: str):
            try:
                # Wait for dependencies *before* taking a worker slot, so a
                # blocked operation never keeps a runnable one from starting.
                for d in deps_of[n]:
                    done[d].wait()
                if any(d in exc for d in deps_of[n]):
                    exc[n] = Exception("Skipped due to failed dependencies")
                    return
                with sem:
                    try:
                        name2op[n].func(name2op[n], ctx)
                    except Exception:  # pylint: disable=broad-except
                        exc[n] = traceback.format_exc()
            finally:
                # Always signal completion so dependents and join() cannot
                # hang, whatever happened above.
                done[n].set()

        ts = [threading.Thread(target=_exec, args=(n,), daemon=True) for n in topo]
        for t in ts:
            t.start()
        for t in ts:
            t.join()
        if exc:
            raise RuntimeError(
                "Some operations failed:\n"
                + "\n".join(f"---- {op} ----\n{tb}" for op, tb in exc.items())
            )
| 98 | + |
| 99 | + |
def collect_ops(config: dict, graph_gen) -> "List[OpNode]":
    """
    Build operation nodes from the pipeline configuration.

    Each stage's ``name`` is looked up as an attribute of ``graph_gen``; that
    attribute must expose an ``op_node``. The node is copied for every stage,
    so the stage-specific dependencies and callable set here never leak into
    the shared node attached to the method or into nodes returned by other
    calls.

    :param config: mapping whose ``"pipeline"`` entry is a list of stage
        dicts with a ``"name"`` and optional ``"deps"`` and ``"params"``
    :param graph_gen: object providing one decorated method per stage name
    :return: one node per stage, in pipeline order
    """
    import copy  # local import: keeps the file's import block untouched

    ops = []
    for stage in config["pipeline"]:
        method = getattr(graph_gen, stage["name"])
        template = method.op_node

        # Shallow copy: the template stays pristine for later calls.
        node = copy.copy(template)

        # if there are runtime dependencies, override them
        runtime_deps = stage.get("deps", template.deps)
        node.deps = list(runtime_deps or [])

        # Bind the current method and stage as defaults so every node keeps
        # its own values instead of the loop's last ones.
        node.func = lambda self, ctx, m=method, sc=stage: m(sc.get("params"))
        ops.append(node)
    return ops
0 commit comments