diff --git a/src/google/adk/plugins/global_instruction_plugin.py b/src/google/adk/plugins/global_instruction_plugin.py index 4251f691f6..ed2a6d4821 100644 --- a/src/google/adk/plugins/global_instruction_plugin.py +++ b/src/google/adk/plugins/global_instruction_plugin.py @@ -79,9 +79,8 @@ async def before_model_callback( return None # Resolve the global instruction (handle both string and InstructionProvider) - readonly_context = ReadonlyContext(callback_context.invocation_context) final_global_instruction = await self._resolve_global_instruction( - readonly_context + callback_context ) if not final_global_instruction: diff --git a/tests/unittests/plugins/test_global_instruction_plugin.py b/tests/unittests/plugins/test_global_instruction_plugin.py index 2253b1fb5a..851f3a9334 100644 --- a/tests/unittests/plugins/test_global_instruction_plugin.py +++ b/tests/unittests/plugins/test_global_instruction_plugin.py @@ -43,7 +43,7 @@ async def test_global_instruction_plugin_with_string(): mock_invocation_context.session = mock_session mock_callback_context = Mock(spec=CallbackContext) - mock_callback_context.invocation_context = mock_invocation_context + mock_callback_context._invocation_context = mock_invocation_context llm_request = LlmRequest( model="gemini-1.5-flash", @@ -80,10 +80,10 @@ async def build_global_instruction(readonly_context: ReadonlyContext) -> str: ) mock_invocation_context = Mock(spec=InvocationContext) - mock_invocation_context.session = mock_session mock_callback_context = Mock(spec=CallbackContext) - mock_callback_context.invocation_context = mock_invocation_context + mock_callback_context._invocation_context = mock_invocation_context + mock_callback_context.session = mock_session llm_request = LlmRequest( model="gemini-1.5-flash", @@ -119,7 +119,7 @@ async def test_global_instruction_plugin_empty_instruction(): mock_invocation_context.session = mock_session mock_callback_context = Mock(spec=CallbackContext) - mock_callback_context.invocation_context = mock_invocation_context + mock_callback_context._invocation_context = mock_invocation_context llm_request = LlmRequest( model="gemini-1.5-flash", @@ -156,7 +156,7 @@ async def test_global_instruction_plugin_leads_existing(): mock_invocation_context.session = mock_session mock_callback_context = Mock(spec=CallbackContext) - mock_callback_context.invocation_context = mock_invocation_context + mock_callback_context._invocation_context = mock_invocation_context llm_request = LlmRequest( model="gemini-1.5-flash", @@ -191,7 +191,7 @@ async def test_global_instruction_plugin_prepends_to_list(): mock_invocation_context.session = mock_session mock_callback_context = Mock(spec=CallbackContext) - mock_callback_context.invocation_context = mock_invocation_context + mock_callback_context._invocation_context = mock_invocation_context llm_request = LlmRequest( model="gemini-1.5-flash",