fix: cannot repair CSV-imported runs #254

Open · wants to merge 1 commit into main

20 changes: 13 additions & 7 deletions libs/core/kiln_ai/adapters/repair/repair_task.py
@@ -1,11 +1,10 @@
import json
from typing import Type

from pydantic import BaseModel, Field

from kiln_ai.adapters.prompt_builders import (
BasePromptBuilder,
SavedPromptBuilder,
PromptGenerators,
prompt_builder_from_id,
)
from kiln_ai.datamodel import Priority, Project, Task, TaskRequirement, TaskRun
@@ -44,16 +43,23 @@ def __init__(self, original_task: Task):
output_json_schema=original_task.output_json_schema,
)

@classmethod
def _get_prompt_id(cls, source_properties: dict) -> str:
"""Extract the prompt ID from source properties, falling back to simple if none provided."""
return (
source_properties.get("prompt_id")
or source_properties.get("prompt_builder_name")
or PromptGenerators.SIMPLE.value # some sources can have no prompt_id or prompt_builder_name
)

@classmethod
def _original_prompt(cls, run: TaskRun, task: Task) -> str:
if run.output.source is None or run.output.source.properties is None:
raise ValueError("No source properties found")

# Get the prompt builder id. Need the second check because we used to store this in a prompt_builder_name field, so loading legacy runs will need this.
prompt_id = run.output.source.properties.get(
"prompt_id"
) or run.output.source.properties.get("prompt_builder_name", None)
if prompt_id is not None and isinstance(prompt_id, str):
prompt_id = cls._get_prompt_id(run.output.source.properties)

if isinstance(prompt_id, str):
prompt_builder = prompt_builder_from_id(prompt_id, task)
if isinstance(prompt_builder, BasePromptBuilder):
return prompt_builder.build_prompt(include_json_instructions=False)
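The new _get_prompt_id helper is what makes CSV-imported runs repairable: runs brought in via file_import carry neither prompt_id nor prompt_builder_name in their source properties, so the repair flow now falls back to the simple prompt generator instead of failing. A minimal sketch of the resulting behavior, assuming PromptGenerators.SIMPLE.value resolves to "simple_prompt_builder" (the value the tests below assert); the property dicts are illustrative examples, not taken from real runs:

from kiln_ai.adapters.repair.repair_task import RepairTaskRun

# Modern runs store prompt_id, which takes precedence over everything else
RepairTaskRun._get_prompt_id({"prompt_id": "multi_shot_prompt_builder"})  # -> "multi_shot_prompt_builder"

# Legacy runs stored the same information under prompt_builder_name; still honored
RepairTaskRun._get_prompt_id({"prompt_builder_name": "legacy_prompt"})  # -> "legacy_prompt"

# CSV/file imports carry no prompt information at all, so the helper falls back to SIMPLE
RepairTaskRun._get_prompt_id({"file_name": "test_file.csv"})  # -> "simple_prompt_builder"
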
112 changes: 108 additions & 4 deletions libs/core/kiln_ai/adapters/repair/test_repair_task.py
@@ -8,10 +8,8 @@
from kiln_ai.adapters.adapter_registry import adapter_for_task
from kiln_ai.adapters.model_adapters.base_adapter import RunOutput
from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
from kiln_ai.adapters.repair.repair_task import (
RepairTaskInput,
RepairTaskRun,
)
from kiln_ai.adapters.prompt_builders import prompt_builder_from_id
from kiln_ai.adapters.repair.repair_task import RepairTaskInput, RepairTaskRun
from kiln_ai.datamodel import (
DataSource,
DataSourceType,
@@ -104,6 +102,56 @@ def sample_task_run(sample_task):
return task_run


@pytest.fixture
def sample_task_run_no_prompt_id(sample_task):
task_run = TaskRun(
parent=sample_task,
input='{"topic": "chicken"}',
input_source=DataSource(
type=DataSourceType.file_import,
properties={
"file_name": "test_file.csv",
},
),
output=TaskOutput(
output='{"setup": "Why did the chicken cross the road?", "punchline": "To get to the other side", "rating": null}',
source=DataSource(
type=DataSourceType.file_import,
properties={
"file_name": "test_file.csv",
},
),
),
)
task_run.save_to_file()
return task_run


@pytest.fixture
def sample_task_run_invalid_prompt_id(sample_task):
task_run = TaskRun(
parent=sample_task,
input='{"topic": "chicken"}',
input_source=DataSource(
type=DataSourceType.human, properties={"created_by": "Jane Doe"}
),
output=TaskOutput(
output='{"setup": "Why did the chicken cross the road?", "punchline": "To get to the other side", "rating": null}',
source=DataSource(
type=DataSourceType.synthetic,
properties={
"model_name": "gpt_4o",
"model_provider": "openai",
"adapter_name": "langchain_adapter",
"prompt_id": "invalid_prompt_id",
},
),
),
)
task_run.save_to_file()
return task_run


@pytest.fixture
def sample_repair_data(sample_task, sample_task_run):
return {
@@ -130,6 +178,40 @@ def test_build_repair_task_input(sample_repair_data):
)


def test_build_repair_task_input_with_no_prompt_id(
sample_task, sample_task_run_no_prompt_id
):
with patch(
"kiln_ai.adapters.repair.repair_task.prompt_builder_from_id",
wraps=prompt_builder_from_id,
) as mock_prompt_builder_from_id:
result = RepairTaskRun.build_repair_task_input(
original_task=sample_task,
task_run=sample_task_run_no_prompt_id,
evaluator_feedback="The joke is too cliché. Please come up with a more original chicken-related joke.",
)

# verify we fall back to the simple prompt builder
mock_prompt_builder_from_id.assert_called_once_with(
"simple_prompt_builder", sample_task
)
# verify we got a valid result
assert isinstance(result, RepairTaskInput)


def test_build_repair_task_input_with_invalid_prompt_id(
sample_task, sample_task_run_invalid_prompt_id
):
with pytest.raises(
ValueError,
):
RepairTaskRun.build_repair_task_input(
original_task=sample_task,
task_run=sample_task_run_invalid_prompt_id,
evaluator_feedback="The joke is too cliché. Please come up with a more original chicken-related joke.",
)


def test_repair_input_schema():
schema = RepairTaskInput.model_json_schema()
assert schema["type"] == "object"
@@ -245,3 +327,25 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai

# Verify that the mock was called
mock_run.assert_called_once()


@pytest.mark.parametrize(
"properties,expected",
[
# Should use the prompt id if it is provided
({"prompt_id": "multi_shot_prompt_builder"}, "multi_shot_prompt_builder"),
({"prompt_builder_name": "legacy_prompt"}, "legacy_prompt"),
# Should fall back to SIMPLE if no prompt_id or prompt_builder_name is provided
({}, "simple_prompt_builder"),
({"some_other_field": "value"}, "simple_prompt_builder"),
],
)
def test_get_prompt_id_fallbacks(properties, expected):
result = RepairTaskRun._get_prompt_id(properties)
assert result == expected


def test_get_prompt_id_precedence():
properties = {"prompt_id": "new_prompt", "prompt_builder_name": "old_prompt"}
result = RepairTaskRun._get_prompt_id(properties)
assert result == "new_prompt"