|
1 | 1 | # Copyright (c) Microsoft. All rights reserved. |
2 | 2 |
|
| 3 | +from __future__ import annotations |
| 4 | + |
3 | 5 | import multiprocessing |
4 | 6 | import sys |
5 | | -from typing import Any, Optional, Union |
| 7 | +import threading |
| 8 | +import time |
| 9 | +from typing import Any, Callable, Coroutine, Iterator, List, Optional, Union |
6 | 10 |
|
7 | 11 | import agentops |
8 | 12 | import pytest |
| 13 | +import uvicorn |
9 | 14 | from agentops.sdk.core import TraceContext |
| 15 | +from fastapi import FastAPI, Request |
| 16 | +from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ( |
| 17 | + ExportTraceServiceRequest, |
| 18 | + ExportTraceServiceResponse, |
| 19 | +) |
| 20 | +from opentelemetry.sdk.trace import ReadableSpan |
10 | 21 | from opentelemetry.trace.status import StatusCode |
| 22 | +from portpicker import pick_unused_port |
11 | 23 |
|
| 24 | +from agentlightning.store.base import LightningStore, LightningStoreCapabilities |
12 | 25 | from agentlightning.tracer.agentops import AgentOpsTracer |
| 26 | +from agentlightning.types import Span |
| 27 | +from agentlightning.utils import otlp |
| 28 | + |
| 29 | + |
| 30 | +class MockOTLPService: |
| 31 | + """A mock OTLP server to capture trace export requests for testing purposes.""" |
| 32 | + |
| 33 | + def __init__(self) -> None: |
| 34 | + self.received: List[ExportTraceServiceRequest] = [] |
| 35 | + |
| 36 | + def start_service(self) -> int: |
| 37 | + app = FastAPI() |
| 38 | + |
| 39 | + @app.post("/v1/traces") |
| 40 | + async def _export_traces(request: Request): # type: ignore |
| 41 | + async def capture(message: ExportTraceServiceRequest) -> None: |
| 42 | + self.received.append(message) |
| 43 | + |
| 44 | + return await otlp.handle_otlp_export( |
| 45 | + request, |
| 46 | + ExportTraceServiceRequest, |
| 47 | + ExportTraceServiceResponse, |
| 48 | + capture, |
| 49 | + signal_name="traces", |
| 50 | + ) |
| 51 | + |
| 52 | + port = pick_unused_port() |
| 53 | + config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") |
| 54 | + self.server = uvicorn.Server(config) |
| 55 | + self.thread = threading.Thread(target=self.server.run, daemon=True) |
| 56 | + self.thread.start() |
| 57 | + timeout = time.time() + 5 |
| 58 | + while not getattr(self.server, "started", False): |
| 59 | + if time.time() > timeout: |
| 60 | + raise RuntimeError("OTLP test server failed to start") |
| 61 | + if not self.thread.is_alive(): |
| 62 | + raise RuntimeError("OTLP test server thread exited before startup") |
| 63 | + time.sleep(0.01) |
| 64 | + |
| 65 | + return port |
| 66 | + |
| 67 | + def stop_service(self) -> None: |
| 68 | + self.server.should_exit = True |
| 69 | + self.thread.join(timeout=5) |
| 70 | + |
| 71 | + def get_traces(self) -> List[ExportTraceServiceRequest]: |
| 72 | + return self.received |
| 73 | + |
| 74 | + |
| 75 | +class MockLightningStore(LightningStore): |
| 76 | + """A minimal stub-only LightningStore, only implements methods likely used in tests.""" |
| 77 | + |
| 78 | + def __init__(self, server_port: int = 80) -> None: |
| 79 | + super().__init__() |
| 80 | + self.otlp_traces = False |
| 81 | + self.server_port = server_port |
| 82 | + |
| 83 | + def enable_otlp_traces(self) -> None: |
| 84 | + self.otlp_traces = True |
| 85 | + |
| 86 | + async def add_otel_span( |
| 87 | + self, |
| 88 | + rollout_id: str, |
| 89 | + attempt_id: str, |
| 90 | + readable_span: ReadableSpan, |
| 91 | + sequence_id: int | None = None, |
| 92 | + ) -> Span: |
| 93 | + if sequence_id is None: |
| 94 | + sequence_id = 0 |
| 95 | + |
| 96 | + span = Span.from_opentelemetry( |
| 97 | + readable_span, rollout_id=rollout_id, attempt_id=attempt_id, sequence_id=sequence_id |
| 98 | + ) |
| 99 | + return span |
| 100 | + |
| 101 | + @property |
| 102 | + def capabilities(self) -> LightningStoreCapabilities: |
| 103 | + """Return the capabilities of the store.""" |
| 104 | + return LightningStoreCapabilities( |
| 105 | + async_safe=False, |
| 106 | + thread_safe=False, |
| 107 | + zero_copy=False, |
| 108 | + otlp_traces=self.otlp_traces, |
| 109 | + ) |
| 110 | + |
| 111 | + def otlp_traces_endpoint(self) -> str: |
| 112 | + return f"http://127.0.0.1:{self.server_port}/v1/traces" |
13 | 113 |
|
14 | 114 |
|
15 | 115 | def _func_with_exception(): |
@@ -86,3 +186,107 @@ def custom_end_trace( |
86 | 186 | agentops.end_trace = old_end_trace |
87 | 187 | tracer.teardown_worker(0) |
88 | 188 | tracer.teardown() |
| 189 | + |
| 190 | + |
| 191 | +async def _test_agentops_trace_without_store_imp(): |
| 192 | + tracer = AgentOpsTracer() |
| 193 | + tracer.init() |
| 194 | + tracer.init_worker(0) |
| 195 | + |
| 196 | + try: |
| 197 | + # Using AgentOpsTracer to trace a function without providing a store, rollout_id, or attempt_id. |
| 198 | + tracer.trace_run(_func_without_exception) |
| 199 | + spans = tracer.get_last_trace() |
| 200 | + assert len(spans) > 0 |
| 201 | + finally: |
| 202 | + tracer.teardown_worker(0) |
| 203 | + tracer.teardown() |
| 204 | + |
| 205 | + |
| 206 | +async def _test_agentops_trace_with_store_disable_imp(): |
| 207 | + tracer = AgentOpsTracer() |
| 208 | + tracer.init() |
| 209 | + tracer.init_worker(0) |
| 210 | + |
| 211 | + try: |
| 212 | + # Using AgentOpsTracer to trace a function with providing a store which disabled native otlp exporter, rollout_id, and attempt_id. |
| 213 | + store = MockLightningStore() |
| 214 | + async with tracer.trace_context( |
| 215 | + name="agentops_test", store=store, rollout_id="test_rollout_id", attempt_id="test_attempt_id" |
| 216 | + ): |
| 217 | + _func_without_exception() |
| 218 | + spans = tracer.get_last_trace() |
| 219 | + assert len(spans) > 0 |
| 220 | + finally: |
| 221 | + tracer.teardown_worker(0) |
| 222 | + tracer.teardown() |
| 223 | + |
| 224 | + |
| 225 | +async def _test_agentops_trace_with_store_enable_imp(): |
| 226 | + mock_service = MockOTLPService() |
| 227 | + port = mock_service.start_service() |
| 228 | + |
| 229 | + tracer = AgentOpsTracer() |
| 230 | + tracer.init() |
| 231 | + tracer.init_worker(0) |
| 232 | + |
| 233 | + try: |
| 234 | + # Using AgentOpsTracer to trace a function with providing a store which disabled native otlp exporter, rollout_id, and attempt_id. |
| 235 | + store = MockLightningStore(port) |
| 236 | + async with tracer.trace_context( |
| 237 | + name="agentops_test", store=store, rollout_id="test_rollout_id", attempt_id="test_attempt_id" |
| 238 | + ): |
| 239 | + _func_without_exception() |
| 240 | + spans = tracer.get_last_trace() |
| 241 | + assert len(spans) > 0 |
| 242 | + finally: |
| 243 | + tracer.teardown_worker(0) |
| 244 | + tracer.teardown() |
| 245 | + |
| 246 | + mock_service.stop_service() |
| 247 | + |
| 248 | + |
| 249 | +def agentops_trace_paths() -> Iterator[Callable[[], Any]]: |
| 250 | + yield from [ |
| 251 | + _test_agentops_trace_without_store_imp, |
| 252 | + _test_agentops_trace_with_store_disable_imp, |
| 253 | + _test_agentops_trace_with_store_enable_imp, |
| 254 | + ] |
| 255 | + |
| 256 | + |
| 257 | +@pytest.mark.parametrize("func_name", [f.__name__ for f in agentops_trace_paths()], ids=str) |
| 258 | +def test_agentops_trace_with_store_or_not(func_name: str): |
| 259 | + """ |
| 260 | + The purpose of this test is to verify whether the following two scenarios both work correctly: |
| 261 | +
|
| 262 | + 1. Using AgentOpsTracer to trace a function without providing a store, rollout_id, or attempt_id. |
| 263 | + 2. Using AgentOpsTracer to trace a function with providing a store which disabled native otlp exporter, rollout_id, and attempt_id. |
| 264 | + 3. Using AgentOpsTracer to trace a function with providing a store which enabled native otlp exporter, rollout_id, and attempt_id. |
| 265 | + """ |
| 266 | + |
| 267 | + func = {f.__name__: f for f in agentops_trace_paths()}[func_name] |
| 268 | + |
| 269 | + ctx = multiprocessing.get_context("spawn") |
| 270 | + proc = ctx.Process(target=_run_async, args=(func,)) |
| 271 | + proc.start() |
| 272 | + proc.join(30.0) # On GPU server, the time is around 10 seconds. |
| 273 | + |
| 274 | + if proc.is_alive(): |
| 275 | + proc.terminate() |
| 276 | + proc.join(5) |
| 277 | + if proc.is_alive(): |
| 278 | + proc.kill() |
| 279 | + |
| 280 | + assert False, "Child process hung. Check test output for details." |
| 281 | + |
| 282 | + assert proc.exitcode == 0, ( |
| 283 | + f"Child process for test_trace_error_status_from_instance failed with exit code {proc.exitcode}. " |
| 284 | + "Check child traceback in test output." |
| 285 | + ) |
| 286 | + |
| 287 | + |
| 288 | +def _run_async(coro: Callable[[], Coroutine[Any, Any, Any]]) -> None: |
| 289 | + """Small wrapper: run async function inside multiprocessing target.""" |
| 290 | + import asyncio |
| 291 | + |
| 292 | + asyncio.run(coro()) |
0 commit comments