27 changes: 25 additions & 2 deletions src/google/adk/models/lite_llm.py
@@ -64,6 +64,19 @@
_NEW_LINE = "\n"
_EXCLUDED_PART_FIELD = {"inline_data": {"data"}}

# Mapping of LiteLLM finish_reason strings to FinishReason enum values
# Note: tool_calls/function_call map to STOP because:
# 1. FinishReason.TOOL_CALL enum does not exist (as of google-genai 0.8.0)
# 2. Tool calls represent normal completion (model stopped to invoke tools)
# 3. Gemini native responses use STOP for tool calls (see lite_llm.py:910)
_FINISH_REASON_MAPPING = {
"length": types.FinishReason.MAX_TOKENS,
"stop": types.FinishReason.STOP,
"tool_calls": types.FinishReason.STOP, # Normal completion with tool invocation
"function_call": types.FinishReason.STOP, # Legacy function call variant
"content_filter": types.FinishReason.SAFETY,
}


class ChatCompletionFileUrlObject(TypedDict, total=False):
file_data: str
@@ -494,13 +507,23 @@ def _model_response_to_generate_content_response(
"""

message = None
if response.get("choices", None):
message = response["choices"][0].get("message", None)
finish_reason = None
if (choices := response.get("choices")) and choices:
first_choice = choices[0]
message = first_choice.get("message", None)
finish_reason = first_choice.get("finish_reason", None)

if not message:
raise ValueError("No message in response")

llm_response = _message_to_generate_content_response(message)
if finish_reason:
# Map LiteLLM finish_reason strings to FinishReason enum
# This provides type consistency with Gemini native responses and avoids warnings
finish_reason_str = str(finish_reason).lower()
llm_response.finish_reason = _FINISH_REASON_MAPPING.get(
finish_reason_str, types.FinishReason.OTHER
)
if response.get("usage", None):
llm_response.usage_metadata = types.GenerateContentResponseUsageMetadata(
prompt_token_count=response["usage"].get("prompt_tokens", 0),
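
Outside the diff itself, the new mapping boils down to a dictionary lookup with an OTHER fallback. Below is a minimal, self-contained sketch of that behaviour, not ADK code: map_finish_reason is a hypothetical helper written only for illustration, and it assumes google-genai is installed and exposes the FinishReason members referenced above.

from google.genai import types

# Same table as in the diff above; values are FinishReason enum members.
_FINISH_REASON_MAPPING = {
    "length": types.FinishReason.MAX_TOKENS,
    "stop": types.FinishReason.STOP,
    "tool_calls": types.FinishReason.STOP,
    "function_call": types.FinishReason.STOP,
    "content_filter": types.FinishReason.SAFETY,
}

def map_finish_reason(raw) -> types.FinishReason:
    # LiteLLM may hand back an enum-like object or a plain string, so
    # normalize to a lowercase string before the lookup; anything
    # unrecognized falls back to OTHER rather than raising or warning.
    return _FINISH_REASON_MAPPING.get(str(raw).lower(), types.FinishReason.OTHER)

assert map_finish_reason("length") == types.FinishReason.MAX_TOKENS
assert map_finish_reason("tool_calls") == types.FinishReason.STOP
assert map_finish_reason("totally_new_reason") == types.FinishReason.OTHER
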
6 changes: 5 additions & 1 deletion src/google/adk/models/llm_response.py
@@ -78,7 +78,11 @@ class LlmResponse(BaseModel):
"""

finish_reason: Optional[types.FinishReason] = None
"""The finish reason of the response."""
"""The finish reason of the response.

Always a types.FinishReason enum. String values from underlying model providers
are mapped to corresponding enum values (with fallback to OTHER for unknown values).
"""

error_code: Optional[str] = None
"""Error code if the response is an error. Code varies by model."""
4 changes: 3 additions & 1 deletion src/google/adk/telemetry/tracing.py
@@ -281,9 +281,11 @@ def trace_call_llm(
llm_response.usage_metadata.candidates_token_count,
)
if llm_response.finish_reason:
# finish_reason is always FinishReason enum
finish_reason_str = llm_response.finish_reason.name.lower()
span.set_attribute(
'gen_ai.response.finish_reasons',
[llm_response.finish_reason.value.lower()],
[finish_reason_str],
)


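Since the lite_llm.py change guarantees finish_reason is already a FinishReason enum member by the time tracing sees it, the span attribute can be derived from the enum's .name. A small illustration of what that produces, assuming google-genai is installed (the attribute key comes from the diff above):

from google.genai import types

finish_reason = types.FinishReason.MAX_TOKENS
finish_reason_str = finish_reason.name.lower()

print(finish_reason_str)    # "max_tokens"
# The 'gen_ai.response.finish_reasons' span attribute is set to a
# one-element list built from that string, matching the new code above.
print([finish_reason_str])  # ['max_tokens']
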
113 changes: 113 additions & 0 deletions tests/unittests/models/test_litellm.py
@@ -19,6 +19,7 @@
import warnings

from google.adk.models.lite_llm import _content_to_message_param
from google.adk.models.lite_llm import _FINISH_REASON_MAPPING
from google.adk.models.lite_llm import _function_declaration_to_tool_param
from google.adk.models.lite_llm import _get_content
from google.adk.models.lite_llm import _message_to_generate_content_response
@@ -1903,3 +1904,115 @@ def test_non_gemini_litellm_no_warning():
# Test with non-Gemini model
LiteLlm(model="openai/gpt-4o")
assert len(w) == 0


@pytest.mark.parametrize(
"finish_reason,response_content,expected_content,has_tool_calls",
[
("length", "Test response", "Test response", False),
("stop", "Complete response", "Complete response", False),
(
"tool_calls",
"",
"",
True,
),
("content_filter", "", "", False),
],
ids=["length", "stop", "tool_calls", "content_filter"],
)
@pytest.mark.asyncio
async def test_finish_reason_propagation(
mock_acompletion,
lite_llm_instance,
finish_reason,
response_content,
expected_content,
has_tool_calls,
):
"""Test that finish_reason is properly propagated from LiteLLM response."""
tool_calls = None
if has_tool_calls:
tool_calls = [
ChatCompletionMessageToolCall(
type="function",
id="test_id",
function=Function(
name="test_function",
arguments='{"arg": "value"}',
),
)
]

mock_response = ModelResponse(
choices=[
Choices(
message=ChatCompletionAssistantMessage(
role="assistant",
content=response_content,
tool_calls=tool_calls,
),
finish_reason=finish_reason,
)
]
)
mock_acompletion.return_value = mock_response

llm_request = LlmRequest(
contents=[
types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
)
],
)

async for response in lite_llm_instance.generate_content_async(llm_request):
assert response.content.role == "model"
# Verify finish_reason is mapped to FinishReason enum
assert isinstance(response.finish_reason, types.FinishReason)
# Verify correct enum mapping using the actual mapping from lite_llm
assert response.finish_reason == _FINISH_REASON_MAPPING[finish_reason]
if expected_content:
assert response.content.parts[0].text == expected_content
if has_tool_calls:
assert len(response.content.parts) > 0
assert response.content.parts[-1].function_call.name == "test_function"

mock_acompletion.assert_called_once()


@pytest.mark.asyncio
async def test_finish_reason_unknown_maps_to_other(
mock_acompletion, lite_llm_instance
):
"""Test that unknown finish_reason values map to FinishReason.OTHER."""
mock_response = ModelResponse(
choices=[
Choices(
message=ChatCompletionAssistantMessage(
role="assistant",
content="Test response",
),
finish_reason="unknown_reason_type",
)
]
)
mock_acompletion.return_value = mock_response

llm_request = LlmRequest(
contents=[
types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
)
],
)

async for response in lite_llm_instance.generate_content_async(llm_request):
assert response.content.role == "model"
# Unknown finish_reason should map to OTHER
assert isinstance(response.finish_reason, types.FinishReason)
assert response.finish_reason == types.FinishReason.OTHER

mock_acompletion.assert_called_once()