
Commit 18a67be

fix: fix vllm wrapper
1 parent c7e32b0 commit 18a67be

File tree

2 files changed, +33 -52 lines changed


graphgen/common/init_llm.py

Lines changed: 1 addition & 1 deletion
@@ -131,7 +131,7 @@ def create_llm(
         ray.get_actor(actor_name)
     except ValueError:
         print(f"Creating Ray actor for LLM {model_type} with backend {backend}.")
-        num_gpus = config.pop("num_gpus", 0)
+        num_gpus = int(config.pop("num_gpus", 0))
         actor = (
             ray.remote(LLMServiceActor)
             .options(
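
The only change in this file coerces `num_gpus` to an `int` before the value reaches the Ray actor options, guarding against configs that deliver the count as a string or float. A minimal sketch of the pattern, assuming the value arrives from a parsed YAML/CLI config; `SimpleActor` and the `config` dict are hypothetical stand-ins, not code from this repository:

```python
import ray

ray.init(ignore_reinit_error=True)


# Hypothetical stand-in for LLMServiceActor; only the num_gpus handling matters here.
@ray.remote
class SimpleActor:
    def ping(self) -> str:
        return "pong"


config = {"num_gpus": "1"}  # values parsed from YAML or the CLI often arrive as strings
num_gpus = int(config.pop("num_gpus", 0))  # coerce before handing the value to Ray

actor = SimpleActor.options(name="demo_actor", num_gpus=num_gpus).remote()
print(ray.get(actor.ping.remote()))  # -> "pong"
```
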
VLLMWrapper module · Lines changed: 32 additions & 51 deletions
@@ -1,14 +1,13 @@
+import uuid
+import math
 from typing import Any, List, Optional
-
 from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
 from graphgen.bases.datatypes import Token
 
-
 class VLLMWrapper(BaseLLMWrapper):
     """
-    Async inference backend based on vLLM (https://github.com/vllm-project/vllm)
+    Async inference backend based on vLLM.
     """
-
     def __init__(
         self,
         model: str,
@@ -20,12 +19,11 @@ def __init__(
         **kwargs: Any,
     ):
         super().__init__(temperature=temperature, top_p=top_p, **kwargs)
-
         try:
             from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
         except ImportError as exc:
             raise ImportError(
-                "VLLMWrapper requires vllm. Install it with: uv pip install vllm --torch-backend=auto"
+                "VLLMWrapper requires vllm. Install it with: uv pip install vllm"
             ) from exc
 
         self.SamplingParams = SamplingParams
@@ -35,9 +33,9 @@ def __init__(
             tensor_parallel_size=tensor_parallel_size,
             gpu_memory_utilization=gpu_memory_utilization,
             trust_remote_code=kwargs.get("trust_remote_code", True),
+            disable_log_stats=False,
         )
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
-
         self.temperature = temperature
         self.top_p = top_p
         self.topk = topk
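
The added `disable_log_stats=False` explicitly keeps vLLM's periodic engine statistics in the logs. A minimal standalone sketch of the engine construction this hunk touches; the model id and memory settings below are placeholders, not values used by GraphGen:

```python
from vllm import AsyncEngineArgs, AsyncLLMEngine

engine_args = AsyncEngineArgs(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model id
    tensor_parallel_size=1,
    gpu_memory_utilization=0.9,
    trust_remote_code=True,
    disable_log_stats=False,  # keep vLLM's periodic stats logging enabled
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
```
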
@@ -60,78 +58,61 @@ async def generate_answer(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> str:
         full_prompt = self._build_inputs(text, history)
+        request_id = f"graphgen_req_{uuid.uuid4()}"
 
         sp = self.SamplingParams(
             temperature=self.temperature if self.temperature > 0 else 1.0,
             top_p=self.top_p if self.temperature > 0 else 1.0,
             max_tokens=extra.get("max_new_tokens", 512),
         )
 
-        results = []
-        async for req_output in self.engine.generate(
-            full_prompt, sp, request_id="graphgen_req"
-        ):
-            results = req_output.outputs
-        return results[-1].text
+        result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
+
+        final_output = None
+        async for request_output in result_generator:
+            final_output = request_output
+
+        if not final_output or not final_output.outputs:
+            return ""
+
+        return final_output.outputs[0].text
 
     async def generate_topk_per_token(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> List[Token]:
         full_prompt = self._build_inputs(text, history)
 
+        request_id = f"graphgen_topk_{uuid.uuid4()}"
+
         sp = self.SamplingParams(
             temperature=0,
             max_tokens=1,
             logprobs=self.topk,
         )
 
-        results = []
-        async for req_output in self.engine.generate(
-            full_prompt, sp, request_id="graphgen_topk"
-        ):
-            results = req_output.outputs
-        top_logprobs = results[-1].logprobs[0]
+        result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
+
+        final_output = None
+        async for request_output in result_generator:
+            final_output = request_output
+
+        if not final_output or not final_output.outputs or not final_output.outputs[0].logprobs:
+            return []
+
+        top_logprobs = final_output.outputs[0].logprobs[0]
 
         tokens = []
         for _, logprob_obj in top_logprobs.items():
             tok_str = logprob_obj.decoded_token
-            prob = float(logprob_obj.logprob.exp())
+            prob = float(math.exp(logprob_obj.logprob))
             tokens.append(Token(tok_str, prob))
+
         tokens.sort(key=lambda x: -x.prob)
         return tokens
 
     async def generate_inputs_prob(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> List[Token]:
-        full_prompt = self._build_inputs(text, history)
-
-        # vLLM has no ready-made API for "mask one token and then compute its prob",
-        # so we take the most direct approach: feed the whole prompt in at once with
-        # prompt_logprobs=True enabled, let vLLM return the logprob of every position
-        # in the *input* part, then pick out the probability of the matching token.
-        sp = self.SamplingParams(
-            temperature=0,
-            max_tokens=0,  # do not generate new tokens
-            prompt_logprobs=1,  # top-1 is enough
+        raise NotImplementedError(
+            "VLLMWrapper does not support per-token logprobs yet."
         )
-
-        results = []
-        async for req_output in self.engine.generate(
-            full_prompt, sp, request_id="graphgen_prob"
-        ):
-            results = req_output.outputs
-
-        # prompt_logprobs is a list whose length equals the number of prompt tokens;
-        # each element is a dict{token_id: logprob_obj} or None (the first position is None)
-        prompt_logprobs = results[-1].prompt_logprobs
-
-        tokens = []
-        for _, logprob_dict in enumerate(prompt_logprobs):
-            if logprob_dict is None:
-                continue
-            # each dict holds only one key/value pair here, because of top-1
-            _, logprob_obj = next(iter(logprob_dict.items()))
-            tok_str = logprob_obj.decoded_token
-            prob = float(logprob_obj.logprob.exp())
-            tokens.append(Token(tok_str, prob))
-        return tokens
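
The substantive fixes in this last hunk: every call now builds a unique `request_id` with `uuid` (engines that track in-flight requests by id can collide when a fixed id is reused across concurrent calls), the streaming generator is drained so that only the final, complete output is used, and `math.exp(...)` replaces `.exp()` because the returned logprob is a plain float. A self-contained sketch of the first two patterns, using a stand-in `FakeEngine` rather than the real `AsyncLLMEngine`:

```python
# Sketch of the drain-the-generator and unique-request-id patterns from this commit.
# FakeEngine is a stand-in, not part of vLLM or graphgen.
import asyncio
import math
import uuid


class FakeEngine:
    """Yields partial outputs the way an async streaming engine does."""

    async def generate(self, prompt: str, request_id: str):
        for chunk in ("Hel", "Hello", "Hello, world"):
            await asyncio.sleep(0)  # pretend work happens between chunks
            yield {"request_id": request_id, "text": chunk}


async def run_once(engine: FakeEngine, prompt: str) -> str:
    request_id = f"demo_req_{uuid.uuid4()}"  # unique per call, as in the fix
    final_output = None
    async for output in engine.generate(prompt, request_id):
        final_output = output  # keep only the last (complete) output
    return "" if final_output is None else final_output["text"]


async def main():
    engine = FakeEngine()
    # Unique request ids make concurrent calls safe; a fixed id could collide.
    texts = await asyncio.gather(*(run_once(engine, p) for p in ("a", "b")))
    print(texts)                    # ['Hello, world', 'Hello, world']
    print(float(math.exp(-0.105)))  # logprob -> probability, ≈ 0.90


asyncio.run(main())
```

The drain-and-keep-last pattern is also why the new empty-output guards (`return ""` / `return []`) exist: if the engine yields nothing, `final_output` stays `None`.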
