
Commit 0fbfcf2

feat: add llm as actors
1 parent 68e5191 · commit 0fbfcf2

File tree

graphgen/bases/base_llm_wrapper.py
graphgen/common/init_llm.py
graphgen/common/init_storage.py
graphgen/models/llm/local/sglang_wrapper.py
graphgen/operators/read/parallel_file_scanner.py

5 files changed: +126 −52 lines

graphgen/bases/base_llm_wrapper.py

Lines changed: 0 additions & 6 deletions
@@ -72,9 +72,3 @@ def filter_think_tags(text: str, think_tag: str = "think") -> str:
 
         filtered = filtered.strip()
         return filtered if filtered else text.strip()
-
-    def shutdown(self) -> None:
-        """Shutdown the LLM engine if applicable."""
-
-    def restart(self) -> None:
-        """Reinitialize the LLM engine if applicable."""

graphgen/common/init_llm.py

Lines changed: 124 additions & 32 deletions
@@ -1,57 +1,152 @@
 import os
 from typing import Any, Dict, Optional
 
+import ray
+
 from graphgen.bases import BaseLLMWrapper
+from graphgen.common.init_storage import get_actor_handle
 from graphgen.models import Tokenizer
 
 
-class LLMFactory:
+class LLMServiceActor:
     """
-    A factory class to create LLM wrapper instances based on the specified backend.
-    Supported backends include:
-    - http_api: HTTPClient
-    - openai_api: OpenAIClient
-    - ollama_api: OllamaClient
-    - huggingface: HuggingFaceWrapper
-    - sglang: SGLangWrapper
+    A Ray actor class to wrap LLM wrapper instances for distributed usage.
     """
 
-    @staticmethod
-    def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper:
-        # add tokenizer
-        tokenizer: Tokenizer = Tokenizer(
-            os.environ.get("TOKENIZER_MODEL", "cl100k_base"),
-        )
+    def __init__(self, backend: str, config: Dict[str, Any]):
+        self.backend = backend
+        tokenizer_model = os.environ.get("TOKENIZER_MODEL", "cl100k_base")
+        tokenizer = Tokenizer(model_name=tokenizer_model)
         config["tokenizer"] = tokenizer
+
         if backend == "http_api":
             from graphgen.models.llm.api.http_client import HTTPClient
 
-            return HTTPClient(**config)
-        if backend in ("openai_api", "azure_openai_api"):
+            self.llm_instance = HTTPClient(**config)
+        elif backend in ("openai_api", "azure_openai_api"):
             from graphgen.models.llm.api.openai_client import OpenAIClient
 
             # pass in concrete backend to the OpenAIClient so that internally we can distinguish
             # between OpenAI and Azure OpenAI
-            return OpenAIClient(**config, backend=backend)
-        if backend == "ollama_api":
+            self.llm_instance = OpenAIClient(**config, backend=backend)
+        elif backend == "ollama_api":
             from graphgen.models.llm.api.ollama_client import OllamaClient
 
-            return OllamaClient(**config)
-        if backend == "huggingface":
+            self.llm_instance = OllamaClient(**config)
+        elif backend == "huggingface":
             from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper
 
-            return HuggingFaceWrapper(**config)
-        if backend == "sglang":
+            self.llm_instance = HuggingFaceWrapper(**config)
+        elif backend == "sglang":
             from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper
 
-            return SGLangWrapper(**config)
+            self.llm_instance = SGLangWrapper(**config)
+
+        elif backend == "vllm":
+            from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper
+
+            self.llm_instance = VLLMWrapper(**config)
+        else:
+            raise NotImplementedError(f"Backend {backend} is not implemented yet.")
+
+    async def generate_answer(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> str:
+        return await self.llm_instance.generate_answer(text, history, **extra)
+
+    async def generate_topk_per_token(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> list:
+        return await self.llm_instance.generate_topk_per_token(text, history, **extra)
 
-        # if backend == "vllm":
-        #     from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper
-        #
-        #     return VLLMWrapper(**config)
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> list:
+        return await self.llm_instance.generate_inputs_prob(text, history, **extra)
 
-        raise NotImplementedError(f"Backend {backend} is not implemented yet.")
+    def ready(self) -> bool:
+        """A simple method to check if the actor is ready."""
+        return True
+
+
+class LLMServiceProxy(BaseLLMWrapper):
+    """
+    A proxy class to interact with the LLMServiceActor for distributed LLM operations.
+    """
+
+    def __init__(self, actor_name: str):
+        super().__init__()
+        self.actor_handle = get_actor_handle(actor_name)
+        self._create_local_tokenizer()
+
+    async def generate_answer(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> str:
+        object_ref = self.actor_handle.generate_answer.remote(text, history, **extra)
+        return await object_ref
+
+    async def generate_topk_per_token(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> list:
+        object_ref = self.actor_handle.generate_topk_per_token.remote(
+            text, history, **extra
+        )
+        return await object_ref
+
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> list:
+        object_ref = self.actor_handle.generate_inputs_prob.remote(
+            text, history, **extra
+        )
+        return await object_ref
+
+    def _create_local_tokenizer(self):
+        tokenizer_model = os.environ.get("TOKENIZER_MODEL", "cl100k_base")
+        self.tokenizer = Tokenizer(model_name=tokenizer_model)
+
+
+class LLMFactory:
+    """
+    A factory class to create LLM wrapper instances based on the specified backend.
+    Supported backends include:
+    - http_api: HTTPClient
+    - openai_api: OpenAIClient
+    - ollama_api: OllamaClient
+    - huggingface: HuggingFaceWrapper
+    - sglang: SGLangWrapper
+    """
+
+    @staticmethod
+    def create_llm(
+        model_type: str, backend: str, config: Dict[str, Any]
+    ) -> BaseLLMWrapper:
+        if not config:
+            raise ValueError(
+                f"No configuration provided for LLM {model_type} with backend {backend}."
+            )
+
+        actor_name = f"Actor_LLM_{model_type}"
+        try:
+            ray.get_actor(actor_name)
+        except ValueError:
+            print(f"Creating Ray actor for LLM {model_type} with backend {backend}.")
+            num_gpus = config.pop("num_gpus", 0)
+            actor = (
+                ray.remote(LLMServiceActor)
+                .options(
+                    name=actor_name,
+                    num_gpus=num_gpus,
+                    lifetime="detached",
+                    get_if_exists=True,
+                )
+                .remote(backend, config)
+            )
 
+            # wait for actor to be ready
+            ray.get(actor.ready.remote())
+
+        return LLMServiceProxy(actor_name)
 
 
 def _load_env_group(prefix: str) -> Dict[str, Any]:
@@ -78,8 +173,5 @@ def init_llm(model_type: str) -> Optional[BaseLLMWrapper]:
     if not config:
         return None
     backend = config.pop("backend")
-    llm_wrapper = LLMFactory.create_llm_wrapper(backend, config)
+    llm_wrapper = LLMFactory.create_llm(model_type, backend, config)
    return llm_wrapper
-
-
-# TODO: use ray serve when loading large models to avoid re-loading in each actor
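For orientation, a minimal usage sketch of the new code path follows. It is an assumption-laden illustration, not part of the commit: "synthesizer" is a hypothetical model_type, and it presumes the environment variables read by _load_env_group are set. The point is that init_llm now returns an LLMServiceProxy bound to a named, detached LLMServiceActor, so every process that asks for the same model_type shares one loaded model instead of re-loading it.

# Hypothetical usage sketch; "synthesizer" is an assumed model_type and the
# config environment variables (read by _load_env_group) must be set.
import asyncio

import ray

from graphgen.common.init_llm import init_llm

ray.init(ignore_reinit_error=True)  # the proxy and actor need a running Ray cluster

async def main():
    llm = init_llm("synthesizer")  # creates or attaches to "Actor_LLM_synthesizer"
    if llm is None:  # init_llm returns None when no config is found in the env
        return
    answer = await llm.generate_answer("What does this pipeline do?")
    print(answer)

asyncio.run(main())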

graphgen/common/init_storage.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ def __init__(self, backend: str, working_dir: str, namespace: str):
             from graphgen.models import NetworkXStorage
 
             self.graph = NetworkXStorage(working_dir, namespace)
-        if backend == "kuzu":
+        elif backend == "kuzu":
             from graphgen.models import KuzuStorage
 
             self.graph = KuzuStorage(working_dir, namespace)
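The one-character fix matters if the branch chain ends in a raising else. The networkx branch and any trailing else sit outside this hunk, so the following is an inferred sketch, not code from the repository: with a plain if, the kuzu check is a second, independent statement, and its else would fire even after the networkx branch had already matched.

# Inferred surrounding structure, for illustration only.
if backend == "networkx":
    self.graph = NetworkXStorage(working_dir, namespace)
elif backend == "kuzu":  # formerly `if`: an `else: raise` below would also trigger for "networkx"
    self.graph = KuzuStorage(working_dir, namespace)
else:
    raise NotImplementedError(f"Graph storage backend {backend} is not supported.")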

graphgen/models/llm/local/sglang_wrapper.py

Lines changed: 0 additions & 12 deletions
@@ -138,15 +138,3 @@ async def generate_inputs_prob(
         raise NotImplementedError(
             "SGLangWrapper does not support per-token logprobs yet."
         )
-
-    def shutdown(self) -> None:
-        """Gracefully shutdown the SGLang engine."""
-        if hasattr(self, "engine"):
-            self.engine.shutdown()
-
-    def restart(self) -> None:
-        """Restart the SGLang engine."""
-        self.shutdown()
-        self.engine = self.engine.__class__(
-            model_path=self.model_path, tp_size=self.tp_size
-        )
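With the shutdown/restart hooks removed here and from BaseLLMWrapper, the SGLang engine's lifetime is now tied to the LLMServiceActor that owns it. A teardown sketch under that assumption (no such call appears in this commit):

# Hypothetical teardown; the "Actor_LLM_<model_type>" naming comes from
# LLMFactory.create_llm, and "synthesizer" is an assumed model_type.
import ray

actor = ray.get_actor("Actor_LLM_synthesizer")
ray.kill(actor)  # detached actors outlive the driver, so release the GPU explicitly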

graphgen/operators/read/parallel_file_scanner.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ class ParallelFileScanner:
     def __init__(
         self, cache_dir: str, allowed_suffix, rescan: bool = False, max_workers: int = 4
     ):
-        self.cache = RocksDBCache(os.path.join(cache_dir, "file_paths_cache"))
+        self.cache = RocksDBCache(os.path.join(cache_dir, "input_paths.db"))
         self.allowed_suffix = set(allowed_suffix) if allowed_suffix else None
         self.rescan = rescan
         self.max_workers = max_workers
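One consequence worth noting: because the cache file name changes from file_paths_cache to input_paths.db, any scan cache built under the old name is orphaned rather than migrated, so the first run after this change will rescan into an empty database. That seems acceptable here, since the cache only memoizes discovered file paths.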
