From 2af78f982988b8d528e52fa30f74e8160f5fc934 Mon Sep 17 00:00:00 2001 From: chenzihong_gavin <522023320011@smail.nju.edu.cn> Date: Tue, 16 Dec 2025 19:54:25 +0800 Subject: [PATCH 1/2] fix: fix data type of num_gpus & generate_topk_per_token of vllmwrapper --- graphgen/common/init_llm.py | 2 +- graphgen/models/llm/local/vllm_wrapper.py | 29 +++++++++++++++-------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/graphgen/common/init_llm.py b/graphgen/common/init_llm.py index 1e4f8cc7..af53709a 100644 --- a/graphgen/common/init_llm.py +++ b/graphgen/common/init_llm.py @@ -131,7 +131,7 @@ def create_llm( ray.get_actor(actor_name) except ValueError: print(f"Creating Ray actor for LLM {model_type} with backend {backend}.") - num_gpus = int(config.pop("num_gpus", 0)) + num_gpus = float(config.pop("num_gpus", 0)) actor = ( ray.remote(LLMServiceActor) .options( diff --git a/graphgen/models/llm/local/vllm_wrapper.py b/graphgen/models/llm/local/vllm_wrapper.py index b8d8a6de..88d3183e 100644 --- a/graphgen/models/llm/local/vllm_wrapper.py +++ b/graphgen/models/llm/local/vllm_wrapper.py @@ -33,8 +33,8 @@ def __init__( engine_args = AsyncEngineArgs( model=model, - tensor_parallel_size=tensor_parallel_size, - gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=int(tensor_parallel_size), + gpu_memory_utilization=float(gpu_memory_utilization), trust_remote_code=kwargs.get("trust_remote_code", True), disable_log_stats=False, ) @@ -82,15 +82,15 @@ async def generate_answer( async def generate_topk_per_token( self, text: str, history: Optional[List[str]] = None, **extra: Any - ) -> List[Token]: + ) -> List[Token]: full_prompt = self._build_inputs(text, history) - request_id = f"graphgen_topk_{uuid.uuid4()}" sp = self.SamplingParams( temperature=0, max_tokens=1, logprobs=self.topk, + prompt_logprobs=1, ) result_generator = self.engine.generate(full_prompt, sp, request_id=request_id) @@ -108,14 +108,23 @@ async def generate_topk_per_token( top_logprobs = final_output.outputs[0].logprobs[0] - tokens = [] + candidate_tokens = [] for _, logprob_obj in top_logprobs.items(): - tok_str = logprob_obj.decoded_token + tok_str = logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else "" prob = float(math.exp(logprob_obj.logprob)) - tokens.append(Token(tok_str, prob)) - - tokens.sort(key=lambda x: -x.prob) - return tokens + candidate_tokens.append(Token(tok_str, prob)) + + candidate_tokens.sort(key=lambda x: -x.prob) + + if candidate_tokens: + main_token = Token( + text=candidate_tokens[0].text, + prob=candidate_tokens[0].prob, + top_candidates=candidate_tokens + ) + return [main_token] + else: + return [] async def generate_inputs_prob( self, text: str, history: Optional[List[str]] = None, **extra: Any From 52251be1654d15cf6697d85f94fb7541eba4c49e Mon Sep 17 00:00:00 2001 From: chenzihong_gavin <522023320011@smail.nju.edu.cn> Date: Tue, 16 Dec 2025 20:04:03 +0800 Subject: [PATCH 2/2] fix: fix lint error --- graphgen/models/llm/local/vllm_wrapper.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/graphgen/models/llm/local/vllm_wrapper.py b/graphgen/models/llm/local/vllm_wrapper.py index 88d3183e..c6e5feac 100644 --- a/graphgen/models/llm/local/vllm_wrapper.py +++ b/graphgen/models/llm/local/vllm_wrapper.py @@ -113,9 +113,9 @@ async def generate_topk_per_token( tok_str = logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else "" prob = float(math.exp(logprob_obj.logprob)) candidate_tokens.append(Token(tok_str, prob)) - + candidate_tokens.sort(key=lambda x: -x.prob) - + if candidate_tokens: main_token = Token( text=candidate_tokens[0].text, @@ -123,8 +123,7 @@ async def generate_topk_per_token( top_candidates=candidate_tokens ) return [main_token] - else: - return [] + return [] async def generate_inputs_prob( self, text: str, history: Optional[List[str]] = None, **extra: Any