42
42
import org .springframework .ai .anthropic .api .AnthropicApi .ContentBlock .Source ;
43
43
import org .springframework .ai .anthropic .api .AnthropicApi .ContentBlock .Type ;
44
44
import org .springframework .ai .anthropic .api .AnthropicApi .Role ;
45
- import org .springframework .ai .anthropic .api .AnthropicCacheStrategy ;
45
+ import org .springframework .ai .anthropic .api .AnthropicCacheOptions ;
46
+ import org .springframework .ai .anthropic .api .AnthropicCacheTtl ;
47
+ import org .springframework .ai .anthropic .api .utils .CacheEligibilityResolver ;
46
48
import org .springframework .ai .chat .messages .AssistantMessage ;
47
49
import org .springframework .ai .chat .messages .Message ;
48
50
import org .springframework .ai .chat .messages .MessageType ;
93
95
* @author Alexandros Pappas
94
96
* @author Jonghoon Park
95
97
* @author Soby Chacko
98
+ * @author Austin Dase
96
99
* @since 1.0.0
97
100
*/
98
101
public class AnthropicChatModel implements ChatModel {
@@ -463,11 +466,9 @@ Prompt buildRequestPrompt(Prompt prompt) {
463
466
requestOptions .setToolContext (ToolCallingChatOptions .mergeToolContext (runtimeOptions .getToolContext (),
464
467
this .defaultOptions .getToolContext ()));
465
468
466
- // Merge cache strategy and TTL (also @JsonIgnore fields)
467
- requestOptions .setCacheStrategy (runtimeOptions .getCacheStrategy () != null
468
- ? runtimeOptions .getCacheStrategy () : this .defaultOptions .getCacheStrategy ());
469
- requestOptions .setCacheTtl (runtimeOptions .getCacheTtl () != null ? runtimeOptions .getCacheTtl ()
470
- : this .defaultOptions .getCacheTtl ());
469
+ // Merge cache options that are Json-ignored
470
+ requestOptions .setCacheOptions (runtimeOptions .getCacheOptions () != null ? runtimeOptions .getCacheOptions ()
471
+ : this .defaultOptions .getCacheOptions ());
471
472
}
472
473
else {
473
474
requestOptions .setHttpHeaders (this .defaultOptions .getHttpHeaders ());
@@ -498,41 +499,23 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
498
499
AnthropicChatOptions requestOptions = null ;
499
500
if (prompt .getOptions () instanceof AnthropicChatOptions ) {
500
501
requestOptions = (AnthropicChatOptions ) prompt .getOptions ();
501
- logger .debug ("DEBUGINFO: Found AnthropicChatOptions - cacheStrategy: {}, cacheTtl: {}" ,
502
- requestOptions .getCacheStrategy (), requestOptions .getCacheTtl ());
502
+ logger .debug ("DEBUGINFO: Found AnthropicChatOptions - cacheOptions {}" , requestOptions .getCacheOptions ());
503
503
}
504
504
else {
505
505
logger .debug ("DEBUGINFO: Options is NOT AnthropicChatOptions, it's: {}" ,
506
506
prompt .getOptions () != null ? prompt .getOptions ().getClass ().getName () : "null" );
507
507
}
508
508
509
- AnthropicCacheStrategy strategy = requestOptions != null ? requestOptions .getCacheStrategy ()
510
- : AnthropicCacheStrategy .NONE ;
511
- String cacheTtl = requestOptions != null ? requestOptions .getCacheTtl () : "5m" ;
509
+ AnthropicCacheOptions cacheOptions = requestOptions != null ? requestOptions .getCacheOptions ()
510
+ : AnthropicCacheOptions .DISABLED ;
512
511
513
- logger . debug ( "Cache strategy: {}, TTL: {}" , strategy , cacheTtl );
512
+ CacheEligibilityResolver cacheEligibilityResolver = CacheEligibilityResolver . from ( cacheOptions );
514
513
515
- // Track how many breakpoints we've used (max 4)
516
- CacheBreakpointTracker breakpointsUsed = new CacheBreakpointTracker ();
517
- ChatCompletionRequest .CacheControl cacheControl = null ;
518
-
519
- if (strategy != AnthropicCacheStrategy .NONE ) {
520
- // Create cache control with TTL if specified, otherwise use default 5m
521
- if (cacheTtl != null && !cacheTtl .equals ("5m" )) {
522
- cacheControl = new ChatCompletionRequest .CacheControl ("ephemeral" , cacheTtl );
523
- logger .debug ("Created cache control with TTL: type={}, ttl={}" , "ephemeral" , cacheTtl );
524
- }
525
- else {
526
- cacheControl = new ChatCompletionRequest .CacheControl ("ephemeral" );
527
- logger .debug ("Created cache control with default TTL: type={}, ttl={}" , "ephemeral" , "5m" );
528
- }
529
- }
514
+ // Process system - as array if caching, string otherwise
515
+ Object systemContent = buildSystemContent (prompt , cacheEligibilityResolver );
530
516
531
517
// Build messages WITHOUT blanket cache control - strategic placement only
532
- List <AnthropicMessage > userMessages = buildMessages (prompt , strategy , cacheControl , breakpointsUsed );
533
-
534
- // Process system - as array if caching, string otherwise
535
- Object systemContent = buildSystemContent (prompt , strategy , cacheControl , breakpointsUsed );
518
+ List <AnthropicMessage > userMessages = buildMessages (prompt , cacheEligibilityResolver );
536
519
537
520
// Build base request
538
521
ChatCompletionRequest request = new ChatCompletionRequest (this .defaultOptions .getModel (), userMessages ,
@@ -547,16 +530,13 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
547
530
List <AnthropicApi .Tool > tools = getFunctionTools (toolDefinitions );
548
531
549
532
// Apply caching to tools if strategy includes them
550
- if ((strategy == AnthropicCacheStrategy .SYSTEM_AND_TOOLS
551
- || strategy == AnthropicCacheStrategy .CONVERSATION_HISTORY ) && breakpointsUsed .canUse ()) {
552
- tools = addCacheToLastTool (tools , cacheControl , breakpointsUsed );
553
- }
533
+ tools = addCacheToLastTool (tools , cacheEligibilityResolver );
554
534
555
535
request = ChatCompletionRequest .from (request ).tools (tools ).build ();
556
536
}
557
537
558
538
// Add beta header for 1-hour TTL if needed
559
- if ("1h" . equals ( cacheTtl ) && requestOptions != null ) {
539
+ if (cacheOptions . getMessageTypeTtl (). containsValue ( AnthropicCacheTtl . ONE_HOUR ) ) {
560
540
Map <String , String > headers = new HashMap <>(requestOptions .getHttpHeaders ());
561
541
headers .put ("anthropic-beta" , AnthropicApi .BETA_EXTENDED_CACHE_TTL );
562
542
requestOptions .setHttpHeaders (headers );
@@ -565,6 +545,25 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
565
545
return request ;
566
546
}
567
547
548
// cacheAwareContentBlock: returns the given content block unchanged, or a copy carrying a
// cache-control marker when the resolver yields one for this message type.
//   contentBlock              - the block to (possibly) decorate
//   messageType               - message type the block belongs to; forwarded to the resolver
//   cacheEligibilityResolver  - decides eligibility and is notified when a marker is attached
+ private static ContentBlock cacheAwareContentBlock (ContentBlock contentBlock , MessageType messageType ,
549
+ CacheEligibilityResolver cacheEligibilityResolver ) {
550
+ String basisForLength = switch (contentBlock .type ()) { // pick the block's payload as a string, by block type
551
+ case TEXT , TEXT_DELTA -> contentBlock .text ();
552
+ case TOOL_RESULT -> contentBlock .content ();
553
+ case TOOL_USE -> JsonParser .toJson (contentBlock .input ()); // tool input is serialized to JSON to obtain a string
554
+ case THINKING , THINKING_DELTA -> contentBlock .thinking ();
555
+ case REDACTED_THINKING -> contentBlock .data ();
556
+ default -> null ; // any other block type hands null to the resolver
557
+ };
558
+
559
+ ChatCompletionRequest .CacheControl cacheControl = cacheEligibilityResolver .resolve (messageType , basisForLength ); // NOTE(review): presumably applies a minimum-length check on basisForLength - confirm in CacheEligibilityResolver
560
+ if (cacheControl == null ) { // resolver produced no marker: leave the block untouched
561
+ return contentBlock ;
562
+ }
563
+ cacheEligibilityResolver .useCacheBlock (); // NOTE(review): presumably records that one cache breakpoint was consumed - confirm
564
+ return ContentBlock .from (contentBlock ).cacheControl (cacheControl ).build (); // build a new block from the original with the marker set
565
+ }
566
+
568
567
private List <AnthropicApi .Tool > getFunctionTools (List <ToolDefinition > toolDefinitions ) {
569
568
return toolDefinitions .stream ().map (toolDefinition -> {
570
569
var name = toolDefinition .name ();
@@ -579,8 +578,7 @@ private List<AnthropicApi.Tool> getFunctionTools(List<ToolDefinition> toolDefini
579
578
* Build messages strategically, applying cache control only where specified by the
580
579
* strategy.
581
580
*/
582
- private List <AnthropicMessage > buildMessages (Prompt prompt , AnthropicCacheStrategy strategy ,
583
- ChatCompletionRequest .CacheControl cacheControl , CacheBreakpointTracker breakpointsUsed ) {
581
+ private List <AnthropicMessage > buildMessages (Prompt prompt , CacheEligibilityResolver cacheEligibilityResolver ) {
584
582
585
583
List <Message > allMessages = prompt .getInstructions ()
586
584
.stream ()
@@ -589,7 +587,7 @@ private List<AnthropicMessage> buildMessages(Prompt prompt, AnthropicCacheStrate
589
587
590
588
// Find the last user message (current question) for CONVERSATION_HISTORY strategy
591
589
int lastUserIndex = -1 ;
592
- if (strategy == AnthropicCacheStrategy . CONVERSATION_HISTORY ) {
590
+ if (cacheEligibilityResolver . isCachingEnabled () ) {
593
591
for (int i = allMessages .size () - 1 ; i >= 0 ; i --) {
594
592
if (allMessages .get (i ).getMessageType () == MessageType .USER ) {
595
593
lastUserIndex = i ;
@@ -601,30 +599,18 @@ private List<AnthropicMessage> buildMessages(Prompt prompt, AnthropicCacheStrate
601
599
List <AnthropicMessage > result = new ArrayList <>();
602
600
for (int i = 0 ; i < allMessages .size (); i ++) {
603
601
Message message = allMessages .get (i );
604
- boolean shouldApplyCache = false ;
605
-
606
- // Apply cache to history tail (message before current question) for
607
- // CONVERSATION_HISTORY
608
- if (strategy == AnthropicCacheStrategy .CONVERSATION_HISTORY && breakpointsUsed .canUse ()) {
609
- if (lastUserIndex > 0 ) {
610
- // Cache the message immediately before the last user message
611
- // (multi-turn conversation)
612
- shouldApplyCache = (i == lastUserIndex - 1 );
613
- }
614
- if (shouldApplyCache ) {
615
- breakpointsUsed .use ();
616
- }
617
- }
618
-
619
- if (message .getMessageType () == MessageType .USER ) {
620
- List <ContentBlock > contents = new ArrayList <>();
621
-
622
- // Apply cache control strategically, not to all user messages
623
- if (shouldApplyCache && cacheControl != null ) {
624
- contents .add (new ContentBlock (message .getText (), cacheControl ));
602
+ MessageType messageType = message .getMessageType ();
603
+ if (messageType == MessageType .USER ) {
604
+ List <ContentBlock > contentBlocks = new ArrayList <>();
605
+ String content = message .getText ();
606
+ boolean isLastUserMessage = lastUserIndex == i ;
607
+ ContentBlock contentBlock = new ContentBlock (content );
608
+ if (isLastUserMessage ) {
609
+ // Never cache the latest user message
610
+ contentBlocks .add (contentBlock );
625
611
}
626
612
else {
627
- contents .add (new ContentBlock ( message . getText () ));
613
+ contentBlocks .add (cacheAwareContentBlock ( contentBlock , messageType , cacheEligibilityResolver ));
628
614
}
629
615
630
616
if (message instanceof UserMessage userMessage ) {
@@ -634,30 +620,33 @@ private List<AnthropicMessage> buildMessages(Prompt prompt, AnthropicCacheStrate
634
620
var source = getSourceByMedia (media );
635
621
return new ContentBlock (contentBlockType , source );
636
622
}).toList ();
637
- contents .addAll (mediaContent );
623
+ contentBlocks .addAll (mediaContent );
638
624
}
639
625
}
640
- result .add (new AnthropicMessage (contents , Role .valueOf (message .getMessageType ().name ())));
626
+ result .add (new AnthropicMessage (contentBlocks , Role .valueOf (message .getMessageType ().name ())));
641
627
}
642
- else if (message . getMessageType () == MessageType .ASSISTANT ) {
628
+ else if (messageType == MessageType .ASSISTANT ) {
643
629
AssistantMessage assistantMessage = (AssistantMessage ) message ;
644
630
List <ContentBlock > contentBlocks = new ArrayList <>();
645
631
if (StringUtils .hasText (message .getText ())) {
646
- contentBlocks .add (new ContentBlock (message .getText ()));
632
+ ContentBlock contentBlock = new ContentBlock (message .getText ());
633
+ contentBlocks .add (cacheAwareContentBlock (contentBlock , messageType , cacheEligibilityResolver ));
647
634
}
648
635
if (!CollectionUtils .isEmpty (assistantMessage .getToolCalls ())) {
649
636
for (AssistantMessage .ToolCall toolCall : assistantMessage .getToolCalls ()) {
650
- contentBlocks .add (new ContentBlock (Type .TOOL_USE , toolCall .id (), toolCall .name (),
651
- ModelOptionsUtils .jsonToMap (toolCall .arguments ())));
637
+ ContentBlock contentBlock = new ContentBlock (Type .TOOL_USE , toolCall .id (), toolCall .name (),
638
+ ModelOptionsUtils .jsonToMap (toolCall .arguments ()));
639
+ contentBlocks .add (cacheAwareContentBlock (contentBlock , messageType , cacheEligibilityResolver ));
652
640
}
653
641
}
654
642
result .add (new AnthropicMessage (contentBlocks , Role .ASSISTANT ));
655
643
}
656
- else if (message . getMessageType () == MessageType .TOOL ) {
644
+ else if (messageType == MessageType .TOOL ) {
657
645
List <ContentBlock > toolResponses = ((ToolResponseMessage ) message ).getResponses ()
658
646
.stream ()
659
647
.map (toolResponse -> new ContentBlock (Type .TOOL_RESULT , toolResponse .id (),
660
648
toolResponse .responseData ()))
649
+ .map (contentBlock -> cacheAwareContentBlock (contentBlock , messageType , cacheEligibilityResolver ))
661
650
.toList ();
662
651
result .add (new AnthropicMessage (toolResponses , Role .USER ));
663
652
}
@@ -671,8 +660,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
671
660
/**
672
661
* Build system content - as array if caching, string otherwise.
673
662
*/
674
- private Object buildSystemContent (Prompt prompt , AnthropicCacheStrategy strategy ,
675
- ChatCompletionRequest .CacheControl cacheControl , CacheBreakpointTracker breakpointsUsed ) {
663
+ private Object buildSystemContent (Prompt prompt , CacheEligibilityResolver cacheEligibilityResolver ) {
676
664
677
665
String systemText = prompt .getInstructions ()
678
666
.stream ()
@@ -685,15 +673,9 @@ private Object buildSystemContent(Prompt prompt, AnthropicCacheStrategy strategy
685
673
}
686
674
687
675
// Use array format when caching system
688
- if ((strategy == AnthropicCacheStrategy .SYSTEM_ONLY || strategy == AnthropicCacheStrategy .SYSTEM_AND_TOOLS
689
- || strategy == AnthropicCacheStrategy .CONVERSATION_HISTORY ) && breakpointsUsed .canUse ()
690
- && cacheControl != null ) {
691
-
692
- logger .debug ("Applying cache control to system message - strategy: {}, cacheControl: {}" , strategy ,
693
- cacheControl );
694
- List <ContentBlock > systemBlocks = List .of (new ContentBlock (systemText , cacheControl ));
695
- breakpointsUsed .use ();
696
- return systemBlocks ;
676
+ if (cacheEligibilityResolver .isCachingEnabled ()) {
677
+ return List
678
+ .of (cacheAwareContentBlock (new ContentBlock (systemText ), MessageType .SYSTEM , cacheEligibilityResolver ));
697
679
}
698
680
699
681
// Use string format when not caching (backward compatible)
@@ -704,9 +686,11 @@ private Object buildSystemContent(Prompt prompt, AnthropicCacheStrategy strategy
704
686
* Add cache control to the last tool for deterministic caching.
705
687
*/
706
688
private List <AnthropicApi .Tool > addCacheToLastTool (List <AnthropicApi .Tool > tools ,
707
- ChatCompletionRequest .CacheControl cacheControl , CacheBreakpointTracker breakpointsUsed ) {
689
+ CacheEligibilityResolver cacheEligibilityResolver ) {
690
+
691
+ ChatCompletionRequest .CacheControl cacheControl = cacheEligibilityResolver .resolveToolCacheControl ();
708
692
709
- if (tools == null || tools . isEmpty () || ! breakpointsUsed . canUse () || cacheControl == null ) {
693
+ if (cacheControl == null || tools == null || tools . isEmpty () ) {
710
694
return tools ;
711
695
}
712
696
@@ -716,7 +700,7 @@ private List<AnthropicApi.Tool> addCacheToLastTool(List<AnthropicApi.Tool> tools
716
700
if (i == tools .size () - 1 ) {
717
701
// Add cache control to last tool
718
702
tool = new AnthropicApi .Tool (tool .name (), tool .description (), tool .inputSchema (), cacheControl );
719
- breakpointsUsed . use ();
703
+ cacheEligibilityResolver . useCacheBlock ();
720
704
}
721
705
modifiedTools .add (tool );
722
706
}
@@ -804,36 +788,4 @@ public AnthropicChatModel build() {
804
788
805
789
}
806
790
807
- /**
808
- * Tracks cache breakpoints used (max 4 allowed by Anthropic). Non-static to ensure
809
- * each request has its own instance.
810
- */
811
- private class CacheBreakpointTracker {
812
-
813
- private int count = 0 ;
814
-
815
- private boolean hasWarned = false ;
816
-
817
- public boolean canUse () {
818
- return this .count < 4 ;
819
- }
820
-
821
- public void use () {
822
- if (this .count < 4 ) {
823
- this .count ++;
824
- }
825
- else if (!this .hasWarned ) {
826
- logger .warn (
827
- "Anthropic cache breakpoint limit (4) reached. Additional cache_control directives will be ignored. "
828
- + "Consider using fewer cache strategies or simpler content structure." );
829
- this .hasWarned = true ;
830
- }
831
- }
832
-
833
- public int getCount () {
834
- return this .count ;
835
- }
836
-
837
- }
838
-
839
791
}
0 commit comments