diff --git a/src/aiperf/common/enums/__init__.py b/src/aiperf/common/enums/__init__.py index c562d5094..77b07db6c 100644 --- a/src/aiperf/common/enums/__init__.py +++ b/src/aiperf/common/enums/__init__.py @@ -91,6 +91,7 @@ ServiceType, ) from aiperf.common.enums.sse_enums import ( + SSEEventType, SSEFieldType, ) from aiperf.common.enums.system_enums import ( @@ -155,6 +156,7 @@ "RecordProcessorType", "RequestRateMode", "ResultsProcessorType", + "SSEEventType", "SSEFieldType", "ServiceRegistrationStatus", "ServiceRunType", diff --git a/src/aiperf/common/enums/sse_enums.py b/src/aiperf/common/enums/sse_enums.py index 8599c40b0..dc6b7032f 100644 --- a/src/aiperf/common/enums/sse_enums.py +++ b/src/aiperf/common/enums/sse_enums.py @@ -12,3 +12,9 @@ class SSEFieldType(CaseInsensitiveStrEnum): ID = "id" RETRY = "retry" COMMENT = "comment" + + +class SSEEventType(CaseInsensitiveStrEnum): + """Event types in an SSE message.""" + + ERROR = "error" diff --git a/src/aiperf/common/exceptions.py b/src/aiperf/common/exceptions.py index d84161e4a..18363b513 100644 --- a/src/aiperf/common/exceptions.py +++ b/src/aiperf/common/exceptions.py @@ -167,6 +167,14 @@ class ShutdownError(AIPerfError): """Exception raised when a service encounters an error while shutting down.""" +class SSEResponseError(AIPerfError): + """Exception raised when an SSE response contains an error.""" + + def __init__(self, message: str, error_code: int = 500) -> None: + self.error_code = error_code + super().__init__(message) + + class UnsupportedHookError(AIPerfError): """Exception raised when a hook is defined on a class that does not have any base classes that provide that hook type.""" diff --git a/src/aiperf/common/models/error_models.py b/src/aiperf/common/models/error_models.py index 000ef5e34..fe3809f14 100644 --- a/src/aiperf/common/models/error_models.py +++ b/src/aiperf/common/models/error_models.py @@ -65,12 +65,15 @@ def __hash__(self) -> int: @classmethod def from_exception(cls, e: BaseException) 
-> "ErrorDetails": """Create an error details object from an exception.""" - return cls( + error_details = cls( type=e.__class__.__name__, message=cls._safe_repr(e), cause=cls._safe_repr(e.__cause__) if e.__cause__ else None, details=[cls._safe_repr(arg) for arg in e.args] if e.args else None, ) + if hasattr(e, "error_code") and isinstance(e.error_code, int): + error_details.code = e.error_code + return error_details class ExitErrorInfo(AIPerfBaseModel): diff --git a/src/aiperf/transports/aiohttp_client.py b/src/aiperf/transports/aiohttp_client.py index 9d18b692f..a01dd1ecc 100644 --- a/src/aiperf/transports/aiohttp_client.py +++ b/src/aiperf/transports/aiohttp_client.py @@ -6,6 +6,7 @@ import aiohttp +from aiperf.common.exceptions import SSEResponseError from aiperf.common.mixins import AIPerfLoggerMixin from aiperf.common.models import ( ErrorDetails, @@ -102,6 +103,7 @@ async def _request( ): # Parse SSE stream with optimal performance async for message in AsyncSSEStreamReader(response.content): + AsyncSSEStreamReader.inspect_message_for_error(message) record.responses.append(message) else: raw_response = await response.text() @@ -114,7 +116,10 @@ async def _request( ) ) record.end_perf_ns = time.perf_counter_ns() - + except SSEResponseError as e: + record.end_perf_ns = time.perf_counter_ns() + self.error(f"Error in SSE response: {e!r}") + record.error = ErrorDetails.from_exception(e) except Exception as e: record.end_perf_ns = time.perf_counter_ns() self.error(f"Error in aiohttp request: {e!r}") diff --git a/src/aiperf/transports/sse_utils.py b/src/aiperf/transports/sse_utils.py index 1bc51d987..b207e728a 100644 --- a/src/aiperf/transports/sse_utils.py +++ b/src/aiperf/transports/sse_utils.py @@ -6,6 +6,8 @@ from collections.abc import AsyncIterator from aiperf.common.aiperf_logger import AIPerfLogger +from aiperf.common.enums.sse_enums import SSEEventType, SSEFieldType +from aiperf.common.exceptions import SSEResponseError from aiperf.common.models import 
SSEMessage _logger = AIPerfLogger(__name__) @@ -71,9 +73,38 @@ async def read_complete_stream(self) -> list[SSEMessage]: """Read the complete SSE stream and return a list of SSE messages.""" messages: list[SSEMessage] = [] async for message in self: + AsyncSSEStreamReader.inspect_message_for_error(message) messages.append(message) return messages + @staticmethod + def inspect_message_for_error(message: SSEMessage): + """Check if the message contains an error event packet and raise an SSEResponseError if so. + + When an error event is present, the first comment field is used as the + error message; without a comment, the full message dumped as JSON is used. + """ + has_error_event = any( + packet.name == SSEFieldType.EVENT and packet.value == SSEEventType.ERROR + for packet in message.packets + ) + + if has_error_event: + error_message = None + for packet in message.packets: + if packet.name == SSEFieldType.COMMENT: + error_message = packet.value + break + + if error_message is None: + error_message = ( + f"Unknown error in SSE response: {message.model_dump_json()}" + ) + + raise SSEResponseError( + f"Error occurred in SSE response: {error_message}", error_code=502 + ) + async def __aiter__(self) -> AsyncIterator[SSEMessage]: """Iterate over the SSE stream in a performant manner and yield parsed SSE messages as they arrive.""" diff --git a/tests/transports/test_aiohttp_client.py b/tests/transports/test_aiohttp_client.py index 5b5a85251..bddf49cd3 100644 --- a/tests/transports/test_aiohttp_client.py +++ b/tests/transports/test_aiohttp_client.py @@ -120,6 +120,74 @@ async def mock_aiter(): record, expected_response_count=2, expected_response_type=SSEMessage ) + @pytest.mark.asyncio + @pytest.mark.parametrize( + "comment_value,expected_error_text", + [ + ("Rate limit exceeded", "Rate limit exceeded"), + (None, "Unknown error in SSE response"), + ], + ) + async def test_sse_stream_error_event_handling( + self, + aiohttp_client: AioHttpClient, + mock_sse_response: Mock, + 
comment_value: str | None, + expected_error_text: str, + ) -> None: + """Test that SSE error events are properly caught and handled in the client.""" + from aiperf.common.enums import SSEEventType, SSEFieldType + from aiperf.common.models import SSEField + + packets = [ + SSEField(name=SSEFieldType.EVENT, value=SSEEventType.ERROR), + ] + if comment_value: + packets.append(SSEField(name=SSEFieldType.COMMENT, value=comment_value)) + packets.append(SSEField(name=SSEFieldType.DATA, value="{}")) + + mock_error_message = SSEMessage(perf_ns=123456789, packets=packets) + + with ( + patch("aiohttp.ClientSession") as mock_session_class, + patch( + "aiperf.transports.aiohttp_client.AsyncSSEStreamReader" + ) as mock_reader_class, + ): + + async def mock_content_iter(): + yield b"event: error\n" + if comment_value: + yield f": {comment_value}\n".encode() + yield b"data: {}\n\n" + + mock_sse_response.content = mock_content_iter() + + setup_mock_session(mock_session_class, mock_sse_response, ["request"]) + + async def mock_aiter(): + yield mock_error_message + from aiperf.transports.sse_utils import AsyncSSEStreamReader + + AsyncSSEStreamReader.inspect_message_for_error(mock_error_message) + + mock_reader = Mock() + mock_reader.__aiter__ = Mock(return_value=mock_aiter()) + mock_reader_class.return_value = mock_reader + + record = await aiohttp_client.post_request( + "http://test.com/stream", + '{"stream": true}', + {"Accept": "text/event-stream"}, + ) + + assert record.error is not None + assert record.error.code == 502 + assert record.error.type == "SSEResponseError" + assert expected_error_text in record.error.message + assert len(record.responses) == 1 + assert isinstance(record.responses[0], SSEMessage) + @pytest.mark.asyncio @pytest.mark.parametrize( "status_code,reason,error_text", @@ -184,13 +252,9 @@ async def test_exception_handling( "exception_class,message,expected_type", [ (aiohttp.ClientConnectorError, "Connection failed", "ClientConnectorError"), - ( - 
aiohttp.ClientResponseError, - "Internal Server Error", - "ClientResponseError", - ), + (aiohttp.ClientResponseError, "Internal Server Error", "ClientResponseError"), ], - ) + ) # fmt: skip async def test_aiohttp_specific_exceptions( self, aiohttp_client: AioHttpClient, diff --git a/tests/transports/test_aiohttp_sse.py b/tests/transports/test_aiohttp_sse.py index 5ce036754..da425dd5e 100644 --- a/tests/transports/test_aiohttp_sse.py +++ b/tests/transports/test_aiohttp_sse.py @@ -8,6 +8,7 @@ import pytest +from aiperf.common.exceptions import SSEResponseError from aiperf.common.models import SSEMessage from aiperf.transports.sse_utils import AsyncSSEStreamReader @@ -448,7 +449,7 @@ async def test_aiter_crlf_all_field_types(self) -> None: """Test __aiter__ with CRLF and all SSE field types.""" chunks = [ b"data: test\r\nevent: custom\r\nid: msg-123\r\nretry: 5000\r\n: comment\r\n\r\n" - ] + ] # fmt: skip reader = AsyncSSEStreamReader(self._create_byte_iterator(chunks)) messages = await self._collect_messages(reader) @@ -529,3 +530,47 @@ async def test_aiter_crlf_performance(self) -> None: assert processing_time < 3.0, ( f"CRLF processing took {processing_time:.3f}s, expected < 3s" ) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "chunks,expected_error", + [ + ([b"data: Normal message\n\n", b"event: error\n: Rate limit\ndata: {}\n\n"], "Rate limit"), + ([b"event: error\ndata: Something went wrong\n\n"], "Unknown error in SSE response"), + ([b"event: error\r\n: Server error\r\ndata: {}\r\n\r\n"], "Server error"), + ([b"event: error\n: Connection timeout\n\n"], "Connection timeout"), + ([b"data: Message 1\n\n", b"data: Message 2\n\n", b"event: error\n: Fatal error\n\n"], "Fatal error"), + ([b'event: error\n: Internal error\ndata: {"error_code": 500}\n\n'], "Internal error"), + ], + ) # fmt: skip + async def test_error_events_raise_in_read_complete_stream( + self, chunks: list[bytes], expected_error: str + ) -> None: + """Test that various error events raise 
SSEResponseError.""" + reader = AsyncSSEStreamReader(self._create_byte_iterator(chunks)) + + with pytest.raises(SSEResponseError) as exc_info: + await reader.read_complete_stream() + + assert expected_error in str(exc_info.value) + assert exc_info.value.error_code == 502 + + @pytest.mark.asyncio + async def test_error_in_manual_iteration_with_inspect(self) -> None: + """Test that manual iteration with inspect raises on error event.""" + chunks = [ + b"data: First message\n\n", + b"event: error\n: Authentication failed\n\n", + b"data: Should not reach\n\n", + ] + + reader = AsyncSSEStreamReader(self._create_byte_iterator(chunks)) + messages = [] + + with pytest.raises(SSEResponseError) as exc_info: + async for message in reader: + AsyncSSEStreamReader.inspect_message_for_error(message) + messages.append(message) + + assert len(messages) == 1 + assert "Authentication failed" in str(exc_info.value) diff --git a/tests/transports/test_sse_utils.py b/tests/transports/test_sse_utils.py index 08a309510..9716ebe45 100644 --- a/tests/transports/test_sse_utils.py +++ b/tests/transports/test_sse_utils.py @@ -5,7 +5,9 @@ import pytest from aiperf.common.enums import SSEFieldType +from aiperf.common.exceptions import SSEResponseError from aiperf.common.models import SSEField, SSEMessage +from aiperf.transports.sse_utils import AsyncSSEStreamReader @pytest.fixture @@ -482,3 +484,190 @@ def test_parse_many_packets(self, base_perf_ns: int) -> None: assert (end_time - start_time) < 0.250 assert result.perf_ns == base_perf_ns assert len(result.packets) == 1000 + + +class TestInspectMessageForError: + """Test suite for SSE error message inspection functionality.""" + + @pytest.mark.parametrize( + "raw_message", + [ + 'data: {"content": "Hello World"}\nevent: message\nid: msg_123', + "event: message\n: This is a comment\ndata: content", + ": This is just a comment", + "", + ], + ) + def test_non_error_messages_pass_through( + self, raw_message: str, base_perf_ns: int + ) -> None: + 
"""Test that messages without error events pass through without raising.""" + message = SSEMessage.parse(raw_message, base_perf_ns) + AsyncSSEStreamReader.inspect_message_for_error(message) + + @pytest.mark.parametrize( + "raw_message,expected_error_text", + [ + ( + "event: error\n: Authentication failed\ndata: {}", + "Authentication failed", + ), + ( + "event: error\n: Rate limit exceeded\ndata: {}", + "Rate limit exceeded", + ), + ( + "event: message\nevent: error\n: Multiple events error\ndata: {}", + "Multiple events error", + ), + ( + "event: error\n: 你好世界 🚀 !@#$%^&*()\ndata: {}", + "你好世界 🚀 !@#$%^&*()", + ), + ], + ) # fmt: skip + def test_error_event_with_comment( + self, raw_message: str, expected_error_text: str, base_perf_ns: int + ) -> None: + """Test that error events with comments raise SSEResponseError with comment message.""" + message = SSEMessage.parse(raw_message, base_perf_ns) + + with pytest.raises(SSEResponseError) as exc_info: + AsyncSSEStreamReader.inspect_message_for_error(message) + + assert expected_error_text in str(exc_info.value) + assert exc_info.value.error_code == 502 + + def test_error_event_without_comment(self, base_perf_ns: int) -> None: + """Test that error event without comment raises with full message.""" + raw_message = 'event: error\ndata: {"error": "Something went wrong"}' + message = SSEMessage.parse(raw_message, base_perf_ns) + + with pytest.raises(SSEResponseError) as exc_info: + AsyncSSEStreamReader.inspect_message_for_error(message) + + assert "Unknown error in SSE response" in str(exc_info.value) + assert exc_info.value.error_code == 502 + + def test_error_event_with_multiple_comments_uses_first( + self, base_perf_ns: int + ) -> None: + """Test that when multiple comment fields exist, first one is used.""" + raw_message = "event: error\n: First error\n: Second error\ndata: {}" + message = SSEMessage.parse(raw_message, base_perf_ns) + + with pytest.raises(SSEResponseError) as exc_info: + 
AsyncSSEStreamReader.inspect_message_for_error(message) + + assert "First error" in str(exc_info.value) + assert "Second error" not in str(exc_info.value) + + @pytest.mark.parametrize("event_case", ["error", "ERROR", "Error", "eRrOr"]) + def test_error_event_case_insensitive( + self, event_case: str, base_perf_ns: int + ) -> None: + """Test that error event detection is case-insensitive.""" + raw_message = f"event: {event_case}\n: Error message" + message = SSEMessage.parse(raw_message, base_perf_ns) + + with pytest.raises(SSEResponseError) as exc_info: + AsyncSSEStreamReader.inspect_message_for_error(message) + + assert "Error message" in str(exc_info.value) + assert exc_info.value.error_code == 502 + + +@pytest.fixture +def create_mock_sse_iterator(): + """Factory fixture for creating mock SSE async iterators.""" + + def _factory(*chunks: bytes): + async def mock_async_iter(): + for chunk in chunks: + yield chunk + + return mock_async_iter() + + return _factory + + +class TestAsyncSSEStreamReaderErrorHandling: + """Test suite for AsyncSSEStreamReader integration with error handling.""" + + async def test_read_complete_stream_success_no_errors( + self, create_mock_sse_iterator + ) -> None: + """Test that read_complete_stream completes successfully when no errors.""" + reader = AsyncSSEStreamReader( + create_mock_sse_iterator( + b"data: Hello\n\n", + b"event: message\ndata: World\n\n", + b"data: [DONE]\n\n", + ) + ) + messages = await reader.read_complete_stream() + + assert len(messages) == 3 + assert all(isinstance(msg, SSEMessage) for msg in messages) + + @pytest.mark.parametrize( + "chunks,expected_error,expected_msg_count", + [ + ( + (b"event: error\n: Rate limit exceeded\n\n",), + "Rate limit exceeded", + 0, + ), + ( + ( + b"data: Message 1\n\n", + b"event: error\n: Server overloaded\n\n", + ), + "Server overloaded", + 1, + ), + ( + ( + b"data: Hello\n\n", + b"event: error\n: Connection timeout\ndata: {}\n\n", + b"data: Should not reach\n\n", + ), + 
"Connection timeout", + 1, + ), + ( + ( + b'data: {"content": "Hello"}\n\n', + b'event: message\ndata: {"content": "World"}\n\n', + b"event: error\n: Authentication expired\ndata: {}\n\n", + ), + "Authentication expired", + 2, + ), + ], + ) + async def test_stream_errors_at_various_positions( + self, + create_mock_sse_iterator, + chunks: tuple[bytes, ...], + expected_error: str, + expected_msg_count: int, + ) -> None: + """Test error handling at different positions in the stream.""" + reader = AsyncSSEStreamReader(create_mock_sse_iterator(*chunks)) + + if expected_msg_count == 0: + # For read_complete_stream, error is raised immediately + with pytest.raises(SSEResponseError) as exc_info: + await reader.read_complete_stream() + assert expected_error in str(exc_info.value) + else: + # For manual iteration, we can count messages before error + messages = [] + with pytest.raises(SSEResponseError) as exc_info: + async for message in reader: + AsyncSSEStreamReader.inspect_message_for_error(message) + messages.append(message) + + assert len(messages) == expected_msg_count + assert expected_error in str(exc_info.value)