Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@

import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.OutputFormat;
import org.springframework.ai.anthropic.api.AnthropicCacheOptions;
import org.springframework.ai.anthropic.api.CitationDocument;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.tool.StructuredOutputChatOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.lang.Nullable;
Expand All @@ -51,7 +54,7 @@
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
public class AnthropicChatOptions implements ToolCallingChatOptions {
public class AnthropicChatOptions implements ToolCallingChatOptions, StructuredOutputChatOptions {

// @formatter:off
private @JsonProperty("model") String model;
Expand Down Expand Up @@ -115,6 +118,11 @@ public void setCacheOptions(AnthropicCacheOptions cacheOptions) {
@JsonIgnore
private Map<String, String> httpHeaders = new HashMap<>();

/**
* The desired response format for structured output.
*/
private @JsonProperty("output_format") OutputFormat outputFormat;

// @formatter:on

public static Builder builder() {
Expand All @@ -141,6 +149,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
.cacheOptions(fromOptions.getCacheOptions())
.citationDocuments(fromOptions.getCitationDocuments() != null
? new ArrayList<>(fromOptions.getCitationDocuments()) : null)
.outputFormat(fromOptions.getOutputFormat())
.build();
}

Expand Down Expand Up @@ -325,6 +334,27 @@ public void validateCitationConsistency() {
}
}

public OutputFormat getOutputFormat() {
return this.outputFormat;
}

public void setOutputFormat(OutputFormat outputFormat) {
Assert.notNull(outputFormat, "outputFormat cannot be null");
this.outputFormat = outputFormat;
}

@Override
@JsonIgnore
public String getOutputSchema() {
return this.getOutputFormat() != null ? ModelOptionsUtils.toJsonString(this.getOutputFormat().schema()) : null;
}

@Override
@JsonIgnore
public void setOutputSchema(String outputSchema) {
this.setOutputFormat(new OutputFormat(outputSchema));
}

@Override
@SuppressWarnings("unchecked")
public AnthropicChatOptions copy() {
Expand All @@ -351,6 +381,7 @@ public boolean equals(Object o) {
&& Objects.equals(this.toolContext, that.toolContext)
&& Objects.equals(this.httpHeaders, that.httpHeaders)
&& Objects.equals(this.cacheOptions, that.cacheOptions)
&& Objects.equals(this.outputFormat, that.outputFormat)
&& Objects.equals(this.citationDocuments, that.citationDocuments);
}

Expand All @@ -359,7 +390,7 @@ public int hashCode() {
return Objects.hash(this.model, this.maxTokens, this.metadata, this.stopSequences, this.temperature, this.topP,
this.topK, this.toolChoice, this.thinking, this.toolCallbacks, this.toolNames,
this.internalToolExecutionEnabled, this.toolContext, this.httpHeaders, this.cacheOptions,
this.citationDocuments);
this.outputFormat, this.citationDocuments);
}

public static final class Builder {
Expand Down Expand Up @@ -501,6 +532,16 @@ public Builder addCitationDocument(CitationDocument document) {
return this;
}

public Builder outputFormat(OutputFormat outputFormat) {
this.options.outputFormat = outputFormat;
return this;
}

public Builder outputSchema(String outputSchema) {
this.options.setOutputSchema(outputSchema);
return this;
}

public AnthropicChatOptions build() {
this.options.validateCitationConsistency();
return this.options;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public static Builder builder() {

public static final String DEFAULT_ANTHROPIC_VERSION = "2023-06-01";

public static final String DEFAULT_ANTHROPIC_BETA_VERSION = "tools-2024-04-04,pdfs-2024-09-25";
public static final String DEFAULT_ANTHROPIC_BETA_VERSION = "tools-2024-04-04,pdfs-2024-09-25,structured-outputs-2025-11-13";

public static final String BETA_EXTENDED_CACHE_TTL = "extended-cache-ttl-2025-04-11";

Expand Down Expand Up @@ -542,18 +542,20 @@ public record ChatCompletionRequest(
@JsonProperty("top_k") Integer topK,
@JsonProperty("tools") List<Tool> tools,
@JsonProperty("tool_choice") ToolChoice toolChoice,
@JsonProperty("thinking") ThinkingConfig thinking) {
@JsonProperty("thinking") ThinkingConfig thinking,
@JsonProperty("output_format") OutputFormat outputFormat) {
// @formatter:on

public ChatCompletionRequest(String model, List<AnthropicMessage> messages, Object system, Integer maxTokens,
Double temperature, Boolean stream) {
this(model, messages, system, maxTokens, null, null, stream, temperature, null, null, null, null, null);
this(model, messages, system, maxTokens, null, null, stream, temperature, null, null, null, null, null,
null);
}

public ChatCompletionRequest(String model, List<AnthropicMessage> messages, Object system, Integer maxTokens,
List<String> stopSequences, Double temperature, Boolean stream) {
this(model, messages, system, maxTokens, null, stopSequences, stream, temperature, null, null, null, null,
null);
null, null);
}

public static ChatCompletionRequestBuilder builder() {
Expand All @@ -564,6 +566,15 @@ public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) {
return new ChatCompletionRequestBuilder(request);
}

@JsonInclude(Include.NON_NULL)
public record OutputFormat(@JsonProperty("type") String type,
@JsonProperty("schema") Map<String, Object> schema) {

public OutputFormat(String jsonSchema) {
this("json_schema", ModelOptionsUtils.jsonToMap(jsonSchema));
}
}

/**
* Metadata about the request.
*
Expand Down Expand Up @@ -631,6 +642,8 @@ public static final class ChatCompletionRequestBuilder {

private ChatCompletionRequest.ThinkingConfig thinking;

private ChatCompletionRequest.OutputFormat outputFormat;

private ChatCompletionRequestBuilder() {
}

Expand All @@ -648,6 +661,7 @@ private ChatCompletionRequestBuilder(ChatCompletionRequest request) {
this.tools = request.tools;
this.toolChoice = request.toolChoice;
this.thinking = request.thinking;
this.outputFormat = request.outputFormat;
}

public ChatCompletionRequestBuilder model(ChatModel model) {
Expand Down Expand Up @@ -725,10 +739,15 @@ public ChatCompletionRequestBuilder thinking(ThinkingType type, Integer budgetTo
return this;
}

public ChatCompletionRequestBuilder outputFormat(ChatCompletionRequest.OutputFormat outputFormat) {
this.outputFormat = outputFormat;
return this;
}

public ChatCompletionRequest build() {
return new ChatCompletionRequest(this.model, this.messages, this.system, this.maxTokens, this.metadata,
this.stopSequences, this.stream, this.temperature, this.topP, this.topK, this.tools,
this.toolChoice, this.thinking);
this.toolChoice, this.thinking, this.outputFormat);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.springframework.ai.anthropic.AnthropicTestConfiguration;
import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.ai.anthropic.api.tool.MockWeatherService;
import org.springframework.ai.chat.client.AdvisorParams;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.model.ChatModel;
Expand Down Expand Up @@ -118,6 +119,25 @@ void listOutputConverterBean() {
assertThat(actorsFilms).hasSize(2);
}

@Test
void listOutputConverterBean2() {

// @formatter:off
List<ActorsFilms> actorsFilms = ChatClient.create(this.chatModel).prompt()
.advisors(AdvisorParams.ENABLE_NATIVE_STRUCTURED_OUTPUT)
.options(AnthropicChatOptions.builder()
.model(AnthropicApi.ChatModel.CLAUDE_SONNET_4_5)
.build())
.user("Generate the filmography of 5 movies for Tom Hanks and Bill Murray.")
.call()
.entity(new ParameterizedTypeReference<>() {
});
// @formatter:on

logger.info("" + actorsFilms);
assertThat(actorsFilms).hasSize(2);
}

@Test
void customOutputConverter() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,9 @@ GeminiRequest createGeminiRequest(Prompt prompt) {
if (requestOptions.getResponseMimeType() != null) {
configBuilder.responseMimeType(requestOptions.getResponseMimeType());
}
if (requestOptions.getResponseSchema() != null) {
configBuilder.responseJsonSchema(jsonToSchema(requestOptions.getResponseSchema()));
}
if (requestOptions.getFrequencyPenalty() != null) {
configBuilder.frequencyPenalty(requestOptions.getFrequencyPenalty().floatValue());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import org.springframework.ai.google.genai.GoogleGenAiChatModel.ChatModel;
import org.springframework.ai.google.genai.common.GoogleGenAiSafetySetting;
import org.springframework.ai.model.tool.StructuredOutputChatOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.lang.Nullable;
Expand All @@ -49,7 +50,7 @@
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
public class GoogleGenAiChatOptions implements ToolCallingChatOptions {
public class GoogleGenAiChatOptions implements ToolCallingChatOptions, StructuredOutputChatOptions {

// https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerationConfig

Expand Down Expand Up @@ -97,6 +98,11 @@ public class GoogleGenAiChatOptions implements ToolCallingChatOptions {
*/
private @JsonProperty("responseMimeType") String responseMimeType;

/**
* Optional. Geminie response schema.
*/
private @JsonProperty("responseSchema") String responseSchema;

/**
* Optional. Frequency penalties.
*/
Expand Down Expand Up @@ -199,8 +205,8 @@ public static GoogleGenAiChatOptions fromOptions(GoogleGenAiChatOptions fromOpti
options.setModel(fromOptions.getModel());
options.setToolCallbacks(fromOptions.getToolCallbacks());
options.setResponseMimeType(fromOptions.getResponseMimeType());
options.setResponseSchema(fromOptions.getResponseSchema());
options.setToolNames(fromOptions.getToolNames());
options.setResponseMimeType(fromOptions.getResponseMimeType());
options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval());
options.setSafetySettings(fromOptions.getSafetySettings());
options.setInternalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled());
Expand Down Expand Up @@ -295,6 +301,14 @@ public void setResponseMimeType(String mimeType) {
this.responseMimeType = mimeType;
}

public String getResponseSchema() {
return this.responseSchema;
}

public void setResponseSchema(String responseSchema) {
this.responseSchema = responseSchema;
}

@Override
public List<ToolCallback> getToolCallbacks() {
return this.toolCallbacks;
Expand Down Expand Up @@ -433,6 +447,18 @@ public void setToolContext(Map<String, Object> toolContext) {
this.toolContext = toolContext;
}

@Override
public String getOutputSchema() {
return this.getResponseSchema();
}

@Override
@JsonIgnore
public void setOutputSchema(String jsonSchemaText) {
this.setResponseSchema(jsonSchemaText);
this.setResponseMimeType("application/json");
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand All @@ -450,6 +476,7 @@ public boolean equals(Object o) {
&& Objects.equals(this.thinkingBudget, that.thinkingBudget)
&& Objects.equals(this.maxOutputTokens, that.maxOutputTokens) && Objects.equals(this.model, that.model)
&& Objects.equals(this.responseMimeType, that.responseMimeType)
&& Objects.equals(this.responseSchema, that.responseSchema)
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
&& Objects.equals(this.toolNames, that.toolNames)
&& Objects.equals(this.safetySettings, that.safetySettings)
Expand All @@ -461,8 +488,9 @@ public boolean equals(Object o) {
public int hashCode() {
return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount,
this.frequencyPenalty, this.presencePenalty, this.thinkingBudget, this.maxOutputTokens, this.model,
this.responseMimeType, this.toolCallbacks, this.toolNames, this.googleSearchRetrieval,
this.safetySettings, this.internalToolExecutionEnabled, this.toolContext, this.labels);
this.responseMimeType, this.responseSchema, this.toolCallbacks, this.toolNames,
this.googleSearchRetrieval, this.safetySettings, this.internalToolExecutionEnabled, this.toolContext,
this.labels);
}

@Override
Expand Down Expand Up @@ -548,6 +576,16 @@ public Builder responseMimeType(String mimeType) {
return this;
}

public Builder responseSchema(String responseSchema) {
this.options.setResponseSchema(responseSchema);
return this;
}

public Builder outputSchema(String jsonSchema) {
this.options.setOutputSchema(jsonSchema);
return this;
}

public Builder toolCallbacks(List<ToolCallback> toolCallbacks) {
this.options.toolCallbacks = toolCallbacks;
return this;
Expand Down
Loading