diff --git a/chatkit/agents.py b/chatkit/agents.py index fbb4446..82c08e7 100644 --- a/chatkit/agents.py +++ b/chatkit/agents.py @@ -38,6 +38,11 @@ from openai.types.responses.response_output_text import ( Annotation as ResponsesAnnotation, ) +from openai.types.responses.response_output_text import ( + AnnotationContainerFileCitation, + AnnotationFileCitation, + AnnotationURLCitation, +) from pydantic import BaseModel, ConfigDict, SkipValidation, TypeAdapter from typing_extensions import assert_never @@ -220,62 +225,6 @@ def _complete(self): self._events.put_nowait(_QueueCompleteSentinel()) -def _convert_content(content: Content) -> AssistantMessageContent: - if content.type == "output_text": - annotations = [ - _convert_annotation(annotation) for annotation in content.annotations - ] - annotations = [a for a in annotations if a is not None] - return AssistantMessageContent( - text=content.text, - annotations=annotations, - ) - else: - return AssistantMessageContent( - text=content.refusal, - annotations=[], - ) - - -def _convert_annotation(raw_annotation: object) -> Annotation | None: - # There is a bug in the OpenAPI client that sometimes parses the annotation delta event into the wrong class - # resulting into annotation being a dict or untyped object instead instead of a ResponsesAnnotation - annotation = TypeAdapter[ResponsesAnnotation](ResponsesAnnotation).validate_python( - raw_annotation - ) - - if annotation.type == "file_citation": - filename = annotation.filename - if not filename: - return None - - return Annotation( - source=FileSource(filename=filename, title=filename), - index=annotation.index, - ) - - if annotation.type == "url_citation": - return Annotation( - source=URLSource( - url=annotation.url, - title=annotation.title, - ), - index=annotation.end_index, - ) - - if annotation.type == "container_file_citation": - filename = annotation.filename - if not filename: - return None - - return Annotation( - source=FileSource(filename=filename, title=filename), - index=annotation.end_index, - ) - - return None - - T1 = TypeVar("T1") T2 = TypeVar("T2") @@ -425,10 +374,86 @@ def partial_image_index_to_progress(self, partial_image_index: int) -> float: return min(1.0, partial_image_index / self.partial_images) + async def file_citation_to_annotation( + self, file_citation: AnnotationFileCitation + ) -> Annotation | None: + """Convert a Responses API file citation into an assistant message annotation.""" + filename = file_citation.filename + if not filename: + return None + + return Annotation( + source=FileSource(filename=filename, title=filename), + index=file_citation.index, + ) + + async def container_file_citation_to_annotation( + self, container_file_citation: AnnotationContainerFileCitation + ) -> Annotation | None: + """Convert a Responses API container file citation into an assistant message annotation.""" + filename = container_file_citation.filename + if not filename: + return None + + return Annotation( + source=FileSource(filename=filename, title=filename), + index=container_file_citation.end_index, + ) + + async def url_citation_to_annotation( + self, url_citation: AnnotationURLCitation + ) -> Annotation | None: + """Convert a Responses API URL citation into an assistant message annotation.""" + return Annotation( + source=URLSource(url=url_citation.url, title=url_citation.title), + index=url_citation.end_index, + ) + _DEFAULT_RESPONSE_STREAM_CONVERTER = ResponseStreamConverter() +async def _convert_content( + content: Content, converter: ResponseStreamConverter +) -> AssistantMessageContent: + if content.type == "output_text": + annotations = [ + await _convert_annotation(annotation, converter) + for annotation in content.annotations + ] + annotations = [a for a in annotations if a is not None] + return AssistantMessageContent( + text=content.text, + annotations=annotations, + ) + else: + return AssistantMessageContent( + text=content.refusal, + annotations=[], + ) + + +async def _convert_annotation( + raw_annotation: object, converter: ResponseStreamConverter +) -> Annotation | None: + # There is a bug in the OpenAPI client that sometimes parses the annotation delta event into the wrong class + # resulting into annotation being a dict or untyped object instead instead of a ResponsesAnnotation + annotation = TypeAdapter[ResponsesAnnotation](ResponsesAnnotation).validate_python( + raw_annotation + ) + + if annotation.type == "file_citation": + return await converter.file_citation_to_annotation(annotation) + + if annotation.type == "url_citation": + return await converter.url_citation_to_annotation(annotation) + + if annotation.type == "container_file_citation": + return await converter.container_file_citation_to_annotation(annotation) + + return None + + async def stream_agent_response( context: AgentContext, result: RunResultStreaming, @@ -546,7 +571,7 @@ def end_workflow(item: WorkflowItem): if event.type == "response.content_part.added": if event.part.type == "reasoning_text": continue - content = _convert_content(event.part) + content = await _convert_content(event.part, converter) yield ThreadItemUpdatedEvent( item_id=event.item_id, update=AssistantMessageContentPartAdded( @@ -574,7 +599,7 @@ def end_workflow(item: WorkflowItem): ), ) elif event.type == "response.output_text.annotation.added": - annotation = _convert_annotation(event.annotation) + annotation = await _convert_annotation(event.annotation, converter) if annotation: # Manually track annotation indices per content part in case we drop an annotation that # we can't convert to our internal representation (e.g. missing filename). @@ -613,7 +638,10 @@ def end_workflow(item: WorkflowItem): # Reusing the Responses message ID id=item.id, thread_id=thread.id, - content=[_convert_content(c) for c in item.content], + content=[ + await _convert_content(c, converter) + for c in item.content + ], created_at=datetime.now(), ), ) @@ -722,7 +750,10 @@ def end_workflow(item: WorkflowItem): # Reusing the Responses message ID id=item.id, thread_id=thread.id, - content=[_convert_content(c) for c in item.content], + content=[ + await _convert_content(c, converter) + for c in item.content + ], created_at=datetime.now(), ), ) diff --git a/tests/test_agents.py b/tests/test_agents.py index 7cdcfa9..1dfe81d 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -1020,6 +1020,213 @@ def add_annotation_event(annotation, sequence_number): ] +async def test_stream_agent_response_annotation_added_normalizes_annotations(): + context = AgentContext( + previous_response_id=None, thread=thread, store=mock_store, request_context=None + ) + result = make_result() + item_id = "item_123" + + def add_annotation_event(annotation, sequence_number): + result.add_event( + RawResponsesStreamEvent( + type="raw_response_event", + data=Mock( + type="response.output_text.annotation.added", + annotation=annotation, + content_index=0, + item_id=item_id, + annotation_index=sequence_number, + output_index=0, + sequence_number=sequence_number, + ), + ) + ) + + # Invalid file citation is dropped (empty filename) + add_annotation_event( + { + "type": "file_citation", + "file_id": "file_invalid", + "filename": "", + "index": 0, + }, + sequence_number=0, + ) + + # Container + URL citations are converted; indices are compacted (0, 1) despite the dropped first event. + add_annotation_event( + { + "type": "container_file_citation", + "container_id": "container_1", + "file_id": "file_123", + "filename": "container.txt", + "start_index": 0, + "end_index": 3, + }, + sequence_number=1, + ) + add_annotation_event( + { + "type": "url_citation", + "url": "https://example.com", + "title": "Example", + "start_index": 1, + "end_index": 5, + }, + sequence_number=2, + ) + + result.done() + + events = await all_events(stream_agent_response(context, result)) + assert events == [ + ThreadItemUpdatedEvent( + item_id=item_id, + update=AssistantMessageContentPartAnnotationAdded( + content_index=0, + annotation_index=0, + annotation=Annotation( + source=FileSource(filename="container.txt", title="container.txt"), + index=3, + ), + ), + ), + ThreadItemUpdatedEvent( + item_id=item_id, + update=AssistantMessageContentPartAnnotationAdded( + content_index=0, + annotation_index=1, + annotation=Annotation( + source=URLSource(url="https://example.com", title="Example"), + index=5, + ), + ), + ), + ] + + +async def test_custom_annotation_conversion_used_by_stream_agent_response(): + context = AgentContext( + previous_response_id=None, thread=thread, store=mock_store, request_context=None + ) + result = make_result() + + class CustomResponseStreamConverter(ResponseStreamConverter): + def __init__(self): + super().__init__() + self.calls: list[str] = [] + + async def file_citation_to_annotation(self, file_citation): + self.calls.append("file_citation") + return Annotation( + source=FileSource( + filename="report.pdf", + title="Usage Report", + description="Usage report for the month of January", + ), + index=111, + ) + + async def url_citation_to_annotation(self, url_citation): + self.calls.append("url_citation") + return Annotation( + source=URLSource(url="https://custom.example/url", title="Custom"), + index=222, + ) + + converter = CustomResponseStreamConverter() + + result.add_event( + RawResponsesStreamEvent( + type="raw_response_event", + data=Mock( + type="response.output_text.annotation.added", + annotation=ResponsesAnnotationFileCitation( + type="file_citation", + file_id="file_123", + filename="file.txt", + index=0, + ), + content_index=0, + item_id="m_1", + annotation_index=0, + output_index=0, + sequence_number=0, + ), + ) + ) + + result.add_event( + RawResponsesStreamEvent( + type="raw_response_event", + data=ResponseOutputItemDoneEvent( + type="response.output_item.done", + item=ResponseOutputMessage( + id="m_1", + content=[ + ResponseOutputText( + annotations=[ + ResponsesAnnotationURLCitation( + type="url_citation", + url="https://example.com", + title="Example", + start_index=0, + end_index=7, + ) + ], + text="Hello!", + type="output_text", + ) + ], + role="assistant", + status="completed", + type="message", + ), + output_index=0, + sequence_number=1, + ), + ) + ) + result.done() + + events = await all_events( + stream_agent_response(context, result, converter=converter) + ) + assert len(events) == 2 + + assert isinstance(events[0], ThreadItemUpdatedEvent) + assert events[0].item_id == "m_1" + assert events[0].update == AssistantMessageContentPartAnnotationAdded( + content_index=0, + annotation_index=0, + annotation=Annotation( + source=FileSource( + filename="report.pdf", + title="Usage Report", + description="Usage report for the month of January", + ), + index=111, + ), + ) + + assert isinstance(events[1], ThreadItemDoneEvent) + assert isinstance(events[1].item, AssistantMessageItem) + assert events[1].item.id == "m_1" + assert events[1].item.content == [ + AssistantMessageContent( + text="Hello!", + annotations=[ + Annotation( + source=URLSource(url="https://custom.example/url", title="Custom"), + index=222, + ) + ], + ) + ] + assert converter.calls == ["file_citation", "url_citation"] + + @pytest.mark.parametrize("throw_guardrail", ["input", "output"]) async def test_stream_agent_response_yields_item_removed_event(throw_guardrail): context = AgentContext(