Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 73 additions & 38 deletions graphgen/models/llm/local/vllm_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
import uuid
from typing import Any, List, Optional
import asyncio

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token
Expand All @@ -19,6 +20,7 @@ def __init__(
temperature: float = 0.6,
top_p: float = 1.0,
topk: int = 5,
timeout: float = 300.0,
**kwargs: Any,
):
super().__init__(temperature=temperature, top_p=top_p, **kwargs)
Expand All @@ -42,6 +44,7 @@ def __init__(
self.temperature = temperature
self.top_p = top_p
self.topk = topk
self.timeout = timeout

@staticmethod
def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
Expand All @@ -57,6 +60,12 @@ def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
lines.append(prompt)
return "\n".join(lines)

async def _consume_generator(self, generator):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better code clarity and maintainability, it's good practice to add type hints to method signatures. Since the specific vLLM types are not imported at the module level, using typing.Any is a reasonable approach here.

Suggested change
async def _consume_generator(self, generator):
async def _consume_generator(self, generator: Any) -> Any:

final_output = None
async for request_output in generator:
final_output = request_output
return final_output

async def generate_answer(
self, text: str, history: Optional[List[str]] = None, **extra: Any
) -> str:
Expand All @@ -71,14 +80,27 @@ async def generate_answer(

result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)

final_output = None
async for request_output in result_generator:
final_output = request_output

if not final_output or not final_output.outputs:
return ""

return final_output.outputs[0].text
try:
final_output = await asyncio.wait_for(
self._consume_generator(result_generator),
timeout=self.timeout
)

if not final_output or not final_output.outputs:
return ""

result_text = final_output.outputs[0].text
return result_text

except asyncio.TimeoutError:
await self.engine.abort(request_id)
raise
except asyncio.CancelledError:
await self.engine.abort(request_id)
raise
except Exception as e:
await self.engine.abort(request_id)
raise
Comment on lines +95 to +103
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The exception handling for TimeoutError, CancelledError, and Exception contains duplicated code (await self.engine.abort(request_id)). You can combine these into a single except block to make the code more concise and maintainable. asyncio.TimeoutError inherits from Exception, while asyncio.CancelledError inherits from BaseException, so catching (Exception, asyncio.CancelledError) covers all intended cases.

Suggested change
except asyncio.TimeoutError:
await self.engine.abort(request_id)
raise
except asyncio.CancelledError:
await self.engine.abort(request_id)
raise
except Exception as e:
await self.engine.abort(request_id)
raise
except (Exception, asyncio.CancelledError):
await self.engine.abort(request_id)
raise


async def generate_topk_per_token(
self, text: str, history: Optional[List[str]] = None, **extra: Any
Expand All @@ -95,41 +117,54 @@ async def generate_topk_per_token(

result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)

final_output = None
async for request_output in result_generator:
final_output = request_output

if (
not final_output
or not final_output.outputs
or not final_output.outputs[0].logprobs
):
return []

top_logprobs = final_output.outputs[0].logprobs[0]

candidate_tokens = []
for _, logprob_obj in top_logprobs.items():
tok_str = (
logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
)
prob = float(math.exp(logprob_obj.logprob))
candidate_tokens.append(Token(tok_str, prob))

candidate_tokens.sort(key=lambda x: -x.prob)

if candidate_tokens:
main_token = Token(
text=candidate_tokens[0].text,
prob=candidate_tokens[0].prob,
top_candidates=candidate_tokens,
try:
final_output = await asyncio.wait_for(
self._consume_generator(result_generator),
timeout=self.timeout
)
return [main_token]
return []

if (
not final_output
or not final_output.outputs
or not final_output.outputs[0].logprobs
):
return []

top_logprobs = final_output.outputs[0].logprobs[0]

candidate_tokens = []
for _, logprob_obj in top_logprobs.items():
tok_str = (
logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
)
prob = float(math.exp(logprob_obj.logprob))
candidate_tokens.append(Token(tok_str, prob))

candidate_tokens.sort(key=lambda x: -x.prob)

if candidate_tokens:
main_token = Token(
text=candidate_tokens[0].text,
prob=candidate_tokens[0].prob,
top_candidates=candidate_tokens,
)
return [main_token]
return []

except asyncio.TimeoutError:
await self.engine.abort(request_id)
raise
except asyncio.CancelledError:
await self.engine.abort(request_id)
raise
except Exception as e:
await self.engine.abort(request_id)
raise
Comment on lines +154 to +162
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the generate_answer method, this exception handling logic is duplicated. This can be simplified by catching (Exception, asyncio.CancelledError) in a single block.

More importantly, the entire try...except block for running the generator with a timeout is a duplicate of the one in generate_answer. This violates the DRY (Don't Repeat Yourself) principle. To improve maintainability, consider extracting this logic into a private helper method.

Suggested change
except asyncio.TimeoutError:
await self.engine.abort(request_id)
raise
except asyncio.CancelledError:
await self.engine.abort(request_id)
raise
except Exception as e:
await self.engine.abort(request_id)
raise
except (Exception, asyncio.CancelledError):
await self.engine.abort(request_id)
raise


async def generate_inputs_prob(
    self, text: str, history: Optional[List[str]] = None, **extra: Any
) -> List[Token]:
    """Per-token probabilities for the input text are not supported here.

    Args:
        text: Input text (unused — the method always raises).
        history: Optional conversation history (unused).
        **extra: Additional keyword arguments (unused).

    Raises:
        NotImplementedError: always; this backend has no per-token
            logprob support yet.
    """
    raise NotImplementedError("VLLMWrapper does not support per-token logprobs yet.")

Loading