Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions app/desktop/desktop_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from kiln_ai.utils.logging import setup_litellm_logging

from app.desktop.log_config import log_config
from app.desktop.studio_server.chat_api import connect_chat_api
from app.desktop.studio_server.copilot_api import connect_copilot_api
from app.desktop.studio_server.data_gen_api import connect_data_gen_api
from app.desktop.studio_server.dev_tools import connect_dev_tools
Expand Down Expand Up @@ -71,6 +72,7 @@ def make_app(tk_root: tk.Tk | None = None):
connect_import_api(app, tk_root=tk_root)
connect_tool_servers_api(app)
connect_prompt_optimization_job_api(app)
connect_chat_api(app)
connect_copilot_api(app)
connect_dev_tools(app)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Contains endpoint functions for accessing the API"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from http import HTTPStatus
from typing import Any

import httpx

from ... import errors
from ...client import AuthenticatedClient, Client
from ...models.chat_request import ChatRequest
from ...models.http_validation_error import HTTPValidationError
from ...types import Response


def _get_kwargs(
*,
body: ChatRequest,
) -> dict[str, Any]:
headers: dict[str, Any] = {}

_kwargs: dict[str, Any] = {
"method": "post",
"url": "/v1/chat/",
}

_kwargs["json"] = body.to_dict()

headers["Content-Type"] = "application/json"

_kwargs["headers"] = headers
return _kwargs


def _parse_response(
*, client: AuthenticatedClient | Client, response: httpx.Response
) -> Any | HTTPValidationError | None:
if response.status_code == 200:
response_200 = response.json()
return response_200

if response.status_code == 422:
response_422 = HTTPValidationError.from_dict(response.json())

return response_422

if client.raise_on_unexpected_status:
raise errors.UnexpectedStatus(response.status_code, response.content)
else:
return None


def _build_response(
*, client: AuthenticatedClient | Client, response: httpx.Response
) -> Response[Any | HTTPValidationError]:
return Response(
status_code=HTTPStatus(response.status_code),
content=response.content,
headers=response.headers,
parsed=_parse_response(client=client, response=response),
)


def sync_detailed(
*,
client: AuthenticatedClient,
body: ChatRequest,
) -> Response[Any | HTTPValidationError]:
"""Handle Chat

Args:
body (ChatRequest):

Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.

Returns:
Response[Any | HTTPValidationError]
"""

kwargs = _get_kwargs(
body=body,
)

response = client.get_httpx_client().request(
**kwargs,
)

return _build_response(client=client, response=response)


def sync(
*,
client: AuthenticatedClient,
body: ChatRequest,
) -> Any | HTTPValidationError | None:
"""Handle Chat

Args:
body (ChatRequest):

Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.

Returns:
Any | HTTPValidationError
"""

return sync_detailed(
client=client,
body=body,
).parsed


async def asyncio_detailed(
*,
client: AuthenticatedClient,
body: ChatRequest,
) -> Response[Any | HTTPValidationError]:
"""Handle Chat

Args:
body (ChatRequest):

Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.

Returns:
Response[Any | HTTPValidationError]
"""

kwargs = _get_kwargs(
body=body,
)

response = await client.get_async_httpx_client().request(**kwargs)

return _build_response(client=client, response=response)


async def asyncio(
*,
client: AuthenticatedClient,
body: ChatRequest,
) -> Any | HTTPValidationError | None:
"""Handle Chat

Args:
body (ChatRequest):

Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.

Returns:
Any | HTTPValidationError
"""

return (
await asyncio_detailed(
client=client,
body=body,
)
).parsed
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
BodyStartPromptOptimizationJobV1JobsPromptOptimizationJobStartPost,
)
from .body_start_sample_job_v1_jobs_sample_job_start_post import BodyStartSampleJobV1JobsSampleJobStartPost
from .chat_request import ChatRequest
from .check_entitlements_v1_check_entitlements_get_response_check_entitlements_v1_check_entitlements_get import (
CheckEntitlementsV1CheckEntitlementsGetResponseCheckEntitlementsV1CheckEntitlementsGet,
)
from .check_model_supported_response import CheckModelSupportedResponse
from .clarify_spec_input import ClarifySpecInput
from .clarify_spec_output import ClarifySpecOutput
from .client_message import ClientMessage
from .client_message_part import ClientMessagePart
from .examples_for_feedback_item import ExamplesForFeedbackItem
from .examples_with_feedback_item import ExamplesWithFeedbackItem
from .generate_batch_input import GenerateBatchInput
Expand Down Expand Up @@ -51,6 +54,8 @@
from .synthetic_data_generation_step_config_input import SyntheticDataGenerationStepConfigInput
from .task_info import TaskInfo
from .task_metadata import TaskMetadata
from .tool_invocation import ToolInvocation
from .tool_invocation_state import ToolInvocationState
from .validation_error import ValidationError

__all__ = (
Expand All @@ -59,10 +64,13 @@
"ApiKeyVerificationResult",
"BodyStartPromptOptimizationJobV1JobsPromptOptimizationJobStartPost",
"BodyStartSampleJobV1JobsSampleJobStartPost",
"ChatRequest",
"CheckEntitlementsV1CheckEntitlementsGetResponseCheckEntitlementsV1CheckEntitlementsGet",
"CheckModelSupportedResponse",
"ClarifySpecInput",
"ClarifySpecOutput",
"ClientMessage",
"ClientMessagePart",
"ExamplesForFeedbackItem",
"ExamplesWithFeedbackItem",
"GenerateBatchInput",
Expand Down Expand Up @@ -101,5 +109,7 @@
"SyntheticDataGenerationStepConfigInput",
"TaskInfo",
"TaskMetadata",
"ToolInvocation",
"ToolInvocationState",
"ValidationError",
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from __future__ import annotations

from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, TypeVar, cast

from attrs import define as _attrs_define
from attrs import field as _attrs_field

from ..types import UNSET, Unset

if TYPE_CHECKING:
from ..models.client_message import ClientMessage


T = TypeVar("T", bound="ChatRequest")


@_attrs_define
class ChatRequest:
"""
Attributes:
messages (list[ClientMessage]):
task_id (None | str | Unset):
"""

messages: list[ClientMessage]
task_id: None | str | Unset = UNSET
additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)

def to_dict(self) -> dict[str, Any]:
messages = []
for messages_item_data in self.messages:
messages_item = messages_item_data.to_dict()
messages.append(messages_item)

task_id: None | str | Unset
if isinstance(self.task_id, Unset):
task_id = UNSET
else:
task_id = self.task_id

field_dict: dict[str, Any] = {}
field_dict.update(self.additional_properties)
field_dict.update(
{
"messages": messages,
}
)
if task_id is not UNSET:
field_dict["task_id"] = task_id

return field_dict

@classmethod
def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T:
from ..models.client_message import ClientMessage

d = dict(src_dict)
messages = []
_messages = d.pop("messages")
for messages_item_data in _messages:
messages_item = ClientMessage.from_dict(messages_item_data)

messages.append(messages_item)

def _parse_task_id(data: object) -> None | str | Unset:
if data is None:
return data
if isinstance(data, Unset):
return data
return cast(None | str | Unset, data)

task_id = _parse_task_id(d.pop("task_id", UNSET))

chat_request = cls(
messages=messages,
task_id=task_id,
)

chat_request.additional_properties = d
return chat_request

@property
def additional_keys(self) -> list[str]:
return list(self.additional_properties.keys())

def __getitem__(self, key: str) -> Any:
return self.additional_properties[key]

def __setitem__(self, key: str, value: Any) -> None:
self.additional_properties[key] = value

def __delitem__(self, key: str) -> None:
del self.additional_properties[key]

def __contains__(self, key: str) -> bool:
return key in self.additional_properties
Loading
Loading