Skip to content

Commit c4e434a

Browse files
garethjevansilayaperumalg
authored andcommitted
fix: bedrock titan embeddings should return usage
Auto-cherry-pick to 1.0.x Signed-off-by: Gareth Evans <[email protected]>
1 parent 95b3df7 commit c4e434a

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,14 @@
2828
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
2929
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingRequest;
3030
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingResponse;
31+
import org.springframework.ai.chat.metadata.DefaultUsage;
3132
import org.springframework.ai.document.Document;
3233
import org.springframework.ai.embedding.AbstractEmbeddingModel;
3334
import org.springframework.ai.embedding.Embedding;
3435
import org.springframework.ai.embedding.EmbeddingOptions;
3536
import org.springframework.ai.embedding.EmbeddingRequest;
3637
import org.springframework.ai.embedding.EmbeddingResponse;
38+
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
3739
import org.springframework.util.Assert;
3840

3941
/**
@@ -89,6 +91,7 @@ public EmbeddingResponse call(EmbeddingRequest request) {
8991

9092
List<Embedding> embeddings = new ArrayList<>();
9193
var indexCounter = new AtomicInteger(0);
94+
int tokenUsage = 0;
9295

9396
for (String inputContent : request.getInstructions()) {
9497
var apiRequest = createTitanEmbeddingRequest(inputContent, request.getOptions());
@@ -111,6 +114,10 @@ public EmbeddingResponse call(EmbeddingRequest request) {
111114
}
112115

113116
embeddings.add(new Embedding(response.embedding(), indexCounter.getAndIncrement()));
117+
118+
if (response.inputTextTokenCount() != null) {
119+
tokenUsage += response.inputTextTokenCount();
120+
}
114121
}
115122
catch (Exception ex) {
116123
logger.error("Titan API embedding failed for input at index {}: {}", indexCounter.get(),
@@ -120,7 +127,10 @@ public EmbeddingResponse call(EmbeddingRequest request) {
120127
}
121128
}
122129

123-
return new EmbeddingResponse(embeddings);
130+
EmbeddingResponseMetadata embeddingResponseMetadata = new EmbeddingResponseMetadata("",
131+
getDefaultUsage(tokenUsage));
132+
133+
return new EmbeddingResponse(embeddings, embeddingResponseMetadata);
124134
}
125135

126136
private TitanEmbeddingRequest createTitanEmbeddingRequest(String inputContent, EmbeddingOptions requestOptions) {
@@ -155,6 +165,10 @@ private String summarizeInput(String input) {
155165
return input.length() > 100 ? input.substring(0, 100) + "..." : input;
156166
}
157167

168+
private DefaultUsage getDefaultUsage(int tokens) {
169+
return new DefaultUsage(tokens, 0);
170+
}
171+
158172
public enum InputType {
159173

160174
TEXT, IMAGE

0 commit comments

Comments
 (0)