Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 45 additions & 22 deletions src/llama_stack/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,44 +7,67 @@
import asyncio
from collections.abc import Coroutine
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any

from opentelemetry import context as otel_context

from llama_stack.core.request_headers import PROVIDER_DATA_VAR

def create_task_with_detached_otel_context(coro: Coroutine[Any, Any, Any]) -> asyncio.Task[Any]:
"""Create an asyncio task that does not inherit the current OpenTelemetry trace context.

asyncio.create_task copies all contextvars at creation time, which causes
fire-and-forget or long-lived background tasks to be attributed to whatever
request happened to spawn them. This inflates trace durations and bundles
unrelated DB operations under the wrong trace.
@dataclass
class RequestContext:
    """Snapshot of request-scoped state for propagation through background queues.

    Background workers are long-lived asyncio tasks whose contextvars are frozen
    at creation time. Capturing both the OTel trace context and the provider /
    auth data at *enqueue* time and re-activating them per work-item ensures:

    * Each DB write is attributed to the correct request trace (OTel).
    * Each DB write is stamped with the correct user identity (PROVIDER_DATA_VAR).
    """

    # OTel context captured via otel_context.get_current() at enqueue time.
    otel_ctx: otel_context.Context
    # Value of PROVIDER_DATA_VAR at enqueue time; may be None when the
    # enqueuing code ran outside a request scope.
    provider_data: Any

def capture_otel_context() -> otel_context.Context:
    """Return a snapshot of the currently active OTel context.

    The snapshot can later be re-attached in a different task so that work
    performed there is attributed to the originating trace.
    """
    current = otel_context.get_current()
    return current

def capture_request_context() -> RequestContext:
    """Capture the caller's request-scoped state for later use in a worker.

    Call this at enqueue time; a background worker can then re-activate the
    same trace and provider/auth data while processing the item.
    """
    otel_snapshot = otel_context.get_current()
    provider_snapshot = PROVIDER_DATA_VAR.get()
    return RequestContext(otel_ctx=otel_snapshot, provider_data=provider_snapshot)


@contextmanager
def activate_request_context(ctx: RequestContext):
    """Temporarily restore a previously captured request context.

    Use this in worker loops that run with a detached (empty) context to
    attribute work back to the originating request.

    Args:
        ctx: Snapshot produced by capture_request_context() at enqueue time.
    """
    otel_token = otel_context.attach(ctx.otel_ctx)
    provider_token = PROVIDER_DATA_VAR.set(ctx.provider_data)
    try:
        yield
    finally:
        # Restore in reverse order of activation.
        PROVIDER_DATA_VAR.reset(provider_token)
        otel_context.detach(otel_token)


def create_detached_background_task(coro: Coroutine[Any, Any, Any]) -> asyncio.Task[Any]:
    """Spawn *coro* as a task whose contextvars carry no request-scoped state.

    asyncio.create_task snapshots every contextvar at creation time, so a
    long-lived background worker would otherwise permanently inherit the
    spawning request's OTel trace and auth identity. Both are blanked around
    the create_task call and restored before returning, leaving the caller's
    own context untouched.
    """
    # Detach: empty OTel context and no provider data while the task's
    # contextvar snapshot is taken.
    detached_otel = otel_context.attach(otel_context.Context())
    cleared_provider = PROVIDER_DATA_VAR.set(None)
    try:
        background = asyncio.create_task(coro)
    finally:
        # Restore in reverse order of modification.
        PROVIDER_DATA_VAR.reset(cleared_provider)
        otel_context.detach(detached_otel)
    return background
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@
from collections.abc import AsyncIterator
from dataclasses import dataclass, field

from opentelemetry import context as otel_context
from pydantic import BaseModel, TypeAdapter

from llama_stack.core.conversations.validation import CONVERSATION_ID_PATTERN
from llama_stack.core.task import activate_otel_context, capture_otel_context, create_task_with_detached_otel_context
from llama_stack.core.task import (
RequestContext,
activate_request_context,
capture_request_context,
create_detached_background_task,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.responses.responses_store import (
ResponsesStore,
Expand Down Expand Up @@ -87,9 +91,9 @@

@dataclass
class _BackgroundWorkItem:
    """Typed queue item that pairs business kwargs with the originating request context."""

    # Snapshot captured at enqueue time; re-activated by the worker per item.
    request_context: RequestContext
    # Keyword arguments forwarded to the background response loop.
    kwargs: dict = field(default_factory=dict)


Expand Down Expand Up @@ -144,7 +148,7 @@ async def initialize(self) -> None:
async def _ensure_workers_started(self) -> None:
    """Start background workers in the current event loop if not already running.

    Tops the worker pool up to BACKGROUND_NUM_WORKERS; idempotent when the
    pool is already full.
    """
    for _ in range(BACKGROUND_NUM_WORKERS - len(self._background_worker_tasks)):
        # Detached: workers must not inherit the spawning request's trace/auth.
        task = create_detached_background_task(self._background_worker())
        self._background_worker_tasks.add(task)
        # Drop finished tasks so the set tracks live workers only.
        task.add_done_callback(self._background_worker_tasks.discard)

Expand All @@ -158,7 +162,7 @@ async def _background_worker(self) -> None:
"""Worker coroutine that pulls items from the queue and processes them."""
while True:
item = await self._background_queue.get()
with activate_otel_context(item.otel_context):
with activate_request_context(item.request_context):
try:
await asyncio.wait_for(
self._run_background_response_loop(**item.kwargs),
Expand Down Expand Up @@ -833,7 +837,7 @@ async def _create_background_response(
try:
self._background_queue.put_nowait(
_BackgroundWorkItem(
otel_context=capture_otel_context(),
request_context=capture_request_context(),
kwargs=dict(
response_id=response_id,
input=input,
Expand Down
16 changes: 10 additions & 6 deletions src/llama_stack/providers/utils/inference/inference_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@
import asyncio
from typing import Any, NamedTuple

from opentelemetry import context as otel_context
from sqlalchemy.exc import IntegrityError

from llama_stack.core.datatypes import AccessRule
from llama_stack.core.storage.datatypes import InferenceStoreReference, StorageBackendType
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.core.storage.sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl
from llama_stack.core.task import activate_otel_context, capture_otel_context, create_task_with_detached_otel_context
from llama_stack.core.task import (
RequestContext,
activate_request_context,
capture_request_context,
create_detached_background_task,
)
from llama_stack.log import get_logger
from llama_stack_api import (
ListOpenAIChatCompletionResponse,
Expand All @@ -30,7 +34,7 @@
class _WriteItem(NamedTuple):
    """Queue entry for an asynchronous chat-completion DB write."""

    completion: OpenAIChatCompletion
    messages: list[OpenAIMessageParam]
    # Request context captured at enqueue time (OTel trace + provider/auth data).
    request_context: RequestContext


class InferenceStore:
Expand Down Expand Up @@ -107,7 +111,7 @@ async def _ensure_workers_started(self) -> None:

if not self._worker_tasks:
for _ in range(self._num_writers):
task = create_task_with_detached_otel_context(self._worker_loop())
task = create_detached_background_task(self._worker_loop())
self._worker_tasks.append(task)

async def store_chat_completion(
Expand All @@ -117,7 +121,7 @@ async def store_chat_completion(
await self._ensure_workers_started()
if self._queue is None:
raise ValueError("Inference store is not initialized")
item = _WriteItem(chat_completion, input_messages, capture_otel_context())
item = _WriteItem(chat_completion, input_messages, capture_request_context())
try:
self._queue.put_nowait(item)
except asyncio.QueueFull:
Expand All @@ -136,7 +140,7 @@ async def _worker_loop(self) -> None:
except asyncio.CancelledError:
break
try:
with activate_otel_context(item.otel_context):
with activate_request_context(item.request_context):
await self._write_chat_completion(item.completion, item.messages)
except Exception as e: # noqa: BLE001
logger.error(f"Error writing chat completion: {e}")
Expand Down
Loading
Loading