# Mapping of Python types to JSON schema types (for Generic API)
PYTHON_TO_JSON_TYPES = {
    "str": "string",
    "float": "number",
    "bool": "boolean",
    "int": "integer",
    "List": "array",
    "list": "array",
    "Dict": "object",
    "dict": "object",
    # "any" has no single JSON Schema counterpart, so it expands to the full
    # list of JSON Schema type names.
    "any": ["string", "number", "integer", "boolean", "array", "object", "null"],
}

# Valid JSON Schema types
VALID_JSON_SCHEMA_TYPES = {
    "string", "number", "integer", "boolean", "array", "object", "null"
}


# The two callables below are members of ``GenericProvider``; the rest of the
# class reaches them as ``self._to_json_schema_type`` and
# ``self._normalize_properties``.

@staticmethod
def _to_json_schema_type(type_value: Any) -> "str | list[str]":
    """Convert a type to valid JSON Schema type(s).

    Args:
        type_value: The declared type. May be a JSON Schema type name, a
            Python type name found in ``PYTHON_TO_JSON_TYPES``, a list/tuple
            of such names, or a falsy value when no type was declared.

    Returns:
        A valid JSON Schema type name, or a list of type names. A list is
        returned for ``"any"`` and for list/tuple input. Falsy, unknown and
        non-string values fall back to ``"string"``.
    """

    def _convert_one(name: Any) -> "str | list[str]":
        # Non-strings (dicts, numbers, ...) cannot name a type; guarding here
        # also keeps unhashable values away from the set/dict lookups below.
        if not isinstance(name, str):
            return "string"
        # Already a valid JSON Schema type: keep it.
        if name in VALID_JSON_SCHEMA_TYPES:
            return name
        mapped = PYTHON_TO_JSON_TYPES.get(name, "string")
        # Hand out a copy so callers can never mutate the module constant.
        return list(mapped) if isinstance(mapped, list) else mapped

    if not type_value:
        return "string"

    # A list of types is legal JSON Schema and is also what this function
    # emits for "any", so normalizing twice must not fail. Convert each
    # entry and merge the results, dropping duplicates but keeping order.
    if isinstance(type_value, (list, tuple)):
        merged: list = []
        for entry in type_value:
            converted = _convert_one(entry)
            names = converted if isinstance(converted, list) else [converted]
            for name in names:
                if name not in merged:
                    merged.append(name)
        return merged

    return _convert_one(type_value)


def _normalize_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively normalize property schemas to use valid JSON Schema types.

    The input is never modified: every dict that needs a change is copied
    first, including nested ``items`` schemas.

    Args:
        properties: Dictionary mapping property names to property schemas.
            A falsy value (``None``, ``{}``) yields an empty dict.

    Returns:
        A new dictionary with normalized property schemas. Values that are
        not dicts are passed through unchanged.
    """
    if not properties:
        return {}

    normalized: Dict[str, Any] = {}
    for prop_name, prop_schema in properties.items():
        if not isinstance(prop_schema, dict):
            normalized[prop_name] = prop_schema
            continue

        normalized_prop = dict(prop_schema)

        # Normalize the type if present
        if "type" in normalized_prop:
            normalized_prop["type"] = self._to_json_schema_type(
                normalized_prop["type"]
            )

        # Recursively normalize nested properties
        nested = normalized_prop.get("properties")
        if isinstance(nested, dict):
            normalized_prop["properties"] = self._normalize_properties(nested)

        # An array's ``items`` schema has the same shape as a property schema.
        # Routing it through this method under a throwaway key normalizes its
        # type, nested properties and nested items, and returns a copy
        # instead of writing into the caller's dict.
        items = normalized_prop.get("items")
        if isinstance(items, dict):
            normalized_prop["items"] = self._normalize_properties(
                {"items": items}
            )["items"]

        normalized[prop_name] = normalized_prop

    return normalized
# Valid JSON Schema types according to the JSON Schema spec
_VALID_JSON_SCHEMA_TYPES = frozenset(
    {"string", "number", "integer", "boolean", "array", "object", "null"}
)


def _make_joke_tool() -> MagicMock:
    """Build a mocked ``BaseTool`` for the tool-schema tests.

    Besides one argument that already carries a valid JSON Schema type, the
    args include a Python-style type, the ``any`` pseudo-type and an argument
    with no type at all. Those are the inputs that need conversion; a tool
    whose only argument is already ``"string"`` would pass these tests even
    without any type normalization.
    """
    from langchain_core.tools import BaseTool

    tool = MagicMock(spec=BaseTool)
    tool.name = "tell_a_joke"
    tool.description = "Tell a joke about a topic."
    tool.args = {
        "topic": {"type": "string", "description": "The topic of the joke"},
        "max_words": {"type": "int", "description": "Upper bound on joke length"},
        "style": {"type": "any", "description": "Free-form style hint"},
        "audience": {"description": "Who the joke is for"},
    }
    return tool


def _assert_valid_property_types(properties: dict, context: str = "") -> None:
    """Assert that every property declares only valid JSON Schema types.

    ``type`` may be a single name or a list of names (``any`` is expanded to
    a list), so both shapes are accepted.
    """
    for prop_name, prop_schema in properties.items():
        assert "type" in prop_schema, f"Property {prop_name} missing type"
        prop_type = prop_schema["type"]
        declared = prop_type if isinstance(prop_type, list) else [prop_type]
        assert declared, f"Property {prop_name} has an empty type list{context}"
        invalid = [t for t in declared if t not in _VALID_JSON_SCHEMA_TYPES]
        assert not invalid, (
            f"Property {prop_name} has invalid JSON Schema type '{prop_type}'"
            f"{context}. Valid types are: {sorted(_VALID_JSON_SCHEMA_TYPES)}"
        )


def _assert_request_carries_valid_tool(request: object, context: str) -> None:
    """Assert that the captured chat request contains the joke tool schema."""
    assert hasattr(request, "chat_request"), "Request missing chat_request"
    chat_request = request.chat_request  # type: ignore[attr-defined]
    assert hasattr(chat_request, "tools"), "chat_request missing tools"
    assert chat_request.tools is not None, "tools is None"
    assert len(chat_request.tools) > 0, "tools list is empty"

    tool = chat_request.tools[0]
    assert tool.name == "tell_a_joke"
    assert isinstance(tool.parameters, dict)
    assert tool.parameters.get("type") == "object"
    _assert_valid_property_types(tool.parameters["properties"], context)


@pytest.mark.requires("oci")
def test_generic_provider_tool_schema_validation() -> None:
    """Test that GenericProvider creates valid JSON schemas for tools.

    This test validates that tool parameters don't have invalid JSON Schema types
    like 'any', which causes OCI API to reject the function schema.
    """
    from langchain_core.messages import HumanMessage
    from langchain_oci.chat_models.oci_generative_ai import GenericProvider

    mock_tool = _make_joke_tool()

    provider = GenericProvider()
    function_def = provider.convert_to_oci_tool(mock_tool)

    # Check that the function definition has valid structure
    assert hasattr(function_def, "name")
    assert hasattr(function_def, "description")
    assert hasattr(function_def, "parameters")

    parameters = function_def.parameters
    assert isinstance(parameters, dict)
    assert parameters.get("type") == "object"
    assert "properties" in parameters

    # Validate each property has a valid JSON Schema type
    properties = parameters["properties"]
    _assert_valid_property_types(properties)

    # Pin the conversions: a Python type name is translated and a missing
    # type falls back to "string".
    assert properties["topic"]["type"] == "string"
    assert properties["max_words"]["type"] == "integer"
    assert properties["audience"]["type"] == "string"

    # Now test that tools actually get passed through to the request
    oci_gen_ai_client = MagicMock()
    llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client)

    captured_requests: list = []

    def mocked_response(*args, **kwargs):  # type: ignore[no-untyped-def]
        captured_requests.append(args[0])
        return MockResponseDict(
            {
                "status": 200,
                "data": MockResponseDict(
                    {
                        "chat_response": MockResponseDict(
                            {
                                "api_format": "GENERIC",
                                "choices": [
                                    MockResponseDict(
                                        {
                                            "message": MockResponseDict(
                                                {
                                                    "role": "ASSISTANT",
                                                    "content": [
                                                        MockResponseDict(
                                                            {
                                                                "text": "joke",
                                                                "type": "TEXT",
                                                            }
                                                        )
                                                    ],
                                                    "tool_calls": [],
                                                }
                                            ),
                                            "finish_reason": "completed",
                                        }
                                    )
                                ],
                                "time_created": "2025-08-14T10:00:01.100000+00:00",
                            }
                        ),
                        "model_id": "meta.llama-3.3-70b-instruct",
                        "model_version": "1.0.0",
                    }
                ),
                "request_id": "1234567890",
                "headers": MockResponseDict({"content-length": "123"}),
            }
        )

    oci_gen_ai_client.chat = mocked_response

    llm_with_tools = llm.bind_tools([mock_tool])
    messages = [HumanMessage(content="tell me a joke")]
    llm_with_tools.invoke(messages)

    # Verify tools were passed to the request
    assert captured_requests, "No request was captured"
    _assert_request_carries_valid_tool(captured_requests[-1], " in request")


@pytest.mark.requires("oci")
def test_generic_provider_tool_schema_validation_streaming() -> None:
    """Test that tools work correctly with streaming enabled.

    This reproduces the issue where is_stream=True causes tool calling to fail.
    """
    import json

    from langchain_core.messages import HumanMessage

    mock_tool = _make_joke_tool()

    oci_gen_ai_client = MagicMock()
    llm = ChatOCIGenAI(
        model_id="meta.llama-3.3-70b-instruct",
        is_stream=True,  # This is the key difference
        client=oci_gen_ai_client,
    )

    captured_requests: list = []

    class MockEvents:
        """Stand-in for the streaming payload: yields one content event, then
        a final event carrying the finish reason."""

        def events(self):  # type: ignore[no-untyped-def]
            # First event with content
            yield MockResponseDict(
                {
                    "data": json.dumps(
                        {"message": {"content": [{"text": "Here's a joke"}]}}
                    )
                }
            )
            # Final event with finish reason
            yield MockResponseDict(
                {"data": json.dumps({"finishReason": "completed"})}
            )

    def mocked_stream_response(*args, **kwargs):  # type: ignore[no-untyped-def]
        captured_requests.append(args[0])
        return MockResponseDict({"data": MockEvents()})

    oci_gen_ai_client.chat = mocked_stream_response

    llm_with_tools = llm.bind_tools([mock_tool])
    messages = [HumanMessage(content="tell me a joke about felipe")]

    # This should not raise an error
    llm_with_tools.invoke(messages)

    # Verify tools were passed to the request
    assert captured_requests, "No request was captured"
    _assert_request_carries_valid_tool(
        captured_requests[-1], " in streaming request"
    )