diff --git a/safetytooling/apis/inference/anthropic.py b/safetytooling/apis/inference/anthropic.py
index 73154f2..db33d62 100644
--- a/safetytooling/apis/inference/anthropic.py
+++ b/safetytooling/apis/inference/anthropic.py
@@ -305,7 +305,7 @@ async def __call__(
         # Safely extract text and thinking content
         text_content = None
-        reasoning_content = None  # We can extract this even if not used by LLMResponse yet
+        reasoning_content = None
         if content:
             for block in content:
                 if block.type == "text" and hasattr(block, "text"):
diff --git a/safetytooling/apis/inference/openai/batch_api.py b/safetytooling/apis/inference/openai/batch_api.py
index 658e191..4386663 100644
--- a/safetytooling/apis/inference/openai/batch_api.py
+++ b/safetytooling/apis/inference/openai/batch_api.py
@@ -123,6 +123,7 @@ async def __call__(
                     api_duration=None,
                     cost=0,
                     batch_custom_id=result["custom_id"],
+                    reasoning_content=choice["message"].get("reasoning_content", None),
                 )

         responses = []
diff --git a/safetytooling/apis/inference/openai/chat.py b/safetytooling/apis/inference/openai/chat.py
index a87557c..cac3a87 100644
--- a/safetytooling/apis/inference/openai/chat.py
+++ b/safetytooling/apis/inference/openai/chat.py
@@ -117,11 +117,29 @@ async def _make_api_call(self, prompt: Prompt, model_id, start_time, **kwargs) -
             )
         else:
             api_func = self.aclient.chat.completions.create
-        api_response: openai.types.chat.ChatCompletion = await api_func(
-            messages=prompt.openai_format(),
-            model=model_id,
-            **kwargs,
-        )
+
+        original_base_url = self.aclient.base_url
+        try:
+            if model_id in {"deepseek-chat", "deepseek-reasoner"}:
+                if prompt.is_last_message_assistant():
+                    # Use the beta endpoint for assistant-prefilled prompts with DeepSeek
+                    self.aclient.base_url = "https://api.deepseek.com/beta"
+                else:
+                    # Use the standard v1 endpoint otherwise
+                    self.aclient.base_url = "https://api.deepseek.com/v1"
+                messages = prompt.deepseek_format()
+            else:
+                messages = prompt.openai_format()
+
+            api_response: openai.types.chat.ChatCompletion = await api_func(
+                messages=messages,
+                model=model_id,
+                **kwargs,
+            )
+        finally:
+            # Always revert the base_url after the call
+            self.aclient.base_url = original_base_url
+
         if hasattr(api_response, "error") and (
             "Rate limit exceeded" in api_response.error["message"] or api_response.error["code"] == 429
         ):  # OpenRouter routes through the error messages from the different providers, so we catch them here
@@ -160,6 +178,7 @@ async def _make_api_call(self, prompt: Prompt, model_id, start_time, **kwargs) -
                 duration=duration,
                 cost=context_cost + count_tokens(choice.message.content, model_id) * completion_token_cost,
                 logprobs=(self.convert_top_logprobs(choice.logprobs) if choice.logprobs is not None else None),
+                reasoning_content=getattr(choice.message, "reasoning_content", None),
             )
         )
         self.add_response_to_prompt_file(prompt_file, responses)
diff --git a/safetytooling/data_models/messages.py b/safetytooling/data_models/messages.py
index 655f898..9569aa8 100644
--- a/safetytooling/data_models/messages.py
+++ b/safetytooling/data_models/messages.py
@@ -26,6 +26,7 @@
 PRINT_COLORS = {
     "user": "cyan",
     "system": "magenta",
+    "developer": "magenta",
     "assistant": "light_green",
     "audio": "yellow",
     "image": "yellow",
@@ -36,6 +37,7 @@
 class MessageRole(str, Enum):
     user = "user"
     system = "system"
+    developer = "developer"  # Replacement for the system role in OpenAI o1 models
     assistant = "assistant"
     audio = "audio"
     image = "image"
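
A minimal usage sketch of the surface this patch adds. The InferenceAPI entry point, its cache_dir argument, and the exact call signature do not appear in this diff and are assumptions drawn from the surrounding repo; only LLMResponse.reasoning_content and the DeepSeek routing are introduced above:

    import asyncio
    from pathlib import Path

    from safetytooling.apis import InferenceAPI
    from safetytooling.data_models import ChatMessage, MessageRole, Prompt


    async def main() -> None:
        # InferenceAPI and cache_dir are assumed from the repo, not this patch.
        api = InferenceAPI(cache_dir=Path(".cache"))
        prompt = Prompt(
            messages=[ChatMessage(role=MessageRole.user, content="What is 17 * 23?")]
        )
        # "deepseek-reasoner" goes through the chat completions path patched
        # above, which swaps in the DeepSeek base URL for the duration of the call.
        responses = await api(model_id="deepseek-reasoner", prompt=prompt)
        for response in responses:
            # reasoning_content carries the model's chain of thought when the
            # provider returns one, and is None otherwise.
            print(response.reasoning_content)
            print(response.completion)


    asyncio.run(main())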