diff --git a/nemoguardrails/actions/llm/utils.py b/nemoguardrails/actions/llm/utils.py index b71adb76c..1cf6fc388 100644 --- a/nemoguardrails/actions/llm/utils.py +++ b/nemoguardrails/actions/llm/utils.py @@ -110,6 +110,7 @@ async def llm_call( generation_llm, prompt, all_callbacks ) + _store_reasoning_traces(response) _store_tool_calls(response) _store_response_metadata(response) return _extract_content(response) @@ -172,6 +173,18 @@ def _convert_messages_to_langchain_format(prompt: List[dict]) -> List: return dicts_to_messages(prompt) +def _store_reasoning_traces(response) -> None: + if hasattr(response, "additional_kwargs"): + additional_kwargs = response.additional_kwargs + if ( + isinstance(additional_kwargs, dict) + and "reasoning_content" in additional_kwargs + ): + reasoning_content = additional_kwargs["reasoning_content"] + if reasoning_content: + reasoning_trace_var.set(reasoning_content) + + def _store_tool_calls(response) -> None: """Extract and store tool calls from response in context.""" tool_calls = getattr(response, "tool_calls", None) @@ -192,15 +205,6 @@ def _store_response_metadata(response) -> None: metadata[field_name] = getattr(response, field_name) llm_response_metadata_var.set(metadata) - if hasattr(response, "additional_kwargs"): - additional_kwargs = response.additional_kwargs - if ( - isinstance(additional_kwargs, dict) - and "reasoning_content" in additional_kwargs - ): - reasoning_content = additional_kwargs["reasoning_content"] - if reasoning_content: - reasoning_trace_var.set(reasoning_content) else: llm_response_metadata_var.set(None) @@ -704,6 +708,12 @@ def extract_tool_calls_from_events(events: list) -> Optional[list]: return None +def extract_bot_thinking_from_events(events: list): + for event in events: + if event.get("type") == "BotThinking": + return event.get("content") + + def get_and_clear_response_metadata_contextvar() -> Optional[dict]: """Get the current response metadata and clear it from the context. diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index fe56bcf08..e736a32df 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -43,8 +43,8 @@ from nemoguardrails.actions.llm.generation import LLMGenerationActions from nemoguardrails.actions.llm.utils import ( + extract_bot_thinking_from_events, extract_tool_calls_from_events, - get_and_clear_reasoning_trace_contextvar, get_and_clear_response_metadata_contextvar, get_colang_history, ) @@ -1037,7 +1037,7 @@ async def generate_async( else: res = GenerationResponse(response=[new_message]) - if reasoning_trace := get_and_clear_reasoning_trace_contextvar(): + if reasoning_trace := extract_bot_thinking_from_events(events): if prompt: # For prompt mode, response should be a string if isinstance(res.response, str): @@ -1182,7 +1182,7 @@ async def generate_async( else: # If a prompt is used, we only return the content of the message. - if reasoning_trace := get_and_clear_reasoning_trace_contextvar(): + if reasoning_trace := extract_bot_thinking_from_events(events): new_message["content"] = reasoning_trace + new_message["content"] if prompt: diff --git a/tests/test_reasoning_trace_extraction.py b/tests/test_reasoning_trace_extraction.py new file mode 100644 index 000000000..5c3be2436 --- /dev/null +++ b/tests/test_reasoning_trace_extraction.py @@ -0,0 +1,306 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock + +import pytest +from langchain_core.messages import AIMessage + +from nemoguardrails.actions.llm.utils import _store_reasoning_traces +from nemoguardrails.context import reasoning_trace_var + + +class TestStoreReasoningTracesUnit: + def test_store_reasoning_traces_with_valid_reasoning_content(self): + test_reasoning = "Step 1: Analyze the question\nStep 2: Formulate response" + + response = AIMessage( + content="The answer is 42", + additional_kwargs={"reasoning_content": test_reasoning}, + ) + + _store_reasoning_traces(response) + + stored_trace = reasoning_trace_var.get() + assert stored_trace == test_reasoning + + reasoning_trace_var.set(None) + + def test_store_reasoning_traces_with_empty_reasoning_content(self): + response = AIMessage( + content="Response", additional_kwargs={"reasoning_content": ""} + ) + + reasoning_trace_var.set(None) + _store_reasoning_traces(response) + + stored_trace = reasoning_trace_var.get() + assert stored_trace is None + + reasoning_trace_var.set(None) + + def test_store_reasoning_traces_with_none_reasoning_content(self): + response = AIMessage( + content="Response", additional_kwargs={"reasoning_content": None} + ) + + reasoning_trace_var.set(None) + _store_reasoning_traces(response) + + stored_trace = reasoning_trace_var.get() + assert stored_trace is None + + reasoning_trace_var.set(None) + + def test_store_reasoning_traces_without_reasoning_content_key(self): + response = AIMessage( + content="Response", additional_kwargs={"other_key": "other_value"} + ) + + reasoning_trace_var.set(None) + _store_reasoning_traces(response) + + stored_trace = reasoning_trace_var.get() + assert stored_trace is None + + reasoning_trace_var.set(None) + + def test_store_reasoning_traces_with_empty_additional_kwargs(self): + response = AIMessage(content="Response", additional_kwargs={}) + + reasoning_trace_var.set(None) + _store_reasoning_traces(response) + + stored_trace = reasoning_trace_var.get() + assert stored_trace is None + + reasoning_trace_var.set(None) + + def test_store_reasoning_traces_without_additional_kwargs_attribute(self): + class SimpleResponse: + def __init__(self, content): + self.content = content + + response = SimpleResponse("Response") + + reasoning_trace_var.set(None) + _store_reasoning_traces(response) + + stored_trace = reasoning_trace_var.get() + assert stored_trace is None + + reasoning_trace_var.set(None) + + def test_store_reasoning_traces_with_non_dict_additional_kwargs(self): + class ResponseWithInvalidKwargs: + def __init__(self): + self.content = "Response" + self.additional_kwargs = "not_a_dict" + + response = ResponseWithInvalidKwargs() + + reasoning_trace_var.set(None) + _store_reasoning_traces(response) + + stored_trace = reasoning_trace_var.get() + assert stored_trace is None + + reasoning_trace_var.set(None) + + def test_store_reasoning_traces_overwrites_previous_trace(self): + initial_trace = "Initial reasoning" + new_trace = "New reasoning" + + reasoning_trace_var.set(initial_trace) + + response = AIMessage( + content="Response", additional_kwargs={"reasoning_content": new_trace} + ) + + _store_reasoning_traces(response) + + stored_trace = reasoning_trace_var.get() + assert stored_trace == new_trace + assert stored_trace != initial_trace + + reasoning_trace_var.set(None) + + def test_store_reasoning_traces_with_multiline_content(self): + multiline_reasoning = """Thought process: +1. First, understand the user's intent +2. Second, check available data +3. Third, formulate a response +4. Finally, validate the response""" + + response = AIMessage( + content="Response", + additional_kwargs={"reasoning_content": multiline_reasoning}, + ) + + _store_reasoning_traces(response) + + stored_trace = reasoning_trace_var.get() + assert stored_trace == multiline_reasoning + + reasoning_trace_var.set(None) + + def test_store_reasoning_traces_with_special_characters(self): + special_reasoning = "Thinking: Let's analyze this with \"quotes\" and 'apostrophes' & symbols!" + + response = AIMessage( + content="Response", + additional_kwargs={"reasoning_content": special_reasoning}, + ) + + _store_reasoning_traces(response) + + stored_trace = reasoning_trace_var.get() + assert stored_trace == special_reasoning + + reasoning_trace_var.set(None) + + +class TestReasoningTraceIntegration: + @pytest.mark.asyncio + async def test_llm_call_extracts_reasoning_from_additional_kwargs(self): + test_reasoning = "Let me think about this carefully..." + + mock_llm = AsyncMock() + mock_response = AIMessage( + content="The answer is 42", + additional_kwargs={"reasoning_content": test_reasoning}, + ) + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + + from nemoguardrails.actions.llm.utils import llm_call + + reasoning_trace_var.set(None) + result = await llm_call(mock_llm, "What is the answer?") + + assert result == "The answer is 42" + stored_trace = reasoning_trace_var.get() + assert stored_trace == test_reasoning + + reasoning_trace_var.set(None) + + @pytest.mark.asyncio + async def test_llm_call_handles_missing_reasoning_content(self): + mock_llm = AsyncMock() + mock_response = AIMessage(content="Regular response", additional_kwargs={}) + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + + from nemoguardrails.actions.llm.utils import llm_call + + reasoning_trace_var.set(None) + result = await llm_call(mock_llm, "Hello") + + assert result == "Regular response" + stored_trace = reasoning_trace_var.get() + assert stored_trace is None + + reasoning_trace_var.set(None) + + @pytest.mark.asyncio + async def test_llm_call_with_message_list_extracts_reasoning(self): + test_reasoning = "Analyzing the conversation context..." + + mock_llm = AsyncMock() + mock_response = AIMessage( + content="Here's my response", + additional_kwargs={"reasoning_content": test_reasoning}, + ) + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + + from nemoguardrails.actions.llm.utils import llm_call + + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"}, + ] + + reasoning_trace_var.set(None) + result = await llm_call(mock_llm, messages) + + assert result == "Here's my response" + stored_trace = reasoning_trace_var.get() + assert stored_trace == test_reasoning + + reasoning_trace_var.set(None) + + @pytest.mark.asyncio + async def test_multiple_llm_calls_preserve_separate_reasoning_traces(self): + first_reasoning = "First analysis" + second_reasoning = "Second analysis" + + mock_llm = AsyncMock() + call_count = 0 + + async def mock_ainvoke(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return AIMessage( + content="First response", + additional_kwargs={"reasoning_content": first_reasoning}, + ) + else: + return AIMessage( + content="Second response", + additional_kwargs={"reasoning_content": second_reasoning}, + ) + + mock_llm.ainvoke = mock_ainvoke + + from nemoguardrails.actions.llm.utils import llm_call + + reasoning_trace_var.set(None) + result1 = await llm_call(mock_llm, "First query") + trace1 = reasoning_trace_var.get() + + reasoning_trace_var.set(None) + result2 = await llm_call(mock_llm, "Second query") + trace2 = reasoning_trace_var.get() + + assert trace1 == first_reasoning + assert trace2 == second_reasoning + + reasoning_trace_var.set(None) + + @pytest.mark.asyncio + async def test_reasoning_content_with_other_additional_kwargs(self): + test_reasoning = "Complex reasoning process" + + mock_llm = AsyncMock() + mock_response = AIMessage( + content="Response", + additional_kwargs={ + "reasoning_content": test_reasoning, + "model": "test-model", + "finish_reason": "stop", + "other_metadata": {"key": "value"}, + }, + ) + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + + from nemoguardrails.actions.llm.utils import llm_call + + reasoning_trace_var.set(None) + result = await llm_call(mock_llm, "Query") + + assert result == "Response" + stored_trace = reasoning_trace_var.get() + assert stored_trace == test_reasoning + + reasoning_trace_var.set(None)