diff --git a/graphgen/common/init_llm.py b/graphgen/common/init_llm.py index af53709a..52604432 100644 --- a/graphgen/common/init_llm.py +++ b/graphgen/common/init_llm.py @@ -4,7 +4,6 @@ import ray from graphgen.bases import BaseLLMWrapper -from graphgen.common.init_storage import get_actor_handle from graphgen.models import Tokenizer @@ -74,9 +73,9 @@ class LLMServiceProxy(BaseLLMWrapper): A proxy class to interact with the LLMServiceActor for distributed LLM operations. """ - def __init__(self, actor_name: str): + def __init__(self, actor_handle: ray.actor.ActorHandle): super().__init__() - self.actor_handle = get_actor_handle(actor_name) + self.actor_handle = actor_handle self._create_local_tokenizer() async def generate_answer( @@ -128,25 +127,25 @@ def create_llm( actor_name = f"Actor_LLM_{model_type}" try: - ray.get_actor(actor_name) + actor_handle = ray.get_actor(actor_name) + print(f"Using existing Ray actor: {actor_name}") except ValueError: print(f"Creating Ray actor for LLM {model_type} with backend {backend}.") num_gpus = float(config.pop("num_gpus", 0)) - actor = ( + actor_handle = ( ray.remote(LLMServiceActor) .options( name=actor_name, num_gpus=num_gpus, - lifetime="detached", get_if_exists=True, ) .remote(backend, config) ) # wait for actor to be ready - ray.get(actor.ready.remote()) + ray.get(actor_handle.ready.remote()) - return LLMServiceProxy(actor_name) + return LLMServiceProxy(actor_handle) def _load_env_group(prefix: str) -> Dict[str, Any]: diff --git a/graphgen/common/init_storage.py b/graphgen/common/init_storage.py index b9358485..56528e7a 100644 --- a/graphgen/common/init_storage.py +++ b/graphgen/common/init_storage.py @@ -48,6 +48,9 @@ def drop(self): def reload(self): return self.kv.reload() + def ready(self) -> bool: + return True + class GraphStorageActor: def __init__(self, backend: str, working_dir: str, namespace: str): @@ -114,22 +117,14 @@ def delete_node(self, node_id: str): def reload(self): return self.graph.reload() - -def get_actor_handle(name: str): - try: - return ray.get_actor(name) - except ValueError as exc: - raise RuntimeError( - f"Actor {name} not found. Make sure it is created before accessing." - ) from exc + def ready(self) -> bool: + return True class RemoteKVStorageProxy(BaseKVStorage): - def __init__(self, namespace: str): + def __init__(self, actor_handle: ray.actor.ActorHandle): super().__init__() - self.namespace = namespace - self.actor_name = f"Actor_KV_{namespace}" - self.actor = get_actor_handle(self.actor_name) + self.actor = actor_handle def data(self) -> Dict[str, Any]: return ray.get(self.actor.data.remote()) @@ -163,11 +158,9 @@ def reload(self): class RemoteGraphStorageProxy(BaseGraphStorage): - def __init__(self, namespace: str): + def __init__(self, actor_handle: ray.actor.ActorHandle): super().__init__() - self.namespace = namespace - self.actor_name = f"Actor_Graph_{namespace}" - self.actor = get_actor_handle(self.actor_name) + self.actor = actor_handle def index_done_callback(self): return ray.get(self.actor.index_done_callback.remote()) @@ -235,27 +228,23 @@ class StorageFactory: def create_storage(backend: str, working_dir: str, namespace: str): if backend in ["json_kv", "rocksdb"]: actor_name = f"Actor_KV_{namespace}" - try: - ray.get_actor(actor_name) - except ValueError: - ray.remote(KVStorageActor).options( - name=actor_name, - lifetime="detached", - get_if_exists=True, - ).remote(backend, working_dir, namespace) - return RemoteKVStorageProxy(namespace) - if backend in ["networkx", "kuzu"]: + actor_class = KVStorageActor + proxy_class = RemoteKVStorageProxy + elif backend in ["networkx", "kuzu"]: actor_name = f"Actor_Graph_{namespace}" - try: - ray.get_actor(actor_name) - except ValueError: - ray.remote(GraphStorageActor).options( - name=actor_name, - lifetime="detached", - get_if_exists=True, - ).remote(backend, working_dir, namespace) - return RemoteGraphStorageProxy(namespace) - raise ValueError(f"Unknown storage backend: {backend}") + actor_class = GraphStorageActor + proxy_class = RemoteGraphStorageProxy + else: + raise ValueError(f"Unknown storage backend: {backend}") + try: + actor_handle = ray.get_actor(actor_name) + except ValueError: + actor_handle = ray.remote(actor_class).options( + name=actor_name, + get_if_exists=True, + ).remote(backend, working_dir, namespace) + ray.get(actor_handle.ready.remote()) + return proxy_class(actor_handle) def init_storage(backend: str, working_dir: str, namespace: str): diff --git a/graphgen/engine.py b/graphgen/engine.py index 501aa854..47ed242a 100644 --- a/graphgen/engine.py +++ b/graphgen/engine.py @@ -1,8 +1,10 @@ +import os import inspect import logging from collections import defaultdict, deque from functools import wraps from typing import Any, Callable, Dict, List, Set +from dotenv import load_dotenv import ray import ray.data @@ -10,7 +12,9 @@ from graphgen.bases import Config, Node from graphgen.utils import logger +from graphgen.common import init_llm, init_storage +load_dotenv() class Engine: def __init__( @@ -20,6 +24,8 @@ def __init__( self.global_params = self.config.global_params self.functions = functions self.datasets: Dict[str, ray.data.Dataset] = {} + self.llm_actors = {} + self.storage_actors = {} ctx = DataContext.get_current() ctx.enable_rich_progress_bars = False @@ -29,6 +35,16 @@ def __init__( ctx.enable_tensor_extension_casting = False ctx._metrics_export_port = 0 # Disable metrics exporter to avoid RpcError + all_env_vars = os.environ.copy() + if "runtime_env" not in ray_init_kwargs: + ray_init_kwargs["runtime_env"] = {} + + existing_env_vars = ray_init_kwargs["runtime_env"].get("env_vars", {}) + ray_init_kwargs["runtime_env"]["env_vars"] = { + **all_env_vars, + **existing_env_vars + } + if not ray.is_initialized(): context = ray.init( ignore_reinit_error=True, @@ -38,6 +54,59 @@ def __init__( ) logger.info("Ray Dashboard URL: %s", context.dashboard_url) + self._init_llms() + self._init_storage() + + def _init_llms(self): + self.llm_actors["synthesizer"] = init_llm("synthesizer") + self.llm_actors["trainee"] = init_llm("trainee") + + def _init_storage(self): + kv_namespaces, graph_namespaces = self._scan_storage_requirements() + working_dir = self.global_params["working_dir"] + + for node_id in kv_namespaces: + proxy = init_storage(self.global_params["kv_backend"], working_dir, node_id) + self.storage_actors[f"kv_{node_id}"] = proxy + logger.info("Create KV Storage Actor: namespace=%s", node_id) + + for ns in graph_namespaces: + proxy = init_storage(self.global_params["graph_backend"], working_dir, ns) + self.storage_actors[f"graph_{ns}"] = proxy + logger.info("Create Graph Storage Actor: namespace=%s", ns) + + def _scan_storage_requirements(self) -> tuple[set[str], set[str]]: + kv_namespaces = set() + graph_namespaces = set() + + # TODO: Temporarily hard-coded; node storage will be centrally managed later. + for node in self.config.nodes: + op_name = node.op_name + if self._function_needs_param(op_name, "kv_backend"): + kv_namespaces.add(op_name) + if self._function_needs_param(op_name, "graph_backend"): + graph_namespaces.add("graph") + return kv_namespaces, graph_namespaces + + def _function_needs_param(self, op_name: str, param_name: str) -> bool: + if op_name not in self.functions: + return False + + func = self.functions[op_name] + + if inspect.isclass(func): + try: + sig = inspect.signature(func.__init__) + return param_name in sig.parameters + except (ValueError, TypeError): + return False + + try: + sig = inspect.signature(func) + return param_name in sig.parameters + except (ValueError, TypeError): + return False + @staticmethod def _topo_sort(nodes: List[Node]) -> List[Node]: id_to_node: Dict[str, Node] = {} diff --git a/graphgen/run.py b/graphgen/run.py index b0383867..a1b65364 100644 --- a/graphgen/run.py +++ b/graphgen/run.py @@ -6,7 +6,6 @@ import ray import yaml -from dotenv import load_dotenv from ray.data.block import Block from ray.data.datasource.filename_provider import FilenameProvider @@ -16,8 +15,6 @@ sys_path = os.path.abspath(os.path.dirname(__file__)) -load_dotenv() - def set_working_dir(folder): os.makedirs(folder, exist_ok=True)