diff --git a/chatkit/server.py b/chatkit/server.py index 12f13ca..4ad02ad 100644 --- a/chatkit/server.py +++ b/chatkit/server.py @@ -1,4 +1,5 @@ import asyncio +import base64 from abc import ABC, abstractmethod from collections.abc import AsyncIterator from contextlib import contextmanager @@ -41,6 +42,7 @@ ErrorEvent, FeedbackKind, HiddenContextItem, + InputTranscribeReq, ItemsFeedbackReq, ItemsListReq, NonStreamingReq, @@ -69,6 +71,7 @@ ThreadStreamEvent, ThreadsUpdateReq, ThreadUpdatedEvent, + TranscriptionResult, UserMessageInput, UserMessageItem, WidgetComponentUpdated, @@ -319,6 +322,14 @@ async def add_feedback( # noqa: B027 """Persist user feedback for one or more thread items.""" pass + async def transcribe( # noqa: B027 + self, audio_bytes: bytes, mime_type: str, context: TContext + ) -> TranscriptionResult: + """Transcribe speech audio to text. Override this method to support dictation.""" + raise NotImplementedError( + "transcribe() must be overridden to support the input.transcribe request." + ) + def action( self, thread: ThreadMetadata, @@ -446,6 +457,12 @@ async def _process_non_streaming( request.params.attachment_id, context=context ) return b"{}" + case InputTranscribeReq(): + audio_bytes = base64.b64decode(request.params.audio_base64) + transcription_result = await self.transcribe( + audio_bytes, request.params.mime_type, context=context + ) + return self._serialize(transcription_result) case ItemsListReq(): items_list_params = request.params items = await self.store.load_thread_items( diff --git a/chatkit/types.py b/chatkit/types.py index cbb5900..217deec 100644 --- a/chatkit/types.py +++ b/chatkit/types.py @@ -174,6 +174,29 @@ class AttachmentCreateParams(BaseModel): mime_type: str +class InputTranscribeReq(BaseReq): + """Request to transcribe an audio payload into text.""" + + type: Literal["input.transcribe"] = "input.transcribe" + params: InputTranscribeParams + + +class InputTranscribeParams(BaseModel): + """Parameters for speech transcription.""" + + audio_base64: str + """Base64-encoded audio bytes.""" + + mime_type: str + """MIME type for the audio payload (e.g. 'audio/webm', 'audio/wav').""" + + +class TranscriptionResult(BaseModel): + """Input speech transcription result.""" + + text: str + + class ItemsListReq(BaseReq): """Request to list items inside a thread.""" @@ -236,6 +259,7 @@ class ThreadDeleteParams(BaseModel): | AttachmentsDeleteReq | ThreadsUpdateReq | ThreadsDeleteReq + | InputTranscribeReq ) """Union of request types that yield immediate responses.""" diff --git a/tests/test_chatkit_server.py b/tests/test_chatkit_server.py index f6797fa..e0d03eb 100644 --- a/tests/test_chatkit_server.py +++ b/tests/test_chatkit_server.py @@ -1,4 +1,5 @@ import asyncio +import base64 import sqlite3 from contextlib import contextmanager from datetime import datetime @@ -38,6 +39,8 @@ FileAttachment, ImageAttachment, InferenceOptions, + InputTranscribeParams, + InputTranscribeReq, ItemFeedbackParams, ItemsFeedbackReq, ItemsListParams, @@ -75,6 +78,7 @@ ThreadUpdatedEvent, ThreadUpdateParams, ToolChoice, + TranscriptionResult, UserMessageInput, UserMessageItem, UserMessageTextContent, @@ -156,6 +160,7 @@ def make_server( ] | None = None, file_store: AttachmentStore | None = None, + transcribe_callback: Callable[[bytes, str, Any], TranscriptionResult] | None = None, ): global server_id db_path = f"file:{server_id}?mode=memory&cache=shared" @@ -203,6 +208,13 @@ async def add_feedback( return handle_feedback(thread_id, item_ids, feedback, context) + async def transcribe( + self, audio_bytes: bytes, mime_type: str, context: Any + ) -> TranscriptionResult: + if transcribe_callback is None: + return await super().transcribe(audio_bytes, mime_type, context) + return transcribe_callback(audio_bytes, mime_type, context) + async def process_streaming( self, request_obj, context: Any | None = None ) -> list[ThreadStreamEvent]: @@ -1843,6 +1855,36 @@ async def responder( assert any(e.type == "thread.item.done" for e in events) +async def test_input_transcribe_decodes_base64_and_passes_mime_type(): + audio_bytes = b"hello audio" + audio_b64 = base64.b64encode(audio_bytes).decode("ascii") + seen: dict[str, Any] = {} + + def transcribe_callback( + audio: bytes, mime: str, context: Any + ) -> TranscriptionResult: + seen["audio"] = audio + seen["mime"] = mime + seen["context"] = context + return TranscriptionResult(text="ok") + + with make_server(transcribe_callback=transcribe_callback) as server: + result = await server.process_non_streaming( + InputTranscribeReq( + params=InputTranscribeParams( + audio_base64=audio_b64, + mime_type="audio/wav", + ) + ) + ) + parsed = TypeAdapter(TranscriptionResult).validate_json(result.json) + assert parsed.text == "ok" + + assert seen["audio"] == audio_bytes + assert seen["mime"] == "audio/wav" + assert seen["context"] == DEFAULT_CONTEXT + + async def test_retry_after_item_passes_tools_to_responder(): pass