Skip to content

Commit ec20ba9

Browse files
authored
Merge pull request #648 from UiPath/fix/response_format
fix(response_format): changed default serialization
2 parents 3fc695c + 80035ec commit ec20ba9

File tree

4 files changed

+112
-7
lines changed

4 files changed

+112
-7
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "uipath"
3-
version = "2.1.106"
3+
version = "2.1.107"
44
description = "Python SDK and CLI for UiPath Platform, enabling programmatic interaction with automation services, process management, and deployment tools."
55
readme = { file = "README.md", content-type = "text/markdown" }
66
requires-python = ">=3.10"

src/uipath/_services/llm_gateway_service.py

Lines changed: 15 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -344,7 +344,7 @@ def __init__(self, config: Config, execution_context: ExecutionContext) -> None:
344344
@traced(name="llm_chat_completions", run_type="uipath")
345345
async def chat_completions(
346346
self,
347-
messages: List[Dict[str, str]],
347+
messages: Union[List[Dict[str, str]], List[tuple[str, str]]],
348348
model: str = ChatModels.gpt_4o_mini_2024_07_18,
349349
max_tokens: int = 4096,
350350
temperature: float = 0,
@@ -475,13 +475,26 @@ class Country(BaseModel):
475475
This service uses UiPath's normalized API format which provides consistent
476476
behavior across different underlying model providers and enhanced enterprise features.
477477
"""
478+
converted_messages = []
479+
480+
for message in messages:
481+
if isinstance(message, tuple) and len(message) == 2:
482+
role, content = message
483+
converted_messages.append({"role": role, "content": content})
484+
elif isinstance(message, dict):
485+
converted_messages.append(message)
486+
else:
487+
raise ValueError(
488+
f"Invalid message format: {message}. Expected tuple (role, content) or dict with 'role' and 'content' keys."
489+
)
490+
478491
endpoint = EndpointManager.get_normalized_endpoint().format(
479492
model=model, api_version=api_version
480493
)
481494
endpoint = Endpoint("/" + endpoint)
482495

483496
request_body = {
484-
"messages": messages,
497+
"messages": converted_messages,
485498
"max_tokens": max_tokens,
486499
"temperature": temperature,
487500
"n": n,

src/uipath/tracing/_utils.py

Lines changed: 18 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -13,16 +13,26 @@
1313

1414
from opentelemetry.sdk.trace import ReadableSpan
1515
from opentelemetry.trace import StatusCode
16+
from pydantic import BaseModel
1617

1718
logger = logging.getLogger(__name__)
1819

1920

2021
def _simple_serialize_defaults(obj):
21-
if hasattr(obj, "model_dump"):
22+
# Handle Pydantic BaseModel instances
23+
if hasattr(obj, "model_dump") and not isinstance(obj, type):
2224
return obj.model_dump(exclude_none=True, mode="json")
23-
if hasattr(obj, "dict"):
25+
26+
# Handle classes - convert to schema representation
27+
if isinstance(obj, type) and issubclass(obj, BaseModel):
28+
return {
29+
"__class__": obj.__name__,
30+
"__module__": obj.__module__,
31+
"schema": obj.model_json_schema(),
32+
}
33+
if hasattr(obj, "dict") and not isinstance(obj, type):
2434
return obj.dict()
25-
if hasattr(obj, "to_dict"):
35+
if hasattr(obj, "to_dict") and not isinstance(obj, type):
2636
return obj.to_dict()
2737

2838
# Handle dataclasses
@@ -31,7 +41,7 @@ def _simple_serialize_defaults(obj):
3141

3242
# Handle enums
3343
if isinstance(obj, Enum):
34-
return obj.value
44+
return _simple_serialize_defaults(obj.value)
3545

3646
if isinstance(obj, (set, tuple)):
3747
if hasattr(obj, "_asdict") and callable(obj._asdict):
@@ -44,6 +54,10 @@ def _simple_serialize_defaults(obj):
4454
if isinstance(obj, (timezone, ZoneInfo)):
4555
return obj.tzname(None)
4656

57+
# Allow JSON-serializable primitives to pass through unchanged
58+
if obj is None or isinstance(obj, (bool, int, float, str)):
59+
return obj
60+
4761
return str(obj)
4862

4963

tests/tracing/test_traced.py

Lines changed: 78 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -650,3 +650,81 @@ def test_complex_input(input: CalculatorInput) -> CalculatorOutput:
650650
assert output["result"] == 54.6 # 10.5 * 5.2 = 54.6
651651
# Verify the enum is serialized as its value
652652
assert output["operator"] == "*"
653+
654+
655+
@pytest.mark.asyncio
656+
async def test_traced_with_pydantic_basemodel_class(setup_tracer):
657+
"""Test that Pydantic BaseModel classes can be serialized in tracing.
658+
659+
This tests the fix for the issue where passing a Pydantic BaseModel class
660+
as a parameter (like response_format=OutputFormat) would cause JSON
661+
serialization errors in tracing.
662+
"""
663+
from pydantic import BaseModel
664+
665+
exporter, provider = setup_tracer
666+
667+
class OutputFormat(BaseModel):
668+
result: str
669+
confidence: float = 0.95
670+
671+
@traced()
672+
async def llm_chat_completions(messages: List[Any], response_format=None):
673+
"""Simulate LLM function with BaseModel class as response_format."""
674+
if response_format:
675+
mock_content = '{"result": "hi!", "confidence": 0.95}'
676+
return {"choices": [{"message": {"content": mock_content}}]}
677+
return {"choices": [{"message": {"content": "hi!"}}]}
678+
679+
# Test with tuple message format and BaseModel class as parameter
680+
messages = [("human", "repeat this: hi!")]
681+
result = await llm_chat_completions(messages, response_format=OutputFormat)
682+
683+
assert result is not None
684+
assert "choices" in result
685+
686+
provider.shutdown() # Ensure spans are flushed
687+
spans = exporter.get_exported_spans()
688+
689+
assert len(spans) == 1
690+
span = spans[0]
691+
assert span.name == "llm_chat_completions"
692+
assert span.attributes["span_type"] == "function_call_async"
693+
694+
# Verify inputs are properly serialized as JSON, including BaseModel class
695+
assert "input.value" in span.attributes
696+
inputs_json = span.attributes["input.value"]
697+
inputs = json.loads(inputs_json)
698+
699+
# Check BaseModel class is properly serialized with schema representation
700+
assert "response_format" in inputs
701+
response_format_data = inputs["response_format"]
702+
703+
# Verify the BaseModel class is serialized as a schema representation
704+
assert "__class__" in response_format_data
705+
assert "__module__" in response_format_data
706+
assert "schema" in response_format_data
707+
assert response_format_data["__class__"] == "OutputFormat"
708+
709+
# Verify the schema contains expected structure
710+
schema = response_format_data["schema"]
711+
assert "properties" in schema
712+
assert "result" in schema["properties"]
713+
assert "confidence" in schema["properties"]
714+
assert schema["properties"]["result"]["type"] == "string"
715+
assert schema["properties"]["confidence"]["type"] == "number"
716+
717+
# Verify that tuple messages are also properly serialized
718+
assert "messages" in inputs
719+
messages_data = inputs["messages"]
720+
assert isinstance(messages_data, list)
721+
assert len(messages_data) == 1
722+
assert messages_data[0] == ["human", "repeat this: hi!"]
723+
724+
# Verify that outputs are properly serialized as JSON
725+
assert "output.value" in span.attributes
726+
output_json = span.attributes["output.value"]
727+
output = json.loads(output_json)
728+
729+
assert "choices" in output
730+
assert len(output["choices"]) == 1

0 commit comments

Comments (0)