import os
from typing import Any, Dict, Optional

+import ray
+
from graphgen.bases import BaseLLMWrapper
+from graphgen.common.init_storage import get_actor_handle
from graphgen.models import Tokenizer


-class LLMFactory:
+class LLMServiceActor:
    """
-    A factory class to create LLM wrapper instances based on the specified backend.
-    Supported backends include:
-    - http_api: HTTPClient
-    - openai_api: OpenAIClient
-    - ollama_api: OllamaClient
-    - huggingface: HuggingFaceWrapper
-    - sglang: SGLangWrapper
+    A Ray actor class that wraps an LLM wrapper instance for distributed usage.
    """

-    @staticmethod
-    def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper:
-        # add tokenizer
-        tokenizer: Tokenizer = Tokenizer(
-            os.environ.get("TOKENIZER_MODEL", "cl100k_base"),
-        )
+    def __init__(self, backend: str, config: Dict[str, Any]):
+        self.backend = backend
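+        # The tokenizer model name defaults to "cl100k_base" unless the
+        # TOKENIZER_MODEL environment variable overrides it.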
+        tokenizer_model = os.environ.get("TOKENIZER_MODEL", "cl100k_base")
+        tokenizer = Tokenizer(model_name=tokenizer_model)
        config["tokenizer"] = tokenizer
+
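+        # Backend-specific clients are imported lazily inside each branch, so the
+        # dependencies of backends that are not selected are never imported.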
        if backend == "http_api":
            from graphgen.models.llm.api.http_client import HTTPClient

-            return HTTPClient(**config)
-        if backend in ("openai_api", "azure_openai_api"):
+            self.llm_instance = HTTPClient(**config)
+        elif backend in ("openai_api", "azure_openai_api"):
            from graphgen.models.llm.api.openai_client import OpenAIClient

            # pass in concrete backend to the OpenAIClient so that internally we can distinguish
            # between OpenAI and Azure OpenAI
-            return OpenAIClient(**config, backend=backend)
-        if backend == "ollama_api":
+            self.llm_instance = OpenAIClient(**config, backend=backend)
+        elif backend == "ollama_api":
            from graphgen.models.llm.api.ollama_client import OllamaClient

-            return OllamaClient(**config)
-        if backend == "huggingface":
+            self.llm_instance = OllamaClient(**config)
+        elif backend == "huggingface":
            from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper

-            return HuggingFaceWrapper(**config)
-        if backend == "sglang":
+            self.llm_instance = HuggingFaceWrapper(**config)
+        elif backend == "sglang":
            from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper

-            return SGLangWrapper(**config)
+            self.llm_instance = SGLangWrapper(**config)
+
+        elif backend == "vllm":
+            from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper
+
+            self.llm_instance = VLLMWrapper(**config)
+        else:
+            raise NotImplementedError(f"Backend {backend} is not implemented yet.")
+
+    async def generate_answer(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> str:
+        return await self.llm_instance.generate_answer(text, history, **extra)
+
+    async def generate_topk_per_token(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> list:
+        return await self.llm_instance.generate_topk_per_token(text, history, **extra)

-        # if backend == "vllm":
-        #     from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper
-        #
-        #     return VLLMWrapper(**config)
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> list:
+        return await self.llm_instance.generate_inputs_prob(text, history, **extra)

-        raise NotImplementedError(f"Backend {backend} is not implemented yet.")
+    def ready(self) -> bool:
+        """A simple method to check if the actor is ready."""
+        return True
+
+
+class LLMServiceProxy(BaseLLMWrapper):
+    """
+    A proxy class that forwards LLM calls to the LLMServiceActor for distributed LLM operations.
+    """
+
+    def __init__(self, actor_name: str):
+        super().__init__()
+        self.actor_handle = get_actor_handle(actor_name)
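+        # Keep a tokenizer on the proxy itself so that token counting can happen
+        # in-process instead of requiring a round trip to the actor.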
+        self._create_local_tokenizer()
+
+    async def generate_answer(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> str:
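+        # .remote() returns a Ray ObjectRef, which can be awaited directly in async code.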
+        object_ref = self.actor_handle.generate_answer.remote(text, history, **extra)
+        return await object_ref
+
+    async def generate_topk_per_token(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> list:
+        object_ref = self.actor_handle.generate_topk_per_token.remote(
+            text, history, **extra
+        )
+        return await object_ref
+
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> list:
+        object_ref = self.actor_handle.generate_inputs_prob.remote(
+            text, history, **extra
+        )
+        return await object_ref
+
+    def _create_local_tokenizer(self):
+        tokenizer_model = os.environ.get("TOKENIZER_MODEL", "cl100k_base")
+        self.tokenizer = Tokenizer(model_name=tokenizer_model)
+
+
+class LLMFactory:
+    """
+    A factory class that creates an LLM wrapper for the specified backend, hosts it in a
+    named Ray actor, and returns an LLMServiceProxy for accessing it.
+    Supported backends include:
+    - http_api: HTTPClient
+    - openai_api / azure_openai_api: OpenAIClient
+    - ollama_api: OllamaClient
+    - huggingface: HuggingFaceWrapper
+    - sglang: SGLangWrapper
+    - vllm: VLLMWrapper
+    """
+
+    @staticmethod
+    def create_llm(
+        model_type: str, backend: str, config: Dict[str, Any]
+    ) -> BaseLLMWrapper:
+        if not config:
+            raise ValueError(
+                f"No configuration provided for LLM {model_type} with backend {backend}."
+            )
+
+        actor_name = f"Actor_LLM_{model_type}"
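+        # ray.get_actor raises ValueError when no actor with this name exists yet,
+        # in which case a new detached actor is created below.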
+        try:
+            ray.get_actor(actor_name)
+        except ValueError:
+            print(f"Creating Ray actor for LLM {model_type} with backend {backend}.")
+            num_gpus = config.pop("num_gpus", 0)
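+            # lifetime="detached" keeps the actor alive independently of this driver,
+            # and get_if_exists=True returns the existing actor if another process
+            # created one with the same name in the meantime.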
+            actor = (
+                ray.remote(LLMServiceActor)
+                .options(
+                    name=actor_name,
+                    num_gpus=num_gpus,
+                    lifetime="detached",
+                    get_if_exists=True,
+                )
+                .remote(backend, config)
+            )
+
+            # wait for actor to be ready
+            ray.get(actor.ready.remote())
+
+        return LLMServiceProxy(actor_name)


def _load_env_group(prefix: str) -> Dict[str, Any]:
@@ -78,8 +173,5 @@ def init_llm(model_type: str) -> Optional[BaseLLMWrapper]:
    if not config:
        return None
    backend = config.pop("backend")
-    llm_wrapper = LLMFactory.create_llm_wrapper(backend, config)
+    llm_wrapper = LLMFactory.create_llm(model_type, backend, config)
    return llm_wrapper
-
-
-# TODO: use ray serve when loading large models to avoid re-loading in each actor