diff --git a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java index 668c1e5a0d7..be6bc4a6d3b 100644 --- a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java +++ b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java @@ -90,6 +90,7 @@ import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; +import org.springframework.util.MimeType; import org.springframework.util.StringUtils; /** @@ -621,14 +622,18 @@ protected List responseCandidateToGeneration(Candidate candidate) { return List.of(new Generation(assistantMessage, chatGenerationMetadata)); } else { - return candidate.content() - .get() - .parts() - .orElse(List.of()) - .stream() - .map(part -> new AssistantMessage(part.text().orElse(""), messageMetadata)) - .map(assistantMessage -> new Generation(assistantMessage, chatGenerationMetadata)) - .toList(); + return candidate.content().flatMap(Content::parts).orElse(List.of()).stream().map(part -> { + // Multimodality Response Support + List media = part.inlineData() + .filter(blob -> blob.data().isPresent() && blob.mimeType().isPresent()) + .map(blob -> Media.builder() + .mimeType(MimeType.valueOf(blob.mimeType().get())) + .data(blob.data().get()) + .build()) + .map(List::of) + .orElse(List.of()); + return new AssistantMessage(part.text().orElse(""), messageMetadata, List.of(), media); + }).map(assistantMessage -> new Generation(assistantMessage, chatGenerationMetadata)).toList(); } } @@ -725,6 +730,14 @@ GeminiRequest createGeminiRequest(Prompt prompt) { configBuilder.systemInstruction(systemContents.get(0)); } + if (!CollectionUtils.isEmpty(requestOptions.getResponseModalities())) { + configBuilder.responseModalities(requestOptions.getResponseModalities()); + } + + if (requestOptions.getImageConfig() != null) { + configBuilder.imageConfig(requestOptions.getImageConfig().convert()); + } + GenerateContentConfig config = configBuilder.build(); // Create message contents @@ -850,7 +863,7 @@ public static final class Builder { private GoogleGenAiChatOptions defaultOptions = GoogleGenAiChatOptions.builder() .temperature(0.7) .topP(1.0) - .model(GoogleGenAiChatModel.ChatModel.GEMINI_2_0_FLASH) + .model(ChatModel.GEMINI_2_0_FLASH) .build(); private ToolCallingManager toolCallingManager; diff --git a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java index 7e05e5fc921..9d6a20e2ab1 100644 --- a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java +++ b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java @@ -113,6 +113,17 @@ public class GoogleGenAiChatOptions implements ToolCallingChatOptions { */ private @JsonProperty("thinkingBudget") Integer thinkingBudget; + /** + * Optional. Response Modalities. + * @see com.google.genai.types.Modality.Known + */ + private @JsonProperty("responseModalities") List responseModalities = new ArrayList<>(); + + /** + * Optional. imageConfig + */ + private @JsonProperty("imageConfig") GoogleGenAiChatOptionsImageConfig imageConfig; + /** * Collection of {@link ToolCallback}s to be used for tool calling in the chat * completion requests. @@ -174,6 +185,8 @@ public static GoogleGenAiChatOptions fromOptions(GoogleGenAiChatOptions fromOpti options.setToolContext(fromOptions.getToolContext()); options.setThinkingBudget(fromOptions.getThinkingBudget()); options.setLabels(fromOptions.getLabels()); + options.setResponseModalities(fromOptions.getResponseModalities()); + options.setImageConfig(fromOptions.getImageConfig()); return options; } @@ -355,6 +368,23 @@ public void setToolContext(Map toolContext) { this.toolContext = toolContext; } + public List getResponseModalities() { + return this.responseModalities; + } + + public void setResponseModalities(List responseModalities) { + Assert.notNull(responseModalities, "responseModalities cannot be null"); + this.responseModalities = responseModalities; + } + + public GoogleGenAiChatOptionsImageConfig getImageConfig() { + return this.imageConfig; + } + + public void setImageConfig(GoogleGenAiChatOptionsImageConfig imageConfig) { + this.imageConfig = imageConfig; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -376,7 +406,9 @@ public boolean equals(Object o) { && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.safetySettings, that.safetySettings) && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) - && Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.labels, that.labels); + && Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.labels, that.labels) + && Objects.equals(this.responseModalities, that.responseModalities) + && Objects.equals(this.imageConfig, that.imageConfig); } @Override @@ -384,7 +416,8 @@ public int hashCode() { return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount, this.frequencyPenalty, this.presencePenalty, this.thinkingBudget, this.maxOutputTokens, this.model, this.responseMimeType, this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, - this.safetySettings, this.internalToolExecutionEnabled, this.toolContext, this.labels); + this.safetySettings, this.internalToolExecutionEnabled, this.toolContext, this.labels, + this.responseModalities, this.imageConfig); } @Override @@ -396,7 +429,7 @@ public String toString() { + this.model + '\'' + ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks + ", toolNames=" + this.toolNames + ", googleSearchRetrieval=" + this.googleSearchRetrieval + ", safetySettings=" + this.safetySettings + ", labels=" + this.labels - + '}'; + + ", responseModalities=" + this.responseModalities + ", imageConfig=" + this.imageConfig + '}'; } @Override @@ -530,6 +563,23 @@ public Builder labels(Map labels) { return this; } + public Builder responseModalities(List responseModalities) { + Assert.notNull(responseModalities, "responseModalities must not be null"); + this.options.responseModalities = responseModalities; + return this; + } + + public Builder responseModalitie(String responseModalitie) { + Assert.hasText(responseModalitie, "responseModalitie must not be empty"); + this.options.responseModalities.add(responseModalitie); + return this; + } + + public Builder imageConfig(GoogleGenAiChatOptionsImageConfig imageConfig) { + this.options.setImageConfig(imageConfig); + return this; + } + public GoogleGenAiChatOptions build() { return this.options; } diff --git a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptionsImageConfig.java b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptionsImageConfig.java new file mode 100644 index 00000000000..1d40e3a1327 --- /dev/null +++ b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptionsImageConfig.java @@ -0,0 +1,89 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.google.genai; + +import java.util.Objects; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.genai.types.ImageConfig; + +/** + * Google GenAI ImageConfig + * + * @author 楚孔响 + * @version 1.0.0 + * @date 2025-10-10 15:33:20 + */ +public class GoogleGenAiChatOptionsImageConfig { + + @JsonProperty("aspectRatio") + private String aspectRatio; + + public String getAspectRatio() { + return this.aspectRatio; + } + + public void setAspectRatio(String aspectRatio) { + this.aspectRatio = aspectRatio; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof GoogleGenAiChatOptionsImageConfig that)) { + return false; + } + + return Objects.equals(this.aspectRatio, that.aspectRatio); + } + + @Override + public int hashCode() { + return Objects.hash(this.aspectRatio); + } + + @Override + public String toString() { + return "GoogleGenAiChatOptionsImageConfig{" + "aspectRatio='" + this.aspectRatio + '\'' + '}'; + } + + public ImageConfig convert() { + return ImageConfig.builder().aspectRatio(this.aspectRatio).build(); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private GoogleGenAiChatOptionsImageConfig imageConfig = new GoogleGenAiChatOptionsImageConfig(); + + public Builder aspectRatio(String aspectRatio) { + this.imageConfig.setAspectRatio(aspectRatio); + return this; + } + + public GoogleGenAiChatOptionsImageConfig build() { + return this.imageConfig; + } + + } + +} diff --git a/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatOptionsTest.java b/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatOptionsTest.java index 3521213bfb5..1b71d991073 100644 --- a/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatOptionsTest.java +++ b/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatOptionsTest.java @@ -16,8 +16,10 @@ package org.springframework.ai.google.genai; +import java.util.List; import java.util.Map; +import com.google.genai.types.Modality; import org.junit.jupiter.api.Test; import static org.assertj.core.api.Assertions.assertThat; @@ -153,4 +155,22 @@ public void testToStringWithLabels() { assertThat(toString).contains("test-model"); } + @Test + public void testResponseMultimodality() { + GoogleGenAiChatOptions options = GoogleGenAiChatOptions.builder() + .responseModalities(List.of(Modality.Known.TEXT.name(), Modality.Known.IMAGE.name())) + .build(); + String toString = options.toString(); + assertThat(toString).contains("responseModalities=[TEXT, IMAGE]"); + } + + @Test + public void testImageConfig() { + GoogleGenAiChatOptions options = GoogleGenAiChatOptions.builder() + .imageConfig(GoogleGenAiChatOptionsImageConfig.builder().aspectRatio("1:1").build()) + .build(); + String toString = options.toString(); + assertThat(toString).contains("imageConfig=GoogleGenAiChatOptionsImageConfig{aspectRatio='1:1'}"); + } + } diff --git a/pom.xml b/pom.xml index 5031dc90eed..6d37b51d894 100644 --- a/pom.xml +++ b/pom.xml @@ -282,7 +282,7 @@ 1.19.2 3.63.1 26.60.0 - 1.10.0 + 1.21.0 9.20.0 4.37.0 2.2.30 diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java index 498c35b8d17..6f9120a8a7c 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java @@ -34,6 +34,7 @@ import org.springframework.ai.chat.metadata.PromptMetadata; import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.content.Media; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -60,6 +61,7 @@ public Flux aggregate(Flux fluxChatResponse, AtomicReference messageTextContentRef = new AtomicReference<>(new StringBuilder()); AtomicReference> messageMetadataMapRef = new AtomicReference<>(); AtomicReference> toolCallsRef = new AtomicReference<>(new ArrayList<>()); + AtomicReference> mediasRef = new AtomicReference<>(new ArrayList<>()); // ChatGeneration Metadata AtomicReference generationMetadataRef = new AtomicReference<>( @@ -80,6 +82,7 @@ public Flux aggregate(Flux fluxChatResponse, messageTextContentRef.set(new StringBuilder()); messageMetadataMapRef.set(new HashMap<>()); toolCallsRef.set(new ArrayList<>()); + mediasRef.set(new ArrayList<>()); metadataIdRef.set(""); metadataModelRef.set(""); metadataUsagePromptTokensRef.set(0); @@ -105,7 +108,9 @@ public Flux aggregate(Flux fluxChatResponse, if (!CollectionUtils.isEmpty(outputMessage.getToolCalls())) { toolCallsRef.get().addAll(outputMessage.getToolCalls()); } - + if (!CollectionUtils.isEmpty(outputMessage.getMedia())) { + mediasRef.get().addAll(outputMessage.getMedia()); + } } if (chatResponse.getMetadata() != null) { if (chatResponse.getMetadata().getUsage() != null) { @@ -137,6 +142,12 @@ public Flux aggregate(Flux fluxChatResponse, List toolCallsList = (List) toolCallsFromMetadata; toolCallsRef.get().addAll(toolCallsList); } + Object mediasFromMetadata = chatResponse.getMetadata().get("medias"); + if (mediasFromMetadata instanceof List) { + @SuppressWarnings("unchecked") + List mediasList = (List) mediasFromMetadata; + mediasRef.get().addAll(mediasList); + } } }).doOnComplete(() -> { @@ -152,18 +163,12 @@ public Flux aggregate(Flux fluxChatResponse, .promptMetadata(metadataPromptMetadataRef.get()) .build(); - AssistantMessage finalAssistantMessage; List collectedToolCalls = toolCallsRef.get(); + List collectedMedias = mediasRef.get(); - if (!CollectionUtils.isEmpty(collectedToolCalls)) { + AssistantMessage finalAssistantMessage = new AssistantMessage(messageTextContentRef.get().toString(), + messageMetadataMapRef.get(), collectedToolCalls, collectedMedias); - finalAssistantMessage = new AssistantMessage(messageTextContentRef.get().toString(), - messageMetadataMapRef.get(), collectedToolCalls); - } - else { - finalAssistantMessage = new AssistantMessage(messageTextContentRef.get().toString(), - messageMetadataMapRef.get()); - } onAggregationComplete.accept(new ChatResponse(List.of(new Generation(finalAssistantMessage, generationMetadataRef.get())), chatResponseMetadata)); @@ -171,6 +176,7 @@ public Flux aggregate(Flux fluxChatResponse, messageTextContentRef.set(new StringBuilder()); messageMetadataMapRef.set(new HashMap<>()); toolCallsRef.set(new ArrayList<>()); + mediasRef.set(new ArrayList<>()); metadataIdRef.set(""); metadataModelRef.set(""); metadataUsagePromptTokensRef.set(0);