Skip to content

Commit bc361de

Browse files
committed
feat: Implement cache management for Anthropic API with eligibility tracking
fix: Update default minimum content length for caching and improve backward compatibility in cache options
refactor: Replace deprecated cache strategy and TTL fields with unified cache options in AnthropicChatOptions
feat: Enhance cache eligibility logic for tool definitions and add comprehensive tests
fix: Improve cache eligibility checks and enhance test coverage for caching strategies
refactor: Update caching documentation to reflect changes in cache options and strategies
gh-4325: Enhance cache management for Anthropic API by introducing per-message TTL and configurable content block usage optimization.
1 parent f5e8349 commit bc361de

File tree

14 files changed

+983
-241
lines changed

14 files changed

+983
-241
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 65 additions & 114 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,9 @@
4242
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source;
4343
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type;
4444
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
45-
import org.springframework.ai.anthropic.api.AnthropicCacheStrategy;
45+
import org.springframework.ai.anthropic.api.AnthropicCacheOptions;
46+
import org.springframework.ai.anthropic.api.AnthropicCacheTtl;
47+
import org.springframework.ai.anthropic.api.utils.CacheEligibilityResolver;
4648
import org.springframework.ai.chat.messages.AssistantMessage;
4749
import org.springframework.ai.chat.messages.Message;
4850
import org.springframework.ai.chat.messages.MessageType;
@@ -463,11 +465,9 @@ Prompt buildRequestPrompt(Prompt prompt) {
463465
requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
464466
this.defaultOptions.getToolContext()));
465467

466-
// Merge cache strategy and TTL (also @JsonIgnore fields)
467-
requestOptions.setCacheStrategy(runtimeOptions.getCacheStrategy() != null
468-
? runtimeOptions.getCacheStrategy() : this.defaultOptions.getCacheStrategy());
469-
requestOptions.setCacheTtl(runtimeOptions.getCacheTtl() != null ? runtimeOptions.getCacheTtl()
470-
: this.defaultOptions.getCacheTtl());
468+
// Merge cache options that are Json-ignored
469+
requestOptions.setCacheOptions(runtimeOptions.getCacheOptions() != null ? runtimeOptions.getCacheOptions()
470+
: this.defaultOptions.getCacheOptions());
471471
}
472472
else {
473473
requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());
@@ -498,41 +498,23 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
498498
AnthropicChatOptions requestOptions = null;
499499
if (prompt.getOptions() instanceof AnthropicChatOptions) {
500500
requestOptions = (AnthropicChatOptions) prompt.getOptions();
501-
logger.debug("DEBUGINFO: Found AnthropicChatOptions - cacheStrategy: {}, cacheTtl: {}",
502-
requestOptions.getCacheStrategy(), requestOptions.getCacheTtl());
501+
logger.debug("DEBUGINFO: Found AnthropicChatOptions - cacheOptions {}", requestOptions.getCacheOptions());
503502
}
504503
else {
505504
logger.debug("DEBUGINFO: Options is NOT AnthropicChatOptions, it's: {}",
506505
prompt.getOptions() != null ? prompt.getOptions().getClass().getName() : "null");
507506
}
508507

509-
AnthropicCacheStrategy strategy = requestOptions != null ? requestOptions.getCacheStrategy()
510-
: AnthropicCacheStrategy.NONE;
511-
String cacheTtl = requestOptions != null ? requestOptions.getCacheTtl() : "5m";
508+
AnthropicCacheOptions cacheOptions = requestOptions != null ? requestOptions.getCacheOptions()
509+
: AnthropicCacheOptions.DISABLED;
512510

513-
logger.debug("Cache strategy: {}, TTL: {}", strategy, cacheTtl);
511+
CacheEligibilityResolver cacheEligibilityResolver = CacheEligibilityResolver.from(cacheOptions);
514512

515-
// Track how many breakpoints we've used (max 4)
516-
CacheBreakpointTracker breakpointsUsed = new CacheBreakpointTracker();
517-
ChatCompletionRequest.CacheControl cacheControl = null;
518-
519-
if (strategy != AnthropicCacheStrategy.NONE) {
520-
// Create cache control with TTL if specified, otherwise use default 5m
521-
if (cacheTtl != null && !cacheTtl.equals("5m")) {
522-
cacheControl = new ChatCompletionRequest.CacheControl("ephemeral", cacheTtl);
523-
logger.debug("Created cache control with TTL: type={}, ttl={}", "ephemeral", cacheTtl);
524-
}
525-
else {
526-
cacheControl = new ChatCompletionRequest.CacheControl("ephemeral");
527-
logger.debug("Created cache control with default TTL: type={}, ttl={}", "ephemeral", "5m");
528-
}
529-
}
513+
// Process system - as array if caching, string otherwise
514+
Object systemContent = buildSystemContent(prompt, cacheEligibilityResolver);
530515

531516
// Build messages WITHOUT blanket cache control - strategic placement only
532-
List<AnthropicMessage> userMessages = buildMessages(prompt, strategy, cacheControl, breakpointsUsed);
533-
534-
// Process system - as array if caching, string otherwise
535-
Object systemContent = buildSystemContent(prompt, strategy, cacheControl, breakpointsUsed);
517+
List<AnthropicMessage> userMessages = buildMessages(prompt, cacheEligibilityResolver);
536518

537519
// Build base request
538520
ChatCompletionRequest request = new ChatCompletionRequest(this.defaultOptions.getModel(), userMessages,
@@ -547,16 +529,13 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
547529
List<AnthropicApi.Tool> tools = getFunctionTools(toolDefinitions);
548530

549531
// Apply caching to tools if strategy includes them
550-
if ((strategy == AnthropicCacheStrategy.SYSTEM_AND_TOOLS
551-
|| strategy == AnthropicCacheStrategy.CONVERSATION_HISTORY) && breakpointsUsed.canUse()) {
552-
tools = addCacheToLastTool(tools, cacheControl, breakpointsUsed);
553-
}
532+
tools = addCacheToLastTool(tools, cacheEligibilityResolver);
554533

555534
request = ChatCompletionRequest.from(request).tools(tools).build();
556535
}
557536

558537
// Add beta header for 1-hour TTL if needed
559-
if ("1h".equals(cacheTtl) && requestOptions != null) {
538+
if (cacheOptions.getMessageTypeTtls().containsValue(AnthropicCacheTtl.ONE_HOUR)) {
560539
Map<String, String> headers = new HashMap<>(requestOptions.getHttpHeaders());
561540
headers.put("anthropic-beta", AnthropicApi.BETA_EXTENDED_CACHE_TTL);
562541
requestOptions.setHttpHeaders(headers);
@@ -565,6 +544,25 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
565544
return request;
566545
}
567546

547+
private static ContentBlock cacheAwareContentBlock(ContentBlock contentBlock, MessageType messageType,
548+
CacheEligibilityResolver cacheEligibilityResolver) {
549+
String basisForLength = switch (contentBlock.type()) {
550+
case TEXT, TEXT_DELTA -> contentBlock.text();
551+
case TOOL_RESULT -> contentBlock.content();
552+
case TOOL_USE -> JsonParser.toJson(contentBlock.input());
553+
case THINKING, THINKING_DELTA -> contentBlock.thinking();
554+
case REDACTED_THINKING -> contentBlock.data();
555+
default -> null;
556+
};
557+
558+
ChatCompletionRequest.CacheControl cacheControl = cacheEligibilityResolver.resolve(messageType, basisForLength);
559+
if (cacheControl == null) {
560+
return contentBlock;
561+
}
562+
cacheEligibilityResolver.useCacheBlock();
563+
return ContentBlock.from(contentBlock).cacheControl(cacheControl).build();
564+
}
565+
568566
private List<AnthropicApi.Tool> getFunctionTools(List<ToolDefinition> toolDefinitions) {
569567
return toolDefinitions.stream().map(toolDefinition -> {
570568
var name = toolDefinition.name();
@@ -579,8 +577,7 @@ private List<AnthropicApi.Tool> getFunctionTools(List<ToolDefinition> toolDefini
579577
* Build messages strategically, applying cache control only where specified by the
580578
* strategy.
581579
*/
582-
private List<AnthropicMessage> buildMessages(Prompt prompt, AnthropicCacheStrategy strategy,
583-
ChatCompletionRequest.CacheControl cacheControl, CacheBreakpointTracker breakpointsUsed) {
580+
private List<AnthropicMessage> buildMessages(Prompt prompt, CacheEligibilityResolver cacheEligibilityResolver) {
584581

585582
List<Message> allMessages = prompt.getInstructions()
586583
.stream()
@@ -589,7 +586,7 @@ private List<AnthropicMessage> buildMessages(Prompt prompt, AnthropicCacheStrate
589586

590587
// Find the last user message (current question) for CONVERSATION_HISTORY strategy
591588
int lastUserIndex = -1;
592-
if (strategy == AnthropicCacheStrategy.CONVERSATION_HISTORY) {
589+
if (cacheEligibilityResolver.isCachingEnabled()) {
593590
for (int i = allMessages.size() - 1; i >= 0; i--) {
594591
if (allMessages.get(i).getMessageType() == MessageType.USER) {
595592
lastUserIndex = i;
@@ -601,30 +598,18 @@ private List<AnthropicMessage> buildMessages(Prompt prompt, AnthropicCacheStrate
601598
List<AnthropicMessage> result = new ArrayList<>();
602599
for (int i = 0; i < allMessages.size(); i++) {
603600
Message message = allMessages.get(i);
604-
boolean shouldApplyCache = false;
605-
606-
// Apply cache to history tail (message before current question) for
607-
// CONVERSATION_HISTORY
608-
if (strategy == AnthropicCacheStrategy.CONVERSATION_HISTORY && breakpointsUsed.canUse()) {
609-
if (lastUserIndex > 0) {
610-
// Cache the message immediately before the last user message
611-
// (multi-turn conversation)
612-
shouldApplyCache = (i == lastUserIndex - 1);
613-
}
614-
if (shouldApplyCache) {
615-
breakpointsUsed.use();
616-
}
617-
}
618-
619-
if (message.getMessageType() == MessageType.USER) {
620-
List<ContentBlock> contents = new ArrayList<>();
621-
622-
// Apply cache control strategically, not to all user messages
623-
if (shouldApplyCache && cacheControl != null) {
624-
contents.add(new ContentBlock(message.getText(), cacheControl));
601+
MessageType messageType = message.getMessageType();
602+
if (messageType == MessageType.USER) {
603+
List<ContentBlock> contentBlocks = new ArrayList<>();
604+
String content = message.getText();
605+
boolean isLastUserMessage = lastUserIndex == i;
606+
ContentBlock contentBlock = new ContentBlock(content);
607+
if (isLastUserMessage) {
608+
// Never cache the latest user message
609+
contentBlocks.add(contentBlock);
625610
}
626611
else {
627-
contents.add(new ContentBlock(message.getText()));
612+
contentBlocks.add(cacheAwareContentBlock(contentBlock, messageType, cacheEligibilityResolver));
628613
}
629614

630615
if (message instanceof UserMessage userMessage) {
@@ -634,30 +619,33 @@ private List<AnthropicMessage> buildMessages(Prompt prompt, AnthropicCacheStrate
634619
var source = getSourceByMedia(media);
635620
return new ContentBlock(contentBlockType, source);
636621
}).toList();
637-
contents.addAll(mediaContent);
622+
contentBlocks.addAll(mediaContent);
638623
}
639624
}
640-
result.add(new AnthropicMessage(contents, Role.valueOf(message.getMessageType().name())));
625+
result.add(new AnthropicMessage(contentBlocks, Role.valueOf(message.getMessageType().name())));
641626
}
642-
else if (message.getMessageType() == MessageType.ASSISTANT) {
627+
else if (messageType == MessageType.ASSISTANT) {
643628
AssistantMessage assistantMessage = (AssistantMessage) message;
644629
List<ContentBlock> contentBlocks = new ArrayList<>();
645630
if (StringUtils.hasText(message.getText())) {
646-
contentBlocks.add(new ContentBlock(message.getText()));
631+
ContentBlock contentBlock = new ContentBlock(message.getText());
632+
contentBlocks.add(cacheAwareContentBlock(contentBlock, messageType, cacheEligibilityResolver));
647633
}
648634
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
649635
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
650-
contentBlocks.add(new ContentBlock(Type.TOOL_USE, toolCall.id(), toolCall.name(),
651-
ModelOptionsUtils.jsonToMap(toolCall.arguments())));
636+
ContentBlock contentBlock = new ContentBlock(Type.TOOL_USE, toolCall.id(), toolCall.name(),
637+
ModelOptionsUtils.jsonToMap(toolCall.arguments()));
638+
contentBlocks.add(cacheAwareContentBlock(contentBlock, messageType, cacheEligibilityResolver));
652639
}
653640
}
654641
result.add(new AnthropicMessage(contentBlocks, Role.ASSISTANT));
655642
}
656-
else if (message.getMessageType() == MessageType.TOOL) {
643+
else if (messageType == MessageType.TOOL) {
657644
List<ContentBlock> toolResponses = ((ToolResponseMessage) message).getResponses()
658645
.stream()
659646
.map(toolResponse -> new ContentBlock(Type.TOOL_RESULT, toolResponse.id(),
660647
toolResponse.responseData()))
648+
.map(contentBlock -> cacheAwareContentBlock(contentBlock, messageType, cacheEligibilityResolver))
661649
.toList();
662650
result.add(new AnthropicMessage(toolResponses, Role.USER));
663651
}
@@ -671,8 +659,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
671659
/**
672660
* Build system content - as array if caching, string otherwise.
673661
*/
674-
private Object buildSystemContent(Prompt prompt, AnthropicCacheStrategy strategy,
675-
ChatCompletionRequest.CacheControl cacheControl, CacheBreakpointTracker breakpointsUsed) {
662+
private Object buildSystemContent(Prompt prompt, CacheEligibilityResolver cacheEligibilityResolver) {
676663

677664
String systemText = prompt.getInstructions()
678665
.stream()
@@ -685,15 +672,9 @@ private Object buildSystemContent(Prompt prompt, AnthropicCacheStrategy strategy
685672
}
686673

687674
// Use array format when caching system
688-
if ((strategy == AnthropicCacheStrategy.SYSTEM_ONLY || strategy == AnthropicCacheStrategy.SYSTEM_AND_TOOLS
689-
|| strategy == AnthropicCacheStrategy.CONVERSATION_HISTORY) && breakpointsUsed.canUse()
690-
&& cacheControl != null) {
691-
692-
logger.debug("Applying cache control to system message - strategy: {}, cacheControl: {}", strategy,
693-
cacheControl);
694-
List<ContentBlock> systemBlocks = List.of(new ContentBlock(systemText, cacheControl));
695-
breakpointsUsed.use();
696-
return systemBlocks;
675+
if (cacheEligibilityResolver.isCachingEnabled()) {
676+
return List
677+
.of(cacheAwareContentBlock(new ContentBlock(systemText), MessageType.SYSTEM, cacheEligibilityResolver));
697678
}
698679

699680
// Use string format when not caching (backward compatible)
@@ -704,9 +685,11 @@ private Object buildSystemContent(Prompt prompt, AnthropicCacheStrategy strategy
704685
* Add cache control to the last tool for deterministic caching.
705686
*/
706687
private List<AnthropicApi.Tool> addCacheToLastTool(List<AnthropicApi.Tool> tools,
707-
ChatCompletionRequest.CacheControl cacheControl, CacheBreakpointTracker breakpointsUsed) {
688+
CacheEligibilityResolver cacheEligibilityResolver) {
689+
690+
ChatCompletionRequest.CacheControl cacheControl = cacheEligibilityResolver.resolveToolCacheControl();
708691

709-
if (tools == null || tools.isEmpty() || !breakpointsUsed.canUse() || cacheControl == null) {
692+
if (cacheControl == null || tools == null || tools.isEmpty()) {
710693
return tools;
711694
}
712695

@@ -716,7 +699,7 @@ private List<AnthropicApi.Tool> addCacheToLastTool(List<AnthropicApi.Tool> tools
716699
if (i == tools.size() - 1) {
717700
// Add cache control to last tool
718701
tool = new AnthropicApi.Tool(tool.name(), tool.description(), tool.inputSchema(), cacheControl);
719-
breakpointsUsed.use();
702+
cacheEligibilityResolver.useCacheBlock();
720703
}
721704
modifiedTools.add(tool);
722705
}
@@ -804,36 +787,4 @@ public AnthropicChatModel build() {
804787

805788
}
806789

807-
/**
808-
* Tracks cache breakpoints used (max 4 allowed by Anthropic). Non-static to ensure
809-
* each request has its own instance.
810-
*/
811-
private class CacheBreakpointTracker {
812-
813-
private int count = 0;
814-
815-
private boolean hasWarned = false;
816-
817-
public boolean canUse() {
818-
return this.count < 4;
819-
}
820-
821-
public void use() {
822-
if (this.count < 4) {
823-
this.count++;
824-
}
825-
else if (!this.hasWarned) {
826-
logger.warn(
827-
"Anthropic cache breakpoint limit (4) reached. Additional cache_control directives will be ignored. "
828-
+ "Consider using fewer cache strategies or simpler content structure.");
829-
this.hasWarned = true;
830-
}
831-
}
832-
833-
public int getCount() {
834-
return this.count;
835-
}
836-
837-
}
838-
839790
}

0 commit comments

Comments
 (0)