From 133a5cfac327bb323df55cf2b725badff238b237 Mon Sep 17 00:00:00 2001
From: Eitan Yarmush
Date: Wed, 10 Sep 2025 18:15:38 +0000
Subject: [PATCH 1/2] feat: allow using the default Anthropic client natively

Signed-off-by: Eitan Yarmush
---
 contributing/samples/token_usage/agent.py    | 48 ++++++++++----------
 src/google/adk/models/anthropic_llm.py       | 30 +++++++++---
 tests/unittests/models/test_anthropic_llm.py | 16 ++++---
 tests/unittests/models/test_models.py        | 42 ++++++++++---------
 4 files changed, 78 insertions(+), 58 deletions(-)

diff --git a/contributing/samples/token_usage/agent.py b/contributing/samples/token_usage/agent.py
index a73f9e7638..35b9775706 100755
--- a/contributing/samples/token_usage/agent.py
+++ b/contributing/samples/token_usage/agent.py
@@ -26,26 +26,26 @@
 
 
 def roll_die(sides: int, tool_context: ToolContext) -> int:
-  """Roll a die and return the rolled result.
+    """Roll a die and return the rolled result.
 
-  Args:
-    sides: The integer number of sides the die has.
+    Args:
+        sides: The integer number of sides the die has.
 
-  Returns:
-    An integer of the result of rolling the die.
-  """
-  result = random.randint(1, sides)
-  if 'rolls' not in tool_context.state:
-    tool_context.state['rolls'] = []
+    Returns:
+        An integer of the result of rolling the die.
+    """
+    result = random.randint(1, sides)
+    if "rolls" not in tool_context.state:
+        tool_context.state["rolls"] = []
 
-  tool_context.state['rolls'] = tool_context.state['rolls'] + [result]
-  return result
+    tool_context.state["rolls"] = tool_context.state["rolls"] + [result]
+    return result
 
 
 roll_agent_with_openai = LlmAgent(
-    model=LiteLlm(model='openai/gpt-4o'),
-    description='Handles rolling dice of different sizes.',
-    name='roll_agent_with_openai',
+    model=LiteLlm(model="openai/gpt-4o"),
+    description="Handles rolling dice of different sizes.",
+    name="roll_agent_with_openai",
     instruction="""
       You are responsible for rolling dice based on the user's request.
       When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
@@ -54,9 +54,9 @@ def roll_die(sides: int, tool_context: ToolContext) -> int:
 )
 
 roll_agent_with_claude = LlmAgent(
-    model=Claude(model='claude-3-7-sonnet@20250219'),
-    description='Handles rolling dice of different sizes.',
-    name='roll_agent_with_claude',
+    model=Claude(model="claude-3-7-sonnet@20250219"),
+    description="Handles rolling dice of different sizes.",
+    name="roll_agent_with_claude",
     instruction="""
      You are responsible for rolling dice based on the user's request.
      When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
@@ -65,9 +65,9 @@ def roll_die(sides: int, tool_context: ToolContext) -> int:
 )
 
 roll_agent_with_litellm_claude = LlmAgent(
-    model=LiteLlm(model='vertex_ai/claude-3-7-sonnet'),
-    description='Handles rolling dice of different sizes.',
-    name='roll_agent_with_litellm_claude',
+    model=LiteLlm(model="vertex_ai/claude-3-7-sonnet"),
+    description="Handles rolling dice of different sizes.",
+    name="roll_agent_with_litellm_claude",
     instruction="""
      You are responsible for rolling dice based on the user's request.
      When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
@@ -76,9 +76,9 @@ def roll_die(sides: int, tool_context: ToolContext) -> int:
 )
 
 roll_agent_with_gemini = LlmAgent(
-    model='gemini-2.0-flash',
-    description='Handles rolling dice of different sizes.',
-    name='roll_agent_with_gemini',
+    model="gemini-2.0-flash",
+    description="Handles rolling dice of different sizes.",
+    name="roll_agent_with_gemini",
     instruction="""
      You are responsible for rolling dice based on the user's request.
      When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
@@ -87,7 +87,7 @@ def roll_die(sides: int, tool_context: ToolContext) -> int:
 )
 
 root_agent = SequentialAgent(
-    name='code_pipeline_agent',
+    name="code_pipeline_agent",
     sub_agents=[
         roll_agent_with_openai,
         roll_agent_with_claude,
diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py
index 6c20b1b9a5..730689efa5 100644
--- a/src/google/adk/models/anthropic_llm.py
+++ b/src/google/adk/models/anthropic_llm.py
@@ -29,9 +29,11 @@
 from typing import TYPE_CHECKING
 from typing import Union
 
-from anthropic import AnthropicVertex
+from anthropic import AsyncAnthropic
+from anthropic import AsyncAnthropicVertex
 from anthropic import NOT_GIVEN
 from anthropic import types as anthropic_types
+from anthropic.resources.messages import AsyncMessages
 from google.genai import types
 from pydantic import BaseModel
 from typing_extensions import override
@@ -244,8 +246,8 @@ def function_declaration_to_tool_param(
   )
 
 
-class Claude(BaseLlm):
-  """Integration with Claude models served from Vertex AI.
+class AnthropicClaude(BaseLlm):
+  """Integration with Claude models served from Anthropic.
 
   Attributes:
     model: The name of the Claude model.
@@ -284,7 +286,7 @@ async def generate_content_async(
         else NOT_GIVEN
     )
     # TODO(b/421255973): Enable streaming for anthropic models.
-    message = self._anthropic_client.messages.create(
+    message = await self._anthropic_client.create(
         model=llm_request.model,
         system=llm_request.config.system_instruction,
         messages=messages,
@@ -295,7 +297,21 @@
     yield message_to_generate_content_response(message)
 
   @cached_property
-  def _anthropic_client(self) -> AnthropicVertex:
+  def _anthropic_client(self) -> AsyncMessages:
+    return AsyncAnthropic().messages
+
+
+class Claude(AnthropicClaude):
+  """Integration with Claude models served from Vertex AI.
+
+  Attributes:
+    model: The name of the Claude model.
+    max_tokens: The maximum number of tokens to generate.
+  """
+
+  @cached_property
+  @override
+  def _anthropic_client(self) -> AsyncMessages:
     if (
         "GOOGLE_CLOUD_PROJECT" not in os.environ
         or "GOOGLE_CLOUD_LOCATION" not in os.environ
@@ -305,7 +321,7 @@ def _anthropic_client(self) -> AnthropicVertex:
           " Anthropic on Vertex."
       )
 
-    return AnthropicVertex(
+    return AsyncAnthropicVertex(
         project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
         region=os.environ["GOOGLE_CLOUD_LOCATION"],
-    )
+    ).messages
diff --git a/tests/unittests/models/test_anthropic_llm.py b/tests/unittests/models/test_anthropic_llm.py
index a81fbc7252..fe4c00718f 100644
--- a/tests/unittests/models/test_anthropic_llm.py
+++ b/tests/unittests/models/test_anthropic_llm.py
@@ -295,7 +295,9 @@ async def test_function_declaration_to_tool_param(
 async def test_generate_content_async(
     claude_llm, llm_request, generate_content_response, generate_llm_response
 ):
-  with mock.patch.object(claude_llm, "_anthropic_client") as mock_client:
+  with mock.patch.object(
+      claude_llm, "_anthropic_client"
+  ) as mock_messages_client:
     with mock.patch.object(
         anthropic_llm,
         "message_to_generate_content_response",
@@ -306,7 +308,7 @@ async def mock_coro():
         return generate_content_response
 
       # Assign the coroutine to the mocked method
-      mock_client.messages.create.return_value = mock_coro()
+      mock_messages_client.create.return_value = mock_coro()
 
       responses = [
           resp
@@ -324,7 +326,9 @@ async def test_generate_content_async_with_max_tokens(
     llm_request, generate_content_response, generate_llm_response
 ):
   claude_llm = Claude(model="claude-3-5-sonnet-v2@20241022", max_tokens=4096)
-  with mock.patch.object(claude_llm, "_anthropic_client") as mock_client:
+  with mock.patch.object(
+      claude_llm, "_anthropic_client"
+  ) as mock_messages_client:
     with mock.patch.object(
         anthropic_llm,
         "message_to_generate_content_response",
@@ -335,7 +339,7 @@ async def mock_coro():
         return generate_content_response
 
       # Assign the coroutine to the mocked method
-      mock_client.messages.create.return_value = mock_coro()
+      mock_messages_client.create.return_value = mock_coro()
 
       _ = [
           resp
@@ -343,6 +347,6 @@ async def mock_coro():
               llm_request, stream=False
           )
       ]
-      mock_client.messages.create.assert_called_once()
-      _, kwargs = mock_client.messages.create.call_args
+      mock_messages_client.create.assert_called_once()
+      _, kwargs = mock_messages_client.create.call_args
       assert kwargs["max_tokens"] == 4096
diff --git a/tests/unittests/models/test_models.py b/tests/unittests/models/test_models.py
index 70246c7bc1..ea952af804 100644
--- a/tests/unittests/models/test_models.py
+++ b/tests/unittests/models/test_models.py
@@ -20,17 +20,17 @@
 
 
 @pytest.mark.parametrize(
-    'model_name',
+    "model_name",
     [
-        'gemini-1.5-flash',
-        'gemini-1.5-flash-001',
-        'gemini-1.5-flash-002',
-        'gemini-1.5-pro',
-        'gemini-1.5-pro-001',
-        'gemini-1.5-pro-002',
-        'gemini-2.0-flash-exp',
-        'projects/123456/locations/us-central1/endpoints/123456',  # finetuned vertex gemini endpoint
-        'projects/123456/locations/us-central1/publishers/google/models/gemini-2.0-flash-exp',  # vertex gemini long name
+        "gemini-1.5-flash",
+        "gemini-1.5-flash-001",
+        "gemini-1.5-flash-002",
+        "gemini-1.5-pro",
+        "gemini-1.5-pro-001",
+        "gemini-1.5-pro-002",
+        "gemini-2.0-flash-exp",
+        "projects/123456/locations/us-central1/endpoints/123456",  # finetuned vertex gemini endpoint
+        "projects/123456/locations/us-central1/publishers/google/models/gemini-2.0-flash-exp",  # vertex gemini long name
     ],
 )
 def test_match_gemini_family(model_name):
@@ -38,16 +38,16 @@ def test_match_gemini_family(model_name):
 
 
 @pytest.mark.parametrize(
-    'model_name',
+    "model_name",
     [
-        'claude-3-5-haiku@20241022',
-        'claude-3-5-sonnet-v2@20241022',
-        'claude-3-5-sonnet@20240620',
-        'claude-3-haiku@20240307',
-        'claude-3-opus@20240229',
-        'claude-3-sonnet@20240229',
-        'claude-sonnet-4@20250514',
-        'claude-opus-4@20250514',
+        "claude-3-5-haiku@20241022",
+        "claude-3-5-sonnet-v2@20241022",
+        "claude-3-5-sonnet@20240620",
+        "claude-3-haiku@20240307",
+        "claude-3-opus@20240229",
+        "claude-3-sonnet@20240229",
+        "claude-sonnet-4@20250514",
+        "claude-opus-4@20250514",
     ],
 )
 def test_match_claude_family(model_name):
@@ -58,5 +58,5 @@
 
 def test_non_exist_model():
   with pytest.raises(ValueError) as e_info:
-    models.LLMRegistry.resolve('non-exist-model')
-  assert 'Model non-exist-model not found.' in str(e_info.value)
+    models.LLMRegistry.resolve("non-exist-model")
+  assert "Model non-exist-model not found." in str(e_info.value)

From 5e071fbd81ea6b4d766acb7772eca490f7b6c90b Mon Sep 17 00:00:00 2001
From: Eitan Yarmush
Date: Wed, 10 Sep 2025 18:16:44 +0000
Subject: [PATCH 2/2] revert: restore single-quote style in test_models.py

Signed-off-by: Eitan Yarmush
---
 tests/unittests/models/test_models.py | 42 +++++++++++++--------------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/tests/unittests/models/test_models.py b/tests/unittests/models/test_models.py
index ea952af804..70246c7bc1 100644
--- a/tests/unittests/models/test_models.py
+++ b/tests/unittests/models/test_models.py
@@ -20,17 +20,17 @@
 
 
 @pytest.mark.parametrize(
-    "model_name",
+    'model_name',
     [
-        "gemini-1.5-flash",
-        "gemini-1.5-flash-001",
-        "gemini-1.5-flash-002",
-        "gemini-1.5-pro",
-        "gemini-1.5-pro-001",
-        "gemini-1.5-pro-002",
-        "gemini-2.0-flash-exp",
-        "projects/123456/locations/us-central1/endpoints/123456",  # finetuned vertex gemini endpoint
-        "projects/123456/locations/us-central1/publishers/google/models/gemini-2.0-flash-exp",  # vertex gemini long name
+        'gemini-1.5-flash',
+        'gemini-1.5-flash-001',
+        'gemini-1.5-flash-002',
+        'gemini-1.5-pro',
+        'gemini-1.5-pro-001',
+        'gemini-1.5-pro-002',
+        'gemini-2.0-flash-exp',
+        'projects/123456/locations/us-central1/endpoints/123456',  # finetuned vertex gemini endpoint
+        'projects/123456/locations/us-central1/publishers/google/models/gemini-2.0-flash-exp',  # vertex gemini long name
     ],
 )
 def test_match_gemini_family(model_name):
@@ -38,16 +38,16 @@ def test_match_gemini_family(model_name):
 
 
 @pytest.mark.parametrize(
-    "model_name",
+    'model_name',
     [
-        "claude-3-5-haiku@20241022",
-        "claude-3-5-sonnet-v2@20241022",
-        "claude-3-5-sonnet@20240620",
-        "claude-3-haiku@20240307",
-        "claude-3-opus@20240229",
-        "claude-3-sonnet@20240229",
-        "claude-sonnet-4@20250514",
-        "claude-opus-4@20250514",
+        'claude-3-5-haiku@20241022',
+        'claude-3-5-sonnet-v2@20241022',
+        'claude-3-5-sonnet@20240620',
+        'claude-3-haiku@20240307',
+        'claude-3-opus@20240229',
+        'claude-3-sonnet@20240229',
+        'claude-sonnet-4@20250514',
+        'claude-opus-4@20250514',
     ],
 )
 def test_match_claude_family(model_name):
@@ -58,5 +58,5 @@
 
 def test_non_exist_model():
   with pytest.raises(ValueError) as e_info:
-    models.LLMRegistry.resolve("non-exist-model")
-  assert "Model non-exist-model not found." in str(e_info.value)
+    models.LLMRegistry.resolve('non-exist-model')
+  assert 'Model non-exist-model not found.' in str(e_info.value)
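
A minimal usage sketch of the two client paths this series enables (illustrative only, not part of the patch; the wiring mirrors contributing/samples/token_usage/agent.py). Assumptions: AnthropicClaude and Claude are importable from google.adk.models.anthropic_llm; ANTHROPIC_API_KEY is set for the direct path and GOOGLE_CLOUD_PROJECT / GOOGLE_CLOUD_LOCATION for the Vertex path; the model IDs are examples (the direct Anthropic API uses "-YYYYMMDD" names, while Vertex uses "@YYYYMMDD"):

    # Usage sketch only; agent names and model IDs are assumptions, not patch content.
    from google.adk.agents import LlmAgent
    from google.adk.models.anthropic_llm import AnthropicClaude
    from google.adk.models.anthropic_llm import Claude

    # Direct Anthropic API: AnthropicClaude builds its client from
    # AsyncAnthropic().messages, so only ANTHROPIC_API_KEY is needed.
    direct_agent = LlmAgent(
        model=AnthropicClaude(model="claude-3-5-sonnet-20241022"),
        name="direct_anthropic_agent",
        instruction="Answer the user's question concisely.",
    )

    # Vertex AI: Claude overrides _anthropic_client with
    # AsyncAnthropicVertex(...).messages and reads the two GOOGLE_CLOUD_* vars.
    vertex_agent = LlmAgent(
        model=Claude(model="claude-3-5-sonnet-v2@20241022"),
        name="vertex_claude_agent",
        instruction="Answer the user's question concisely.",
    )

Either way, generate_content_async is unchanged: it awaits create() on whichever AsyncMessages client the cached property returns.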