Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions src/llama_stack/core/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from collections.abc import Coroutine
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any

from opentelemetry import context as otel_context

from llama_stack.core.request_headers import PROVIDER_DATA_VAR


@dataclass
class RequestContext:
    """Snapshot of request-scoped state for propagation through background queues.

    Background workers are long-lived asyncio tasks whose contextvars are frozen
    at creation time. Capturing both the OTel trace context and the provider /
    auth data at *enqueue* time and re-activating them per work-item ensures:

    * Each DB write is attributed to the correct request trace (OTel).
    * Each DB write is stamped with the correct user identity (PROVIDER_DATA_VAR).
    """

    # OTel context captured via otel_context.get_current() at enqueue time.
    otel_ctx: otel_context.Context
    # Value of PROVIDER_DATA_VAR at enqueue time (provider/auth data; may be None).
    provider_data: Any


def capture_request_context() -> RequestContext:
    """Capture the caller's request-scoped state as a :class:`RequestContext`.

    Bundles the currently active OTel context together with the current
    ``PROVIDER_DATA_VAR`` value so a background worker can re-activate both
    later via :func:`activate_request_context`.
    """
    otel_snapshot = otel_context.get_current()
    provider_snapshot = PROVIDER_DATA_VAR.get()
    return RequestContext(otel_ctx=otel_snapshot, provider_data=provider_snapshot)


@contextmanager
def activate_request_context(ctx: RequestContext):
    """Re-activate a previously captured request context for the enclosed block.

    Intended for worker loops running under a detached (empty) context: while
    the block executes, OTel spans and ``PROVIDER_DATA_VAR`` resolve to the
    values captured at enqueue time. Both are restored on exit — in reverse
    order of activation — even if the block raises.
    """
    token_otel = otel_context.attach(ctx.otel_ctx)
    token_provider = PROVIDER_DATA_VAR.set(ctx.provider_data)
    try:
        yield
    finally:
        # Undo in LIFO order relative to the attach/set calls above.
        PROVIDER_DATA_VAR.reset(token_provider)
        otel_context.detach(token_otel)


def create_detached_background_task(coro: Coroutine[Any, Any, Any]) -> asyncio.Task[Any]:
    """Spawn *coro* as a task that does not inherit request-scoped context.

    ``asyncio.create_task`` snapshots the caller's contextvars at creation
    time, so a long-lived worker spawned during a request would otherwise
    permanently carry that request's OTel trace and auth identity. We blank
    both variables just long enough for the task's context snapshot to be
    taken, then restore the caller's state so the caller is unaffected.
    """
    blank_otel_token = otel_context.attach(otel_context.Context())
    blank_provider_token = PROVIDER_DATA_VAR.set(None)
    try:
        # The new task copies the (now empty) context here.
        return asyncio.create_task(coro)
    finally:
        PROVIDER_DATA_VAR.reset(blank_provider_token)
        otel_context.detach(blank_otel_token)
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,16 @@
import time
import uuid
from collections.abc import AsyncIterator
from dataclasses import dataclass, field

from pydantic import BaseModel, TypeAdapter

from llama_stack.core.task import (
RequestContext,
activate_request_context,
capture_request_context,
create_detached_background_task,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.responses.responses_store import (
ResponsesStore,
Expand Down Expand Up @@ -80,6 +87,14 @@
BACKGROUND_NUM_WORKERS = 10


@dataclass
class _BackgroundWorkItem:
    """Typed queue item that pairs business kwargs with the originating request context."""

    # Context captured at enqueue time; the worker re-activates it per item so
    # trace attribution and user identity match the originating request.
    request_context: RequestContext
    # Keyword arguments forwarded to the background response processing loop.
    kwargs: dict = field(default_factory=dict)

class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
input_items: ListOpenAIResponseInputItem
response: OpenAIResponseObject
Expand Down Expand Up @@ -131,7 +146,7 @@ async def initialize(self) -> None:
async def _ensure_workers_started(self) -> None:
"""Start background workers in the current event loop if not already running."""
for _ in range(BACKGROUND_NUM_WORKERS - len(self._background_worker_tasks)):
task = asyncio.create_task(self._background_worker())
task = create_detached_background_task(self._background_worker())
self._background_worker_tasks.add(task)
task.add_done_callback(self._background_worker_tasks.discard)

Expand All @@ -144,48 +159,49 @@ async def shutdown(self) -> None:
async def _background_worker(self) -> None:
"""Worker coroutine that pulls items from the queue and processes them."""
while True:
kwargs = await self._background_queue.get()
try:
await asyncio.wait_for(
self._run_background_response_loop(**kwargs),
timeout=BACKGROUND_RESPONSE_TIMEOUT_SECONDS,
)
except TimeoutError:
response_id = kwargs["response_id"]
logger.exception(
f"Background response {response_id} timed out after {BACKGROUND_RESPONSE_TIMEOUT_SECONDS}s"
)
try:
existing = await self.responses_store.get_response_object(response_id)
existing.status = "failed"
existing.error = OpenAIResponseError(
code="processing_error",
message=f"Background response timed out after {BACKGROUND_RESPONSE_TIMEOUT_SECONDS}s",
)
await self.responses_store.update_response_object(existing)
except Exception:
logger.exception(
f"Failed to update response {response_id} with timeout status. "
"Client polling this response will not see the failure."
)
except Exception as e:
response_id = kwargs["response_id"]
logger.exception(f"Error processing background response {response_id}")
item = await self._background_queue.get()
with activate_request_context(item.request_context):
try:
existing = await self.responses_store.get_response_object(response_id)
existing.status = "failed"
existing.error = OpenAIResponseError(
code="processing_error",
message=str(e),
await asyncio.wait_for(
self._run_background_response_loop(**item.kwargs),
timeout=BACKGROUND_RESPONSE_TIMEOUT_SECONDS,
)
await self.responses_store.update_response_object(existing)
except Exception:
except TimeoutError:
response_id = item.kwargs["response_id"]
logger.exception(
f"Failed to update response {response_id} with error status. "
"Client polling this response will not see the failure."
f"Background response {response_id} timed out after {BACKGROUND_RESPONSE_TIMEOUT_SECONDS}s"
)
finally:
self._background_queue.task_done()
try:
existing = await self.responses_store.get_response_object(response_id)
existing.status = "failed"
existing.error = OpenAIResponseError(
code="processing_error",
message=f"Background response timed out after {BACKGROUND_RESPONSE_TIMEOUT_SECONDS}s",
)
await self.responses_store.update_response_object(existing)
except Exception:
logger.exception(
f"Failed to update response {response_id} with timeout status. "
"Client polling this response will not see the failure."
)
except Exception as e:
response_id = item.kwargs["response_id"]
logger.exception(f"Error processing background response {response_id}")
try:
existing = await self.responses_store.get_response_object(response_id)
existing.status = "failed"
existing.error = OpenAIResponseError(
code="processing_error",
message=str(e),
)
await self.responses_store.update_response_object(existing)
except Exception:
logger.exception(
f"Failed to update response {response_id} with error status. "
"Client polling this response will not see the failure."
)
finally:
self._background_queue.task_done()

async def _prepend_previous_response(
self,
Expand Down Expand Up @@ -812,33 +828,36 @@ async def _create_background_response(
# Enqueue work item for background workers. Raises QueueFull if at capacity.
try:
self._background_queue.put_nowait(
dict(
response_id=response_id,
input=input,
model=model,
prompt=prompt,
instructions=instructions,
previous_response_id=previous_response_id,
conversation=conversation,
store=store,
temperature=temperature,
frequency_penalty=frequency_penalty,
text=text,
tool_choice=tool_choice,
tools=tools,
include=include,
max_infer_iters=max_infer_iters,
guardrail_ids=guardrail_ids,
parallel_tool_calls=parallel_tool_calls,
max_tool_calls=max_tool_calls,
reasoning=reasoning,
max_output_tokens=max_output_tokens,
safety_identifier=safety_identifier,
service_tier=service_tier,
metadata=metadata,
truncation=truncation,
presence_penalty=presence_penalty,
extra_body=extra_body,
_BackgroundWorkItem(
request_context=capture_request_context(),
kwargs=dict(
response_id=response_id,
input=input,
model=model,
prompt=prompt,
instructions=instructions,
previous_response_id=previous_response_id,
conversation=conversation,
store=store,
temperature=temperature,
frequency_penalty=frequency_penalty,
text=text,
tool_choice=tool_choice,
tools=tools,
include=include,
max_infer_iters=max_infer_iters,
guardrail_ids=guardrail_ids,
parallel_tool_calls=parallel_tool_calls,
max_tool_calls=max_tool_calls,
reasoning=reasoning,
max_output_tokens=max_output_tokens,
safety_identifier=safety_identifier,
service_tier=service_tier,
metadata=metadata,
truncation=truncation,
presence_penalty=presence_penalty,
extra_body=extra_body,
),
)
)
except asyncio.QueueFull:
Expand Down
28 changes: 20 additions & 8 deletions src/llama_stack/providers/utils/inference/inference_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,20 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any
from typing import Any, NamedTuple

from sqlalchemy.exc import IntegrityError

from llama_stack.core.datatypes import AccessRule
from llama_stack.core.storage.datatypes import InferenceStoreReference, StorageBackendType
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.core.storage.sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl
from llama_stack.core.task import (
RequestContext,
activate_request_context,
capture_request_context,
create_detached_background_task,
)
from llama_stack.log import get_logger
from llama_stack_api import (
ListOpenAIChatCompletionResponse,
Expand All @@ -25,6 +31,12 @@
logger = get_logger(name=__name__, category="inference")


class _WriteItem(NamedTuple):
    """Queue entry pairing a chat-completion write with its originating request context."""

    completion: OpenAIChatCompletion
    messages: list[OpenAIMessageParam]
    # Captured at enqueue time so the worker can attribute the DB write to the
    # originating request (OTel trace + provider/auth identity).
    request_context: RequestContext

class InferenceStore:
def __init__(
self,
Expand All @@ -37,7 +49,7 @@ def __init__(
self.enable_write_queue = True

# Async write queue and worker control
self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
self._queue: asyncio.Queue[_WriteItem] | None = None
self._worker_tasks: list[asyncio.Task[Any]] = []
self._max_write_queue_size: int = reference.max_write_queue_size
self._num_writers: int = max(1, reference.num_writers)
Expand Down Expand Up @@ -98,9 +110,8 @@ async def _ensure_workers_started(self) -> None:
)

if not self._worker_tasks:
loop = asyncio.get_running_loop()
for _ in range(self._num_writers):
task = loop.create_task(self._worker_loop())
task = create_detached_background_task(self._worker_loop())
self._worker_tasks.append(task)

async def store_chat_completion(
Expand All @@ -110,13 +121,14 @@ async def store_chat_completion(
await self._ensure_workers_started()
if self._queue is None:
raise ValueError("Inference store is not initialized")
item = _WriteItem(chat_completion, input_messages, capture_request_context())
try:
self._queue.put_nowait((chat_completion, input_messages))
self._queue.put_nowait(item)
except asyncio.QueueFull:
logger.warning(
f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '<unknown>')}"
)
await self._queue.put((chat_completion, input_messages))
await self._queue.put(item)
else:
await self._write_chat_completion(chat_completion, input_messages)

Expand All @@ -127,9 +139,9 @@ async def _worker_loop(self) -> None:
item = await self._queue.get()
except asyncio.CancelledError:
break
chat_completion, input_messages = item
try:
await self._write_chat_completion(chat_completion, input_messages)
with activate_request_context(item.request_context):
await self._write_chat_completion(item.completion, item.messages)
except Exception as e: # noqa: BLE001
logger.error(f"Error writing chat completion: {e}")
finally:
Expand Down
Loading
Loading