diff --git a/src/llama_stack/core/task.py b/src/llama_stack/core/task.py new file mode 100644 index 0000000000..e95b767841 --- /dev/null +++ b/src/llama_stack/core/task.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +from collections.abc import Coroutine +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any + +from opentelemetry import context as otel_context + +from llama_stack.core.request_headers import PROVIDER_DATA_VAR + + +@dataclass +class RequestContext: + """Snapshot of request-scoped state for propagation through background queues. + + Background workers are long-lived asyncio tasks whose contextvars are frozen + at creation time. Capturing both the OTel trace context and the provider / + auth data at *enqueue* time and re-activating them per work-item ensures: + + * Each DB write is attributed to the correct request trace (OTel). + * Each DB write is stamped with the correct user identity (PROVIDER_DATA_VAR). + """ + + otel_ctx: otel_context.Context + provider_data: Any + + +def capture_request_context() -> RequestContext: + """Snapshot the current request-scoped context for later use in a worker.""" + return RequestContext( + otel_ctx=otel_context.get_current(), + provider_data=PROVIDER_DATA_VAR.get(), + ) + + +@contextmanager +def activate_request_context(ctx: RequestContext): + """Temporarily restore a previously captured request context. + + Use this in worker loops that run with a detached (empty) context to + attribute work back to the originating request. 
+ """ + otel_token = otel_context.attach(ctx.otel_ctx) + provider_token = PROVIDER_DATA_VAR.set(ctx.provider_data) + try: + yield + finally: + PROVIDER_DATA_VAR.reset(provider_token) + otel_context.detach(otel_token) + + +def create_detached_background_task(coro: Coroutine[Any, Any, Any]) -> asyncio.Task[Any]: + """Create an asyncio task that does not inherit request-scoped context. + + asyncio.create_task copies all contextvars at creation time, which causes + long-lived background workers to permanently inherit the spawning request's + OTel trace and auth identity. This helper temporarily clears both before + creating the task, then immediately restores them so the caller is unaffected. + """ + otel_token = otel_context.attach(otel_context.Context()) + provider_token = PROVIDER_DATA_VAR.set(None) + try: + task = asyncio.create_task(coro) + finally: + PROVIDER_DATA_VAR.reset(provider_token) + otel_context.detach(otel_token) + return task diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py index 46aafccb6e..c58055da39 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py @@ -9,9 +9,16 @@ import time import uuid from collections.abc import AsyncIterator +from dataclasses import dataclass, field from pydantic import BaseModel, TypeAdapter +from llama_stack.core.task import ( + RequestContext, + activate_request_context, + capture_request_context, + create_detached_background_task, +) from llama_stack.log import get_logger from llama_stack.providers.utils.responses.responses_store import ( ResponsesStore, @@ -80,6 +87,14 @@ BACKGROUND_NUM_WORKERS = 10 +@dataclass +class _BackgroundWorkItem: + """Typed queue item that pairs business kwargs with the originating request context.""" + + request_context: 
RequestContext + kwargs: dict = field(default_factory=dict) + + class OpenAIResponsePreviousResponseWithInputItems(BaseModel): input_items: ListOpenAIResponseInputItem response: OpenAIResponseObject @@ -131,7 +146,7 @@ async def initialize(self) -> None: async def _ensure_workers_started(self) -> None: """Start background workers in the current event loop if not already running.""" for _ in range(BACKGROUND_NUM_WORKERS - len(self._background_worker_tasks)): - task = asyncio.create_task(self._background_worker()) + task = create_detached_background_task(self._background_worker()) self._background_worker_tasks.add(task) task.add_done_callback(self._background_worker_tasks.discard) @@ -144,48 +159,49 @@ async def shutdown(self) -> None: async def _background_worker(self) -> None: """Worker coroutine that pulls items from the queue and processes them.""" while True: - kwargs = await self._background_queue.get() - try: - await asyncio.wait_for( - self._run_background_response_loop(**kwargs), - timeout=BACKGROUND_RESPONSE_TIMEOUT_SECONDS, - ) - except TimeoutError: - response_id = kwargs["response_id"] - logger.exception( - f"Background response {response_id} timed out after {BACKGROUND_RESPONSE_TIMEOUT_SECONDS}s" - ) - try: - existing = await self.responses_store.get_response_object(response_id) - existing.status = "failed" - existing.error = OpenAIResponseError( - code="processing_error", - message=f"Background response timed out after {BACKGROUND_RESPONSE_TIMEOUT_SECONDS}s", - ) - await self.responses_store.update_response_object(existing) - except Exception: - logger.exception( - f"Failed to update response {response_id} with timeout status. " - "Client polling this response will not see the failure." 
- ) - except Exception as e: - response_id = kwargs["response_id"] - logger.exception(f"Error processing background response {response_id}") + item = await self._background_queue.get() + with activate_request_context(item.request_context): try: - existing = await self.responses_store.get_response_object(response_id) - existing.status = "failed" - existing.error = OpenAIResponseError( - code="processing_error", - message=str(e), + await asyncio.wait_for( + self._run_background_response_loop(**item.kwargs), + timeout=BACKGROUND_RESPONSE_TIMEOUT_SECONDS, ) - await self.responses_store.update_response_object(existing) - except Exception: + except TimeoutError: + response_id = item.kwargs["response_id"] logger.exception( - f"Failed to update response {response_id} with error status. " - "Client polling this response will not see the failure." + f"Background response {response_id} timed out after {BACKGROUND_RESPONSE_TIMEOUT_SECONDS}s" ) - finally: - self._background_queue.task_done() + try: + existing = await self.responses_store.get_response_object(response_id) + existing.status = "failed" + existing.error = OpenAIResponseError( + code="processing_error", + message=f"Background response timed out after {BACKGROUND_RESPONSE_TIMEOUT_SECONDS}s", + ) + await self.responses_store.update_response_object(existing) + except Exception: + logger.exception( + f"Failed to update response {response_id} with timeout status. " + "Client polling this response will not see the failure." + ) + except Exception as e: + response_id = item.kwargs["response_id"] + logger.exception(f"Error processing background response {response_id}") + try: + existing = await self.responses_store.get_response_object(response_id) + existing.status = "failed" + existing.error = OpenAIResponseError( + code="processing_error", + message=str(e), + ) + await self.responses_store.update_response_object(existing) + except Exception: + logger.exception( + f"Failed to update response {response_id} with error status. 
" + "Client polling this response will not see the failure." + ) + finally: + self._background_queue.task_done() async def _prepend_previous_response( self, @@ -812,33 +828,36 @@ async def _create_background_response( # Enqueue work item for background workers. Raises QueueFull if at capacity. try: self._background_queue.put_nowait( - dict( - response_id=response_id, - input=input, - model=model, - prompt=prompt, - instructions=instructions, - previous_response_id=previous_response_id, - conversation=conversation, - store=store, - temperature=temperature, - frequency_penalty=frequency_penalty, - text=text, - tool_choice=tool_choice, - tools=tools, - include=include, - max_infer_iters=max_infer_iters, - guardrail_ids=guardrail_ids, - parallel_tool_calls=parallel_tool_calls, - max_tool_calls=max_tool_calls, - reasoning=reasoning, - max_output_tokens=max_output_tokens, - safety_identifier=safety_identifier, - service_tier=service_tier, - metadata=metadata, - truncation=truncation, - presence_penalty=presence_penalty, - extra_body=extra_body, + _BackgroundWorkItem( + request_context=capture_request_context(), + kwargs=dict( + response_id=response_id, + input=input, + model=model, + prompt=prompt, + instructions=instructions, + previous_response_id=previous_response_id, + conversation=conversation, + store=store, + temperature=temperature, + frequency_penalty=frequency_penalty, + text=text, + tool_choice=tool_choice, + tools=tools, + include=include, + max_infer_iters=max_infer_iters, + guardrail_ids=guardrail_ids, + parallel_tool_calls=parallel_tool_calls, + max_tool_calls=max_tool_calls, + reasoning=reasoning, + max_output_tokens=max_output_tokens, + safety_identifier=safety_identifier, + service_tier=service_tier, + metadata=metadata, + truncation=truncation, + presence_penalty=presence_penalty, + extra_body=extra_body, + ), ) ) except asyncio.QueueFull: diff --git a/src/llama_stack/providers/utils/inference/inference_store.py 
b/src/llama_stack/providers/utils/inference/inference_store.py index 78327573b6..801a5b29f9 100644 --- a/src/llama_stack/providers/utils/inference/inference_store.py +++ b/src/llama_stack/providers/utils/inference/inference_store.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import asyncio -from typing import Any +from typing import Any, NamedTuple from sqlalchemy.exc import IntegrityError @@ -12,6 +12,12 @@ from llama_stack.core.storage.datatypes import InferenceStoreReference, StorageBackendType from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore from llama_stack.core.storage.sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl +from llama_stack.core.task import ( + RequestContext, + activate_request_context, + capture_request_context, + create_detached_background_task, +) from llama_stack.log import get_logger from llama_stack_api import ( ListOpenAIChatCompletionResponse, @@ -25,6 +31,12 @@ logger = get_logger(name=__name__, category="inference") +class _WriteItem(NamedTuple): + completion: OpenAIChatCompletion + messages: list[OpenAIMessageParam] + request_context: RequestContext + + class InferenceStore: def __init__( self, @@ -37,7 +49,7 @@ def __init__( self.enable_write_queue = True # Async write queue and worker control - self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None + self._queue: asyncio.Queue[_WriteItem] | None = None self._worker_tasks: list[asyncio.Task[Any]] = [] self._max_write_queue_size: int = reference.max_write_queue_size self._num_writers: int = max(1, reference.num_writers) @@ -98,9 +110,8 @@ async def _ensure_workers_started(self) -> None: ) if not self._worker_tasks: - loop = asyncio.get_running_loop() for _ in range(self._num_writers): - task = loop.create_task(self._worker_loop()) + task = create_detached_background_task(self._worker_loop()) 
self._worker_tasks.append(task) async def store_chat_completion( @@ -110,13 +121,14 @@ async def store_chat_completion( await self._ensure_workers_started() if self._queue is None: raise ValueError("Inference store is not initialized") + item = _WriteItem(chat_completion, input_messages, capture_request_context()) try: - self._queue.put_nowait((chat_completion, input_messages)) + self._queue.put_nowait(item) except asyncio.QueueFull: logger.warning( f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '')}" ) - await self._queue.put((chat_completion, input_messages)) + await self._queue.put(item) else: await self._write_chat_completion(chat_completion, input_messages) @@ -127,9 +139,9 @@ async def _worker_loop(self) -> None: item = await self._queue.get() except asyncio.CancelledError: break - chat_completion, input_messages = item try: - await self._write_chat_completion(chat_completion, input_messages) + with activate_request_context(item.request_context): + await self._write_chat_completion(item.completion, item.messages) except Exception as e: # noqa: BLE001 logger.error(f"Error writing chat completion: {e}") finally: diff --git a/tests/unit/core/test_task.py b/tests/unit/core/test_task.py new file mode 100644 index 0000000000..e175c59183 --- /dev/null +++ b/tests/unit/core/test_task.py @@ -0,0 +1,297 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import asyncio + +from opentelemetry import context as otel_context +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult + +from llama_stack.core.request_headers import PROVIDER_DATA_VAR +from llama_stack.core.task import ( + RequestContext, + activate_request_context, + capture_request_context, + create_detached_background_task, +) + + +class _CollectingExporter(SpanExporter): + """Collects finished spans in memory for test assertions.""" + + def __init__(self): + self.spans = [] + + def export(self, spans): + self.spans.extend(spans) + return SpanExportResult.SUCCESS + + +async def test_detached_task_runs_coroutine(): + """The helper creates a task that actually runs the coroutine to completion.""" + result = [] + + async def work(): + result.append("done") + + task = create_detached_background_task(work()) + await task + assert result == ["done"] + + +async def test_detached_task_clears_otel_context(): + """The task should run with an empty OTel context, not the parent's.""" + provider = TracerProvider() + tracer = provider.get_tracer("test") + + captured_span = {} + + async def capture_context(): + captured_span["inner"] = trace.get_current_span() + + with tracer.start_as_current_span("parent-span"): + parent_ctx = otel_context.get_current() + parent_span = trace.get_current_span() + + task = create_detached_background_task(capture_context()) + await task + + assert not captured_span["inner"].is_recording() + assert parent_span.is_recording() + assert otel_context.get_current() == parent_ctx + + +async def test_detached_task_clears_provider_data(): + """The task should run with PROVIDER_DATA_VAR cleared.""" + captured = {} + token = PROVIDER_DATA_VAR.set({"__authenticated_user": "alice"}) + + async def capture_provider(): + captured["value"] = PROVIDER_DATA_VAR.get() + + try: + task = 
create_detached_background_task(capture_provider()) + await task + + assert captured["value"] is None, "Background task should not inherit PROVIDER_DATA_VAR" + assert PROVIDER_DATA_VAR.get() == {"__authenticated_user": "alice"}, "Caller's context should be unaffected" + finally: + PROVIDER_DATA_VAR.reset(token) + + +async def test_detached_task_restores_caller_context(): + """The calling coroutine's context is not affected by creating a detached task.""" + provider = TracerProvider() + tracer = provider.get_tracer("test") + + token = PROVIDER_DATA_VAR.set({"__authenticated_user": "bob"}) + try: + with tracer.start_as_current_span("parent-span"): + otel_before = otel_context.get_current() + provider_before = PROVIDER_DATA_VAR.get() + + create_detached_background_task(asyncio.sleep(0)) + + assert otel_context.get_current() == otel_before + assert PROVIDER_DATA_VAR.get() == provider_before + finally: + PROVIDER_DATA_VAR.reset(token) + + +async def test_detached_task_produces_independent_trace(): + """Spans created inside a detached task belong to a separate trace.""" + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + async def background_work(): + with tracer.start_as_current_span("background-db-write"): + await asyncio.sleep(0) + + with tracer.start_as_current_span("http-request"): + task = create_detached_background_task(background_work()) + await task + + provider.force_flush() + span_by_name = {s.name: s for s in exporter.spans} + + request_span = span_by_name["http-request"] + bg_span = span_by_name["background-db-write"] + + assert request_span.context.trace_id != bg_span.context.trace_id + assert bg_span.parent is None + + +async def test_normal_child_task_shares_trace(): + """Contrast: a regular asyncio.create_task DOES inherit the parent trace.""" + exporter = _CollectingExporter() + provider = TracerProvider() + 
provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + async def child_work(): + with tracer.start_as_current_span("child-span"): + await asyncio.sleep(0) + + with tracer.start_as_current_span("parent-request"): + task = asyncio.create_task(child_work()) + await task + + provider.force_flush() + span_by_name = {s.name: s for s in exporter.spans} + + parent_span = span_by_name["parent-request"] + child_span = span_by_name["child-span"] + + assert parent_span.context.trace_id == child_span.context.trace_id, ( + "Regular create_task should share the parent's trace" + ) + + +async def test_context_through_queue_pattern(): + """End-to-end: context captured at enqueue time is correctly attached in a detached worker. + + This simulates the inference_store pattern: + 1. Request creates a span and enqueues work with captured context + 2. Worker runs in a detached (empty) context + 3. Worker attaches the captured context before processing + 4. The resulting span belongs to the original request's trace + """ + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + queue: asyncio.Queue[tuple[str, RequestContext]] = asyncio.Queue() + + async def worker(): + item, ctx = await queue.get() + with activate_request_context(ctx): + with tracer.start_as_current_span(f"db-write-{item}"): + await asyncio.sleep(0) + queue.task_done() + + token = PROVIDER_DATA_VAR.set({"user": "A"}) + try: + with tracer.start_as_current_span("http-request-A"): + ctx_a = capture_request_context() + await queue.put(("A", ctx_a)) + finally: + PROVIDER_DATA_VAR.reset(token) + + worker_task = create_detached_background_task(worker()) + await worker_task + await queue.join() + + provider.force_flush() + span_by_name = {s.name: s for s in exporter.spans} + + request_span = span_by_name["http-request-A"] + write_span = span_by_name["db-write-A"] + + assert 
request_span.context.trace_id == write_span.context.trace_id, ( + "DB write should belong to the same trace as the originating request" + ) + + +async def test_capture_and_activate_request_context(): + """capture_request_context snapshots both OTel and provider data; activate restores both.""" + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + token = PROVIDER_DATA_VAR.set({"__authenticated_user": "charlie"}) + try: + with tracer.start_as_current_span("request"): + ctx = capture_request_context() + request_trace_id = trace.get_current_span().get_span_context().trace_id + + assert isinstance(ctx, RequestContext) + assert ctx.provider_data == {"__authenticated_user": "charlie"} + + # After span ends, activate context and verify OTel trace is restored + with activate_request_context(ctx): + with tracer.start_as_current_span("reattached-work"): + reattached_trace_id = trace.get_current_span().get_span_context().trace_id + assert PROVIDER_DATA_VAR.get() == {"__authenticated_user": "charlie"} + + assert request_trace_id == reattached_trace_id + finally: + PROVIDER_DATA_VAR.reset(token) + + +async def test_activate_restores_on_exit(): + """activate_request_context restores the previous context when the block exits.""" + provider = TracerProvider() + tracer = provider.get_tracer("test") + + token = PROVIDER_DATA_VAR.set({"__authenticated_user": "outer_user"}) + try: + with tracer.start_as_current_span("outer"): + outer_otel = otel_context.get_current() + + inner_ctx = RequestContext( + otel_ctx=otel_context.Context(), + provider_data={"__authenticated_user": "inner_user"}, + ) + with activate_request_context(inner_ctx): + assert PROVIDER_DATA_VAR.get() == {"__authenticated_user": "inner_user"} + + assert PROVIDER_DATA_VAR.get() == {"__authenticated_user": "outer_user"} + assert otel_context.get_current() == outer_otel + finally: + 
PROVIDER_DATA_VAR.reset(token) + + +async def test_context_through_queue_no_cross_contamination(): + """Two requests enqueue work; each item's context is correctly propagated.""" + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + queue: asyncio.Queue[tuple[str, RequestContext]] = asyncio.Queue() + processed = asyncio.Event() + + async def worker(): + for _ in range(2): + label, ctx = await queue.get() + with activate_request_context(ctx): + assert PROVIDER_DATA_VAR.get() == {"user": label} + with tracer.start_as_current_span(f"db-write-{label}"): + await asyncio.sleep(0) + queue.task_done() + processed.set() + + worker_task = create_detached_background_task(worker()) + + token_a = PROVIDER_DATA_VAR.set({"user": "A"}) + with tracer.start_as_current_span("request-A"): + await queue.put(("A", capture_request_context())) + PROVIDER_DATA_VAR.reset(token_a) + + token_b = PROVIDER_DATA_VAR.set({"user": "B"}) + with tracer.start_as_current_span("request-B"): + await queue.put(("B", capture_request_context())) + PROVIDER_DATA_VAR.reset(token_b) + + await processed.wait() + await worker_task + + provider.force_flush() + span_by_name = {s.name: s for s in exporter.spans} + + request_a = span_by_name["request-A"] + request_b = span_by_name["request-B"] + write_a = span_by_name["db-write-A"] + write_b = span_by_name["db-write-B"] + + assert write_a.context.trace_id == request_a.context.trace_id + assert write_b.context.trace_id == request_b.context.trace_id + assert request_a.context.trace_id != request_b.context.trace_id diff --git a/tests/unit/providers/agents/meta_reference/test_responses_background.py b/tests/unit/providers/agents/meta_reference/test_responses_background.py index b95b87e3d9..0e8916cded 100644 --- a/tests/unit/providers/agents/meta_reference/test_responses_background.py +++ 
b/tests/unit/providers/agents/meta_reference/test_responses_background.py @@ -6,8 +6,21 @@ """Unit tests for background parameter support in Responses API.""" +import asyncio +from unittest.mock import AsyncMock, patch + import pytest +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult +from llama_stack.core.datatypes import User +from llama_stack.core.request_headers import PROVIDER_DATA_VAR, get_authenticated_user +from llama_stack.core.task import capture_request_context, create_detached_background_task +from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import ( + OpenAIResponsesImpl, + _BackgroundWorkItem, +) from llama_stack_api import OpenAIResponseError, OpenAIResponseObject @@ -144,3 +157,324 @@ def test_error_response_with_background(self): assert response.background is True assert response.error is not None assert response.error.code == "processing_error" + + +def _make_responses_impl(): + """Create an OpenAIResponsesImpl with all dependencies mocked.""" + return OpenAIResponsesImpl( + inference_api=AsyncMock(), + tool_groups_api=AsyncMock(), + tool_runtime_api=AsyncMock(), + responses_store=AsyncMock(), + vector_io_api=AsyncMock(), + safety_api=None, + conversations_api=AsyncMock(), + prompts_api=AsyncMock(), + files_api=AsyncMock(), + connectors_api=AsyncMock(), + ) + + +class _CollectingExporter(SpanExporter): + """Collects finished spans in memory for test assertions.""" + + def __init__(self): + self.spans = [] + + def export(self, spans): + self.spans.extend(spans) + return SpanExportResult.SUCCESS + + +class TestResponsesOtelContextPropagation: + """Verify that OTel trace context flows correctly through the background worker queue.""" + + async def test_worker_attributes_work_to_correct_request_trace(self): + """Each queued response is processed under its originating request's trace 
context.""" + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + impl = _make_responses_impl() + + async def mock_response_loop(**kwargs): + with tracer.start_as_current_span(f"process-{kwargs['response_id']}"): + await asyncio.sleep(0) + + with patch.object(impl, "_run_background_response_loop", side_effect=mock_response_loop): + worker_task = create_detached_background_task(impl._background_worker()) + + with tracer.start_as_current_span("request-A"): + impl._background_queue.put_nowait( + _BackgroundWorkItem(request_context=capture_request_context(), kwargs=dict(response_id="resp-A")) + ) + + with tracer.start_as_current_span("request-B"): + impl._background_queue.put_nowait( + _BackgroundWorkItem(request_context=capture_request_context(), kwargs=dict(response_id="resp-B")) + ) + + await impl._background_queue.join() + worker_task.cancel() + try: + await worker_task + except asyncio.CancelledError: + pass + + provider.force_flush() + spans_by_name = {s.name: s for s in exporter.spans} + + request_a_trace = spans_by_name["request-A"].context.trace_id + request_b_trace = spans_by_name["request-B"].context.trace_id + process_a_trace = spans_by_name["process-resp-A"].context.trace_id + process_b_trace = spans_by_name["process-resp-B"].context.trace_id + + assert request_a_trace != request_b_trace, "Requests should have distinct traces" + assert process_a_trace == request_a_trace, "Response processing for resp-A should be in request-A's trace" + assert process_b_trace == request_b_trace, "Response processing for resp-B should be in request-B's trace" + + async def test_worker_does_not_leak_context_between_items(self): + """After processing one item, the worker returns to a clean OTel context.""" + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = 
provider.get_tracer("test") + + impl = _make_responses_impl() + trace_ids_during_processing = {} + + async def mock_response_loop(**kwargs): + rid = kwargs["response_id"] + span_ctx = trace.get_current_span().get_span_context() + trace_ids_during_processing[rid] = span_ctx.trace_id if span_ctx.trace_id != 0 else None + with tracer.start_as_current_span(f"work-{rid}"): + await asyncio.sleep(0) + + with patch.object(impl, "_run_background_response_loop", side_effect=mock_response_loop): + worker_task = create_detached_background_task(impl._background_worker()) + + with tracer.start_as_current_span("req-1"): + impl._background_queue.put_nowait( + _BackgroundWorkItem(request_context=capture_request_context(), kwargs=dict(response_id="r1")) + ) + + with tracer.start_as_current_span("req-2"): + impl._background_queue.put_nowait( + _BackgroundWorkItem(request_context=capture_request_context(), kwargs=dict(response_id="r2")) + ) + + await impl._background_queue.join() + worker_task.cancel() + try: + await worker_task + except asyncio.CancelledError: + pass + + provider.force_flush() + spans_by_name = {s.name: s for s in exporter.spans} + + req1_trace = spans_by_name["req-1"].context.trace_id + req2_trace = spans_by_name["req-2"].context.trace_id + + assert trace_ids_during_processing["r1"] is not None, "r1 should have a trace context" + assert trace_ids_during_processing["r2"] is not None, "r2 should have a trace context" + assert trace_ids_during_processing["r1"] == req1_trace + assert trace_ids_during_processing["r2"] == req2_trace + + async def test_error_handling_runs_under_request_context(self): + """When processing fails, the error handler's DB writes are also in the request's trace.""" + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + impl = _make_responses_impl() + + mock_response = OpenAIResponseObject( + id="resp-err", + created_at=1234567890, + 
# =====================================================================
# Section 1: responses background-worker identity tests
# (reconstructed from a whitespace-mangled diff; a leading test whose
# definition starts outside the visible window is omitted here)
# =====================================================================


def _set_authenticated_user(user: User | None):
    """Simulate what ProviderDataMiddleware does for each request."""
    if user:
        PROVIDER_DATA_VAR.set({"__authenticated_user": user})
    else:
        PROVIDER_DATA_VAR.set(None)


async def _drain_queue_and_stop_worker(impl, worker_task) -> None:
    """Shared teardown: wait for all queued items to be processed, then stop the worker.

    The background worker is a long-lived loop, so it only exits via
    cancellation; the resulting CancelledError is expected and suppressed.
    Extracted because this sequence was previously triplicated verbatim.
    """
    await impl._background_queue.join()
    worker_task.cancel()
    try:
        await worker_task
    except asyncio.CancelledError:
        pass


class TestResponsesProviderDataPropagation:
    """Verify that PROVIDER_DATA_VAR flows correctly through the background worker queue.

    The responses worker processes the full response loop (LLM calls, tool execution,
    DB writes). All operations inside the worker must run with the originating
    request's auth identity, not whichever request first spawned the worker.
    """

    async def test_worker_runs_under_correct_user_identity(self):
        """Each queued response is processed under its originating user's identity."""
        impl = _make_responses_impl()

        alice = User(principal="alice", attributes={"roles": ["user"]})
        bob = User(principal="bob", attributes={"roles": ["user"]})

        # response_id -> identity observed *inside* the worker loop.
        observed_users: dict[str, User | None] = {}

        async def mock_response_loop(**kwargs):
            observed_users[kwargs["response_id"]] = get_authenticated_user()

        with patch.object(impl, "_run_background_response_loop", side_effect=mock_response_loop):
            worker_task = create_detached_background_task(impl._background_worker())

            # Two different users enqueue work; each item snapshots its own
            # request context at enqueue time.
            _set_authenticated_user(alice)
            impl._background_queue.put_nowait(
                _BackgroundWorkItem(request_context=capture_request_context(), kwargs=dict(response_id="resp-alice"))
            )

            _set_authenticated_user(bob)
            impl._background_queue.put_nowait(
                _BackgroundWorkItem(request_context=capture_request_context(), kwargs=dict(response_id="resp-bob"))
            )

            await _drain_queue_and_stop_worker(impl, worker_task)

        _set_authenticated_user(None)

        assert observed_users["resp-alice"] is not None, "Alice's request should have a user"
        assert observed_users["resp-bob"] is not None, "Bob's request should have a user"
        assert observed_users["resp-alice"].principal == "alice", "Alice's response should run as alice"
        assert observed_users["resp-bob"].principal == "bob", "Bob's response should run as bob"

    async def test_worker_does_not_leak_identity_between_items(self):
        """After processing one item, the worker returns to a clean state."""
        impl = _make_responses_impl()

        alice = User(principal="alice", attributes={"roles": ["user"]})

        # Identity observed inside the worker, in processing order.
        user_after_processing: list[User | None] = []

        async def mock_response_loop(**kwargs):
            user_after_processing.append(get_authenticated_user())

        with patch.object(impl, "_run_background_response_loop", side_effect=mock_response_loop):
            worker_task = create_detached_background_task(impl._background_worker())

            # First item: enqueued by Alice
            _set_authenticated_user(alice)
            impl._background_queue.put_nowait(
                _BackgroundWorkItem(request_context=capture_request_context(), kwargs=dict(response_id="resp-1"))
            )

            # Second item: enqueued with no user (anonymous)
            _set_authenticated_user(None)
            impl._background_queue.put_nowait(
                _BackgroundWorkItem(request_context=capture_request_context(), kwargs=dict(response_id="resp-2"))
            )

            await _drain_queue_and_stop_worker(impl, worker_task)

        assert user_after_processing[0] is not None, "First item should run as alice"
        assert user_after_processing[0].principal == "alice"
        assert user_after_processing[1] is None, "Second item should run as anonymous — alice's identity must not leak"

    async def test_error_handler_runs_under_correct_identity(self):
        """When processing fails, error-handling DB writes use the correct user."""
        impl = _make_responses_impl()

        bob = User(principal="bob", attributes={"roles": ["user"]})

        mock_response = OpenAIResponseObject(
            id="resp-err",
            created_at=1234567890,
            model="test-model",
            status="in_progress",
            output=[],
            store=True,
        )
        impl.responses_store.get_response_object = AsyncMock(return_value=mock_response)

        # Record the identity active during each error-path DB write.
        error_handler_users: list[User | None] = []
        original_update = impl.responses_store.update_response_object

        async def tracking_update(obj):
            error_handler_users.append(get_authenticated_user())
            return await original_update(obj)

        impl.responses_store.update_response_object = tracking_update

        async def failing_loop(**kwargs):
            raise RuntimeError("simulated failure")

        with patch.object(impl, "_run_background_response_loop", side_effect=failing_loop):
            worker_task = create_detached_background_task(impl._background_worker())

            _set_authenticated_user(bob)
            impl._background_queue.put_nowait(
                _BackgroundWorkItem(request_context=capture_request_context(), kwargs=dict(response_id="resp-err"))
            )

            await _drain_queue_and_stop_worker(impl, worker_task)

        _set_authenticated_user(None)

        assert len(error_handler_users) > 0, "Error handler should have made DB updates"
        for user in error_handler_users:
            assert user is not None, "Error handler should have a user identity"
            assert user.principal == "bob", "Error handler should run as bob, not the worker's inherited identity"


# =====================================================================
# Section 2: tests/unit/utils/inference/test_provider_data_leak.py
#
# Reproduces the PROVIDER_DATA_VAR contextvar leak through background
# worker tasks. The inference store uses a write queue with long-lived
# worker tasks (on Postgres). asyncio.create_task copies all contextvars
# at creation time, so the worker permanently inherits the first
# request's PROVIDER_DATA_VAR — every DB write gets stamped with the
# first user's identity, regardless of who actually made the request.
#
# The test forces the write queue on (normally disabled for SQLite) to
# demonstrate the leak without needing a Postgres instance.
# =====================================================================

import time

import pytest

from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import PROVIDER_DATA_VAR
from llama_stack.core.storage.datatypes import InferenceStoreReference, SqliteSqlStoreConfig
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack_api import (
    OpenAIChatCompletion,
    OpenAIChatCompletionResponseMessage,
    OpenAIChoice,
    OpenAIUserMessageParam,
)


@pytest.fixture(autouse=True)
def setup_backends(tmp_path):
    """Register a throwaway SQLite backend so InferenceStore can initialize."""
    db_path = str(tmp_path / "test_leak.db")
    register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path=db_path)})


def _set_authenticated_user(user: User | None):
    """Simulate what ProviderDataMiddleware does for each request."""
    if user:
        PROVIDER_DATA_VAR.set({"__authenticated_user": user})
    else:
        PROVIDER_DATA_VAR.set(None)


def _make_completion(completion_id: str, created: int) -> OpenAIChatCompletion:
    """Build a minimal single-choice chat completion for storage tests."""
    return OpenAIChatCompletion(
        id=completion_id,
        created=created,
        model="test-model",
        object="chat.completion",
        choices=[
            OpenAIChoice(
                index=0,
                message=OpenAIChatCompletionResponseMessage(
                    role="assistant",
                    content=f"Response for {completion_id}",
                ),
                finish_reason="stop",
            )
        ],
    )


async def test_provider_data_leak_through_write_queue():
    """Demonstrates that PROVIDER_DATA_VAR leaks into background workers.

    Expected behavior: each completion should be owned by the user who created it.
    Actual behavior: all completions are owned by whoever triggered worker creation.
    """
    # Owner-scoped reads: a user may only list completions they own;
    # anyone may create.
    owner_policy = [
        AccessRule(permit=Scope(actions=[Action.READ]), when=["user is owner"]),
        AccessRule(permit=Scope(actions=[Action.CREATE]), when=[]),
    ]

    reference = InferenceStoreReference(
        backend="sql_default",
        table_name="leak_test",
        num_writers=1,
    )
    store = InferenceStore(reference, policy=owner_policy)
    await store.initialize()

    # Force the write queue on (normally disabled for SQLite)
    store.enable_write_queue = True

    alice = User(principal="alice", attributes={"roles": ["user"]})
    bob = User(principal="bob", attributes={"roles": ["user"]})

    base_time = int(time.time())

    try:
        # --- Request 1: Alice creates a completion ---
        # This is the first request, so it spawns the background worker.
        # Without the fix, the worker inherits Alice's PROVIDER_DATA_VAR permanently.
        _set_authenticated_user(alice)
        await store.store_chat_completion(
            _make_completion("alice-completion", base_time + 1),
            [OpenAIUserMessageParam(role="user", content="Hello from Alice")],
        )
        await store.flush()

        # --- Request 2: Bob creates a completion ---
        # The worker is already running; Bob's write goes through the queue and
        # must be processed under Bob's identity, not Alice's.
        _set_authenticated_user(bob)
        await store.store_chat_completion(
            _make_completion("bob-completion", base_time + 2),
            [OpenAIUserMessageParam(role="user", content="Hello from Bob")],
        )
        await store.flush()

        # --- Now verify: can each user see only their own completions? ---
        _set_authenticated_user(alice)
        alice_results = await store.list_chat_completions()

        _set_authenticated_user(bob)
        bob_results = await store.list_chat_completions()
    finally:
        # Always clear the request identity and stop the store's worker tasks,
        # even on failure, so neither bleeds into other tests in this process.
        _set_authenticated_user(None)
        await store.shutdown()

    # --- Assertions ---
    alice_ids = [c.id for c in alice_results.data]
    bob_ids = [c.id for c in bob_results.data]

    print(f"\nAlice sees: {alice_ids}")
    print(f"Bob sees: {bob_ids}")

    # If the bug exists:
    #   Alice sees: ['alice-completion', 'bob-completion'] (both!)
    #   Bob sees: [] (nothing!)
    #
    # If fixed:
    #   Alice sees: ['alice-completion']
    #   Bob sees: ['bob-completion']
    assert "alice-completion" in alice_ids, "Alice should see her own completion"
    assert "bob-completion" not in alice_ids, (
        "BUG: Alice can see Bob's completion — PROVIDER_DATA_VAR leaked from worker"
    )
    assert "bob-completion" in bob_ids, "Bob should see his own completion"
    assert "alice-completion" not in bob_ids, "BUG: Bob can see Alice's completion — unexpected cross-contamination"