|  | 
| 7 | 7 | ################################################################################ | 
| 8 | 8 | 
 | 
| 9 | 9 | import time | 
|  | 10 | +from itertools import count | 
| 10 | 11 | from unittest.mock import AsyncMock, Mock, patch | 
| 11 | 12 | 
 | 
| 12 | 13 | import pytest | 
|  | 14 | +from pytest import param | 
| 13 | 15 | 
 | 
| 14 |  | -from aiperf.clients.http.aiohttp_client import AioHttpSSEStreamReader | 
|  | 16 | +from aiperf.clients.http.aiohttp_client import AioHttpSSEStreamReader, parse_sse_message | 
|  | 17 | +from aiperf.common.enums import SSEFieldType | 
|  | 18 | +from aiperf.common.exceptions import SSEResponseError | 
| 15 | 19 | from aiperf.common.models import SSEMessage | 
|  | 20 | +from aiperf.common.models.error_models import ErrorDetails | 
| 16 | 21 | from tests.clients.http.conftest import ( | 
| 17 | 22 |     create_sse_chunk_list, | 
| 18 | 23 |     setup_single_sse_chunk, | 
| @@ -231,3 +236,232 @@ async def test_malformed_sse_stream(self, mock_sse_response: Mock) -> None: | 
| 231 | 236 |         with pytest.raises(Exception, match="Stream corruption"): | 
| 232 | 237 |             async for _ in reader: | 
| 233 | 238 |                 pass | 
|  | 239 | + | 
|  | 240 | + | 
|  | 241 | +@pytest.fixture | 
|  | 242 | +async def stream_data_builder(): | 
|  | 243 | +    """Async fixture for building SSE stream data with auto-incrementing timestamps.""" | 
|  | 244 | +    timestamp_base = 123456789 | 
|  | 245 | + | 
|  | 246 | +    def build(*messages: tuple[str, int | None]) -> list[tuple[str, int]]: | 
|  | 247 | +        ts_gen = count(timestamp_base) | 
|  | 248 | +        return [ | 
|  | 249 | +            (msg, ts_override if ts_override is not None else next(ts_gen)) | 
|  | 250 | +            for msg, ts_override in messages | 
|  | 251 | +        ] | 
|  | 252 | + | 
|  | 253 | +    return build | 
|  | 254 | + | 
|  | 255 | + | 
|  | 256 | +@pytest.fixture | 
|  | 257 | +async def mock_sse_iterator(mock_sse_response: Mock): | 
|  | 258 | +    """Async fixture that patches AioHttpSSEStreamReader.__aiter__ for testing.""" | 
|  | 259 | + | 
|  | 260 | +    async def _read_stream(sse_data: list[tuple[str, int]]) -> list[SSEMessage]: | 
|  | 261 | +        async def mock_aiter(): | 
|  | 262 | +            for data in sse_data: | 
|  | 263 | +                yield data | 
|  | 264 | + | 
|  | 265 | +        with patch.object( | 
|  | 266 | +            AioHttpSSEStreamReader, "__aiter__", return_value=mock_aiter() | 
|  | 267 | +        ): | 
|  | 268 | +            reader = AioHttpSSEStreamReader(Mock()) | 
|  | 269 | +            return await reader.read_complete_stream() | 
|  | 270 | + | 
|  | 271 | +    return _read_stream | 
|  | 272 | + | 
|  | 273 | + | 
|  | 274 | +@pytest.mark.asyncio | 
|  | 275 | +class TestSSEErrorHandling: | 
|  | 276 | +    """Test suite for SSE error handling in AioHttpSSEStreamReader.""" | 
|  | 277 | + | 
|  | 278 | +    TIMESTAMP = 123456789 | 
|  | 279 | + | 
|  | 280 | +    @pytest.mark.parametrize( | 
|  | 281 | +        "sse_message,expected_error_match", | 
|  | 282 | +        [ | 
|  | 283 | +            param( | 
|  | 284 | +                "event: error\ncomment: Invalid request", | 
|  | 285 | +                "Invalid request", | 
|  | 286 | +                id="error_with_invalid_request", | 
|  | 287 | +            ), | 
|  | 288 | +            param( | 
|  | 289 | +                "event: error\ncomment: Server error", | 
|  | 290 | +                "Server error", | 
|  | 291 | +                id="error_with_server_error", | 
|  | 292 | +            ), | 
|  | 293 | +            param( | 
|  | 294 | +                "event: error\ncomment: Auth failed", | 
|  | 295 | +                "Auth failed", | 
|  | 296 | +                id="error_with_auth_failed", | 
|  | 297 | +            ), | 
|  | 298 | +            param( | 
|  | 299 | +                'event: error\ncomment: {"code": "E001", "msg": "API Error"}', | 
|  | 300 | +                "Error occurred in SSE response", | 
|  | 301 | +                id="error_with_json_message", | 
|  | 302 | +            ), | 
|  | 303 | +        ], | 
|  | 304 | +    ) | 
|  | 305 | +    async def test_error_with_comment_message( | 
|  | 306 | +        self, | 
|  | 307 | +        sse_message: str, | 
|  | 308 | +        expected_error_match: str, | 
|  | 309 | +        mock_sse_iterator, | 
|  | 310 | +    ) -> None: | 
|  | 311 | +        """Test SSE error with event: error followed by comment raises error.""" | 
|  | 312 | +        with pytest.raises(SSEResponseError, match=expected_error_match): | 
|  | 313 | +            await mock_sse_iterator([(sse_message, self.TIMESTAMP)]) | 
|  | 314 | + | 
|  | 315 | +    @pytest.mark.parametrize( | 
|  | 316 | +        "sse_message", | 
|  | 317 | +        [ | 
|  | 318 | +            param("event: error", id="error_only"), | 
|  | 319 | +            param("event: error\ndata: error details", id="error_with_data"), | 
|  | 320 | +            param("event: error\nid: err-123", id="error_with_id"), | 
|  | 321 | +        ], | 
|  | 322 | +    ) | 
|  | 323 | +    async def test_error_without_comment_field( | 
|  | 324 | +        self, sse_message: str, mock_sse_iterator | 
|  | 325 | +    ) -> None: | 
|  | 326 | +        """Test SSE error without comment field raises with unknown error message.""" | 
|  | 327 | +        with pytest.raises(SSEResponseError, match="Unknown error"): | 
|  | 328 | +            await mock_sse_iterator([(sse_message, self.TIMESTAMP)]) | 
|  | 329 | + | 
|  | 330 | +    @pytest.mark.parametrize( | 
|  | 331 | +        "error_variant", | 
|  | 332 | +        [ | 
|  | 333 | +            param("error", id="lowercase"), | 
|  | 334 | +            param("ERROR", id="uppercase"), | 
|  | 335 | +            param("Error", id="mixed_case"), | 
|  | 336 | +            param("eRrOr", id="random_case"), | 
|  | 337 | +        ], | 
|  | 338 | +    ) | 
|  | 339 | +    async def test_error_case_insensitive( | 
|  | 340 | +        self, error_variant: str, mock_sse_iterator | 
|  | 341 | +    ) -> None: | 
|  | 342 | +        """Test that event: error is case-insensitive.""" | 
|  | 343 | +        sse_message = f"event: error\ncomment: Test error from {error_variant}" | 
|  | 344 | +        with pytest.raises(SSEResponseError, match="Test error"): | 
|  | 345 | +            await mock_sse_iterator([(sse_message, self.TIMESTAMP)]) | 
|  | 346 | + | 
|  | 347 | +    async def test_error_in_middle_of_stream( | 
|  | 348 | +        self, stream_data_builder, mock_sse_iterator | 
|  | 349 | +    ) -> None: | 
|  | 350 | +        """Test that error in the middle of stream raises immediately.""" | 
|  | 351 | +        sse_data = stream_data_builder( | 
|  | 352 | +            ("data: First message", None), | 
|  | 353 | +            ("event: error\ncomment: Stream error", None), | 
|  | 354 | +            ("data: This should not be processed", None), | 
|  | 355 | +        ) | 
|  | 356 | + | 
|  | 357 | +        with pytest.raises(SSEResponseError, match="Stream error"): | 
|  | 358 | +            await mock_sse_iterator(sse_data) | 
|  | 359 | + | 
|  | 360 | +    async def test_error_code_is_502(self, mock_sse_iterator) -> None: | 
|  | 361 | +        """Test that SSEResponseError has correct error code.""" | 
|  | 362 | +        with pytest.raises(SSEResponseError) as exc_info: | 
|  | 363 | +            await mock_sse_iterator([("event: error\ncomment: Test", self.TIMESTAMP)]) | 
|  | 364 | +        assert exc_info.value.error_code == 502 | 
|  | 365 | + | 
|  | 366 | +    async def test_error_details_preserved_with_502_code( | 
|  | 367 | +        self, mock_sse_iterator | 
|  | 368 | +    ) -> None: | 
|  | 369 | +        """Test that error details are properly preserved when SSEResponseError with 502 is converted.""" | 
|  | 370 | +        error_message = "SSE stream error occurred" | 
|  | 371 | +        with pytest.raises(SSEResponseError) as exc_info: | 
|  | 372 | +            await mock_sse_iterator( | 
|  | 373 | +                [("event: error\ncomment: " + error_message, self.TIMESTAMP)] | 
|  | 374 | +            ) | 
|  | 375 | + | 
|  | 376 | +        exc = exc_info.value | 
|  | 377 | +        error_details = ErrorDetails.from_exception(exc) | 
|  | 378 | + | 
|  | 379 | +        assert error_details.code == 502 | 
|  | 380 | +        assert error_details.type == "SSEResponseError" | 
|  | 381 | +        assert error_message in error_details.message | 
|  | 382 | + | 
|  | 383 | +    @pytest.mark.parametrize( | 
|  | 384 | +        "event_type", | 
|  | 385 | +        [ | 
|  | 386 | +            param("message", id="event_message"), | 
|  | 387 | +            param("update", id="event_update"), | 
|  | 388 | +            param("completion", id="event_completion"), | 
|  | 389 | +            param("chunk", id="event_chunk"), | 
|  | 390 | +            param("ping", id="event_ping"), | 
|  | 391 | +        ], | 
|  | 392 | +    ) | 
|  | 393 | +    async def test_normal_events_not_treated_as_errors(self, event_type: str) -> None: | 
|  | 394 | +        """Test that normal event types are not treated as errors.""" | 
|  | 395 | +        raw_message = f"event: {event_type}\ndata: some data" | 
|  | 396 | +        result = parse_sse_message(raw_message, self.TIMESTAMP) | 
|  | 397 | + | 
|  | 398 | +        assert result.packets[0].name == SSEFieldType.EVENT | 
|  | 399 | +        assert result.packets[0].value == event_type | 
|  | 400 | + | 
|  | 401 | +    async def test_error_with_very_long_message(self, mock_sse_iterator) -> None: | 
|  | 402 | +        """Test SSE error with very long error message (>1000 chars).""" | 
|  | 403 | +        long_message = "x" * 5000 | 
|  | 404 | +        with pytest.raises(SSEResponseError, match=long_message[:100]): | 
|  | 405 | +            await mock_sse_iterator( | 
|  | 406 | +                [("event: error\ncomment: " + long_message, self.TIMESTAMP)] | 
|  | 407 | +            ) | 
|  | 408 | + | 
|  | 409 | +    @pytest.mark.parametrize( | 
|  | 410 | +        "special_chars", | 
|  | 411 | +        [ | 
|  | 412 | +            param("Error with 'quotes'", id="with_single_quotes"), | 
|  | 413 | +            param('Error with "double quotes"', id="with_double_quotes"), | 
|  | 414 | +            param("Error: with: multiple: colons", id="with_colons"), | 
|  | 415 | +            param("Error\twith\ttabs", id="with_tabs"), | 
|  | 416 | +            param("Error with émojis 🚀 💻", id="with_emojis"), | 
|  | 417 | +            param("Error with <html> tags", id="with_html_tags"), | 
|  | 418 | +            param("Error with | pipe | symbols", id="with_pipes"), | 
|  | 419 | +        ], | 
|  | 420 | +    ) | 
|  | 421 | +    async def test_error_with_special_characters( | 
|  | 422 | +        self, special_chars: str, mock_sse_iterator | 
|  | 423 | +    ) -> None: | 
|  | 424 | +        """Test SSE error with special characters in message.""" | 
|  | 425 | +        with pytest.raises(SSEResponseError): | 
|  | 426 | +            await mock_sse_iterator( | 
|  | 427 | +                [("event: error\ncomment: " + special_chars, self.TIMESTAMP)] | 
|  | 428 | +            ) | 
|  | 429 | + | 
|  | 430 | +    async def test_error_with_only_whitespace_comment(self, mock_sse_iterator) -> None: | 
|  | 431 | +        """Test SSE error with only whitespace in comment field.""" | 
|  | 432 | +        with pytest.raises(SSEResponseError): | 
|  | 433 | +            await mock_sse_iterator([("event: error\ncomment:   ", self.TIMESTAMP)]) | 
|  | 434 | + | 
|  | 435 | +    async def test_multiple_errors_in_sequence( | 
|  | 436 | +        self, stream_data_builder, mock_sse_iterator | 
|  | 437 | +    ) -> None: | 
|  | 438 | +        """Test that first error in sequence is raised immediately.""" | 
|  | 439 | +        sse_data = stream_data_builder( | 
|  | 440 | +            ("event: error\ncomment: First error", None), | 
|  | 441 | +            ("event: error\ncomment: Second error", None), | 
|  | 442 | +        ) | 
|  | 443 | + | 
|  | 444 | +        with pytest.raises(SSEResponseError, match="First error"): | 
|  | 445 | +            await mock_sse_iterator(sse_data) | 
|  | 446 | + | 
|  | 447 | +    async def test_error_with_mixed_additional_fields(self, mock_sse_iterator) -> None: | 
|  | 448 | +        """Test SSE error with multiple additional fields mixed in.""" | 
|  | 449 | +        sse_message = ( | 
|  | 450 | +            "event: error\n" | 
|  | 451 | +            "id: err-001\n" | 
|  | 452 | +            "comment: Mixed fields error\n" | 
|  | 453 | +            "retry: 3000\n" | 
|  | 454 | +            "data: extra data" | 
|  | 455 | +        ) | 
|  | 456 | +        with pytest.raises(SSEResponseError, match="Mixed fields error"): | 
|  | 457 | +            await mock_sse_iterator([(sse_message, self.TIMESTAMP)]) | 
|  | 458 | + | 
|  | 459 | +    async def test_error_with_special_escape_sequences(self) -> None: | 
|  | 460 | +        """Test parsing error message with escape sequences in raw format.""" | 
|  | 461 | +        raw_message = "event: error\ncomment: Error\\nwith\\nnewlines" | 
|  | 462 | +        result = parse_sse_message(raw_message, self.TIMESTAMP) | 
|  | 463 | + | 
|  | 464 | +        assert result.packets[0].name == SSEFieldType.EVENT | 
|  | 465 | +        assert result.packets[0].value == "error" | 
|  | 466 | +        assert result.packets[1].name == SSEFieldType.COMMENT | 
|  | 467 | +        assert result.packets[1].value == "Error\\nwith\\nnewlines" | 
0 commit comments