+import uuid
+import math
 from typing import Any, List, Optional
-
 from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
 from graphgen.bases.datatypes import Token
 
-
 class VLLMWrapper(BaseLLMWrapper):
     """
-    Async inference backend based on vLLM (https://github.com/vllm-project/vllm)
+    Async inference backend based on vLLM.
     """
-
     def __init__(
         self,
         model: str,
@@ -20,12 +19,11 @@ def __init__(
         **kwargs: Any,
     ):
         super().__init__(temperature=temperature, top_p=top_p, **kwargs)
-
         try:
             from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
         except ImportError as exc:
             raise ImportError(
-                "VLLMWrapper requires vllm. Install it with: uv pip install vllm --torch-backend=auto"
+                "VLLMWrapper requires vllm. Install it with: uv pip install vllm"
             ) from exc
 
         self.SamplingParams = SamplingParams
@@ -35,9 +33,9 @@ def __init__(
             tensor_parallel_size=tensor_parallel_size,
             gpu_memory_utilization=gpu_memory_utilization,
             trust_remote_code=kwargs.get("trust_remote_code", True),
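+            # keep vLLM's periodic engine stats logging enabled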
+            disable_log_stats=False,
         )
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
-
         self.temperature = temperature
         self.top_p = top_p
         self.topk = topk
@@ -60,78 +58,61 @@ async def generate_answer(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> str:
         full_prompt = self._build_inputs(text, history)
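+        # use a unique request id per call so concurrent requests don't clash in the engine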
+        request_id = f"graphgen_req_{uuid.uuid4()}"
 
         sp = self.SamplingParams(
             temperature=self.temperature if self.temperature > 0 else 1.0,
             top_p=self.top_p if self.temperature > 0 else 1.0,
             max_tokens=extra.get("max_new_tokens", 512),
         )
 
-        results = []
-        async for req_output in self.engine.generate(
-            full_prompt, sp, request_id="graphgen_req"
-        ):
-            results = req_output.outputs
-        return results[-1].text
+        result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
+
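+        # drain the stream; each yielded RequestOutput carries the generation so far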
+        final_output = None
+        async for request_output in result_generator:
+            final_output = request_output
+
+        if not final_output or not final_output.outputs:
+            return ""
+
+        return final_output.outputs[0].text
 
     async def generate_topk_per_token(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> List[Token]:
         full_prompt = self._build_inputs(text, history)
 
+        request_id = f"graphgen_topk_{uuid.uuid4()}"
+
         sp = self.SamplingParams(
             temperature=0,
             max_tokens=1,
             logprobs=self.topk,
         )
 
-        results = []
-        async for req_output in self.engine.generate(
-            full_prompt, sp, request_id="graphgen_topk"
-        ):
-            results = req_output.outputs
-        top_logprobs = results[-1].logprobs[0]
+        result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
+
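+        # same draining pattern as generate_answer: keep only the final RequestOutput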
+        final_output = None
+        async for request_output in result_generator:
+            final_output = request_output
+
+        if not final_output or not final_output.outputs or not final_output.outputs[0].logprobs:
+            return []
+
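+        # with max_tokens=1 there is a single generation step; logprobs[0] maps token ids to its top-k candidates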
+        top_logprobs = final_output.outputs[0].logprobs[0]
 
         tokens = []
         for _, logprob_obj in top_logprobs.items():
             tok_str = logprob_obj.decoded_token
-            prob = float(logprob_obj.logprob.exp())
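+            # the logprob is a plain float, so convert with math.exp rather than a tensor .exp()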
+            prob = float(math.exp(logprob_obj.logprob))
             tokens.append(Token(tok_str, prob))
+
         tokens.sort(key=lambda x: -x.prob)
         return tokens
 
     async def generate_inputs_prob(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> List[Token]:
-        full_prompt = self._build_inputs(text, history)
-
-        # vLLM has no ready-made "mask one token and compute its prob" API,
-        # so we take the most direct approach: feed the whole prompt in one go,
-        # enable prompt_logprobs so vLLM returns a logprob for every position of
-        # the *input* part, and then pick out the probability of each token.
-        sp = self.SamplingParams(
-            temperature=0,
-            max_tokens=0,  # do not generate any new tokens
-            prompt_logprobs=1,  # top-1 is enough
+        raise NotImplementedError(
+            "VLLMWrapper does not support per-token logprobs yet."
         )
-
-        results = []
-        async for req_output in self.engine.generate(
-            full_prompt, sp, request_id="graphgen_prob"
-        ):
-            results = req_output.outputs
-
-        # prompt_logprobs is a list whose length equals the number of prompt tokens;
-        # each element is a dict{token_id: logprob_obj}, or None (the first position is None)
-        prompt_logprobs = results[-1].prompt_logprobs
-
-        tokens = []
-        for _, logprob_dict in enumerate(prompt_logprobs):
-            if logprob_dict is None:
-                continue
-            # each dict holds exactly one key/value pair because we asked for top-1
-            _, logprob_obj = next(iter(logprob_dict.items()))
-            tok_str = logprob_obj.decoded_token
-            prob = float(logprob_obj.logprob.exp())
-            tokens.append(Token(tok_str, prob))
-        return tokens