Skip to content

Commit 1f67dc4

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Add HITL tool confirmation support to ADK/Java
This is a port of the python implementation and part of the "human in the loop" workflow. PiperOrigin-RevId: 820215719
1 parent 7f12064 commit 1f67dc4

File tree

14 files changed

+755
-9
lines changed

14 files changed

+755
-9
lines changed

core/src/main/java/com/google/adk/events/EventActions.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import com.fasterxml.jackson.annotation.JsonProperty;
1919
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
20+
import com.google.adk.tools.ToolConfirmation;
2021
import com.google.errorprone.annotations.CanIgnoreReturnValue;
2122
import com.google.genai.types.Part;
2223
import java.util.Objects;
@@ -37,6 +38,8 @@ public class EventActions {
3738
private Optional<Boolean> escalate = Optional.empty();
3839
private ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs =
3940
new ConcurrentHashMap<>();
41+
private ConcurrentMap<String, ToolConfirmation> requestedToolConfirmations =
42+
new ConcurrentHashMap<>();
4043
private Optional<Boolean> endInvocation = Optional.empty();
4144

4245
/** Default constructor for Jackson. */
@@ -113,6 +116,16 @@ public void setRequestedAuthConfigs(
113116
this.requestedAuthConfigs = requestedAuthConfigs;
114117
}
115118

119+
@JsonProperty("requestedToolConfirmations")
120+
public ConcurrentMap<String, ToolConfirmation> requestedToolConfirmations() {
121+
return requestedToolConfirmations;
122+
}
123+
124+
public void setRequestedToolConfirmations(
125+
ConcurrentMap<String, ToolConfirmation> requestedToolConfirmations) {
126+
this.requestedToolConfirmations = requestedToolConfirmations;
127+
}
128+
116129
@JsonProperty("endInvocation")
117130
public Optional<Boolean> endInvocation() {
118131
return endInvocation;
@@ -148,6 +161,7 @@ public boolean equals(Object o) {
148161
&& Objects.equals(transferToAgent, that.transferToAgent)
149162
&& Objects.equals(escalate, that.escalate)
150163
&& Objects.equals(requestedAuthConfigs, that.requestedAuthConfigs)
164+
&& Objects.equals(requestedToolConfirmations, that.requestedToolConfirmations)
151165
&& Objects.equals(endInvocation, that.endInvocation);
152166
}
153167

@@ -160,6 +174,7 @@ public int hashCode() {
160174
transferToAgent,
161175
escalate,
162176
requestedAuthConfigs,
177+
requestedToolConfirmations,
163178
endInvocation);
164179
}
165180

@@ -172,6 +187,8 @@ public static class Builder {
172187
private Optional<Boolean> escalate = Optional.empty();
173188
private ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs =
174189
new ConcurrentHashMap<>();
190+
private ConcurrentMap<String, ToolConfirmation> requestedToolConfirmations =
191+
new ConcurrentHashMap<>();
175192
private Optional<Boolean> endInvocation = Optional.empty();
176193

177194
public Builder() {}
@@ -183,6 +200,8 @@ private Builder(EventActions eventActions) {
183200
this.transferToAgent = eventActions.transferToAgent();
184201
this.escalate = eventActions.escalate();
185202
this.requestedAuthConfigs = new ConcurrentHashMap<>(eventActions.requestedAuthConfigs());
203+
this.requestedToolConfirmations =
204+
new ConcurrentHashMap<>(eventActions.requestedToolConfirmations());
186205
this.endInvocation = eventActions.endInvocation();
187206
}
188207

@@ -229,6 +248,13 @@ public Builder requestedAuthConfigs(
229248
return this;
230249
}
231250

251+
@CanIgnoreReturnValue
252+
@JsonProperty("requestedToolConfirmations")
253+
public Builder requestedToolConfirmations(ConcurrentMap<String, ToolConfirmation> value) {
254+
this.requestedToolConfirmations = value;
255+
return this;
256+
}
257+
232258
@CanIgnoreReturnValue
233259
@JsonProperty("endInvocation")
234260
public Builder endInvocation(boolean endInvocation) {
@@ -256,6 +282,9 @@ public Builder merge(EventActions other) {
256282
if (other.requestedAuthConfigs() != null) {
257283
this.requestedAuthConfigs.putAll(other.requestedAuthConfigs());
258284
}
285+
if (other.requestedToolConfirmations() != null) {
286+
this.requestedToolConfirmations.putAll(other.requestedToolConfirmations());
287+
}
259288
if (other.endInvocation().isPresent()) {
260289
this.endInvocation = other.endInvocation();
261290
}
@@ -270,6 +299,7 @@ public EventActions build() {
270299
eventActions.setTransferToAgent(this.transferToAgent);
271300
eventActions.setEscalate(this.escalate);
272301
eventActions.setRequestedAuthConfigs(this.requestedAuthConfigs);
302+
eventActions.setRequestedToolConfirmations(this.requestedToolConfirmations);
273303
eventActions.setEndInvocation(this.endInvocation);
274304
return eventActions;
275305
}

core/src/main/java/com/google/adk/flows/llmflows/Contents.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ private ImmutableList<Content> getContents(
109109
if (!isEventBelongsToBranch(currentBranch, event)) {
110110
continue;
111111
}
112+
if (isRequestConfirmationEvent(event)) {
113+
continue;
114+
}
112115

113116
// TODO: Skip auth events.
114117

@@ -511,4 +514,19 @@ private static boolean hasContentWithNonEmptyParts(Event event) {
511514
.map(list -> !list.isEmpty()) // Optional<Boolean>
512515
.orElse(false);
513516
}
517+
518+
/** Checks if the event is a request confirmation event. */
519+
private static boolean isRequestConfirmationEvent(Event event) {
520+
return event.content().flatMap(Content::parts).orElse(ImmutableList.of()).stream()
521+
.anyMatch(
522+
part ->
523+
part.functionCall()
524+
.flatMap(FunctionCall::name)
525+
.map(name -> name.equals(Functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME))
526+
.orElse(false)
527+
|| part.functionResponse()
528+
.flatMap(FunctionResponse::name)
529+
.map(name -> name.equals(Functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME))
530+
.orElse(false));
531+
}
514532
}

core/src/main/java/com/google/adk/flows/llmflows/Functions.java

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package com.google.adk.flows.llmflows;
1919

20+
import static com.google.common.collect.ImmutableMap.toImmutableMap;
21+
2022
import com.google.adk.Telemetry;
2123
import com.google.adk.agents.ActiveStreamingTool;
2224
import com.google.adk.agents.Callbacks.AfterToolCallback;
@@ -27,6 +29,7 @@
2729
import com.google.adk.events.EventActions;
2830
import com.google.adk.tools.BaseTool;
2931
import com.google.adk.tools.FunctionTool;
32+
import com.google.adk.tools.ToolConfirmation;
3033
import com.google.adk.tools.ToolContext;
3134
import com.google.common.base.VerifyException;
3235
import com.google.common.collect.ImmutableList;
@@ -59,6 +62,7 @@
5962
public final class Functions {
6063

6164
private static final String AF_FUNCTION_CALL_ID_PREFIX = "adk-";
65+
public static final String REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = "adk_request_confirmation";
6266
private static final Logger logger = LoggerFactory.getLogger(Functions.class);
6367

6468
/** Generates a unique ID for a function call. */
@@ -121,6 +125,15 @@ public static void populateClientFunctionCallId(Event modelResponseEvent) {
121125
/** Handles standard, non-streaming function calls. */
122126
public static Maybe<Event> handleFunctionCalls(
123127
InvocationContext invocationContext, Event functionCallEvent, Map<String, BaseTool> tools) {
128+
return handleFunctionCalls(invocationContext, functionCallEvent, tools, null);
129+
}
130+
131+
/** Handles standard, non-streaming function calls with tool confirmations. */
132+
public static Maybe<Event> handleFunctionCalls(
133+
InvocationContext invocationContext,
134+
Event functionCallEvent,
135+
Map<String, BaseTool> tools,
136+
@Nullable Map<String, ToolConfirmation> toolConfirmations) {
124137
ImmutableList<FunctionCall> functionCalls = functionCallEvent.functionCalls();
125138

126139
List<Maybe<Event>> functionResponseEvents = new ArrayList<>();
@@ -133,6 +146,10 @@ public static Maybe<Event> handleFunctionCalls(
133146
ToolContext toolContext =
134147
ToolContext.builder(invocationContext)
135148
.functionCallId(functionCall.id().orElse(""))
149+
.toolConfirmation(
150+
toolConfirmations != null
151+
? toolConfirmations.get(functionCall.id().orElse(null))
152+
: null)
136153
.build();
137154

138155
Map<String, Object> functionArgs = functionCall.args().orElse(new HashMap<>());
@@ -553,5 +570,65 @@ private static Event buildResponseEvent(
553570
}
554571
}
555572

573+
/**
574+
* Generates a request confirmation event from a function response event.
575+
*
576+
* @param invocationContext The invocation context.
577+
* @param functionCallEvent The event containing the original function call.
578+
* @param functionResponseEvent The event containing the function response.
579+
* @return An optional event containing the request confirmation function call.
580+
*/
581+
public static Optional<Event> generateRequestConfirmationEvent(
582+
InvocationContext invocationContext, Event functionCallEvent, Event functionResponseEvent) {
583+
if (functionResponseEvent.actions().requestedToolConfirmations().isEmpty()) {
584+
return Optional.empty();
585+
}
586+
587+
List<Part> parts = new ArrayList<>();
588+
Set<String> longRunningToolIds = new HashSet<>();
589+
ImmutableMap<String, FunctionCall> functionCallsById =
590+
functionCallEvent.functionCalls().stream()
591+
.filter(fc -> fc.id().isPresent())
592+
.collect(toImmutableMap(fc -> fc.id().get(), fc -> fc));
593+
594+
for (Map.Entry<String, ToolConfirmation> entry :
595+
functionResponseEvent.actions().requestedToolConfirmations().entrySet().stream()
596+
.filter(fc -> functionCallsById.containsKey(fc.getKey()))
597+
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue))
598+
.entrySet()) {
599+
600+
FunctionCall requestConfirmationFunctionCall =
601+
FunctionCall.builder()
602+
.name(REQUEST_CONFIRMATION_FUNCTION_CALL_NAME)
603+
.args(
604+
ImmutableMap.of(
605+
"originalFunctionCall",
606+
functionCallsById.get(entry.getKey()),
607+
"toolConfirmation",
608+
entry.getValue()))
609+
.id(generateClientFunctionCallId())
610+
.build();
611+
612+
longRunningToolIds.add(requestConfirmationFunctionCall.id().get());
613+
parts.add(Part.builder().functionCall(requestConfirmationFunctionCall).build());
614+
}
615+
616+
if (parts.isEmpty()) {
617+
return Optional.empty();
618+
}
619+
620+
var contentBuilder = Content.builder().parts(parts);
621+
functionResponseEvent.content().flatMap(Content::role).ifPresent(contentBuilder::role);
622+
623+
return Optional.of(
624+
Event.builder()
625+
.invocationId(invocationContext.invocationId())
626+
.author(invocationContext.agent().name())
627+
.branch(invocationContext.branch())
628+
.content(contentBuilder.build())
629+
.longRunningToolIds(longRunningToolIds)
630+
.build());
631+
}
632+
556633
private Functions() {}
557634
}

0 commit comments

Comments
 (0)