From aa66265c3cd321b2c35f0b96ec40419386172e71 Mon Sep 17 00:00:00 2001 From: Joshua Brown Date: Fri, 10 Oct 2025 11:52:08 -0700 Subject: [PATCH 1/2] Fix OCI tool calling with GenericProvider --- .../chat_models/oci_generative_ai.py | 6 +- .../chat_models/test_oci_generative_ai.py | 150 +++++++++++++++++- 2 files changed, 152 insertions(+), 4 deletions(-) diff --git a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py index 939313a..c4d8dc3 100644 --- a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py +++ b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py @@ -712,9 +712,9 @@ def messages_to_oci_params( ) else: oci_message = self.oci_chat_message[role](content=tool_content) - elif isinstance(message, AIMessage) and message.additional_kwargs.get( - "tool_calls" - ): + elif isinstance(message, AIMessage) and ( + message.tool_calls or + message.additional_kwargs.get("tool_calls")): # Process content and tool calls for assistant messages content = self._process_message_content(message.content) tool_calls = [] diff --git a/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py b/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py index d33adde..84daccc 100644 --- a/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py +++ b/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py @@ -6,9 +6,9 @@ from unittest.mock import MagicMock import pytest -from langchain_core.messages import HumanMessage from pytest import MonkeyPatch +from langchain_core.messages import HumanMessage, AIMessage from langchain_oci.chat_models.oci_generative_ai import ChatOCIGenAI @@ -575,6 +575,154 @@ def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def] assert response["parsed"].conditions == "Sunny" +@pytest.mark.requires("oci") +def test_ai_message_tool_calls_direct_field(monkeypatch: MonkeyPatch) -> None: + """Test AIMessage with tool_calls in the direct tool_calls field.""" + + oci_gen_ai_client = MagicMock() + llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client) + + def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def] + return MockResponseDict( + { + "status": 200, + "data": MockResponseDict( + { + "chat_response": MockResponseDict( + { + "api_format": "GENERIC", + "choices": [ + MockResponseDict( + { + "message": MockResponseDict( + { + "role": "ASSISTANT", + "name": None, + "content": [ + MockResponseDict( + { + "text": ( + "I'll help you." + ), + "type": "TEXT", + } + ) + ], + "tool_calls": [], + } + ), + "finish_reason": "completed", + } + ) + ], + "time_created": "2025-08-14T10:00:01.100000+00:00", + } + ), + "model_id": "meta.llama-3.3-70b-instruct", + "model_version": "1.0.0", + } + ), + "request_id": "1234567890", + "headers": MockResponseDict({"content-length": "123"}), + } + ) + + monkeypatch.setattr(llm.client, "chat", mocked_response) + + # Create AIMessage with tool_calls in the direct tool_calls field + ai_message = AIMessage( + content="I need to call a function", + tool_calls=[ + { + "id": "call_123", + "name": "get_weather", + "args": {"location": "San Francisco"}, + } + ] + ) + + messages = [ai_message] + + # This should not raise an error and should process the tool_calls correctly + response = llm.invoke(messages) + assert response.content == "I'll help you." + + +@pytest.mark.requires("oci") +def test_ai_message_tool_calls_additional_kwargs(monkeypatch: MonkeyPatch) -> None: + """Test AIMessage with tool_calls in additional_kwargs field.""" + + oci_gen_ai_client = MagicMock() + llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client) + + def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def] + return MockResponseDict( + { + "status": 200, + "data": MockResponseDict( + { + "chat_response": MockResponseDict( + { + "api_format": "GENERIC", + "choices": [ + MockResponseDict( + { + "message": MockResponseDict( + { + "role": "ASSISTANT", + "name": None, + "content": [ + MockResponseDict( + { + "text": ( + "I'll help you." + ), + "type": "TEXT", + } + ) + ], + "tool_calls": [], + } + ), + "finish_reason": "completed", + } + ) + ], + "time_created": "2025-08-14T10:00:01.100000+00:00", + } + ), + "model_id": "meta.llama-3.3-70b-instruct", + "model_version": "1.0.0", + } + ), + "request_id": "1234567890", + "headers": MockResponseDict({"content-length": "123"}), + } + ) + + monkeypatch.setattr(llm.client, "chat", mocked_response) + + # Create AIMessage with tool_calls in additional_kwargs + ai_message = AIMessage( + content="I need to call a function", + additional_kwargs={ + "tool_calls": [ + { + "id": "call_456", + "name": "get_weather", + "args": {"location": "New York"}, + } + ] + } + ) + + messages = [ai_message] + + # This should not raise an error and should process the tool_calls correctly + response = llm.invoke(messages) + assert response.content == "I'll help you." + + def test_get_provider(): """Test determining the provider based on the model_id.""" model_provider_map = { From 3c62fd7d380cd0ff75bee1069330abd4d81a53ad Mon Sep 17 00:00:00 2001 From: Joshua Brown Date: Fri, 10 Oct 2025 12:03:06 -0700 Subject: [PATCH 2/2] Update unit test to verify tool call --- .../unit_tests/chat_models/test_oci_generative_ai.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py b/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py index 84daccc..1a6649a 100644 --- a/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py +++ b/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py @@ -582,7 +582,18 @@ def test_ai_message_tool_calls_direct_field(monkeypatch: MonkeyPatch) -> None: oci_gen_ai_client = MagicMock() llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client) + # Track if the tool_calls processing branch is executed + tool_calls_processed = False + def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def] + nonlocal tool_calls_processed + # Check if the request contains tool_calls in the message + request = args[0] + if hasattr(request, 'chat_request') and hasattr(request.chat_request, 'messages'): + for msg in request.chat_request.messages: + if hasattr(msg, 'tool_calls') and msg.tool_calls: + tool_calls_processed = True + break return MockResponseDict( { "status": 200,