Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 46 additions & 15 deletions aworld/models/llm_http_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,39 @@ def __init__(
if headers:
self.headers.update(headers)

# Shared aiohttp session to prevent memory leaks
# One session per handler instance, not per request
self._session: Optional[Any] = None

async def _get_session(self):
    """Get or create the shared aiohttp session.

    A single session per handler is reused across requests, preventing
    the leak of repeatedly creating and destroying sessions (and their
    connection pools) on every call.

    Returns:
        aiohttp.ClientSession: The shared session instance.
    """
    import asyncio
    import aiohttp

    # Lazily create the guard lock.  There is no await between this check
    # and the assignment, so lazy init is safe on a single event loop.
    # NOTE(review): if a handler instance is shared across threads or
    # event loops, move this lock into __init__ -- confirm usage.
    if getattr(self, "_session_lock", None) is None:
        self._session_lock = asyncio.Lock()

    async with self._session_lock:
        # Re-check under the lock so concurrent coroutines cannot each
        # create (and leak) their own ClientSession.
        if self._session is None or self._session.closed:
            connector = aiohttp.TCPConnector(
                limit=100,          # max simultaneous connections overall
                limit_per_host=30,  # max simultaneous connections per host
            )
            self._session = aiohttp.ClientSession(connector=connector)
    return self._session
Comment on lines +73 to +79
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

这部分代码存在并发安全问题。当多个协程同时执行此 if 检查时,可能会发生竞态条件,导致创建多个 aiohttp.ClientSession 实例,从而引发资源泄漏。

为了解决这个问题,您应该使用 asyncio.Lock,正如您在 PR 描述中提到的那样。

  1. __init__ 方法中初始化锁:

    import asyncio
    # ...
    class LLMHTTPHandler:
        def __init__(self, ...):
            # ...
            self._session: Optional[Any] = None
            self._session_lock = asyncio.Lock()
  2. _get_session 方法中使用锁:

    async def _get_session(self):
        # ...
        import aiohttp
        async with self._session_lock:
            if self._session is None or self._session.closed:
                # ... create session ...
                self._session = aiohttp.ClientSession(connector=connector)
        return self._session

return self._session

async def close(self):
    """Release the shared aiohttp session.

    Invoke this when the handler is no longer needed so the pooled
    connections are properly torn down.
    """
    session = self._session
    if session and not session.closed:
        await session.close()
        # Drop the reference so a later _get_session() builds a fresh one.
        self._session = None
Comment on lines +82 to +90
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

这个 close 方法对于手动清理资源非常好。为了使资源管理更加健壮和易于使用,可以考虑实现异步上下文管理器协议(__aenter____aexit__)。这允许用户将处理器包装在 async with 块中,确保 close() 被自动调用,正如您在 PR 描述中提到的那样。

class LLMHTTPHandler:
    # ... existing code ...

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

# 用法:
async with LLMHTTPHandler(...) as handler:
    await handler.async_call(...)


def _parse_sse_line(self, line: bytes) -> Optional[Dict[str, Any]]:
"""Parse a Server-Sent Events (SSE) line.

Expand Down Expand Up @@ -178,8 +211,8 @@ async def _make_async_request_stream(
if headers:
request_headers.update(headers)

# Create an independent session and keep it open
session = aiohttp.ClientSession()
# Use the shared session instead of creating a new one
session = await self._get_session()
try:
response = await session.post(
url,
Expand Down Expand Up @@ -215,9 +248,7 @@ async def _make_async_request_stream(
except Exception as e:
logger.error(f"Error in stream: {str(e)}")
raise
finally:
# Ensure the session is eventually closed
await session.close()
# Note: We don't close the session here as it's shared and reused

async def _make_async_request(
self,
Expand All @@ -237,21 +268,21 @@ async def _make_async_request(
Raises:
aiohttp.ClientError: If the request fails.
"""
import aiohttp
url = f"{self.base_url}/{endpoint.lstrip('/')}"
request_headers = self.headers.copy()
if headers:
request_headers.update(headers)

async with aiohttp.ClientSession() as session:
async with session.post(
url,
headers=request_headers,
json=data,
timeout=self.timeout,
) as response:
response.raise_for_status()
return await response.json()
# Use the shared session instead of creating a new one
session = await self._get_session()
async with session.post(
url,
headers=request_headers,
json=data,
timeout=self.timeout,
) as response:
response.raise_for_status()
return await response.json()

def sync_call(
self,
Expand Down
20 changes: 17 additions & 3 deletions aworld/models/openai_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,29 @@
ENDOFTEXT: 100256,
}

# Global cache to prevent memory leaks from repeatedly loading BPE files
_BPE_CACHE = {}


def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
"""Load tiktoken BPE file similar to qwen_tokenizer."""
"""Load tiktoken BPE file with caching to prevent memory leaks."""
# Check cache first
if tiktoken_bpe_file in _BPE_CACHE:
return _BPE_CACHE[tiktoken_bpe_file]

# Load and decode file
with open(tiktoken_bpe_file, 'rb') as f:
contents = f.read()
return {
base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)

result = {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in contents.splitlines() if line)
}

# Cache the result
_BPE_CACHE[tiktoken_bpe_file] = result
return result
Comment on lines +38 to +58
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

The _BPE_CACHE dictionary is an unbounded cache that can grow indefinitely with unique tiktoken_bpe_file paths. This poses a risk of memory exhaustion and Denial of Service (DoS) if file paths are attacker-controlled. Additionally, the current cache implementation has a race condition where multiple threads/coroutines might redundantly load and process the same file. Using functools.lru_cache as suggested addresses both the unbounded cache issue and provides thread-safety.

from functools import lru_cache

@lru_cache(maxsize=16)
def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
    """Load tiktoken BPE file with caching to prevent memory leaks."""
    # Load and decode file
    with open(tiktoken_bpe_file, 'rb') as f:
        contents = f.read()

    result = {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in contents.splitlines() if line)
    }

    return result



class OpenAITokenizer:
"""OpenAI tokenizer using local tiktoken file."""
Expand Down
19 changes: 17 additions & 2 deletions aworld/models/qwen_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,29 @@
))
SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS)

# Global cache to prevent memory leaks from repeatedly loading BPE files
_BPE_CACHE = {}


def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
"""Load tiktoken BPE file with caching to prevent memory leaks."""
# Check cache first
if tiktoken_bpe_file in _BPE_CACHE:
return _BPE_CACHE[tiktoken_bpe_file]

# Load and decode file
with open(tiktoken_bpe_file, 'rb') as f:
contents = f.read()
return {
base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)

result = {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in contents.splitlines() if line)
}

# Cache the result
_BPE_CACHE[tiktoken_bpe_file] = result
return result
Comment on lines +49 to +69
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

The _BPE_CACHE dictionary is an unbounded cache that can grow indefinitely with unique tiktoken_bpe_file paths. This poses a risk of memory exhaustion and Denial of Service (DoS) if file paths are attacker-controlled. The cache is also updated without synchronization, so concurrent threads may redundantly load and process the same file. Using functools.lru_cache as suggested bounds the cache and keeps its bookkeeping thread-safe (note: lru_cache does not stop two threads from both loading a file on a cold cache, but the cached result remains consistent).

from functools import lru_cache

@lru_cache(maxsize=16)
def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
    """Load tiktoken BPE file with caching to prevent memory leaks."""
    # Load and decode file
    with open(tiktoken_bpe_file, 'rb') as f:
        contents = f.read()

    result = {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in contents.splitlines() if line)
    }

    return result



class QWenTokenizer:
"""QWen tokenizer."""
Expand Down
56 changes: 34 additions & 22 deletions aworld/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,34 @@
from aworld.models.openai_tokenizer import openai_tokenizer
from aworld.utils import import_package

# Global cache for tiktoken encodings to prevent memory leaks
_TIKTOKEN_ENCODING_CACHE = {}


from functools import lru_cache


@lru_cache(maxsize=128)
def _get_cached_tiktoken_encoding(model: str):
    """Return a cached tiktoken encoding for *model*.

    Args:
        model: Model name (e.g., 'gpt-4o', 'claude-3-opus'). Unknown
            names fall back to the ``cl100k_base`` encoding.

    Returns:
        A tiktoken encoding object, shared across calls for the same model.

    The bounded LRU cache prevents unbounded growth when callers pass
    arbitrary model names, and its bookkeeping is safe under concurrent
    threads (tiktoken itself also caches encodings internally, so the
    fallback returns a shared cl100k_base instance).
    """
    # Imported lazily so the module loads even when tiktoken is absent.
    import tiktoken
    try:
        encoding = tiktoken.encoding_for_model(model)
        logger.debug(f"Created and cached tiktoken encoding for model: {model}")
        return encoding
    except KeyError:
        logger.debug(f"{model} model not found. Using cl100k_base encoding.")
        return tiktoken.get_encoding("cl100k_base")
Comment on lines +15 to +40
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

The _TIKTOKEN_ENCODING_CACHE dictionary is an unbounded cache, which can lead to memory exhaustion and a Denial of Service (DoS) if user-controlled model parameters are used to populate it. The cache is also updated without synchronization, so under concurrent access tiktoken.encoding_for_model() may be called redundantly for the same model. Using functools.lru_cache as suggested bounds the cache and keeps its bookkeeping thread-safe (note: lru_cache does not stop two threads from both computing an entry on a cold cache, but the cached result remains consistent).

from functools import lru_cache

@lru_cache(maxsize=128)
def _get_cached_tiktoken_encoding(model: str):
    """
    Get cached tiktoken encoding to prevent memory leaks.

    Args:
        model: Model name (e.g., 'gpt-4o', 'claude-3-opus')

    Returns:
        Cached tiktoken encoding object
    """
    import tiktoken
    try:
        encoding = tiktoken.encoding_for_model(model)
        logger.debug(f"Created and cached tiktoken encoding for model: {model}")
        return encoding
    except KeyError:
        logger.debug(f"{model} model not found. Using cl100k_base encoding.")
        return tiktoken.get_encoding("cl100k_base")



class ModelUtils:
"""Utility class for model-related operations"""
Expand Down Expand Up @@ -265,37 +293,26 @@ def usage_process(usage: Dict[str, Union[int, Dict[str, int]]] = {}, context: Co

def num_tokens_from_string(string: str, model: str = "openai"):
    """Return the number of tokens *string* encodes to under *model*'s tokenizer.

    Args:
        string: Text to tokenize.
        model: 'qwen', 'openai', or any tiktoken-known model name; unknown
            names fall back to cl100k_base inside the cached-encoding helper.

    Returns:
        int: Token count.
    """
    if model.lower() == "qwen":
        encoding = qwen_tokenizer
    elif model.lower() == "openai":
        encoding = openai_tokenizer
    else:
        # Reuse a cached encoding instead of constructing one per call.
        encoding = _get_cached_tiktoken_encoding(model)
    return len(encoding.encode(string))

def num_tokens_from_messages(messages, model="openai"):
"""Return the number of tokens used by a list of messages."""
import_package("tiktoken")
import tiktoken

if model.lower() == "qwen":
encoding = qwen_tokenizer
elif model.lower() == "openai":
encoding = openai_tokenizer
else:
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
logger.warning(
f"{model} model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
# Use cached encoding to prevent memory leaks
encoding = _get_cached_tiktoken_encoding(model)

tokens_per_message = 3
tokens_per_name = 1
Expand All @@ -316,19 +333,14 @@ def num_tokens_from_messages(messages, model="openai"):

def truncate_tokens_from_messages(messages: List[Dict[str, Any]], max_tokens: int, keep_both_sides: bool = False, model: str = "gpt-4o"):
    """Truncate *messages* so their token count fits within *max_tokens*.

    Args:
        messages: Chat messages as a list of dicts.
        max_tokens: Token budget to truncate down to.
        keep_both_sides: Whether to keep both the head and tail of the
            conversation when truncating.
        model: Tokenizer selector ('qwen', 'openai', or a tiktoken model name).

    Returns:
        The truncated message list.
    """
    import_package("tiktoken")

    if model.lower() == "qwen":
        return qwen_tokenizer.truncate(messages, max_tokens, keep_both_sides)
    elif model.lower() == "openai":
        return openai_tokenizer.truncate(messages, max_tokens, keep_both_sides)

    # Use cached encoding to prevent memory leaks.
    encoding = _get_cached_tiktoken_encoding(model)
    # NOTE(review): tiktoken encoding objects expose encode/decode but no
    # truncate() method -- confirm this branch is ever reached / works.
    return encoding.truncate(messages, max_tokens, keep_both_sides)


Expand Down