Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions chatkit/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import base64
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from contextlib import contextmanager
Expand Down Expand Up @@ -41,6 +42,7 @@
ErrorEvent,
FeedbackKind,
HiddenContextItem,
InputTranscribeReq,
ItemsFeedbackReq,
ItemsListReq,
NonStreamingReq,
Expand Down Expand Up @@ -69,6 +71,7 @@
ThreadStreamEvent,
ThreadsUpdateReq,
ThreadUpdatedEvent,
TranscriptionResult,
UserMessageInput,
UserMessageItem,
WidgetComponentUpdated,
Expand Down Expand Up @@ -319,6 +322,14 @@ async def add_feedback( # noqa: B027
"""Persist user feedback for one or more thread items."""
pass

async def transcribe(
    self, audio_bytes: bytes, mime_type: str, context: TContext
) -> TranscriptionResult:
    """Transcribe speech audio to text.

    Override this method to support dictation. The default implementation
    always raises, so a server that does not override it rejects
    ``input.transcribe`` requests.

    Args:
        audio_bytes: Raw audio bytes (already decoded from base64 by the
            request handler).
        mime_type: MIME type reported for the audio payload.
        context: Per-request context forwarded from the request handler.

    Returns:
        The transcription result produced by the overriding implementation.

    Raises:
        NotImplementedError: Always, unless the method is overridden.
    """
    # No lint suppression needed here: the body raises, so it is not an
    # "empty method" in the sense of flake8-bugbear B027.
    raise NotImplementedError(
        "transcribe() must be overridden to support the input.transcribe request."
    )

def action(
self,
thread: ThreadMetadata,
Expand Down Expand Up @@ -446,6 +457,12 @@ async def _process_non_streaming(
request.params.attachment_id, context=context
)
return b"{}"
case InputTranscribeReq():
audio_bytes = base64.b64decode(request.params.audio_base64)
transcription_result = await self.transcribe(
audio_bytes, request.params.mime_type, context=context
)
return self._serialize(transcription_result)
case ItemsListReq():
items_list_params = request.params
items = await self.store.load_thread_items(
Expand Down
24 changes: 24 additions & 0 deletions chatkit/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,29 @@ class AttachmentCreateParams(BaseModel):
mime_type: str


class InputTranscribeReq(BaseReq):
    """Request to transcribe an audio payload into text."""

    # Literal tag for this request kind; defaults to "input.transcribe".
    type: Literal["input.transcribe"] = "input.transcribe"
    # Payload of the request: the base64-encoded audio and its MIME type.
    params: InputTranscribeParams


class InputTranscribeParams(BaseModel):
    """Parameters for speech transcription.

    Carries the audio as a base64 string together with the MIME type of the
    encoded payload.
    """

    audio_base64: str
    """Base64-encoded audio bytes."""

    mime_type: str
    """MIME type for the audio payload (e.g. 'audio/webm', 'audio/wav')."""


class TranscriptionResult(BaseModel):
    """Input speech transcription result."""

    # The transcribed text.
    text: str


class ItemsListReq(BaseReq):
"""Request to list items inside a thread."""

Expand Down Expand Up @@ -236,6 +259,7 @@ class ThreadDeleteParams(BaseModel):
| AttachmentsDeleteReq
| ThreadsUpdateReq
| ThreadsDeleteReq
| InputTranscribeReq
)
"""Union of request types that yield immediate responses."""

Expand Down
42 changes: 42 additions & 0 deletions tests/test_chatkit_server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import base64
import sqlite3
from contextlib import contextmanager
from datetime import datetime
Expand Down Expand Up @@ -38,6 +39,8 @@
FileAttachment,
ImageAttachment,
InferenceOptions,
InputTranscribeParams,
InputTranscribeReq,
ItemFeedbackParams,
ItemsFeedbackReq,
ItemsListParams,
Expand Down Expand Up @@ -75,6 +78,7 @@
ThreadUpdatedEvent,
ThreadUpdateParams,
ToolChoice,
TranscriptionResult,
UserMessageInput,
UserMessageItem,
UserMessageTextContent,
Expand Down Expand Up @@ -156,6 +160,7 @@ def make_server(
]
| None = None,
file_store: AttachmentStore | None = None,
transcribe_callback: Callable[[bytes, str, Any], TranscriptionResult] | None = None,
):
global server_id
db_path = f"file:{server_id}?mode=memory&cache=shared"
Expand Down Expand Up @@ -203,6 +208,13 @@ async def add_feedback(
return
handle_feedback(thread_id, item_ids, feedback, context)

async def transcribe(
    self, audio_bytes: bytes, mime_type: str, context: Any
) -> TranscriptionResult:
    """Delegate to the injected callback, or to the base class when none is set."""
    if transcribe_callback is not None:
        # The test supplied a hook: call it synchronously with the raw arguments.
        return transcribe_callback(audio_bytes, mime_type, context)
    # No hook supplied: fall through to the base implementation.
    return await super().transcribe(audio_bytes, mime_type, context)

async def process_streaming(
self, request_obj, context: Any | None = None
) -> list[ThreadStreamEvent]:
Expand Down Expand Up @@ -1843,6 +1855,36 @@ async def responder(
assert any(e.type == "thread.item.done" for e in events)


async def test_input_transcribe_decodes_base64_and_passes_mime_type():
    """An input.transcribe request hands decoded bytes, MIME type and context to transcribe()."""
    raw_audio = b"hello audio"
    encoded_audio = base64.b64encode(raw_audio).decode("ascii")
    seen: dict[str, Any] = {}

    def record_call(audio: bytes, mime: str, context: Any) -> TranscriptionResult:
        # Capture every argument so the assertions below can inspect them.
        seen.update(audio=audio, mime=mime, context=context)
        return TranscriptionResult(text="ok")

    request = InputTranscribeReq(
        params=InputTranscribeParams(
            audio_base64=encoded_audio,
            mime_type="audio/wav",
        )
    )

    with make_server(transcribe_callback=record_call) as server:
        result = await server.process_non_streaming(request)
        parsed = TypeAdapter(TranscriptionResult).validate_json(result.json)
        assert parsed.text == "ok"

    assert seen["audio"] == raw_audio
    assert seen["mime"] == "audio/wav"
    assert seen["context"] == DEFAULT_CONTEXT


async def test_retry_after_item_passes_tools_to_responder():
pass

Expand Down