diff --git a/chatkit/server.py b/chatkit/server.py index 9d82b64..01372cb 100644 --- a/chatkit/server.py +++ b/chatkit/server.py @@ -664,7 +664,6 @@ async def _process_new_thread_item_respond( context: TContext, ) -> AsyncIterator[ThreadStreamEvent]: await self.store.add_thread_item(thread.id, item, context=context) - await self._cleanup_pending_client_tool_call(thread, context) yield ThreadItemDoneEvent(item=item) async for event in self._process_events( diff --git a/tests/test_chatkit_server.py b/tests/test_chatkit_server.py index d7ac167..41be193 100644 --- a/tests/test_chatkit_server.py +++ b/tests/test_chatkit_server.py @@ -679,73 +679,6 @@ async def responder( assert events[1].type == "thread.item.done" assert events[1].item.type == "assistant_message" - -async def test_removes_tool_call_if_no_output_provided(): - async def responder( - thread: ThreadMetadata, input: UserMessageItem | None, context: Any - ) -> AsyncIterator[ThreadStreamEvent]: - assert isinstance(input, UserMessageItem) - assert input.content[0].type == "input_text" - if input.content[0].text == "Message 1": - yield ThreadItemDoneEvent( - item=ClientToolCallItem( - id="msg_1", - created_at=datetime.now(), - name="tool_call_1", - arguments={"arg1": "val1", "arg2": False}, - call_id="tool_call_1", - thread_id=thread.id, - ), - ) - else: - yield ThreadItemDoneEvent( - item=AssistantMessageItem( - id="msg_2", - content=[AssistantMessageContent(text="All done!")], - created_at=datetime.now(), - thread_id=thread.id, - ), - ) - - with make_server(responder) as server: - events = await server.process_streaming( - ThreadsCreateReq( - params=ThreadCreateParams( - input=UserMessageInput( - content=[UserMessageTextContent(text="Message 1")], - attachments=[], - inference_options=InferenceOptions(), - ) - ) - ) - ) - thread = next( - event.thread for event in events if event.type == "thread.created" - ) - - await server.process_streaming( - ThreadsAddUserMessageReq( - params=ThreadAddUserMessageParams( - thread_id=thread.id, - input=UserMessageInput( - content=[UserMessageTextContent(text="Message 2")], - attachments=[], - inference_options=InferenceOptions(), - ), - ) - ) - ) - - items_result = await server.process_non_streaming( - ItemsListReq(params=ItemsListParams(thread_id=thread.id)) - ) - items = TypeAdapter(Page[ThreadItem]).validate_json(items_result.json) - assert len(items.data) == 3 - assert items.data[0].type == "assistant_message" - assert items.data[1].type == "user_message" - assert items.data[2].type == "user_message" - - async def test_respond_with_tool_status(): async def responder( thread: ThreadMetadata, input: UserMessageItem | None, context: Any