Skip to content

Commit 0e739b4

Browse files
committed
fix: parse and detect sse error event data from dynamo
1 parent 5ee6f01 commit 0e739b4

File tree

6 files changed

+288
-16
lines changed

6 files changed

+288
-16
lines changed

aiperf/clients/http/aiohttp_client.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
from aiperf.clients.http.defaults import AioHttpDefaults, SocketDefaults
1111
from aiperf.clients.model_endpoint_info import ModelEndpointInfo
12-
from aiperf.common.enums import SSEFieldType
12+
from aiperf.common.enums import SSEEventType, SSEFieldType
13+
from aiperf.common.exceptions import SSEResponseError
1314
from aiperf.common.mixins import AIPerfLoggerMixin
1415
from aiperf.common.models import (
1516
ErrorDetails,
@@ -129,10 +130,14 @@ async def _request(
129130
)
130131
record.end_perf_ns = time.perf_counter_ns()
131132

133+
except SSEResponseError as e:
134+
record.end_perf_ns = time.perf_counter_ns()
135+
self.error(f"Error in SSE response: {e!r}")
136+
record.error = ErrorDetails.from_exception(e)
132137
except Exception as e:
133138
record.end_perf_ns = time.perf_counter_ns()
134139
self.error(f"Error in aiohttp request: {e!r}")
135-
record.error = ErrorDetails(type=e.__class__.__name__, message=str(e))
140+
record.error = ErrorDetails.from_exception(e)
136141

137142
return record
138143

@@ -182,6 +187,27 @@ async def read_complete_stream(self) -> list[SSEMessage]:
182187
async for raw_message, first_byte_ns in self.__aiter__():
183188
# Parse the raw SSE message into a SSEMessage object
184189
message = parse_sse_message(raw_message, first_byte_ns)
190+
191+
# Check if the message contains an error in the format:
192+
# event: error
193+
# comment: <error message>
194+
# If so, raise an exception with the error message.
195+
if (
196+
message.packets
197+
and message.packets[0].name == SSEFieldType.EVENT
198+
and message.packets[0].value == SSEEventType.ERROR
199+
):
200+
if (
201+
len(message.packets) > 1
202+
and message.packets[1].name == SSEFieldType.COMMENT
203+
):
204+
error_message = message.packets[1].value
205+
else:
206+
error_message = f"Unknown error {message.model_dump_json()}"
207+
raise SSEResponseError(
208+
f"Error occurred in SSE response: {error_message}", error_code=502
209+
)
210+
185211
messages.append(message)
186212

187213
return messages
@@ -211,18 +237,11 @@ async def __aiter__(self) -> typing.AsyncIterator[tuple[str, int]]:
211237
break
212238
chunk = first_byte + chunk
213239

214-
try:
215-
decoded = chunk.decode("utf-8")
216-
for sub_chunk in decoded.split("\n\n"):
217-
if sub_chunk:
218-
yield (sub_chunk, chunk_ns_first_byte)
219-
# Use the fastest available decoder
220-
except UnicodeDecodeError:
221-
# Handle potential encoding issues gracefully
222-
yield (
223-
chunk.decode("utf-8", errors="replace").strip(),
224-
chunk_ns_first_byte,
225-
)
240+
# Replace invalid UTF-8 characters with the Unicode replacement character
241+
decoded = chunk.decode("utf-8", errors="replace")
242+
for sub_chunk in decoded.split("\n\n"):
243+
if sub_chunk:
244+
yield (sub_chunk, chunk_ns_first_byte)
226245

227246

228247
def parse_sse_message(raw_message: str, perf_ns: int) -> SSEMessage:

aiperf/common/enums/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
ServiceType,
9494
)
9595
from aiperf.common.enums.sse_enums import (
96+
SSEEventType,
9697
SSEFieldType,
9798
)
9899
from aiperf.common.enums.system_enums import (
@@ -162,6 +163,7 @@
162163
"RecordProcessorType",
163164
"RequestRateMode",
164165
"ResultsProcessorType",
166+
"SSEEventType",
165167
"SSEFieldType",
166168
"ServiceRegistrationStatus",
167169
"ServiceRunType",

aiperf/common/enums/sse_enums.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,9 @@ class SSEFieldType(CaseInsensitiveStrEnum):
1212
ID = "id"
1313
RETRY = "retry"
1414
COMMENT = "comment"
15+
16+
17+
class SSEEventType(CaseInsensitiveStrEnum):
    """Known values of the ``event`` field in an SSE message."""

    # Value of the ``event`` field that marks a message as an error report.
    ERROR = "error"

aiperf/common/exceptions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,14 @@ class ShutdownError(AIPerfError):
155155
"""Exception raised when a service encounters an error while shutting down."""
156156

157157

158+
class SSEResponseError(AIPerfError):
    """Raised when an SSE response stream carries an error message.

    Attributes:
        error_code: Integer code attached to the error. Defaults to 500 when
            the caller does not supply one.
    """

    def __init__(self, message: str, error_code: int = 500) -> None:
        """Store the error code, then initialise the base error with ``message``.

        Args:
            message: Human-readable description of the error.
            error_code: Integer code to expose as ``self.error_code``.
        """
        # Set the attribute before delegating so it is already available
        # while the base class initialiser runs.
        self.error_code = error_code
        super().__init__(message)
164+
165+
158166
class UnsupportedHookError(AIPerfError):
159167
"""Exception raised when a hook is defined on a class that does not have any base classes that provide that hook type."""
160168

aiperf/common/models/error_models.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,13 @@ def __hash__(self) -> int:
4141
@classmethod
def from_exception(cls, e: BaseException) -> "ErrorDetails":
    """Build an error details object describing the exception ``e``.

    The exception's class name becomes ``type`` and ``str(e)`` becomes
    ``message``. When the exception exposes an integer ``error_code``
    attribute, that value is copied into ``code``.
    """
    details = cls(
        type=e.__class__.__name__,
        message=str(e),
    )
    # Exceptions without an ``error_code`` attribute yield None here, which
    # fails the int check and leaves ``code`` untouched.
    code = getattr(e, "error_code", None)
    if isinstance(code, int):
        details.code = code
    return details
4851

4952

5053
class ExitErrorInfo(AIPerfBaseModel):

tests/clients/http/test_aiohttp_sse.py

Lines changed: 235 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,17 @@
77
################################################################################
88

99
import time
10+
from itertools import count
1011
from unittest.mock import AsyncMock, Mock, patch
1112

1213
import pytest
14+
from pytest import param
1315

14-
from aiperf.clients.http.aiohttp_client import AioHttpSSEStreamReader
16+
from aiperf.clients.http.aiohttp_client import AioHttpSSEStreamReader, parse_sse_message
17+
from aiperf.common.enums import SSEFieldType
18+
from aiperf.common.exceptions import SSEResponseError
1519
from aiperf.common.models import SSEMessage
20+
from aiperf.common.models.error_models import ErrorDetails
1621
from tests.clients.http.conftest import (
1722
create_sse_chunk_list,
1823
setup_single_sse_chunk,
@@ -231,3 +236,232 @@ async def test_malformed_sse_stream(self, mock_sse_response: Mock) -> None:
231236
with pytest.raises(Exception, match="Stream corruption"):
232237
async for _ in reader:
233238
pass
239+
240+
241+
@pytest.fixture
def stream_data_builder():
    """Fixture returning a builder for SSE stream data with auto-incrementing timestamps.

    The fixture is synchronous on purpose: it awaits nothing, so declaring it
    ``async`` would only make it depend on the async plugin's fixture handling.

    Returns:
        A ``build(*messages)`` callable. Each argument is a
        ``(message, timestamp_override)`` tuple; the result is a list of
        ``(message, timestamp)`` tuples in the same order.
    """
    timestamp_base = 123456789

    def build(*messages: tuple[str, int | None]) -> list[tuple[str, int]]:
        # A fresh counter per call: every build() starts again at the base.
        # The counter only advances for entries without an explicit override.
        ts_gen = count(timestamp_base)
        return [
            (msg, ts_override if ts_override is not None else next(ts_gen))
            for msg, ts_override in messages
        ]

    return build
254+
255+
256+
@pytest.fixture
def mock_sse_iterator(mock_sse_response: Mock):
    """Fixture that patches AioHttpSSEStreamReader.__aiter__ for testing.

    The fixture itself is synchronous (it awaits nothing); only the helper it
    returns is a coroutine function, which tests must ``await``.

    Returns:
        An async ``_read_stream(sse_data)`` callable. ``sse_data`` is a list of
        ``(raw_message, first_byte_ns)`` tuples that the patched ``__aiter__``
        yields in order; the helper returns whatever
        ``read_complete_stream()`` returns, or propagates what it raises.
    """
    # NOTE(review): ``mock_sse_response`` is requested but not used in the
    # body; kept so the fixture's dependencies stay unchanged - confirm
    # whether it can be dropped.

    async def _read_stream(sse_data: list[tuple[str, int]]) -> list[SSEMessage]:
        async def mock_aiter():
            for data in sse_data:
                yield data

        # The patch is applied on the class and is only active for the
        # duration of this single read.
        with patch.object(
            AioHttpSSEStreamReader, "__aiter__", return_value=mock_aiter()
        ):
            reader = AioHttpSSEStreamReader(Mock())
            return await reader.read_complete_stream()

    return _read_stream
272+
273+
274+
@pytest.mark.asyncio
class TestSSEErrorHandling:
    """Test suite for SSE error handling in AioHttpSSEStreamReader."""

    # Timestamp attached to every single-message stream used in these tests.
    TIMESTAMP = 123456789

    @pytest.mark.parametrize(
        "sse_message,expected_error_match",
        [
            param(
                "event: error\ncomment: Invalid request",
                "Invalid request",
                id="error_with_invalid_request",
            ),
            param(
                "event: error\ncomment: Server error",
                "Server error",
                id="error_with_server_error",
            ),
            param(
                "event: error\ncomment: Auth failed",
                "Auth failed",
                id="error_with_auth_failed",
            ),
            param(
                'event: error\ncomment: {"code": "E001", "msg": "API Error"}',
                "Error occurred in SSE response",
                id="error_with_json_message",
            ),
        ],
    )
    async def test_error_with_comment_message(
        self,
        sse_message: str,
        expected_error_match: str,
        mock_sse_iterator,
    ) -> None:
        """Test SSE error with event: error followed by comment raises error."""
        with pytest.raises(SSEResponseError, match=expected_error_match):
            await mock_sse_iterator([(sse_message, self.TIMESTAMP)])

    @pytest.mark.parametrize(
        "sse_message",
        [
            param("event: error", id="error_only"),
            param("event: error\ndata: error details", id="error_with_data"),
            param("event: error\nid: err-123", id="error_with_id"),
        ],
    )
    async def test_error_without_comment_field(
        self, sse_message: str, mock_sse_iterator
    ) -> None:
        """Test SSE error without comment field raises with unknown error message."""
        with pytest.raises(SSEResponseError, match="Unknown error"):
            await mock_sse_iterator([(sse_message, self.TIMESTAMP)])

    @pytest.mark.parametrize(
        "error_variant",
        [
            param("error", id="lowercase"),
            param("ERROR", id="uppercase"),
            param("Error", id="mixed_case"),
            param("eRrOr", id="random_case"),
        ],
    )
    async def test_error_case_insensitive(
        self, error_variant: str, mock_sse_iterator
    ) -> None:
        """Test that event: error is case-insensitive."""
        # The variant must go into the event field itself. With a hard-coded
        # lowercase "error" every parametrized case would send the same event
        # and the case-insensitive comparison would never be exercised.
        sse_message = (
            f"event: {error_variant}\ncomment: Test error from {error_variant}"
        )
        with pytest.raises(SSEResponseError, match="Test error"):
            await mock_sse_iterator([(sse_message, self.TIMESTAMP)])

    async def test_error_in_middle_of_stream(
        self, stream_data_builder, mock_sse_iterator
    ) -> None:
        """Test that error in the middle of stream raises immediately."""
        sse_data = stream_data_builder(
            ("data: First message", None),
            ("event: error\ncomment: Stream error", None),
            ("data: This should not be processed", None),
        )

        with pytest.raises(SSEResponseError, match="Stream error"):
            await mock_sse_iterator(sse_data)

    async def test_error_code_is_502(self, mock_sse_iterator) -> None:
        """Test that SSEResponseError has correct error code."""
        with pytest.raises(SSEResponseError) as exc_info:
            await mock_sse_iterator([("event: error\ncomment: Test", self.TIMESTAMP)])
        assert exc_info.value.error_code == 502

    async def test_error_details_preserved_with_502_code(
        self, mock_sse_iterator
    ) -> None:
        """Test that error details are properly preserved when SSEResponseError with 502 is converted."""
        error_message = "SSE stream error occurred"
        with pytest.raises(SSEResponseError) as exc_info:
            await mock_sse_iterator(
                [("event: error\ncomment: " + error_message, self.TIMESTAMP)]
            )

        exc = exc_info.value
        error_details = ErrorDetails.from_exception(exc)

        assert error_details.code == 502
        assert error_details.type == "SSEResponseError"
        assert error_message in error_details.message

    @pytest.mark.parametrize(
        "event_type",
        [
            param("message", id="event_message"),
            param("update", id="event_update"),
            param("completion", id="event_completion"),
            param("chunk", id="event_chunk"),
            param("ping", id="event_ping"),
        ],
    )
    async def test_normal_events_not_treated_as_errors(self, event_type: str) -> None:
        """Test that normal event types are not treated as errors."""
        raw_message = f"event: {event_type}\ndata: some data"
        result = parse_sse_message(raw_message, self.TIMESTAMP)

        assert result.packets[0].name == SSEFieldType.EVENT
        assert result.packets[0].value == event_type

    async def test_error_with_very_long_message(self, mock_sse_iterator) -> None:
        """Test SSE error with very long error message (>1000 chars)."""
        long_message = "x" * 5000
        # Only a prefix is used as the regex so the pattern stays small.
        with pytest.raises(SSEResponseError, match=long_message[:100]):
            await mock_sse_iterator(
                [("event: error\ncomment: " + long_message, self.TIMESTAMP)]
            )

    @pytest.mark.parametrize(
        "special_chars",
        [
            param("Error with 'quotes'", id="with_single_quotes"),
            param('Error with "double quotes"', id="with_double_quotes"),
            param("Error: with: multiple: colons", id="with_colons"),
            param("Error\twith\ttabs", id="with_tabs"),
            param("Error with émojis 🚀 💻", id="with_emojis"),
            param("Error with <html> tags", id="with_html_tags"),
            param("Error with | pipe | symbols", id="with_pipes"),
        ],
    )
    async def test_error_with_special_characters(
        self, special_chars: str, mock_sse_iterator
    ) -> None:
        """Test SSE error with special characters in message."""
        with pytest.raises(SSEResponseError):
            await mock_sse_iterator(
                [("event: error\ncomment: " + special_chars, self.TIMESTAMP)]
            )

    async def test_error_with_only_whitespace_comment(self, mock_sse_iterator) -> None:
        """Test SSE error with only whitespace in comment field."""
        with pytest.raises(SSEResponseError):
            await mock_sse_iterator([("event: error\ncomment: ", self.TIMESTAMP)])

    async def test_multiple_errors_in_sequence(
        self, stream_data_builder, mock_sse_iterator
    ) -> None:
        """Test that first error in sequence is raised immediately."""
        sse_data = stream_data_builder(
            ("event: error\ncomment: First error", None),
            ("event: error\ncomment: Second error", None),
        )

        with pytest.raises(SSEResponseError, match="First error"):
            await mock_sse_iterator(sse_data)

    async def test_error_with_mixed_additional_fields(self, mock_sse_iterator) -> None:
        """Test SSE error with multiple additional fields mixed in."""
        # NOTE(review): the reader only treats the packet directly after
        # ``event: error`` as the comment. Here that packet is ``id``, so this
        # presumably passes through the "Unknown error" fallback, whose JSON
        # dump of the message contains the comment text - confirm that this is
        # the coverage intended.
        sse_message = (
            "event: error\n"
            "id: err-001\n"
            "comment: Mixed fields error\n"
            "retry: 3000\n"
            "data: extra data"
        )
        with pytest.raises(SSEResponseError, match="Mixed fields error"):
            await mock_sse_iterator([(sse_message, self.TIMESTAMP)])

    async def test_error_with_special_escape_sequences(self) -> None:
        """Test parsing error message with escape sequences in raw format."""
        # The backslashes are literal characters, not real newlines, so the
        # comment must come back from the parser unchanged.
        raw_message = "event: error\ncomment: Error\\nwith\\nnewlines"
        result = parse_sse_message(raw_message, self.TIMESTAMP)

        assert result.packets[0].name == SSEFieldType.EVENT
        assert result.packets[0].value == "error"
        assert result.packets[1].name == SSEFieldType.COMMENT
        assert result.packets[1].value == "Error\\nwith\\nnewlines"

0 commit comments

Comments
 (0)