From 3f83187b9d4568b434dfee505c7a4acb39b7e79d Mon Sep 17 00:00:00 2001 From: "Leonard Q. Marcq" Date: Thu, 26 Feb 2026 19:02:17 +0800 Subject: [PATCH 01/32] refactor: support multiturn (existing conversation history with task run as session) --- app/web_ui/src/lib/api_schema.d.ts | 5 + libs/core/kiln_ai/adapters/chat/__init__.py | 2 + .../kiln_ai/adapters/chat/chat_formatter.py | 57 ++++- .../adapters/chat/test_chat_formatter.py | 27 ++ .../adapters/model_adapters/base_adapter.py | 149 ++++++++--- .../model_adapters/litellm_adapter.py | 15 +- .../adapters/model_adapters/mcp_adapter.py | 30 ++- .../model_adapters/test_base_adapter.py | 113 ++++++++- .../model_adapters/test_litellm_adapter.py | 63 ++++- .../model_adapters/test_mcp_adapter.py | 79 ++++++ .../test_saving_adapter_results.py | 175 ++++++++++++- .../model_adapters/test_structured_output.py | 6 +- .../kiln_ai/adapters/test_prompt_builders.py | 6 +- libs/core/kiln_ai/datamodel/test_basemodel.py | 2 +- libs/server/kiln_server/run_api.py | 6 +- libs/server/kiln_server/test_run_api.py | 234 ++++++++++++++++++ 16 files changed, 911 insertions(+), 58 deletions(-) diff --git a/app/web_ui/src/lib/api_schema.d.ts b/app/web_ui/src/lib/api_schema.d.ts index 95edbf682..d7bb96796 100644 --- a/app/web_ui/src/lib/api_schema.d.ts +++ b/app/web_ui/src/lib/api_schema.d.ts @@ -6232,6 +6232,11 @@ export interface components { } | unknown[] | null; /** Tags */ tags?: string[] | null; + /** + * Task Run Id + * @description When set, continue an existing session. The new message is appended to the run's trace. 
+ */ + task_run_id?: string | null; }; /** * SampleApi diff --git a/libs/core/kiln_ai/adapters/chat/__init__.py b/libs/core/kiln_ai/adapters/chat/__init__.py index 7ab6328f0..b8d7877fd 100644 --- a/libs/core/kiln_ai/adapters/chat/__init__.py +++ b/libs/core/kiln_ai/adapters/chat/__init__.py @@ -3,6 +3,7 @@ ChatFormatter, ChatMessage, ChatStrategy, + MultiturnFormatter, ToolCallMessage, ToolResponseMessage, get_chat_formatter, @@ -14,6 +15,7 @@ "ChatFormatter", "ChatMessage", "ChatStrategy", + "MultiturnFormatter", "ToolCallMessage", "ToolResponseMessage", "build_tool_call_messages", diff --git a/libs/core/kiln_ai/adapters/chat/chat_formatter.py b/libs/core/kiln_ai/adapters/chat/chat_formatter.py index d1a73514f..a6f71f4b4 100644 --- a/libs/core/kiln_ai/adapters/chat/chat_formatter.py +++ b/libs/core/kiln_ai/adapters/chat/chat_formatter.py @@ -3,11 +3,14 @@ import json from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, List, Literal, Optional, Sequence, Union +from typing import Any, Dict, List, Literal, Optional, Sequence, Union from kiln_ai.datamodel.datamodel_enums import ChatStrategy, InputType from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error -from kiln_ai.utils.open_ai_types import ChatCompletionMessageToolCallParam +from kiln_ai.utils.open_ai_types import ( + ChatCompletionMessageParam, + ChatCompletionMessageToolCallParam, +) COT_FINAL_ANSWER_PROMPT = "Considering the above, return a final result." @@ -90,6 +93,11 @@ def intermediate_outputs(self) -> Dict[str, str]: """Get the intermediate outputs from the chat formatter.""" return self._intermediate_outputs + def initial_messages(self) -> list[Any]: + """Messages to seed the conversation. 
Empty for fresh runs; prior trace for continuation.""" + # TODO: fix the type somehow + return [] + @abstractmethod def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]: """Advance the conversation and return the next messages if any.""" @@ -236,6 +244,51 @@ def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]: return None +class MultiturnFormatter(ChatFormatter): + """ + Formatter for continuing a multi-turn conversation with prior trace. + Takes prior_trace (existing conversation) and appends the new user message. + Produces a single turn: the new user message. Tool calls and multi-turn + model responses are handled by _run_model_turn's internal loop. + """ + + def __init__( + self, + prior_trace: list[ChatCompletionMessageParam], + user_input: InputType, + ) -> None: + super().__init__( + system_message="", + user_input=user_input, + thinking_instructions=None, + ) + self._prior_trace = prior_trace + + def initial_messages(self) -> list[Any]: + """Messages to seed the conversation (prior trace).""" + # TODO: use the type we need, but trace is untyped, and we cannot import from litellm adapter here + # or we get circular imports + return list(self._prior_trace) + + def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]: + if self._state == "start": + # prior trace is already in the messages list and contains system and so on, we only need + # to append the latest new user message + user_msg = BasicChatMessage("user", format_user_message(self.user_input)) + self._state = "awaiting_final" + self._messages.append(user_msg) + return ChatTurn(messages=[user_msg], final_call=True) + + if self._state == "awaiting_final": + if previous_output is None: + raise ValueError("previous_output required for final step") + self._messages.append(BasicChatMessage("assistant", previous_output)) + self._state = "done" + return None + + return None + + def get_chat_formatter( strategy: ChatStrategy, 
system_message: str, diff --git a/libs/core/kiln_ai/adapters/chat/test_chat_formatter.py b/libs/core/kiln_ai/adapters/chat/test_chat_formatter.py index 642d49236..e0138a954 100644 --- a/libs/core/kiln_ai/adapters/chat/test_chat_formatter.py +++ b/libs/core/kiln_ai/adapters/chat/test_chat_formatter.py @@ -1,6 +1,7 @@ from kiln_ai.adapters.chat import ChatStrategy, get_chat_formatter from kiln_ai.adapters.chat.chat_formatter import ( COT_FINAL_ANSWER_PROMPT, + MultiturnFormatter, format_user_message, ) @@ -119,6 +120,32 @@ def test_chat_formatter_r1_style(): assert formatter.intermediate_outputs() == {} +def test_multiturn_formatter_initial_messages(): + prior_trace = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + formatter = MultiturnFormatter(prior_trace=prior_trace, user_input="new input") + assert formatter.initial_messages() == prior_trace + + +def test_multiturn_formatter_next_turn(): + prior_trace = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + formatter = MultiturnFormatter(prior_trace=prior_trace, user_input="follow-up") + + first = formatter.next_turn() + assert first is not None + assert len(first.messages) == 1 + assert first.messages[0].role == "user" + assert first.messages[0].content == "follow-up" + assert first.final_call + + assert formatter.next_turn("assistant response") is None + + def test_format_user_message(): # String assert format_user_message("test input") == "test input" diff --git a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py index 134ebc7d3..a4d43b873 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py @@ -3,7 +3,11 @@ from dataclasses import dataclass from typing import Dict, Tuple -from kiln_ai.adapters.chat.chat_formatter import ChatFormatter, get_chat_formatter +from 
kiln_ai.adapters.chat.chat_formatter import ( + ChatFormatter, + MultiturnFormatter, + get_chat_formatter, +) from kiln_ai.adapters.ml_model_list import ( KilnModelProvider, StructuredOutputMode, @@ -23,6 +27,7 @@ TaskRun, Usage, ) +from kiln_ai.datamodel.basemodel import ID_TYPE from kiln_ai.datamodel.datamodel_enums import ChatStrategy, InputType from kiln_ai.datamodel.json_schema import validate_schema_with_value_error from kiln_ai.datamodel.run_config import ( @@ -123,14 +128,18 @@ async def invoke( self, input: InputType, input_source: DataSource | None = None, + task_run_id: ID_TYPE | None = None, ) -> TaskRun: - run_output, _ = await self.invoke_returning_run_output(input, input_source) + run_output, _ = await self.invoke_returning_run_output( + input, input_source, task_run_id + ) return run_output async def _run_returning_run_output( self, input: InputType, input_source: DataSource | None = None, + task_run_id: ID_TYPE | None = None, ) -> Tuple[TaskRun, RunOutput]: # validate input, allowing arrays if self.input_schema is not None: @@ -141,6 +150,25 @@ async def _run_returning_run_output( require_object=False, ) + prior_trace: list[ChatCompletionMessageParam] | None = None + existing_run: TaskRun | None = None + + if task_run_id is not None: + if self.task.path is None: + raise ValueError( + "Cannot continue session: task has no path. Save the task first." + ) + existing_run = TaskRun.from_id_and_parent_path(task_run_id, self.task.path) + if existing_run is None: + raise ValueError( + f"Run not found. Cannot continue session. ID: {task_run_id}" + ) + if not existing_run.trace or len(existing_run.trace) == 0: + raise ValueError( + f"Run has no trace. Cannot continue session without conversation history. 
ID: {task_run_id}" + ) + prior_trace = existing_run.trace + # Format model input for model call (we save the original input in the task without formatting) formatted_input = input formatter_id = self.model_provider().formatter @@ -149,7 +177,7 @@ async def _run_returning_run_output( formatted_input = formatter.format_input(input) # Run - run_output, usage = await self._run(formatted_input) + run_output, usage = await self._run(formatted_input, prior_trace=prior_trace) # Parse provider = self.model_provider() @@ -198,10 +226,28 @@ async def _run_returning_run_output( "Reasoning is required for this model, but no reasoning was returned." ) - # Generate the run and output - run = self.generate_run( - input, input_source, parsed_output, usage, run_output.trace - ) + # Create the run and output - merge if there is an existing run + if existing_run is not None: + merged_output = RunOutput( + output=parsed_output.output, + intermediate_outputs=parsed_output.intermediate_outputs + or run_output.intermediate_outputs, + output_logprobs=parsed_output.output_logprobs + or run_output.output_logprobs, + trace=run_output.trace, + ) + run = self.generate_run( + input, + input_source, + merged_output, + usage, + run_output.trace, + existing_run=existing_run, + ) + else: + run = self.generate_run( + input, input_source, parsed_output, usage, run_output.trace + ) # Save the run if configured to do so, and we have a path to save to if ( @@ -210,7 +256,7 @@ async def _run_returning_run_output( and self.task.path is not None ): run.save_to_file() - else: + elif existing_run is None: # Clear the ID to indicate it's not persisted run.id = None @@ -220,6 +266,7 @@ async def invoke_returning_run_output( self, input: InputType, input_source: DataSource | None = None, + task_run_id: ID_TYPE | None = None, ) -> Tuple[TaskRun, RunOutput]: # Determine if this is the root agent (no existing run context) is_root_agent = get_agent_run_id() is None @@ -229,7 +276,9 @@ async def 
invoke_returning_run_output( set_agent_run_id(run_id) try: - return await self._run_returning_run_output(input, input_source) + return await self._run_returning_run_output( + input, input_source, task_run_id + ) finally: if is_root_agent: try: @@ -247,7 +296,11 @@ def adapter_name(self) -> str: pass @abstractmethod - async def _run(self, input: InputType) -> Tuple[RunOutput, Usage | None]: + async def _run( + self, + input: InputType, + prior_trace: list[ChatCompletionMessageParam] | None = None, + ) -> Tuple[RunOutput, Usage | None]: pass def build_prompt(self) -> str: @@ -267,7 +320,14 @@ def build_prompt(self) -> str: include_json_instructions=add_json_instructions ) - def build_chat_formatter(self, input: InputType) -> ChatFormatter: + def build_chat_formatter( + self, + input: InputType, + prior_trace: list[ChatCompletionMessageParam] | None = None, + ) -> ChatFormatter: + if prior_trace is not None: + return MultiturnFormatter(prior_trace, input) + if self.prompt_builder is None: raise ValueError("Prompt builder is not available for MCP run config") # Determine the chat strategy to use based on the prompt the user selected, the model's capabilities, and if the model was finetuned with a specific chat strategy. 
@@ -323,52 +383,71 @@ def generate_run( run_output: RunOutput, usage: Usage | None = None, trace: list[ChatCompletionMessageParam] | None = None, + existing_run: TaskRun | None = None, ) -> TaskRun: - # Convert input and output to JSON strings if they aren't strings - input_str = ( - input if isinstance(input, str) else json.dumps(input, ensure_ascii=False) - ) output_str = ( json.dumps(run_output.output, ensure_ascii=False) if isinstance(run_output.output, dict) else run_output.output ) - # If no input source is provided, use the human data source - if input_source is None: - input_source = DataSource( - type=DataSourceType.human, - properties={"created_by": Config.shared().user_id}, - ) - - # Synthetic since an adapter, not a human, is creating this - # Special case for MCP run configs which calls a mcp tool output_source_type = ( DataSourceType.tool_call if self.run_config.type == "mcp" else DataSourceType.synthetic ) - new_task_run = TaskRun( + new_output = TaskOutput( + output=output_str, + source=DataSource( + type=output_source_type, + properties=self._properties_for_task_output(), + run_config=self.run_config, + ), + ) + + if existing_run is not None: + accumulated_usage = existing_run.usage + if usage is not None: + if accumulated_usage is not None: + accumulated_usage = accumulated_usage + usage + else: + accumulated_usage = usage + + merged_intermediate = dict(existing_run.intermediate_outputs or {}) + if run_output.intermediate_outputs: + for k, v in run_output.intermediate_outputs.items(): + merged_intermediate[k] = v + + existing_run.output = new_output + existing_run.trace = trace + existing_run.usage = accumulated_usage + existing_run.intermediate_outputs = merged_intermediate + + return existing_run + + # Convert input and output to JSON strings if they aren't strings + input_str = ( + input if isinstance(input, str) else json.dumps(input, ensure_ascii=False) + ) + + if input_source is None: + input_source = DataSource( + type=DataSourceType.human, 
+ properties={"created_by": Config.shared().user_id}, + ) + + return TaskRun( parent=self.task, input=input_str, input_source=input_source, - output=TaskOutput( - output=output_str, - source=DataSource( - type=output_source_type, - properties=self._properties_for_task_output(), - run_config=self.run_config, - ), - ), + output=new_output, intermediate_outputs=run_output.intermediate_outputs, tags=self.base_adapter_config.default_tags or [], usage=usage, trace=trace, ) - return new_task_run - def _properties_for_task_output(self) -> Dict[str, str | int | float]: match self.run_config.type: case "mcp": diff --git a/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py index 41a8e64c0..d49a7dc82 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py @@ -184,20 +184,29 @@ async def _run_model_turn( f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens." ) - async def _run(self, input: InputType) -> tuple[RunOutput, Usage | None]: + async def _run( + self, + input: InputType, + prior_trace: list[ChatCompletionMessageParam] | None = None, + ) -> tuple[RunOutput, Usage | None]: usage = Usage() provider = self.model_provider() if not provider.model_id: raise ValueError("Model ID is required for OpenAI compatible models") - chat_formatter = self.build_chat_formatter(input) - messages: list[ChatCompletionMessageIncludingLiteLLM] = [] + # build_chat_formatter returns MultiturnFormatter when prior_trace is set, else prompt-based formatter + chat_formatter = self.build_chat_formatter(input, prior_trace) + messages: list[ChatCompletionMessageIncludingLiteLLM] = copy.deepcopy( + chat_formatter.initial_messages() + ) prior_output: str | None = None final_choice: Choices | None = None turns = 0 + # Same loop for both fresh runs and prior_trace continuation. 
+ # _run_model_turn has its own internal loop for tool calls (model calls tool -> we run it -> model continues). while True: turns += 1 if turns > MAX_CALLS_PER_TURN: diff --git a/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py index a7f1a8b6c..cd0d34662 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py @@ -42,7 +42,17 @@ def __init__( def adapter_name(self) -> str: return "mcp_adapter" - async def _run(self, input: InputType) -> Tuple[RunOutput, Usage | None]: + async def _run( + self, + input: InputType, + prior_trace: list[ChatCompletionMessageParam] | None = None, + ) -> Tuple[RunOutput, Usage | None]: + if prior_trace is not None: + raise NotImplementedError( + "Session continuation is not supported for MCP adapter. " + "MCP tools are single-turn and do not maintain conversation state." + ) + run_config = self.run_config if not isinstance(run_config, McpRunConfigProperties): raise ValueError("MCPAdapter requires McpRunConfigProperties") @@ -75,19 +85,35 @@ async def invoke( self, input: InputType, input_source: DataSource | None = None, + task_run_id: str | None = None, ) -> TaskRun: - run_output, _ = await self.invoke_returning_run_output(input, input_source) + if task_run_id is not None: + raise NotImplementedError( + "Session continuation is not supported for MCP adapter. " + "MCP tools are single-turn and do not maintain conversation state." + ) + + run_output, _ = await self.invoke_returning_run_output( + input, input_source, task_run_id + ) return run_output async def invoke_returning_run_output( self, input: InputType, input_source: DataSource | None = None, + task_run_id: str | None = None, ) -> Tuple[TaskRun, RunOutput]: """ Runs the task and returns both the persisted TaskRun and raw RunOutput. 
If this call is the root of a run, it creates an agent run context, ensures MCP tool calls have a valid session scope, and cleans up the session/context on completion. """ + if task_run_id is not None: + raise NotImplementedError( + "Session continuation is not supported for MCP adapter. " + "MCP tools are single-turn and do not maintain conversation state." + ) + is_root_agent = get_agent_run_id() is None if is_root_agent: diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py index d2a59a70f..8b5619d6b 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py @@ -10,7 +10,7 @@ ) from kiln_ai.adapters.prompt_builders import BasePromptBuilder from kiln_ai.datamodel import Task -from kiln_ai.datamodel.datamodel_enums import ChatStrategy +from kiln_ai.datamodel.datamodel_enums import ChatStrategy, ModelProviderName from kiln_ai.datamodel.project import Project from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties, ToolsRunConfig from kiln_ai.datamodel.tool_id import KilnBuiltInToolId @@ -20,7 +20,7 @@ class MockAdapter(BaseAdapter): """Concrete implementation of BaseAdapter for testing""" - async def _run(self, input): + async def _run(self, input, prior_trace=None): return None, None def adapter_name(self) -> str: @@ -233,7 +233,7 @@ async def test_input_formatting( # Mock the _run method to capture the input captured_input = None - async def mock_run(input): + async def mock_run(input, prior_trace=None): nonlocal captured_input captured_input = input return RunOutput(output="test output", intermediate_outputs={}), None @@ -420,6 +420,101 @@ def test_build_chat_formatter( mock_prompt_builder.chain_of_thought_prompt.assert_called_once() +def test_build_chat_formatter_with_prior_trace_returns_multiturn_formatter(adapter): + prior_trace = [ + {"role": "user", "content": "hi"}, + 
{"role": "assistant", "content": "hello"}, + ] + formatter = adapter.build_chat_formatter("new input", prior_trace=prior_trace) + assert formatter.__class__.__name__ == "MultiturnFormatter" + assert formatter.initial_messages() == prior_trace + + +@pytest.mark.asyncio +async def test_task_run_id_task_path_none_raises(base_project): + task = Task( + name="test_task", + instruction="test_instruction", + parent=base_project, + ) + assert task.path is None + adapter = MockAdapter( + task=task, + run_config=KilnAgentRunConfigProperties( + model_name="gpt_4o", + model_provider_name=ModelProviderName.openai, + prompt_id="simple_prompt_builder", + structured_output_mode=StructuredOutputMode.json_schema, + ), + ) + with pytest.raises(ValueError, match="task has no path"): + await adapter.invoke("input", task_run_id="some-id") + + +@pytest.mark.asyncio +async def test_invoke_returning_run_output_passes_task_run_id_to_run( + adapter, mock_parser, tmp_path +): + project = Project(name="proj", path=tmp_path / "proj.kiln") + project.save_to_file() + task = Task( + name="t", + instruction="i", + parent=project, + ) + task.save_to_file() + adapter.task = task + + trace = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + initial_run = adapter.generate_run( + input="hi", + input_source=None, + run_output=RunOutput( + output="hello", + intermediate_outputs=None, + trace=trace, + ), + trace=trace, + ) + initial_run.save_to_file() + run_id = initial_run.id + assert run_id is not None + + captured_prior_trace = None + + async def mock_run(input, prior_trace=None): + nonlocal captured_prior_trace + captured_prior_trace = prior_trace + return RunOutput(output="ok", intermediate_outputs=None, trace=trace), None + + adapter._run = mock_run + + provider = MagicMock() + provider.parser = "test_parser" + provider.formatter = None + provider.reasoning_capable = False + adapter.model_provider = MagicMock(return_value=provider) + 
mock_parser.parse_output.return_value = RunOutput( + output="ok", intermediate_outputs=None, trace=trace + ) + + with ( + patch( + "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id", + return_value=mock_parser, + ), + patch( + "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id", + ), + ): + await adapter.invoke_returning_run_output("follow-up", task_run_id=run_id) + + assert captured_prior_trace == trace + + @pytest.mark.parametrize( "initial_mode,expected_mode", [ @@ -681,7 +776,7 @@ async def test_invoke_sets_run_context(self, adapter, clear_context): from kiln_ai.run_context import get_agent_run_id # Mock the _run method - async def mock_run(input): + async def mock_run(input, prior_trace=None): # Check that run ID is set during _run run_id = get_agent_run_id() assert run_id is not None @@ -721,7 +816,7 @@ async def test_invoke_clears_run_context_after(self, adapter, clear_context): from kiln_ai.run_context import get_agent_run_id # Mock the _run method - async def mock_run(input): + async def mock_run(input, prior_trace=None): return RunOutput(output="test output", intermediate_outputs={}), None adapter._run = mock_run @@ -759,7 +854,7 @@ async def test_invoke_clears_run_context_on_error(self, adapter, clear_context): from kiln_ai.run_context import get_agent_run_id # Mock the _run method to raise an error - async def mock_run(input): + async def mock_run(input, prior_trace=None): # Run ID should be set even when error occurs run_id = get_agent_run_id() assert run_id is not None @@ -796,7 +891,7 @@ async def test_sub_agent_inherits_run(self, adapter, clear_context): set_agent_run_id(parent_run_id) # Mock the _run method to check inherited run ID - async def mock_run(input): + async def mock_run(input, prior_trace=None): # Sub-agent should see parent's run ID run_id = get_agent_run_id() assert run_id == parent_run_id @@ -845,7 +940,7 @@ async def test_sub_agent_does_not_create_new_run(self, adapter, clear_context): 
run_id_during_run = None # Mock the _run method to capture run ID - async def mock_run(input): + async def mock_run(input, prior_trace=None): nonlocal run_id_during_run run_id_during_run = get_agent_run_id() return RunOutput(output="test output", intermediate_outputs={}), None @@ -885,7 +980,7 @@ async def test_cleanup_session_called_on_completion(self, adapter, clear_context from kiln_ai.adapters.run_output import RunOutput # Mock the _run method - async def mock_run(input): + async def mock_run(input, prior_trace=None): return RunOutput(output="test output", intermediate_outputs={}), None adapter._run = mock_run diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py index 7d0414b99..3b44812a6 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py @@ -7,7 +7,10 @@ from kiln_ai.adapters.ml_model_list import ModelProviderName, StructuredOutputMode from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig -from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter +from kiln_ai.adapters.model_adapters.litellm_adapter import ( + LiteLlmAdapter, + ModelTurnResult, +) from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig from kiln_ai.datamodel import Project, Task, Usage from kiln_ai.datamodel.run_config import ( @@ -1301,3 +1304,61 @@ async def test_dict_input_converted_to_json(tmp_path, config): assert isinstance(content, str) parsed_content = json.loads(content) assert parsed_content == {"x": 10, "y": 20} + + +@pytest.mark.asyncio +async def test_run_with_prior_trace_uses_multiturn_formatter(mock_task): + config = LiteLlmConfig( + base_url="https://api.test.com", + run_config_properties=KilnAgentRunConfigProperties( + model_name="test-model", + model_provider_name="openai_compatible", + prompt_id="simple_prompt_builder", + 
structured_output_mode="json_schema", + ), + default_headers={"X-Test": "test"}, + additional_body_options={"api_key": "test_key"}, + ) + prior_trace = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + adapter = LiteLlmAdapter(config=config, kiln_task=mock_task) + + build_chat_formatter_calls = [] + + original_build = adapter.build_chat_formatter + + def capturing_build(input, prior_trace_arg=None): + build_chat_formatter_calls.append((input, prior_trace_arg)) + return original_build(input, prior_trace_arg) + + adapter.build_chat_formatter = capturing_build + + async def mock_run_model_turn( + provider, prior_messages, top_logprobs, skip_response_format + ): + extended = list(prior_messages) + extended.append({"role": "assistant", "content": "How can I help?"}) + return ModelTurnResult( + assistant_message="How can I help?", + all_messages=extended, + model_response=None, + model_choice=None, + usage=Usage(), + ) + + adapter._run_model_turn = mock_run_model_turn + + run_output, _ = await adapter._run("follow-up", prior_trace=prior_trace) + + assert len(build_chat_formatter_calls) == 1 + assert build_chat_formatter_calls[0][0] == "follow-up" + assert build_chat_formatter_calls[0][1] == prior_trace + + assert run_output.trace is not None + assert len(run_output.trace) == 4 + assert run_output.trace[0]["content"] == "hi" + assert run_output.trace[1]["content"] == "hello" + assert run_output.trace[2]["content"] == "follow-up" + assert run_output.trace[3]["content"] == "How can I help?" 
diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py index 393a3859f..ac00673d5 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py @@ -328,3 +328,82 @@ async def test_mcp_adapter_sets_and_clears_run_context( await adapter.invoke_returning_run_output("input") assert get_agent_run_id() is None + + +@pytest.mark.asyncio +async def test_mcp_adapter_rejects_multiturn_invoke_returning_run_output( + project_with_local_mcp_server, local_mcp_tool_id +): + """Session continuation (task_run_id) is not supported for MCP adapter.""" + project, _ = project_with_local_mcp_server + task = Task( + name="Test Task", + parent=project, + instruction="Echo input", + ) + + run_config = McpRunConfigProperties( + tool_reference=MCPToolReference(tool_id=local_mcp_tool_id) + ) + + adapter = MCPAdapter(task=task, run_config=run_config) + + with pytest.raises(NotImplementedError) as exc_info: + await adapter.invoke_returning_run_output("input", task_run_id="some-run-id") + + assert "Session continuation is not supported" in str(exc_info.value) + assert "MCP adapter" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_mcp_adapter_rejects_multiturn_invoke( + project_with_local_mcp_server, local_mcp_tool_id +): + """invoke with task_run_id raises NotImplementedError for MCP adapter.""" + project, _ = project_with_local_mcp_server + task = Task( + name="Test Task", + parent=project, + instruction="Echo input", + ) + + run_config = McpRunConfigProperties( + tool_reference=MCPToolReference(tool_id=local_mcp_tool_id) + ) + + adapter = MCPAdapter(task=task, run_config=run_config) + + with pytest.raises(NotImplementedError) as exc_info: + await adapter.invoke("input", task_run_id="some-run-id") + + assert "Session continuation is not supported" in str(exc_info.value) + + +@pytest.mark.asyncio +async def 
test_mcp_adapter_rejects_prior_trace_in_run( + project_with_local_mcp_server, local_mcp_tool_id +): + """_run with prior_trace raises NotImplementedError for MCP adapter.""" + project, _ = project_with_local_mcp_server + task = Task( + name="Test Task", + parent=project, + instruction="Echo input", + ) + + run_config = McpRunConfigProperties( + tool_reference=MCPToolReference(tool_id=local_mcp_tool_id) + ) + + adapter = MCPAdapter(task=task, run_config=run_config) + + prior_trace = [ + {"role": "user", "content": "first message"}, + {"role": "assistant", "content": "first response"}, + ] + + with pytest.raises(NotImplementedError) as exc_info: + await adapter._run("follow-up message", prior_trace=prior_trace) + + assert "Session continuation is not supported" in str(exc_info.value) + assert "MCP tools are single-turn" in str(exc_info.value) diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py index 9cc5fa5d9..82ea05905 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest @@ -10,7 +10,11 @@ class MockAdapter(BaseAdapter): - async def _run(self, input: InputType) -> tuple[RunOutput, Usage | None]: + async def _run( + self, + input: InputType, + prior_trace=None, + ) -> tuple[RunOutput, Usage | None]: return RunOutput(output="Test output", intermediate_outputs=None), None def adapter_name(self) -> str: @@ -233,6 +237,173 @@ async def test_autosave_true(test_task, adapter): assert output.source.properties["top_p"] == 1.0 +@pytest.mark.asyncio +async def test_invoke_continue_session(test_task, adapter): + """Test that invoke with task_run_id continues a session and updates the run.""" + with patch("kiln_ai.utils.config.Config.shared") as 
mock_shared: + mock_config = mock_shared.return_value + mock_config.autosave_runs = True + mock_config.user_id = "test_user" + + trace = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + initial_run = adapter.generate_run( + input="Hello", + input_source=None, + run_output=RunOutput( + output="Hi there!", + intermediate_outputs=None, + trace=trace, + ), + trace=trace, + ) + initial_run.save_to_file() + run_id = initial_run.id + assert run_id is not None + + async def mock_run(input, prior_trace=None): + if prior_trace is not None: + extended_trace = [ + *prior_trace, + {"role": "user", "content": input}, + {"role": "assistant", "content": "How can I help?"}, + ] + return ( + RunOutput( + output="How can I help?", + intermediate_outputs=None, + trace=extended_trace, + ), + None, + ) + return RunOutput(output="Test output", intermediate_outputs=None), None + + adapter._run = mock_run + + with ( + patch.object( + adapter, + "model_provider", + return_value=MagicMock( + parser="default", + formatter=None, + reasoning_capable=False, + ), + ), + patch( + "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id" + ) as mock_parser_from_id, + ): + mock_parser = MagicMock() + mock_parser.parse_output.return_value = RunOutput( + output="How can I help?", + intermediate_outputs=None, + trace=[ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "Tell me more"}, + {"role": "assistant", "content": "How can I help?"}, + ], + ) + mock_parser_from_id.return_value = mock_parser + + updated_run = await adapter.invoke("Tell me more", task_run_id=run_id) + + assert updated_run.id == run_id + assert updated_run.input == "Hello" + assert updated_run.output.output == "How can I help?" + assert len(updated_run.trace) == 4 + assert updated_run.trace[-2]["content"] == "Tell me more" + assert updated_run.trace[-1]["content"] == "How can I help?" 
+ + reloaded = Task.load_from_file(test_task.path) + runs = reloaded.runs() + assert len(runs) == 1 + assert runs[0].output.output == "How can I help?" + + +@pytest.mark.asyncio +async def test_invoke_continue_invalid_task_run_id(test_task, adapter): + """Test that invoke with invalid task_run_id raises ValueError.""" + with patch("kiln_ai.utils.config.Config.shared") as mock_shared: + mock_config = mock_shared.return_value + mock_config.autosave_runs = True + mock_config.user_id = "test_user" + + with pytest.raises(ValueError, match="Run not found"): + await adapter.invoke("Hello", task_run_id="nonexistent-id") + + +@pytest.mark.asyncio +async def test_invoke_continue_run_without_trace(test_task, adapter): + """Test that invoke with task_run_id for a run without trace raises ValueError.""" + with patch("kiln_ai.utils.config.Config.shared") as mock_shared: + mock_config = mock_shared.return_value + mock_config.autosave_runs = True + mock_config.user_id = "test_user" + + run_without_trace = adapter.generate_run( + input="Hello", + input_source=None, + run_output=RunOutput( + output="Hi", + intermediate_outputs=None, + trace=None, + ), + ) + run_without_trace.save_to_file() + + with pytest.raises(ValueError, match="no trace"): + await adapter.invoke("Follow up", task_run_id=run_without_trace.id) + + +def test_generate_run_with_existing_run_merges_usage_and_intermediate_outputs( + test_task, adapter +): + trace = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + initial_run = adapter.generate_run( + input="hi", + input_source=None, + run_output=RunOutput( + output="hello", + intermediate_outputs={"chain_of_thought": "old"}, + trace=trace, + ), + usage=Usage(input_tokens=10, output_tokens=20), + trace=trace, + ) + extended_trace = [ + *trace, + {"role": "user", "content": "follow-up"}, + {"role": "assistant", "content": "ok"}, + ] + result = adapter.generate_run( + input="hi", + input_source=None, + run_output=RunOutput( + 
output="ok", + intermediate_outputs={"new_key": "new_val"}, + trace=extended_trace, + ), + usage=Usage(input_tokens=5, output_tokens=10), + trace=extended_trace, + existing_run=initial_run, + ) + assert result is initial_run + assert result.usage.input_tokens == 15 + assert result.usage.output_tokens == 30 + assert result.intermediate_outputs == { + "chain_of_thought": "old", + "new_key": "new_val", + } + assert result.output.output == "ok" + + def test_properties_for_task_output_custom_values(test_task): """Test that _properties_for_task_output includes custom temperature, top_p, and structured_output_mode""" adapter = MockAdapter( diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py b/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py index fd5451705..92d9a1f4d 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py @@ -53,7 +53,11 @@ def __init__(self, kiln_task: datamodel.Task, response: InputType | None): ) self.response = response - async def _run(self, input: str) -> tuple[RunOutput, Usage | None]: + async def _run( + self, + input: str, + prior_trace=None, + ) -> tuple[RunOutput, Usage | None]: return RunOutput(output=self.response, intermediate_outputs=None), None def adapter_name(self) -> str: diff --git a/libs/core/kiln_ai/adapters/test_prompt_builders.py b/libs/core/kiln_ai/adapters/test_prompt_builders.py index 0f6485e73..25cadd99d 100644 --- a/libs/core/kiln_ai/adapters/test_prompt_builders.py +++ b/libs/core/kiln_ai/adapters/test_prompt_builders.py @@ -58,7 +58,11 @@ def test_simple_prompt_builder(tmp_path): class MockAdapter(BaseAdapter): - async def _run(self, input: InputType) -> tuple[RunOutput, Usage | None]: + async def _run( + self, + input: InputType, + prior_trace=None, + ) -> tuple[RunOutput, Usage | None]: return RunOutput(output="mock response", intermediate_outputs=None), None def 
adapter_name(self) -> str: diff --git a/libs/core/kiln_ai/datamodel/test_basemodel.py b/libs/core/kiln_ai/datamodel/test_basemodel.py index 92fadb7fc..03c6dfeaf 100644 --- a/libs/core/kiln_ai/datamodel/test_basemodel.py +++ b/libs/core/kiln_ai/datamodel/test_basemodel.py @@ -862,7 +862,7 @@ def individual_lookups(): class MockAdapter(BaseAdapter): """Implementation of BaseAdapter for testing""" - async def _run(self, input): + async def _run(self, input, prior_trace=None): return RunOutput(output="test output", intermediate_outputs=None), None def adapter_name(self) -> str: diff --git a/libs/server/kiln_server/run_api.py b/libs/server/kiln_server/run_api.py index f123d421e..f23aee0e8 100644 --- a/libs/server/kiln_server/run_api.py +++ b/libs/server/kiln_server/run_api.py @@ -56,6 +56,10 @@ class RunTaskRequest(BaseModel): plaintext_input: str | None = None structured_input: StructuredInputType | None = None tags: list[str] | None = None + task_run_id: str | None = Field( + default=None, + description="When set, continue an existing session. The new message is appended to the run's trace.", + ) # Allows use of the model_name field (usually pydantic will reserve model_*) model_config = ConfigDict(protected_namespaces=()) @@ -281,7 +285,7 @@ async def run_task( detail="No input provided. 
Ensure your provided the proper format (plaintext or structured).", ) - return await adapter.invoke(input) + return await adapter.invoke(input, task_run_id=request.task_run_id) @app.patch("/api/projects/{project_id}/tasks/{task_id}/runs/{run_id}") async def update_run( diff --git a/libs/server/kiln_server/test_run_api.py b/libs/server/kiln_server/test_run_api.py index 14d37371c..a645de32a 100644 --- a/libs/server/kiln_server/test_run_api.py +++ b/libs/server/kiln_server/test_run_api.py @@ -1,5 +1,7 @@ import logging +import os import time +from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -17,6 +19,8 @@ TaskOutputRatingType, TaskRun, ) +from kiln_ai.datamodel.tool_id import KilnBuiltInToolId +from kiln_ai.utils.config import Config from kiln_server.custom_errors import connect_custom_errors from kiln_server.run_api import ( @@ -135,6 +139,59 @@ async def test_run_task_success(client, task_run_setup): assert res["id"] is not None +@pytest.mark.asyncio +async def test_run_task_with_task_run_id_continues_session(client, task_run_setup): + """Test that run_task with task_run_id passes it to adapter.invoke for session continuation.""" + project = task_run_setup["project"] + task = task_run_setup["task"] + task_run = task_run_setup["task_run"] + + run_task_request = { + "run_config_properties": { + "model_name": "gpt_4o", + "model_provider_name": "ollama", + "prompt_id": "simple_prompt_builder", + "structured_output_mode": "json_schema", + }, + "plaintext_input": "Follow-up message", + "task_run_id": task_run.id, + } + + continued_run = TaskRun( + parent=task, + input=task_run.input, + input_source=task_run.input_source, + output=TaskOutput( + output="Continued response", + source=task_run.output.source, + ), + ) + continued_run.id = task_run.id + + with ( + patch("kiln_server.run_api.task_from_id") as mock_task_from_id, + patch.object(LiteLlmAdapter, "invoke", new_callable=AsyncMock) as mock_invoke, + 
patch("kiln_ai.utils.config.Config.shared") as MockConfig, + ): + mock_task_from_id.return_value = task + mock_invoke.return_value = continued_run + + mock_config_instance = MockConfig.return_value + mock_config_instance.ollama_base_url = "http://localhost:11434/v1" + + response = client.post( + f"/api/projects/{project.id}/tasks/{task.id}/run", json=run_task_request + ) + + assert response.status_code == 200 + mock_invoke.assert_called_once() + call_kwargs = mock_invoke.call_args[1] + assert call_kwargs["task_run_id"] == task_run.id + assert mock_invoke.call_args[0][0] == "Follow-up message" + res = response.json() + assert res["output"]["output"] == "Continued response" + + @pytest.mark.asyncio async def test_run_task_structured_output(client, task_run_setup): task = task_run_setup["task"] @@ -1663,3 +1720,180 @@ async def test_benchmark_tag_runs(client, task_run_setup, run_count): logger.info( f"Performance: {runs_per_second:.1f} runs/second, {avg_time_per_run * 1000:.2f}ms per run" ) + + +def _adapter_sanity_check_output_path() -> Path: + return Path(__file__).resolve().parent / "adapter_sanity_check.txt" + + +def _append_to_sanity_check(content: str, output_path: Path) -> None: + with open(output_path, "a", encoding="utf-8") as f: + f.write(content) + f.write("\n") + + +@pytest.fixture +def adapter_sanity_check_setup(tmp_path): + """Setup for paid adapter sanity check tests - real project/task, no adapter mocking.""" + # if project at the path does not exist, create it, otherwise reuse + project_path = ( + Path("/Users/leonardmarcq/Downloads/") + / "adapter_sanity_project" + / "project.kiln" + ) + if not project_path.exists(): + project_path.parent.mkdir() + + project = Project(name="Adapter Sanity Project", path=str(project_path)) + project.save_to_file() + + task = Task( + name="Adapter Sanity Task", + instruction="You are a helpful assistant. 
Respond concisely.", + description="Task for adapter sanity checking", + parent=project, + ) + task.save_to_file() + + else: + project = Project.load_from_file(project_path) + task = next( + ( + t + for t in project.tasks(readonly=True) + if t.name == "Adapter Sanity Task" + ), + None, + ) + if task is None: + raise ValueError("Task not found") + + config = Config.shared() + original_projects = list(config.projects) if config.projects else [] + config._settings["projects"] = [*original_projects, str(project.path)] + + yield {"project": project, "task": task} + + config._settings["projects"] = original_projects + + +@pytest.fixture +def adapter_sanity_check_math_tools_setup(tmp_path): + """Setup for paid math tools test - task with instructions to use add, multiply, etc.""" + project_path = tmp_path / "adapter_sanity_math_project" / "project.kiln" + project_path.parent.mkdir() + + project = Project(name="Adapter Sanity Math Project", path=str(project_path)) + project.save_to_file() + + task = Task( + name="Math Tools Task", + instruction="You are an assistant that performs math using the provided tools. You MUST use the add, subtract, multiply, and divide tools for any arithmetic. For example, for 2+2 you must call the add tool with a=2 and b=2. End your response with the final answer in square brackets, e.g. 
[4].", + description="Task for testing Kiln built-in math tools", + parent=project, + ) + task.save_to_file() + + config = Config.shared() + original_projects = list(config.projects) if config.projects else [] + config._settings["projects"] = [*original_projects, str(project.path)] + + yield {"project": project, "task": task} + + config._settings["projects"] = original_projects + + +def _assert_math_tools_response(res: dict, expected_in_output: str) -> None: + """Assert response has correct output, trace with tool calls, and output matches latest message.""" + assert res["id"] is not None + + output = res.get("output", {}).get("output", "") + assert output is not None + assert expected_in_output in output + + trace = res.get("trace") or [] + assistant_with_tool_calls = [ + m for m in trace if m.get("role") == "assistant" and m.get("tool_calls") + ] + assert len(assistant_with_tool_calls) >= 1 + + tool_messages = [m for m in trace if m.get("role") == "tool"] + assert len(tool_messages) >= 1 + + last_assistant = next( + (m for m in reversed(trace) if m.get("role") == "assistant"), None + ) + assert last_assistant is not None + last_content = last_assistant.get("content") or "" + assert expected_in_output in last_content + + intermediate_outputs = res.get("intermediate_outputs") or {} + if intermediate_outputs: + for key, value in intermediate_outputs.items(): + assert isinstance(value, str) + + +@pytest.mark.paid +@pytest.mark.asyncio +async def test_run_task_adapter_sanity_math_tools( + client, adapter_sanity_check_math_tools_setup +): + """Multi-turn run with built-in Kiln math tools. 
Uses gpt_4o_mini for function calling.""" + if not os.environ.get("OPENROUTER_API_KEY"): + pytest.skip("OPENROUTER_API_KEY required for this test") + + project = adapter_sanity_check_math_tools_setup["project"] + task = adapter_sanity_check_math_tools_setup["task"] + + run_config = { + "model_name": "gpt_4o_mini", + "model_provider_name": "openrouter", + "prompt_id": "simple_prompt_builder", + "structured_output_mode": "json_schema", + "tools_config": { + "tools": [ + KilnBuiltInToolId.ADD_NUMBERS.value, + KilnBuiltInToolId.SUBTRACT_NUMBERS.value, + KilnBuiltInToolId.MULTIPLY_NUMBERS.value, + KilnBuiltInToolId.DIVIDE_NUMBERS.value, + ] + }, + } + + response1 = client.post( + f"/api/projects/{project.id}/tasks/{task.id}/run", + json={ + "run_config_properties": run_config, + "plaintext_input": "What is 2 + 2? Use the tools to calculate.", + }, + ) + assert response1.status_code == 200 + res1 = response1.json() + _assert_math_tools_response(res1, "4") + task_run_id = res1["id"] + + response2 = client.post( + f"/api/projects/{project.id}/tasks/{task.id}/run", + json={ + "run_config_properties": run_config, + "plaintext_input": "What is 3 times 4? Use the tools to calculate.", + "task_run_id": task_run_id, + }, + ) + assert response2.status_code == 200 + res2 = response2.json() + assert res2["id"] == task_run_id + _assert_math_tools_response(res2, "12") + + response3 = client.post( + f"/api/projects/{project.id}/tasks/{task.id}/run", + json={ + "run_config_properties": run_config, + "plaintext_input": "What is 7 times 8 plus 3? Use the tools to calculate.", + "task_run_id": task_run_id, + }, + ) + assert response3.status_code == 200 + res3 = response3.json() + assert res3["id"] == task_run_id + _assert_math_tools_response(res3, "59") From 6c7cc58e5216ea728aa837c1c705842e16788cea Mon Sep 17 00:00:00 2001 From: "Leonard Q. 
Marcq" Date: Fri, 27 Feb 2026 15:21:28 +0800 Subject: [PATCH 02/32] refactor: use correct typing in chat formatter --- libs/core/kiln_ai/adapters/chat/__init__.py | 2 ++ .../core/kiln_ai/adapters/chat/chat_formatter.py | 16 ++++++++++------ .../adapters/model_adapters/litellm_adapter.py | 7 ++----- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/libs/core/kiln_ai/adapters/chat/__init__.py b/libs/core/kiln_ai/adapters/chat/__init__.py index b8d7877fd..11b5eda12 100644 --- a/libs/core/kiln_ai/adapters/chat/__init__.py +++ b/libs/core/kiln_ai/adapters/chat/__init__.py @@ -1,5 +1,6 @@ from .chat_formatter import ( BasicChatMessage, + ChatCompletionMessageIncludingLiteLLM, ChatFormatter, ChatMessage, ChatStrategy, @@ -12,6 +13,7 @@ __all__ = [ "BasicChatMessage", + "ChatCompletionMessageIncludingLiteLLM", "ChatFormatter", "ChatMessage", "ChatStrategy", diff --git a/libs/core/kiln_ai/adapters/chat/chat_formatter.py b/libs/core/kiln_ai/adapters/chat/chat_formatter.py index a6f71f4b4..22ba371a3 100644 --- a/libs/core/kiln_ai/adapters/chat/chat_formatter.py +++ b/libs/core/kiln_ai/adapters/chat/chat_formatter.py @@ -3,7 +3,9 @@ import json from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, List, Literal, Optional, Sequence, Union +from typing import Dict, List, Literal, Optional, Sequence, TypeAlias, Union + +from litellm.types.utils import Message as LiteLLMMessage from kiln_ai.datamodel.datamodel_enums import ChatStrategy, InputType from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error @@ -15,6 +17,11 @@ COT_FINAL_ANSWER_PROMPT = "Considering the above, return a final result." 
+ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[ + ChatCompletionMessageParam, LiteLLMMessage +] + + @dataclass class BasicChatMessage: role: Literal["system", "assistant", "user"] @@ -93,9 +100,8 @@ def intermediate_outputs(self) -> Dict[str, str]: """Get the intermediate outputs from the chat formatter.""" return self._intermediate_outputs - def initial_messages(self) -> list[Any]: + def initial_messages(self) -> list[ChatCompletionMessageIncludingLiteLLM]: """Messages to seed the conversation. Empty for fresh runs; prior trace for continuation.""" - # TODO: fix the type somehow return [] @abstractmethod @@ -264,10 +270,8 @@ def __init__( ) self._prior_trace = prior_trace - def initial_messages(self) -> list[Any]: + def initial_messages(self) -> list[ChatCompletionMessageIncludingLiteLLM]: """Messages to seed the conversation (prior trace).""" - # TODO: use the type we need, but trace is untyped, and we cannot import from litellm adapter here - # or we get circular imports return list(self._prior_trace) def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]: diff --git a/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py index d49a7dc82..131be097c 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py @@ -3,7 +3,7 @@ import json import logging from dataclasses import dataclass -from typing import Any, Dict, List, Tuple, TypeAlias, Union +from typing import Any, Dict, List, Tuple import litellm from litellm.types.utils import ( @@ -19,6 +19,7 @@ ) import kiln_ai.datamodel as datamodel +from kiln_ai.adapters.chat import ChatCompletionMessageIncludingLiteLLM from kiln_ai.adapters.ml_model_list import ( KilnModelProvider, ModelProviderName, @@ -56,10 +57,6 @@ logger = logging.getLogger(__name__) -ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[ - 
ChatCompletionMessageParam, LiteLLMMessage -] - @dataclass class ModelTurnResult: From 04748e4aebfafb73e6bc2ce830b0d96d754a604e Mon Sep 17 00:00:00 2001 From: "Leonard Q. Marcq" Date: Fri, 27 Feb 2026 16:34:32 +0800 Subject: [PATCH 03/32] refactor: retrieve task_run one level up --- .../adapters/model_adapters/base_adapter.py | 35 +++--- .../adapters/model_adapters/mcp_adapter.py | 10 +- .../model_adapters/test_base_adapter.py | 36 +++++-- .../model_adapters/test_mcp_adapter.py | 16 ++- .../test_saving_adapter_results.py | 18 +--- libs/server/kiln_server/run_api.py | 23 +++- libs/server/kiln_server/test_run_api.py | 100 +++++++++++++++++- 7 files changed, 179 insertions(+), 59 deletions(-) diff --git a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py index a4d43b873..7cda8c6dc 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py @@ -27,7 +27,6 @@ TaskRun, Usage, ) -from kiln_ai.datamodel.basemodel import ID_TYPE from kiln_ai.datamodel.datamodel_enums import ChatStrategy, InputType from kiln_ai.datamodel.json_schema import validate_schema_with_value_error from kiln_ai.datamodel.run_config import ( @@ -128,10 +127,10 @@ async def invoke( self, input: InputType, input_source: DataSource | None = None, - task_run_id: ID_TYPE | None = None, + existing_run: TaskRun | None = None, ) -> TaskRun: run_output, _ = await self.invoke_returning_run_output( - input, input_source, task_run_id + input, input_source, existing_run ) return run_output @@ -139,7 +138,7 @@ async def _run_returning_run_output( self, input: InputType, input_source: DataSource | None = None, - task_run_id: ID_TYPE | None = None, + existing_run: TaskRun | None = None, ) -> Tuple[TaskRun, RunOutput]: # validate input, allowing arrays if self.input_schema is not None: @@ -150,24 +149,14 @@ async def _run_returning_run_output( require_object=False, ) - 
prior_trace: list[ChatCompletionMessageParam] | None = None - existing_run: TaskRun | None = None + if existing_run is not None and ( + not existing_run.trace or len(existing_run.trace) == 0 + ): + raise ValueError( + "Run has no trace. Cannot continue session without conversation history." + ) - if task_run_id is not None: - if self.task.path is None: - raise ValueError( - "Cannot continue session: task has no path. Save the task first." - ) - existing_run = TaskRun.from_id_and_parent_path(task_run_id, self.task.path) - if existing_run is None: - raise ValueError( - f"Run not found. Cannot continue session. ID: {task_run_id}" - ) - if not existing_run.trace or len(existing_run.trace) == 0: - raise ValueError( - f"Run has no trace. Cannot continue session without conversation history. ID: {task_run_id}" - ) - prior_trace = existing_run.trace + prior_trace = existing_run.trace if existing_run else None # Format model input for model call (we save the original input in the task without formatting) formatted_input = input @@ -266,7 +255,7 @@ async def invoke_returning_run_output( self, input: InputType, input_source: DataSource | None = None, - task_run_id: ID_TYPE | None = None, + existing_run: TaskRun | None = None, ) -> Tuple[TaskRun, RunOutput]: # Determine if this is the root agent (no existing run context) is_root_agent = get_agent_run_id() is None @@ -277,7 +266,7 @@ async def invoke_returning_run_output( try: return await self._run_returning_run_output( - input, input_source, task_run_id + input, input_source, existing_run ) finally: if is_root_agent: diff --git a/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py index cd0d34662..45aabc53e 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py @@ -85,16 +85,16 @@ async def invoke( self, input: InputType, input_source: DataSource | None = None, - task_run_id: str | None 
= None, + existing_run: TaskRun | None = None, ) -> TaskRun: - if task_run_id is not None: + if existing_run is not None: raise NotImplementedError( "Session continuation is not supported for MCP adapter. " "MCP tools are single-turn and do not maintain conversation state." ) run_output, _ = await self.invoke_returning_run_output( - input, input_source, task_run_id + input, input_source, existing_run ) return run_output @@ -102,13 +102,13 @@ async def invoke_returning_run_output( self, input: InputType, input_source: DataSource | None = None, - task_run_id: str | None = None, + existing_run: TaskRun | None = None, ) -> Tuple[TaskRun, RunOutput]: """ Runs the task and returns both the persisted TaskRun and raw RunOutput. If this call is the root of a run, it creates an agent run context, ensures MCP tool calls have a valid session scope, and cleans up the session/context on completion. """ - if task_run_id is not None: + if existing_run is not None: raise NotImplementedError( "Session continuation is not supported for MCP adapter. " "MCP tools are single-turn and do not maintain conversation state." 
diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py index 8b5619d6b..8b150c68e 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py @@ -9,7 +9,13 @@ RunOutput, ) from kiln_ai.adapters.prompt_builders import BasePromptBuilder -from kiln_ai.datamodel import Task +from kiln_ai.datamodel import ( + DataSource, + DataSourceType, + Task, + TaskOutput, + TaskRun, +) from kiln_ai.datamodel.datamodel_enums import ChatStrategy, ModelProviderName from kiln_ai.datamodel.project import Project from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties, ToolsRunConfig @@ -431,13 +437,12 @@ def test_build_chat_formatter_with_prior_trace_returns_multiturn_formatter(adapt @pytest.mark.asyncio -async def test_task_run_id_task_path_none_raises(base_project): +async def test_existing_run_without_trace_raises(base_project): task = Task( name="test_task", instruction="test_instruction", parent=base_project, ) - assert task.path is None adapter = MockAdapter( task=task, run_config=KilnAgentRunConfigProperties( @@ -447,12 +452,29 @@ async def test_task_run_id_task_path_none_raises(base_project): structured_output_mode=StructuredOutputMode.json_schema, ), ) - with pytest.raises(ValueError, match="task has no path"): - await adapter.invoke("input", task_run_id="some-id") + run_without_trace = TaskRun( + parent=task, + input="hi", + input_source=None, + output=TaskOutput( + output="hello", + source=DataSource( + type=DataSourceType.synthetic, + properties={ + "model_name": "gpt_4o", + "model_provider": "openai", + "adapter_name": "test", + }, + ), + ), + trace=None, + ) + with pytest.raises(ValueError, match="no trace"): + await adapter.invoke("input", existing_run=run_without_trace) @pytest.mark.asyncio -async def test_invoke_returning_run_output_passes_task_run_id_to_run( +async def 
test_invoke_returning_run_output_passes_existing_run_to_run( adapter, mock_parser, tmp_path ): project = Project(name="proj", path=tmp_path / "proj.kiln") @@ -510,7 +532,7 @@ async def mock_run(input, prior_trace=None): "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id", ), ): - await adapter.invoke_returning_run_output("follow-up", task_run_id=run_id) + await adapter.invoke_returning_run_output("follow-up", existing_run=initial_run) assert captured_prior_trace == trace diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py index ac00673d5..ad5826d8c 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from mcp.types import CallToolResult, TextContent @@ -334,7 +334,7 @@ async def test_mcp_adapter_sets_and_clears_run_context( async def test_mcp_adapter_rejects_multiturn_invoke_returning_run_output( project_with_local_mcp_server, local_mcp_tool_id ): - """Session continuation (task_run_id) is not supported for MCP adapter.""" + """Session continuation (existing_run) is not supported for MCP adapter.""" project, _ = project_with_local_mcp_server task = Task( name="Test Task", @@ -348,8 +348,11 @@ async def test_mcp_adapter_rejects_multiturn_invoke_returning_run_output( adapter = MCPAdapter(task=task, run_config=run_config) + existing_run = MagicMock() + existing_run.trace = [{"role": "user", "content": "hi"}] + with pytest.raises(NotImplementedError) as exc_info: - await adapter.invoke_returning_run_output("input", task_run_id="some-run-id") + await adapter.invoke_returning_run_output("input", existing_run=existing_run) assert "Session continuation is not supported" in str(exc_info.value) assert "MCP adapter" in 
str(exc_info.value) @@ -359,7 +362,7 @@ async def test_mcp_adapter_rejects_multiturn_invoke_returning_run_output( async def test_mcp_adapter_rejects_multiturn_invoke( project_with_local_mcp_server, local_mcp_tool_id ): - """invoke with task_run_id raises NotImplementedError for MCP adapter.""" + """invoke with existing_run raises NotImplementedError for MCP adapter.""" project, _ = project_with_local_mcp_server task = Task( name="Test Task", @@ -373,8 +376,11 @@ async def test_mcp_adapter_rejects_multiturn_invoke( adapter = MCPAdapter(task=task, run_config=run_config) + existing_run = MagicMock() + existing_run.trace = [{"role": "user", "content": "hi"}] + with pytest.raises(NotImplementedError) as exc_info: - await adapter.invoke("input", task_run_id="some-run-id") + await adapter.invoke("input", existing_run=existing_run) assert "Session continuation is not supported" in str(exc_info.value) diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py index 82ea05905..f96c70a41 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py @@ -309,7 +309,7 @@ async def mock_run(input, prior_trace=None): ) mock_parser_from_id.return_value = mock_parser - updated_run = await adapter.invoke("Tell me more", task_run_id=run_id) + updated_run = await adapter.invoke("Tell me more", existing_run=initial_run) assert updated_run.id == run_id assert updated_run.input == "Hello" @@ -324,21 +324,9 @@ async def mock_run(input, prior_trace=None): assert runs[0].output.output == "How can I help?" 
-@pytest.mark.asyncio -async def test_invoke_continue_invalid_task_run_id(test_task, adapter): - """Test that invoke with invalid task_run_id raises ValueError.""" - with patch("kiln_ai.utils.config.Config.shared") as mock_shared: - mock_config = mock_shared.return_value - mock_config.autosave_runs = True - mock_config.user_id = "test_user" - - with pytest.raises(ValueError, match="Run not found"): - await adapter.invoke("Hello", task_run_id="nonexistent-id") - - @pytest.mark.asyncio async def test_invoke_continue_run_without_trace(test_task, adapter): - """Test that invoke with task_run_id for a run without trace raises ValueError.""" + """Test that invoke with existing_run that has no trace raises ValueError.""" with patch("kiln_ai.utils.config.Config.shared") as mock_shared: mock_config = mock_shared.return_value mock_config.autosave_runs = True @@ -356,7 +344,7 @@ async def test_invoke_continue_run_without_trace(test_task, adapter): run_without_trace.save_to_file() with pytest.raises(ValueError, match="no trace"): - await adapter.invoke("Follow up", task_run_id=run_without_trace.id) + await adapter.invoke("Follow up", existing_run=run_without_trace) def test_generate_run_with_existing_run_merges_usage_and_intermediate_outputs( diff --git a/libs/server/kiln_server/run_api.py b/libs/server/kiln_server/run_api.py index f23aee0e8..ea13104d2 100644 --- a/libs/server/kiln_server/run_api.py +++ b/libs/server/kiln_server/run_api.py @@ -285,7 +285,28 @@ async def run_task( detail="No input provided. Ensure your provided the proper format (plaintext or structured).", ) - return await adapter.invoke(input, task_run_id=request.task_run_id) + existing_run: TaskRun | None = None + if request.task_run_id is not None: + if task.path is None: + raise HTTPException( + status_code=400, + detail="Cannot continue session: task has no path. 
Save the task first.", + ) + existing_run = TaskRun.from_id_and_parent_path( + request.task_run_id, task.path + ) + if existing_run is None: + raise HTTPException( + status_code=404, + detail="Run not found. Cannot continue session.", + ) + if not existing_run.trace or len(existing_run.trace) == 0: + raise HTTPException( + status_code=400, + detail="Run has no trace. Cannot continue session without conversation history.", + ) + + return await adapter.invoke(input, existing_run=existing_run) @app.patch("/api/projects/{project_id}/tasks/{task_id}/runs/{run_id}") async def update_run( diff --git a/libs/server/kiln_server/test_run_api.py b/libs/server/kiln_server/test_run_api.py index a645de32a..608a03a94 100644 --- a/libs/server/kiln_server/test_run_api.py +++ b/libs/server/kiln_server/test_run_api.py @@ -99,6 +99,10 @@ def task_run_setup(tmp_path): }, ), ), + trace=[ + {"role": "user", "content": "Test input"}, + {"role": "assistant", "content": "Test output"}, + ], ) task_run.save_to_file() @@ -186,12 +190,88 @@ async def test_run_task_with_task_run_id_continues_session(client, task_run_setu assert response.status_code == 200 mock_invoke.assert_called_once() call_kwargs = mock_invoke.call_args[1] - assert call_kwargs["task_run_id"] == task_run.id + assert call_kwargs["existing_run"].id == task_run.id assert mock_invoke.call_args[0][0] == "Follow-up message" res = response.json() assert res["output"]["output"] == "Continued response" +@pytest.mark.asyncio +async def test_run_task_task_run_id_not_found_returns_404(client, task_run_setup): + """Test that run_task with nonexistent task_run_id returns 404.""" + project = task_run_setup["project"] + task = task_run_setup["task"] + + run_task_request = { + "run_config_properties": { + "model_name": "gpt_4o", + "model_provider_name": "ollama", + "prompt_id": "simple_prompt_builder", + "structured_output_mode": "json_schema", + }, + "plaintext_input": "Follow-up", + "task_run_id": "nonexistent-run-id", + } + + with 
patch("kiln_server.run_api.task_from_id") as mock_task_from_id: + mock_task_from_id.return_value = task + response = client.post( + f"/api/projects/{project.id}/tasks/{task.id}/run", json=run_task_request + ) + + assert response.status_code == 404 + assert "Run not found" in response.json()["message"] + + +@pytest.mark.asyncio +async def test_run_task_task_run_id_no_trace_returns_400(client, task_run_setup): + """Test that run_task with task_run_id for run without trace returns 400.""" + project = task_run_setup["project"] + task = task_run_setup["task"] + + task_run_no_trace = TaskRun( + parent=task, + input="Hello", + input_source=DataSource( + type=DataSourceType.human, properties={"created_by": "Test User"} + ), + output=TaskOutput( + output="Hi", + source=DataSource( + type=DataSourceType.synthetic, + properties={ + "model_name": "gpt_4o", + "model_provider": "ollama", + "adapter_name": "kiln_langchain_adapter", + "prompt_id": "simple_prompt_builder", + }, + ), + ), + trace=None, + ) + task_run_no_trace.save_to_file() + + run_task_request = { + "run_config_properties": { + "model_name": "gpt_4o", + "model_provider_name": "ollama", + "prompt_id": "simple_prompt_builder", + "structured_output_mode": "json_schema", + }, + "plaintext_input": "Follow-up", + "task_run_id": task_run_no_trace.id, + } + + with patch("kiln_server.run_api.task_from_id") as mock_task_from_id: + mock_task_from_id.return_value = task + response = client.post( + f"/api/projects/{project.id}/tasks/{task.id}/run", json=run_task_request + ) + + assert response.status_code == 400 + assert "no trace" in response.json()["message"].lower() + + @pytest.mark.asyncio async def test_run_task_structured_output(client, task_run_setup): task = task_run_setup["task"] @@ -1838,7 +1918,7 @@ def _assert_math_tools_response(res: dict, expected_in_output: str) -> None: async def test_run_task_adapter_sanity_math_tools( client, adapter_sanity_check_math_tools_setup ): - """Multi-turn run with built-in Kiln math 
tools. Uses gpt_4o_mini for function calling.""" + """Multi-turn run with built-in Kiln math tools. Test that tools + continue session work as expected.""" if not os.environ.get("OPENROUTER_API_KEY"): pytest.skip("OPENROUTER_API_KEY required for this test") @@ -1846,7 +1926,7 @@ async def test_run_task_adapter_sanity_math_tools( task = adapter_sanity_check_math_tools_setup["task"] run_config = { - "model_name": "gpt_4o_mini", + "model_name": "gpt_5_nano", "model_provider_name": "openrouter", "prompt_id": "simple_prompt_builder", "structured_output_mode": "json_schema", @@ -1897,3 +1977,17 @@ async def test_run_task_adapter_sanity_math_tools( res3 = response3.json() assert res3["id"] == task_run_id _assert_math_tools_response(res3, "59") + + # now ask it to list out all the previous results in an array + response4 = client.post( + f"/api/projects/{project.id}/tasks/{task.id}/run", + json={ + "run_config_properties": run_config, + "plaintext_input": "List all the previous results in an array - e.g. 
[55, 81, 7].", + "task_run_id": task_run_id, + }, + ) + assert response4.status_code == 200 + res4 = response4.json() + assert res4["id"] == task_run_id + assert res4["output"]["output"] == "[4, 12, 59]" From 57896a81dc50f9fb42b04bab083ac150b7c3e67d Mon Sep 17 00:00:00 2001 From: scosman Date: Wed, 18 Feb 2026 15:35:46 -0500 Subject: [PATCH 04/32] Proof of concept streaming API --- .../adapters/litellm_utils/__init__.py | 0 .../litellm_utils/litellm_streaming.py | 60 ++++++++ .../litellm_utils/test_litellm_streaming.py | 139 ++++++++++++++++++ .../adapters/model_adapters/base_adapter.py | 17 ++- .../model_adapters/litellm_adapter.py | 17 ++- .../adapters/model_adapters/mcp_adapter.py | 11 +- .../model_adapters/test_base_adapter.py | 139 ++++++++++++++++-- .../model_adapters/test_litellm_adapter.py | 14 +- .../test_litellm_adapter_tools.py | 40 ++++- .../test_saving_adapter_results.py | 9 +- .../model_adapters/test_structured_output.py | 28 ++-- .../kiln_ai/adapters/test_prompt_adaptors.py | 20 +-- .../kiln_ai/adapters/test_prompt_builders.py | 6 +- libs/core/kiln_ai/datamodel/test_basemodel.py | 2 +- 14 files changed, 439 insertions(+), 63 deletions(-) create mode 100644 libs/core/kiln_ai/adapters/litellm_utils/__init__.py create mode 100644 libs/core/kiln_ai/adapters/litellm_utils/litellm_streaming.py create mode 100644 libs/core/kiln_ai/adapters/litellm_utils/test_litellm_streaming.py diff --git a/libs/core/kiln_ai/adapters/litellm_utils/__init__.py b/libs/core/kiln_ai/adapters/litellm_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/libs/core/kiln_ai/adapters/litellm_utils/litellm_streaming.py b/libs/core/kiln_ai/adapters/litellm_utils/litellm_streaming.py new file mode 100644 index 000000000..68b29dd62 --- /dev/null +++ b/libs/core/kiln_ai/adapters/litellm_utils/litellm_streaming.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from typing import Any, AsyncIterator, Optional, Union + +import litellm +from 
litellm.types.utils import ( + ModelResponse, + ModelResponseStream, + TextCompletionResponse, +) + + +class StreamingCompletion: + """ + Async iterable wrapper around ``litellm.acompletion`` with streaming. + + Yields ``ModelResponseStream`` chunks as they arrive. After iteration + completes, the assembled ``ModelResponse`` is available via the + ``.response`` property. + + Usage:: + + stream = StreamingCompletion(model=..., messages=...) + async for chunk in stream: + # handle chunk however you like (print, log, send over WS, …) + pass + final = stream.response # fully assembled ModelResponse + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + kwargs = dict(kwargs) + kwargs.pop("stream", None) + self._args = args + self._kwargs = kwargs + self._response: Optional[Union[ModelResponse, TextCompletionResponse]] = None + self._iterated: bool = False + + @property + def response(self) -> Optional[Union[ModelResponse, TextCompletionResponse]]: + """The final assembled response. Only available after iteration.""" + if not self._iterated: + raise RuntimeError( + "StreamingCompletion has not been iterated yet. 
" + "Use 'async for chunk in stream:' before accessing .response" + ) + return self._response + + async def __aiter__(self) -> AsyncIterator[ModelResponseStream]: + self._response = None + self._iterated = False + + chunks: list[ModelResponseStream] = [] + stream = await litellm.acompletion(*self._args, stream=True, **self._kwargs) + + async for chunk in stream: + chunks.append(chunk) + yield chunk + + self._response = litellm.stream_chunk_builder(chunks) + self._iterated = True diff --git a/libs/core/kiln_ai/adapters/litellm_utils/test_litellm_streaming.py b/libs/core/kiln_ai/adapters/litellm_utils/test_litellm_streaming.py new file mode 100644 index 000000000..e35a51982 --- /dev/null +++ b/libs/core/kiln_ai/adapters/litellm_utils/test_litellm_streaming.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, List +from unittest.mock import MagicMock, patch + +import pytest + +from kiln_ai.adapters.litellm_utils.litellm_streaming import StreamingCompletion + + +def _make_chunk(content: str | None = None, finish_reason: str | None = None) -> Any: + """Build a minimal chunk object matching litellm's streaming shape.""" + delta = SimpleNamespace(content=content, role="assistant") + choice = SimpleNamespace(delta=delta, finish_reason=finish_reason, index=0) + return SimpleNamespace(choices=[choice], id="chatcmpl-test", model="test-model") + + +async def _async_iter(items: List[Any]): + """Turn a plain list into an async iterator.""" + for item in items: + yield item + + +@pytest.fixture +def mock_acompletion(): + with patch("litellm.acompletion") as mock: + yield mock + + +@pytest.fixture +def mock_chunk_builder(): + with patch("litellm.stream_chunk_builder") as mock: + yield mock + + +class TestStreamingCompletion: + async def test_yields_all_chunks(self, mock_acompletion, mock_chunk_builder): + chunks = [_make_chunk("Hello"), _make_chunk(" world"), _make_chunk("!")] + mock_acompletion.return_value = 
_async_iter(chunks) + mock_chunk_builder.return_value = MagicMock(name="final_response") + + stream = StreamingCompletion(model="test", messages=[]) + received = [chunk async for chunk in stream] + + assert received == chunks + + async def test_response_available_after_iteration( + self, mock_acompletion, mock_chunk_builder + ): + chunks = [_make_chunk("hi")] + mock_acompletion.return_value = _async_iter(chunks) + sentinel = MagicMock(name="final_response") + mock_chunk_builder.return_value = sentinel + + stream = StreamingCompletion(model="test", messages=[]) + async for _ in stream: + pass + + assert stream.response is sentinel + + async def test_response_raises_before_iteration(self): + stream = StreamingCompletion(model="test", messages=[]) + with pytest.raises(RuntimeError, match="not been iterated"): + _ = stream.response + + async def test_stream_kwarg_is_stripped(self, mock_acompletion, mock_chunk_builder): + mock_acompletion.return_value = _async_iter([]) + mock_chunk_builder.return_value = None + + stream = StreamingCompletion(model="test", messages=[], stream=False) + async for _ in stream: + pass + + _, call_kwargs = mock_acompletion.call_args + assert call_kwargs["stream"] is True + + async def test_passes_args_and_kwargs_through( + self, mock_acompletion, mock_chunk_builder + ): + mock_acompletion.return_value = _async_iter([]) + mock_chunk_builder.return_value = None + + stream = StreamingCompletion( + model="gpt-4", messages=[{"role": "user", "content": "hi"}], temperature=0.5 + ) + async for _ in stream: + pass + + _, call_kwargs = mock_acompletion.call_args + assert call_kwargs["model"] == "gpt-4" + assert call_kwargs["messages"] == [{"role": "user", "content": "hi"}] + assert call_kwargs["temperature"] == 0.5 + assert call_kwargs["stream"] is True + + async def test_chunks_passed_to_builder(self, mock_acompletion, mock_chunk_builder): + chunks = [_make_chunk("a"), _make_chunk("b")] + mock_acompletion.return_value = _async_iter(chunks) + 
mock_chunk_builder.return_value = MagicMock() + + stream = StreamingCompletion(model="test", messages=[]) + async for _ in stream: + pass + + mock_chunk_builder.assert_called_once_with(chunks) + + async def test_re_iteration_resets_state( + self, mock_acompletion, mock_chunk_builder + ): + first_chunks = [_make_chunk("first")] + second_chunks = [_make_chunk("second")] + first_response = MagicMock(name="first_response") + second_response = MagicMock(name="second_response") + + mock_acompletion.side_effect = [ + _async_iter(first_chunks), + _async_iter(second_chunks), + ] + mock_chunk_builder.side_effect = [first_response, second_response] + + stream = StreamingCompletion(model="test", messages=[]) + + async for _ in stream: + pass + assert stream.response is first_response + + async for _ in stream: + pass + assert stream.response is second_response + + async def test_empty_stream(self, mock_acompletion, mock_chunk_builder): + mock_acompletion.return_value = _async_iter([]) + mock_chunk_builder.return_value = None + + stream = StreamingCompletion(model="test", messages=[]) + received = [chunk async for chunk in stream] + + assert received == [] + assert stream.response is None diff --git a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py index 7cda8c6dc..6c2863502 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py @@ -1,8 +1,11 @@ import json from abc import ABCMeta, abstractmethod +from collections.abc import Awaitable, Callable from dataclasses import dataclass from typing import Dict, Tuple +from litellm.types.utils import ModelResponseStream + from kiln_ai.adapters.chat.chat_formatter import ( ChatFormatter, MultiturnFormatter, @@ -49,6 +52,8 @@ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error from kiln_ai.utils.open_ai_types import ChatCompletionMessageParam +StreamCallback = 
Callable[[ModelResponseStream], Awaitable[None]] + @dataclass class AdapterConfig: @@ -128,9 +133,10 @@ async def invoke( input: InputType, input_source: DataSource | None = None, existing_run: TaskRun | None = None, + on_chunk: StreamCallback | None = None, ) -> TaskRun: run_output, _ = await self.invoke_returning_run_output( - input, input_source, existing_run + input, input_source, existing_run, on_chunk=on_chunk ) return run_output @@ -139,6 +145,7 @@ async def _run_returning_run_output( input: InputType, input_source: DataSource | None = None, existing_run: TaskRun | None = None, + on_chunk: StreamCallback | None = None, ) -> Tuple[TaskRun, RunOutput]: # validate input, allowing arrays if self.input_schema is not None: @@ -166,7 +173,9 @@ async def _run_returning_run_output( formatted_input = formatter.format_input(input) # Run - run_output, usage = await self._run(formatted_input, prior_trace=prior_trace) + run_output, usage = await self._run( + formatted_input, prior_trace=prior_trace, on_chunk=on_chunk + ) # Parse provider = self.model_provider() @@ -256,6 +265,7 @@ async def invoke_returning_run_output( input: InputType, input_source: DataSource | None = None, existing_run: TaskRun | None = None, + on_chunk: StreamCallback | None = None, ) -> Tuple[TaskRun, RunOutput]: # Determine if this is the root agent (no existing run context) is_root_agent = get_agent_run_id() is None @@ -266,7 +276,7 @@ async def invoke_returning_run_output( try: return await self._run_returning_run_output( - input, input_source, existing_run + input, input_source, existing_run, on_chunk=on_chunk ) finally: if is_root_agent: @@ -289,6 +299,7 @@ async def _run( self, input: InputType, prior_trace: list[ChatCompletionMessageParam] | None = None, + on_chunk: StreamCallback | None = None, ) -> Tuple[RunOutput, Usage | None]: pass diff --git a/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py index 
131be097c..344bedec0 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py @@ -5,7 +5,6 @@ from dataclasses import dataclass from typing import Any, Dict, List, Tuple -import litellm from litellm.types.utils import ( ChatCompletionMessageToolCall, ChoiceLogprobs, @@ -20,6 +19,7 @@ import kiln_ai.datamodel as datamodel from kiln_ai.adapters.chat import ChatCompletionMessageIncludingLiteLLM +from kiln_ai.adapters.litellm_utils.litellm_streaming import StreamingCompletion from kiln_ai.adapters.ml_model_list import ( KilnModelProvider, ModelProviderName, @@ -29,6 +29,7 @@ AdapterConfig, BaseAdapter, RunOutput, + StreamCallback, Usage, ) from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig @@ -95,6 +96,7 @@ async def _run_model_turn( prior_messages: list[ChatCompletionMessageIncludingLiteLLM], top_logprobs: int | None, skip_response_format: bool, + on_chunk: StreamCallback | None = None, ) -> ModelTurnResult: """ Call the model for a single top level turn: from user message to agent message. 
@@ -118,7 +120,7 @@ async def _run_model_turn( # Make the completion call model_response, response_choice = await self.acompletion_checking_response( - **completion_kwargs + on_chunk=on_chunk, **completion_kwargs ) # count the usage @@ -185,6 +187,7 @@ async def _run( self, input: InputType, prior_trace: list[ChatCompletionMessageParam] | None = None, + on_chunk: StreamCallback | None = None, ) -> tuple[RunOutput, Usage | None]: usage = Usage() @@ -229,6 +232,7 @@ async def _run( messages, self.base_adapter_config.top_logprobs if turn.final_call else None, skip_response_format, + on_chunk=on_chunk, ) usage += turn_result.usage @@ -297,9 +301,14 @@ def _extract_reasoning_to_intermediate_outputs( intermediate_outputs["reasoning"] = stripped_reasoning_content async def acompletion_checking_response( - self, **kwargs + self, on_chunk: StreamCallback | None = None, **kwargs ) -> Tuple[ModelResponse, Choices]: - response = await litellm.acompletion(**kwargs) + stream = StreamingCompletion(**kwargs) + async for chunk in stream: + if on_chunk is not None: + await on_chunk(chunk) + response = stream.response + if ( not isinstance(response, ModelResponse) or not response.choices diff --git a/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py index 45aabc53e..236b63a9e 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py @@ -1,7 +1,11 @@ import json from typing import Tuple -from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig, BaseAdapter +from kiln_ai.adapters.model_adapters.base_adapter import ( + AdapterConfig, + BaseAdapter, + StreamCallback, +) from kiln_ai.adapters.parsers.json_parser import parse_json_string from kiln_ai.adapters.run_output import RunOutput from kiln_ai.datamodel import DataSource, Task, TaskRun, Usage @@ -46,6 +50,7 @@ async def _run( self, input: InputType, prior_trace: 
list[ChatCompletionMessageParam] | None = None, + on_chunk: StreamCallback | None = None, ) -> Tuple[RunOutput, Usage | None]: if prior_trace is not None: raise NotImplementedError( @@ -86,6 +91,7 @@ async def invoke( input: InputType, input_source: DataSource | None = None, existing_run: TaskRun | None = None, + on_chunk: StreamCallback | None = None, ) -> TaskRun: if existing_run is not None: raise NotImplementedError( @@ -94,7 +100,7 @@ async def invoke( ) run_output, _ = await self.invoke_returning_run_output( - input, input_source, existing_run + input, input_source, existing_run, on_chunk=on_chunk ) return run_output @@ -103,6 +109,7 @@ async def invoke_returning_run_output( input: InputType, input_source: DataSource | None = None, existing_run: TaskRun | None = None, + on_chunk: StreamCallback | None = None, ) -> Tuple[TaskRun, RunOutput]: """ Runs the task and returns both the persisted TaskRun and raw RunOutput. diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py index 8b150c68e..e6f1cca82 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py @@ -26,7 +26,7 @@ class MockAdapter(BaseAdapter): """Concrete implementation of BaseAdapter for testing""" - async def _run(self, input, prior_trace=None): + async def _run(self, input, **kwargs): return None, None def adapter_name(self) -> str: @@ -239,7 +239,7 @@ async def test_input_formatting( # Mock the _run method to capture the input captured_input = None - async def mock_run(input, prior_trace=None): + async def mock_run(input, **kwargs): nonlocal captured_input captured_input = input return RunOutput(output="test output", intermediate_outputs={}), None @@ -507,9 +507,9 @@ async def test_invoke_returning_run_output_passes_existing_run_to_run( captured_prior_trace = None - async def mock_run(input, prior_trace=None): + async def 
mock_run(input, **kwargs): nonlocal captured_prior_trace - captured_prior_trace = prior_trace + captured_prior_trace = kwargs.get("prior_trace") return RunOutput(output="ok", intermediate_outputs=None, trace=trace), None adapter._run = mock_run @@ -798,7 +798,7 @@ async def test_invoke_sets_run_context(self, adapter, clear_context): from kiln_ai.run_context import get_agent_run_id # Mock the _run method - async def mock_run(input, prior_trace=None): + async def mock_run(input, **kwargs): # Check that run ID is set during _run run_id = get_agent_run_id() assert run_id is not None @@ -838,7 +838,7 @@ async def test_invoke_clears_run_context_after(self, adapter, clear_context): from kiln_ai.run_context import get_agent_run_id # Mock the _run method - async def mock_run(input, prior_trace=None): + async def mock_run(input, **kwargs): return RunOutput(output="test output", intermediate_outputs={}), None adapter._run = mock_run @@ -876,7 +876,7 @@ async def test_invoke_clears_run_context_on_error(self, adapter, clear_context): from kiln_ai.run_context import get_agent_run_id # Mock the _run method to raise an error - async def mock_run(input, prior_trace=None): + async def mock_run(input, **kwargs): # Run ID should be set even when error occurs run_id = get_agent_run_id() assert run_id is not None @@ -913,7 +913,7 @@ async def test_sub_agent_inherits_run(self, adapter, clear_context): set_agent_run_id(parent_run_id) # Mock the _run method to check inherited run ID - async def mock_run(input, prior_trace=None): + async def mock_run(input, **kwargs): # Sub-agent should see parent's run ID run_id = get_agent_run_id() assert run_id == parent_run_id @@ -962,7 +962,7 @@ async def test_sub_agent_does_not_create_new_run(self, adapter, clear_context): run_id_during_run = None # Mock the _run method to capture run ID - async def mock_run(input, prior_trace=None): + async def mock_run(input, **kwargs): nonlocal run_id_during_run run_id_during_run = get_agent_run_id() return 
RunOutput(output="test output", intermediate_outputs={}), None @@ -1002,7 +1002,7 @@ async def test_cleanup_session_called_on_completion(self, adapter, clear_context from kiln_ai.adapters.run_output import RunOutput # Mock the _run method - async def mock_run(input, prior_trace=None): + async def mock_run(input, **kwargs): return RunOutput(output="test output", intermediate_outputs={}), None adapter._run = mock_run @@ -1045,3 +1045,122 @@ async def mock_run(input, prior_trace=None): assert call_args is not None run_id = call_args[0][0] if call_args[0] else call_args[1]["run_id"] assert run_id.startswith("run_") + + +class TestStreamCallback: + """Tests for the on_chunk streaming callback parameter.""" + + @pytest.fixture + def stream_adapter(self, base_task): + return MockAdapter( + task=base_task, + run_config=KilnAgentRunConfigProperties( + model_name="test_model", + model_provider_name="openai", + prompt_id="simple_prompt_builder", + structured_output_mode="json_schema", + ), + ) + + def _setup_adapter_mocks(self, adapter): + provider = MagicMock() + provider.parser = "test_parser" + provider.formatter = None + provider.reasoning_capable = False + adapter.model_provider = MagicMock(return_value=provider) + + @pytest.mark.asyncio + async def test_on_chunk_forwarded_to_run(self, stream_adapter): + """Test that on_chunk is passed through to _run.""" + received_kwargs = {} + + async def mock_run(input, **kwargs): + received_kwargs.update(kwargs) + return RunOutput(output="test output", intermediate_outputs={}), None + + stream_adapter._run = mock_run + self._setup_adapter_mocks(stream_adapter) + + callback = AsyncMock() + + parser = MagicMock() + parser.parse_output.return_value = RunOutput( + output="test output", intermediate_outputs={} + ) + + with ( + patch( + "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id" + ) as mock_parser_factory, + patch( + "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id" + ), + ): + 
mock_parser_factory.return_value = parser + await stream_adapter.invoke_returning_run_output( + {"test": "input"}, on_chunk=callback + ) + + assert received_kwargs.get("on_chunk") is callback + + @pytest.mark.asyncio + async def test_on_chunk_none_by_default(self, stream_adapter): + """Test that on_chunk defaults to None when not provided.""" + received_kwargs = {} + + async def mock_run(input, **kwargs): + received_kwargs.update(kwargs) + return RunOutput(output="test output", intermediate_outputs={}), None + + stream_adapter._run = mock_run + self._setup_adapter_mocks(stream_adapter) + + parser = MagicMock() + parser.parse_output.return_value = RunOutput( + output="test output", intermediate_outputs={} + ) + + with ( + patch( + "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id" + ) as mock_parser_factory, + patch( + "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id" + ), + ): + mock_parser_factory.return_value = parser + await stream_adapter.invoke_returning_run_output({"test": "input"}) + + assert received_kwargs.get("on_chunk") is None + + @pytest.mark.asyncio + async def test_invoke_forwards_on_chunk(self, stream_adapter): + """Test that invoke() also forwards on_chunk.""" + received_kwargs = {} + + async def mock_run(input, **kwargs): + received_kwargs.update(kwargs) + return RunOutput(output="test output", intermediate_outputs={}), None + + stream_adapter._run = mock_run + self._setup_adapter_mocks(stream_adapter) + + callback = AsyncMock() + + parser = MagicMock() + parser.parse_output.return_value = RunOutput( + output="test output", intermediate_outputs={} + ) + + with ( + patch( + "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id" + ) as mock_parser_factory, + patch( + "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id" + ), + ): + mock_parser_factory.return_value = parser + await stream_adapter.invoke({"test": "input"}, on_chunk=callback) + + assert 
received_kwargs.get("on_chunk") is callback diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py index 3b44812a6..b47d2d46e 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py @@ -1212,7 +1212,11 @@ async def test_array_input_converted_to_json(tmp_path, config): mock_config_obj.user_id = "test_user" with ( - patch("litellm.acompletion", new=AsyncMock(return_value=mock_response)), + patch.object( + LiteLlmAdapter, + "acompletion_checking_response", + new=AsyncMock(return_value=(mock_response, mock_response.choices[0])), + ), patch("kiln_ai.utils.config.Config.shared", return_value=mock_config_obj), ): array_input = [1, 2, 3, 4, 5] @@ -1282,7 +1286,11 @@ async def test_dict_input_converted_to_json(tmp_path, config): mock_config_obj.user_id = "test_user" with ( - patch("litellm.acompletion", new=AsyncMock(return_value=mock_response)), + patch.object( + LiteLlmAdapter, + "acompletion_checking_response", + new=AsyncMock(return_value=(mock_response, mock_response.choices[0])), + ), patch("kiln_ai.utils.config.Config.shared", return_value=mock_config_obj), ): dict_input = {"x": 10, "y": 20} @@ -1336,7 +1344,7 @@ def capturing_build(input, prior_trace_arg=None): adapter.build_chat_formatter = capturing_build async def mock_run_model_turn( - provider, prior_messages, top_logprobs, skip_response_format + provider, prior_messages, top_logprobs, skip_response_format, on_chunk=None ): extended = list(prior_messages) extended.append({"role": "assistant", "content": "How can I help?"}) diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py index fb7bc4c21..bc15622ed 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +++ 
b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest from litellm.types.utils import ModelResponse @@ -89,7 +89,15 @@ async def run_simple_task_with_tools( with patch.object(adapter, "available_tools", return_value=mock_math_tools): if simplified: - run = await adapter.invoke("what is 2+2") + # test our chunking handler also works e2e on real models + received_chunks = [] + + async def on_chunk_handler(chunk): + received_chunks.append(chunk) + + run = await adapter.invoke("what is 2+2", on_chunk=on_chunk_handler) + + assert len(received_chunks) > 0 # Verify that AddTool.run was called with correct parameters add_spy.run.assert_called() @@ -287,10 +295,19 @@ async def test_tools_simplied_mocked(tmp_path): mock_config.open_ai_api_key = "mock_api_key" mock_config.user_id = "test_user" + responses = [mock_response_1, mock_response_2] + + async def mock_acompletion_checking_response(self, on_chunk=None, **kwargs): + if on_chunk is not None: + await on_chunk(Mock()) + response = responses.pop(0) + return response, response.choices[0] + with ( - patch( - "litellm.acompletion", - side_effect=[mock_response_1, mock_response_2], + patch.object( + LiteLlmAdapter, + "acompletion_checking_response", + new=mock_acompletion_checking_response, ), patch("kiln_ai.utils.config.Config.shared", return_value=mock_config), ): @@ -386,9 +403,16 @@ async def test_tools_mocked(tmp_path): mock_config.user_id = "test_user" with ( - patch( - "litellm.acompletion", - side_effect=[mock_response_1, mock_response_2, mock_response_3], + patch.object( + LiteLlmAdapter, + "acompletion_checking_response", + new=AsyncMock( + side_effect=[ + (mock_response_1, mock_response_1.choices[0]), + (mock_response_2, mock_response_2.choices[0]), + (mock_response_3, mock_response_3.choices[0]), + ] + ), ), 
patch("kiln_ai.utils.config.Config.shared", return_value=mock_config), ): diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py index f96c70a41..43df8c522 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py @@ -10,11 +10,7 @@ class MockAdapter(BaseAdapter): - async def _run( - self, - input: InputType, - prior_trace=None, - ) -> tuple[RunOutput, Usage | None]: + async def _run(self, input: InputType, **kwargs) -> tuple[RunOutput, Usage | None]: return RunOutput(output="Test output", intermediate_outputs=None), None def adapter_name(self) -> str: @@ -263,7 +259,8 @@ async def test_invoke_continue_session(test_task, adapter): run_id = initial_run.id assert run_id is not None - async def mock_run(input, prior_trace=None): + async def mock_run(input, **kwargs): + prior_trace = kwargs.get("prior_trace") if prior_trace is not None: extended_trace = [ *prior_trace, diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py b/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py index 92d9a1f4d..46fa88aa7 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py @@ -1,7 +1,7 @@ import json from pathlib import Path from typing import Dict -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest from litellm.types.utils import ModelResponse @@ -10,6 +10,7 @@ from kiln_ai.adapters.adapter_registry import adapter_for_task from kiln_ai.adapters.ml_model_list import built_in_models from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput, Usage +from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter from kiln_ai.adapters.ollama_tools import 
ollama_online from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers from kiln_ai.datamodel import PromptId @@ -53,11 +54,7 @@ def __init__(self, kiln_task: datamodel.Task, response: InputType | None): ) self.response = response - async def _run( - self, - input: str, - prior_trace=None, - ) -> tuple[RunOutput, Usage | None]: + async def _run(self, input: str, **kwargs) -> tuple[RunOutput, Usage | None]: return RunOutput(output=self.response, intermediate_outputs=None), None def adapter_name(self) -> str: @@ -351,9 +348,10 @@ async def test_all_built_in_models_structured_input_mocked(tmp_path): mock_config.groq_api_key = "mock_api_key" with ( - patch( - "litellm.acompletion", - side_effect=[mock_response], + patch.object( + LiteLlmAdapter, + "acompletion_checking_response", + new=AsyncMock(return_value=(mock_response, mock_response.choices[0])), ), patch("kiln_ai.utils.config.Config.shared", return_value=mock_config), ): @@ -406,9 +404,15 @@ async def test_structured_input_cot_prompt_builder_mocked(tmp_path): mock_config.groq_api_key = "mock_api_key" with ( - patch( - "litellm.acompletion", - side_effect=[mock_response_1, mock_response_2], + patch.object( + LiteLlmAdapter, + "acompletion_checking_response", + new=AsyncMock( + side_effect=[ + (mock_response_1, mock_response_1.choices[0]), + (mock_response_2, mock_response_2.choices[0]), + ] + ), ), patch("kiln_ai.utils.config.Config.shared", return_value=mock_config), ): diff --git a/libs/core/kiln_ai/adapters/test_prompt_adaptors.py b/libs/core/kiln_ai/adapters/test_prompt_adaptors.py index 23dd15b0d..187de1ff2 100644 --- a/libs/core/kiln_ai/adapters/test_prompt_adaptors.py +++ b/libs/core/kiln_ai/adapters/test_prompt_adaptors.py @@ -1,9 +1,9 @@ import os from pathlib import Path -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import pytest -from litellm.utils import ModelResponse +from litellm.types.utils import ModelResponse import kiln_ai.datamodel as 
datamodel from kiln_ai.adapters.adapter_registry import adapter_for_task @@ -113,13 +113,15 @@ async def test_amazon_bedrock(tmp_path): async def test_mock_returning_run(tmp_path): task = build_test_task(tmp_path) - with patch("litellm.acompletion") as mock_acompletion: - # Configure the mock to return a properly structured response - mock_acompletion.return_value = ModelResponse( - model="custom_model", - choices=[{"message": {"content": "mock response"}}], - ) - + mock_response = ModelResponse( + model="custom_model", + choices=[{"message": {"content": "mock response"}}], + ) + with patch.object( + LiteLlmAdapter, + "acompletion_checking_response", + new=AsyncMock(return_value=(mock_response, mock_response.choices[0])), + ): run_config = KilnAgentRunConfigProperties( model_name="custom_model", model_provider_name=ModelProviderName.ollama, diff --git a/libs/core/kiln_ai/adapters/test_prompt_builders.py b/libs/core/kiln_ai/adapters/test_prompt_builders.py index 25cadd99d..7a67ff5c9 100644 --- a/libs/core/kiln_ai/adapters/test_prompt_builders.py +++ b/libs/core/kiln_ai/adapters/test_prompt_builders.py @@ -58,11 +58,7 @@ def test_simple_prompt_builder(tmp_path): class MockAdapter(BaseAdapter): - async def _run( - self, - input: InputType, - prior_trace=None, - ) -> tuple[RunOutput, Usage | None]: + async def _run(self, input: InputType, **kwargs) -> tuple[RunOutput, Usage | None]: return RunOutput(output="mock response", intermediate_outputs=None), None def adapter_name(self) -> str: diff --git a/libs/core/kiln_ai/datamodel/test_basemodel.py b/libs/core/kiln_ai/datamodel/test_basemodel.py index 03c6dfeaf..897a3749f 100644 --- a/libs/core/kiln_ai/datamodel/test_basemodel.py +++ b/libs/core/kiln_ai/datamodel/test_basemodel.py @@ -862,7 +862,7 @@ def individual_lookups(): class MockAdapter(BaseAdapter): """Implementation of BaseAdapter for testing""" - async def _run(self, input, prior_trace=None): + async def _run(self, input, **kwargs): return RunOutput(output="test 
output", intermediate_outputs=None), None def adapter_name(self) -> str: From 3d6ced250f657dd56850fc53e93d5bc6a2f7b296 Mon Sep 17 00:00:00 2001 From: "Leonard Q. Marcq" Date: Tue, 3 Mar 2026 18:23:45 +0800 Subject: [PATCH 05/32] test: paid integration test for streaming --- .../test_litellm_adapter_streaming.py | 413 ++++++++++++++++++ 1 file changed, 413 insertions(+) create mode 100644 libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py new file mode 100644 index 000000000..9c2b7e9ef --- /dev/null +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py @@ -0,0 +1,413 @@ +import json +import logging +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Callable, Tuple +from unittest.mock import patch + +import litellm +import pytest +from litellm.types.utils import ChatCompletionDeltaToolCall + +from kiln_ai.adapters.ml_model_list import ModelProviderName, StructuredOutputMode +from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter +from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig +from kiln_ai.datamodel import Project, PromptGenerators, Task +from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties, ToolsRunConfig +from kiln_ai.datamodel.tool_id import KilnBuiltInToolId + +logger = logging.getLogger(__name__) + + +class ChunkRendererAbstract(ABC): + @abstractmethod + async def render_chunk(self, chunk: litellm.ModelResponseStream): + pass + + @abstractmethod + def get_stream_text(self) -> str: + pass + + +class ChunkRenderer(ChunkRendererAbstract): + def __init__(self): + self.chunk_texts: list[str] = [] + self.current_block_type: str | None = None + + def print_and_append(self, text: str): + # replace with print if your logger is not outputting info logs + 
logger.info(text) + self.chunk_texts.append(text) + + def enter_block(self, block_type: str): + if self.current_block_type != block_type: + if self.current_block_type is not None: + self.print_and_append(f"\n") + + self.print_and_append(f"\n<{block_type}>\n") + self.current_block_type = block_type + + def render_reasoning(self, reasoning_content: str): + self.enter_block("reasoning") + self.print_and_append(reasoning_content) + + def render_content(self, content: str): + self.enter_block("content") + self.print_and_append(content) + + def render_tool_call(self, tool_calls: list[ChatCompletionDeltaToolCall | Any]): + self.enter_block("tool_call") + for tool_call in tool_calls: + # first it says the tool name, then the arguments + if tool_call.function.name is not None: + self.print_and_append(f'Calling tool: "{tool_call.function.name}" ') + self.print_and_append("with args: ") + elif tool_call.function.arguments is not None: + args = tool_call.function.arguments + self.print_and_append(args) + + def render_stop(self, stop_reason: str): + self.print_and_append("\n") + + def render_unknown(self, chunk: litellm.ModelResponseStream): + self.enter_block("unknown") + self.print_and_append(f"Unknown chunk: {chunk}") + + async def render_chunk(self, chunk: litellm.ModelResponseStream): + if chunk.choices[0].finish_reason is not None: + self.render_stop(chunk.choices[0].finish_reason) + return + elif chunk.choices[0].delta is not None: + # inconsistent behavior between providers, some have multiple fields at once, some don't + if chunk.choices[0].delta.tool_calls is not None: + self.render_tool_call(chunk.choices[0].delta.tool_calls) + elif getattr(chunk.choices[0].delta, "reasoning_content", None) is not None: + text = getattr(chunk.choices[0].delta, "reasoning_content", None) + if text is not None: + self.render_reasoning(text) + elif chunk.choices[0].delta.content is not None: + self.render_content(chunk.choices[0].delta.content) + else: + self.render_unknown(chunk) + + 
def get_stream_text(self) -> str: + return "".join(self.chunk_texts) + + +class ChunkRawRenderer(ChunkRendererAbstract): + def __init__(self): + self.chunks: list[litellm.ModelResponseStream] = [] + self.current_block_type: str | None = None + + async def render_chunk(self, chunk: litellm.ModelResponseStream): + logger.info(str(chunk)) + self.chunks.append(chunk) + + def get_stream_text(self) -> str: + return "\n".join([str(chunk) for chunk in self.chunks]) + + +@pytest.fixture +def task(tmp_path): + project_path: Path = tmp_path / "test_project" / "project.kiln" + project_path.parent.mkdir() + + project = Project(name="Test Project", path=project_path) + project.save_to_file() + + task = Task( + name="Streaming Test Task", + instruction="Think about it hard! Solve the math problem provided by the user, in a step by step manner. Use the tools provided to solve the math problem. Then use the result in a short sentence about a cat going to the mall. Remember to use the tools for math even if the operation looks easy.", + parent=project, + ) + task.save_to_file() + return task + + +@pytest.fixture +def adapter_factory(task: Task) -> Callable[[str, ModelProviderName], LiteLlmAdapter]: + def create_adapter( + model_id: str, provider_name: ModelProviderName + ) -> LiteLlmAdapter: + adapter = LiteLlmAdapter( + kiln_task=task, + config=LiteLlmConfig( + run_config_properties=KilnAgentRunConfigProperties( + model_name=model_id, + model_provider_name=provider_name, + prompt_id=PromptGenerators.SIMPLE, + structured_output_mode=StructuredOutputMode.unknown, + tools_config=ToolsRunConfig( + tools=[ + KilnBuiltInToolId.ADD_NUMBERS, + KilnBuiltInToolId.SUBTRACT_NUMBERS, + KilnBuiltInToolId.MULTIPLY_NUMBERS, + KilnBuiltInToolId.DIVIDE_NUMBERS, + ], + ), + ) + ), + ) + return adapter + + return create_adapter + + +@pytest.mark.paid +@pytest.mark.parametrize( + "model_id,provider_name", + [ + ("claude_sonnet_4_5", ModelProviderName.openrouter), + ("claude_sonnet_4_5", 
ModelProviderName.anthropic), + ("claude_sonnet_4_6", ModelProviderName.openrouter), + ("claude_sonnet_4_6", ModelProviderName.anthropic), + ("claude_opus_4_5", ModelProviderName.openrouter), + ("claude_opus_4_5", ModelProviderName.anthropic), + ("claude_opus_4_6", ModelProviderName.openrouter), + ("claude_opus_4_6", ModelProviderName.anthropic), + ("minimax_m2_5", ModelProviderName.openrouter), + ("claude_4_5_haiku", ModelProviderName.openrouter), + ("claude_4_5_haiku", ModelProviderName.anthropic), + ], +) +async def test_acompletion_streaming_response( + model_id: str, + provider_name: ModelProviderName, + adapter_factory: Callable[[str, ModelProviderName], LiteLlmAdapter], +): + """Check the accumulated response has all the expected parts""" + adapter = adapter_factory(model_id, provider_name) + + renderer = ChunkRenderer() + + # we proxy all the calls to the original function so we can spy on the return values + captured_responses: list[Tuple[litellm.ModelResponse, litellm.Choices]] = [] + origin_func = adapter.acompletion_checking_response + + async def spy( + *args: Any, **kwargs: Any + ) -> Tuple[litellm.ModelResponse, litellm.Choices]: + nonlocal captured_responses + + result = await origin_func(*args, **kwargs) + captured_responses.append(result) + return result + + with patch.object(adapter, "acompletion_checking_response", side_effect=spy): + task_run = await adapter.invoke( + input="123 + 321 = ?", + on_chunk=renderer.render_chunk, + ) + + # there is one call per thing going on (tool call, content, etc.) 
+ # with our toy task, we expect ~2 or 3 calls (reasoning + tool call -> content) + if len(captured_responses) == 0: + raise RuntimeError( + "captured_responses is empty after invocation - test probably broken due to wrong spy" + ) + + # check we are getting the trace successfully + assert task_run.trace is not None, "Task run trace is None" + assert len(task_run.trace) > 0, "Task run trace is empty" + + assistant_messages: list[litellm.Message] = [] + for model_response, _ in captured_responses: + for choice in model_response.choices: + if isinstance(choice, litellm.Choices): + assistant_messages.append(choice.message) + assert len(assistant_messages) > 0, "No assistant messages found in the trace" + + # we do not know which message the reasoning / content / tool call is in, but we know each one + # should appear in at least one message so we accumulate them here + reasoning_contents: list[str] = [] + contents: list[str] = [] + tool_calls: list[ChatCompletionDeltaToolCall | Any] = [] + for assistant_message in assistant_messages: + reasoning_content = getattr(assistant_message, "reasoning_content", None) + if reasoning_content: + reasoning_contents.append(reasoning_content) + + content = getattr(assistant_message, "content", None) + if content: + contents.append(str(content)) + + _tool_calls = getattr(assistant_message, "tool_calls", None) + if _tool_calls: + tool_calls.extend(_tool_calls) + + # check we got all the expected parts somewhere + assert len(reasoning_contents) > 0, "No reasoning contents found in the trace" + assert len(contents) > 0, "No contents found in the trace" + assert len(tool_calls) > 0, "No tool calls found in the trace" + assert len(tool_calls) == 1, "Expected exactly one tool call (to do the math)" + + # check we got some non-empty reasoning - we should have gotten some reasoning at least somewhere + # usually the toolcall + assert not all( + reasoning_content.strip() == "" for reasoning_content in reasoning_contents + ), "All reasoning 
contents are empty" + + # check we got some non-empty content (we get empty strings when there is no content) + assert not all(content.strip() == "" for content in contents), ( + "All contents are empty" + ) + + for tool_call in tool_calls: + assert tool_call.function.name is not None, "Tool call name is None" + assert tool_call.function.arguments is not None, "Tool call arguments are None" + assert json.loads(tool_call.function.arguments) is not None, ( + "Tool call arguments are not JSON" + ) + tool_call_args = json.loads(tool_call.function.arguments) + assert tool_call_args == { + "a": 123, + "b": 321, + } or tool_call_args == { + "a": 321, + "b": 123, + }, f"Tool call arguments are not the expected values: {tool_call_args}" + + +@pytest.mark.paid +@pytest.mark.parametrize( + "model_id,provider_name", + [ + ("claude_sonnet_4_5", ModelProviderName.openrouter), + ("claude_sonnet_4_5", ModelProviderName.anthropic), + ("claude_sonnet_4_6", ModelProviderName.openrouter), + ("claude_sonnet_4_6", ModelProviderName.anthropic), + ("claude_opus_4_5", ModelProviderName.openrouter), + ("claude_opus_4_5", ModelProviderName.anthropic), + ("claude_opus_4_6", ModelProviderName.openrouter), + ("claude_opus_4_6", ModelProviderName.anthropic), + ("minimax_m2_5", ModelProviderName.openrouter), + ("claude_4_5_haiku", ModelProviderName.openrouter), + ("claude_4_5_haiku", ModelProviderName.anthropic), + ], +) +async def test_acompletion_streaming_chunks( + model_id: str, + provider_name: ModelProviderName, + adapter_factory: Callable[[str, ModelProviderName], LiteLlmAdapter], +): + """Collect all chunks from all completion calls, then one pass to check we got reasoning, content, and tool calls.""" + + adapter = adapter_factory(model_id, provider_name) + + chunks: list[litellm.ModelResponseStream] = [] + + renderer = ChunkRenderer() + + async def collect_chunks(chunk: litellm.ModelResponseStream) -> None: + chunks.append(chunk) + await renderer.render_chunk(chunk) + + await 
adapter.invoke(input="123 + 321 = ?", on_chunk=collect_chunks) + + assert len(chunks) > 0, "No chunks collected" + reasoning_contents: list[str] = [] + contents: list[str] = [] + tool_calls: list[ChatCompletionDeltaToolCall | Any] = [] + + for chunk in chunks: + if chunk.choices[0].finish_reason is not None: + continue + delta = chunk.choices[0].delta + if delta is None: + continue + if delta.tool_calls is not None: + tool_calls.extend(delta.tool_calls) + elif getattr(delta, "reasoning_content", None) is not None: + text = getattr(delta, "reasoning_content", None) + if text is not None: + reasoning_contents.append(text) + elif delta.content is not None: + contents.append(delta.content) + + assert len(reasoning_contents) > 0, "No reasoning content in chunks" + assert len(contents) > 0, "No content in chunks" + assert len(tool_calls) > 0, "No tool calls in chunks" + assert not all(r.strip() == "" for r in reasoning_contents), ( + "All reasoning content in chunks is empty" + ) + assert not all(c.strip() == "" for c in contents), "All content in chunks is empty" + + tool_call_function_names = [ + tool_call.function.name + for tool_call in tool_calls + if tool_call.function.name is not None + ] + assert len(tool_call_function_names) == 1, ( + "Expected exactly one tool call function name" + ) + assert tool_call_function_names[0] == "add", "Tool call function name is not 'add'" + + tool_call_args_chunks = "".join( + [ + tool_call.function.arguments + for tool_call in tool_calls + if tool_call.function.arguments is not None + ] + ) + + tool_call_args = json.loads(tool_call_args_chunks) + assert tool_call_args == {"a": 123, "b": 321} or tool_call_args == { + "a": 321, + "b": 123, + }, f"Tool call arguments not as expected: {tool_call_args}" + + +@pytest.mark.paid +@pytest.mark.parametrize( + "model_id,provider_name", + [ + ("claude_sonnet_4_5", ModelProviderName.openrouter), + ("claude_sonnet_4_5", ModelProviderName.anthropic), + ("claude_sonnet_4_6", 
ModelProviderName.openrouter), + ("claude_sonnet_4_6", ModelProviderName.anthropic), + ("claude_opus_4_5", ModelProviderName.openrouter), + ("claude_opus_4_5", ModelProviderName.anthropic), + ("claude_opus_4_6", ModelProviderName.openrouter), + ("claude_opus_4_6", ModelProviderName.anthropic), + ("minimax_m2_5", ModelProviderName.openrouter), + ("claude_4_5_haiku", ModelProviderName.openrouter), + ("claude_4_5_haiku", ModelProviderName.anthropic), + ], +) +async def test_acompletion_streaming_rendering( + model_id: str, + provider_name: ModelProviderName, + adapter_factory: Callable[[str, ModelProviderName], LiteLlmAdapter], +): + """Test that the streaming response with a renderer to see how it looks""" + adapter = adapter_factory(model_id, provider_name) + renderer = ChunkRenderer() + await adapter.invoke(input="123 + 321 = ?", on_chunk=renderer.render_chunk) + assert renderer.get_stream_text() is not None + + +@pytest.mark.paid +@pytest.mark.parametrize( + "model_id,provider_name", + [ + ("claude_sonnet_4_5", ModelProviderName.openrouter), + ("claude_sonnet_4_5", ModelProviderName.anthropic), + ("claude_sonnet_4_6", ModelProviderName.openrouter), + ("claude_sonnet_4_6", ModelProviderName.anthropic), + ("claude_opus_4_5", ModelProviderName.openrouter), + ("claude_opus_4_5", ModelProviderName.anthropic), + ("claude_opus_4_6", ModelProviderName.openrouter), + ("claude_opus_4_6", ModelProviderName.anthropic), + ("minimax_m2_5", ModelProviderName.openrouter), + ], +) +async def test_acompletion_streaming_rendering_raw_chunks( + model_id: str, + provider_name: ModelProviderName, + adapter_factory: Callable[[str, ModelProviderName], LiteLlmAdapter], +): + """Test that the streaming response with a renderer to see how it looks, but with raw chunks""" + adapter = adapter_factory(model_id, provider_name) + renderer = ChunkRawRenderer() + await adapter.invoke(input="123 + 321 = ?", on_chunk=renderer.render_chunk) + assert renderer.get_stream_text() is not None From 
1464fb84c58383d7404e784af6afbfdbd9f24867 Mon Sep 17 00:00:00 2001 From: "Leonard Q. Marcq" Date: Tue, 3 Mar 2026 18:30:02 +0800 Subject: [PATCH 06/32] test: add test for session + streaming together --- .../test_litellm_adapter_streaming.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py index 9c2b7e9ef..8dbc2118d 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py @@ -411,3 +411,49 @@ async def test_acompletion_streaming_rendering_raw_chunks( renderer = ChunkRawRenderer() await adapter.invoke(input="123 + 321 = ?", on_chunk=renderer.render_chunk) assert renderer.get_stream_text() is not None + + +@pytest.mark.paid +@pytest.mark.parametrize( + "model_id,provider_name", + [ + ("claude_sonnet_4_5", ModelProviderName.openrouter), + ("claude_sonnet_4_5", ModelProviderName.anthropic), + ("claude_sonnet_4_6", ModelProviderName.openrouter), + ("claude_sonnet_4_6", ModelProviderName.anthropic), + ("claude_opus_4_5", ModelProviderName.openrouter), + ("claude_opus_4_5", ModelProviderName.anthropic), + ("claude_opus_4_6", ModelProviderName.openrouter), + ("claude_opus_4_6", ModelProviderName.anthropic), + ("minimax_m2_5", ModelProviderName.openrouter), + ], +) +async def test_acompletion_streaming_with_existing_run( + model_id: str, + provider_name: ModelProviderName, + adapter_factory: Callable[[str, ModelProviderName], LiteLlmAdapter], +): + """Test that streaming works when continuing an existing run (session continuation).""" + adapter = adapter_factory(model_id, provider_name) + renderer = ChunkRawRenderer() + + initial_run = await adapter.invoke( + input="123 + 321 = ?", + on_chunk=renderer.render_chunk, + ) + assert initial_run.trace is not None + assert len(initial_run.trace) > 
0 + initial_trace_len = len(initial_run.trace) + + continuation_renderer = ChunkRawRenderer() + continued_run = await adapter.invoke( + input="What was the result? Reply in one short sentence.", + existing_run=initial_run, + on_chunk=continuation_renderer.render_chunk, + ) + + assert continued_run.id == initial_run.id + assert continued_run.trace is not None + assert len(continued_run.trace) > initial_trace_len + assert continuation_renderer.get_stream_text() is not None + assert len(continuation_renderer.chunks) > 0 From dee10b372c5d4140275081166dea78f356d0212b Mon Sep 17 00:00:00 2001 From: "Leonard Q. Marcq" Date: Tue, 3 Mar 2026 18:37:25 +0800 Subject: [PATCH 07/32] Update libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../adapters/model_adapters/test_litellm_adapter_streaming.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py index 8dbc2118d..62e16ce2a 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py @@ -62,7 +62,7 @@ def render_tool_call(self, tool_calls: list[ChatCompletionDeltaToolCall | Any]): if tool_call.function.name is not None: self.print_and_append(f'Calling tool: "{tool_call.function.name}" ') self.print_and_append("with args: ") - elif tool_call.function.arguments is not None: + if tool_call.function.arguments is not None: args = tool_call.function.arguments self.print_and_append(args) From 456088de5c1f4e8688936bf221ba79044b8f1694 Mon Sep 17 00:00:00 2001 From: "Leonard Q. 
Marcq" Date: Sun, 8 Mar 2026 15:07:32 +0800 Subject: [PATCH 08/32] refactor: stream with support for AI SDK (with tool events) and OpenAI protocols --- .gitignore | 2 + .../adapters/model_adapters/adapter_stream.py | 291 +++++++++++ .../adapters/model_adapters/base_adapter.py | 240 ++++++++- .../model_adapters/litellm_adapter.py | 39 +- .../adapters/model_adapters/mcp_adapter.py | 6 +- .../adapters/model_adapters/stream_events.py | 289 +++++++++++ .../model_adapters/test_adapter_stream.py | 372 ++++++++++++++ .../model_adapters/test_base_adapter.py | 119 +---- .../model_adapters/test_litellm_adapter.py | 2 +- .../test_litellm_adapter_streaming.py | 473 ++++++------------ .../test_litellm_adapter_tools.py | 14 +- .../model_adapters/test_stream_events.py | 227 +++++++++ 12 files changed, 1622 insertions(+), 452 deletions(-) create mode 100644 libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py create mode 100644 libs/core/kiln_ai/adapters/model_adapters/stream_events.py create mode 100644 libs/core/kiln_ai/adapters/model_adapters/test_adapter_stream.py create mode 100644 libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py diff --git a/.gitignore b/.gitignore index a4b8f5f71..ce76a0841 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,5 @@ libs/server/build dist/ .mcp.json + +test_output/ diff --git a/libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py b/libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py new file mode 100644 index 000000000..db4c4642a --- /dev/null +++ b/libs/core/kiln_ai/adapters/model_adapters/adapter_stream.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +import copy +import json +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, AsyncIterator + +from litellm.types.utils import ( + ChatCompletionMessageToolCall, + Choices, + ModelResponse, +) + +from kiln_ai.adapters.chat import ChatCompletionMessageIncludingLiteLLM +from kiln_ai.adapters.chat.chat_formatter 
import ChatFormatter +from kiln_ai.adapters.litellm_utils.litellm_streaming import StreamingCompletion +from kiln_ai.adapters.ml_model_list import KilnModelProvider +from kiln_ai.adapters.model_adapters.stream_events import ( + AdapterStreamEvent, + ToolCallEvent, + ToolCallEventType, +) +from kiln_ai.adapters.run_output import RunOutput +from kiln_ai.datamodel import Usage + +if TYPE_CHECKING: + from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter + +MAX_CALLS_PER_TURN = 10 +MAX_TOOL_CALLS_PER_TURN = 30 + +logger = logging.getLogger(__name__) + + +@dataclass +class AdapterStreamResult: + run_output: RunOutput + usage: Usage + + +class AdapterStream: + """ + Orchestrates a full task execution as an async iterator, + composing StreamingCompletion instances across chat turns and tool-call rounds. + + Yields ``ModelResponseStream`` chunks from each LLM call and + ``ToolCallEvent`` instances between tool-call rounds. + + After iteration completes the ``result`` property provides the + ``AdapterStreamResult`` with the final ``RunOutput`` and ``Usage``. + """ + + def __init__( + self, + adapter: LiteLlmAdapter, + provider: KilnModelProvider, + chat_formatter: ChatFormatter, + initial_messages: list[ChatCompletionMessageIncludingLiteLLM], + top_logprobs: int | None, + ) -> None: + self._adapter = adapter + self._provider = provider + self._chat_formatter = chat_formatter + self._messages = initial_messages + self._top_logprobs = top_logprobs + self._result: AdapterStreamResult | None = None + self._iterated = False + + @property + def result(self) -> AdapterStreamResult: + if not self._iterated: + raise RuntimeError( + "AdapterStream has not been iterated yet. 
" + "Use 'async for event in stream:' before accessing .result" + ) + if self._result is None: + raise RuntimeError("AdapterStream completed without producing a result") + return self._result + + async def __aiter__(self) -> AsyncIterator[AdapterStreamEvent]: + self._result = None + self._iterated = False + + usage = Usage() + prior_output: str | None = None + final_choice: Choices | None = None + turns = 0 + + while True: + turns += 1 + if turns > MAX_CALLS_PER_TURN: + raise RuntimeError( + f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens." + ) + + turn = self._chat_formatter.next_turn(prior_output) + if turn is None: + break + + for message in turn.messages: + if message.content is None: + raise ValueError("Empty message content isn't allowed") + self._messages.append( + {"role": message.role, "content": message.content} # type: ignore[arg-type] + ) + + skip_response_format = not turn.final_call + turn_top_logprobs = self._top_logprobs if turn.final_call else None + + async for event in self._stream_model_turn( + skip_response_format, turn_top_logprobs + ): + if isinstance(event, _ModelTurnComplete): + usage += event.usage + prior_output = event.assistant_message + final_choice = event.model_choice + else: + yield event + + if not prior_output: + raise RuntimeError("No assistant message/output returned from model") + + logprobs = self._adapter._extract_and_validate_logprobs(final_choice) + + intermediate_outputs = self._chat_formatter.intermediate_outputs() + self._adapter._extract_reasoning_to_intermediate_outputs( + final_choice, intermediate_outputs + ) + + if not isinstance(prior_output, str): + raise RuntimeError(f"assistant message is not a string: {prior_output}") + + trace = self._adapter.all_messages_to_trace(self._messages) + self._result = AdapterStreamResult( + run_output=RunOutput( + output=prior_output, + intermediate_outputs=intermediate_outputs, + output_logprobs=logprobs, + trace=trace, + ), + usage=usage, + ) + 
self._iterated = True + + async def _stream_model_turn( + self, + skip_response_format: bool, + top_logprobs: int | None, + ) -> AsyncIterator[AdapterStreamEvent | _ModelTurnComplete]: + usage = Usage() + tool_calls_count = 0 + + while tool_calls_count < MAX_TOOL_CALLS_PER_TURN: + completion_kwargs = await self._adapter.build_completion_kwargs( + self._provider, + copy.deepcopy(self._messages), + top_logprobs, + skip_response_format, + ) + + stream = StreamingCompletion(**completion_kwargs) + async for chunk in stream: + yield chunk + + response, response_choice = _validate_response(stream.response) + usage += self._adapter.usage_from_response(response) + + content = response_choice.message.content + tool_calls = response_choice.message.tool_calls + if not content and not tool_calls: + raise ValueError( + "Model returned an assistant message, but no content or tool calls. This is not supported." + ) + + self._messages.append(response_choice.message) + + if tool_calls and len(tool_calls) > 0: + async for event in self._handle_tool_calls(tool_calls): + yield event + + assistant_msg = self._extract_task_response(tool_calls) + if assistant_msg is not None: + yield _ModelTurnComplete( + assistant_message=assistant_msg, + model_choice=response_choice, + usage=usage, + ) + return + + tool_calls_count += 1 + continue + + if content: + yield _ModelTurnComplete( + assistant_message=content, + model_choice=response_choice, + usage=usage, + ) + return + + raise RuntimeError( + "Model returned neither content nor tool calls. It must return at least one of these." + ) + + raise RuntimeError( + f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens." 
+ ) + + async def _handle_tool_calls( + self, + tool_calls: list[ChatCompletionMessageToolCall], + ) -> AsyncIterator[AdapterStreamEvent]: + real_tool_calls = [ + tc for tc in tool_calls if tc.function.name != "task_response" + ] + + for tc in real_tool_calls: + try: + parsed_args = json.loads(tc.function.arguments) + except (json.JSONDecodeError, TypeError): + parsed_args = None + + yield ToolCallEvent( + event_type=ToolCallEventType.INPUT_AVAILABLE, + tool_call_id=tc.id, + tool_name=tc.function.name or "unknown", + arguments=parsed_args, + error=( + f"Failed to parse arguments: {tc.function.arguments}" + if parsed_args is None + else None + ), + ) + + _, tool_msgs = await self._adapter.process_tool_calls(tool_calls) + + for tool_msg in tool_msgs: + tc_id = tool_msg["tool_call_id"] + tc_name = _find_tool_name(tool_calls, tc_id) + content = tool_msg["content"] + yield ToolCallEvent( + event_type=ToolCallEventType.OUTPUT_AVAILABLE, + tool_call_id=tc_id, + tool_name=tc_name, + result=str(content) if content is not None else None, + ) + + self._messages.extend(tool_msgs) + + @staticmethod + def _extract_task_response( + tool_calls: list[ChatCompletionMessageToolCall], + ) -> str | None: + for tc in tool_calls: + if tc.function.name == "task_response": + return tc.function.arguments + return None + + +@dataclass +class _ModelTurnComplete: + """Internal sentinel yielded when a model turn finishes.""" + + assistant_message: str + model_choice: Choices | None + usage: Usage + + +def _validate_response( + response: Any, +) -> tuple[ModelResponse, Choices]: + if ( + not isinstance(response, ModelResponse) + or not response.choices + or len(response.choices) == 0 + or not isinstance(response.choices[0], Choices) + ): + raise RuntimeError( + f"Expected ModelResponse with Choices, got {type(response)}." 
+ ) + return response, response.choices[0] + + +def _find_tool_name( + tool_calls: list[ChatCompletionMessageToolCall], tool_call_id: str +) -> str: + for tc in tool_calls: + if tc.id == tool_call_id: + return tc.function.name or "unknown" + return "unknown" diff --git a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py index 6c2863502..e6a301401 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import json +import uuid from abc import ABCMeta, abstractmethod -from collections.abc import Awaitable, Callable from dataclasses import dataclass -from typing import Dict, Tuple +from typing import TYPE_CHECKING, AsyncIterator, Dict, Tuple from litellm.types.utils import ModelResponseStream @@ -16,6 +18,13 @@ StructuredOutputMode, default_structured_output_mode_for_model_provider, ) +from kiln_ai.adapters.model_adapters.adapter_stream import AdapterStreamResult +from kiln_ai.adapters.model_adapters.stream_events import ( + AiSdkEventType, + AiSdkStreamConverter, + AiSdkStreamEvent, + ToolCallEvent, +) from kiln_ai.adapters.parsers.json_parser import parse_json_string from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id from kiln_ai.adapters.parsers.request_formatters import request_formatter_from_id @@ -52,7 +61,8 @@ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error from kiln_ai.utils.open_ai_types import ChatCompletionMessageParam -StreamCallback = Callable[[ModelResponseStream], Awaitable[None]] +if TYPE_CHECKING: + from kiln_ai.adapters.model_adapters.adapter_stream import AdapterStream @dataclass @@ -133,10 +143,9 @@ async def invoke( input: InputType, input_source: DataSource | None = None, existing_run: TaskRun | None = None, - on_chunk: StreamCallback | None = None, ) -> TaskRun: run_output, _ = await 
self.invoke_returning_run_output( - input, input_source, existing_run, on_chunk=on_chunk + input, input_source, existing_run ) return run_output @@ -145,7 +154,6 @@ async def _run_returning_run_output( input: InputType, input_source: DataSource | None = None, existing_run: TaskRun | None = None, - on_chunk: StreamCallback | None = None, ) -> Tuple[TaskRun, RunOutput]: # validate input, allowing arrays if self.input_schema is not None: @@ -173,9 +181,7 @@ async def _run_returning_run_output( formatted_input = formatter.format_input(input) # Run - run_output, usage = await self._run( - formatted_input, prior_trace=prior_trace, on_chunk=on_chunk - ) + run_output, usage = await self._run(formatted_input, prior_trace=prior_trace) # Parse provider = self.model_provider() @@ -265,7 +271,6 @@ async def invoke_returning_run_output( input: InputType, input_source: DataSource | None = None, existing_run: TaskRun | None = None, - on_chunk: StreamCallback | None = None, ) -> Tuple[TaskRun, RunOutput]: # Determine if this is the root agent (no existing run context) is_root_agent = get_agent_run_id() is None @@ -276,7 +281,7 @@ async def invoke_returning_run_output( try: return await self._run_returning_run_output( - input, input_source, existing_run, on_chunk=on_chunk + input, input_source, existing_run ) finally: if is_root_agent: @@ -287,6 +292,210 @@ async def invoke_returning_run_output( finally: clear_agent_run_id() + async def invoke_openai_stream( + self, + input: InputType, + input_source: DataSource | None = None, + existing_run: TaskRun | None = None, + ) -> AsyncIterator[ModelResponseStream]: + """Stream raw OpenAI-protocol chunks for the task execution. + + Yields ``ModelResponseStream`` chunks as they arrive from the model. + After the iterator is exhausted the run has been validated and saved + (when configured). Tool-call rounds happen internally and are not + surfaced; use ``invoke_ai_sdk_stream`` if you need tool-call events. 
+ """ + is_root_agent = get_agent_run_id() is None + if is_root_agent: + set_agent_run_id(generate_agent_run_id()) + + try: + adapter_stream = self._prepare_stream(input, existing_run) + + async for event in adapter_stream: + if isinstance(event, ModelResponseStream): + yield event + + self._finalize_stream(adapter_stream, input, input_source, existing_run) + finally: + if is_root_agent: + try: + run_id = get_agent_run_id() + if run_id: + await MCPSessionManager.shared().cleanup_session(run_id) + finally: + clear_agent_run_id() + + async def invoke_ai_sdk_stream( + self, + input: InputType, + input_source: DataSource | None = None, + existing_run: TaskRun | None = None, + ) -> AsyncIterator[AiSdkStreamEvent]: + """Stream AI SDK protocol events for the task execution. + + Yields ``AiSdkStreamEvent`` instances covering text, reasoning, + tool-call lifecycle, step boundaries, and control events. + """ + is_root_agent = get_agent_run_id() is None + if is_root_agent: + set_agent_run_id(generate_agent_run_id()) + + try: + adapter_stream = self._prepare_stream(input, existing_run) + + message_id = f"msg-{uuid.uuid4().hex}" + converter = AiSdkStreamConverter() + + yield AiSdkStreamEvent(AiSdkEventType.START, {"messageId": message_id}) + + yield AiSdkStreamEvent(AiSdkEventType.START_STEP) + + async for event in adapter_stream: + # ModelResponseStream events come from LiteLLM's own OpenAI compatible streaming + if isinstance(event, ModelResponseStream): + for ai_event in converter.convert_chunk(event): + yield ai_event + # ToolCallEvent events come from ourselves and are emitted on rounds of toolcalls + elif isinstance(event, ToolCallEvent): + for ai_event in converter.convert_tool_event(event): + yield ai_event + + for ai_event in converter.finalize(): + yield ai_event + + yield AiSdkStreamEvent(AiSdkEventType.FINISH_STEP) + + self._finalize_stream(adapter_stream, input, input_source, existing_run) + finally: + if is_root_agent: + try: + run_id = get_agent_run_id() + if 
run_id: + await MCPSessionManager.shared().cleanup_session(run_id) + finally: + clear_agent_run_id() + + def _prepare_stream( + self, + input: InputType, + existing_run: TaskRun | None, + ) -> AdapterStream: + if self.input_schema is not None: + validate_schema_with_value_error( + input, + self.input_schema, + "This task requires a specific input schema. While the model produced JSON, that JSON didn't meet the schema. Search 'Troubleshooting Structured Data Issues' in our docs for more information.", + require_object=False, + ) + + if existing_run is not None and ( + not existing_run.trace or len(existing_run.trace) == 0 + ): + raise ValueError( + "Run has no trace. Cannot continue session without conversation history." + ) + + prior_trace = existing_run.trace if existing_run else None + + formatted_input = input + formatter_id = self.model_provider().formatter + if formatter_id is not None: + formatter = request_formatter_from_id(formatter_id) + formatted_input = formatter.format_input(input) + + return self._create_run_stream(formatted_input, prior_trace) + + def _finalize_stream( + self, + adapter_stream: AdapterStream, + input: InputType, + input_source: DataSource | None, + existing_run: TaskRun | None, + ) -> TaskRun: + """Streaming invocations are only concerned with passing through events as they come in. + At the end of the stream, we still need to validate the output, create a run and everything + else that a non-streaming invocation would do. 
+ """ + + result: AdapterStreamResult = adapter_stream.result + run_output = result.run_output + usage = result.usage + + provider = self.model_provider() + parser = model_parser_from_id(provider.parser) + parsed_output = parser.parse_output(original_output=run_output) + + if self.output_schema is not None: + if isinstance(parsed_output.output, str): + parsed_output.output = parse_json_string(parsed_output.output) + if not isinstance(parsed_output.output, dict): + raise RuntimeError( + f"structured response is not a dict: {parsed_output.output}" + ) + validate_schema_with_value_error( + parsed_output.output, + self.output_schema, + "This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema. Search 'Troubleshooting Structured Data Issues' in our docs for more information.", + ) + else: + if not isinstance(parsed_output.output, str): + raise RuntimeError( + f"response is not a string for non-structured task: {parsed_output.output}" + ) + + trace_has_toolcalls = parsed_output.trace is not None and any( + message.get("role", None) == "tool" for message in parsed_output.trace + ) + if ( + provider.reasoning_capable + and ( + not parsed_output.intermediate_outputs + or "reasoning" not in parsed_output.intermediate_outputs + ) + and not ( + provider.reasoning_optional_for_structured_output + and self.has_structured_output() + ) + and not trace_has_toolcalls + ): + raise RuntimeError( + "Reasoning is required for this model, but no reasoning was returned." 
+ ) + + if existing_run is not None: + merged_output = RunOutput( + output=parsed_output.output, + intermediate_outputs=parsed_output.intermediate_outputs + or run_output.intermediate_outputs, + output_logprobs=parsed_output.output_logprobs + or run_output.output_logprobs, + trace=run_output.trace, + ) + run = self.generate_run( + input, + input_source, + merged_output, + usage, + run_output.trace, + existing_run=existing_run, + ) + else: + run = self.generate_run( + input, input_source, parsed_output, usage, run_output.trace + ) + + if ( + self.base_adapter_config.allow_saving + and Config.shared().autosave_runs + and self.task.path is not None + ): + run.save_to_file() + elif existing_run is None: + run.id = None + + return run + def has_structured_output(self) -> bool: return self.output_schema is not None @@ -299,10 +508,17 @@ async def _run( self, input: InputType, prior_trace: list[ChatCompletionMessageParam] | None = None, - on_chunk: StreamCallback | None = None, ) -> Tuple[RunOutput, Usage | None]: pass + def _create_run_stream( + self, + input: InputType, + prior_trace: list[ChatCompletionMessageParam] | None = None, + ) -> AdapterStream: + """Create a stream for the adapter. 
Implementations must override this method to support streaming.""" + raise NotImplementedError("Streaming is not supported for this adapter type") + def build_prompt(self) -> str: if self.prompt_builder is None: raise ValueError("Prompt builder is not available for MCP run config") diff --git a/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py index 344bedec0..cebccdf57 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from typing import Any, Dict, List, Tuple +import litellm from litellm.types.utils import ( ChatCompletionMessageToolCall, ChoiceLogprobs, @@ -19,17 +20,16 @@ import kiln_ai.datamodel as datamodel from kiln_ai.adapters.chat import ChatCompletionMessageIncludingLiteLLM -from kiln_ai.adapters.litellm_utils.litellm_streaming import StreamingCompletion from kiln_ai.adapters.ml_model_list import ( KilnModelProvider, ModelProviderName, StructuredOutputMode, ) +from kiln_ai.adapters.model_adapters.adapter_stream import AdapterStream from kiln_ai.adapters.model_adapters.base_adapter import ( AdapterConfig, BaseAdapter, RunOutput, - StreamCallback, Usage, ) from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig @@ -96,7 +96,6 @@ async def _run_model_turn( prior_messages: list[ChatCompletionMessageIncludingLiteLLM], top_logprobs: int | None, skip_response_format: bool, - on_chunk: StreamCallback | None = None, ) -> ModelTurnResult: """ Call the model for a single top level turn: from user message to agent message. 
@@ -120,7 +119,7 @@ async def _run_model_turn( # Make the completion call model_response, response_choice = await self.acompletion_checking_response( - on_chunk=on_chunk, **completion_kwargs + **completion_kwargs ) # count the usage @@ -187,7 +186,6 @@ async def _run( self, input: InputType, prior_trace: list[ChatCompletionMessageParam] | None = None, - on_chunk: StreamCallback | None = None, ) -> tuple[RunOutput, Usage | None]: usage = Usage() @@ -232,7 +230,6 @@ async def _run( messages, self.base_adapter_config.top_logprobs if turn.final_call else None, skip_response_format, - on_chunk=on_chunk, ) usage += turn_result.usage @@ -265,6 +262,28 @@ async def _run( return output, usage + def _create_run_stream( + self, + input: InputType, + prior_trace: list[ChatCompletionMessageParam] | None = None, + ) -> AdapterStream: + provider = self.model_provider() + if not provider.model_id: + raise ValueError("Model ID is required for OpenAI compatible models") + + chat_formatter = self.build_chat_formatter(input, prior_trace) + initial_messages: list[ChatCompletionMessageIncludingLiteLLM] = copy.deepcopy( + chat_formatter.initial_messages() + ) + + return AdapterStream( + adapter=self, + provider=provider, + chat_formatter=chat_formatter, + initial_messages=initial_messages, + top_logprobs=self.base_adapter_config.top_logprobs, + ) + def _extract_and_validate_logprobs( self, final_choice: Choices | None ) -> ChoiceLogprobs | None: @@ -301,13 +320,9 @@ def _extract_reasoning_to_intermediate_outputs( intermediate_outputs["reasoning"] = stripped_reasoning_content async def acompletion_checking_response( - self, on_chunk: StreamCallback | None = None, **kwargs + self, **kwargs: Any ) -> Tuple[ModelResponse, Choices]: - stream = StreamingCompletion(**kwargs) - async for chunk in stream: - if on_chunk is not None: - await on_chunk(chunk) - response = stream.response + response = await litellm.acompletion(**kwargs) if ( not isinstance(response, ModelResponse) diff --git 
a/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py index 236b63a9e..6055d96d9 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py @@ -4,7 +4,6 @@ from kiln_ai.adapters.model_adapters.base_adapter import ( AdapterConfig, BaseAdapter, - StreamCallback, ) from kiln_ai.adapters.parsers.json_parser import parse_json_string from kiln_ai.adapters.run_output import RunOutput @@ -50,7 +49,6 @@ async def _run( self, input: InputType, prior_trace: list[ChatCompletionMessageParam] | None = None, - on_chunk: StreamCallback | None = None, ) -> Tuple[RunOutput, Usage | None]: if prior_trace is not None: raise NotImplementedError( @@ -91,7 +89,6 @@ async def invoke( input: InputType, input_source: DataSource | None = None, existing_run: TaskRun | None = None, - on_chunk: StreamCallback | None = None, ) -> TaskRun: if existing_run is not None: raise NotImplementedError( @@ -100,7 +97,7 @@ async def invoke( ) run_output, _ = await self.invoke_returning_run_output( - input, input_source, existing_run, on_chunk=on_chunk + input, input_source, existing_run ) return run_output @@ -109,7 +106,6 @@ async def invoke_returning_run_output( input: InputType, input_source: DataSource | None = None, existing_run: TaskRun | None = None, - on_chunk: StreamCallback | None = None, ) -> Tuple[TaskRun, RunOutput]: """ Runs the task and returns both the persisted TaskRun and raw RunOutput. 
diff --git a/libs/core/kiln_ai/adapters/model_adapters/stream_events.py b/libs/core/kiln_ai/adapters/model_adapters/stream_events.py new file mode 100644 index 000000000..a60c03260 --- /dev/null +++ b/libs/core/kiln_ai/adapters/model_adapters/stream_events.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +import json +import uuid +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from litellm.types.utils import ModelResponseStream + + +class AiSdkEventType(str, Enum): + START = "start" + FINISH = "finish" + ERROR = "error" + ABORT = "abort" + + TEXT_START = "text-start" + TEXT_DELTA = "text-delta" + TEXT_END = "text-end" + + REASONING_START = "reasoning-start" + REASONING_DELTA = "reasoning-delta" + REASONING_END = "reasoning-end" + + TOOL_INPUT_START = "tool-input-start" + TOOL_INPUT_DELTA = "tool-input-delta" + TOOL_INPUT_AVAILABLE = "tool-input-available" + TOOL_INPUT_ERROR = "tool-input-error" + + TOOL_OUTPUT_AVAILABLE = "tool-output-available" + TOOL_OUTPUT_ERROR = "tool-output-error" + + START_STEP = "start-step" + FINISH_STEP = "finish-step" + + METADATA = "metadata" + SOURCE_URL = "source-url" + SOURCE_DOCUMENT = "source-document" + FILE = "file" + + +@dataclass +class AiSdkStreamEvent: + type: AiSdkEventType + payload: dict[str, Any] = field(default_factory=dict) + + def to_sse(self) -> str: + data = {"type": self.type.value, **self.payload} + return f"data: {json.dumps(data, separators=(',', ':'))}\n\n" + + +class ToolCallEventType(str, Enum): + INPUT_AVAILABLE = "input_available" + OUTPUT_AVAILABLE = "output_available" + OUTPUT_ERROR = "output_error" + + +@dataclass +class ToolCallEvent: + event_type: ToolCallEventType + tool_call_id: str + tool_name: str + arguments: dict[str, Any] | None = None + result: str | None = None + error: str | None = None + + +AdapterStreamEvent = ModelResponseStream | ToolCallEvent + + +class AiSdkStreamConverter: + """Stateful converter from OpenAI streaming chunks to AI 
SDK events.""" + + def __init__(self) -> None: + self._text_started = False + self._text_id = f"text-{uuid.uuid4().hex[:12]}" + self._reasoning_started = False + self._reasoning_id = f"reasoning-{uuid.uuid4().hex[:12]}" + self._reasoning_block_count = 0 + self._tool_calls_state: dict[int, dict[str, Any]] = {} + self._finish_reason: str | None = None + self._usage_data: Any = None + + def convert_chunk(self, chunk: ModelResponseStream) -> list[AiSdkStreamEvent]: + events: list[AiSdkStreamEvent] = [] + + for choice in chunk.choices: + if choice.finish_reason is not None: + self._finish_reason = choice.finish_reason + + delta = choice.delta + if delta is None: + continue + + reasoning_content = getattr(delta, "reasoning_content", None) + if reasoning_content is not None: + if not self._reasoning_started: + self._reasoning_block_count += 1 + self._reasoning_id = f"reasoning-{uuid.uuid4().hex[:12]}" + events.append( + AiSdkStreamEvent( + AiSdkEventType.REASONING_START, + {"id": self._reasoning_id}, + ) + ) + self._reasoning_started = True + events.append( + AiSdkStreamEvent( + AiSdkEventType.REASONING_DELTA, + {"id": self._reasoning_id, "delta": reasoning_content}, + ) + ) + + if delta.content: + if self._reasoning_started: + events.append( + AiSdkStreamEvent( + AiSdkEventType.REASONING_END, + {"id": self._reasoning_id}, + ) + ) + self._reasoning_started = False + + if not self._text_started: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TEXT_START, + {"id": self._text_id}, + ) + ) + self._text_started = True + events.append( + AiSdkStreamEvent( + AiSdkEventType.TEXT_DELTA, + {"id": self._text_id, "delta": delta.content}, + ) + ) + + if delta.tool_calls: + if self._reasoning_started: + events.append( + AiSdkStreamEvent( + AiSdkEventType.REASONING_END, + {"id": self._reasoning_id}, + ) + ) + self._reasoning_started = False + + for tc_delta in delta.tool_calls: + idx = tc_delta.index + tc_state = self._tool_calls_state.setdefault( + idx, + { + "id": None, + 
"name": None, + "arguments": "", + "started": False, + }, + ) + + if tc_delta.id is not None: + tc_state["id"] = tc_delta.id + + func = getattr(tc_delta, "function", None) + if func is not None: + if func.name is not None: + tc_state["name"] = func.name + if func.arguments: + tc_state["arguments"] += func.arguments + + if tc_state["id"] and tc_state["name"] and not tc_state["started"]: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TOOL_INPUT_START, + { + "toolCallId": tc_state["id"], + "toolName": tc_state["name"], + }, + ) + ) + tc_state["started"] = True + + if func and func.arguments and tc_state["id"]: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TOOL_INPUT_DELTA, + { + "toolCallId": tc_state["id"], + "inputTextDelta": func.arguments, + }, + ) + ) + + if not chunk.choices: + usage = getattr(chunk, "usage", None) + if usage is not None: + self._usage_data = usage + + return events + + def convert_tool_event(self, event: ToolCallEvent) -> list[AiSdkStreamEvent]: + events: list[AiSdkStreamEvent] = [] + + if event.event_type == ToolCallEventType.INPUT_AVAILABLE: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TOOL_INPUT_AVAILABLE, + { + "toolCallId": event.tool_call_id, + "toolName": event.tool_name, + "input": event.arguments or {}, + }, + ) + ) + elif event.event_type == ToolCallEventType.OUTPUT_AVAILABLE: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TOOL_OUTPUT_AVAILABLE, + { + "toolCallId": event.tool_call_id, + "output": event.result, + }, + ) + ) + elif event.event_type == ToolCallEventType.OUTPUT_ERROR: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TOOL_OUTPUT_ERROR, + { + "toolCallId": event.tool_call_id, + "errorText": event.error or "Unknown error", + }, + ) + ) + + return events + + def finalize(self) -> list[AiSdkStreamEvent]: + events: list[AiSdkStreamEvent] = [] + + if self._reasoning_started: + events.append( + AiSdkStreamEvent( + AiSdkEventType.REASONING_END, + {"id": self._reasoning_id}, + ) + ) + 
self._reasoning_started = False + + if self._text_started: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TEXT_END, + {"id": self._text_id}, + ) + ) + self._text_started = False + + finish_payload: dict[str, Any] = {} + if self._finish_reason is not None: + finish_payload["finishReason"] = self._finish_reason.replace("_", "-") + + if self._usage_data is not None: + usage_payload: dict[str, Any] = { + "promptTokens": self._usage_data.prompt_tokens, + "completionTokens": self._usage_data.completion_tokens, + } + total = getattr(self._usage_data, "total_tokens", None) + if total is not None: + usage_payload["totalTokens"] = total + finish_payload["usage"] = usage_payload + + if finish_payload: + events.append( + AiSdkStreamEvent( + AiSdkEventType.FINISH, + {"messageMetadata": finish_payload}, + ) + ) + else: + events.append(AiSdkStreamEvent(AiSdkEventType.FINISH)) + + return events + + def reset_for_next_step(self) -> None: + """Reset per-step state between LLM calls in a multi-step flow.""" + self._tool_calls_state = {} + self._finish_reason = None diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_adapter_stream.py b/libs/core/kiln_ai/adapters/model_adapters/test_adapter_stream.py new file mode 100644 index 000000000..25715645c --- /dev/null +++ b/libs/core/kiln_ai/adapters/model_adapters/test_adapter_stream.py @@ -0,0 +1,372 @@ +import json +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from litellm.types.utils import ( + ChatCompletionMessageToolCall, + Choices, + Delta, + Function, + ModelResponse, + ModelResponseStream, + StreamingChoices, +) +from litellm.types.utils import Message as LiteLLMMessage + +from kiln_ai.adapters.chat import ChatFormatter +from kiln_ai.adapters.model_adapters.adapter_stream import AdapterStream +from kiln_ai.adapters.model_adapters.stream_events import ( + ToolCallEvent, + ToolCallEventType, +) +from kiln_ai.datamodel import Usage + + +def _make_streaming_chunk( + 
content: str | None = None, + finish_reason: str | None = None, +) -> ModelResponseStream: + delta = Delta(content=content) + choice = StreamingChoices( + index=0, + delta=delta, + finish_reason=finish_reason, + ) + return ModelResponseStream(id="test-stream", choices=[choice]) + + +def _make_model_response( + content: str = "Hello", + tool_calls: list[ChatCompletionMessageToolCall] | None = None, +) -> ModelResponse: + message = LiteLLMMessage(content=content, role="assistant") + if tool_calls is not None: + message.tool_calls = tool_calls + choice = Choices( + index=0, + message=message, + finish_reason="stop" if tool_calls is None else "tool_calls", + ) + return ModelResponse(id="test-response", choices=[choice]) + + +def _make_tool_call( + call_id: str = "call_1", + name: str = "add", + arguments: dict[str, Any] | None = None, +) -> ChatCompletionMessageToolCall: + args = json.dumps(arguments or {"a": 1, "b": 2}) + return ChatCompletionMessageToolCall( + id=call_id, + type="function", + function=Function(name=name, arguments=args), + ) + + +class FakeChatFormatter(ChatFormatter): + """A simple chat formatter that returns a single turn then None.""" + + def __init__(self, num_turns: int = 1): + self._turn_count = 0 + self._num_turns = num_turns + + def next_turn(self, prior_output: str | None): + if self._turn_count >= self._num_turns: + return None + self._turn_count += 1 + turn = MagicMock() + turn.messages = [MagicMock(role="user", content="test input")] + turn.final_call = self._turn_count == self._num_turns + return turn + + def intermediate_outputs(self): + return {} + + +class FakeStreamingCompletion: + """Mocks StreamingCompletion: yields chunks, then exposes .response""" + + def __init__( + self, + model_response: ModelResponse, + chunks: list[ModelResponseStream] | None = None, + ): + self._chunks = chunks or [ + _make_streaming_chunk(content="Hel"), + _make_streaming_chunk(content="lo"), + _make_streaming_chunk(finish_reason="stop"), + ] + 
self._response = model_response + + @property + def response(self): + return self._response + + async def __aiter__(self): + for chunk in self._chunks: + yield chunk + + +@pytest.fixture +def mock_adapter(): + adapter = MagicMock() + adapter.build_completion_kwargs = AsyncMock(return_value={"model": "test"}) + adapter.usage_from_response = MagicMock(return_value=Usage()) + adapter.process_tool_calls = AsyncMock(return_value=(None, [])) + adapter._extract_and_validate_logprobs = MagicMock(return_value=None) + adapter._extract_reasoning_to_intermediate_outputs = MagicMock() + adapter.all_messages_to_trace = MagicMock(return_value=[]) + adapter.base_adapter_config = MagicMock() + adapter.base_adapter_config.top_logprobs = None + return adapter + + +@pytest.fixture +def mock_provider(): + provider = MagicMock() + provider.model_id = "test-model" + return provider + + +class TestAdapterStreamSimple: + @pytest.mark.asyncio + async def test_simple_content_response(self, mock_adapter, mock_provider): + response = _make_model_response(content="Hello world") + fake_stream = FakeStreamingCompletion(response) + formatter = FakeChatFormatter() + + with patch( + "kiln_ai.adapters.model_adapters.adapter_stream.StreamingCompletion", + return_value=fake_stream, + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=formatter, + initial_messages=[], + top_logprobs=None, + ) + + events = [] + async for event in stream: + events.append(event) + + chunks = [e for e in events if isinstance(e, ModelResponseStream)] + assert len(chunks) == 3 + + result = stream.result + assert result.run_output.output == "Hello world" + + @pytest.mark.asyncio + async def test_result_not_available_before_iteration( + self, mock_adapter, mock_provider + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=FakeChatFormatter(), + initial_messages=[], + top_logprobs=None, + ) + with pytest.raises(RuntimeError, match="not been 
iterated"): + _ = stream.result + + +class TestAdapterStreamToolCalls: + @pytest.mark.asyncio + async def test_tool_call_yields_events(self, mock_adapter, mock_provider): + tool_call = _make_tool_call( + call_id="call_1", name="add", arguments={"a": 1, "b": 2} + ) + tool_response = _make_model_response(content=None, tool_calls=[tool_call]) + final_response = _make_model_response(content="The answer is 3") + + tool_stream = FakeStreamingCompletion( + tool_response, + [_make_streaming_chunk(finish_reason="tool_calls")], + ) + final_stream = FakeStreamingCompletion( + final_response, + [ + _make_streaming_chunk(content="The answer is 3"), + _make_streaming_chunk(finish_reason="stop"), + ], + ) + + streams_iter = iter([tool_stream, final_stream]) + + mock_adapter.process_tool_calls = AsyncMock( + return_value=( + None, + [{"role": "tool", "tool_call_id": "call_1", "content": "3"}], + ) + ) + + with patch( + "kiln_ai.adapters.model_adapters.adapter_stream.StreamingCompletion", + side_effect=lambda **kw: next(streams_iter), + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=FakeChatFormatter(), + initial_messages=[], + top_logprobs=None, + ) + + events = [] + async for event in stream: + events.append(event) + + tool_events = [e for e in events if isinstance(e, ToolCallEvent)] + assert len(tool_events) == 2 + + input_event = next( + e for e in tool_events if e.event_type == ToolCallEventType.INPUT_AVAILABLE + ) + assert input_event.tool_call_id == "call_1" + assert input_event.tool_name == "add" + assert input_event.arguments == {"a": 1, "b": 2} + + output_event = next( + e for e in tool_events if e.event_type == ToolCallEventType.OUTPUT_AVAILABLE + ) + assert output_event.tool_call_id == "call_1" + assert output_event.result == "3" + + assert stream.result.run_output.output == "The answer is 3" + + @pytest.mark.asyncio + async def test_task_response_tool_call(self, mock_adapter, mock_provider): + task_response_call = 
_make_tool_call( + call_id="call_tr", name="task_response", arguments={"result": "42"} + ) + response = _make_model_response(content=None, tool_calls=[task_response_call]) + + fake_stream = FakeStreamingCompletion( + response, + [_make_streaming_chunk(finish_reason="tool_calls")], + ) + + mock_adapter.process_tool_calls = AsyncMock( + return_value=('{"result": "42"}', []) + ) + + with patch( + "kiln_ai.adapters.model_adapters.adapter_stream.StreamingCompletion", + return_value=fake_stream, + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=FakeChatFormatter(), + initial_messages=[], + top_logprobs=None, + ) + + events = [] + async for event in stream: + events.append(event) + + tool_events = [e for e in events if isinstance(e, ToolCallEvent)] + assert len(tool_events) == 0 + + assert stream.result.run_output.output == '{"result": "42"}' + + @pytest.mark.asyncio + async def test_too_many_tool_calls_raises(self, mock_adapter, mock_provider): + tool_call = _make_tool_call() + response = _make_model_response(content=None, tool_calls=[tool_call]) + + mock_adapter.process_tool_calls = AsyncMock( + return_value=( + None, + [{"role": "tool", "tool_call_id": "call_1", "content": "ok"}], + ) + ) + + def make_stream(**kw): + return FakeStreamingCompletion( + response, + [_make_streaming_chunk(finish_reason="tool_calls")], + ) + + with ( + patch( + "kiln_ai.adapters.model_adapters.adapter_stream.StreamingCompletion", + side_effect=make_stream, + ), + patch( + "kiln_ai.adapters.model_adapters.adapter_stream.MAX_TOOL_CALLS_PER_TURN", + 2, + ), + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=FakeChatFormatter(), + initial_messages=[], + top_logprobs=None, + ) + + with pytest.raises(RuntimeError, match="Too many tool calls"): + async for _ in stream: + pass + + @pytest.mark.asyncio + async def test_unparseable_tool_call_arguments(self, mock_adapter, mock_provider): + bad_tool_call = 
ChatCompletionMessageToolCall( + id="call_bad", + type="function", + function=Function(name="add", arguments="not json"), + ) + response = _make_model_response(content=None, tool_calls=[bad_tool_call]) + final_response = _make_model_response(content="fallback") + + tool_stream = FakeStreamingCompletion( + response, + [_make_streaming_chunk(finish_reason="tool_calls")], + ) + final_stream = FakeStreamingCompletion( + final_response, + [ + _make_streaming_chunk(content="fallback"), + _make_streaming_chunk(finish_reason="stop"), + ], + ) + + streams_iter = iter([tool_stream, final_stream]) + + mock_adapter.process_tool_calls = AsyncMock( + return_value=( + None, + [{"role": "tool", "tool_call_id": "call_bad", "content": "error"}], + ) + ) + + with patch( + "kiln_ai.adapters.model_adapters.adapter_stream.StreamingCompletion", + side_effect=lambda **kw: next(streams_iter), + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=FakeChatFormatter(), + initial_messages=[], + top_logprobs=None, + ) + + events = [] + async for event in stream: + events.append(event) + + input_events = [ + e + for e in events + if isinstance(e, ToolCallEvent) + and e.event_type == ToolCallEventType.INPUT_AVAILABLE + ] + assert len(input_events) == 1 + assert input_events[0].arguments is None + assert "Failed to parse" in (input_events[0].error or "") diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py index e6f1cca82..0e8e51d95 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py @@ -1047,8 +1047,8 @@ async def mock_run(input, **kwargs): assert run_id.startswith("run_") -class TestStreamCallback: - """Tests for the on_chunk streaming callback parameter.""" +class TestStreamMethods: + """Tests for the streaming methods on BaseAdapter.""" @pytest.fixture def 
stream_adapter(self, base_task): @@ -1062,105 +1062,28 @@ def stream_adapter(self, base_task): ), ) - def _setup_adapter_mocks(self, adapter): + @pytest.mark.asyncio + async def test_invoke_openai_stream_raises_for_unsupported_adapter( + self, stream_adapter + ): + """MockAdapter does not implement _create_run_stream.""" provider = MagicMock() - provider.parser = "test_parser" provider.formatter = None - provider.reasoning_capable = False - adapter.model_provider = MagicMock(return_value=provider) - - @pytest.mark.asyncio - async def test_on_chunk_forwarded_to_run(self, stream_adapter): - """Test that on_chunk is passed through to _run.""" - received_kwargs = {} - - async def mock_run(input, **kwargs): - received_kwargs.update(kwargs) - return RunOutput(output="test output", intermediate_outputs={}), None + stream_adapter.model_provider = MagicMock(return_value=provider) - stream_adapter._run = mock_run - self._setup_adapter_mocks(stream_adapter) - - callback = AsyncMock() - - parser = MagicMock() - parser.parse_output.return_value = RunOutput( - output="test output", intermediate_outputs={} - ) - - with ( - patch( - "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id" - ) as mock_parser_factory, - patch( - "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id" - ), - ): - mock_parser_factory.return_value = parser - await stream_adapter.invoke_returning_run_output( - {"test": "input"}, on_chunk=callback - ) - - assert received_kwargs.get("on_chunk") is callback + with pytest.raises(NotImplementedError, match="Streaming is not supported"): + async for _chunk in stream_adapter.invoke_openai_stream("test input"): + pass @pytest.mark.asyncio - async def test_on_chunk_none_by_default(self, stream_adapter): - """Test that on_chunk defaults to None when not provided.""" - received_kwargs = {} - - async def mock_run(input, **kwargs): - received_kwargs.update(kwargs) - return RunOutput(output="test output", intermediate_outputs={}), None - - 
stream_adapter._run = mock_run - self._setup_adapter_mocks(stream_adapter) - - parser = MagicMock() - parser.parse_output.return_value = RunOutput( - output="test output", intermediate_outputs={} - ) - - with ( - patch( - "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id" - ) as mock_parser_factory, - patch( - "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id" - ), - ): - mock_parser_factory.return_value = parser - await stream_adapter.invoke_returning_run_output({"test": "input"}) - - assert received_kwargs.get("on_chunk") is None - - @pytest.mark.asyncio - async def test_invoke_forwards_on_chunk(self, stream_adapter): - """Test that invoke() also forwards on_chunk.""" - received_kwargs = {} - - async def mock_run(input, **kwargs): - received_kwargs.update(kwargs) - return RunOutput(output="test output", intermediate_outputs={}), None - - stream_adapter._run = mock_run - self._setup_adapter_mocks(stream_adapter) - - callback = AsyncMock() - - parser = MagicMock() - parser.parse_output.return_value = RunOutput( - output="test output", intermediate_outputs={} - ) - - with ( - patch( - "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id" - ) as mock_parser_factory, - patch( - "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id" - ), - ): - mock_parser_factory.return_value = parser - await stream_adapter.invoke({"test": "input"}, on_chunk=callback) + async def test_invoke_ai_sdk_stream_raises_for_unsupported_adapter( + self, stream_adapter + ): + """MockAdapter does not implement _create_run_stream.""" + provider = MagicMock() + provider.formatter = None + stream_adapter.model_provider = MagicMock(return_value=provider) - assert received_kwargs.get("on_chunk") is callback + with pytest.raises(NotImplementedError, match="Streaming is not supported"): + async for _event in stream_adapter.invoke_ai_sdk_stream("test input"): + pass diff --git 
a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py index b47d2d46e..197b1f500 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py @@ -1344,7 +1344,7 @@ def capturing_build(input, prior_trace_arg=None): adapter.build_chat_formatter = capturing_build async def mock_run_model_turn( - provider, prior_messages, top_logprobs, skip_response_format, on_chunk=None + provider, prior_messages, top_logprobs, skip_response_format ): extended = list(prior_messages) extended.append({"role": "assistant", "content": "How can I help?"}) diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py index 62e16ce2a..43b2aaff4 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py @@ -1,9 +1,9 @@ import json import logging -from abc import ABC, abstractmethod +import re +from datetime import datetime, timezone from pathlib import Path -from typing import Any, Callable, Tuple -from unittest.mock import patch +from typing import Any, Callable import litellm import pytest @@ -12,99 +12,70 @@ from kiln_ai.adapters.ml_model_list import ModelProviderName, StructuredOutputMode from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig +from kiln_ai.adapters.model_adapters.stream_events import ( + AiSdkEventType, + AiSdkStreamEvent, +) from kiln_ai.datamodel import Project, PromptGenerators, Task from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties, ToolsRunConfig from kiln_ai.datamodel.tool_id import KilnBuiltInToolId logger = logging.getLogger(__name__) - -class ChunkRendererAbstract(ABC): 
- @abstractmethod - async def render_chunk(self, chunk: litellm.ModelResponseStream): - pass - - @abstractmethod - def get_stream_text(self) -> str: - pass - - -class ChunkRenderer(ChunkRendererAbstract): - def __init__(self): - self.chunk_texts: list[str] = [] - self.current_block_type: str | None = None - - def print_and_append(self, text: str): - # replace with print if your logger is not outputting info logs - logger.info(text) - self.chunk_texts.append(text) - - def enter_block(self, block_type: str): - if self.current_block_type != block_type: - if self.current_block_type is not None: - self.print_and_append(f"\n") - - self.print_and_append(f"\n<{block_type}>\n") - self.current_block_type = block_type - - def render_reasoning(self, reasoning_content: str): - self.enter_block("reasoning") - self.print_and_append(reasoning_content) - - def render_content(self, content: str): - self.enter_block("content") - self.print_and_append(content) - - def render_tool_call(self, tool_calls: list[ChatCompletionDeltaToolCall | Any]): - self.enter_block("tool_call") - for tool_call in tool_calls: - # first it says the tool name, then the arguments - if tool_call.function.name is not None: - self.print_and_append(f'Calling tool: "{tool_call.function.name}" ') - self.print_and_append("with args: ") - if tool_call.function.arguments is not None: - args = tool_call.function.arguments - self.print_and_append(args) - - def render_stop(self, stop_reason: str): - self.print_and_append("\n") - - def render_unknown(self, chunk: litellm.ModelResponseStream): - self.enter_block("unknown") - self.print_and_append(f"Unknown chunk: {chunk}") - - async def render_chunk(self, chunk: litellm.ModelResponseStream): - if chunk.choices[0].finish_reason is not None: - self.render_stop(chunk.choices[0].finish_reason) - return - elif chunk.choices[0].delta is not None: - # inconsistent behavior between providers, some have multiple fields at once, some don't - if chunk.choices[0].delta.tool_calls is 
not None: - self.render_tool_call(chunk.choices[0].delta.tool_calls) - elif getattr(chunk.choices[0].delta, "reasoning_content", None) is not None: - text = getattr(chunk.choices[0].delta, "reasoning_content", None) - if text is not None: - self.render_reasoning(text) - elif chunk.choices[0].delta.content is not None: - self.render_content(chunk.choices[0].delta.content) - else: - self.render_unknown(chunk) - - def get_stream_text(self) -> str: - return "".join(self.chunk_texts) - - -class ChunkRawRenderer(ChunkRendererAbstract): - def __init__(self): - self.chunks: list[litellm.ModelResponseStream] = [] - self.current_block_type: str | None = None - - async def render_chunk(self, chunk: litellm.ModelResponseStream): - logger.info(str(chunk)) - self.chunks.append(chunk) - - def get_stream_text(self) -> str: - return "\n".join([str(chunk) for chunk in self.chunks]) +STREAMING_MODELS = [ + ("claude_sonnet_4_5", ModelProviderName.openrouter), + ("claude_sonnet_4_5", ModelProviderName.anthropic), + ("claude_sonnet_4_6", ModelProviderName.openrouter), + ("claude_sonnet_4_6", ModelProviderName.anthropic), + ("claude_opus_4_5", ModelProviderName.openrouter), + ("claude_opus_4_5", ModelProviderName.anthropic), + ("claude_opus_4_6", ModelProviderName.openrouter), + ("claude_opus_4_6", ModelProviderName.anthropic), + ("minimax_m2_5", ModelProviderName.openrouter), + ("claude_4_5_haiku", ModelProviderName.openrouter), + ("claude_4_5_haiku", ModelProviderName.anthropic), +] + +STREAMING_MODELS_NO_HAIKU = [m for m in STREAMING_MODELS if "haiku" not in m[0]] + +PAID_TEST_OUTPUT_DIR = Path(__file__).resolve().parents[5] / "test_output" + + +def _serialize_for_dump(obj: Any) -> Any: + if hasattr(obj, "model_dump"): + return obj.model_dump(mode="json") + if isinstance(obj, list): + if not obj: + return [] + first = obj[0] + if hasattr(first, "type") and hasattr(first, "payload"): + return [{"type": e.type.value, "payload": e.payload} for e in obj] + if hasattr(first, "model_dump"): 
+ return [item.model_dump(mode="json") for item in obj] + return [_serialize_for_dump(x) for x in obj] + return obj + + +def _dump_paid_test_output(request: pytest.FixtureRequest, **payloads: Any) -> Path: + test_name = re.sub(r"[^\w\-]", "_", request.node.name) + param_id = "default" + if hasattr(request.node, "callspec") and request.node.callspec is not None: + id_attr = getattr(request.node.callspec, "id", None) + if id_attr is not None: + param_id = re.sub(r"[^\w\-]", "_", str(id_attr)) + timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S") + out_dir = PAID_TEST_OUTPUT_DIR / test_name / param_id / timestamp + out_dir.mkdir(parents=True, exist_ok=True) + for filename, data in payloads.items(): + if data is None: + continue + if not filename.endswith(".json"): + filename = f"{filename}.json" + serialized = _serialize_for_dump(data) + (out_dir / filename).write_text( + json.dumps(serialized, indent=2, default=str), encoding="utf-8" + ) + return out_dir @pytest.fixture @@ -129,7 +100,7 @@ def adapter_factory(task: Task) -> Callable[[str, ModelProviderName], LiteLlmAda def create_adapter( model_id: str, provider_name: ModelProviderName ) -> LiteLlmAdapter: - adapter = LiteLlmAdapter( + return LiteLlmAdapter( kiln_task=task, config=LiteLlmConfig( run_config_properties=KilnAgentRunConfigProperties(
ModelProviderName.openrouter), - ("claude_4_5_haiku", ModelProviderName.anthropic), - ], -) -async def test_acompletion_streaming_response( +@pytest.mark.parametrize("model_id,provider_name", STREAMING_MODELS) +async def test_invoke_openai_stream_chunks( + request: pytest.FixtureRequest, model_id: str, provider_name: ModelProviderName, adapter_factory: Callable[[str, ModelProviderName], LiteLlmAdapter], ): - """Check the accumulated response has all the expected parts""" - adapter = adapter_factory(model_id, provider_name) - - renderer = ChunkRenderer() - - # we proxy all the calls to the original function so we can spy on the return values - captured_responses: list[Tuple[litellm.ModelResponse, litellm.Choices]] = [] - origin_func = adapter.acompletion_checking_response - - async def spy( - *args: Any, **kwargs: Any - ) -> Tuple[litellm.ModelResponse, litellm.Choices]: - nonlocal captured_responses - - result = await origin_func(*args, **kwargs) - captured_responses.append(result) - return result - - with patch.object(adapter, "acompletion_checking_response", side_effect=spy): - task_run = await adapter.invoke( - input="123 + 321 = ?", - on_chunk=renderer.render_chunk, - ) - - # there is one call per thing going on (tool call, content, etc.) 
- # with our toy task, we expect ~2 or 3 calls (reasoning + tool call -> content) - if len(captured_responses) == 0: - raise RuntimeError( - "captured_responses is empty after invocation - test probably broken due to wrong spy" - ) - - # check we are getting the trace successfully - assert task_run.trace is not None, "Task run trace is None" - assert len(task_run.trace) > 0, "Task run trace is empty" - - assistant_messages: list[litellm.Message] = [] - for model_response, _ in captured_responses: - for choice in model_response.choices: - if isinstance(choice, litellm.Choices): - assistant_messages.append(choice.message) - assert len(assistant_messages) > 0, "No assistant messages found in the trace" - - # we do not know which message the reasoning / content / tool call is in, but we know each one - # should appear in at least one message so we accumulate them here - reasoning_contents: list[str] = [] - contents: list[str] = [] - tool_calls: list[ChatCompletionDeltaToolCall | Any] = [] - for assistant_message in assistant_messages: - reasoning_content = getattr(assistant_message, "reasoning_content", None) - if reasoning_content: - reasoning_contents.append(reasoning_content) - - content = getattr(assistant_message, "content", None) - if content: - contents.append(str(content)) - - _tool_calls = getattr(assistant_message, "tool_calls", None) - if _tool_calls: - tool_calls.extend(_tool_calls) - - # check we got all the expected parts somewhere - assert len(reasoning_contents) > 0, "No reasoning contents found in the trace" - assert len(contents) > 0, "No contents found in the trace" - assert len(tool_calls) > 0, "No tool calls found in the trace" - assert len(tool_calls) == 1, "Expected exactly one tool call (to do the math)" - - # check we got some non-empty reasoning - we should have gotten some reasoning at least somewhere - # usually the toolcall - assert not all( - reasoning_content.strip() == "" for reasoning_content in reasoning_contents - ), "All reasoning 
contents are empty" - - # check we got some non-empty content (we get empty strings when there is no content) - assert not all(content.strip() == "" for content in contents), ( - "All contents are empty" - ) - - for tool_call in tool_calls: - assert tool_call.function.name is not None, "Tool call name is None" - assert tool_call.function.arguments is not None, "Tool call arguments are None" - assert json.loads(tool_call.function.arguments) is not None, ( - "Tool call arguments are not JSON" - ) - tool_call_args = json.loads(tool_call.function.arguments) - assert tool_call_args == { - "a": 123, - "b": 321, - } or tool_call_args == { - "a": 321, - "b": 123, - }, f"Tool call arguments are not the expected values: {tool_call_args}" - - -@pytest.mark.paid -@pytest.mark.parametrize( - "model_id,provider_name", - [ - ("claude_sonnet_4_5", ModelProviderName.openrouter), - ("claude_sonnet_4_5", ModelProviderName.anthropic), - ("claude_sonnet_4_6", ModelProviderName.openrouter), - ("claude_sonnet_4_6", ModelProviderName.anthropic), - ("claude_opus_4_5", ModelProviderName.openrouter), - ("claude_opus_4_5", ModelProviderName.anthropic), - ("claude_opus_4_6", ModelProviderName.openrouter), - ("claude_opus_4_6", ModelProviderName.anthropic), - ("minimax_m2_5", ModelProviderName.openrouter), - ("claude_4_5_haiku", ModelProviderName.openrouter), - ("claude_4_5_haiku", ModelProviderName.anthropic), - ], -) -async def test_acompletion_streaming_chunks( - model_id: str, - provider_name: ModelProviderName, - adapter_factory: Callable[[str, ModelProviderName], LiteLlmAdapter], -): - """Collect all chunks from all completion calls, then one pass to check we got reasoning, content, and tool calls.""" - + """Collect all OpenAI-protocol chunks via invoke_openai_stream and verify we got reasoning, content, and tool call data.""" adapter = adapter_factory(model_id, provider_name) chunks: list[litellm.ModelResponseStream] = [] - - renderer = ChunkRenderer() - - async def collect_chunks(chunk: 
litellm.ModelResponseStream) -> None: + async for chunk in adapter.invoke_openai_stream(input="123 + 321 = ?"): chunks.append(chunk) - await renderer.render_chunk(chunk) - - await adapter.invoke(input="123 + 321 = ?", on_chunk=collect_chunks) + _dump_paid_test_output(request, chunks=chunks) assert len(chunks) > 0, "No chunks collected" + reasoning_contents: list[str] = [] contents: list[str] = [] tool_calls: list[ChatCompletionDeltaToolCall | Any] = [] @@ -333,9 +169,7 @@ async def collect_chunks(chunk: litellm.ModelResponseStream) -> None: assert not all(c.strip() == "" for c in contents), "All content in chunks is empty" tool_call_function_names = [ - tool_call.function.name - for tool_call in tool_calls - if tool_call.function.name is not None + tc.function.name for tc in tool_calls if tc.function.name is not None ] assert len(tool_call_function_names) == 1, ( "Expected exactly one tool call function name" @@ -343,13 +177,8 @@ async def collect_chunks(chunk: litellm.ModelResponseStream) -> None: assert tool_call_function_names[0] == "add", "Tool call function name is not 'add'" tool_call_args_chunks = "".join( - [ - tool_call.function.arguments - for tool_call in tool_calls - if tool_call.function.arguments is not None - ] + tc.function.arguments for tc in tool_calls if tc.function.arguments is not None ) - tool_call_args = json.loads(tool_call_args_chunks) assert tool_call_args == {"a": 123, "b": 321} or tool_call_args == { "a": 321, @@ -358,102 +187,122 @@ async def collect_chunks(chunk: litellm.ModelResponseStream) -> None: @pytest.mark.paid -@pytest.mark.parametrize( - "model_id,provider_name", - [ - ("claude_sonnet_4_5", ModelProviderName.openrouter), - ("claude_sonnet_4_5", ModelProviderName.anthropic), - ("claude_sonnet_4_6", ModelProviderName.openrouter), - ("claude_sonnet_4_6", ModelProviderName.anthropic), - ("claude_opus_4_5", ModelProviderName.openrouter), - ("claude_opus_4_5", ModelProviderName.anthropic), - ("claude_opus_4_6", 
ModelProviderName.openrouter), - ("claude_opus_4_6", ModelProviderName.anthropic), - ("minimax_m2_5", ModelProviderName.openrouter), - ("claude_4_5_haiku", ModelProviderName.openrouter), - ("claude_4_5_haiku", ModelProviderName.anthropic), - ], -) -async def test_acompletion_streaming_rendering( +@pytest.mark.parametrize("model_id,provider_name", STREAMING_MODELS) +async def test_invoke_ai_sdk_stream( + request: pytest.FixtureRequest, model_id: str, provider_name: ModelProviderName, adapter_factory: Callable[[str, ModelProviderName], LiteLlmAdapter], ): - """Test that the streaming response with a renderer to see how it looks""" + """Collect AI SDK events and verify the full protocol lifecycle including tool events.""" adapter = adapter_factory(model_id, provider_name) - renderer = ChunkRenderer() - await adapter.invoke(input="123 + 321 = ?", on_chunk=renderer.render_chunk) - assert renderer.get_stream_text() is not None + + events: list[AiSdkStreamEvent] = [] + async for event in adapter.invoke_ai_sdk_stream(input="123 + 321 = ?"): + events.append(event) + logger.info(f"AI SDK event: {event.type.value} {event.payload}") + + _dump_paid_test_output(request, events=events) + assert len(events) > 0, "No events collected" + + event_types = [e.type for e in events] + + assert event_types[0] == AiSdkEventType.START, "First event should be START" + assert event_types[1] == AiSdkEventType.START_STEP, ( + "Second event should be START_STEP" + ) + + assert AiSdkEventType.FINISH_STEP in event_types, "Should have FINISH_STEP" + assert AiSdkEventType.FINISH in event_types, "Should have FINISH" + + assert AiSdkEventType.REASONING_START in event_types, "Should have REASONING_START" + assert AiSdkEventType.REASONING_DELTA in event_types, "Should have REASONING_DELTA" + + assert AiSdkEventType.TEXT_START in event_types, "Should have TEXT_START" + assert AiSdkEventType.TEXT_DELTA in event_types, "Should have TEXT_DELTA" + assert AiSdkEventType.TEXT_END in event_types, "Should have 
TEXT_END" + + assert AiSdkEventType.TOOL_INPUT_START in event_types, ( + "Should have TOOL_INPUT_START" + ) + assert AiSdkEventType.TOOL_INPUT_AVAILABLE in event_types, ( + "Should have TOOL_INPUT_AVAILABLE" + ) + assert AiSdkEventType.TOOL_OUTPUT_AVAILABLE in event_types, ( + "Should have TOOL_OUTPUT_AVAILABLE" + ) + + text_deltas = [ + e.payload.get("delta", "") + for e in events + if e.type == AiSdkEventType.TEXT_DELTA + ] + full_text = "".join(text_deltas) + assert len(full_text) > 0, "Text content is empty" + + tool_input_available = [ + e for e in events if e.type == AiSdkEventType.TOOL_INPUT_AVAILABLE + ] + assert len(tool_input_available) >= 1, ( + "Should have at least one tool-input-available" + ) + tool_input = tool_input_available[0].payload.get("input", {}) + assert "a" in tool_input and "b" in tool_input, ( + f"Tool input should have a and b keys: {tool_input}" + ) + + tool_output_available = [ + e for e in events if e.type == AiSdkEventType.TOOL_OUTPUT_AVAILABLE + ] + assert len(tool_output_available) >= 1, ( + "Should have at least one tool-output-available" + ) + assert tool_output_available[0].payload.get("output") is not None, ( + "Tool output should not be None" + ) @pytest.mark.paid -@pytest.mark.parametrize( - "model_id,provider_name", - [ - ("claude_sonnet_4_5", ModelProviderName.openrouter), - ("claude_sonnet_4_5", ModelProviderName.anthropic), - ("claude_sonnet_4_6", ModelProviderName.openrouter), - ("claude_sonnet_4_6", ModelProviderName.anthropic), - ("claude_opus_4_5", ModelProviderName.openrouter), - ("claude_opus_4_5", ModelProviderName.anthropic), - ("claude_opus_4_6", ModelProviderName.openrouter), - ("claude_opus_4_6", ModelProviderName.anthropic), - ("minimax_m2_5", ModelProviderName.openrouter), - ], -) -async def test_acompletion_streaming_rendering_raw_chunks( +@pytest.mark.parametrize("model_id,provider_name", STREAMING_MODELS_NO_HAIKU) +async def test_invoke_openai_stream_non_streaming_still_works( + request: 
pytest.FixtureRequest, model_id: str, provider_name: ModelProviderName, adapter_factory: Callable[[str, ModelProviderName], LiteLlmAdapter], ): - """Test that the streaming response with a renderer to see how it looks, but with raw chunks""" + """Verify the non-streaming invoke() still works after the refactor.""" adapter = adapter_factory(model_id, provider_name) - renderer = ChunkRawRenderer() - await adapter.invoke(input="123 + 321 = ?", on_chunk=renderer.render_chunk) - assert renderer.get_stream_text() is not None + task_run = await adapter.invoke(input="123 + 321 = ?") + + _dump_paid_test_output(request, task_run=task_run) + assert task_run.trace is not None, "Task run trace is None" + assert len(task_run.trace) > 0, "Task run trace is empty" + assert "444" in task_run.output.output, ( + f"Expected 444 in output: {task_run.output.output}" + ) @pytest.mark.paid -@pytest.mark.parametrize( - "model_id,provider_name", - [ - ("claude_sonnet_4_5", ModelProviderName.openrouter), - ("claude_sonnet_4_5", ModelProviderName.anthropic), - ("claude_sonnet_4_6", ModelProviderName.openrouter), - ("claude_sonnet_4_6", ModelProviderName.anthropic), - ("claude_opus_4_5", ModelProviderName.openrouter), - ("claude_opus_4_5", ModelProviderName.anthropic), - ("claude_opus_4_6", ModelProviderName.openrouter), - ("claude_opus_4_6", ModelProviderName.anthropic), - ("minimax_m2_5", ModelProviderName.openrouter), - ], -) -async def test_acompletion_streaming_with_existing_run( +@pytest.mark.parametrize("model_id,provider_name", STREAMING_MODELS_NO_HAIKU) +async def test_invoke_openai_stream_with_existing_run( + request: pytest.FixtureRequest, model_id: str, provider_name: ModelProviderName, adapter_factory: Callable[[str, ModelProviderName], LiteLlmAdapter], ): """Test that streaming works when continuing an existing run (session continuation).""" adapter = adapter_factory(model_id, provider_name) - renderer = ChunkRawRenderer() - initial_run = await adapter.invoke( - input="123 + 321 
= ?", - on_chunk=renderer.render_chunk, - ) + initial_run = await adapter.invoke(input="123 + 321 = ?") assert initial_run.trace is not None assert len(initial_run.trace) > 0 - initial_trace_len = len(initial_run.trace) - continuation_renderer = ChunkRawRenderer() - continued_run = await adapter.invoke( + continuation_chunks: list[litellm.ModelResponseStream] = [] + async for chunk in adapter.invoke_openai_stream( input="What was the result? Reply in one short sentence.", existing_run=initial_run, - on_chunk=continuation_renderer.render_chunk, - ) + ): + continuation_chunks.append(chunk) - assert continued_run.id == initial_run.id - assert continued_run.trace is not None - assert len(continued_run.trace) > initial_trace_len - assert continuation_renderer.get_stream_text() is not None - assert len(continuation_renderer.chunks) > 0 + _dump_paid_test_output(request, continuation_chunks=continuation_chunks) + assert len(continuation_chunks) > 0, "No continuation chunks collected" diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py index bc15622ed..3674fe5b3 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py @@ -89,15 +89,7 @@ async def run_simple_task_with_tools( with patch.object(adapter, "available_tools", return_value=mock_math_tools): if simplified: - # test our chunking handler also works e2e on real models - received_chunks = [] - - async def on_chunk_handler(chunk): - received_chunks.append(chunk) - - run = await adapter.invoke("what is 2+2", on_chunk=on_chunk_handler) - - assert len(received_chunks) > 0 + run = await adapter.invoke("what is 2+2") # Verify that AddTool.run was called with correct parameters add_spy.run.assert_called() @@ -297,9 +289,7 @@ async def test_tools_simplied_mocked(tmp_path): responses = [mock_response_1, mock_response_2] 
- async def mock_acompletion_checking_response(self, on_chunk=None, **kwargs): - if on_chunk is not None: - await on_chunk(Mock()) + async def mock_acompletion_checking_response(self, **kwargs): response = responses.pop(0) return response, response.choices[0] diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py b/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py new file mode 100644 index 000000000..c12d83949 --- /dev/null +++ b/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py @@ -0,0 +1,227 @@ +from litellm.types.utils import ( + ChatCompletionDeltaToolCall, + Delta, + Function, + ModelResponseStream, + StreamingChoices, +) + +from kiln_ai.adapters.model_adapters.stream_events import ( + AiSdkEventType, + AiSdkStreamConverter, + AiSdkStreamEvent, + ToolCallEvent, + ToolCallEventType, +) + + +def _make_tool_call_delta( + index: int = 0, + call_id: str | None = None, + name: str | None = None, + arguments: str | None = None, +) -> ChatCompletionDeltaToolCall: + func = Function(name=name, arguments=arguments or "") + tc = ChatCompletionDeltaToolCall(index=index, function=func) + if call_id is not None: + tc.id = call_id + return tc + + +def _make_chunk( + content: str | None = None, + reasoning_content: str | None = None, + tool_calls: list[ChatCompletionDeltaToolCall] | None = None, + finish_reason: str | None = None, +) -> ModelResponseStream: + delta = Delta(content=content, tool_calls=tool_calls) + if reasoning_content is not None: + delta.reasoning_content = reasoning_content + choice = StreamingChoices( + index=0, + delta=delta, + finish_reason=finish_reason, + ) + return ModelResponseStream(id="test", choices=[choice]) + + +class TestAiSdkStreamEvent: + def test_to_sse(self): + event = AiSdkStreamEvent(AiSdkEventType.START, {"messageId": "msg-123"}) + sse = event.to_sse() + assert sse.startswith("data: ") + assert sse.endswith("\n\n") + assert '"type":"start"' in sse + assert '"messageId":"msg-123"' in sse + 
+ +class TestAiSdkStreamConverter: + def test_text_start_and_delta(self): + converter = AiSdkStreamConverter() + events = converter.convert_chunk(_make_chunk(content="Hello")) + types = [e.type for e in events] + assert AiSdkEventType.TEXT_START in types + assert AiSdkEventType.TEXT_DELTA in types + assert events[-1].payload["delta"] == "Hello" + + def test_text_delta_no_duplicate_start(self): + converter = AiSdkStreamConverter() + converter.convert_chunk(_make_chunk(content="Hello")) + events = converter.convert_chunk(_make_chunk(content=" world")) + types = [e.type for e in events] + assert AiSdkEventType.TEXT_START not in types + assert AiSdkEventType.TEXT_DELTA in types + + def test_reasoning_start_and_delta(self): + converter = AiSdkStreamConverter() + events = converter.convert_chunk(_make_chunk(reasoning_content="Thinking...")) + types = [e.type for e in events] + assert AiSdkEventType.REASONING_START in types + assert AiSdkEventType.REASONING_DELTA in types + + def test_reasoning_ends_when_content_starts(self): + converter = AiSdkStreamConverter() + converter.convert_chunk(_make_chunk(reasoning_content="Thinking...")) + events = converter.convert_chunk(_make_chunk(content="Answer")) + types = [e.type for e in events] + assert AiSdkEventType.REASONING_END in types + assert AiSdkEventType.TEXT_START in types + + def test_reasoning_ends_when_tool_calls_start(self): + converter = AiSdkStreamConverter() + converter.convert_chunk(_make_chunk(reasoning_content="Thinking...")) + + tc_delta = _make_tool_call_delta( + index=0, call_id="call_1", name="add", arguments='{"a":1}' + ) + events = converter.convert_chunk(_make_chunk(tool_calls=[tc_delta])) + types = [e.type for e in events] + assert AiSdkEventType.REASONING_END in types + + def test_tool_call_input_start_and_delta(self): + converter = AiSdkStreamConverter() + + tc_delta = _make_tool_call_delta( + index=0, call_id="call_1", name="add", arguments='{"a":' + ) + events = 
converter.convert_chunk(_make_chunk(tool_calls=[tc_delta])) + types = [e.type for e in events] + assert AiSdkEventType.TOOL_INPUT_START in types + assert AiSdkEventType.TOOL_INPUT_DELTA in types + + start_event = next( + e for e in events if e.type == AiSdkEventType.TOOL_INPUT_START + ) + assert start_event.payload["toolCallId"] == "call_1" + assert start_event.payload["toolName"] == "add" + + def test_finalize_closes_open_blocks(self): + converter = AiSdkStreamConverter() + converter.convert_chunk(_make_chunk(content="text")) + events = converter.finalize() + types = [e.type for e in events] + assert AiSdkEventType.TEXT_END in types + assert AiSdkEventType.FINISH in types + + def test_finalize_closes_reasoning(self): + converter = AiSdkStreamConverter() + converter.convert_chunk(_make_chunk(reasoning_content="thinking")) + events = converter.finalize() + types = [e.type for e in events] + assert AiSdkEventType.REASONING_END in types + + def test_convert_tool_event_input_available(self): + converter = AiSdkStreamConverter() + event = ToolCallEvent( + event_type=ToolCallEventType.INPUT_AVAILABLE, + tool_call_id="call_1", + tool_name="add", + arguments={"a": 1, "b": 2}, + ) + events = converter.convert_tool_event(event) + assert len(events) == 1 + assert events[0].type == AiSdkEventType.TOOL_INPUT_AVAILABLE + assert events[0].payload["toolCallId"] == "call_1" + assert events[0].payload["input"] == {"a": 1, "b": 2} + + def test_convert_tool_event_output_available(self): + converter = AiSdkStreamConverter() + event = ToolCallEvent( + event_type=ToolCallEventType.OUTPUT_AVAILABLE, + tool_call_id="call_1", + tool_name="add", + result="3", + ) + events = converter.convert_tool_event(event) + assert len(events) == 1 + assert events[0].type == AiSdkEventType.TOOL_OUTPUT_AVAILABLE + assert events[0].payload["output"] == "3" + + def test_convert_tool_event_output_error(self): + converter = AiSdkStreamConverter() + event = ToolCallEvent( + 
event_type=ToolCallEventType.OUTPUT_ERROR, + tool_call_id="call_1", + tool_name="add", + error="Something went wrong", + ) + events = converter.convert_tool_event(event) + assert len(events) == 1 + assert events[0].type == AiSdkEventType.TOOL_OUTPUT_ERROR + assert events[0].payload["errorText"] == "Something went wrong" + + def test_reasoning_not_interrupted_by_empty_content(self): + # Minimax and similar models send chunks with both reasoning_content and + # delta.content="" simultaneously. Empty content must not close reasoning + # blocks or emit useless text-delta events. + converter = AiSdkStreamConverter() + + chunk1 = _make_chunk(reasoning_content="The", content="") + chunk2 = _make_chunk(reasoning_content=" user", content="") + chunk3 = _make_chunk(reasoning_content=" is", content="") + + events1 = converter.convert_chunk(chunk1) + events2 = converter.convert_chunk(chunk2) + events3 = converter.convert_chunk(chunk3) + + all_types1 = [e.type for e in events1] + all_types2 = [e.type for e in events2] + all_types3 = [e.type for e in events3] + + # First chunk opens the reasoning block + assert AiSdkEventType.REASONING_START in all_types1 + assert AiSdkEventType.REASONING_DELTA in all_types1 + # No text events from empty content + assert AiSdkEventType.TEXT_START not in all_types1 + assert AiSdkEventType.TEXT_DELTA not in all_types1 + + # Subsequent chunks must NOT re-open reasoning (no start) and must NOT + # close reasoning with reasoning-end + assert AiSdkEventType.REASONING_START not in all_types2 + assert AiSdkEventType.REASONING_END not in all_types2 + assert AiSdkEventType.REASONING_DELTA in all_types2 + assert AiSdkEventType.TEXT_DELTA not in all_types2 + + assert AiSdkEventType.REASONING_START not in all_types3 + assert AiSdkEventType.REASONING_END not in all_types3 + assert AiSdkEventType.REASONING_DELTA in all_types3 + assert AiSdkEventType.TEXT_DELTA not in all_types3 + + def test_reset_for_next_step(self): + converter = AiSdkStreamConverter() + 
converter._finish_reason = "tool_calls" + converter._tool_calls_state = { + 0: {"id": "x", "name": "y", "arguments": "", "started": True} + } + converter.reset_for_next_step() + assert converter._tool_calls_state == {} + assert converter._finish_reason is None + + def test_finish_reason_in_finalize(self): + converter = AiSdkStreamConverter() + converter.convert_chunk(_make_chunk(content="done", finish_reason="stop")) + events = converter.finalize() + finish_events = [e for e in events if e.type == AiSdkEventType.FINISH] + assert len(finish_events) == 1 + meta = finish_events[0].payload.get("messageMetadata", {}) + assert meta.get("finishReason") == "stop" From 0ea65b4fcfd69186b1ef7d5ac85c3d0cd6508a12 Mon Sep 17 00:00:00 2001 From: "Leonard Q. Marcq" Date: Sun, 8 Mar 2026 15:34:06 +0800 Subject: [PATCH 09/32] refactor: ai sdk events as pydantic models --- .../kiln_ai/adapters/model_adapters/stream_events.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/libs/core/kiln_ai/adapters/model_adapters/stream_events.py b/libs/core/kiln_ai/adapters/model_adapters/stream_events.py index a60c03260..f95e7b838 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/stream_events.py +++ b/libs/core/kiln_ai/adapters/model_adapters/stream_events.py @@ -7,6 +7,7 @@ from typing import Any from litellm.types.utils import ModelResponseStream +from pydantic import BaseModel class AiSdkEventType(str, Enum): @@ -40,11 +41,16 @@ class AiSdkEventType(str, Enum): FILE = "file" -@dataclass -class AiSdkStreamEvent: +class AiSdkStreamEvent(BaseModel): type: AiSdkEventType payload: dict[str, Any] = field(default_factory=dict) + def __init__(self, type: AiSdkEventType, payload: dict[str, Any] = {}): + super().__init__( + type=type, + payload=payload, + ) + def to_sse(self) -> str: data = {"type": self.type.value, **self.payload} return f"data: {json.dumps(data, separators=(',', ':'))}\n\n" From a98d886146f917aa71b7edc04b784f09065177aa Mon Sep 17 00:00:00 2001 From: 
"Leonard Q. Marcq" Date: Sun, 8 Mar 2026 15:47:54 +0800 Subject: [PATCH 10/32] fix: model_dump implementation and remove to_see to leave transport specifics to caller --- .../adapters/model_adapters/stream_events.py | 19 +++++++------------ .../model_adapters/test_stream_events.py | 10 ++++------ 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/libs/core/kiln_ai/adapters/model_adapters/stream_events.py b/libs/core/kiln_ai/adapters/model_adapters/stream_events.py index f95e7b838..b2784484a 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/stream_events.py +++ b/libs/core/kiln_ai/adapters/model_adapters/stream_events.py @@ -1,13 +1,11 @@ from __future__ import annotations -import json import uuid from dataclasses import dataclass, field from enum import Enum from typing import Any from litellm.types.utils import ModelResponseStream -from pydantic import BaseModel class AiSdkEventType(str, Enum): @@ -41,19 +39,16 @@ class AiSdkEventType(str, Enum): FILE = "file" -class AiSdkStreamEvent(BaseModel): +@dataclass +class AiSdkStreamEvent: type: AiSdkEventType payload: dict[str, Any] = field(default_factory=dict) - def __init__(self, type: AiSdkEventType, payload: dict[str, Any] = {}): - super().__init__( - type=type, - payload=payload, - ) - - def to_sse(self) -> str: - data = {"type": self.type.value, **self.payload} - return f"data: {json.dumps(data, separators=(',', ':'))}\n\n" + def model_dump(self) -> dict[str, Any]: + return { + "type": self.type.value, + **self.payload, + } class ToolCallEventType(str, Enum): diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py b/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py index c12d83949..ccc79581d 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py @@ -46,13 +46,11 @@ def _make_chunk( class TestAiSdkStreamEvent: - def test_to_sse(self): + def test_model_dump(self): event = 
AiSdkStreamEvent(AiSdkEventType.START, {"messageId": "msg-123"}) - sse = event.to_sse() - assert sse.startswith("data: ") - assert sse.endswith("\n\n") - assert '"type":"start"' in sse - assert '"messageId":"msg-123"' in sse + dump = event.model_dump() + assert dump["type"] == "start" + assert dump["messageId"] == "msg-123" class TestAiSdkStreamConverter: From e989dca32bdd7fe116fc328e8d5c5ddd8a912253 Mon Sep 17 00:00:00 2001 From: "Leonard Q. Marcq" Date: Sun, 8 Mar 2026 16:06:12 +0800 Subject: [PATCH 11/32] fix: should reset before next round of toolcalls --- .../adapters/model_adapters/base_adapter.py | 5 ++ .../model_adapters/test_base_adapter.py | 82 +++++++++++++++++-- .../model_adapters/test_stream_events.py | 38 +++++++++ 3 files changed, 118 insertions(+), 7 deletions(-) diff --git a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py index e6a301401..c599de650 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py @@ -351,13 +351,18 @@ async def invoke_ai_sdk_stream( yield AiSdkStreamEvent(AiSdkEventType.START_STEP) + last_event_was_tool_call = False async for event in adapter_stream: # ModelResponseStream events come from LiteLLM's own OpenAI compatible streaming if isinstance(event, ModelResponseStream): + if last_event_was_tool_call: + converter.reset_for_next_step() + last_event_was_tool_call = False for ai_event in converter.convert_chunk(event): yield ai_event # ToolCallEvent events come from ourselves and are emitted on rounds of toolcalls elif isinstance(event, ToolCallEvent): + last_event_was_tool_call = True for ai_event in converter.convert_tool_event(event): yield ai_event diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py index 0e8e51d95..f842e4b98 100644 --- 
a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py @@ -1,6 +1,13 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from litellm.types.utils import ( + ChatCompletionDeltaToolCall, + Delta, + Function, + ModelResponseStream, + StreamingChoices, +) from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode from kiln_ai.adapters.model_adapters.base_adapter import ( @@ -8,14 +15,13 @@ BaseAdapter, RunOutput, ) -from kiln_ai.adapters.prompt_builders import BasePromptBuilder -from kiln_ai.datamodel import ( - DataSource, - DataSourceType, - Task, - TaskOutput, - TaskRun, +from kiln_ai.adapters.model_adapters.stream_events import ( + AiSdkEventType, + ToolCallEvent, + ToolCallEventType, ) +from kiln_ai.adapters.prompt_builders import BasePromptBuilder +from kiln_ai.datamodel import DataSource, DataSourceType, Task, TaskOutput, TaskRun from kiln_ai.datamodel.datamodel_enums import ChatStrategy, ModelProviderName from kiln_ai.datamodel.project import Project from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties, ToolsRunConfig @@ -1087,3 +1093,65 @@ async def test_invoke_ai_sdk_stream_raises_for_unsupported_adapter( with pytest.raises(NotImplementedError, match="Streaming is not supported"): async for _event in stream_adapter.invoke_ai_sdk_stream("test input"): pass + + @pytest.mark.asyncio + async def test_invoke_ai_sdk_stream_resets_converter_between_tool_rounds( + self, stream_adapter + ): + """tool-input-start must be emitted for a new tool call at index 0 after a tool round.""" + + def _make_tool_chunk(call_id: str, name: str) -> ModelResponseStream: + func = Function(name=name, arguments='{"x":1}') + tc = ChatCompletionDeltaToolCall(index=0, function=func) + tc.id = call_id + delta = Delta(tool_calls=[tc]) + choice = StreamingChoices(index=0, delta=delta, finish_reason=None) + return ModelResponseStream(id="test", 
choices=[choice]) + + round1_chunk = _make_tool_chunk("call_r1", "tool_a") + round2_chunk = _make_tool_chunk("call_r2", "tool_b") + + fake_events = [ + round1_chunk, + ToolCallEvent( + event_type=ToolCallEventType.INPUT_AVAILABLE, + tool_call_id="call_r1", + tool_name="tool_a", + arguments={"x": 1}, + ), + ToolCallEvent( + event_type=ToolCallEventType.OUTPUT_AVAILABLE, + tool_call_id="call_r1", + tool_name="tool_a", + result="done", + ), + round2_chunk, + ] + + class FakeAdapterStream: + result = MagicMock() + + async def __aiter__(self): + for event in fake_events: + yield event + + with ( + patch.object( + stream_adapter, + "_prepare_stream", + return_value=FakeAdapterStream(), + ), + patch.object(stream_adapter, "_finalize_stream"), + ): + events = [] + async for event in stream_adapter.invoke_ai_sdk_stream("test input"): + events.append(event) + + tool_input_starts = [ + e for e in events if e.type == AiSdkEventType.TOOL_INPUT_START + ] + assert len(tool_input_starts) == 2, ( + "tool-input-start must fire once per tool-call round" + ) + assert tool_input_starts[0].payload["toolCallId"] == "call_r1" + assert tool_input_starts[1].payload["toolCallId"] == "call_r2" diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py b/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py index ccc79581d..220b7b13e 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py @@ -223,3 +223,41 @@ def test_finish_reason_in_finalize(self): assert len(finish_events) == 1 meta = finish_events[0].payload.get("messageMetadata", {}) assert meta.get("finishReason") == "stop" + + def test_tool_input_start_reemitted_after_reset(self): + """After reset_for_next_step, tool-input-start must fire again for index 0.""" + converter = AiSdkStreamConverter() + + tc_round1 = _make_tool_call_delta( + index=0, call_id="call_r1", name="search", arguments='{"q":"hi"}' + ) + events_r1 = 
converter.convert_chunk(_make_chunk(tool_calls=[tc_round1])) + starts_r1 = [e for e in events_r1 if e.type == AiSdkEventType.TOOL_INPUT_START] + assert len(starts_r1) == 1 + assert starts_r1[0].payload["toolCallId"] == "call_r1" + + converter.reset_for_next_step() + + tc_round2 = _make_tool_call_delta( + index=0, call_id="call_r2", name="search", arguments='{"q":"world"}' + ) + events_r2 = converter.convert_chunk(_make_chunk(tool_calls=[tc_round2])) + starts_r2 = [e for e in events_r2 if e.type == AiSdkEventType.TOOL_INPUT_START] + assert len(starts_r2) == 1, "tool-input-start must be re-emitted for index 0 after reset" + assert starts_r2[0].payload["toolCallId"] == "call_r2" + + def test_tool_input_start_not_reemitted_without_reset(self): + """Without reset, a second tool call at index 0 must NOT re-emit tool-input-start.""" + converter = AiSdkStreamConverter() + + tc_round1 = _make_tool_call_delta( + index=0, call_id="call_r1", name="search", arguments='{"q":"hi"}' + ) + converter.convert_chunk(_make_chunk(tool_calls=[tc_round1])) + + tc_round2 = _make_tool_call_delta( + index=0, call_id="call_r2", name="search", arguments='{"q":"world"}' + ) + events_r2 = converter.convert_chunk(_make_chunk(tool_calls=[tc_round2])) + starts_r2 = [e for e in events_r2 if e.type == AiSdkEventType.TOOL_INPUT_START] + assert len(starts_r2) == 0, "Without reset, started=True blocks duplicate tool-input-start" From 11710f20092436f19b0b65911be6ead6e0dccb5e Mon Sep 17 00:00:00 2001 From: "Leonard Q. 
Marcq" Date: Sun, 8 Mar 2026 16:47:05 +0800 Subject: [PATCH 12/32] refactor: take in a trace instead of a task_run for session continuation --- .../adapters/model_adapters/base_adapter.py | 118 ++++-------------- .../adapters/model_adapters/mcp_adapter.py | 10 +- .../model_adapters/test_base_adapter.py | 62 +++++---- .../test_litellm_adapter_streaming.py | 4 +- .../model_adapters/test_mcp_adapter.py | 8 +- .../test_saving_adapter_results.py | 104 +++++++-------- libs/server/kiln_server/run_api.py | 5 +- libs/server/kiln_server/test_run_api.py | 13 +- 8 files changed, 125 insertions(+), 199 deletions(-) diff --git a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py index c599de650..d2f323467 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py @@ -142,10 +142,10 @@ async def invoke( self, input: InputType, input_source: DataSource | None = None, - existing_run: TaskRun | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, ) -> TaskRun: run_output, _ = await self.invoke_returning_run_output( - input, input_source, existing_run + input, input_source, prior_trace ) return run_output @@ -153,7 +153,7 @@ async def _run_returning_run_output( self, input: InputType, input_source: DataSource | None = None, - existing_run: TaskRun | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, ) -> Tuple[TaskRun, RunOutput]: # validate input, allowing arrays if self.input_schema is not None: @@ -164,14 +164,7 @@ async def _run_returning_run_output( require_object=False, ) - if existing_run is not None and ( - not existing_run.trace or len(existing_run.trace) == 0 - ): - raise ValueError( - "Run has no trace. Cannot continue session without conversation history." 
- ) - - prior_trace = existing_run.trace if existing_run else None + prior_trace = prior_trace if prior_trace else None # Format model input for model call (we save the original input in the task without formatting) formatted_input = input @@ -230,28 +223,9 @@ async def _run_returning_run_output( "Reasoning is required for this model, but no reasoning was returned." ) - # Create the run and output - merge if there is an existing run - if existing_run is not None: - merged_output = RunOutput( - output=parsed_output.output, - intermediate_outputs=parsed_output.intermediate_outputs - or run_output.intermediate_outputs, - output_logprobs=parsed_output.output_logprobs - or run_output.output_logprobs, - trace=run_output.trace, - ) - run = self.generate_run( - input, - input_source, - merged_output, - usage, - run_output.trace, - existing_run=existing_run, - ) - else: - run = self.generate_run( - input, input_source, parsed_output, usage, run_output.trace - ) + run = self.generate_run( + input, input_source, parsed_output, usage, run_output.trace + ) # Save the run if configured to do so, and we have a path to save to if ( @@ -260,7 +234,7 @@ async def _run_returning_run_output( and self.task.path is not None ): run.save_to_file() - elif existing_run is None: + else: # Clear the ID to indicate it's not persisted run.id = None @@ -270,7 +244,7 @@ async def invoke_returning_run_output( self, input: InputType, input_source: DataSource | None = None, - existing_run: TaskRun | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, ) -> Tuple[TaskRun, RunOutput]: # Determine if this is the root agent (no existing run context) is_root_agent = get_agent_run_id() is None @@ -281,7 +255,7 @@ async def invoke_returning_run_output( try: return await self._run_returning_run_output( - input, input_source, existing_run + input, input_source, prior_trace ) finally: if is_root_agent: @@ -296,7 +270,7 @@ async def invoke_openai_stream( self, input: InputType, 
input_source: DataSource | None = None, - existing_run: TaskRun | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, ) -> AsyncIterator[ModelResponseStream]: """Stream raw OpenAI-protocol chunks for the task execution. @@ -310,13 +284,13 @@ async def invoke_openai_stream( set_agent_run_id(generate_agent_run_id()) try: - adapter_stream = self._prepare_stream(input, existing_run) + adapter_stream = self._prepare_stream(input, prior_trace) async for event in adapter_stream: if isinstance(event, ModelResponseStream): yield event - self._finalize_stream(adapter_stream, input, input_source, existing_run) + self._finalize_stream(adapter_stream, input, input_source, prior_trace) finally: if is_root_agent: try: @@ -330,7 +304,7 @@ async def invoke_ai_sdk_stream( self, input: InputType, input_source: DataSource | None = None, - existing_run: TaskRun | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, ) -> AsyncIterator[AiSdkStreamEvent]: """Stream AI SDK protocol events for the task execution. 
@@ -342,7 +316,7 @@ async def invoke_ai_sdk_stream( set_agent_run_id(generate_agent_run_id()) try: - adapter_stream = self._prepare_stream(input, existing_run) + adapter_stream = self._prepare_stream(input, prior_trace) message_id = f"msg-{uuid.uuid4().hex}" converter = AiSdkStreamConverter() @@ -371,7 +345,7 @@ async def invoke_ai_sdk_stream( yield AiSdkStreamEvent(AiSdkEventType.FINISH_STEP) - self._finalize_stream(adapter_stream, input, input_source, existing_run) + self._finalize_stream(adapter_stream, input, input_source, prior_trace) finally: if is_root_agent: try: @@ -384,7 +358,7 @@ async def invoke_ai_sdk_stream( def _prepare_stream( self, input: InputType, - existing_run: TaskRun | None, + prior_trace: list[ChatCompletionMessageParam] | None, ) -> AdapterStream: if self.input_schema is not None: validate_schema_with_value_error( @@ -394,14 +368,7 @@ def _prepare_stream( require_object=False, ) - if existing_run is not None and ( - not existing_run.trace or len(existing_run.trace) == 0 - ): - raise ValueError( - "Run has no trace. Cannot continue session without conversation history." - ) - - prior_trace = existing_run.trace if existing_run else None + prior_trace = prior_trace if prior_trace else None formatted_input = input formatter_id = self.model_provider().formatter @@ -416,7 +383,7 @@ def _finalize_stream( adapter_stream: AdapterStream, input: InputType, input_source: DataSource | None, - existing_run: TaskRun | None, + prior_trace: list[ChatCompletionMessageParam] | None, ) -> TaskRun: """Streaming invocations are only concerned with passing through events as they come in. At the end of the stream, we still need to validate the output, create a run and everything @@ -468,27 +435,9 @@ def _finalize_stream( "Reasoning is required for this model, but no reasoning was returned." 
) - if existing_run is not None: - merged_output = RunOutput( - output=parsed_output.output, - intermediate_outputs=parsed_output.intermediate_outputs - or run_output.intermediate_outputs, - output_logprobs=parsed_output.output_logprobs - or run_output.output_logprobs, - trace=run_output.trace, - ) - run = self.generate_run( - input, - input_source, - merged_output, - usage, - run_output.trace, - existing_run=existing_run, - ) - else: - run = self.generate_run( - input, input_source, parsed_output, usage, run_output.trace - ) + run = self.generate_run( + input, input_source, parsed_output, usage, run_output.trace + ) if ( self.base_adapter_config.allow_saving @@ -496,7 +445,7 @@ def _finalize_stream( and self.task.path is not None ): run.save_to_file() - elif existing_run is None: + else: run.id = None return run @@ -604,7 +553,6 @@ def generate_run( run_output: RunOutput, usage: Usage | None = None, trace: list[ChatCompletionMessageParam] | None = None, - existing_run: TaskRun | None = None, ) -> TaskRun: output_str = ( json.dumps(run_output.output, ensure_ascii=False) @@ -627,26 +575,6 @@ def generate_run( ), ) - if existing_run is not None: - accumulated_usage = existing_run.usage - if usage is not None: - if accumulated_usage is not None: - accumulated_usage = accumulated_usage + usage - else: - accumulated_usage = usage - - merged_intermediate = dict(existing_run.intermediate_outputs or {}) - if run_output.intermediate_outputs: - for k, v in run_output.intermediate_outputs.items(): - merged_intermediate[k] = v - - existing_run.output = new_output - existing_run.trace = trace - existing_run.usage = accumulated_usage - existing_run.intermediate_outputs = merged_intermediate - - return existing_run - # Convert input and output to JSON strings if they aren't strings input_str = ( input if isinstance(input, str) else json.dumps(input, ensure_ascii=False) diff --git a/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py 
b/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py index 6055d96d9..c488e7fc0 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py @@ -88,16 +88,16 @@ async def invoke( self, input: InputType, input_source: DataSource | None = None, - existing_run: TaskRun | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, ) -> TaskRun: - if existing_run is not None: + if prior_trace: raise NotImplementedError( "Session continuation is not supported for MCP adapter. " "MCP tools are single-turn and do not maintain conversation state." ) run_output, _ = await self.invoke_returning_run_output( - input, input_source, existing_run + input, input_source, prior_trace ) return run_output @@ -105,13 +105,13 @@ async def invoke_returning_run_output( self, input: InputType, input_source: DataSource | None = None, - existing_run: TaskRun | None = None, + prior_trace: list[ChatCompletionMessageParam] | None = None, ) -> Tuple[TaskRun, RunOutput]: """ Runs the task and returns both the persisted TaskRun and raw RunOutput. If this call is the root of a run, it creates an agent run context, ensures MCP tool calls have a valid session scope, and cleans up the session/context on completion. """ - if existing_run is not None: + if prior_trace: raise NotImplementedError( "Session continuation is not supported for MCP adapter. " "MCP tools are single-turn and do not maintain conversation state." 
diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py index f842e4b98..88854666e 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py @@ -443,7 +443,7 @@ def test_build_chat_formatter_with_prior_trace_returns_multiturn_formatter(adapt @pytest.mark.asyncio -async def test_existing_run_without_trace_raises(base_project): +async def test_invoke_with_prior_trace_none_starts_fresh(base_project): task = Task( name="test_task", instruction="test_instruction", @@ -458,29 +458,38 @@ async def test_existing_run_without_trace_raises(base_project): structured_output_mode=StructuredOutputMode.json_schema, ), ) - run_without_trace = TaskRun( - parent=task, - input="hi", - input_source=None, - output=TaskOutput( - output="hello", - source=DataSource( - type=DataSourceType.synthetic, - properties={ - "model_name": "gpt_4o", - "model_provider": "openai", - "adapter_name": "test", - }, + adapter._run = AsyncMock( + return_value=( + RunOutput(output="ok", intermediate_outputs=None, trace=None), + None, + ) + ) + with ( + patch( + "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id", + return_value=MagicMock(parse_output=MagicMock(return_value=RunOutput(output="ok", intermediate_outputs=None, trace=None))), + ), + patch( + "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id", + ), + patch.object( + adapter, + "model_provider", + return_value=MagicMock( + parser="default", + formatter=None, + reasoning_capable=False, ), ), - trace=None, - ) - with pytest.raises(ValueError, match="no trace"): - await adapter.invoke("input", existing_run=run_without_trace) + ): + run = await adapter.invoke("input", prior_trace=None) + assert run.output.output == "ok" + adapter._run.assert_called_once() + assert adapter._run.call_args[1].get("prior_trace") is None @pytest.mark.asyncio -async def 
test_invoke_returning_run_output_passes_existing_run_to_run( +async def test_invoke_returning_run_output_passes_prior_trace_to_run( adapter, mock_parser, tmp_path ): project = Project(name="proj", path=tmp_path / "proj.kiln") @@ -497,19 +506,6 @@ async def test_invoke_returning_run_output_passes_existing_run_to_run( {"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}, ] - initial_run = adapter.generate_run( - input="hi", - input_source=None, - run_output=RunOutput( - output="hello", - intermediate_outputs=None, - trace=trace, - ), - trace=trace, - ) - initial_run.save_to_file() - run_id = initial_run.id - assert run_id is not None captured_prior_trace = None @@ -538,7 +534,7 @@ async def mock_run(input, **kwargs): "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id", ), ): - await adapter.invoke_returning_run_output("follow-up", existing_run=initial_run) + await adapter.invoke_returning_run_output("follow-up", prior_trace=trace) assert captured_prior_trace == trace diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py index 43b2aaff4..65c540376 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py @@ -284,7 +284,7 @@ async def test_invoke_openai_stream_non_streaming_still_works( @pytest.mark.paid @pytest.mark.parametrize("model_id,provider_name", STREAMING_MODELS_NO_HAIKU) -async def test_invoke_openai_stream_with_existing_run( +async def test_invoke_openai_stream_with_prior_trace( request: pytest.FixtureRequest, model_id: str, provider_name: ModelProviderName, @@ -300,7 +300,7 @@ async def test_invoke_openai_stream_with_existing_run( continuation_chunks: list[litellm.ModelResponseStream] = [] async for chunk in adapter.invoke_openai_stream( input="What was the result? 
Reply in one short sentence.", - existing_run=initial_run, + prior_trace=initial_run.trace, ): continuation_chunks.append(chunk) diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py index ad5826d8c..0e2142208 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py @@ -334,7 +334,7 @@ async def test_mcp_adapter_sets_and_clears_run_context( async def test_mcp_adapter_rejects_multiturn_invoke_returning_run_output( project_with_local_mcp_server, local_mcp_tool_id ): - """Session continuation (existing_run) is not supported for MCP adapter.""" + """Session continuation (prior_trace) is not supported for MCP adapter.""" project, _ = project_with_local_mcp_server task = Task( name="Test Task", @@ -352,7 +352,7 @@ async def test_mcp_adapter_rejects_multiturn_invoke_returning_run_output( existing_run.trace = [{"role": "user", "content": "hi"}] with pytest.raises(NotImplementedError) as exc_info: - await adapter.invoke_returning_run_output("input", existing_run=existing_run) + await adapter.invoke_returning_run_output("input", prior_trace=existing_run.trace) assert "Session continuation is not supported" in str(exc_info.value) assert "MCP adapter" in str(exc_info.value) @@ -362,7 +362,7 @@ async def test_mcp_adapter_rejects_multiturn_invoke_returning_run_output( async def test_mcp_adapter_rejects_multiturn_invoke( project_with_local_mcp_server, local_mcp_tool_id ): - """invoke with existing_run raises NotImplementedError for MCP adapter.""" + """invoke with prior_trace raises NotImplementedError for MCP adapter.""" project, _ = project_with_local_mcp_server task = Task( name="Test Task", @@ -380,7 +380,7 @@ async def test_mcp_adapter_rejects_multiturn_invoke( existing_run.trace = [{"role": "user", "content": "hi"}] with pytest.raises(NotImplementedError) as exc_info: - await adapter.invoke("input", 
existing_run=existing_run) + await adapter.invoke("input", prior_trace=existing_run.trace) assert "Session continuation is not supported" in str(exc_info.value) diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py index 43df8c522..db20fe5ea 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -235,7 +235,7 @@ async def test_autosave_true(test_task, adapter): @pytest.mark.asyncio async def test_invoke_continue_session(test_task, adapter): - """Test that invoke with task_run_id continues a session and updates the run.""" + """Test that invoke with prior_trace continues a session and creates a new run.""" with patch("kiln_ai.utils.config.Config.shared") as mock_shared: mock_config = mock_shared.return_value mock_config.autosave_runs = True @@ -245,19 +245,6 @@ async def test_invoke_continue_session(test_task, adapter): {"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there!"}, ] - initial_run = adapter.generate_run( - input="Hello", - input_source=None, - run_output=RunOutput( - output="Hi there!", - intermediate_outputs=None, - trace=trace, - ), - trace=trace, - ) - initial_run.save_to_file() - run_id = initial_run.id - assert run_id is not None async def mock_run(input, **kwargs): prior_trace = kwargs.get("prior_trace") @@ -306,14 +293,14 @@ async def mock_run(input, **kwargs): ) mock_parser_from_id.return_value = mock_parser - updated_run = await adapter.invoke("Tell me more", existing_run=initial_run) + new_run = await adapter.invoke("Tell me more", prior_trace=trace) - assert updated_run.id == run_id - assert updated_run.input == "Hello" - assert updated_run.output.output == "How can I help?" 
- assert len(updated_run.trace) == 4 - assert updated_run.trace[-2]["content"] == "Tell me more" - assert updated_run.trace[-1]["content"] == "How can I help?" + assert new_run.id is not None + assert new_run.input == "Tell me more" + assert new_run.output.output == "How can I help?" + assert len(new_run.trace) == 4 + assert new_run.trace[-2]["content"] == "Tell me more" + assert new_run.trace[-1]["content"] == "How can I help?" reloaded = Task.load_from_file(test_task.path) runs = reloaded.runs() @@ -322,36 +309,55 @@ async def mock_run(input, **kwargs): @pytest.mark.asyncio -async def test_invoke_continue_run_without_trace(test_task, adapter): - """Test that invoke with existing_run that has no trace raises ValueError.""" +async def test_invoke_with_empty_prior_trace_starts_fresh(test_task, adapter): + """Test that invoke with prior_trace=[] starts a fresh conversation (no error).""" with patch("kiln_ai.utils.config.Config.shared") as mock_shared: mock_config = mock_shared.return_value mock_config.autosave_runs = True mock_config.user_id = "test_user" - run_without_trace = adapter.generate_run( - input="Hello", - input_source=None, - run_output=RunOutput( - output="Hi", - intermediate_outputs=None, - trace=None, - ), + adapter._run = AsyncMock( + return_value=( + RunOutput(output="Fresh reply", intermediate_outputs=None, trace=None), + None, + ) ) - run_without_trace.save_to_file() - - with pytest.raises(ValueError, match="no trace"): - await adapter.invoke("Follow up", existing_run=run_without_trace) + with ( + patch.object( + adapter, + "model_provider", + return_value=MagicMock( + parser="default", + formatter=None, + reasoning_capable=False, + ), + ), + patch( + "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id", + return_value=MagicMock( + parse_output=MagicMock( + return_value=RunOutput( + output="Fresh reply", + intermediate_outputs=None, + trace=None, + ) + ) + ), + ), + patch( + 
"kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id", + ), + ): + run = await adapter.invoke("Follow up", prior_trace=[]) + assert run.output.output == "Fresh reply" -def test_generate_run_with_existing_run_merges_usage_and_intermediate_outputs( - test_task, adapter -): +def test_generate_run_always_creates_new_task_run(test_task, adapter): trace = [ {"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}, ] - initial_run = adapter.generate_run( + run1 = adapter.generate_run( input="hi", input_source=None, run_output=RunOutput( @@ -367,8 +373,8 @@ def test_generate_run_with_existing_run_merges_usage_and_intermediate_outputs( {"role": "user", "content": "follow-up"}, {"role": "assistant", "content": "ok"}, ] - result = adapter.generate_run( - input="hi", + run2 = adapter.generate_run( + input="follow-up", input_source=None, run_output=RunOutput( output="ok", @@ -377,16 +383,12 @@ def test_generate_run_with_existing_run_merges_usage_and_intermediate_outputs( ), usage=Usage(input_tokens=5, output_tokens=10), trace=extended_trace, - existing_run=initial_run, ) - assert result is initial_run - assert result.usage.input_tokens == 15 - assert result.usage.output_tokens == 30 - assert result.intermediate_outputs == { - "chain_of_thought": "old", - "new_key": "new_val", - } - assert result.output.output == "ok" + assert run2 is not run1 + assert run2.usage is not None and run2.usage.input_tokens == 5 + assert run2.usage.output_tokens == 10 + assert run2.intermediate_outputs == {"new_key": "new_val"} + assert run2.output.output == "ok" def test_properties_for_task_output_custom_values(test_task): diff --git a/libs/server/kiln_server/run_api.py b/libs/server/kiln_server/run_api.py index ea13104d2..6f1764575 100644 --- a/libs/server/kiln_server/run_api.py +++ b/libs/server/kiln_server/run_api.py @@ -285,7 +285,7 @@ async def run_task( detail="No input provided. 
Ensure your provided the proper format (plaintext or structured).", ) - existing_run: TaskRun | None = None + prior_trace: list | None = None if request.task_run_id is not None: if task.path is None: raise HTTPException( @@ -305,8 +305,9 @@ async def run_task( status_code=400, detail="Run has no trace. Cannot continue session without conversation history.", ) + prior_trace = existing_run.trace - return await adapter.invoke(input, existing_run=existing_run) + return await adapter.invoke(input, prior_trace=prior_trace) @app.patch("/api/projects/{project_id}/tasks/{task_id}/runs/{run_id}") async def update_run( diff --git a/libs/server/kiln_server/test_run_api.py b/libs/server/kiln_server/test_run_api.py index 608a03a94..3085fd069 100644 --- a/libs/server/kiln_server/test_run_api.py +++ b/libs/server/kiln_server/test_run_api.py @@ -190,7 +190,7 @@ async def test_run_task_with_task_run_id_continues_session(client, task_run_setu assert response.status_code == 200 mock_invoke.assert_called_once() call_kwargs = mock_invoke.call_args[1] - assert call_kwargs["existing_run"].id == task_run.id + assert call_kwargs["prior_trace"] == task_run.trace assert mock_invoke.call_args[0][0] == "Follow-up message" res = response.json() assert res["output"]["output"] == "Continued response" @@ -1962,7 +1962,7 @@ async def test_run_task_adapter_sanity_math_tools( ) assert response2.status_code == 200 res2 = response2.json() - assert res2["id"] == task_run_id + assert res2["id"] != task_run_id _assert_math_tools_response(res2, "12") response3 = client.post( @@ -1970,24 +1970,23 @@ async def test_run_task_adapter_sanity_math_tools( json={ "run_config_properties": run_config, "plaintext_input": "What is 7 times 8 plus 3? 
Use the tools to calculate.", - "task_run_id": task_run_id, + "task_run_id": res2["id"], }, ) assert response3.status_code == 200 res3 = response3.json() - assert res3["id"] == task_run_id + assert res3["id"] != res2["id"] _assert_math_tools_response(res3, "59") - # now ask it to list out all the previous results in an array response4 = client.post( f"/api/projects/{project.id}/tasks/{task.id}/run", json={ "run_config_properties": run_config, "plaintext_input": "List all the previous results in an array - e.g. [55, 81, 7].", - "task_run_id": task_run_id, + "task_run_id": res3["id"], }, ) assert response4.status_code == 200 res4 = response4.json() - assert res4["id"] == task_run_id + assert res4["id"] != res3["id"] assert res4["output"]["output"] == "[4, 12, 59]" From 3ad6a27479f79ecd7764f2e43d014814b5b4d271 Mon Sep 17 00:00:00 2001 From: "Leonard Q. Marcq" Date: Sun, 8 Mar 2026 16:54:43 +0800 Subject: [PATCH 13/32] refactor: remove ability to continue task run at api level --- libs/server/kiln_server/run_api.py | 32 +----- libs/server/kiln_server/test_run_api.py | 142 +----------------------- 2 files changed, 8 insertions(+), 166 deletions(-) diff --git a/libs/server/kiln_server/run_api.py b/libs/server/kiln_server/run_api.py index 6f1764575..cd10ec52f 100644 --- a/libs/server/kiln_server/run_api.py +++ b/libs/server/kiln_server/run_api.py @@ -16,6 +16,7 @@ from kiln_ai.datamodel.datamodel_enums import StructuredInputType from kiln_ai.datamodel.task import RunConfigProperties from kiln_ai.datamodel.task_output import DataSource, DataSourceType, TaskOutput +from kiln_ai.utils.config import Config from kiln_ai.utils.dataset_import import ( DatasetFileImporter, DatasetImportFormat, @@ -32,6 +33,9 @@ update_run_lock = Lock() +Config.shared().autosave_runs = True + + def deep_update( source: Dict[str, Any] | None, update: Dict[str, Any | None] ) -> Dict[str, Any]: @@ -56,10 +60,6 @@ class RunTaskRequest(BaseModel): plaintext_input: str | None = None structured_input: 
StructuredInputType | None = None tags: list[str] | None = None - task_run_id: str | None = Field( - default=None, - description="When set, continue an existing session. The new message is appended to the run's trace.", - ) # Allows use of the model_name field (usually pydantic will reserve model_*) model_config = ConfigDict(protected_namespaces=()) @@ -285,29 +285,7 @@ async def run_task( detail="No input provided. Ensure your provided the proper format (plaintext or structured).", ) - prior_trace: list | None = None - if request.task_run_id is not None: - if task.path is None: - raise HTTPException( - status_code=400, - detail="Cannot continue session: task has no path. Save the task first.", - ) - existing_run = TaskRun.from_id_and_parent_path( - request.task_run_id, task.path - ) - if existing_run is None: - raise HTTPException( - status_code=404, - detail="Run not found. Cannot continue session.", - ) - if not existing_run.trace or len(existing_run.trace) == 0: - raise HTTPException( - status_code=400, - detail="Run has no trace. 
Cannot continue session without conversation history.", - ) - prior_trace = existing_run.trace - - return await adapter.invoke(input, prior_trace=prior_trace) + return await adapter.invoke(input) @app.patch("/api/projects/{project_id}/tasks/{task_id}/runs/{run_id}") async def update_run( diff --git a/libs/server/kiln_server/test_run_api.py b/libs/server/kiln_server/test_run_api.py index 3085fd069..974f1d433 100644 --- a/libs/server/kiln_server/test_run_api.py +++ b/libs/server/kiln_server/test_run_api.py @@ -143,135 +143,6 @@ async def test_run_task_success(client, task_run_setup): assert res["id"] is not None -@pytest.mark.asyncio -async def test_run_task_with_task_run_id_continues_session(client, task_run_setup): - """Test that run_task with task_run_id passes it to adapter.invoke for session continuation.""" - project = task_run_setup["project"] - task = task_run_setup["task"] - task_run = task_run_setup["task_run"] - - run_task_request = { - "run_config_properties": { - "model_name": "gpt_4o", - "model_provider_name": "ollama", - "prompt_id": "simple_prompt_builder", - "structured_output_mode": "json_schema", - }, - "plaintext_input": "Follow-up message", - "task_run_id": task_run.id, - } - - continued_run = TaskRun( - parent=task, - input=task_run.input, - input_source=task_run.input_source, - output=TaskOutput( - output="Continued response", - source=task_run.output.source, - ), - ) - continued_run.id = task_run.id - - with ( - patch("kiln_server.run_api.task_from_id") as mock_task_from_id, - patch.object(LiteLlmAdapter, "invoke", new_callable=AsyncMock) as mock_invoke, - patch("kiln_ai.utils.config.Config.shared") as MockConfig, - ): - mock_task_from_id.return_value = task - mock_invoke.return_value = continued_run - - mock_config_instance = MockConfig.return_value - mock_config_instance.ollama_base_url = "http://localhost:11434/v1" - - response = client.post( - f"/api/projects/{project.id}/tasks/{task.id}/run", json=run_task_request - ) - - assert 
response.status_code == 200 - mock_invoke.assert_called_once() - call_kwargs = mock_invoke.call_args[1] - assert call_kwargs["prior_trace"] == task_run.trace - assert mock_invoke.call_args[0][0] == "Follow-up message" - res = response.json() - assert res["output"]["output"] == "Continued response" - - -@pytest.mark.asyncio -async def test_run_task_task_run_id_not_found_returns_404(client, task_run_setup): - """Test that run_task with nonexistent task_run_id returns 404.""" - project = task_run_setup["project"] - task = task_run_setup["task"] - - run_task_request = { - "run_config_properties": { - "model_name": "gpt_4o", - "model_provider_name": "ollama", - "prompt_id": "simple_prompt_builder", - "structured_output_mode": "json_schema", - }, - "plaintext_input": "Follow-up", - "task_run_id": "nonexistent-run-id", - } - - with patch("kiln_server.run_api.task_from_id") as mock_task_from_id: - mock_task_from_id.return_value = task - response = client.post( - f"/api/projects/{project.id}/tasks/{task.id}/run", json=run_task_request - ) - - assert response.status_code == 404 - assert "Run not found" in response.json()["message"] - - -@pytest.mark.asyncio -async def test_run_task_task_run_id_no_trace_returns_400(client, task_run_setup): - """Test that run_task with task_run_id for run without trace returns 400.""" - project = task_run_setup["project"] - task = task_run_setup["task"] - - task_run_no_trace = TaskRun( - parent=task, - input="Hello", - input_source=DataSource( - type=DataSourceType.human, properties={"created_by": "Test User"} - ), - output=TaskOutput( - output="Hi", - source=DataSource( - type=DataSourceType.synthetic, - properties={ - "model_name": "gpt_4o", - "model_provider": "ollama", - "adapter_name": "kiln_langchain_adapter", - "prompt_id": "simple_prompt_builder", - }, - ), - ), - trace=None, - ) - task_run_no_trace.save_to_file() - - run_task_request = { - "run_config_properties": { - "model_name": "gpt_4o", - "model_provider_name": "ollama", - 
"prompt_id": "simple_prompt_builder", - "structured_output_mode": "json_schema", - }, - "plaintext_input": "Follow-up", - "task_run_id": task_run_no_trace.id, - } - - with patch("kiln_server.run_api.task_from_id") as mock_task_from_id: - mock_task_from_id.return_value = task - response = client.post( - f"/api/projects/{project.id}/tasks/{task.id}/run", json=run_task_request - ) - - assert response.status_code == 400 - assert "no trace" in response.json()["message"].lower() - - @pytest.mark.asyncio async def test_run_task_structured_output(client, task_run_setup): task = task_run_setup["task"] @@ -1918,7 +1789,7 @@ def _assert_math_tools_response(res: dict, expected_in_output: str) -> None: async def test_run_task_adapter_sanity_math_tools( client, adapter_sanity_check_math_tools_setup ): - """Multi-turn run with built-in Kiln math tools. Test that tools + continue session work as expected.""" + """Multiple runs with built-in Kiln math tools. Test that tools work across independent runs.""" if not os.environ.get("OPENROUTER_API_KEY"): pytest.skip("OPENROUTER_API_KEY required for this test") @@ -1950,19 +1821,16 @@ async def test_run_task_adapter_sanity_math_tools( assert response1.status_code == 200 res1 = response1.json() _assert_math_tools_response(res1, "4") - task_run_id = res1["id"] response2 = client.post( f"/api/projects/{project.id}/tasks/{task.id}/run", json={ "run_config_properties": run_config, "plaintext_input": "What is 3 times 4? Use the tools to calculate.", - "task_run_id": task_run_id, }, ) assert response2.status_code == 200 res2 = response2.json() - assert res2["id"] != task_run_id _assert_math_tools_response(res2, "12") response3 = client.post( @@ -1970,23 +1838,19 @@ async def test_run_task_adapter_sanity_math_tools( json={ "run_config_properties": run_config, "plaintext_input": "What is 7 times 8 plus 3? 
Use the tools to calculate.", - "task_run_id": res2["id"], }, ) assert response3.status_code == 200 res3 = response3.json() - assert res3["id"] != res2["id"] _assert_math_tools_response(res3, "59") response4 = client.post( f"/api/projects/{project.id}/tasks/{task.id}/run", json={ "run_config_properties": run_config, - "plaintext_input": "List all the previous results in an array - e.g. [55, 81, 7].", - "task_run_id": res3["id"], + "plaintext_input": "What is 10 minus 3? Use the tools to calculate.", }, ) assert response4.status_code == 200 res4 = response4.json() - assert res4["id"] != res3["id"] - assert res4["output"]["output"] == "[4, 12, 59]" + _assert_math_tools_response(res4, "7") From 4d8e99fc0021fe27ed8ea5b89c7f01865b7fe045 Mon Sep 17 00:00:00 2001 From: "Leonard Q. Marcq" Date: Sun, 8 Mar 2026 17:29:36 +0800 Subject: [PATCH 14/32] refactor: wrap stream iterators to allow exposing task run at the end --- app/web_ui/src/lib/api_schema.d.ts | 5 - .../adapters/chat/test_chat_formatter.py | 44 ++++ .../adapters/model_adapters/base_adapter.py | 224 ++++++++++++------ .../model_adapters/test_base_adapter.py | 88 ++++++- .../model_adapters/test_litellm_adapter.py | 118 +++++++++ .../model_adapters/test_mcp_adapter.py | 4 +- .../model_adapters/test_stream_events.py | 8 +- 7 files changed, 408 insertions(+), 83 deletions(-) diff --git a/app/web_ui/src/lib/api_schema.d.ts b/app/web_ui/src/lib/api_schema.d.ts index 535a9aa60..754759f8a 100644 --- a/app/web_ui/src/lib/api_schema.d.ts +++ b/app/web_ui/src/lib/api_schema.d.ts @@ -6249,11 +6249,6 @@ export interface components { } | unknown[] | null; /** Tags */ tags?: string[] | null; - /** - * Task Run Id - * @description When set, continue an existing session. The new message is appended to the run's trace. 
- */ - task_run_id?: string | null; }; /** * SampleApi diff --git a/libs/core/kiln_ai/adapters/chat/test_chat_formatter.py b/libs/core/kiln_ai/adapters/chat/test_chat_formatter.py index e0138a954..2903b6eee 100644 --- a/libs/core/kiln_ai/adapters/chat/test_chat_formatter.py +++ b/libs/core/kiln_ai/adapters/chat/test_chat_formatter.py @@ -146,6 +146,50 @@ def test_multiturn_formatter_next_turn(): assert formatter.next_turn("assistant response") is None +def test_multiturn_formatter_preserves_tool_call_messages(): + prior_trace = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "4"}, + { + "role": "assistant", + "content": "", + "reasoning_content": "Let me multiply 4 by 7.\n", + "tool_calls": [ + { + "id": "call_abc123", + "function": {"arguments": '{"a": 4, "b": 7}', "name": "multiply"}, + "type": "function", + } + ], + }, + { + "content": "28", + "role": "tool", + "tool_call_id": "call_abc123", + "kiln_task_tool_data": None, + }, + { + "role": "assistant", + "content": "4 multiplied by 7 is 28.", + "reasoning_content": "Done.\n", + }, + ] + formatter = MultiturnFormatter(prior_trace=prior_trace, user_input="now double it") + initial = formatter.initial_messages() + assert initial == prior_trace + assert initial[2]["tool_calls"][0]["id"] == "call_abc123" + assert initial[2]["tool_calls"][0]["function"]["name"] == "multiply" + assert initial[3]["role"] == "tool" + assert initial[3]["tool_call_id"] == "call_abc123" + + first = formatter.next_turn() + assert first is not None + assert len(first.messages) == 1 + assert first.messages[0].role == "user" + assert first.messages[0].content == "now double it" + assert first.final_call + + def test_format_user_message(): # String assert format_user_message("test input") == "test input" diff --git a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py index d2f323467..d9633ac81 100644 --- 
a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py @@ -266,94 +266,38 @@ async def invoke_returning_run_output( finally: clear_agent_run_id() - async def invoke_openai_stream( + def invoke_openai_stream( self, input: InputType, input_source: DataSource | None = None, prior_trace: list[ChatCompletionMessageParam] | None = None, - ) -> AsyncIterator[ModelResponseStream]: + ) -> OpenAIStreamResult: """Stream raw OpenAI-protocol chunks for the task execution. - Yields ``ModelResponseStream`` chunks as they arrive from the model. - After the iterator is exhausted the run has been validated and saved - (when configured). Tool-call rounds happen internally and are not - surfaced; use ``invoke_ai_sdk_stream`` if you need tool-call events. - """ - is_root_agent = get_agent_run_id() is None - if is_root_agent: - set_agent_run_id(generate_agent_run_id()) - - try: - adapter_stream = self._prepare_stream(input, prior_trace) - - async for event in adapter_stream: - if isinstance(event, ModelResponseStream): - yield event + Returns an async-iterable that yields ``ModelResponseStream`` chunks + as they arrive from the model. After the iterator is exhausted the + run has been validated and saved (when configured). The resulting + ``TaskRun`` is available via the ``.task_run`` property. - self._finalize_stream(adapter_stream, input, input_source, prior_trace) - finally: - if is_root_agent: - try: - run_id = get_agent_run_id() - if run_id: - await MCPSessionManager.shared().cleanup_session(run_id) - finally: - clear_agent_run_id() + Tool-call rounds happen internally and are not surfaced; use + ``invoke_ai_sdk_stream`` if you need tool-call events. 
+ """ + return OpenAIStreamResult(self, input, input_source, prior_trace) - async def invoke_ai_sdk_stream( + def invoke_ai_sdk_stream( self, input: InputType, input_source: DataSource | None = None, prior_trace: list[ChatCompletionMessageParam] | None = None, - ) -> AsyncIterator[AiSdkStreamEvent]: + ) -> AiSdkStreamResult: """Stream AI SDK protocol events for the task execution. - Yields ``AiSdkStreamEvent`` instances covering text, reasoning, - tool-call lifecycle, step boundaries, and control events. + Returns an async-iterable that yields ``AiSdkStreamEvent`` instances + covering text, reasoning, tool-call lifecycle, step boundaries, and + control events. After the iterator is exhausted the resulting + ``TaskRun`` is available via the ``.task_run`` property. """ - is_root_agent = get_agent_run_id() is None - if is_root_agent: - set_agent_run_id(generate_agent_run_id()) - - try: - adapter_stream = self._prepare_stream(input, prior_trace) - - message_id = f"msg-{uuid.uuid4().hex}" - converter = AiSdkStreamConverter() - - yield AiSdkStreamEvent(AiSdkEventType.START, {"messageId": message_id}) - - yield AiSdkStreamEvent(AiSdkEventType.START_STEP) - - last_event_was_tool_call = False - async for event in adapter_stream: - # ModelResponseStream events come from LiteLLM's own OpenAI compatible streaming - if isinstance(event, ModelResponseStream): - if last_event_was_tool_call: - converter.reset_for_next_step() - last_event_was_tool_call = False - for ai_event in converter.convert_chunk(event): - yield ai_event - # ToolCallEvent events come from ourselves and are emitted on rounds of toolcalls - elif isinstance(event, ToolCallEvent): - last_event_was_tool_call = True - for ai_event in converter.convert_tool_event(event): - yield ai_event - - for ai_event in converter.finalize(): - yield ai_event - - yield AiSdkStreamEvent(AiSdkEventType.FINISH_STEP) - - self._finalize_stream(adapter_stream, input, input_source, prior_trace) - finally: - if is_root_agent: - try: - 
run_id = get_agent_run_id() - if run_id: - await MCPSessionManager.shared().cleanup_session(run_id) - finally: - clear_agent_run_id() + return AiSdkStreamResult(self, input, input_source, prior_trace) def _prepare_stream( self, @@ -663,3 +607,137 @@ async def available_tools(self) -> list[KilnToolInterface]: ) return tools + + +class OpenAIStreamResult: + """Async-iterable wrapper around the OpenAI streaming flow. + + Yields ``ModelResponseStream`` chunks. After iteration the resulting + ``TaskRun`` is available via the ``.task_run`` property. + """ + + def __init__( + self, + adapter: BaseAdapter, + input: InputType, + input_source: DataSource | None, + prior_trace: list[ChatCompletionMessageParam] | None, + ) -> None: + self._adapter = adapter + self._input = input + self._input_source = input_source + self._prior_trace = prior_trace + self._task_run: TaskRun | None = None + + @property + def task_run(self) -> TaskRun: + if self._task_run is None: + raise RuntimeError( + "Stream has not been fully consumed yet. " + "Iterate over the stream before accessing .task_run" + ) + return self._task_run + + async def __aiter__(self) -> AsyncIterator[ModelResponseStream]: + self._task_run = None + is_root_agent = get_agent_run_id() is None + if is_root_agent: + set_agent_run_id(generate_agent_run_id()) + + try: + adapter_stream = self._adapter._prepare_stream( + self._input, self._prior_trace + ) + + async for event in adapter_stream: + if isinstance(event, ModelResponseStream): + yield event + + self._task_run = self._adapter._finalize_stream( + adapter_stream, self._input, self._input_source, self._prior_trace + ) + finally: + if is_root_agent: + try: + run_id = get_agent_run_id() + if run_id: + await MCPSessionManager.shared().cleanup_session(run_id) + finally: + clear_agent_run_id() + + +class AiSdkStreamResult: + """Async-iterable wrapper around the AI SDK streaming flow. + + Yields ``AiSdkStreamEvent`` instances. 
After iteration the resulting + ``TaskRun`` is available via the ``.task_run`` property. + """ + + def __init__( + self, + adapter: BaseAdapter, + input: InputType, + input_source: DataSource | None, + prior_trace: list[ChatCompletionMessageParam] | None, + ) -> None: + self._adapter = adapter + self._input = input + self._input_source = input_source + self._prior_trace = prior_trace + self._task_run: TaskRun | None = None + + @property + def task_run(self) -> TaskRun: + if self._task_run is None: + raise RuntimeError( + "Stream has not been fully consumed yet. " + "Iterate over the stream before accessing .task_run" + ) + return self._task_run + + async def __aiter__(self) -> AsyncIterator[AiSdkStreamEvent]: + self._task_run = None + is_root_agent = get_agent_run_id() is None + if is_root_agent: + set_agent_run_id(generate_agent_run_id()) + + try: + adapter_stream = self._adapter._prepare_stream( + self._input, self._prior_trace + ) + + message_id = f"msg-{uuid.uuid4().hex}" + converter = AiSdkStreamConverter() + + yield AiSdkStreamEvent(AiSdkEventType.START, {"messageId": message_id}) + yield AiSdkStreamEvent(AiSdkEventType.START_STEP) + + last_event_was_tool_call = False + async for event in adapter_stream: + if isinstance(event, ModelResponseStream): + if last_event_was_tool_call: + converter.reset_for_next_step() + last_event_was_tool_call = False + for ai_event in converter.convert_chunk(event): + yield ai_event + elif isinstance(event, ToolCallEvent): + last_event_was_tool_call = True + for ai_event in converter.convert_tool_event(event): + yield ai_event + + for ai_event in converter.finalize(): + yield ai_event + + yield AiSdkStreamEvent(AiSdkEventType.FINISH_STEP) + + self._task_run = self._adapter._finalize_stream( + adapter_stream, self._input, self._input_source, self._prior_trace + ) + finally: + if is_root_agent: + try: + run_id = get_agent_run_id() + if run_id: + await MCPSessionManager.shared().cleanup_session(run_id) + finally: + 
clear_agent_run_id() diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py index 88854666e..b000df81f 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py @@ -21,7 +21,7 @@ ToolCallEventType, ) from kiln_ai.adapters.prompt_builders import BasePromptBuilder -from kiln_ai.datamodel import DataSource, DataSourceType, Task, TaskOutput, TaskRun +from kiln_ai.datamodel import Task, TaskRun from kiln_ai.datamodel.datamodel_enums import ChatStrategy, ModelProviderName from kiln_ai.datamodel.project import Project from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties, ToolsRunConfig @@ -467,7 +467,13 @@ async def test_invoke_with_prior_trace_none_starts_fresh(base_project): with ( patch( "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id", - return_value=MagicMock(parse_output=MagicMock(return_value=RunOutput(output="ok", intermediate_outputs=None, trace=None))), + return_value=MagicMock( + parse_output=MagicMock( + return_value=RunOutput( + output="ok", intermediate_outputs=None, trace=None + ) + ) + ), ), patch( "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id", @@ -1151,3 +1157,81 @@ async def __aiter__(self): ) assert tool_input_starts[0].payload["toolCallId"] == "call_r1" assert tool_input_starts[1].payload["toolCallId"] == "call_r2" + + @pytest.mark.asyncio + async def test_openai_stream_exposes_task_run_after_iteration(self, stream_adapter): + fake_chunk = ModelResponseStream( + id="test", + choices=[ + StreamingChoices( + index=0, + delta=Delta(content="hi"), + finish_reason=None, + ) + ], + ) + + class FakeAdapterStream: + result = MagicMock() + + async def __aiter__(self): + yield fake_chunk + + expected_run = MagicMock(spec=TaskRun) + + with ( + patch.object( + stream_adapter, + "_prepare_stream", + return_value=FakeAdapterStream(), 
+ ), + patch.object(stream_adapter, "_finalize_stream", return_value=expected_run), + ): + stream = stream_adapter.invoke_openai_stream("test input") + + with pytest.raises(RuntimeError, match="not been fully consumed"): + _ = stream.task_run + + async for _chunk in stream: + pass + + assert stream.task_run is expected_run + + @pytest.mark.asyncio + async def test_ai_sdk_stream_exposes_task_run_after_iteration(self, stream_adapter): + fake_chunk = ModelResponseStream( + id="test", + choices=[ + StreamingChoices( + index=0, + delta=Delta(content="hi"), + finish_reason=None, + ) + ], + ) + + class FakeAdapterStream: + result = MagicMock() + + async def __aiter__(self): + yield fake_chunk + + expected_run = MagicMock(spec=TaskRun) + + with ( + patch.object( + stream_adapter, + "_prepare_stream", + return_value=FakeAdapterStream(), + ), + patch.object(stream_adapter, "_finalize_stream", return_value=expected_run), + ): + stream = stream_adapter.invoke_ai_sdk_stream("test input") + + with pytest.raises(RuntimeError, match="not been fully consumed"): + _ = stream.task_run + + async for _event in stream: + pass + + assert stream.task_run is expected_run diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py index 197b1f500..6ec73e6b6 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py @@ -1370,3 +1370,121 @@ async def mock_run_model_turn( assert run_output.trace[1]["content"] == "hello" assert run_output.trace[2]["content"] == "follow-up" assert run_output.trace[3]["content"] == "How can I help?" 
+ + +@pytest.mark.asyncio +async def test_run_with_prior_trace_preserves_tool_calls(mock_task): + """Prior trace containing tool calls should be passed through to the model and preserved in the output trace.""" + config = LiteLlmConfig( + base_url="https://api.test.com", + run_config_properties=KilnAgentRunConfigProperties( + model_name="test-model", + model_provider_name="openai_compatible", + prompt_id="simple_prompt_builder", + structured_output_mode="json_schema", + ), + default_headers={"X-Test": "test"}, + additional_body_options={"api_key": "test_key"}, + ) + + prior_trace = [ + {"role": "system", "content": "Use the math tools."}, + {"role": "user", "content": "4"}, + { + "role": "assistant", + "content": "", + "reasoning_content": "Let me multiply 4 by 7.\n", + "tool_calls": [ + { + "id": "call_abc123", + "function": {"arguments": '{"a": 4, "b": 7}', "name": "multiply"}, + "type": "function", + } + ], + }, + { + "content": "28", + "role": "tool", + "tool_call_id": "call_abc123", + "kiln_task_tool_data": None, + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Now add 144.\n", + "tool_calls": [ + { + "id": "call_def456", + "function": {"arguments": '{"a": 28, "b": 144}', "name": "add"}, + "type": "function", + } + ], + }, + { + "content": "172", + "role": "tool", + "tool_call_id": "call_def456", + "kiln_task_tool_data": None, + }, + { + "role": "assistant", + "content": "There were 172 distinct species of giant tortoises.", + "reasoning_content": "Now I have 172.\n", + }, + ] + adapter = LiteLlmAdapter(config=config, kiln_task=mock_task) + + captured_messages = [] + + async def mock_run_model_turn( + provider, prior_messages, top_logprobs, skip_response_format + ): + captured_messages.extend(prior_messages) + extended = list(prior_messages) + extended.append({"role": "assistant", "content": '{"test": "response"}'}) + return ModelTurnResult( + assistant_message='{"test": "response"}', + all_messages=extended, + model_response=None, + 
model_choice=None, + usage=Usage(), + ) + + adapter._run_model_turn = mock_run_model_turn + + run_output, _ = await adapter._run("what else?", prior_trace=prior_trace) + + assert run_output.trace is not None + # 7 prior trace messages + 1 new user + 1 new assistant = 9 + assert len(run_output.trace) == 9 + + # Verify tool call messages are preserved in the trace + assistant_with_tools = run_output.trace[2] + assert assistant_with_tools["role"] == "assistant" + assert assistant_with_tools["tool_calls"][0]["id"] == "call_abc123" + assert assistant_with_tools["tool_calls"][0]["function"]["name"] == "multiply" + assert assistant_with_tools["reasoning_content"] == "Let me multiply 4 by 7.\n" + + tool_response = run_output.trace[3] + assert tool_response["role"] == "tool" + assert tool_response["tool_call_id"] == "call_abc123" + assert tool_response["content"] == "28" + + second_tool_call = run_output.trace[4] + assert second_tool_call["tool_calls"][0]["id"] == "call_def456" + assert second_tool_call["tool_calls"][0]["function"]["name"] == "add" + + second_tool_response = run_output.trace[5] + assert second_tool_response["role"] == "tool" + assert second_tool_response["tool_call_id"] == "call_def456" + assert second_tool_response["content"] == "172" + + # Verify the tool call messages were passed to _run_model_turn (i.e., sent to the model) + assert any( + m.get("tool_calls") is not None + for m in captured_messages + if isinstance(m, dict) + ) + assert any( + m.get("role") == "tool" for m in captured_messages if isinstance(m, dict) + ) diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py index 0e2142208..cb0a3e94b 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py @@ -352,7 +352,9 @@ async def test_mcp_adapter_rejects_multiturn_invoke_returning_run_output( existing_run.trace = [{"role": "user", 
"content": "hi"}] with pytest.raises(NotImplementedError) as exc_info: - await adapter.invoke_returning_run_output("input", prior_trace=existing_run.trace) + await adapter.invoke_returning_run_output( + "input", prior_trace=existing_run.trace + ) assert "Session continuation is not supported" in str(exc_info.value) assert "MCP adapter" in str(exc_info.value) diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py b/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py index 220b7b13e..d4f3b3e43 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py @@ -243,7 +243,9 @@ def test_tool_input_start_reemitted_after_reset(self): ) events_r2 = converter.convert_chunk(_make_chunk(tool_calls=[tc_round2])) starts_r2 = [e for e in events_r2 if e.type == AiSdkEventType.TOOL_INPUT_START] - assert len(starts_r2) == 1, "tool-input-start must be re-emitted for index 0 after reset" + assert len(starts_r2) == 1, ( + "tool-input-start must be re-emitted for index 0 after reset" + ) assert starts_r2[0].payload["toolCallId"] == "call_r2" def test_tool_input_start_not_reemitted_without_reset(self): @@ -260,4 +262,6 @@ def test_tool_input_start_not_reemitted_without_reset(self): ) events_r2 = converter.convert_chunk(_make_chunk(tool_calls=[tc_round2])) starts_r2 = [e for e in events_r2 if e.type == AiSdkEventType.TOOL_INPUT_START] - assert len(starts_r2) == 0, "Without reset, started=True blocks duplicate tool-input-start" + assert len(starts_r2) == 0, ( + "Without reset, started=True blocks duplicate tool-input-start" + ) From eb537edf6ffeabb5ca7d3dd13dde9b1452e92715 Mon Sep 17 00:00:00 2001 From: "Leonard Q. 
Marcq" Date: Sun, 8 Mar 2026 18:08:18 +0800 Subject: [PATCH 15/32] fix: remove autosave_runs hardcoded --- libs/server/kiln_server/run_api.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/libs/server/kiln_server/run_api.py b/libs/server/kiln_server/run_api.py index cd10ec52f..f123d421e 100644 --- a/libs/server/kiln_server/run_api.py +++ b/libs/server/kiln_server/run_api.py @@ -16,7 +16,6 @@ from kiln_ai.datamodel.datamodel_enums import StructuredInputType from kiln_ai.datamodel.task import RunConfigProperties from kiln_ai.datamodel.task_output import DataSource, DataSourceType, TaskOutput -from kiln_ai.utils.config import Config from kiln_ai.utils.dataset_import import ( DatasetFileImporter, DatasetImportFormat, @@ -33,9 +32,6 @@ update_run_lock = Lock() -Config.shared().autosave_runs = True - - def deep_update( source: Dict[str, Any] | None, update: Dict[str, Any | None] ) -> Dict[str, Any]: From 66b31511f2457f26c65f29f243d366542b57e511 Mon Sep 17 00:00:00 2001 From: "Leonard Q. 
Marcq" Date: Sun, 8 Mar 2026 19:20:31 +0800 Subject: [PATCH 16/32] fix: close text when opening toolcall --- .../adapters/model_adapters/stream_events.py | 10 ++++ .../test_litellm_adapter_streaming.py | 59 +++++++++++++++++++ .../model_adapters/test_stream_events.py | 53 +++++++++++++++++ 3 files changed, 122 insertions(+) diff --git a/libs/core/kiln_ai/adapters/model_adapters/stream_events.py b/libs/core/kiln_ai/adapters/model_adapters/stream_events.py index b2784484a..88fc60675 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/stream_events.py +++ b/libs/core/kiln_ai/adapters/model_adapters/stream_events.py @@ -124,6 +124,7 @@ def convert_chunk(self, chunk: ModelResponseStream) -> list[AiSdkStreamEvent]: self._reasoning_started = False if not self._text_started: + self._text_id = f"text-{uuid.uuid4().hex[:12]}" events.append( AiSdkStreamEvent( AiSdkEventType.TEXT_START, @@ -148,6 +149,15 @@ def convert_chunk(self, chunk: ModelResponseStream) -> list[AiSdkStreamEvent]: ) self._reasoning_started = False + if self._text_started: + events.append( + AiSdkStreamEvent( + AiSdkEventType.TEXT_END, + {"id": self._text_id}, + ) + ) + self._text_started = False + for tc_delta in delta.tool_calls: idx = tc_delta.index tc_state = self._tool_calls_state.setdefault( diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py index 65c540376..8b786a673 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py @@ -262,6 +262,65 @@ async def test_invoke_ai_sdk_stream( ) +@pytest.mark.paid +@pytest.mark.parametrize("model_id,provider_name", STREAMING_MODELS) +async def test_ai_sdk_stream_text_ends_before_tool_calls( + request: pytest.FixtureRequest, + model_id: str, + provider_name: ModelProviderName, + adapter_factory: Callable[[str, ModelProviderName], 
LiteLlmAdapter], +): + """Verify text blocks are properly closed before tool-input-start and reopened with a new ID after tool execution.""" + adapter = adapter_factory(model_id, provider_name) + + events: list[AiSdkStreamEvent] = [] + async for event in adapter.invoke_ai_sdk_stream( + input="First tell me you're about to calculate, then compute 11 + 50 and 50 * 85, then add the results. Use the tools for all math." + ): + events.append(event) + + _dump_paid_test_output(request, events=events) + + event_types = [e.type for e in events] + assert AiSdkEventType.TEXT_START in event_types, "Should have TEXT_START" + assert AiSdkEventType.TOOL_INPUT_START in event_types, ( + "Should have TOOL_INPUT_START" + ) + + text_ids_seen: list[str] = [] + text_open = False + for event in events: + if event.type == AiSdkEventType.TEXT_START: + assert not text_open, ( + "text-start emitted while a text block was already open" + ) + text_open = True + text_ids_seen.append(event.payload["id"]) + + elif event.type == AiSdkEventType.TEXT_END: + assert text_open, "text-end emitted without a preceding text-start" + text_open = False + + elif event.type == AiSdkEventType.TEXT_DELTA: + assert text_open, ( + f"text-delta emitted outside an open text block: {event.payload}" + ) + + elif event.type == AiSdkEventType.TOOL_INPUT_START: + assert not text_open, ( + "tool-input-start emitted while text block was still open " + "(missing text-end before tool calls)" + ) + + assert len(text_ids_seen) >= 2, ( + f"Expected at least 2 distinct text blocks (before and after tool calls), " + f"got {len(text_ids_seen)}: {text_ids_seen}" + ) + assert len(set(text_ids_seen)) == len(text_ids_seen), ( + f"Each text block should have a unique ID, got duplicates: {text_ids_seen}" + ) + + @pytest.mark.paid @pytest.mark.parametrize("model_id,provider_name", STREAMING_MODELS_NO_HAIKU) async def test_invoke_openai_stream_non_streaming_still_works( diff --git 
a/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py b/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py index d4f3b3e43..3c945c7e3 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py @@ -248,6 +248,59 @@ def test_tool_input_start_reemitted_after_reset(self): ) assert starts_r2[0].payload["toolCallId"] == "call_r2" + def test_text_ends_before_tool_calls(self): + converter = AiSdkStreamConverter() + events1 = converter.convert_chunk(_make_chunk(content="Hello")) + text_start = next(e for e in events1 if e.type == AiSdkEventType.TEXT_START) + text_id_1 = text_start.payload["id"] + + tc_delta = _make_tool_call_delta( + index=0, call_id="call_1", name="add", arguments='{"a":1}' + ) + events2 = converter.convert_chunk(_make_chunk(tool_calls=[tc_delta])) + types2 = [e.type for e in events2] + assert AiSdkEventType.TEXT_END in types2 + text_end_idx = types2.index(AiSdkEventType.TEXT_END) + tool_start_idx = types2.index(AiSdkEventType.TOOL_INPUT_START) + assert text_end_idx < tool_start_idx, ( + "text-end must come before tool-input-start" + ) + + text_end_event = next( + e for e in events2 if e.type == AiSdkEventType.TEXT_END + ) + assert text_end_event.payload["id"] == text_id_1 + + def test_text_restarts_with_new_id_after_tool_calls(self): + converter = AiSdkStreamConverter() + events1 = converter.convert_chunk(_make_chunk(content="Hello")) + text_start_1 = next(e for e in events1 if e.type == AiSdkEventType.TEXT_START) + text_id_1 = text_start_1.payload["id"] + + tc_delta = _make_tool_call_delta( + index=0, call_id="call_1", name="add", arguments='{"a":1}' + ) + converter.convert_chunk(_make_chunk(tool_calls=[tc_delta])) + converter.reset_for_next_step() + + events3 = converter.convert_chunk(_make_chunk(content="Result")) + types3 = [e.type for e in events3] + assert AiSdkEventType.TEXT_START in types3 + text_start_2 = next(e for e in events3 if 
e.type == AiSdkEventType.TEXT_START) + text_id_2 = text_start_2.payload["id"] + assert text_id_1 != text_id_2, ( + "New text block after tool calls must have a different id" + ) + + def test_no_text_end_before_tool_calls_when_text_not_started(self): + converter = AiSdkStreamConverter() + tc_delta = _make_tool_call_delta( + index=0, call_id="call_1", name="add", arguments='{"a":1}' + ) + events = converter.convert_chunk(_make_chunk(tool_calls=[tc_delta])) + types = [e.type for e in events] + assert AiSdkEventType.TEXT_END not in types + def test_tool_input_start_not_reemitted_without_reset(self): """Without reset, a second tool call at index 0 must NOT re-emit tool-input-start.""" converter = AiSdkStreamConverter() From 446058085db1f809df67215c6576eb866d698de5 Mon Sep 17 00:00:00 2001 From: "Leonard Q. Marcq" Date: Sun, 8 Mar 2026 20:01:16 +0800 Subject: [PATCH 17/32] fix: reset fully --- libs/core/kiln_ai/adapters/model_adapters/stream_events.py | 6 +++++- .../kiln_ai/adapters/model_adapters/test_stream_events.py | 4 +--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/libs/core/kiln_ai/adapters/model_adapters/stream_events.py b/libs/core/kiln_ai/adapters/model_adapters/stream_events.py index 88fc60675..fbe2ff230 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/stream_events.py +++ b/libs/core/kiln_ai/adapters/model_adapters/stream_events.py @@ -95,7 +95,7 @@ def convert_chunk(self, chunk: ModelResponseStream) -> list[AiSdkStreamEvent]: continue reasoning_content = getattr(delta, "reasoning_content", None) - if reasoning_content is not None: + if reasoning_content: if not self._reasoning_started: self._reasoning_block_count += 1 self._reasoning_id = f"reasoning-{uuid.uuid4().hex[:12]}" @@ -298,3 +298,7 @@ def reset_for_next_step(self) -> None: """Reset per-step state between LLM calls in a multi-step flow.""" self._tool_calls_state = {} self._finish_reason = None + self._text_started = False + self._reasoning_started = False + self._text_id = 
f"text-{uuid.uuid4().hex[:12]}" + self._reasoning_id = f"reasoning-{uuid.uuid4().hex[:12]}" diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py b/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py index 3c945c7e3..9df765ba9 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py @@ -266,9 +266,7 @@ def test_text_ends_before_tool_calls(self): "text-end must come before tool-input-start" ) - text_end_event = next( - e for e in events2 if e.type == AiSdkEventType.TEXT_END - ) + text_end_event = next(e for e in events2 if e.type == AiSdkEventType.TEXT_END) assert text_end_event.payload["id"] == text_id_1 def test_text_restarts_with_new_id_after_tool_calls(self): From 89133f6095b919d61bb182b85e8368d37d8d4c26 Mon Sep 17 00:00:00 2001 From: "Leonard Q. Marcq" Date: Sat, 14 Mar 2026 01:10:27 +0800 Subject: [PATCH 18/32] feat: support nesting task runs into each other --- app/web_ui/src/lib/api_schema.d.ts | 10 + libs/core/kiln_ai/datamodel/basemodel.py | 32 ++- libs/core/kiln_ai/datamodel/task.py | 18 ++ libs/core/kiln_ai/datamodel/task_run.py | 104 +++++++- libs/core/kiln_ai/datamodel/test_models.py | 272 +++++++++++++++++++++ 5 files changed, 426 insertions(+), 10 deletions(-) diff --git a/app/web_ui/src/lib/api_schema.d.ts b/app/web_ui/src/lib/api_schema.d.ts index de9e56143..000d7f5cd 100644 --- a/app/web_ui/src/lib/api_schema.d.ts +++ b/app/web_ui/src/lib/api_schema.d.ts @@ -6934,6 +6934,11 @@ export interface components { * * Contains the input used, its source, the output produced, and optional * repair information if the output needed correction. + * + * Can be nested under another TaskRun; nested runs are stored as child runs + * in a "runs" subfolder (same relationship name as Task's runs). + * + * Accepts both Task and TaskRun as parents (polymorphic). 
*/ "TaskRun-Input": { /** @@ -6996,6 +7001,11 @@ export interface components { * * Contains the input used, its source, the output produced, and optional * repair information if the output needed correction. + * + * Can be nested under another TaskRun; nested runs are stored as child runs + * in a "runs" subfolder (same relationship name as Task's runs). + * + * Accepts both Task and TaskRun as parents (polymorphic). */ "TaskRun-Output": { /** diff --git a/libs/core/kiln_ai/datamodel/basemodel.py b/libs/core/kiln_ai/datamodel/basemodel.py index c29f11f7e..b325f3109 100644 --- a/libs/core/kiln_ai/datamodel/basemodel.py +++ b/libs/core/kiln_ai/datamodel/basemodel.py @@ -545,17 +545,38 @@ def relationship_name(cls) -> str: def parent_type(cls) -> Type[KilnBaseModel]: raise NotImplementedError("Parent type must be implemented") - @model_validator(mode="after") - def check_parent_type(self) -> Self: + def _check_parent_type( + self, + expected_parent_types: List[Type[KilnBaseModel]] | None = None, + ) -> Self: cached_parent = self.cached_parent() - if cached_parent is not None: + if cached_parent is None: + return self + + # some models support having multiple parent types, so we allow overriding the expected parent + if expected_parent_types is not None: + if not any( + isinstance(cached_parent, expected_parent_type) + for expected_parent_type in expected_parent_types + ): + raise ValueError( + f"Parent must be one of {expected_parent_types}, but was {type(cached_parent)}" + ) + else: + # default case where we expect a single parent type to be valid expected_parent_type = self.__class__.parent_type() if not isinstance(cached_parent, expected_parent_type): raise ValueError( f"Parent must be of type {expected_parent_type}, but was {type(cached_parent)}" ) + return self + @model_validator(mode="after") + def check_parent_type(self) -> Self: + """Default validation for parent type. 
Can be overridden by subclasses - for example if the parent is polymorphic.""" + return self._check_parent_type() + def build_child_dirname(self) -> Path: # Default implementation for readable folder names. # {id} - {name}/{type}.kiln @@ -602,8 +623,9 @@ def iterate_children_paths_of_parent_path(cls: Type[PT], parent_path: Path | Non else: parent_folder = parent_path - parent = cls.parent_type().load_from_file(parent_path) - if parent is None: + # cannot validate the parent type here because some parentable models are polymorphic + # and can be nested under different types of parent models + if not parent_path.exists(): raise ValueError("Parent must be set to load children") # Ignore type error: this is abstract base class, but children must implement relationship_name diff --git a/libs/core/kiln_ai/datamodel/task.py b/libs/core/kiln_ai/datamodel/task.py index c223164cb..d5ace6449 100644 --- a/libs/core/kiln_ai/datamodel/task.py +++ b/libs/core/kiln_ai/datamodel/task.py @@ -203,3 +203,21 @@ def parent_project(self) -> Union["Project", None]: if self.parent is None or self.parent.__class__.__name__ != "Project": return None return self.parent # type: ignore + + def find_task_run_by_id_dfs( + self, task_run_id: str, readonly: bool = False + ) -> TaskRun | None: + """ + Find a task run by id in the entire task run tree. This is an expensive DFS + traversal of the file system so do not use too willy nilly. + + If you already know the root task run, you can use the same method on + the root TaskRun instead - that will save a bunch of subtree traversals. 
+ """ + stack: List[TaskRun] = list(self.runs(readonly=readonly)) + while stack: + run = stack.pop() + if run.id == task_run_id: + return run + stack.extend(run.runs(readonly=readonly)) + return None diff --git a/libs/core/kiln_ai/datamodel/task_run.py b/libs/core/kiln_ai/datamodel/task_run.py index bef7ae700..048898e6a 100644 --- a/libs/core/kiln_ai/datamodel/task_run.py +++ b/libs/core/kiln_ai/datamodel/task_run.py @@ -1,10 +1,14 @@ import json -from typing import TYPE_CHECKING, Dict, List, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union from pydantic import BaseModel, Field, ValidationInfo, model_validator from typing_extensions import Self -from kiln_ai.datamodel.basemodel import KilnParentedModel +from kiln_ai.datamodel.basemodel import ( + KilnBaseModel, + KilnParentedModel, + KilnParentModel, +) from kiln_ai.datamodel.json_schema import validate_schema_with_value_error from kiln_ai.datamodel.strict_mode import strict_mode from kiln_ai.datamodel.task_output import DataSource, TaskOutput @@ -73,12 +77,17 @@ def _add_optional_float(a: float | None, b: float | None) -> float | None: ) -class TaskRun(KilnParentedModel): +class TaskRun(KilnParentedModel, KilnParentModel, parent_of={}): """ Represents a single execution of a Task. Contains the input used, its source, the output produced, and optional repair information if the output needed correction. + + Can be nested under another TaskRun; nested runs are stored as child runs + in a "runs" subfolder (same relationship name as Task's runs). + + Accepts both Task and TaskRun as parents (polymorphic). """ input: str = Field( @@ -132,9 +141,87 @@ def has_thinking_training_data(self) -> bool: # Workaround to return typed parent without importing Task def parent_task(self) -> Union["Task", None]: - if self.parent is None or self.parent.__class__.__name__ != "Task": + """The Task that this Run is in. 
Note the TaskRun may be nested in which case we walk back up the tree all the way to the root.""" + if self.parent is None: + return None + + # this task run is already the root task run + if self.parent.__class__.__name__ == "Task": + return self.parent # type: ignore + + # this task run is nested under other ones, so we walk back + # up to the root task run + parent_run = self.cached_parent() + if isinstance(parent_run, TaskRun): + return parent_run.parent_task() + + return None + + def parent_run(self) -> "TaskRun | None": + """The TaskRun that contains this run, if this run is nested; otherwise None.""" + parent = self.cached_parent() + if parent is None or not isinstance(parent, TaskRun): + return None + return parent + + def runs(self, readonly: bool = False) -> list["TaskRun"]: + """The list of child task runs.""" + return super().runs(readonly=readonly) # type: ignore + + def is_root_task_run(self) -> bool: + """Is this the root task run? (not nested under another task run)""" + return self.parent is None or self.parent.__class__.__name__ == "Task" + + def find_task_run_by_id_dfs( + self, task_run_id: str, readonly: bool = False + ) -> "TaskRun | None": + """ + Find a task run by id in the entire task run tree. This is an expensive DFS + traversal of the file system so do not use too willy nilly. 
+ """ + stack: List[TaskRun] = list(self.runs(readonly=readonly)) + while stack: + run = stack.pop() + if run.id == task_run_id: + return run + stack.extend(run.runs(readonly=readonly)) + return None + + def load_parent(self) -> Optional[KilnBaseModel]: + """Load the parent of this task run - this is an override of the default parent loading logic to support nested task runs.""" + cached = self.cached_parent() + if cached is not None: + return cached + if self.path is None: return None - return self.parent # type: ignore + parent_dir = self.path.parent.parent.parent + task_run_path = parent_dir / TaskRun.base_filename() + if task_run_path.exists() and task_run_path != self.path: + try: + loaded_parent_run = TaskRun.load_from_file(task_run_path) + super().__setattr__("parent", loaded_parent_run) + return loaded_parent_run + except ValueError: + pass + + from kiln_ai.datamodel.task import Task + + task_path = parent_dir / Task.base_filename() + if task_path.exists(): + loaded_parent_task = Task.load_from_file(task_path) + super().__setattr__("parent", loaded_parent_task) + return loaded_parent_task + + return None + + @model_validator(mode="after") + def check_parent_type(self) -> Self: + """Check that the parent is a Task or TaskRun. This overrides the default parent type check + that only supports a single parent type.""" + # need to import here to avoid circular imports + from kiln_ai.datamodel.task import Task + + return self._check_parent_type([Task, TaskRun]) @model_validator(mode="after") def validate_input_format(self, info: ValidationInfo) -> Self: @@ -258,3 +345,10 @@ def validate_tags(self) -> Self: raise ValueError("Tags cannot contain spaces. 
Try underscores.") return self + + +# cannot do this in the class definition due to circular reference between TaskRun and itself: +# wire up TaskRun as its own child type so .runs() returns TaskRun instances +# this makes TaskRun polymorphic - can be parented under Task or another TaskRun +TaskRun._parent_of["runs"] = TaskRun +TaskRun._create_child_method("runs", TaskRun) diff --git a/libs/core/kiln_ai/datamodel/test_models.py b/libs/core/kiln_ai/datamodel/test_models.py index 68d0f611e..8da052eb6 100644 --- a/libs/core/kiln_ai/datamodel/test_models.py +++ b/libs/core/kiln_ai/datamodel/test_models.py @@ -1,5 +1,6 @@ import json import os +from typing import Union from unittest.mock import patch import pytest @@ -743,3 +744,274 @@ def test_generate_model_id(): # check it is a valid name - as we typically use model ids in filenames on FS validator = name_validator(min_length=1, max_length=12) validator(model_id) + + +# project and task fixture +@pytest.fixture +def task(tmp_path): + project_path = tmp_path / "project.kiln" + project = Project(name="P", path=project_path) + project.save_to_file() + task = Task(name="T", instruction="Do it", parent=project) + task.save_to_file() + return task + + +def test_nested_task_run_folder_structure(task: Task): + output = TaskOutput(output="out") + + parent_run = TaskRun(input="in", output=output, parent=task) + parent_run.save_to_file() + + nested_run = TaskRun(input="nested in", output=output, parent=parent_run) + nested_run.save_to_file() + + assert task.path is not None + assert parent_run.path is not None + assert nested_run.path is not None + task_dir = task.path.parent + runs_dir = task_dir / "runs" + parent_run_dir = runs_dir / parent_run.build_child_dirname() + nested_runs_dir = parent_run_dir / "runs" + + assert runs_dir.is_dir() + assert (parent_run_dir / "task_run.kiln").is_file() + assert nested_runs_dir.is_dir() + assert nested_run.path.is_file() + assert nested_run.path.parent.parent.parent == parent_run_dir + 
assert nested_run.path.name == TaskRun.base_filename() + + assert parent_run.parent_task() == task + assert parent_run.parent_run() is None + assert nested_run.parent_task() == task + assert nested_run.parent_run() == parent_run + assert len(parent_run.runs()) == 1 + assert parent_run.runs()[0].id == nested_run.id + + +def test_nested_task_runs_multiple_levels(task: Task): + output = TaskOutput(output="out") + + run1 = TaskRun(input="in1", output=output, parent=task) + run1.save_to_file() + run2 = TaskRun(input="in2", output=output, parent=run1) + run2.save_to_file() + run3 = TaskRun(input="in3", output=output, parent=run2) + run3.save_to_file() + + assert run1.parent_run() is None + assert run1.parent_task() == task + assert len(run1.runs()) == 1 + assert run1.runs()[0].id == run2.id + + assert run2.parent_run() == run1 + assert run2.parent_task() == task + assert len(run2.runs()) == 1 + assert run2.runs()[0].id == run3.id + + assert run3.parent_run() == run2 + assert run3.parent_task() == task + assert len(run3.runs()) == 0 + + +def test_parent_task_deeply_nested_task_run(task: Task): + output = TaskOutput(output="out") + run1 = TaskRun(input="in1", output=output, parent=task) + run1.save_to_file() + run2 = TaskRun(input="in2", output=output, parent=run1) + run2.save_to_file() + run3 = TaskRun(input="in3", output=output, parent=run2) + run3.save_to_file() + + assert run3.parent_task() == task + + +def test_find_nested_task_run_by_id_given_parent_run(task: Task): + assert task.path is not None + + output = TaskOutput(output="out") + parent_run = TaskRun(input="in", output=output, parent=task) + parent_run.save_to_file() + nested_run = TaskRun(input="nested in", output=output, parent=parent_run) + nested_run.save_to_file() + target_id = nested_run.id + + loaded_task = Task.load_from_file(task.path) + loaded_parent = next(r for r in loaded_task.runs() if r.id == parent_run.id) + found = next(r for r in loaded_parent.runs() if r.id == target_id) + assert found is not 
None + assert found.id == target_id + assert found.input == "nested in" + + +def test_find_root_task_run_by_id_given_task(task: Task): + output = TaskOutput(output="out") + root_run = TaskRun(input="in", output=output, parent=task) + root_run.save_to_file() + target_id = root_run.id + + assert task.path is not None + loaded_task = Task.load_from_file(task.path) + found = next(r for r in loaded_task.runs() if r.id == target_id) + assert found is not None + assert found.id == target_id + assert found.input == "in" + + +def test_find_deeply_nested_find_task_run_by_id_dfs(task: Task): + """Find a run by id when we only have the task; run may be nested several levels. + No built-in helper exists for this; test uses a local recursive search. + A Task.run_by_id(id) or similar could be added to the datamodel if needed.""" + assert task.path is not None + output = TaskOutput(output="out") + + run1 = TaskRun(input="in1", output=output, parent=task) + run1.save_to_file() + + run2 = TaskRun(input="in2", output=output, parent=run1) + run2.save_to_file() + + run3 = TaskRun(input="in3", output=output, parent=run2) + run3.save_to_file() + + # this is the task run we will be searching for + assert run3.id is not None + + def find_run_by_id( + parent: Union[Task, TaskRun], task_run_id_to_find: str + ) -> Union[TaskRun, None]: + for r in parent.runs(): + if r.id == task_run_id_to_find: + return r + + # not ideal that we recurse over FS - might be costly, will need to use this + # lightly + found = find_run_by_id(r, task_run_id_to_find) + if found is not None: + return found + + return None + + loaded_task = Task.load_from_file(task.path) + found = find_run_by_id(loaded_task, run3.id) + assert found is not None + assert found.id == run3.id + assert found.input == "in3" + + +def test_find_task_run_by_id_dfs_finds_deeply_nested_run(task: Task): + """Task.find_task_run_by_id_dfs finds a run nested several levels (iterative stack-based DFS).""" + assert task.path is not None + output = 
TaskOutput(output="out") + + run1 = TaskRun(input="in1", output=output, parent=task) + run1.save_to_file() + + run2 = TaskRun(input="in2", output=output, parent=run1) + run2.save_to_file() + + run3 = TaskRun(input="in3", output=output, parent=run2) + run3.save_to_file() + + assert run3.id is not None + + loaded_task = Task.load_from_file(task.path) + found = loaded_task.find_task_run_by_id_dfs(run3.id) + assert found is not None + assert found.id == run3.id + assert found.input == "in3" + + +def test_find_task_run_by_id_dfs_finds_root_run(task: Task): + """Task.find_task_run_by_id_dfs finds the root run when it is the only run.""" + assert task.path is not None + output = TaskOutput(output="out") + root_run = TaskRun(input="root in", output=output, parent=task) + root_run.save_to_file() + assert root_run.id is not None + + loaded_task = Task.load_from_file(task.path) + found = loaded_task.find_task_run_by_id_dfs(root_run.id) + assert found is not None + assert found.id == root_run.id + assert found.input == "root in" + + +def test_find_task_run_by_id_dfs_returns_none_when_not_found(task: Task): + """Task.find_task_run_by_id_dfs returns None when no run has the given id.""" + assert task.path is not None + output = TaskOutput(output="out") + TaskRun(input="in", output=output, parent=task).save_to_file() + + loaded_task = Task.load_from_file(task.path) + found = loaded_task.find_task_run_by_id_dfs("nonexistent-id") + assert found is None + + +def test_task_run_find_task_run_by_id_dfs_finds_deeply_nested_run(task: Task): + """TaskRun.find_task_run_by_id_dfs finds a run nested several levels in its subtree.""" + assert task.path is not None + output = TaskOutput(output="out") + + run1 = TaskRun(input="in1", output=output, parent=task) + run1.save_to_file() + + run2 = TaskRun(input="in2", output=output, parent=run1) + run2.save_to_file() + + run3 = TaskRun(input="in3", output=output, parent=run2) + run3.save_to_file() + + assert run3.id is not None + + assert run1.path 
is not None + loaded_run1 = TaskRun.load_from_file(run1.path) + found = loaded_run1.find_task_run_by_id_dfs(run3.id) + assert found is not None + assert found.id == run3.id + assert found.input == "in3" + + +def test_task_run_find_task_run_by_id_dfs_finds_direct_child(task: Task): + """TaskRun.find_task_run_by_id_dfs finds a direct child run.""" + assert task.path is not None + output = TaskOutput(output="out") + + run1 = TaskRun(input="in1", output=output, parent=task) + run1.save_to_file() + + run2 = TaskRun(input="in2", output=output, parent=run1) + run2.save_to_file() + + assert run2.id is not None + + loaded_run1 = TaskRun.load_from_file(run1.path) + found = loaded_run1.find_task_run_by_id_dfs(run2.id) + assert found is not None + assert found.id == run2.id + assert found.input == "in2" + + +def test_task_run_find_task_run_by_id_dfs_returns_none_when_not_found(task: Task): + """TaskRun.find_task_run_by_id_dfs returns None when no run in its subtree has the given id.""" + assert task.path is not None + output = TaskOutput(output="out") + + run1 = TaskRun(input="in1", output=output, parent=task) + run1.save_to_file() + + TaskRun(input="in2", output=output, parent=run1).save_to_file() + + loaded_run1 = TaskRun.load_from_file(run1.path) + found = loaded_run1.find_task_run_by_id_dfs("nonexistent-id") + assert found is None + + +def test_is_root_task_run(task: Task): + output = TaskOutput(output="out") + root_run = TaskRun(input="in", output=output, parent=task) + root_run.save_to_file() + assert root_run.is_root_task_run() + nested_run = TaskRun(input="nested in", output=output, parent=root_run) + nested_run.save_to_file() + assert not nested_run.is_root_task_run() From dc74f47618ba4392fd0ed60cba27fe3268bc5f17 Mon Sep 17 00:00:00 2001 From: "Leonard Q. 
Marcq" Date: Sat, 14 Mar 2026 01:29:22 +0800 Subject: [PATCH 19/32] chore: remove redundant test --- libs/core/kiln_ai/datamodel/test_models.py | 41 ---------------------- 1 file changed, 41 deletions(-) diff --git a/libs/core/kiln_ai/datamodel/test_models.py b/libs/core/kiln_ai/datamodel/test_models.py index 8da052eb6..8ea0cd273 100644 --- a/libs/core/kiln_ai/datamodel/test_models.py +++ b/libs/core/kiln_ai/datamodel/test_models.py @@ -858,47 +858,6 @@ def test_find_root_task_run_by_id_given_task(task: Task): assert found.input == "in" -def test_find_deeply_nested_find_task_run_by_id_dfs(task: Task): - """Find a run by id when we only have the task; run may be nested several levels. - No built-in helper exists for this; test uses a local recursive search. - A Task.run_by_id(id) or similar could be added to the datamodel if needed.""" - assert task.path is not None - output = TaskOutput(output="out") - - run1 = TaskRun(input="in1", output=output, parent=task) - run1.save_to_file() - - run2 = TaskRun(input="in2", output=output, parent=run1) - run2.save_to_file() - - run3 = TaskRun(input="in3", output=output, parent=run2) - run3.save_to_file() - - # this is the task run we will be searching for - assert run3.id is not None - - def find_run_by_id( - parent: Union[Task, TaskRun], task_run_id_to_find: str - ) -> Union[TaskRun, None]: - for r in parent.runs(): - if r.id == task_run_id_to_find: - return r - - # not ideal that we recurse over FS - might be costly, will need to use this - # lightly - found = find_run_by_id(r, task_run_id_to_find) - if found is not None: - return found - - return None - - loaded_task = Task.load_from_file(task.path) - found = find_run_by_id(loaded_task, run3.id) - assert found is not None - assert found.id == run3.id - assert found.input == "in3" - - def test_find_task_run_by_id_dfs_finds_deeply_nested_run(task: Task): """Task.find_task_run_by_id_dfs finds a run nested several levels (iterative stack-based DFS).""" assert task.path is not 
None From 79d3806b0185337c926cb3e9e55fd9df5b2d114d Mon Sep 17 00:00:00 2001 From: "Leonard Q. Marcq" Date: Sat, 14 Mar 2026 01:35:13 +0800 Subject: [PATCH 20/32] chore: lint unused import --- libs/core/kiln_ai/datamodel/test_models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/core/kiln_ai/datamodel/test_models.py b/libs/core/kiln_ai/datamodel/test_models.py index 8ea0cd273..ea2b51be2 100644 --- a/libs/core/kiln_ai/datamodel/test_models.py +++ b/libs/core/kiln_ai/datamodel/test_models.py @@ -1,6 +1,5 @@ import json import os -from typing import Union from unittest.mock import patch import pytest From 8aa5da1dbc9d0b1891164fe724b6a60109d90425 Mon Sep 17 00:00:00 2001 From: "Leonard Q. Marcq" Date: Tue, 17 Mar 2026 00:17:58 +0800 Subject: [PATCH 21/32] refactor: allow for subclasses to declare parent types --- libs/core/kiln_ai/datamodel/basemodel.py | 53 ++++- libs/core/kiln_ai/datamodel/task_run.py | 9 +- libs/core/kiln_ai/datamodel/test_basemodel.py | 199 +++++++++++++++++- libs/core/kiln_ai/datamodel/test_task.py | 176 +++++++++++++++- 4 files changed, 432 insertions(+), 5 deletions(-) diff --git a/libs/core/kiln_ai/datamodel/basemodel.py b/libs/core/kiln_ai/datamodel/basemodel.py index b325f3109..14e808671 100644 --- a/libs/core/kiln_ai/datamodel/basemodel.py +++ b/libs/core/kiln_ai/datamodel/basemodel.py @@ -545,6 +545,17 @@ def relationship_name(cls) -> str: def parent_type(cls) -> Type[KilnBaseModel]: raise NotImplementedError("Parent type must be implemented") + @classmethod + def _parent_types(cls) -> List[Type["KilnBaseModel"]] | None: + """Return accepted parent types. This must be implemented by the subclass if + the model can have multiple parent types. + + Return None (default) to use the single parent_type() check. + Override and return a list of parent types for models that can be nested + under more than one parent type (e.g. TaskRun can be nested under Task or TaskRun). 
+ """ + return None + def _check_parent_type( self, expected_parent_types: List[Type[KilnBaseModel]] | None = None, @@ -623,11 +634,49 @@ def iterate_children_paths_of_parent_path(cls: Type[PT], parent_path: Path | Non else: parent_folder = parent_path - # cannot validate the parent type here because some parentable models are polymorphic - # and can be nested under different types of parent models if not parent_path.exists(): raise ValueError("Parent must be set to load children") + # Validate the parent file's declared type so we fail fast when the caller + # passes a wrong path. For polymorphic children (e.g. TaskRun) the + # subclass overrides _accepted_parent_types() to broaden the check to all + # accepted parent types + parent_types_override = cls._parent_types() + if parent_types_override is None: + # Default: single expected parent type — original behaviour + parent = cls.parent_type().load_from_file(parent_path) + if parent is None: + raise ValueError("Parent must be set to load children") + else: + # Polymorphic parent: read only the model_type field to avoid a full load. 
+ with open(parent_path, "r", encoding="utf-8") as fh: + actual_parent_type_name = json.loads(fh.read()).get("model_type", "") + parent_type_names = {t.type_name() for t in parent_types_override} + if actual_parent_type_name not in parent_type_names: + raise ValueError( + f"Parent model_type '{actual_parent_type_name}' is not one of " + f"{parent_type_names}" + ) + else: + # find the parent type that matches the actual parent type name we found on disk + # and validate it + parent_type = next( + ( + t + for t in parent_types_override + if t.type_name() == actual_parent_type_name + ), + None, + ) + if parent_type is None: + raise ValueError( + f"Could not find parent type '{actual_parent_type_name}' in " + f"{parent_types_override}" + ) + parent = parent_type.load_from_file(parent_path) + if parent is None: + raise ValueError("Parent must be set to load children") + # Ignore type error: this is abstract base class, but children must implement relationship_name relationship_folder = parent_folder / Path(cls.relationship_name()) # type: ignore diff --git a/libs/core/kiln_ai/datamodel/task_run.py b/libs/core/kiln_ai/datamodel/task_run.py index 048898e6a..c6594ed0b 100644 --- a/libs/core/kiln_ai/datamodel/task_run.py +++ b/libs/core/kiln_ai/datamodel/task_run.py @@ -1,5 +1,5 @@ import json -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union from pydantic import BaseModel, Field, ValidationInfo, model_validator from typing_extensions import Self @@ -164,6 +164,13 @@ def parent_run(self) -> "TaskRun | None": return None return parent + @classmethod + def _parent_types(cls) -> List[Type["KilnBaseModel"]]: + # lazy import to avoid circular dependency + from kiln_ai.datamodel.task import Task + + return [Task, TaskRun] + def runs(self, readonly: bool = False) -> list["TaskRun"]: """The list of child task runs.""" return super().runs(readonly=readonly) # type: ignore diff --git 
a/libs/core/kiln_ai/datamodel/test_basemodel.py b/libs/core/kiln_ai/datamodel/test_basemodel.py index 897a3749f..be4bc8c45 100644 --- a/libs/core/kiln_ai/datamodel/test_basemodel.py +++ b/libs/core/kiln_ai/datamodel/test_basemodel.py @@ -11,7 +11,7 @@ from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter from kiln_ai.adapters.run_output import RunOutput -from kiln_ai.datamodel import Task, TaskRun +from kiln_ai.datamodel import Project, Task, TaskOutput, TaskRun from kiln_ai.datamodel.basemodel import ( MAX_FILENAME_LENGTH, KilnBaseModel, @@ -1166,3 +1166,200 @@ def test_readonly_cache_integration(tmp_model_cache, tmp_path): cached_readonly = ReadonlyTestModel.load_from_file(test_file, readonly=True) assert cached_readonly._readonly is True assert cached_readonly.name == "cached_model" + + +# ============================================================================ +# Tests for _parent_types() and parent type validation in _load_parent_and_validate_children +# ============================================================================ + + +def test_default_parent_types_returns_none(): + """Default _parent_types() returns None for models with single parent type.""" + assert DefaultParentedModel._parent_types() is None + assert NamedParentedModel._parent_types() is None + + +def test_taskrun_parent_types_returns_task_and_taskrun(): + """TaskRun._parent_types() returns [Task, TaskRun] for polymorphic parent support.""" + + parent_types = TaskRun._parent_types() + assert parent_types is not None + assert len(parent_types) == 2 + + parent_type_names = {t.type_name() for t in parent_types} + assert "task" in parent_type_names + assert "task_run" in parent_type_names + + +def test_invalid_parent_with_single_parent_type(tmp_path): + """Loading children fails when parent path points to wrong model type (single parent case).""" + # Create a project (wrong parent type) + project_path = tmp_path / "project.kiln" + project = Project(name="Test Project", 
path=project_path) + project.save_to_file() + + # Try to load DefaultParentedModel children from a Project path + # DefaultParentedModel expects BaseParentExample as parent, not Project + # The error occurs when trying to load the parent as the wrong type + with pytest.raises( + ValueError, match="Cannot load from file because the model type is incorrect" + ): + list(DefaultParentedModel.iterate_children_paths_of_parent_path(project_path)) + + +def test_invalid_parent_with_multiple_parent_types(tmp_path): + """Loading children fails when parent's model_type is not in accepted polymorphic types.""" + # Create a project (not an accepted parent type for TaskRun) + project_path = tmp_path / "project.kiln" + project = Project(name="Test Project", path=project_path) + project.save_to_file() + + # Try to load TaskRun children from a Project path + # TaskRun accepts Task and TaskRun, not Project + with pytest.raises(ValueError, match="Parent model_type 'project' is not one of"): + list(TaskRun.iterate_children_paths_of_parent_path(project_path)) + + +def test_valid_parent_type_single_parent(tmp_path): + """Successfully loads children when parent type matches single expected type.""" + parent = BaseParentExample(path=tmp_path / BaseParentExample.base_filename()) + parent.save_to_file() + + child = DefaultParentedModel(parent=parent, name="Test Child") + child.save_to_file() + + # Load children - should succeed since parent is correct type + children = list( + DefaultParentedModel.iterate_children_paths_of_parent_path(parent.path) + ) + assert len(children) == 1 + assert children[0] == child.path + + +def test_valid_parent_type_polymorphic_taskrun_as_parent(tmp_path): + """TaskRun can be loaded with TaskRun as parent (polymorphic case).""" + output = TaskOutput(output="test output") + + # Create a task as the ultimate parent + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + # Create a parent 
TaskRun + parent_run = TaskRun(input="parent input", output=output, parent=task) + parent_run.save_to_file() + + # Create a nested TaskRun + nested_run = TaskRun(input="nested input", output=output, parent=parent_run) + nested_run.save_to_file() + + # Load children of parent_run - should succeed + children = list(TaskRun.iterate_children_paths_of_parent_path(parent_run.path)) + assert len(children) == 1 + assert children[0] == nested_run.path + + +def test_valid_parent_type_polymorphic_task_as_parent(tmp_path): + """TaskRun can be loaded with Task as parent (polymorphic case).""" + output = TaskOutput(output="test output") + + # Create a task + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + # Create a TaskRun under the task + task_run = TaskRun(input="test input", output=output, parent=task) + task_run.save_to_file() + + # Load children of task - should succeed + children = list(TaskRun.iterate_children_paths_of_parent_path(task.path)) + assert len(children) == 1 + assert children[0] == task_run.path + + +def test_invalid_parent_type_name_mismatch_polymorphic(tmp_path): + """Polymorphic validation fails when parent file has wrong model_type.""" + # Create a file with wrong model_type + wrong_parent_path = tmp_path / "wrong_parent.kiln" + wrong_data = { + "v": 1, + "name": "Wrong Parent", + "model_type": "project", # Wrong type - not in accepted types + } + with open(wrong_parent_path, "w") as f: + json.dump(wrong_data, f) + + # Try to load TaskRun children from wrong parent + # TaskRun accepts Task and TaskRun, not Project + with pytest.raises(ValueError, match="Parent model_type 'project' is not one of"): + list(TaskRun.iterate_children_paths_of_parent_path(wrong_parent_path)) + + +def test_parent_loading_single_parent_type_fails_on_corrupt_file(tmp_path): + """Single parent type loading fails when parent file is corrupt/invalid.""" + # Create a corrupt parent file + 
corrupt_parent_path = tmp_path / "corrupt.kiln" + with open(corrupt_parent_path, "w") as f: + f.write("not valid json {{{") + + # The load_from_file call within iterate_children_paths_of_parent_path + # will fail when trying to parse the corrupt JSON + with pytest.raises(ValueError, match="Expecting value"): + list( + DefaultParentedModel.iterate_children_paths_of_parent_path( + corrupt_parent_path + ) + ) + + +def test_parent_loading_polymorphic_fails_on_corrupt_file(tmp_path): + """Polymorphic parent loading fails when parent file is corrupt/invalid.""" + # Create a corrupt parent file + corrupt_parent_path = tmp_path / "corrupt.kiln" + with open(corrupt_parent_path, "w") as f: + f.write("not valid json {{{") + + # The polymorphic path first reads model_type, which will fail on corrupt JSON + with pytest.raises(json.JSONDecodeError): + list(TaskRun.iterate_children_paths_of_parent_path(corrupt_parent_path)) + + +def test_parent_loading_single_parent_nonexistent_file(tmp_path): + """Single parent type loading fails when parent file doesn't exist.""" + nonexistent_path = tmp_path / "nonexistent.kiln" + + with pytest.raises(ValueError, match="Parent must be set to load children"): + list( + DefaultParentedModel.iterate_children_paths_of_parent_path(nonexistent_path) + ) + + +def test_parent_loading_polymorphic_nonexistent_file(tmp_path): + """Polymorphic parent loading fails when parent file doesn't exist.""" + nonexistent_path = tmp_path / "nonexistent.kiln" + + with pytest.raises(ValueError, match="Parent must be set to load children"): + list(TaskRun.iterate_children_paths_of_parent_path(nonexistent_path)) + + +def test_all_children_of_parent_path_single_parent_type(tmp_path): + """all_children_of_parent_path works correctly for single parent type models.""" + parent = BaseParentExample(path=tmp_path / "parent.kiln") + parent.save_to_file() + + child1 = DefaultParentedModel(parent=parent, name="Child1") + child2 = DefaultParentedModel(parent=parent, 
name="Child2") + child1.save_to_file() + child2.save_to_file() + + children = DefaultParentedModel.all_children_of_parent_path(parent.path) + assert len(children) == 2 + names = {child.name for child in children} + assert names == {"Child1", "Child2"} diff --git a/libs/core/kiln_ai/datamodel/test_task.py b/libs/core/kiln_ai/datamodel/test_task.py index 1621db603..2d819323b 100644 --- a/libs/core/kiln_ai/datamodel/test_task.py +++ b/libs/core/kiln_ai/datamodel/test_task.py @@ -1,3 +1,5 @@ +import json + import pytest from pydantic import ValidationError @@ -6,6 +8,7 @@ StructuredOutputMode, TaskOutputRatingType, ) +from kiln_ai.datamodel.project import Project from kiln_ai.datamodel.prompt_id import PromptGenerators from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties from kiln_ai.datamodel.spec import Spec @@ -15,7 +18,8 @@ ToxicityProperties, ) from kiln_ai.datamodel.task import Task, TaskRunConfig -from kiln_ai.datamodel.task_output import normalize_rating +from kiln_ai.datamodel.task_output import TaskOutput, normalize_rating +from kiln_ai.datamodel.task_run import TaskRun def test_runconfig_valid_creation(): @@ -457,3 +461,173 @@ def test_task_prompt_optimization_jobs_readonly(tmp_path): assert ( prompt_optimization_jobs_default[0].name == "Readonly Prompt Optimization Job" ) + + +def test_all_children_of_parent_path_polymorphic(tmp_path): + """all_children_of_parent_path works correctly for polymorphic parent models.""" + # Test with TaskRun and Task + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + output = TaskOutput(output="test output") + + # Create direct children of Task + run1 = TaskRun(input="input1", output=output, parent=task) + run2 = TaskRun(input="input2", output=output, parent=task) + run1.save_to_file() + run2.save_to_file() + + children = TaskRun.all_children_of_parent_path(task.path) + assert len(children) == 2 + inputs = {child.input for child in 
children} + assert inputs == {"input1", "input2"} + + +def test_taskrun_nested_validates_parent_type_on_load(tmp_path): + """Loading a TaskRun validates its parent type is Task or TaskRun.""" + output = TaskOutput(output="test output") + + # Create a task + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + # Create parent TaskRun + parent_run = TaskRun(input="parent input", output=output, parent=task) + parent_run.save_to_file() + + # Create nested TaskRun + nested_run = TaskRun(input="nested input", output=output, parent=parent_run) + nested_run.save_to_file() + + # Reload from disk - parent type should be validated + # When loading from disk, the parent attribute points to the ultimate parent (Task) + # Use load_parent() to get the direct parent (TaskRun) + loaded_run = TaskRun.load_from_file(nested_run.path) + assert loaded_run is not None + assert loaded_run.input == "nested input" + # parent_task() returns the ultimate parent task (different instance but same data) + loaded_parent_task = loaded_run.parent_task() + assert loaded_parent_task is not None + assert loaded_parent_task.name == "Test Task" + assert loaded_parent_task.instruction == "Test instruction" + # Use load_parent() to get the direct parent TaskRun + direct_parent = loaded_run.load_parent() + assert direct_parent is not None + assert direct_parent.id == parent_run.id + assert direct_parent.input == "parent input" + + +def test_taskrun_loads_from_task_path(tmp_path): + """TaskRun children can be loaded from Task path (valid polymorphic parent).""" + output = TaskOutput(output="test output") + + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + run = TaskRun(input="test input", output=output, parent=task) + run.save_to_file() + + # Load children from task path - should succeed + children = 
list(TaskRun.iterate_children_paths_of_parent_path(task.path)) + assert len(children) == 1 + assert children[0] == run.path + + +def test_taskrun_loads_from_taskrun_path(tmp_path): + """TaskRun children can be loaded from TaskRun path (valid polymorphic parent).""" + output = TaskOutput(output="test output") + + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + parent_run = TaskRun(input="parent input", output=output, parent=task) + parent_run.save_to_file() + + nested_run = TaskRun(input="nested input", output=output, parent=parent_run) + nested_run.save_to_file() + + # Load children from TaskRun path - should succeed + children = list(TaskRun.iterate_children_paths_of_parent_path(parent_run.path)) + assert len(children) == 1 + assert children[0] == nested_run.path + + +def test_taskrun_fails_to_load_from_project_path(tmp_path): + """TaskRun children cannot be loaded from Project path (invalid polymorphic parent).""" + project_path = tmp_path / "project.kiln" + project = Project(name="Test Project", path=project_path) + project.save_to_file() + + # Try to load TaskRun children from a Project path - should fail + with pytest.raises(ValueError, match="Parent model_type 'project' is not one of"): + list(TaskRun.iterate_children_paths_of_parent_path(project_path)) + + +def test_multiple_nested_levels_validates_each_level(tmp_path): + """Multi-level nesting validates parent type at each level.""" + output = TaskOutput(output="test output") + + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + run1 = TaskRun(input="input1", output=output, parent=task) + run1.save_to_file() + + run2 = TaskRun(input="input2", output=output, parent=run1) + run2.save_to_file() + + run3 = TaskRun(input="input3", output=output, parent=run2) + run3.save_to_file() + + # Load children at each level - all should succeed + task_children = 
TaskRun.all_children_of_parent_path(task.path) + assert len(task_children) == 1 + assert task_children[0].id == run1.id + + run1_children = TaskRun.all_children_of_parent_path(run1.path) + assert len(run1_children) == 1 + assert run1_children[0].id == run2.id + + run2_children = TaskRun.all_children_of_parent_path(run2.path) + assert len(run2_children) == 1 + assert run2_children[0].id == run3.id + + +def test_polymorphic_parent_type_validation_fast_fail(tmp_path): + """Polymorphic validation fails fast without loading entire parent model.""" + # Create a file that's syntactically valid JSON but semantically invalid + # for the parent type - the polymorphic path should only read model_type + invalid_parent_path = tmp_path / "invalid.kiln" + + # Write a file that would fail if we tried to fully load as Task/TaskRun + # but should be caught by the model_type check first + invalid_data = { + "model_type": "project", # Wrong type - not Task or TaskRun + "extra_field": "this would cause issues", + } + with open(invalid_parent_path, "w") as f: + json.dump(invalid_data, f) + + # Should fail on model_type check, not on full load + with pytest.raises(ValueError, match="Parent model_type 'project' is not one of"): + list(TaskRun.iterate_children_paths_of_parent_path(invalid_parent_path)) From 27e039307da5b17aa17434445a93f79b1eed5262 Mon Sep 17 00:00:00 2001 From: "Leonard Q. 
Marcq" Date: Tue, 17 Mar 2026 00:35:01 +0800 Subject: [PATCH 22/32] fix: throw on task run parent error --- libs/core/kiln_ai/datamodel/task_run.py | 7 +- libs/core/kiln_ai/datamodel/test_task.py | 134 +++++++++++++++++++++++ 2 files changed, 139 insertions(+), 2 deletions(-) diff --git a/libs/core/kiln_ai/datamodel/task_run.py b/libs/core/kiln_ai/datamodel/task_run.py index c6594ed0b..f3f1361ab 100644 --- a/libs/core/kiln_ai/datamodel/task_run.py +++ b/libs/core/kiln_ai/datamodel/task_run.py @@ -208,8 +208,11 @@ def load_parent(self) -> Optional[KilnBaseModel]: loaded_parent_run = TaskRun.load_from_file(task_run_path) super().__setattr__("parent", loaded_parent_run) return loaded_parent_run - except ValueError: - pass + except ValueError as e: + raise ValueError( + f"Failed to load parent TaskRun from {task_run_path}. " + f"This indicates a malformed nested task run. Error: {e}" + ) from e from kiln_ai.datamodel.task import Task diff --git a/libs/core/kiln_ai/datamodel/test_task.py b/libs/core/kiln_ai/datamodel/test_task.py index 2d819323b..f3921f9b8 100644 --- a/libs/core/kiln_ai/datamodel/test_task.py +++ b/libs/core/kiln_ai/datamodel/test_task.py @@ -631,3 +631,137 @@ def test_polymorphic_parent_type_validation_fast_fail(tmp_path): # Should fail on model_type check, not on full load with pytest.raises(ValueError, match="Parent model_type 'project' is not one of"): list(TaskRun.iterate_children_paths_of_parent_path(invalid_parent_path)) + + +def test_load_parent_raises_on_malformed_taskrun(tmp_path): + """load_parent raises ValueError with context when parent TaskRun is malformed.""" + output = TaskOutput(output="test output") + + # Create a task + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + # Create parent TaskRun directory and valid run + parent_run_dir = tmp_path / "runs" / "parent_run" + parent_run_dir.mkdir(parents=True) + parent_run = TaskRun( + input="parent input", + 
output=output, + path=parent_run_dir / TaskRun.base_filename(), + ) + parent_run.save_to_file() + + # Create nested TaskRun directory and run (before corrupting parent) + nested_run_dir = parent_run_dir / "runs" / "nested_run" + nested_run_dir.mkdir(parents=True) + nested_run = TaskRun( + input="nested input", + output=output, + path=nested_run_dir / TaskRun.base_filename(), + ) + nested_run.save_to_file() + + # Verify it loads correctly with valid parent + loaded_nested = TaskRun.load_from_file(nested_run.path) + loaded_parent = loaded_nested.load_parent() + assert loaded_parent is not None + assert loaded_parent.input == "parent input" + + # Now corrupt the parent TaskRun file + with open(parent_run.path, "w") as f: + json.dump({"model_type": "task_run", "input": 123}, f) # Invalid input type + + # Reload nested run and try to load parent - should raise with context + loaded_nested = TaskRun.load_from_file(nested_run.path) + with pytest.raises(ValueError) as exc_info: + loaded_nested.load_parent() + + error_msg = str(exc_info.value) + assert "Failed to load parent TaskRun" in error_msg + assert str(parent_run.path) in error_msg + assert "malformed nested task run" in error_msg + + +def test_load_parent_succeeds_for_valid_taskrun_parent(tmp_path): + """load_parent successfully loads a valid TaskRun parent.""" + output = TaskOutput(output="test output") + + # Create a task + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + # Create parent TaskRun + parent_run = TaskRun(input="parent input", output=output, parent=task) + parent_run.save_to_file() + + # Create nested TaskRun + runs_dir = parent_run.path.parent / "runs" + runs_dir.mkdir(exist_ok=True) + nested_run_dir = runs_dir / "nested_run" + nested_run_dir.mkdir(exist_ok=True) + nested_run = TaskRun( + input="nested input", output=output, path=nested_run_dir / "task_run.kiln" + ) + nested_run.save_to_file() + + # Reload and load parent 
- should succeed + loaded_nested = TaskRun.load_from_file(nested_run.path) + loaded_parent = loaded_nested.load_parent() + + assert loaded_parent is not None + assert loaded_parent.id == parent_run.id + assert loaded_parent.input == "parent input" + assert isinstance(loaded_parent, TaskRun) + + +def test_is_root_task_run_true_when_parent_is_task(tmp_path): + """is_root_task_run returns True when parent is a Task.""" + output = TaskOutput(output="test output") + + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + run = TaskRun(input="test input", output=output, parent=task) + run.save_to_file() + + loaded_run = TaskRun.load_from_file(run.path) + assert loaded_run.is_root_task_run() is True + + +def test_is_root_task_run_false_when_parent_is_taskrun(tmp_path): + """is_root_task_run returns False when parent is another TaskRun.""" + output = TaskOutput(output="test output") + + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + parent_run = TaskRun(input="parent input", output=output, parent=task) + parent_run.save_to_file() + + runs_dir = parent_run.path.parent / "runs" + runs_dir.mkdir(exist_ok=True) + nested_run_dir = runs_dir / "nested_run" + nested_run_dir.mkdir(exist_ok=True) + nested_run = TaskRun( + input="nested input", output=output, path=nested_run_dir / "task_run.kiln" + ) + nested_run.save_to_file() + + loaded_nested = TaskRun.load_from_file(nested_run.path) + assert loaded_nested.is_root_task_run() is False From 243da49e5259cf3d0ba455f6acba1af995ba8789 Mon Sep 17 00:00:00 2001 From: "Leonard Q. 
Marcq" Date: Tue, 17 Mar 2026 01:12:59 +0800 Subject: [PATCH 23/32] test: add mega test --- libs/core/kiln_ai/datamodel/test_models.py | 178 +++++++++++++++++++++ 1 file changed, 178 insertions(+) diff --git a/libs/core/kiln_ai/datamodel/test_models.py b/libs/core/kiln_ai/datamodel/test_models.py index ea2b51be2..09973c4dd 100644 --- a/libs/core/kiln_ai/datamodel/test_models.py +++ b/libs/core/kiln_ai/datamodel/test_models.py @@ -973,3 +973,181 @@ def test_is_root_task_run(task: Task): nested_run = TaskRun(input="nested in", output=output, parent=root_run) nested_run.save_to_file() assert not nested_run.is_root_task_run() + + +def test_comprehensive_nested_task_run_hierarchy(tmp_path): + """Comprehensive integration test for polymorphic parent support on TaskRun. + + Tests: + - Project -> Task -> TaskRun hierarchy + - Multiple levels of nested TaskRuns (Task -> TaskRun -> TaskRun -> TaskRun -> TaskRun) + - Sibling TaskRuns at various levels + - Retrieval of all runs from the hierarchy + - is_root_task_run() correctness at all levels + - parent_task() and parent_run() correctness at all levels + """ + project_path = tmp_path / "project.kiln" + project = Project(name="Test Project", path=project_path) + project.save_to_file() + task = Task(name="Test Task", instruction="Test instruction", parent=project) + task.save_to_file() + + output = TaskOutput(output="test output") + + # Level 1: Two sibling TaskRuns under Task + run1_l1 = TaskRun(input="level1_run1", output=output, parent=task) + run1_l1.save_to_file() + + run2_l1 = TaskRun(input="level1_run2", output=output, parent=task) + run2_l1.save_to_file() + + # Level 2: Nested TaskRuns under run1_l1 (with sibling) + run1_l2 = TaskRun(input="level2_run1", output=output, parent=run1_l1) + run1_l2.save_to_file() + + run2_l2 = TaskRun(input="level2_run2", output=output, parent=run1_l1) + run2_l2.save_to_file() + + # Level 3: Nested TaskRuns under run1_l2 (with sibling) + run1_l3 = TaskRun(input="level3_run1", 
output=output, parent=run1_l2) + run1_l3.save_to_file() + + run2_l3 = TaskRun(input="level3_run2", output=output, parent=run1_l2) + run2_l3.save_to_file() + + # Level 4: Deepest nested TaskRun under run1_l3 + run1_l4 = TaskRun(input="level4_run1", output=output, parent=run1_l3) + run1_l4.save_to_file() + + # Level 4 sibling of run1_l4 under run2_l3 + run2_l4 = TaskRun(input="level4_sibling", output=output, parent=run2_l3) + run2_l4.save_to_file() + + # Reload everything from disk to test persistence and retrieval + loaded_project = Project.load_from_file(project_path) + loaded_task = loaded_project.tasks()[0] + + # Verify Task has 2 root-level runs + root_runs = loaded_task.runs() + assert len(root_runs) == 2 + root_run_ids = {r.id for r in root_runs} + assert run1_l1.id in root_run_ids + assert run2_l1.id in root_run_ids + + # Verify run1_l1 hierarchy + loaded_run1_l1 = next(r for r in root_runs if r.id == run1_l1.id) + assert loaded_run1_l1.parent_task() == loaded_task + assert loaded_run1_l1.parent_run() is None + assert loaded_run1_l1.is_root_task_run() is True + + level2_runs = loaded_run1_l1.runs() + assert len(level2_runs) == 2 + level2_run_ids = {r.id for r in level2_runs} + assert run1_l2.id in level2_run_ids + assert run2_l2.id in level2_run_ids + + # Verify run1_l2 hierarchy + loaded_run1_l2 = next(r for r in level2_runs if r.id == run1_l2.id) + assert loaded_run1_l2.parent_task() == loaded_task + assert loaded_run1_l2.parent_run() == loaded_run1_l1 + assert loaded_run1_l2.is_root_task_run() is False + + level3_runs = loaded_run1_l2.runs() + assert len(level3_runs) == 2 + level3_run_ids = {r.id for r in level3_runs} + assert run1_l3.id in level3_run_ids + assert run2_l3.id in level3_run_ids + + # Verify run1_l3 hierarchy (level 3) + loaded_run1_l3 = next(r for r in level3_runs if r.id == run1_l3.id) + assert loaded_run1_l3.parent_task() == loaded_task + assert loaded_run1_l3.parent_run().id == loaded_run1_l2.id + assert loaded_run1_l3.is_root_task_run() 
is False + + level4_runs = loaded_run1_l3.runs() + assert len(level4_runs) == 1 + assert level4_runs[0].id == run1_l4.id + + # Verify run1_l4 (deepest run) + loaded_run1_l4 = level4_runs[0] + assert loaded_run1_l4.parent_task() == loaded_task + assert loaded_run1_l4.parent_run().id == loaded_run1_l3.id + assert loaded_run1_l4.is_root_task_run() is False + assert len(loaded_run1_l4.runs()) == 0 + + # Verify run2_l3's nested run (sibling at level 3 has child at level 4) + loaded_run2_l3 = next(r for r in level3_runs if r.id == run2_l3.id) + assert loaded_run2_l3.parent_task() == loaded_task + assert loaded_run2_l3.parent_run().id == loaded_run1_l2.id + assert loaded_run2_l3.is_root_task_run() is False + + level4_sibling_runs = loaded_run2_l3.runs() + assert len(level4_sibling_runs) == 1 + assert level4_sibling_runs[0].id == run2_l4.id + + # Verify run2_l4 (sibling at level 4) + loaded_run2_l4 = level4_sibling_runs[0] + assert loaded_run2_l4.parent_task() == loaded_task + assert loaded_run2_l4.parent_run().id == loaded_run2_l3.id + assert loaded_run2_l4.is_root_task_run() is False + + # Verify run2_l1 (sibling at level 1 has no children) + loaded_run2_l1 = next(r for r in root_runs if r.id == run2_l1.id) + assert loaded_run2_l1.parent_task() == loaded_task + assert loaded_run2_l1.parent_run() is None + assert loaded_run2_l1.is_root_task_run() is True + assert len(loaded_run2_l1.runs()) == 0 + + # Verify run2_l2 (sibling at level 2 has no children) + loaded_run2_l2 = next(r for r in level2_runs if r.id == run2_l2.id) + assert loaded_run2_l2.parent_task() == loaded_task + assert loaded_run2_l2.parent_run() == loaded_run1_l1 + assert loaded_run2_l2.is_root_task_run() is False + assert len(loaded_run2_l2.runs()) == 0 + + # Test finding runs by ID from Task (DFS) + found_run1_l4 = loaded_task.find_task_run_by_id_dfs(run1_l4.id) + assert found_run1_l4 is not None + assert found_run1_l4.id == run1_l4.id + assert found_run1_l4.input == "level4_run1" + + found_run2_l4 = 
loaded_task.find_task_run_by_id_dfs(run2_l4.id) + assert found_run2_l4 is not None + assert found_run2_l4.id == run2_l4.id + assert found_run2_l4.input == "level4_sibling" + + # Test finding runs by ID from a TaskRun (DFS) + found_from_run1_l1 = loaded_run1_l1.find_task_run_by_id_dfs(run1_l4.id) + assert found_from_run1_l1 is not None + assert found_from_run1_l1.id == run1_l4.id + + # Count total runs in the hierarchy (should be 8) + all_runs = collect_all_task_runs(loaded_task) + assert len(all_runs) == 8 + + # Verify all levels are represented correctly + is_root_runs = [r for r in all_runs if r.is_root_task_run()] + assert len(is_root_runs) == 2 # run1_l1 and run2_l1 + + # Verify all non-root runs have parent_run set correctly + non_root_runs = [r for r in all_runs if not r.is_root_task_run()] + assert len(non_root_runs) == 6 + for run in non_root_runs: + assert run.parent_run() is not None + + +def collect_all_task_runs(root): + """Helper to recursively collect all TaskRuns in a hierarchy.""" + runs = [] + + def collect(node): + if isinstance(node, TaskRun): + runs.append(node) + for child in node.runs(): + collect(child) + elif hasattr(node, "runs"): + for child in node.runs(): + collect(child) + + collect(root) + return runs From 1accfe022014504bab733e608718647b6a985599 Mon Sep 17 00:00:00 2001 From: "Leonard Q. 
Marcq" Date: Tue, 17 Mar 2026 01:33:46 +0800 Subject: [PATCH 24/32] refactor: iterative walk back up the tree instead of recursion --- libs/core/kiln_ai/datamodel/task_run.py | 32 +++++++++++++++---------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/libs/core/kiln_ai/datamodel/task_run.py b/libs/core/kiln_ai/datamodel/task_run.py index f3f1361ab..c231c4f36 100644 --- a/libs/core/kiln_ai/datamodel/task_run.py +++ b/libs/core/kiln_ai/datamodel/task_run.py @@ -142,20 +142,26 @@ def has_thinking_training_data(self) -> bool: # Workaround to return typed parent without importing Task def parent_task(self) -> Union["Task", None]: """The Task that this Run is in. Note the TaskRun may be nested in which case we walk back up the tree all the way to the root.""" - if self.parent is None: - return None - - # this task run is already the root task run - if self.parent.__class__.__name__ == "Task": - return self.parent # type: ignore - - # this task run is nested under other ones, so we walk back - # up to the root task run - parent_run = self.cached_parent() - if isinstance(parent_run, TaskRun): - return parent_run.parent_task() + current = self + while True: + # should never really happen, except maybe in tests + if current.parent is None: + return None - return None + # this task run is the root task run + # so we just return its parent (a Task) + if current.parent.__class__.__name__ == "Task": + return current.parent # type: ignore + + # the parent of this task is not a Task, so it has to be a TaskRun + # and we just walk back up the tree of TaskRuns until we find a Task + parent_run = current.cached_parent() + if isinstance(parent_run, TaskRun): + current = parent_run + else: + # the parent is not a TaskRun, but also not a Task, so it is not + # a real parent + return None def parent_run(self) -> "TaskRun | None": """The TaskRun that contains this run, if this run is nested; otherwise None.""" From 7d536d306b359f3219ed982b8908b45bb5d13cfa Mon Sep 17 
00:00:00 2001 From: "Leonard Q. Marcq" Date: Tue, 17 Mar 2026 03:19:40 +0800 Subject: [PATCH 25/32] refactor: remove unused arg --- libs/core/kiln_ai/adapters/model_adapters/base_adapter.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py index d9633ac81..40e32d5ef 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py @@ -327,7 +327,6 @@ def _finalize_stream( adapter_stream: AdapterStream, input: InputType, input_source: DataSource | None, - prior_trace: list[ChatCompletionMessageParam] | None, ) -> TaskRun: """Streaming invocations are only concerned with passing through events as they come in. At the end of the stream, we still need to validate the output, create a run and everything @@ -654,7 +653,7 @@ async def __aiter__(self) -> AsyncIterator[ModelResponseStream]: yield event self._task_run = self._adapter._finalize_stream( - adapter_stream, self._input, self._input_source, self._prior_trace + adapter_stream, self._input, self._input_source ) finally: if is_root_agent: @@ -731,7 +730,7 @@ async def __aiter__(self) -> AsyncIterator[AiSdkStreamEvent]: yield AiSdkStreamEvent(AiSdkEventType.FINISH_STEP) self._task_run = self._adapter._finalize_stream( - adapter_stream, self._input, self._input_source, self._prior_trace + adapter_stream, self._input, self._input_source ) finally: if is_root_agent: From 880ae77c1c017168a11a64ea48e10590eb5ac00b Mon Sep 17 00:00:00 2001 From: "Leonard Q. 
Marcq" Date: Tue, 17 Mar 2026 03:55:14 +0800 Subject: [PATCH 26/32] fix: paid test with thinking level --- .../test_litellm_adapter_streaming.py | 68 +++++++++++-------- 1 file changed, 38 insertions(+), 30 deletions(-) diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py index 8b786a673..28b92faf1 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py @@ -23,21 +23,19 @@ logger = logging.getLogger(__name__) STREAMING_MODELS = [ - ("claude_sonnet_4_5", ModelProviderName.openrouter), - ("claude_sonnet_4_5", ModelProviderName.anthropic), - ("claude_sonnet_4_6", ModelProviderName.openrouter), - ("claude_sonnet_4_6", ModelProviderName.anthropic), - ("claude_opus_4_5", ModelProviderName.openrouter), - ("claude_opus_4_5", ModelProviderName.anthropic), - ("claude_opus_4_6", ModelProviderName.openrouter), - ("claude_opus_4_6", ModelProviderName.anthropic), - ("minimax_m2_5", ModelProviderName.openrouter), - ("claude_4_5_haiku", ModelProviderName.openrouter), - ("claude_4_5_haiku", ModelProviderName.anthropic), + ("claude_sonnet_4_5", ModelProviderName.openrouter, "medium"), + ("claude_sonnet_4_5", ModelProviderName.anthropic, "medium"), + ("claude_sonnet_4_6", ModelProviderName.openrouter, "medium"), + ("claude_sonnet_4_6", ModelProviderName.anthropic, "medium"), + ("claude_opus_4_5", ModelProviderName.openrouter, "medium"), + ("claude_opus_4_5", ModelProviderName.anthropic, "medium"), + ("claude_opus_4_6", ModelProviderName.openrouter, "medium"), + ("claude_opus_4_6", ModelProviderName.anthropic, "medium"), + ("minimax_m2_5", ModelProviderName.openrouter, "medium"), + ("claude_4_5_haiku", ModelProviderName.openrouter, "medium"), + ("claude_4_5_haiku", ModelProviderName.anthropic, "medium"), ] -STREAMING_MODELS_NO_HAIKU = [m for m in 
STREAMING_MODELS if "haiku" not in m[0]] - PAID_TEST_OUTPUT_DIR = Path(__file__).resolve().parents[5] / "test_output" @@ -96,9 +94,11 @@ def task(tmp_path): @pytest.fixture -def adapter_factory(task: Task) -> Callable[[str, ModelProviderName], LiteLlmAdapter]: +def adapter_factory( + task: Task, +) -> Callable[[str, ModelProviderName, str | None], LiteLlmAdapter]: def create_adapter( - model_id: str, provider_name: ModelProviderName + model_id: str, provider_name: ModelProviderName, thinking_level: str | None ) -> LiteLlmAdapter: return LiteLlmAdapter( kiln_task=task, @@ -116,6 +116,7 @@ def create_adapter( KilnBuiltInToolId.DIVIDE_NUMBERS, ], ), + thinking_level=thinking_level, ) ), ) @@ -124,15 +125,16 @@ def create_adapter( @pytest.mark.paid -@pytest.mark.parametrize("model_id,provider_name", STREAMING_MODELS) +@pytest.mark.parametrize("model_id,provider_name,thinking_level", STREAMING_MODELS) async def test_invoke_openai_stream_chunks( request: pytest.FixtureRequest, model_id: str, provider_name: ModelProviderName, - adapter_factory: Callable[[str, ModelProviderName], LiteLlmAdapter], + thinking_level: str | None, + adapter_factory: Callable[[str, ModelProviderName, str | None], LiteLlmAdapter], ): """Collect all OpenAI-protocol chunks via invoke_openai_stream and verify we got reasoning, content, and tool call data.""" - adapter = adapter_factory(model_id, provider_name) + adapter = adapter_factory(model_id, provider_name, thinking_level) chunks: list[litellm.ModelResponseStream] = [] async for chunk in adapter.invoke_openai_stream(input="123 + 321 = ?"): @@ -146,6 +148,8 @@ async def test_invoke_openai_stream_chunks( tool_calls: list[ChatCompletionDeltaToolCall | Any] = [] for chunk in chunks: + if len(chunk.choices) == 0: + continue if chunk.choices[0].finish_reason is not None: continue delta = chunk.choices[0].delta @@ -187,15 +191,16 @@ async def test_invoke_openai_stream_chunks( @pytest.mark.paid -@pytest.mark.parametrize("model_id,provider_name", 
STREAMING_MODELS) +@pytest.mark.parametrize("model_id,provider_name,thinking_level", STREAMING_MODELS) async def test_invoke_ai_sdk_stream( request: pytest.FixtureRequest, model_id: str, provider_name: ModelProviderName, - adapter_factory: Callable[[str, ModelProviderName], LiteLlmAdapter], + thinking_level: str | None, + adapter_factory: Callable[[str, ModelProviderName, str | None], LiteLlmAdapter], ): """Collect AI SDK events and verify the full protocol lifecycle including tool events.""" - adapter = adapter_factory(model_id, provider_name) + adapter = adapter_factory(model_id, provider_name, thinking_level) events: list[AiSdkStreamEvent] = [] async for event in adapter.invoke_ai_sdk_stream(input="123 + 321 = ?"): @@ -263,15 +268,16 @@ async def test_invoke_ai_sdk_stream( @pytest.mark.paid -@pytest.mark.parametrize("model_id,provider_name", STREAMING_MODELS) +@pytest.mark.parametrize("model_id,provider_name,thinking_level", STREAMING_MODELS) async def test_ai_sdk_stream_text_ends_before_tool_calls( request: pytest.FixtureRequest, model_id: str, provider_name: ModelProviderName, - adapter_factory: Callable[[str, ModelProviderName], LiteLlmAdapter], + thinking_level: str | None, + adapter_factory: Callable[[str, ModelProviderName, str | None], LiteLlmAdapter], ): """Verify text blocks are properly closed before tool-input-start and reopened with a new ID after tool execution.""" - adapter = adapter_factory(model_id, provider_name) + adapter = adapter_factory(model_id, provider_name, thinking_level) events: list[AiSdkStreamEvent] = [] async for event in adapter.invoke_ai_sdk_stream( @@ -322,15 +328,16 @@ async def test_ai_sdk_stream_text_ends_before_tool_calls( @pytest.mark.paid -@pytest.mark.parametrize("model_id,provider_name", STREAMING_MODELS_NO_HAIKU) +@pytest.mark.parametrize("model_id,provider_name,thinking_level", STREAMING_MODELS) async def test_invoke_openai_stream_non_streaming_still_works( request: pytest.FixtureRequest, model_id: str, provider_name: 
ModelProviderName, - adapter_factory: Callable[[str, ModelProviderName], LiteLlmAdapter], + thinking_level: str | None, + adapter_factory: Callable[[str, ModelProviderName, str | None], LiteLlmAdapter], ): """Verify the non-streaming invoke() still works after the refactor.""" - adapter = adapter_factory(model_id, provider_name) + adapter = adapter_factory(model_id, provider_name, thinking_level) task_run = await adapter.invoke(input="123 + 321 = ?") _dump_paid_test_output(request, task_run=task_run) @@ -342,15 +349,16 @@ async def test_invoke_openai_stream_non_streaming_still_works( @pytest.mark.paid -@pytest.mark.parametrize("model_id,provider_name", STREAMING_MODELS_NO_HAIKU) +@pytest.mark.parametrize("model_id,provider_name,thinking_level", STREAMING_MODELS) async def test_invoke_openai_stream_with_prior_trace( request: pytest.FixtureRequest, model_id: str, provider_name: ModelProviderName, - adapter_factory: Callable[[str, ModelProviderName], LiteLlmAdapter], + thinking_level: str | None, + adapter_factory: Callable[[str, ModelProviderName, str | None], LiteLlmAdapter], ): """Test that streaming works when continuing an existing run (session continuation).""" - adapter = adapter_factory(model_id, provider_name) + adapter = adapter_factory(model_id, provider_name, thinking_level) initial_run = await adapter.invoke(input="123 + 321 = ?") assert initial_run.trace is not None From f7070048ec53b3877d705408ccc9c928a8994d46 Mon Sep 17 00:00:00 2001 From: "Leonard Q. 
Marcq" Date: Tue, 17 Mar 2026 04:36:25 +0800 Subject: [PATCH 27/32] fix: finish_step must come before finish in AI SDK protocol --- .../adapters/model_adapters/base_adapter.py | 5 ++++- .../adapters/model_adapters/stream_events.py | 9 ++++++++- .../test_litellm_adapter_streaming.py | 5 +++++ .../model_adapters/test_stream_events.py | 16 +++++++++------- 4 files changed, 26 insertions(+), 9 deletions(-) diff --git a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py index 40e32d5ef..60cd014b1 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py @@ -724,7 +724,7 @@ async def __aiter__(self) -> AsyncIterator[AiSdkStreamEvent]: for ai_event in converter.convert_tool_event(event): yield ai_event - for ai_event in converter.finalize(): + for ai_event in converter.close_open_blocks(): yield ai_event yield AiSdkStreamEvent(AiSdkEventType.FINISH_STEP) @@ -732,6 +732,9 @@ async def __aiter__(self) -> AsyncIterator[AiSdkStreamEvent]: self._task_run = self._adapter._finalize_stream( adapter_stream, self._input, self._input_source ) + + for ai_event in converter.finalize(): + yield ai_event finally: if is_root_agent: try: diff --git a/libs/core/kiln_ai/adapters/model_adapters/stream_events.py b/libs/core/kiln_ai/adapters/model_adapters/stream_events.py index fbe2ff230..eb023f910 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/stream_events.py +++ b/libs/core/kiln_ai/adapters/model_adapters/stream_events.py @@ -247,7 +247,8 @@ def convert_tool_event(self, event: ToolCallEvent) -> list[AiSdkStreamEvent]: return events - def finalize(self) -> list[AiSdkStreamEvent]: + def close_open_blocks(self) -> list[AiSdkStreamEvent]: + """Close any open text/reasoning blocks. 
Call before FINISH_STEP.""" events: list[AiSdkStreamEvent] = [] if self._reasoning_started: @@ -268,6 +269,12 @@ def finalize(self) -> list[AiSdkStreamEvent]: ) self._text_started = False + return events + + def finalize(self) -> list[AiSdkStreamEvent]: + """Emit the terminal FINISH event with usage/finish metadata. Call after FINISH_STEP.""" + events: list[AiSdkStreamEvent] = [] + finish_payload: dict[str, Any] = {} if self._finish_reason is not None: finish_payload["finishReason"] = self._finish_reason.replace("_", "-") diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py index 28b92faf1..cc2d51b98 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter_streaming.py @@ -219,6 +219,11 @@ async def test_invoke_ai_sdk_stream( assert AiSdkEventType.FINISH_STEP in event_types, "Should have FINISH_STEP" assert AiSdkEventType.FINISH in event_types, "Should have FINISH" + finish_step_idx = event_types.index(AiSdkEventType.FINISH_STEP) + finish_idx = event_types.index(AiSdkEventType.FINISH) + assert finish_step_idx < finish_idx, ( + f"FINISH_STEP (idx {finish_step_idx}) must come before FINISH (idx {finish_idx})" + ) assert AiSdkEventType.REASONING_START in event_types, "Should have REASONING_START" assert AiSdkEventType.REASONING_DELTA in event_types, "Should have REASONING_DELTA" diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py b/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py index 9df765ba9..f5f9ca73a 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py @@ -113,19 +113,20 @@ def test_tool_call_input_start_and_delta(self): assert start_event.payload["toolCallId"] == "call_1" assert start_event.payload["toolName"] == "add" 
- def test_finalize_closes_open_blocks(self): + def test_close_open_blocks_closes_text(self): converter = AiSdkStreamConverter() converter.convert_chunk(_make_chunk(content="text")) - events = converter.finalize() - types = [e.type for e in events] + block_events = converter.close_open_blocks() + types = [e.type for e in block_events] assert AiSdkEventType.TEXT_END in types - assert AiSdkEventType.FINISH in types + finish_events = converter.finalize() + assert any(e.type == AiSdkEventType.FINISH for e in finish_events) - def test_finalize_closes_reasoning(self): + def test_close_open_blocks_closes_reasoning(self): converter = AiSdkStreamConverter() converter.convert_chunk(_make_chunk(reasoning_content="thinking")) - events = converter.finalize() - types = [e.type for e in events] + block_events = converter.close_open_blocks() + types = [e.type for e in block_events] assert AiSdkEventType.REASONING_END in types def test_convert_tool_event_input_available(self): @@ -218,6 +219,7 @@ def test_reset_for_next_step(self): def test_finish_reason_in_finalize(self): converter = AiSdkStreamConverter() converter.convert_chunk(_make_chunk(content="done", finish_reason="stop")) + converter.close_open_blocks() events = converter.finalize() finish_events = [e for e in events if e.type == AiSdkEventType.FINISH] assert len(finish_events) == 1 From 61b124962bd77a9ccbfe170c12ae6b7cb5b251ed Mon Sep 17 00:00:00 2001 From: "Leonard Q. 
Marcq" Date: Tue, 17 Mar 2026 04:49:32 +0800 Subject: [PATCH 28/32] test: more coverage for streaming --- .../model_adapters/test_adapter_stream.py | 126 +++++++++++++ .../model_adapters/test_base_adapter.py | 170 +++++++++++++++++- .../model_adapters/test_stream_events.py | 67 +++++++ 3 files changed, 362 insertions(+), 1 deletion(-) diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_adapter_stream.py b/libs/core/kiln_ai/adapters/model_adapters/test_adapter_stream.py index 25715645c..a104ea066 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_adapter_stream.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_adapter_stream.py @@ -370,3 +370,129 @@ async def test_unparseable_tool_call_arguments(self, mock_adapter, mock_provider assert len(input_events) == 1 assert input_events[0].arguments is None assert "Failed to parse" in (input_events[0].error or "") + + +class TestAdapterStreamEdgeCases: + @pytest.mark.asyncio + async def test_too_many_turns_raises(self, mock_adapter, mock_provider): + formatter = FakeChatFormatter(num_turns=15) + response = _make_model_response(content="ok") + fake_stream = FakeStreamingCompletion(response) + + with ( + patch( + "kiln_ai.adapters.model_adapters.adapter_stream.StreamingCompletion", + return_value=fake_stream, + ), + patch( + "kiln_ai.adapters.model_adapters.adapter_stream.MAX_CALLS_PER_TURN", + 2, + ), + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=formatter, + initial_messages=[], + top_logprobs=None, + ) + with pytest.raises(RuntimeError, match="Too many turns"): + async for _ in stream: + pass + + @pytest.mark.asyncio + async def test_empty_message_content_raises(self, mock_adapter, mock_provider): + formatter = MagicMock() + turn = MagicMock() + turn.messages = [MagicMock(role="user", content=None)] + turn.final_call = True + formatter.next_turn = MagicMock(side_effect=[turn, None]) + + stream = AdapterStream( + adapter=mock_adapter, + 
provider=mock_provider, + chat_formatter=formatter, + initial_messages=[], + top_logprobs=None, + ) + with pytest.raises(ValueError, match="Empty message content"): + async for _ in stream: + pass + + @pytest.mark.asyncio + async def test_no_content_or_tool_calls_raises(self, mock_adapter, mock_provider): + response = _make_model_response(content=None, tool_calls=None) + response.choices[0].message.content = None + fake_stream = FakeStreamingCompletion( + response, [_make_streaming_chunk(finish_reason="stop")] + ) + + with patch( + "kiln_ai.adapters.model_adapters.adapter_stream.StreamingCompletion", + return_value=fake_stream, + ): + stream = AdapterStream( + adapter=mock_adapter, + provider=mock_provider, + chat_formatter=FakeChatFormatter(), + initial_messages=[], + top_logprobs=None, + ) + with pytest.raises(ValueError, match="no content or tool calls"): + async for _ in stream: + pass + + +class TestValidateResponse: + def test_valid_response(self): + from kiln_ai.adapters.model_adapters.adapter_stream import _validate_response + + response = _make_model_response(content="hello") + result, choice = _validate_response(response) + assert result is response + assert choice.message.content == "hello" + + def test_none_response_raises(self): + from kiln_ai.adapters.model_adapters.adapter_stream import _validate_response + + with pytest.raises(RuntimeError, match="Expected ModelResponse"): + _validate_response(None) + + def test_empty_choices_raises(self): + from kiln_ai.adapters.model_adapters.adapter_stream import _validate_response + + response = ModelResponse(id="test", choices=[]) + with pytest.raises(RuntimeError, match="Expected ModelResponse"): + _validate_response(response) + + +class TestFindToolName: + def test_found(self): + from kiln_ai.adapters.model_adapters.adapter_stream import _find_tool_name + + tc = ChatCompletionMessageToolCall( + id="call_1", + type="function", + function=Function(name="add", arguments="{}"), + ) + assert _find_tool_name([tc], 
"call_1") == "add" + + def test_not_found_returns_unknown(self): + from kiln_ai.adapters.model_adapters.adapter_stream import _find_tool_name + + tc = ChatCompletionMessageToolCall( + id="call_1", + type="function", + function=Function(name="add", arguments="{}"), + ) + assert _find_tool_name([tc], "call_999") == "unknown" + + def test_name_is_none_returns_unknown(self): + from kiln_ai.adapters.model_adapters.adapter_stream import _find_tool_name + + tc = ChatCompletionMessageToolCall( + id="call_1", + type="function", + function=Function(name=None, arguments="{}"), + ) + assert _find_tool_name([tc], "call_1") == "unknown" diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py index 2bf10b3a9..ee802a7bc 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py @@ -21,7 +21,7 @@ ToolCallEventType, ) from kiln_ai.adapters.prompt_builders import BasePromptBuilder -from kiln_ai.datamodel import Task, TaskRun +from kiln_ai.datamodel import Task, TaskRun, Usage from kiln_ai.datamodel.datamodel_enums import ChatStrategy, ModelProviderName from kiln_ai.datamodel.project import Project from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties, ToolsRunConfig @@ -1237,3 +1237,171 @@ async def __aiter__(self): pass assert stream.task_run is expected_run + + +class TestFinalizeStream: + """Tests for _finalize_stream post-processing after streaming.""" + + @pytest.fixture + def finalize_adapter(self, base_task): + return MockAdapter( + task=base_task, + run_config=KilnAgentRunConfigProperties( + model_name="test_model", + model_provider_name="openai", + prompt_id="simple_prompt_builder", + structured_output_mode="json_schema", + ), + ) + + def _make_adapter_stream(self, output, usage=None, trace=None): + from kiln_ai.adapters.model_adapters.adapter_stream import AdapterStreamResult + + stream = 
MagicMock() + stream.result = AdapterStreamResult( + run_output=RunOutput( + output=output, + intermediate_outputs={}, + trace=trace, + ), + usage=usage or Usage(), + ) + return stream + + def test_finalize_stream_plain_text(self, finalize_adapter): + provider = MagicMock() + provider.parser = None + provider.reasoning_capable = False + finalize_adapter.model_provider = MagicMock(return_value=provider) + + adapter_stream = self._make_adapter_stream("Hello world") + run = finalize_adapter._finalize_stream(adapter_stream, "test input", None) + + assert isinstance(run, TaskRun) + assert run.output.output == "Hello world" + assert run.id is None + + def _make_structured_adapter(self, base_task, schema): + base_task.output_json_schema = schema + adapter = MockAdapter( + task=base_task, + run_config=KilnAgentRunConfigProperties( + model_name="test_model", + model_provider_name="openai", + prompt_id="simple_prompt_builder", + structured_output_mode="json_schema", + ), + ) + return adapter + + def test_finalize_stream_structured_output(self, base_task): + schema = '{"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}' + adapter = self._make_structured_adapter(base_task, schema) + + provider = MagicMock() + provider.parser = None + provider.reasoning_capable = False + adapter.model_provider = MagicMock(return_value=provider) + + adapter_stream = self._make_adapter_stream({"name": "test"}) + run = adapter._finalize_stream(adapter_stream, "test input", None) + + assert isinstance(run, TaskRun) + assert '"name"' in run.output.output + + def test_finalize_stream_structured_output_from_json_string(self, base_task): + schema = '{"type": "object", "properties": {"val": {"type": "integer"}}, "required": ["val"]}' + adapter = self._make_structured_adapter(base_task, schema) + + provider = MagicMock() + provider.parser = None + provider.reasoning_capable = False + adapter.model_provider = MagicMock(return_value=provider) + + adapter_stream = 
self._make_adapter_stream('{"val": 42}') + run = adapter._finalize_stream(adapter_stream, "test input", None) + assert isinstance(run, TaskRun) + + def test_finalize_stream_structured_output_not_dict_raises(self, base_task): + schema = '{"type": "object", "properties": {"x": {"type": "string"}}, "required": ["x"]}' + adapter = self._make_structured_adapter(base_task, schema) + + provider = MagicMock() + provider.parser = None + provider.reasoning_capable = False + adapter.model_provider = MagicMock(return_value=provider) + + adapter_stream = self._make_adapter_stream(42) + with pytest.raises(RuntimeError, match="structured response is not a dict"): + adapter._finalize_stream(adapter_stream, "test input", None) + + def test_finalize_stream_non_structured_non_string_raises(self, finalize_adapter): + provider = MagicMock() + provider.parser = None + provider.reasoning_capable = False + finalize_adapter.model_provider = MagicMock(return_value=provider) + + adapter_stream = self._make_adapter_stream({"unexpected": "dict"}) + with pytest.raises(RuntimeError, match="not a string for non-structured"): + finalize_adapter._finalize_stream(adapter_stream, "test input", None) + + def test_finalize_stream_reasoning_required_but_missing(self, finalize_adapter): + provider = MagicMock() + provider.parser = None + provider.reasoning_capable = True + provider.reasoning_optional_for_structured_output = False + finalize_adapter.model_provider = MagicMock(return_value=provider) + + adapter_stream = self._make_adapter_stream("output") + with pytest.raises(RuntimeError, match="Reasoning is required"): + finalize_adapter._finalize_stream(adapter_stream, "test input", None) + + def test_finalize_stream_reasoning_not_required_with_tool_calls( + self, finalize_adapter + ): + provider = MagicMock() + provider.parser = None + provider.reasoning_capable = True + provider.reasoning_optional_for_structured_output = False + finalize_adapter.model_provider = MagicMock(return_value=provider) + + 
trace = [ + {"role": "user", "content": "hi"}, + {"role": "tool", "content": "result", "tool_call_id": "call_1"}, + ] + adapter_stream = self._make_adapter_stream("output", trace=trace) + run = finalize_adapter._finalize_stream(adapter_stream, "test input", None) + assert isinstance(run, TaskRun) + + def test_finalize_stream_saves_when_allowed(self, tmp_path): + project_path = tmp_path / "proj" / "project.kiln" + project_path.parent.mkdir() + project = Project(name="test", path=project_path) + project.save_to_file() + task = Task(name="t", instruction="i", parent=project) + task.save_to_file() + + adapter = MockAdapter( + task=task, + run_config=KilnAgentRunConfigProperties( + model_name="test_model", + model_provider_name="openai", + prompt_id="simple_prompt_builder", + structured_output_mode="json_schema", + ), + config=AdapterConfig(allow_saving=True), + ) + + provider = MagicMock() + provider.parser = None + provider.reasoning_capable = False + adapter.model_provider = MagicMock(return_value=provider) + + adapter_stream = self._make_adapter_stream("result") + with patch( + "kiln_ai.adapters.model_adapters.base_adapter.Config" + ) as mock_config: + mock_config.shared.return_value.autosave_runs = True + mock_config.shared.return_value.user_id = "test_user" + run = adapter._finalize_stream(adapter_stream, "test input", None) + assert run.id is not None diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py b/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py index f5f9ca73a..e7abbb832 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_stream_events.py @@ -318,3 +318,70 @@ def test_tool_input_start_not_reemitted_without_reset(self): assert len(starts_r2) == 0, ( "Without reset, started=True blocks duplicate tool-input-start" ) + + def test_choice_with_none_delta_is_skipped(self): + chunk = ModelResponseStream( + id="test", + 
choices=[StreamingChoices(index=0, delta=None, finish_reason=None)], + ) + converter = AiSdkStreamConverter() + events = converter.convert_chunk(chunk) + assert events == [] + + def test_usage_data_extracted_from_empty_choices_chunk(self): + usage = type( + "Usage", + (), + { + "prompt_tokens": 10, + "completion_tokens": 20, + "total_tokens": 30, + }, + )() + chunk = ModelResponseStream(id="test", choices=[]) + chunk.usage = usage + + converter = AiSdkStreamConverter() + converter.convert_chunk(chunk) + assert converter._usage_data is usage + + def test_finalize_includes_usage_payload(self): + usage = type( + "Usage", + (), + { + "prompt_tokens": 15, + "completion_tokens": 25, + "total_tokens": 40, + }, + )() + converter = AiSdkStreamConverter() + converter._usage_data = usage + converter._finish_reason = "stop" + + events = converter.finalize() + finish = next(e for e in events if e.type == AiSdkEventType.FINISH) + meta = finish.payload["messageMetadata"] + assert meta["finishReason"] == "stop" + assert meta["usage"]["promptTokens"] == 15 + assert meta["usage"]["completionTokens"] == 25 + assert meta["usage"]["totalTokens"] == 40 + + def test_finalize_usage_without_total_tokens(self): + usage = type( + "Usage", + (), + { + "prompt_tokens": 5, + "completion_tokens": 10, + }, + )() + converter = AiSdkStreamConverter() + converter._usage_data = usage + + events = converter.finalize() + finish = next(e for e in events if e.type == AiSdkEventType.FINISH) + meta = finish.payload["messageMetadata"] + assert meta["usage"]["promptTokens"] == 5 + assert meta["usage"]["completionTokens"] == 10 + assert "totalTokens" not in meta["usage"] From 5f81cfc62e529551f96584686f5c91e62aefd54f Mon Sep 17 00:00:00 2001 From: "Leonard Q. 
Marcq" Date: Tue, 17 Mar 2026 05:17:34 +0800 Subject: [PATCH 29/32] fix: do not use only the cached parent --- libs/core/kiln_ai/datamodel/task_run.py | 27 ++++++++++++------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/libs/core/kiln_ai/datamodel/task_run.py b/libs/core/kiln_ai/datamodel/task_run.py index c231c4f36..a3ad22e4e 100644 --- a/libs/core/kiln_ai/datamodel/task_run.py +++ b/libs/core/kiln_ai/datamodel/task_run.py @@ -142,33 +142,32 @@ def has_thinking_training_data(self) -> bool: # Workaround to return typed parent without importing Task def parent_task(self) -> Union["Task", None]: """The Task that this Run is in. Note the TaskRun may be nested in which case we walk back up the tree all the way to the root.""" - current = self + current: TaskRun = self while True: # should never really happen, except maybe in tests - if current.parent is None: + parent = current.parent + if parent is None: return None # this task run is the root task run # so we just return its parent (a Task) - if current.parent.__class__.__name__ == "Task": - return current.parent # type: ignore - - # the parent of this task is not a Task, so it has to be a TaskRun - # and we just walk back up the tree of TaskRuns until we find a Task - parent_run = current.cached_parent() - if isinstance(parent_run, TaskRun): - current = parent_run - else: + if parent.__class__.__name__ == "Task": + return parent # type: ignore + + if parent.__class__.__name__ != "TaskRun": # the parent is not a TaskRun, but also not a Task, so it is not # a real parent return None + # the parent is a TaskRun, so we just walk up the tree until we find a Task + current = parent # type: ignore + def parent_run(self) -> "TaskRun | None": """The TaskRun that contains this run, if this run is nested; otherwise None.""" - parent = self.cached_parent() - if parent is None or not isinstance(parent, TaskRun): + parent = self.parent + if parent is None or parent.__class__.__name__ != "TaskRun": return 
None - return parent + return parent # type: ignore @classmethod def _parent_types(cls) -> List[Type["KilnBaseModel"]]: From 72dde7a8a8abbc66bdb7c821329d4cc1a2dc0485 Mon Sep 17 00:00:00 2001 From: "Leonard Q. Marcq" Date: Tue, 17 Mar 2026 05:29:26 +0800 Subject: [PATCH 30/32] refactor: remove dead code, add test --- libs/core/kiln_ai/datamodel/basemodel.py | 28 ++++-------- libs/core/kiln_ai/datamodel/test_models.py | 53 ++++++++++++++++++++++ 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/libs/core/kiln_ai/datamodel/basemodel.py b/libs/core/kiln_ai/datamodel/basemodel.py index 14e808671..90e981b78 100644 --- a/libs/core/kiln_ai/datamodel/basemodel.py +++ b/libs/core/kiln_ai/datamodel/basemodel.py @@ -657,25 +657,15 @@ def iterate_children_paths_of_parent_path(cls: Type[PT], parent_path: Path | Non f"Parent model_type '{actual_parent_type_name}' is not one of " f"{parent_type_names}" ) - else: - # find the parent type that matches the actual parent type name we found on disk - # and validate it - parent_type = next( - ( - t - for t in parent_types_override - if t.type_name() == actual_parent_type_name - ), - None, - ) - if parent_type is None: - raise ValueError( - f"Could not find parent type '{actual_parent_type_name}' in " - f"{parent_types_override}" - ) - parent = parent_type.load_from_file(parent_path) - if parent is None: - raise ValueError("Parent must be set to load children") + + parent_type = next( + t + for t in parent_types_override + if t.type_name() == actual_parent_type_name + ) + parent = parent_type.load_from_file(parent_path) + if parent is None: + raise ValueError("Parent must be set to load children") # Ignore type error: this is abstract base class, but children must implement relationship_name relationship_folder = parent_folder / Path(cls.relationship_name()) # type: ignore diff --git a/libs/core/kiln_ai/datamodel/test_models.py b/libs/core/kiln_ai/datamodel/test_models.py index 09973c4dd..cc11ba9a6 100644 --- 
a/libs/core/kiln_ai/datamodel/test_models.py +++ b/libs/core/kiln_ai/datamodel/test_models.py @@ -1151,3 +1151,56 @@ def collect(node): collect(root) return runs + + +def test_task_run_wrong_parent_type_raises(tmp_path): + project = Project(name="proj", path=tmp_path / "project.kiln") + project.save_to_file() + + with pytest.raises(ValidationError, match="Parent must be one of"): + TaskRun( + input="bad parent", + output=TaskOutput( + output="x", + source=DataSource( + type=DataSourceType.human, properties={"created_by": "test"} + ), + ), + parent=project, + ) + + +def test_task_run_runs_on_disk(tmp_path): + project = Project(name="proj", path=tmp_path / "project.kiln") + project.save_to_file() + task = Task(name="t", instruction="i", parent=project) + task.save_to_file() + + parent_run = TaskRun( + input="parent", + output=TaskOutput( + output="parent out", + source=DataSource( + type=DataSourceType.human, properties={"created_by": "test"} + ), + ), + parent=task, + ) + parent_run.save_to_file() + + child_run = TaskRun( + input="child", + output=TaskOutput( + output="child out", + source=DataSource( + type=DataSourceType.human, properties={"created_by": "test"} + ), + ), + parent=parent_run, + ) + child_run.save_to_file() + + loaded = TaskRun.load_from_file(parent_run.path) + children = loaded.runs() + assert len(children) == 1 + assert children[0].id == child_run.id From e26e52dea1fce6265faf2383f3198442c5c2da61 Mon Sep 17 00:00:00 2001 From: "Leonard Q. 
Marcq" Date: Tue, 17 Mar 2026 16:10:03 +0800 Subject: [PATCH 31/32] cr: use isinstance instead of class name --- libs/core/kiln_ai/datamodel/task_run.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/libs/core/kiln_ai/datamodel/task_run.py b/libs/core/kiln_ai/datamodel/task_run.py index a3ad22e4e..40dac39e3 100644 --- a/libs/core/kiln_ai/datamodel/task_run.py +++ b/libs/core/kiln_ai/datamodel/task_run.py @@ -142,6 +142,9 @@ def has_thinking_training_data(self) -> bool: # Workaround to return typed parent without importing Task def parent_task(self) -> Union["Task", None]: """The Task that this Run is in. Note the TaskRun may be nested in which case we walk back up the tree all the way to the root.""" + # lazy import to avoid circular dependency + from kiln_ai.datamodel.task import Task + current: TaskRun = self while True: # should never really happen, except maybe in tests @@ -151,23 +154,23 @@ def parent_task(self) -> Union["Task", None]: # this task run is the root task run # so we just return its parent (a Task) - if parent.__class__.__name__ == "Task": - return parent # type: ignore + if isinstance(parent, Task): + return parent - if parent.__class__.__name__ != "TaskRun": + if not isinstance(parent, TaskRun): # the parent is not a TaskRun, but also not a Task, so it is not # a real parent return None # the parent is a TaskRun, so we just walk up the tree until we find a Task - current = parent # type: ignore + current = parent def parent_run(self) -> "TaskRun | None": """The TaskRun that contains this run, if this run is nested; otherwise None.""" parent = self.parent - if parent is None or parent.__class__.__name__ != "TaskRun": + if parent is None or not isinstance(parent, TaskRun): return None - return parent # type: ignore + return parent @classmethod def _parent_types(cls) -> List[Type["KilnBaseModel"]]: @@ -182,7 +185,10 @@ def runs(self, readonly: bool = False) -> list["TaskRun"]: def is_root_task_run(self) -> 
bool: """Is this the root task run? (not nested under another task run)""" - return self.parent is None or self.parent.__class__.__name__ == "Task" + # lazy import to avoid circular dependency + from kiln_ai.datamodel.task import Task + + return self.parent is None or isinstance(self.parent, Task) def find_task_run_by_id_dfs( self, task_run_id: str, readonly: bool = False From 2c120dea3869ba6f3da7ddd14c1a5c2b254880aa Mon Sep 17 00:00:00 2001 From: "Leonard Q. Marcq" Date: Tue, 17 Mar 2026 20:42:41 +0800 Subject: [PATCH 32/32] refactor: actually saving task_run under its parent during invoke --- .../adapters/model_adapters/base_adapter.py | 33 +++-- .../adapters/model_adapters/mcp_adapter.py | 22 ++-- .../model_adapters/test_mcp_adapter.py | 48 +++++++ .../test_saving_adapter_results.py | 121 +++++++++++++++++- 4 files changed, 205 insertions(+), 19 deletions(-) diff --git a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py index 60cd014b1..b8c6c0bfa 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py @@ -143,9 +143,10 @@ async def invoke( input: InputType, input_source: DataSource | None = None, prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> TaskRun: run_output, _ = await self.invoke_returning_run_output( - input, input_source, prior_trace + input, input_source, prior_trace, parent_task_run ) return run_output @@ -154,6 +155,7 @@ async def _run_returning_run_output( input: InputType, input_source: DataSource | None = None, prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> Tuple[TaskRun, RunOutput]: # validate input, allowing arrays if self.input_schema is not None: @@ -224,7 +226,7 @@ async def _run_returning_run_output( ) run = self.generate_run( - input, input_source, parsed_output, usage, 
run_output.trace + input, input_source, parsed_output, usage, run_output.trace, parent_task_run ) # Save the run if configured to do so, and we have a path to save to @@ -245,6 +247,7 @@ async def invoke_returning_run_output( input: InputType, input_source: DataSource | None = None, prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> Tuple[TaskRun, RunOutput]: # Determine if this is the root agent (no existing run context) is_root_agent = get_agent_run_id() is None @@ -255,7 +258,7 @@ async def invoke_returning_run_output( try: return await self._run_returning_run_output( - input, input_source, prior_trace + input, input_source, prior_trace, parent_task_run ) finally: if is_root_agent: @@ -271,6 +274,7 @@ def invoke_openai_stream( input: InputType, input_source: DataSource | None = None, prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> OpenAIStreamResult: """Stream raw OpenAI-protocol chunks for the task execution. @@ -282,13 +286,16 @@ def invoke_openai_stream( Tool-call rounds happen internally and are not surfaced; use ``invoke_ai_sdk_stream`` if you need tool-call events. """ - return OpenAIStreamResult(self, input, input_source, prior_trace) + return OpenAIStreamResult( + self, input, input_source, prior_trace, parent_task_run + ) def invoke_ai_sdk_stream( self, input: InputType, input_source: DataSource | None = None, prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> AiSdkStreamResult: """Stream AI SDK protocol events for the task execution. @@ -297,7 +304,9 @@ def invoke_ai_sdk_stream( control events. After the iterator is exhausted the resulting ``TaskRun`` is available via the ``.task_run`` property. 
""" - return AiSdkStreamResult(self, input, input_source, prior_trace) + return AiSdkStreamResult( + self, input, input_source, prior_trace, parent_task_run + ) def _prepare_stream( self, @@ -327,6 +336,7 @@ def _finalize_stream( adapter_stream: AdapterStream, input: InputType, input_source: DataSource | None, + parent_task_run: TaskRun | None = None, ) -> TaskRun: """Streaming invocations are only concerned with passing through events as they come in. At the end of the stream, we still need to validate the output, create a run and everything @@ -379,7 +389,7 @@ def _finalize_stream( ) run = self.generate_run( - input, input_source, parsed_output, usage, run_output.trace + input, input_source, parsed_output, usage, run_output.trace, parent_task_run ) if ( @@ -496,6 +506,7 @@ def generate_run( run_output: RunOutput, usage: Usage | None = None, trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> TaskRun: output_str = ( json.dumps(run_output.output, ensure_ascii=False) @@ -530,7 +541,7 @@ def generate_run( ) return TaskRun( - parent=self.task, + parent=parent_task_run if parent_task_run is not None else self.task, input=input_str, input_source=input_source, output=new_output, @@ -621,11 +632,13 @@ def __init__( input: InputType, input_source: DataSource | None, prior_trace: list[ChatCompletionMessageParam] | None, + parent_task_run: TaskRun | None = None, ) -> None: self._adapter = adapter self._input = input self._input_source = input_source self._prior_trace = prior_trace + self._parent_task_run = parent_task_run self._task_run: TaskRun | None = None @property @@ -653,7 +666,7 @@ async def __aiter__(self) -> AsyncIterator[ModelResponseStream]: yield event self._task_run = self._adapter._finalize_stream( - adapter_stream, self._input, self._input_source + adapter_stream, self._input, self._input_source, self._parent_task_run ) finally: if is_root_agent: @@ -678,11 +691,13 @@ def __init__( input: InputType, 
input_source: DataSource | None, prior_trace: list[ChatCompletionMessageParam] | None, + parent_task_run: TaskRun | None = None, ) -> None: self._adapter = adapter self._input = input self._input_source = input_source self._prior_trace = prior_trace + self._parent_task_run = parent_task_run self._task_run: TaskRun | None = None @property @@ -730,7 +745,7 @@ async def __aiter__(self) -> AsyncIterator[AiSdkStreamEvent]: yield AiSdkStreamEvent(AiSdkEventType.FINISH_STEP) self._task_run = self._adapter._finalize_stream( - adapter_stream, self._input, self._input_source + adapter_stream, self._input, self._input_source, self._parent_task_run ) for ai_event in converter.finalize(): diff --git a/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py index c488e7fc0..9b38ff9bb 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py @@ -1,10 +1,7 @@ import json from typing import Tuple -from kiln_ai.adapters.model_adapters.base_adapter import ( - AdapterConfig, - BaseAdapter, -) +from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig, BaseAdapter from kiln_ai.adapters.parsers.json_parser import parse_json_string from kiln_ai.adapters.run_output import RunOutput from kiln_ai.datamodel import DataSource, Task, TaskRun, Usage @@ -89,15 +86,16 @@ async def invoke( input: InputType, input_source: DataSource | None = None, prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> TaskRun: - if prior_trace: + if prior_trace or parent_task_run is not None: raise NotImplementedError( "Session continuation is not supported for MCP adapter. " "MCP tools are single-turn and do not maintain conversation state." 
) run_output, _ = await self.invoke_returning_run_output( - input, input_source, prior_trace + input, input_source, prior_trace, parent_task_run ) return run_output @@ -106,12 +104,13 @@ async def invoke_returning_run_output( input: InputType, input_source: DataSource | None = None, prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> Tuple[TaskRun, RunOutput]: """ Runs the task and returns both the persisted TaskRun and raw RunOutput. If this call is the root of a run, it creates an agent run context, ensures MCP tool calls have a valid session scope, and cleans up the session/context on completion. """ - if prior_trace: + if prior_trace or parent_task_run is not None: raise NotImplementedError( "Session continuation is not supported for MCP adapter. " "MCP tools are single-turn and do not maintain conversation state." @@ -124,7 +123,9 @@ async def invoke_returning_run_output( set_agent_run_id(run_id) try: - return await self._run_and_validate_output(input, input_source) + return await self._run_and_validate_output( + input, input_source, parent_task_run + ) finally: if is_root_agent: try: @@ -138,6 +139,7 @@ async def _run_and_validate_output( self, input: InputType, input_source: DataSource | None, + parent_task_run: TaskRun | None = None, ) -> Tuple[TaskRun, RunOutput]: """ Run the MCP task and validate the output. 
@@ -176,7 +178,9 @@ async def _run_and_validate_output( # Build single turn trace trace = self._build_single_turn_trace(input, run_output.output) - run = self.generate_run(input, input_source, run_output, usage, trace) + run = self.generate_run( + input, input_source, run_output, usage, trace, parent_task_run + ) if ( self.base_adapter_config.allow_saving diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py index cb0a3e94b..689801f86 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py @@ -415,3 +415,51 @@ async def test_mcp_adapter_rejects_prior_trace_in_run( assert "Session continuation is not supported" in str(exc_info.value) assert "MCP tools are single-turn" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_mcp_adapter_rejects_parent_task_run_in_invoke( + project_with_local_mcp_server, local_mcp_tool_id +): + """invoke with parent_task_run raises NotImplementedError for MCP adapter.""" + project, _ = project_with_local_mcp_server + task = Task( + name="Test Task", + parent=project, + instruction="Echo input", + ) + + run_config = McpRunConfigProperties( + tool_reference=MCPToolReference(tool_id=local_mcp_tool_id) + ) + + adapter = MCPAdapter(task=task, run_config=run_config) + + with pytest.raises(NotImplementedError) as exc_info: + await adapter.invoke("input", parent_task_run=MagicMock()) + + assert "Session continuation is not supported" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_mcp_adapter_rejects_parent_task_run_in_invoke_returning_run_output( + project_with_local_mcp_server, local_mcp_tool_id +): + """invoke_returning_run_output with parent_task_run raises NotImplementedError for MCP adapter.""" + project, _ = project_with_local_mcp_server + task = Task( + name="Test Task", + parent=project, + instruction="Echo input", + ) + + run_config = 
McpRunConfigProperties( + tool_reference=MCPToolReference(tool_id=local_mcp_tool_id) + ) + + adapter = MCPAdapter(task=task, run_config=run_config) + + with pytest.raises(NotImplementedError) as exc_info: + await adapter.invoke_returning_run_output("input", parent_task_run=MagicMock()) + + assert "Session continuation is not supported" in str(exc_info.value) diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py index db20fe5ea..c2ebf17a5 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py @@ -3,7 +3,7 @@ import pytest from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput -from kiln_ai.datamodel import DataSource, DataSourceType, Project, Task, Usage +from kiln_ai.datamodel import DataSource, DataSourceType, Project, Task, TaskRun, Usage from kiln_ai.datamodel.datamodel_enums import InputType from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties from kiln_ai.utils.config import Config @@ -425,3 +425,122 @@ def test_properties_for_task_output_custom_values(test_task): assert output.source.properties["structured_output_mode"] == "json_schema" assert output.source.properties["temperature"] == 0.7 assert output.source.properties["top_p"] == 0.9 + + +def test_generate_run_with_parent_task_run_sets_parent(test_task, adapter): + """Test that generate_run with parent_task_run uses it as parent instead of the task.""" + prior_run = adapter.generate_run( + input="prior input", + input_source=None, + run_output=RunOutput(output="prior output", intermediate_outputs=None), + ) + prior_run.save_to_file() + assert prior_run.id is not None + + new_run = adapter.generate_run( + input="new input", + input_source=None, + run_output=RunOutput(output="new output", intermediate_outputs=None), + parent_task_run=prior_run, + ) + + 
assert new_run.parent == prior_run + + new_run.save_to_file() + + reloaded_prior_run = TaskRun.load_from_file(prior_run.path) + child_runs = reloaded_prior_run.runs() + assert len(child_runs) == 1 + assert child_runs[0].output.output == "new output" + + # The task should only have the prior run as a direct child + reloaded_task = Task.load_from_file(test_task.path) + task_runs = reloaded_task.runs() + assert len(task_runs) == 1 + assert task_runs[0].id == prior_run.id + + +def test_generate_run_without_parent_task_run_defaults_to_task(test_task, adapter): + """Test that generate_run without parent_task_run defaults to using the task as parent.""" + run = adapter.generate_run( + input="input", + input_source=None, + run_output=RunOutput(output="output", intermediate_outputs=None), + ) + assert run.parent == test_task + + +@pytest.mark.asyncio +async def test_invoke_with_parent_task_run_saves_as_child(test_task, adapter): + """Test that invoke with parent_task_run saves the new run as a child of that run.""" + trace = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + # Create and save a prior run to act as parent + prior_run = adapter.generate_run( + input="Hello", + input_source=None, + run_output=RunOutput( + output="Hi there!", intermediate_outputs=None, trace=trace + ), + trace=trace, + ) + prior_run.save_to_file() + assert prior_run.id is not None + + continuation_trace = [ + *trace, + {"role": "user", "content": "Tell me more"}, + {"role": "assistant", "content": "More details!"}, + ] + continuation_output = RunOutput( + output="More details!", + intermediate_outputs=None, + trace=continuation_trace, + ) + + adapter._run = AsyncMock(return_value=(continuation_output, None)) + + with ( + patch("kiln_ai.utils.config.Config.shared") as mock_shared, + patch.object( + adapter, + "model_provider", + return_value=MagicMock( + parser="default", + formatter=None, + reasoning_capable=False, + ), + ), + patch( + 
"kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id", + return_value=MagicMock( + parse_output=MagicMock(return_value=continuation_output) + ), + ), + ): + mock_shared.return_value.autosave_runs = True + mock_shared.return_value.user_id = "test_user" + + new_run = await adapter.invoke( + "Tell me more", + prior_trace=trace, + parent_task_run=prior_run, + ) + + assert new_run.id is not None + assert new_run.parent == prior_run + + # The prior run should have the new run as a child + reloaded_prior_run = TaskRun.load_from_file(prior_run.path) + child_runs = reloaded_prior_run.runs() + assert len(child_runs) == 1 + assert child_runs[0].output.output == "More details!" + + # The task should only have the prior run as a direct child + reloaded_task = Task.load_from_file(test_task.path) + task_runs = reloaded_task.runs() + assert len(task_runs) == 1 + assert task_runs[0].id == prior_run.id