Skip to content

Commit 65f9230

Browse files
committed
wip
1 parent 50cfff2 commit 65f9230

4 files changed

Lines changed: 159 additions & 123 deletions

File tree

sentry_sdk/ai/span_config.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import sentry_sdk
2+
from sentry_sdk.consts import SPANDATA
3+
from sentry_sdk.ai.utils import (
4+
set_data_normalized,
5+
normalize_message_roles,
6+
truncate_and_annotate_messages,
7+
)
8+
from sentry_sdk.scope import should_send_default_pii
9+
10+
from typing import TYPE_CHECKING
11+
12+
if TYPE_CHECKING:
13+
from typing import Any, Dict
14+
from sentry_sdk.tracing import Span
15+
16+
17+
def set_input_span_data(span, kwargs, integration, config):
    # type: (Span, Dict[str, Any], Any, Dict[str, Any]) -> None
    """
    Populate request-side span attributes from a declarative config.

    Recognised config keys:
        system: str - value written to gen_ai.system (required)
        operation: str - value written to gen_ai.operation.name (required)
        params: dict - kwargs key -> span attr, copied whenever the key is present
        pii_params: dict - kwargs key -> span attr, copied only when PII may be sent
        extract_messages: callable(kwargs) -> list or None
        message_target: str - span attr for the messages
            (default: GEN_AI_REQUEST_MESSAGES)
        truncation_fn: callable(messages, span, scope) or None - applied to the
            messages (default: truncate_and_annotate_messages, None to skip)
        is_given: callable(value) -> bool - filters out NotGiven-style sentinels
        extra_static: dict - span attr -> value pairs that are always set
    """
    set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, config["system"])
    set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, config["operation"])

    is_given = config.get("is_given")

    def _copy_present_kwargs(mapping):
        # type: (Dict[str, str]) -> None
        # Copy each mapped kwarg onto the span, skipping absent keys and values
        # the optional ``is_given`` predicate rejects.
        for name, attr in mapping.items():
            if name not in kwargs:
                continue
            candidate = kwargs[name]
            if is_given is not None and not is_given(candidate):
                continue
            set_data_normalized(span, attr, candidate)

    _copy_present_kwargs(config.get("params", {}))

    # Prompt content and PII-bearing params need both the global PII switch
    # and the integration's own opt-in.
    if should_send_default_pii() and integration.include_prompts:
        extractor = config.get("extract_messages")
        prompt = extractor(kwargs) if extractor is not None else None
        if prompt:
            prompt = normalize_message_roles(prompt)
            # An explicit ``None`` in the config disables truncation; a missing
            # key falls back to the default truncation helper.
            truncate = config.get("truncation_fn", truncate_and_annotate_messages)
            if truncate is not None:
                current_scope = sentry_sdk.get_current_scope()
                prompt = truncate(prompt, span, current_scope)
            if prompt is not None:
                destination = config.get(
                    "message_target", SPANDATA.GEN_AI_REQUEST_MESSAGES
                )
                set_data_normalized(span, destination, prompt, unpack=False)

        _copy_present_kwargs(config.get("pii_params", {}))

    for attr, static_value in config.get("extra_static", {}).items():
        set_data_normalized(span, attr, static_value)

sentry_sdk/integrations/cohere/__init__.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from sentry_sdk.ai.monitoring import record_token_usage
55
from sentry_sdk.consts import OP, SPANDATA
6-
from sentry_sdk.ai.utils import set_data_normalized
6+
from sentry_sdk.ai.span_config import set_input_span_data
77

88
from typing import TYPE_CHECKING
99

@@ -13,7 +13,6 @@
1313
from typing import Any, Callable
1414

1515
import sentry_sdk
16-
from sentry_sdk.scope import should_send_default_pii
1716
from sentry_sdk.integrations import DidNotEnable, Integration
1817
from sentry_sdk.utils import capture_internal_exceptions, event_from_exception, reraise
1918

@@ -43,6 +42,16 @@ def _normalize_embedding_input(texts):
4342
return [texts]
4443

4544

45+
def _extract_embedding_texts(kw):
    # type: (dict[str, Any]) -> Any
    """Return the normalized ``texts`` kwarg, or None when it was not passed."""
    if "texts" not in kw:
        return None
    return _normalize_embedding_input(kw["texts"])


# Declarative span config for Cohere embed calls (passed to set_input_span_data).
COHERE_EMBED_CONFIG = dict(
    system="cohere",
    operation="embeddings",
    params={"model": SPANDATA.GEN_AI_REQUEST_MODEL},
    extract_messages=_extract_embedding_texts,
    message_target=SPANDATA.GEN_AI_EMBEDDINGS_INPUT,
    # Explicit None: embedding inputs are recorded without truncation.
    truncation_fn=None,
)
53+
54+
4655
class CohereIntegration(Integration):
4756
identifier = "cohere"
4857
origin = f"auto.ai.{identifier}"
@@ -91,22 +100,7 @@ def new_embed(*args, **kwargs):
91100
name=f"embeddings {model}".strip(),
92101
origin=CohereIntegration.origin,
93102
) as span:
94-
set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "cohere")
95-
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
96-
97-
if "texts" in kwargs and (
98-
should_send_default_pii() and integration.include_prompts
99-
):
100-
set_data_normalized(
101-
span,
102-
SPANDATA.GEN_AI_EMBEDDINGS_INPUT,
103-
_normalize_embedding_input(kwargs["texts"]),
104-
)
105-
106-
if "model" in kwargs:
107-
set_data_normalized(
108-
span, SPANDATA.GEN_AI_REQUEST_MODEL, kwargs["model"]
109-
)
103+
set_input_span_data(span, kwargs, integration, COHERE_EMBED_CONFIG)
110104

111105
try:
112106
res = f(*args, **kwargs)

sentry_sdk/integrations/cohere/v1.py

Lines changed: 62 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@
33

44
from sentry_sdk.ai.monitoring import record_token_usage
55
from sentry_sdk.consts import OP, SPANDATA
6-
from sentry_sdk.ai.utils import (
7-
set_data_normalized,
8-
normalize_message_roles,
9-
truncate_and_annotate_messages,
10-
)
6+
from sentry_sdk.ai.utils import set_data_normalized
7+
from sentry_sdk.ai.span_config import set_input_span_data
118

129
from typing import TYPE_CHECKING
1310

@@ -48,6 +45,32 @@
4845
"preamble": SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS,
4946
}
5047

48+
49+
def _extract_messages_v1(kwargs):
50+
# type: (dict[str, Any]) -> list[dict[str, str]]
51+
"""Extract role/content dicts from V1-style chat_history + message."""
52+
messages = []
53+
for x in kwargs.get("chat_history", []):
54+
messages.append(
55+
{
56+
"role": getattr(x, "role", "").lower(),
57+
"content": getattr(x, "message", ""),
58+
}
59+
)
60+
message = kwargs.get("message")
61+
if message:
62+
messages.append({"role": "user", "content": message})
63+
return messages
64+
65+
66+
# Declarative span config for Cohere V1 chat calls (passed to
# set_input_span_data); per-call values such as the streaming flag are merged
# in by the caller.
COHERE_V1_CHAT_CONFIG = dict(
    system="cohere",
    operation="chat",
    params=COLLECTED_CHAT_PARAMS,
    pii_params=COLLECTED_PII_CHAT_PARAMS,
    extract_messages=_extract_messages_v1,
)
73+
5174
COLLECTED_CHAT_RESP_ATTRS = {
5275
"generation_id": SPANDATA.GEN_AI_RESPONSE_ID,
5376
"finish_reason": SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
@@ -77,36 +100,6 @@ def _wrap_chat(f, streaming):
77100
if not _has_chat_types:
78101
return f
79102

80-
def collect_chat_response_fields(span, res, include_pii):
81-
# type: (Span, NonStreamedChatResponse, bool) -> None
82-
if include_pii:
83-
if hasattr(res, "text"):
84-
set_data_normalized(
85-
span,
86-
SPANDATA.GEN_AI_RESPONSE_TEXT,
87-
[res.text],
88-
)
89-
for attr, spandata_key in COLLECTED_PII_CHAT_RESP_ATTRS.items():
90-
if hasattr(res, attr):
91-
set_data_normalized(span, spandata_key, getattr(res, attr))
92-
93-
for attr, spandata_key in COLLECTED_CHAT_RESP_ATTRS.items():
94-
if hasattr(res, attr):
95-
set_data_normalized(span, spandata_key, getattr(res, attr))
96-
97-
if hasattr(res, "meta"):
98-
if hasattr(res.meta, "billed_units"):
99-
record_token_usage(
100-
span,
101-
input_tokens=res.meta.billed_units.input_tokens,
102-
output_tokens=res.meta.billed_units.output_tokens,
103-
)
104-
elif hasattr(res.meta, "tokens"):
105-
record_token_usage(
106-
span,
107-
input_tokens=res.meta.tokens.input_tokens,
108-
output_tokens=res.meta.tokens.output_tokens,
109-
)
110103

111104
@wraps(f)
112105
def new_chat(*args, **kwargs):
@@ -120,7 +113,6 @@ def new_chat(*args, **kwargs):
120113
):
121114
return f(*args, **kwargs)
122115

123-
message = kwargs.get("message")
124116
model = kwargs.get("model", "")
125117

126118
with sentry_sdk.start_span(
@@ -137,41 +129,10 @@ def new_chat(*args, **kwargs):
137129
reraise(*exc_info)
138130

139131
with capture_internal_exceptions():
140-
set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "cohere")
141-
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "chat")
142-
if model:
143-
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MODEL, model)
144-
145-
if should_send_default_pii() and integration.include_prompts:
146-
messages = []
147-
for x in kwargs.get("chat_history", []):
148-
messages.append(
149-
{
150-
"role": getattr(x, "role", "").lower(),
151-
"content": getattr(x, "message", ""),
152-
}
153-
)
154-
messages.append({"role": "user", "content": message})
155-
messages = normalize_message_roles(messages)
156-
scope = sentry_sdk.get_current_scope()
157-
messages_data = truncate_and_annotate_messages(
158-
messages, span, scope
159-
)
160-
if messages_data is not None:
161-
set_data_normalized(
162-
span,
163-
SPANDATA.GEN_AI_REQUEST_MESSAGES,
164-
messages_data,
165-
unpack=False,
166-
)
167-
for k, v in COLLECTED_PII_CHAT_PARAMS.items():
168-
if k in kwargs:
169-
set_data_normalized(span, v, kwargs[k])
170-
171-
for k, v in COLLECTED_CHAT_PARAMS.items():
172-
if k in kwargs:
173-
set_data_normalized(span, v, kwargs[k])
174-
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_STREAMING, streaming)
132+
set_input_span_data(span, kwargs, integration, {
133+
**COHERE_V1_CHAT_CONFIG,
134+
"extra_static": {SPANDATA.GEN_AI_RESPONSE_STREAMING: streaming},
135+
})
175136

176137
if streaming:
177138
old_iterator = res
@@ -203,4 +164,34 @@ def new_iterator():
203164
set_data_normalized(span, "unknown_response", True)
204165
return res
205166

167+
def collect_chat_response_fields(span, res, include_pii):
    # type: (Span, NonStreamedChatResponse, bool) -> None
    """
    Copy response attributes and token usage from a chat response onto the span.

    When ``include_pii`` is true, the response text and every attribute listed
    in COLLECTED_PII_CHAT_RESP_ATTRS are recorded. Attributes listed in
    COLLECTED_CHAT_RESP_ATTRS are always recorded. Token usage is taken from
    ``res.meta.billed_units``, falling back to ``res.meta.tokens``.
    """
    if include_pii:
        if hasattr(res, "text"):
            set_data_normalized(
                span,
                SPANDATA.GEN_AI_RESPONSE_TEXT,
                [res.text],
            )
        for attr, spandata_key in COLLECTED_PII_CHAT_RESP_ATTRS.items():
            if hasattr(res, attr):
                set_data_normalized(span, spandata_key, getattr(res, attr))

    for attr, spandata_key in COLLECTED_CHAT_RESP_ATTRS.items():
        if hasattr(res, attr):
            set_data_normalized(span, spandata_key, getattr(res, attr))

    # ``hasattr`` is true for an attribute that exists but holds None, so the
    # usage objects are checked by value: a None ``billed_units`` falls back
    # to ``tokens`` instead of raising AttributeError on ``.input_tokens``.
    meta = getattr(res, "meta", None)
    usage = getattr(meta, "billed_units", None)
    if usage is None:
        usage = getattr(meta, "tokens", None)
    if usage is not None:
        record_token_usage(
            span,
            input_tokens=usage.input_tokens,
            output_tokens=usage.output_tokens,
        )
206197
return new_chat

sentry_sdk/integrations/cohere/v2.py

Lines changed: 17 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@
33

44
from sentry_sdk.ai.monitoring import record_token_usage
55
from sentry_sdk.consts import OP, SPANDATA
6-
from sentry_sdk.ai.utils import (
7-
set_data_normalized,
8-
normalize_message_roles,
9-
truncate_and_annotate_messages,
10-
)
6+
from sentry_sdk.ai.utils import set_data_normalized
7+
from sentry_sdk.ai.span_config import set_input_span_data
118

129
from typing import TYPE_CHECKING
1310

@@ -101,6 +98,15 @@ def _extract_messages_v2(messages):
10198
return result
10299

103100

101+
def _extract_messages_from_kwargs_v2(kw):
    # type: (dict[str, Any]) -> Any
    """Run _extract_messages_v2 on the ``messages`` kwarg (empty list if absent)."""
    return _extract_messages_v2(kw.get("messages", []))


# Declarative span config for Cohere V2 chat calls (passed to
# set_input_span_data); per-call values are merged in by the caller.
COHERE_V2_CHAT_CONFIG = dict(
    system="cohere",
    operation="chat",
    params=COLLECTED_CHAT_PARAMS,
    pii_params={"tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS},
    extract_messages=_extract_messages_from_kwargs_v2,
)
108+
109+
104110
def _record_token_usage_v2(span, usage):
105111
# type: (Span, Any) -> None
106112
"""Extract and record token usage from a V2 Usage object."""
@@ -180,36 +186,13 @@ def new_chat(*args, **kwargs):
180186
reraise(*exc_info)
181187

182188
with capture_internal_exceptions():
183-
set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "cohere")
184-
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "chat")
189+
extra = {SPANDATA.GEN_AI_RESPONSE_STREAMING: streaming}
185190
if model:
186-
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MODEL, model)
187-
188-
if should_send_default_pii() and integration.include_prompts:
189-
messages = _extract_messages_v2(kwargs.get("messages", []))
190-
messages = normalize_message_roles(messages)
191-
scope = sentry_sdk.get_current_scope()
192-
messages_data = truncate_and_annotate_messages(
193-
messages, span, scope
194-
)
195-
if messages_data is not None:
196-
set_data_normalized(
197-
span,
198-
SPANDATA.GEN_AI_REQUEST_MESSAGES,
199-
messages_data,
200-
unpack=False,
201-
)
202-
if "tools" in kwargs:
203-
set_data_normalized(
204-
span,
205-
SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
206-
kwargs["tools"],
207-
)
208-
209-
for k, v in COLLECTED_CHAT_PARAMS.items():
210-
if k in kwargs:
211-
set_data_normalized(span, v, kwargs[k])
212-
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_STREAMING, streaming)
191+
extra[SPANDATA.GEN_AI_RESPONSE_MODEL] = model
192+
set_input_span_data(span, kwargs, integration, {
193+
**COHERE_V2_CHAT_CONFIG,
194+
"extra_static": extra,
195+
})
213196

214197
if streaming:
215198
old_iterator = res

0 commit comments

Comments
 (0)