Fix: fix detached actors #132
Changes from all commits: a6d32d7, 44265e4, a16b8b7, 48192fb, d155d20, fa230d2
@@ -1,16 +1,20 @@
import os
import inspect
import logging
from collections import defaultdict, deque
from functools import wraps
from typing import Any, Callable, Dict, List, Set
from dotenv import load_dotenv

import ray
import ray.data
from ray.data import DataContext

from graphgen.bases import Config, Node
from graphgen.utils import logger
from graphgen.common import init_llm, init_storage

load_dotenv()


class Engine:
    def __init__(
@@ -20,6 +24,8 @@ def __init__(
        self.global_params = self.config.global_params
        self.functions = functions
        self.datasets: Dict[str, ray.data.Dataset] = {}
        self.llm_actors = {}
        self.storage_actors = {}

        ctx = DataContext.get_current()
        ctx.enable_rich_progress_bars = False
@@ -29,6 +35,16 @@
        ctx.enable_tensor_extension_casting = False
        ctx._metrics_export_port = 0  # Disable metrics exporter to avoid RpcError

        all_env_vars = os.environ.copy()
        if "runtime_env" not in ray_init_kwargs:
            ray_init_kwargs["runtime_env"] = {}

        existing_env_vars = ray_init_kwargs["runtime_env"].get("env_vars", {})
        ray_init_kwargs["runtime_env"]["env_vars"] = {
            **all_env_vars,
            **existing_env_vars
        }

        if not ray.is_initialized():
            context = ray.init(
                ignore_reinit_error=True,
@@ -38,6 +54,59 @@ def __init__(
            )
            logger.info("Ray Dashboard URL: %s", context.dashboard_url)

        self._init_llms()
        self._init_storage()

    def _init_llms(self):
        self.llm_actors["synthesizer"] = init_llm("synthesizer")
        self.llm_actors["trainee"] = init_llm("trainee")
Comment on lines +60 to +62
Contributor
The … Additionally, the LLM types …
    def _init_storage(self):
        kv_namespaces, graph_namespaces = self._scan_storage_requirements()
        working_dir = self.global_params["working_dir"]

        for node_id in kv_namespaces:
            proxy = init_storage(self.global_params["kv_backend"], working_dir, node_id)
            self.storage_actors[f"kv_{node_id}"] = proxy
            logger.info("Create KV Storage Actor: namespace=%s", node_id)

        for ns in graph_namespaces:
            proxy = init_storage(self.global_params["graph_backend"], working_dir, ns)
            self.storage_actors[f"graph_{ns}"] = proxy
            logger.info("Create Graph Storage Actor: namespace=%s", ns)

    def _scan_storage_requirements(self) -> tuple[set[str], set[str]]:
        kv_namespaces = set()
        graph_namespaces = set()

        # TODO: Temporarily hard-coded; node storage will be centrally managed later.
        for node in self.config.nodes:
            op_name = node.op_name
            if self._function_needs_param(op_name, "kv_backend"):
                kv_namespaces.add(op_name)
            if self._function_needs_param(op_name, "graph_backend"):
                graph_namespaces.add("graph")
        return kv_namespaces, graph_namespaces

    def _function_needs_param(self, op_name: str, param_name: str) -> bool:
        if op_name not in self.functions:
            return False

        func = self.functions[op_name]

        if inspect.isclass(func):
            try:
                sig = inspect.signature(func.__init__)
                return param_name in sig.parameters
            except (ValueError, TypeError):
                return False

        try:
            sig = inspect.signature(func)
            return param_name in sig.parameters
        except (ValueError, TypeError):
            return False

    @staticmethod
    def _topo_sort(nodes: List[Node]) -> List[Node]:
        id_to_node: Dict[str, Node] = {}
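Aside (not part of the diff): the change forwards the driver's environment through `runtime_env` so that the actors created by `init_llm`/`init_storage` can read it. Below is a minimal, self-contained sketch of that pattern, assuming those helpers create named, detached Ray actors; the actor class, actor name, and `MY_API_KEY` variable are hypothetical.

```python
import os
import ray

# Hypothetical variable standing in for whatever load_dotenv() placed in os.environ.
os.environ["MY_API_KEY"] = "secret"

# Same idea as the diff: forward the driver's environment to every worker/actor.
ray.init(ignore_reinit_error=True, runtime_env={"env_vars": dict(os.environ)})

@ray.remote
class StorageProxy:
    """Stand-in for the proxy actor a helper like init_storage might return."""

    def read_key(self) -> str:
        # Runs in a separate process; it only sees env vars forwarded via runtime_env.
        return os.environ.get("MY_API_KEY", "<missing>")

# A detached, named actor outlives the job that created it and can be looked up
# again with get_if_exists=True instead of being re-created.
proxy = StorageProxy.options(
    name="kv_extract",        # hypothetical namespace-derived name
    lifetime="detached",
    get_if_exists=True,
).remote()

print(ray.get(proxy.read_key.remote()))  # -> "secret"
```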
For consistency with logging across the project, it's better to use the configured logger instead of `print()`. This provides benefits like log levels, timestamps, and consistent formatting. You could import `logger` from `graphgen.utils` and replace the `print` calls with `logger.info()`.
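A rough sketch of the suggested swap (the `print` call and the `node_id` value below are illustrative only; `logger` is the object this diff already imports from `graphgen.utils`):

```python
from graphgen.utils import logger  # shared project logger, as imported in the diff

node_id = "extract"  # hypothetical namespace, for illustration only

# Before: plain print, with no level, timestamp, or shared formatting.
# print(f"Create KV Storage Actor: namespace={node_id}")

# After: routed through the configured logger.
logger.info("Create KV Storage Actor: namespace=%s", node_id)
```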