Skip to content

Commit 8e6f48a

Browse files
authored
fix: attached custom attributes to all spans (#1235)
1 parent aaf9715 commit 8e6f48a

File tree

11 files changed

+88
-14
lines changed

11 files changed

+88
-14
lines changed

src/strands/event_loop/event_loop.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,10 @@ async def event_loop_cycle(
133133
# Create tracer span for this event loop cycle
134134
tracer = get_tracer()
135135
cycle_span = tracer.start_event_loop_cycle_span(
136-
invocation_state=invocation_state, messages=agent.messages, parent_span=agent.trace_span
136+
invocation_state=invocation_state,
137+
messages=agent.messages,
138+
parent_span=agent.trace_span,
139+
custom_trace_attributes=agent.trace_attributes,
137140
)
138141
invocation_state["event_loop_cycle_span"] = cycle_span
139142

@@ -320,6 +323,7 @@ async def _handle_model_execution(
320323
messages=agent.messages,
321324
parent_span=cycle_span,
322325
model_id=model_id,
326+
custom_trace_attributes=agent.trace_attributes,
323327
)
324328
with trace_api.use_span(model_invoke_span):
325329
await agent.hooks.invoke_callbacks_async(

src/strands/multiagent/base.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88
from abc import ABC, abstractmethod
99
from dataclasses import dataclass, field
1010
from enum import Enum
11-
from typing import Any, AsyncIterator, Union
11+
from typing import Any, AsyncIterator, Mapping, Union
1212

1313
from .._async import run_async
1414
from ..agent import AgentResult
1515
from ..types.event_loop import Metrics, Usage
1616
from ..types.multiagent import MultiAgentInput
17+
from ..types.traces import AttributeValue
1718

1819
logger = logging.getLogger(__name__)
1920

@@ -238,6 +239,18 @@ def deserialize_state(self, payload: dict[str, Any]) -> None:
238239
"""Restore orchestrator state from a session dict."""
239240
raise NotImplementedError
240241

242+
def _parse_trace_attributes(
243+
self, attributes: Mapping[str, AttributeValue] | None = None
244+
) -> dict[str, AttributeValue]:
245+
trace_attributes: dict[str, AttributeValue] = {}
246+
if attributes:
247+
for k, v in attributes.items():
248+
if isinstance(v, (str, int, float, bool)) or (
249+
isinstance(v, list) and all(isinstance(x, (str, int, float, bool)) for x in v)
250+
):
251+
trace_attributes[k] = v
252+
return trace_attributes
253+
241254

242255
# Private helper function to avoid duplicate code
243256

src/strands/multiagent/graph.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import logging
2020
import time
2121
from dataclasses import dataclass, field
22-
from typing import Any, AsyncIterator, Callable, Optional, Tuple, cast
22+
from typing import Any, AsyncIterator, Callable, Mapping, Optional, Tuple, cast
2323

2424
from opentelemetry import trace as trace_api
2525

@@ -46,6 +46,7 @@
4646
from ..types.content import ContentBlock, Messages
4747
from ..types.event_loop import Metrics, Usage
4848
from ..types.multiagent import MultiAgentInput
49+
from ..types.traces import AttributeValue
4950
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
5051

5152
logger = logging.getLogger(__name__)
@@ -413,6 +414,7 @@ def __init__(
413414
session_manager: Optional[SessionManager] = None,
414415
hooks: Optional[list[HookProvider]] = None,
415416
id: str = _DEFAULT_GRAPH_ID,
417+
trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
416418
) -> None:
417419
"""Initialize Graph with execution limits and reset behavior.
418420
@@ -427,6 +429,7 @@ def __init__(
427429
session_manager: Session manager for persisting graph state and execution history (default: None)
428430
hooks: List of hook providers for monitoring and extending graph execution behavior (default: None)
429431
id: Unique graph id (default: None)
432+
trace_attributes: Custom trace attributes to apply to the agent's trace span (default: None)
430433
"""
431434
super().__init__()
432435

@@ -442,6 +445,7 @@ def __init__(
442445
self.reset_on_revisit = reset_on_revisit
443446
self.state = GraphState()
444447
self.tracer = get_tracer()
448+
self.trace_attributes: dict[str, AttributeValue] = self._parse_trace_attributes(trace_attributes)
445449
self.session_manager = session_manager
446450
self.hooks = HookRegistry()
447451
if self.session_manager:
@@ -537,7 +541,7 @@ async def stream_async(
537541
self.state.status = Status.EXECUTING
538542
self.state.start_time = start_time
539543

540-
span = self.tracer.start_multiagent_span(task, "graph")
544+
span = self.tracer.start_multiagent_span(task, "graph", custom_trace_attributes=self.trace_attributes)
541545
with trace_api.use_span(span, end_on_exit=True):
542546
try:
543547
logger.debug(

src/strands/multiagent/swarm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import logging
1919
import time
2020
from dataclasses import dataclass, field
21-
from typing import Any, AsyncIterator, Callable, Optional, Tuple, cast
21+
from typing import Any, AsyncIterator, Callable, Mapping, Optional, Tuple, cast
2222

2323
from opentelemetry import trace as trace_api
2424

@@ -46,6 +46,7 @@
4646
from ..types.content import ContentBlock, Messages
4747
from ..types.event_loop import Metrics, Usage
4848
from ..types.multiagent import MultiAgentInput
49+
from ..types.traces import AttributeValue
4950
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
5051

5152
logger = logging.getLogger(__name__)
@@ -226,6 +227,7 @@ def __init__(
226227
session_manager: Optional[SessionManager] = None,
227228
hooks: Optional[list[HookProvider]] = None,
228229
id: str = _DEFAULT_SWARM_ID,
230+
trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
229231
) -> None:
230232
"""Initialize Swarm with agents and configuration.
231233
@@ -243,6 +245,7 @@ def __init__(
243245
Disabled by default (default: 0)
244246
session_manager: Session manager for persisting graph state and execution history (default: None)
245247
hooks: List of hook providers for monitoring and extending graph execution behavior (default: None)
248+
trace_attributes: Custom trace attributes to apply to the agent's trace span (default: None)
246249
"""
247250
super().__init__()
248251
self.id = id
@@ -262,6 +265,7 @@ def __init__(
262265
completion_status=Status.PENDING,
263266
)
264267
self.tracer = get_tracer()
268+
self.trace_attributes: dict[str, AttributeValue] = self._parse_trace_attributes(trace_attributes)
265269

266270
self.session_manager = session_manager
267271
self.hooks = HookRegistry()
@@ -356,7 +360,7 @@ async def stream_async(
356360
self.state.completion_status = Status.EXECUTING
357361
self.state.start_time = time.time()
358362

359-
span = self.tracer.start_multiagent_span(task, "swarm")
363+
span = self.tracer.start_multiagent_span(task, "swarm", custom_trace_attributes=self.trace_attributes)
360364
with trace_api.use_span(span, end_on_exit=True):
361365
try:
362366
current_node = cast(SwarmNode, self.state.current_node)

src/strands/telemetry/tracer.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ def start_model_invoke_span(
277277
messages: Messages,
278278
parent_span: Optional[Span] = None,
279279
model_id: Optional[str] = None,
280+
custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
280281
**kwargs: Any,
281282
) -> Span:
282283
"""Start a new span for a model invocation.
@@ -285,13 +286,17 @@ def start_model_invoke_span(
285286
messages: Messages being sent to the model.
286287
parent_span: Optional parent span to link this span to.
287288
model_id: Optional identifier for the model being invoked.
289+
custom_trace_attributes: Optional mapping of custom trace attributes to include in the span.
288290
**kwargs: Additional attributes to add to the span.
289291
290292
Returns:
291293
The created span, or None if tracing is not enabled.
292294
"""
293295
attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation_name="chat")
294296

297+
if custom_trace_attributes:
298+
attributes.update(custom_trace_attributes)
299+
295300
if model_id:
296301
attributes["gen_ai.request.model"] = model_id
297302

@@ -358,12 +363,19 @@ def end_model_invoke_span(
358363

359364
self._end_span(span, attributes, error)
360365

361-
def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None, **kwargs: Any) -> Span:
366+
def start_tool_call_span(
367+
self,
368+
tool: ToolUse,
369+
parent_span: Optional[Span] = None,
370+
custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
371+
**kwargs: Any,
372+
) -> Span:
362373
"""Start a new span for a tool call.
363374
364375
Args:
365376
tool: The tool being used.
366377
parent_span: Optional parent span to link this span to.
378+
custom_trace_attributes: Optional mapping of custom trace attributes to include in the span.
367379
**kwargs: Additional attributes to add to the span.
368380
369381
Returns:
@@ -377,6 +389,8 @@ def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None
377389
}
378390
)
379391

392+
if custom_trace_attributes:
393+
attributes.update(custom_trace_attributes)
380394
# Add additional kwargs as attributes
381395
attributes.update(kwargs)
382396

@@ -477,6 +491,7 @@ def start_event_loop_cycle_span(
477491
invocation_state: Any,
478492
messages: Messages,
479493
parent_span: Optional[Span] = None,
494+
custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
480495
**kwargs: Any,
481496
) -> Optional[Span]:
482497
"""Start a new span for an event loop cycle.
@@ -485,6 +500,7 @@ def start_event_loop_cycle_span(
485500
invocation_state: Arguments for the event loop cycle.
486501
parent_span: Optional parent span to link this span to.
487502
messages: Messages being processed in this cycle.
503+
custom_trace_attributes: Optional mapping of custom trace attributes to include in the span.
488504
**kwargs: Additional attributes to add to the span.
489505
490506
Returns:
@@ -497,6 +513,9 @@ def start_event_loop_cycle_span(
497513
"event_loop.cycle_id": event_loop_cycle_id,
498514
}
499515

516+
if custom_trace_attributes:
517+
attributes.update(custom_trace_attributes)
518+
500519
if "event_loop_parent_cycle_id" in invocation_state:
501520
attributes["event_loop.parent_cycle_id"] = str(invocation_state["event_loop_parent_cycle_id"])
502521

@@ -679,6 +698,7 @@ def start_multiagent_span(
679698
self,
680699
task: MultiAgentInput,
681700
instance: str,
701+
custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
682702
) -> Span:
683703
"""Start a new span for swarm invocation."""
684704
operation = f"invoke_{instance}"
@@ -689,6 +709,9 @@ def start_multiagent_span(
689709
}
690710
)
691711

712+
if custom_trace_attributes:
713+
attributes.update(custom_trace_attributes)
714+
692715
span = self._start_span(operation, attributes=attributes, span_kind=trace_api.SpanKind.CLIENT)
693716

694717
if self.use_latest_genai_conventions:

src/strands/tools/executors/_executor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,9 @@ async def _stream_with_trace(
249249

250250
tracer = get_tracer()
251251

252-
tool_call_span = tracer.start_tool_call_span(tool_use, cycle_span)
252+
tool_call_span = tracer.start_tool_call_span(
253+
tool_use, cycle_span, custom_trace_attributes=agent.trace_attributes
254+
)
253255
tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
254256
tool_start_time = time.time()
255257

tests/strands/event_loop/test_event_loop.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis
143143
mock.hooks = hook_registry
144144
mock.tool_executor = tool_executor
145145
mock._interrupt_state = _InterruptState()
146+
mock.trace_attributes = {}
146147

147148
return mock
148149

@@ -738,7 +739,10 @@ async def test_event_loop_cycle_with_parent_span(
738739

739740
# Verify parent_span was used when creating cycle span
740741
mock_tracer.start_event_loop_cycle_span.assert_called_once_with(
741-
invocation_state=unittest.mock.ANY, parent_span=parent_span, messages=messages
742+
invocation_state=unittest.mock.ANY,
743+
parent_span=parent_span,
744+
messages=messages,
745+
custom_trace_attributes=unittest.mock.ANY,
742746
)
743747

744748

tests/strands/event_loop/test_event_loop_structured_output.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def mock_agent():
4040
agent.hooks = Mock()
4141
agent.hooks.invoke_callbacks_async = AsyncMock()
4242
agent.trace_span = None
43+
agent.trace_attributes = {}
4344
agent.tool_executor = Mock()
4445
agent._append_message = AsyncMock()
4546

tests/strands/telemetry/test_tracer.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,15 +149,20 @@ def test_start_model_invoke_span(mock_tracer):
149149

150150
messages = [{"role": "user", "content": [{"text": "Hello"}]}]
151151
model_id = "test-model"
152+
custom_attrs = {"custom_key": "custom_value", "user_id": "12345"}
152153

153-
span = tracer.start_model_invoke_span(messages=messages, agent_name="TestAgent", model_id=model_id)
154+
span = tracer.start_model_invoke_span(
155+
messages=messages, agent_name="TestAgent", model_id=model_id, custom_trace_attributes=custom_attrs
156+
)
154157

155158
mock_tracer.start_span.assert_called_once()
156159
assert mock_tracer.start_span.call_args[1]["name"] == "chat"
157160
assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL
158161
mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents")
159162
mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "chat")
160163
mock_span.set_attribute.assert_any_call("gen_ai.request.model", model_id)
164+
mock_span.set_attribute.assert_any_call("custom_key", "custom_value")
165+
mock_span.set_attribute.assert_any_call("user_id", "12345")
161166
mock_span.add_event.assert_called_with(
162167
"gen_ai.user.message", attributes={"content": json.dumps(messages[0]["content"])}
163168
)
@@ -293,15 +298,18 @@ def test_start_tool_call_span(mock_tracer):
293298
mock_tracer.start_span.return_value = mock_span
294299

295300
tool = {"name": "test-tool", "toolUseId": "123", "input": {"param": "value"}}
301+
custom_attrs = {"session_id": "abc123", "environment": "production"}
296302

297-
span = tracer.start_tool_call_span(tool)
303+
span = tracer.start_tool_call_span(tool, custom_trace_attributes=custom_attrs)
298304

299305
mock_tracer.start_span.assert_called_once()
300306
assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool"
301307
mock_span.set_attribute.assert_any_call("gen_ai.tool.name", "test-tool")
302308
mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents")
303309
mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "execute_tool")
304310
mock_span.set_attribute.assert_any_call("gen_ai.tool.call.id", "123")
311+
mock_span.set_attribute.assert_any_call("session_id", "abc123")
312+
mock_span.set_attribute.assert_any_call("environment", "production")
305313
mock_span.add_event.assert_any_call(
306314
"gen_ai.tool.message", attributes={"role": "tool", "content": json.dumps({"param": "value"}), "id": "123"}
307315
)
@@ -361,14 +369,17 @@ def test_start_swarm_call_span_with_string_task(mock_tracer):
361369
mock_tracer.start_span.return_value = mock_span
362370

363371
task = "Design foo bar"
372+
custom_attrs = {"workflow_id": "wf-789", "priority": "high"}
364373

365-
span = tracer.start_multiagent_span(task, "swarm")
374+
span = tracer.start_multiagent_span(task, "swarm", custom_trace_attributes=custom_attrs)
366375

367376
mock_tracer.start_span.assert_called_once()
368377
assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm"
369378
mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents")
370379
mock_span.set_attribute.assert_any_call("gen_ai.agent.name", "swarm")
371380
mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "invoke_swarm")
381+
mock_span.set_attribute.assert_any_call("workflow_id", "wf-789")
382+
mock_span.set_attribute.assert_any_call("priority", "high")
372383
mock_span.add_event.assert_any_call("gen_ai.user.message", attributes={"content": "Design foo bar"})
373384
assert span is not None
374385

@@ -575,12 +586,17 @@ def test_start_event_loop_cycle_span(mock_tracer):
575586

576587
event_loop_kwargs = {"event_loop_cycle_id": "cycle-123"}
577588
messages = [{"role": "user", "content": [{"text": "Hello"}]}]
589+
custom_attrs = {"request_id": "req-456", "trace_level": "debug"}
578590

579-
span = tracer.start_event_loop_cycle_span(event_loop_kwargs, messages=messages)
591+
span = tracer.start_event_loop_cycle_span(
592+
event_loop_kwargs, messages=messages, custom_trace_attributes=custom_attrs
593+
)
580594

581595
mock_tracer.start_span.assert_called_once()
582596
assert mock_tracer.start_span.call_args[1]["name"] == "execute_event_loop_cycle"
583597
mock_span.set_attribute.assert_any_call("event_loop.cycle_id", "cycle-123")
598+
mock_span.set_attribute.assert_any_call("request_id", "req-456")
599+
mock_span.set_attribute.assert_any_call("trace_level", "debug")
584600
mock_span.add_event.assert_any_call(
585601
"gen_ai.user.message", attributes={"content": json.dumps([{"text": "Hello"}])}
586602
)

tests/strands/tools/executors/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def agent(tool_registry, hook_registry):
105105
mock_agent.tool_registry = tool_registry
106106
mock_agent.hooks = hook_registry
107107
mock_agent._interrupt_state = _InterruptState()
108+
mock_agent.trace_attributes = {}
108109
return mock_agent
109110

110111

0 commit comments

Comments
 (0)