diff --git a/app/desktop/desktop_server.py b/app/desktop/desktop_server.py index bbb6b3910..dd52ce35b 100644 --- a/app/desktop/desktop_server.py +++ b/app/desktop/desktop_server.py @@ -14,6 +14,7 @@ from kiln_ai.utils.logging import setup_litellm_logging from app.desktop.log_config import log_config +from app.desktop.studio_server.chat_api import connect_chat_api from app.desktop.studio_server.copilot_api import connect_copilot_api from app.desktop.studio_server.data_gen_api import connect_data_gen_api from app.desktop.studio_server.dev_tools import connect_dev_tools @@ -71,6 +72,7 @@ def make_app(tk_root: tk.Tk | None = None): connect_import_api(app, tk_root=tk_root) connect_tool_servers_api(app) connect_prompt_optimization_job_api(app) + connect_chat_api(app) connect_copilot_api(app) connect_dev_tools(app) diff --git a/app/desktop/studio_server/api_client/kiln_ai_server_client/api/chat/__init__.py b/app/desktop/studio_server/api_client/kiln_ai_server_client/api/chat/__init__.py new file mode 100644 index 000000000..2d7c0b23d --- /dev/null +++ b/app/desktop/studio_server/api_client/kiln_ai_server_client/api/chat/__init__.py @@ -0,0 +1 @@ +"""Contains endpoint functions for accessing the API""" diff --git a/app/desktop/studio_server/api_client/kiln_ai_server_client/api/chat/handle_chat_v1_chat_post.py b/app/desktop/studio_server/api_client/kiln_ai_server_client/api/chat/handle_chat_v1_chat_post.py new file mode 100644 index 000000000..e621ed42c --- /dev/null +++ b/app/desktop/studio_server/api_client/kiln_ai_server_client/api/chat/handle_chat_v1_chat_post.py @@ -0,0 +1,164 @@ +from http import HTTPStatus +from typing import Any + +import httpx + +from ... import errors +from ...client import AuthenticatedClient, Client +from ...models.chat_request import ChatRequest +from ...models.http_validation_error import HTTPValidationError +from ...types import Response + + +def _get_kwargs( + *, + body: ChatRequest, +) -> dict[str, Any]: + headers: dict[str, Any] = {} + + _kwargs: dict[str, Any] = { + "method": "post", + "url": "/v1/chat/", + } + + _kwargs["json"] = body.to_dict() + + headers["Content-Type"] = "application/json" + + _kwargs["headers"] = headers + return _kwargs + + +def _parse_response( + *, client: AuthenticatedClient | Client, response: httpx.Response +) -> Any | HTTPValidationError | None: + if response.status_code == 200: + response_200 = response.json() + return response_200 + + if response.status_code == 422: + response_422 = HTTPValidationError.from_dict(response.json()) + + return response_422 + + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response( + *, client: AuthenticatedClient | Client, response: httpx.Response +) -> Response[Any | HTTPValidationError]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + *, + client: AuthenticatedClient, + body: ChatRequest, +) -> Response[Any | HTTPValidationError]: + """Handle Chat + + Args: + body (ChatRequest): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Any | HTTPValidationError] + """ + + kwargs = _get_kwargs( + body=body, + ) + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + *, + client: AuthenticatedClient, + body: ChatRequest, +) -> Any | HTTPValidationError | None: + """Handle Chat + + Args: + body (ChatRequest): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Any | HTTPValidationError + """ + + return sync_detailed( + client=client, + body=body, + ).parsed + + +async def asyncio_detailed( + *, + client: AuthenticatedClient, + body: ChatRequest, +) -> Response[Any | HTTPValidationError]: + """Handle Chat + + Args: + body (ChatRequest): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Any | HTTPValidationError] + """ + + kwargs = _get_kwargs( + body=body, + ) + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + *, + client: AuthenticatedClient, + body: ChatRequest, +) -> Any | HTTPValidationError | None: + """Handle Chat + + Args: + body (ChatRequest): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Any | HTTPValidationError + """ + + return ( + await asyncio_detailed( + client=client, + body=body, + ) + ).parsed diff --git a/app/desktop/studio_server/api_client/kiln_ai_server_client/models/__init__.py b/app/desktop/studio_server/api_client/kiln_ai_server_client/models/__init__.py index d793ae480..5bd7ed2e2 100644 --- a/app/desktop/studio_server/api_client/kiln_ai_server_client/models/__init__.py +++ b/app/desktop/studio_server/api_client/kiln_ai_server_client/models/__init__.py @@ -7,12 +7,15 @@ BodyStartPromptOptimizationJobV1JobsPromptOptimizationJobStartPost, ) from .body_start_sample_job_v1_jobs_sample_job_start_post import BodyStartSampleJobV1JobsSampleJobStartPost +from .chat_request import ChatRequest from .check_entitlements_v1_check_entitlements_get_response_check_entitlements_v1_check_entitlements_get import ( CheckEntitlementsV1CheckEntitlementsGetResponseCheckEntitlementsV1CheckEntitlementsGet, ) from .check_model_supported_response import CheckModelSupportedResponse from .clarify_spec_input import ClarifySpecInput from .clarify_spec_output import ClarifySpecOutput +from .client_message import ClientMessage +from .client_message_part import ClientMessagePart from .examples_for_feedback_item import ExamplesForFeedbackItem from .examples_with_feedback_item import ExamplesWithFeedbackItem from .generate_batch_input import GenerateBatchInput @@ -51,6 +54,8 @@ from .synthetic_data_generation_step_config_input import SyntheticDataGenerationStepConfigInput from .task_info import TaskInfo from .task_metadata import TaskMetadata +from .tool_invocation import ToolInvocation +from .tool_invocation_state import ToolInvocationState from .validation_error import ValidationError __all__ = ( @@ -59,10 +64,13 @@ "ApiKeyVerificationResult", "BodyStartPromptOptimizationJobV1JobsPromptOptimizationJobStartPost", "BodyStartSampleJobV1JobsSampleJobStartPost", + "ChatRequest", "CheckEntitlementsV1CheckEntitlementsGetResponseCheckEntitlementsV1CheckEntitlementsGet", "CheckModelSupportedResponse", "ClarifySpecInput", "ClarifySpecOutput", + "ClientMessage", + "ClientMessagePart", "ExamplesForFeedbackItem", "ExamplesWithFeedbackItem", "GenerateBatchInput", @@ -101,5 +109,7 @@ "SyntheticDataGenerationStepConfigInput", "TaskInfo", "TaskMetadata", + "ToolInvocation", + "ToolInvocationState", "ValidationError", ) diff --git a/app/desktop/studio_server/api_client/kiln_ai_server_client/models/chat_request.py b/app/desktop/studio_server/api_client/kiln_ai_server_client/models/chat_request.py new file mode 100644 index 000000000..aaab8fe61 --- /dev/null +++ b/app/desktop/studio_server/api_client/kiln_ai_server_client/models/chat_request.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, TypeVar, cast + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..types import UNSET, Unset + +if TYPE_CHECKING: + from ..models.client_message import ClientMessage + + +T = TypeVar("T", bound="ChatRequest") + + +@_attrs_define +class ChatRequest: + """ + Attributes: + messages (list[ClientMessage]): + task_id (None | str | Unset): + """ + + messages: list[ClientMessage] + task_id: None | str | Unset = UNSET + additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> dict[str, Any]: + messages = [] + for messages_item_data in self.messages: + messages_item = messages_item_data.to_dict() + messages.append(messages_item) + + task_id: None | str | Unset + if isinstance(self.task_id, Unset): + task_id = UNSET + else: + task_id = self.task_id + + field_dict: dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "messages": messages, + } + ) + if task_id is not UNSET: + field_dict["task_id"] = task_id + + return field_dict + + @classmethod + def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + from ..models.client_message import ClientMessage + + d = dict(src_dict) + messages = [] + _messages = d.pop("messages") + for messages_item_data in _messages: + messages_item = ClientMessage.from_dict(messages_item_data) + + messages.append(messages_item) + + def _parse_task_id(data: object) -> None | str | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(None | str | Unset, data) + + task_id = _parse_task_id(d.pop("task_id", UNSET)) + + chat_request = cls( + messages=messages, + task_id=task_id, + ) + + chat_request.additional_properties = d + return chat_request + + @property + def additional_keys(self) -> list[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/app/desktop/studio_server/api_client/kiln_ai_server_client/models/client_message.py b/app/desktop/studio_server/api_client/kiln_ai_server_client/models/client_message.py new file mode 100644 index 000000000..ba9847cf1 --- /dev/null +++ b/app/desktop/studio_server/api_client/kiln_ai_server_client/models/client_message.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, TypeVar, cast + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..types import UNSET, Unset + +if TYPE_CHECKING: + from ..models.client_message_part import ClientMessagePart + from ..models.tool_invocation import ToolInvocation + + +T = TypeVar("T", bound="ClientMessage") + + +@_attrs_define +class ClientMessage: + """ + Attributes: + role (str): + content (None | str | Unset): + parts (list[ClientMessagePart] | None | Unset): + tool_invocations (list[ToolInvocation] | None | Unset): + """ + + role: str + content: None | str | Unset = UNSET + parts: list[ClientMessagePart] | None | Unset = UNSET + tool_invocations: list[ToolInvocation] | None | Unset = UNSET + additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> dict[str, Any]: + role = self.role + + content: None | str | Unset + if isinstance(self.content, Unset): + content = UNSET + else: + content = self.content + + parts: list[dict[str, Any]] | None | Unset + if isinstance(self.parts, Unset): + parts = UNSET + elif isinstance(self.parts, list): + parts = [] + for parts_type_0_item_data in self.parts: + parts_type_0_item = parts_type_0_item_data.to_dict() + parts.append(parts_type_0_item) + + else: + parts = self.parts + + tool_invocations: list[dict[str, Any]] | None | Unset + if isinstance(self.tool_invocations, Unset): + tool_invocations = UNSET + elif isinstance(self.tool_invocations, list): + tool_invocations = [] + for tool_invocations_type_0_item_data in self.tool_invocations: + tool_invocations_type_0_item = tool_invocations_type_0_item_data.to_dict() + tool_invocations.append(tool_invocations_type_0_item) + + else: + tool_invocations = self.tool_invocations + + field_dict: dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "role": role, + } + ) + if content is not UNSET: + field_dict["content"] = content + if parts is not UNSET: + field_dict["parts"] = parts + if tool_invocations is not UNSET: + field_dict["toolInvocations"] = tool_invocations + + return field_dict + + @classmethod + def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + from ..models.client_message_part import ClientMessagePart + from ..models.tool_invocation import ToolInvocation + + d = dict(src_dict) + role = d.pop("role") + + def _parse_content(data: object) -> None | str | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(None | str | Unset, data) + + content = _parse_content(d.pop("content", UNSET)) + + def _parse_parts(data: object) -> list[ClientMessagePart] | None | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + try: + if not isinstance(data, list): + raise TypeError() + parts_type_0 = [] + _parts_type_0 = data + for parts_type_0_item_data in _parts_type_0: + parts_type_0_item = ClientMessagePart.from_dict(parts_type_0_item_data) + + parts_type_0.append(parts_type_0_item) + + return parts_type_0 + except (TypeError, ValueError, AttributeError, KeyError): + pass + return cast(list[ClientMessagePart] | None | Unset, data) + + parts = _parse_parts(d.pop("parts", UNSET)) + + def _parse_tool_invocations(data: object) -> list[ToolInvocation] | None | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + try: + if not isinstance(data, list): + raise TypeError() + tool_invocations_type_0 = [] + _tool_invocations_type_0 = data + for tool_invocations_type_0_item_data in _tool_invocations_type_0: + tool_invocations_type_0_item = ToolInvocation.from_dict(tool_invocations_type_0_item_data) + + tool_invocations_type_0.append(tool_invocations_type_0_item) + + return tool_invocations_type_0 + except (TypeError, ValueError, AttributeError, KeyError): + pass + return cast(list[ToolInvocation] | None | Unset, data) + + tool_invocations = _parse_tool_invocations(d.pop("toolInvocations", UNSET)) + + client_message = cls( + role=role, + content=content, + parts=parts, + tool_invocations=tool_invocations, + ) + + client_message.additional_properties = d + return client_message + + @property + def additional_keys(self) -> list[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/app/desktop/studio_server/api_client/kiln_ai_server_client/models/client_message_part.py b/app/desktop/studio_server/api_client/kiln_ai_server_client/models/client_message_part.py new file mode 100644 index 000000000..67880d9cd --- /dev/null +++ b/app/desktop/studio_server/api_client/kiln_ai_server_client/models/client_message_part.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, TypeVar, cast + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..types import UNSET, Unset + +T = TypeVar("T", bound="ClientMessagePart") + + +@_attrs_define +class ClientMessagePart: + """ + Attributes: + type_ (str): + text (None | str | Unset): + content_type (None | str | Unset): + url (None | str | Unset): + data (Any | Unset): + tool_call_id (None | str | Unset): + tool_name (None | str | Unset): + state (None | str | Unset): + input_ (Any | Unset): + output (Any | Unset): + args (Any | Unset): + """ + + type_: str + text: None | str | Unset = UNSET + content_type: None | str | Unset = UNSET + url: None | str | Unset = UNSET + data: Any | Unset = UNSET + tool_call_id: None | str | Unset = UNSET + tool_name: None | str | Unset = UNSET + state: None | str | Unset = UNSET + input_: Any | Unset = UNSET + output: Any | Unset = UNSET + args: Any | Unset = UNSET + additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> dict[str, Any]: + type_ = self.type_ + + text: None | str | Unset + if isinstance(self.text, Unset): + text = UNSET + else: + text = self.text + + content_type: None | str | Unset + if isinstance(self.content_type, Unset): + content_type = UNSET + else: + content_type = self.content_type + + url: None | str | Unset + if isinstance(self.url, Unset): + url = UNSET + else: + url = self.url + + data = self.data + + tool_call_id: None | str | Unset + if isinstance(self.tool_call_id, Unset): + tool_call_id = UNSET + else: + tool_call_id = self.tool_call_id + + tool_name: None | str | Unset + if isinstance(self.tool_name, Unset): + tool_name = UNSET + else: + tool_name = self.tool_name + + state: None | str | Unset + if isinstance(self.state, Unset): + state = UNSET + else: + state = self.state + + input_ = self.input_ + + output = self.output + + args = self.args + + field_dict: dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "type": type_, + } + ) + if text is not UNSET: + field_dict["text"] = text + if content_type is not UNSET: + field_dict["contentType"] = content_type + if url is not UNSET: + field_dict["url"] = url + if data is not UNSET: + field_dict["data"] = data + if tool_call_id is not UNSET: + field_dict["toolCallId"] = tool_call_id + if tool_name is not UNSET: + field_dict["toolName"] = tool_name + if state is not UNSET: + field_dict["state"] = state + if input_ is not UNSET: + field_dict["input"] = input_ + if output is not UNSET: + field_dict["output"] = output + if args is not UNSET: + field_dict["args"] = args + + return field_dict + + @classmethod + def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + d = dict(src_dict) + type_ = d.pop("type") + + def _parse_text(data: object) -> None | str | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(None | str | Unset, data) + + text = _parse_text(d.pop("text", UNSET)) + + def _parse_content_type(data: object) -> None | str | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(None | str | Unset, data) + + content_type = _parse_content_type(d.pop("contentType", UNSET)) + + def _parse_url(data: object) -> None | str | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(None | str | Unset, data) + + url = _parse_url(d.pop("url", UNSET)) + + data = d.pop("data", UNSET) + + def _parse_tool_call_id(data: object) -> None | str | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(None | str | Unset, data) + + tool_call_id = _parse_tool_call_id(d.pop("toolCallId", UNSET)) + + def _parse_tool_name(data: object) -> None | str | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(None | str | Unset, data) + + tool_name = _parse_tool_name(d.pop("toolName", UNSET)) + + def _parse_state(data: object) -> None | str | Unset: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(None | str | Unset, data) + + state = _parse_state(d.pop("state", UNSET)) + + input_ = d.pop("input", UNSET) + + output = d.pop("output", UNSET) + + args = d.pop("args", UNSET) + + client_message_part = cls( + type_=type_, + text=text, + content_type=content_type, + url=url, + data=data, + tool_call_id=tool_call_id, + tool_name=tool_name, + state=state, + input_=input_, + output=output, + args=args, + ) + + client_message_part.additional_properties = d + return client_message_part + + @property + def additional_keys(self) -> list[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/app/desktop/studio_server/api_client/kiln_ai_server_client/models/tool_invocation.py b/app/desktop/studio_server/api_client/kiln_ai_server_client/models/tool_invocation.py new file mode 100644 index 000000000..775865148 --- /dev/null +++ b/app/desktop/studio_server/api_client/kiln_ai_server_client/models/tool_invocation.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, TypeVar + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..models.tool_invocation_state import ToolInvocationState + +T = TypeVar("T", bound="ToolInvocation") + + +@_attrs_define +class ToolInvocation: + """ + Attributes: + state (ToolInvocationState): + tool_call_id (str): + tool_name (str): + args (Any): + result (Any): + """ + + state: ToolInvocationState + tool_call_id: str + tool_name: str + args: Any + result: Any + additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> dict[str, Any]: + state = self.state.value + + tool_call_id = self.tool_call_id + + tool_name = self.tool_name + + args = self.args + + result = self.result + + field_dict: dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "state": state, + "toolCallId": tool_call_id, + "toolName": tool_name, + "args": args, + "result": result, + } + ) + + return field_dict + + @classmethod + def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + d = dict(src_dict) + state = ToolInvocationState(d.pop("state")) + + tool_call_id = d.pop("toolCallId") + + tool_name = d.pop("toolName") + + args = d.pop("args") + + result = d.pop("result") + + tool_invocation = cls( + state=state, + tool_call_id=tool_call_id, + tool_name=tool_name, + args=args, + result=result, + ) + + tool_invocation.additional_properties = d + return tool_invocation + + @property + def additional_keys(self) -> list[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/app/desktop/studio_server/api_client/kiln_ai_server_client/models/tool_invocation_state.py b/app/desktop/studio_server/api_client/kiln_ai_server_client/models/tool_invocation_state.py new file mode 100644 index 000000000..eed1ef25c --- /dev/null +++ b/app/desktop/studio_server/api_client/kiln_ai_server_client/models/tool_invocation_state.py @@ -0,0 +1,10 @@ +from enum import Enum + + +class ToolInvocationState(str, Enum): + CALL = "call" + PARTIAL_CALL = "partial-call" + RESULT = "result" + + def __str__(self) -> str: + return str(self.value) diff --git a/app/desktop/studio_server/chat_api.py b/app/desktop/studio_server/chat_api.py new file mode 100644 index 000000000..1457171b5 --- /dev/null +++ b/app/desktop/studio_server/chat_api.py @@ -0,0 +1,209 @@ +import json +import logging +from typing import Any + +import httpx +from app.desktop.studio_server.api_client.kiln_server_client import ( + _get_base_url, + _get_common_headers, +) +from app.desktop.studio_server.utils.copilot_utils import get_copilot_api_key +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +from kiln_ai.datamodel import Project, TaskRun +from kiln_ai.utils.config import Config + +logger = logging.getLogger(__name__) + +_CHAT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=30.0) +_MAX_CLIENT_TOOL_ROUNDS = 5 + + +def _build_upstream_headers(api_key: str) -> dict[str, str]: + return { + **_get_common_headers(), + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + + +def _find_task_run_by_id(task_run_id: str) -> TaskRun | None: + """Search all projects and tasks for a task run with the given ID.""" + project_paths = Config.shared().projects or [] + for project_path in project_paths: + try: + project = Project.load_from_file(project_path) + except Exception: + continue + for task in project.tasks(): + run = TaskRun.from_id_and_parent_path(task_run_id, task.path) + if run is not None: + return run + return None + + +def _execute_client_tool(tool_name: str, arguments: dict[str, Any]) -> str: + """Execute a client-side tool and return the result as a string.""" + if tool_name == "read_task_run": + task_run_id = arguments.get("task_run_id", "") + if not task_run_id: + return json.dumps({"error": "task_run_id is required"}) + try: + run = _find_task_run_by_id(task_run_id) + if run is None: + return json.dumps({"error": f"Task run not found: {task_run_id}"}) + return run.model_dump_json(indent=2) + except Exception as e: + return json.dumps({"error": f"Failed to read task run: {e}"}) + return json.dumps({"error": f"Unknown client tool: {tool_name}"}) + + +def _parse_sse_events( + raw: bytes, +) -> tuple[list[bytes], dict[str, Any] | None]: + """Parse raw SSE bytes into passthrough lines and an optional client-tool-call event. + + Returns (lines_to_forward, client_tool_event_or_none). + """ + lines_to_forward: list[bytes] = [] + client_tool_event: dict[str, Any] | None = None + + for line in raw.split(b"\n"): + if line.startswith(b"data: "): + payload = line[6:].strip() + if payload and payload != b"[DONE]": + try: + event = json.loads(payload) + if ( + isinstance(event, dict) + and event.get("type") == "client-tool-call" + ): + client_tool_event = event + continue + except (json.JSONDecodeError, TypeError): + pass + lines_to_forward.append(line) + + return lines_to_forward, client_tool_event + + +def connect_chat_api(app: FastAPI) -> None: + @app.post("/api/chat") + async def chat(request: Request) -> StreamingResponse: + api_key = get_copilot_api_key() + body_bytes = await request.body() + body_json = json.loads(body_bytes) + + upstream_url = f"{_get_base_url()}/v1/chat/" + headers = _build_upstream_headers(api_key) + + async def stream_with_client_tools(): + current_body = body_json + rounds = 0 + + while rounds < _MAX_CLIENT_TOOL_ROUNDS: + rounds += 1 + client_tool_event = None + + async with httpx.AsyncClient(timeout=_CHAT_TIMEOUT) as client: + async with client.stream( + "POST", + upstream_url, + content=json.dumps(current_body).encode(), + headers=headers, + ) as upstream: + if upstream.status_code != 200: + error_body = await upstream.aread() + detail = "Chat request failed." + if error_body.startswith(b"{"): + try: + detail = ( + json.loads(error_body).get("message", detail) + or detail + ) + except json.JSONDecodeError: + pass + yield f"data: {json.dumps({'type': 'error', 'message': detail})}\n\n".encode() + return + + try: + async for chunk in upstream.aiter_bytes(): + lines, tool_event = _parse_sse_events(chunk) + if tool_event: + client_tool_event = tool_event + forward_bytes = b"\n".join(lines) + if forward_bytes.strip(): + yield forward_bytes + b"\n" + except httpx.RemoteProtocolError: + if client_tool_event is not None: + logger.debug( + "Connection closed after client tool call event (expected)" + ) + else: + raise + + if client_tool_event is None: + return + + tool_name = client_tool_event.get("toolName", "") + tool_call_id = client_tool_event.get("toolCallId", "") + tool_input = client_tool_event.get("input", {}) + + logger.info( + f"Executing client tool: {tool_name} (call_id={tool_call_id})" + ) + + yield f"data: {json.dumps({'type': 'tool-output-available', 'toolCallId': tool_call_id, 'output': '(executing locally...)'})}\n\n".encode() + + tool_result = _execute_client_tool(tool_name, tool_input) + + yield f"data: {json.dumps({'type': 'tool-output-available', 'toolCallId': tool_call_id, 'output': tool_result})}\n\n".encode() + + current_body = _build_continuation_body( + current_body, tool_call_id, tool_name, tool_input, tool_result + ) + + return StreamingResponse( + content=stream_with_client_tools(), + media_type="text/event-stream", + ) + + +def _build_continuation_body( + original_body: dict[str, Any], + tool_call_id: str, + tool_name: str, + tool_input: Any, + tool_result: str, +) -> dict[str, Any]: + """Build the request body for continuing after a client tool call. + + Appends a single assistant message containing both the tool call and its + result so the backend's convert_to_openai_messages produces the correct + assistant(tool_calls) + tool(result) sequence. + """ + messages = list(original_body.get("messages", [])) + + messages.append( + { + "role": "assistant", + "parts": [ + { + "type": f"tool-{tool_name}", + "toolCallId": tool_call_id, + "toolName": tool_name, + "input": tool_input, + "state": "call", + }, + { + "type": f"tool-{tool_name}", + "toolCallId": tool_call_id, + "toolName": tool_name, + "output": tool_result, + "state": "output-available", + }, + ], + } + ) + + return {**original_body, "messages": messages} diff --git a/app/desktop/studio_server/test_chat_api.py b/app/desktop/studio_server/test_chat_api.py new file mode 100644 index 000000000..3be97003c --- /dev/null +++ b/app/desktop/studio_server/test_chat_api.py @@ -0,0 +1,283 @@ +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from app.desktop.studio_server.chat_api import ( + _build_continuation_body, + _execute_client_tool, + _parse_sse_events, + connect_chat_api, +) +from fastapi import FastAPI +from fastapi.testclient import TestClient +from kiln_server.custom_errors import connect_custom_errors + + +@pytest.fixture +def app(): + app = FastAPI() + connect_custom_errors(app) + connect_chat_api(app) + return app + + +@pytest.fixture +def client(app): + return TestClient(app) + + +@pytest.fixture +def mock_api_key(): + with patch( + "app.desktop.studio_server.utils.copilot_utils.Config.shared" + ) as mock_config_shared: + mock_config = mock_config_shared.return_value + mock_config.kiln_copilot_api_key = "test_api_key" + yield mock_config + + +def _make_httpx_mock(status_code: int = 200, chunks: list[bytes] | None = None): + if chunks is None: + chunks = [b'data: {"type":"text-delta","delta":"hello"}\n\n'] + + async def mock_aiter_bytes(): + for chunk in chunks: + yield chunk + + mock_upstream = MagicMock() + mock_upstream.status_code = status_code + mock_upstream.aiter_bytes.return_value = mock_aiter_bytes() + mock_upstream.aread = AsyncMock( + return_value=b'{"message":"upstream error"}' if status_code != 200 else b"" + ) + mock_upstream.__aenter__ = AsyncMock(return_value=mock_upstream) + mock_upstream.__aexit__ = AsyncMock(return_value=None) + + mock_client = MagicMock() + mock_client.stream.return_value = mock_upstream + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + mock_async_client_class = MagicMock(return_value=mock_client) + return mock_async_client_class, mock_client, mock_upstream + + +# --- SSE passthrough tests --- + + +class TestChatStreaming: + def test_streams_chunks(self, client, mock_api_key): + chunks = [ + b'data: {"type":"text-delta","delta":"hello"}\n\n', + b'data: {"type":"finish"}\n\n', + ] + mock_class, _, _ = _make_httpx_mock(chunks=chunks) + + with patch("app.desktop.studio_server.chat_api.httpx.AsyncClient", mock_class): + response = client.post( + "/api/chat", + json={"messages": [{"role": "user", "content": "hi"}]}, + ) + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("text/event-stream") + assert b"text-delta" in response.content + + def test_forwards_auth_header(self, client, mock_api_key): + mock_class, mock_client, _ = _make_httpx_mock() + + with patch("app.desktop.studio_server.chat_api.httpx.AsyncClient", mock_class): + client.post( + "/api/chat", + json={"messages": [{"role": "user", "content": "hi"}]}, + ) + + call_kwargs = mock_client.stream.call_args + headers = call_kwargs.kwargs.get("headers", {}) + assert headers.get("Authorization") == "Bearer test_api_key" + + def test_returns_401_when_no_api_key(self, client): + with patch( + "app.desktop.studio_server.utils.copilot_utils.Config.shared" + ) as mock_config_shared: + mock_config = mock_config_shared.return_value + mock_config.kiln_copilot_api_key = None + + response = client.post( + "/api/chat", + json={"messages": [{"role": "user", "content": "hi"}]}, + ) + + assert response.status_code == 401 + + def test_handles_upstream_error(self, client, mock_api_key): + mock_class, _, _ = _make_httpx_mock(status_code=500) + + with patch("app.desktop.studio_server.chat_api.httpx.AsyncClient", mock_class): + response = client.post( + "/api/chat", + json={"messages": [{"role": "user", "content": "hi"}]}, + ) + + assert response.status_code == 200 + assert b"error" in response.content + + +# --- SSE parsing tests --- + + +class TestParseSSEEvents: + def test_passthrough_normal_events(self): + raw = b'data: {"type":"text-delta","delta":"hi"}\n\n' + lines, tool_event = _parse_sse_events(raw) + assert tool_event is None + assert any(b"text-delta" in line for line in lines) + + def test_detects_client_tool_call(self): + raw = ( + b'data: {"type":"text-delta","delta":"hi"}\n' + b'data: {"type":"client-tool-call","toolCallId":"tc1","toolName":"read_task_run","input":{"path":"/x"}}\n\n' + ) + lines, tool_event = _parse_sse_events(raw) + assert tool_event is not None + assert tool_event["toolName"] == "read_task_run" + assert tool_event["toolCallId"] == "tc1" + assert not any(b"client-tool-call" in line for line in lines) + + def test_handles_empty_input(self): + lines, tool_event = _parse_sse_events(b"") + assert tool_event is None + + +# --- Client tool execution tests --- + + +class TestExecuteClientTool: + def test_read_task_run_success(self): + mock_run = MagicMock() + mock_run.model_dump_json.return_value = '{"id": "42", "input": "hello"}' + + with patch( + "app.desktop.studio_server.chat_api._find_task_run_by_id", + return_value=mock_run, + ): + result = _execute_client_tool("read_task_run", {"task_run_id": "42"}) + assert '"id": "42"' in result + + def test_read_task_run_not_found(self): + with patch( + "app.desktop.studio_server.chat_api._find_task_run_by_id", + return_value=None, + ): + result = _execute_client_tool("read_task_run", {"task_run_id": "999"}) + parsed = json.loads(result) + assert "error" in parsed + assert "999" in parsed["error"] + + def test_read_task_run_missing_id(self): + result = _execute_client_tool("read_task_run", {}) + parsed = json.loads(result) + assert "error" in parsed + + def test_unknown_tool(self): + result = _execute_client_tool("unknown_tool", {}) + assert "Unknown client tool" in result + + +# --- Continuation body tests --- + + +class TestBuildContinuationBody: + def test_appends_tool_messages(self): + original = {"messages": [{"role": "user", "content": "hi"}]} + result = _build_continuation_body( + original, "tc1", "read_task_run", {"path": "/x"}, '{"data": "result"}' + ) + + assert len(result["messages"]) == 2 + assert result["messages"][0]["role"] == "user" + + parts = result["messages"][1]["parts"] + assert result["messages"][1]["role"] == "assistant" + assert len(parts) == 2 + assert parts[0]["toolCallId"] == "tc1" + assert parts[0]["state"] == "call" + assert parts[0]["input"] == {"path": "/x"} + assert parts[1]["state"] == "output-available" + assert parts[1]["output"] == '{"data": "result"}' + assert "input" not in parts[1] + + def test_preserves_original_body_fields(self): + original = { + "messages": [{"role": "user", "content": "hi"}], + "task_id": "test_task", + } + result = _build_continuation_body(original, "tc1", "tool", {}, "result") + assert result["task_id"] == "test_task" + + +# --- Client tool round-trip test --- + + +class TestClientToolRoundTrip: + def test_detects_and_continues_after_client_tool(self, client, mock_api_key): + """First request returns client-tool-call, proxy executes locally and sends continuation.""" + first_response_chunks = [ + b'data: {"type":"text-delta","delta":"Let me read that"}\n\n', + b'data: {"type":"client-tool-call","toolCallId":"tc1","toolName":"read_task_run","input":{"path":"/fake"}}\n\n', + b'data: {"type":"finish"}\n\n', + ] + second_response_chunks = [ + b'data: {"type":"text-delta","delta":"Here is the result"}\n\n', + b'data: {"type":"finish"}\n\n', + ] + + call_count = 0 + + def make_stream_mock(chunks): + async def mock_aiter_bytes(): + for chunk in chunks: + yield chunk + + mock_upstream = MagicMock() + mock_upstream.status_code = 200 + mock_upstream.aiter_bytes.return_value = mock_aiter_bytes() + mock_upstream.__aenter__ = AsyncMock(return_value=mock_upstream) + mock_upstream.__aexit__ = AsyncMock(return_value=None) + return mock_upstream + + def side_effect_stream(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return make_stream_mock(first_response_chunks) + return make_stream_mock(second_response_chunks) + + mock_client = MagicMock() + mock_client.stream.side_effect = side_effect_stream + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + mock_class = MagicMock(return_value=mock_client) + + with ( + patch("app.desktop.studio_server.chat_api.httpx.AsyncClient", mock_class), + patch( + "app.desktop.studio_server.chat_api._execute_client_tool", + return_value='{"data": "mock result"}', + ), + ): + response = client.post( + "/api/chat", + json={"messages": [{"role": "user", "content": "read my task run"}]}, + ) + + assert response.status_code == 200 + content = response.content + assert b"Let me read that" in content + assert b"Here is the result" in content + assert call_count == 2 + + continuation_call = mock_client.stream.call_args_list[1] + continuation_body = json.loads(continuation_call.kwargs["content"]) + assert len(continuation_body["messages"]) == 2 diff --git a/app/web_ui/package-lock.json b/app/web_ui/package-lock.json index 4a092c346..434db91ae 100644 --- a/app/web_ui/package-lock.json +++ b/app/web_ui/package-lock.json @@ -10,8 +10,10 @@ "dependencies": { "@floating-ui/dom": "^1.7.2", "@kinde-oss/kinde-auth-pkce-js": "^4.3.0", + "dompurify": "^3.3.2", "echarts": "^6.0.0", "highlight.js": "^11.10.0", + "marked": "^17.0.4", "openapi-fetch": "^0.12.2", "posthog-js": "^1.184.2" }, @@ -21,6 +23,7 @@ "@sveltejs/kit": "^2.20.6", "@sveltejs/vite-plugin-svelte": "^3.1.1", "@tailwindcss/typography": "^0.5.13", + "@types/dompurify": "^3.0.5", "@typescript-eslint/eslint-plugin": "^8.46.1", "@typescript-eslint/parser": "^8.46.1", "autoprefixer": "^10.4.15", @@ -2096,6 +2099,16 @@ "integrity": "sha512-4Kh9a6B2bQciAhf7FSuMRRkUWecJgJu9nPnx3yzpsfXX/c50REIqpHY4C82bXP90qrLtXtkDxTZosYO3UpOwlA==", "dev": true }, + "node_modules/@types/dompurify": { + "version": "3.0.5", + "resolved": "https://registry.npmjs.org/@types/dompurify/-/dompurify-3.0.5.tgz", + "integrity": "sha512-1Wg0g3BtQF7sSb27fJQAKck1HECM6zV1EB66j8JH9i3LCjYabJa0FSdiSgsD5K/RbrsR0SiraKacLB+T8ZVYAg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/trusted-types": "*" + } + }, "node_modules/@types/estree": { "version": "1.0.6", "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.6.tgz", @@ -2134,9 +2147,8 @@ "version": "2.0.7", "resolved": "https://registry.npmjs.org/@types/trusted-types/-/trusted-types-2.0.7.tgz", "integrity": "sha512-ScaPdn1dQczgbl0QFTeTOmVHFULt394XJgOQNoyVhZ6r2vLnMLJfBPd53SB52T/3G36VI1/g2MZaX0cwDuXsfw==", - "dev": true, - "license": "MIT", - "optional": true + "devOptional": true, + "license": "MIT" }, "node_modules/@typescript-eslint/eslint-plugin": { "version": "8.46.1", @@ -3372,11 +3384,13 @@ "dev": true }, "node_modules/dompurify": { - "version": "3.3.0", - "resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.3.0.tgz", - "integrity": "sha512-r+f6MYR1gGN1eJv0TVQbhA7if/U7P87cdPl3HN5rikqaBSBxLiCb/b9O+2eG0cxz0ghyU+mU1QkbsOwERMYlWQ==", - "dev": true, + "version": "3.3.2", + "resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.3.2.tgz", + "integrity": "sha512-6obghkliLdmKa56xdbLOpUZ43pAR6xFy1uOrxBaIDjT+yaRuuybLjGS9eVBoSR/UPU5fq3OXClEHLJNGvbxKpQ==", "license": "(MPL-2.0 OR Apache-2.0)", + "engines": { + "node": ">=20" + }, "optionalDependencies": { "@types/trusted-types": "^2.0.7" } @@ -4878,16 +4892,15 @@ "license": "MIT" }, "node_modules/marked": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/marked/-/marked-4.3.0.tgz", - "integrity": "sha512-PRsaiG84bK+AMvxziE/lCFss8juXjNaWzVbN5tXAm4XjeaS9NAHhop+PjQxz2A9h8Q4M/xGmzP8vqNwy6JeK0A==", - "dev": true, + "version": "17.0.4", + "resolved": "https://registry.npmjs.org/marked/-/marked-17.0.4.tgz", + "integrity": "sha512-NOmVMM+KAokHMvjWmC5N/ZOvgmSWuqJB8FoYI019j4ogb/PeRMKoKIjReZ2w3376kkA8dSJIP8uD993Kxc0iRQ==", "license": "MIT", "bin": { "marked": "bin/marked.js" }, "engines": { - "node": ">= 12" + "node": ">= 20" } }, "node_modules/math-intrinsics": { @@ -6200,6 +6213,19 @@ "styled-components": "^4.1.1 || ^5.1.1 || ^6.0.5" } }, + "node_modules/redoc/node_modules/marked": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/marked/-/marked-4.3.0.tgz", + "integrity": "sha512-PRsaiG84bK+AMvxziE/lCFss8juXjNaWzVbN5tXAm4XjeaS9NAHhop+PjQxz2A9h8Q4M/xGmzP8vqNwy6JeK0A==", + "dev": true, + "license": "MIT", + "bin": { + "marked": "bin/marked.js" + }, + "engines": { + "node": ">= 12" + } + }, "node_modules/reftools": { "version": "1.1.9", "resolved": "https://registry.npmjs.org/reftools/-/reftools-1.1.9.tgz", diff --git a/app/web_ui/package.json b/app/web_ui/package.json index 28038aa15..11ccd4f57 100644 --- a/app/web_ui/package.json +++ b/app/web_ui/package.json @@ -21,6 +21,7 @@ "@sveltejs/kit": "^2.20.6", "@sveltejs/vite-plugin-svelte": "^3.1.1", "@tailwindcss/typography": "^0.5.13", + "@types/dompurify": "^3.0.5", "@typescript-eslint/eslint-plugin": "^8.46.1", "@typescript-eslint/parser": "^8.46.1", "autoprefixer": "^10.4.15", @@ -43,9 +44,11 @@ "type": "module", "dependencies": { "@floating-ui/dom": "^1.7.2", - "echarts": "^6.0.0", "@kinde-oss/kinde-auth-pkce-js": "^4.3.0", + "dompurify": "^3.3.2", + "echarts": "^6.0.0", "highlight.js": "^11.10.0", + "marked": "^17.0.4", "openapi-fetch": "^0.12.2", "posthog-js": "^1.184.2" }, diff --git a/app/web_ui/src/app.css b/app/web_ui/src/app.css index 073088829..46c1c575d 100644 --- a/app/web_ui/src/app.css +++ b/app/web_ui/src/app.css @@ -28,3 +28,17 @@ input::-webkit-inner-spin-button { input[type="number"] { -moz-appearance: textfield; } + +@keyframes thinking-dot { + 0%, + 100% { + opacity: 0.35; + } + 50% { + opacity: 1; + } +} + +.thinking-dot { + animation: thinking-dot 0.8s ease-in-out infinite; +} diff --git a/app/web_ui/src/lib/chat/ChatMarkdown.svelte b/app/web_ui/src/lib/chat/ChatMarkdown.svelte new file mode 100644 index 000000000..342a63c90 --- /dev/null +++ b/app/web_ui/src/lib/chat/ChatMarkdown.svelte @@ -0,0 +1,108 @@ + + + + + + +{#if sanitized} + +{:else} + {text || ""} +{/if} diff --git a/app/web_ui/src/lib/chat/streaming_chat.ts b/app/web_ui/src/lib/chat/streaming_chat.ts new file mode 100644 index 000000000..b28f1718e --- /dev/null +++ b/app/web_ui/src/lib/chat/streaming_chat.ts @@ -0,0 +1,338 @@ +/** + * Custom streaming chat: parses SSE from the backend (AI SDK protocol JSON events). + * Does not use @ai-sdk/svelte because we use Svelte 4 and @ai-sdk/svelte uses Svelte 5. + * + * There is an ancient version of the lib that works with Svelte 4, but then that forces us + * to use an old version of the protocol on the backend too, which is not a good idea. + */ + +export type ChatMessagePart = + | { type: "text"; text: string } + | { type: "reasoning"; reasoning: string } + | { + type: `tool-${string}` + toolCallId: string + toolName?: string + input?: unknown + output?: unknown + } + +export interface ChatMessage { + id: string + role: "user" | "assistant" | "system" + content?: string + parts?: ChatMessagePart[] +} + +/** Body the backend expects: POST /api/chat */ +export interface BackendChatRequest { + messages: Array<{ + role: string + content?: string + parts?: Array> + }> +} + +function toBackendMessage(m: ChatMessage): BackendChatRequest["messages"][0] { + if (m.role === "user") { + return { role: "user", content: m.content ?? "" } + } + if (m.role === "assistant" && m.parts?.length) { + return { + role: "assistant", + parts: m.parts.map((p) => { + if (p.type === "text") return { type: "text", text: p.text } + if (p.type === "reasoning") + return { type: "reasoning", reasoning: p.reasoning } + return { + type: p.type, + toolCallId: p.toolCallId, + toolName: p.toolName, + input: p.input, + output: p.output, + } + }), + } + } + return { role: m.role, content: m.content ?? "" } +} + +/** SSE event from backend (AI SDK stream event shape) */ +interface StreamEvent { + type: string + delta?: string + id?: string + messageId?: string + toolCallId?: string + toolName?: string + input?: unknown + inputTextDelta?: string + output?: unknown + errorText?: string + messageMetadata?: { finishReason?: string; usage?: unknown } +} + +export interface StreamChatOptions { + apiUrl: string + messages: ChatMessage[] + onAssistantMessage: (update: (draft: ChatMessage) => void) => void + onFinish: () => void + onError: (error: Error) => void + signal?: AbortSignal +} + +function generateId(): string { + return `msg-${Date.now()}-${Math.random().toString(36).slice(2, 9)}` +} + +/** + * POST to apiUrl with messages, then parse SSE stream and call onAssistantMessage + * for each event that updates the assistant reply. Calls onFinish when stream ends + * or onError on failure. Respects signal for abort. + */ +export async function streamChat(options: StreamChatOptions): Promise { + const { apiUrl, messages, onAssistantMessage, onFinish, onError, signal } = + options + + const body: BackendChatRequest = { + messages: messages.map(toBackendMessage), + } + + let response: Response + try { + response = await fetch(apiUrl, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(body), + signal, + }) + } catch (err) { + if ((err as Error).name === "AbortError") { + onFinish() + return + } + onError(err instanceof Error ? err : new Error(String(err))) + return + } + + if (!response.ok) { + const text = await response.text() + onError( + new Error( + `Chat API error ${response.status}: ${text || response.statusText}`, + ), + ) + return + } + + const reader = response.body?.getReader() + if (!reader) { + onError(new Error("No response body")) + return + } + + const decoder = new TextDecoder() + let buffer = "" + + type PartSlot = + | { kind: "text"; id: string } + | { kind: "reasoning"; id: string } + | { kind: "tool"; id: string } + const partOrder: PartSlot[] = [] + const textBlocks = new Map() + const reasoningBlocks = new Map() + const toolMap = new Map< + string, + { + type: `tool-${string}` + toolCallId: string + toolName?: string + input?: unknown + output?: unknown + } + >() + const toolInputBuffer = new Map() + let currentTextId: string | null = null + let currentReasoningId: string | null = null + let slotIdCounter = 0 + function nextSlotId(): string { + slotIdCounter += 1 + return `slot-${slotIdCounter}` + } + + function flushAssistant() { + onAssistantMessage((draft) => { + const next: ChatMessagePart[] = [] + for (const slot of partOrder) { + if (slot.kind === "text") { + const text = textBlocks.get(slot.id) + if (text) next.push({ type: "text", text }) + } else if (slot.kind === "reasoning") { + const reasoning = reasoningBlocks.get(slot.id) + if (reasoning) next.push({ type: "reasoning", reasoning }) + } else { + const tool = toolMap.get(slot.id) + if (tool) next.push(tool) + } + } + draft.parts = next + }) + } + + try { + while (true) { + const { done, value } = await reader.read() + if (done) break + buffer += decoder.decode(value, { stream: true }) + const lines = buffer.split("\n") + buffer = lines.pop() ?? "" + for (const line of lines) { + if (line.startsWith("data: ")) { + const payload = line.slice(6).trim() + if (payload === "[DONE]" || payload === "") continue + let event: StreamEvent + try { + event = JSON.parse(payload) as StreamEvent + } catch { + continue + } + const typ = event.type + if ( + typ === "text-start" || + (typ === "text-delta" && currentTextId === null) + ) { + if (typ === "text-start" && currentTextId !== null) { + currentTextId = null + } + if (currentTextId === null) { + const id = nextSlotId() + partOrder.push({ kind: "text", id }) + currentTextId = id + textBlocks.set(id, "") + } + } + if (typ === "text-delta" && event.delta != null) { + if (currentTextId === null) { + const id = nextSlotId() + partOrder.push({ kind: "text", id }) + currentTextId = id + textBlocks.set(id, "") + } + textBlocks.set( + currentTextId, + (textBlocks.get(currentTextId) ?? "") + event.delta, + ) + flushAssistant() + } else if (typ === "text-end") { + currentTextId = null + } else if ( + typ === "reasoning-start" || + (typ === "reasoning-delta" && currentReasoningId === null) + ) { + if (typ === "reasoning-start" && currentReasoningId !== null) { + currentReasoningId = null + } + if (currentReasoningId === null) { + const id = nextSlotId() + partOrder.push({ kind: "reasoning", id }) + currentReasoningId = id + reasoningBlocks.set(id, "") + } + } + if (typ === "reasoning-delta" && event.delta != null) { + if (currentReasoningId === null) { + const id = nextSlotId() + partOrder.push({ kind: "reasoning", id }) + currentReasoningId = id + reasoningBlocks.set(id, "") + } + reasoningBlocks.set( + currentReasoningId, + (reasoningBlocks.get(currentReasoningId) ?? "") + event.delta, + ) + flushAssistant() + } else if (typ === "reasoning-end") { + currentReasoningId = null + } else if (typ === "tool-input-start" && event.toolCallId) { + const key = event.toolCallId + if (!toolMap.has(key)) { + partOrder.push({ kind: "tool", id: key }) + toolMap.set(key, { + type: `tool-${event.toolName ?? "unknown"}`, + toolCallId: event.toolCallId, + toolName: event.toolName, + }) + } + flushAssistant() + } else if ( + typ === "tool-input-delta" && + event.toolCallId && + event.inputTextDelta != null + ) { + const key = event.toolCallId + const prev = toolInputBuffer.get(key) ?? "" + toolInputBuffer.set(key, prev + event.inputTextDelta) + let entry = toolMap.get(key) + if (!entry) { + partOrder.push({ kind: "tool", id: key }) + entry = { + type: `tool-${event.toolName ?? "unknown"}`, + toolCallId: event.toolCallId, + toolName: event.toolName, + } + toolMap.set(key, entry) + } + try { + entry.input = JSON.parse( + toolInputBuffer.get(key) ?? "{}", + ) as unknown + } catch { + entry.input = toolInputBuffer.get(key) + } + flushAssistant() + } else if (typ === "tool-input-available" && event.toolCallId) { + const key = event.toolCallId + let entry = toolMap.get(key) + if (!entry) { + partOrder.push({ kind: "tool", id: key }) + entry = { + type: `tool-${event.toolName ?? "unknown"}`, + toolCallId: event.toolCallId, + toolName: event.toolName, + input: event.input, + } + toolMap.set(key, entry) + } else { + entry.input = event.input + } + toolInputBuffer.delete(key) + flushAssistant() + } else if (typ === "tool-output-available" && event.toolCallId) { + const entry = toolMap.get(event.toolCallId) + if (entry) { + entry.output = event.output + flushAssistant() + } + } else if (typ === "tool-output-error" && event.toolCallId) { + const entry = toolMap.get(event.toolCallId) + if (entry) { + entry.output = { error: event.errorText } + flushAssistant() + } + } else if (typ === "finish" || typ === "finish-step") { + break + } + } + } + } + onFinish() + } catch (err) { + if ((err as Error).name === "AbortError") { + onFinish() + return + } + onError(err instanceof Error ? err : new Error(String(err))) + } +} + +export { generateId as chatGenerateId } diff --git a/app/web_ui/src/lib/ui/icons/arrow_up_icon.svelte b/app/web_ui/src/lib/ui/icons/arrow_up_icon.svelte new file mode 100644 index 000000000..a8a9d1153 --- /dev/null +++ b/app/web_ui/src/lib/ui/icons/arrow_up_icon.svelte @@ -0,0 +1,15 @@ + diff --git a/app/web_ui/src/lib/ui/icons/stop_icon.svelte b/app/web_ui/src/lib/ui/icons/stop_icon.svelte new file mode 100644 index 000000000..77bab00a8 --- /dev/null +++ b/app/web_ui/src/lib/ui/icons/stop_icon.svelte @@ -0,0 +1,13 @@ + diff --git a/app/web_ui/src/routes/(app)/+layout.svelte b/app/web_ui/src/routes/(app)/+layout.svelte index 661516a6d..5bb004de7 100644 --- a/app/web_ui/src/routes/(app)/+layout.svelte +++ b/app/web_ui/src/routes/(app)/+layout.svelte @@ -30,6 +30,7 @@ Specs, Generate, Run, + Chat, FineTune, Models, Optimize, @@ -67,6 +68,8 @@ section = Section.Specs } else if (path_start("/optimize", $page.url.pathname)) { section = Section.Optimize + } else if (path_start("/chat", $page.url.pathname)) { + section = Section.Chat } else { section = Section.None } @@ -174,6 +177,25 @@ Run +