diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 972870be..d183c6f7 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -17,6 +17,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.google.adk.tools.ToolConfirmation; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.Part; import java.util.Objects; @@ -37,6 +38,8 @@ public class EventActions { private Optional escalate = Optional.empty(); private ConcurrentMap> requestedAuthConfigs = new ConcurrentHashMap<>(); + private ConcurrentMap requestedToolConfirmations = + new ConcurrentHashMap<>(); private Optional endInvocation = Optional.empty(); /** Default constructor for Jackson. */ @@ -113,6 +116,16 @@ public void setRequestedAuthConfigs( this.requestedAuthConfigs = requestedAuthConfigs; } + @JsonProperty("requestedToolConfirmations") + public ConcurrentMap requestedToolConfirmations() { + return requestedToolConfirmations; + } + + public void setRequestedToolConfirmations( + ConcurrentMap requestedToolConfirmations) { + this.requestedToolConfirmations = requestedToolConfirmations; + } + @JsonProperty("endInvocation") public Optional endInvocation() { return endInvocation; @@ -148,6 +161,7 @@ public boolean equals(Object o) { && Objects.equals(transferToAgent, that.transferToAgent) && Objects.equals(escalate, that.escalate) && Objects.equals(requestedAuthConfigs, that.requestedAuthConfigs) + && Objects.equals(requestedToolConfirmations, that.requestedToolConfirmations) && Objects.equals(endInvocation, that.endInvocation); } @@ -160,6 +174,7 @@ public int hashCode() { transferToAgent, escalate, requestedAuthConfigs, + requestedToolConfirmations, endInvocation); } @@ -172,6 +187,8 @@ public static class Builder { private Optional escalate = Optional.empty(); private ConcurrentMap> requestedAuthConfigs = new ConcurrentHashMap<>(); + private ConcurrentMap requestedToolConfirmations = + new ConcurrentHashMap<>(); private Optional endInvocation = Optional.empty(); public Builder() {} @@ -183,6 +200,8 @@ private Builder(EventActions eventActions) { this.transferToAgent = eventActions.transferToAgent(); this.escalate = eventActions.escalate(); this.requestedAuthConfigs = new ConcurrentHashMap<>(eventActions.requestedAuthConfigs()); + this.requestedToolConfirmations = + new ConcurrentHashMap<>(eventActions.requestedToolConfirmations()); this.endInvocation = eventActions.endInvocation(); } @@ -229,6 +248,13 @@ public Builder requestedAuthConfigs( return this; } + @CanIgnoreReturnValue + @JsonProperty("requestedToolConfirmations") + public Builder requestedToolConfirmations(ConcurrentMap value) { + this.requestedToolConfirmations = value; + return this; + } + @CanIgnoreReturnValue @JsonProperty("endInvocation") public Builder endInvocation(boolean endInvocation) { @@ -256,6 +282,9 @@ public Builder merge(EventActions other) { if (other.requestedAuthConfigs() != null) { this.requestedAuthConfigs.putAll(other.requestedAuthConfigs()); } + if (other.requestedToolConfirmations() != null) { + this.requestedToolConfirmations.putAll(other.requestedToolConfirmations()); + } if (other.endInvocation().isPresent()) { this.endInvocation = other.endInvocation(); } @@ -270,6 +299,7 @@ public EventActions build() { eventActions.setTransferToAgent(this.transferToAgent); eventActions.setEscalate(this.escalate); eventActions.setRequestedAuthConfigs(this.requestedAuthConfigs); + eventActions.setRequestedToolConfirmations(this.requestedToolConfirmations); eventActions.setEndInvocation(this.endInvocation); return eventActions; } diff --git a/core/src/main/java/com/google/adk/tools/ToolConfirmation.java b/core/src/main/java/com/google/adk/tools/ToolConfirmation.java new file mode 100644 index 00000000..8a588b52 --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/ToolConfirmation.java @@ -0,0 +1,71 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.google.auto.value.AutoValue; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import javax.annotation.Nullable; + +/** Represents a tool confirmation configuration. */ +@AutoValue +@JsonDeserialize(builder = ToolConfirmation.Builder.class) +public abstract class ToolConfirmation { + + @Nullable + @JsonProperty("hint") + public abstract String hint(); + + @JsonProperty("confirmed") + public abstract boolean confirmed(); + + @Nullable + @JsonProperty("payload") + public abstract Object payload(); + + public static Builder builder() { + return new AutoValue_ToolConfirmation.Builder().hint("").confirmed(false); + } + + public abstract Builder toBuilder(); + + /** Builder for {@link ToolConfirmation}. */ + @AutoValue.Builder + public abstract static class Builder { + @CanIgnoreReturnValue + @JsonProperty("hint") + public abstract Builder hint(@Nullable String hint); + + @CanIgnoreReturnValue + @JsonProperty("confirmed") + public abstract Builder confirmed(boolean confirmed); + + @CanIgnoreReturnValue + @JsonProperty("payload") + public abstract Builder payload(@Nullable Object payload); + + /** For internal usage. Please use `ToolConfirmation.builder()` for instantiation. */ + @JsonCreator + private static Builder create() { + return new AutoValue_ToolConfirmation.Builder(); + } + + public abstract ToolConfirmation build(); + } +} diff --git a/core/src/main/java/com/google/adk/tools/ToolContext.java b/core/src/main/java/com/google/adk/tools/ToolContext.java index f6a55431..76fe47a7 100644 --- a/core/src/main/java/com/google/adk/tools/ToolContext.java +++ b/core/src/main/java/com/google/adk/tools/ToolContext.java @@ -23,17 +23,21 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.reactivex.rxjava3.core.Single; import java.util.Optional; +import javax.annotation.Nullable; /** ToolContext object provides a structured context for executing tools or functions. */ public class ToolContext extends CallbackContext { private Optional functionCallId = Optional.empty(); + private Optional toolConfirmation = Optional.empty(); private ToolContext( InvocationContext invocationContext, EventActions eventActions, - Optional functionCallId) { + Optional functionCallId, + Optional toolConfirmation) { super(invocationContext, eventActions); this.functionCallId = functionCallId; + this.toolConfirmation = toolConfirmation; } public EventActions actions() { @@ -52,6 +56,14 @@ public void functionCallId(String functionCallId) { this.functionCallId = Optional.ofNullable(functionCallId); } + public Optional toolConfirmation() { + return toolConfirmation; + } + + public void toolConfirmation(ToolConfirmation toolConfirmation) { + this.toolConfirmation = Optional.ofNullable(toolConfirmation); + } + @SuppressWarnings("unused") private void requestCredential() { // TODO: b/414678311 - Implement credential request logic. Make this public. @@ -64,6 +76,35 @@ private void getAuthResponse() { throw new UnsupportedOperationException("Auth response retrieval not implemented yet."); } + /** + * Requests confirmation for the given function call. + * + * @param hint A hint to the user on how to confirm the tool call. + * @param payload The payload used to confirm the tool call. + */ + public void requestConfirmation(@Nullable String hint, @Nullable Object payload) { + if (functionCallId.isEmpty()) { + throw new IllegalStateException("function_call_id is not set."); + } + this.eventActions + .requestedToolConfirmations() + .put(functionCallId.get(), ToolConfirmation.builder().hint(hint).payload(payload).build()); + } + + /** + * Requests confirmation for the given function call. + * + * @param hint A hint to the user on how to confirm the tool call. + */ + public void requestConfirmation(@Nullable String hint) { + requestConfirmation(hint, null); + } + + /** Requests confirmation for the given function call. */ + public void requestConfirmation() { + requestConfirmation(null, null); + } + /** Searches the memory of the current user. */ public Single searchMemory(String query) { if (invocationContext.memoryService() == null) { @@ -82,7 +123,8 @@ public static Builder builder(InvocationContext invocationContext) { public Builder toBuilder() { return new Builder(invocationContext) .actions(eventActions) - .functionCallId(functionCallId.orElse(null)); + .functionCallId(functionCallId.orElse(null)) + .toolConfirmation(toolConfirmation.orElse(null)); } /** Builder for {@link ToolContext}. */ @@ -90,6 +132,7 @@ public static final class Builder { private final InvocationContext invocationContext; private EventActions eventActions = EventActions.builder().build(); // Default empty actions private Optional functionCallId = Optional.empty(); + private Optional toolConfirmation = Optional.empty(); private Builder(InvocationContext invocationContext) { this.invocationContext = invocationContext; @@ -107,8 +150,14 @@ public Builder functionCallId(String functionCallId) { return this; } + @CanIgnoreReturnValue + public Builder toolConfirmation(ToolConfirmation toolConfirmation) { + this.toolConfirmation = Optional.ofNullable(toolConfirmation); + return this; + } + public ToolContext build() { - return new ToolContext(invocationContext, eventActions, functionCallId); + return new ToolContext(invocationContext, eventActions, functionCallId, toolConfirmation); } } } diff --git a/core/src/test/java/com/google/adk/tools/ToolContextTest.java b/core/src/test/java/com/google/adk/tools/ToolContextTest.java index d8900352..d6dd3bcc 100644 --- a/core/src/test/java/com/google/adk/tools/ToolContextTest.java +++ b/core/src/test/java/com/google/adk/tools/ToolContextTest.java @@ -80,4 +80,51 @@ public void listArtifacts_noArtifacts_returnsEmptyList() { assertThat(filenames).isEmpty(); } + + @Test + public void requestConfirmation_noFunctionCallId_throwsException() { + ToolContext toolContext = ToolContext.builder(mockInvocationContext).build(); + IllegalStateException exception = + assertThrows( + IllegalStateException.class, () -> toolContext.requestConfirmation(null, null)); + assertThat(exception).hasMessageThat().isEqualTo("function_call_id is not set."); + } + + @Test + public void requestConfirmation_withHintAndPayload_setsToolConfirmation() { + ToolContext toolContext = + ToolContext.builder(mockInvocationContext).functionCallId("testId").build(); + toolContext.requestConfirmation("testHint", "testPayload"); + assertThat(toolContext.actions().requestedToolConfirmations()) + .containsExactly( + "testId", ToolConfirmation.builder().hint("testHint").payload("testPayload").build()); + } + + @Test + public void requestConfirmation_withHint_setsToolConfirmation() { + ToolContext toolContext = + ToolContext.builder(mockInvocationContext).functionCallId("testId").build(); + toolContext.requestConfirmation("testHint"); + assertThat(toolContext.actions().requestedToolConfirmations()) + .containsExactly( + "testId", ToolConfirmation.builder().hint("testHint").payload(null).build()); + } + + @Test + public void requestConfirmation_noHintOrPayload_setsToolConfirmation() { + ToolContext toolContext = + ToolContext.builder(mockInvocationContext).functionCallId("testId").build(); + toolContext.requestConfirmation(); + assertThat(toolContext.actions().requestedToolConfirmations()) + .containsExactly("testId", ToolConfirmation.builder().hint(null).payload(null).build()); + } + + @Test + public void requestConfirmation_nullHint_setsToolConfirmation() { + ToolContext toolContext = + ToolContext.builder(mockInvocationContext).functionCallId("testId").build(); + toolContext.requestConfirmation(null); + assertThat(toolContext.actions().requestedToolConfirmations()) + .containsExactly("testId", ToolConfirmation.builder().hint(null).payload(null).build()); + } }