diff --git a/libs/core/kiln_ai/adapters/chat/__init__.py b/libs/core/kiln_ai/adapters/chat/__init__.py index 7ab6328f0..11b5eda12 100644 --- a/libs/core/kiln_ai/adapters/chat/__init__.py +++ b/libs/core/kiln_ai/adapters/chat/__init__.py @@ -1,8 +1,10 @@ from .chat_formatter import ( BasicChatMessage, + ChatCompletionMessageIncludingLiteLLM, ChatFormatter, ChatMessage, ChatStrategy, + MultiturnFormatter, ToolCallMessage, ToolResponseMessage, get_chat_formatter, @@ -11,9 +13,11 @@ __all__ = [ "BasicChatMessage", + "ChatCompletionMessageIncludingLiteLLM", "ChatFormatter", "ChatMessage", "ChatStrategy", + "MultiturnFormatter", "ToolCallMessage", "ToolResponseMessage", "build_tool_call_messages", diff --git a/libs/core/kiln_ai/adapters/chat/chat_formatter.py b/libs/core/kiln_ai/adapters/chat/chat_formatter.py index d1a73514f..22ba371a3 100644 --- a/libs/core/kiln_ai/adapters/chat/chat_formatter.py +++ b/libs/core/kiln_ai/adapters/chat/chat_formatter.py @@ -3,15 +3,25 @@ import json from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, List, Literal, Optional, Sequence, Union +from typing import Dict, List, Literal, Optional, Sequence, TypeAlias, Union + +from litellm.types.utils import Message as LiteLLMMessage from kiln_ai.datamodel.datamodel_enums import ChatStrategy, InputType from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error -from kiln_ai.utils.open_ai_types import ChatCompletionMessageToolCallParam +from kiln_ai.utils.open_ai_types import ( + ChatCompletionMessageParam, + ChatCompletionMessageToolCallParam, +) COT_FINAL_ANSWER_PROMPT = "Considering the above, return a final result." +ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[ + ChatCompletionMessageParam, LiteLLMMessage +] + + @dataclass class BasicChatMessage: role: Literal["system", "assistant", "user"] @@ -90,6 +100,10 @@ def intermediate_outputs(self) -> Dict[str, str]: """Get the intermediate outputs from the chat formatter.""" return self._intermediate_outputs + def initial_messages(self) -> list[ChatCompletionMessageIncludingLiteLLM]: + """Messages to seed the conversation. Empty for fresh runs; prior trace for continuation.""" + return [] + @abstractmethod def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]: """Advance the conversation and return the next messages if any.""" @@ -236,6 +250,49 @@ def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]: return None +class MultiturnFormatter(ChatFormatter): + """ + Formatter for continuing a multi-turn conversation with prior trace. + Takes prior_trace (existing conversation) and appends the new user message. + Produces a single turn: the new user message. Tool calls and multi-turn + model responses are handled by _run_model_turn's internal loop. 
+ """ + + def __init__( + self, + prior_trace: list[ChatCompletionMessageParam], + user_input: InputType, + ) -> None: + super().__init__( + system_message="", + user_input=user_input, + thinking_instructions=None, + ) + self._prior_trace = prior_trace + + def initial_messages(self) -> list[ChatCompletionMessageIncludingLiteLLM]: + """Messages to seed the conversation (prior trace).""" + return list(self._prior_trace) + + def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]: + if self._state == "start": + # prior trace is already in the messages list and contains system and so on, we only need + # to append the latest new user message + user_msg = BasicChatMessage("user", format_user_message(self.user_input)) + self._state = "awaiting_final" + self._messages.append(user_msg) + return ChatTurn(messages=[user_msg], final_call=True) + + if self._state == "awaiting_final": + if previous_output is None: + raise ValueError("previous_output required for final step") + self._messages.append(BasicChatMessage("assistant", previous_output)) + self._state = "done" + return None + + return None + + def get_chat_formatter( strategy: ChatStrategy, system_message: str, diff --git a/libs/core/kiln_ai/adapters/chat/test_chat_formatter.py b/libs/core/kiln_ai/adapters/chat/test_chat_formatter.py index 642d49236..e0138a954 100644 --- a/libs/core/kiln_ai/adapters/chat/test_chat_formatter.py +++ b/libs/core/kiln_ai/adapters/chat/test_chat_formatter.py @@ -1,6 +1,7 @@ from kiln_ai.adapters.chat import ChatStrategy, get_chat_formatter from kiln_ai.adapters.chat.chat_formatter import ( COT_FINAL_ANSWER_PROMPT, + MultiturnFormatter, format_user_message, ) @@ -119,6 +120,32 @@ def test_chat_formatter_r1_style(): assert formatter.intermediate_outputs() == {} +def test_multiturn_formatter_initial_messages(): + prior_trace = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + formatter = MultiturnFormatter(prior_trace=prior_trace, user_input="new input") + assert formatter.initial_messages() == prior_trace + + +def test_multiturn_formatter_next_turn(): + prior_trace = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + formatter = MultiturnFormatter(prior_trace=prior_trace, user_input="follow-up") + + first = formatter.next_turn() + assert first is not None + assert len(first.messages) == 1 + assert first.messages[0].role == "user" + assert first.messages[0].content == "follow-up" + assert first.final_call + + assert formatter.next_turn("assistant response") is None + + def test_format_user_message(): # String assert format_user_message("test input") == "test input" diff --git a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py index 134ebc7d3..56cfe1e54 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py @@ -3,7 +3,11 @@ from dataclasses import dataclass from typing import Dict, Tuple -from kiln_ai.adapters.chat.chat_formatter import ChatFormatter, get_chat_formatter +from kiln_ai.adapters.chat.chat_formatter import ( + ChatFormatter, + MultiturnFormatter, + get_chat_formatter, +) from kiln_ai.adapters.ml_model_list import ( KilnModelProvider, StructuredOutputMode, @@ -123,14 +127,18 @@ async def invoke( self, input: InputType, input_source: DataSource | None = None, + existing_run: TaskRun | None = None, ) -> TaskRun: - run_output, _ = await 
self.invoke_returning_run_output(input, input_source) + run_output, _ = await self.invoke_returning_run_output( + input, input_source, existing_run + ) return run_output async def _run_returning_run_output( self, input: InputType, input_source: DataSource | None = None, + existing_run: TaskRun | None = None, ) -> Tuple[TaskRun, RunOutput]: # validate input, allowing arrays if self.input_schema is not None: @@ -141,6 +149,15 @@ async def _run_returning_run_output( require_object=False, ) + if existing_run is not None and ( + not existing_run.trace or len(existing_run.trace) == 0 + ): + raise ValueError( + "Run has no trace. Cannot continue session without conversation history." + ) + + prior_trace = existing_run.trace if existing_run else None + # Format model input for model call (we save the original input in the task without formatting) formatted_input = input formatter_id = self.model_provider().formatter @@ -149,7 +166,7 @@ async def _run_returning_run_output( formatted_input = formatter.format_input(input) # Run - run_output, usage = await self._run(formatted_input) + run_output, usage = await self._run(formatted_input, prior_trace=prior_trace) # Parse provider = self.model_provider() @@ -198,10 +215,28 @@ async def _run_returning_run_output( "Reasoning is required for this model, but no reasoning was returned." ) - # Generate the run and output - run = self.generate_run( - input, input_source, parsed_output, usage, run_output.trace - ) + # Create the run and output - merge if there is an existing run + if existing_run is not None: + merged_output = RunOutput( + output=parsed_output.output, + intermediate_outputs=parsed_output.intermediate_outputs + or run_output.intermediate_outputs, + output_logprobs=parsed_output.output_logprobs + or run_output.output_logprobs, + trace=run_output.trace, + ) + run = self.generate_run( + input, + input_source, + merged_output, + usage, + run_output.trace, + existing_run=existing_run, + ) + else: + run = self.generate_run( + input, input_source, parsed_output, usage, run_output.trace + ) # Save the run if configured to do so, and we have a path to save to if ( @@ -210,7 +245,7 @@ async def _run_returning_run_output( and self.task.path is not None ): run.save_to_file() - else: + elif existing_run is None: # Clear the ID to indicate it's not persisted run.id = None @@ -220,6 +255,7 @@ async def invoke_returning_run_output( self, input: InputType, input_source: DataSource | None = None, + existing_run: TaskRun | None = None, ) -> Tuple[TaskRun, RunOutput]: # Determine if this is the root agent (no existing run context) is_root_agent = get_agent_run_id() is None @@ -229,7 +265,9 @@ async def invoke_returning_run_output( set_agent_run_id(run_id) try: - return await self._run_returning_run_output(input, input_source) + return await self._run_returning_run_output( + input, input_source, existing_run + ) finally: if is_root_agent: try: @@ -247,7 +285,11 @@ def adapter_name(self) -> str: pass @abstractmethod - async def _run(self, input: InputType) -> Tuple[RunOutput, Usage | None]: + async def _run( + self, + input: InputType, + prior_trace: list[ChatCompletionMessageParam] | None = None, + ) -> Tuple[RunOutput, Usage | None]: pass def build_prompt(self) -> str: @@ -267,7 +309,14 @@ def build_prompt(self) -> str: include_json_instructions=add_json_instructions ) - def build_chat_formatter(self, input: InputType) -> ChatFormatter: + def build_chat_formatter( + self, + input: InputType, + prior_trace: list[ChatCompletionMessageParam] | None = None, + ) -> 
ChatFormatter: + if prior_trace is not None: + return MultiturnFormatter(prior_trace, input) + if self.prompt_builder is None: raise ValueError("Prompt builder is not available for MCP run config") # Determine the chat strategy to use based on the prompt the user selected, the model's capabilities, and if the model was finetuned with a specific chat strategy. @@ -323,24 +372,14 @@ def generate_run( run_output: RunOutput, usage: Usage | None = None, trace: list[ChatCompletionMessageParam] | None = None, + existing_run: TaskRun | None = None, ) -> TaskRun: - # Convert input and output to JSON strings if they aren't strings - input_str = ( - input if isinstance(input, str) else json.dumps(input, ensure_ascii=False) - ) output_str = ( json.dumps(run_output.output, ensure_ascii=False) if isinstance(run_output.output, dict) else run_output.output ) - # If no input source is provided, use the human data source - if input_source is None: - input_source = DataSource( - type=DataSourceType.human, - properties={"created_by": Config.shared().user_id}, - ) - # Synthetic since an adapter, not a human, is creating this # Special case for MCP run configs which calls a mcp tool output_source_type = ( @@ -349,26 +388,41 @@ def generate_run( else DataSourceType.synthetic ) - new_task_run = TaskRun( + new_output = TaskOutput( + output=output_str, + source=DataSource( + type=output_source_type, + properties=self._properties_for_task_output(), + run_config=self.run_config, + ), + ) + + final_usage = usage + final_intermediate = run_output.intermediate_outputs + if existing_run is not None: + final_usage = (existing_run.usage or Usage()) + (usage or Usage()) + final_intermediate = run_output.intermediate_outputs + + input_str = ( + input if isinstance(input, str) else json.dumps(input, ensure_ascii=False) + ) + if input_source is None: + input_source = DataSource( + type=DataSourceType.human, + properties={"created_by": Config.shared().user_id}, + ) + + return TaskRun( parent=self.task, input=input_str, input_source=input_source, - output=TaskOutput( - output=output_str, - source=DataSource( - type=output_source_type, - properties=self._properties_for_task_output(), - run_config=self.run_config, - ), - ), - intermediate_outputs=run_output.intermediate_outputs, + output=new_output, + intermediate_outputs=final_intermediate, tags=self.base_adapter_config.default_tags or [], - usage=usage, + usage=final_usage, trace=trace, ) - return new_task_run - def _properties_for_task_output(self) -> Dict[str, str | int | float]: match self.run_config.type: case "mcp": diff --git a/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py index 41a8e64c0..131be097c 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/litellm_adapter.py @@ -3,7 +3,7 @@ import json import logging from dataclasses import dataclass -from typing import Any, Dict, List, Tuple, TypeAlias, Union +from typing import Any, Dict, List, Tuple import litellm from litellm.types.utils import ( @@ -19,6 +19,7 @@ ) import kiln_ai.datamodel as datamodel +from kiln_ai.adapters.chat import ChatCompletionMessageIncludingLiteLLM from kiln_ai.adapters.ml_model_list import ( KilnModelProvider, ModelProviderName, @@ -56,10 +57,6 @@ logger = logging.getLogger(__name__) -ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[ - ChatCompletionMessageParam, LiteLLMMessage -] - @dataclass class ModelTurnResult: @@ -184,20 +181,29 @@ async 
def _run_model_turn( f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens." ) - async def _run(self, input: InputType) -> tuple[RunOutput, Usage | None]: + async def _run( + self, + input: InputType, + prior_trace: list[ChatCompletionMessageParam] | None = None, + ) -> tuple[RunOutput, Usage | None]: usage = Usage() provider = self.model_provider() if not provider.model_id: raise ValueError("Model ID is required for OpenAI compatible models") - chat_formatter = self.build_chat_formatter(input) - messages: list[ChatCompletionMessageIncludingLiteLLM] = [] + # build_chat_formatter returns MultiturnFormatter when prior_trace is set, else prompt-based formatter + chat_formatter = self.build_chat_formatter(input, prior_trace) + messages: list[ChatCompletionMessageIncludingLiteLLM] = copy.deepcopy( + chat_formatter.initial_messages() + ) prior_output: str | None = None final_choice: Choices | None = None turns = 0 + # Same loop for both fresh runs and prior_trace continuation. + # _run_model_turn has its own internal loop for tool calls (model calls tool -> we run it -> model continues). while True: turns += 1 if turns > MAX_CALLS_PER_TURN: diff --git a/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py index a7f1a8b6c..45aabc53e 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py @@ -42,7 +42,17 @@ def __init__( def adapter_name(self) -> str: return "mcp_adapter" - async def _run(self, input: InputType) -> Tuple[RunOutput, Usage | None]: + async def _run( + self, + input: InputType, + prior_trace: list[ChatCompletionMessageParam] | None = None, + ) -> Tuple[RunOutput, Usage | None]: + if prior_trace is not None: + raise NotImplementedError( + "Session continuation is not supported for MCP adapter. " + "MCP tools are single-turn and do not maintain conversation state." + ) + run_config = self.run_config if not isinstance(run_config, McpRunConfigProperties): raise ValueError("MCPAdapter requires McpRunConfigProperties") @@ -75,19 +85,35 @@ async def invoke( self, input: InputType, input_source: DataSource | None = None, + existing_run: TaskRun | None = None, ) -> TaskRun: - run_output, _ = await self.invoke_returning_run_output(input, input_source) + if existing_run is not None: + raise NotImplementedError( + "Session continuation is not supported for MCP adapter. " + "MCP tools are single-turn and do not maintain conversation state." + ) + + run_output, _ = await self.invoke_returning_run_output( + input, input_source, existing_run + ) return run_output async def invoke_returning_run_output( self, input: InputType, input_source: DataSource | None = None, + existing_run: TaskRun | None = None, ) -> Tuple[TaskRun, RunOutput]: """ Runs the task and returns both the persisted TaskRun and raw RunOutput. If this call is the root of a run, it creates an agent run context, ensures MCP tool calls have a valid session scope, and cleans up the session/context on completion. """ + if existing_run is not None: + raise NotImplementedError( + "Session continuation is not supported for MCP adapter. " + "MCP tools are single-turn and do not maintain conversation state." 
+ ) + is_root_agent = get_agent_run_id() is None if is_root_agent: diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py index d2a59a70f..8b150c68e 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py @@ -9,8 +9,14 @@ RunOutput, ) from kiln_ai.adapters.prompt_builders import BasePromptBuilder -from kiln_ai.datamodel import Task -from kiln_ai.datamodel.datamodel_enums import ChatStrategy +from kiln_ai.datamodel import ( + DataSource, + DataSourceType, + Task, + TaskOutput, + TaskRun, +) +from kiln_ai.datamodel.datamodel_enums import ChatStrategy, ModelProviderName from kiln_ai.datamodel.project import Project from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties, ToolsRunConfig from kiln_ai.datamodel.tool_id import KilnBuiltInToolId @@ -20,7 +26,7 @@ class MockAdapter(BaseAdapter): """Concrete implementation of BaseAdapter for testing""" - async def _run(self, input): + async def _run(self, input, prior_trace=None): return None, None def adapter_name(self) -> str: @@ -233,7 +239,7 @@ async def test_input_formatting( # Mock the _run method to capture the input captured_input = None - async def mock_run(input): + async def mock_run(input, prior_trace=None): nonlocal captured_input captured_input = input return RunOutput(output="test output", intermediate_outputs={}), None @@ -420,6 +426,117 @@ def test_build_chat_formatter( mock_prompt_builder.chain_of_thought_prompt.assert_called_once() +def test_build_chat_formatter_with_prior_trace_returns_multiturn_formatter(adapter): + prior_trace = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + formatter = adapter.build_chat_formatter("new input", prior_trace=prior_trace) + assert formatter.__class__.__name__ == "MultiturnFormatter" + assert formatter.initial_messages() == prior_trace + + +@pytest.mark.asyncio +async def test_existing_run_without_trace_raises(base_project): + task = Task( + name="test_task", + instruction="test_instruction", + parent=base_project, + ) + adapter = MockAdapter( + task=task, + run_config=KilnAgentRunConfigProperties( + model_name="gpt_4o", + model_provider_name=ModelProviderName.openai, + prompt_id="simple_prompt_builder", + structured_output_mode=StructuredOutputMode.json_schema, + ), + ) + run_without_trace = TaskRun( + parent=task, + input="hi", + input_source=None, + output=TaskOutput( + output="hello", + source=DataSource( + type=DataSourceType.synthetic, + properties={ + "model_name": "gpt_4o", + "model_provider": "openai", + "adapter_name": "test", + }, + ), + ), + trace=None, + ) + with pytest.raises(ValueError, match="no trace"): + await adapter.invoke("input", existing_run=run_without_trace) + + +@pytest.mark.asyncio +async def test_invoke_returning_run_output_passes_existing_run_to_run( + adapter, mock_parser, tmp_path +): + project = Project(name="proj", path=tmp_path / "proj.kiln") + project.save_to_file() + task = Task( + name="t", + instruction="i", + parent=project, + ) + task.save_to_file() + adapter.task = task + + trace = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + initial_run = adapter.generate_run( + input="hi", + input_source=None, + run_output=RunOutput( + output="hello", + intermediate_outputs=None, + trace=trace, + ), + trace=trace, + ) + initial_run.save_to_file() + run_id = initial_run.id + assert run_id is not None + 
+ captured_prior_trace = None + + async def mock_run(input, prior_trace=None): + nonlocal captured_prior_trace + captured_prior_trace = prior_trace + return RunOutput(output="ok", intermediate_outputs=None, trace=trace), None + + adapter._run = mock_run + + provider = MagicMock() + provider.parser = "test_parser" + provider.formatter = None + provider.reasoning_capable = False + adapter.model_provider = MagicMock(return_value=provider) + mock_parser.parse_output.return_value = RunOutput( + output="ok", intermediate_outputs=None, trace=trace + ) + + with ( + patch( + "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id", + return_value=mock_parser, + ), + patch( + "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id", + ), + ): + await adapter.invoke_returning_run_output("follow-up", existing_run=initial_run) + + assert captured_prior_trace == trace + + @pytest.mark.parametrize( "initial_mode,expected_mode", [ @@ -681,7 +798,7 @@ async def test_invoke_sets_run_context(self, adapter, clear_context): from kiln_ai.run_context import get_agent_run_id # Mock the _run method - async def mock_run(input): + async def mock_run(input, prior_trace=None): # Check that run ID is set during _run run_id = get_agent_run_id() assert run_id is not None @@ -721,7 +838,7 @@ async def test_invoke_clears_run_context_after(self, adapter, clear_context): from kiln_ai.run_context import get_agent_run_id # Mock the _run method - async def mock_run(input): + async def mock_run(input, prior_trace=None): return RunOutput(output="test output", intermediate_outputs={}), None adapter._run = mock_run @@ -759,7 +876,7 @@ async def test_invoke_clears_run_context_on_error(self, adapter, clear_context): from kiln_ai.run_context import get_agent_run_id # Mock the _run method to raise an error - async def mock_run(input): + async def mock_run(input, prior_trace=None): # Run ID should be set even when error occurs run_id = get_agent_run_id() assert run_id is not None @@ -796,7 +913,7 @@ async def test_sub_agent_inherits_run(self, adapter, clear_context): set_agent_run_id(parent_run_id) # Mock the _run method to check inherited run ID - async def mock_run(input): + async def mock_run(input, prior_trace=None): # Sub-agent should see parent's run ID run_id = get_agent_run_id() assert run_id == parent_run_id @@ -845,7 +962,7 @@ async def test_sub_agent_does_not_create_new_run(self, adapter, clear_context): run_id_during_run = None # Mock the _run method to capture run ID - async def mock_run(input): + async def mock_run(input, prior_trace=None): nonlocal run_id_during_run run_id_during_run = get_agent_run_id() return RunOutput(output="test output", intermediate_outputs={}), None @@ -885,7 +1002,7 @@ async def test_cleanup_session_called_on_completion(self, adapter, clear_context from kiln_ai.adapters.run_output import RunOutput # Mock the _run method - async def mock_run(input): + async def mock_run(input, prior_trace=None): return RunOutput(output="test output", intermediate_outputs={}), None adapter._run = mock_run diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py index 7d0414b99..3b44812a6 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py @@ -7,7 +7,10 @@ from kiln_ai.adapters.ml_model_list import ModelProviderName, StructuredOutputMode from kiln_ai.adapters.model_adapters.base_adapter import 
AdapterConfig -from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter +from kiln_ai.adapters.model_adapters.litellm_adapter import ( + LiteLlmAdapter, + ModelTurnResult, +) from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig from kiln_ai.datamodel import Project, Task, Usage from kiln_ai.datamodel.run_config import ( @@ -1301,3 +1304,61 @@ async def test_dict_input_converted_to_json(tmp_path, config): assert isinstance(content, str) parsed_content = json.loads(content) assert parsed_content == {"x": 10, "y": 20} + + +@pytest.mark.asyncio +async def test_run_with_prior_trace_uses_multiturn_formatter(mock_task): + config = LiteLlmConfig( + base_url="https://api.test.com", + run_config_properties=KilnAgentRunConfigProperties( + model_name="test-model", + model_provider_name="openai_compatible", + prompt_id="simple_prompt_builder", + structured_output_mode="json_schema", + ), + default_headers={"X-Test": "test"}, + additional_body_options={"api_key": "test_key"}, + ) + prior_trace = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + adapter = LiteLlmAdapter(config=config, kiln_task=mock_task) + + build_chat_formatter_calls = [] + + original_build = adapter.build_chat_formatter + + def capturing_build(input, prior_trace_arg=None): + build_chat_formatter_calls.append((input, prior_trace_arg)) + return original_build(input, prior_trace_arg) + + adapter.build_chat_formatter = capturing_build + + async def mock_run_model_turn( + provider, prior_messages, top_logprobs, skip_response_format + ): + extended = list(prior_messages) + extended.append({"role": "assistant", "content": "How can I help?"}) + return ModelTurnResult( + assistant_message="How can I help?", + all_messages=extended, + model_response=None, + model_choice=None, + usage=Usage(), + ) + + adapter._run_model_turn = mock_run_model_turn + + run_output, _ = await adapter._run("follow-up", prior_trace=prior_trace) + + assert len(build_chat_formatter_calls) == 1 + assert build_chat_formatter_calls[0][0] == "follow-up" + assert build_chat_formatter_calls[0][1] == prior_trace + + assert run_output.trace is not None + assert len(run_output.trace) == 4 + assert run_output.trace[0]["content"] == "hi" + assert run_output.trace[1]["content"] == "hello" + assert run_output.trace[2]["content"] == "follow-up" + assert run_output.trace[3]["content"] == "How can I help?" 
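Reviewer note (not part of the patch): the test above drives `_run` with a prior trace; underneath, the continuation path reduces to the `MultiturnFormatter` contract introduced earlier in this diff. A minimal standalone sketch of that contract follows — the message values are illustrative, the names (`MultiturnFormatter`, `initial_messages`, `next_turn`, `final_call`) come from the diff itself.

from kiln_ai.adapters.chat import MultiturnFormatter

# Prior conversation in the OpenAI-style message shape the adapters persist on TaskRun.trace.
prior_trace = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
]

formatter = MultiturnFormatter(prior_trace=prior_trace, user_input="tell me more")

# The adapter seeds its message list from the prior trace...
messages = list(formatter.initial_messages())

# ...and the formatter emits exactly one new turn: the fresh user message, marked as the final call.
turn = formatter.next_turn()
assert turn is not None and turn.final_call
messages.extend({"role": m.role, "content": m.content} for m in turn.messages)

# Feeding the model's reply back closes the conversation; no further turns are produced.
assert formatter.next_turn("happy to elaborate") is None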
diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py index 393a3859f..ad5826d8c 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from mcp.types import CallToolResult, TextContent @@ -328,3 +328,88 @@ async def test_mcp_adapter_sets_and_clears_run_context( await adapter.invoke_returning_run_output("input") assert get_agent_run_id() is None + + +@pytest.mark.asyncio +async def test_mcp_adapter_rejects_multiturn_invoke_returning_run_output( + project_with_local_mcp_server, local_mcp_tool_id +): + """Session continuation (existing_run) is not supported for MCP adapter.""" + project, _ = project_with_local_mcp_server + task = Task( + name="Test Task", + parent=project, + instruction="Echo input", + ) + + run_config = McpRunConfigProperties( + tool_reference=MCPToolReference(tool_id=local_mcp_tool_id) + ) + + adapter = MCPAdapter(task=task, run_config=run_config) + + existing_run = MagicMock() + existing_run.trace = [{"role": "user", "content": "hi"}] + + with pytest.raises(NotImplementedError) as exc_info: + await adapter.invoke_returning_run_output("input", existing_run=existing_run) + + assert "Session continuation is not supported" in str(exc_info.value) + assert "MCP adapter" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_mcp_adapter_rejects_multiturn_invoke( + project_with_local_mcp_server, local_mcp_tool_id +): + """invoke with existing_run raises NotImplementedError for MCP adapter.""" + project, _ = project_with_local_mcp_server + task = Task( + name="Test Task", + parent=project, + instruction="Echo input", + ) + + run_config = McpRunConfigProperties( + tool_reference=MCPToolReference(tool_id=local_mcp_tool_id) + ) + + adapter = MCPAdapter(task=task, run_config=run_config) + + existing_run = MagicMock() + existing_run.trace = [{"role": "user", "content": "hi"}] + + with pytest.raises(NotImplementedError) as exc_info: + await adapter.invoke("input", existing_run=existing_run) + + assert "Session continuation is not supported" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_mcp_adapter_rejects_prior_trace_in_run( + project_with_local_mcp_server, local_mcp_tool_id +): + """_run with prior_trace raises NotImplementedError for MCP adapter.""" + project, _ = project_with_local_mcp_server + task = Task( + name="Test Task", + parent=project, + instruction="Echo input", + ) + + run_config = McpRunConfigProperties( + tool_reference=MCPToolReference(tool_id=local_mcp_tool_id) + ) + + adapter = MCPAdapter(task=task, run_config=run_config) + + prior_trace = [ + {"role": "user", "content": "first message"}, + {"role": "assistant", "content": "first response"}, + ] + + with pytest.raises(NotImplementedError) as exc_info: + await adapter._run("follow-up message", prior_trace=prior_trace) + + assert "Session continuation is not supported" in str(exc_info.value) + assert "MCP tools are single-turn" in str(exc_info.value) diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py index 9cc5fa5d9..00d4622a5 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +++ 
b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest @@ -10,7 +10,11 @@ class MockAdapter(BaseAdapter): - async def _run(self, input: InputType) -> tuple[RunOutput, Usage | None]: + async def _run( + self, + input: InputType, + prior_trace=None, + ) -> tuple[RunOutput, Usage | None]: return RunOutput(output="Test output", intermediate_outputs=None), None def adapter_name(self) -> str: @@ -233,6 +237,163 @@ async def test_autosave_true(test_task, adapter): assert output.source.properties["top_p"] == 1.0 +@pytest.mark.asyncio +async def test_invoke_continue_session(test_task, adapter): + """Test that invoke with existing_run continues a session and creates a new run.""" + with patch("kiln_ai.utils.config.Config.shared") as mock_shared: + mock_config = mock_shared.return_value + mock_config.autosave_runs = True + mock_config.user_id = "test_user" + + trace = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + initial_run = adapter.generate_run( + input="Hello", + input_source=None, + run_output=RunOutput( + output="Hi there!", + intermediate_outputs=None, + trace=trace, + ), + trace=trace, + ) + initial_run.save_to_file() + run_id = initial_run.id + assert run_id is not None + + async def mock_run(input, prior_trace=None): + if prior_trace is not None: + extended_trace = [ + *prior_trace, + {"role": "user", "content": input}, + {"role": "assistant", "content": "How can I help?"}, + ] + return ( + RunOutput( + output="How can I help?", + intermediate_outputs=None, + trace=extended_trace, + ), + None, + ) + return RunOutput(output="Test output", intermediate_outputs=None), None + + adapter._run = mock_run + + with ( + patch.object( + adapter, + "model_provider", + return_value=MagicMock( + parser="default", + formatter=None, + reasoning_capable=False, + ), + ), + patch( + "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id" + ) as mock_parser_from_id, + ): + mock_parser = MagicMock() + mock_parser.parse_output.return_value = RunOutput( + output="How can I help?", + intermediate_outputs=None, + trace=[ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "Tell me more"}, + {"role": "assistant", "content": "How can I help?"}, + ], + ) + mock_parser_from_id.return_value = mock_parser + + updated_run = await adapter.invoke("Tell me more", existing_run=initial_run) + + assert updated_run.id != run_id + assert updated_run.input == "Tell me more" + assert updated_run.output.output == "How can I help?" + assert len(updated_run.trace) == 4 + assert updated_run.trace[-2]["content"] == "Tell me more" + assert updated_run.trace[-1]["content"] == "How can I help?" + + reloaded = Task.load_from_file(test_task.path) + runs = reloaded.runs() + assert len(runs) == 2 + initial_run_reloaded = next(r for r in runs if r.id == run_id) + continued_run = next(r for r in runs if r.id == updated_run.id) + assert initial_run_reloaded.output.output == "Hi there!" + assert continued_run.output.output == "How can I help?" 
+ + +@pytest.mark.asyncio +async def test_invoke_continue_run_without_trace(test_task, adapter): + """Test that invoke with existing_run that has no trace raises ValueError.""" + with patch("kiln_ai.utils.config.Config.shared") as mock_shared: + mock_config = mock_shared.return_value + mock_config.autosave_runs = True + mock_config.user_id = "test_user" + + run_without_trace = adapter.generate_run( + input="Hello", + input_source=None, + run_output=RunOutput( + output="Hi", + intermediate_outputs=None, + trace=None, + ), + ) + run_without_trace.save_to_file() + + with pytest.raises(ValueError, match="no trace"): + await adapter.invoke("Follow up", existing_run=run_without_trace) + + +def test_generate_run_with_existing_run_merges_usage_and_intermediate_outputs( + test_task, adapter +): + trace = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + initial_run = adapter.generate_run( + input="hi", + input_source=None, + run_output=RunOutput( + output="hello", + intermediate_outputs={"chain_of_thought": "old"}, + trace=trace, + ), + usage=Usage(input_tokens=10, output_tokens=20), + trace=trace, + ) + extended_trace = [ + *trace, + {"role": "user", "content": "follow-up"}, + {"role": "assistant", "content": "ok"}, + ] + result = adapter.generate_run( + input="follow-up", + input_source=None, + run_output=RunOutput( + output="ok", + intermediate_outputs={"new_key": "new_val"}, + trace=extended_trace, + ), + usage=Usage(input_tokens=5, output_tokens=10), + trace=extended_trace, + existing_run=initial_run, + ) + assert result is not initial_run + assert result.id != initial_run.id + assert result.input == "follow-up" + assert result.usage.input_tokens == 15 + assert result.usage.output_tokens == 30 + assert result.intermediate_outputs == {"new_key": "new_val"} + assert result.output.output == "ok" + + def test_properties_for_task_output_custom_values(test_task): """Test that _properties_for_task_output includes custom temperature, top_p, and structured_output_mode""" adapter = MockAdapter( diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py b/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py index fd5451705..92d9a1f4d 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py @@ -53,7 +53,11 @@ def __init__(self, kiln_task: datamodel.Task, response: InputType | None): ) self.response = response - async def _run(self, input: str) -> tuple[RunOutput, Usage | None]: + async def _run( + self, + input: str, + prior_trace=None, + ) -> tuple[RunOutput, Usage | None]: return RunOutput(output=self.response, intermediate_outputs=None), None def adapter_name(self) -> str: diff --git a/libs/core/kiln_ai/adapters/test_prompt_builders.py b/libs/core/kiln_ai/adapters/test_prompt_builders.py index 0f6485e73..25cadd99d 100644 --- a/libs/core/kiln_ai/adapters/test_prompt_builders.py +++ b/libs/core/kiln_ai/adapters/test_prompt_builders.py @@ -58,7 +58,11 @@ def test_simple_prompt_builder(tmp_path): class MockAdapter(BaseAdapter): - async def _run(self, input: InputType) -> tuple[RunOutput, Usage | None]: + async def _run( + self, + input: InputType, + prior_trace=None, + ) -> tuple[RunOutput, Usage | None]: return RunOutput(output="mock response", intermediate_outputs=None), None def adapter_name(self) -> str: diff --git a/libs/core/kiln_ai/datamodel/test_basemodel.py b/libs/core/kiln_ai/datamodel/test_basemodel.py index 
92fadb7fc..03c6dfeaf 100644 --- a/libs/core/kiln_ai/datamodel/test_basemodel.py +++ b/libs/core/kiln_ai/datamodel/test_basemodel.py @@ -862,7 +862,7 @@ def individual_lookups(): class MockAdapter(BaseAdapter): """Implementation of BaseAdapter for testing""" - async def _run(self, input): + async def _run(self, input, prior_trace=None): return RunOutput(output="test output", intermediate_outputs=None), None def adapter_name(self) -> str: diff --git a/libs/server/kiln_server/test_run_api.py b/libs/server/kiln_server/test_run_api.py index 14d37371c..d7bbd6701 100644 --- a/libs/server/kiln_server/test_run_api.py +++ b/libs/server/kiln_server/test_run_api.py @@ -1,5 +1,7 @@ import logging +import os import time +from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -17,6 +19,8 @@ TaskOutputRatingType, TaskRun, ) +from kiln_ai.datamodel.tool_id import KilnBuiltInToolId +from kiln_ai.utils.config import Config from kiln_server.custom_errors import connect_custom_errors from kiln_server.run_api import ( @@ -95,6 +99,10 @@ def task_run_setup(tmp_path): }, ), ), + trace=[ + {"role": "user", "content": "Test input"}, + {"role": "assistant", "content": "Test output"}, + ], ) task_run.save_to_file() @@ -1663,3 +1671,153 @@ async def test_benchmark_tag_runs(client, task_run_setup, run_count): logger.info( f"Performance: {runs_per_second:.1f} runs/second, {avg_time_per_run * 1000:.2f}ms per run" ) + + +def _adapter_sanity_check_output_path() -> Path: + return Path(__file__).resolve().parent / "adapter_sanity_check.txt" + + +def _append_to_sanity_check(content: str, output_path: Path) -> None: + with open(output_path, "a", encoding="utf-8") as f: + f.write(content) + f.write("\n") + + +@pytest.fixture +def adapter_sanity_check_setup(tmp_path): + """Setup for paid adapter sanity check tests - real project/task, no adapter mocking.""" + # if project at the path does not exist, create it, otherwise reuse + project_path = ( + Path("/Users/leonardmarcq/Downloads/") + / "adapter_sanity_project" + / "project.kiln" + ) + if not project_path.exists(): + project_path.parent.mkdir() + + project = Project(name="Adapter Sanity Project", path=str(project_path)) + project.save_to_file() + + task = Task( + name="Adapter Sanity Task", + instruction="You are a helpful assistant. Respond concisely.", + description="Task for adapter sanity checking", + parent=project, + ) + task.save_to_file() + + else: + project = Project.load_from_file(project_path) + task = next( + ( + t + for t in project.tasks(readonly=True) + if t.name == "Adapter Sanity Task" + ), + None, + ) + if task is None: + raise ValueError("Task not found") + + config = Config.shared() + original_projects = list(config.projects) if config.projects else [] + config._settings["projects"] = [*original_projects, str(project.path)] + + yield {"project": project, "task": task} + + config._settings["projects"] = original_projects + + +@pytest.fixture +def adapter_sanity_check_math_tools_setup(tmp_path): + """Setup for paid math tools test - task with instructions to use add, multiply, etc.""" + project_path = tmp_path / "adapter_sanity_math_project" / "project.kiln" + project_path.parent.mkdir() + + project = Project(name="Adapter Sanity Math Project", path=str(project_path)) + project.save_to_file() + + task = Task( + name="Math Tools Task", + instruction="You are an assistant that performs math using the provided tools. You MUST use the add, subtract, multiply, and divide tools for any arithmetic. 
For example, for 2+2 you must call the add tool with a=2 and b=2. End your response with the final answer in square brackets, e.g. [4].", + description="Task for testing Kiln built-in math tools", + parent=project, + ) + task.save_to_file() + + config = Config.shared() + original_projects = list(config.projects) if config.projects else [] + config._settings["projects"] = [*original_projects, str(project.path)] + + yield {"project": project, "task": task} + + config._settings["projects"] = original_projects + + +def _assert_math_tools_response(res: dict, expected_in_output: str) -> None: + """Assert response has correct output, trace with tool calls, and output matches latest message.""" + assert res["id"] is not None + + output = res.get("output", {}).get("output", "") + assert output is not None + assert expected_in_output in output + + trace = res.get("trace") or [] + assistant_with_tool_calls = [ + m for m in trace if m.get("role") == "assistant" and m.get("tool_calls") + ] + assert len(assistant_with_tool_calls) >= 1 + + tool_messages = [m for m in trace if m.get("role") == "tool"] + assert len(tool_messages) >= 1 + + last_assistant = next( + (m for m in reversed(trace) if m.get("role") == "assistant"), None + ) + assert last_assistant is not None + last_content = last_assistant.get("content") or "" + assert expected_in_output in last_content + + intermediate_outputs = res.get("intermediate_outputs") or {} + if intermediate_outputs: + for key, value in intermediate_outputs.items(): + assert isinstance(value, str) + + +@pytest.mark.paid +@pytest.mark.asyncio +async def test_run_task_adapter_sanity_math_tools( + client, adapter_sanity_check_math_tools_setup +): + """Single-turn run with built-in Kiln math tools. Test that tools work as expected.""" + if not os.environ.get("OPENROUTER_API_KEY"): + pytest.skip("OPENROUTER_API_KEY required for this test") + + project = adapter_sanity_check_math_tools_setup["project"] + task = adapter_sanity_check_math_tools_setup["task"] + + run_config = { + "model_name": "gpt_5_nano", + "model_provider_name": "openrouter", + "prompt_id": "simple_prompt_builder", + "structured_output_mode": "json_schema", + "tools_config": { + "tools": [ + KilnBuiltInToolId.ADD_NUMBERS.value, + KilnBuiltInToolId.SUBTRACT_NUMBERS.value, + KilnBuiltInToolId.MULTIPLY_NUMBERS.value, + KilnBuiltInToolId.DIVIDE_NUMBERS.value, + ] + }, + } + + response = client.post( + f"/api/projects/{project.id}/tasks/{task.id}/run", + json={ + "run_config_properties": run_config, + "plaintext_input": "What is 2 + 2? Use the tools to calculate.", + }, + ) + assert response.status_code == 200 + res = response.json() + _assert_math_tools_response(res, "4")
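Reviewer note (not part of the patch): taken together, these changes add session continuation at the adapter layer. Below is a minimal usage sketch of the new `existing_run` flow. It assumes a saved kiln `Task` and real provider credentials; the config values are placeholders copied from the test fixtures in this diff, and `default_headers` is omitted on the assumption it is optional.

import asyncio

from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties


async def continue_session(task, first_input: str, follow_up: str):
    # Placeholder config mirroring the litellm adapter tests; swap in real provider details.
    config = LiteLlmConfig(
        base_url="https://api.example.com",
        run_config_properties=KilnAgentRunConfigProperties(
            model_name="test-model",
            model_provider_name="openai_compatible",
            prompt_id="simple_prompt_builder",
            structured_output_mode="json_schema",
        ),
        additional_body_options={"api_key": "..."},
    )
    adapter = LiteLlmAdapter(config=config, kiln_task=task)

    # First turn: a normal run. The full trace (system/user/assistant/tool messages) is stored on the TaskRun.
    first_run = await adapter.invoke(first_input)

    # Continuation: pass the prior run back in. The adapter requires a non-empty trace, seeds the
    # conversation from it via MultiturnFormatter, appends only the new user message, and returns a
    # new TaskRun whose usage is summed with the prior run's and whose trace is the extended one.
    return await adapter.invoke(follow_up, existing_run=first_run)


# Example driver (hypothetical values):
# asyncio.run(continue_session(task, "hi", "tell me more"))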