diff --git a/.gitignore b/.gitignore index cc44bed61..68f9660f2 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,7 @@ libs/core/build libs/server/build dist/ +test_output/ .mcp.json diff --git a/app/web_ui/src/lib/api_schema.d.ts b/app/web_ui/src/lib/api_schema.d.ts index bbb643d4e..095e32310 100644 --- a/app/web_ui/src/lib/api_schema.d.ts +++ b/app/web_ui/src/lib/api_schema.d.ts @@ -7031,6 +7031,11 @@ export interface components { * * Contains the input used, its source, the output produced, and optional * repair information if the output needed correction. + * + * Can be nested under another TaskRun; nested runs are stored as child runs + * in a "runs" subfolder (same relationship name as Task's runs). + * + * Accepts both Task and TaskRun as parents (polymorphic). */ "TaskRun-Input": { /** @@ -7093,6 +7098,11 @@ export interface components { * * Contains the input used, its source, the output produced, and optional * repair information if the output needed correction. + * + * Can be nested under another TaskRun; nested runs are stored as child runs + * in a "runs" subfolder (same relationship name as Task's runs). + * + * Accepts both Task and TaskRun as parents (polymorphic). 
*/ "TaskRun-Output": { /** diff --git a/libs/core/kiln_ai/adapters/chat/__init__.py b/libs/core/kiln_ai/adapters/chat/__init__.py index 7ab6328f0..11b5eda12 100644 --- a/libs/core/kiln_ai/adapters/chat/__init__.py +++ b/libs/core/kiln_ai/adapters/chat/__init__.py @@ -1,8 +1,10 @@ from .chat_formatter import ( BasicChatMessage, + ChatCompletionMessageIncludingLiteLLM, ChatFormatter, ChatMessage, ChatStrategy, + MultiturnFormatter, ToolCallMessage, ToolResponseMessage, get_chat_formatter, @@ -11,9 +13,11 @@ __all__ = [ "BasicChatMessage", + "ChatCompletionMessageIncludingLiteLLM", "ChatFormatter", "ChatMessage", "ChatStrategy", + "MultiturnFormatter", "ToolCallMessage", "ToolResponseMessage", "build_tool_call_messages", diff --git a/libs/core/kiln_ai/adapters/chat/chat_formatter.py b/libs/core/kiln_ai/adapters/chat/chat_formatter.py index d1a73514f..22ba371a3 100644 --- a/libs/core/kiln_ai/adapters/chat/chat_formatter.py +++ b/libs/core/kiln_ai/adapters/chat/chat_formatter.py @@ -3,15 +3,25 @@ import json from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, List, Literal, Optional, Sequence, Union +from typing import Dict, List, Literal, Optional, Sequence, TypeAlias, Union + +from litellm.types.utils import Message as LiteLLMMessage from kiln_ai.datamodel.datamodel_enums import ChatStrategy, InputType from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error -from kiln_ai.utils.open_ai_types import ChatCompletionMessageToolCallParam +from kiln_ai.utils.open_ai_types import ( + ChatCompletionMessageParam, + ChatCompletionMessageToolCallParam, +) COT_FINAL_ANSWER_PROMPT = "Considering the above, return a final result." 
+ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[ + ChatCompletionMessageParam, LiteLLMMessage +] + + @dataclass class BasicChatMessage: role: Literal["system", "assistant", "user"] @@ -90,6 +100,10 @@ def intermediate_outputs(self) -> Dict[str, str]: """Get the intermediate outputs from the chat formatter.""" return self._intermediate_outputs + def initial_messages(self) -> list[ChatCompletionMessageIncludingLiteLLM]: + """Messages to seed the conversation. Empty for fresh runs; prior trace for continuation.""" + return [] + @abstractmethod def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]: """Advance the conversation and return the next messages if any.""" @@ -236,6 +250,49 @@ def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]: return None +class MultiturnFormatter(ChatFormatter): + """ + Formatter for continuing a multi-turn conversation with prior trace. + Takes prior_trace (existing conversation) and appends the new user message. + Produces a single turn: the new user message. Tool calls and multi-turn + model responses are handled by _run_model_turn's internal loop. 
+ """ + + def __init__( + self, + prior_trace: list[ChatCompletionMessageParam], + user_input: InputType, + ) -> None: + super().__init__( + system_message="", + user_input=user_input, + thinking_instructions=None, + ) + self._prior_trace = prior_trace + + def initial_messages(self) -> list[ChatCompletionMessageIncludingLiteLLM]: + """Messages to seed the conversation (prior trace).""" + return list(self._prior_trace) + + def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]: + if self._state == "start": + # prior trace is already in the messages list and contains system and so on, we only need + # to append the latest new user message + user_msg = BasicChatMessage("user", format_user_message(self.user_input)) + self._state = "awaiting_final" + self._messages.append(user_msg) + return ChatTurn(messages=[user_msg], final_call=True) + + if self._state == "awaiting_final": + if previous_output is None: + raise ValueError("previous_output required for final step") + self._messages.append(BasicChatMessage("assistant", previous_output)) + self._state = "done" + return None + + return None + + def get_chat_formatter( strategy: ChatStrategy, system_message: str, diff --git a/libs/core/kiln_ai/adapters/chat/test_chat_formatter.py b/libs/core/kiln_ai/adapters/chat/test_chat_formatter.py index 642d49236..2903b6eee 100644 --- a/libs/core/kiln_ai/adapters/chat/test_chat_formatter.py +++ b/libs/core/kiln_ai/adapters/chat/test_chat_formatter.py @@ -1,6 +1,7 @@ from kiln_ai.adapters.chat import ChatStrategy, get_chat_formatter from kiln_ai.adapters.chat.chat_formatter import ( COT_FINAL_ANSWER_PROMPT, + MultiturnFormatter, format_user_message, ) @@ -119,6 +120,76 @@ def test_chat_formatter_r1_style(): assert formatter.intermediate_outputs() == {} +def test_multiturn_formatter_initial_messages(): + prior_trace = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + formatter = MultiturnFormatter(prior_trace=prior_trace, 
user_input="new input") + assert formatter.initial_messages() == prior_trace + + +def test_multiturn_formatter_next_turn(): + prior_trace = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + formatter = MultiturnFormatter(prior_trace=prior_trace, user_input="follow-up") + + first = formatter.next_turn() + assert first is not None + assert len(first.messages) == 1 + assert first.messages[0].role == "user" + assert first.messages[0].content == "follow-up" + assert first.final_call + + assert formatter.next_turn("assistant response") is None + + +def test_multiturn_formatter_preserves_tool_call_messages(): + prior_trace = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "4"}, + { + "role": "assistant", + "content": "", + "reasoning_content": "Let me multiply 4 by 7.\n", + "tool_calls": [ + { + "id": "call_abc123", + "function": {"arguments": '{"a": 4, "b": 7}', "name": "multiply"}, + "type": "function", + } + ], + }, + { + "content": "28", + "role": "tool", + "tool_call_id": "call_abc123", + "kiln_task_tool_data": None, + }, + { + "role": "assistant", + "content": "4 multiplied by 7 is 28.", + "reasoning_content": "Done.\n", + }, + ] + formatter = MultiturnFormatter(prior_trace=prior_trace, user_input="now double it") + initial = formatter.initial_messages() + assert initial == prior_trace + assert initial[2]["tool_calls"][0]["id"] == "call_abc123" + assert initial[2]["tool_calls"][0]["function"]["name"] == "multiply" + assert initial[3]["role"] == "tool" + assert initial[3]["tool_call_id"] == "call_abc123" + + first = formatter.next_turn() + assert first is not None + assert len(first.messages) == 1 + assert first.messages[0].role == "user" + assert first.messages[0].content == "now double it" + assert first.final_call + + def test_format_user_message(): # String assert format_user_message("test input") == "test input" diff --git 
a/libs/core/kiln_ai/adapters/litellm_utils/__init__.py b/libs/core/kiln_ai/adapters/litellm_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/libs/core/kiln_ai/adapters/litellm_utils/litellm_streaming.py b/libs/core/kiln_ai/adapters/litellm_utils/litellm_streaming.py new file mode 100644 index 000000000..68b29dd62 --- /dev/null +++ b/libs/core/kiln_ai/adapters/litellm_utils/litellm_streaming.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from typing import Any, AsyncIterator, Optional, Union + +import litellm +from litellm.types.utils import ( + ModelResponse, + ModelResponseStream, + TextCompletionResponse, +) + + +class StreamingCompletion: + """ + Async iterable wrapper around ``litellm.acompletion`` with streaming. + + Yields ``ModelResponseStream`` chunks as they arrive. After iteration + completes, the assembled ``ModelResponse`` is available via the + ``.response`` property. + + Usage:: + + stream = StreamingCompletion(model=..., messages=...) + async for chunk in stream: + # handle chunk however you like (print, log, send over WS, …) + pass + final = stream.response # fully assembled ModelResponse + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + kwargs = dict(kwargs) + kwargs.pop("stream", None) + self._args = args + self._kwargs = kwargs + self._response: Optional[Union[ModelResponse, TextCompletionResponse]] = None + self._iterated: bool = False + + @property + def response(self) -> Optional[Union[ModelResponse, TextCompletionResponse]]: + """The final assembled response. Only available after iteration.""" + if not self._iterated: + raise RuntimeError( + "StreamingCompletion has not been iterated yet. 
" + "Use 'async for chunk in stream:' before accessing .response" + ) + return self._response + + async def __aiter__(self) -> AsyncIterator[ModelResponseStream]: + self._response = None + self._iterated = False + + chunks: list[ModelResponseStream] = [] + stream = await litellm.acompletion(*self._args, stream=True, **self._kwargs) + + async for chunk in stream: + chunks.append(chunk) + yield chunk + + self._response = litellm.stream_chunk_builder(chunks) + self._iterated = True diff --git a/libs/core/kiln_ai/adapters/litellm_utils/test_litellm_streaming.py b/libs/core/kiln_ai/adapters/litellm_utils/test_litellm_streaming.py new file mode 100644 index 000000000..e35a51982 --- /dev/null +++ b/libs/core/kiln_ai/adapters/litellm_utils/test_litellm_streaming.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, List +from unittest.mock import MagicMock, patch + +import pytest + +from kiln_ai.adapters.litellm_utils.litellm_streaming import StreamingCompletion + + +def _make_chunk(content: str | None = None, finish_reason: str | None = None) -> Any: + """Build a minimal chunk object matching litellm's streaming shape.""" + delta = SimpleNamespace(content=content, role="assistant") + choice = SimpleNamespace(delta=delta, finish_reason=finish_reason, index=0) + return SimpleNamespace(choices=[choice], id="chatcmpl-test", model="test-model") + + +async def _async_iter(items: List[Any]): + """Turn a plain list into an async iterator.""" + for item in items: + yield item + + +@pytest.fixture +def mock_acompletion(): + with patch("litellm.acompletion") as mock: + yield mock + + +@pytest.fixture +def mock_chunk_builder(): + with patch("litellm.stream_chunk_builder") as mock: + yield mock + + +class TestStreamingCompletion: + async def test_yields_all_chunks(self, mock_acompletion, mock_chunk_builder): + chunks = [_make_chunk("Hello"), _make_chunk(" world"), _make_chunk("!")] + mock_acompletion.return_value = 
_async_iter(chunks) + mock_chunk_builder.return_value = MagicMock(name="final_response") + + stream = StreamingCompletion(model="test", messages=[]) + received = [chunk async for chunk in stream] + + assert received == chunks + + async def test_response_available_after_iteration( + self, mock_acompletion, mock_chunk_builder + ): + chunks = [_make_chunk("hi")] + mock_acompletion.return_value = _async_iter(chunks) + sentinel = MagicMock(name="final_response") + mock_chunk_builder.return_value = sentinel + + stream = StreamingCompletion(model="test", messages=[]) + async for _ in stream: + pass + + assert stream.response is sentinel + + async def test_response_raises_before_iteration(self): + stream = StreamingCompletion(model="test", messages=[]) + with pytest.raises(RuntimeError, match="not been iterated"): + _ = stream.response + + async def test_stream_kwarg_is_stripped(self, mock_acompletion, mock_chunk_builder): + mock_acompletion.return_value = _async_iter([]) + mock_chunk_builder.return_value = None + + stream = StreamingCompletion(model="test", messages=[], stream=False) + async for _ in stream: + pass + + _, call_kwargs = mock_acompletion.call_args + assert call_kwargs["stream"] is True + + async def test_passes_args_and_kwargs_through( + self, mock_acompletion, mock_chunk_builder + ): + mock_acompletion.return_value = _async_iter([]) + mock_chunk_builder.return_value = None + + stream = StreamingCompletion( + model="gpt-4", messages=[{"role": "user", "content": "hi"}], temperature=0.5 + ) + async for _ in stream: + pass + + _, call_kwargs = mock_acompletion.call_args + assert call_kwargs["model"] == "gpt-4" + assert call_kwargs["messages"] == [{"role": "user", "content": "hi"}] + assert call_kwargs["temperature"] == 0.5 + assert call_kwargs["stream"] is True + + async def test_chunks_passed_to_builder(self, mock_acompletion, mock_chunk_builder): + chunks = [_make_chunk("a"), _make_chunk("b")] + mock_acompletion.return_value = _async_iter(chunks) + 
mock_chunk_builder.return_value = MagicMock() + + stream = StreamingCompletion(model="test", messages=[]) + async for _ in stream: + pass + + mock_chunk_builder.assert_called_once_with(chunks) + + async def test_re_iteration_resets_state( + self, mock_acompletion, mock_chunk_builder + ): + first_chunks = [_make_chunk("first")] + second_chunks = [_make_chunk("second")] + first_response = MagicMock(name="first_response") + second_response = MagicMock(name="second_response") + + mock_acompletion.side_effect = [ + _async_iter(first_chunks), + _async_iter(second_chunks), + ] + mock_chunk_builder.side_effect = [first_response, second_response] + + stream = StreamingCompletion(model="test", messages=[]) + + async for _ in stream: + pass + assert stream.response is first_response + + async for _ in stream: + pass + assert stream.response is second_response + + async def test_empty_stream(self, mock_acompletion, mock_chunk_builder): + mock_acompletion.return_value = _async_iter([]) + mock_chunk_builder.return_value = None + + stream = StreamingCompletion(model="test", messages=[]) + received = [chunk async for chunk in stream] + + assert received == [] + assert stream.response is None diff --git a/libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py b/libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py new file mode 100644 index 000000000..db4c4642a --- /dev/null +++ b/libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +import copy +import json +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, AsyncIterator + +from litellm.types.utils import ( + ChatCompletionMessageToolCall, + Choices, + ModelResponse, +) + +from kiln_ai.adapters.chat import ChatCompletionMessageIncludingLiteLLM +from kiln_ai.adapters.chat.chat_formatter import ChatFormatter +from kiln_ai.adapters.litellm_utils.litellm_streaming import StreamingCompletion +from 
kiln_ai.adapters.ml_model_list import KilnModelProvider +from kiln_ai.adapters.model_adapters.stream_events import ( + AdapterStreamEvent, + ToolCallEvent, + ToolCallEventType, +) +from kiln_ai.adapters.run_output import RunOutput +from kiln_ai.datamodel import Usage + +if TYPE_CHECKING: + from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter + +MAX_CALLS_PER_TURN = 10 +MAX_TOOL_CALLS_PER_TURN = 30 + +logger = logging.getLogger(__name__) + + +@dataclass +class AdapterStreamResult: + run_output: RunOutput + usage: Usage + + +class AdapterStream: + """ + Orchestrates a full task execution as an async iterator, + composing StreamingCompletion instances across chat turns and tool-call rounds. + + Yields ``ModelResponseStream`` chunks from each LLM call and + ``ToolCallEvent`` instances between tool-call rounds. + + After iteration completes the ``result`` property provides the + ``AdapterStreamResult`` with the final ``RunOutput`` and ``Usage``. + """ + + def __init__( + self, + adapter: LiteLlmAdapter, + provider: KilnModelProvider, + chat_formatter: ChatFormatter, + initial_messages: list[ChatCompletionMessageIncludingLiteLLM], + top_logprobs: int | None, + ) -> None: + self._adapter = adapter + self._provider = provider + self._chat_formatter = chat_formatter + self._messages = initial_messages + self._top_logprobs = top_logprobs + self._result: AdapterStreamResult | None = None + self._iterated = False + + @property + def result(self) -> AdapterStreamResult: + if not self._iterated: + raise RuntimeError( + "AdapterStream has not been iterated yet. 
" + "Use 'async for event in stream:' before accessing .result" + ) + if self._result is None: + raise RuntimeError("AdapterStream completed without producing a result") + return self._result + + async def __aiter__(self) -> AsyncIterator[AdapterStreamEvent]: + self._result = None + self._iterated = False + + usage = Usage() + prior_output: str | None = None + final_choice: Choices | None = None + turns = 0 + + while True: + turns += 1 + if turns > MAX_CALLS_PER_TURN: + raise RuntimeError( + f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens." + ) + + turn = self._chat_formatter.next_turn(prior_output) + if turn is None: + break + + for message in turn.messages: + if message.content is None: + raise ValueError("Empty message content isn't allowed") + self._messages.append( + {"role": message.role, "content": message.content} # type: ignore[arg-type] + ) + + skip_response_format = not turn.final_call + turn_top_logprobs = self._top_logprobs if turn.final_call else None + + async for event in self._stream_model_turn( + skip_response_format, turn_top_logprobs + ): + if isinstance(event, _ModelTurnComplete): + usage += event.usage + prior_output = event.assistant_message + final_choice = event.model_choice + else: + yield event + + if not prior_output: + raise RuntimeError("No assistant message/output returned from model") + + logprobs = self._adapter._extract_and_validate_logprobs(final_choice) + + intermediate_outputs = self._chat_formatter.intermediate_outputs() + self._adapter._extract_reasoning_to_intermediate_outputs( + final_choice, intermediate_outputs + ) + + if not isinstance(prior_output, str): + raise RuntimeError(f"assistant message is not a string: {prior_output}") + + trace = self._adapter.all_messages_to_trace(self._messages) + self._result = AdapterStreamResult( + run_output=RunOutput( + output=prior_output, + intermediate_outputs=intermediate_outputs, + output_logprobs=logprobs, + trace=trace, + ), + usage=usage, + ) + 
self._iterated = True + + async def _stream_model_turn( + self, + skip_response_format: bool, + top_logprobs: int | None, + ) -> AsyncIterator[AdapterStreamEvent | _ModelTurnComplete]: + usage = Usage() + tool_calls_count = 0 + + while tool_calls_count < MAX_TOOL_CALLS_PER_TURN: + completion_kwargs = await self._adapter.build_completion_kwargs( + self._provider, + copy.deepcopy(self._messages), + top_logprobs, + skip_response_format, + ) + + stream = StreamingCompletion(**completion_kwargs) + async for chunk in stream: + yield chunk + + response, response_choice = _validate_response(stream.response) + usage += self._adapter.usage_from_response(response) + + content = response_choice.message.content + tool_calls = response_choice.message.tool_calls + if not content and not tool_calls: + raise ValueError( + "Model returned an assistant message, but no content or tool calls. This is not supported." + ) + + self._messages.append(response_choice.message) + + if tool_calls and len(tool_calls) > 0: + async for event in self._handle_tool_calls(tool_calls): + yield event + + assistant_msg = self._extract_task_response(tool_calls) + if assistant_msg is not None: + yield _ModelTurnComplete( + assistant_message=assistant_msg, + model_choice=response_choice, + usage=usage, + ) + return + + tool_calls_count += 1 + continue + + if content: + yield _ModelTurnComplete( + assistant_message=content, + model_choice=response_choice, + usage=usage, + ) + return + + raise RuntimeError( + "Model returned neither content nor tool calls. It must return at least one of these." + ) + + raise RuntimeError( + f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens." 
+ ) + + async def _handle_tool_calls( + self, + tool_calls: list[ChatCompletionMessageToolCall], + ) -> AsyncIterator[AdapterStreamEvent]: + real_tool_calls = [ + tc for tc in tool_calls if tc.function.name != "task_response" + ] + + for tc in real_tool_calls: + try: + parsed_args = json.loads(tc.function.arguments) + except (json.JSONDecodeError, TypeError): + parsed_args = None + + yield ToolCallEvent( + event_type=ToolCallEventType.INPUT_AVAILABLE, + tool_call_id=tc.id, + tool_name=tc.function.name or "unknown", + arguments=parsed_args, + error=( + f"Failed to parse arguments: {tc.function.arguments}" + if parsed_args is None + else None + ), + ) + + _, tool_msgs = await self._adapter.process_tool_calls(tool_calls) + + for tool_msg in tool_msgs: + tc_id = tool_msg["tool_call_id"] + tc_name = _find_tool_name(tool_calls, tc_id) + content = tool_msg["content"] + yield ToolCallEvent( + event_type=ToolCallEventType.OUTPUT_AVAILABLE, + tool_call_id=tc_id, + tool_name=tc_name, + result=str(content) if content is not None else None, + ) + + self._messages.extend(tool_msgs) + + @staticmethod + def _extract_task_response( + tool_calls: list[ChatCompletionMessageToolCall], + ) -> str | None: + for tc in tool_calls: + if tc.function.name == "task_response": + return tc.function.arguments + return None + + +@dataclass +class _ModelTurnComplete: + """Internal sentinel yielded when a model turn finishes.""" + + assistant_message: str + model_choice: Choices | None + usage: Usage + + +def _validate_response( + response: Any, +) -> tuple[ModelResponse, Choices]: + if ( + not isinstance(response, ModelResponse) + or not response.choices + or len(response.choices) == 0 + or not isinstance(response.choices[0], Choices) + ): + raise RuntimeError( + f"Expected ModelResponse with Choices, got {type(response)}." 
+ ) + return response, response.choices[0] + + +def _find_tool_name( + tool_calls: list[ChatCompletionMessageToolCall], tool_call_id: str +) -> str: + for tc in tool_calls: + if tc.id == tool_call_id: + return tc.function.name or "unknown" + return "unknown" diff --git a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py index 807aabb8e..698cd7158 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py @@ -1,14 +1,30 @@ +from __future__ import annotations + import json +import uuid from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import Dict, Tuple +from typing import TYPE_CHECKING, AsyncIterator, Dict, Tuple + +from litellm.types.utils import ModelResponseStream -from kiln_ai.adapters.chat.chat_formatter import ChatFormatter, get_chat_formatter +from kiln_ai.adapters.chat.chat_formatter import ( + ChatFormatter, + MultiturnFormatter, + get_chat_formatter, +) from kiln_ai.adapters.ml_model_list import ( KilnModelProvider, StructuredOutputMode, default_structured_output_mode_for_model_provider, ) +from kiln_ai.adapters.model_adapters.adapter_stream import AdapterStreamResult +from kiln_ai.adapters.model_adapters.stream_events import ( + AiSdkEventType, + AiSdkStreamConverter, + AiSdkStreamEvent, + ToolCallEvent, +) from kiln_ai.adapters.parsers.json_parser import parse_json_string from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id from kiln_ai.adapters.parsers.request_formatters import request_formatter_from_id @@ -48,6 +64,9 @@ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error from kiln_ai.utils.open_ai_types import ChatCompletionMessageParam +if TYPE_CHECKING: + from kiln_ai.adapters.model_adapters.adapter_stream import AdapterStream + SkillsDict = Dict[str, Skill] @@ -136,14 +155,20 @@ async def invoke( self, input: InputType, 
input_source: DataSource | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> TaskRun: - run_output, _ = await self.invoke_returning_run_output(input, input_source) + run_output, _ = await self.invoke_returning_run_output( + input, input_source, prior_trace, parent_task_run + ) return run_output async def _run_returning_run_output( self, input: InputType, input_source: DataSource | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> Tuple[TaskRun, RunOutput]: # validate input, allowing arrays if self.input_schema is not None: @@ -154,6 +179,8 @@ async def _run_returning_run_output( require_object=False, ) + prior_trace = prior_trace if prior_trace else None + # Format model input for model call (we save the original input in the task without formatting) formatted_input = input formatter_id = self.model_provider().formatter @@ -162,7 +189,7 @@ async def _run_returning_run_output( formatted_input = formatter.format_input(input) # Run - run_output, usage = await self._run(formatted_input) + run_output, usage = await self._run(formatted_input, prior_trace=prior_trace) # Parse provider = self.model_provider() @@ -211,9 +238,8 @@ async def _run_returning_run_output( "Reasoning is required for this model, but no reasoning was returned." 
) - # Generate the run and output run = self.generate_run( - input, input_source, parsed_output, usage, run_output.trace + input, input_source, parsed_output, usage, run_output.trace, parent_task_run ) # Save the run if configured to do so, and we have a path to save to @@ -233,6 +259,8 @@ async def invoke_returning_run_output( self, input: InputType, input_source: DataSource | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> Tuple[TaskRun, RunOutput]: # Determine if this is the root agent (no existing run context) is_root_agent = get_agent_run_id() is None @@ -242,7 +270,9 @@ async def invoke_returning_run_output( set_agent_run_id(run_id) try: - return await self._run_returning_run_output(input, input_source) + return await self._run_returning_run_output( + input, input_source, prior_trace, parent_task_run + ) finally: if is_root_agent: try: @@ -252,6 +282,140 @@ async def invoke_returning_run_output( finally: clear_agent_run_id() + def invoke_openai_stream( + self, + input: InputType, + input_source: DataSource | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, + ) -> OpenAIStreamResult: + """Stream raw OpenAI-protocol chunks for the task execution. + + Returns an async-iterable that yields ``ModelResponseStream`` chunks + as they arrive from the model. After the iterator is exhausted the + run has been validated and saved (when configured). The resulting + ``TaskRun`` is available via the ``.task_run`` property. + + Tool-call rounds happen internally and are not surfaced; use + ``invoke_ai_sdk_stream`` if you need tool-call events. 
+ """ + return OpenAIStreamResult( + self, input, input_source, prior_trace, parent_task_run + ) + + def invoke_ai_sdk_stream( + self, + input: InputType, + input_source: DataSource | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, + ) -> AiSdkStreamResult: + """Stream AI SDK protocol events for the task execution. + + Returns an async-iterable that yields ``AiSdkStreamEvent`` instances + covering text, reasoning, tool-call lifecycle, step boundaries, and + control events. After the iterator is exhausted the resulting + ``TaskRun`` is available via the ``.task_run`` property. + """ + return AiSdkStreamResult( + self, input, input_source, prior_trace, parent_task_run + ) + + def _prepare_stream( + self, + input: InputType, + prior_trace: list[ChatCompletionMessageParam] | None, + ) -> AdapterStream: + if self.input_schema is not None: + validate_schema_with_value_error( + input, + self.input_schema, + "This task requires a specific input schema. While the model produced JSON, that JSON didn't meet the schema. Search 'Troubleshooting Structured Data Issues' in our docs for more information.", + require_object=False, + ) + + prior_trace = prior_trace if prior_trace else None + + formatted_input = input + formatter_id = self.model_provider().formatter + if formatter_id is not None: + formatter = request_formatter_from_id(formatter_id) + formatted_input = formatter.format_input(input) + + return self._create_run_stream(formatted_input, prior_trace) + + def _finalize_stream( + self, + adapter_stream: AdapterStream, + input: InputType, + input_source: DataSource | None, + parent_task_run: TaskRun | None = None, + ) -> TaskRun: + """Streaming invocations are only concerned with passing through events as they come in. + At the end of the stream, we still need to validate the output, create a run and everything + else that a non-streaming invocation would do. 
+ """ + + result: AdapterStreamResult = adapter_stream.result + run_output = result.run_output + usage = result.usage + + provider = self.model_provider() + parser = model_parser_from_id(provider.parser) + parsed_output = parser.parse_output(original_output=run_output) + + if self.output_schema is not None: + if isinstance(parsed_output.output, str): + parsed_output.output = parse_json_string(parsed_output.output) + if not isinstance(parsed_output.output, dict): + raise RuntimeError( + f"structured response is not a dict: {parsed_output.output}" + ) + validate_schema_with_value_error( + parsed_output.output, + self.output_schema, + "This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema. Search 'Troubleshooting Structured Data Issues' in our docs for more information.", + ) + else: + if not isinstance(parsed_output.output, str): + raise RuntimeError( + f"response is not a string for non-structured task: {parsed_output.output}" + ) + + trace_has_toolcalls = parsed_output.trace is not None and any( + message.get("role", None) == "tool" for message in parsed_output.trace + ) + if ( + provider.reasoning_capable + and ( + not parsed_output.intermediate_outputs + or "reasoning" not in parsed_output.intermediate_outputs + ) + and not ( + provider.reasoning_optional_for_structured_output + and self.has_structured_output() + ) + and not trace_has_toolcalls + ): + raise RuntimeError( + "Reasoning is required for this model, but no reasoning was returned." 
+ ) + + run = self.generate_run( + input, input_source, parsed_output, usage, run_output.trace, parent_task_run + ) + + if ( + self.base_adapter_config.allow_saving + and Config.shared().autosave_runs + and self.task.path is not None + ): + run.save_to_file() + else: + run.id = None + + return run + def has_structured_output(self) -> bool: return self.output_schema is not None @@ -260,9 +424,21 @@ def adapter_name(self) -> str: pass @abstractmethod - async def _run(self, input: InputType) -> Tuple[RunOutput, Usage | None]: + async def _run( + self, + input: InputType, + prior_trace: list[ChatCompletionMessageParam] | None = None, + ) -> Tuple[RunOutput, Usage | None]: pass + def _create_run_stream( + self, + input: InputType, + prior_trace: list[ChatCompletionMessageParam] | None = None, + ) -> AdapterStream: + """Create a stream for the adapter. Implementations must override this method to support streaming.""" + raise NotImplementedError("Streaming is not supported for this adapter type") + def build_prompt(self) -> str: if self.prompt_builder is None: raise ValueError("Prompt builder is not available for MCP run config") @@ -333,7 +509,13 @@ def _resolve_skills(self) -> list[Skill]: self._resolved_skills = skills return self._resolved_skills - def build_chat_formatter(self, input: InputType) -> ChatFormatter: + def build_chat_formatter( + self, + input: InputType, + prior_trace: list[ChatCompletionMessageParam] | None = None, + ) -> ChatFormatter: + if prior_trace is not None: + return MultiturnFormatter(prior_trace, input) if self.prompt_builder is None: raise ValueError("Prompt builder is not available for MCP run config") # Determine the chat strategy to use based on the prompt the user selected, the model's capabilities, and if the model was finetuned with a specific chat strategy. 
@@ -389,52 +571,51 @@ def generate_run( run_output: RunOutput, usage: Usage | None = None, trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> TaskRun: - # Convert input and output to JSON strings if they aren't strings - input_str = ( - input if isinstance(input, str) else json.dumps(input, ensure_ascii=False) - ) output_str = ( json.dumps(run_output.output, ensure_ascii=False) if isinstance(run_output.output, dict) else run_output.output ) - # If no input source is provided, use the human data source - if input_source is None: - input_source = DataSource( - type=DataSourceType.human, - properties={"created_by": Config.shared().user_id}, - ) - - # Synthetic since an adapter, not a human, is creating this - # Special case for MCP run configs which calls a mcp tool output_source_type = ( DataSourceType.tool_call if self.run_config.type == "mcp" else DataSourceType.synthetic ) - new_task_run = TaskRun( - parent=self.task, + new_output = TaskOutput( + output=output_str, + source=DataSource( + type=output_source_type, + properties=self._properties_for_task_output(), + run_config=self.run_config, + ), + ) + + # Convert input and output to JSON strings if they aren't strings + input_str = ( + input if isinstance(input, str) else json.dumps(input, ensure_ascii=False) + ) + + if input_source is None: + input_source = DataSource( + type=DataSourceType.human, + properties={"created_by": Config.shared().user_id}, + ) + + return TaskRun( + parent=parent_task_run if parent_task_run is not None else self.task, input=input_str, input_source=input_source, - output=TaskOutput( - output=output_str, - source=DataSource( - type=output_source_type, - properties=self._properties_for_task_output(), - run_config=self.run_config, - ), - ), + output=new_output, intermediate_outputs=run_output.intermediate_outputs, tags=self.base_adapter_config.default_tags or [], usage=usage, trace=trace, ) - return new_task_run - def 
_properties_for_task_output(self) -> Dict[str, str | int | float]: match self.run_config.type: case "mcp": @@ -509,3 +690,144 @@ async def available_tools(self) -> list[KilnToolInterface]: ) return tools + + +class OpenAIStreamResult: + """Async-iterable wrapper around the OpenAI streaming flow. + + Yields ``ModelResponseStream`` chunks. After iteration the resulting + ``TaskRun`` is available via the ``.task_run`` property. + """ + + def __init__( + self, + adapter: BaseAdapter, + input: InputType, + input_source: DataSource | None, + prior_trace: list[ChatCompletionMessageParam] | None, + parent_task_run: TaskRun | None = None, + ) -> None: + self._adapter = adapter + self._input = input + self._input_source = input_source + self._prior_trace = prior_trace + self._parent_task_run = parent_task_run + self._task_run: TaskRun | None = None + + @property + def task_run(self) -> TaskRun: + if self._task_run is None: + raise RuntimeError( + "Stream has not been fully consumed yet. " + "Iterate over the stream before accessing .task_run" + ) + return self._task_run + + async def __aiter__(self) -> AsyncIterator[ModelResponseStream]: + self._task_run = None + is_root_agent = get_agent_run_id() is None + if is_root_agent: + set_agent_run_id(generate_agent_run_id()) + + try: + adapter_stream = self._adapter._prepare_stream( + self._input, self._prior_trace + ) + + async for event in adapter_stream: + if isinstance(event, ModelResponseStream): + yield event + + self._task_run = self._adapter._finalize_stream( + adapter_stream, self._input, self._input_source, self._parent_task_run + ) + finally: + if is_root_agent: + try: + run_id = get_agent_run_id() + if run_id: + await MCPSessionManager.shared().cleanup_session(run_id) + finally: + clear_agent_run_id() + + +class AiSdkStreamResult: + """Async-iterable wrapper around the AI SDK streaming flow. + + Yields ``AiSdkStreamEvent`` instances. 
After iteration the resulting + ``TaskRun`` is available via the ``.task_run`` property. + """ + + def __init__( + self, + adapter: BaseAdapter, + input: InputType, + input_source: DataSource | None, + prior_trace: list[ChatCompletionMessageParam] | None, + parent_task_run: TaskRun | None = None, + ) -> None: + self._adapter = adapter + self._input = input + self._input_source = input_source + self._prior_trace = prior_trace + self._parent_task_run = parent_task_run + self._task_run: TaskRun | None = None + + @property + def task_run(self) -> TaskRun: + if self._task_run is None: + raise RuntimeError( + "Stream has not been fully consumed yet. " + "Iterate over the stream before accessing .task_run" + ) + return self._task_run + + async def __aiter__(self) -> AsyncIterator[AiSdkStreamEvent]: + self._task_run = None + is_root_agent = get_agent_run_id() is None + if is_root_agent: + set_agent_run_id(generate_agent_run_id()) + + try: + adapter_stream = self._adapter._prepare_stream( + self._input, self._prior_trace + ) + + message_id = f"msg-{uuid.uuid4().hex}" + converter = AiSdkStreamConverter() + + yield AiSdkStreamEvent(AiSdkEventType.START, {"messageId": message_id}) + yield AiSdkStreamEvent(AiSdkEventType.START_STEP) + + last_event_was_tool_call = False + async for event in adapter_stream: + if isinstance(event, ModelResponseStream): + if last_event_was_tool_call: + converter.reset_for_next_step() + last_event_was_tool_call = False + for ai_event in converter.convert_chunk(event): + yield ai_event + elif isinstance(event, ToolCallEvent): + last_event_was_tool_call = True + for ai_event in converter.convert_tool_event(event): + yield ai_event + + for ai_event in converter.close_open_blocks(): + yield ai_event + + yield AiSdkStreamEvent(AiSdkEventType.FINISH_STEP) + + self._task_run = self._adapter._finalize_stream( + adapter_stream, self._input, self._input_source, self._parent_task_run + ) + + for ai_event in converter.finalize(): + yield ai_event + finally: + 
if is_root_agent: + try: + run_id = get_agent_run_id() + if run_id: + await MCPSessionManager.shared().cleanup_session(run_id) + finally: + clear_agent_run_id() diff --git a/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py index bc0dc00c1..da3999c59 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py @@ -3,7 +3,7 @@ import json import logging from dataclasses import dataclass -from typing import Any, Dict, List, Tuple, TypeAlias, Union +from typing import Any, Dict, List, Tuple import litellm from litellm.types.utils import ( @@ -19,11 +19,13 @@ ) import kiln_ai.datamodel as datamodel +from kiln_ai.adapters.chat import ChatCompletionMessageIncludingLiteLLM from kiln_ai.adapters.ml_model_list import ( KilnModelProvider, ModelProviderName, StructuredOutputMode, ) +from kiln_ai.adapters.model_adapters.adapter_stream import AdapterStream from kiln_ai.adapters.model_adapters.base_adapter import ( AdapterConfig, BaseAdapter, @@ -56,10 +58,6 @@ logger = logging.getLogger(__name__) -ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[ - ChatCompletionMessageParam, LiteLLMMessage -] - @dataclass class ModelTurnResult: @@ -184,20 +182,29 @@ async def _run_model_turn( f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens." 
) - async def _run(self, input: InputType) -> tuple[RunOutput, Usage | None]: + async def _run( + self, + input: InputType, + prior_trace: list[ChatCompletionMessageParam] | None = None, + ) -> tuple[RunOutput, Usage | None]: usage = Usage() provider = self.model_provider() if not provider.model_id: raise ValueError("Model ID is required for OpenAI compatible models") - chat_formatter = self.build_chat_formatter(input) - messages: list[ChatCompletionMessageIncludingLiteLLM] = [] + # build_chat_formatter returns MultiturnFormatter when prior_trace is set, else prompt-based formatter + chat_formatter = self.build_chat_formatter(input, prior_trace) + messages: list[ChatCompletionMessageIncludingLiteLLM] = copy.deepcopy( + chat_formatter.initial_messages() + ) prior_output: str | None = None final_choice: Choices | None = None turns = 0 + # Same loop for both fresh runs and prior_trace continuation. + # _run_model_turn has its own internal loop for tool calls (model calls tool -> we run it -> model continues). 
while True: turns += 1 if turns > MAX_CALLS_PER_TURN: @@ -255,6 +262,28 @@ async def _run(self, input: InputType) -> tuple[RunOutput, Usage | None]: return output, usage + def _create_run_stream( + self, + input: InputType, + prior_trace: list[ChatCompletionMessageParam] | None = None, + ) -> AdapterStream: + provider = self.model_provider() + if not provider.model_id: + raise ValueError("Model ID is required for OpenAI compatible models") + + chat_formatter = self.build_chat_formatter(input, prior_trace) + initial_messages: list[ChatCompletionMessageIncludingLiteLLM] = copy.deepcopy( + chat_formatter.initial_messages() + ) + + return AdapterStream( + adapter=self, + provider=provider, + chat_formatter=chat_formatter, + initial_messages=initial_messages, + top_logprobs=self.base_adapter_config.top_logprobs, + ) + def _extract_and_validate_logprobs( self, final_choice: Choices | None ) -> ChoiceLogprobs | None: @@ -291,9 +320,10 @@ def _extract_reasoning_to_intermediate_outputs( intermediate_outputs["reasoning"] = stripped_reasoning_content async def acompletion_checking_response( - self, **kwargs + self, **kwargs: Any ) -> Tuple[ModelResponse, Choices]: response = await litellm.acompletion(**kwargs) + if ( not isinstance(response, ModelResponse) or not response.choices diff --git a/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py index a7f1a8b6c..9b38ff9bb 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py @@ -42,7 +42,17 @@ def __init__( def adapter_name(self) -> str: return "mcp_adapter" - async def _run(self, input: InputType) -> Tuple[RunOutput, Usage | None]: + async def _run( + self, + input: InputType, + prior_trace: list[ChatCompletionMessageParam] | None = None, + ) -> Tuple[RunOutput, Usage | None]: + if prior_trace is not None: + raise NotImplementedError( + "Session continuation is not supported for MCP 
adapter. " + "MCP tools are single-turn and do not maintain conversation state." + ) + run_config = self.run_config if not isinstance(run_config, McpRunConfigProperties): raise ValueError("MCPAdapter requires McpRunConfigProperties") @@ -75,19 +85,37 @@ async def invoke( self, input: InputType, input_source: DataSource | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> TaskRun: - run_output, _ = await self.invoke_returning_run_output(input, input_source) + if prior_trace or parent_task_run is not None: + raise NotImplementedError( + "Session continuation is not supported for MCP adapter. " + "MCP tools are single-turn and do not maintain conversation state." + ) + + run_output, _ = await self.invoke_returning_run_output( + input, input_source, prior_trace, parent_task_run + ) return run_output async def invoke_returning_run_output( self, input: InputType, input_source: DataSource | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> Tuple[TaskRun, RunOutput]: """ Runs the task and returns both the persisted TaskRun and raw RunOutput. If this call is the root of a run, it creates an agent run context, ensures MCP tool calls have a valid session scope, and cleans up the session/context on completion. """ + if prior_trace or parent_task_run is not None: + raise NotImplementedError( + "Session continuation is not supported for MCP adapter. " + "MCP tools are single-turn and do not maintain conversation state." 
+ ) + is_root_agent = get_agent_run_id() is None if is_root_agent: @@ -95,7 +123,9 @@ async def invoke_returning_run_output( set_agent_run_id(run_id) try: - return await self._run_and_validate_output(input, input_source) + return await self._run_and_validate_output( + input, input_source, parent_task_run + ) finally: if is_root_agent: try: @@ -109,6 +139,7 @@ async def _run_and_validate_output( self, input: InputType, input_source: DataSource | None, + parent_task_run: TaskRun | None = None, ) -> Tuple[TaskRun, RunOutput]: """ Run the MCP task and validate the output. @@ -147,7 +178,9 @@ async def _run_and_validate_output( # Build single turn trace trace = self._build_single_turn_trace(input, run_output.output) - run = self.generate_run(input, input_source, run_output, usage, trace) + run = self.generate_run( + input, input_source, run_output, usage, trace, parent_task_run + ) if ( self.base_adapter_config.allow_saving diff --git a/libs/core/kiln_ai/adapters/model_adapters/stream_events.py b/libs/core/kiln_ai/adapters/model_adapters/stream_events.py new file mode 100644 index 000000000..eb023f910 --- /dev/null +++ b/libs/core/kiln_ai/adapters/model_adapters/stream_events.py @@ -0,0 +1,311 @@ +from __future__ import annotations + +import uuid +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from litellm.types.utils import ModelResponseStream + + +class AiSdkEventType(str, Enum): + START = "start" + FINISH = "finish" + ERROR = "error" + ABORT = "abort" + + TEXT_START = "text-start" + TEXT_DELTA = "text-delta" + TEXT_END = "text-end" + + REASONING_START = "reasoning-start" + REASONING_DELTA = "reasoning-delta" + REASONING_END = "reasoning-end" + + TOOL_INPUT_START = "tool-input-start" + TOOL_INPUT_DELTA = "tool-input-delta" + TOOL_INPUT_AVAILABLE = "tool-input-available" + TOOL_INPUT_ERROR = "tool-input-error" + + TOOL_OUTPUT_AVAILABLE = "tool-output-available" + TOOL_OUTPUT_ERROR = "tool-output-error" + + START_STEP = 
"start-step" + FINISH_STEP = "finish-step" + + METADATA = "metadata" + SOURCE_URL = "source-url" + SOURCE_DOCUMENT = "source-document" + FILE = "file" + + +@dataclass +class AiSdkStreamEvent: + type: AiSdkEventType + payload: dict[str, Any] = field(default_factory=dict) + + def model_dump(self) -> dict[str, Any]: + return { + "type": self.type.value, + **self.payload, + } + + +class ToolCallEventType(str, Enum): + INPUT_AVAILABLE = "input_available" + OUTPUT_AVAILABLE = "output_available" + OUTPUT_ERROR = "output_error" + + +@dataclass +class ToolCallEvent: + event_type: ToolCallEventType + tool_call_id: str + tool_name: str + arguments: dict[str, Any] | None = None + result: str | None = None + error: str | None = None + + +AdapterStreamEvent = ModelResponseStream | ToolCallEvent + + +class AiSdkStreamConverter: + """Stateful converter from OpenAI streaming chunks to AI SDK events.""" + + def __init__(self) -> None: + self._text_started = False + self._text_id = f"text-{uuid.uuid4().hex[:12]}" + self._reasoning_started = False + self._reasoning_id = f"reasoning-{uuid.uuid4().hex[:12]}" + self._reasoning_block_count = 0 + self._tool_calls_state: dict[int, dict[str, Any]] = {} + self._finish_reason: str | None = None + self._usage_data: Any = None + + def convert_chunk(self, chunk: ModelResponseStream) -> list[AiSdkStreamEvent]: + events: list[AiSdkStreamEvent] = [] + + for choice in chunk.choices: + if choice.finish_reason is not None: + self._finish_reason = choice.finish_reason + + delta = choice.delta + if delta is None: + continue + + reasoning_content = getattr(delta, "reasoning_content", None) + if reasoning_content: + if not self._reasoning_started: + self._reasoning_block_count += 1 + self._reasoning_id = f"reasoning-{uuid.uuid4().hex[:12]}" + events.append( + AiSdkStreamEvent( + AiSdkEventType.REASONING_START, + {"id": self._reasoning_id}, + ) + ) + self._reasoning_started = True + events.append( + AiSdkStreamEvent( + AiSdkEventType.REASONING_DELTA, + 
{"id": self._reasoning_id, "delta": reasoning_content}, + ) + ) + + if delta.content: + if self._reasoning_started: + events.append( + AiSdkStreamEvent( + AiSdkEventType.REASONING_END, + {"id": self._reasoning_id}, + ) + ) + self._reasoning_started = False + + if not self._text_started: + self._text_id = f"text-{uuid.uuid4().hex[:12]}" + events.append( + AiSdkStreamEvent( + AiSdkEventType.TEXT_START, + {"id": self._text_id}, + ) + ) + self._text_started = True + events.append( + AiSdkStreamEvent( + AiSdkEventType.TEXT_DELTA, + {"id": self._text_id, "delta": delta.content}, + ) + ) + + if delta.tool_calls: + if self._reasoning_started: + events.append( + AiSdkStreamEvent( + AiSdkEventType.REASONING_END, + {"id": self._reasoning_id}, + ) + ) + self._reasoning_started = False + + if self._text_started: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TEXT_END, + {"id": self._text_id}, + ) + ) + self._text_started = False + + for tc_delta in delta.tool_calls: + idx = tc_delta.index + tc_state = self._tool_calls_state.setdefault( + idx, + { + "id": None, + "name": None, + "arguments": "", + "started": False, + }, + ) + + if tc_delta.id is not None: + tc_state["id"] = tc_delta.id + + func = getattr(tc_delta, "function", None) + if func is not None: + if func.name is not None: + tc_state["name"] = func.name + if func.arguments: + tc_state["arguments"] += func.arguments + + if tc_state["id"] and tc_state["name"] and not tc_state["started"]: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TOOL_INPUT_START, + { + "toolCallId": tc_state["id"], + "toolName": tc_state["name"], + }, + ) + ) + tc_state["started"] = True + + if func and func.arguments and tc_state["id"]: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TOOL_INPUT_DELTA, + { + "toolCallId": tc_state["id"], + "inputTextDelta": func.arguments, + }, + ) + ) + + if not chunk.choices: + usage = getattr(chunk, "usage", None) + if usage is not None: + self._usage_data = usage + + return events + + def 
convert_tool_event(self, event: ToolCallEvent) -> list[AiSdkStreamEvent]: + events: list[AiSdkStreamEvent] = [] + + if event.event_type == ToolCallEventType.INPUT_AVAILABLE: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TOOL_INPUT_AVAILABLE, + { + "toolCallId": event.tool_call_id, + "toolName": event.tool_name, + "input": event.arguments or {}, + }, + ) + ) + elif event.event_type == ToolCallEventType.OUTPUT_AVAILABLE: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TOOL_OUTPUT_AVAILABLE, + { + "toolCallId": event.tool_call_id, + "output": event.result, + }, + ) + ) + elif event.event_type == ToolCallEventType.OUTPUT_ERROR: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TOOL_OUTPUT_ERROR, + { + "toolCallId": event.tool_call_id, + "errorText": event.error or "Unknown error", + }, + ) + ) + + return events + + def close_open_blocks(self) -> list[AiSdkStreamEvent]: + """Close any open text/reasoning blocks. Call before FINISH_STEP.""" + events: list[AiSdkStreamEvent] = [] + + if self._reasoning_started: + events.append( + AiSdkStreamEvent( + AiSdkEventType.REASONING_END, + {"id": self._reasoning_id}, + ) + ) + self._reasoning_started = False + + if self._text_started: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TEXT_END, + {"id": self._text_id}, + ) + ) + self._text_started = False + + return events + + def finalize(self) -> list[AiSdkStreamEvent]: + """Emit the terminal FINISH event with usage/finish metadata. 
Call after FINISH_STEP.""" + events: list[AiSdkStreamEvent] = [] + + finish_payload: dict[str, Any] = {} + if self._finish_reason is not None: + finish_payload["finishReason"] = self._finish_reason.replace("_", "-") + + if self._usage_data is not None: + usage_payload: dict[str, Any] = { + "promptTokens": self._usage_data.prompt_tokens, + "completionTokens": self._usage_data.completion_tokens, + } + total = getattr(self._usage_data, "total_tokens", None) + if total is not None: + usage_payload["totalTokens"] = total + finish_payload["usage"] = usage_payload + + if finish_payload: + events.append( + AiSdkStreamEvent( + AiSdkEventType.FINISH, + {"messageMetadata": finish_payload}, + ) + ) + else: + events.append(AiSdkStreamEvent(AiSdkEventType.FINISH)) + + return events + + def reset_for_next_step(self) -> None: + """Reset per-step state between LLM calls in a multi-step flow.""" + self._tool_calls_state = {} + self._finish_reason = None + self._text_started = False + self._reasoning_started = False + self._text_id = f"text-{uuid.uuid4().hex[:12]}" + self._reasoning_id = f"reasoning-{uuid.uuid4().hex[:12]}" diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_adapter_stream.py b/libs/core/kiln_ai/adapters/model_adapters/test_adapter_stream.py new file mode 100644 index 000000000..a104ea066 --- /dev/null +++ b/libs/core/kiln_ai/adapters/model_adapters/test_adapter_stream.py @@ -0,0 +1,498 @@ +import json +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from litellm.types.utils import ( + ChatCompletionMessageToolCall, + Choices, + Delta, + Function, + ModelResponse, + ModelResponseStream, + StreamingChoices, +) +from litellm.types.utils import Message as LiteLLMMessage + +from kiln_ai.adapters.chat import ChatFormatter +from kiln_ai.adapters.model_adapters.adapter_stream import AdapterStream +from kiln_ai.adapters.model_adapters.stream_events import ( + ToolCallEvent, + ToolCallEventType, +) +from 
kiln_ai.datamodel import Usage + + +def _make_streaming_chunk( + content: str | None = None, + finish_reason: str | None = None, +) -> ModelResponseStream: + delta = Delta(content=content) + choice = StreamingChoices( + index=0, + delta=delta, + finish_reason=finish_reason, + ) + return ModelResponseStream(id="test-stream", choices=[choice]) + + +def _make_model_response( + content: str = "Hello", + tool_calls: list[ChatCompletionMessageToolCall] | None = None, +) -> ModelResponse: + message = LiteLLMMessage(content=content, role="assistant") + if tool_calls is not None: + message.tool_calls = tool_calls + choice = Choices( + index=0, + message=message, + finish_reason="stop" if tool_calls is None else "tool_calls", + ) + return ModelResponse(id="test-response", choices=[choice]) + + +def _make_tool_call( + call_id: str = "call_1", + name: str = "add", + arguments: dict[str, Any] | None = None, +) -> ChatCompletionMessageToolCall: + args = json.dumps(arguments or {"a": 1, "b": 2}) + return ChatCompletionMessageToolCall( + id=call_id, + type="function", + function=Function(name=name, arguments=args), + ) + + +class FakeChatFormatter(ChatFormatter): + """A simple chat formatter that returns a single turn then None.""" + + def __init__(self, num_turns: int = 1): + self._turn_count = 0 + self._num_turns = num_turns + + def next_turn(self, prior_output: str | None): + if self._turn_count >= self._num_turns: + return None + self._turn_count += 1 + turn = MagicMock() + turn.messages = [MagicMock(role="user", content="test input")] + turn.final_call = self._turn_count == self._num_turns + return turn + + def intermediate_outputs(self): + return {} + + +class FakeStreamingCompletion: + """Mocks StreamingCompletion: yields chunks, then exposes .response""" + + def __init__( + self, + model_response: ModelResponse, + chunks: list[ModelResponseStream] | None = None, + ): + self._chunks = chunks or [ + _make_streaming_chunk(content="Hel"), + _make_streaming_chunk(content="lo"), 
+ _make_streaming_chunk(finish_reason="stop"), + ] + self._response = model_response + + @property + def response(self): + return self._response + + async def __aiter__(self): + for chunk in self._chunks: + yield chunk + + +@pytest.fixture +def mock_adapter(): + adapter = MagicMock() + adapter.build_completion_kwargs = AsyncMock(return_value={"model": "test"}) + adapter.usage_from_response = MagicMock(return_value=Usage()) + adapter.process_tool_calls = AsyncMock(return_value=(None, [])) + adapter._extract_and_validate_logprobs = MagicMock(return_value=None) + adapter._extract_reasoning_to_intermediate_outputs = MagicMock() + adapter.all_messages_to_trace = MagicMock(return_value=[]) + adapter.base_adapter_config = MagicMock() + adapter.base_adapter_config.top_logprobs = None + return adapter + + +@pytest.fixture +def mock_provider(): + provider = MagicMock() + provider.model_id = "test-model" + return provider + + +class TestAdapterStreamSimple: + @pytest.mark.asyncio + async def test_simple_content_response(self, mock_adapter, mock_provider): + response = _make_model_response(content="Hello world") + fake_stream = FakeStreamingCompletion(response) + formatter = FakeChatFormatter() + + with patch( + "kiln_ai.adapters.model_adapters.adapter_stream.StreamingCompletion", + return_value=fake_stream, + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=formatter, + initial_messages=[], + top_logprobs=None, + ) + + events = [] + async for event in stream: + events.append(event) + + chunks = [e for e in events if isinstance(e, ModelResponseStream)] + assert len(chunks) == 3 + + result = stream.result + assert result.run_output.output == "Hello world" + + @pytest.mark.asyncio + async def test_result_not_available_before_iteration( + self, mock_adapter, mock_provider + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=FakeChatFormatter(), + initial_messages=[], + top_logprobs=None, + 
) + with pytest.raises(RuntimeError, match="not been iterated"): + _ = stream.result + + +class TestAdapterStreamToolCalls: + @pytest.mark.asyncio + async def test_tool_call_yields_events(self, mock_adapter, mock_provider): + tool_call = _make_tool_call( + call_id="call_1", name="add", arguments={"a": 1, "b": 2} + ) + tool_response = _make_model_response(content=None, tool_calls=[tool_call]) + final_response = _make_model_response(content="The answer is 3") + + tool_stream = FakeStreamingCompletion( + tool_response, + [_make_streaming_chunk(finish_reason="tool_calls")], + ) + final_stream = FakeStreamingCompletion( + final_response, + [ + _make_streaming_chunk(content="The answer is 3"), + _make_streaming_chunk(finish_reason="stop"), + ], + ) + + streams_iter = iter([tool_stream, final_stream]) + + mock_adapter.process_tool_calls = AsyncMock( + return_value=( + None, + [{"role": "tool", "tool_call_id": "call_1", "content": "3"}], + ) + ) + + with patch( + "kiln_ai.adapters.model_adapters.adapter_stream.StreamingCompletion", + side_effect=lambda **kw: next(streams_iter), + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=FakeChatFormatter(), + initial_messages=[], + top_logprobs=None, + ) + + events = [] + async for event in stream: + events.append(event) + + tool_events = [e for e in events if isinstance(e, ToolCallEvent)] + assert len(tool_events) == 2 + + input_event = next( + e for e in tool_events if e.event_type == ToolCallEventType.INPUT_AVAILABLE + ) + assert input_event.tool_call_id == "call_1" + assert input_event.tool_name == "add" + assert input_event.arguments == {"a": 1, "b": 2} + + output_event = next( + e for e in tool_events if e.event_type == ToolCallEventType.OUTPUT_AVAILABLE + ) + assert output_event.tool_call_id == "call_1" + assert output_event.result == "3" + + assert stream.result.run_output.output == "The answer is 3" + + @pytest.mark.asyncio + async def test_task_response_tool_call(self, 
mock_adapter, mock_provider): + task_response_call = _make_tool_call( + call_id="call_tr", name="task_response", arguments={"result": "42"} + ) + response = _make_model_response(content=None, tool_calls=[task_response_call]) + + fake_stream = FakeStreamingCompletion( + response, + [_make_streaming_chunk(finish_reason="tool_calls")], + ) + + mock_adapter.process_tool_calls = AsyncMock( + return_value=('{"result": "42"}', []) + ) + + with patch( + "kiln_ai.adapters.model_adapters.adapter_stream.StreamingCompletion", + return_value=fake_stream, + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=FakeChatFormatter(), + initial_messages=[], + top_logprobs=None, + ) + + events = [] + async for event in stream: + events.append(event) + + tool_events = [e for e in events if isinstance(e, ToolCallEvent)] + assert len(tool_events) == 0 + + assert stream.result.run_output.output == '{"result": "42"}' + + @pytest.mark.asyncio + async def test_too_many_tool_calls_raises(self, mock_adapter, mock_provider): + tool_call = _make_tool_call() + response = _make_model_response(content=None, tool_calls=[tool_call]) + + mock_adapter.process_tool_calls = AsyncMock( + return_value=( + None, + [{"role": "tool", "tool_call_id": "call_1", "content": "ok"}], + ) + ) + + def make_stream(**kw): + return FakeStreamingCompletion( + response, + [_make_streaming_chunk(finish_reason="tool_calls")], + ) + + with ( + patch( + "kiln_ai.adapters.model_adapters.adapter_stream.StreamingCompletion", + side_effect=make_stream, + ), + patch( + "kiln_ai.adapters.model_adapters.adapter_stream.MAX_TOOL_CALLS_PER_TURN", + 2, + ), + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=FakeChatFormatter(), + initial_messages=[], + top_logprobs=None, + ) + + with pytest.raises(RuntimeError, match="Too many tool calls"): + async for _ in stream: + pass + + @pytest.mark.asyncio + async def 
test_unparseable_tool_call_arguments(self, mock_adapter, mock_provider): + bad_tool_call = ChatCompletionMessageToolCall( + id="call_bad", + type="function", + function=Function(name="add", arguments="not json"), + ) + response = _make_model_response(content=None, tool_calls=[bad_tool_call]) + final_response = _make_model_response(content="fallback") + + tool_stream = FakeStreamingCompletion( + response, + [_make_streaming_chunk(finish_reason="tool_calls")], + ) + final_stream = FakeStreamingCompletion( + final_response, + [ + _make_streaming_chunk(content="fallback"), + _make_streaming_chunk(finish_reason="stop"), + ], + ) + + streams_iter = iter([tool_stream, final_stream]) + + mock_adapter.process_tool_calls = AsyncMock( + return_value=( + None, + [{"role": "tool", "tool_call_id": "call_bad", "content": "error"}], + ) + ) + + with patch( + "kiln_ai.adapters.model_adapters.adapter_stream.StreamingCompletion", + side_effect=lambda **kw: next(streams_iter), + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=FakeChatFormatter(), + initial_messages=[], + top_logprobs=None, + ) + + events = [] + async for event in stream: + events.append(event) + + input_events = [ + e + for e in events + if isinstance(e, ToolCallEvent) + and e.event_type == ToolCallEventType.INPUT_AVAILABLE + ] + assert len(input_events) == 1 + assert input_events[0].arguments is None + assert "Failed to parse" in (input_events[0].error or "") + + +class TestAdapterStreamEdgeCases: + @pytest.mark.asyncio + async def test_too_many_turns_raises(self, mock_adapter, mock_provider): + formatter = FakeChatFormatter(num_turns=15) + response = _make_model_response(content="ok") + fake_stream = FakeStreamingCompletion(response) + + with ( + patch( + "kiln_ai.adapters.model_adapters.adapter_stream.StreamingCompletion", + return_value=fake_stream, + ), + patch( + "kiln_ai.adapters.model_adapters.adapter_stream.MAX_CALLS_PER_TURN", + 2, + ), + ): + stream = 
AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=formatter, + initial_messages=[], + top_logprobs=None, + ) + with pytest.raises(RuntimeError, match="Too many turns"): + async for _ in stream: + pass + + @pytest.mark.asyncio + async def test_empty_message_content_raises(self, mock_adapter, mock_provider): + formatter = MagicMock() + turn = MagicMock() + turn.messages = [MagicMock(role="user", content=None)] + turn.final_call = True + formatter.next_turn = MagicMock(side_effect=[turn, None]) + + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=formatter, + initial_messages=[], + top_logprobs=None, + ) + with pytest.raises(ValueError, match="Empty message content"): + async for _ in stream: + pass + + @pytest.mark.asyncio + async def test_no_content_or_tool_calls_raises(self, mock_adapter, mock_provider): + response = _make_model_response(content=None, tool_calls=None) + response.choices[0].message.content = None + fake_stream = FakeStreamingCompletion( + response, [_make_streaming_chunk(finish_reason="stop")] + ) + + with patch( + "kiln_ai.adapters.model_adapters.adapter_stream.StreamingCompletion", + return_value=fake_stream, + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=FakeChatFormatter(), + initial_messages=[], + top_logprobs=None, + ) + with pytest.raises(ValueError, match="no content or tool calls"): + async for _ in stream: + pass + + +class TestValidateResponse: + def test_valid_response(self): + from kiln_ai.adapters.model_adapters.adapter_stream import _validate_response + + response = _make_model_response(content="hello") + result, choice = _validate_response(response) + assert result is response + assert choice.message.content == "hello" + + def test_none_response_raises(self): + from kiln_ai.adapters.model_adapters.adapter_stream import _validate_response + + with pytest.raises(RuntimeError, match="Expected ModelResponse"): + 
_validate_response(None) + + def test_empty_choices_raises(self): + from kiln_ai.adapters.model_adapters.adapter_stream import _validate_response + + response = ModelResponse(id="test", choices=[]) + with pytest.raises(RuntimeError, match="Expected ModelResponse"): + _validate_response(response) + + +class TestFindToolName: + def test_found(self): + from kiln_ai.adapters.model_adapters.adapter_stream import _find_tool_name + + tc = ChatCompletionMessageToolCall( + id="call_1", + type="function", + function=Function(name="add", arguments="{}"), + ) + assert _find_tool_name([tc], "call_1") == "add" + + def test_not_found_returns_unknown(self): + from kiln_ai.adapters.model_adapters.adapter_stream import _find_tool_name + + tc = ChatCompletionMessageToolCall( + id="call_1", + type="function", + function=Function(name="add", arguments="{}"), + ) + assert _find_tool_name([tc], "call_999") == "unknown" + + def test_name_is_none_returns_unknown(self): + from kiln_ai.adapters.model_adapters.adapter_stream import _find_tool_name + + tc = ChatCompletionMessageToolCall( + id="call_1", + type="function", + function=Function(name=None, arguments="{}"), + ) + assert _find_tool_name([tc], "call_1") == "unknown" diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py index 8f17bf8d4..2f6770fa5 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py @@ -1,6 +1,13 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from litellm.types.utils import ( + ChatCompletionDeltaToolCall, + Delta, + Function, + ModelResponseStream, + StreamingChoices, +) from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode from kiln_ai.adapters.model_adapters.base_adapter import ( @@ -8,9 +15,14 @@ BaseAdapter, RunOutput, ) +from kiln_ai.adapters.model_adapters.stream_events import ( + 
AiSdkEventType, + ToolCallEvent, + ToolCallEventType, +) from kiln_ai.adapters.prompt_builders import BasePromptBuilder -from kiln_ai.datamodel import Task -from kiln_ai.datamodel.datamodel_enums import ChatStrategy +from kiln_ai.datamodel import Task, TaskRun, Usage +from kiln_ai.datamodel.datamodel_enums import ChatStrategy, ModelProviderName from kiln_ai.datamodel.project import Project from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties, ToolsRunConfig from kiln_ai.datamodel.skill import Skill @@ -21,7 +33,7 @@ class MockAdapter(BaseAdapter): """Concrete implementation of BaseAdapter for testing""" - async def _run(self, input): + async def _run(self, input, **kwargs): return None, None def adapter_name(self) -> str: @@ -235,7 +247,7 @@ async def test_input_formatting( # Mock the _run method to capture the input captured_input = None - async def mock_run(input): + async def mock_run(input, **kwargs): nonlocal captured_input captured_input = input return RunOutput(output="test output", intermediate_outputs={}), None @@ -424,6 +436,119 @@ def test_build_chat_formatter( mock_prompt_builder.chain_of_thought_prompt.assert_called_once() +def test_build_chat_formatter_with_prior_trace_returns_multiturn_formatter(adapter): + prior_trace = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + formatter = adapter.build_chat_formatter("new input", prior_trace=prior_trace) + assert formatter.__class__.__name__ == "MultiturnFormatter" + assert formatter.initial_messages() == prior_trace + + +@pytest.mark.asyncio +async def test_invoke_with_prior_trace_none_starts_fresh(base_project): + task = Task( + name="test_task", + instruction="test_instruction", + parent=base_project, + ) + adapter = MockAdapter( + task=task, + run_config=KilnAgentRunConfigProperties( + model_name="gpt_4o", + model_provider_name=ModelProviderName.openai, + prompt_id="simple_prompt_builder", + structured_output_mode=StructuredOutputMode.json_schema, 
+ ), + ) + adapter._run = AsyncMock( + return_value=( + RunOutput(output="ok", intermediate_outputs=None, trace=None), + None, + ) + ) + with ( + patch( + "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id", + return_value=MagicMock( + parse_output=MagicMock( + return_value=RunOutput( + output="ok", intermediate_outputs=None, trace=None + ) + ) + ), + ), + patch( + "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id", + ), + patch.object( + adapter, + "model_provider", + return_value=MagicMock( + parser="default", + formatter=None, + reasoning_capable=False, + ), + ), + ): + run = await adapter.invoke("input", prior_trace=None) + assert run.output.output == "ok" + adapter._run.assert_called_once() + assert adapter._run.call_args[1].get("prior_trace") is None + + +@pytest.mark.asyncio +async def test_invoke_returning_run_output_passes_prior_trace_to_run( + adapter, mock_parser, tmp_path +): + project = Project(name="proj", path=tmp_path / "proj.kiln") + project.save_to_file() + task = Task( + name="t", + instruction="i", + parent=project, + ) + task.save_to_file() + adapter.task = task + + trace = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + + captured_prior_trace = None + + async def mock_run(input, **kwargs): + nonlocal captured_prior_trace + captured_prior_trace = kwargs.get("prior_trace") + return RunOutput(output="ok", intermediate_outputs=None, trace=trace), None + + adapter._run = mock_run + + provider = MagicMock() + provider.parser = "test_parser" + provider.formatter = None + provider.reasoning_capable = False + adapter.model_provider = MagicMock(return_value=provider) + mock_parser.parse_output.return_value = RunOutput( + output="ok", intermediate_outputs=None, trace=trace + ) + + with ( + patch( + "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id", + return_value=mock_parser, + ), + patch( + 
"kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id", + ), + ): + await adapter.invoke_returning_run_output("follow-up", prior_trace=trace) + + assert captured_prior_trace == trace + + @pytest.mark.parametrize( "initial_mode,expected_mode", [ @@ -685,7 +810,7 @@ async def test_invoke_sets_run_context(self, adapter, clear_context): from kiln_ai.run_context import get_agent_run_id # Mock the _run method - async def mock_run(input): + async def mock_run(input, **kwargs): # Check that run ID is set during _run run_id = get_agent_run_id() assert run_id is not None @@ -725,7 +850,7 @@ async def test_invoke_clears_run_context_after(self, adapter, clear_context): from kiln_ai.run_context import get_agent_run_id # Mock the _run method - async def mock_run(input): + async def mock_run(input, **kwargs): return RunOutput(output="test output", intermediate_outputs={}), None adapter._run = mock_run @@ -763,7 +888,7 @@ async def test_invoke_clears_run_context_on_error(self, adapter, clear_context): from kiln_ai.run_context import get_agent_run_id # Mock the _run method to raise an error - async def mock_run(input): + async def mock_run(input, **kwargs): # Run ID should be set even when error occurs run_id = get_agent_run_id() assert run_id is not None @@ -800,7 +925,7 @@ async def test_sub_agent_inherits_run(self, adapter, clear_context): set_agent_run_id(parent_run_id) # Mock the _run method to check inherited run ID - async def mock_run(input): + async def mock_run(input, **kwargs): # Sub-agent should see parent's run ID run_id = get_agent_run_id() assert run_id == parent_run_id @@ -849,7 +974,7 @@ async def test_sub_agent_does_not_create_new_run(self, adapter, clear_context): run_id_during_run = None # Mock the _run method to capture run ID - async def mock_run(input): + async def mock_run(input, **kwargs): nonlocal run_id_during_run run_id_during_run = get_agent_run_id() return RunOutput(output="test output", intermediate_outputs={}), None @@ -889,7 
+1014,7 @@ async def test_cleanup_session_called_on_completion(self, adapter, clear_context from kiln_ai.adapters.run_output import RunOutput # Mock the _run method - async def mock_run(input): + async def mock_run(input, **kwargs): return RunOutput(output="test output", intermediate_outputs={}), None adapter._run = mock_run @@ -934,6 +1059,356 @@ async def mock_run(input): assert run_id.startswith("run_") +class TestStreamMethods: + """Tests for the streaming methods on BaseAdapter.""" + + @pytest.fixture + def stream_adapter(self, base_task): + return MockAdapter( + task=base_task, + run_config=KilnAgentRunConfigProperties( + model_name="test_model", + model_provider_name="openai", + prompt_id="simple_prompt_builder", + structured_output_mode="json_schema", + ), + ) + + @pytest.mark.asyncio + async def test_invoke_openai_stream_raises_for_unsupported_adapter( + self, stream_adapter + ): + """MockAdapter does not implement _create_run_stream.""" + provider = MagicMock() + provider.formatter = None + stream_adapter.model_provider = MagicMock(return_value=provider) + + with pytest.raises(NotImplementedError, match="Streaming is not supported"): + async for _chunk in stream_adapter.invoke_openai_stream("test input"): + pass + + @pytest.mark.asyncio + async def test_invoke_ai_sdk_stream_raises_for_unsupported_adapter( + self, stream_adapter + ): + """MockAdapter does not implement _create_run_stream.""" + provider = MagicMock() + provider.formatter = None + stream_adapter.model_provider = MagicMock(return_value=provider) + + with pytest.raises(NotImplementedError, match="Streaming is not supported"): + async for _event in stream_adapter.invoke_ai_sdk_stream("test input"): + pass + + @pytest.mark.asyncio + async def test_invoke_ai_sdk_stream_resets_converter_between_tool_rounds( + self, stream_adapter + ): + """tool-input-start must be emitted for a new tool call at index 0 after a tool round.""" + + def _make_tool_chunk(call_id: str, name: str) -> ModelResponseStream: 
+ func = Function(name=name, arguments='{"x":1}') + tc = ChatCompletionDeltaToolCall(index=0, function=func) + tc.id = call_id + delta = Delta(tool_calls=[tc]) + choice = StreamingChoices(index=0, delta=delta, finish_reason=None) + return ModelResponseStream(id="test", choices=[choice]) + + round1_chunk = _make_tool_chunk("call_r1", "tool_a") + round2_chunk = _make_tool_chunk("call_r2", "tool_b") + + fake_events = [ + round1_chunk, + ToolCallEvent( + event_type=ToolCallEventType.INPUT_AVAILABLE, + tool_call_id="call_r1", + tool_name="tool_a", + arguments={"x": 1}, + ), + ToolCallEvent( + event_type=ToolCallEventType.OUTPUT_AVAILABLE, + tool_call_id="call_r1", + tool_name="tool_a", + result="done", + ), + round2_chunk, + ] + + class FakeAdapterStream: + result = MagicMock() + + async def __aiter__(self): + for event in fake_events: + yield event + + with ( + patch.object( + stream_adapter, + "_prepare_stream", + return_value=FakeAdapterStream(), + ), + patch.object(stream_adapter, "_finalize_stream"), + ): + events = [] + async for event in stream_adapter.invoke_ai_sdk_stream("test input"): + events.append(event) + + tool_input_starts = [ + e for e in events if e.type == AiSdkEventType.TOOL_INPUT_START + ] + assert len(tool_input_starts) == 2, ( + "tool-input-start must fire once per tool-call round" + ) + assert tool_input_starts[0].payload["toolCallId"] == "call_r1" + assert tool_input_starts[1].payload["toolCallId"] == "call_r2" + + @pytest.mark.asyncio + async def test_openai_stream_exposes_task_run_after_iteration(self, stream_adapter): + fake_chunk = ModelResponseStream( + id="test", + choices=[ + StreamingChoices( + index=0, + delta=Delta(content="hi"), + finish_reason=None, + ) + ], + ) + + class FakeAdapterStream: + result = MagicMock() + + async def __aiter__(self): + yield fake_chunk + + expected_run = MagicMock(spec=TaskRun) + + with ( + patch.object( + stream_adapter, + "_prepare_stream", + return_value=FakeAdapterStream(), + ), + 
patch.object(stream_adapter, "_finalize_stream", return_value=expected_run), + ): + stream = stream_adapter.invoke_openai_stream("test input") + + with pytest.raises(RuntimeError, match="not been fully consumed"): + _ = stream.task_run + + async for _chunk in stream: + pass + + assert stream.task_run is expected_run + + @pytest.mark.asyncio + async def test_ai_sdk_stream_exposes_task_run_after_iteration(self, stream_adapter): + fake_chunk = ModelResponseStream( + id="test", + choices=[ + StreamingChoices( + index=0, + delta=Delta(content="hi"), + finish_reason=None, + ) + ], + ) + + class FakeAdapterStream: + result = MagicMock() + + async def __aiter__(self): + yield fake_chunk + + expected_run = MagicMock(spec=TaskRun) + + with ( + patch.object( + stream_adapter, + "_prepare_stream", + return_value=FakeAdapterStream(), + ), + patch.object(stream_adapter, "_finalize_stream", return_value=expected_run), + ): + stream = stream_adapter.invoke_ai_sdk_stream("test input") + + with pytest.raises(RuntimeError, match="not been fully consumed"): + _ = stream.task_run + + async for _event in stream: + pass + + assert stream.task_run is expected_run + + +class TestFinalizeStream: + """Tests for _finalize_stream post-processing after streaming.""" + + @pytest.fixture + def finalize_adapter(self, base_task): + return MockAdapter( + task=base_task, + run_config=KilnAgentRunConfigProperties( + model_name="test_model", + model_provider_name="openai", + prompt_id="simple_prompt_builder", + structured_output_mode="json_schema", + ), + ) + + def _make_adapter_stream(self, output, usage=None, trace=None): + from kiln_ai.adapters.model_adapters.adapter_stream import AdapterStreamResult + + stream = MagicMock() + stream.result = AdapterStreamResult( + run_output=RunOutput( + output=output, + intermediate_outputs={}, + trace=trace, + ), + usage=usage or Usage(), + ) + return stream + + def test_finalize_stream_plain_text(self, finalize_adapter): + provider = MagicMock() + 
provider.parser = None + provider.reasoning_capable = False + finalize_adapter.model_provider = MagicMock(return_value=provider) + + adapter_stream = self._make_adapter_stream("Hello world") + run = finalize_adapter._finalize_stream(adapter_stream, "test input", None) + + assert isinstance(run, TaskRun) + assert run.output.output == "Hello world" + assert run.id is None + + def _make_structured_adapter(self, base_task, schema): + base_task.output_json_schema = schema + adapter = MockAdapter( + task=base_task, + run_config=KilnAgentRunConfigProperties( + model_name="test_model", + model_provider_name="openai", + prompt_id="simple_prompt_builder", + structured_output_mode="json_schema", + ), + ) + return adapter + + def test_finalize_stream_structured_output(self, base_task): + schema = '{"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}' + adapter = self._make_structured_adapter(base_task, schema) + + provider = MagicMock() + provider.parser = None + provider.reasoning_capable = False + adapter.model_provider = MagicMock(return_value=provider) + + adapter_stream = self._make_adapter_stream({"name": "test"}) + run = adapter._finalize_stream(adapter_stream, "test input", None) + + assert isinstance(run, TaskRun) + assert '"name"' in run.output.output + + def test_finalize_stream_structured_output_from_json_string(self, base_task): + schema = '{"type": "object", "properties": {"val": {"type": "integer"}}, "required": ["val"]}' + adapter = self._make_structured_adapter(base_task, schema) + + provider = MagicMock() + provider.parser = None + provider.reasoning_capable = False + adapter.model_provider = MagicMock(return_value=provider) + + adapter_stream = self._make_adapter_stream('{"val": 42}') + run = adapter._finalize_stream(adapter_stream, "test input", None) + assert isinstance(run, TaskRun) + + def test_finalize_stream_structured_output_not_dict_raises(self, base_task): + schema = '{"type": "object", "properties": {"x": {"type": 
"string"}}, "required": ["x"]}' + adapter = self._make_structured_adapter(base_task, schema) + + provider = MagicMock() + provider.parser = None + provider.reasoning_capable = False + adapter.model_provider = MagicMock(return_value=provider) + + adapter_stream = self._make_adapter_stream(42) + with pytest.raises(RuntimeError, match="structured response is not a dict"): + adapter._finalize_stream(adapter_stream, "test input", None) + + def test_finalize_stream_non_structured_non_string_raises(self, finalize_adapter): + provider = MagicMock() + provider.parser = None + provider.reasoning_capable = False + finalize_adapter.model_provider = MagicMock(return_value=provider) + + adapter_stream = self._make_adapter_stream({"unexpected": "dict"}) + with pytest.raises(RuntimeError, match="not a string for non-structured"): + finalize_adapter._finalize_stream(adapter_stream, "test input", None) + + def test_finalize_stream_reasoning_required_but_missing(self, finalize_adapter): + provider = MagicMock() + provider.parser = None + provider.reasoning_capable = True + provider.reasoning_optional_for_structured_output = False + finalize_adapter.model_provider = MagicMock(return_value=provider) + + adapter_stream = self._make_adapter_stream("output") + with pytest.raises(RuntimeError, match="Reasoning is required"): + finalize_adapter._finalize_stream(adapter_stream, "test input", None) + + def test_finalize_stream_reasoning_not_required_with_tool_calls( + self, finalize_adapter + ): + provider = MagicMock() + provider.parser = None + provider.reasoning_capable = True + provider.reasoning_optional_for_structured_output = False + finalize_adapter.model_provider = MagicMock(return_value=provider) + + trace = [ + {"role": "user", "content": "hi"}, + {"role": "tool", "content": "result", "tool_call_id": "call_1"}, + ] + adapter_stream = self._make_adapter_stream("output", trace=trace) + run = finalize_adapter._finalize_stream(adapter_stream, "test input", None) + assert 
isinstance(run, TaskRun) + + def test_finalize_stream_saves_when_allowed(self, tmp_path): + project_path = tmp_path / "proj" / "project.kiln" + project_path.parent.mkdir() + project = Project(name="test", path=project_path) + project.save_to_file() + task = Task(name="t", instruction="i", parent=project) + task.save_to_file() + + adapter = MockAdapter( + task=task, + run_config=KilnAgentRunConfigProperties( + model_name="test_model", + model_provider_name="openai", + prompt_id="simple_prompt_builder", + structured_output_mode="json_schema", + ), + config=AdapterConfig(allow_saving=True), + ) + + provider = MagicMock() + provider.parser = None + provider.reasoning_capable = False + adapter.model_provider = MagicMock(return_value=provider) + + adapter_stream = self._make_adapter_stream("result") + with patch( + "kiln_ai.adapters.model_adapters.base_adapter.Config" + ) as mock_config: + mock_config.shared.return_value.autosave_runs = True + mock_config.shared.return_value.user_id = "test_user" + run = adapter._finalize_stream(adapter_stream, "test input", None) + assert run.id is not None + + class TestResolveSkills: def test_returns_empty_for_non_kiln_agent(self, base_task): from kiln_ai.datamodel.run_config import ( diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py index 959d8b680..c6d8741d9 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py @@ -7,7 +7,10 @@ from kiln_ai.adapters.ml_model_list import ModelProviderName, StructuredOutputMode from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig -from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter +from kiln_ai.adapters.model_adapters.litellm_adapter import ( + LiteLlmAdapter, + ModelTurnResult, +) from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig from 
kiln_ai.datamodel import Project, Task, Usage from kiln_ai.datamodel.run_config import ( @@ -1318,7 +1321,11 @@ async def test_array_input_converted_to_json(tmp_path, config): mock_config_obj.user_id = "test_user" with ( - patch("litellm.acompletion", new=AsyncMock(return_value=mock_response)), + patch.object( + LiteLlmAdapter, + "acompletion_checking_response", + new=AsyncMock(return_value=(mock_response, mock_response.choices[0])), + ), patch("kiln_ai.utils.config.Config.shared", return_value=mock_config_obj), ): array_input = [1, 2, 3, 4, 5] @@ -1388,7 +1395,11 @@ async def test_dict_input_converted_to_json(tmp_path, config): mock_config_obj.user_id = "test_user" with ( - patch("litellm.acompletion", new=AsyncMock(return_value=mock_response)), + patch.object( + LiteLlmAdapter, + "acompletion_checking_response", + new=AsyncMock(return_value=(mock_response, mock_response.choices[0])), + ), patch("kiln_ai.utils.config.Config.shared", return_value=mock_config_obj), ): dict_input = {"x": 10, "y": 20} @@ -1410,3 +1421,179 @@ async def test_dict_input_converted_to_json(tmp_path, config): assert isinstance(content, str) parsed_content = json.loads(content) assert parsed_content == {"x": 10, "y": 20} + + +@pytest.mark.asyncio +async def test_run_with_prior_trace_uses_multiturn_formatter(mock_task): + config = LiteLlmConfig( + base_url="https://api.test.com", + run_config_properties=KilnAgentRunConfigProperties( + model_name="test-model", + model_provider_name="openai_compatible", + prompt_id="simple_prompt_builder", + structured_output_mode="json_schema", + ), + default_headers={"X-Test": "test"}, + additional_body_options={"api_key": "test_key"}, + ) + prior_trace = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + adapter = LiteLlmAdapter(config=config, kiln_task=mock_task) + + build_chat_formatter_calls = [] + + original_build = adapter.build_chat_formatter + + def capturing_build(input, prior_trace_arg=None): + 
build_chat_formatter_calls.append((input, prior_trace_arg)) + return original_build(input, prior_trace_arg) + + adapter.build_chat_formatter = capturing_build + + async def mock_run_model_turn( + provider, prior_messages, top_logprobs, skip_response_format + ): + extended = list(prior_messages) + extended.append({"role": "assistant", "content": "How can I help?"}) + return ModelTurnResult( + assistant_message="How can I help?", + all_messages=extended, + model_response=None, + model_choice=None, + usage=Usage(), + ) + + adapter._run_model_turn = mock_run_model_turn + + run_output, _ = await adapter._run("follow-up", prior_trace=prior_trace) + + assert len(build_chat_formatter_calls) == 1 + assert build_chat_formatter_calls[0][0] == "follow-up" + assert build_chat_formatter_calls[0][1] == prior_trace + + assert run_output.trace is not None + assert len(run_output.trace) == 4 + assert run_output.trace[0]["content"] == "hi" + assert run_output.trace[1]["content"] == "hello" + assert run_output.trace[2]["content"] == "follow-up" + assert run_output.trace[3]["content"] == "How can I help?" 
+ + +@pytest.mark.asyncio +async def test_run_with_prior_trace_preserves_tool_calls(mock_task): + """Prior trace containing tool calls should be passed through to the model and preserved in the output trace.""" + config = LiteLlmConfig( + base_url="https://api.test.com", + run_config_properties=KilnAgentRunConfigProperties( + model_name="test-model", + model_provider_name="openai_compatible", + prompt_id="simple_prompt_builder", + structured_output_mode="json_schema", + ), + default_headers={"X-Test": "test"}, + additional_body_options={"api_key": "test_key"}, + ) + + prior_trace = [ + {"role": "system", "content": "Use the math tools."}, + {"role": "user", "content": "4"}, + { + "role": "assistant", + "content": "", + "reasoning_content": "Let me multiply 4 by 7.\n", + "tool_calls": [ + { + "id": "call_abc123", + "function": {"arguments": '{"a": 4, "b": 7}', "name": "multiply"}, + "type": "function", + } + ], + }, + { + "content": "28", + "role": "tool", + "tool_call_id": "call_abc123", + "kiln_task_tool_data": None, + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Now add 144.\n", + "tool_calls": [ + { + "id": "call_def456", + "function": {"arguments": '{"a": 28, "b": 144}', "name": "add"}, + "type": "function", + } + ], + }, + { + "content": "172", + "role": "tool", + "tool_call_id": "call_def456", + "kiln_task_tool_data": None, + }, + { + "role": "assistant", + "content": "There were 172 distinct species of giant tortoises.", + "reasoning_content": "Now I have 172.\n", + }, + ] + adapter = LiteLlmAdapter(config=config, kiln_task=mock_task) + + captured_messages = [] + + async def mock_run_model_turn( + provider, prior_messages, top_logprobs, skip_response_format + ): + captured_messages.extend(prior_messages) + extended = list(prior_messages) + extended.append({"role": "assistant", "content": '{"test": "response"}'}) + return ModelTurnResult( + assistant_message='{"test": "response"}', + all_messages=extended, + model_response=None, + 
model_choice=None, + usage=Usage(), + ) + + adapter._run_model_turn = mock_run_model_turn + + run_output, _ = await adapter._run("what else?", prior_trace=prior_trace) + + assert run_output.trace is not None + # 7 prior trace messages + 1 new user + 1 new assistant = 9 + assert len(run_output.trace) == 9 + + # Verify tool call messages are preserved in the trace + assistant_with_tools = run_output.trace[2] + assert assistant_with_tools["role"] == "assistant" + assert assistant_with_tools["tool_calls"][0]["id"] == "call_abc123" + assert assistant_with_tools["tool_calls"][0]["function"]["name"] == "multiply" + assert assistant_with_tools["reasoning_content"] == "Let me multiply 4 by 7.\n" + + tool_response = run_output.trace[3] + assert tool_response["role"] == "tool" + assert tool_response["tool_call_id"] == "call_abc123" + assert tool_response["content"] == "28" + + second_tool_call = run_output.trace[4] + assert second_tool_call["tool_calls"][0]["id"] == "call_def456" + assert second_tool_call["tool_calls"][0]["function"]["name"] == "add" + + second_tool_response = run_output.trace[5] + assert second_tool_response["role"] == "tool" + assert second_tool_response["tool_call_id"] == "call_def456" + assert second_tool_response["content"] == "172" + + # Verify the tool call messages were passed to _run_model_turn (i.e., sent to the model) + assert any( + m.get("tool_calls") is not None + for m in captured_messages + if isinstance(m, dict) + ) + assert any( + m.get("role") == "tool" for m in captured_messages if isinstance(m, dict) + ) diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py new file mode 100644 index 000000000..cc2d51b98 --- /dev/null +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py @@ -0,0 +1,380 @@ +import json +import logging +import re +from datetime import datetime, timezone +from pathlib import Path +from 
typing import Any, Callable + +import litellm +import pytest +from litellm.types.utils import ChatCompletionDeltaToolCall + +from kiln_ai.adapters.ml_model_list import ModelProviderName, StructuredOutputMode +from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter +from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig +from kiln_ai.adapters.model_adapters.stream_events import ( + AiSdkEventType, + AiSdkStreamEvent, +) +from kiln_ai.datamodel import Project, PromptGenerators, Task +from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties, ToolsRunConfig +from kiln_ai.datamodel.tool_id import KilnBuiltInToolId + +logger = logging.getLogger(__name__) + +STREAMING_MODELS = [ + ("claude_sonnet_4_5", ModelProviderName.openrouter, "medium"), + ("claude_sonnet_4_5", ModelProviderName.anthropic, "medium"), + ("claude_sonnet_4_6", ModelProviderName.openrouter, "medium"), + ("claude_sonnet_4_6", ModelProviderName.anthropic, "medium"), + ("claude_opus_4_5", ModelProviderName.openrouter, "medium"), + ("claude_opus_4_5", ModelProviderName.anthropic, "medium"), + ("claude_opus_4_6", ModelProviderName.openrouter, "medium"), + ("claude_opus_4_6", ModelProviderName.anthropic, "medium"), + ("minimax_m2_5", ModelProviderName.openrouter, "medium"), + ("claude_4_5_haiku", ModelProviderName.openrouter, "medium"), + ("claude_4_5_haiku", ModelProviderName.anthropic, "medium"), +] + +PAID_TEST_OUTPUT_DIR = Path(__file__).resolve().parents[5] / "test_output" + + +def _serialize_for_dump(obj: Any) -> Any: + if hasattr(obj, "model_dump"): + return obj.model_dump(mode="json") + if isinstance(obj, list): + if not obj: + return [] + first = obj[0] + if hasattr(first, "type") and hasattr(first, "payload"): + return [{"type": e.type.value, "payload": e.payload} for e in obj] + if hasattr(first, "model_dump"): + return [item.model_dump(mode="json") for item in obj] + return [_serialize_for_dump(x) for x in obj] + return obj + + +def 
_dump_paid_test_output(request: pytest.FixtureRequest, **payloads: Any) -> Path: + test_name = re.sub(r"[^\w\-]", "_", request.node.name) + param_id = "default" + if hasattr(request.node, "callspec") and request.node.callspec is not None: + id_attr = getattr(request.node.callspec, "id", None) + if id_attr is not None: + param_id = re.sub(r"[^\w\-]", "_", str(id_attr)) + timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S") + out_dir = PAID_TEST_OUTPUT_DIR / test_name / param_id / timestamp + out_dir.mkdir(parents=True, exist_ok=True) + for filename, data in payloads.items(): + if data is None: + continue + if not filename.endswith(".json"): + filename = f"{filename}.json" + serialized = _serialize_for_dump(data) + (out_dir / filename).write_text( + json.dumps(serialized, indent=2, default=str), encoding="utf-8" + ) + return out_dir + + +@pytest.fixture +def task(tmp_path): + project_path: Path = tmp_path / "test_project" / "project.kiln" + project_path.parent.mkdir() + + project = Project(name="Test Project", path=project_path) + project.save_to_file() + + task = Task( + name="Streaming Test Task", + instruction="Think about it hard! Solve the math problem provided by the user, in a step by step manner. Use the tools provided to solve the math problem. Then use the result in a short sentence about a cat going to the mall. 
Remember to use the tools for math even if the operation looks easy.", + parent=project, + ) + task.save_to_file() + return task + + +@pytest.fixture +def adapter_factory( + task: Task, +) -> Callable[[str, ModelProviderName, str | None], LiteLlmAdapter]: + def create_adapter( + model_id: str, provider_name: ModelProviderName, thinking_level: str | None + ) -> LiteLlmAdapter: + return LiteLlmAdapter( + kiln_task=task, + config=LiteLlmConfig( + run_config_properties=KilnAgentRunConfigProperties( + model_name=model_id, + model_provider_name=provider_name, + prompt_id=PromptGenerators.SIMPLE, + structured_output_mode=StructuredOutputMode.unknown, + tools_config=ToolsRunConfig( + tools=[ + KilnBuiltInToolId.ADD_NUMBERS, + KilnBuiltInToolId.SUBTRACT_NUMBERS, + KilnBuiltInToolId.MULTIPLY_NUMBERS, + KilnBuiltInToolId.DIVIDE_NUMBERS, + ], + ), + thinking_level=thinking_level, + ) + ), + ) + + return create_adapter + + +@pytest.mark.paid +@pytest.mark.parametrize("model_id,provider_name,thinking_level", STREAMING_MODELS) +async def test_invoke_openai_stream_chunks( + request: pytest.FixtureRequest, + model_id: str, + provider_name: ModelProviderName, + thinking_level: str | None, + adapter_factory: Callable[[str, ModelProviderName, str | None], LiteLlmAdapter], +): + """Collect all OpenAI-protocol chunks via invoke_openai_stream and verify we got reasoning, content, and tool call data.""" + adapter = adapter_factory(model_id, provider_name, thinking_level) + + chunks: list[litellm.ModelResponseStream] = [] + async for chunk in adapter.invoke_openai_stream(input="123 + 321 = ?"): + chunks.append(chunk) + + _dump_paid_test_output(request, chunks=chunks) + assert len(chunks) > 0, "No chunks collected" + + reasoning_contents: list[str] = [] + contents: list[str] = [] + tool_calls: list[ChatCompletionDeltaToolCall | Any] = [] + + for chunk in chunks: + if len(chunk.choices) == 0: + continue + if chunk.choices[0].finish_reason is not None: + continue + delta = 
chunk.choices[0].delta + if delta is None: + continue + if delta.tool_calls is not None: + tool_calls.extend(delta.tool_calls) + elif getattr(delta, "reasoning_content", None) is not None: + text = getattr(delta, "reasoning_content", None) + if text is not None: + reasoning_contents.append(text) + elif delta.content is not None: + contents.append(delta.content) + + assert len(reasoning_contents) > 0, "No reasoning content in chunks" + assert len(contents) > 0, "No content in chunks" + assert len(tool_calls) > 0, "No tool calls in chunks" + assert not all(r.strip() == "" for r in reasoning_contents), ( + "All reasoning content in chunks is empty" + ) + assert not all(c.strip() == "" for c in contents), "All content in chunks is empty" + + tool_call_function_names = [ + tc.function.name for tc in tool_calls if tc.function.name is not None + ] + assert len(tool_call_function_names) == 1, ( + "Expected exactly one tool call function name" + ) + assert tool_call_function_names[0] == "add", "Tool call function name is not 'add'" + + tool_call_args_chunks = "".join( + tc.function.arguments for tc in tool_calls if tc.function.arguments is not None + ) + tool_call_args = json.loads(tool_call_args_chunks) + assert tool_call_args == {"a": 123, "b": 321} or tool_call_args == { + "a": 321, + "b": 123, + }, f"Tool call arguments not as expected: {tool_call_args}" + + +@pytest.mark.paid +@pytest.mark.parametrize("model_id,provider_name,thinking_level", STREAMING_MODELS) +async def test_invoke_ai_sdk_stream( + request: pytest.FixtureRequest, + model_id: str, + provider_name: ModelProviderName, + thinking_level: str | None, + adapter_factory: Callable[[str, ModelProviderName, str | None], LiteLlmAdapter], +): + """Collect AI SDK events and verify the full protocol lifecycle including tool events.""" + adapter = adapter_factory(model_id, provider_name, thinking_level) + + events: list[AiSdkStreamEvent] = [] + async for event in adapter.invoke_ai_sdk_stream(input="123 + 321 = ?"): + 
events.append(event) + logger.info(f"AI SDK event: {event.type.value} {event.payload}") + + _dump_paid_test_output(request, events=events) + assert len(events) > 0, "No events collected" + + event_types = [e.type for e in events] + + assert event_types[0] == AiSdkEventType.START, "First event should be START" + assert event_types[1] == AiSdkEventType.START_STEP, ( + "Second event should be START_STEP" + ) + + assert AiSdkEventType.FINISH_STEP in event_types, "Should have FINISH_STEP" + assert AiSdkEventType.FINISH in event_types, "Should have FINISH" + finish_step_idx = event_types.index(AiSdkEventType.FINISH_STEP) + finish_idx = event_types.index(AiSdkEventType.FINISH) + assert finish_step_idx < finish_idx, ( + f"FINISH_STEP (idx {finish_step_idx}) must come before FINISH (idx {finish_idx})" + ) + + assert AiSdkEventType.REASONING_START in event_types, "Should have REASONING_START" + assert AiSdkEventType.REASONING_DELTA in event_types, "Should have REASONING_DELTA" + + assert AiSdkEventType.TEXT_START in event_types, "Should have TEXT_START" + assert AiSdkEventType.TEXT_DELTA in event_types, "Should have TEXT_DELTA" + assert AiSdkEventType.TEXT_END in event_types, "Should have TEXT_END" + + assert AiSdkEventType.TOOL_INPUT_START in event_types, ( + "Should have TOOL_INPUT_START" + ) + assert AiSdkEventType.TOOL_INPUT_AVAILABLE in event_types, ( + "Should have TOOL_INPUT_AVAILABLE" + ) + assert AiSdkEventType.TOOL_OUTPUT_AVAILABLE in event_types, ( + "Should have TOOL_OUTPUT_AVAILABLE" + ) + + text_deltas = [ + e.payload.get("delta", "") + for e in events + if e.type == AiSdkEventType.TEXT_DELTA + ] + full_text = "".join(text_deltas) + assert len(full_text) > 0, "Text content is empty" + + tool_input_available = [ + e for e in events if e.type == AiSdkEventType.TOOL_INPUT_AVAILABLE + ] + assert len(tool_input_available) >= 1, ( + "Should have at least one tool-input-available" + ) + tool_input = tool_input_available[0].payload.get("input", {}) + assert "a" in 
tool_input and "b" in tool_input, ( + f"Tool input should have a and b keys: {tool_input}" + ) + + tool_output_available = [ + e for e in events if e.type == AiSdkEventType.TOOL_OUTPUT_AVAILABLE + ] + assert len(tool_output_available) >= 1, ( + "Should have at least one tool-output-available" + ) + assert tool_output_available[0].payload.get("output") is not None, ( + "Tool output should not be None" + ) + + +@pytest.mark.paid +@pytest.mark.parametrize("model_id,provider_name,thinking_level", STREAMING_MODELS) +async def test_ai_sdk_stream_text_ends_before_tool_calls( + request: pytest.FixtureRequest, + model_id: str, + provider_name: ModelProviderName, + thinking_level: str | None, + adapter_factory: Callable[[str, ModelProviderName, str | None], LiteLlmAdapter], +): + """Verify text blocks are properly closed before tool-input-start and reopened with a new ID after tool execution.""" + adapter = adapter_factory(model_id, provider_name, thinking_level) + + events: list[AiSdkStreamEvent] = [] + async for event in adapter.invoke_ai_sdk_stream( + input="First tell me you're about to calculate, then compute 11 + 50 and 50 * 85, then add the results. Use the tools for all math." 
+ ): + events.append(event) + + _dump_paid_test_output(request, events=events) + + event_types = [e.type for e in events] + assert AiSdkEventType.TEXT_START in event_types, "Should have TEXT_START" + assert AiSdkEventType.TOOL_INPUT_START in event_types, ( + "Should have TOOL_INPUT_START" + ) + + text_ids_seen: list[str] = [] + text_open = False + for event in events: + if event.type == AiSdkEventType.TEXT_START: + assert not text_open, ( + "text-start emitted while a text block was already open" + ) + text_open = True + text_ids_seen.append(event.payload["id"]) + + elif event.type == AiSdkEventType.TEXT_END: + assert text_open, "text-end emitted without a preceding text-start" + text_open = False + + elif event.type == AiSdkEventType.TEXT_DELTA: + assert text_open, ( + f"text-delta emitted outside an open text block: {event.payload}" + ) + + elif event.type == AiSdkEventType.TOOL_INPUT_START: + assert not text_open, ( + "tool-input-start emitted while text block was still open " + "(missing text-end before tool calls)" + ) + + assert len(text_ids_seen) >= 2, ( + f"Expected at least 2 distinct text blocks (before and after tool calls), " + f"got {len(text_ids_seen)}: {text_ids_seen}" + ) + assert len(set(text_ids_seen)) == len(text_ids_seen), ( + f"Each text block should have a unique ID, got duplicates: {text_ids_seen}" + ) + + +@pytest.mark.paid +@pytest.mark.parametrize("model_id,provider_name,thinking_level", STREAMING_MODELS) +async def test_invoke_openai_stream_non_streaming_still_works( + request: pytest.FixtureRequest, + model_id: str, + provider_name: ModelProviderName, + thinking_level: str | None, + adapter_factory: Callable[[str, ModelProviderName, str | None], LiteLlmAdapter], +): + """Verify the non-streaming invoke() still works after the refactor.""" + adapter = adapter_factory(model_id, provider_name, thinking_level) + task_run = await adapter.invoke(input="123 + 321 = ?") + + _dump_paid_test_output(request, task_run=task_run) + assert 
task_run.trace is not None, "Task run trace is None" + assert len(task_run.trace) > 0, "Task run trace is empty" + assert "444" in task_run.output.output, ( + f"Expected 444 in output: {task_run.output.output}" + ) + + +@pytest.mark.paid +@pytest.mark.parametrize("model_id,provider_name,thinking_level", STREAMING_MODELS) +async def test_invoke_openai_stream_with_prior_trace( + request: pytest.FixtureRequest, + model_id: str, + provider_name: ModelProviderName, + thinking_level: str | None, + adapter_factory: Callable[[str, ModelProviderName, str | None], LiteLlmAdapter], +): + """Test that streaming works when continuing an existing run (session continuation).""" + adapter = adapter_factory(model_id, provider_name, thinking_level) + + initial_run = await adapter.invoke(input="123 + 321 = ?") + assert initial_run.trace is not None + assert len(initial_run.trace) > 0 + + continuation_chunks: list[litellm.ModelResponseStream] = [] + async for chunk in adapter.invoke_openai_stream( + input="What was the result? 
Reply in one short sentence.", + prior_trace=initial_run.trace, + ): + continuation_chunks.append(chunk) + + _dump_paid_test_output(request, continuation_chunks=continuation_chunks) + assert len(continuation_chunks) > 0, "No continuation chunks collected" diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py index fb7bc4c21..3674fe5b3 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest from litellm.types.utils import ModelResponse @@ -287,10 +287,17 @@ async def test_tools_simplied_mocked(tmp_path): mock_config.open_ai_api_key = "mock_api_key" mock_config.user_id = "test_user" + responses = [mock_response_1, mock_response_2] + + async def mock_acompletion_checking_response(self, **kwargs): + response = responses.pop(0) + return response, response.choices[0] + with ( - patch( - "litellm.acompletion", - side_effect=[mock_response_1, mock_response_2], + patch.object( + LiteLlmAdapter, + "acompletion_checking_response", + new=mock_acompletion_checking_response, ), patch("kiln_ai.utils.config.Config.shared", return_value=mock_config), ): @@ -386,9 +393,16 @@ async def test_tools_mocked(tmp_path): mock_config.user_id = "test_user" with ( - patch( - "litellm.acompletion", - side_effect=[mock_response_1, mock_response_2, mock_response_3], + patch.object( + LiteLlmAdapter, + "acompletion_checking_response", + new=AsyncMock( + side_effect=[ + (mock_response_1, mock_response_1.choices[0]), + (mock_response_2, mock_response_2.choices[0]), + (mock_response_3, mock_response_3.choices[0]), + ] + ), ), patch("kiln_ai.utils.config.Config.shared", return_value=mock_config), ): diff --git 
a/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py index 393a3859f..689801f86 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from mcp.types import CallToolResult, TextContent @@ -328,3 +328,138 @@ async def test_mcp_adapter_sets_and_clears_run_context( await adapter.invoke_returning_run_output("input") assert get_agent_run_id() is None + + +@pytest.mark.asyncio +async def test_mcp_adapter_rejects_multiturn_invoke_returning_run_output( + project_with_local_mcp_server, local_mcp_tool_id +): + """Session continuation (prior_trace) is not supported for MCP adapter.""" + project, _ = project_with_local_mcp_server + task = Task( + name="Test Task", + parent=project, + instruction="Echo input", + ) + + run_config = McpRunConfigProperties( + tool_reference=MCPToolReference(tool_id=local_mcp_tool_id) + ) + + adapter = MCPAdapter(task=task, run_config=run_config) + + existing_run = MagicMock() + existing_run.trace = [{"role": "user", "content": "hi"}] + + with pytest.raises(NotImplementedError) as exc_info: + await adapter.invoke_returning_run_output( + "input", prior_trace=existing_run.trace + ) + + assert "Session continuation is not supported" in str(exc_info.value) + assert "MCP adapter" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_mcp_adapter_rejects_multiturn_invoke( + project_with_local_mcp_server, local_mcp_tool_id +): + """invoke with prior_trace raises NotImplementedError for MCP adapter.""" + project, _ = project_with_local_mcp_server + task = Task( + name="Test Task", + parent=project, + instruction="Echo input", + ) + + run_config = McpRunConfigProperties( + 
tool_reference=MCPToolReference(tool_id=local_mcp_tool_id) + ) + + adapter = MCPAdapter(task=task, run_config=run_config) + + existing_run = MagicMock() + existing_run.trace = [{"role": "user", "content": "hi"}] + + with pytest.raises(NotImplementedError) as exc_info: + await adapter.invoke("input", prior_trace=existing_run.trace) + + assert "Session continuation is not supported" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_mcp_adapter_rejects_prior_trace_in_run( + project_with_local_mcp_server, local_mcp_tool_id +): + """_run with prior_trace raises NotImplementedError for MCP adapter.""" + project, _ = project_with_local_mcp_server + task = Task( + name="Test Task", + parent=project, + instruction="Echo input", + ) + + run_config = McpRunConfigProperties( + tool_reference=MCPToolReference(tool_id=local_mcp_tool_id) + ) + + adapter = MCPAdapter(task=task, run_config=run_config) + + prior_trace = [ + {"role": "user", "content": "first message"}, + {"role": "assistant", "content": "first response"}, + ] + + with pytest.raises(NotImplementedError) as exc_info: + await adapter._run("follow-up message", prior_trace=prior_trace) + + assert "Session continuation is not supported" in str(exc_info.value) + assert "MCP tools are single-turn" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_mcp_adapter_rejects_parent_task_run_in_invoke( + project_with_local_mcp_server, local_mcp_tool_id +): + """invoke with parent_task_run raises NotImplementedError for MCP adapter.""" + project, _ = project_with_local_mcp_server + task = Task( + name="Test Task", + parent=project, + instruction="Echo input", + ) + + run_config = McpRunConfigProperties( + tool_reference=MCPToolReference(tool_id=local_mcp_tool_id) + ) + + adapter = MCPAdapter(task=task, run_config=run_config) + + with pytest.raises(NotImplementedError) as exc_info: + await adapter.invoke("input", parent_task_run=MagicMock()) + + assert "Session continuation is not supported" in 
str(exc_info.value) + + +@pytest.mark.asyncio +async def test_mcp_adapter_rejects_parent_task_run_in_invoke_returning_run_output( + project_with_local_mcp_server, local_mcp_tool_id +): + """invoke_returning_run_output with parent_task_run raises NotImplementedError for MCP adapter.""" + project, _ = project_with_local_mcp_server + task = Task( + name="Test Task", + parent=project, + instruction="Echo input", + ) + + run_config = McpRunConfigProperties( + tool_reference=MCPToolReference(tool_id=local_mcp_tool_id) + ) + + adapter = MCPAdapter(task=task, run_config=run_config) + + with pytest.raises(NotImplementedError) as exc_info: + await adapter.invoke_returning_run_output("input", parent_task_run=MagicMock()) + + assert "Session continuation is not supported" in str(exc_info.value) diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py index 9cc5fa5d9..c2ebf17a5 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py @@ -1,16 +1,16 @@ -from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput -from kiln_ai.datamodel import DataSource, DataSourceType, Project, Task, Usage +from kiln_ai.datamodel import DataSource, DataSourceType, Project, Task, TaskRun, Usage from kiln_ai.datamodel.datamodel_enums import InputType from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties from kiln_ai.utils.config import Config class MockAdapter(BaseAdapter): - async def _run(self, input: InputType) -> tuple[RunOutput, Usage | None]: + async def _run(self, input: InputType, **kwargs) -> tuple[RunOutput, Usage | None]: return RunOutput(output="Test output", intermediate_outputs=None), None def adapter_name(self) -> str: @@ -233,6 +233,164 
@@ async def test_autosave_true(test_task, adapter): assert output.source.properties["top_p"] == 1.0 +@pytest.mark.asyncio +async def test_invoke_continue_session(test_task, adapter): + """Test that invoke with prior_trace continues a session and creates a new run.""" + with patch("kiln_ai.utils.config.Config.shared") as mock_shared: + mock_config = mock_shared.return_value + mock_config.autosave_runs = True + mock_config.user_id = "test_user" + + trace = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + async def mock_run(input, **kwargs): + prior_trace = kwargs.get("prior_trace") + if prior_trace is not None: + extended_trace = [ + *prior_trace, + {"role": "user", "content": input}, + {"role": "assistant", "content": "How can I help?"}, + ] + return ( + RunOutput( + output="How can I help?", + intermediate_outputs=None, + trace=extended_trace, + ), + None, + ) + return RunOutput(output="Test output", intermediate_outputs=None), None + + adapter._run = mock_run + + with ( + patch.object( + adapter, + "model_provider", + return_value=MagicMock( + parser="default", + formatter=None, + reasoning_capable=False, + ), + ), + patch( + "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id" + ) as mock_parser_from_id, + ): + mock_parser = MagicMock() + mock_parser.parse_output.return_value = RunOutput( + output="How can I help?", + intermediate_outputs=None, + trace=[ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "Tell me more"}, + {"role": "assistant", "content": "How can I help?"}, + ], + ) + mock_parser_from_id.return_value = mock_parser + + new_run = await adapter.invoke("Tell me more", prior_trace=trace) + + assert new_run.id is not None + assert new_run.input == "Tell me more" + assert new_run.output.output == "How can I help?" 
+ assert len(new_run.trace) == 4 + assert new_run.trace[-2]["content"] == "Tell me more" + assert new_run.trace[-1]["content"] == "How can I help?" + + reloaded = Task.load_from_file(test_task.path) + runs = reloaded.runs() + assert len(runs) == 1 + assert runs[0].output.output == "How can I help?" + + +@pytest.mark.asyncio +async def test_invoke_with_empty_prior_trace_starts_fresh(test_task, adapter): + """Test that invoke with prior_trace=[] starts a fresh conversation (no error).""" + with patch("kiln_ai.utils.config.Config.shared") as mock_shared: + mock_config = mock_shared.return_value + mock_config.autosave_runs = True + mock_config.user_id = "test_user" + + adapter._run = AsyncMock( + return_value=( + RunOutput(output="Fresh reply", intermediate_outputs=None, trace=None), + None, + ) + ) + with ( + patch.object( + adapter, + "model_provider", + return_value=MagicMock( + parser="default", + formatter=None, + reasoning_capable=False, + ), + ), + patch( + "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id", + return_value=MagicMock( + parse_output=MagicMock( + return_value=RunOutput( + output="Fresh reply", + intermediate_outputs=None, + trace=None, + ) + ) + ), + ), + patch( + "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id", + ), + ): + run = await adapter.invoke("Follow up", prior_trace=[]) + assert run.output.output == "Fresh reply" + + +def test_generate_run_always_creates_new_task_run(test_task, adapter): + trace = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + run1 = adapter.generate_run( + input="hi", + input_source=None, + run_output=RunOutput( + output="hello", + intermediate_outputs={"chain_of_thought": "old"}, + trace=trace, + ), + usage=Usage(input_tokens=10, output_tokens=20), + trace=trace, + ) + extended_trace = [ + *trace, + {"role": "user", "content": "follow-up"}, + {"role": "assistant", "content": "ok"}, + ] + run2 = adapter.generate_run( + input="follow-up", 
+ input_source=None, + run_output=RunOutput( + output="ok", + intermediate_outputs={"new_key": "new_val"}, + trace=extended_trace, + ), + usage=Usage(input_tokens=5, output_tokens=10), + trace=extended_trace, + ) + assert run2 is not run1 + assert run2.usage is not None and run2.usage.input_tokens == 5 + assert run2.usage.output_tokens == 10 + assert run2.intermediate_outputs == {"new_key": "new_val"} + assert run2.output.output == "ok" + + def test_properties_for_task_output_custom_values(test_task): """Test that _properties_for_task_output includes custom temperature, top_p, and structured_output_mode""" adapter = MockAdapter( @@ -267,3 +425,122 @@ def test_properties_for_task_output_custom_values(test_task): assert output.source.properties["structured_output_mode"] == "json_schema" assert output.source.properties["temperature"] == 0.7 assert output.source.properties["top_p"] == 0.9 + + +def test_generate_run_with_parent_task_run_sets_parent(test_task, adapter): + """Test that generate_run with parent_task_run uses it as parent instead of the task.""" + prior_run = adapter.generate_run( + input="prior input", + input_source=None, + run_output=RunOutput(output="prior output", intermediate_outputs=None), + ) + prior_run.save_to_file() + assert prior_run.id is not None + + new_run = adapter.generate_run( + input="new input", + input_source=None, + run_output=RunOutput(output="new output", intermediate_outputs=None), + parent_task_run=prior_run, + ) + + assert new_run.parent == prior_run + + new_run.save_to_file() + + reloaded_prior_run = TaskRun.load_from_file(prior_run.path) + child_runs = reloaded_prior_run.runs() + assert len(child_runs) == 1 + assert child_runs[0].output.output == "new output" + + # The task should only have the prior run as a direct child + reloaded_task = Task.load_from_file(test_task.path) + task_runs = reloaded_task.runs() + assert len(task_runs) == 1 + assert task_runs[0].id == prior_run.id + + +def 
test_generate_run_without_parent_task_run_defaults_to_task(test_task, adapter): + """Test that generate_run without parent_task_run defaults to using the task as parent.""" + run = adapter.generate_run( + input="input", + input_source=None, + run_output=RunOutput(output="output", intermediate_outputs=None), + ) + assert run.parent == test_task + + +@pytest.mark.asyncio +async def test_invoke_with_parent_task_run_saves_as_child(test_task, adapter): + """Test that invoke with parent_task_run saves the new run as a child of that run.""" + trace = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + # Create and save a prior run to act as parent + prior_run = adapter.generate_run( + input="Hello", + input_source=None, + run_output=RunOutput( + output="Hi there!", intermediate_outputs=None, trace=trace + ), + trace=trace, + ) + prior_run.save_to_file() + assert prior_run.id is not None + + continuation_trace = [ + *trace, + {"role": "user", "content": "Tell me more"}, + {"role": "assistant", "content": "More details!"}, + ] + continuation_output = RunOutput( + output="More details!", + intermediate_outputs=None, + trace=continuation_trace, + ) + + adapter._run = AsyncMock(return_value=(continuation_output, None)) + + with ( + patch("kiln_ai.utils.config.Config.shared") as mock_shared, + patch.object( + adapter, + "model_provider", + return_value=MagicMock( + parser="default", + formatter=None, + reasoning_capable=False, + ), + ), + patch( + "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id", + return_value=MagicMock( + parse_output=MagicMock(return_value=continuation_output) + ), + ), + ): + mock_shared.return_value.autosave_runs = True + mock_shared.return_value.user_id = "test_user" + + new_run = await adapter.invoke( + "Tell me more", + prior_trace=trace, + parent_task_run=prior_run, + ) + + assert new_run.id is not None + assert new_run.parent == prior_run + + # The prior run should have the new run as a 
child + reloaded_prior_run = TaskRun.load_from_file(prior_run.path) + child_runs = reloaded_prior_run.runs() + assert len(child_runs) == 1 + assert child_runs[0].output.output == "More details!" + + # The task should only have the prior run as a direct child + reloaded_task = Task.load_from_file(test_task.path) + task_runs = reloaded_task.runs() + assert len(task_runs) == 1 + assert task_runs[0].id == prior_run.id diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py b/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py new file mode 100644 index 000000000..e7abbb832 --- /dev/null +++ b/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py @@ -0,0 +1,387 @@ +from litellm.types.utils import ( + ChatCompletionDeltaToolCall, + Delta, + Function, + ModelResponseStream, + StreamingChoices, +) + +from kiln_ai.adapters.model_adapters.stream_events import ( + AiSdkEventType, + AiSdkStreamConverter, + AiSdkStreamEvent, + ToolCallEvent, + ToolCallEventType, +) + + +def _make_tool_call_delta( + index: int = 0, + call_id: str | None = None, + name: str | None = None, + arguments: str | None = None, +) -> ChatCompletionDeltaToolCall: + func = Function(name=name, arguments=arguments or "") + tc = ChatCompletionDeltaToolCall(index=index, function=func) + if call_id is not None: + tc.id = call_id + return tc + + +def _make_chunk( + content: str | None = None, + reasoning_content: str | None = None, + tool_calls: list[ChatCompletionDeltaToolCall] | None = None, + finish_reason: str | None = None, +) -> ModelResponseStream: + delta = Delta(content=content, tool_calls=tool_calls) + if reasoning_content is not None: + delta.reasoning_content = reasoning_content + choice = StreamingChoices( + index=0, + delta=delta, + finish_reason=finish_reason, + ) + return ModelResponseStream(id="test", choices=[choice]) + + +class TestAiSdkStreamEvent: + def test_model_dump(self): + event = AiSdkStreamEvent(AiSdkEventType.START, {"messageId": "msg-123"}) 
+ dump = event.model_dump() + assert dump["type"] == "start" + assert dump["messageId"] == "msg-123" + + +class TestAiSdkStreamConverter: + def test_text_start_and_delta(self): + converter = AiSdkStreamConverter() + events = converter.convert_chunk(_make_chunk(content="Hello")) + types = [e.type for e in events] + assert AiSdkEventType.TEXT_START in types + assert AiSdkEventType.TEXT_DELTA in types + assert events[-1].payload["delta"] == "Hello" + + def test_text_delta_no_duplicate_start(self): + converter = AiSdkStreamConverter() + converter.convert_chunk(_make_chunk(content="Hello")) + events = converter.convert_chunk(_make_chunk(content=" world")) + types = [e.type for e in events] + assert AiSdkEventType.TEXT_START not in types + assert AiSdkEventType.TEXT_DELTA in types + + def test_reasoning_start_and_delta(self): + converter = AiSdkStreamConverter() + events = converter.convert_chunk(_make_chunk(reasoning_content="Thinking...")) + types = [e.type for e in events] + assert AiSdkEventType.REASONING_START in types + assert AiSdkEventType.REASONING_DELTA in types + + def test_reasoning_ends_when_content_starts(self): + converter = AiSdkStreamConverter() + converter.convert_chunk(_make_chunk(reasoning_content="Thinking...")) + events = converter.convert_chunk(_make_chunk(content="Answer")) + types = [e.type for e in events] + assert AiSdkEventType.REASONING_END in types + assert AiSdkEventType.TEXT_START in types + + def test_reasoning_ends_when_tool_calls_start(self): + converter = AiSdkStreamConverter() + converter.convert_chunk(_make_chunk(reasoning_content="Thinking...")) + + tc_delta = _make_tool_call_delta( + index=0, call_id="call_1", name="add", arguments='{"a":1}' + ) + events = converter.convert_chunk(_make_chunk(tool_calls=[tc_delta])) + types = [e.type for e in events] + assert AiSdkEventType.REASONING_END in types + + def test_tool_call_input_start_and_delta(self): + converter = AiSdkStreamConverter() + + tc_delta = _make_tool_call_delta( + index=0, 
call_id="call_1", name="add", arguments='{"a":' + ) + events = converter.convert_chunk(_make_chunk(tool_calls=[tc_delta])) + types = [e.type for e in events] + assert AiSdkEventType.TOOL_INPUT_START in types + assert AiSdkEventType.TOOL_INPUT_DELTA in types + + start_event = next( + e for e in events if e.type == AiSdkEventType.TOOL_INPUT_START + ) + assert start_event.payload["toolCallId"] == "call_1" + assert start_event.payload["toolName"] == "add" + + def test_close_open_blocks_closes_text(self): + converter = AiSdkStreamConverter() + converter.convert_chunk(_make_chunk(content="text")) + block_events = converter.close_open_blocks() + types = [e.type for e in block_events] + assert AiSdkEventType.TEXT_END in types + finish_events = converter.finalize() + assert any(e.type == AiSdkEventType.FINISH for e in finish_events) + + def test_close_open_blocks_closes_reasoning(self): + converter = AiSdkStreamConverter() + converter.convert_chunk(_make_chunk(reasoning_content="thinking")) + block_events = converter.close_open_blocks() + types = [e.type for e in block_events] + assert AiSdkEventType.REASONING_END in types + + def test_convert_tool_event_input_available(self): + converter = AiSdkStreamConverter() + event = ToolCallEvent( + event_type=ToolCallEventType.INPUT_AVAILABLE, + tool_call_id="call_1", + tool_name="add", + arguments={"a": 1, "b": 2}, + ) + events = converter.convert_tool_event(event) + assert len(events) == 1 + assert events[0].type == AiSdkEventType.TOOL_INPUT_AVAILABLE + assert events[0].payload["toolCallId"] == "call_1" + assert events[0].payload["input"] == {"a": 1, "b": 2} + + def test_convert_tool_event_output_available(self): + converter = AiSdkStreamConverter() + event = ToolCallEvent( + event_type=ToolCallEventType.OUTPUT_AVAILABLE, + tool_call_id="call_1", + tool_name="add", + result="3", + ) + events = converter.convert_tool_event(event) + assert len(events) == 1 + assert events[0].type == AiSdkEventType.TOOL_OUTPUT_AVAILABLE + assert 
events[0].payload["output"] == "3" + + def test_convert_tool_event_output_error(self): + converter = AiSdkStreamConverter() + event = ToolCallEvent( + event_type=ToolCallEventType.OUTPUT_ERROR, + tool_call_id="call_1", + tool_name="add", + error="Something went wrong", + ) + events = converter.convert_tool_event(event) + assert len(events) == 1 + assert events[0].type == AiSdkEventType.TOOL_OUTPUT_ERROR + assert events[0].payload["errorText"] == "Something went wrong" + + def test_reasoning_not_interrupted_by_empty_content(self): + # Minimax and similar models send chunks with both reasoning_content and + # delta.content="" simultaneously. Empty content must not close reasoning + # blocks or emit useless text-delta events. + converter = AiSdkStreamConverter() + + chunk1 = _make_chunk(reasoning_content="The", content="") + chunk2 = _make_chunk(reasoning_content=" user", content="") + chunk3 = _make_chunk(reasoning_content=" is", content="") + + events1 = converter.convert_chunk(chunk1) + events2 = converter.convert_chunk(chunk2) + events3 = converter.convert_chunk(chunk3) + + all_types1 = [e.type for e in events1] + all_types2 = [e.type for e in events2] + all_types3 = [e.type for e in events3] + + # First chunk opens the reasoning block + assert AiSdkEventType.REASONING_START in all_types1 + assert AiSdkEventType.REASONING_DELTA in all_types1 + # No text events from empty content + assert AiSdkEventType.TEXT_START not in all_types1 + assert AiSdkEventType.TEXT_DELTA not in all_types1 + + # Subsequent chunks must NOT re-open reasoning (no start) and must NOT + # close reasoning with reasoning-end + assert AiSdkEventType.REASONING_START not in all_types2 + assert AiSdkEventType.REASONING_END not in all_types2 + assert AiSdkEventType.REASONING_DELTA in all_types2 + assert AiSdkEventType.TEXT_DELTA not in all_types2 + + assert AiSdkEventType.REASONING_START not in all_types3 + assert AiSdkEventType.REASONING_END not in all_types3 + assert AiSdkEventType.REASONING_DELTA 
in all_types3 + assert AiSdkEventType.TEXT_DELTA not in all_types3 + + def test_reset_for_next_step(self): + converter = AiSdkStreamConverter() + converter._finish_reason = "tool_calls" + converter._tool_calls_state = { + 0: {"id": "x", "name": "y", "arguments": "", "started": True} + } + converter.reset_for_next_step() + assert converter._tool_calls_state == {} + assert converter._finish_reason is None + + def test_finish_reason_in_finalize(self): + converter = AiSdkStreamConverter() + converter.convert_chunk(_make_chunk(content="done", finish_reason="stop")) + converter.close_open_blocks() + events = converter.finalize() + finish_events = [e for e in events if e.type == AiSdkEventType.FINISH] + assert len(finish_events) == 1 + meta = finish_events[0].payload.get("messageMetadata", {}) + assert meta.get("finishReason") == "stop" + + def test_tool_input_start_reemitted_after_reset(self): + """After reset_for_next_step, tool-input-start must fire again for index 0.""" + converter = AiSdkStreamConverter() + + tc_round1 = _make_tool_call_delta( + index=0, call_id="call_r1", name="search", arguments='{"q":"hi"}' + ) + events_r1 = converter.convert_chunk(_make_chunk(tool_calls=[tc_round1])) + starts_r1 = [e for e in events_r1 if e.type == AiSdkEventType.TOOL_INPUT_START] + assert len(starts_r1) == 1 + assert starts_r1[0].payload["toolCallId"] == "call_r1" + + converter.reset_for_next_step() + + tc_round2 = _make_tool_call_delta( + index=0, call_id="call_r2", name="search", arguments='{"q":"world"}' + ) + events_r2 = converter.convert_chunk(_make_chunk(tool_calls=[tc_round2])) + starts_r2 = [e for e in events_r2 if e.type == AiSdkEventType.TOOL_INPUT_START] + assert len(starts_r2) == 1, ( + "tool-input-start must be re-emitted for index 0 after reset" + ) + assert starts_r2[0].payload["toolCallId"] == "call_r2" + + def test_text_ends_before_tool_calls(self): + converter = AiSdkStreamConverter() + events1 = converter.convert_chunk(_make_chunk(content="Hello")) + 
text_start = next(e for e in events1 if e.type == AiSdkEventType.TEXT_START) + text_id_1 = text_start.payload["id"] + + tc_delta = _make_tool_call_delta( + index=0, call_id="call_1", name="add", arguments='{"a":1}' + ) + events2 = converter.convert_chunk(_make_chunk(tool_calls=[tc_delta])) + types2 = [e.type for e in events2] + assert AiSdkEventType.TEXT_END in types2 + text_end_idx = types2.index(AiSdkEventType.TEXT_END) + tool_start_idx = types2.index(AiSdkEventType.TOOL_INPUT_START) + assert text_end_idx < tool_start_idx, ( + "text-end must come before tool-input-start" + ) + + text_end_event = next(e for e in events2 if e.type == AiSdkEventType.TEXT_END) + assert text_end_event.payload["id"] == text_id_1 + + def test_text_restarts_with_new_id_after_tool_calls(self): + converter = AiSdkStreamConverter() + events1 = converter.convert_chunk(_make_chunk(content="Hello")) + text_start_1 = next(e for e in events1 if e.type == AiSdkEventType.TEXT_START) + text_id_1 = text_start_1.payload["id"] + + tc_delta = _make_tool_call_delta( + index=0, call_id="call_1", name="add", arguments='{"a":1}' + ) + converter.convert_chunk(_make_chunk(tool_calls=[tc_delta])) + converter.reset_for_next_step() + + events3 = converter.convert_chunk(_make_chunk(content="Result")) + types3 = [e.type for e in events3] + assert AiSdkEventType.TEXT_START in types3 + text_start_2 = next(e for e in events3 if e.type == AiSdkEventType.TEXT_START) + text_id_2 = text_start_2.payload["id"] + assert text_id_1 != text_id_2, ( + "New text block after tool calls must have a different id" + ) + + def test_no_text_end_before_tool_calls_when_text_not_started(self): + converter = AiSdkStreamConverter() + tc_delta = _make_tool_call_delta( + index=0, call_id="call_1", name="add", arguments='{"a":1}' + ) + events = converter.convert_chunk(_make_chunk(tool_calls=[tc_delta])) + types = [e.type for e in events] + assert AiSdkEventType.TEXT_END not in types + + def 
test_tool_input_start_not_reemitted_without_reset(self): + """Without reset, a second tool call at index 0 must NOT re-emit tool-input-start.""" + converter = AiSdkStreamConverter() + + tc_round1 = _make_tool_call_delta( + index=0, call_id="call_r1", name="search", arguments='{"q":"hi"}' + ) + converter.convert_chunk(_make_chunk(tool_calls=[tc_round1])) + + tc_round2 = _make_tool_call_delta( + index=0, call_id="call_r2", name="search", arguments='{"q":"world"}' + ) + events_r2 = converter.convert_chunk(_make_chunk(tool_calls=[tc_round2])) + starts_r2 = [e for e in events_r2 if e.type == AiSdkEventType.TOOL_INPUT_START] + assert len(starts_r2) == 0, ( + "Without reset, started=True blocks duplicate tool-input-start" + ) + + def test_choice_with_none_delta_is_skipped(self): + chunk = ModelResponseStream( + id="test", + choices=[StreamingChoices(index=0, delta=None, finish_reason=None)], + ) + converter = AiSdkStreamConverter() + events = converter.convert_chunk(chunk) + assert events == [] + + def test_usage_data_extracted_from_empty_choices_chunk(self): + usage = type( + "Usage", + (), + { + "prompt_tokens": 10, + "completion_tokens": 20, + "total_tokens": 30, + }, + )() + chunk = ModelResponseStream(id="test", choices=[]) + chunk.usage = usage + + converter = AiSdkStreamConverter() + converter.convert_chunk(chunk) + assert converter._usage_data is usage + + def test_finalize_includes_usage_payload(self): + usage = type( + "Usage", + (), + { + "prompt_tokens": 15, + "completion_tokens": 25, + "total_tokens": 40, + }, + )() + converter = AiSdkStreamConverter() + converter._usage_data = usage + converter._finish_reason = "stop" + + events = converter.finalize() + finish = next(e for e in events if e.type == AiSdkEventType.FINISH) + meta = finish.payload["messageMetadata"] + assert meta["finishReason"] == "stop" + assert meta["usage"]["promptTokens"] == 15 + assert meta["usage"]["completionTokens"] == 25 + assert meta["usage"]["totalTokens"] == 40 + + def 
test_finalize_usage_without_total_tokens(self): + usage = type( + "Usage", + (), + { + "prompt_tokens": 5, + "completion_tokens": 10, + }, + )() + converter = AiSdkStreamConverter() + converter._usage_data = usage + + events = converter.finalize() + finish = next(e for e in events if e.type == AiSdkEventType.FINISH) + meta = finish.payload["messageMetadata"] + assert meta["usage"]["promptTokens"] == 5 + assert meta["usage"]["completionTokens"] == 10 + assert "totalTokens" not in meta["usage"] diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py b/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py index fd5451705..46fa88aa7 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py @@ -1,7 +1,7 @@ import json from pathlib import Path from typing import Dict -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest from litellm.types.utils import ModelResponse @@ -10,6 +10,7 @@ from kiln_ai.adapters.adapter_registry import adapter_for_task from kiln_ai.adapters.ml_model_list import built_in_models from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput, Usage +from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter from kiln_ai.adapters.ollama_tools import ollama_online from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers from kiln_ai.datamodel import PromptId @@ -53,7 +54,7 @@ def __init__(self, kiln_task: datamodel.Task, response: InputType | None): ) self.response = response - async def _run(self, input: str) -> tuple[RunOutput, Usage | None]: + async def _run(self, input: str, **kwargs) -> tuple[RunOutput, Usage | None]: return RunOutput(output=self.response, intermediate_outputs=None), None def adapter_name(self) -> str: @@ -347,9 +348,10 @@ async def test_all_built_in_models_structured_input_mocked(tmp_path): 
mock_config.groq_api_key = "mock_api_key" with ( - patch( - "litellm.acompletion", - side_effect=[mock_response], + patch.object( + LiteLlmAdapter, + "acompletion_checking_response", + new=AsyncMock(return_value=(mock_response, mock_response.choices[0])), ), patch("kiln_ai.utils.config.Config.shared", return_value=mock_config), ): @@ -402,9 +404,15 @@ async def test_structured_input_cot_prompt_builder_mocked(tmp_path): mock_config.groq_api_key = "mock_api_key" with ( - patch( - "litellm.acompletion", - side_effect=[mock_response_1, mock_response_2], + patch.object( + LiteLlmAdapter, + "acompletion_checking_response", + new=AsyncMock( + side_effect=[ + (mock_response_1, mock_response_1.choices[0]), + (mock_response_2, mock_response_2.choices[0]), + ] + ), ), patch("kiln_ai.utils.config.Config.shared", return_value=mock_config), ): diff --git a/libs/core/kiln_ai/adapters/test_prompt_adaptors.py b/libs/core/kiln_ai/adapters/test_prompt_adaptors.py index 23dd15b0d..187de1ff2 100644 --- a/libs/core/kiln_ai/adapters/test_prompt_adaptors.py +++ b/libs/core/kiln_ai/adapters/test_prompt_adaptors.py @@ -1,9 +1,9 @@ import os from pathlib import Path -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import pytest -from litellm.utils import ModelResponse +from litellm.types.utils import ModelResponse import kiln_ai.datamodel as datamodel from kiln_ai.adapters.adapter_registry import adapter_for_task @@ -113,13 +113,15 @@ async def test_amazon_bedrock(tmp_path): async def test_mock_returning_run(tmp_path): task = build_test_task(tmp_path) - with patch("litellm.acompletion") as mock_acompletion: - # Configure the mock to return a properly structured response - mock_acompletion.return_value = ModelResponse( - model="custom_model", - choices=[{"message": {"content": "mock response"}}], - ) - + mock_response = ModelResponse( + model="custom_model", + choices=[{"message": {"content": "mock response"}}], + ) + with patch.object( + LiteLlmAdapter, + 
"acompletion_checking_response", + new=AsyncMock(return_value=(mock_response, mock_response.choices[0])), + ): run_config = KilnAgentRunConfigProperties( model_name="custom_model", model_provider_name=ModelProviderName.ollama, diff --git a/libs/core/kiln_ai/adapters/test_prompt_builders.py b/libs/core/kiln_ai/adapters/test_prompt_builders.py index b7dca4463..f488e68b0 100644 --- a/libs/core/kiln_ai/adapters/test_prompt_builders.py +++ b/libs/core/kiln_ai/adapters/test_prompt_builders.py @@ -60,7 +60,7 @@ def test_simple_prompt_builder(tmp_path): class MockAdapter(BaseAdapter): - async def _run(self, input: InputType) -> tuple[RunOutput, Usage | None]: + async def _run(self, input: InputType, **kwargs) -> tuple[RunOutput, Usage | None]: return RunOutput(output="mock response", intermediate_outputs=None), None def adapter_name(self) -> str: diff --git a/libs/core/kiln_ai/datamodel/basemodel.py b/libs/core/kiln_ai/datamodel/basemodel.py index c29f11f7e..90e981b78 100644 --- a/libs/core/kiln_ai/datamodel/basemodel.py +++ b/libs/core/kiln_ai/datamodel/basemodel.py @@ -545,17 +545,49 @@ def relationship_name(cls) -> str: def parent_type(cls) -> Type[KilnBaseModel]: raise NotImplementedError("Parent type must be implemented") - @model_validator(mode="after") - def check_parent_type(self) -> Self: + @classmethod + def _parent_types(cls) -> List[Type["KilnBaseModel"]] | None: + """Return accepted parent types. This must be implemented by the subclass if + the model can have multiple parent types. + + Return None (default) to use the single parent_type() check. + Override and return a list of parent types for models that can be nested + under more than one parent type (e.g. TaskRun can be nested under Task or TaskRun). 
+ """ + return None + + def _check_parent_type( + self, + expected_parent_types: List[Type[KilnBaseModel]] | None = None, + ) -> Self: cached_parent = self.cached_parent() - if cached_parent is not None: + if cached_parent is None: + return self + + # some models support having multiple parent types, so we allow overriding the expected parent + if expected_parent_types is not None: + if not any( + isinstance(cached_parent, expected_parent_type) + for expected_parent_type in expected_parent_types + ): + raise ValueError( + f"Parent must be one of {expected_parent_types}, but was {type(cached_parent)}" + ) + else: + # default case where we expect a single parent type to be valid expected_parent_type = self.__class__.parent_type() if not isinstance(cached_parent, expected_parent_type): raise ValueError( f"Parent must be of type {expected_parent_type}, but was {type(cached_parent)}" ) + return self + @model_validator(mode="after") + def check_parent_type(self) -> Self: + """Default validation for parent type. Can be overridden by subclasses - for example if the parent is polymorphic.""" + return self._check_parent_type() + def build_child_dirname(self) -> Path: # Default implementation for readable folder names. # {id} - {name}/{type}.kiln @@ -602,10 +634,39 @@ def iterate_children_paths_of_parent_path(cls: Type[PT], parent_path: Path | Non else: parent_folder = parent_path - parent = cls.parent_type().load_from_file(parent_path) - if parent is None: + if not parent_path.exists(): raise ValueError("Parent must be set to load children") + # Validate the parent file's declared type so we fail fast when the caller + # passes a wrong path. For polymorphic children (e.g. 
TaskRun) the + # subclass overrides _accepted_parent_types() to broaden the check to all + # accepted parent types + parent_types_override = cls._parent_types() + if parent_types_override is None: + # Default: single expected parent type — original behaviour + parent = cls.parent_type().load_from_file(parent_path) + if parent is None: + raise ValueError("Parent must be set to load children") + else: + # Polymorphic parent: read only the model_type field to avoid a full load. + with open(parent_path, "r", encoding="utf-8") as fh: + actual_parent_type_name = json.loads(fh.read()).get("model_type", "") + parent_type_names = {t.type_name() for t in parent_types_override} + if actual_parent_type_name not in parent_type_names: + raise ValueError( + f"Parent model_type '{actual_parent_type_name}' is not one of " + f"{parent_type_names}" + ) + + parent_type = next( + t + for t in parent_types_override + if t.type_name() == actual_parent_type_name + ) + parent = parent_type.load_from_file(parent_path) + if parent is None: + raise ValueError("Parent must be set to load children") + # Ignore type error: this is abstract base class, but children must implement relationship_name relationship_folder = parent_folder / Path(cls.relationship_name()) # type: ignore diff --git a/libs/core/kiln_ai/datamodel/task.py b/libs/core/kiln_ai/datamodel/task.py index c223164cb..d5ace6449 100644 --- a/libs/core/kiln_ai/datamodel/task.py +++ b/libs/core/kiln_ai/datamodel/task.py @@ -203,3 +203,21 @@ def parent_project(self) -> Union["Project", None]: if self.parent is None or self.parent.__class__.__name__ != "Project": return None return self.parent # type: ignore + + def find_task_run_by_id_dfs( + self, task_run_id: str, readonly: bool = False + ) -> TaskRun | None: + """ + Find a task run by id in the entire task run tree. This is an expensive DFS + traversal of the file system so do not use too willy nilly. 
+ + If you already know the root task run, you can use the same method on + the root TaskRun instead - that will save a bunch of subtree traversals. + """ + stack: List[TaskRun] = list(self.runs(readonly=readonly)) + while stack: + run = stack.pop() + if run.id == task_run_id: + return run + stack.extend(run.runs(readonly=readonly)) + return None diff --git a/libs/core/kiln_ai/datamodel/task_run.py b/libs/core/kiln_ai/datamodel/task_run.py index bef7ae700..40dac39e3 100644 --- a/libs/core/kiln_ai/datamodel/task_run.py +++ b/libs/core/kiln_ai/datamodel/task_run.py @@ -1,10 +1,14 @@ import json -from typing import TYPE_CHECKING, Dict, List, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union from pydantic import BaseModel, Field, ValidationInfo, model_validator from typing_extensions import Self -from kiln_ai.datamodel.basemodel import KilnParentedModel +from kiln_ai.datamodel.basemodel import ( + KilnBaseModel, + KilnParentedModel, + KilnParentModel, +) from kiln_ai.datamodel.json_schema import validate_schema_with_value_error from kiln_ai.datamodel.strict_mode import strict_mode from kiln_ai.datamodel.task_output import DataSource, TaskOutput @@ -73,12 +77,17 @@ def _add_optional_float(a: float | None, b: float | None) -> float | None: ) -class TaskRun(KilnParentedModel): +class TaskRun(KilnParentedModel, KilnParentModel, parent_of={}): """ Represents a single execution of a Task. Contains the input used, its source, the output produced, and optional repair information if the output needed correction. + + Can be nested under another TaskRun; nested runs are stored as child runs + in a "runs" subfolder (same relationship name as Task's runs). + + Accepts both Task and TaskRun as parents (polymorphic). 
""" input: str = Field( @@ -132,9 +141,108 @@ def has_thinking_training_data(self) -> bool: # Workaround to return typed parent without importing Task def parent_task(self) -> Union["Task", None]: - if self.parent is None or self.parent.__class__.__name__ != "Task": + """The Task that this Run is in. Note the TaskRun may be nested in which case we walk back up the tree all the way to the root.""" + # lazy import to avoid circular dependency + from kiln_ai.datamodel.task import Task + + current: TaskRun = self + while True: + # should never really happen, except maybe in tests + parent = current.parent + if parent is None: + return None + + # this task run is the root task run + # so we just return its parent (a Task) + if isinstance(parent, Task): + return parent + + if not isinstance(parent, TaskRun): + # the parent is not a TaskRun, but also not a Task, so it is not + # a real parent + return None + + # the parent is a TaskRun, so we just walk up the tree until we find a Task + current = parent + + def parent_run(self) -> "TaskRun | None": + """The TaskRun that contains this run, if this run is nested; otherwise None.""" + parent = self.parent + if parent is None or not isinstance(parent, TaskRun): return None - return self.parent # type: ignore + return parent + + @classmethod + def _parent_types(cls) -> List[Type["KilnBaseModel"]]: + # lazy import to avoid circular dependency + from kiln_ai.datamodel.task import Task + + return [Task, TaskRun] + + def runs(self, readonly: bool = False) -> list["TaskRun"]: + """The list of child task runs.""" + return super().runs(readonly=readonly) # type: ignore + + def is_root_task_run(self) -> bool: + """Is this the root task run? 
(not nested under another task run)""" + # lazy import to avoid circular dependency + from kiln_ai.datamodel.task import Task + + return self.parent is None or isinstance(self.parent, Task) + + def find_task_run_by_id_dfs( + self, task_run_id: str, readonly: bool = False + ) -> "TaskRun | None": + """ + Find a task run by id in the entire task run tree. This is an expensive DFS + traversal of the file system so do not use too willy nilly. + """ + stack: List[TaskRun] = list(self.runs(readonly=readonly)) + while stack: + run = stack.pop() + if run.id == task_run_id: + return run + stack.extend(run.runs(readonly=readonly)) + return None + + def load_parent(self) -> Optional[KilnBaseModel]: + """Load the parent of this task run - this is an override of the default parent loading logic to support nested task runs.""" + cached = self.cached_parent() + if cached is not None: + return cached + if self.path is None: + return None + parent_dir = self.path.parent.parent.parent + task_run_path = parent_dir / TaskRun.base_filename() + if task_run_path.exists() and task_run_path != self.path: + try: + loaded_parent_run = TaskRun.load_from_file(task_run_path) + super().__setattr__("parent", loaded_parent_run) + return loaded_parent_run + except ValueError as e: + raise ValueError( + f"Failed to load parent TaskRun from {task_run_path}. " + f"This indicates a malformed nested task run. Error: {e}" + ) from e + + from kiln_ai.datamodel.task import Task + + task_path = parent_dir / Task.base_filename() + if task_path.exists(): + loaded_parent_task = Task.load_from_file(task_path) + super().__setattr__("parent", loaded_parent_task) + return loaded_parent_task + + return None + + @model_validator(mode="after") + def check_parent_type(self) -> Self: + """Check that the parent is a Task or TaskRun. 
This overrides the default parent type check + that only supports a single parent type.""" + # need to import here to avoid circular imports + from kiln_ai.datamodel.task import Task + + return self._check_parent_type([Task, TaskRun]) @model_validator(mode="after") def validate_input_format(self, info: ValidationInfo) -> Self: @@ -258,3 +366,10 @@ def validate_tags(self) -> Self: raise ValueError("Tags cannot contain spaces. Try underscores.") return self + + +# cannot do this in the class definition due to circular reference between TaskRun and itself: +# wire up TaskRun as its own child type so .runs() returns TaskRun instances +# this makes TaskRun polymorphic - can be parented under Task or another TaskRun +TaskRun._parent_of["runs"] = TaskRun +TaskRun._create_child_method("runs", TaskRun) diff --git a/libs/core/kiln_ai/datamodel/test_basemodel.py b/libs/core/kiln_ai/datamodel/test_basemodel.py index 92fadb7fc..be4bc8c45 100644 --- a/libs/core/kiln_ai/datamodel/test_basemodel.py +++ b/libs/core/kiln_ai/datamodel/test_basemodel.py @@ -11,7 +11,7 @@ from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter from kiln_ai.adapters.run_output import RunOutput -from kiln_ai.datamodel import Task, TaskRun +from kiln_ai.datamodel import Project, Task, TaskOutput, TaskRun from kiln_ai.datamodel.basemodel import ( MAX_FILENAME_LENGTH, KilnBaseModel, @@ -862,7 +862,7 @@ def individual_lookups(): class MockAdapter(BaseAdapter): """Implementation of BaseAdapter for testing""" - async def _run(self, input): + async def _run(self, input, **kwargs): return RunOutput(output="test output", intermediate_outputs=None), None def adapter_name(self) -> str: @@ -1166,3 +1166,200 @@ def test_readonly_cache_integration(tmp_model_cache, tmp_path): cached_readonly = ReadonlyTestModel.load_from_file(test_file, readonly=True) assert cached_readonly._readonly is True assert cached_readonly.name == "cached_model" + + +# 
============================================================================ +# Tests for _parent_types() and parent type validation in _load_parent_and_validate_children +# ============================================================================ + + +def test_default_parent_types_returns_none(): + """Default _parent_types() returns None for models with single parent type.""" + assert DefaultParentedModel._parent_types() is None + assert NamedParentedModel._parent_types() is None + + +def test_taskrun_parent_types_returns_task_and_taskrun(): + """TaskRun._parent_types() returns [Task, TaskRun] for polymorphic parent support.""" + + parent_types = TaskRun._parent_types() + assert parent_types is not None + assert len(parent_types) == 2 + + parent_type_names = {t.type_name() for t in parent_types} + assert "task" in parent_type_names + assert "task_run" in parent_type_names + + +def test_invalid_parent_with_single_parent_type(tmp_path): + """Loading children fails when parent path points to wrong model type (single parent case).""" + # Create a project (wrong parent type) + project_path = tmp_path / "project.kiln" + project = Project(name="Test Project", path=project_path) + project.save_to_file() + + # Try to load DefaultParentedModel children from a Project path + # DefaultParentedModel expects BaseParentExample as parent, not Project + # The error occurs when trying to load the parent as the wrong type + with pytest.raises( + ValueError, match="Cannot load from file because the model type is incorrect" + ): + list(DefaultParentedModel.iterate_children_paths_of_parent_path(project_path)) + + +def test_invalid_parent_with_multiple_parent_types(tmp_path): + """Loading children fails when parent's model_type is not in accepted polymorphic types.""" + # Create a project (not an accepted parent type for TaskRun) + project_path = tmp_path / "project.kiln" + project = Project(name="Test Project", path=project_path) + project.save_to_file() + + # Try to load TaskRun 
children from a Project path + # TaskRun accepts Task and TaskRun, not Project + with pytest.raises(ValueError, match="Parent model_type 'project' is not one of"): + list(TaskRun.iterate_children_paths_of_parent_path(project_path)) + + +def test_valid_parent_type_single_parent(tmp_path): + """Successfully loads children when parent type matches single expected type.""" + parent = BaseParentExample(path=tmp_path / BaseParentExample.base_filename()) + parent.save_to_file() + + child = DefaultParentedModel(parent=parent, name="Test Child") + child.save_to_file() + + # Load children - should succeed since parent is correct type + children = list( + DefaultParentedModel.iterate_children_paths_of_parent_path(parent.path) + ) + assert len(children) == 1 + assert children[0] == child.path + + +def test_valid_parent_type_polymorphic_taskrun_as_parent(tmp_path): + """TaskRun can be loaded with TaskRun as parent (polymorphic case).""" + output = TaskOutput(output="test output") + + # Create a task as the ultimate parent + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + # Create a parent TaskRun + parent_run = TaskRun(input="parent input", output=output, parent=task) + parent_run.save_to_file() + + # Create a nested TaskRun + nested_run = TaskRun(input="nested input", output=output, parent=parent_run) + nested_run.save_to_file() + + # Load children of parent_run - should succeed + children = list(TaskRun.iterate_children_paths_of_parent_path(parent_run.path)) + assert len(children) == 1 + assert children[0] == nested_run.path + + +def test_valid_parent_type_polymorphic_task_as_parent(tmp_path): + """TaskRun can be loaded with Task as parent (polymorphic case).""" + output = TaskOutput(output="test output") + + # Create a task + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + # Create a TaskRun under the task + task_run 
= TaskRun(input="test input", output=output, parent=task) + task_run.save_to_file() + + # Load children of task - should succeed + children = list(TaskRun.iterate_children_paths_of_parent_path(task.path)) + assert len(children) == 1 + assert children[0] == task_run.path + + +def test_invalid_parent_type_name_mismatch_polymorphic(tmp_path): + """Polymorphic validation fails when parent file has wrong model_type.""" + # Create a file with wrong model_type + wrong_parent_path = tmp_path / "wrong_parent.kiln" + wrong_data = { + "v": 1, + "name": "Wrong Parent", + "model_type": "project", # Wrong type - not in accepted types + } + with open(wrong_parent_path, "w") as f: + json.dump(wrong_data, f) + + # Try to load TaskRun children from wrong parent + # TaskRun accepts Task and TaskRun, not Project + with pytest.raises(ValueError, match="Parent model_type 'project' is not one of"): + list(TaskRun.iterate_children_paths_of_parent_path(wrong_parent_path)) + + +def test_parent_loading_single_parent_type_fails_on_corrupt_file(tmp_path): + """Single parent type loading fails when parent file is corrupt/invalid.""" + # Create a corrupt parent file + corrupt_parent_path = tmp_path / "corrupt.kiln" + with open(corrupt_parent_path, "w") as f: + f.write("not valid json {{{") + + # The load_from_file call within iterate_children_paths_of_parent_path + # will fail when trying to parse the corrupt JSON + with pytest.raises(ValueError, match="Expecting value"): + list( + DefaultParentedModel.iterate_children_paths_of_parent_path( + corrupt_parent_path + ) + ) + + +def test_parent_loading_polymorphic_fails_on_corrupt_file(tmp_path): + """Polymorphic parent loading fails when parent file is corrupt/invalid.""" + # Create a corrupt parent file + corrupt_parent_path = tmp_path / "corrupt.kiln" + with open(corrupt_parent_path, "w") as f: + f.write("not valid json {{{") + + # The polymorphic path first reads model_type, which will fail on corrupt JSON + with 
pytest.raises(json.JSONDecodeError): + list(TaskRun.iterate_children_paths_of_parent_path(corrupt_parent_path)) + + +def test_parent_loading_single_parent_nonexistent_file(tmp_path): + """Single parent type loading fails when parent file doesn't exist.""" + nonexistent_path = tmp_path / "nonexistent.kiln" + + with pytest.raises(ValueError, match="Parent must be set to load children"): + list( + DefaultParentedModel.iterate_children_paths_of_parent_path(nonexistent_path) + ) + + +def test_parent_loading_polymorphic_nonexistent_file(tmp_path): + """Polymorphic parent loading fails when parent file doesn't exist.""" + nonexistent_path = tmp_path / "nonexistent.kiln" + + with pytest.raises(ValueError, match="Parent must be set to load children"): + list(TaskRun.iterate_children_paths_of_parent_path(nonexistent_path)) + + +def test_all_children_of_parent_path_single_parent_type(tmp_path): + """all_children_of_parent_path works correctly for single parent type models.""" + parent = BaseParentExample(path=tmp_path / "parent.kiln") + parent.save_to_file() + + child1 = DefaultParentedModel(parent=parent, name="Child1") + child2 = DefaultParentedModel(parent=parent, name="Child2") + child1.save_to_file() + child2.save_to_file() + + children = DefaultParentedModel.all_children_of_parent_path(parent.path) + assert len(children) == 2 + names = {child.name for child in children} + assert names == {"Child1", "Child2"} diff --git a/libs/core/kiln_ai/datamodel/test_models.py b/libs/core/kiln_ai/datamodel/test_models.py index 68d0f611e..cc11ba9a6 100644 --- a/libs/core/kiln_ai/datamodel/test_models.py +++ b/libs/core/kiln_ai/datamodel/test_models.py @@ -743,3 +743,464 @@ def test_generate_model_id(): # check it is a valid name - as we typically use model ids in filenames on FS validator = name_validator(min_length=1, max_length=12) validator(model_id) + + +# project and task fixture +@pytest.fixture +def task(tmp_path): + project_path = tmp_path / "project.kiln" + project = 
Project(name="P", path=project_path) + project.save_to_file() + task = Task(name="T", instruction="Do it", parent=project) + task.save_to_file() + return task + + +def test_nested_task_run_folder_structure(task: Task): + output = TaskOutput(output="out") + + parent_run = TaskRun(input="in", output=output, parent=task) + parent_run.save_to_file() + + nested_run = TaskRun(input="nested in", output=output, parent=parent_run) + nested_run.save_to_file() + + assert task.path is not None + assert parent_run.path is not None + assert nested_run.path is not None + task_dir = task.path.parent + runs_dir = task_dir / "runs" + parent_run_dir = runs_dir / parent_run.build_child_dirname() + nested_runs_dir = parent_run_dir / "runs" + + assert runs_dir.is_dir() + assert (parent_run_dir / "task_run.kiln").is_file() + assert nested_runs_dir.is_dir() + assert nested_run.path.is_file() + assert nested_run.path.parent.parent.parent == parent_run_dir + assert nested_run.path.name == TaskRun.base_filename() + + assert parent_run.parent_task() == task + assert parent_run.parent_run() is None + assert nested_run.parent_task() == task + assert nested_run.parent_run() == parent_run + assert len(parent_run.runs()) == 1 + assert parent_run.runs()[0].id == nested_run.id + + +def test_nested_task_runs_multiple_levels(task: Task): + output = TaskOutput(output="out") + + run1 = TaskRun(input="in1", output=output, parent=task) + run1.save_to_file() + run2 = TaskRun(input="in2", output=output, parent=run1) + run2.save_to_file() + run3 = TaskRun(input="in3", output=output, parent=run2) + run3.save_to_file() + + assert run1.parent_run() is None + assert run1.parent_task() == task + assert len(run1.runs()) == 1 + assert run1.runs()[0].id == run2.id + + assert run2.parent_run() == run1 + assert run2.parent_task() == task + assert len(run2.runs()) == 1 + assert run2.runs()[0].id == run3.id + + assert run3.parent_run() == run2 + assert run3.parent_task() == task + assert len(run3.runs()) == 0 + + +def 
test_parent_task_deeply_nested_task_run(task: Task): + output = TaskOutput(output="out") + run1 = TaskRun(input="in1", output=output, parent=task) + run1.save_to_file() + run2 = TaskRun(input="in2", output=output, parent=run1) + run2.save_to_file() + run3 = TaskRun(input="in3", output=output, parent=run2) + run3.save_to_file() + + assert run3.parent_task() == task + + +def test_find_nested_task_run_by_id_given_parent_run(task: Task): + assert task.path is not None + + output = TaskOutput(output="out") + parent_run = TaskRun(input="in", output=output, parent=task) + parent_run.save_to_file() + nested_run = TaskRun(input="nested in", output=output, parent=parent_run) + nested_run.save_to_file() + target_id = nested_run.id + + loaded_task = Task.load_from_file(task.path) + loaded_parent = next(r for r in loaded_task.runs() if r.id == parent_run.id) + found = next(r for r in loaded_parent.runs() if r.id == target_id) + assert found is not None + assert found.id == target_id + assert found.input == "nested in" + + +def test_find_root_task_run_by_id_given_task(task: Task): + output = TaskOutput(output="out") + root_run = TaskRun(input="in", output=output, parent=task) + root_run.save_to_file() + target_id = root_run.id + + assert task.path is not None + loaded_task = Task.load_from_file(task.path) + found = next(r for r in loaded_task.runs() if r.id == target_id) + assert found is not None + assert found.id == target_id + assert found.input == "in" + + +def test_find_task_run_by_id_dfs_finds_deeply_nested_run(task: Task): + """Task.find_task_run_by_id_dfs finds a run nested several levels (iterative stack-based DFS).""" + assert task.path is not None + output = TaskOutput(output="out") + + run1 = TaskRun(input="in1", output=output, parent=task) + run1.save_to_file() + + run2 = TaskRun(input="in2", output=output, parent=run1) + run2.save_to_file() + + run3 = TaskRun(input="in3", output=output, parent=run2) + run3.save_to_file() + + assert run3.id is not None + + 
loaded_task = Task.load_from_file(task.path) + found = loaded_task.find_task_run_by_id_dfs(run3.id) + assert found is not None + assert found.id == run3.id + assert found.input == "in3" + + +def test_find_task_run_by_id_dfs_finds_root_run(task: Task): + """Task.find_task_run_by_id_dfs finds the root run when it is the only run.""" + assert task.path is not None + output = TaskOutput(output="out") + root_run = TaskRun(input="root in", output=output, parent=task) + root_run.save_to_file() + assert root_run.id is not None + + loaded_task = Task.load_from_file(task.path) + found = loaded_task.find_task_run_by_id_dfs(root_run.id) + assert found is not None + assert found.id == root_run.id + assert found.input == "root in" + + +def test_find_task_run_by_id_dfs_returns_none_when_not_found(task: Task): + """Task.find_task_run_by_id_dfs returns None when no run has the given id.""" + assert task.path is not None + output = TaskOutput(output="out") + TaskRun(input="in", output=output, parent=task).save_to_file() + + loaded_task = Task.load_from_file(task.path) + found = loaded_task.find_task_run_by_id_dfs("nonexistent-id") + assert found is None + + +def test_task_run_find_task_run_by_id_dfs_finds_deeply_nested_run(task: Task): + """TaskRun.find_task_run_by_id_dfs finds a run nested several levels in its subtree.""" + assert task.path is not None + output = TaskOutput(output="out") + + run1 = TaskRun(input="in1", output=output, parent=task) + run1.save_to_file() + + run2 = TaskRun(input="in2", output=output, parent=run1) + run2.save_to_file() + + run3 = TaskRun(input="in3", output=output, parent=run2) + run3.save_to_file() + + assert run3.id is not None + + assert run1.path is not None + loaded_run1 = TaskRun.load_from_file(run1.path) + found = loaded_run1.find_task_run_by_id_dfs(run3.id) + assert found is not None + assert found.id == run3.id + assert found.input == "in3" + + +def test_task_run_find_task_run_by_id_dfs_finds_direct_child(task: Task): + 
"""TaskRun.find_task_run_by_id_dfs finds a direct child run.""" + assert task.path is not None + output = TaskOutput(output="out") + + run1 = TaskRun(input="in1", output=output, parent=task) + run1.save_to_file() + + run2 = TaskRun(input="in2", output=output, parent=run1) + run2.save_to_file() + + assert run2.id is not None + + loaded_run1 = TaskRun.load_from_file(run1.path) + found = loaded_run1.find_task_run_by_id_dfs(run2.id) + assert found is not None + assert found.id == run2.id + assert found.input == "in2" + + +def test_task_run_find_task_run_by_id_dfs_returns_none_when_not_found(task: Task): + """TaskRun.find_task_run_by_id_dfs returns None when no run in its subtree has the given id.""" + assert task.path is not None + output = TaskOutput(output="out") + + run1 = TaskRun(input="in1", output=output, parent=task) + run1.save_to_file() + + TaskRun(input="in2", output=output, parent=run1).save_to_file() + + loaded_run1 = TaskRun.load_from_file(run1.path) + found = loaded_run1.find_task_run_by_id_dfs("nonexistent-id") + assert found is None + + +def test_is_root_task_run(task: Task): + output = TaskOutput(output="out") + root_run = TaskRun(input="in", output=output, parent=task) + root_run.save_to_file() + assert root_run.is_root_task_run() + nested_run = TaskRun(input="nested in", output=output, parent=root_run) + nested_run.save_to_file() + assert not nested_run.is_root_task_run() + + +def test_comprehensive_nested_task_run_hierarchy(tmp_path): + """Comprehensive integration test for polymorphic parent support on TaskRun. 
+ + Tests: + - Project -> Task -> TaskRun hierarchy + - Multiple levels of nested TaskRuns (Task -> TaskRun -> TaskRun -> TaskRun -> TaskRun) + - Sibling TaskRuns at various levels + - Retrieval of all runs from the hierarchy + - is_root_task_run() correctness at all levels + - parent_task() and parent_run() correctness at all levels + """ + project_path = tmp_path / "project.kiln" + project = Project(name="Test Project", path=project_path) + project.save_to_file() + task = Task(name="Test Task", instruction="Test instruction", parent=project) + task.save_to_file() + + output = TaskOutput(output="test output") + + # Level 1: Two sibling TaskRuns under Task + run1_l1 = TaskRun(input="level1_run1", output=output, parent=task) + run1_l1.save_to_file() + + run2_l1 = TaskRun(input="level1_run2", output=output, parent=task) + run2_l1.save_to_file() + + # Level 2: Nested TaskRuns under run1_l1 (with sibling) + run1_l2 = TaskRun(input="level2_run1", output=output, parent=run1_l1) + run1_l2.save_to_file() + + run2_l2 = TaskRun(input="level2_run2", output=output, parent=run1_l1) + run2_l2.save_to_file() + + # Level 3: Nested TaskRuns under run1_l2 (with sibling) + run1_l3 = TaskRun(input="level3_run1", output=output, parent=run1_l2) + run1_l3.save_to_file() + + run2_l3 = TaskRun(input="level3_run2", output=output, parent=run1_l2) + run2_l3.save_to_file() + + # Level 4: Deepest nested TaskRun under run1_l3 + run1_l4 = TaskRun(input="level4_run1", output=output, parent=run1_l3) + run1_l4.save_to_file() + + # Level 4 sibling of run1_l4 under run2_l3 + run2_l4 = TaskRun(input="level4_sibling", output=output, parent=run2_l3) + run2_l4.save_to_file() + + # Reload everything from disk to test persistence and retrieval + loaded_project = Project.load_from_file(project_path) + loaded_task = loaded_project.tasks()[0] + + # Verify Task has 2 root-level runs + root_runs = loaded_task.runs() + assert len(root_runs) == 2 + root_run_ids = {r.id for r in root_runs} + assert run1_l1.id in 
root_run_ids + assert run2_l1.id in root_run_ids + + # Verify run1_l1 hierarchy + loaded_run1_l1 = next(r for r in root_runs if r.id == run1_l1.id) + assert loaded_run1_l1.parent_task() == loaded_task + assert loaded_run1_l1.parent_run() is None + assert loaded_run1_l1.is_root_task_run() is True + + level2_runs = loaded_run1_l1.runs() + assert len(level2_runs) == 2 + level2_run_ids = {r.id for r in level2_runs} + assert run1_l2.id in level2_run_ids + assert run2_l2.id in level2_run_ids + + # Verify run1_l2 hierarchy + loaded_run1_l2 = next(r for r in level2_runs if r.id == run1_l2.id) + assert loaded_run1_l2.parent_task() == loaded_task + assert loaded_run1_l2.parent_run() == loaded_run1_l1 + assert loaded_run1_l2.is_root_task_run() is False + + level3_runs = loaded_run1_l2.runs() + assert len(level3_runs) == 2 + level3_run_ids = {r.id for r in level3_runs} + assert run1_l3.id in level3_run_ids + assert run2_l3.id in level3_run_ids + + # Verify run1_l3 hierarchy (level 3) + loaded_run1_l3 = next(r for r in level3_runs if r.id == run1_l3.id) + assert loaded_run1_l3.parent_task() == loaded_task + assert loaded_run1_l3.parent_run().id == loaded_run1_l2.id + assert loaded_run1_l3.is_root_task_run() is False + + level4_runs = loaded_run1_l3.runs() + assert len(level4_runs) == 1 + assert level4_runs[0].id == run1_l4.id + + # Verify run1_l4 (deepest run) + loaded_run1_l4 = level4_runs[0] + assert loaded_run1_l4.parent_task() == loaded_task + assert loaded_run1_l4.parent_run().id == loaded_run1_l3.id + assert loaded_run1_l4.is_root_task_run() is False + assert len(loaded_run1_l4.runs()) == 0 + + # Verify run2_l3's nested run (sibling at level 3 has child at level 4) + loaded_run2_l3 = next(r for r in level3_runs if r.id == run2_l3.id) + assert loaded_run2_l3.parent_task() == loaded_task + assert loaded_run2_l3.parent_run().id == loaded_run1_l2.id + assert loaded_run2_l3.is_root_task_run() is False + + level4_sibling_runs = loaded_run2_l3.runs() + assert 
len(level4_sibling_runs) == 1 + assert level4_sibling_runs[0].id == run2_l4.id + + # Verify run2_l4 (sibling at level 4) + loaded_run2_l4 = level4_sibling_runs[0] + assert loaded_run2_l4.parent_task() == loaded_task + assert loaded_run2_l4.parent_run().id == loaded_run2_l3.id + assert loaded_run2_l4.is_root_task_run() is False + + # Verify run2_l1 (sibling at level 1 has no children) + loaded_run2_l1 = next(r for r in root_runs if r.id == run2_l1.id) + assert loaded_run2_l1.parent_task() == loaded_task + assert loaded_run2_l1.parent_run() is None + assert loaded_run2_l1.is_root_task_run() is True + assert len(loaded_run2_l1.runs()) == 0 + + # Verify run2_l2 (sibling at level 2 has no children) + loaded_run2_l2 = next(r for r in level2_runs if r.id == run2_l2.id) + assert loaded_run2_l2.parent_task() == loaded_task + assert loaded_run2_l2.parent_run() == loaded_run1_l1 + assert loaded_run2_l2.is_root_task_run() is False + assert len(loaded_run2_l2.runs()) == 0 + + # Test finding runs by ID from Task (DFS) + found_run1_l4 = loaded_task.find_task_run_by_id_dfs(run1_l4.id) + assert found_run1_l4 is not None + assert found_run1_l4.id == run1_l4.id + assert found_run1_l4.input == "level4_run1" + + found_run2_l4 = loaded_task.find_task_run_by_id_dfs(run2_l4.id) + assert found_run2_l4 is not None + assert found_run2_l4.id == run2_l4.id + assert found_run2_l4.input == "level4_sibling" + + # Test finding runs by ID from a TaskRun (DFS) + found_from_run1_l1 = loaded_run1_l1.find_task_run_by_id_dfs(run1_l4.id) + assert found_from_run1_l1 is not None + assert found_from_run1_l1.id == run1_l4.id + + # Count total runs in the hierarchy (should be 8) + all_runs = collect_all_task_runs(loaded_task) + assert len(all_runs) == 8 + + # Verify all levels are represented correctly + is_root_runs = [r for r in all_runs if r.is_root_task_run()] + assert len(is_root_runs) == 2 # run1_l1 and run2_l1 + + # Verify all non-root runs have parent_run set correctly + non_root_runs = [r for r in 
all_runs if not r.is_root_task_run()] + assert len(non_root_runs) == 6 + for run in non_root_runs: + assert run.parent_run() is not None + + +def collect_all_task_runs(root): + """Helper to recursively collect all TaskRuns in a hierarchy.""" + runs = [] + + def collect(node): + if isinstance(node, TaskRun): + runs.append(node) + for child in node.runs(): + collect(child) + elif hasattr(node, "runs"): + for child in node.runs(): + collect(child) + + collect(root) + return runs + + +def test_task_run_wrong_parent_type_raises(tmp_path): + project = Project(name="proj", path=tmp_path / "project.kiln") + project.save_to_file() + + with pytest.raises(ValidationError, match="Parent must be one of"): + TaskRun( + input="bad parent", + output=TaskOutput( + output="x", + source=DataSource( + type=DataSourceType.human, properties={"created_by": "test"} + ), + ), + parent=project, + ) + + +def test_task_run_runs_on_disk(tmp_path): + project = Project(name="proj", path=tmp_path / "project.kiln") + project.save_to_file() + task = Task(name="t", instruction="i", parent=project) + task.save_to_file() + + parent_run = TaskRun( + input="parent", + output=TaskOutput( + output="parent out", + source=DataSource( + type=DataSourceType.human, properties={"created_by": "test"} + ), + ), + parent=task, + ) + parent_run.save_to_file() + + child_run = TaskRun( + input="child", + output=TaskOutput( + output="child out", + source=DataSource( + type=DataSourceType.human, properties={"created_by": "test"} + ), + ), + parent=parent_run, + ) + child_run.save_to_file() + + loaded = TaskRun.load_from_file(parent_run.path) + children = loaded.runs() + assert len(children) == 1 + assert children[0].id == child_run.id diff --git a/libs/core/kiln_ai/datamodel/test_task.py b/libs/core/kiln_ai/datamodel/test_task.py index 1621db603..f3921f9b8 100644 --- a/libs/core/kiln_ai/datamodel/test_task.py +++ b/libs/core/kiln_ai/datamodel/test_task.py @@ -1,3 +1,5 @@ +import json + import pytest from pydantic 
import ValidationError @@ -6,6 +8,7 @@ StructuredOutputMode, TaskOutputRatingType, ) +from kiln_ai.datamodel.project import Project from kiln_ai.datamodel.prompt_id import PromptGenerators from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties from kiln_ai.datamodel.spec import Spec @@ -15,7 +18,8 @@ ToxicityProperties, ) from kiln_ai.datamodel.task import Task, TaskRunConfig -from kiln_ai.datamodel.task_output import normalize_rating +from kiln_ai.datamodel.task_output import TaskOutput, normalize_rating +from kiln_ai.datamodel.task_run import TaskRun def test_runconfig_valid_creation(): @@ -457,3 +461,307 @@ def test_task_prompt_optimization_jobs_readonly(tmp_path): assert ( prompt_optimization_jobs_default[0].name == "Readonly Prompt Optimization Job" ) + + +def test_all_children_of_parent_path_polymorphic(tmp_path): + """all_children_of_parent_path works correctly for polymorphic parent models.""" + # Test with TaskRun and Task + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + output = TaskOutput(output="test output") + + # Create direct children of Task + run1 = TaskRun(input="input1", output=output, parent=task) + run2 = TaskRun(input="input2", output=output, parent=task) + run1.save_to_file() + run2.save_to_file() + + children = TaskRun.all_children_of_parent_path(task.path) + assert len(children) == 2 + inputs = {child.input for child in children} + assert inputs == {"input1", "input2"} + + +def test_taskrun_nested_validates_parent_type_on_load(tmp_path): + """Loading a TaskRun validates its parent type is Task or TaskRun.""" + output = TaskOutput(output="test output") + + # Create a task + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + # Create parent TaskRun + parent_run = TaskRun(input="parent input", output=output, parent=task) + parent_run.save_to_file() + + # Create nested 
TaskRun + nested_run = TaskRun(input="nested input", output=output, parent=parent_run) + nested_run.save_to_file() + + # Reload from disk - parent type should be validated + # When loading from disk, the parent attribute points to the ultimate parent (Task) + # Use load_parent() to get the direct parent (TaskRun) + loaded_run = TaskRun.load_from_file(nested_run.path) + assert loaded_run is not None + assert loaded_run.input == "nested input" + # parent_task() returns the ultimate parent task (different instance but same data) + loaded_parent_task = loaded_run.parent_task() + assert loaded_parent_task is not None + assert loaded_parent_task.name == "Test Task" + assert loaded_parent_task.instruction == "Test instruction" + # Use load_parent() to get the direct parent TaskRun + direct_parent = loaded_run.load_parent() + assert direct_parent is not None + assert direct_parent.id == parent_run.id + assert direct_parent.input == "parent input" + + +def test_taskrun_loads_from_task_path(tmp_path): + """TaskRun children can be loaded from Task path (valid polymorphic parent).""" + output = TaskOutput(output="test output") + + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + run = TaskRun(input="test input", output=output, parent=task) + run.save_to_file() + + # Load children from task path - should succeed + children = list(TaskRun.iterate_children_paths_of_parent_path(task.path)) + assert len(children) == 1 + assert children[0] == run.path + + +def test_taskrun_loads_from_taskrun_path(tmp_path): + """TaskRun children can be loaded from TaskRun path (valid polymorphic parent).""" + output = TaskOutput(output="test output") + + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + parent_run = TaskRun(input="parent input", output=output, parent=task) + parent_run.save_to_file() + + nested_run = TaskRun(input="nested 
input", output=output, parent=parent_run) + nested_run.save_to_file() + + # Load children from TaskRun path - should succeed + children = list(TaskRun.iterate_children_paths_of_parent_path(parent_run.path)) + assert len(children) == 1 + assert children[0] == nested_run.path + + +def test_taskrun_fails_to_load_from_project_path(tmp_path): + """TaskRun children cannot be loaded from Project path (invalid polymorphic parent).""" + project_path = tmp_path / "project.kiln" + project = Project(name="Test Project", path=project_path) + project.save_to_file() + + # Try to load TaskRun children from a Project path - should fail + with pytest.raises(ValueError, match="Parent model_type 'project' is not one of"): + list(TaskRun.iterate_children_paths_of_parent_path(project_path)) + + +def test_multiple_nested_levels_validates_each_level(tmp_path): + """Multi-level nesting validates parent type at each level.""" + output = TaskOutput(output="test output") + + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + run1 = TaskRun(input="input1", output=output, parent=task) + run1.save_to_file() + + run2 = TaskRun(input="input2", output=output, parent=run1) + run2.save_to_file() + + run3 = TaskRun(input="input3", output=output, parent=run2) + run3.save_to_file() + + # Load children at each level - all should succeed + task_children = TaskRun.all_children_of_parent_path(task.path) + assert len(task_children) == 1 + assert task_children[0].id == run1.id + + run1_children = TaskRun.all_children_of_parent_path(run1.path) + assert len(run1_children) == 1 + assert run1_children[0].id == run2.id + + run2_children = TaskRun.all_children_of_parent_path(run2.path) + assert len(run2_children) == 1 + assert run2_children[0].id == run3.id + + +def test_polymorphic_parent_type_validation_fast_fail(tmp_path): + """Polymorphic validation fails fast without loading entire parent model.""" + # Create a file that's 
syntactically valid JSON but semantically invalid + # for the parent type - the polymorphic path should only read model_type + invalid_parent_path = tmp_path / "invalid.kiln" + + # Write a file that would fail if we tried to fully load as Task/TaskRun + # but should be caught by the model_type check first + invalid_data = { + "model_type": "project", # Wrong type - not Task or TaskRun + "extra_field": "this would cause issues", + } + with open(invalid_parent_path, "w") as f: + json.dump(invalid_data, f) + + # Should fail on model_type check, not on full load + with pytest.raises(ValueError, match="Parent model_type 'project' is not one of"): + list(TaskRun.iterate_children_paths_of_parent_path(invalid_parent_path)) + + +def test_load_parent_raises_on_malformed_taskrun(tmp_path): + """load_parent raises ValueError with context when parent TaskRun is malformed.""" + output = TaskOutput(output="test output") + + # Create a task + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + # Create parent TaskRun directory and valid run + parent_run_dir = tmp_path / "runs" / "parent_run" + parent_run_dir.mkdir(parents=True) + parent_run = TaskRun( + input="parent input", + output=output, + path=parent_run_dir / TaskRun.base_filename(), + ) + parent_run.save_to_file() + + # Create nested TaskRun directory and run (before corrupting parent) + nested_run_dir = parent_run_dir / "runs" / "nested_run" + nested_run_dir.mkdir(parents=True) + nested_run = TaskRun( + input="nested input", + output=output, + path=nested_run_dir / TaskRun.base_filename(), + ) + nested_run.save_to_file() + + # Verify it loads correctly with valid parent + loaded_nested = TaskRun.load_from_file(nested_run.path) + loaded_parent = loaded_nested.load_parent() + assert loaded_parent is not None + assert loaded_parent.input == "parent input" + + # Now corrupt the parent TaskRun file + with open(parent_run.path, "w") as f: + 
json.dump({"model_type": "task_run", "input": 123}, f) # Invalid input type + + # Reload nested run and try to load parent - should raise with context + loaded_nested = TaskRun.load_from_file(nested_run.path) + with pytest.raises(ValueError) as exc_info: + loaded_nested.load_parent() + + error_msg = str(exc_info.value) + assert "Failed to load parent TaskRun" in error_msg + assert str(parent_run.path) in error_msg + assert "malformed nested task run" in error_msg + + +def test_load_parent_succeeds_for_valid_taskrun_parent(tmp_path): + """load_parent successfully loads a valid TaskRun parent.""" + output = TaskOutput(output="test output") + + # Create a task + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + # Create parent TaskRun + parent_run = TaskRun(input="parent input", output=output, parent=task) + parent_run.save_to_file() + + # Create nested TaskRun + runs_dir = parent_run.path.parent / "runs" + runs_dir.mkdir(exist_ok=True) + nested_run_dir = runs_dir / "nested_run" + nested_run_dir.mkdir(exist_ok=True) + nested_run = TaskRun( + input="nested input", output=output, path=nested_run_dir / "task_run.kiln" + ) + nested_run.save_to_file() + + # Reload and load parent - should succeed + loaded_nested = TaskRun.load_from_file(nested_run.path) + loaded_parent = loaded_nested.load_parent() + + assert loaded_parent is not None + assert loaded_parent.id == parent_run.id + assert loaded_parent.input == "parent input" + assert isinstance(loaded_parent, TaskRun) + + +def test_is_root_task_run_true_when_parent_is_task(tmp_path): + """is_root_task_run returns True when parent is a Task.""" + output = TaskOutput(output="test output") + + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + run = TaskRun(input="test input", output=output, parent=task) + run.save_to_file() + + loaded_run = TaskRun.load_from_file(run.path) + 
assert loaded_run.is_root_task_run() is True + + +def test_is_root_task_run_false_when_parent_is_taskrun(tmp_path): + """is_root_task_run returns False when parent is another TaskRun.""" + output = TaskOutput(output="test output") + + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + parent_run = TaskRun(input="parent input", output=output, parent=task) + parent_run.save_to_file() + + runs_dir = parent_run.path.parent / "runs" + runs_dir.mkdir(exist_ok=True) + nested_run_dir = runs_dir / "nested_run" + nested_run_dir.mkdir(exist_ok=True) + nested_run = TaskRun( + input="nested input", output=output, path=nested_run_dir / "task_run.kiln" + ) + nested_run.save_to_file() + + loaded_nested = TaskRun.load_from_file(nested_run.path) + assert loaded_nested.is_root_task_run() is False diff --git a/libs/server/kiln_server/test_run_api.py b/libs/server/kiln_server/test_run_api.py index 14d37371c..974f1d433 100644 --- a/libs/server/kiln_server/test_run_api.py +++ b/libs/server/kiln_server/test_run_api.py @@ -1,5 +1,7 @@ import logging +import os import time +from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -17,6 +19,8 @@ TaskOutputRatingType, TaskRun, ) +from kiln_ai.datamodel.tool_id import KilnBuiltInToolId +from kiln_ai.utils.config import Config from kiln_server.custom_errors import connect_custom_errors from kiln_server.run_api import ( @@ -95,6 +99,10 @@ def task_run_setup(tmp_path): }, ), ), + trace=[ + {"role": "user", "content": "Test input"}, + {"role": "assistant", "content": "Test output"}, + ], ) task_run.save_to_file() @@ -1663,3 +1671,186 @@ async def test_benchmark_tag_runs(client, task_run_setup, run_count): logger.info( f"Performance: {runs_per_second:.1f} runs/second, {avg_time_per_run * 1000:.2f}ms per run" ) + + +def _adapter_sanity_check_output_path() -> Path: + return Path(__file__).resolve().parent / "adapter_sanity_check.txt" + 
+ +def _append_to_sanity_check(content: str, output_path: Path) -> None: + with open(output_path, "a", encoding="utf-8") as f: + f.write(content) + f.write("\n") + + +@pytest.fixture +def adapter_sanity_check_setup(tmp_path): + """Setup for paid adapter sanity check tests - real project/task, no adapter mocking.""" + # if project at the path does not exist, create it, otherwise reuse + project_path = ( + tmp_path + / "adapter_sanity_project" + / "project.kiln" + ) + if not project_path.exists(): + project_path.parent.mkdir(parents=True, exist_ok=True) + + project = Project(name="Adapter Sanity Project", path=str(project_path)) + project.save_to_file() + + task = Task( + name="Adapter Sanity Task", + instruction="You are a helpful assistant. Respond concisely.", + description="Task for adapter sanity checking", + parent=project, + ) + task.save_to_file() + + else: + project = Project.load_from_file(project_path) + task = next( + ( + t + for t in project.tasks(readonly=True) + if t.name == "Adapter Sanity Task" + ), + None, + ) + if task is None: + raise ValueError("Task not found") + + config = Config.shared() + original_projects = list(config.projects) if config.projects else [] + config._settings["projects"] = [*original_projects, str(project.path)] + + yield {"project": project, "task": task} + + config._settings["projects"] = original_projects + + +@pytest.fixture +def adapter_sanity_check_math_tools_setup(tmp_path): + """Setup for paid math tools test - task with instructions to use add, multiply, etc.""" + project_path = tmp_path / "adapter_sanity_math_project" / "project.kiln" + project_path.parent.mkdir() + + project = Project(name="Adapter Sanity Math Project", path=str(project_path)) + project.save_to_file() + + task = Task( + name="Math Tools Task", + instruction="You are an assistant that performs math using the provided tools. You MUST use the add, subtract, multiply, and divide tools for any arithmetic.
For example, for 2+2 you must call the add tool with a=2 and b=2. End your response with the final answer in square brackets, e.g. [4].", + description="Task for testing Kiln built-in math tools", + parent=project, + ) + task.save_to_file() + + config = Config.shared() + original_projects = list(config.projects) if config.projects else [] + config._settings["projects"] = [*original_projects, str(project.path)] + + yield {"project": project, "task": task} + + config._settings["projects"] = original_projects + + +def _assert_math_tools_response(res: dict, expected_in_output: str) -> None: + """Assert response has correct output, trace with tool calls, and output matches latest message.""" + assert res["id"] is not None + + output = res.get("output", {}).get("output", "") + assert output is not None + assert expected_in_output in output + + trace = res.get("trace") or [] + assistant_with_tool_calls = [ + m for m in trace if m.get("role") == "assistant" and m.get("tool_calls") + ] + assert len(assistant_with_tool_calls) >= 1 + + tool_messages = [m for m in trace if m.get("role") == "tool"] + assert len(tool_messages) >= 1 + + last_assistant = next( + (m for m in reversed(trace) if m.get("role") == "assistant"), None + ) + assert last_assistant is not None + last_content = last_assistant.get("content") or "" + assert expected_in_output in last_content + + intermediate_outputs = res.get("intermediate_outputs") or {} + if intermediate_outputs: + for key, value in intermediate_outputs.items(): + assert isinstance(value, str) + + +@pytest.mark.paid +@pytest.mark.asyncio +async def test_run_task_adapter_sanity_math_tools( + client, adapter_sanity_check_math_tools_setup +): + """Multiple runs with built-in Kiln math tools. 
Test that tools work across independent runs.""" + if not os.environ.get("OPENROUTER_API_KEY"): + pytest.skip("OPENROUTER_API_KEY required for this test") + + project = adapter_sanity_check_math_tools_setup["project"] + task = adapter_sanity_check_math_tools_setup["task"] + + run_config = { + "model_name": "gpt_5_nano", + "model_provider_name": "openrouter", + "prompt_id": "simple_prompt_builder", + "structured_output_mode": "json_schema", + "tools_config": { + "tools": [ + KilnBuiltInToolId.ADD_NUMBERS.value, + KilnBuiltInToolId.SUBTRACT_NUMBERS.value, + KilnBuiltInToolId.MULTIPLY_NUMBERS.value, + KilnBuiltInToolId.DIVIDE_NUMBERS.value, + ] + }, + } + + response1 = client.post( + f"/api/projects/{project.id}/tasks/{task.id}/run", + json={ + "run_config_properties": run_config, + "plaintext_input": "What is 2 + 2? Use the tools to calculate.", + }, + ) + assert response1.status_code == 200 + res1 = response1.json() + _assert_math_tools_response(res1, "4") + + response2 = client.post( + f"/api/projects/{project.id}/tasks/{task.id}/run", + json={ + "run_config_properties": run_config, + "plaintext_input": "What is 3 times 4? Use the tools to calculate.", + }, + ) + assert response2.status_code == 200 + res2 = response2.json() + _assert_math_tools_response(res2, "12") + + response3 = client.post( + f"/api/projects/{project.id}/tasks/{task.id}/run", + json={ + "run_config_properties": run_config, + "plaintext_input": "What is 7 times 8 plus 3? Use the tools to calculate.", + }, + ) + assert response3.status_code == 200 + res3 = response3.json() + _assert_math_tools_response(res3, "59") + + response4 = client.post( + f"/api/projects/{project.id}/tasks/{task.id}/run", + json={ + "run_config_properties": run_config, + "plaintext_input": "What is 10 minus 3? Use the tools to calculate.", + }, + ) + assert response4.status_code == 200 + res4 = response4.json() + _assert_math_tools_response(res4, "7")