|
1 | 1 | import unittest.mock |
| 2 | +from typing import cast |
2 | 3 |
|
3 | 4 | import pytest |
4 | 5 |
|
5 | 6 | import strands |
6 | 7 | import strands.event_loop |
7 | | -from strands.types._events import TypedEvent |
| 8 | +from strands.types._events import ModelStopReason, TypedEvent |
| 9 | +from strands.types.content import Message |
8 | 10 | from strands.types.streaming import ( |
9 | 11 | ContentBlockDeltaEvent, |
10 | 12 | ContentBlockStartEvent, |
@@ -565,6 +567,88 @@ async def test_process_stream(response, exp_events, agenerator, alist): |
565 | 567 | assert non_typed_events == [] |
566 | 568 |
|
567 | 569 |
|
| 570 | +def _get_message_from_event(event: ModelStopReason) -> Message: |
| 571 | + return cast(Message, event["stop"][1]) |
| 572 | + |
| 573 | + |
| 574 | +@pytest.mark.asyncio |
| 575 | +async def test_process_stream_with_no_signature(agenerator, alist): |
| 576 | + response = [ |
| 577 | + {"messageStart": {"role": "assistant"}}, |
| 578 | + { |
| 579 | + "contentBlockDelta": { |
| 580 | + "delta": {"reasoningContent": {"text": 'User asks: "Reason about 2+2" so I will do that'}}, |
| 581 | + "contentBlockIndex": 0, |
| 582 | + } |
| 583 | + }, |
| 584 | + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "."}}, "contentBlockIndex": 0}}, |
| 585 | + {"contentBlockStop": {"contentBlockIndex": 0}}, |
| 586 | + { |
| 587 | + "contentBlockDelta": { |
| 588 | + "delta": {"text": "Sure! Let’s do it"}, |
| 589 | + "contentBlockIndex": 1, |
| 590 | + } |
| 591 | + }, |
| 592 | + {"contentBlockStop": {"contentBlockIndex": 1}}, |
| 593 | + {"messageStop": {"stopReason": "end_turn"}}, |
| 594 | + { |
| 595 | + "metadata": { |
| 596 | + "usage": {"inputTokens": 112, "outputTokens": 764, "totalTokens": 876}, |
| 597 | + "metrics": {"latencyMs": 2970}, |
| 598 | + } |
| 599 | + }, |
| 600 | + ] |
| 601 | + |
| 602 | + stream = strands.event_loop.streaming.process_stream(agenerator(response)) |
| 603 | + |
| 604 | + last_event = cast(ModelStopReason, (await alist(stream))[-1]) |
| 605 | + |
| 606 | + message = _get_message_from_event(last_event) |
| 607 | + |
| 608 | + assert "signature" not in message["content"][0]["reasoningContent"]["reasoningText"] |
| 609 | + assert message["content"][1]["text"] == "Sure! Let’s do it" |
| 610 | + |
| 611 | + |
| 612 | +@pytest.mark.asyncio |
| 613 | +async def test_process_stream_with_signature(agenerator, alist): |
| 614 | + response = [ |
| 615 | + {"messageStart": {"role": "assistant"}}, |
| 616 | + { |
| 617 | + "contentBlockDelta": { |
| 618 | + "delta": {"reasoningContent": {"text": 'User asks: "Reason about 2+2" so I will do that'}}, |
| 619 | + "contentBlockIndex": 0, |
| 620 | + } |
| 621 | + }, |
| 622 | + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "."}}, "contentBlockIndex": 0}}, |
| 623 | + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "test-"}}, "contentBlockIndex": 0}}, |
| 624 | + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "signature"}}, "contentBlockIndex": 0}}, |
| 625 | + {"contentBlockStop": {"contentBlockIndex": 0}}, |
| 626 | + { |
| 627 | + "contentBlockDelta": { |
| 628 | + "delta": {"text": "Sure! Let’s do it"}, |
| 629 | + "contentBlockIndex": 1, |
| 630 | + } |
| 631 | + }, |
| 632 | + {"contentBlockStop": {"contentBlockIndex": 1}}, |
| 633 | + {"messageStop": {"stopReason": "end_turn"}}, |
| 634 | + { |
| 635 | + "metadata": { |
| 636 | + "usage": {"inputTokens": 112, "outputTokens": 764, "totalTokens": 876}, |
| 637 | + "metrics": {"latencyMs": 2970}, |
| 638 | + } |
| 639 | + }, |
| 640 | + ] |
| 641 | + |
| 642 | + stream = strands.event_loop.streaming.process_stream(agenerator(response)) |
| 643 | + |
| 644 | + last_event = cast(ModelStopReason, (await alist(stream))[-1]) |
| 645 | + |
| 646 | + message = _get_message_from_event(last_event) |
| 647 | + |
| 648 | + assert message["content"][0]["reasoningContent"]["reasoningText"]["signature"] == "test-signature" |
| 649 | + assert message["content"][1]["text"] == "Sure! Let’s do it" |
| 650 | + |
| 651 | + |
568 | 652 | @pytest.mark.asyncio |
569 | 653 | async def test_stream_messages(agenerator, alist): |
570 | 654 | mock_model = unittest.mock.MagicMock() |
|
0 commit comments