50 changes: 50 additions & 0 deletions src/llama_stack/core/task.py
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from collections.abc import Coroutine
from contextlib import contextmanager
from typing import Any

from opentelemetry import context as otel_context


def create_task_with_detached_otel_context(coro: Coroutine[Any, Any, Any]) -> asyncio.Task[Any]:
"""Create an asyncio task that does not inherit the current OpenTelemetry trace context.

asyncio.create_task copies all contextvars at creation time, which causes
fire-and-forget or long-lived background tasks to be attributed to whatever
request happened to spawn them. This inflates trace durations and bundles
unrelated DB operations under the wrong trace.

This helper temporarily clears the OTel context before creating the task,
then immediately restores it so the calling coroutine is unaffected.
"""
token = otel_context.attach(otel_context.Context())
try:
task = asyncio.create_task(coro)
finally:
otel_context.detach(token)
return task


def capture_otel_context() -> otel_context.Context:
"""Snapshot the current OTel context for later use in a different task."""
return otel_context.get_current()


@contextmanager
def activate_otel_context(ctx: otel_context.Context):
"""Temporarily activate a previously captured OTel context.

Use this in worker loops that run with a detached (empty) context to
attribute work back to the originating request's trace.
"""
token = otel_context.attach(ctx)
try:
yield
finally:
otel_context.detach(token)
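
The docstring above leans on a subtle detail: asyncio.create_task snapshots all contextvars (including the OTel context slot) at creation time. The following standalone demonstration uses a plain contextvar, no OTel required; request_id and the coroutine names are illustrative only, a sketch rather than code from this PR:

import asyncio
import contextvars

request_id: contextvars.ContextVar[str] = contextvars.ContextVar("request_id", default="<none>")


async def background() -> None:
    # Prints "req-42": the task inherited the spawning request's context,
    # which is exactly how long-lived background work gets misattributed.
    print(request_id.get())


async def main() -> None:
    request_id.set("req-42")
    await asyncio.create_task(background())


asyncio.run(main())

Attaching an empty otel_context.Context() before create_task, as the helper does, breaks this inheritance for the OTel slot specifically while leaving every other contextvar intact.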
@@ -9,10 +9,13 @@
import time
import uuid
from collections.abc import AsyncIterator
from dataclasses import dataclass, field

from opentelemetry import context as otel_context
from pydantic import BaseModel, TypeAdapter

from llama_stack.core.conversations.validation import CONVERSATION_ID_PATTERN
from llama_stack.core.task import activate_otel_context, capture_otel_context, create_task_with_detached_otel_context
from llama_stack.log import get_logger
from llama_stack.providers.utils.responses.responses_store import (
ResponsesStore,
@@ -82,6 +85,14 @@
BACKGROUND_NUM_WORKERS = 10


@dataclass
class _BackgroundWorkItem:
"""Typed queue item for background response processing."""

otel_context: otel_context.Context
kwargs: dict = field(default_factory=dict)


class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
input_items: ListOpenAIResponseInputItem
response: OpenAIResponseObject
Expand Down Expand Up @@ -118,7 +129,7 @@ def __init__(
self.prompts_api = prompts_api
self.files_api = files_api
self.connectors_api = connectors_api
        self._background_queue: asyncio.Queue[_BackgroundWorkItem] = asyncio.Queue(maxsize=BACKGROUND_QUEUE_MAX_SIZE)
self._background_worker_tasks: set[asyncio.Task] = set()

async def initialize(self) -> None:
@@ -133,7 +144,7 @@ async def initialize(self) -> None:
async def _ensure_workers_started(self) -> None:
"""Start background workers in the current event loop if not already running."""
for _ in range(BACKGROUND_NUM_WORKERS - len(self._background_worker_tasks)):
            task = create_task_with_detached_otel_context(self._background_worker())
self._background_worker_tasks.add(task)
task.add_done_callback(self._background_worker_tasks.discard)

@@ -146,48 +157,49 @@ async def shutdown(self) -> None:
async def _background_worker(self) -> None:
"""Worker coroutine that pulls items from the queue and processes them."""
while True:
            item = await self._background_queue.get()
            with activate_otel_context(item.otel_context):
                try:
                    await asyncio.wait_for(
                        self._run_background_response_loop(**item.kwargs),
                        timeout=BACKGROUND_RESPONSE_TIMEOUT_SECONDS,
                    )
                except TimeoutError:
                    response_id = item.kwargs["response_id"]
                    logger.exception(
                        f"Background response {response_id} timed out after {BACKGROUND_RESPONSE_TIMEOUT_SECONDS}s"
                    )
                    try:
                        existing = await self.responses_store.get_response_object(response_id)
                        existing.status = "failed"
                        existing.error = OpenAIResponseError(
                            code="processing_error",
                            message=f"Background response timed out after {BACKGROUND_RESPONSE_TIMEOUT_SECONDS}s",
                        )
                        await self.responses_store.update_response_object(existing)
                    except Exception:
                        logger.exception(
                            f"Failed to update response {response_id} with timeout status. "
                            "Client polling this response will not see the failure."
                        )
                except Exception as e:
                    response_id = item.kwargs["response_id"]
                    logger.exception(f"Error processing background response {response_id}")
                    try:
                        existing = await self.responses_store.get_response_object(response_id)
                        existing.status = "failed"
                        existing.error = OpenAIResponseError(
                            code="processing_error",
                            message=str(e),
                        )
                        await self.responses_store.update_response_object(existing)
                    except Exception:
                        logger.exception(
                            f"Failed to update response {response_id} with error status. "
                            "Client polling this response will not see the failure."
                        )
                finally:
                    self._background_queue.task_done()

async def _prepend_previous_response(
self,
Expand Down Expand Up @@ -820,33 +832,36 @@ async def _create_background_response(
# Enqueue work item for background workers. Raises QueueFull if at capacity.
try:
self._background_queue.put_nowait(
                _BackgroundWorkItem(
                    otel_context=capture_otel_context(),
                    kwargs=dict(
                        response_id=response_id,
                        input=input,
                        model=model,
                        prompt=prompt,
                        instructions=instructions,
                        previous_response_id=previous_response_id,
                        conversation=conversation,
                        store=store,
                        temperature=temperature,
                        frequency_penalty=frequency_penalty,
                        text=text,
                        tool_choice=tool_choice,
                        tools=tools,
                        include=include,
                        max_infer_iters=max_infer_iters,
                        guardrail_ids=guardrail_ids,
                        parallel_tool_calls=parallel_tool_calls,
                        max_tool_calls=max_tool_calls,
                        reasoning=reasoning,
                        max_output_tokens=max_output_tokens,
                        safety_identifier=safety_identifier,
                        service_tier=service_tier,
                        metadata=metadata,
                        truncation=truncation,
                        presence_penalty=presence_penalty,
                        extra_body=extra_body,
                    ),
                )
            )
except asyncio.QueueFull:
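
Taken together, the changes above follow a capture-at-enqueue / activate-at-dequeue shape: workers are created detached, and each queue item carries the context of the request that produced it. Below is a minimal self-contained sketch of that pattern, assuming the helpers from src/llama_stack/core/task.py; the queue, worker, and handle_request names are illustrative, not code from this PR:

import asyncio

from opentelemetry import trace

from llama_stack.core.task import (
    activate_otel_context,
    capture_otel_context,
    create_task_with_detached_otel_context,
)

tracer = trace.get_tracer(__name__)


async def worker(queue: asyncio.Queue) -> None:
    # Created detached, so the loop itself is attributed to no request.
    while True:
        ctx, _payload = await queue.get()
        with activate_otel_context(ctx):
            with tracer.start_as_current_span("process_item"):
                pass  # spans here join the originating request's trace
        queue.task_done()


async def handle_request(queue: asyncio.Queue) -> None:
    with tracer.start_as_current_span("request"):
        # Snapshot this request's context alongside the work item.
        await queue.put((capture_otel_context(), {"n": 1}))


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    worker_task = create_task_with_detached_otel_context(worker(queue))
    await handle_request(queue)
    await queue.join()
    worker_task.cancel()


asyncio.run(main())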
24 changes: 16 additions & 8 deletions src/llama_stack/providers/utils/inference/inference_store.py
@@ -4,14 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any, NamedTuple

from opentelemetry import context as otel_context
from sqlalchemy.exc import IntegrityError

from llama_stack.core.datatypes import AccessRule
from llama_stack.core.storage.datatypes import InferenceStoreReference, StorageBackendType
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.core.storage.sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl
from llama_stack.core.task import activate_otel_context, capture_otel_context, create_task_with_detached_otel_context
from llama_stack.log import get_logger
from llama_stack_api import (
ListOpenAIChatCompletionResponse,
@@ -25,6 +27,12 @@
logger = get_logger(name=__name__, category="inference")


class _WriteItem(NamedTuple):
completion: OpenAIChatCompletion
messages: list[OpenAIMessageParam]
otel_context: otel_context.Context


class InferenceStore:
def __init__(
self,
@@ -37,7 +45,7 @@ def __init__(
self.enable_write_queue = True

# Async write queue and worker control
        self._queue: asyncio.Queue[_WriteItem] | None = None
self._worker_tasks: list[asyncio.Task[Any]] = []
self._max_write_queue_size: int = reference.max_write_queue_size
self._num_writers: int = max(1, reference.num_writers)
@@ -98,9 +106,8 @@ async def _ensure_workers_started(self) -> None:
)

if not self._worker_tasks:
            for _ in range(self._num_writers):
                task = create_task_with_detached_otel_context(self._worker_loop())
self._worker_tasks.append(task)

async def store_chat_completion(
Expand All @@ -110,13 +117,14 @@ async def store_chat_completion(
await self._ensure_workers_started()
if self._queue is None:
raise ValueError("Inference store is not initialized")
item = _WriteItem(chat_completion, input_messages, capture_otel_context())
try:
            self._queue.put_nowait(item)
except asyncio.QueueFull:
logger.warning(
f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '<unknown>')}"
)
            await self._queue.put(item)
else:
await self._write_chat_completion(chat_completion, input_messages)

Expand All @@ -127,9 +135,9 @@ async def _worker_loop(self) -> None:
item = await self._queue.get()
except asyncio.CancelledError:
break
            try:
                with activate_otel_context(item.otel_context):
                    await self._write_chat_completion(item.completion, item.messages)
except Exception as e: # noqa: BLE001
logger.error(f"Error writing chat completion: {e}")
finally: