Skip to content
Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from dataclasses import dataclass


@dataclass
class Response:
"""Response details from agent execution."""

"""The list of response messages from the agent."""
messages: list[str]
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
from typing import TYPE_CHECKING, Any

from opentelemetry import baggage, context, trace
from opentelemetry.trace import Span, SpanKind, Status, StatusCode, Tracer, set_span_in_context
from opentelemetry.trace import (
Span,
SpanKind,
Status,
StatusCode,
Tracer,
set_span_in_context,
)

from .constants import (
ENABLE_A365_OBSERVABILITY,
Expand All @@ -32,6 +39,7 @@
SOURCE_NAME,
TENANT_ID_KEY,
)
from .utils import parse_parent_id_to_context

if TYPE_CHECKING:
from .agent_details import AgentDetails
Expand Down Expand Up @@ -71,6 +79,7 @@ def __init__(
activity_name: str,
agent_details: "AgentDetails | None" = None,
tenant_details: "TenantDetails | None" = None,
parent_id: str | None = None,
):
"""Initialize the OpenTelemetry scope.

Expand All @@ -80,6 +89,8 @@ def __init__(
activity_name: The name of the activity for display purposes
agent_details: Optional agent details
tenant_details: Optional tenant details
parent_id: Optional parent Activity ID used to link this span to an upstream
operation
"""
self._span: Span | None = None
self._start_time = time.time()
Expand All @@ -102,12 +113,13 @@ def __init__(
elif kind.lower() == "consumer":
activity_kind = SpanKind.CONSUMER

# Get current context for parent relationship
current_context = context.get_current()
# Get context for parent relationship
# If parent_id is provided, parse it and use it as the parent context
# Otherwise, use the current context
parent_context = parse_parent_id_to_context(parent_id)
span_context = parent_context if parent_context else context.get_current()

self._span = tracer.start_span(
activity_name, kind=activity_kind, context=current_context
)
self._span = tracer.start_span(activity_name, kind=activity_kind, context=span_context)

# Log span creation
if self._span:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from ..agent_details import AgentDetails
from ..constants import GEN_AI_OUTPUT_MESSAGES_KEY
from ..models.response import Response
from ..opentelemetry_scope import OpenTelemetryScope
from ..tenant_details import TenantDetails
from ..utils import safe_json_dumps

OUTPUT_OPERATION_NAME = "output_messages"


class OutputScope(OpenTelemetryScope):
"""Provides OpenTelemetry tracing scope for output messages."""

@staticmethod
def start(
agent_details: AgentDetails,
tenant_details: TenantDetails,
response: Response,
parent_id: str | None = None,
) -> "OutputScope":
"""Creates and starts a new scope for output tracing.

Args:
agent_details: The details of the agent
tenant_details: The details of the tenant
response: The response details from the agent
parent_id: Optional parent Activity ID used to link this span to an upstream
operation

Returns:
A new OutputScope instance
"""
return OutputScope(agent_details, tenant_details, response, parent_id)

def __init__(
self,
agent_details: AgentDetails,
tenant_details: TenantDetails,
response: Response,
parent_id: str | None = None,
):
"""Initialize the output scope.

Args:
agent_details: The details of the agent
tenant_details: The details of the tenant
response: The response details from the agent
parent_id: Optional parent Activity ID used to link this span to an upstream
operation
"""
super().__init__(
kind="Client",
operation_name=OUTPUT_OPERATION_NAME,
activity_name=(f"{OUTPUT_OPERATION_NAME} {agent_details.agent_id}"),
agent_details=agent_details,
tenant_details=tenant_details,
parent_id=parent_id,
)

# Initialize accumulated messages list
self._output_messages: list[str] = list(response.messages)

# Set response messages
self.set_tag_maybe(GEN_AI_OUTPUT_MESSAGES_KEY, safe_json_dumps(self._output_messages))

def record_output_messages(self, messages: list[str]) -> None:
"""Records the output messages for telemetry tracking.

Appends the provided messages to the accumulated output messages list.

Args:
messages: List of output messages to append
"""
self._output_messages.extend(messages)
self.set_tag_maybe(GEN_AI_OUTPUT_MESSAGES_KEY, safe_json_dumps(self._output_messages))
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
from threading import RLock
from typing import Any, Generic, TypeVar, cast

from opentelemetry import context
from opentelemetry.semconv.attributes.exception_attributes import (
EXCEPTION_MESSAGE,
EXCEPTION_STACKTRACE,
)
from opentelemetry.trace import Span
from opentelemetry.trace import NonRecordingSpan, Span, SpanContext, TraceFlags, set_span_in_context
from opentelemetry.util.types import AttributeValue
from wrapt import ObjectProxy

Expand All @@ -27,6 +28,128 @@
logger.addHandler(logging.NullHandler())


# W3C Trace Context constants
W3C_TRACE_CONTEXT_VERSION = "00"
W3C_TRACE_ID_LENGTH = 32 # 32 hex chars = 128 bits
W3C_SPAN_ID_LENGTH = 16 # 16 hex chars = 64 bits


def validate_w3c_trace_context_version(version: str) -> bool:
"""Validate W3C Trace Context version.

Args:
version: The version string to validate

Returns:
True if valid, False otherwise
"""
return version == W3C_TRACE_CONTEXT_VERSION


def _is_valid_hex(hex_string: str) -> bool:
"""Check if a string contains only valid hexadecimal characters.

Args:
hex_string: The string to validate

Returns:
True if all characters are valid hexadecimal (0-9, a-f, A-F), False otherwise
"""
return all(c in "0123456789abcdefABCDEF" for c in hex_string)


def validate_trace_id(trace_id_hex: str) -> bool:
"""Validate W3C Trace Context trace_id format.

Args:
trace_id_hex: The trace_id hex string to validate (should be 32 hex chars)

Returns:
True if valid (32 hex chars), False otherwise
"""
return len(trace_id_hex) == W3C_TRACE_ID_LENGTH and _is_valid_hex(trace_id_hex)


def validate_span_id(span_id_hex: str) -> bool:
"""Validate W3C Trace Context span_id format.

Args:
span_id_hex: The span_id hex string to validate (should be 16 hex chars)

Returns:
True if valid (16 hex chars), False otherwise
"""
return len(span_id_hex) == W3C_SPAN_ID_LENGTH and _is_valid_hex(span_id_hex)


def parse_parent_id_to_context(parent_id: str | None) -> context.Context | None:
"""Parse a W3C trace context parent ID and return a context with the parent span.

The parent_id format is expected to be W3C Trace Context format:
"00-{trace_id}-{span_id}-{trace_flags}"
Example: "00-1234567890abcdef1234567890abcdef-abcdefabcdef1234-01"

Args:
parent_id: The W3C Trace Context format parent ID string

Returns:
A context containing the parent span, or None if parent_id is invalid
"""
if not parent_id:
return None

try:
# W3C Trace Context format: "00-{trace_id}-{span_id}-{trace_flags}"
parts = parent_id.split("-")
if len(parts) != 4:
logger.warning(f"Invalid parent_id format (expected 4 parts): {parent_id}")
return None

version, trace_id_hex, span_id_hex, trace_flags_hex = parts

# Validate W3C Trace Context version
if not validate_w3c_trace_context_version(version):
logger.warning(f"Unsupported W3C Trace Context version: {version}")
return None

# Validate trace_id (must be 32 hex chars)
if not validate_trace_id(trace_id_hex):
logger.warning(
f"Invalid trace_id (expected {W3C_TRACE_ID_LENGTH} hex chars): '{trace_id_hex}'"
)
return None

# Validate span_id (must be 16 hex chars)
if not validate_span_id(span_id_hex):
logger.warning(
f"Invalid span_id (expected {W3C_SPAN_ID_LENGTH} hex chars): '{span_id_hex}'"
)
return None

# Parse the hex values
trace_id = int(trace_id_hex, 16)
span_id = int(span_id_hex, 16)
trace_flags = TraceFlags(int(trace_flags_hex, 16))

# Create a SpanContext from the parsed values
parent_span_context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=True,
trace_flags=trace_flags,
)

# Create a NonRecordingSpan with the parent context
parent_span = NonRecordingSpan(parent_span_context)

# Create a context with the parent span
return set_span_in_context(parent_span)

except (ValueError, IndexError) as e:
logger.warning(f"Failed to parse parent_id '{parent_id}': {e}")
return None


def safe_json_dumps(obj: Any, **kwargs: Any) -> str:
return json.dumps(obj, default=str, ensure_ascii=False, **kwargs)

Expand Down
Loading