Skip to content

Commit b1d4fbd

Browse files
committed
refactor
1 parent 9532175 commit b1d4fbd

4 files changed

Lines changed: 244 additions & 291 deletions

File tree

sentry_sdk/integrations/cohere/__init__.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,6 @@
2121
except ImportError:
2222
raise DidNotEnable("Cohere not installed")
2323

24-
COLLECTED_CHAT_PARAMS = {
25-
"model": SPANDATA.GEN_AI_REQUEST_MODEL,
26-
"temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
27-
"max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
28-
"k": SPANDATA.GEN_AI_REQUEST_TOP_K,
29-
"p": SPANDATA.GEN_AI_REQUEST_TOP_P,
30-
"seed": SPANDATA.GEN_AI_REQUEST_SEED,
31-
"frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
32-
"presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
33-
}
34-
3524

3625
def _normalize_embedding_input(texts):
3726
# type: (Any) -> Any
@@ -66,7 +55,6 @@ def __init__(self, include_prompts=True):
6655
def setup_once():
6756
# type: () -> None
6857
# Lazy imports to avoid circular dependencies:
69-
# v1/v2 import COLLECTED_CHAT_PARAMS and _capture_exception from this module.
7058
from sentry_sdk.integrations.cohere.v1 import setup_v1
7159
from sentry_sdk.integrations.cohere.v2 import setup_v2
7260

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from sentry_sdk.ai.utils import set_data_normalized
2+
3+
from typing import TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from typing import Any
7+
8+
_MISSING = object()
9+
10+
11+
def transitive_getattr(obj, *attrs):
    # type: (Any, *str) -> Any
    """Follow a chain of attribute names starting at ``obj``.

    Each name in ``attrs`` is looked up on the result of the previous lookup.

    :param obj: Object the walk starts from.
    :param attrs: Attribute names, resolved left to right.
    :returns: The value reached after the last name, ``obj`` itself when no
        names are given, or ``None`` as soon as one name cannot be resolved.
        A chain that legitimately ends in ``None`` is therefore
        indistinguishable from a missing attribute.
    """
    current = obj
    for attr in attrs:
        try:
            current = getattr(current, attr)
        except AttributeError:
            # A missing link anywhere in the chain ends the walk.
            return None
    return current
19+
20+
21+
def get_first_from_sources(obj, source_paths, require_truthy=False):
    # type: (Any, list[tuple[str, ...]], bool) -> Any
    """Return the first usable value found along ``source_paths``.

    :param obj: Object the attribute paths are resolved against.
    :param source_paths: Candidate attribute paths, tried in order. Each path
        is a tuple of attribute names resolved with ``transitive_getattr``.
    :param require_truthy: When true, a value is only accepted if it is
        truthy; otherwise any value other than ``None`` is accepted.
    :returns: The first accepted value, or ``None`` when no path yields one.
    """
    for source_path in source_paths:
        value = transitive_getattr(obj, *source_path)
        # Spelled out as two branches: the one-line conditional form of this
        # test depends on operator precedence that is easy to misread.
        if require_truthy:
            usable = bool(value)
        else:
            usable = value is not None
        if usable:
            return value
    return None
28+
29+
30+
def set_span_data_from_sources(span, obj, target_sources, require_truthy=False):
    # type: (Any, Any, dict[str, list[tuple[str, ...]]], bool) -> None
    """Copy values found on ``obj`` onto ``span`` under the given span-data keys.

    :param span: Span passed to ``set_data_normalized`` for every value found.
    :param obj: Object the attribute paths are resolved against.
    :param target_sources: Mapping of span-data key to the candidate attribute
        paths for that key; the first usable path wins (see
        ``get_first_from_sources``).
    :param require_truthy: Forwarded to ``get_first_from_sources``. Defaults
        to ``False`` so both helpers share the same default.

    Keys for which no value is found are left untouched on the span.
    """
    for spandata_key, source_paths in target_sources.items():
        value = get_first_from_sources(
            obj, source_paths, require_truthy=require_truthy
        )
        if value is not None:
            set_data_normalized(span, spandata_key, value)

sentry_sdk/integrations/cohere/v1.py

Lines changed: 91 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,30 @@
22
from functools import wraps
33

44
from sentry_sdk.ai.monitoring import record_token_usage
5-
from sentry_sdk.consts import OP, SPANDATA
6-
from sentry_sdk.ai.utils import set_data_normalized
75
from sentry_sdk.ai.span_config import set_input_span_data
6+
from sentry_sdk.ai.utils import set_data_normalized, transform_message_content
7+
from sentry_sdk.consts import OP, SPANDATA
88

99
from typing import TYPE_CHECKING
1010

1111
if TYPE_CHECKING:
1212
from typing import Any, Callable, Iterator
1313
from cohere import StreamedChatResponse
14-
from sentry_sdk.tracing import Span
1514

1615
import sentry_sdk
17-
from sentry_sdk.scope import should_send_default_pii
18-
from sentry_sdk.utils import capture_internal_exceptions, reraise
19-
2016
from sentry_sdk.integrations.cohere import (
2117
CohereIntegration,
22-
COLLECTED_CHAT_PARAMS,
2318
_capture_exception,
2419
)
20+
from sentry_sdk.integrations.cohere.utils import (
21+
get_first_from_sources,
22+
set_span_data_from_sources,
23+
)
24+
from sentry_sdk.scope import should_send_default_pii
25+
from sentry_sdk.utils import capture_internal_exceptions, reraise
2526

2627
try:
27-
from cohere import (
28-
ChatStreamEndEvent,
29-
NonStreamedChatResponse,
30-
)
28+
from cohere import ChatStreamEndEvent, NonStreamedChatResponse
3129

3230
try:
3331
from cohere import StreamEndStreamedChatResponse
@@ -40,53 +38,32 @@
4038
except ImportError:
4139
_has_chat_types = False
4240

43-
COLLECTED_PII_CHAT_PARAMS = {
44-
"tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
45-
"preamble": SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS,
46-
}
47-
48-
49-
def _extract_messages_v1(kwargs):
50-
# type: (dict[str, Any]) -> list[dict[str, str]]
51-
"""Extract role/content dicts from V1-style chat_history + message."""
52-
messages = []
53-
for x in kwargs.get("chat_history", []):
54-
messages.append(
55-
{
56-
"role": getattr(x, "role", "").lower(),
57-
"content": getattr(x, "message", ""),
58-
}
59-
)
60-
message = kwargs.get("message")
61-
if message:
62-
messages.append({"role": "user", "content": message})
63-
return messages
64-
65-
66-
COHERE_V1_CHAT_CONFIG = {
67-
"system": "cohere",
68-
"operation": "chat",
69-
"params": COLLECTED_CHAT_PARAMS,
70-
"pii_params": COLLECTED_PII_CHAT_PARAMS,
71-
"extract_messages": _extract_messages_v1,
41+
CHAT_RESPONSE_SOURCES = {
42+
SPANDATA.GEN_AI_RESPONSE_ID: [("generation_id",)],
43+
SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS: [("finish_reason",)],
7244
}
73-
74-
COLLECTED_CHAT_RESP_ATTRS = {
75-
"generation_id": SPANDATA.GEN_AI_RESPONSE_ID,
76-
"finish_reason": SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
45+
PII_CHAT_RESPONSE_SOURCES = {
46+
SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS: [("tool_calls",)],
7747
}
78-
79-
COLLECTED_PII_CHAT_RESP_ATTRS = {
80-
"tool_calls": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
48+
CHAT_RESPONSE_TEXT_SOURCES = [("text",)]
49+
CHAT_USAGE_TOKEN_SOURCES = {
50+
SPANDATA.GEN_AI_USAGE_INPUT_TOKENS: [
51+
("meta", "billed_units", "input_tokens"),
52+
("meta", "tokens", "input_tokens"),
53+
],
54+
SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS: [
55+
("meta", "billed_units", "output_tokens"),
56+
("meta", "tokens", "output_tokens"),
57+
],
8158
}
59+
STREAM_RESPONSE_SOURCES = [("response",)]
8260

8361

8462
def setup_v1(wrap_embed_fn):
8563
# type: (Callable[..., Any]) -> None
86-
"""Called from CohereIntegration.setup_once() to patch V1 Client methods."""
8764
try:
88-
from cohere.client import Client
8965
from cohere.base_client import BaseCohere
66+
from cohere.client import Client
9067
except ImportError:
9168
return
9269

@@ -104,7 +81,6 @@ def _wrap_chat(f, streaming):
10481
def new_chat(*args, **kwargs):
10582
# type: (*Any, **Any) -> Any
10683
integration = sentry_sdk.get_client().get_integration(CohereIntegration)
107-
10884
if (
10985
integration is None
11086
or "message" not in kwargs
@@ -113,6 +89,7 @@ def new_chat(*args, **kwargs):
11389
return f(*args, **kwargs)
11490

11591
model = kwargs.get("model", "")
92+
include_pii = should_send_default_pii() and integration.include_prompts
11693

11794
with sentry_sdk.start_span(
11895
op=OP.GEN_AI_CHAT,
@@ -128,75 +105,84 @@ def new_chat(*args, **kwargs):
128105
reraise(*exc_info)
129106

130107
with capture_internal_exceptions():
108+
if model:
109+
set_data_normalized(span, SPANDATA.GEN_AI_REQUEST_MODEL, model)
131110
set_input_span_data(
132111
span,
133112
kwargs,
134113
integration,
135114
{
136-
**COHERE_V1_CHAT_CONFIG,
115+
"system": "cohere",
116+
"operation": "chat",
117+
"extract_messages": _extract_messages_v1,
137118
"extra_static": {SPANDATA.GEN_AI_RESPONSE_STREAMING: streaming},
138119
},
139120
)
140121

141122
if streaming:
142-
old_iterator = res
143-
144-
def new_iterator():
145-
# type: () -> Iterator[StreamedChatResponse]
146-
with capture_internal_exceptions():
147-
for x in old_iterator:
148-
if isinstance(x, ChatStreamEndEvent) or isinstance(
149-
x, StreamEndStreamedChatResponse
150-
):
151-
collect_chat_response_fields(
152-
span,
153-
x.response,
154-
include_pii=should_send_default_pii()
155-
and integration.include_prompts,
156-
)
157-
yield x
158-
159-
return new_iterator()
160-
elif isinstance(res, NonStreamedChatResponse):
161-
collect_chat_response_fields(
162-
span,
163-
res,
164-
include_pii=should_send_default_pii()
165-
and integration.include_prompts,
166-
)
123+
return _iter_v1_stream_events(res, span, include_pii)
124+
if isinstance(res, NonStreamedChatResponse):
125+
_collect_v1_response_fields(span, res, include_pii=include_pii)
167126
else:
168127
set_data_normalized(span, "unknown_response", True)
169128
return res
170129

171-
def collect_chat_response_fields(span, res, include_pii):
172-
# type: (Span, NonStreamedChatResponse, bool) -> None
173-
if include_pii:
174-
if hasattr(res, "text"):
175-
set_data_normalized(
176-
span,
177-
SPANDATA.GEN_AI_RESPONSE_TEXT,
178-
[res.text],
179-
)
180-
for attr, spandata_key in COLLECTED_PII_CHAT_RESP_ATTRS.items():
181-
if hasattr(res, attr):
182-
set_data_normalized(span, spandata_key, getattr(res, attr))
130+
return new_chat
183131

184-
for attr, spandata_key in COLLECTED_CHAT_RESP_ATTRS.items():
185-
if hasattr(res, attr):
186-
set_data_normalized(span, spandata_key, getattr(res, attr))
187132

188-
if hasattr(res, "meta"):
189-
if hasattr(res.meta, "billed_units"):
190-
record_token_usage(
191-
span,
192-
input_tokens=res.meta.billed_units.input_tokens,
193-
output_tokens=res.meta.billed_units.output_tokens,
194-
)
195-
elif hasattr(res.meta, "tokens"):
196-
record_token_usage(
197-
span,
198-
input_tokens=res.meta.tokens.input_tokens,
199-
output_tokens=res.meta.tokens.output_tokens,
200-
)
133+
def _extract_messages_v1(kwargs):
    # type: (dict[str, Any]) -> list[dict[str, str]]
    """Build role/content dicts from the ``chat_history`` and ``message`` kwargs.

    History entries come first, in their original order. Each contributes its
    lower-cased ``role`` attribute and its ``message`` attribute, the latter
    passed through ``transform_message_content``. A truthy ``message`` kwarg is
    appended last with the role ``"user"``.

    :param kwargs: Keyword arguments of the wrapped chat call.
    :returns: List of ``{"role": ..., "content": ...}`` dicts; empty when
        neither a history nor a message is present.
    """
    messages = []
    # ``chat_history=None`` can be passed explicitly; treat it like an absent
    # history instead of failing on iteration.
    for entry in kwargs.get("chat_history") or []:
        # A role that is present but None must not break ``.lower()``.
        role = getattr(entry, "role", "") or ""
        messages.append(
            {
                "role": role.lower(),
                "content": transform_message_content(getattr(entry, "message", "")),
            }
        )
    message = kwargs.get("message")
    if message:
        messages.append({"role": "user", "content": transform_message_content(message)})
    return messages
201147

202-
return new_chat
148+
149+
def _iter_v1_stream_events(old_iterator, span, include_pii):
    # type: (Any, Any, bool) -> Iterator[StreamedChatResponse]
    """Yield every event of ``old_iterator`` unchanged, recording data on stream end.

    When an event is a ``ChatStreamEndEvent`` or a
    ``StreamEndStreamedChatResponse``, its final response is recorded on
    ``span`` via ``_collect_v1_stream_end_fields`` before the event is yielded.

    :param old_iterator: The stream returned by the wrapped chat call.
    :param span: Span that receives the response data.
    :param include_pii: Whether PII-bearing response fields may be recorded.
    """
    for event in old_iterator:
        if isinstance(event, (ChatStreamEndEvent, StreamEndStreamedChatResponse)):
            # Guard only the instrumentation. Wrapping the whole loop would
            # also swallow errors raised by the underlying stream or thrown
            # in by the consumer at ``yield``, ending the stream silently.
            with capture_internal_exceptions():
                _collect_v1_stream_end_fields(span, event, include_pii)
        yield event
158+
159+
160+
def _collect_v1_stream_end_fields(span, event, include_pii):
    # type: (Any, Any, bool) -> None
    """Record the final response carried by a stream-end ``event`` on ``span``.

    The response is looked up on ``event`` through ``STREAM_RESPONSE_SOURCES``;
    when none is found nothing is recorded.
    """
    final_response = get_first_from_sources(event, STREAM_RESPONSE_SOURCES)
    if final_response is None:
        return
    _collect_v1_response_fields(span, final_response, include_pii)
165+
166+
167+
def _collect_v1_response_fields(span, response, include_pii):
    # type: (Any, Any, bool) -> None
    """Record response attributes and token usage from ``response`` on ``span``.

    With ``include_pii`` set, the response text (wrapped in a one-element
    list) and the fields listed in ``PII_CHAT_RESPONSE_SOURCES`` are recorded.
    The fields in ``CHAT_RESPONSE_SOURCES`` and the token usage are recorded
    regardless of ``include_pii``.
    """
    if include_pii:
        response_text = get_first_from_sources(response, CHAT_RESPONSE_TEXT_SOURCES)
        if response_text is not None:
            set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, [response_text])
        set_span_data_from_sources(
            span, response, PII_CHAT_RESPONSE_SOURCES, require_truthy=False
        )

    set_span_data_from_sources(
        span, response, CHAT_RESPONSE_SOURCES, require_truthy=False
    )

    # Each token count is resolved on its own, trying its source paths in order.
    input_paths = CHAT_USAGE_TOKEN_SOURCES[SPANDATA.GEN_AI_USAGE_INPUT_TOKENS]
    input_tokens = get_first_from_sources(response, input_paths)
    output_paths = CHAT_USAGE_TOKEN_SOURCES[SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS]
    output_tokens = get_first_from_sources(response, output_paths)
    record_token_usage(span, input_tokens=input_tokens, output_tokens=output_tokens)

0 commit comments

Comments
 (0)