Skip to content

Commit fa121af

Browse files
committed
feat: Implement cache management for Anthropic API with eligibility tracking
gh-4325: Enhance cache management for Anthropic API by introducing per-message TTL and configurable content block usage optimization. Signed-off-by: Austin Dase <[email protected]>
1 parent 3e17e16 commit fa121af

File tree

14 files changed

+995
-241
lines changed

14 files changed

+995
-241
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 66 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@
4242
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source;
4343
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type;
4444
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
45-
import org.springframework.ai.anthropic.api.AnthropicCacheStrategy;
45+
import org.springframework.ai.anthropic.api.AnthropicCacheOptions;
46+
import org.springframework.ai.anthropic.api.AnthropicCacheTtl;
47+
import org.springframework.ai.anthropic.api.utils.CacheEligibilityResolver;
4648
import org.springframework.ai.chat.messages.AssistantMessage;
4749
import org.springframework.ai.chat.messages.Message;
4850
import org.springframework.ai.chat.messages.MessageType;
@@ -93,6 +95,7 @@
9395
* @author Alexandros Pappas
9496
* @author Jonghoon Park
9597
* @author Soby Chacko
98+
* @author Austin Dase
9699
* @since 1.0.0
97100
*/
98101
public class AnthropicChatModel implements ChatModel {
@@ -463,11 +466,9 @@ Prompt buildRequestPrompt(Prompt prompt) {
463466
requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
464467
this.defaultOptions.getToolContext()));
465468

466-
// Merge cache strategy and TTL (also @JsonIgnore fields)
467-
requestOptions.setCacheStrategy(runtimeOptions.getCacheStrategy() != null
468-
? runtimeOptions.getCacheStrategy() : this.defaultOptions.getCacheStrategy());
469-
requestOptions.setCacheTtl(runtimeOptions.getCacheTtl() != null ? runtimeOptions.getCacheTtl()
470-
: this.defaultOptions.getCacheTtl());
469+
// Merge cache options that are Json-ignored
470+
requestOptions.setCacheOptions(runtimeOptions.getCacheOptions() != null ? runtimeOptions.getCacheOptions()
471+
: this.defaultOptions.getCacheOptions());
471472
}
472473
else {
473474
requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());
@@ -498,41 +499,23 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
498499
AnthropicChatOptions requestOptions = null;
499500
if (prompt.getOptions() instanceof AnthropicChatOptions) {
500501
requestOptions = (AnthropicChatOptions) prompt.getOptions();
501-
logger.debug("DEBUGINFO: Found AnthropicChatOptions - cacheStrategy: {}, cacheTtl: {}",
502-
requestOptions.getCacheStrategy(), requestOptions.getCacheTtl());
502+
logger.debug("DEBUGINFO: Found AnthropicChatOptions - cacheOptions {}", requestOptions.getCacheOptions());
503503
}
504504
else {
505505
logger.debug("DEBUGINFO: Options is NOT AnthropicChatOptions, it's: {}",
506506
prompt.getOptions() != null ? prompt.getOptions().getClass().getName() : "null");
507507
}
508508

509-
AnthropicCacheStrategy strategy = requestOptions != null ? requestOptions.getCacheStrategy()
510-
: AnthropicCacheStrategy.NONE;
511-
String cacheTtl = requestOptions != null ? requestOptions.getCacheTtl() : "5m";
509+
AnthropicCacheOptions cacheOptions = requestOptions != null ? requestOptions.getCacheOptions()
510+
: AnthropicCacheOptions.DISABLED;
512511

513-
logger.debug("Cache strategy: {}, TTL: {}", strategy, cacheTtl);
512+
CacheEligibilityResolver cacheEligibilityResolver = CacheEligibilityResolver.from(cacheOptions);
514513

515-
// Track how many breakpoints we've used (max 4)
516-
CacheBreakpointTracker breakpointsUsed = new CacheBreakpointTracker();
517-
ChatCompletionRequest.CacheControl cacheControl = null;
518-
519-
if (strategy != AnthropicCacheStrategy.NONE) {
520-
// Create cache control with TTL if specified, otherwise use default 5m
521-
if (cacheTtl != null && !cacheTtl.equals("5m")) {
522-
cacheControl = new ChatCompletionRequest.CacheControl("ephemeral", cacheTtl);
523-
logger.debug("Created cache control with TTL: type={}, ttl={}", "ephemeral", cacheTtl);
524-
}
525-
else {
526-
cacheControl = new ChatCompletionRequest.CacheControl("ephemeral");
527-
logger.debug("Created cache control with default TTL: type={}, ttl={}", "ephemeral", "5m");
528-
}
529-
}
514+
// Process system - as array if caching, string otherwise
515+
Object systemContent = buildSystemContent(prompt, cacheEligibilityResolver);
530516

531517
// Build messages WITHOUT blanket cache control - strategic placement only
532-
List<AnthropicMessage> userMessages = buildMessages(prompt, strategy, cacheControl, breakpointsUsed);
533-
534-
// Process system - as array if caching, string otherwise
535-
Object systemContent = buildSystemContent(prompt, strategy, cacheControl, breakpointsUsed);
518+
List<AnthropicMessage> userMessages = buildMessages(prompt, cacheEligibilityResolver);
536519

537520
// Build base request
538521
ChatCompletionRequest request = new ChatCompletionRequest(this.defaultOptions.getModel(), userMessages,
@@ -547,16 +530,13 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
547530
List<AnthropicApi.Tool> tools = getFunctionTools(toolDefinitions);
548531

549532
// Apply caching to tools if strategy includes them
550-
if ((strategy == AnthropicCacheStrategy.SYSTEM_AND_TOOLS
551-
|| strategy == AnthropicCacheStrategy.CONVERSATION_HISTORY) && breakpointsUsed.canUse()) {
552-
tools = addCacheToLastTool(tools, cacheControl, breakpointsUsed);
553-
}
533+
tools = addCacheToLastTool(tools, cacheEligibilityResolver);
554534

555535
request = ChatCompletionRequest.from(request).tools(tools).build();
556536
}
557537

558538
// Add beta header for 1-hour TTL if needed
559-
if ("1h".equals(cacheTtl) && requestOptions != null) {
539+
if (cacheOptions.getMessageTypeTtl().containsValue(AnthropicCacheTtl.ONE_HOUR)) {
560540
Map<String, String> headers = new HashMap<>(requestOptions.getHttpHeaders());
561541
headers.put("anthropic-beta", AnthropicApi.BETA_EXTENDED_CACHE_TTL);
562542
requestOptions.setHttpHeaders(headers);
@@ -565,6 +545,25 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
565545
return request;
566546
}
567547

548+
private static ContentBlock cacheAwareContentBlock(ContentBlock contentBlock, MessageType messageType,
549+
CacheEligibilityResolver cacheEligibilityResolver) {
550+
String basisForLength = switch (contentBlock.type()) {
551+
case TEXT, TEXT_DELTA -> contentBlock.text();
552+
case TOOL_RESULT -> contentBlock.content();
553+
case TOOL_USE -> JsonParser.toJson(contentBlock.input());
554+
case THINKING, THINKING_DELTA -> contentBlock.thinking();
555+
case REDACTED_THINKING -> contentBlock.data();
556+
default -> null;
557+
};
558+
559+
ChatCompletionRequest.CacheControl cacheControl = cacheEligibilityResolver.resolve(messageType, basisForLength);
560+
if (cacheControl == null) {
561+
return contentBlock;
562+
}
563+
cacheEligibilityResolver.useCacheBlock();
564+
return ContentBlock.from(contentBlock).cacheControl(cacheControl).build();
565+
}
566+
568567
private List<AnthropicApi.Tool> getFunctionTools(List<ToolDefinition> toolDefinitions) {
569568
return toolDefinitions.stream().map(toolDefinition -> {
570569
var name = toolDefinition.name();
@@ -579,8 +578,7 @@ private List<AnthropicApi.Tool> getFunctionTools(List<ToolDefinition> toolDefini
579578
* Build messages strategically, applying cache control only where specified by the
580579
* strategy.
581580
*/
582-
private List<AnthropicMessage> buildMessages(Prompt prompt, AnthropicCacheStrategy strategy,
583-
ChatCompletionRequest.CacheControl cacheControl, CacheBreakpointTracker breakpointsUsed) {
581+
private List<AnthropicMessage> buildMessages(Prompt prompt, CacheEligibilityResolver cacheEligibilityResolver) {
584582

585583
List<Message> allMessages = prompt.getInstructions()
586584
.stream()
@@ -589,7 +587,7 @@ private List<AnthropicMessage> buildMessages(Prompt prompt, AnthropicCacheStrate
589587

590588
// Find the last user message (current question) for CONVERSATION_HISTORY strategy
591589
int lastUserIndex = -1;
592-
if (strategy == AnthropicCacheStrategy.CONVERSATION_HISTORY) {
590+
if (cacheEligibilityResolver.isCachingEnabled()) {
593591
for (int i = allMessages.size() - 1; i >= 0; i--) {
594592
if (allMessages.get(i).getMessageType() == MessageType.USER) {
595593
lastUserIndex = i;
@@ -601,30 +599,18 @@ private List<AnthropicMessage> buildMessages(Prompt prompt, AnthropicCacheStrate
601599
List<AnthropicMessage> result = new ArrayList<>();
602600
for (int i = 0; i < allMessages.size(); i++) {
603601
Message message = allMessages.get(i);
604-
boolean shouldApplyCache = false;
605-
606-
// Apply cache to history tail (message before current question) for
607-
// CONVERSATION_HISTORY
608-
if (strategy == AnthropicCacheStrategy.CONVERSATION_HISTORY && breakpointsUsed.canUse()) {
609-
if (lastUserIndex > 0) {
610-
// Cache the message immediately before the last user message
611-
// (multi-turn conversation)
612-
shouldApplyCache = (i == lastUserIndex - 1);
613-
}
614-
if (shouldApplyCache) {
615-
breakpointsUsed.use();
616-
}
617-
}
618-
619-
if (message.getMessageType() == MessageType.USER) {
620-
List<ContentBlock> contents = new ArrayList<>();
621-
622-
// Apply cache control strategically, not to all user messages
623-
if (shouldApplyCache && cacheControl != null) {
624-
contents.add(new ContentBlock(message.getText(), cacheControl));
602+
MessageType messageType = message.getMessageType();
603+
if (messageType == MessageType.USER) {
604+
List<ContentBlock> contentBlocks = new ArrayList<>();
605+
String content = message.getText();
606+
boolean isLastUserMessage = lastUserIndex == i;
607+
ContentBlock contentBlock = new ContentBlock(content);
608+
if (isLastUserMessage) {
609+
// Never cache the latest user message
610+
contentBlocks.add(contentBlock);
625611
}
626612
else {
627-
contents.add(new ContentBlock(message.getText()));
613+
contentBlocks.add(cacheAwareContentBlock(contentBlock, messageType, cacheEligibilityResolver));
628614
}
629615

630616
if (message instanceof UserMessage userMessage) {
@@ -634,30 +620,33 @@ private List<AnthropicMessage> buildMessages(Prompt prompt, AnthropicCacheStrate
634620
var source = getSourceByMedia(media);
635621
return new ContentBlock(contentBlockType, source);
636622
}).toList();
637-
contents.addAll(mediaContent);
623+
contentBlocks.addAll(mediaContent);
638624
}
639625
}
640-
result.add(new AnthropicMessage(contents, Role.valueOf(message.getMessageType().name())));
626+
result.add(new AnthropicMessage(contentBlocks, Role.valueOf(message.getMessageType().name())));
641627
}
642-
else if (message.getMessageType() == MessageType.ASSISTANT) {
628+
else if (messageType == MessageType.ASSISTANT) {
643629
AssistantMessage assistantMessage = (AssistantMessage) message;
644630
List<ContentBlock> contentBlocks = new ArrayList<>();
645631
if (StringUtils.hasText(message.getText())) {
646-
contentBlocks.add(new ContentBlock(message.getText()));
632+
ContentBlock contentBlock = new ContentBlock(message.getText());
633+
contentBlocks.add(cacheAwareContentBlock(contentBlock, messageType, cacheEligibilityResolver));
647634
}
648635
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
649636
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
650-
contentBlocks.add(new ContentBlock(Type.TOOL_USE, toolCall.id(), toolCall.name(),
651-
ModelOptionsUtils.jsonToMap(toolCall.arguments())));
637+
ContentBlock contentBlock = new ContentBlock(Type.TOOL_USE, toolCall.id(), toolCall.name(),
638+
ModelOptionsUtils.jsonToMap(toolCall.arguments()));
639+
contentBlocks.add(cacheAwareContentBlock(contentBlock, messageType, cacheEligibilityResolver));
652640
}
653641
}
654642
result.add(new AnthropicMessage(contentBlocks, Role.ASSISTANT));
655643
}
656-
else if (message.getMessageType() == MessageType.TOOL) {
644+
else if (messageType == MessageType.TOOL) {
657645
List<ContentBlock> toolResponses = ((ToolResponseMessage) message).getResponses()
658646
.stream()
659647
.map(toolResponse -> new ContentBlock(Type.TOOL_RESULT, toolResponse.id(),
660648
toolResponse.responseData()))
649+
.map(contentBlock -> cacheAwareContentBlock(contentBlock, messageType, cacheEligibilityResolver))
661650
.toList();
662651
result.add(new AnthropicMessage(toolResponses, Role.USER));
663652
}
@@ -671,8 +660,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
671660
/**
672661
* Build system content - as array if caching, string otherwise.
673662
*/
674-
private Object buildSystemContent(Prompt prompt, AnthropicCacheStrategy strategy,
675-
ChatCompletionRequest.CacheControl cacheControl, CacheBreakpointTracker breakpointsUsed) {
663+
private Object buildSystemContent(Prompt prompt, CacheEligibilityResolver cacheEligibilityResolver) {
676664

677665
String systemText = prompt.getInstructions()
678666
.stream()
@@ -685,15 +673,9 @@ private Object buildSystemContent(Prompt prompt, AnthropicCacheStrategy strategy
685673
}
686674

687675
// Use array format when caching system
688-
if ((strategy == AnthropicCacheStrategy.SYSTEM_ONLY || strategy == AnthropicCacheStrategy.SYSTEM_AND_TOOLS
689-
|| strategy == AnthropicCacheStrategy.CONVERSATION_HISTORY) && breakpointsUsed.canUse()
690-
&& cacheControl != null) {
691-
692-
logger.debug("Applying cache control to system message - strategy: {}, cacheControl: {}", strategy,
693-
cacheControl);
694-
List<ContentBlock> systemBlocks = List.of(new ContentBlock(systemText, cacheControl));
695-
breakpointsUsed.use();
696-
return systemBlocks;
676+
if (cacheEligibilityResolver.isCachingEnabled()) {
677+
return List
678+
.of(cacheAwareContentBlock(new ContentBlock(systemText), MessageType.SYSTEM, cacheEligibilityResolver));
697679
}
698680

699681
// Use string format when not caching (backward compatible)
@@ -704,9 +686,11 @@ private Object buildSystemContent(Prompt prompt, AnthropicCacheStrategy strategy
704686
* Add cache control to the last tool for deterministic caching.
705687
*/
706688
private List<AnthropicApi.Tool> addCacheToLastTool(List<AnthropicApi.Tool> tools,
707-
ChatCompletionRequest.CacheControl cacheControl, CacheBreakpointTracker breakpointsUsed) {
689+
CacheEligibilityResolver cacheEligibilityResolver) {
690+
691+
ChatCompletionRequest.CacheControl cacheControl = cacheEligibilityResolver.resolveToolCacheControl();
708692

709-
if (tools == null || tools.isEmpty() || !breakpointsUsed.canUse() || cacheControl == null) {
693+
if (cacheControl == null || tools == null || tools.isEmpty()) {
710694
return tools;
711695
}
712696

@@ -716,7 +700,7 @@ private List<AnthropicApi.Tool> addCacheToLastTool(List<AnthropicApi.Tool> tools
716700
if (i == tools.size() - 1) {
717701
// Add cache control to last tool
718702
tool = new AnthropicApi.Tool(tool.name(), tool.description(), tool.inputSchema(), cacheControl);
719-
breakpointsUsed.use();
703+
cacheEligibilityResolver.useCacheBlock();
720704
}
721705
modifiedTools.add(tool);
722706
}
@@ -804,36 +788,4 @@ public AnthropicChatModel build() {
804788

805789
}
806790

807-
/**
808-
* Tracks cache breakpoints used (max 4 allowed by Anthropic). Non-static to ensure
809-
* each request has its own instance.
810-
*/
811-
private class CacheBreakpointTracker {
812-
813-
private int count = 0;
814-
815-
private boolean hasWarned = false;
816-
817-
public boolean canUse() {
818-
return this.count < 4;
819-
}
820-
821-
public void use() {
822-
if (this.count < 4) {
823-
this.count++;
824-
}
825-
else if (!this.hasWarned) {
826-
logger.warn(
827-
"Anthropic cache breakpoint limit (4) reached. Additional cache_control directives will be ignored. "
828-
+ "Consider using fewer cache strategies or simpler content structure.");
829-
this.hasWarned = true;
830-
}
831-
}
832-
833-
public int getCount() {
834-
return this.count;
835-
}
836-
837-
}
838-
839791
}

0 commit comments

Comments
 (0)