Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions core/src/main/java/com/google/adk/flows/llmflows/Contents.java
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ private ImmutableList<Content> getContents(
if (!isEventBelongsToBranch(currentBranch, event)) {
continue;
}
if (isRequestConfirmationEvent(event)) {
continue;
}

// TODO: Skip auth events.

Expand Down Expand Up @@ -511,4 +514,19 @@ private static boolean hasContentWithNonEmptyParts(Event event) {
.map(list -> !list.isEmpty()) // Optional<Boolean>
.orElse(false);
}

/** Checks if the event is a request confirmation event. */
private static boolean isRequestConfirmationEvent(Event event) {
return event.content().flatMap(Content::parts).orElse(ImmutableList.of()).stream()
.anyMatch(
part ->
part.functionCall()
.flatMap(FunctionCall::name)
.map(Functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME::equals)
.orElse(false)
|| part.functionResponse()
.flatMap(FunctionResponse::name)
.map(Functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME::equals)
.orElse(false));
}
}
110 changes: 93 additions & 17 deletions core/src/main/java/com/google/adk/flows/llmflows/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package com.google.adk.flows.llmflows;

import static com.google.common.collect.ImmutableMap.toImmutableMap;

import com.google.adk.Telemetry;
import com.google.adk.agents.ActiveStreamingTool;
import com.google.adk.agents.Callbacks.AfterToolCallback;
Expand All @@ -27,6 +29,7 @@
import com.google.adk.events.EventActions;
import com.google.adk.tools.BaseTool;
import com.google.adk.tools.FunctionTool;
import com.google.adk.tools.ToolConfirmation;
import com.google.adk.tools.ToolContext;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
Expand All @@ -52,14 +55,14 @@
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import org.jspecify.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Utility class for handling function calls. */
public final class Functions {

private static final String AF_FUNCTION_CALL_ID_PREFIX = "adk-";
public static final String REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = "adk_request_confirmation";
private static final Logger logger = LoggerFactory.getLogger(Functions.class);

/** Generates a unique ID for a function call. */
Expand Down Expand Up @@ -122,6 +125,15 @@ public static void populateClientFunctionCallId(Event modelResponseEvent) {
/** Handles standard, non-streaming function calls. */
public static Maybe<Event> handleFunctionCalls(
InvocationContext invocationContext, Event functionCallEvent, Map<String, BaseTool> tools) {
return handleFunctionCalls(invocationContext, functionCallEvent, tools, ImmutableMap.of());
}

/** Handles standard, non-streaming function calls with tool confirmations. */
public static Maybe<Event> handleFunctionCalls(
InvocationContext invocationContext,
Event functionCallEvent,
Map<String, BaseTool> tools,
Map<String, ToolConfirmation> toolConfirmations) {
ImmutableList<FunctionCall> functionCalls = functionCallEvent.functionCalls();

List<Maybe<Event>> functionResponseEvents = new ArrayList<>();
Expand All @@ -134,9 +146,10 @@ public static Maybe<Event> handleFunctionCalls(
ToolContext toolContext =
ToolContext.builder(invocationContext)
.functionCallId(functionCall.id().orElse(""))
.toolConfirmation(toolConfirmations.get(functionCall.id().orElse(null)))
.build();

Map<String, Object> functionArgs = functionCall.args().orElse(new HashMap<>());
Map<String, Object> functionArgs = functionCall.args().orElse(ImmutableMap.of());

Maybe<Map<String, Object>> maybeFunctionResult =
maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext)
Expand Down Expand Up @@ -192,10 +205,12 @@ public static Maybe<Event> handleFunctionCalls(
if (events.isEmpty()) {
return Maybe.empty();
}
Event mergedEvent = Functions.mergeParallelFunctionResponseEvents(events);
if (mergedEvent == null) {
Optional<Event> maybeMergedEvent =
Functions.mergeParallelFunctionResponseEvents(events);
if (maybeMergedEvent.isEmpty()) {
return Maybe.empty();
}
var mergedEvent = maybeMergedEvent.get();

if (events.size() > 1) {
Tracer tracer = Telemetry.getTracer();
Expand Down Expand Up @@ -288,7 +303,7 @@ public static Maybe<Event> handleFunctionCallsLive(
if (events.isEmpty()) {
return Maybe.empty();
}
return Maybe.just(Functions.mergeParallelFunctionResponseEvents(events));
return Maybe.just(Functions.mergeParallelFunctionResponseEvents(events).orElse(null));
});
}

Expand Down Expand Up @@ -387,13 +402,13 @@ public static Set<String> getLongRunningFunctionCalls(
return longRunningFunctionCalls;
}

private static @Nullable Event mergeParallelFunctionResponseEvents(
private static Optional<Event> mergeParallelFunctionResponseEvents(
List<Event> functionResponseEvents) {
if (functionResponseEvents.isEmpty()) {
return null;
return Optional.empty();
}
if (functionResponseEvents.size() == 1) {
return functionResponseEvents.get(0);
return Optional.of(functionResponseEvents.get(0));
}
// Use the first event as the base for common attributes
Event baseEvent = functionResponseEvents.get(0);
Expand All @@ -410,15 +425,16 @@ public static Set<String> getLongRunningFunctionCalls(
mergedActionsBuilder.merge(event.actions());
}

return Event.builder()
.id(Event.generateEventId())
.invocationId(baseEvent.invocationId())
.author(baseEvent.author())
.branch(baseEvent.branch())
.content(Optional.of(Content.builder().role("user").parts(mergedParts).build()))
.actions(mergedActionsBuilder.build())
.timestamp(baseEvent.timestamp())
.build();
return Optional.of(
Event.builder()
.id(Event.generateEventId())
.invocationId(baseEvent.invocationId())
.author(baseEvent.author())
.branch(baseEvent.branch())
.content(Optional.of(Content.builder().role("user").parts(mergedParts).build()))
.actions(mergedActionsBuilder.build())
.timestamp(baseEvent.timestamp())
.build());
}

private static Maybe<Map<String, Object>> maybeInvokeBeforeToolCall(
Expand Down Expand Up @@ -563,5 +579,65 @@ private static Event buildResponseEvent(
}
}

/**
* Generates a request confirmation event from a function response event.
*
* @param invocationContext The invocation context.
* @param functionCallEvent The event containing the original function call.
* @param functionResponseEvent The event containing the function response.
* @return An optional event containing the request confirmation function call.
*/
public static Optional<Event> generateRequestConfirmationEvent(
InvocationContext invocationContext, Event functionCallEvent, Event functionResponseEvent) {
if (functionResponseEvent.actions().requestedToolConfirmations().isEmpty()) {
return Optional.empty();
}

List<Part> parts = new ArrayList<>();
Set<String> longRunningToolIds = new HashSet<>();
ImmutableMap<String, FunctionCall> functionCallsById =
functionCallEvent.functionCalls().stream()
.filter(fc -> fc.id().isPresent())
.collect(toImmutableMap(fc -> fc.id().get(), fc -> fc));

for (Map.Entry<String, ToolConfirmation> entry :
functionResponseEvent.actions().requestedToolConfirmations().entrySet().stream()
.filter(fc -> functionCallsById.containsKey(fc.getKey()))
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue))
.entrySet()) {

FunctionCall requestConfirmationFunctionCall =
FunctionCall.builder()
.name(REQUEST_CONFIRMATION_FUNCTION_CALL_NAME)
.args(
ImmutableMap.of(
"originalFunctionCall",
functionCallsById.get(entry.getKey()),
"toolConfirmation",
entry.getValue()))
.id(generateClientFunctionCallId())
.build();

longRunningToolIds.add(requestConfirmationFunctionCall.id().get());
parts.add(Part.builder().functionCall(requestConfirmationFunctionCall).build());
}

if (parts.isEmpty()) {
return Optional.empty();
}

var contentBuilder = Content.builder().parts(parts);
functionResponseEvent.content().flatMap(Content::role).ifPresent(contentBuilder::role);

return Optional.of(
Event.builder()
.invocationId(invocationContext.invocationId())
.author(invocationContext.agent().name())
.branch(invocationContext.branch())
.content(contentBuilder.build())
.longRunningToolIds(longRunningToolIds)
.build());
}

private Functions() {}
}
Loading