Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions libs/core/kiln_ai/adapters/chat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from .chat_formatter import (
BasicChatMessage,
ChatCompletionMessageIncludingLiteLLM,
ChatFormatter,
ChatMessage,
ChatStrategy,
MultiturnFormatter,
ToolCallMessage,
ToolResponseMessage,
get_chat_formatter,
Expand All @@ -11,9 +13,11 @@

__all__ = [
"BasicChatMessage",
"ChatCompletionMessageIncludingLiteLLM",
"ChatFormatter",
"ChatMessage",
"ChatStrategy",
"MultiturnFormatter",
"ToolCallMessage",
"ToolResponseMessage",
"build_tool_call_messages",
Expand Down
61 changes: 59 additions & 2 deletions libs/core/kiln_ai/adapters/chat/chat_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,25 @@
import json
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Sequence, Union
from typing import Dict, List, Literal, Optional, Sequence, TypeAlias, Union

from litellm.types.utils import Message as LiteLLMMessage

from kiln_ai.datamodel.datamodel_enums import ChatStrategy, InputType
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
from kiln_ai.utils.open_ai_types import ChatCompletionMessageToolCallParam
from kiln_ai.utils.open_ai_types import (
ChatCompletionMessageParam,
ChatCompletionMessageToolCallParam,
)

COT_FINAL_ANSWER_PROMPT = "Considering the above, return a final result."


# Union of OpenAI-style chat completion message params and LiteLLM's own
# Message type. Used for conversation traces that may mix both
# representations (e.g. seeding a continued run with a prior trace).
ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[
    ChatCompletionMessageParam, LiteLLMMessage
]


@dataclass
class BasicChatMessage:
role: Literal["system", "assistant", "user"]
Expand Down Expand Up @@ -90,6 +100,10 @@ def intermediate_outputs(self) -> Dict[str, str]:
"""Get the intermediate outputs from the chat formatter."""
return self._intermediate_outputs

def initial_messages(self) -> list[ChatCompletionMessageIncludingLiteLLM]:
    """Messages used to seed the conversation.

    Fresh runs have no seed, so the base implementation returns an empty
    list; formatters that continue an existing conversation override this
    to return the prior trace.
    """
    no_seed: list[ChatCompletionMessageIncludingLiteLLM] = []
    return no_seed

@abstractmethod
def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]:
"""Advance the conversation and return the next messages if any."""
Expand Down Expand Up @@ -236,6 +250,49 @@ def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]:
return None


class MultiturnFormatter(ChatFormatter):
    """
    Formatter that continues an existing multi-turn conversation.

    The prior trace (system message plus all earlier turns) seeds the
    conversation via ``initial_messages()``; this formatter contributes only
    the single new user message. Tool calls and multi-turn model responses
    are handled by _run_model_turn's internal loop, not by this class.
    """

    def __init__(
        self,
        prior_trace: list[ChatCompletionMessageParam],
        user_input: InputType,
    ) -> None:
        # The system message already lives inside the prior trace, so the
        # base class gets an empty one.
        super().__init__(
            system_message="",
            user_input=user_input,
            thinking_instructions=None,
        )
        self._prior_trace = prior_trace

    def initial_messages(self) -> list[ChatCompletionMessageIncludingLiteLLM]:
        """Return a copy of the prior trace to seed the conversation."""
        return [*self._prior_trace]

    def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]:
        if self._state == "awaiting_final":
            # Record the model's answer and finish the conversation.
            if previous_output is None:
                raise ValueError("previous_output required for final step")
            self._messages.append(BasicChatMessage("assistant", previous_output))
            self._state = "done"
            return None

        if self._state != "start":
            # Already done (or in an unexpected state): nothing left to emit.
            return None

        # Start state: the prior trace already contains the system message
        # and earlier turns, so the only new message is the latest user turn.
        new_user_message = BasicChatMessage(
            role="user", content=format_user_message(self.user_input)
        )
        self._messages.append(new_user_message)
        self._state = "awaiting_final"
        return ChatTurn(messages=[new_user_message], final_call=True)


def get_chat_formatter(
strategy: ChatStrategy,
system_message: str,
Expand Down
27 changes: 27 additions & 0 deletions libs/core/kiln_ai/adapters/chat/test_chat_formatter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from kiln_ai.adapters.chat import ChatStrategy, get_chat_formatter
from kiln_ai.adapters.chat.chat_formatter import (
COT_FINAL_ANSWER_PROMPT,
MultiturnFormatter,
format_user_message,
)

Expand Down Expand Up @@ -119,6 +120,32 @@ def test_chat_formatter_r1_style():
assert formatter.intermediate_outputs() == {}


def test_multiturn_formatter_initial_messages():
    """The formatter seeds the conversation with exactly the prior trace."""
    trace = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello"},
    ]
    fmt = MultiturnFormatter(prior_trace=trace, user_input="new input")
    assert fmt.initial_messages() == trace


def test_multiturn_formatter_next_turn():
    """First turn emits only the new user message; second turn ends the run."""
    trace = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello"},
    ]
    fmt = MultiturnFormatter(prior_trace=trace, user_input="follow-up")

    turn = fmt.next_turn()
    assert turn is not None
    (only_message,) = turn.messages
    assert only_message.role == "user"
    assert only_message.content == "follow-up"
    assert turn.final_call

    # Feeding back the assistant's answer completes the conversation.
    assert fmt.next_turn("assistant response") is None


def test_format_user_message():
# String
assert format_user_message("test input") == "test input"
Expand Down
124 changes: 89 additions & 35 deletions libs/core/kiln_ai/adapters/model_adapters/base_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
from dataclasses import dataclass
from typing import Dict, Tuple

from kiln_ai.adapters.chat.chat_formatter import ChatFormatter, get_chat_formatter
from kiln_ai.adapters.chat.chat_formatter import (
ChatFormatter,
MultiturnFormatter,
get_chat_formatter,
)
from kiln_ai.adapters.ml_model_list import (
KilnModelProvider,
StructuredOutputMode,
Expand Down Expand Up @@ -123,14 +127,18 @@ async def invoke(
self,
input: InputType,
input_source: DataSource | None = None,
existing_run: TaskRun | None = None,
) -> TaskRun:
run_output, _ = await self.invoke_returning_run_output(input, input_source)
run_output, _ = await self.invoke_returning_run_output(
input, input_source, existing_run
)
return run_output

async def _run_returning_run_output(
self,
input: InputType,
input_source: DataSource | None = None,
existing_run: TaskRun | None = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

naming maybe one of: continue_from? prior_task_run? parent_task_run?

) -> Tuple[TaskRun, RunOutput]:
# validate input, allowing arrays
if self.input_schema is not None:
Expand All @@ -141,6 +149,15 @@ async def _run_returning_run_output(
require_object=False,
)

if existing_run is not None and (
not existing_run.trace or len(existing_run.trace) == 0
):
raise ValueError(
"Run has no trace. Cannot continue session without conversation history."
)

prior_trace = existing_run.trace if existing_run else None

# Format model input for model call (we save the original input in the task without formatting)
formatted_input = input
formatter_id = self.model_provider().formatter
Expand All @@ -149,7 +166,7 @@ async def _run_returning_run_output(
formatted_input = formatter.format_input(input)

# Run
run_output, usage = await self._run(formatted_input)
run_output, usage = await self._run(formatted_input, prior_trace=prior_trace)

# Parse
provider = self.model_provider()
Expand Down Expand Up @@ -198,10 +215,28 @@ async def _run_returning_run_output(
"Reasoning is required for this model, but no reasoning was returned."
)

# Generate the run and output
run = self.generate_run(
input, input_source, parsed_output, usage, run_output.trace
)
# Create the run and output - merge if there is an existing run
if existing_run is not None:
Copy link
Collaborator

@scosman scosman Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see comment here: #1088 (comment)

I'm leaning more towards always saving a new task_run, and just setting a new parent_id field on the child? Keep it immutable.

Will add a bit of work in UI, but much more robust for collisions.

merged_output = RunOutput(
output=parsed_output.output,
intermediate_outputs=parsed_output.intermediate_outputs
or run_output.intermediate_outputs,
output_logprobs=parsed_output.output_logprobs
or run_output.output_logprobs,
trace=run_output.trace,
)
run = self.generate_run(
input,
input_source,
merged_output,
usage,
run_output.trace,
existing_run=existing_run,
)
else:
run = self.generate_run(
input, input_source, parsed_output, usage, run_output.trace
)

# Save the run if configured to do so, and we have a path to save to
if (
Expand All @@ -210,7 +245,7 @@ async def _run_returning_run_output(
and self.task.path is not None
):
run.save_to_file()
else:
elif existing_run is None:
# Clear the ID to indicate it's not persisted
run.id = None

Expand All @@ -220,6 +255,7 @@ async def invoke_returning_run_output(
self,
input: InputType,
input_source: DataSource | None = None,
existing_run: TaskRun | None = None,
) -> Tuple[TaskRun, RunOutput]:
# Determine if this is the root agent (no existing run context)
is_root_agent = get_agent_run_id() is None
Expand All @@ -229,7 +265,9 @@ async def invoke_returning_run_output(
set_agent_run_id(run_id)

try:
return await self._run_returning_run_output(input, input_source)
return await self._run_returning_run_output(
input, input_source, existing_run
)
finally:
if is_root_agent:
try:
Expand All @@ -247,7 +285,11 @@ def adapter_name(self) -> str:
pass

@abstractmethod
async def _run(
    self,
    input: InputType,
    prior_trace: list[ChatCompletionMessageParam] | None = None,
) -> Tuple[RunOutput, Usage | None]:
    """Execute the model call for this adapter.

    Args:
        input: The (already formatted) task input.
        prior_trace: Prior conversation trace to seed the call with when
            continuing an existing run; None for fresh runs.

    Returns:
        A tuple of the raw run output and the provider-reported usage,
        if any.
    """
    # Defect fixed: the pasted diff carried BOTH the pre-change one-line
    # signature and the post-change multi-line signature, which is invalid
    # Python. This reconstructs the intended post-change definition.
    pass

def build_prompt(self) -> str:
Expand All @@ -267,7 +309,14 @@ def build_prompt(self) -> str:
include_json_instructions=add_json_instructions
)

def build_chat_formatter(self, input: InputType) -> ChatFormatter:
def build_chat_formatter(
self,
input: InputType,
prior_trace: list[ChatCompletionMessageParam] | None = None,
) -> ChatFormatter:
if prior_trace is not None:
return MultiturnFormatter(prior_trace, input)

if self.prompt_builder is None:
raise ValueError("Prompt builder is not available for MCP run config")
# Determine the chat strategy to use based on the prompt the user selected, the model's capabilities, and if the model was finetuned with a specific chat strategy.
Expand Down Expand Up @@ -323,24 +372,14 @@ def generate_run(
run_output: RunOutput,
usage: Usage | None = None,
trace: list[ChatCompletionMessageParam] | None = None,
existing_run: TaskRun | None = None,
) -> TaskRun:
# Convert input and output to JSON strings if they aren't strings
input_str = (
input if isinstance(input, str) else json.dumps(input, ensure_ascii=False)
)
output_str = (
json.dumps(run_output.output, ensure_ascii=False)
if isinstance(run_output.output, dict)
else run_output.output
)

# If no input source is provided, use the human data source
if input_source is None:
input_source = DataSource(
type=DataSourceType.human,
properties={"created_by": Config.shared().user_id},
)

# Synthetic since an adapter, not a human, is creating this
# Special case for MCP run configs which calls a mcp tool
output_source_type = (
Expand All @@ -349,26 +388,41 @@ def generate_run(
else DataSourceType.synthetic
)

new_task_run = TaskRun(
new_output = TaskOutput(
output=output_str,
source=DataSource(
type=output_source_type,
properties=self._properties_for_task_output(),
run_config=self.run_config,
),
)

final_usage = usage
final_intermediate = run_output.intermediate_outputs
if existing_run is not None:
final_usage = (existing_run.usage or Usage()) + (usage or Usage())
final_intermediate = run_output.intermediate_outputs

input_str = (
input if isinstance(input, str) else json.dumps(input, ensure_ascii=False)
)
if input_source is None:
input_source = DataSource(
type=DataSourceType.human,
properties={"created_by": Config.shared().user_id},
)

return TaskRun(
parent=self.task,
input=input_str,
input_source=input_source,
output=TaskOutput(
output=output_str,
source=DataSource(
type=output_source_type,
properties=self._properties_for_task_output(),
run_config=self.run_config,
),
),
intermediate_outputs=run_output.intermediate_outputs,
output=new_output,
intermediate_outputs=final_intermediate,
tags=self.base_adapter_config.default_tags or [],
usage=usage,
usage=final_usage,
trace=trace,
)

return new_task_run

def _properties_for_task_output(self) -> Dict[str, str | int | float]:
match self.run_config.type:
case "mcp":
Expand Down
Loading
Loading