diff --git a/graphgen/common/init_llm.py b/graphgen/common/init_llm.py
index 1e4f8cc7..af53709a 100644
--- a/graphgen/common/init_llm.py
+++ b/graphgen/common/init_llm.py
@@ -131,7 +131,7 @@ def create_llm(
         ray.get_actor(actor_name)
     except ValueError:
         print(f"Creating Ray actor for LLM {model_type} with backend {backend}.")
-        num_gpus = int(config.pop("num_gpus", 0))
+        num_gpus = float(config.pop("num_gpus", 0))
         actor = (
             ray.remote(LLMServiceActor)
             .options(
diff --git a/graphgen/models/llm/local/vllm_wrapper.py b/graphgen/models/llm/local/vllm_wrapper.py
index b8d8a6de..c6e5feac 100644
--- a/graphgen/models/llm/local/vllm_wrapper.py
+++ b/graphgen/models/llm/local/vllm_wrapper.py
@@ -33,8 +33,8 @@ def __init__(
         engine_args = AsyncEngineArgs(
             model=model,
-            tensor_parallel_size=tensor_parallel_size,
-            gpu_memory_utilization=gpu_memory_utilization,
+            tensor_parallel_size=int(tensor_parallel_size),
+            gpu_memory_utilization=float(gpu_memory_utilization),
             trust_remote_code=kwargs.get("trust_remote_code", True),
             disable_log_stats=False,
         )
@@ -82,15 +82,15 @@ async def generate_answer(
     async def generate_topk_per_token(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
-    ) -> List[Token]:
+    ) -> List[Token]:
         full_prompt = self._build_inputs(text, history)
-
         request_id = f"graphgen_topk_{uuid.uuid4()}"
         sp = self.SamplingParams(
             temperature=0,
             max_tokens=1,
             logprobs=self.topk,
+            prompt_logprobs=1,
         )
         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
@@ -108,14 +108,22 @@ async def generate_topk_per_token(
         top_logprobs = final_output.outputs[0].logprobs[0]
-        tokens = []
+        candidate_tokens = []
         for _, logprob_obj in top_logprobs.items():
-            tok_str = logprob_obj.decoded_token
+            tok_str = logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
             prob = float(math.exp(logprob_obj.logprob))
-            tokens.append(Token(tok_str, prob))
-
-        tokens.sort(key=lambda x: -x.prob)
-        return tokens
+            candidate_tokens.append(Token(tok_str, prob))
+
+        candidate_tokens.sort(key=lambda x: -x.prob)
+
+        if candidate_tokens:
+            main_token = Token(
+                text=candidate_tokens[0].text,
+                prob=candidate_tokens[0].prob,
+                top_candidates=candidate_tokens
+            )
+            return [main_token]
+        return []

     async def generate_inputs_prob(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
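
After this change, generate_topk_per_token returns a single-element list: one Token for the greedy next token, with the full sorted candidate list attached. The snippet below is a minimal consumption sketch, not part of the diff; it assumes Token exposes the text, prob, and top_candidates fields used in the constructor call above, and that wrapper is an already-initialized vLLM wrapper instance.

import asyncio

async def inspect_next_token(wrapper, prompt: str) -> None:
    # New return shape: a one-element list whose Token carries the alternatives.
    tokens = await wrapper.generate_topk_per_token(prompt)
    if not tokens:
        print("no candidates returned")
        return
    main = tokens[0]
    print(f"greedy token: {main.text!r} (p={main.prob:.4f})")
    for cand in main.top_candidates:
        print(f"  candidate: {cand.text!r} (p={cand.prob:.4f})")

# Hypothetical usage:
# asyncio.run(inspect_next_token(wrapper, "The capital of France is"))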