-
Notifications
You must be signed in to change notification settings - Fork 120
Feat session cache #766
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Feat session cache #766
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -56,6 +56,39 @@ def __init__( | |
| if headers: | ||
| self.headers.update(headers) | ||
|
|
||
| # Shared aiohttp session to prevent memory leaks | ||
| # One session per handler instance, not per request | ||
| self._session: Optional[Any] = None | ||
|
|
||
| async def _get_session(self): | ||
| """Get or create the shared aiohttp session. | ||
|
|
||
| This method ensures we reuse a single session across all requests, | ||
| preventing memory leaks from creating/destroying sessions repeatedly. | ||
|
|
||
| Returns: | ||
| aiohttp.ClientSession: The shared session instance. | ||
| """ | ||
| import aiohttp | ||
| if self._session is None or self._session.closed: | ||
| # Create session with connection pooling | ||
| connector = aiohttp.TCPConnector( | ||
| limit=100, # Max connections | ||
| limit_per_host=30, # Max connections per host | ||
| ) | ||
| self._session = aiohttp.ClientSession(connector=connector) | ||
| return self._session | ||
|
|
||
| async def close(self): | ||
| """Close the shared aiohttp session. | ||
|
|
||
| Call this method when the handler is no longer needed to properly | ||
| clean up resources. | ||
| """ | ||
| if self._session and not self._session.closed: | ||
| await self._session.close() | ||
| self._session = None | ||
|
Comment on lines
+82
to
+90
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 建议让这个 handler 实现异步上下文管理器协议，以便在退出时自动调用 close() 清理会话，例如：class LLMHTTPHandler:
# ... existing code ...
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
# 用法:
async with LLMHTTPHandler(...) as handler:
await handler.async_call(...) |
||
|
|
||
| def _parse_sse_line(self, line: bytes) -> Optional[Dict[str, Any]]: | ||
| """Parse a Server-Sent Events (SSE) line. | ||
|
|
||
|
|
@@ -178,8 +211,8 @@ async def _make_async_request_stream( | |
| if headers: | ||
| request_headers.update(headers) | ||
|
|
||
| # Create an independent session and keep it open | ||
| session = aiohttp.ClientSession() | ||
| # Use the shared session instead of creating a new one | ||
| session = await self._get_session() | ||
| try: | ||
| response = await session.post( | ||
| url, | ||
|
|
@@ -215,9 +248,7 @@ async def _make_async_request_stream( | |
| except Exception as e: | ||
| logger.error(f"Error in stream: {str(e)}") | ||
| raise | ||
| finally: | ||
| # Ensure the session is eventually closed | ||
| await session.close() | ||
| # Note: We don't close the session here as it's shared and reused | ||
|
|
||
| async def _make_async_request( | ||
| self, | ||
|
|
@@ -237,21 +268,21 @@ async def _make_async_request( | |
| Raises: | ||
| aiohttp.ClientError: If the request fails. | ||
| """ | ||
| import aiohttp | ||
| url = f"{self.base_url}/{endpoint.lstrip('/')}" | ||
| request_headers = self.headers.copy() | ||
| if headers: | ||
| request_headers.update(headers) | ||
|
|
||
| async with aiohttp.ClientSession() as session: | ||
| async with session.post( | ||
| url, | ||
| headers=request_headers, | ||
| json=data, | ||
| timeout=self.timeout, | ||
| ) as response: | ||
| response.raise_for_status() | ||
| return await response.json() | ||
| # Use the shared session instead of creating a new one | ||
| session = await self._get_session() | ||
| async with session.post( | ||
| url, | ||
| headers=request_headers, | ||
| json=data, | ||
| timeout=self.timeout, | ||
| ) as response: | ||
| response.raise_for_status() | ||
| return await response.json() | ||
|
|
||
| def sync_call( | ||
| self, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,15 +34,29 @@ | |
| ENDOFTEXT: 100256, | ||
| } | ||
|
|
||
| # Global cache to prevent memory leaks from repeatedly loading BPE files | ||
| _BPE_CACHE = {} | ||
|
|
||
|
|
||
| def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: | ||
| """Load tiktoken BPE file similar to qwen_tokenizer.""" | ||
| """Load tiktoken BPE file with caching to prevent memory leaks.""" | ||
| # Check cache first | ||
| if tiktoken_bpe_file in _BPE_CACHE: | ||
| return _BPE_CACHE[tiktoken_bpe_file] | ||
|
|
||
| # Load and decode file | ||
| with open(tiktoken_bpe_file, 'rb') as f: | ||
| contents = f.read() | ||
| return { | ||
| base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line) | ||
|
|
||
| result = { | ||
| base64.b64decode(token): int(rank) | ||
| for token, rank in (line.split() for line in contents.splitlines() if line) | ||
| } | ||
|
|
||
| # Cache the result | ||
| _BPE_CACHE[tiktoken_bpe_file] = result | ||
| return result | ||
|
Comment on lines
+38
to
+58
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The manual `_BPE_CACHE` dict can be replaced with the standard-library `functools.lru_cache`, which is simpler, thread-safe, and bounded: from functools import lru_cache
@lru_cache(maxsize=16)
def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
"""Load tiktoken BPE file with caching to prevent memory leaks."""
# Load and decode file
with open(tiktoken_bpe_file, 'rb') as f:
contents = f.read()
result = {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in contents.splitlines() if line)
}
return result |
||
|
|
||
|
|
||
| class OpenAITokenizer: | ||
| """OpenAI tokenizer using local tiktoken file.""" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,14 +45,29 @@ | |
| )) | ||
| SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS) | ||
|
|
||
| # Global cache to prevent memory leaks from repeatedly loading BPE files | ||
| _BPE_CACHE = {} | ||
|
|
||
|
|
||
| def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: | ||
| """Load tiktoken BPE file with caching to prevent memory leaks.""" | ||
| # Check cache first | ||
| if tiktoken_bpe_file in _BPE_CACHE: | ||
| return _BPE_CACHE[tiktoken_bpe_file] | ||
|
|
||
| # Load and decode file | ||
| with open(tiktoken_bpe_file, 'rb') as f: | ||
| contents = f.read() | ||
| return { | ||
| base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line) | ||
|
|
||
| result = { | ||
| base64.b64decode(token): int(rank) | ||
| for token, rank in (line.split() for line in contents.splitlines() if line) | ||
| } | ||
|
|
||
| # Cache the result | ||
| _BPE_CACHE[tiktoken_bpe_file] = result | ||
| return result | ||
|
Comment on lines
+49
to
+69
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The manual `_BPE_CACHE` dict can be replaced with the standard-library `functools.lru_cache`, which is simpler, thread-safe, and bounded: from functools import lru_cache
@lru_cache(maxsize=16)
def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
"""Load tiktoken BPE file with caching to prevent memory leaks."""
# Load and decode file
with open(tiktoken_bpe_file, 'rb') as f:
contents = f.read()
result = {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in contents.splitlines() if line)
}
return result |
||
|
|
||
|
|
||
| class QWenTokenizer: | ||
| """QWen tokenizer.""" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,34 @@ | |
| from aworld.models.openai_tokenizer import openai_tokenizer | ||
| from aworld.utils import import_package | ||
|
|
||
| # Global cache for tiktoken encodings to prevent memory leaks | ||
| _TIKTOKEN_ENCODING_CACHE = {} | ||
|
|
||
|
|
||
| def _get_cached_tiktoken_encoding(model: str): | ||
| """ | ||
| Get cached tiktoken encoding to prevent memory leaks. | ||
|
|
||
| Args: | ||
| model: Model name (e.g., 'gpt-4o', 'claude-3-opus') | ||
|
|
||
| Returns: | ||
| Cached tiktoken encoding object | ||
| """ | ||
| if model not in _TIKTOKEN_ENCODING_CACHE: | ||
| import tiktoken | ||
| try: | ||
| _TIKTOKEN_ENCODING_CACHE[model] = tiktoken.encoding_for_model(model) | ||
| logger.debug(f"Created and cached tiktoken encoding for model: {model}") | ||
| except KeyError: | ||
| logger.debug(f"{model} model not found. Using cl100k_base encoding.") | ||
| # Cache cl100k_base if not already cached | ||
| if "cl100k_base" not in _TIKTOKEN_ENCODING_CACHE: | ||
| _TIKTOKEN_ENCODING_CACHE["cl100k_base"] = tiktoken.get_encoding("cl100k_base") | ||
| # Reuse cl100k_base for this model | ||
| _TIKTOKEN_ENCODING_CACHE[model] = _TIKTOKEN_ENCODING_CACHE["cl100k_base"] | ||
| return _TIKTOKEN_ENCODING_CACHE[model] | ||
|
Comment on lines
+15
to
+40
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The manual `_TIKTOKEN_ENCODING_CACHE` dict can be replaced with the standard-library `functools.lru_cache`, which is simpler, thread-safe, and bounded: from functools import lru_cache
@lru_cache(maxsize=128)
def _get_cached_tiktoken_encoding(model: str):
"""
Get cached tiktoken encoding to prevent memory leaks.
Args:
model: Model name (e.g., 'gpt-4o', 'claude-3-opus')
Returns:
Cached tiktoken encoding object
"""
import tiktoken
try:
encoding = tiktoken.encoding_for_model(model)
logger.debug(f"Created and cached tiktoken encoding for model: {model}")
return encoding
except KeyError:
logger.debug(f"{model} model not found. Using cl100k_base encoding.")
return tiktoken.get_encoding("cl100k_base") |
||
|
|
||
|
|
||
| class ModelUtils: | ||
| """Utility class for model-related operations""" | ||
|
|
@@ -265,37 +293,26 @@ def usage_process(usage: Dict[str, Union[int, Dict[str, int]]] = {}, context: Co | |
|
|
||
| def num_tokens_from_string(string: str, model: str = "openai"): | ||
| """Return the number of tokens used by a list of messages.""" | ||
| import tiktoken | ||
|
|
||
| if model.lower() == "qwen": | ||
| encoding = qwen_tokenizer | ||
| elif model.lower() == "openai": | ||
| encoding = openai_tokenizer | ||
| else: | ||
| try: | ||
| encoding = tiktoken.encoding_for_model(model) | ||
| except KeyError: | ||
| logger.debug( | ||
| f"{model} model not found. Using cl100k_base encoding.") | ||
| encoding = tiktoken.get_encoding("cl100k_base") | ||
| # Use cached encoding to prevent memory leaks | ||
| encoding = _get_cached_tiktoken_encoding(model) | ||
| return len(encoding.encode(string)) | ||
|
|
||
| def num_tokens_from_messages(messages, model="openai"): | ||
| """Return the number of tokens used by a list of messages.""" | ||
| import_package("tiktoken") | ||
| import tiktoken | ||
|
|
||
| if model.lower() == "qwen": | ||
| encoding = qwen_tokenizer | ||
| elif model.lower() == "openai": | ||
| encoding = openai_tokenizer | ||
| else: | ||
| try: | ||
| encoding = tiktoken.encoding_for_model(model) | ||
| except KeyError: | ||
| logger.warning( | ||
| f"{model} model not found. Using cl100k_base encoding.") | ||
| encoding = tiktoken.get_encoding("cl100k_base") | ||
| # Use cached encoding to prevent memory leaks | ||
| encoding = _get_cached_tiktoken_encoding(model) | ||
|
|
||
| tokens_per_message = 3 | ||
| tokens_per_name = 1 | ||
|
|
@@ -316,19 +333,14 @@ def num_tokens_from_messages(messages, model="openai"): | |
|
|
||
| def truncate_tokens_from_messages(messages: List[Dict[str, Any]], max_tokens: int, keep_both_sides: bool = False, model: str = "gpt-4o"): | ||
| import_package("tiktoken") | ||
| import tiktoken | ||
|
|
||
| if model.lower() == "qwen": | ||
| return qwen_tokenizer.truncate(messages, max_tokens, keep_both_sides) | ||
| elif model.lower() == "openai": | ||
| return openai_tokenizer.truncate(messages, max_tokens, keep_both_sides) | ||
|
|
||
| try: | ||
| encoding = tiktoken.encoding_for_model(model) | ||
| except KeyError: | ||
| logger.warning(f"{model} model not found. Using cl100k_base encoding.") | ||
| encoding = tiktoken.get_encoding("cl100k_base") | ||
|
|
||
| # Use cached encoding to prevent memory leaks | ||
| encoding = _get_cached_tiktoken_encoding(model) | ||
| return encoding.truncate(messages, max_tokens, keep_both_sides) | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分代码存在并发安全问题。当多个协程同时执行此 if 检查时，可能会发生竞态条件，导致创建多个 aiohttp.ClientSession 实例，从而引发资源泄漏。为了解决这个问题，您应该使用 asyncio.Lock，正如您在 PR 描述中提到的那样：在 __init__ 方法中初始化锁，并在 _get_session 方法中用该锁保护"检查—创建"这段逻辑。