diff --git a/app/web_ui/src/lib/api_schema.d.ts b/app/web_ui/src/lib/api_schema.d.ts index ee09f4958..e9d47a12b 100644 --- a/app/web_ui/src/lib/api_schema.d.ts +++ b/app/web_ui/src/lib/api_schema.d.ts @@ -6939,6 +6939,11 @@ export interface components { * * Contains the input used, its source, the output produced, and optional * repair information if the output needed correction. + * + * Can be nested under another TaskRun; nested runs are stored as child runs + * in a "runs" subfolder (same relationship name as Task's runs). + * + * Accepts both Task and TaskRun as parents (polymorphic). */ "TaskRun-Input": { /** @@ -7001,6 +7006,11 @@ export interface components { * * Contains the input used, its source, the output produced, and optional * repair information if the output needed correction. + * + * Can be nested under another TaskRun; nested runs are stored as child runs + * in a "runs" subfolder (same relationship name as Task's runs). + * + * Accepts both Task and TaskRun as parents (polymorphic). 
*/ "TaskRun-Output": { /** diff --git a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py index 60cd014b1..b8c6c0bfa 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py @@ -143,9 +143,10 @@ async def invoke( input: InputType, input_source: DataSource | None = None, prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> TaskRun: run_output, _ = await self.invoke_returning_run_output( - input, input_source, prior_trace + input, input_source, prior_trace, parent_task_run ) return run_output @@ -154,6 +155,7 @@ async def _run_returning_run_output( input: InputType, input_source: DataSource | None = None, prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> Tuple[TaskRun, RunOutput]: # validate input, allowing arrays if self.input_schema is not None: @@ -224,7 +226,7 @@ async def _run_returning_run_output( ) run = self.generate_run( - input, input_source, parsed_output, usage, run_output.trace + input, input_source, parsed_output, usage, run_output.trace, parent_task_run ) # Save the run if configured to do so, and we have a path to save to @@ -245,6 +247,7 @@ async def invoke_returning_run_output( input: InputType, input_source: DataSource | None = None, prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> Tuple[TaskRun, RunOutput]: # Determine if this is the root agent (no existing run context) is_root_agent = get_agent_run_id() is None @@ -255,7 +258,7 @@ async def invoke_returning_run_output( try: return await self._run_returning_run_output( - input, input_source, prior_trace + input, input_source, prior_trace, parent_task_run ) finally: if is_root_agent: @@ -271,6 +274,7 @@ def invoke_openai_stream( input: InputType, input_source: DataSource | None = None, 
prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> OpenAIStreamResult: """Stream raw OpenAI-protocol chunks for the task execution. @@ -282,13 +286,16 @@ def invoke_openai_stream( Tool-call rounds happen internally and are not surfaced; use ``invoke_ai_sdk_stream`` if you need tool-call events. """ - return OpenAIStreamResult(self, input, input_source, prior_trace) + return OpenAIStreamResult( + self, input, input_source, prior_trace, parent_task_run + ) def invoke_ai_sdk_stream( self, input: InputType, input_source: DataSource | None = None, prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> AiSdkStreamResult: """Stream AI SDK protocol events for the task execution. @@ -297,7 +304,9 @@ def invoke_ai_sdk_stream( control events. After the iterator is exhausted the resulting ``TaskRun`` is available via the ``.task_run`` property. """ - return AiSdkStreamResult(self, input, input_source, prior_trace) + return AiSdkStreamResult( + self, input, input_source, prior_trace, parent_task_run + ) def _prepare_stream( self, @@ -327,6 +336,7 @@ def _finalize_stream( adapter_stream: AdapterStream, input: InputType, input_source: DataSource | None, + parent_task_run: TaskRun | None = None, ) -> TaskRun: """Streaming invocations are only concerned with passing through events as they come in. 
At the end of the stream, we still need to validate the output, create a run and everything @@ -379,7 +389,7 @@ def _finalize_stream( ) run = self.generate_run( - input, input_source, parsed_output, usage, run_output.trace + input, input_source, parsed_output, usage, run_output.trace, parent_task_run ) if ( @@ -496,6 +506,7 @@ def generate_run( run_output: RunOutput, usage: Usage | None = None, trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> TaskRun: output_str = ( json.dumps(run_output.output, ensure_ascii=False) @@ -530,7 +541,7 @@ def generate_run( ) return TaskRun( - parent=self.task, + parent=parent_task_run if parent_task_run is not None else self.task, input=input_str, input_source=input_source, output=new_output, @@ -621,11 +632,13 @@ def __init__( input: InputType, input_source: DataSource | None, prior_trace: list[ChatCompletionMessageParam] | None, + parent_task_run: TaskRun | None = None, ) -> None: self._adapter = adapter self._input = input self._input_source = input_source self._prior_trace = prior_trace + self._parent_task_run = parent_task_run self._task_run: TaskRun | None = None @property @@ -653,7 +666,7 @@ async def __aiter__(self) -> AsyncIterator[ModelResponseStream]: yield event self._task_run = self._adapter._finalize_stream( - adapter_stream, self._input, self._input_source + adapter_stream, self._input, self._input_source, self._parent_task_run ) finally: if is_root_agent: @@ -678,11 +691,13 @@ def __init__( input: InputType, input_source: DataSource | None, prior_trace: list[ChatCompletionMessageParam] | None, + parent_task_run: TaskRun | None = None, ) -> None: self._adapter = adapter self._input = input self._input_source = input_source self._prior_trace = prior_trace + self._parent_task_run = parent_task_run self._task_run: TaskRun | None = None @property @@ -730,7 +745,7 @@ async def __aiter__(self) -> AsyncIterator[AiSdkStreamEvent]: yield 
AiSdkStreamEvent(AiSdkEventType.FINISH_STEP) self._task_run = self._adapter._finalize_stream( - adapter_stream, self._input, self._input_source + adapter_stream, self._input, self._input_source, self._parent_task_run ) for ai_event in converter.finalize(): diff --git a/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py index c488e7fc0..9b38ff9bb 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/mcp_adapter.py @@ -1,10 +1,7 @@ import json from typing import Tuple -from kiln_ai.adapters.model_adapters.base_adapter import ( - AdapterConfig, - BaseAdapter, -) +from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig, BaseAdapter from kiln_ai.adapters.parsers.json_parser import parse_json_string from kiln_ai.adapters.run_output import RunOutput from kiln_ai.datamodel import DataSource, Task, TaskRun, Usage @@ -89,15 +86,16 @@ async def invoke( input: InputType, input_source: DataSource | None = None, prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> TaskRun: - if prior_trace: + if prior_trace or parent_task_run is not None: raise NotImplementedError( "Session continuation is not supported for MCP adapter. " "MCP tools are single-turn and do not maintain conversation state." ) run_output, _ = await self.invoke_returning_run_output( - input, input_source, prior_trace + input, input_source, prior_trace, parent_task_run ) return run_output @@ -106,12 +104,13 @@ async def invoke_returning_run_output( input: InputType, input_source: DataSource | None = None, prior_trace: list[ChatCompletionMessageParam] | None = None, + parent_task_run: TaskRun | None = None, ) -> Tuple[TaskRun, RunOutput]: """ Runs the task and returns both the persisted TaskRun and raw RunOutput. 
If this call is the root of a run, it creates an agent run context, ensures MCP tool calls have a valid session scope, and cleans up the session/context on completion. """ - if prior_trace: + if prior_trace or parent_task_run is not None: raise NotImplementedError( "Session continuation is not supported for MCP adapter. " "MCP tools are single-turn and do not maintain conversation state." @@ -124,7 +123,9 @@ async def invoke_returning_run_output( set_agent_run_id(run_id) try: - return await self._run_and_validate_output(input, input_source) + return await self._run_and_validate_output( + input, input_source, parent_task_run + ) finally: if is_root_agent: try: @@ -138,6 +139,7 @@ async def _run_and_validate_output( self, input: InputType, input_source: DataSource | None, + parent_task_run: TaskRun | None = None, ) -> Tuple[TaskRun, RunOutput]: """ Run the MCP task and validate the output. @@ -176,7 +178,9 @@ async def _run_and_validate_output( # Build single turn trace trace = self._build_single_turn_trace(input, run_output.output) - run = self.generate_run(input, input_source, run_output, usage, trace) + run = self.generate_run( + input, input_source, run_output, usage, trace, parent_task_run + ) if ( self.base_adapter_config.allow_saving diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py index cb0a3e94b..689801f86 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_mcp_adapter.py @@ -415,3 +415,51 @@ async def test_mcp_adapter_rejects_prior_trace_in_run( assert "Session continuation is not supported" in str(exc_info.value) assert "MCP tools are single-turn" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_mcp_adapter_rejects_parent_task_run_in_invoke( + project_with_local_mcp_server, local_mcp_tool_id +): + """invoke with parent_task_run raises NotImplementedError for MCP adapter.""" + 
project, _ = project_with_local_mcp_server + task = Task( + name="Test Task", + parent=project, + instruction="Echo input", + ) + + run_config = McpRunConfigProperties( + tool_reference=MCPToolReference(tool_id=local_mcp_tool_id) + ) + + adapter = MCPAdapter(task=task, run_config=run_config) + + with pytest.raises(NotImplementedError) as exc_info: + await adapter.invoke("input", parent_task_run=MagicMock()) + + assert "Session continuation is not supported" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_mcp_adapter_rejects_parent_task_run_in_invoke_returning_run_output( + project_with_local_mcp_server, local_mcp_tool_id +): + """invoke_returning_run_output with parent_task_run raises NotImplementedError for MCP adapter.""" + project, _ = project_with_local_mcp_server + task = Task( + name="Test Task", + parent=project, + instruction="Echo input", + ) + + run_config = McpRunConfigProperties( + tool_reference=MCPToolReference(tool_id=local_mcp_tool_id) + ) + + adapter = MCPAdapter(task=task, run_config=run_config) + + with pytest.raises(NotImplementedError) as exc_info: + await adapter.invoke_returning_run_output("input", parent_task_run=MagicMock()) + + assert "Session continuation is not supported" in str(exc_info.value) diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py index db20fe5ea..c2ebf17a5 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py @@ -3,7 +3,7 @@ import pytest from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput -from kiln_ai.datamodel import DataSource, DataSourceType, Project, Task, Usage +from kiln_ai.datamodel import DataSource, DataSourceType, Project, Task, TaskRun, Usage from kiln_ai.datamodel.datamodel_enums import InputType from kiln_ai.datamodel.run_config import 
KilnAgentRunConfigProperties from kiln_ai.utils.config import Config @@ -425,3 +425,122 @@ def test_properties_for_task_output_custom_values(test_task): assert output.source.properties["structured_output_mode"] == "json_schema" assert output.source.properties["temperature"] == 0.7 assert output.source.properties["top_p"] == 0.9 + + +def test_generate_run_with_parent_task_run_sets_parent(test_task, adapter): + """Test that generate_run with parent_task_run uses it as parent instead of the task.""" + prior_run = adapter.generate_run( + input="prior input", + input_source=None, + run_output=RunOutput(output="prior output", intermediate_outputs=None), + ) + prior_run.save_to_file() + assert prior_run.id is not None + + new_run = adapter.generate_run( + input="new input", + input_source=None, + run_output=RunOutput(output="new output", intermediate_outputs=None), + parent_task_run=prior_run, + ) + + assert new_run.parent == prior_run + + new_run.save_to_file() + + reloaded_prior_run = TaskRun.load_from_file(prior_run.path) + child_runs = reloaded_prior_run.runs() + assert len(child_runs) == 1 + assert child_runs[0].output.output == "new output" + + # The task should only have the prior run as a direct child + reloaded_task = Task.load_from_file(test_task.path) + task_runs = reloaded_task.runs() + assert len(task_runs) == 1 + assert task_runs[0].id == prior_run.id + + +def test_generate_run_without_parent_task_run_defaults_to_task(test_task, adapter): + """Test that generate_run without parent_task_run defaults to using the task as parent.""" + run = adapter.generate_run( + input="input", + input_source=None, + run_output=RunOutput(output="output", intermediate_outputs=None), + ) + assert run.parent == test_task + + +@pytest.mark.asyncio +async def test_invoke_with_parent_task_run_saves_as_child(test_task, adapter): + """Test that invoke with parent_task_run saves the new run as a child of that run.""" + trace = [ + {"role": "user", "content": "Hello"}, + {"role": 
"assistant", "content": "Hi there!"}, + ] + + # Create and save a prior run to act as parent + prior_run = adapter.generate_run( + input="Hello", + input_source=None, + run_output=RunOutput( + output="Hi there!", intermediate_outputs=None, trace=trace + ), + trace=trace, + ) + prior_run.save_to_file() + assert prior_run.id is not None + + continuation_trace = [ + *trace, + {"role": "user", "content": "Tell me more"}, + {"role": "assistant", "content": "More details!"}, + ] + continuation_output = RunOutput( + output="More details!", + intermediate_outputs=None, + trace=continuation_trace, + ) + + adapter._run = AsyncMock(return_value=(continuation_output, None)) + + with ( + patch("kiln_ai.utils.config.Config.shared") as mock_shared, + patch.object( + adapter, + "model_provider", + return_value=MagicMock( + parser="default", + formatter=None, + reasoning_capable=False, + ), + ), + patch( + "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id", + return_value=MagicMock( + parse_output=MagicMock(return_value=continuation_output) + ), + ), + ): + mock_shared.return_value.autosave_runs = True + mock_shared.return_value.user_id = "test_user" + + new_run = await adapter.invoke( + "Tell me more", + prior_trace=trace, + parent_task_run=prior_run, + ) + + assert new_run.id is not None + assert new_run.parent == prior_run + + # The prior run should have the new run as a child + reloaded_prior_run = TaskRun.load_from_file(prior_run.path) + child_runs = reloaded_prior_run.runs() + assert len(child_runs) == 1 + assert child_runs[0].output.output == "More details!" 
+ + # The task should only have the prior run as a direct child + reloaded_task = Task.load_from_file(test_task.path) + task_runs = reloaded_task.runs() + assert len(task_runs) == 1 + assert task_runs[0].id == prior_run.id diff --git a/libs/core/kiln_ai/datamodel/basemodel.py b/libs/core/kiln_ai/datamodel/basemodel.py index c29f11f7e..90e981b78 100644 --- a/libs/core/kiln_ai/datamodel/basemodel.py +++ b/libs/core/kiln_ai/datamodel/basemodel.py @@ -545,17 +545,49 @@ def relationship_name(cls) -> str: def parent_type(cls) -> Type[KilnBaseModel]: raise NotImplementedError("Parent type must be implemented") - @model_validator(mode="after") - def check_parent_type(self) -> Self: + @classmethod + def _parent_types(cls) -> List[Type["KilnBaseModel"]] | None: + """Return accepted parent types. This must be implemented by the subclass if + the model can have multiple parent types. + + Return None (default) to use the single parent_type() check. + Override and return a list of parent types for models that can be nested + under more than one parent type (e.g. TaskRun can be nested under Task or TaskRun). 
+ """ + return None + + def _check_parent_type( + self, + expected_parent_types: List[Type[KilnBaseModel]] | None = None, + ) -> Self: cached_parent = self.cached_parent() - if cached_parent is not None: + if cached_parent is None: + return self + + # some models support having multiple parent types, so we allow overriding the expected parent + if expected_parent_types is not None: + if not any( + isinstance(cached_parent, expected_parent_type) + for expected_parent_type in expected_parent_types + ): + raise ValueError( + f"Parent must be one of {expected_parent_types}, but was {type(cached_parent)}" + ) + else: + # default case where we expect a single parent type to be valid expected_parent_type = self.__class__.parent_type() if not isinstance(cached_parent, expected_parent_type): raise ValueError( f"Parent must be of type {expected_parent_type}, but was {type(cached_parent)}" ) + return self + @model_validator(mode="after") + def check_parent_type(self) -> Self: + """Default validation for parent type. Can be overridden by subclasses - for example if the parent is polymorphic.""" + return self._check_parent_type() + def build_child_dirname(self) -> Path: # Default implementation for readable folder names. # {id} - {name}/{type}.kiln @@ -602,10 +634,39 @@ def iterate_children_paths_of_parent_path(cls: Type[PT], parent_path: Path | Non else: parent_folder = parent_path - parent = cls.parent_type().load_from_file(parent_path) - if parent is None: + if not parent_path.exists(): raise ValueError("Parent must be set to load children") + # Validate the parent file's declared type so we fail fast when the caller + # passes a wrong path. For polymorphic children (e.g. 
TaskRun) the + # subclass overrides _parent_types() to broaden the check to all + # accepted parent types. + parent_types_override = cls._parent_types() + if parent_types_override is None: + # Default: single expected parent type — original behaviour + parent = cls.parent_type().load_from_file(parent_path) + if parent is None: + raise ValueError("Parent must be set to load children") + else: + # Polymorphic parent: read only the model_type field to avoid a full load. + with open(parent_path, "r", encoding="utf-8") as fh: + actual_parent_type_name = json.loads(fh.read()).get("model_type", "") + parent_type_names = {t.type_name() for t in parent_types_override} + if actual_parent_type_name not in parent_type_names: + raise ValueError( + f"Parent model_type '{actual_parent_type_name}' is not one of " + f"{parent_type_names}" + ) + + parent_type = next( + t + for t in parent_types_override + if t.type_name() == actual_parent_type_name + ) + parent = parent_type.load_from_file(parent_path) + if parent is None: + raise ValueError("Parent must be set to load children") + # Ignore type error: this is abstract base class, but children must implement relationship_name relationship_folder = parent_folder / Path(cls.relationship_name()) # type: ignore diff --git a/libs/core/kiln_ai/datamodel/task.py b/libs/core/kiln_ai/datamodel/task.py index c223164cb..d5ace6449 100644 --- a/libs/core/kiln_ai/datamodel/task.py +++ b/libs/core/kiln_ai/datamodel/task.py @@ -203,3 +203,21 @@ def parent_project(self) -> Union["Project", None]: if self.parent is None or self.parent.__class__.__name__ != "Project": return None return self.parent # type: ignore + + def find_task_run_by_id_dfs( + self, task_run_id: str, readonly: bool = False + ) -> TaskRun | None: + """ + Find a task run by id in the entire task run tree. This is an expensive DFS + traversal of the file system so do not use too willy nilly. 
+ + If you already know the root task run, you can use the same method on + the root TaskRun instead - that will save a bunch of subtree traversals. + """ + stack: List[TaskRun] = list(self.runs(readonly=readonly)) + while stack: + run = stack.pop() + if run.id == task_run_id: + return run + stack.extend(run.runs(readonly=readonly)) + return None diff --git a/libs/core/kiln_ai/datamodel/task_run.py b/libs/core/kiln_ai/datamodel/task_run.py index bef7ae700..40dac39e3 100644 --- a/libs/core/kiln_ai/datamodel/task_run.py +++ b/libs/core/kiln_ai/datamodel/task_run.py @@ -1,10 +1,14 @@ import json -from typing import TYPE_CHECKING, Dict, List, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union from pydantic import BaseModel, Field, ValidationInfo, model_validator from typing_extensions import Self -from kiln_ai.datamodel.basemodel import KilnParentedModel +from kiln_ai.datamodel.basemodel import ( + KilnBaseModel, + KilnParentedModel, + KilnParentModel, +) from kiln_ai.datamodel.json_schema import validate_schema_with_value_error from kiln_ai.datamodel.strict_mode import strict_mode from kiln_ai.datamodel.task_output import DataSource, TaskOutput @@ -73,12 +77,17 @@ def _add_optional_float(a: float | None, b: float | None) -> float | None: ) -class TaskRun(KilnParentedModel): +class TaskRun(KilnParentedModel, KilnParentModel, parent_of={}): """ Represents a single execution of a Task. Contains the input used, its source, the output produced, and optional repair information if the output needed correction. + + Can be nested under another TaskRun; nested runs are stored as child runs + in a "runs" subfolder (same relationship name as Task's runs). + + Accepts both Task and TaskRun as parents (polymorphic). 
""" input: str = Field( @@ -132,9 +141,108 @@ def has_thinking_training_data(self) -> bool: # Workaround to return typed parent without importing Task def parent_task(self) -> Union["Task", None]: - if self.parent is None or self.parent.__class__.__name__ != "Task": + """The Task that this Run is in. Note the TaskRun may be nested in which case we walk back up the tree all the way to the root.""" + # lazy import to avoid circular dependency + from kiln_ai.datamodel.task import Task + + current: TaskRun = self + while True: + # should never really happen, except maybe in tests + parent = current.parent + if parent is None: + return None + + # this task run is the root task run + # so we just return its parent (a Task) + if isinstance(parent, Task): + return parent + + if not isinstance(parent, TaskRun): + # the parent is not a TaskRun, but also not a Task, so it is not + # a real parent + return None + + # the parent is a TaskRun, so we just walk up the tree until we find a Task + current = parent + + def parent_run(self) -> "TaskRun | None": + """The TaskRun that contains this run, if this run is nested; otherwise None.""" + parent = self.parent + if parent is None or not isinstance(parent, TaskRun): return None - return self.parent # type: ignore + return parent + + @classmethod + def _parent_types(cls) -> List[Type["KilnBaseModel"]]: + # lazy import to avoid circular dependency + from kiln_ai.datamodel.task import Task + + return [Task, TaskRun] + + def runs(self, readonly: bool = False) -> list["TaskRun"]: + """The list of child task runs.""" + return super().runs(readonly=readonly) # type: ignore + + def is_root_task_run(self) -> bool: + """Is this the root task run? 
(not nested under another task run)""" + # lazy import to avoid circular dependency + from kiln_ai.datamodel.task import Task + + return self.parent is None or isinstance(self.parent, Task) + + def find_task_run_by_id_dfs( + self, task_run_id: str, readonly: bool = False + ) -> "TaskRun | None": + """ + Find a task run by id in the entire task run tree. This is an expensive DFS + traversal of the file system so do not use too willy nilly. + """ + stack: List[TaskRun] = list(self.runs(readonly=readonly)) + while stack: + run = stack.pop() + if run.id == task_run_id: + return run + stack.extend(run.runs(readonly=readonly)) + return None + + def load_parent(self) -> Optional[KilnBaseModel]: + """Load the parent of this task run - this is an override of the default parent loading logic to support nested task runs.""" + cached = self.cached_parent() + if cached is not None: + return cached + if self.path is None: + return None + parent_dir = self.path.parent.parent.parent + task_run_path = parent_dir / TaskRun.base_filename() + if task_run_path.exists() and task_run_path != self.path: + try: + loaded_parent_run = TaskRun.load_from_file(task_run_path) + super().__setattr__("parent", loaded_parent_run) + return loaded_parent_run + except ValueError as e: + raise ValueError( + f"Failed to load parent TaskRun from {task_run_path}. " + f"This indicates a malformed nested task run. Error: {e}" + ) from e + + from kiln_ai.datamodel.task import Task + + task_path = parent_dir / Task.base_filename() + if task_path.exists(): + loaded_parent_task = Task.load_from_file(task_path) + super().__setattr__("parent", loaded_parent_task) + return loaded_parent_task + + return None + + @model_validator(mode="after") + def check_parent_type(self) -> Self: + """Check that the parent is a Task or TaskRun. 
This overrides the default parent type check + that only supports a single parent type.""" + # need to import here to avoid circular imports + from kiln_ai.datamodel.task import Task + + return self._check_parent_type([Task, TaskRun]) @model_validator(mode="after") def validate_input_format(self, info: ValidationInfo) -> Self: @@ -258,3 +366,10 @@ def validate_tags(self) -> Self: raise ValueError("Tags cannot contain spaces. Try underscores.") return self + + +# cannot do this in the class definition due to circular reference between TaskRun and itself: +# wire up TaskRun as its own child type so .runs() returns TaskRun instances +# this makes TaskRun polymorphic - can be parented under Task or another TaskRun +TaskRun._parent_of["runs"] = TaskRun +TaskRun._create_child_method("runs", TaskRun) diff --git a/libs/core/kiln_ai/datamodel/test_basemodel.py b/libs/core/kiln_ai/datamodel/test_basemodel.py index 897a3749f..be4bc8c45 100644 --- a/libs/core/kiln_ai/datamodel/test_basemodel.py +++ b/libs/core/kiln_ai/datamodel/test_basemodel.py @@ -11,7 +11,7 @@ from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter from kiln_ai.adapters.run_output import RunOutput -from kiln_ai.datamodel import Task, TaskRun +from kiln_ai.datamodel import Project, Task, TaskOutput, TaskRun from kiln_ai.datamodel.basemodel import ( MAX_FILENAME_LENGTH, KilnBaseModel, @@ -1166,3 +1166,200 @@ def test_readonly_cache_integration(tmp_model_cache, tmp_path): cached_readonly = ReadonlyTestModel.load_from_file(test_file, readonly=True) assert cached_readonly._readonly is True assert cached_readonly.name == "cached_model" + + +# ============================================================================ +# Tests for _parent_types() and parent type validation in _load_parent_and_validate_children +# ============================================================================ + + +def test_default_parent_types_returns_none(): + """Default _parent_types() returns None for models with 
single parent type.""" + assert DefaultParentedModel._parent_types() is None + assert NamedParentedModel._parent_types() is None + + +def test_taskrun_parent_types_returns_task_and_taskrun(): + """TaskRun._parent_types() returns [Task, TaskRun] for polymorphic parent support.""" + + parent_types = TaskRun._parent_types() + assert parent_types is not None + assert len(parent_types) == 2 + + parent_type_names = {t.type_name() for t in parent_types} + assert "task" in parent_type_names + assert "task_run" in parent_type_names + + +def test_invalid_parent_with_single_parent_type(tmp_path): + """Loading children fails when parent path points to wrong model type (single parent case).""" + # Create a project (wrong parent type) + project_path = tmp_path / "project.kiln" + project = Project(name="Test Project", path=project_path) + project.save_to_file() + + # Try to load DefaultParentedModel children from a Project path + # DefaultParentedModel expects BaseParentExample as parent, not Project + # The error occurs when trying to load the parent as the wrong type + with pytest.raises( + ValueError, match="Cannot load from file because the model type is incorrect" + ): + list(DefaultParentedModel.iterate_children_paths_of_parent_path(project_path)) + + +def test_invalid_parent_with_multiple_parent_types(tmp_path): + """Loading children fails when parent's model_type is not in accepted polymorphic types.""" + # Create a project (not an accepted parent type for TaskRun) + project_path = tmp_path / "project.kiln" + project = Project(name="Test Project", path=project_path) + project.save_to_file() + + # Try to load TaskRun children from a Project path + # TaskRun accepts Task and TaskRun, not Project + with pytest.raises(ValueError, match="Parent model_type 'project' is not one of"): + list(TaskRun.iterate_children_paths_of_parent_path(project_path)) + + +def test_valid_parent_type_single_parent(tmp_path): + """Successfully loads children when parent type matches single expected 
type.""" + parent = BaseParentExample(path=tmp_path / BaseParentExample.base_filename()) + parent.save_to_file() + + child = DefaultParentedModel(parent=parent, name="Test Child") + child.save_to_file() + + # Load children - should succeed since parent is correct type + children = list( + DefaultParentedModel.iterate_children_paths_of_parent_path(parent.path) + ) + assert len(children) == 1 + assert children[0] == child.path + + +def test_valid_parent_type_polymorphic_taskrun_as_parent(tmp_path): + """TaskRun can be loaded with TaskRun as parent (polymorphic case).""" + output = TaskOutput(output="test output") + + # Create a task as the ultimate parent + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + # Create a parent TaskRun + parent_run = TaskRun(input="parent input", output=output, parent=task) + parent_run.save_to_file() + + # Create a nested TaskRun + nested_run = TaskRun(input="nested input", output=output, parent=parent_run) + nested_run.save_to_file() + + # Load children of parent_run - should succeed + children = list(TaskRun.iterate_children_paths_of_parent_path(parent_run.path)) + assert len(children) == 1 + assert children[0] == nested_run.path + + +def test_valid_parent_type_polymorphic_task_as_parent(tmp_path): + """TaskRun can be loaded with Task as parent (polymorphic case).""" + output = TaskOutput(output="test output") + + # Create a task + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + # Create a TaskRun under the task + task_run = TaskRun(input="test input", output=output, parent=task) + task_run.save_to_file() + + # Load children of task - should succeed + children = list(TaskRun.iterate_children_paths_of_parent_path(task.path)) + assert len(children) == 1 + assert children[0] == task_run.path + + +def test_invalid_parent_type_name_mismatch_polymorphic(tmp_path): + """Polymorphic 
validation fails when parent file has wrong model_type.""" + # Create a file with wrong model_type + wrong_parent_path = tmp_path / "wrong_parent.kiln" + wrong_data = { + "v": 1, + "name": "Wrong Parent", + "model_type": "project", # Wrong type - not in accepted types + } + with open(wrong_parent_path, "w") as f: + json.dump(wrong_data, f) + + # Try to load TaskRun children from wrong parent + # TaskRun accepts Task and TaskRun, not Project + with pytest.raises(ValueError, match="Parent model_type 'project' is not one of"): + list(TaskRun.iterate_children_paths_of_parent_path(wrong_parent_path)) + + +def test_parent_loading_single_parent_type_fails_on_corrupt_file(tmp_path): + """Single parent type loading fails when parent file is corrupt/invalid.""" + # Create a corrupt parent file + corrupt_parent_path = tmp_path / "corrupt.kiln" + with open(corrupt_parent_path, "w") as f: + f.write("not valid json {{{") + + # The load_from_file call within iterate_children_paths_of_parent_path + # will fail when trying to parse the corrupt JSON + with pytest.raises(ValueError, match="Expecting value"): + list( + DefaultParentedModel.iterate_children_paths_of_parent_path( + corrupt_parent_path + ) + ) + + +def test_parent_loading_polymorphic_fails_on_corrupt_file(tmp_path): + """Polymorphic parent loading fails when parent file is corrupt/invalid.""" + # Create a corrupt parent file + corrupt_parent_path = tmp_path / "corrupt.kiln" + with open(corrupt_parent_path, "w") as f: + f.write("not valid json {{{") + + # The polymorphic path first reads model_type, which will fail on corrupt JSON + with pytest.raises(json.JSONDecodeError): + list(TaskRun.iterate_children_paths_of_parent_path(corrupt_parent_path)) + + +def test_parent_loading_single_parent_nonexistent_file(tmp_path): + """Single parent type loading fails when parent file doesn't exist.""" + nonexistent_path = tmp_path / "nonexistent.kiln" + + with pytest.raises(ValueError, match="Parent must be set to load children"): + 
list( + DefaultParentedModel.iterate_children_paths_of_parent_path(nonexistent_path) + ) + + +def test_parent_loading_polymorphic_nonexistent_file(tmp_path): + """Polymorphic parent loading fails when parent file doesn't exist.""" + nonexistent_path = tmp_path / "nonexistent.kiln" + + with pytest.raises(ValueError, match="Parent must be set to load children"): + list(TaskRun.iterate_children_paths_of_parent_path(nonexistent_path)) + + +def test_all_children_of_parent_path_single_parent_type(tmp_path): + """all_children_of_parent_path works correctly for single parent type models.""" + parent = BaseParentExample(path=tmp_path / "parent.kiln") + parent.save_to_file() + + child1 = DefaultParentedModel(parent=parent, name="Child1") + child2 = DefaultParentedModel(parent=parent, name="Child2") + child1.save_to_file() + child2.save_to_file() + + children = DefaultParentedModel.all_children_of_parent_path(parent.path) + assert len(children) == 2 + names = {child.name for child in children} + assert names == {"Child1", "Child2"} diff --git a/libs/core/kiln_ai/datamodel/test_models.py b/libs/core/kiln_ai/datamodel/test_models.py index 68d0f611e..cc11ba9a6 100644 --- a/libs/core/kiln_ai/datamodel/test_models.py +++ b/libs/core/kiln_ai/datamodel/test_models.py @@ -743,3 +743,464 @@ def test_generate_model_id(): # check it is a valid name - as we typically use model ids in filenames on FS validator = name_validator(min_length=1, max_length=12) validator(model_id) + + +# project and task fixture +@pytest.fixture +def task(tmp_path): + project_path = tmp_path / "project.kiln" + project = Project(name="P", path=project_path) + project.save_to_file() + task = Task(name="T", instruction="Do it", parent=project) + task.save_to_file() + return task + + +def test_nested_task_run_folder_structure(task: Task): + output = TaskOutput(output="out") + + parent_run = TaskRun(input="in", output=output, parent=task) + parent_run.save_to_file() + + nested_run = TaskRun(input="nested in", 
output=output, parent=parent_run) + nested_run.save_to_file() + + assert task.path is not None + assert parent_run.path is not None + assert nested_run.path is not None + task_dir = task.path.parent + runs_dir = task_dir / "runs" + parent_run_dir = runs_dir / parent_run.build_child_dirname() + nested_runs_dir = parent_run_dir / "runs" + + assert runs_dir.is_dir() + assert (parent_run_dir / "task_run.kiln").is_file() + assert nested_runs_dir.is_dir() + assert nested_run.path.is_file() + assert nested_run.path.parent.parent.parent == parent_run_dir + assert nested_run.path.name == TaskRun.base_filename() + + assert parent_run.parent_task() == task + assert parent_run.parent_run() is None + assert nested_run.parent_task() == task + assert nested_run.parent_run() == parent_run + assert len(parent_run.runs()) == 1 + assert parent_run.runs()[0].id == nested_run.id + + +def test_nested_task_runs_multiple_levels(task: Task): + output = TaskOutput(output="out") + + run1 = TaskRun(input="in1", output=output, parent=task) + run1.save_to_file() + run2 = TaskRun(input="in2", output=output, parent=run1) + run2.save_to_file() + run3 = TaskRun(input="in3", output=output, parent=run2) + run3.save_to_file() + + assert run1.parent_run() is None + assert run1.parent_task() == task + assert len(run1.runs()) == 1 + assert run1.runs()[0].id == run2.id + + assert run2.parent_run() == run1 + assert run2.parent_task() == task + assert len(run2.runs()) == 1 + assert run2.runs()[0].id == run3.id + + assert run3.parent_run() == run2 + assert run3.parent_task() == task + assert len(run3.runs()) == 0 + + +def test_parent_task_deeply_nested_task_run(task: Task): + output = TaskOutput(output="out") + run1 = TaskRun(input="in1", output=output, parent=task) + run1.save_to_file() + run2 = TaskRun(input="in2", output=output, parent=run1) + run2.save_to_file() + run3 = TaskRun(input="in3", output=output, parent=run2) + run3.save_to_file() + + assert run3.parent_task() == task + + +def 
test_find_nested_task_run_by_id_given_parent_run(task: Task): + assert task.path is not None + + output = TaskOutput(output="out") + parent_run = TaskRun(input="in", output=output, parent=task) + parent_run.save_to_file() + nested_run = TaskRun(input="nested in", output=output, parent=parent_run) + nested_run.save_to_file() + target_id = nested_run.id + + loaded_task = Task.load_from_file(task.path) + loaded_parent = next(r for r in loaded_task.runs() if r.id == parent_run.id) + found = next(r for r in loaded_parent.runs() if r.id == target_id) + assert found is not None + assert found.id == target_id + assert found.input == "nested in" + + +def test_find_root_task_run_by_id_given_task(task: Task): + output = TaskOutput(output="out") + root_run = TaskRun(input="in", output=output, parent=task) + root_run.save_to_file() + target_id = root_run.id + + assert task.path is not None + loaded_task = Task.load_from_file(task.path) + found = next(r for r in loaded_task.runs() if r.id == target_id) + assert found is not None + assert found.id == target_id + assert found.input == "in" + + +def test_find_task_run_by_id_dfs_finds_deeply_nested_run(task: Task): + """Task.find_task_run_by_id_dfs finds a run nested several levels (iterative stack-based DFS).""" + assert task.path is not None + output = TaskOutput(output="out") + + run1 = TaskRun(input="in1", output=output, parent=task) + run1.save_to_file() + + run2 = TaskRun(input="in2", output=output, parent=run1) + run2.save_to_file() + + run3 = TaskRun(input="in3", output=output, parent=run2) + run3.save_to_file() + + assert run3.id is not None + + loaded_task = Task.load_from_file(task.path) + found = loaded_task.find_task_run_by_id_dfs(run3.id) + assert found is not None + assert found.id == run3.id + assert found.input == "in3" + + +def test_find_task_run_by_id_dfs_finds_root_run(task: Task): + """Task.find_task_run_by_id_dfs finds the root run when it is the only run.""" + assert task.path is not None + output = 
TaskOutput(output="out") + root_run = TaskRun(input="root in", output=output, parent=task) + root_run.save_to_file() + assert root_run.id is not None + + loaded_task = Task.load_from_file(task.path) + found = loaded_task.find_task_run_by_id_dfs(root_run.id) + assert found is not None + assert found.id == root_run.id + assert found.input == "root in" + + +def test_find_task_run_by_id_dfs_returns_none_when_not_found(task: Task): + """Task.find_task_run_by_id_dfs returns None when no run has the given id.""" + assert task.path is not None + output = TaskOutput(output="out") + TaskRun(input="in", output=output, parent=task).save_to_file() + + loaded_task = Task.load_from_file(task.path) + found = loaded_task.find_task_run_by_id_dfs("nonexistent-id") + assert found is None + + +def test_task_run_find_task_run_by_id_dfs_finds_deeply_nested_run(task: Task): + """TaskRun.find_task_run_by_id_dfs finds a run nested several levels in its subtree.""" + assert task.path is not None + output = TaskOutput(output="out") + + run1 = TaskRun(input="in1", output=output, parent=task) + run1.save_to_file() + + run2 = TaskRun(input="in2", output=output, parent=run1) + run2.save_to_file() + + run3 = TaskRun(input="in3", output=output, parent=run2) + run3.save_to_file() + + assert run3.id is not None + + assert run1.path is not None + loaded_run1 = TaskRun.load_from_file(run1.path) + found = loaded_run1.find_task_run_by_id_dfs(run3.id) + assert found is not None + assert found.id == run3.id + assert found.input == "in3" + + +def test_task_run_find_task_run_by_id_dfs_finds_direct_child(task: Task): + """TaskRun.find_task_run_by_id_dfs finds a direct child run.""" + assert task.path is not None + output = TaskOutput(output="out") + + run1 = TaskRun(input="in1", output=output, parent=task) + run1.save_to_file() + + run2 = TaskRun(input="in2", output=output, parent=run1) + run2.save_to_file() + + assert run2.id is not None + + loaded_run1 = TaskRun.load_from_file(run1.path) + found = 
loaded_run1.find_task_run_by_id_dfs(run2.id) + assert found is not None + assert found.id == run2.id + assert found.input == "in2" + + +def test_task_run_find_task_run_by_id_dfs_returns_none_when_not_found(task: Task): + """TaskRun.find_task_run_by_id_dfs returns None when no run in its subtree has the given id.""" + assert task.path is not None + output = TaskOutput(output="out") + + run1 = TaskRun(input="in1", output=output, parent=task) + run1.save_to_file() + + TaskRun(input="in2", output=output, parent=run1).save_to_file() + + loaded_run1 = TaskRun.load_from_file(run1.path) + found = loaded_run1.find_task_run_by_id_dfs("nonexistent-id") + assert found is None + + +def test_is_root_task_run(task: Task): + output = TaskOutput(output="out") + root_run = TaskRun(input="in", output=output, parent=task) + root_run.save_to_file() + assert root_run.is_root_task_run() + nested_run = TaskRun(input="nested in", output=output, parent=root_run) + nested_run.save_to_file() + assert not nested_run.is_root_task_run() + + +def test_comprehensive_nested_task_run_hierarchy(tmp_path): + """Comprehensive integration test for polymorphic parent support on TaskRun. 
+ + Tests: + - Project -> Task -> TaskRun hierarchy + - Multiple levels of nested TaskRuns (Task -> TaskRun -> TaskRun -> TaskRun -> TaskRun) + - Sibling TaskRuns at various levels + - Retrieval of all runs from the hierarchy + - is_root_task_run() correctness at all levels + - parent_task() and parent_run() correctness at all levels + """ + project_path = tmp_path / "project.kiln" + project = Project(name="Test Project", path=project_path) + project.save_to_file() + task = Task(name="Test Task", instruction="Test instruction", parent=project) + task.save_to_file() + + output = TaskOutput(output="test output") + + # Level 1: Two sibling TaskRuns under Task + run1_l1 = TaskRun(input="level1_run1", output=output, parent=task) + run1_l1.save_to_file() + + run2_l1 = TaskRun(input="level1_run2", output=output, parent=task) + run2_l1.save_to_file() + + # Level 2: Nested TaskRuns under run1_l1 (with sibling) + run1_l2 = TaskRun(input="level2_run1", output=output, parent=run1_l1) + run1_l2.save_to_file() + + run2_l2 = TaskRun(input="level2_run2", output=output, parent=run1_l1) + run2_l2.save_to_file() + + # Level 3: Nested TaskRuns under run1_l2 (with sibling) + run1_l3 = TaskRun(input="level3_run1", output=output, parent=run1_l2) + run1_l3.save_to_file() + + run2_l3 = TaskRun(input="level3_run2", output=output, parent=run1_l2) + run2_l3.save_to_file() + + # Level 4: Deepest nested TaskRun under run1_l3 + run1_l4 = TaskRun(input="level4_run1", output=output, parent=run1_l3) + run1_l4.save_to_file() + + # Level 4 sibling of run1_l4 under run2_l3 + run2_l4 = TaskRun(input="level4_sibling", output=output, parent=run2_l3) + run2_l4.save_to_file() + + # Reload everything from disk to test persistence and retrieval + loaded_project = Project.load_from_file(project_path) + loaded_task = loaded_project.tasks()[0] + + # Verify Task has 2 root-level runs + root_runs = loaded_task.runs() + assert len(root_runs) == 2 + root_run_ids = {r.id for r in root_runs} + assert run1_l1.id in 
root_run_ids + assert run2_l1.id in root_run_ids + + # Verify run1_l1 hierarchy + loaded_run1_l1 = next(r for r in root_runs if r.id == run1_l1.id) + assert loaded_run1_l1.parent_task() == loaded_task + assert loaded_run1_l1.parent_run() is None + assert loaded_run1_l1.is_root_task_run() is True + + level2_runs = loaded_run1_l1.runs() + assert len(level2_runs) == 2 + level2_run_ids = {r.id for r in level2_runs} + assert run1_l2.id in level2_run_ids + assert run2_l2.id in level2_run_ids + + # Verify run1_l2 hierarchy + loaded_run1_l2 = next(r for r in level2_runs if r.id == run1_l2.id) + assert loaded_run1_l2.parent_task() == loaded_task + assert loaded_run1_l2.parent_run() == loaded_run1_l1 + assert loaded_run1_l2.is_root_task_run() is False + + level3_runs = loaded_run1_l2.runs() + assert len(level3_runs) == 2 + level3_run_ids = {r.id for r in level3_runs} + assert run1_l3.id in level3_run_ids + assert run2_l3.id in level3_run_ids + + # Verify run1_l3 hierarchy (level 3) + loaded_run1_l3 = next(r for r in level3_runs if r.id == run1_l3.id) + assert loaded_run1_l3.parent_task() == loaded_task + assert loaded_run1_l3.parent_run().id == loaded_run1_l2.id + assert loaded_run1_l3.is_root_task_run() is False + + level4_runs = loaded_run1_l3.runs() + assert len(level4_runs) == 1 + assert level4_runs[0].id == run1_l4.id + + # Verify run1_l4 (deepest run) + loaded_run1_l4 = level4_runs[0] + assert loaded_run1_l4.parent_task() == loaded_task + assert loaded_run1_l4.parent_run().id == loaded_run1_l3.id + assert loaded_run1_l4.is_root_task_run() is False + assert len(loaded_run1_l4.runs()) == 0 + + # Verify run2_l3's nested run (sibling at level 3 has child at level 4) + loaded_run2_l3 = next(r for r in level3_runs if r.id == run2_l3.id) + assert loaded_run2_l3.parent_task() == loaded_task + assert loaded_run2_l3.parent_run().id == loaded_run1_l2.id + assert loaded_run2_l3.is_root_task_run() is False + + level4_sibling_runs = loaded_run2_l3.runs() + assert 
len(level4_sibling_runs) == 1 + assert level4_sibling_runs[0].id == run2_l4.id + + # Verify run2_l4 (sibling at level 4) + loaded_run2_l4 = level4_sibling_runs[0] + assert loaded_run2_l4.parent_task() == loaded_task + assert loaded_run2_l4.parent_run().id == loaded_run2_l3.id + assert loaded_run2_l4.is_root_task_run() is False + + # Verify run2_l1 (sibling at level 1 has no children) + loaded_run2_l1 = next(r for r in root_runs if r.id == run2_l1.id) + assert loaded_run2_l1.parent_task() == loaded_task + assert loaded_run2_l1.parent_run() is None + assert loaded_run2_l1.is_root_task_run() is True + assert len(loaded_run2_l1.runs()) == 0 + + # Verify run2_l2 (sibling at level 2 has no children) + loaded_run2_l2 = next(r for r in level2_runs if r.id == run2_l2.id) + assert loaded_run2_l2.parent_task() == loaded_task + assert loaded_run2_l2.parent_run() == loaded_run1_l1 + assert loaded_run2_l2.is_root_task_run() is False + assert len(loaded_run2_l2.runs()) == 0 + + # Test finding runs by ID from Task (DFS) + found_run1_l4 = loaded_task.find_task_run_by_id_dfs(run1_l4.id) + assert found_run1_l4 is not None + assert found_run1_l4.id == run1_l4.id + assert found_run1_l4.input == "level4_run1" + + found_run2_l4 = loaded_task.find_task_run_by_id_dfs(run2_l4.id) + assert found_run2_l4 is not None + assert found_run2_l4.id == run2_l4.id + assert found_run2_l4.input == "level4_sibling" + + # Test finding runs by ID from a TaskRun (DFS) + found_from_run1_l1 = loaded_run1_l1.find_task_run_by_id_dfs(run1_l4.id) + assert found_from_run1_l1 is not None + assert found_from_run1_l1.id == run1_l4.id + + # Count total runs in the hierarchy (should be 8) + all_runs = collect_all_task_runs(loaded_task) + assert len(all_runs) == 8 + + # Verify all levels are represented correctly + is_root_runs = [r for r in all_runs if r.is_root_task_run()] + assert len(is_root_runs) == 2 # run1_l1 and run2_l1 + + # Verify all non-root runs have parent_run set correctly + non_root_runs = [r for r in 
all_runs if not r.is_root_task_run()] + assert len(non_root_runs) == 6 + for run in non_root_runs: + assert run.parent_run() is not None + + +def collect_all_task_runs(root): + """Helper to recursively collect all TaskRuns in a hierarchy.""" + runs = [] + + def collect(node): + if isinstance(node, TaskRun): + runs.append(node) + for child in node.runs(): + collect(child) + elif hasattr(node, "runs"): + for child in node.runs(): + collect(child) + + collect(root) + return runs + + +def test_task_run_wrong_parent_type_raises(tmp_path): + project = Project(name="proj", path=tmp_path / "project.kiln") + project.save_to_file() + + with pytest.raises(ValidationError, match="Parent must be one of"): + TaskRun( + input="bad parent", + output=TaskOutput( + output="x", + source=DataSource( + type=DataSourceType.human, properties={"created_by": "test"} + ), + ), + parent=project, + ) + + +def test_task_run_runs_on_disk(tmp_path): + project = Project(name="proj", path=tmp_path / "project.kiln") + project.save_to_file() + task = Task(name="t", instruction="i", parent=project) + task.save_to_file() + + parent_run = TaskRun( + input="parent", + output=TaskOutput( + output="parent out", + source=DataSource( + type=DataSourceType.human, properties={"created_by": "test"} + ), + ), + parent=task, + ) + parent_run.save_to_file() + + child_run = TaskRun( + input="child", + output=TaskOutput( + output="child out", + source=DataSource( + type=DataSourceType.human, properties={"created_by": "test"} + ), + ), + parent=parent_run, + ) + child_run.save_to_file() + + loaded = TaskRun.load_from_file(parent_run.path) + children = loaded.runs() + assert len(children) == 1 + assert children[0].id == child_run.id diff --git a/libs/core/kiln_ai/datamodel/test_task.py b/libs/core/kiln_ai/datamodel/test_task.py index 1621db603..f3921f9b8 100644 --- a/libs/core/kiln_ai/datamodel/test_task.py +++ b/libs/core/kiln_ai/datamodel/test_task.py @@ -1,3 +1,5 @@ +import json + import pytest from pydantic 
import ValidationError @@ -6,6 +8,7 @@ StructuredOutputMode, TaskOutputRatingType, ) +from kiln_ai.datamodel.project import Project from kiln_ai.datamodel.prompt_id import PromptGenerators from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties from kiln_ai.datamodel.spec import Spec @@ -15,7 +18,8 @@ ToxicityProperties, ) from kiln_ai.datamodel.task import Task, TaskRunConfig -from kiln_ai.datamodel.task_output import normalize_rating +from kiln_ai.datamodel.task_output import TaskOutput, normalize_rating +from kiln_ai.datamodel.task_run import TaskRun def test_runconfig_valid_creation(): @@ -457,3 +461,307 @@ def test_task_prompt_optimization_jobs_readonly(tmp_path): assert ( prompt_optimization_jobs_default[0].name == "Readonly Prompt Optimization Job" ) + + +def test_all_children_of_parent_path_polymorphic(tmp_path): + """all_children_of_parent_path works correctly for polymorphic parent models.""" + # Test with TaskRun and Task + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + output = TaskOutput(output="test output") + + # Create direct children of Task + run1 = TaskRun(input="input1", output=output, parent=task) + run2 = TaskRun(input="input2", output=output, parent=task) + run1.save_to_file() + run2.save_to_file() + + children = TaskRun.all_children_of_parent_path(task.path) + assert len(children) == 2 + inputs = {child.input for child in children} + assert inputs == {"input1", "input2"} + + +def test_taskrun_nested_validates_parent_type_on_load(tmp_path): + """Loading a TaskRun validates its parent type is Task or TaskRun.""" + output = TaskOutput(output="test output") + + # Create a task + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + # Create parent TaskRun + parent_run = TaskRun(input="parent input", output=output, parent=task) + parent_run.save_to_file() + + # Create nested 
TaskRun + nested_run = TaskRun(input="nested input", output=output, parent=parent_run) + nested_run.save_to_file() + + # Reload from disk - parent type should be validated + # When loading from disk, the parent attribute points to the ultimate parent (Task) + # Use load_parent() to get the direct parent (TaskRun) + loaded_run = TaskRun.load_from_file(nested_run.path) + assert loaded_run is not None + assert loaded_run.input == "nested input" + # parent_task() returns the ultimate parent task (different instance but same data) + loaded_parent_task = loaded_run.parent_task() + assert loaded_parent_task is not None + assert loaded_parent_task.name == "Test Task" + assert loaded_parent_task.instruction == "Test instruction" + # Use load_parent() to get the direct parent TaskRun + direct_parent = loaded_run.load_parent() + assert direct_parent is not None + assert direct_parent.id == parent_run.id + assert direct_parent.input == "parent input" + + +def test_taskrun_loads_from_task_path(tmp_path): + """TaskRun children can be loaded from Task path (valid polymorphic parent).""" + output = TaskOutput(output="test output") + + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + run = TaskRun(input="test input", output=output, parent=task) + run.save_to_file() + + # Load children from task path - should succeed + children = list(TaskRun.iterate_children_paths_of_parent_path(task.path)) + assert len(children) == 1 + assert children[0] == run.path + + +def test_taskrun_loads_from_taskrun_path(tmp_path): + """TaskRun children can be loaded from TaskRun path (valid polymorphic parent).""" + output = TaskOutput(output="test output") + + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + parent_run = TaskRun(input="parent input", output=output, parent=task) + parent_run.save_to_file() + + nested_run = TaskRun(input="nested 
input", output=output, parent=parent_run) + nested_run.save_to_file() + + # Load children from TaskRun path - should succeed + children = list(TaskRun.iterate_children_paths_of_parent_path(parent_run.path)) + assert len(children) == 1 + assert children[0] == nested_run.path + + +def test_taskrun_fails_to_load_from_project_path(tmp_path): + """TaskRun children cannot be loaded from Project path (invalid polymorphic parent).""" + project_path = tmp_path / "project.kiln" + project = Project(name="Test Project", path=project_path) + project.save_to_file() + + # Try to load TaskRun children from a Project path - should fail + with pytest.raises(ValueError, match="Parent model_type 'project' is not one of"): + list(TaskRun.iterate_children_paths_of_parent_path(project_path)) + + +def test_multiple_nested_levels_validates_each_level(tmp_path): + """Multi-level nesting validates parent type at each level.""" + output = TaskOutput(output="test output") + + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + run1 = TaskRun(input="input1", output=output, parent=task) + run1.save_to_file() + + run2 = TaskRun(input="input2", output=output, parent=run1) + run2.save_to_file() + + run3 = TaskRun(input="input3", output=output, parent=run2) + run3.save_to_file() + + # Load children at each level - all should succeed + task_children = TaskRun.all_children_of_parent_path(task.path) + assert len(task_children) == 1 + assert task_children[0].id == run1.id + + run1_children = TaskRun.all_children_of_parent_path(run1.path) + assert len(run1_children) == 1 + assert run1_children[0].id == run2.id + + run2_children = TaskRun.all_children_of_parent_path(run2.path) + assert len(run2_children) == 1 + assert run2_children[0].id == run3.id + + +def test_polymorphic_parent_type_validation_fast_fail(tmp_path): + """Polymorphic validation fails fast without loading entire parent model.""" + # Create a file that's 
syntactically valid JSON but semantically invalid + # for the parent type - the polymorphic path should only read model_type + invalid_parent_path = tmp_path / "invalid.kiln" + + # Write a file that would fail if we tried to fully load as Task/TaskRun + # but should be caught by the model_type check first + invalid_data = { + "model_type": "project", # Wrong type - not Task or TaskRun + "extra_field": "this would cause issues", + } + with open(invalid_parent_path, "w") as f: + json.dump(invalid_data, f) + + # Should fail on model_type check, not on full load + with pytest.raises(ValueError, match="Parent model_type 'project' is not one of"): + list(TaskRun.iterate_children_paths_of_parent_path(invalid_parent_path)) + + +def test_load_parent_raises_on_malformed_taskrun(tmp_path): + """load_parent raises ValueError with context when parent TaskRun is malformed.""" + output = TaskOutput(output="test output") + + # Create a task + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + # Create parent TaskRun directory and valid run + parent_run_dir = tmp_path / "runs" / "parent_run" + parent_run_dir.mkdir(parents=True) + parent_run = TaskRun( + input="parent input", + output=output, + path=parent_run_dir / TaskRun.base_filename(), + ) + parent_run.save_to_file() + + # Create nested TaskRun directory and run (before corrupting parent) + nested_run_dir = parent_run_dir / "runs" / "nested_run" + nested_run_dir.mkdir(parents=True) + nested_run = TaskRun( + input="nested input", + output=output, + path=nested_run_dir / TaskRun.base_filename(), + ) + nested_run.save_to_file() + + # Verify it loads correctly with valid parent + loaded_nested = TaskRun.load_from_file(nested_run.path) + loaded_parent = loaded_nested.load_parent() + assert loaded_parent is not None + assert loaded_parent.input == "parent input" + + # Now corrupt the parent TaskRun file + with open(parent_run.path, "w") as f: + 
json.dump({"model_type": "task_run", "input": 123}, f) # Invalid input type + + # Reload nested run and try to load parent - should raise with context + loaded_nested = TaskRun.load_from_file(nested_run.path) + with pytest.raises(ValueError) as exc_info: + loaded_nested.load_parent() + + error_msg = str(exc_info.value) + assert "Failed to load parent TaskRun" in error_msg + assert str(parent_run.path) in error_msg + assert "malformed nested task run" in error_msg + + +def test_load_parent_succeeds_for_valid_taskrun_parent(tmp_path): + """load_parent successfully loads a valid TaskRun parent.""" + output = TaskOutput(output="test output") + + # Create a task + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + # Create parent TaskRun + parent_run = TaskRun(input="parent input", output=output, parent=task) + parent_run.save_to_file() + + # Create nested TaskRun + runs_dir = parent_run.path.parent / "runs" + runs_dir.mkdir(exist_ok=True) + nested_run_dir = runs_dir / "nested_run" + nested_run_dir.mkdir(exist_ok=True) + nested_run = TaskRun( + input="nested input", output=output, path=nested_run_dir / "task_run.kiln" + ) + nested_run.save_to_file() + + # Reload and load parent - should succeed + loaded_nested = TaskRun.load_from_file(nested_run.path) + loaded_parent = loaded_nested.load_parent() + + assert loaded_parent is not None + assert loaded_parent.id == parent_run.id + assert loaded_parent.input == "parent input" + assert isinstance(loaded_parent, TaskRun) + + +def test_is_root_task_run_true_when_parent_is_task(tmp_path): + """is_root_task_run returns True when parent is a Task.""" + output = TaskOutput(output="test output") + + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + run = TaskRun(input="test input", output=output, parent=task) + run.save_to_file() + + loaded_run = TaskRun.load_from_file(run.path) + 
assert loaded_run.is_root_task_run() is True + + +def test_is_root_task_run_false_when_parent_is_taskrun(tmp_path): + """is_root_task_run returns False when parent is another TaskRun.""" + output = TaskOutput(output="test output") + + task = Task( + name="Test Task", + instruction="Test instruction", + path=tmp_path / "task.kiln", + ) + task.save_to_file() + + parent_run = TaskRun(input="parent input", output=output, parent=task) + parent_run.save_to_file() + + runs_dir = parent_run.path.parent / "runs" + runs_dir.mkdir(exist_ok=True) + nested_run_dir = runs_dir / "nested_run" + nested_run_dir.mkdir(exist_ok=True) + nested_run = TaskRun( + input="nested input", output=output, path=nested_run_dir / "task_run.kiln" + ) + nested_run.save_to_file() + + loaded_nested = TaskRun.load_from_file(nested_run.path) + assert loaded_nested.is_root_task_run() is False