Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/aiperf/common/enums/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
ServiceType,
)
from aiperf.common.enums.sse_enums import (
SSEEventType,
SSEFieldType,
)
from aiperf.common.enums.system_enums import (
Expand Down Expand Up @@ -155,6 +156,7 @@
"RecordProcessorType",
"RequestRateMode",
"ResultsProcessorType",
"SSEEventType",
"SSEFieldType",
"ServiceRegistrationStatus",
"ServiceRunType",
Expand Down
6 changes: 6 additions & 0 deletions src/aiperf/common/enums/sse_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,9 @@ class SSEFieldType(CaseInsensitiveStrEnum):
ID = "id"
RETRY = "retry"
COMMENT = "comment"


class SSEEventType(CaseInsensitiveStrEnum):
"""Event types in an SSE message."""

ERROR = "error"
8 changes: 8 additions & 0 deletions src/aiperf/common/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,14 @@ class ShutdownError(AIPerfError):
"""Exception raised when a service encounters an error while shutting down."""


class SSEResponseError(AIPerfError):
"""Exception raised when a SSE response contains an error."""

def __init__(self, message: str, error_code: int = 500) -> None:
self.error_code = error_code
super().__init__(message)


class UnsupportedHookError(AIPerfError):
"""Exception raised when a hook is defined on a class that does not have any base classes that provide that hook type."""

Expand Down
5 changes: 4 additions & 1 deletion src/aiperf/common/models/error_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,15 @@ def __hash__(self) -> int:
@classmethod
def from_exception(cls, e: BaseException) -> "ErrorDetails":
"""Create an error details object from an exception."""
return cls(
error_details = cls(
type=e.__class__.__name__,
message=cls._safe_repr(e),
cause=cls._safe_repr(e.__cause__) if e.__cause__ else None,
details=[cls._safe_repr(arg) for arg in e.args] if e.args else None,
)
if hasattr(e, "error_code") and isinstance(e.error_code, int):
error_details.code = e.error_code
return error_details


class ExitErrorInfo(AIPerfBaseModel):
Expand Down
7 changes: 6 additions & 1 deletion src/aiperf/transports/aiohttp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import aiohttp

from aiperf.common.exceptions import SSEResponseError
from aiperf.common.mixins import AIPerfLoggerMixin
from aiperf.common.models import (
ErrorDetails,
Expand Down Expand Up @@ -102,6 +103,7 @@ async def _request(
):
# Parse SSE stream with optimal performance
async for message in AsyncSSEStreamReader(response.content):
AsyncSSEStreamReader.inspect_message_for_error(message)
record.responses.append(message)
else:
raw_response = await response.text()
Expand All @@ -114,7 +116,10 @@ async def _request(
)
)
record.end_perf_ns = time.perf_counter_ns()

except SSEResponseError as e:
record.end_perf_ns = time.perf_counter_ns()
self.error(f"Error in SSE response: {e!r}")
record.error = ErrorDetails.from_exception(e)
except Exception as e:
record.end_perf_ns = time.perf_counter_ns()
self.error(f"Error in aiohttp request: {e!r}")
Expand Down
31 changes: 31 additions & 0 deletions src/aiperf/transports/sse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from collections.abc import AsyncIterator

from aiperf.common.aiperf_logger import AIPerfLogger
from aiperf.common.enums.sse_enums import SSEEventType, SSEFieldType
from aiperf.common.exceptions import SSEResponseError
from aiperf.common.models import SSEMessage

_logger = AIPerfLogger(__name__)
Expand Down Expand Up @@ -71,9 +73,38 @@ async def read_complete_stream(self) -> list[SSEMessage]:
"""Read the complete SSE stream and return a list of SSE messages."""
messages: list[SSEMessage] = []
async for message in self:
AsyncSSEStreamReader.inspect_message_for_error(message)
messages.append(message)
return messages

@staticmethod
def inspect_message_for_error(message: SSEMessage):
"""Check if the message contains an error event packet and raise an SSEResponseError if so.

If so, look for any comment field and raise an SSEResponseError
with that comment as the error message, otherwise use the full message.
"""
has_error_event = any(
packet.name == SSEFieldType.EVENT and packet.value == SSEEventType.ERROR
for packet in message.packets
)

if has_error_event:
error_message = None
for packet in message.packets:
if packet.name == SSEFieldType.COMMENT:
error_message = packet.value
break

if error_message is None:
error_message = (
f"Unknown error in SSE response: {message.model_dump_json()}"
)

raise SSEResponseError(
f"Error occurred in SSE response: {error_message}", error_code=502
)

async def __aiter__(self) -> AsyncIterator[SSEMessage]:
"""Iterate over the SSE stream in a performant manner and yield parsed SSE messages as they arrive."""

Expand Down
76 changes: 70 additions & 6 deletions tests/transports/test_aiohttp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,74 @@ async def mock_aiter():
record, expected_response_count=2, expected_response_type=SSEMessage
)

@pytest.mark.asyncio
@pytest.mark.parametrize(
"comment_value,expected_error_text",
[
("Rate limit exceeded", "Rate limit exceeded"),
(None, "Unknown error in SSE response"),
],
)
async def test_sse_stream_error_event_handling(
self,
aiohttp_client: AioHttpClient,
mock_sse_response: Mock,
comment_value: str | None,
expected_error_text: str,
) -> None:
"""Test that SSE error events are properly caught and handled in the client."""
from aiperf.common.enums import SSEEventType, SSEFieldType
from aiperf.common.models import SSEField

packets = [
SSEField(name=SSEFieldType.EVENT, value=SSEEventType.ERROR),
]
if comment_value:
packets.append(SSEField(name=SSEFieldType.COMMENT, value=comment_value))
packets.append(SSEField(name=SSEFieldType.DATA, value="{}"))

mock_error_message = SSEMessage(perf_ns=123456789, packets=packets)

with (
patch("aiohttp.ClientSession") as mock_session_class,
patch(
"aiperf.transports.aiohttp_client.AsyncSSEStreamReader"
) as mock_reader_class,
):

async def mock_content_iter():
yield b"event: error\n"
if comment_value:
yield f": {comment_value}\n".encode()
yield b"data: {}\n\n"

mock_sse_response.content = mock_content_iter()

setup_mock_session(mock_session_class, mock_sse_response, ["request"])

async def mock_aiter():
yield mock_error_message
from aiperf.transports.sse_utils import AsyncSSEStreamReader

AsyncSSEStreamReader.inspect_message_for_error(mock_error_message)

mock_reader = Mock()
mock_reader.__aiter__ = Mock(return_value=mock_aiter())
mock_reader_class.return_value = mock_reader

record = await aiohttp_client.post_request(
"http://test.com/stream",
'{"stream": true}',
{"Accept": "text/event-stream"},
)

assert record.error is not None
assert record.error.code == 502
assert record.error.type == "SSEResponseError"
assert expected_error_text in record.error.message
assert len(record.responses) == 1
assert isinstance(record.responses[0], SSEMessage)

@pytest.mark.asyncio
@pytest.mark.parametrize(
"status_code,reason,error_text",
Expand Down Expand Up @@ -184,13 +252,9 @@ async def test_exception_handling(
"exception_class,message,expected_type",
[
(aiohttp.ClientConnectorError, "Connection failed", "ClientConnectorError"),
(
aiohttp.ClientResponseError,
"Internal Server Error",
"ClientResponseError",
),
(aiohttp.ClientResponseError, "Internal Server Error", "ClientResponseError"),
],
)
) # fmt: skip
async def test_aiohttp_specific_exceptions(
self,
aiohttp_client: AioHttpClient,
Expand Down
47 changes: 46 additions & 1 deletion tests/transports/test_aiohttp_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import pytest

from aiperf.common.exceptions import SSEResponseError
from aiperf.common.models import SSEMessage
from aiperf.transports.sse_utils import AsyncSSEStreamReader

Expand Down Expand Up @@ -448,7 +449,7 @@ async def test_aiter_crlf_all_field_types(self) -> None:
"""Test __aiter__ with CRLF and all SSE field types."""
chunks = [
b"data: test\r\nevent: custom\r\nid: msg-123\r\nretry: 5000\r\n: comment\r\n\r\n"
]
] # fmt: skip

reader = AsyncSSEStreamReader(self._create_byte_iterator(chunks))
messages = await self._collect_messages(reader)
Expand Down Expand Up @@ -529,3 +530,47 @@ async def test_aiter_crlf_performance(self) -> None:
assert processing_time < 3.0, (
f"CRLF processing took {processing_time:.3f}s, expected < 3s"
)

@pytest.mark.asyncio
@pytest.mark.parametrize(
"chunks,expected_error",
[
([b"data: Normal message\n\n", b"event: error\n: Rate limit\ndata: {}\n\n"], "Rate limit"),
([b"event: error\ndata: Something went wrong\n\n"], "Unknown error in SSE response"),
([b"event: error\r\n: Server error\r\ndata: {}\r\n\r\n"], "Server error"),
([b"event: error\n: Connection timeout\n\n"], "Connection timeout"),
([b"data: Message 1\n\n", b"data: Message 2\n\n", b"event: error\n: Fatal error\n\n"], "Fatal error"),
([b'event: error\n: Internal error\ndata: {"error_code": 500}\n\n'], "Internal error"),
],
) # fmt: skip
async def test_error_events_raise_in_read_complete_stream(
self, chunks: list[bytes], expected_error: str
) -> None:
"""Test that various error events raise SSEResponseError."""
reader = AsyncSSEStreamReader(self._create_byte_iterator(chunks))

with pytest.raises(SSEResponseError) as exc_info:
await reader.read_complete_stream()

assert expected_error in str(exc_info.value)
assert exc_info.value.error_code == 502

@pytest.mark.asyncio
async def test_error_in_manual_iteration_with_inspect(self) -> None:
"""Test that manual iteration with inspect raises on error event."""
chunks = [
b"data: First message\n\n",
b"event: error\n: Authentication failed\n\n",
b"data: Should not reach\n\n",
]

reader = AsyncSSEStreamReader(self._create_byte_iterator(chunks))
messages = []

with pytest.raises(SSEResponseError) as exc_info:
async for message in reader:
AsyncSSEStreamReader.inspect_message_for_error(message)
messages.append(message)

assert len(messages) == 1
assert "Authentication failed" in str(exc_info.value)
Loading