Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 91 additions & 60 deletions chatkit/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@
from openai.types.responses.response_output_text import (
Annotation as ResponsesAnnotation,
)
from openai.types.responses.response_output_text import (
AnnotationContainerFileCitation,
AnnotationFileCitation,
AnnotationURLCitation,
)
from pydantic import BaseModel, ConfigDict, SkipValidation, TypeAdapter
from typing_extensions import assert_never

Expand Down Expand Up @@ -220,62 +225,6 @@ def _complete(self):
self._events.put_nowait(_QueueCompleteSentinel())


def _convert_content(content: Content) -> AssistantMessageContent:
    """Build an ``AssistantMessageContent`` from a Responses content part.

    An ``output_text`` part keeps its text together with every annotation that
    ``_convert_annotation`` could translate; annotations it maps to ``None``
    are left out. Any other part is read through its ``refusal`` attribute and
    carries no annotations.
    """
    if content.type != "output_text":
        return AssistantMessageContent(text=content.refusal, annotations=[])

    converted = []
    for raw in content.annotations:
        candidate = _convert_annotation(raw)
        if candidate is not None:
            converted.append(candidate)

    return AssistantMessageContent(text=content.text, annotations=converted)


# Built once at import time: every TypeAdapter instantiation constructs a new
# validator, and this conversion runs for each annotation that is processed.
_RESPONSES_ANNOTATION_ADAPTER = TypeAdapter[ResponsesAnnotation](ResponsesAnnotation)


def _convert_annotation(raw_annotation: object) -> Annotation | None:
    """Translate a Responses API annotation into an ``Annotation``.

    Args:
        raw_annotation: The annotation as delivered by the client. It is
            re-validated against ``ResponsesAnnotation`` before use, so a dict
            or an untyped object is accepted as well as a typed instance.

    Returns:
        The converted annotation, or ``None`` when a file or container-file
        citation has no filename or the annotation type is not one of the
        three handled below.

    Raises:
        pydantic.ValidationError: If ``raw_annotation`` cannot be validated as
            a ``ResponsesAnnotation``.
    """
    # There is a bug in the OpenAI client that sometimes parses the annotation
    # delta event into the wrong class, resulting in the annotation being a
    # dict or untyped object instead of a ResponsesAnnotation.
    annotation = _RESPONSES_ANNOTATION_ADAPTER.validate_python(raw_annotation)

    if annotation.type == "url_citation":
        return Annotation(
            source=URLSource(
                url=annotation.url,
                title=annotation.title,
            ),
            index=annotation.end_index,
        )

    # Both file-based citation kinds share the same shape; they only differ in
    # which attribute carries the position within the text.
    if annotation.type == "file_citation":
        index = annotation.index
    elif annotation.type == "container_file_citation":
        index = annotation.end_index
    else:
        return None

    filename = annotation.filename
    if not filename:
        # A file source cannot be built without a name, so drop the annotation.
        return None

    return Annotation(
        source=FileSource(filename=filename, title=filename),
        index=index,
    )


# Generic type parameters.
# NOTE(review): their users are outside the visible hunk — presumably the
# generic helpers defined right after this point; confirm.
T1 = TypeVar("T1")
T2 = TypeVar("T2")

Expand Down Expand Up @@ -425,10 +374,86 @@ def partial_image_index_to_progress(self, partial_image_index: int) -> float:

return min(1.0, partial_image_index / self.partial_images)

async def file_citation_to_annotation(
    self, file_citation: AnnotationFileCitation
) -> Annotation | None:
    """Turn a Responses API file citation into an assistant message annotation.

    Returns ``None`` when the citation has no filename; otherwise the filename
    serves as both the source's filename and its title, and the citation's
    ``index`` becomes the annotation index.
    """
    cited_name = file_citation.filename
    return (
        Annotation(
            source=FileSource(filename=cited_name, title=cited_name),
            index=file_citation.index,
        )
        if cited_name
        else None
    )

async def container_file_citation_to_annotation(
    self, container_file_citation: AnnotationContainerFileCitation
) -> Annotation | None:
    """Turn a Responses API container file citation into an assistant message annotation.

    Returns ``None`` when the citation has no filename; otherwise the filename
    serves as both the source's filename and its title, and the citation's
    ``end_index`` becomes the annotation index.
    """
    name = container_file_citation.filename
    if name:
        file_source = FileSource(filename=name, title=name)
        return Annotation(
            source=file_source,
            index=container_file_citation.end_index,
        )
    return None

async def url_citation_to_annotation(
    self, url_citation: AnnotationURLCitation
) -> Annotation | None:
    """Turn a Responses API URL citation into an assistant message annotation.

    The citation's ``url`` and ``title`` populate the source, and its
    ``end_index`` becomes the annotation index.
    """
    link_source = URLSource(url=url_citation.url, title=url_citation.title)
    return Annotation(source=link_source, index=url_citation.end_index)


# Module-level ResponseStreamConverter built with no constructor arguments.
# NOTE(review): presumably the fallback used when a caller does not supply its
# own converter — confirm against stream_agent_response.
_DEFAULT_RESPONSE_STREAM_CONVERTER = ResponseStreamConverter()


async def _convert_content(
    content: Content, converter: ResponseStreamConverter
) -> AssistantMessageContent:
    """Build an ``AssistantMessageContent`` from a Responses content part.

    An ``output_text`` part keeps its text together with every annotation that
    ``_convert_annotation`` (given ``converter``) could translate; annotations
    it maps to ``None`` are left out. Any other part is read through its
    ``refusal`` attribute and carries no annotations.
    """
    if content.type != "output_text":
        return AssistantMessageContent(text=content.refusal, annotations=[])

    converted = []
    # Annotations are awaited one at a time, in their original order.
    for raw in content.annotations:
        candidate = await _convert_annotation(raw, converter)
        if candidate is not None:
            converted.append(candidate)

    return AssistantMessageContent(text=content.text, annotations=converted)


# Built once at import time: every TypeAdapter instantiation constructs a new
# validator, and this conversion runs for each annotation that is processed.
_RESPONSES_ANNOTATION_ADAPTER = TypeAdapter[ResponsesAnnotation](ResponsesAnnotation)


async def _convert_annotation(
    raw_annotation: object, converter: ResponseStreamConverter
) -> Annotation | None:
    """Translate a Responses API annotation by dispatching to ``converter``.

    Args:
        raw_annotation: The annotation as delivered by the client. It is
            re-validated against ``ResponsesAnnotation`` before use, so a dict
            or an untyped object is accepted as well as a typed instance.
        converter: Supplies the per-citation-type conversion coroutines.

    Returns:
        Whatever the matching converter method returns, or ``None`` when the
        annotation type is not one of the three handled below.

    Raises:
        pydantic.ValidationError: If ``raw_annotation`` cannot be validated as
            a ``ResponsesAnnotation``.
    """
    # There is a bug in the OpenAI client that sometimes parses the annotation
    # delta event into the wrong class, resulting in the annotation being a
    # dict or untyped object instead of a ResponsesAnnotation.
    annotation = _RESPONSES_ANNOTATION_ADAPTER.validate_python(raw_annotation)

    if annotation.type == "file_citation":
        return await converter.file_citation_to_annotation(annotation)

    if annotation.type == "url_citation":
        return await converter.url_citation_to_annotation(annotation)

    if annotation.type == "container_file_citation":
        return await converter.container_file_citation_to_annotation(annotation)

    return None


async def stream_agent_response(
context: AgentContext,
result: RunResultStreaming,
Expand Down Expand Up @@ -546,7 +571,7 @@ def end_workflow(item: WorkflowItem):
if event.type == "response.content_part.added":
if event.part.type == "reasoning_text":
continue
content = _convert_content(event.part)
content = await _convert_content(event.part, converter)
yield ThreadItemUpdatedEvent(
item_id=event.item_id,
update=AssistantMessageContentPartAdded(
Expand Down Expand Up @@ -574,7 +599,7 @@ def end_workflow(item: WorkflowItem):
),
)
elif event.type == "response.output_text.annotation.added":
annotation = _convert_annotation(event.annotation)
annotation = await _convert_annotation(event.annotation, converter)
if annotation:
# Manually track annotation indices per content part in case we drop an annotation that
# we can't convert to our internal representation (e.g. missing filename).
Expand Down Expand Up @@ -613,7 +638,10 @@ def end_workflow(item: WorkflowItem):
# Reusing the Responses message ID
id=item.id,
thread_id=thread.id,
content=[_convert_content(c) for c in item.content],
content=[
await _convert_content(c, converter)
for c in item.content
],
created_at=datetime.now(),
),
)
Expand Down Expand Up @@ -722,7 +750,10 @@ def end_workflow(item: WorkflowItem):
# Reusing the Responses message ID
id=item.id,
thread_id=thread.id,
content=[_convert_content(c) for c in item.content],
content=[
await _convert_content(c, converter)
for c in item.content
],
created_at=datetime.now(),
),
)
Expand Down
Loading