diff --git a/backend/openedx_ai_extensions/processors/llm/litellm_base_processor.py b/backend/openedx_ai_extensions/processors/llm/litellm_base_processor.py
index d8786c1e..97080ce0 100644
--- a/backend/openedx_ai_extensions/processors/llm/litellm_base_processor.py
+++ b/backend/openedx_ai_extensions/processors/llm/litellm_base_processor.py
@@ -60,10 +60,6 @@ def __init__(self, config=None, user_session=None):
         if functions_schema_filtered:
             self.extra_params["tools"] = functions_schema_filtered
 
-        if self.stream and "tools" in self.extra_params:
-            logger.warning("Streaming responses with tools is not supported; disabling streaming.")
-            self.stream = False
-
         self.mcp_configs = {}
         allowed_mcp_configs = self.config.get("mcp_configs", [])
         if allowed_mcp_configs:
diff --git a/backend/openedx_ai_extensions/processors/llm/llm_processor.py b/backend/openedx_ai_extensions/processors/llm/llm_processor.py
index b5083b58..4852f148 100644
--- a/backend/openedx_ai_extensions/processors/llm/llm_processor.py
+++ b/backend/openedx_ai_extensions/processors/llm/llm_processor.py
@@ -127,33 +127,90 @@ def _build_response_api_params(self, system_role=None):
         return params
 
-    def _yield_threaded_stream(self, response):
+    def _yield_threaded_stream(self, response, params=None):
         """
         Helper generator to handle streaming logic for threaded responses.
+        Handles tool call execution recursively during streaming.
         """
         total_tokens = None
-        try:
+        try:  # pylint: disable=R1702
             for chunk in response:
-                if hasattr(chunk, "usage") and chunk.usage:
+                # Track token usage from the response metadata
+                chunk_response = getattr(chunk, "response", None)
+                if chunk_response and hasattr(chunk_response, "usage") and chunk_response.usage:
+                    total_tokens = chunk_response.usage.total_tokens
+                elif hasattr(chunk, "usage") and chunk.usage:
                     total_tokens = chunk.usage.total_tokens
 
-                if getattr(chunk, "response", None):
-                    resp = getattr(chunk, "response", None)
-                    if resp is not None:
-                        response_id = getattr(resp, "id", None)
-                        self.user_session.remote_response_id = response_id
-                        self.user_session.save()
+                # Persist thread ID for ongoing conversation context
+                if chunk_response:
+                    if chunk_response is not None and self.user_session:
+                        response_id = getattr(chunk_response, "id", None)
+                        if response_id:
+                            self.user_session.remote_response_id = response_id
+                            self.user_session.save()
+
+                # Stream text deltas, filtering out empty JSON artifacts
                 if hasattr(chunk, "delta"):
-                    yield chunk.delta
+                    yield chunk.delta if chunk.delta != "{}" else ""
+
+                # Check for completed tool call requests
+                chunk_type = getattr(chunk, "type", None)
+                if chunk_type == "response.output_item.done":
+                    item = chunk.item
+                    if getattr(item, "type", None) == "function_call":
+                        logger.info(f"[LLM STREAM] Intercepted tool call: {item.name}")
+
+                        function_name = item.name
+                        call_id = item.call_id
+                        arguments_str = item.arguments
+
+                        # Add function call intent to history (required for API consistency)
+                        if params is not None:
+                            params["input"].append({
+                                "type": "function_call",
+                                "call_id": call_id,
+                                "name": function_name,
+                                "arguments": arguments_str
+                            })
+
+                        # Parse arguments and execute the local function
+                        try:
+                            function_args = json.loads(arguments_str)
+                        except json.JSONDecodeError:
+                            function_args = {}
+
+                        if function_name in AVAILABLE_TOOLS:
+                            func = AVAILABLE_TOOLS[function_name]
+                            try:
+                                tool_output = func(**function_args)
+                            except Exception as e:  # pylint: disable=broad-exception-caught
+                                tool_output = f"Error: {str(e)}"
+                        else:
+                            tool_output = "Error: Tool not found."
+
+                        if params is None:
+                            yield "\n[System Error: Missing params]".encode("utf-8")
+                            return
+
+                        # Add tool output to history and re-trigger LLM response
+                        params["input"].append({
+                            "type": "function_call_output",
+                            "call_id": call_id,
+                            "output": str(tool_output)
+                        })
+
+                        # Recursively stream the new response interpreting the tool output
+                        new_response = responses(**params)
+                        yield from self._yield_threaded_stream(new_response, params)
+                        return
 
             if total_tokens is not None:
                 logger.info(f"[LLM STREAM] Tokens used: {total_tokens}")
-            else:
-                logger.info("[LLM STREAM] Tokens used: unknown (model did not report)")
         except Exception as e:  # pylint: disable=broad-exception-caught
-            logger.error(f"Error during threaded AI streaming: {e}", exc_info=True)
-            yield f"\n[AI Error: {e}]"
+            logger.error(f"Error: {e}", exc_info=True)
+            yield f"\n[AI Error: {e}]".encode("utf-8")
 
     def _call_responses_wrapper(self, params, initialize=False):
         """
@@ -163,7 +220,7 @@ def _call_responses_wrapper(self, params, initialize=False):
         """
         response = self._responses_with_tools(tool_calls=[], params=params)
         if params["stream"]:
-            return self._yield_threaded_stream(response)
+            return self._yield_threaded_stream(response, params=params)
 
         response_id = getattr(response, "id", None)
         content = self._extract_response_content(response=response)
@@ -245,12 +302,26 @@ def _completion_with_tools(self, tool_calls, params):
         """Handle tool calls recursively until no more tool calls are present."""
         for tool_call in tool_calls:
             function_name = tool_call.function.name
+
+            # Ensure tool exists
+            if function_name not in AVAILABLE_TOOLS:
+                logger.error(f"Tool '{function_name}' requested by LLM but not available locally.")
+                continue
+
             function_to_call = AVAILABLE_TOOLS[function_name]
-            function_args = json.loads(tool_call.function.arguments)
+            logger.info(f"[LLM] Tool call: {function_to_call}")
+
+            try:
+                function_args = json.loads(tool_call.function.arguments)
+                function_response = function_to_call(**function_args)
+                logger.info(f"[LLM] Response from tool call: {function_response}")
+            except json.JSONDecodeError:
+                function_response = "Error: Invalid JSON arguments provided."
+                logger.error(f"Failed to parse JSON arguments for {function_name}")
+            except Exception as e:  # pylint: disable=broad-exception-caught
+                function_response = f"Error executing tool: {str(e)}"
+                logger.error(f"Error executing tool {function_name}: {e}")
 
-            function_response = function_to_call(
-                **function_args,
-            )
             params["messages"].append(
                 {
                     "tool_call_id": tool_call.id,
@@ -263,10 +334,9 @@ def _completion_with_tools(self, tool_calls, params):
         # Call completion again with updated messages
         response = completion(**params)
 
-        # For streaming, return the generator immediately
-        # Tool calls are not supported in streaming mode
+        # For streaming, we need to handle the stream to detect tool calls
        if params.get("stream"):
-            return response
+            return self._handle_streaming_tool_calls(response, params)
 
         # For non-streaming, check for tool calls and handle recursively
         new_tool_calls = response.choices[0].message.tool_calls
@@ -276,6 +346,99 @@ def _completion_with_tools(self, tool_calls, params):
 
         return response
 
+    def _handle_streaming_tool_calls(self, response, params):
+        """
+        Generator that handles streaming responses containing tool calls.
+        It accumulates tool call chunks, executes them, and recursively calls completion.
+ """ + tool_calls_buffer = {} # index -> {id, function: {name, arguments}} + accumulating_tools = False + logger.info("[LLM STREAM] Streaming tool calls") + + for chunk in response: + delta = chunk.choices[0].delta + + # If there is content, yield it immediately to the user + if delta.content: + yield chunk + + # If there are tool calls, buffer them + if delta.tool_calls: + if not accumulating_tools: + logger.info("[AI STREAM] Start: buffer function") + accumulating_tools = True + for tc_chunk in delta.tool_calls: + idx = tc_chunk.index + + if idx not in tool_calls_buffer: + tool_calls_buffer[idx] = { + "id": "", + "type": "function", + "function": {"name": "", "arguments": ""} + } + + if tc_chunk.id: + tool_calls_buffer[idx]["id"] += tc_chunk.id + + if tc_chunk.function: + if tc_chunk.function.name: + tool_calls_buffer[idx]["function"]["name"] += tc_chunk.function.name + if tc_chunk.function.arguments: + tool_calls_buffer[idx]["function"]["arguments"] += tc_chunk.function.arguments + + # If we accumulated tool calls, reconstruct them and recurse + if accumulating_tools and tool_calls_buffer: + + # Helper classes to mimic the object structure LiteLLM expects in _completion_with_tools + class FunctionMock: + def __init__(self, name, arguments): + self.name = name + self.arguments = arguments + + class ToolCallMock: + def __init__(self, t_id, name, arguments): + self.id = t_id + self.function = FunctionMock(name, arguments) + self.type = "function" + + # Prepare list for the recursive call + reconstructed_tool_calls = [] + + # Prepare message to append to history (as dict for JSON serialization) + assistant_message_tool_calls = [] + + for idx in sorted(tool_calls_buffer.keys()): + data = tool_calls_buffer[idx] + + # Create object for internal logic + tc_obj = ToolCallMock( + t_id=data['id'], + name=data['function']['name'], + arguments=data['function']['arguments'] + ) + reconstructed_tool_calls.append(tc_obj) + + # Create dict for history + assistant_message_tool_calls.append({ + "id": data['id'], + "type": "function", + "function": { + "name": data['function']['name'], + "arguments": data['function']['arguments'] + } + }) + + # Append the Assistant's intent to call tools to the history + params["messages"].append({ + "role": "assistant", + "content": None, + "tool_calls": assistant_message_tool_calls + }) + + # Recursively call completion with the reconstructed tools + # yield from delegates the generation of the next stream (result of tool) to this generator + yield from self._completion_with_tools(reconstructed_tool_calls, params) + def _responses_with_tools(self, tool_calls, params): """Handle tool calls recursively until no more tool calls are present.""" for tool_call in tool_calls: diff --git a/backend/tests/test_litellm_base_processor.py b/backend/tests/test_litellm_base_processor.py index e3ce0a75..5b03a85b 100644 --- a/backend/tests/test_litellm_base_processor.py +++ b/backend/tests/test_litellm_base_processor.py @@ -577,39 +577,6 @@ def test_non_string_provider_raises_error(mock_settings): # pylint: disable=unu LitellmProcessor(config=config, user_session=None) -# ============================================================================ -# Streaming with Tools Tests -# ============================================================================ - - -@patch.object(settings, "AI_EXTENSIONS", new_callable=lambda: { - "default": { - "MODEL": "openai/gpt-4", - } -}) -@pytest.mark.django_db -def test_streaming_with_tools_disables_streaming(mock_settings): # pylint: 
-    """
-    Test that streaming is disabled when tools are enabled.
-    """
-    config = {
-        "LitellmProcessor": {
-            "stream": True,
-            "enabled_tools": ["roll_dice"],
-        }
-    }
-    with patch('openedx_ai_extensions.processors.llm.litellm_base_processor.logger') as mock_logger:
-        processor = LitellmProcessor(config=config, user_session=None)
-
-        # Verify streaming was disabled
-        assert processor.stream is False
-
-        # Verify warning was logged
-        mock_logger.warning.assert_called_once_with(
-            "Streaming responses with tools is not supported; disabling streaming."
-        )
-
-
 # ============================================================================
 # MCP Configs Tests
 # ============================================================================
diff --git a/backend/tests/test_llm_processor.py b/backend/tests/test_llm_processor.py
index 2e2adb10..5c40e07b 100644
--- a/backend/tests/test_llm_processor.py
+++ b/backend/tests/test_llm_processor.py
@@ -9,6 +9,7 @@
 from opaque_keys.edx.keys import CourseKey
 from opaque_keys.edx.locator import BlockUsageLocator
 
+from openedx_ai_extensions.functions.decorators import AVAILABLE_TOOLS
 from openedx_ai_extensions.processors.llm.llm_processor import LLMProcessor
 from openedx_ai_extensions.workflows.models import AIWorkflowProfile, AIWorkflowScope, AIWorkflowSession
 
@@ -89,8 +90,9 @@ def llm_processor(user_session, settings):  # pylint: disable=redefined-outer-na
 class MockDelta:
     """Mock for the delta object in a streaming chunk."""
 
-    def __init__(self, content):
+    def __init__(self, content, tool_calls=None):
         self.content = content
+        self.tool_calls = tool_calls
 
 
 class MockChoice:
@@ -130,6 +132,26 @@ def __init__(self, content, is_stream=True):
         ]
 
 
+class MockUsage:
+    """Mock for usage statistics."""
+    def __init__(self, total_tokens=10):
+        self.total_tokens = total_tokens
+
+
+class MockStreamChunk:
+    """Mock for a streaming chunk."""
+    def __init__(self, content, is_delta=True):
+        self.usage = MockUsage(total_tokens=5)
+        self.delta = None
+        self.choices = []
+
+        if is_delta:
+            mock_delta = MockDelta(content)
+            self.choices = [MockChoice(delta=mock_delta)]
+            self.delta = mock_delta
+        self.response = Mock(id="stream-id-123")
+
+
 # ============================================================================
 # Non-Streaming Tests (Standard)
 # ============================================================================
@@ -713,3 +735,225 @@ def test_call_with_custom_prompt_missing_prompt_raises_error(
 
     with pytest.raises(ValueError, match="Custom prompt not provided in configuration"):
         processor.process(input_data="Test input")
+
+
+# ============================================================================
+# Streaming Tool Call Tests
+# ============================================================================
+
+class MockToolStreamChunk:
+    """
+    Helper for simulating tool call chunks in a stream.
+    Structure follows: chunk.choices[0].delta.tool_calls[...]
+    """
+
+    def __init__(self, index, tool_id=None, name=None, arguments=None):
+        self.usage = MockUsage(total_tokens=5)
+
+        # 1. Create the function mock
+        func_mock = Mock()
+        func_mock.name = name
+        func_mock.arguments = arguments
+
+        # 2. Create the tool_call mock
+        tool_call_mock = Mock()
+        tool_call_mock.index = index
+        tool_call_mock.id = tool_id
+        tool_call_mock.function = func_mock
+
+        # Construct the delta
+        delta = MockDelta(content=None, tool_calls=[tool_call_mock])
+
+        # Construct the choice
+        self.choices = [MockChoice(delta=delta)]
+
+
+@pytest.mark.django_db
+@patch("openedx_ai_extensions.processors.llm.llm_processor.completion")
+@patch("openedx_ai_extensions.processors.llm.llm_processor.adapt_to_provider")
+def test_streaming_tool_execution_recursion(
+    mock_adapt, mock_completion, llm_processor  # pylint: disable=redefined-outer-name
+):
+    """
+    Test that streaming correctly handles tool calls:
+    1. Buffers tool call chunks.
+    2. Executes the tool.
+    3. Recursively calls completion with tool output.
+    4. Yields the final content chunks.
+    """
+    # 1. Setup
+    mock_adapt.side_effect = lambda provider, params, **kwargs: params
+
+    # Configure processor for streaming + custom function calling
+    llm_processor.config["function"] = "summarize_content"  # Uses _call_completion_wrapper
+    llm_processor.config["stream"] = True
+    llm_processor.stream = True
+    llm_processor.extra_params["tools"] = ["mock_tool"]  # Needs to pass check in init if strict, but mainly for logic
+
+    # 2. Define a Mock Tool
+    mock_tool_func = Mock(return_value="tool_result_value")
+
+    # Patch the global AVAILABLE_TOOLS to include our mock
+    with patch.dict(AVAILABLE_TOOLS, {"mock_tool": mock_tool_func}):
+        # 3. Define Stream Sequences
+
+        # Sequence 1: The Model decides to call "mock_tool" with args {"arg": "val"}
+        # Split into multiple chunks to test buffering logic
+        tool_chunks = [
+            # Chunk 1: ID and Name
+            MockToolStreamChunk(index=0, tool_id="call_123", name="mock_tool"),
+            # Chunk 2: Start of args
+            MockToolStreamChunk(index=0, arguments='{"arg":'),
+            # Chunk 3: End of args
+            MockToolStreamChunk(index=0, arguments=' "val"}'),
+        ]
+
+        # Sequence 2: The Model sees the tool result and generates final text
+        content_chunks = [
+            MockStreamChunk("Result "),
+            MockStreamChunk("is "),
+            MockStreamChunk("tool_result_value"),
+        ]
+
+        # Configure completion to return the first sequence, then the second
+        mock_completion.side_effect = [iter(tool_chunks), iter(content_chunks)]
+
+        # 4. Execute
+        generator = llm_processor.process(context="Ctx")
+        results = list(generator)
+
+    # 5. Assertions
+
+    # Check final output (byte encoded by _handle_streaming_completion)
+    assert b"Result " in results
+    assert b"is " in results
+    assert b"tool_result_value" in results
+
+    # Check Tool Execution
+    mock_tool_func.assert_called_once_with(arg="val")
+
+    # Check Recursion (completion called twice)
+    assert mock_completion.call_count == 2
+
+    # Verify second call included the tool output
+    second_call_kwargs = mock_completion.call_args_list[1][1]
+    messages = second_call_kwargs["messages"]
+
+    # Should have: System, (Context/User), Assistant(ToolCall), Tool(Result)
+    # Finding the tool message
+    tool_msg = next((m for m in messages if m.get("role") == "tool"), None)
+    assert tool_msg is not None
+    assert tool_msg["tool_call_id"] == "call_123"
+    assert tool_msg["content"] == "tool_result_value"
+    assert tool_msg["name"] == "mock_tool"
+
+
+# ============================================================================
+# Threaded Streaming & Tool Call Tests (Responses API)
+# ============================================================================
+
+class MockResponseUsage:
+    """Helper to mock usage in Responses API."""
+
+    def __init__(self, total):
+        self.total_tokens = total
+
+
+class MockResponsesChunk:
+    """Helper to mock chunks specifically for the Responses API stream."""
+
+    def __init__(self, chunk_type, **kwargs):
+        self.type = chunk_type
+        self.delta = kwargs.get("delta")
+        self.item = kwargs.get("item")
+        self.response = None
+        usage_total = kwargs.get("usage_total")
+        response_id = kwargs.get("response_id")
+        if usage_total or response_id:
+            self.response = Mock(
+                usage=MockResponseUsage(usage_total) if usage_total else None,
+                id=response_id
+            )
+        if usage_total:
+            self.usage = MockResponseUsage(usage_total)
+
+
+@pytest.mark.django_db
+@patch("openedx_ai_extensions.processors.llm.llm_processor.logger")
+@patch("openedx_ai_extensions.processors.llm.llm_processor.responses")
+def test_yield_threaded_stream_text_and_tokens(
+    mock_responses, mock_logger, llm_processor, user_session  # pylint: disable=W0621,W0613
+):
+    """
+    Test verifies text yielding and ensures token usage is LOGGED properly.
+    """
+    chunks = [
+        MockResponsesChunk("response.created", response_id="resp_123"),
+        MockResponsesChunk("response.delta", delta="Hello"),
+        MockResponsesChunk("response.delta", delta="{}"),
+        MockResponsesChunk("response.delta", delta=" world"),
+        MockResponsesChunk("response.completed", usage_total=42)
+    ]
+    # pylint: disable=protected-access
+    generator = llm_processor._yield_threaded_stream(iter(chunks))
+    results = list(generator)
+
+    assert "Hello" in results
+    assert " world" in results
+    assert "{}" not in results
+
+    user_session.refresh_from_db()
+    assert user_session.remote_response_id == "resp_123"
+
+    found_token_log = any(
+        "Tokens used: 42" in str(args)
+        for args, _ in mock_logger.info.call_args_list
+    )
+    assert found_token_log, "Total tokens (42) were not logged as expected."
+
+
+@pytest.mark.django_db
+@patch("openedx_ai_extensions.processors.llm.llm_processor.responses")
+def test_yield_threaded_stream_recursive_tool_call(
+    mock_responses, llm_processor  # pylint: disable=W0621
+):
+    """
+    Test recursive tool execution in threaded stream.
+ """ + mock_dice_roll = Mock(return_value="[6]") + + with patch.dict("openedx_ai_extensions.functions.decorators.AVAILABLE_TOOLS", + {"roll_dice": mock_dice_roll}): + item_mock = Mock( + type="function_call", + call_id="call_abc", + arguments='{"n_dice": 1}' + ) + item_mock.name = "roll_dice" + + stream_a = [ + MockResponsesChunk("response.output_item.done", item=item_mock) + ] + + stream_b = [ + MockResponsesChunk("response.delta", delta="You rolled a 6") + ] + + mock_responses.return_value = iter(stream_b) + + params = {"input": [{"role": "user", "content": "Roll dice"}], "stream": True} + # pylint: disable=protected-access + generator = llm_processor._yield_threaded_stream(iter(stream_a), params=params) + results = list(generator) + + assert "Error: Tool not found." not in str(results) + + assert "You rolled a 6" in results + + mock_dice_roll.assert_called_once_with(n_dice=1) + + history = params["input"] + assert history[1]["type"] == "function_call" + assert history[1]["name"] == "roll_dice" + + mock_responses.assert_called_once()