1717
1818package com .google .adk .flows .llmflows ;
1919
20+ import static com .google .common .collect .ImmutableMap .toImmutableMap ;
21+
2022import com .google .adk .Telemetry ;
2123import com .google .adk .agents .ActiveStreamingTool ;
2224import com .google .adk .agents .Callbacks .AfterToolCallback ;
2729import com .google .adk .events .EventActions ;
2830import com .google .adk .tools .BaseTool ;
2931import com .google .adk .tools .FunctionTool ;
32+ import com .google .adk .tools .ToolConfirmation ;
3033import com .google .adk .tools .ToolContext ;
3134import com .google .common .base .VerifyException ;
3235import com .google .common .collect .ImmutableList ;
5962public final class Functions {
6063
6164 private static final String AF_FUNCTION_CALL_ID_PREFIX = "adk-" ;
65+ public static final String REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = "adk_request_confirmation" ;
6266 private static final Logger logger = LoggerFactory .getLogger (Functions .class );
6367
6468 /** Generates a unique ID for a function call. */
@@ -121,6 +125,15 @@ public static void populateClientFunctionCallId(Event modelResponseEvent) {
121125 /** Handles standard, non-streaming function calls. */
122126 public static Maybe <Event > handleFunctionCalls (
123127 InvocationContext invocationContext , Event functionCallEvent , Map <String , BaseTool > tools ) {
128+ return handleFunctionCalls (invocationContext , functionCallEvent , tools , null );
129+ }
130+
131+ /** Handles standard, non-streaming function calls with tool confirmations. */
132+ public static Maybe <Event > handleFunctionCalls (
133+ InvocationContext invocationContext ,
134+ Event functionCallEvent ,
135+ Map <String , BaseTool > tools ,
136+ @ Nullable Map <String , ToolConfirmation > toolConfirmations ) {
124137 ImmutableList <FunctionCall > functionCalls = functionCallEvent .functionCalls ();
125138
126139 List <Maybe <Event >> functionResponseEvents = new ArrayList <>();
@@ -133,6 +146,10 @@ public static Maybe<Event> handleFunctionCalls(
133146 ToolContext toolContext =
134147 ToolContext .builder (invocationContext )
135148 .functionCallId (functionCall .id ().orElse ("" ))
149+ .toolConfirmation (
150+ toolConfirmations != null
151+ ? toolConfirmations .get (functionCall .id ().orElse (null ))
152+ : null )
136153 .build ();
137154
138155 Map <String , Object > functionArgs = functionCall .args ().orElse (new HashMap <>());
@@ -553,5 +570,65 @@ private static Event buildResponseEvent(
553570 }
554571 }
555572
573+ /**
574+ * Generates a request confirmation event from a function response event.
575+ *
576+ * @param invocationContext The invocation context.
577+ * @param functionCallEvent The event containing the original function call.
578+ * @param functionResponseEvent The event containing the function response.
579+ * @return An optional event containing the request confirmation function call.
580+ */
581+ public static Optional <Event > generateRequestConfirmationEvent (
582+ InvocationContext invocationContext , Event functionCallEvent , Event functionResponseEvent ) {
583+ if (functionResponseEvent .actions ().requestedToolConfirmations ().isEmpty ()) {
584+ return Optional .empty ();
585+ }
586+
587+ List <Part > parts = new ArrayList <>();
588+ Set <String > longRunningToolIds = new HashSet <>();
589+ ImmutableMap <String , FunctionCall > functionCallsById =
590+ functionCallEvent .functionCalls ().stream ()
591+ .filter (fc -> fc .id ().isPresent ())
592+ .collect (toImmutableMap (fc -> fc .id ().get (), fc -> fc ));
593+
594+ for (Map .Entry <String , ToolConfirmation > entry :
595+ functionResponseEvent .actions ().requestedToolConfirmations ().entrySet ().stream ()
596+ .filter (fc -> functionCallsById .containsKey (fc .getKey ()))
597+ .collect (toImmutableMap (Map .Entry ::getKey , Map .Entry ::getValue ))
598+ .entrySet ()) {
599+
600+ FunctionCall requestConfirmationFunctionCall =
601+ FunctionCall .builder ()
602+ .name (REQUEST_CONFIRMATION_FUNCTION_CALL_NAME )
603+ .args (
604+ ImmutableMap .of (
605+ "originalFunctionCall" ,
606+ functionCallsById .get (entry .getKey ()),
607+ "toolConfirmation" ,
608+ entry .getValue ()))
609+ .id (generateClientFunctionCallId ())
610+ .build ();
611+
612+ longRunningToolIds .add (requestConfirmationFunctionCall .id ().get ());
613+ parts .add (Part .builder ().functionCall (requestConfirmationFunctionCall ).build ());
614+ }
615+
616+ if (parts .isEmpty ()) {
617+ return Optional .empty ();
618+ }
619+
620+ var contentBuilder = Content .builder ().parts (parts );
621+ functionResponseEvent .content ().flatMap (Content ::role ).ifPresent (contentBuilder ::role );
622+
623+ return Optional .of (
624+ Event .builder ()
625+ .invocationId (invocationContext .invocationId ())
626+ .author (invocationContext .agent ().name ())
627+ .branch (invocationContext .branch ())
628+ .content (contentBuilder .build ())
629+ .longRunningToolIds (longRunningToolIds )
630+ .build ());
631+ }
632+
556633 private Functions () {}
557634}
0 commit comments