22from functools import wraps
33
44from sentry_sdk .ai .monitoring import record_token_usage
5- from sentry_sdk .consts import OP , SPANDATA
6- from sentry_sdk .ai .utils import set_data_normalized
75from sentry_sdk .ai .span_config import set_input_span_data
6+ from sentry_sdk .ai .utils import set_data_normalized , transform_message_content
7+ from sentry_sdk .consts import OP , SPANDATA
88
99from typing import TYPE_CHECKING
1010
1111if TYPE_CHECKING :
1212 from typing import Any , Callable , Iterator
1313 from cohere import StreamedChatResponse
14- from sentry_sdk .tracing import Span
1514
1615import sentry_sdk
17- from sentry_sdk .scope import should_send_default_pii
18- from sentry_sdk .utils import capture_internal_exceptions , reraise
19-
2016from sentry_sdk .integrations .cohere import (
2117 CohereIntegration ,
22- COLLECTED_CHAT_PARAMS ,
2318 _capture_exception ,
2419)
20+ from sentry_sdk .integrations .cohere .utils import (
21+ get_first_from_sources ,
22+ set_span_data_from_sources ,
23+ )
24+ from sentry_sdk .scope import should_send_default_pii
25+ from sentry_sdk .utils import capture_internal_exceptions , reraise
2526
2627try :
27- from cohere import (
28- ChatStreamEndEvent ,
29- NonStreamedChatResponse ,
30- )
28+ from cohere import ChatStreamEndEvent , NonStreamedChatResponse
3129
3230 try :
3331 from cohere import StreamEndStreamedChatResponse
4038except ImportError :
4139 _has_chat_types = False
4240
43- COLLECTED_PII_CHAT_PARAMS = {
44- "tools" : SPANDATA .GEN_AI_REQUEST_AVAILABLE_TOOLS ,
45- "preamble" : SPANDATA .GEN_AI_SYSTEM_INSTRUCTIONS ,
46- }
47-
48-
49- def _extract_messages_v1 (kwargs ):
50- # type: (dict[str, Any]) -> list[dict[str, str]]
51- """Extract role/content dicts from V1-style chat_history + message."""
52- messages = []
53- for x in kwargs .get ("chat_history" , []):
54- messages .append (
55- {
56- "role" : getattr (x , "role" , "" ).lower (),
57- "content" : getattr (x , "message" , "" ),
58- }
59- )
60- message = kwargs .get ("message" )
61- if message :
62- messages .append ({"role" : "user" , "content" : message })
63- return messages
64-
65-
66- COHERE_V1_CHAT_CONFIG = {
67- "system" : "cohere" ,
68- "operation" : "chat" ,
69- "params" : COLLECTED_CHAT_PARAMS ,
70- "pii_params" : COLLECTED_PII_CHAT_PARAMS ,
71- "extract_messages" : _extract_messages_v1 ,
41+ CHAT_RESPONSE_SOURCES = {
42+ SPANDATA .GEN_AI_RESPONSE_ID : [("generation_id" ,)],
43+ SPANDATA .GEN_AI_RESPONSE_FINISH_REASONS : [("finish_reason" ,)],
7244}
73-
74- COLLECTED_CHAT_RESP_ATTRS = {
75- "generation_id" : SPANDATA .GEN_AI_RESPONSE_ID ,
76- "finish_reason" : SPANDATA .GEN_AI_RESPONSE_FINISH_REASONS ,
45+ PII_CHAT_RESPONSE_SOURCES = {
46+ SPANDATA .GEN_AI_RESPONSE_TOOL_CALLS : [("tool_calls" ,)],
7747}
78-
79- COLLECTED_PII_CHAT_RESP_ATTRS = {
80- "tool_calls" : SPANDATA .GEN_AI_RESPONSE_TOOL_CALLS ,
48+ CHAT_RESPONSE_TEXT_SOURCES = [("text" ,)]
49+ CHAT_USAGE_TOKEN_SOURCES = {
50+ SPANDATA .GEN_AI_USAGE_INPUT_TOKENS : [
51+ ("meta" , "billed_units" , "input_tokens" ),
52+ ("meta" , "tokens" , "input_tokens" ),
53+ ],
54+ SPANDATA .GEN_AI_USAGE_OUTPUT_TOKENS : [
55+ ("meta" , "billed_units" , "output_tokens" ),
56+ ("meta" , "tokens" , "output_tokens" ),
57+ ],
8158}
59+ STREAM_RESPONSE_SOURCES = [("response" ,)]
8260
8361
8462def setup_v1 (wrap_embed_fn ):
8563 # type: (Callable[..., Any]) -> None
86- """Called from CohereIntegration.setup_once() to patch V1 Client methods."""
8764 try :
88- from cohere .client import Client
8965 from cohere .base_client import BaseCohere
66+ from cohere .client import Client
9067 except ImportError :
9168 return
9269
@@ -104,7 +81,6 @@ def _wrap_chat(f, streaming):
10481 def new_chat (* args , ** kwargs ):
10582 # type: (*Any, **Any) -> Any
10683 integration = sentry_sdk .get_client ().get_integration (CohereIntegration )
107-
10884 if (
10985 integration is None
11086 or "message" not in kwargs
@@ -113,6 +89,7 @@ def new_chat(*args, **kwargs):
11389 return f (* args , ** kwargs )
11490
11591 model = kwargs .get ("model" , "" )
92+ include_pii = should_send_default_pii () and integration .include_prompts
11693
11794 with sentry_sdk .start_span (
11895 op = OP .GEN_AI_CHAT ,
@@ -128,75 +105,84 @@ def new_chat(*args, **kwargs):
128105 reraise (* exc_info )
129106
130107 with capture_internal_exceptions ():
108+ if model :
109+ set_data_normalized (span , SPANDATA .GEN_AI_REQUEST_MODEL , model )
131110 set_input_span_data (
132111 span ,
133112 kwargs ,
134113 integration ,
135114 {
136- ** COHERE_V1_CHAT_CONFIG ,
115+ "system" : "cohere" ,
116+ "operation" : "chat" ,
117+ "extract_messages" : _extract_messages_v1 ,
137118 "extra_static" : {SPANDATA .GEN_AI_RESPONSE_STREAMING : streaming },
138119 },
139120 )
140121
141122 if streaming :
142- old_iterator = res
143-
144- def new_iterator ():
145- # type: () -> Iterator[StreamedChatResponse]
146- with capture_internal_exceptions ():
147- for x in old_iterator :
148- if isinstance (x , ChatStreamEndEvent ) or isinstance (
149- x , StreamEndStreamedChatResponse
150- ):
151- collect_chat_response_fields (
152- span ,
153- x .response ,
154- include_pii = should_send_default_pii ()
155- and integration .include_prompts ,
156- )
157- yield x
158-
159- return new_iterator ()
160- elif isinstance (res , NonStreamedChatResponse ):
161- collect_chat_response_fields (
162- span ,
163- res ,
164- include_pii = should_send_default_pii ()
165- and integration .include_prompts ,
166- )
123+ return _iter_v1_stream_events (res , span , include_pii )
124+ if isinstance (res , NonStreamedChatResponse ):
125+ _collect_v1_response_fields (span , res , include_pii = include_pii )
167126 else :
168127 set_data_normalized (span , "unknown_response" , True )
169128 return res
170129
171- def collect_chat_response_fields (span , res , include_pii ):
172- # type: (Span, NonStreamedChatResponse, bool) -> None
173- if include_pii :
174- if hasattr (res , "text" ):
175- set_data_normalized (
176- span ,
177- SPANDATA .GEN_AI_RESPONSE_TEXT ,
178- [res .text ],
179- )
180- for attr , spandata_key in COLLECTED_PII_CHAT_RESP_ATTRS .items ():
181- if hasattr (res , attr ):
182- set_data_normalized (span , spandata_key , getattr (res , attr ))
130+ return new_chat
183131
184- for attr , spandata_key in COLLECTED_CHAT_RESP_ATTRS .items ():
185- if hasattr (res , attr ):
186- set_data_normalized (span , spandata_key , getattr (res , attr ))
187132
188- if hasattr (res , "meta" ):
189- if hasattr (res .meta , "billed_units" ):
190- record_token_usage (
191- span ,
192- input_tokens = res .meta .billed_units .input_tokens ,
193- output_tokens = res .meta .billed_units .output_tokens ,
194- )
195- elif hasattr (res .meta , "tokens" ):
196- record_token_usage (
197- span ,
198- input_tokens = res .meta .tokens .input_tokens ,
199- output_tokens = res .meta .tokens .output_tokens ,
200- )
133+ def _extract_messages_v1 (kwargs ):
134+ # type: (dict[str, Any]) -> list[dict[str, str]]
135+ messages = []
136+ for x in kwargs .get ("chat_history" , []):
137+ messages .append (
138+ {
139+ "role" : getattr (x , "role" , "" ).lower (),
140+ "content" : transform_message_content (getattr (x , "message" , "" )),
141+ }
142+ )
143+ message = kwargs .get ("message" )
144+ if message :
145+ messages .append ({"role" : "user" , "content" : transform_message_content (message )})
146+ return messages
201147
202- return new_chat
148+
149+ def _iter_v1_stream_events (old_iterator , span , include_pii ):
150+ # type: (Any, Any, bool) -> Iterator[StreamedChatResponse]
151+ with capture_internal_exceptions ():
152+ for x in old_iterator :
153+ if isinstance (x , ChatStreamEndEvent ) or isinstance (
154+ x , StreamEndStreamedChatResponse
155+ ):
156+ _collect_v1_stream_end_fields (span , x , include_pii )
157+ yield x
158+
159+
160+ def _collect_v1_stream_end_fields (span , event , include_pii ):
161+ # type: (Any, Any, bool) -> None
162+ response = get_first_from_sources (event , STREAM_RESPONSE_SOURCES )
163+ if response is not None :
164+ _collect_v1_response_fields (span , response , include_pii )
165+
166+
167+ def _collect_v1_response_fields (span , response , include_pii ):
168+ # type: (Any, Any, bool) -> None
169+ if include_pii :
170+ text = get_first_from_sources (response , CHAT_RESPONSE_TEXT_SOURCES )
171+ if text is not None :
172+ set_data_normalized (span , SPANDATA .GEN_AI_RESPONSE_TEXT , [text ])
173+ set_span_data_from_sources (
174+ span , response , PII_CHAT_RESPONSE_SOURCES , require_truthy = False
175+ )
176+
177+ set_span_data_from_sources (
178+ span , response , CHAT_RESPONSE_SOURCES , require_truthy = False
179+ )
180+ record_token_usage (
181+ span ,
182+ input_tokens = get_first_from_sources (
183+ response , CHAT_USAGE_TOKEN_SOURCES [SPANDATA .GEN_AI_USAGE_INPUT_TOKENS ]
184+ ),
185+ output_tokens = get_first_from_sources (
186+ response , CHAT_USAGE_TOKEN_SOURCES [SPANDATA .GEN_AI_USAGE_OUTPUT_TOKENS ]
187+ ),
188+ )
0 commit comments