Skip to content
42 changes: 31 additions & 11 deletions app/desktop/studio_server/finetune_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
import httpx
from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import StreamingResponse
from kiln_ai.adapters.adapter_registry import (
load_skills_for_task,
load_skills_from_tool_ids,
)
from kiln_ai.adapters.fine_tune.base_finetune import FineTuneParameter, FineTuneStatus
from kiln_ai.adapters.fine_tune.dataset_formatter import (
DatasetFormat,
Expand All @@ -18,10 +22,6 @@
ModelProviderName,
built_in_models,
)
from kiln_ai.adapters.adapter_registry import (
load_skills_for_task,
load_skills_from_tool_ids,
)
from kiln_ai.adapters.prompt_builders import (
chain_of_thought_prompt,
prompt_builder_from_id,
Expand Down Expand Up @@ -200,7 +200,8 @@ def compute_finetune_tag_info(
high_quality_count: Dict[str, int] = {}
reasoning_and_high_quality_count: Dict[str, int] = {}

required_tools_set = set(tool_filter) if tool_filter else None
# None means no filter; [] means explicitly match runs with no tools/skills.
required_tools_set = None if tool_filter is None else set(tool_filter)

for sample in task.runs(readonly=True):
# filter by tools if provided
Expand Down Expand Up @@ -376,7 +377,14 @@ async def finetune_dataset_info(
project_id: str,
task_id: str,
tool_ids: Annotated[list[str] | None, Query()] = None,
empty_tool_filter: bool = False,
) -> FinetuneDatasetInfo:
# In the fine-tune UI, "no tools/skills selected" should mean `tool_ids=[]`,
# but `openapi-fetch` omits empty arrays, so we recover that state from
# `empty_tool_filter=true`.
if empty_tool_filter and tool_ids is None:
tool_ids = []

task = task_from_id(project_id, task_id)
# Only include datasets that are part of a finetune.
# Orphan datasets are created when a user creates a dataset but doesn't create a finetune.
Expand All @@ -394,13 +402,20 @@ async def finetune_dataset_info(
eligible_finetune_tags = compute_finetune_tag_info(task, tool_filter=tool_ids)

eligible_datasets = existing_datasets
if tool_ids:
# Only filter datasets when the caller provided a tool/skill selection.
# `tool_ids=[]` is a real filter meaning "match datasets with no tools/skills".
if tool_ids is not None:
required_tools_set = set(tool_ids)
eligible_datasets = [
dataset
for dataset in existing_datasets
if set(dataset.tool_info().tools) == required_tools_set
]
eligible_datasets = []
for dataset in existing_datasets:
tool_info = dataset.tool_info()
# Reusable datasets must have a uniform tool/skill set.
# `tool_info.tools=None` means the dataset mixes different tool/skill selections.
if (
tool_info.tools is not None
and set(tool_info.tools) == required_tools_set
):
eligible_datasets.append(dataset)

return FinetuneDatasetInfo(
existing_datasets=existing_datasets,
Expand Down Expand Up @@ -528,6 +543,11 @@ async def download_dataset_jsonl(
)

tool_info = dataset.tool_info()
if tool_info.tools is None:
raise HTTPException(
status_code=400,
detail="Dataset contains mixed tool/skill selections and cannot be exported",
)
skills_dict = load_skills_from_tool_ids(task, tool_info.tools)
skills = list(skills_dict.values())

Expand Down
Loading
Loading