Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions app/web_ui/src/lib/api_schema.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6939,6 +6939,11 @@ export interface components {
*
* Contains the input used, its source, the output produced, and optional
* repair information if the output needed correction.
*
* Can be nested under another TaskRun; nested runs are stored as child runs
* in a "runs" subfolder (same relationship name as Task's runs).
*
* Accepts both Task and TaskRun as parents (polymorphic).
*/
"TaskRun-Input": {
/**
Expand Down Expand Up @@ -7001,6 +7006,11 @@ export interface components {
*
* Contains the input used, its source, the output produced, and optional
* repair information if the output needed correction.
*
* Can be nested under another TaskRun; nested runs are stored as child runs
* in a "runs" subfolder (same relationship name as Task's runs).
*
* Accepts both Task and TaskRun as parents (polymorphic).
*/
"TaskRun-Output": {
/**
Expand Down
33 changes: 24 additions & 9 deletions libs/core/kiln_ai/adapters/model_adapters/base_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,10 @@ async def invoke(
input: InputType,
input_source: DataSource | None = None,
prior_trace: list[ChatCompletionMessageParam] | None = None,
parent_task_run: TaskRun | None = None,
) -> TaskRun:
run_output, _ = await self.invoke_returning_run_output(
input, input_source, prior_trace
input, input_source, prior_trace, parent_task_run
)
return run_output

Expand All @@ -154,6 +155,7 @@ async def _run_returning_run_output(
input: InputType,
input_source: DataSource | None = None,
prior_trace: list[ChatCompletionMessageParam] | None = None,
parent_task_run: TaskRun | None = None,
) -> Tuple[TaskRun, RunOutput]:
# validate input, allowing arrays
if self.input_schema is not None:
Expand Down Expand Up @@ -224,7 +226,7 @@ async def _run_returning_run_output(
)

run = self.generate_run(
input, input_source, parsed_output, usage, run_output.trace
input, input_source, parsed_output, usage, run_output.trace, parent_task_run
)

# Save the run if configured to do so, and we have a path to save to
Expand All @@ -245,6 +247,7 @@ async def invoke_returning_run_output(
input: InputType,
input_source: DataSource | None = None,
prior_trace: list[ChatCompletionMessageParam] | None = None,
parent_task_run: TaskRun | None = None,
) -> Tuple[TaskRun, RunOutput]:
# Determine if this is the root agent (no existing run context)
is_root_agent = get_agent_run_id() is None
Expand All @@ -255,7 +258,7 @@ async def invoke_returning_run_output(

try:
return await self._run_returning_run_output(
input, input_source, prior_trace
input, input_source, prior_trace, parent_task_run
)
finally:
if is_root_agent:
Expand All @@ -271,6 +274,7 @@ def invoke_openai_stream(
input: InputType,
input_source: DataSource | None = None,
prior_trace: list[ChatCompletionMessageParam] | None = None,
parent_task_run: TaskRun | None = None,
) -> OpenAIStreamResult:
"""Stream raw OpenAI-protocol chunks for the task execution.

Expand All @@ -282,13 +286,16 @@ def invoke_openai_stream(
Tool-call rounds happen internally and are not surfaced; use
``invoke_ai_sdk_stream`` if you need tool-call events.
"""
return OpenAIStreamResult(self, input, input_source, prior_trace)
return OpenAIStreamResult(
self, input, input_source, prior_trace, parent_task_run
)

def invoke_ai_sdk_stream(
self,
input: InputType,
input_source: DataSource | None = None,
prior_trace: list[ChatCompletionMessageParam] | None = None,
parent_task_run: TaskRun | None = None,
) -> AiSdkStreamResult:
"""Stream AI SDK protocol events for the task execution.

Expand All @@ -297,7 +304,9 @@ def invoke_ai_sdk_stream(
control events. After the iterator is exhausted the resulting
``TaskRun`` is available via the ``.task_run`` property.
"""
return AiSdkStreamResult(self, input, input_source, prior_trace)
return AiSdkStreamResult(
self, input, input_source, prior_trace, parent_task_run
)

def _prepare_stream(
self,
Expand Down Expand Up @@ -327,6 +336,7 @@ def _finalize_stream(
adapter_stream: AdapterStream,
input: InputType,
input_source: DataSource | None,
parent_task_run: TaskRun | None = None,
) -> TaskRun:
"""Streaming invocations are only concerned with passing through events as they come in.
At the end of the stream, we still need to validate the output, create a run and everything
Expand Down Expand Up @@ -379,7 +389,7 @@ def _finalize_stream(
)

run = self.generate_run(
input, input_source, parsed_output, usage, run_output.trace
input, input_source, parsed_output, usage, run_output.trace, parent_task_run
)

if (
Expand Down Expand Up @@ -496,6 +506,7 @@ def generate_run(
run_output: RunOutput,
usage: Usage | None = None,
trace: list[ChatCompletionMessageParam] | None = None,
parent_task_run: TaskRun | None = None,
) -> TaskRun:
output_str = (
json.dumps(run_output.output, ensure_ascii=False)
Expand Down Expand Up @@ -530,7 +541,7 @@ def generate_run(
)

return TaskRun(
parent=self.task,
parent=parent_task_run if parent_task_run is not None else self.task,
input=input_str,
input_source=input_source,
output=new_output,
Expand Down Expand Up @@ -621,11 +632,13 @@ def __init__(
input: InputType,
input_source: DataSource | None,
prior_trace: list[ChatCompletionMessageParam] | None,
parent_task_run: TaskRun | None = None,
) -> None:
self._adapter = adapter
self._input = input
self._input_source = input_source
self._prior_trace = prior_trace
self._parent_task_run = parent_task_run
self._task_run: TaskRun | None = None

@property
Expand Down Expand Up @@ -653,7 +666,7 @@ async def __aiter__(self) -> AsyncIterator[ModelResponseStream]:
yield event

self._task_run = self._adapter._finalize_stream(
adapter_stream, self._input, self._input_source
adapter_stream, self._input, self._input_source, self._parent_task_run
)
finally:
if is_root_agent:
Expand All @@ -678,11 +691,13 @@ def __init__(
input: InputType,
input_source: DataSource | None,
prior_trace: list[ChatCompletionMessageParam] | None,
parent_task_run: TaskRun | None = None,
) -> None:
self._adapter = adapter
self._input = input
self._input_source = input_source
self._prior_trace = prior_trace
self._parent_task_run = parent_task_run
self._task_run: TaskRun | None = None

@property
Expand Down Expand Up @@ -730,7 +745,7 @@ async def __aiter__(self) -> AsyncIterator[AiSdkStreamEvent]:
yield AiSdkStreamEvent(AiSdkEventType.FINISH_STEP)

self._task_run = self._adapter._finalize_stream(
adapter_stream, self._input, self._input_source
adapter_stream, self._input, self._input_source, self._parent_task_run
)

for ai_event in converter.finalize():
Expand Down
22 changes: 13 additions & 9 deletions libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import json
from typing import Tuple

from kiln_ai.adapters.model_adapters.base_adapter import (
AdapterConfig,
BaseAdapter,
)
from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig, BaseAdapter
from kiln_ai.adapters.parsers.json_parser import parse_json_string
from kiln_ai.adapters.run_output import RunOutput
from kiln_ai.datamodel import DataSource, Task, TaskRun, Usage
Expand Down Expand Up @@ -89,15 +86,16 @@ async def invoke(
input: InputType,
input_source: DataSource | None = None,
prior_trace: list[ChatCompletionMessageParam] | None = None,
parent_task_run: TaskRun | None = None,
) -> TaskRun:
if prior_trace:
if prior_trace or parent_task_run is not None:
raise NotImplementedError(
"Session continuation is not supported for MCP adapter. "
"MCP tools are single-turn and do not maintain conversation state."
)

run_output, _ = await self.invoke_returning_run_output(
input, input_source, prior_trace
input, input_source, prior_trace, parent_task_run
)
return run_output

Expand All @@ -106,12 +104,13 @@ async def invoke_returning_run_output(
input: InputType,
input_source: DataSource | None = None,
prior_trace: list[ChatCompletionMessageParam] | None = None,
parent_task_run: TaskRun | None = None,
) -> Tuple[TaskRun, RunOutput]:
"""
Runs the task and returns both the persisted TaskRun and raw RunOutput.
If this call is the root of a run, it creates an agent run context, ensures MCP tool calls have a valid session scope, and cleans up the session/context on completion.
"""
if prior_trace:
if prior_trace or parent_task_run is not None:
raise NotImplementedError(
"Session continuation is not supported for MCP adapter. "
"MCP tools are single-turn and do not maintain conversation state."
Expand All @@ -124,7 +123,9 @@ async def invoke_returning_run_output(
set_agent_run_id(run_id)

try:
return await self._run_and_validate_output(input, input_source)
return await self._run_and_validate_output(
input, input_source, parent_task_run
)
finally:
if is_root_agent:
try:
Expand All @@ -138,6 +139,7 @@ async def _run_and_validate_output(
self,
input: InputType,
input_source: DataSource | None,
parent_task_run: TaskRun | None = None,
) -> Tuple[TaskRun, RunOutput]:
"""
Run the MCP task and validate the output.
Expand Down Expand Up @@ -176,7 +178,9 @@ async def _run_and_validate_output(
# Build single turn trace
trace = self._build_single_turn_trace(input, run_output.output)

run = self.generate_run(input, input_source, run_output, usage, trace)
run = self.generate_run(
input, input_source, run_output, usage, trace, parent_task_run
)

if (
self.base_adapter_config.allow_saving
Expand Down
48 changes: 48 additions & 0 deletions libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,3 +415,51 @@ async def test_mcp_adapter_rejects_prior_trace_in_run(

assert "Session continuation is not supported" in str(exc_info.value)
assert "MCP tools are single-turn" in str(exc_info.value)


@pytest.mark.asyncio
async def test_mcp_adapter_rejects_parent_task_run_in_invoke(
    project_with_local_mcp_server, local_mcp_tool_id
):
    """MCPAdapter.invoke must raise NotImplementedError when parent_task_run is given."""
    proj, _ = project_with_local_mcp_server
    adapter = MCPAdapter(
        task=Task(
            name="Test Task",
            parent=proj,
            instruction="Echo input",
        ),
        run_config=McpRunConfigProperties(
            tool_reference=MCPToolReference(tool_id=local_mcp_tool_id)
        ),
    )

    # MCP tools are single-turn: supplying any parent run must be rejected up front.
    with pytest.raises(NotImplementedError) as exc_info:
        await adapter.invoke("input", parent_task_run=MagicMock())

    assert "Session continuation is not supported" in str(exc_info.value)


@pytest.mark.asyncio
async def test_mcp_adapter_rejects_parent_task_run_in_invoke_returning_run_output(
    project_with_local_mcp_server, local_mcp_tool_id
):
    """MCPAdapter.invoke_returning_run_output must raise NotImplementedError for parent_task_run."""
    proj, _ = project_with_local_mcp_server
    adapter = MCPAdapter(
        task=Task(
            name="Test Task",
            parent=proj,
            instruction="Echo input",
        ),
        run_config=McpRunConfigProperties(
            tool_reference=MCPToolReference(tool_id=local_mcp_tool_id)
        ),
    )

    # Nested runs are unsupported by the single-turn MCP adapter; expect a hard failure.
    with pytest.raises(NotImplementedError) as exc_info:
        await adapter.invoke_returning_run_output("input", parent_task_run=MagicMock())

    assert "Session continuation is not supported" in str(exc_info.value)
Loading
Loading