diff --git a/graphgen/models/llm/local/vllm_wrapper.py b/graphgen/models/llm/local/vllm_wrapper.py
index fc412b51..74dc3c4e 100644
--- a/graphgen/models/llm/local/vllm_wrapper.py
+++ b/graphgen/models/llm/local/vllm_wrapper.py
@@ -1,6 +1,7 @@
 import math
 import uuid
 from typing import Any, List, Optional
+import asyncio
 
 from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
 from graphgen.bases.datatypes import Token
@@ -19,6 +20,7 @@ def __init__(
         temperature: float = 0.6,
         top_p: float = 1.0,
         topk: int = 5,
+        timeout: float = 300.0,
         **kwargs: Any,
     ):
         super().__init__(temperature=temperature, top_p=top_p, **kwargs)
@@ -42,6 +44,7 @@ def __init__(
         self.temperature = temperature
         self.top_p = top_p
         self.topk = topk
+        self.timeout = timeout
 
     @staticmethod
     def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
@@ -57,6 +60,12 @@ def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
         lines.append(prompt)
         return "\n".join(lines)
 
+    async def _consume_generator(self, generator):
+        final_output = None
+        async for request_output in generator:
+            final_output = request_output
+        return final_output
+
     async def generate_answer(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> str:
@@ -71,14 +80,27 @@ async def generate_answer(
         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
 
-        final_output = None
-        async for request_output in result_generator:
-            final_output = request_output
-
-        if not final_output or not final_output.outputs:
-            return ""
-
-        return final_output.outputs[0].text
+        try:
+            final_output = await asyncio.wait_for(
+                self._consume_generator(result_generator),
+                timeout=self.timeout
+            )
+
+            if not final_output or not final_output.outputs:
+                return ""
+
+            result_text = final_output.outputs[0].text
+            return result_text
+
+        except asyncio.TimeoutError:
+            await self.engine.abort(request_id)
+            raise
+        except asyncio.CancelledError:
+            await self.engine.abort(request_id)
+            raise
+        except Exception:
+            await self.engine.abort(request_id)
+            raise
 
     async def generate_topk_per_token(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
@@ -95,37 +117,49 @@ async def generate_topk_per_token(
         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
 
-        final_output = None
-        async for request_output in result_generator:
-            final_output = request_output
-
-        if (
-            not final_output
-            or not final_output.outputs
-            or not final_output.outputs[0].logprobs
-        ):
-            return []
-
-        top_logprobs = final_output.outputs[0].logprobs[0]
-
-        candidate_tokens = []
-        for _, logprob_obj in top_logprobs.items():
-            tok_str = (
-                logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
-            )
-            prob = float(math.exp(logprob_obj.logprob))
-            candidate_tokens.append(Token(tok_str, prob))
-
-        candidate_tokens.sort(key=lambda x: -x.prob)
-
-        if candidate_tokens:
-            main_token = Token(
-                text=candidate_tokens[0].text,
-                prob=candidate_tokens[0].prob,
-                top_candidates=candidate_tokens,
+        try:
+            final_output = await asyncio.wait_for(
+                self._consume_generator(result_generator),
+                timeout=self.timeout
             )
-            return [main_token]
-        return []
+
+            if (
+                not final_output
+                or not final_output.outputs
+                or not final_output.outputs[0].logprobs
+            ):
+                return []
+
+            top_logprobs = final_output.outputs[0].logprobs[0]
+
+            candidate_tokens = []
+            for _, logprob_obj in top_logprobs.items():
+                tok_str = (
+                    logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
+                )
+                prob = float(math.exp(logprob_obj.logprob))
+                candidate_tokens.append(Token(tok_str, prob))
+
+            candidate_tokens.sort(key=lambda x: -x.prob)
+
+            if candidate_tokens:
+                main_token = Token(
+                    text=candidate_tokens[0].text,
+                    prob=candidate_tokens[0].prob,
+                    top_candidates=candidate_tokens,
+                )
+                return [main_token]
+            return []
+
+        except asyncio.TimeoutError:
+            await self.engine.abort(request_id)
+            raise
+        except asyncio.CancelledError:
+            await self.engine.abort(request_id)
+            raise
+        except Exception:
+            await self.engine.abort(request_id)
+            raise
 
     async def generate_inputs_prob(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
@@ -133,3 +167,4 @@ async def generate_inputs_prob(
         raise NotImplementedError(
             "VLLMWrapper does not support per-token logprobs yet."
         )
+
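For reviewers, a minimal, self-contained sketch (not part of the patch) of the pattern the diff introduces: the entire consumption of the streaming result is bounded by `asyncio.wait_for`, and on timeout or cancellation the request is explicitly aborted on the engine so it stops occupying a scheduler slot. `_StubEngine` below is a hypothetical stand-in for vLLM's `AsyncLLMEngine`, used only so the example runs without a GPU; all names in it are illustrative.

```python
import asyncio
import uuid


class _StubEngine:
    """Hypothetical stand-in for vLLM's AsyncLLMEngine (illustration only)."""

    async def generate(self, prompt: str, request_id: str):
        # Simulate a streaming request that never finishes in time.
        for i in range(10):
            await asyncio.sleep(1.0)
            yield f"partial output {i} for {request_id}"

    async def abort(self, request_id: str) -> None:
        print(f"aborted request {request_id}")


async def _consume_generator(generator):
    # Mirrors VLLMWrapper._consume_generator: keep only the last streamed item.
    final_output = None
    async for request_output in generator:
        final_output = request_output
    return final_output


async def main() -> None:
    engine = _StubEngine()
    request_id = str(uuid.uuid4())
    result_generator = engine.generate("hello", request_id)
    try:
        # Bound the whole consumption loop, not each individual chunk.
        final_output = await asyncio.wait_for(
            _consume_generator(result_generator), timeout=2.0
        )
        print(final_output)
    except asyncio.TimeoutError:
        # wait_for cancels the consumer, but the engine-side request must still
        # be aborted explicitly, as the patch does in every except branch.
        await engine.abort(request_id)
        print("request timed out")


if __name__ == "__main__":
    asyncio.run(main())
```

Running the sketch prints the abort message after roughly two seconds, which is the behaviour the timeout and except branches in `generate_answer` and `generate_topk_per_token` are meant to guarantee for hung requests.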