Skip to content

Commit 3a3bfee

Browse files
authored
add test code to agentops's tracer (#324)
1 parent a733950 commit 3a3bfee

File tree

1 file changed

+205
-1
lines changed

1 file changed

+205
-1
lines changed

tests/tracer/test_agentops.py

Lines changed: 205 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,115 @@
11
# Copyright (c) Microsoft. All rights reserved.
22

3+
from __future__ import annotations
4+
35
import multiprocessing
46
import sys
5-
from typing import Any, Optional, Union
7+
import threading
8+
import time
9+
from typing import Any, Callable, Coroutine, Iterator, List, Optional, Union
610

711
import agentops
812
import pytest
13+
import uvicorn
914
from agentops.sdk.core import TraceContext
15+
from fastapi import FastAPI, Request
16+
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
17+
ExportTraceServiceRequest,
18+
ExportTraceServiceResponse,
19+
)
20+
from opentelemetry.sdk.trace import ReadableSpan
1021
from opentelemetry.trace.status import StatusCode
22+
from portpicker import pick_unused_port
1123

24+
from agentlightning.store.base import LightningStore, LightningStoreCapabilities
1225
from agentlightning.tracer.agentops import AgentOpsTracer
26+
from agentlightning.types import Span
27+
from agentlightning.utils import otlp
28+
29+
30+
class MockOTLPService:
    """A mock OTLP server to capture trace export requests for testing purposes.

    The service serves a FastAPI app with uvicorn on a background daemon thread
    bound to 127.0.0.1. Every message passed to the capture callback by
    ``otlp.handle_otlp_export`` is appended to ``received``.
    """

    def __init__(self) -> None:
        # Messages captured by the POST /v1/traces handler, in arrival order.
        self.received: List[ExportTraceServiceRequest] = []
        # Filled in by start_service(). Kept as None until then so that
        # stop_service() is a harmless no-op on a service that never started.
        self.server: Optional[uvicorn.Server] = None
        self.thread: Optional[threading.Thread] = None

    def start_service(self) -> int:
        """Start the mock server on a free local port and wait until it is up.

        Returns:
            The port the server is listening on.

        Raises:
            RuntimeError: If the server does not report ``started`` within
                5 seconds, or if the server thread exits before startup.
        """
        app = FastAPI()

        @app.post("/v1/traces")
        async def _export_traces(request: Request):  # type: ignore
            async def capture(message: ExportTraceServiceRequest) -> None:
                self.received.append(message)

            return await otlp.handle_otlp_export(
                request,
                ExportTraceServiceRequest,
                ExportTraceServiceResponse,
                capture,
                signal_name="traces",
            )

        port = pick_unused_port()
        config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error")
        server = uvicorn.Server(config)
        thread = threading.Thread(target=server.run, daemon=True)
        self.server = server
        self.thread = thread
        thread.start()

        # Use a monotonic clock so a wall-clock adjustment cannot shorten or
        # stretch the startup deadline.
        deadline = time.monotonic() + 5
        while not getattr(server, "started", False):
            if time.monotonic() > deadline:
                # Ask the half-started server to shut down instead of leaving
                # it running in the background after we give up.
                server.should_exit = True
                raise RuntimeError("OTLP test server failed to start")
            if not thread.is_alive():
                raise RuntimeError("OTLP test server thread exited before startup")
            time.sleep(0.01)

        return port

    def stop_service(self) -> None:
        """Signal the server to exit and wait up to 5 seconds for its thread.

        Safe to call when start_service() was never called.
        """
        if self.server is not None:
            self.server.should_exit = True
        if self.thread is not None:
            self.thread.join(timeout=5)

    def get_traces(self) -> List[ExportTraceServiceRequest]:
        """Return the list of export requests captured so far (not a copy)."""
        return self.received
73+
74+
75+
class MockLightningStore(LightningStore):
    """Stub LightningStore that implements only the pieces these tests use.

    It reports its capabilities, converts OpenTelemetry spans into ``Span``
    objects without persisting them, and exposes a local OTLP traces endpoint
    built from ``server_port``.
    """

    def __init__(self, server_port: int = 80) -> None:
        super().__init__()
        # Port used when building the OTLP traces endpoint URL.
        self.server_port = server_port
        # Reported through `capabilities`; off until enable_otlp_traces() is called.
        self.otlp_traces = False

    def enable_otlp_traces(self) -> None:
        """Flip the flag so `capabilities` reports OTLP trace support."""
        self.otlp_traces = True

    async def add_otel_span(
        self,
        rollout_id: str,
        attempt_id: str,
        readable_span: ReadableSpan,
        sequence_id: int | None = None,
    ) -> Span:
        """Convert ``readable_span`` into a ``Span`` and return it.

        A missing ``sequence_id`` is treated as 0. Nothing is stored.
        """
        effective_sequence_id = 0 if sequence_id is None else sequence_id
        return Span.from_opentelemetry(
            readable_span,
            rollout_id=rollout_id,
            attempt_id=attempt_id,
            sequence_id=effective_sequence_id,
        )

    @property
    def capabilities(self) -> LightningStoreCapabilities:
        """Return the capabilities of the store."""
        declared = LightningStoreCapabilities(
            async_safe=False,
            thread_safe=False,
            zero_copy=False,
            otlp_traces=self.otlp_traces,
        )
        return declared

    def otlp_traces_endpoint(self) -> str:
        """Return the local ``/v1/traces`` URL for the configured port."""
        port = self.server_port
        return f"http://127.0.0.1:{port}/v1/traces"
13113

14114

15115
def _func_with_exception():
@@ -86,3 +186,107 @@ def custom_end_trace(
86186
agentops.end_trace = old_end_trace
87187
tracer.teardown_worker(0)
88188
tracer.teardown()
189+
190+
191+
async def _test_agentops_trace_without_store_imp():
    """Trace a call with no store, rollout_id, or attempt_id and expect spans.

    The tracer and its worker 0 are always torn down, even if the check fails.
    """
    agentops_tracer = AgentOpsTracer()
    agentops_tracer.init()
    agentops_tracer.init_worker(0)

    try:
        # No store, rollout_id, or attempt_id is supplied on this path.
        agentops_tracer.trace_run(_func_without_exception)
        recorded = agentops_tracer.get_last_trace()
        assert len(recorded) > 0
    finally:
        agentops_tracer.teardown_worker(0)
        agentops_tracer.teardown()
204+
205+
206+
async def _test_agentops_trace_with_store_disable_imp():
    """Trace a call inside ``trace_context`` with a store whose OTLP flag is off.

    The store is a default ``MockLightningStore`` (``otlp_traces`` is False), and
    a rollout_id and attempt_id are supplied. The tracer and its worker 0 are
    always torn down, even if the check fails.
    """
    agentops_tracer = AgentOpsTracer()
    agentops_tracer.init()
    agentops_tracer.init_worker(0)

    try:
        # Store with the native OTLP exporter left disabled.
        plain_store = MockLightningStore()
        scoped_trace = agentops_tracer.trace_context(
            name="agentops_test",
            store=plain_store,
            rollout_id="test_rollout_id",
            attempt_id="test_attempt_id",
        )
        async with scoped_trace:
            _func_without_exception()
        recorded = agentops_tracer.get_last_trace()
        assert len(recorded) > 0
    finally:
        agentops_tracer.teardown_worker(0)
        agentops_tracer.teardown()
223+
224+
225+
async def _test_agentops_trace_with_store_enable_imp():
    """Trace a call inside ``trace_context`` with a store that enables OTLP traces.

    A ``MockOTLPService`` is started first and its port is given to the store,
    so the store's OTLP endpoint points at the mock server. The tracer, its
    worker 0, and the mock server are always shut down, even if the check fails.
    """
    mock_service = MockOTLPService()
    port = mock_service.start_service()

    try:
        tracer = AgentOpsTracer()
        tracer.init()
        tracer.init_worker(0)

        try:
            # Using AgentOpsTracer to trace a function with providing a store which enabled native otlp exporter, rollout_id, and attempt_id.
            store = MockLightningStore(port)
            # NOTE(review): without this call the store reports otlp_traces=False,
            # which makes this path identical to the "disable" scenario.
            store.enable_otlp_traces()
            async with tracer.trace_context(
                name="agentops_test", store=store, rollout_id="test_rollout_id", attempt_id="test_attempt_id"
            ):
                _func_without_exception()
            spans = tracer.get_last_trace()
            assert len(spans) > 0
        finally:
            tracer.teardown_worker(0)
            tracer.teardown()
    finally:
        # Stop the mock server on failure too, so it never outlives the test body.
        mock_service.stop_service()
247+
248+
249+
def agentops_trace_paths() -> Iterator[Callable[[], Any]]:
    """Yield each tracing scenario implementation, in a fixed order."""
    # Order: without store, store with OTLP disabled, store with OTLP enabled.
    yield _test_agentops_trace_without_store_imp
    yield _test_agentops_trace_with_store_disable_imp
    yield _test_agentops_trace_with_store_enable_imp
255+
256+
257+
@pytest.mark.parametrize("func_name", [f.__name__ for f in agentops_trace_paths()], ids=str)
def test_agentops_trace_with_store_or_not(func_name: str):
    """
    The purpose of this test is to verify whether the following three scenarios all work correctly:

    1. Using AgentOpsTracer to trace a function without providing a store, rollout_id, or attempt_id.
    2. Using AgentOpsTracer to trace a function with providing a store which disabled native otlp exporter, rollout_id, and attempt_id.
    3. Using AgentOpsTracer to trace a function with providing a store which enabled native otlp exporter, rollout_id, and attempt_id.

    Each scenario runs in its own child process created with the "spawn"
    start method. The test fails if the child is still running after 30
    seconds or exits with a non-zero code.
    """

    # Parametrize by name and look the callable up here, so the test IDs are
    # the readable function names.
    func = {f.__name__: f for f in agentops_trace_paths()}[func_name]

    ctx = multiprocessing.get_context("spawn")
    proc = ctx.Process(target=_run_async, args=(func,))
    proc.start()
    proc.join(30.0)  # On GPU server, the time is around 10 seconds.

    if proc.is_alive():
        # Escalate: terminate first, then kill if the child is still alive.
        proc.terminate()
        proc.join(5)
        if proc.is_alive():
            proc.kill()

        # pytest.fail is not stripped under `python -O`, unlike `assert False`.
        pytest.fail(f"Child process for {func_name} hung. Check test output for details.")

    assert proc.exitcode == 0, (
        f"Child process for {func_name} failed with exit code {proc.exitcode}. "
        "Check child traceback in test output."
    )
286+
287+
288+
def _run_async(coro: Callable[[], Coroutine[Any, Any, Any]]) -> None:
289+
"""Small wrapper: run async function inside multiprocessing target."""
290+
import asyncio
291+
292+
asyncio.run(coro())

0 commit comments

Comments
 (0)