Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
run: uv run python3 -m pytest --runslow .

- name: Check Python Types
run: uvx ty check
run: uv run ty check

- name: Build Core
run: uv build
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/format_and_lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ jobs:

- name: Lint with ruff
run: |
uvx ruff check
uv run ruff check

- name: Format with ruff
run: |
uvx ruff format --check .
uv run ruff format --check .

- name: Typecheck with ty
run: |
uvx ty check
uv run ty check
12 changes: 6 additions & 6 deletions checks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ echo $PWD
headerStart="\n\033[4;34m=== "
headerEnd=" ===\033[0m\n"

echo "${headerStart}Checking Python: uvx ruff check ${headerEnd}"
uvx ruff check
echo "${headerStart}Checking Python: uv run ruff check ${headerEnd}"
uv run ruff check

echo "${headerStart}Checking Python: uvx ruff format --check ${headerEnd}"
uvx ruff format --check .
echo "${headerStart}Checking Python: uv run ruff format --check ${headerEnd}"
uv run ruff format --check .

echo "${headerStart}Checking Python Types: uvx ty check${headerEnd}"
uvx ty check
echo "${headerStart}Checking Python Types: uv run ty check${headerEnd}"
uv run ty check

echo "${headerStart}Checking for Misspellings${headerEnd}"
if command -v misspell >/dev/null 2>&1; then
Expand Down
10 changes: 5 additions & 5 deletions hooks_mcp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,23 @@ actions:

- name: "lint_python"
description: "Lint the python source code, checking for errors and warnings"
command: "uvx ruff check"
command: "uv run ruff check"

- name: "lint_fix_python"
description: "Lint the python source code, fixing errors and warnings which it can fix. Not all errors can be fixed automatically."
command: "uvx ruff check --fix"
command: "uv run ruff check --fix"

- name: "check_format_python"
description: "Check if the python source code is formatted correctly"
command: "uvx ruff format --check ."
command: "uv run ruff format --check ."

- name: "format_python"
description: "Format the python source code"
command: "uvx ruff format ."
command: "uv run ruff format ."

- name: "typecheck_python"
description: "Typecheck the source code"
command: "uvx ty check"
command: "uv run ty check"

- name: "test_file_python"
description: "Run tests in a specific python file or directory"
Expand Down
Empty file.
60 changes: 60 additions & 0 deletions libs/core/kiln_ai/adapters/litellm_utils/litellm_streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

from typing import Any, AsyncIterator, Optional, Union

import litellm
from litellm.types.utils import (
ModelResponse,
ModelResponseStream,
TextCompletionResponse,
)


class StreamingCompletion:
    """
    Async iterable wrapper around ``litellm.acompletion`` with streaming.

    Yields ``ModelResponseStream`` chunks as they arrive. After iteration
    finishes — including when it is interrupted by an exception in the
    consumer's loop body or by abandoning the generator — the assembled
    ``ModelResponse`` is available via the ``.response`` property (``None``
    if no chunks were received).

    Usage::

        stream = StreamingCompletion(model=..., messages=...)
        async for chunk in stream:
            # handle chunk however you like (print, log, send over WS, …)
            pass
        final = stream.response  # fully assembled ModelResponse
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Copy so we never mutate the caller's kwargs dict, and drop any
        # caller-supplied "stream" flag: this wrapper always streams.
        kwargs = dict(kwargs)
        kwargs.pop("stream", None)
        self._args = args
        self._kwargs = kwargs
        self._response: Optional[Union[ModelResponse, TextCompletionResponse]] = None
        self._iterated: bool = False

    @property
    def response(self) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
        """The final assembled response. Only available after iteration."""
        if not self._iterated:
            raise RuntimeError(
                "StreamingCompletion has not been iterated yet. "
                "Use 'async for chunk in stream:' before accessing .response"
            )
        return self._response

    async def __aiter__(self) -> AsyncIterator[ModelResponseStream]:
        # Reset state so the same instance can be iterated more than once.
        self._response = None
        self._iterated = False

        chunks: list[ModelResponseStream] = []
        try:
            stream = await litellm.acompletion(*self._args, stream=True, **self._kwargs)
            async for chunk in stream:
                chunks.append(chunk)
                yield chunk
        finally:
            # Runs even when the consumer's loop body raises or the generator
            # is abandoned (GeneratorExit at the yield point): finalize state
            # so .response is accessible with whatever chunks arrived.
            # NOTE(review): skipping the builder for an empty chunk list
            # mirrors litellm.stream_chunk_builder's own None-for-empty result.
            self._response = litellm.stream_chunk_builder(chunks) if chunks else None
            self._iterated = True
Comment on lines +48 to +60
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Missing try/finally leaves response inaccessible after any interrupted iteration.

If the consumer's async for body raises (e.g., an on_chunk callback throws) or if the streaming call itself fails, Python sends GeneratorExit into the generator at the yield point. The two lines after the loop never execute, so _iterated stays False and stream.response will always raise RuntimeError — callers cannot distinguish "not yet started" from "stream failed".

Additionally, the litellm stream object (a CustomStreamWrapper) won't have .aclose() called implicitly when the generator is abandoned without exhaustion.

🐛 Proposed fix — `try/finally` for guaranteed state finalization
     async def __aiter__(self) -> AsyncIterator[ModelResponseStream]:
         self._response = None
         self._iterated = False

         chunks: list[ModelResponseStream] = []
-        stream = await litellm.acompletion(*self._args, stream=True, **self._kwargs)
-
-        async for chunk in stream:
-            chunks.append(chunk)
-            yield chunk
-
-        self._response = litellm.stream_chunk_builder(chunks)
-        self._iterated = True
+        try:
+            stream = await litellm.acompletion(*self._args, stream=True, **self._kwargs)
+            async for chunk in stream:
+                chunks.append(chunk)
+                yield chunk
+        finally:
+            self._response = litellm.stream_chunk_builder(chunks) if chunks else None
+            self._iterated = True
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@libs/core/kiln_ai/adapters/litellm_utils/litellm_streaming.py` around lines
48 - 60, The async iterator __aiter__ in litellm_streaming.py can be aborted
leaving _iterated False and _response unset and never closing the underlying
litellm stream; wrap the streaming logic in a try/finally: create the stream as
before, iterate and yield chunks inside try, and in finally always call
stream.aclose() if stream exists, set self._response =
litellm.stream_chunk_builder(chunks) (even if empty) and self._iterated = True
so stream.response works after interruption; ensure any exceptions are re-raised
after finalization so behavior is preserved.

139 changes: 139 additions & 0 deletions libs/core/kiln_ai/adapters/litellm_utils/test_litellm_streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from __future__ import annotations

from types import SimpleNamespace
from typing import Any, List
from unittest.mock import MagicMock, patch

import pytest

from kiln_ai.adapters.litellm_utils.litellm_streaming import StreamingCompletion


def _make_chunk(content: str | None = None, finish_reason: str | None = None) -> Any:
"""Build a minimal chunk object matching litellm's streaming shape."""
delta = SimpleNamespace(content=content, role="assistant")
choice = SimpleNamespace(delta=delta, finish_reason=finish_reason, index=0)
return SimpleNamespace(choices=[choice], id="chatcmpl-test", model="test-model")


async def _async_iter(items: List[Any]):
"""Turn a plain list into an async iterator."""
for item in items:
yield item


@pytest.fixture
def mock_acompletion():
    """Patch ``litellm.acompletion`` for the duration of a single test."""
    patcher = patch("litellm.acompletion")
    mocked = patcher.start()
    try:
        yield mocked
    finally:
        patcher.stop()


@pytest.fixture
def mock_chunk_builder():
    """Patch ``litellm.stream_chunk_builder`` for the duration of a single test."""
    patcher = patch("litellm.stream_chunk_builder")
    mocked = patcher.start()
    try:
        yield mocked
    finally:
        patcher.stop()


class TestStreamingCompletion:
    """Behavioral tests for the StreamingCompletion async wrapper."""

    @staticmethod
    async def _drain(wrapper) -> None:
        """Exhaust *wrapper*, discarding every yielded chunk."""
        async for _ in wrapper:
            pass

    async def test_yields_all_chunks(self, mock_acompletion, mock_chunk_builder):
        expected = [_make_chunk("Hello"), _make_chunk(" world"), _make_chunk("!")]
        mock_acompletion.return_value = _async_iter(expected)
        mock_chunk_builder.return_value = MagicMock(name="final_response")

        wrapper = StreamingCompletion(model="test", messages=[])
        collected = []
        async for item in wrapper:
            collected.append(item)

        assert collected == expected

    async def test_response_available_after_iteration(
        self, mock_acompletion, mock_chunk_builder
    ):
        mock_acompletion.return_value = _async_iter([_make_chunk("hi")])
        final = MagicMock(name="final_response")
        mock_chunk_builder.return_value = final

        wrapper = StreamingCompletion(model="test", messages=[])
        await self._drain(wrapper)

        assert wrapper.response is final

    async def test_response_raises_before_iteration(self):
        wrapper = StreamingCompletion(model="test", messages=[])
        with pytest.raises(RuntimeError, match="not been iterated"):
            _ = wrapper.response

    async def test_stream_kwarg_is_stripped(self, mock_acompletion, mock_chunk_builder):
        mock_acompletion.return_value = _async_iter([])
        mock_chunk_builder.return_value = None

        # Even when the caller explicitly passes stream=False, the wrapper
        # must force streaming on the underlying call.
        wrapper = StreamingCompletion(model="test", messages=[], stream=False)
        await self._drain(wrapper)

        _, forwarded = mock_acompletion.call_args
        assert forwarded["stream"] is True

    async def test_passes_args_and_kwargs_through(
        self, mock_acompletion, mock_chunk_builder
    ):
        mock_acompletion.return_value = _async_iter([])
        mock_chunk_builder.return_value = None

        wrapper = StreamingCompletion(
            model="gpt-4", messages=[{"role": "user", "content": "hi"}], temperature=0.5
        )
        await self._drain(wrapper)

        _, forwarded = mock_acompletion.call_args
        assert forwarded["model"] == "gpt-4"
        assert forwarded["messages"] == [{"role": "user", "content": "hi"}]
        assert forwarded["temperature"] == 0.5
        assert forwarded["stream"] is True

    async def test_chunks_passed_to_builder(self, mock_acompletion, mock_chunk_builder):
        streamed = [_make_chunk("a"), _make_chunk("b")]
        mock_acompletion.return_value = _async_iter(streamed)
        mock_chunk_builder.return_value = MagicMock()

        wrapper = StreamingCompletion(model="test", messages=[])
        await self._drain(wrapper)

        mock_chunk_builder.assert_called_once_with(streamed)

    async def test_re_iteration_resets_state(
        self, mock_acompletion, mock_chunk_builder
    ):
        response_a = MagicMock(name="first_response")
        response_b = MagicMock(name="second_response")

        mock_acompletion.side_effect = [
            _async_iter([_make_chunk("first")]),
            _async_iter([_make_chunk("second")]),
        ]
        mock_chunk_builder.side_effect = [response_a, response_b]

        wrapper = StreamingCompletion(model="test", messages=[])

        await self._drain(wrapper)
        assert wrapper.response is response_a

        # A second pass must rebuild the response from the new stream.
        await self._drain(wrapper)
        assert wrapper.response is response_b

    async def test_empty_stream(self, mock_acompletion, mock_chunk_builder):
        mock_acompletion.return_value = _async_iter([])
        mock_chunk_builder.return_value = None

        wrapper = StreamingCompletion(model="test", messages=[])
        collected = [item async for item in wrapper]

        assert collected == []
        assert wrapper.response is None
Loading
Loading