diff --git a/src/llama_stack/core/task.py b/src/llama_stack/core/task.py new file mode 100644 index 0000000000..cdcd2c655e --- /dev/null +++ b/src/llama_stack/core/task.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +from collections.abc import Coroutine +from contextlib import contextmanager +from typing import Any + +from opentelemetry import context as otel_context + + +def create_task_with_detached_otel_context(coro: Coroutine[Any, Any, Any]) -> asyncio.Task[Any]: + """Create an asyncio task that does not inherit the current OpenTelemetry trace context. + + asyncio.create_task copies all contextvars at creation time, which causes + fire-and-forget or long-lived background tasks to be attributed to whatever + request happened to spawn them. This inflates trace durations and bundles + unrelated DB operations under the wrong trace. + + This helper temporarily clears the OTel context before creating the task, + then immediately restores it so the calling coroutine is unaffected. + """ + token = otel_context.attach(otel_context.Context()) + try: + task = asyncio.create_task(coro) + finally: + otel_context.detach(token) + return task + + +def capture_otel_context() -> otel_context.Context: + """Snapshot the current OTel context for later use in a different task.""" + return otel_context.get_current() + + +@contextmanager +def activate_otel_context(ctx: otel_context.Context): + """Temporarily activate a previously captured OTel context. + + Use this in worker loops that run with a detached (empty) context to + attribute work back to the originating request's trace. 
+ """ + token = otel_context.attach(ctx) + try: + yield + finally: + otel_context.detach(token) diff --git a/src/llama_stack/providers/inline/agents/builtin/responses/openai_responses.py b/src/llama_stack/providers/inline/agents/builtin/responses/openai_responses.py index 3e01d7a8ab..8ac9c4f5c5 100644 --- a/src/llama_stack/providers/inline/agents/builtin/responses/openai_responses.py +++ b/src/llama_stack/providers/inline/agents/builtin/responses/openai_responses.py @@ -9,10 +9,13 @@ import time import uuid from collections.abc import AsyncIterator +from dataclasses import dataclass, field +from opentelemetry import context as otel_context from pydantic import BaseModel, TypeAdapter from llama_stack.core.conversations.validation import CONVERSATION_ID_PATTERN +from llama_stack.core.task import activate_otel_context, capture_otel_context, create_task_with_detached_otel_context from llama_stack.log import get_logger from llama_stack.providers.utils.responses.responses_store import ( ResponsesStore, @@ -82,6 +85,14 @@ BACKGROUND_NUM_WORKERS = 10 +@dataclass +class _BackgroundWorkItem: + """Typed queue item for background response processing.""" + + otel_context: otel_context.Context + kwargs: dict = field(default_factory=dict) + + class OpenAIResponsePreviousResponseWithInputItems(BaseModel): input_items: ListOpenAIResponseInputItem response: OpenAIResponseObject @@ -118,7 +129,7 @@ def __init__( self.prompts_api = prompts_api self.files_api = files_api self.connectors_api = connectors_api - self._background_queue: asyncio.Queue = asyncio.Queue(maxsize=BACKGROUND_QUEUE_MAX_SIZE) + self._background_queue: asyncio.Queue[_BackgroundWorkItem] = asyncio.Queue(maxsize=BACKGROUND_QUEUE_MAX_SIZE) self._background_worker_tasks: set[asyncio.Task] = set() async def initialize(self) -> None: @@ -133,7 +144,7 @@ async def initialize(self) -> None: async def _ensure_workers_started(self) -> None: """Start background workers in the current event loop if not already running.""" for _ 
in range(BACKGROUND_NUM_WORKERS - len(self._background_worker_tasks)): - task = asyncio.create_task(self._background_worker()) + task = create_task_with_detached_otel_context(self._background_worker()) self._background_worker_tasks.add(task) task.add_done_callback(self._background_worker_tasks.discard) @@ -146,48 +157,49 @@ async def shutdown(self) -> None: async def _background_worker(self) -> None: """Worker coroutine that pulls items from the queue and processes them.""" while True: - kwargs = await self._background_queue.get() - try: - await asyncio.wait_for( - self._run_background_response_loop(**kwargs), - timeout=BACKGROUND_RESPONSE_TIMEOUT_SECONDS, - ) - except TimeoutError: - response_id = kwargs["response_id"] - logger.exception( - f"Background response {response_id} timed out after {BACKGROUND_RESPONSE_TIMEOUT_SECONDS}s" - ) + item = await self._background_queue.get() + with activate_otel_context(item.otel_context): try: - existing = await self.responses_store.get_response_object(response_id) - existing.status = "failed" - existing.error = OpenAIResponseError( - code="processing_error", - message=f"Background response timed out after {BACKGROUND_RESPONSE_TIMEOUT_SECONDS}s", + await asyncio.wait_for( + self._run_background_response_loop(**item.kwargs), + timeout=BACKGROUND_RESPONSE_TIMEOUT_SECONDS, ) - await self.responses_store.update_response_object(existing) - except Exception: + except TimeoutError: + response_id = item.kwargs["response_id"] logger.exception( - f"Failed to update response {response_id} with timeout status. " - "Client polling this response will not see the failure." 
- ) - except Exception as e: - response_id = kwargs["response_id"] - logger.exception(f"Error processing background response {response_id}") - try: - existing = await self.responses_store.get_response_object(response_id) - existing.status = "failed" - existing.error = OpenAIResponseError( - code="processing_error", - message=str(e), + f"Background response {response_id} timed out after {BACKGROUND_RESPONSE_TIMEOUT_SECONDS}s" ) - await self.responses_store.update_response_object(existing) - except Exception: - logger.exception( - f"Failed to update response {response_id} with error status. " - "Client polling this response will not see the failure." - ) - finally: - self._background_queue.task_done() + try: + existing = await self.responses_store.get_response_object(response_id) + existing.status = "failed" + existing.error = OpenAIResponseError( + code="processing_error", + message=f"Background response timed out after {BACKGROUND_RESPONSE_TIMEOUT_SECONDS}s", + ) + await self.responses_store.update_response_object(existing) + except Exception: + logger.exception( + f"Failed to update response {response_id} with timeout status. " + "Client polling this response will not see the failure." + ) + except Exception as e: + response_id = item.kwargs["response_id"] + logger.exception(f"Error processing background response {response_id}") + try: + existing = await self.responses_store.get_response_object(response_id) + existing.status = "failed" + existing.error = OpenAIResponseError( + code="processing_error", + message=str(e), + ) + await self.responses_store.update_response_object(existing) + except Exception: + logger.exception( + f"Failed to update response {response_id} with error status. " + "Client polling this response will not see the failure." + ) + finally: + self._background_queue.task_done() async def _prepend_previous_response( self, @@ -820,33 +832,36 @@ async def _create_background_response( # Enqueue work item for background workers. 
Raises QueueFull if at capacity. try: self._background_queue.put_nowait( - dict( - response_id=response_id, - input=input, - model=model, - prompt=prompt, - instructions=instructions, - previous_response_id=previous_response_id, - conversation=conversation, - store=store, - temperature=temperature, - frequency_penalty=frequency_penalty, - text=text, - tool_choice=tool_choice, - tools=tools, - include=include, - max_infer_iters=max_infer_iters, - guardrail_ids=guardrail_ids, - parallel_tool_calls=parallel_tool_calls, - max_tool_calls=max_tool_calls, - reasoning=reasoning, - max_output_tokens=max_output_tokens, - safety_identifier=safety_identifier, - service_tier=service_tier, - metadata=metadata, - truncation=truncation, - presence_penalty=presence_penalty, - extra_body=extra_body, + _BackgroundWorkItem( + otel_context=capture_otel_context(), + kwargs=dict( + response_id=response_id, + input=input, + model=model, + prompt=prompt, + instructions=instructions, + previous_response_id=previous_response_id, + conversation=conversation, + store=store, + temperature=temperature, + frequency_penalty=frequency_penalty, + text=text, + tool_choice=tool_choice, + tools=tools, + include=include, + max_infer_iters=max_infer_iters, + guardrail_ids=guardrail_ids, + parallel_tool_calls=parallel_tool_calls, + max_tool_calls=max_tool_calls, + reasoning=reasoning, + max_output_tokens=max_output_tokens, + safety_identifier=safety_identifier, + service_tier=service_tier, + metadata=metadata, + truncation=truncation, + presence_penalty=presence_penalty, + extra_body=extra_body, + ), ) ) except asyncio.QueueFull: diff --git a/src/llama_stack/providers/utils/inference/inference_store.py b/src/llama_stack/providers/utils/inference/inference_store.py index 78327573b6..f96fd4cd31 100644 --- a/src/llama_stack/providers/utils/inference/inference_store.py +++ b/src/llama_stack/providers/utils/inference/inference_store.py @@ -4,14 +4,16 @@ # This source code is licensed under the terms described 
in the LICENSE file in # the root directory of this source tree. import asyncio -from typing import Any +from typing import Any, NamedTuple +from opentelemetry import context as otel_context from sqlalchemy.exc import IntegrityError from llama_stack.core.datatypes import AccessRule from llama_stack.core.storage.datatypes import InferenceStoreReference, StorageBackendType from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore from llama_stack.core.storage.sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl +from llama_stack.core.task import activate_otel_context, capture_otel_context, create_task_with_detached_otel_context from llama_stack.log import get_logger from llama_stack_api import ( ListOpenAIChatCompletionResponse, @@ -25,6 +27,12 @@ logger = get_logger(name=__name__, category="inference") +class _WriteItem(NamedTuple): + completion: OpenAIChatCompletion + messages: list[OpenAIMessageParam] + otel_context: otel_context.Context + + class InferenceStore: def __init__( self, @@ -37,7 +45,7 @@ def __init__( self.enable_write_queue = True # Async write queue and worker control - self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None + self._queue: asyncio.Queue[_WriteItem] | None = None self._worker_tasks: list[asyncio.Task[Any]] = [] self._max_write_queue_size: int = reference.max_write_queue_size self._num_writers: int = max(1, reference.num_writers) @@ -98,9 +106,8 @@ async def _ensure_workers_started(self) -> None: ) if not self._worker_tasks: - loop = asyncio.get_running_loop() for _ in range(self._num_writers): - task = loop.create_task(self._worker_loop()) + task = create_task_with_detached_otel_context(self._worker_loop()) self._worker_tasks.append(task) async def store_chat_completion( @@ -110,13 +117,14 @@ async def store_chat_completion( await self._ensure_workers_started() if self._queue is None: raise ValueError("Inference store is not initialized") + item = 
_WriteItem(chat_completion, input_messages, capture_otel_context()) try: - self._queue.put_nowait((chat_completion, input_messages)) + self._queue.put_nowait(item) except asyncio.QueueFull: logger.warning( f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '')}" ) - await self._queue.put((chat_completion, input_messages)) + await self._queue.put(item) else: await self._write_chat_completion(chat_completion, input_messages) @@ -127,9 +135,9 @@ async def _worker_loop(self) -> None: item = await self._queue.get() except asyncio.CancelledError: break - chat_completion, input_messages = item try: - await self._write_chat_completion(chat_completion, input_messages) + with activate_otel_context(item.otel_context): + await self._write_chat_completion(item.completion, item.messages) except Exception as e: # noqa: BLE001 logger.error(f"Error writing chat completion: {e}") finally: diff --git a/tests/unit/core/test_task.py b/tests/unit/core/test_task.py new file mode 100644 index 0000000000..afb497ec7b --- /dev/null +++ b/tests/unit/core/test_task.py @@ -0,0 +1,250 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import asyncio + +from opentelemetry import context as otel_context +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult + +from llama_stack.core.task import ( + activate_otel_context, + capture_otel_context, + create_task_with_detached_otel_context, +) + + +class _CollectingExporter(SpanExporter): + """Collects finished spans in memory for test assertions.""" + + def __init__(self): + self.spans = [] + + def export(self, spans): + self.spans.extend(spans) + return SpanExportResult.SUCCESS + + +async def test_detached_task_runs_coroutine(): + """The helper creates a task that actually runs the coroutine to completion.""" + result = [] + + async def work(): + result.append("done") + + task = create_task_with_detached_otel_context(work()) + await task + assert result == ["done"] + + +async def test_detached_task_clears_otel_context(): + """The task should run with an empty OTel context, not the parent's.""" + provider = TracerProvider() + tracer = provider.get_tracer("test") + + captured_span = {} + + async def capture_context(): + captured_span["inner"] = trace.get_current_span() + + with tracer.start_as_current_span("parent-span"): + parent_ctx = otel_context.get_current() + parent_span = trace.get_current_span() + + task = create_task_with_detached_otel_context(capture_context()) + await task + + assert not captured_span["inner"].is_recording() + assert parent_span.is_recording() + assert otel_context.get_current() == parent_ctx + + +async def test_detached_task_restores_caller_context(): + """The calling coroutine's OTel context is not affected by creating a detached task.""" + provider = TracerProvider() + tracer = provider.get_tracer("test") + + with tracer.start_as_current_span("parent-span"): + before = otel_context.get_current() + create_task_with_detached_otel_context(asyncio.sleep(0)) + after = otel_context.get_current() + 
assert before == after + + +async def test_detached_task_produces_independent_trace(): + """Spans created inside a detached task belong to a separate trace, not the parent's.""" + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + async def background_work(): + with tracer.start_as_current_span("background-db-write"): + await asyncio.sleep(0) + + with tracer.start_as_current_span("http-request"): + task = create_task_with_detached_otel_context(background_work()) + await task + + provider.force_flush() + span_by_name = {s.name: s for s in exporter.spans} + + request_span = span_by_name["http-request"] + bg_span = span_by_name["background-db-write"] + + assert request_span.context.trace_id != bg_span.context.trace_id, ( + "Background span should belong to a different trace than the request" + ) + assert bg_span.parent is None, "Background span should be a root span with no parent" + + +async def test_normal_child_task_shares_trace(): + """Contrast: a regular asyncio.create_task DOES inherit the parent trace.""" + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + async def child_work(): + with tracer.start_as_current_span("child-span"): + await asyncio.sleep(0) + + with tracer.start_as_current_span("parent-request"): + task = asyncio.create_task(child_work()) + await task + + provider.force_flush() + span_by_name = {s.name: s for s in exporter.spans} + + parent_span = span_by_name["parent-request"] + child_span = span_by_name["child-span"] + + assert parent_span.context.trace_id == child_span.context.trace_id, ( + "Regular create_task should share the parent's trace" + ) + + +async def test_capture_and_attach_otel_context(): + """capture_otel_context snapshots the current context; activate_otel_context re-activates it.""" + exporter = 
_CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + with tracer.start_as_current_span("request"): + ctx = capture_otel_context() + request_trace_id = trace.get_current_span().get_span_context().trace_id + + with activate_otel_context(ctx): + with tracer.start_as_current_span("reattached-work"): + reattached_trace_id = trace.get_current_span().get_span_context().trace_id + + assert request_trace_id == reattached_trace_id, "Work done under attached context should share the original trace" + + +async def test_attached_context_restores_on_exit(): + """activate_otel_context restores the previous context when the block exits.""" + provider = TracerProvider() + tracer = provider.get_tracer("test") + + with tracer.start_as_current_span("outer"): + outer_ctx = otel_context.get_current() + + inner_ctx = otel_context.Context() + with activate_otel_context(inner_ctx): + assert otel_context.get_current() == inner_ctx + + assert otel_context.get_current() == outer_ctx + + +async def test_context_through_queue_pattern(): + """End-to-end: context captured at enqueue time is correctly attached in a detached worker. + + This simulates the inference_store pattern: + 1. Request creates a span and enqueues work with captured context + 2. Worker runs in a detached (empty) context + 3. Worker attaches the captured context before processing + 4. 
The resulting span belongs to the original request's trace + """ + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + queue: asyncio.Queue[tuple[str, otel_context.Context]] = asyncio.Queue() + + async def worker(): + item, ctx = await queue.get() + with activate_otel_context(ctx): + with tracer.start_as_current_span(f"db-write-{item}"): + await asyncio.sleep(0) + queue.task_done() + + with tracer.start_as_current_span("http-request-A"): + ctx_a = capture_otel_context() + await queue.put(("A", ctx_a)) + + worker_task = create_task_with_detached_otel_context(worker()) + await worker_task + await queue.join() + + provider.force_flush() + span_by_name = {s.name: s for s in exporter.spans} + + request_span = span_by_name["http-request-A"] + write_span = span_by_name["db-write-A"] + + assert request_span.context.trace_id == write_span.context.trace_id, ( + "DB write should belong to the same trace as the originating request" + ) + + +async def test_context_through_queue_no_cross_contamination(): + """Two requests enqueue work; each DB write is attributed to its own request trace. + + This is the key property: workers don't permanently inherit any single + request's context, and each queued item carries the correct context. 
+ """ + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + queue: asyncio.Queue[tuple[str, otel_context.Context]] = asyncio.Queue() + processed = asyncio.Event() + + async def worker(): + for _ in range(2): + item, ctx = await queue.get() + with activate_otel_context(ctx): + with tracer.start_as_current_span(f"db-write-{item}"): + await asyncio.sleep(0) + queue.task_done() + processed.set() + + worker_task = create_task_with_detached_otel_context(worker()) + + with tracer.start_as_current_span("request-A"): + await queue.put(("A", capture_otel_context())) + + with tracer.start_as_current_span("request-B"): + await queue.put(("B", capture_otel_context())) + + await processed.wait() + await worker_task + + provider.force_flush() + span_by_name = {s.name: s for s in exporter.spans} + + request_a = span_by_name["request-A"] + request_b = span_by_name["request-B"] + write_a = span_by_name["db-write-A"] + write_b = span_by_name["db-write-B"] + + assert write_a.context.trace_id == request_a.context.trace_id, "Write A should be in request A's trace" + assert write_b.context.trace_id == request_b.context.trace_id, "Write B should be in request B's trace" + assert request_a.context.trace_id != request_b.context.trace_id, "Request A and B should have different traces" diff --git a/tests/unit/providers/agents/builtin/test_responses_background.py b/tests/unit/providers/agents/builtin/test_responses_background.py index b95b87e3d9..81b4a66df1 100644 --- a/tests/unit/providers/agents/builtin/test_responses_background.py +++ b/tests/unit/providers/agents/builtin/test_responses_background.py @@ -6,11 +6,33 @@ """Unit tests for background parameter support in Responses API.""" +import asyncio +from unittest.mock import AsyncMock, patch + import pytest +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from 
opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult +from llama_stack.core.task import capture_otel_context, create_task_with_detached_otel_context +from llama_stack.providers.inline.agents.builtin.responses.openai_responses import ( + OpenAIResponsesImpl, + _BackgroundWorkItem, +) from llama_stack_api import OpenAIResponseError, OpenAIResponseObject +class _CollectingExporter(SpanExporter): + """Collects finished spans in memory for test assertions.""" + + def __init__(self): + self.spans = [] + + def export(self, spans): + self.spans.extend(spans) + return SpanExportResult.SUCCESS + + class TestBackgroundFieldInResponseObject: """Test that the background field is properly defined in OpenAIResponseObject.""" @@ -144,3 +166,182 @@ def test_error_response_with_background(self): assert response.background is True assert response.error is not None assert response.error.code == "processing_error" + + +def _make_responses_impl(): + """Create an OpenAIResponsesImpl with all dependencies mocked.""" + return OpenAIResponsesImpl( + inference_api=AsyncMock(), + tool_groups_api=AsyncMock(), + tool_runtime_api=AsyncMock(), + responses_store=AsyncMock(), + vector_io_api=AsyncMock(), + safety_api=None, + conversations_api=AsyncMock(), + prompts_api=AsyncMock(), + files_api=AsyncMock(), + connectors_api=AsyncMock(), + ) + + +class TestResponsesOtelContextPropagation: + """Verify that OTel trace context flows correctly through the background worker queue. + + The responses worker runs a full multi-step loop (_run_background_response_loop) + containing status updates, LLM calls, tool execution, and DB writes. All of + these operations must be attributed to the originating request's trace, not + to whichever request first spawned the worker. 
+ """ + + async def test_worker_attributes_work_to_correct_request_trace(self): + """Each queued response is processed under its originating request's trace context.""" + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + impl = _make_responses_impl() + + async def mock_response_loop(**kwargs): + with tracer.start_as_current_span(f"process-{kwargs['response_id']}"): + await asyncio.sleep(0) + + with patch.object(impl, "_run_background_response_loop", side_effect=mock_response_loop): + worker_task = create_task_with_detached_otel_context(impl._background_worker()) + + with tracer.start_as_current_span("request-A"): + impl._background_queue.put_nowait( + _BackgroundWorkItem(otel_context=capture_otel_context(), kwargs=dict(response_id="resp-A")) + ) + + with tracer.start_as_current_span("request-B"): + impl._background_queue.put_nowait( + _BackgroundWorkItem(otel_context=capture_otel_context(), kwargs=dict(response_id="resp-B")) + ) + + await impl._background_queue.join() + worker_task.cancel() + try: + await worker_task + except asyncio.CancelledError: + pass + + provider.force_flush() + spans_by_name = {s.name: s for s in exporter.spans} + + request_a_trace = spans_by_name["request-A"].context.trace_id + request_b_trace = spans_by_name["request-B"].context.trace_id + process_a_trace = spans_by_name["process-resp-A"].context.trace_id + process_b_trace = spans_by_name["process-resp-B"].context.trace_id + + assert request_a_trace != request_b_trace, "Requests should have distinct traces" + + assert process_a_trace == request_a_trace, "Response processing for resp-A should be in request-A's trace" + assert process_b_trace == request_b_trace, "Response processing for resp-B should be in request-B's trace" + + async def test_worker_does_not_leak_context_between_items(self): + """After processing one item, the worker returns to a clean context. 
+ + This ensures that if item A's processing sets some OTel state, it + doesn't bleed into item B's processing. + """ + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + impl = _make_responses_impl() + trace_ids_during_processing = {} + + async def mock_response_loop(**kwargs): + rid = kwargs["response_id"] + # The parent span has ended by this point, but the context still + # carries its trace_id. Child spans inherit this trace_id. + span_ctx = trace.get_current_span().get_span_context() + trace_ids_during_processing[rid] = span_ctx.trace_id if span_ctx.trace_id != 0 else None + with tracer.start_as_current_span(f"work-{rid}"): + await asyncio.sleep(0) + + with patch.object(impl, "_run_background_response_loop", side_effect=mock_response_loop): + worker_task = create_task_with_detached_otel_context(impl._background_worker()) + + with tracer.start_as_current_span("req-1"): + impl._background_queue.put_nowait( + _BackgroundWorkItem(otel_context=capture_otel_context(), kwargs=dict(response_id="r1")) + ) + + with tracer.start_as_current_span("req-2"): + impl._background_queue.put_nowait( + _BackgroundWorkItem(otel_context=capture_otel_context(), kwargs=dict(response_id="r2")) + ) + + await impl._background_queue.join() + worker_task.cancel() + try: + await worker_task + except asyncio.CancelledError: + pass + + provider.force_flush() + spans_by_name = {s.name: s for s in exporter.spans} + + req1_trace = spans_by_name["req-1"].context.trace_id + req2_trace = spans_by_name["req-2"].context.trace_id + + assert trace_ids_during_processing["r1"] is not None, "r1 should have a trace context" + assert trace_ids_during_processing["r2"] is not None, "r2 should have a trace context" + assert trace_ids_during_processing["r1"] == req1_trace + assert trace_ids_during_processing["r2"] == req2_trace + + async def test_error_handling_runs_under_request_context(self): + 
"""When processing fails, the error handler's DB writes are also in the request's trace.""" + exporter = _CollectingExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = provider.get_tracer("test") + + impl = _make_responses_impl() + mock_response = OpenAIResponseObject( + id="resp-err", + created_at=1234567890, + model="test-model", + status="in_progress", + output=[], + store=True, + ) + impl.responses_store.get_response_object = AsyncMock(return_value=mock_response) + impl.responses_store.update_response_object = AsyncMock() + + error_update_trace_ids = [] + original_update = impl.responses_store.update_response_object + + async def tracking_update(obj): + span_ctx = trace.get_current_span().get_span_context() + if span_ctx.trace_id != 0: + error_update_trace_ids.append(span_ctx.trace_id) + return await original_update(obj) + + impl.responses_store.update_response_object = tracking_update + + async def failing_loop(**kwargs): + raise RuntimeError("simulated failure") + + with patch.object(impl, "_run_background_response_loop", side_effect=failing_loop): + worker_task = create_task_with_detached_otel_context(impl._background_worker()) + + with tracer.start_as_current_span("failing-request"): + request_trace = trace.get_current_span().get_span_context().trace_id + impl._background_queue.put_nowait( + _BackgroundWorkItem(otel_context=capture_otel_context(), kwargs=dict(response_id="resp-err")) + ) + + await impl._background_queue.join() + worker_task.cancel() + try: + await worker_task + except asyncio.CancelledError: + pass + + assert len(error_update_trace_ids) > 0, "Error handler should have made DB updates" + for tid in error_update_trace_ids: + assert tid == request_trace, "Error handler DB writes should be in the failing request's trace" diff --git a/tests/unit/utils/inference/test_inference_store.py b/tests/unit/utils/inference/test_inference_store.py index 1fd7655ed2..dbc474f89c 100644 --- 
a/tests/unit/utils/inference/test_inference_store.py
+++ b/tests/unit/utils/inference/test_inference_store.py
@@ -4,9 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 import time
 
 import pytest
+from opentelemetry import trace
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult
 
 from llama_stack.core.storage.datatypes import InferenceStoreReference, SqliteSqlStoreConfig
 from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
@@ -20,6 +24,17 @@
 )
 
 
+class _CollectingExporter(SpanExporter):
+    """Collects finished spans in memory for test assertions."""
+
+    def __init__(self):
+        self.spans = []
+
+    def export(self, spans):
+        self.spans.extend(spans)
+        return SpanExportResult.SUCCESS
+
+
 @pytest.fixture(autouse=True)
 def setup_backends(tmp_path):
     """Register SQL store backends for testing."""
@@ -239,3 +254,163 @@ async def test_inference_store_custom_table_name():
     # Verify the error message uses the custom table name
     with pytest.raises(ValueError, match=f"Record with id='non-existent' not found in table '{custom_table_name}'"):
         await store.list_chat_completions(after="non-existent", limit=2)
+
+
+async def test_otel_traces_not_leaked_across_requests():
+    """Two sequential requests produce clean, separate OTel traces.
+
+    Reproduces the bug observed in Jaeger traces where background worker tasks
+    permanently inherited the first request's OTel context. This caused all
+    subsequent DB writes from other requests to appear under that trace,
+    inflating it from 5s to 62s with 334 unrelated INSERT operations.
+
+    The fix captures OTel context at enqueue time and attaches it per-item
+    in the worker loop, so each DB write is attributed to its originating request.
+    """
+    exporter = _CollectingExporter()
+    provider = TracerProvider()
+    provider.add_span_processor(SimpleSpanProcessor(exporter))
+    tracer = provider.get_tracer("test")
+
+    reference = InferenceStoreReference(
+        backend="sql_default",
+        table_name="otel_test_completions",
+        num_writers=1,
+    )
+    store = InferenceStore(reference, policy=[])
+    await store.initialize()
+    store.enable_write_queue = True
+
+    # Wrap the real DB write so every write emits a span in whatever OTel
+    # context the worker happens to be running with at that moment.
+    original_write = store._write_chat_completion
+
+    async def instrumented_write(completion, messages):
+        with tracer.start_as_current_span(f"db-write-{completion.id}"):
+            await original_write(completion, messages)
+
+    store._write_chat_completion = instrumented_write
+
+    base_time = int(time.time())
+    completion_a = create_test_chat_completion("completion-A", base_time + 1)
+    completion_b = create_test_chat_completion("completion-B", base_time + 2)
+    messages_a = [OpenAIUserMessageParam(role="user", content="request A")]
+    messages_b = [OpenAIUserMessageParam(role="user", content="request B")]
+
+    # Simulate two API requests arriving in sequence (as the InferenceRouter does:
+    # asyncio.create_task(store.store_chat_completion(...)) inside a request span).
+    with tracer.start_as_current_span("request-A"):
+        task_a = asyncio.create_task(store.store_chat_completion(completion_a, messages_a))
+        await task_a
+
+    with tracer.start_as_current_span("request-B"):
+        task_b = asyncio.create_task(store.store_chat_completion(completion_b, messages_b))
+        await task_b
+
+    await store.flush()
+    await store.shutdown()
+
+    provider.force_flush()
+    # Index spans by name with a comprehension, consistent with the sibling
+    # test below (the manual loop-and-assign form was needlessly verbose).
+    spans_by_name = {s.name: s for s in exporter.spans}
+
+    request_a_trace = spans_by_name["request-A"].context.trace_id
+    request_b_trace = spans_by_name["request-B"].context.trace_id
+    write_a_trace = spans_by_name["db-write-completion-A"].context.trace_id
+    write_b_trace = spans_by_name["db-write-completion-B"].context.trace_id
+
+    assert request_a_trace != request_b_trace, "Requests should have distinct trace IDs"
+
+    assert write_a_trace == request_a_trace, (
+        f"DB write for completion-A should be in request-A's trace, "
+        f"got trace {write_a_trace:#x} expected {request_a_trace:#x}"
+    )
+    assert write_b_trace == request_b_trace, (
+        f"DB write for completion-B should be in request-B's trace, "
+        f"got trace {write_b_trace:#x} expected {request_b_trace:#x}"
+    )
+
+
+async def test_otel_worker_does_not_inherit_first_request_trace():
+    """Workers start with a detached context and don't permanently adopt any request's trace.
+
+    Before the fix, the worker task was created via loop.create_task() inside
+    the first request's span context, permanently binding all future work to
+    that trace. This test verifies that worker-internal operations (like queue
+    polling) don't produce spans under any request's trace.
+    """
+    exporter = _CollectingExporter()
+    provider = TracerProvider()
+    provider.add_span_processor(SimpleSpanProcessor(exporter))
+    tracer = provider.get_tracer("test")
+
+    reference = InferenceStoreReference(
+        backend="sql_default",
+        table_name="otel_worker_test",
+        num_writers=1,
+    )
+    store = InferenceStore(reference, policy=[])
+    await store.initialize()
+    store.enable_write_queue = True
+
+    original_write = store._write_chat_completion
+
+    async def instrumented_write(completion, messages):
+        with tracer.start_as_current_span(f"db-write-{completion.id}"):
+            await original_write(completion, messages)
+
+    store._write_chat_completion = instrumented_write
+
+    base_time = int(time.time())
+
+    # First request spawns the worker (this is where the old bug lived:
+    # the worker permanently inherited request-1's trace context)
+    with tracer.start_as_current_span("request-1-spawns-worker"):
+        first_request_trace = trace.get_current_span().get_span_context().trace_id
+        completion_1 = create_test_chat_completion("comp-1", base_time + 1)
+        task = asyncio.create_task(
+            store.store_chat_completion(
+                completion_1,
+                [OpenAIUserMessageParam(role="user", content="first")],
+            )
+        )
+        await task
+        await store.flush()
+
+    # Second request enqueues work; worker is already running
+    with tracer.start_as_current_span("request-2"):
+        second_request_trace = trace.get_current_span().get_span_context().trace_id
+        completion_2 = create_test_chat_completion("comp-2", base_time + 2)
+        task = asyncio.create_task(
+            store.store_chat_completion(
+                completion_2,
+                [OpenAIUserMessageParam(role="user", content="second")],
+            )
+        )
+        await task
+        await store.flush()
+
+    # Third request (no trace context at all)
+    completion_3 = create_test_chat_completion("comp-3", base_time + 3)
+    await store.store_chat_completion(
+        completion_3,
+        [OpenAIUserMessageParam(role="user", content="third")],
+    )
+    await store.flush()
+    await store.shutdown()
+
+    provider.force_flush()
+    spans_by_name = {s.name: s for s in exporter.spans}
+
+    # Write 1 should be in request-1's trace
+    assert spans_by_name["db-write-comp-1"].context.trace_id == first_request_trace
+
+    # Write 2 should be in request-2's trace, NOT request-1's
+    assert spans_by_name["db-write-comp-2"].context.trace_id == second_request_trace
+    assert spans_by_name["db-write-comp-2"].context.trace_id != first_request_trace, (
+        "BUG REPRODUCED: write for request-2 leaked into request-1's trace"
+    )
+
+    # Write 3 (no request context) should be in its own independent trace
+    write_3_trace = spans_by_name["db-write-comp-3"].context.trace_id
+    assert write_3_trace != first_request_trace
+    assert write_3_trace != second_request_trace