1616package org .springframework .ai .bedrock .anthropic3 ;
1717
1818import reactor .core .publisher .Flux ;
19- import software .amazon .awssdk .services .bedrockruntime .model .ConverseResponse ;
2019import software .amazon .awssdk .services .bedrockruntime .model .ConverseStreamOutput ;
20+ import software .amazon .awssdk .services .bedrockruntime .model .Tool ;
21+ import software .amazon .awssdk .services .bedrockruntime .model .ToolConfiguration ;
22+ import software .amazon .awssdk .services .bedrockruntime .model .ToolInputSchema ;
23+ import software .amazon .awssdk .services .bedrockruntime .model .ToolSpecification ;
2124
25+ import java .util .HashSet ;
26+ import java .util .List ;
27+ import java .util .Set ;
28+
29+ import org .springframework .ai .bedrock .BedrockConverseChatGenerationMetadata ;
2230import org .springframework .ai .bedrock .api .BedrockConverseApi ;
31+ import org .springframework .ai .bedrock .api .BedrockConverseApi .BedrockConverseRequest ;
2332import org .springframework .ai .bedrock .api .BedrockConverseApiUtils ;
33+ import org .springframework .ai .chat .messages .Message ;
2434import org .springframework .ai .chat .model .ChatModel ;
2535import org .springframework .ai .chat .model .ChatResponse ;
36+ import org .springframework .ai .chat .model .Generation ;
2637import org .springframework .ai .chat .model .StreamingChatModel ;
2738import org .springframework .ai .chat .prompt .ChatOptions ;
2839import org .springframework .ai .chat .prompt .Prompt ;
2940import org .springframework .ai .model .ModelDescription ;
41+ import org .springframework .ai .model .ModelOptionsUtils ;
42+ import org .springframework .ai .model .function .AbstractFunctionCallSupport ;
43+ import org .springframework .ai .model .function .FunctionCallbackContext ;
3044import org .springframework .util .Assert ;
45+ import org .springframework .util .CollectionUtils ;
3146
3247/**
3348 * Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock Anthropic3 chat
3853 * @author Wei Jiang
3954 * @since 1.0.0
4055 */
41- public class BedrockAnthropic3ChatModel implements ChatModel , StreamingChatModel {
56+ public class BedrockAnthropic3ChatModel
57+ extends AbstractFunctionCallSupport <Message , BedrockConverseRequest , ChatResponse >
58+ implements ChatModel , StreamingChatModel {
4259
4360 private final String modelId ;
4461
@@ -56,6 +73,13 @@ public BedrockAnthropic3ChatModel(BedrockConverseApi converseApi, Anthropic3Chat
5673 }
5774
5875 public BedrockAnthropic3ChatModel (String modelId , BedrockConverseApi converseApi , Anthropic3ChatOptions options ) {
76+ this (modelId , converseApi , options , null );
77+ }
78+
79+ public BedrockAnthropic3ChatModel (String modelId , BedrockConverseApi converseApi , Anthropic3ChatOptions options ,
80+ FunctionCallbackContext functionCallbackContext ) {
81+ super (functionCallbackContext );
82+
5983 Assert .notNull (modelId , "modelId must not be null." );
6084 Assert .notNull (converseApi , "BedrockConverseApi must not be null." );
6185 Assert .notNull (options , "Anthropic3ChatOptions must not be null." );
@@ -69,29 +93,125 @@ public BedrockAnthropic3ChatModel(String modelId, BedrockConverseApi converseApi
6993 public ChatResponse call (Prompt prompt ) {
7094 Assert .notNull (prompt , "Prompt must not be null." );
7195
72- var request = BedrockConverseApiUtils .createConverseRequest (modelId , prompt , defaultOptions );
73-
74- ConverseResponse response = this .converseApi .converse (request );
96+ var request = createBedrockConverseRequest (prompt );
7597
76- return BedrockConverseApiUtils . convertConverseResponse ( response );
98+ return this . callWithFunctionSupport ( request );
7799 }
78100
79101 @ Override
80102 public Flux <ChatResponse > stream (Prompt prompt ) {
81103 Assert .notNull (prompt , "Prompt must not be null." );
82104
105+ // TODO
83106 var request = BedrockConverseApiUtils .createConverseStreamRequest (modelId , prompt , defaultOptions );
84107
85108 Flux <ConverseStreamOutput > fluxResponse = this .converseApi .converseStream (request );
86109
87110 return fluxResponse .map (output -> BedrockConverseApiUtils .convertConverseStreamOutput (output ));
88111 }
89112
113+ private BedrockConverseRequest createBedrockConverseRequest (Prompt prompt ) {
114+ var request = BedrockConverseApiUtils .createBedrockConverseRequest (modelId , prompt , defaultOptions );
115+
116+ ToolConfiguration toolConfiguration = createToolConfiguration (prompt );
117+ request .setToolConfiguration (toolConfiguration );
118+
119+ return request ;
120+ }
121+
122+ private ToolConfiguration createToolConfiguration (Prompt prompt ) {
123+ Set <String > functionsForThisRequest = new HashSet <>();
124+
125+ if (this .defaultOptions != null ) {
126+ Set <String > promptEnabledFunctions = this .handleFunctionCallbackConfigurations (this .defaultOptions ,
127+ !IS_RUNTIME_CALL );
128+ functionsForThisRequest .addAll (promptEnabledFunctions );
129+ }
130+
131+ if (prompt .getOptions () != null ) {
132+ if (prompt .getOptions () instanceof ChatOptions runtimeOptions ) {
133+ Anthropic3ChatOptions updatedRuntimeOptions = ModelOptionsUtils .copyToTarget (runtimeOptions ,
134+ ChatOptions .class , Anthropic3ChatOptions .class );
135+
136+ Set <String > defaultEnabledFunctions = this .handleFunctionCallbackConfigurations (updatedRuntimeOptions ,
137+ IS_RUNTIME_CALL );
138+ functionsForThisRequest .addAll (defaultEnabledFunctions );
139+ }
140+ else {
141+ throw new IllegalArgumentException ("Prompt options are not of type ChatOptions: "
142+ + prompt .getOptions ().getClass ().getSimpleName ());
143+ }
144+ }
145+
146+ if (CollectionUtils .isEmpty (functionsForThisRequest )) {
147+ return null ;
148+ }
149+ else {
150+ return ToolConfiguration .builder ().tools (getFunctionTools (functionsForThisRequest )).build ();
151+ }
152+ }
153+
154+ private List <Tool > getFunctionTools (Set <String > functionNames ) {
155+ return this .resolveFunctionCallbacks (functionNames ).stream ().map (functionCallback -> {
156+ var description = functionCallback .getDescription ();
157+ var name = functionCallback .getName ();
158+ String inputSchema = functionCallback .getInputTypeSchema ();
159+
160+ return Tool .builder ()
161+ .toolSpec (ToolSpecification .builder ()
162+ .name (name )
163+ .description (description )
164+ .inputSchema (ToolInputSchema .builder ()
165+ .json (BedrockConverseApiUtils .convertObjectToDocument (ModelOptionsUtils .jsonToMap (inputSchema )))
166+ .build ())
167+ .build ())
168+ .build ();
169+ }).toList ();
170+ }
171+
90172 @ Override
91173 public ChatOptions getDefaultOptions () {
92174 return Anthropic3ChatOptions .fromOptions (this .defaultOptions );
93175 }
94176
177+ @ Override
178+ protected BedrockConverseRequest doCreateToolResponseRequest (BedrockConverseRequest previousRequest ,
179+ Message responseMessage , List <Message > conversationHistory ) {
180+ // TODO
181+ return null ;
182+ }
183+
184+ @ Override
185+ protected List <Message > doGetUserMessages (BedrockConverseRequest request ) {
186+ return BedrockConverseApiUtils .getMessagesInstructions (request .getMessages ());
187+ }
188+
189+ @ Override
190+ protected Message doGetToolResponseMessage (ChatResponse response ) {
191+ return response .getResult ().getOutput ();
192+ }
193+
194+ @ Override
195+ protected ChatResponse doChatCompletion (BedrockConverseRequest request ) {
196+ return converseApi .converse (request );
197+ }
198+
199+ @ Override
200+ protected Flux <ChatResponse > doChatCompletionStream (BedrockConverseRequest request ) {
201+ return converseApi .converseStream (request );
202+ }
203+
204+ @ Override
205+ protected boolean isToolFunctionCall (ChatResponse response ) {
206+ Generation result = response .getResult ();
207+
208+ if (result .getMetadata () instanceof BedrockConverseChatGenerationMetadata metadata ) {
209+ return metadata .isToolUse ();
210+ }
211+
212+ return false ;
213+ }
214+
95215 /**
96216 * Anthropic3 models version.
97217 */
0 commit comments