98 changes: 95 additions & 3 deletions libs/oci/langchain_oci/chat_models/oci_generative_ai.py
@@ -58,7 +58,7 @@

CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"

-# Mapping of JSON schema types to Python types
+# Mapping of JSON schema types to Python types (for Cohere API)
JSON_TO_PYTHON_TYPES = {
    "string": "str",
    "number": "float",
@@ -69,6 +69,24 @@
    "any": "any",
}

# Mapping of Python types to JSON schema types (for Generic API)
PYTHON_TO_JSON_TYPES = {
    "str": "string",
    "float": "number",
    "bool": "boolean",
    "int": "integer",
    "List": "array",
    "list": "array",
    "Dict": "object",
    "dict": "object",
    "any": ["string", "number", "integer", "boolean", "array", "object", "null"],
}

# Valid JSON Schema types
VALID_JSON_SCHEMA_TYPES = {
    "string", "number", "integer", "boolean", "array", "object", "null"
}


class OCIUtils:
    """Utility functions for OCI Generative AI integration."""
@@ -550,6 +568,77 @@ class GenericProvider(Provider):

    stop_sequence_key: str = "stop"

    @staticmethod
    def _to_json_schema_type(type_value: Optional[str]) -> str | list[str]:
        """Convert a type to valid JSON Schema type(s).

        Args:
            type_value: The type value (could be Python type or JSON Schema type)

        Returns:
            A valid JSON Schema type or list of types. Defaults to "string" if invalid.
        """
        if not type_value:
            return "string"

        # If it's already a valid JSON Schema type, return it
        if type_value in VALID_JSON_SCHEMA_TYPES:
            return type_value

        # Try to convert from Python type to JSON Schema type
        if type_value in PYTHON_TO_JSON_TYPES:
            return PYTHON_TO_JSON_TYPES[type_value]  # type: ignore[return-value]

        # Default to string for unknown types
        return "string"

    def _normalize_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]:
        """Recursively normalize property schemas to use valid JSON Schema types.

        Args:
            properties: Dictionary of property schemas

        Returns:
            Dictionary with normalized property schemas
        """
        normalized = {}
        for prop_name, prop_schema in properties.items():
            if isinstance(prop_schema, dict):
                normalized_prop = prop_schema.copy()

                # Normalize the type if present
                if "type" in normalized_prop:
                    normalized_prop["type"] = self._to_json_schema_type(
                        normalized_prop["type"]
                    )

                # Recursively normalize nested properties
                if "properties" in normalized_prop and isinstance(
                    normalized_prop["properties"], dict
                ):
                    normalized_prop["properties"] = self._normalize_properties(
                        normalized_prop["properties"]
                    )

                # Handle array items
                if "items" in normalized_prop and isinstance(
                    normalized_prop["items"], dict
                ):
                    if "type" in normalized_prop["items"]:
                        normalized_prop["items"]["type"] = self._to_json_schema_type(
                            normalized_prop["items"]["type"]
                        )
                    if "properties" in normalized_prop["items"]:
                        normalized_prop["items"]["properties"] = self._normalize_properties(
                            normalized_prop["items"]["properties"]
                        )

                normalized[prop_name] = normalized_prop
            else:
                normalized[prop_name] = prop_schema

        return normalized
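
    # Editorial illustration (not part of this PR): given a tool schema that uses
    # Python-style type names, the normalizer should rewrite it into valid JSON
    # Schema, recursing into nested "properties" and array "items", e.g.
    #   {"tags": {"type": "List", "items": {"type": "str"}}}
    # is expected to become
    #   {"tags": {"type": "array", "items": {"type": "string"}}}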

    def __init__(self) -> None:
        from oci.generative_ai_inference import models

@@ -814,7 +903,7 @@ def convert_to_oci_tool(
"type": "object",
"properties": {
p_name: {
"type": p_def.get("type", "any"),
"type": self._to_json_schema_type(p_def.get("type")),
"description": p_def.get("description", ""),
}
for p_name, p_def in tool.args.items()
@@ -837,7 +926,9 @@
                ),
                parameters={
                    "type": "object",
-                   "properties": parameters.get("properties", {}),
+                   "properties": self._normalize_properties(
+                       parameters.get("properties", {})
+                   ),
                    "required": parameters.get("required", []),
                },
            )
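
            # Editorial illustration (not part of this PR): for tool properties
            # declared with Python-style names, e.g.
            #   {"count": {"type": "int"}, "topic": {"type": "str"}}
            # the normalized parameters sent to OCI should read
            #   {"count": {"type": "integer"}, "topic": {"type": "string"}},
            # so the request no longer carries type names that are not valid JSON Schema.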
@@ -1261,6 +1352,7 @@ def _generate(
        request = self._prepare_request(messages, stop=stop, stream=False, **kwargs)
        response = self.client.chat(request)


        content = self._provider.chat_response_to_text(response)

        if stop is not None:
207 changes: 206 additions & 1 deletion libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py
@@ -734,7 +734,7 @@ def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
    assert response.content == "I'll help you."


-def test_get_provider():
+def test_get_provider() -> None:
    """Test determining the provider based on the model_id."""
    model_provider_map = {
        "cohere.command-latest": "CohereProvider",
@@ -746,3 +746,208 @@
            ChatOCIGenAI(model_id=model_id)._provider.__class__.__name__
            == provider_name
        )


@pytest.mark.requires("oci")
def test_generic_provider_tool_schema_validation() -> None:
"""Test that GenericProvider creates valid JSON schemas for tools.

This test validates that tool parameters don't have invalid JSON Schema types
like 'any', which causes OCI API to reject the function schema.
"""
from langchain_core.tools import BaseTool
from langchain_core.messages import HumanMessage
from langchain_oci.chat_models.oci_generative_ai import GenericProvider

# Mock a BaseTool
mock_tool = MagicMock(spec=BaseTool)
mock_tool.name = "tell_a_joke"
mock_tool.description = "Tell a joke about a topic."
mock_tool.args = {
"topic": {
"type": "string",
"description": "The topic of the joke"
}
}

provider = GenericProvider()
function_def = provider.convert_to_oci_tool(mock_tool)

# Valid JSON Schema types according to JSON Schema spec
valid_types = {"string", "number", "integer", "boolean", "array", "object", "null"}

# Check that the function definition has valid structure
assert hasattr(function_def, "name")
assert hasattr(function_def, "description")
assert hasattr(function_def, "parameters")

parameters = function_def.parameters
assert isinstance(parameters, dict)
assert parameters.get("type") == "object"
assert "properties" in parameters

# Validate each property has a valid JSON Schema type
for prop_name, prop_schema in parameters["properties"].items():
assert "type" in prop_schema, f"Property {prop_name} missing type"
prop_type = prop_schema["type"]
assert prop_type in valid_types, (
f"Property {prop_name} has invalid JSON Schema type '{prop_type}'. "
f"Valid types are: {valid_types}"
)

# Now test that tools actually get passed through to the request
oci_gen_ai_client = MagicMock()
llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client)

request_captured = None

def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
nonlocal request_captured
request_captured = args[0]
return MockResponseDict(
{
"status": 200,
"data": MockResponseDict(
{
"chat_response": MockResponseDict(
{
"api_format": "GENERIC",
"choices": [
MockResponseDict(
{
"message": MockResponseDict(
{
"role": "ASSISTANT",
"content": [MockResponseDict({"text": "joke", "type": "TEXT"})],
"tool_calls": [],
}
),
"finish_reason": "completed",
}
)
],
"time_created": "2025-08-14T10:00:01.100000+00:00",
}
),
"model_id": "meta.llama-3.3-70b-instruct",
"model_version": "1.0.0",
}
),
"request_id": "1234567890",
"headers": MockResponseDict({"content-length": "123"}),
}
)

oci_gen_ai_client.chat = mocked_response

llm_with_tools = llm.bind_tools([mock_tool])
messages = [HumanMessage(content="tell me a joke")]
llm_with_tools.invoke(messages)

# Verify tools were passed to the request
assert request_captured is not None, "No request was captured"
assert hasattr(request_captured, 'chat_request'), "Request missing chat_request"
assert hasattr(request_captured.chat_request, 'tools'), "chat_request missing tools"
assert request_captured.chat_request.tools is not None, "tools is None"
assert len(request_captured.chat_request.tools) > 0, "tools list is empty"

# Validate the tool schema in the request
tool = request_captured.chat_request.tools[0]
assert tool.name == "tell_a_joke"
assert isinstance(tool.parameters, dict)
assert tool.parameters.get("type") == "object"

# Validate each property type
for prop_name, prop_schema in tool.parameters["properties"].items():
prop_type = prop_schema["type"]
assert prop_type in valid_types, (
f"Property {prop_name} has invalid JSON Schema type '{prop_type}' in request"
)


@pytest.mark.requires("oci")
def test_generic_provider_tool_schema_validation_streaming() -> None:
"""Test that tools work correctly with streaming enabled.

This reproduces the issue where is_stream=True causes tool calling to fail.
"""
from langchain_core.tools import BaseTool
from langchain_core.messages import HumanMessage

# Mock a BaseTool
mock_tool = MagicMock(spec=BaseTool)
mock_tool.name = "tell_a_joke"
mock_tool.description = "Tell a joke about a topic."
mock_tool.args = {
"topic": {
"type": "string",
"description": "The topic of the joke"
}
}

oci_gen_ai_client = MagicMock()
llm = ChatOCIGenAI(
model_id="meta.llama-3.3-70b-instruct",
is_stream=True, # This is the key difference
client=oci_gen_ai_client
)

request_captured = None

def mocked_stream_response(*args, **kwargs): # type: ignore[no-untyped-def]
nonlocal request_captured
request_captured = args[0]

# Create a mock streaming response
class MockEvents:
def events(self): # type: ignore[no-untyped-def]
import json
# First event with content
yield MockResponseDict({
"data": json.dumps({
"message": {
"content": [{"text": "Here's a joke"}]
}
})
})
# Final event with finish reason
yield MockResponseDict({
"data": json.dumps({
"finishReason": "completed"
})
})

return MockResponseDict({
"data": MockEvents()
})

oci_gen_ai_client.chat = mocked_stream_response

llm_with_tools = llm.bind_tools([mock_tool])
messages = [HumanMessage(content="tell me a joke about felipe")]

# This should not raise an error
response = llm_with_tools.invoke(messages)

# Verify tools were passed to the request
assert request_captured is not None, "No request was captured"
assert hasattr(request_captured, 'chat_request'), "Request missing chat_request"
assert hasattr(request_captured.chat_request, 'tools'), "chat_request missing tools"
assert request_captured.chat_request.tools is not None, "tools is None"
assert len(request_captured.chat_request.tools) > 0, "tools list is empty"

# Valid JSON Schema types
valid_types = {"string", "number", "integer", "boolean", "array", "object", "null"}

# Validate the tool schema in the request
tool = request_captured.chat_request.tools[0]
assert tool.name == "tell_a_joke"
assert isinstance(tool.parameters, dict)
assert tool.parameters.get("type") == "object"

# Validate each property type
for prop_name, prop_schema in tool.parameters["properties"].items():
prop_type = prop_schema["type"]
assert prop_type in valid_types, (
f"Property {prop_name} has invalid JSON Schema type '{prop_type}' in streaming request"
)