Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions graphgen/common/init_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import ray

from graphgen.bases import BaseLLMWrapper
from graphgen.common.init_storage import get_actor_handle
from graphgen.models import Tokenizer


Expand Down Expand Up @@ -74,9 +73,9 @@ class LLMServiceProxy(BaseLLMWrapper):
A proxy class to interact with the LLMServiceActor for distributed LLM operations.
"""

def __init__(self, actor_name: str):
def __init__(self, actor_handle: ray.actor.ActorHandle):
super().__init__()
self.actor_handle = get_actor_handle(actor_name)
self.actor_handle = actor_handle
self._create_local_tokenizer()

async def generate_answer(
Expand Down Expand Up @@ -128,25 +127,25 @@ def create_llm(

actor_name = f"Actor_LLM_{model_type}"
try:
ray.get_actor(actor_name)
actor_handle = ray.get_actor(actor_name)
print(f"Using existing Ray actor: {actor_name}")
except ValueError:
print(f"Creating Ray actor for LLM {model_type} with backend {backend}.")
Comment on lines +131 to 133
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with logging across the project, it's better to use the configured logger instead of print(). This provides benefits like log levels, timestamps, and consistent formatting. You could import logger from graphgen.utils and replace the print calls with logger.info().

num_gpus = float(config.pop("num_gpus", 0))
actor = (
actor_handle = (
ray.remote(LLMServiceActor)
.options(
name=actor_name,
num_gpus=num_gpus,
lifetime="detached",
get_if_exists=True,
)
.remote(backend, config)
)

# wait for actor to be ready
ray.get(actor.ready.remote())
ray.get(actor_handle.ready.remote())

return LLMServiceProxy(actor_name)
return LLMServiceProxy(actor_handle)


def _load_env_group(prefix: str) -> Dict[str, Any]:
Expand Down
61 changes: 25 additions & 36 deletions graphgen/common/init_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def drop(self):
def reload(self):
return self.kv.reload()

# Readiness probe: unconditionally returns True. The storage factory calls
# `ray.get(actor_handle.ready.remote())` before wrapping the handle in a proxy,
# so a completed call here is what lets that wait finish.
def ready(self) -> bool:
return True


class GraphStorageActor:
def __init__(self, backend: str, working_dir: str, namespace: str):
Expand Down Expand Up @@ -114,22 +117,14 @@ def delete_node(self, node_id: str):
def reload(self):
return self.graph.reload()


def get_actor_handle(name: str):
    """Look up a named Ray actor and return its handle.

    Args:
        name: The name the actor was registered under.

    Returns:
        Whatever ``ray.get_actor(name)`` returns for that name.

    Raises:
        RuntimeError: If ``ray.get_actor`` raises ``ValueError`` for this
            name; the original exception is chained as the cause.
    """
    try:
        handle = ray.get_actor(name)
    except ValueError as exc:
        message = f"Actor {name} not found. Make sure it is created before accessing."
        raise RuntimeError(message) from exc
    return handle
# Readiness probe: unconditionally returns True. The storage factory calls
# `ray.get(actor_handle.ready.remote())` before wrapping the handle in a proxy,
# so a completed call here is what lets that wait finish.
def ready(self) -> bool:
return True


class RemoteKVStorageProxy(BaseKVStorage):
def __init__(self, namespace: str):
def __init__(self, actor_handle: ray.actor.ActorHandle):
super().__init__()
self.namespace = namespace
self.actor_name = f"Actor_KV_{namespace}"
self.actor = get_actor_handle(self.actor_name)
self.actor = actor_handle

def data(self) -> Dict[str, Any]:
return ray.get(self.actor.data.remote())
Expand Down Expand Up @@ -163,11 +158,9 @@ def reload(self):


class RemoteGraphStorageProxy(BaseGraphStorage):
def __init__(self, namespace: str):
def __init__(self, actor_handle: ray.actor.ActorHandle):
super().__init__()
self.namespace = namespace
self.actor_name = f"Actor_Graph_{namespace}"
self.actor = get_actor_handle(self.actor_name)
self.actor = actor_handle

def index_done_callback(self):
return ray.get(self.actor.index_done_callback.remote())
Expand Down Expand Up @@ -235,27 +228,23 @@ class StorageFactory:
def create_storage(backend: str, working_dir: str, namespace: str):
if backend in ["json_kv", "rocksdb"]:
actor_name = f"Actor_KV_{namespace}"
try:
ray.get_actor(actor_name)
except ValueError:
ray.remote(KVStorageActor).options(
name=actor_name,
lifetime="detached",
get_if_exists=True,
).remote(backend, working_dir, namespace)
return RemoteKVStorageProxy(namespace)
if backend in ["networkx", "kuzu"]:
actor_class = KVStorageActor
proxy_class = RemoteKVStorageProxy
elif backend in ["networkx", "kuzu"]:
actor_name = f"Actor_Graph_{namespace}"
try:
ray.get_actor(actor_name)
except ValueError:
ray.remote(GraphStorageActor).options(
name=actor_name,
lifetime="detached",
get_if_exists=True,
).remote(backend, working_dir, namespace)
return RemoteGraphStorageProxy(namespace)
raise ValueError(f"Unknown storage backend: {backend}")
actor_class = GraphStorageActor
proxy_class = RemoteGraphStorageProxy
else:
raise ValueError(f"Unknown storage backend: {backend}")
try:
actor_handle = ray.get_actor(actor_name)
except ValueError:
actor_handle = ray.remote(actor_class).options(
name=actor_name,
get_if_exists=True,
).remote(backend, working_dir, namespace)
ray.get(actor_handle.ready.remote())
return proxy_class(actor_handle)


def init_storage(backend: str, working_dir: str, namespace: str):
Expand Down
69 changes: 69 additions & 0 deletions graphgen/engine.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import os
import inspect
import logging
from collections import defaultdict, deque
from functools import wraps
from typing import Any, Callable, Dict, List, Set
from dotenv import load_dotenv

import ray
import ray.data
from ray.data import DataContext

from graphgen.bases import Config, Node
from graphgen.utils import logger
from graphgen.common import init_llm, init_storage

load_dotenv()

class Engine:
def __init__(
Expand All @@ -20,6 +24,8 @@ def __init__(
self.global_params = self.config.global_params
self.functions = functions
self.datasets: Dict[str, ray.data.Dataset] = {}
self.llm_actors = {}
self.storage_actors = {}

ctx = DataContext.get_current()
ctx.enable_rich_progress_bars = False
Expand All @@ -29,6 +35,16 @@ def __init__(
ctx.enable_tensor_extension_casting = False
ctx._metrics_export_port = 0 # Disable metrics exporter to avoid RpcError

all_env_vars = os.environ.copy()
if "runtime_env" not in ray_init_kwargs:
ray_init_kwargs["runtime_env"] = {}

existing_env_vars = ray_init_kwargs["runtime_env"].get("env_vars", {})
ray_init_kwargs["runtime_env"]["env_vars"] = {
**all_env_vars,
**existing_env_vars
}

if not ray.is_initialized():
context = ray.init(
ignore_reinit_error=True,
Expand All @@ -38,6 +54,59 @@ def __init__(
)
logger.info("Ray Dashboard URL: %s", context.dashboard_url)

self._init_llms()
self._init_storage()

def _init_llms(self):
    """Initialise the two fixed LLM roles and cache the results by role name.

    The value returned by ``init_llm(role)`` is stored in ``self.llm_actors``
    under the same role string, first for "synthesizer", then for "trainee".
    """
    for role in ("synthesizer", "trainee"):
        self.llm_actors[role] = init_llm(role)
Comment on lines +60 to +62
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The _init_llms method initializes self.llm_actors, but this attribute is not used anywhere else in the Engine class. This appears to be either dead code or part of an incomplete feature. If it's not used, it should be removed to avoid confusion and unnecessary initialization of LLM actors, which can be resource-intensive. If it is intended for use by operators, it should be passed to them, for example via a shared context.

Additionally, the LLM types synthesizer and trainee are hardcoded. This is inflexible. A better approach would be to derive the required LLM types from the configuration, for instance by inspecting the nodes in the graph for LLM requirements.


def _init_storage(self):
    """Create storage for every namespace the configured nodes require.

    Namespaces come from ``_scan_storage_requirements``. For each one the
    object returned by ``init_storage`` is stored in ``self.storage_actors``
    under ``kv_<namespace>`` or ``graph_<namespace>``. KV namespaces are
    processed before graph namespaces.
    """
    kv_names, graph_names = self._scan_storage_requirements()
    base_dir = self.global_params["working_dir"]

    # (key prefix, global_params key for the backend, namespaces, log template)
    plans = (
        ("kv", "kv_backend", kv_names, "Create KV Storage Actor: namespace=%s"),
        ("graph", "graph_backend", graph_names, "Create Graph Storage Actor: namespace=%s"),
    )
    for prefix, backend_key, namespaces, log_template in plans:
        for namespace in namespaces:
            # Backend is looked up lazily so a missing key only matters when
            # at least one namespace of that kind is actually required.
            backend = self.global_params[backend_key]
            self.storage_actors[f"{prefix}_{namespace}"] = init_storage(
                backend, base_dir, namespace
            )
            logger.info(log_template, namespace)

def _scan_storage_requirements(self) -> tuple[set[str], set[str]]:
kv_namespaces = set()
graph_namespaces = set()

# TODO: Temporarily hard-coded; node storage will be centrally managed later.
for node in self.config.nodes:
op_name = node.op_name
if self._function_needs_param(op_name, "kv_backend"):
kv_namespaces.add(op_name)
if self._function_needs_param(op_name, "graph_backend"):
graph_namespaces.add("graph")
return kv_namespaces, graph_namespaces

def _function_needs_param(self, op_name: str, param_name: str) -> bool:
if op_name not in self.functions:
return False

func = self.functions[op_name]

if inspect.isclass(func):
try:
sig = inspect.signature(func.__init__)
return param_name in sig.parameters
except (ValueError, TypeError):
return False

try:
sig = inspect.signature(func)
return param_name in sig.parameters
except (ValueError, TypeError):
return False

@staticmethod
def _topo_sort(nodes: List[Node]) -> List[Node]:
id_to_node: Dict[str, Node] = {}
Expand Down
3 changes: 0 additions & 3 deletions graphgen/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import ray
import yaml
from dotenv import load_dotenv
from ray.data.block import Block
from ray.data.datasource.filename_provider import FilenameProvider

Expand All @@ -16,8 +15,6 @@

sys_path = os.path.abspath(os.path.dirname(__file__))

load_dotenv()


def set_working_dir(folder):
os.makedirs(folder, exist_ok=True)
Expand Down