Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 96 additions & 8 deletions core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,14 @@

import com.google.common.collect.ImmutableList;
import com.google.genai.types.Content;
import com.google.genai.types.FunctionCall;
import com.google.genai.types.LiveServerContent;
import com.google.genai.types.LiveServerMessage;
import com.google.genai.types.LiveServerSetupComplete;
import com.google.genai.types.LiveServerToolCall;
import com.google.genai.types.LiveServerToolCallCancellation;
import com.google.genai.types.Part;
import com.google.genai.types.UsageMetadata;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand All @@ -34,8 +39,7 @@ public final class GeminiLlmConnectionTest {
public void convertToServerResponse_withInterruptedTrue_mapsInterruptedField() {
LiveServerContent serverContent =
LiveServerContent.builder()
.modelTurn(
Content.builder().parts(ImmutableList.of(Part.fromText("Model response"))).build())
.modelTurn(Content.fromParts(Part.fromText("Model response")))
.turnComplete(false)
.interrupted(true)
.build();
Expand All @@ -55,10 +59,7 @@ public void convertToServerResponse_withInterruptedTrue_mapsInterruptedField() {
public void convertToServerResponse_withInterruptedFalse_mapsInterruptedField() {
LiveServerContent serverContent =
LiveServerContent.builder()
.modelTurn(
Content.builder()
.parts(ImmutableList.of(Part.fromText("Continuing response")))
.build())
.modelTurn(Content.fromParts(Part.fromText("Continuing response")))
.turnComplete(false)
.interrupted(false)
.build();
Expand All @@ -75,8 +76,7 @@ public void convertToServerResponse_withInterruptedFalse_mapsInterruptedField()
public void convertToServerResponse_withoutInterruptedField_mapsEmptyOptional() {
LiveServerContent serverContent =
LiveServerContent.builder()
.modelTurn(
Content.builder().parts(ImmutableList.of(Part.fromText("Normal response"))).build())
.modelTurn(Content.fromParts(Part.fromText("Normal response")))
.turnComplete(true)
.build();

Expand All @@ -87,4 +87,92 @@ public void convertToServerResponse_withoutInterruptedField_mapsEmptyOptional()
assertThat(response.interrupted()).isEmpty();
assertThat(response.turnComplete()).hasValue(true);
}

@Test
public void convertToServerResponse_withTurnCompleteTrue_mapsPartialFalse() {
LiveServerContent serverContent =
LiveServerContent.builder()
.modelTurn(Content.fromParts(Part.fromText("Final response")))
.turnComplete(true)
.build();

LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build();

LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get();

assertThat(response.partial()).hasValue(false);
assertThat(response.turnComplete()).hasValue(true);
}

@Test
public void convertToServerResponse_withTurnCompleteFalse_mapsPartialTrue() {
LiveServerContent serverContent =
LiveServerContent.builder()
.modelTurn(Content.fromParts(Part.fromText("Partial response")))
.turnComplete(false)
.build();

LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build();

LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get();

assertThat(response.partial()).hasValue(true);
assertThat(response.turnComplete()).hasValue(false);
}

@Test
public void convertToServerResponse_withToolCall_mapsContentWithFunctionCall() {
FunctionCall functionCall = FunctionCall.builder().name("tool").build();
LiveServerToolCall toolCall =
LiveServerToolCall.builder().functionCalls(ImmutableList.of(functionCall)).build();

LiveServerMessage message = LiveServerMessage.builder().toolCall(toolCall).build();

LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get();

assertThat(response.content()).isPresent();
assertThat(response.content().get().parts()).isPresent();
assertThat(response.content().get().parts().get()).hasSize(1);
assertThat(response.content().get().parts().get().get(0).functionCall()).hasValue(functionCall);
assertThat(response.partial()).hasValue(false);
assertThat(response.turnComplete()).hasValue(false);
}

@Test
public void convertToServerResponse_withUsageMetadata_returnsEmptyOptional() {
LiveServerMessage message =
LiveServerMessage.builder().usageMetadata(UsageMetadata.builder().build()).build();

assertThat(GeminiLlmConnection.convertToServerResponse(message)).isEmpty();
}

@Test
public void convertToServerResponse_withToolCallCancellation_returnsEmptyOptional() {
LiveServerMessage message =
LiveServerMessage.builder()
.toolCallCancellation(LiveServerToolCallCancellation.builder().build())
.build();

assertThat(GeminiLlmConnection.convertToServerResponse(message)).isEmpty();
}

@Test
public void convertToServerResponse_withSetupComplete_returnsEmptyOptional() {
LiveServerMessage message =
LiveServerMessage.builder()
.setupComplete(LiveServerSetupComplete.builder().build())
.build();

assertThat(GeminiLlmConnection.convertToServerResponse(message)).isEmpty();
}

@Test
public void convertToServerResponse_withUnknownMessage_returnsErrorResponse() {
LiveServerMessage message = LiveServerMessage.builder().build();

LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get();

assertThat(response.errorCode()).isPresent();
assertThat(response.errorMessage()).hasValue("Received unknown server message.");
}
}