Skip to content

Commit 470d00e

Browse files
author
wmz7year
committed
Amazon Bedrock converse API module supports custom BedrockRuntimeClient and BedrockRuntimeAsyncClient.
1 parent 678389f commit 470d00e

File tree

10 files changed

+190
-79
lines changed

10 files changed

+190
-79
lines changed

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ public abstract class AbstractBedrockApi<I, O, SO> {
7070

7171
private final String modelId;
7272
private final ObjectMapper objectMapper;
73-
private final Region region;
7473
private final BedrockRuntimeClient client;
7574
private final BedrockRuntimeAsyncClient clientStreaming;
7675

@@ -136,29 +135,36 @@ public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProv
136135
*/
137136
public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
138137
ObjectMapper objectMapper, Duration timeout) {
138+
this(modelId, BedrockRuntimeClient.builder()
139+
.region(region)
140+
.credentialsProvider(credentialsProvider)
141+
.overrideConfiguration(c -> c.apiCallTimeout(timeout))
142+
.build(), BedrockRuntimeAsyncClient.builder()
143+
.region(region)
144+
.credentialsProvider(credentialsProvider)
145+
.overrideConfiguration(c -> c.apiCallTimeout(timeout))
146+
.build(), objectMapper);
147+
}
139148

149+
/**
150+
* Create a new AbstractBedrockApi instance using the provided AWS Bedrock clients, region and object mapper.
151+
*
152+
* @param modelId The model id to use.
153+
* @param bedrockRuntimeClient The AWS BedrockRuntimeClient instance.
154+
* @param bedrockRuntimeAsyncClient The AWS BedrockRuntimeAsyncClient instance.
155+
* @param objectMapper The object mapper to use for JSON serialization and deserialization.
156+
*/
157+
public AbstractBedrockApi(String modelId, BedrockRuntimeClient bedrockRuntimeClient, BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient,
158+
ObjectMapper objectMapper) {
140159
Assert.hasText(modelId, "Model id must not be empty");
141-
Assert.notNull(credentialsProvider, "Credentials provider must not be null");
142-
Assert.notNull(region, "Region must not be empty");
160+
Assert.notNull(bedrockRuntimeClient, "bedrockRuntimeClient must not be null");
161+
Assert.notNull(bedrockRuntimeAsyncClient, "bedrockRuntimeAsyncClient must not be null");
143162
Assert.notNull(objectMapper, "Object mapper must not be null");
144-
Assert.notNull(timeout, "Timeout must not be null");
145163

146164
this.modelId = modelId;
165+
this.client = bedrockRuntimeClient;
166+
this.clientStreaming = bedrockRuntimeAsyncClient;
147167
this.objectMapper = objectMapper;
148-
this.region = region;
149-
150-
151-
this.client = BedrockRuntimeClient.builder()
152-
.region(this.region)
153-
.credentialsProvider(credentialsProvider)
154-
.overrideConfiguration(c -> c.apiCallTimeout(timeout))
155-
.build();
156-
157-
this.clientStreaming = BedrockRuntimeAsyncClient.builder()
158-
.region(this.region)
159-
.credentialsProvider(credentialsProvider)
160-
.overrideConfiguration(c -> c.apiCallTimeout(timeout))
161-
.build();
162168
}
163169

164170
/**
@@ -168,13 +174,6 @@ public String getModelId() {
168174
return this.modelId;
169175
}
170176

171-
/**
172-
* @return The AWS region.
173-
*/
174-
public Region getRegion() {
175-
return this.region;
176-
}
177-
178177
/**
179178
* Encapsulates the metrics about the model invocation.
180179
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/BedrockConverseApi.java

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,9 @@ public class BedrockConverseApi {
6969

7070
private static final Logger logger = LoggerFactory.getLogger(BedrockConverseApi.class);
7171

72-
private final Region region;
72+
private final BedrockRuntimeClient bedrockRuntimeClient;
7373

74-
private final BedrockRuntimeClient client;
75-
76-
private final BedrockRuntimeAsyncClient clientStreaming;
74+
private final BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient;
7775

7876
private final RetryTemplate retryTemplate;
7977

@@ -155,32 +153,34 @@ public BedrockConverseApi(AwsCredentialsProvider credentialsProvider, Region reg
155153
*/
156154
public BedrockConverseApi(AwsCredentialsProvider credentialsProvider, Region region, Duration timeout,
157155
RetryTemplate retryTemplate) {
158-
Assert.notNull(credentialsProvider, "Credentials provider must not be null");
159-
Assert.notNull(region, "Region must not be empty");
160-
Assert.notNull(timeout, "Timeout must not be null");
161-
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
162-
163-
this.region = region;
164-
this.retryTemplate = retryTemplate;
165-
166-
this.client = BedrockRuntimeClient.builder()
167-
.region(this.region)
168-
.credentialsProvider(credentialsProvider)
169-
.overrideConfiguration(c -> c.apiCallTimeout(timeout))
170-
.build();
171-
172-
this.clientStreaming = BedrockRuntimeAsyncClient.builder()
173-
.region(this.region)
174-
.credentialsProvider(credentialsProvider)
175-
.overrideConfiguration(c -> c.apiCallTimeout(timeout))
176-
.build();
156+
this(BedrockRuntimeClient.builder()
157+
.region(region)
158+
.credentialsProvider(credentialsProvider)
159+
.overrideConfiguration(c -> c.apiCallTimeout(timeout))
160+
.build(), BedrockRuntimeAsyncClient.builder()
161+
.region(region)
162+
.credentialsProvider(credentialsProvider)
163+
.overrideConfiguration(c -> c.apiCallTimeout(timeout))
164+
.build(), retryTemplate);
177165
}
178166

179167
/**
180-
* @return The AWS region.
168+
* Create a new BedrockConverseApi instance using the provided AWS Bedrock clients and the RetryTemplate.
169+
*
170+
* @param bedrockRuntimeClient The AWS BedrockRuntimeClient instance.
171+
* @param bedrockRuntimeAsyncClient The AWS BedrockRuntimeAsyncClient instance.
172+
* @param retryTemplate The retry template used to retry the Amazon Bedrock Converse
173+
* API calls.
181174
*/
182-
public Region getRegion() {
183-
return this.region;
175+
public BedrockConverseApi(BedrockRuntimeClient bedrockRuntimeClient, BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient,
176+
RetryTemplate retryTemplate) {
177+
Assert.notNull(bedrockRuntimeClient, "bedrockRuntimeClient must not be null");
178+
Assert.notNull(bedrockRuntimeAsyncClient, "bedrockRuntimeAsyncClient must not be null");
179+
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
180+
181+
this.bedrockRuntimeClient = bedrockRuntimeClient;
182+
this.bedrockRuntimeAsyncClient = bedrockRuntimeAsyncClient;
183+
this.retryTemplate = retryTemplate;
184184
}
185185

186186
/**
@@ -215,7 +215,7 @@ public ConverseResponse converse(ConverseRequest converseRequest) {
215215
Assert.notNull(converseRequest, "'converseRequest' must not be null");
216216

217217
return this.retryTemplate.execute(ctx -> {
218-
return client.converse(converseRequest);
218+
return bedrockRuntimeClient.converse(converseRequest);
219219
});
220220
}
221221

@@ -280,7 +280,7 @@ public Flux<ConverseStreamOutput> converseStream(ConverseStreamRequest converseS
280280
})
281281
.build();
282282

283-
clientStreaming.converseStream(converseStreamRequest, responseHandler);
283+
bedrockRuntimeAsyncClient.converseStream(converseStreamRequest, responseHandler);
284284

285285
return eventSink.asFlux();
286286
});

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import com.fasterxml.jackson.databind.ObjectMapper;
2626
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
2727
import software.amazon.awssdk.regions.Region;
28+
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
29+
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
2830

2931
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
3032
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingRequest;
@@ -109,6 +111,19 @@ public CohereEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credenti
109111
super(modelId, credentialsProvider, region, objectMapper, timeout);
110112
}
111113

114+
/**
115+
* Create a new CohereEmbeddingBedrockApi instance using the provided AWS Bedrock clients, region and object mapper.
116+
*
117+
* @param modelId The model id to use.
118+
* @param bedrockRuntimeClient The AWS BedrockRuntimeClient instance.
119+
* @param bedrockRuntimeAsyncClient The AWS BedrockRuntimeAsyncClient instance.
120+
* @param objectMapper The object mapper to use for JSON serialization and deserialization.
121+
*/
122+
public CohereEmbeddingBedrockApi(String model, BedrockRuntimeClient bedrockRuntimeClient,
123+
BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ObjectMapper objectMapper) {
124+
super(model, bedrockRuntimeClient, bedrockRuntimeAsyncClient, objectMapper);
125+
}
126+
112127
/**
113128
* The Cohere Embed model request.
114129
*

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import com.fasterxml.jackson.databind.ObjectMapper;
2525
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
2626
import software.amazon.awssdk.regions.Region;
27+
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
28+
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
2729

2830
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
2931
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingRequest;
@@ -81,6 +83,20 @@ public TitanEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credentia
8183
super(modelId, credentialsProvider, region, objectMapper, timeout);
8284
}
8385

86+
/**
87+
* Create a new TitanEmbeddingBedrockApi instance.
88+
*
89+
* @param modelId The model id to use.
90+
* @param bedrockRuntimeClient The AWS BedrockRuntimeClient instance.
91+
* @param bedrockRuntimeAsyncClient The AWS BedrockRuntimeAsyncClient instance.
92+
* @param objectMapper The object mapper to use for JSON serialization and deserialization.
93+
*/
94+
public TitanEmbeddingBedrockApi(String model, BedrockRuntimeClient bedrockRuntimeClient,
95+
BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ObjectMapper objectMapper) {
96+
super(model, bedrockRuntimeClient, bedrockRuntimeAsyncClient, objectMapper);
97+
}
98+
99+
84100
/**
85101
* Titan Embedding request parameters.
86102
*

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
import software.amazon.awssdk.regions.Region;
2323
import software.amazon.awssdk.regions.providers.AwsRegionProvider;
2424
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
25+
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
26+
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
2527

28+
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
2629
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
2730
import org.springframework.boot.context.properties.EnableConfigurationProperties;
2831
import org.springframework.context.annotation.Bean;
@@ -60,6 +63,32 @@ public AwsRegionProvider regionProvider(BedrockAwsConnectionProperties propertie
6063
return DefaultAwsRegionProviderChain.builder().build();
6164
}
6265

66+
@Bean
67+
@ConditionalOnMissingBean
68+
@ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class })
69+
public BedrockRuntimeClient bedrockRuntimeClient(AwsCredentialsProvider credentialsProvider,
70+
AwsRegionProvider regionProvider, BedrockAwsConnectionProperties properties) {
71+
72+
return BedrockRuntimeClient.builder()
73+
.region(regionProvider.getRegion())
74+
.credentialsProvider(credentialsProvider)
75+
.overrideConfiguration(c -> c.apiCallTimeout(properties.getTimeout()))
76+
.build();
77+
}
78+
79+
@Bean
80+
@ConditionalOnMissingBean
81+
@ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class })
82+
public BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient(AwsCredentialsProvider credentialsProvider,
83+
AwsRegionProvider regionProvider, BedrockAwsConnectionProperties properties) {
84+
85+
return BedrockRuntimeAsyncClient.builder()
86+
.region(regionProvider.getRegion())
87+
.credentialsProvider(credentialsProvider)
88+
.overrideConfiguration(c -> c.apiCallTimeout(properties.getTimeout()))
89+
.build();
90+
}
91+
6392
/**
6493
* @author Wei Jiang
6594
*/

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/api/BedrockConverseApiAutoConfiguration.java

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@
2828
import org.springframework.context.annotation.Import;
2929
import org.springframework.retry.support.RetryTemplate;
3030

31-
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
32-
import software.amazon.awssdk.regions.providers.AwsRegionProvider;
3331
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
3432
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
3533

@@ -47,12 +45,10 @@ public class BedrockConverseApiAutoConfiguration {
4745

4846
@Bean
4947
@ConditionalOnMissingBean
50-
@ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class })
51-
public BedrockConverseApi bedrockConverseApi(AwsCredentialsProvider credentialsProvider,
52-
AwsRegionProvider regionProvider, BedrockAwsConnectionProperties awsProperties,
53-
RetryTemplate retryTemplate) {
54-
return new BedrockConverseApi(credentialsProvider, regionProvider.getRegion(), awsProperties.getTimeout(),
55-
retryTemplate);
48+
@ConditionalOnBean({ BedrockRuntimeClient.class, BedrockRuntimeAsyncClient.class })
49+
public BedrockConverseApi bedrockConverseApi(BedrockRuntimeClient bedrockRuntimeClient,
50+
BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, RetryTemplate retryTemplate) {
51+
return new BedrockConverseApi(bedrockRuntimeClient, bedrockRuntimeAsyncClient, retryTemplate);
5652
}
5753

5854
}

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
package org.springframework.ai.autoconfigure.bedrock.cohere;
1717

1818
import com.fasterxml.jackson.databind.ObjectMapper;
19-
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
20-
import software.amazon.awssdk.regions.providers.AwsRegionProvider;
19+
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
20+
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
2121

2222
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration;
2323
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties;
@@ -48,12 +48,11 @@ public class BedrockCohereEmbeddingAutoConfiguration {
4848

4949
@Bean
5050
@ConditionalOnMissingBean
51-
@ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class })
52-
public CohereEmbeddingBedrockApi cohereEmbeddingApi(AwsCredentialsProvider credentialsProvider,
53-
AwsRegionProvider regionProvider, BedrockCohereEmbeddingProperties properties,
54-
BedrockAwsConnectionProperties awsProperties) {
55-
return new CohereEmbeddingBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(),
56-
new ObjectMapper(), awsProperties.getTimeout());
51+
@ConditionalOnBean({ BedrockRuntimeClient.class, BedrockRuntimeAsyncClient.class })
52+
public CohereEmbeddingBedrockApi cohereEmbeddingApi(BedrockCohereEmbeddingProperties properties,
53+
BedrockRuntimeClient bedrockRuntimeClient, BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient) {
54+
return new CohereEmbeddingBedrockApi(properties.getModel(), bedrockRuntimeClient, bedrockRuntimeAsyncClient,
55+
new ObjectMapper());
5756
}
5857

5958
@Bean

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
package org.springframework.ai.autoconfigure.bedrock.titan;
1717

1818
import com.fasterxml.jackson.databind.ObjectMapper;
19-
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
20-
import software.amazon.awssdk.regions.providers.AwsRegionProvider;
19+
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
20+
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
2121

2222
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration;
2323
import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties;
@@ -48,12 +48,11 @@ public class BedrockTitanEmbeddingAutoConfiguration {
4848

4949
@Bean
5050
@ConditionalOnMissingBean
51-
@ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class })
52-
public TitanEmbeddingBedrockApi titanEmbeddingBedrockApi(AwsCredentialsProvider credentialsProvider,
53-
AwsRegionProvider regionProvider, BedrockTitanEmbeddingProperties properties,
54-
BedrockAwsConnectionProperties awsProperties) {
55-
return new TitanEmbeddingBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(),
56-
new ObjectMapper(), awsProperties.getTimeout());
51+
@ConditionalOnBean({ BedrockRuntimeClient.class, BedrockRuntimeAsyncClient.class })
52+
public TitanEmbeddingBedrockApi titanEmbeddingBedrockApi(BedrockTitanEmbeddingProperties properties,
53+
BedrockRuntimeClient bedrockRuntimeClient, BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient) {
54+
return new TitanEmbeddingBedrockApi(properties.getModel(), bedrockRuntimeClient, bedrockRuntimeAsyncClient,
55+
new ObjectMapper());
5756
}
5857

5958
@Bean

0 commit comments

Comments
 (0)