Skip to content

Commit 0c0c2ff

Browse files
committed
Call Gemini function even when text part is returned
* add tests Closes gh-2499 Signed-off-by: Vlad Stoian <[email protected]>
1 parent c4e434a commit 0c0c2ff

File tree

2 files changed

+44
-27
lines changed

2 files changed

+44
-27
lines changed

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.Collection;
2121
import java.util.List;
2222
import java.util.Map;
23+
import java.util.stream.Collectors;
2324

2425
import com.fasterxml.jackson.annotation.JsonInclude;
2526
import com.fasterxml.jackson.annotation.JsonInclude.Include;
@@ -587,35 +588,27 @@ protected List<Generation> responseCandidateToGeneration(Candidate candidate) {
587588
.finishReason(candidateFinishReason.name())
588589
.build();
589590

590-
boolean isFunctionCall = candidate.getContent().getPartsList().stream().allMatch(Part::hasFunctionCall);
591-
592-
if (isFunctionCall) {
593-
List<AssistantMessage.ToolCall> assistantToolCalls = candidate.getContent()
594-
.getPartsList()
595-
.stream()
596-
.filter(part -> part.hasFunctionCall())
597-
.map(part -> {
598-
FunctionCall functionCall = part.getFunctionCall();
599-
var functionName = functionCall.getName();
600-
String functionArguments = structToJson(functionCall.getArgs());
601-
return new AssistantMessage.ToolCall("", "function", functionName, functionArguments);
602-
})
603-
.toList();
591+
var assistantToolCalls = candidate.getContent()
592+
.getPartsList()
593+
.stream()
594+
.filter(Part::hasFunctionCall)
595+
.map(part -> {
596+
FunctionCall functionCall = part.getFunctionCall();
597+
var functionName = functionCall.getName();
598+
String functionArguments = structToJson(functionCall.getArgs());
599+
return new AssistantMessage.ToolCall("", "function", functionName, functionArguments);
600+
})
601+
.toList();
604602

605-
AssistantMessage assistantMessage = new AssistantMessage("", messageMetadata, assistantToolCalls);
603+
var text = candidate.getContent()
604+
.getPartsList()
605+
.stream()
606+
.filter(Part::hasText)
607+
.map(Part::getText)
608+
.collect(Collectors.joining(System.lineSeparator()));
606609

607-
return List.of(new Generation(assistantMessage, chatGenerationMetadata));
608-
}
609-
else {
610-
List<Generation> generations = candidate.getContent()
611-
.getPartsList()
612-
.stream()
613-
.map(part -> new AssistantMessage(part.getText(), messageMetadata))
614-
.map(assistantMessage -> new Generation(assistantMessage, chatGenerationMetadata))
615-
.toList();
616-
617-
return generations;
618-
}
610+
return List.of(new Generation(new AssistantMessage(text, messageMetadata, assistantToolCalls),
611+
chatGenerationMetadata));
619612
}
620613

621614
private ChatResponseMetadata toChatResponseMetadata(Usage usage) {

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiChatModelToolCallingIT.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,30 @@ public void functionCallExplicitOpenApiSchema() {
9696
assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15");
9797
}
9898

99+
@Test
100+
public void functionCallModelReturnsMixedTextAndFunctionCallParts() {
101+
UserMessage userMessage = new UserMessage(
102+
"What can you tell me about the temperature in San Francisco, Paris and in Tokyo? Return the temperature in Celsius. Expose your thinking process.");
103+
104+
List<Message> messages = new ArrayList<>(List.of(userMessage));
105+
106+
var promptOptions = VertexAiGeminiChatOptions.builder()
107+
.model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH)
108+
.toolCallbacks(List.of(FunctionToolCallback.builder("get_current_weather", new MockWeatherService())
109+
.description("Get the current weather in a given location.")
110+
.inputType(MockWeatherService.Request.class)
111+
.build()))
112+
.build();
113+
114+
ChatResponse chatResponse = this.chatModel.call(new Prompt(messages, promptOptions));
115+
116+
assertThat(chatResponse.getResult().getOutput().getText()).contains("30", "10", "15");
117+
118+
assertThat(chatResponse.getMetadata()).isNotNull();
119+
assertThat(chatResponse.getMetadata().getUsage()).isNotNull();
120+
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isGreaterThan(150).isLessThan(550);
121+
}
122+
99123
@Test
100124
public void functionCallTestInferredOpenApiSchema() {
101125

0 commit comments

Comments
 (0)