diff --git a/src/llama_stack/core/task.py b/src/llama_stack/core/task.py new file mode 100644 index 0000000000..e95b767841 --- /dev/null +++ b/src/llama_stack/core/task.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +from collections.abc import Coroutine +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any + +from opentelemetry import context as otel_context + +from llama_stack.core.request_headers import PROVIDER_DATA_VAR + + +@dataclass +class RequestContext: + """Snapshot of request-scoped state for propagation through background queues. + + Background workers are long-lived asyncio tasks whose contextvars are frozen + at creation time. Capturing both the OTel trace context and the provider / + auth data at *enqueue* time and re-activating them per work-item ensures: + + * Each DB write is attributed to the correct request trace (OTel). + * Each DB write is stamped with the correct user identity (PROVIDER_DATA_VAR). + """ + + otel_ctx: otel_context.Context + provider_data: Any + + +def capture_request_context() -> RequestContext: + """Snapshot the current request-scoped context for later use in a worker.""" + return RequestContext( + otel_ctx=otel_context.get_current(), + provider_data=PROVIDER_DATA_VAR.get(), + ) + + +@contextmanager +def activate_request_context(ctx: RequestContext): + """Temporarily restore a previously captured request context. + + Use this in worker loops that run with a detached (empty) context to + attribute work back to the originating request. + """ + otel_token = otel_context.attach(ctx.otel_ctx) + provider_token = PROVIDER_DATA_VAR.set(ctx.provider_data) + try: + yield + finally: + PROVIDER_DATA_VAR.reset(provider_token) + otel_context.detach(otel_token) + + +def create_detached_background_task(coro: Coroutine[Any, Any, Any]) -> asyncio.Task[Any]: + """Create an asyncio task that does not inherit request-scoped context. + + asyncio.create_task copies all contextvars at creation time, which causes + long-lived background workers to permanently inherit the spawning request's + OTel trace and auth identity. This helper temporarily clears both before + creating the task, then immediately restores them so the caller is unaffected. 
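+
+    Example (illustrative sketch; ``worker_loop``, ``queue`` and ``handle`` stand
+    in for any long-lived worker coroutine and its per-item processing):
+
+        async def worker_loop():
+            while True:
+                item = await queue.get()
+                with activate_request_context(item.request_context):
+                    await handle(item)
+
+        task = create_detached_background_task(worker_loop())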
+ """ + otel_token = otel_context.attach(otel_context.Context()) + provider_token = PROVIDER_DATA_VAR.set(None) + try: + task = asyncio.create_task(coro) + finally: + PROVIDER_DATA_VAR.reset(provider_token) + otel_context.detach(otel_token) + return task diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py index 46aafccb6e..a9bc082bb8 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py @@ -12,6 +12,16 @@ from pydantic import BaseModel, TypeAdapter +<<<<<<< HEAD:src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +======= +from llama_stack.core.conversations.validation import CONVERSATION_ID_PATTERN +from llama_stack.core.task import ( + RequestContext, + activate_request_context, + capture_request_context, + create_detached_background_task, +) +>>>>>>> 9b86ce80 (fix: provider_data_var context leak (#5227)):src/llama_stack/providers/inline/agents/builtin/responses/openai_responses.py from llama_stack.log import get_logger from llama_stack.providers.utils.responses.responses_store import ( ResponsesStore, @@ -80,6 +90,17 @@ BACKGROUND_NUM_WORKERS = 10 +<<<<<<< HEAD:src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +======= +@dataclass +class _BackgroundWorkItem: + """Typed queue item that pairs business kwargs with the originating request context.""" + + request_context: RequestContext + kwargs: dict = field(default_factory=dict) + + +>>>>>>> 9b86ce80 (fix: provider_data_var context leak (#5227)):src/llama_stack/providers/inline/agents/builtin/responses/openai_responses.py class OpenAIResponsePreviousResponseWithInputItems(BaseModel): input_items: ListOpenAIResponseInputItem response: OpenAIResponseObject @@ -131,7 +152,11 @@ async def initialize(self) -> None: async def _ensure_workers_started(self) -> None: """Start background workers in the current event loop if not already running.""" for _ in range(BACKGROUND_NUM_WORKERS - len(self._background_worker_tasks)): +<<<<<<< HEAD:src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py task = asyncio.create_task(self._background_worker()) +======= + task = create_detached_background_task(self._background_worker()) +>>>>>>> 9b86ce80 (fix: provider_data_var context leak (#5227)):src/llama_stack/providers/inline/agents/builtin/responses/openai_responses.py self._background_worker_tasks.add(task) task.add_done_callback(self._background_worker_tasks.discard) @@ -144,6 +169,7 @@ async def shutdown(self) -> None: async def _background_worker(self) -> None: """Worker coroutine that pulls items from the queue and processes them.""" while True: +<<<<<<< HEAD:src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py kwargs = await self._background_queue.get() try: await asyncio.wait_for( @@ -155,6 +181,10 @@ async def _background_worker(self) -> None: logger.exception( f"Background response {response_id} timed out after {BACKGROUND_RESPONSE_TIMEOUT_SECONDS}s" ) +======= + item = await self._background_queue.get() + with activate_request_context(item.request_context): +>>>>>>> 9b86ce80 (fix: provider_data_var context leak (#5227)):src/llama_stack/providers/inline/agents/builtin/responses/openai_responses.py try: existing = await 
@@ -812,32 +838,36 @@ async def _create_background_response(
         # Enqueue work item for background workers. Raises QueueFull if at capacity.
         try:
             self._background_queue.put_nowait(
-                dict(
-                    response_id=response_id,
-                    input=input,
-                    model=model,
-                    prompt=prompt,
-                    instructions=instructions,
-                    previous_response_id=previous_response_id,
-                    conversation=conversation,
-                    store=store,
-                    temperature=temperature,
-                    frequency_penalty=frequency_penalty,
-                    text=text,
-                    tool_choice=tool_choice,
-                    tools=tools,
-                    include=include,
-                    max_infer_iters=max_infer_iters,
-                    guardrail_ids=guardrail_ids,
-                    parallel_tool_calls=parallel_tool_calls,
-                    max_tool_calls=max_tool_calls,
-                    reasoning=reasoning,
-                    max_output_tokens=max_output_tokens,
-                    safety_identifier=safety_identifier,
-                    service_tier=service_tier,
-                    metadata=metadata,
-                    truncation=truncation,
-                    presence_penalty=presence_penalty,
-                    extra_body=extra_body,
+                _BackgroundWorkItem(
+                    request_context=capture_request_context(),
+                    kwargs=dict(
+                        response_id=response_id,
+                        input=input,
+                        model=model,
+                        prompt=prompt,
+                        instructions=instructions,
+                        previous_response_id=previous_response_id,
+                        conversation=conversation,
+                        store=store,
+                        temperature=temperature,
+                        frequency_penalty=frequency_penalty,
+                        text=text,
+                        tool_choice=tool_choice,
+                        tools=tools,
+                        include=include,
+                        max_infer_iters=max_infer_iters,
+                        guardrail_ids=guardrail_ids,
+                        parallel_tool_calls=parallel_tool_calls,
+                        max_tool_calls=max_tool_calls,
+                        reasoning=reasoning,
+                        max_output_tokens=max_output_tokens,
+                        safety_identifier=safety_identifier,
+                        service_tier=service_tier,
+                        metadata=metadata,
+                        truncation=truncation,
+                        presence_penalty=presence_penalty,
+                        extra_body=extra_body,
+                    ),
                 )
             )
         except asyncio.QueueFull:
diff --git a/src/llama_stack/providers/utils/inference/inference_store.py b/src/llama_stack/providers/utils/inference/inference_store.py
index 78327573b6..663ea8cdbd 100644
--- a/src/llama_stack/providers/utils/inference/inference_store.py
+++ b/src/llama_stack/providers/utils/inference/inference_store.py
@@ -12,6 +12,12 @@
 from llama_stack.core.storage.datatypes import InferenceStoreReference, StorageBackendType
 from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
 from llama_stack.core.storage.sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl
+from llama_stack.core.task import (
+    RequestContext,
+    activate_request_context,
+    capture_request_context,
+    create_detached_background_task,
+)
 from llama_stack.log import get_logger
 from llama_stack_api import (
     ListOpenAIChatCompletionResponse,
@@ -25,6 +31,13 @@
 logger = get_logger(name=__name__, category="inference")
 
 
+class _WriteItem(NamedTuple):
+    completion: OpenAIChatCompletion
+    messages: list[OpenAIMessageParam]
+    request_context: RequestContext
+
+
 class InferenceStore:
     def __init__(
         self,
@@ -100,7 +113,6 @@ async def _ensure_workers_started(self) -> None:
         if not self._worker_tasks:
-            loop = asyncio.get_running_loop()
             for _ in range(self._num_writers):
-                task = loop.create_task(self._worker_loop())
+                task = create_detached_background_task(self._worker_loop())
                 self._worker_tasks.append(task)
 
     async def store_chat_completion(
@@ -110,6 +122,7 @@
         await self._ensure_workers_started()
         if self._queue is None:
             raise ValueError("Inference store is not initialized")
+        item = _WriteItem(chat_completion, input_messages, capture_request_context())
         try:
-            self._queue.put_nowait((chat_completion, input_messages))
+            self._queue.put_nowait(item)
         except asyncio.QueueFull:
@@ -129,7 +142,7 @@ async def _worker_loop(self) -> None:
                 break
-            chat_completion, input_messages = item
             try:
-                await self._write_chat_completion(chat_completion, input_messages)
+                with activate_request_context(item.request_context):
+                    await self._write_chat_completion(item.completion, item.messages)
             except Exception as e:  # noqa: BLE001
                 logger.error(f"Error writing chat completion: {e}")
             finally:
diff --git a/tests/unit/core/test_task.py b/tests/unit/core/test_task.py
new file mode 100644
index 0000000000..e175c59183
--- /dev/null
+++ b/tests/unit/core/test_task.py
@@ -0,0 +1,297 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+
+from opentelemetry import context as otel_context
+from opentelemetry import trace
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult
+
+from llama_stack.core.request_headers import PROVIDER_DATA_VAR
+from llama_stack.core.task import (
+    RequestContext,
+    activate_request_context,
+    capture_request_context,
+    create_detached_background_task,
+)
+
+
+class _CollectingExporter(SpanExporter):
+    """Collects finished spans in memory for test assertions."""
+
+    def __init__(self):
+        self.spans = []
+
+    def export(self, spans):
+        self.spans.extend(spans)
+        return SpanExportResult.SUCCESS
+
+
+async def test_detached_task_runs_coroutine():
+    """The helper creates a task that actually runs the coroutine to completion."""
+    result = []
+
+    async def work():
+        result.append("done")
+
+    task = create_detached_background_task(work())
+    await task
+    assert result == ["done"]
+
+
+async def test_detached_task_clears_otel_context():
+    """The task should run with an empty OTel context, not the parent's."""
+    provider = TracerProvider()
+    tracer = provider.get_tracer("test")
+
+    captured_span = {}
+
+    async def capture_context():
+        captured_span["inner"] = trace.get_current_span()
+
+    with tracer.start_as_current_span("parent-span"):
+        parent_ctx = otel_context.get_current()
+        parent_span = trace.get_current_span()
+
+        task = create_detached_background_task(capture_context())
+        await task
+
+        assert not captured_span["inner"].is_recording()
+        assert parent_span.is_recording()
+        assert otel_context.get_current() == parent_ctx
+
+
+async def test_detached_task_clears_provider_data():
+    """The task should run with PROVIDER_DATA_VAR cleared."""
+    captured = {}
+    token = PROVIDER_DATA_VAR.set({"__authenticated_user": "alice"})
+
+    async def capture_provider():
+        captured["value"] = PROVIDER_DATA_VAR.get()
+
+    try:
+        task = create_detached_background_task(capture_provider())
+        await task
+
+        assert captured["value"] is None, "Background task should not inherit PROVIDER_DATA_VAR"
+        assert PROVIDER_DATA_VAR.get() == {"__authenticated_user": "alice"}, "Caller's context should be unaffected"
+    finally:
+        PROVIDER_DATA_VAR.reset(token)
+
+
+async def test_detached_task_restores_caller_context():
+    """The calling coroutine's context is not affected by creating a detached task."""
+    provider = TracerProvider()
+    tracer = provider.get_tracer("test")
+
+    token = PROVIDER_DATA_VAR.set({"__authenticated_user": "bob"})
+    try:
+        with 
tracer.start_as_current_span("parent-span"): + otel_before = otel_context.get_current() + provider_before = PROVIDER_DATA_VAR.get() + + create_detached_background_task(asyncio.sleep(0)) + + assert otel_context.get_current() == otel_before + assert PROVIDER_DATA_VAR.get() == provider_before + finally: + PROVIDER_DATA_VAR.reset(token) + + +async def test_detached_task_produces_independent_trace(): + """Spans created inside a detached task belong to a separate trace.""" + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + async def background_work(): + with tracer.start_as_current_span("background-db-write"): + await asyncio.sleep(0) + + with tracer.start_as_current_span("http-request"): + task = create_detached_background_task(background_work()) + await task + + provider.force_flush() + span_by_name = {s.name: s for s in exporter.spans} + + request_span = span_by_name["http-request"] + bg_span = span_by_name["background-db-write"] + + assert request_span.context.trace_id != bg_span.context.trace_id + assert bg_span.parent is None + + +async def test_normal_child_task_shares_trace(): + """Contrast: a regular asyncio.create_task DOES inherit the parent trace.""" + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + async def child_work(): + with tracer.start_as_current_span("child-span"): + await asyncio.sleep(0) + + with tracer.start_as_current_span("parent-request"): + task = asyncio.create_task(child_work()) + await task + + provider.force_flush() + span_by_name = {s.name: s for s in exporter.spans} + + parent_span = span_by_name["parent-request"] + child_span = span_by_name["child-span"] + + assert parent_span.context.trace_id == child_span.context.trace_id, ( + "Regular create_task should share the parent's trace" + ) + + +async def test_context_through_queue_pattern(): + """End-to-end: context captured at enqueue time is correctly attached in a detached worker. + + This simulates the inference_store pattern: + 1. Request creates a span and enqueues work with captured context + 2. Worker runs in a detached (empty) context + 3. Worker attaches the captured context before processing + 4. 
The resulting span belongs to the original request's trace + """ + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + queue: asyncio.Queue[tuple[str, RequestContext]] = asyncio.Queue() + + async def worker(): + item, ctx = await queue.get() + with activate_request_context(ctx): + with tracer.start_as_current_span(f"db-write-{item}"): + await asyncio.sleep(0) + queue.task_done() + + token = PROVIDER_DATA_VAR.set({"user": "A"}) + try: + with tracer.start_as_current_span("http-request-A"): + ctx_a = capture_request_context() + await queue.put(("A", ctx_a)) + finally: + PROVIDER_DATA_VAR.reset(token) + + worker_task = create_detached_background_task(worker()) + await worker_task + await queue.join() + + provider.force_flush() + span_by_name = {s.name: s for s in exporter.spans} + + request_span = span_by_name["http-request-A"] + write_span = span_by_name["db-write-A"] + + assert request_span.context.trace_id == write_span.context.trace_id, ( + "DB write should belong to the same trace as the originating request" + ) + + +async def test_capture_and_activate_request_context(): + """capture_request_context snapshots both OTel and provider data; activate restores both.""" + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + token = PROVIDER_DATA_VAR.set({"__authenticated_user": "charlie"}) + try: + with tracer.start_as_current_span("request"): + ctx = capture_request_context() + request_trace_id = trace.get_current_span().get_span_context().trace_id + + assert isinstance(ctx, RequestContext) + assert ctx.provider_data == {"__authenticated_user": "charlie"} + + # After span ends, activate context and verify OTel trace is restored + with activate_request_context(ctx): + with tracer.start_as_current_span("reattached-work"): + reattached_trace_id = trace.get_current_span().get_span_context().trace_id + assert PROVIDER_DATA_VAR.get() == {"__authenticated_user": "charlie"} + + assert request_trace_id == reattached_trace_id + finally: + PROVIDER_DATA_VAR.reset(token) + + +async def test_activate_restores_on_exit(): + """activate_request_context restores the previous context when the block exits.""" + provider = TracerProvider() + tracer = provider.get_tracer("test") + + token = PROVIDER_DATA_VAR.set({"__authenticated_user": "outer_user"}) + try: + with tracer.start_as_current_span("outer"): + outer_otel = otel_context.get_current() + + inner_ctx = RequestContext( + otel_ctx=otel_context.Context(), + provider_data={"__authenticated_user": "inner_user"}, + ) + with activate_request_context(inner_ctx): + assert PROVIDER_DATA_VAR.get() == {"__authenticated_user": "inner_user"} + + assert PROVIDER_DATA_VAR.get() == {"__authenticated_user": "outer_user"} + assert otel_context.get_current() == outer_otel + finally: + PROVIDER_DATA_VAR.reset(token) + + +async def test_context_through_queue_no_cross_contamination(): + """Two requests enqueue work; each item's context is correctly propagated.""" + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + queue: asyncio.Queue[tuple[str, RequestContext]] = asyncio.Queue() + processed = asyncio.Event() + + async def worker(): + for _ in range(2): + label, ctx = await queue.get() + with activate_request_context(ctx): + assert 
PROVIDER_DATA_VAR.get() == {"user": label} + with tracer.start_as_current_span(f"db-write-{label}"): + await asyncio.sleep(0) + queue.task_done() + processed.set() + + worker_task = create_detached_background_task(worker()) + + token_a = PROVIDER_DATA_VAR.set({"user": "A"}) + with tracer.start_as_current_span("request-A"): + await queue.put(("A", capture_request_context())) + PROVIDER_DATA_VAR.reset(token_a) + + token_b = PROVIDER_DATA_VAR.set({"user": "B"}) + with tracer.start_as_current_span("request-B"): + await queue.put(("B", capture_request_context())) + PROVIDER_DATA_VAR.reset(token_b) + + await processed.wait() + await worker_task + + provider.force_flush() + span_by_name = {s.name: s for s in exporter.spans} + + request_a = span_by_name["request-A"] + request_b = span_by_name["request-B"] + write_a = span_by_name["db-write-A"] + write_b = span_by_name["db-write-B"] + + assert write_a.context.trace_id == request_a.context.trace_id + assert write_b.context.trace_id == request_b.context.trace_id + assert request_a.context.trace_id != request_b.context.trace_id diff --git a/tests/unit/providers/agents/builtin/test_responses_background.py b/tests/unit/providers/agents/builtin/test_responses_background.py new file mode 100644 index 0000000000..40d2a1c3a1 --- /dev/null +++ b/tests/unit/providers/agents/builtin/test_responses_background.py @@ -0,0 +1,480 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Unit tests for background parameter support in Responses API.""" + +import asyncio +from unittest.mock import AsyncMock, patch + +import pytest +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult + +from llama_stack.core.datatypes import User +from llama_stack.core.request_headers import PROVIDER_DATA_VAR, get_authenticated_user +from llama_stack.core.task import capture_request_context, create_detached_background_task +from llama_stack.providers.inline.agents.builtin.responses.openai_responses import ( + OpenAIResponsesImpl, + _BackgroundWorkItem, +) +from llama_stack_api import OpenAIResponseError, OpenAIResponseObject + + +class TestBackgroundFieldInResponseObject: + """Test that the background field is properly defined in OpenAIResponseObject.""" + + def test_background_field_default_is_none(self): + """Verify background field defaults to None.""" + response = OpenAIResponseObject( + id="resp_123", + created_at=1234567890, + model="test-model", + status="completed", + output=[], + store=True, + ) + assert response.background is None + + def test_background_field_can_be_true(self): + """Verify background field can be set to True.""" + response = OpenAIResponseObject( + id="resp_123", + created_at=1234567890, + model="test-model", + status="queued", + output=[], + background=True, + store=True, + ) + assert response.background is True + + def test_background_field_can_be_false(self): + """Verify background field can be False.""" + response = OpenAIResponseObject( + id="resp_123", + created_at=1234567890, + model="test-model", + status="completed", + output=[], + background=False, + store=True, + ) + assert response.background is False + + +class TestResponseStatus: + """Test that all expected status values work correctly.""" + + @pytest.mark.parametrize( + "status", + ["queued", "in_progress", 
"completed", "failed", "incomplete"], + ) + def test_valid_status_values(self, status): + """Verify all OpenAI-compatible status values are accepted.""" + response = OpenAIResponseObject( + id="resp_123", + created_at=1234567890, + model="test-model", + status=status, + output=[], + background=True if status in ("queued", "in_progress") else False, + store=True, + ) + assert response.status == status + + def test_queued_status_with_background(self): + """Verify queued status is typically used with background=True.""" + response = OpenAIResponseObject( + id="resp_123", + created_at=1234567890, + model="test-model", + status="queued", + output=[], + background=True, + store=True, + ) + assert response.status == "queued" + assert response.background is True + + +class TestResponseObjectSerialization: + """Test that the response object serializes correctly with background field.""" + + def test_model_dump_includes_background(self): + """Verify model_dump includes the background field.""" + response = OpenAIResponseObject( + id="resp_123", + created_at=1234567890, + model="test-model", + status="queued", + output=[], + background=True, + store=True, + ) + data = response.model_dump() + assert "background" in data + assert data["background"] is True + + def test_model_dump_json_includes_background(self): + """Verify JSON serialization includes the background field.""" + response = OpenAIResponseObject( + id="resp_123", + created_at=1234567890, + model="test-model", + status="completed", + output=[], + background=False, + store=True, + ) + json_str = response.model_dump_json() + assert '"background":false' in json_str or '"background": false' in json_str + + +class TestResponseErrorForBackground: + """Test error responses for background processing failures.""" + + def test_error_response_with_background(self): + """Verify error responses can include background field.""" + error = OpenAIResponseError( + code="processing_error", + message="Background processing failed", + ) + response = OpenAIResponseObject( + id="resp_123", + created_at=1234567890, + model="test-model", + status="failed", + output=[], + background=True, + error=error, + store=True, + ) + assert response.status == "failed" + assert response.background is True + assert response.error is not None + assert response.error.code == "processing_error" + + +def _make_responses_impl(): + """Create an OpenAIResponsesImpl with all dependencies mocked.""" + return OpenAIResponsesImpl( + inference_api=AsyncMock(), + tool_groups_api=AsyncMock(), + tool_runtime_api=AsyncMock(), + responses_store=AsyncMock(), + vector_io_api=AsyncMock(), + safety_api=None, + conversations_api=AsyncMock(), + prompts_api=AsyncMock(), + files_api=AsyncMock(), + connectors_api=AsyncMock(), + ) + + +class _CollectingExporter(SpanExporter): + """Collects finished spans in memory for test assertions.""" + + def __init__(self): + self.spans = [] + + def export(self, spans): + self.spans.extend(spans) + return SpanExportResult.SUCCESS + + +class TestResponsesOtelContextPropagation: + """Verify that OTel trace context flows correctly through the background worker queue.""" + + async def test_worker_attributes_work_to_correct_request_trace(self): + """Each queued response is processed under its originating request's trace context.""" + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + impl = _make_responses_impl() + + async def mock_response_loop(**kwargs): + 
with tracer.start_as_current_span(f"process-{kwargs['response_id']}"): + await asyncio.sleep(0) + + with patch.object(impl, "_run_background_response_loop", side_effect=mock_response_loop): + worker_task = create_detached_background_task(impl._background_worker()) + + with tracer.start_as_current_span("request-A"): + impl._background_queue.put_nowait( + _BackgroundWorkItem(request_context=capture_request_context(), kwargs=dict(response_id="resp-A")) + ) + + with tracer.start_as_current_span("request-B"): + impl._background_queue.put_nowait( + _BackgroundWorkItem(request_context=capture_request_context(), kwargs=dict(response_id="resp-B")) + ) + + await impl._background_queue.join() + worker_task.cancel() + try: + await worker_task + except asyncio.CancelledError: + pass + + provider.force_flush() + spans_by_name = {s.name: s for s in exporter.spans} + + request_a_trace = spans_by_name["request-A"].context.trace_id + request_b_trace = spans_by_name["request-B"].context.trace_id + process_a_trace = spans_by_name["process-resp-A"].context.trace_id + process_b_trace = spans_by_name["process-resp-B"].context.trace_id + + assert request_a_trace != request_b_trace, "Requests should have distinct traces" + assert process_a_trace == request_a_trace, "Response processing for resp-A should be in request-A's trace" + assert process_b_trace == request_b_trace, "Response processing for resp-B should be in request-B's trace" + + async def test_worker_does_not_leak_context_between_items(self): + """After processing one item, the worker returns to a clean OTel context.""" + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + impl = _make_responses_impl() + trace_ids_during_processing = {} + + async def mock_response_loop(**kwargs): + rid = kwargs["response_id"] + span_ctx = trace.get_current_span().get_span_context() + trace_ids_during_processing[rid] = span_ctx.trace_id if span_ctx.trace_id != 0 else None + with tracer.start_as_current_span(f"work-{rid}"): + await asyncio.sleep(0) + + with patch.object(impl, "_run_background_response_loop", side_effect=mock_response_loop): + worker_task = create_detached_background_task(impl._background_worker()) + + with tracer.start_as_current_span("req-1"): + impl._background_queue.put_nowait( + _BackgroundWorkItem(request_context=capture_request_context(), kwargs=dict(response_id="r1")) + ) + + with tracer.start_as_current_span("req-2"): + impl._background_queue.put_nowait( + _BackgroundWorkItem(request_context=capture_request_context(), kwargs=dict(response_id="r2")) + ) + + await impl._background_queue.join() + worker_task.cancel() + try: + await worker_task + except asyncio.CancelledError: + pass + + provider.force_flush() + spans_by_name = {s.name: s for s in exporter.spans} + + req1_trace = spans_by_name["req-1"].context.trace_id + req2_trace = spans_by_name["req-2"].context.trace_id + + assert trace_ids_during_processing["r1"] is not None, "r1 should have a trace context" + assert trace_ids_during_processing["r2"] is not None, "r2 should have a trace context" + assert trace_ids_during_processing["r1"] == req1_trace + assert trace_ids_during_processing["r2"] == req2_trace + + async def test_error_handling_runs_under_request_context(self): + """When processing fails, the error handler's DB writes are also in the request's trace.""" + exporter = _CollectingExporter() + provider = TracerProvider() + 
provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + impl = _make_responses_impl() + + mock_response = OpenAIResponseObject( + id="resp-err", + created_at=1234567890, + model="test-model", + status="in_progress", + output=[], + store=True, + ) + impl.responses_store.get_response_object = AsyncMock(return_value=mock_response) + impl.responses_store.update_response_object = AsyncMock() + + error_update_trace_ids = [] + original_update = impl.responses_store.update_response_object + + async def tracking_update(obj): + span_ctx = trace.get_current_span().get_span_context() + if span_ctx.trace_id != 0: + error_update_trace_ids.append(span_ctx.trace_id) + return await original_update(obj) + + impl.responses_store.update_response_object = tracking_update + + async def failing_loop(**kwargs): + raise RuntimeError("simulated failure") + + with patch.object(impl, "_run_background_response_loop", side_effect=failing_loop): + worker_task = create_detached_background_task(impl._background_worker()) + + with tracer.start_as_current_span("failing-request"): + request_trace = trace.get_current_span().get_span_context().trace_id + impl._background_queue.put_nowait( + _BackgroundWorkItem(request_context=capture_request_context(), kwargs=dict(response_id="resp-err")) + ) + + await impl._background_queue.join() + worker_task.cancel() + try: + await worker_task + except asyncio.CancelledError: + pass + + assert len(error_update_trace_ids) > 0, "Error handler should have made DB updates" + for tid in error_update_trace_ids: + assert tid == request_trace, "Error handler DB writes should be in the failing request's trace" + + +def _set_authenticated_user(user: User | None): + """Simulate what ProviderDataMiddleware does for each request.""" + if user: + PROVIDER_DATA_VAR.set({"__authenticated_user": user}) + else: + PROVIDER_DATA_VAR.set(None) + + +class TestResponsesProviderDataPropagation: + """Verify that PROVIDER_DATA_VAR flows correctly through the background worker queue. + + The responses worker processes the full response loop (LLM calls, tool execution, + DB writes). All operations inside the worker must run with the originating + request's auth identity, not whichever request first spawned the worker. 
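+
+    The pattern under test, roughly (a sketch using the names from the
+    production code above):
+
+        item = _BackgroundWorkItem(request_context=capture_request_context(), kwargs={...})
+        self._background_queue.put_nowait(item)  # captured under the request's context
+        ...
+        # in the detached worker, per item:
+        with activate_request_context(item.request_context):
+            await self._run_background_response_loop(**item.kwargs)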
+ """ + + async def test_worker_runs_under_correct_user_identity(self): + """Each queued response is processed under its originating user's identity.""" + impl = _make_responses_impl() + + alice = User(principal="alice", attributes={"roles": ["user"]}) + bob = User(principal="bob", attributes={"roles": ["user"]}) + + observed_users: dict[str, User | None] = {} + + async def mock_response_loop(**kwargs): + observed_users[kwargs["response_id"]] = get_authenticated_user() + + with patch.object(impl, "_run_background_response_loop", side_effect=mock_response_loop): + worker_task = create_detached_background_task(impl._background_worker()) + + _set_authenticated_user(alice) + impl._background_queue.put_nowait( + _BackgroundWorkItem(request_context=capture_request_context(), kwargs=dict(response_id="resp-alice")) + ) + + _set_authenticated_user(bob) + impl._background_queue.put_nowait( + _BackgroundWorkItem(request_context=capture_request_context(), kwargs=dict(response_id="resp-bob")) + ) + + await impl._background_queue.join() + worker_task.cancel() + try: + await worker_task + except asyncio.CancelledError: + pass + + _set_authenticated_user(None) + + assert observed_users["resp-alice"] is not None, "Alice's request should have a user" + assert observed_users["resp-bob"] is not None, "Bob's request should have a user" + assert observed_users["resp-alice"].principal == "alice", "Alice's response should run as alice" + assert observed_users["resp-bob"].principal == "bob", "Bob's response should run as bob" + + async def test_worker_does_not_leak_identity_between_items(self): + """After processing one item, the worker returns to a clean state.""" + impl = _make_responses_impl() + + alice = User(principal="alice", attributes={"roles": ["user"]}) + + user_after_processing: list[User | None] = [] + + async def mock_response_loop(**kwargs): + user_after_processing.append(get_authenticated_user()) + + with patch.object(impl, "_run_background_response_loop", side_effect=mock_response_loop): + worker_task = create_detached_background_task(impl._background_worker()) + + # First item: enqueued by Alice + _set_authenticated_user(alice) + impl._background_queue.put_nowait( + _BackgroundWorkItem(request_context=capture_request_context(), kwargs=dict(response_id="resp-1")) + ) + + # Second item: enqueued with no user (anonymous) + _set_authenticated_user(None) + impl._background_queue.put_nowait( + _BackgroundWorkItem(request_context=capture_request_context(), kwargs=dict(response_id="resp-2")) + ) + + await impl._background_queue.join() + worker_task.cancel() + try: + await worker_task + except asyncio.CancelledError: + pass + + assert user_after_processing[0] is not None, "First item should run as alice" + assert user_after_processing[0].principal == "alice" + assert user_after_processing[1] is None, "Second item should run as anonymous — alice's identity must not leak" + + async def test_error_handler_runs_under_correct_identity(self): + """When processing fails, error-handling DB writes use the correct user.""" + impl = _make_responses_impl() + + bob = User(principal="bob", attributes={"roles": ["user"]}) + + mock_response = OpenAIResponseObject( + id="resp-err", + created_at=1234567890, + model="test-model", + status="in_progress", + output=[], + store=True, + ) + impl.responses_store.get_response_object = AsyncMock(return_value=mock_response) + + error_handler_users: list[User | None] = [] + original_update = impl.responses_store.update_response_object + + async def tracking_update(obj): + 
error_handler_users.append(get_authenticated_user()) + return await original_update(obj) + + impl.responses_store.update_response_object = tracking_update + + async def failing_loop(**kwargs): + raise RuntimeError("simulated failure") + + with patch.object(impl, "_run_background_response_loop", side_effect=failing_loop): + worker_task = create_detached_background_task(impl._background_worker()) + + _set_authenticated_user(bob) + impl._background_queue.put_nowait( + _BackgroundWorkItem(request_context=capture_request_context(), kwargs=dict(response_id="resp-err")) + ) + + await impl._background_queue.join() + worker_task.cancel() + try: + await worker_task + except asyncio.CancelledError: + pass + + _set_authenticated_user(None) + + assert len(error_handler_users) > 0, "Error handler should have made DB updates" + for user in error_handler_users: + assert user is not None, "Error handler should have a user identity" + assert user.principal == "bob", "Error handler should run as bob, not the worker's inherited identity" diff --git a/tests/unit/utils/inference/test_provider_data_leak.py b/tests/unit/utils/inference/test_provider_data_leak.py new file mode 100644 index 0000000000..ac91ec03dc --- /dev/null +++ b/tests/unit/utils/inference/test_provider_data_leak.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Reproduces the PROVIDER_DATA_VAR contextvar leak through background worker tasks. + +The inference store uses a write queue with long-lived worker tasks (on Postgres). +asyncio.create_task copies all contextvars at creation time, so the worker +permanently inherits the first request's PROVIDER_DATA_VAR. This means every +DB write is stamped with the first user's identity, regardless of who actually +made the request. + +This test forces the write queue on (normally disabled for SQLite) to demonstrate +the leak without needing a Postgres instance. 
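+
+A minimal sketch of the underlying asyncio behavior (plain contextvars, no
+llama-stack machinery; the names here are illustrative):
+
+    import asyncio
+    import contextvars
+
+    var = contextvars.ContextVar("user", default=None)
+
+    async def worker():
+        return var.get()  # frozen: create_task copied the context
+
+    async def main():
+        var.set("alice")
+        task = asyncio.create_task(worker())  # context snapshot happens here
+        var.set("bob")  # too late: the task still sees "alice"
+        assert await task == "alice"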
+""" + +import time + +import pytest + +from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope +from llama_stack.core.datatypes import User +from llama_stack.core.request_headers import PROVIDER_DATA_VAR +from llama_stack.core.storage.datatypes import InferenceStoreReference, SqliteSqlStoreConfig +from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends +from llama_stack.providers.utils.inference.inference_store import InferenceStore +from llama_stack_api import ( + OpenAIChatCompletion, + OpenAIChatCompletionResponseMessage, + OpenAIChoice, + OpenAIUserMessageParam, +) + + +@pytest.fixture(autouse=True) +def setup_backends(tmp_path): + db_path = str(tmp_path / "test_leak.db") + register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path=db_path)}) + + +def _set_authenticated_user(user: User | None): + """Simulate what ProviderDataMiddleware does for each request.""" + if user: + PROVIDER_DATA_VAR.set({"__authenticated_user": user}) + else: + PROVIDER_DATA_VAR.set(None) + + +def _make_completion(completion_id: str, created: int) -> OpenAIChatCompletion: + return OpenAIChatCompletion( + id=completion_id, + created=created, + model="test-model", + object="chat.completion", + choices=[ + OpenAIChoice( + index=0, + message=OpenAIChatCompletionResponseMessage( + role="assistant", + content=f"Response for {completion_id}", + ), + finish_reason="stop", + ) + ], + ) + + +async def test_provider_data_leak_through_write_queue(): + """Demonstrates that PROVIDER_DATA_VAR leaks into background workers. + + Expected behavior: each completion should be owned by the user who created it. + Actual behavior: all completions are owned by whoever triggered worker creation. + """ + owner_policy = [ + AccessRule(permit=Scope(actions=[Action.READ]), when=["user is owner"]), + AccessRule(permit=Scope(actions=[Action.CREATE]), when=[]), + ] + + reference = InferenceStoreReference( + backend="sql_default", + table_name="leak_test", + num_writers=1, + ) + store = InferenceStore(reference, policy=owner_policy) + await store.initialize() + + # Force the write queue on (normally disabled for SQLite) + store.enable_write_queue = True + + alice = User(principal="alice", attributes={"roles": ["user"]}) + bob = User(principal="bob", attributes={"roles": ["user"]}) + + base_time = int(time.time()) + + # --- Request 1: Alice creates a completion --- + # This is the first request, so it spawns the background worker. + # The worker inherits Alice's PROVIDER_DATA_VAR permanently. + _set_authenticated_user(alice) + await store.store_chat_completion( + _make_completion("alice-completion", base_time + 1), + [OpenAIUserMessageParam(role="user", content="Hello from Alice")], + ) + await store.flush() + + # --- Request 2: Bob creates a completion --- + # The worker is already running with Alice's context. + # Bob's write goes through the queue but is processed under Alice's identity. + _set_authenticated_user(bob) + await store.store_chat_completion( + _make_completion("bob-completion", base_time + 2), + [OpenAIUserMessageParam(role="user", content="Hello from Bob")], + ) + await store.flush() + + # --- Now verify: can each user see only their own completions? 
--- + + # Alice should see 1 completion (her own) + _set_authenticated_user(alice) + alice_results = await store.list_chat_completions() + + # Bob should see 1 completion (his own) + _set_authenticated_user(bob) + bob_results = await store.list_chat_completions() + + await store.shutdown() + + # --- Assertions --- + alice_ids = [c.id for c in alice_results.data] + bob_ids = [c.id for c in bob_results.data] + + print(f"\nAlice sees: {alice_ids}") + print(f"Bob sees: {bob_ids}") + + # If the bug exists: + # Alice sees: ['alice-completion', 'bob-completion'] (both!) + # Bob sees: [] (nothing!) + # + # If fixed: + # Alice sees: ['alice-completion'] + # Bob sees: ['bob-completion'] + + assert "alice-completion" in alice_ids, "Alice should see her own completion" + assert "bob-completion" not in alice_ids, ( + "BUG: Alice can see Bob's completion — PROVIDER_DATA_VAR leaked from worker" + ) + assert "bob-completion" in bob_ids, "Bob should see his own completion" + assert "alice-completion" not in bob_ids, "BUG: Bob can see Alice's completion — unexpected cross-contamination"