Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@

package org.springframework.ai.zhipuai;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

import io.micrometer.observation.ObservationRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
Expand All @@ -43,13 +43,15 @@
import org.springframework.ai.zhipuai.api.ZhiPuApiConstants;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

/**
* ZhiPuAI Embedding Model implementation.
*
* @author Geng Rong
* @author Soby Chacko
* @author YuJie Wan
* @since 1.0.0
*/
public class ZhiPuAiEmbeddingModel extends AbstractEmbeddingModel {
Expand Down Expand Up @@ -150,63 +152,50 @@ public float[] embed(Document document) {
@Override
public EmbeddingResponse call(EmbeddingRequest request) {
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
if (request.getInstructions().size() != 1) {
logger.warn(
"ZhiPu Embedding does not support batch embedding. Will make multiple API calls to embed(Document)");
}

EmbeddingRequest embeddingRequest = buildEmbeddingRequest(request);
var embeddingRequest = buildEmbeddingRequest(request);

var observationContext = EmbeddingModelObservationContext.builder()
.embeddingRequest(embeddingRequest)
.provider(ZhiPuApiConstants.PROVIDER_NAME)
.build();

var zhipuEmbeddingRequest = zhipuEmbeddingRequest(embeddingRequest);
return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
List<float[]> embeddingList = new ArrayList<>();

var totalUsage = new ZhiPuAiApi.Usage(0, 0, 0);

for (String inputContent : request.getInstructions()) {
var apiRequest = createEmbeddingRequest(inputContent, embeddingRequest.getOptions());

ZhiPuAiApi.EmbeddingList<ZhiPuAiApi.Embedding> response = this.retryTemplate
.execute(ctx -> this.zhiPuAiApi.embeddings(apiRequest).getBody());
if (response == null || response.data() == null || response.data().isEmpty()) {
logger.warn("No embeddings returned for input: {}", inputContent);
embeddingList.add(new float[0]);
}
else {
int completionTokens = totalUsage.completionTokens() + response.usage().completionTokens();
int promptTokens = totalUsage.promptTokens() + response.usage().promptTokens();
int totalTokens = totalUsage.totalTokens() + response.usage().totalTokens();
totalUsage = new ZhiPuAiApi.Usage(completionTokens, promptTokens, totalTokens);
embeddingList.add(response.data().get(0).embedding());
}
}
var embeddingResponse = this.retryTemplate
.execute(ctx -> this.zhiPuAiApi.embeddings(zhipuEmbeddingRequest));

String model = (request.getOptions() != null && request.getOptions().getModel() != null)
? request.getOptions().getModel() : "unknown";
if (embeddingResponse == null || embeddingResponse.getBody() == null
|| CollectionUtils.isEmpty(embeddingResponse.getBody().data())) {
logger.warn("No embeddings returned for request: {}", request);
return new EmbeddingResponse(List.of());
}

var metadata = new EmbeddingResponseMetadata(model, getDefaultUsage(totalUsage));
ZhiPuAiApi.Usage usage = embeddingResponse.getBody().usage();
Usage usageResponse = usage != null ? getDefaultUsage(usage) : new EmptyUsage();

var indexCounter = new AtomicInteger(0);
var metadata = new EmbeddingResponseMetadata(embeddingResponse.getBody().model(), usageResponse);

List<Embedding> embeddings = embeddingList.stream()
.map(e -> new Embedding(e, indexCounter.getAndIncrement()))
List<Embedding> embeddings = embeddingResponse.getBody()
.data()
.stream()
.map(e -> new Embedding(e.embedding(), e.index()))
.toList();

EmbeddingResponse embeddingResponse = new EmbeddingResponse(embeddings, metadata);

observationContext.setResponse(embeddingResponse);

return embeddingResponse;
EmbeddingResponse response = new EmbeddingResponse(embeddings, metadata);
observationContext.setResponse(response);
return response;
});
}

/**
 * Builds the ZhiPuAI API request that carries all instructions of the given request
 * as one list-valued input.
 * @param embeddingRequest the request supplying the instructions and the options from
 * which the model and dimensions are read
 * @return the API-level embedding request
 */
private ZhiPuAiApi.EmbeddingRequest<List<String>> zhipuEmbeddingRequest(EmbeddingRequest embeddingRequest) {
    var inputs = embeddingRequest.getInstructions();
    var options = embeddingRequest.getOptions();
    return new ZhiPuAiApi.EmbeddingRequest<>(inputs, options.getModel(), options.getDimensions());
}

/**
 * Converts a ZhiPuAI usage object into a {@link DefaultUsage}.
 * @param usage the usage reported by the ZhiPuAI API; must not be null
 * @return a usage carrying the prompt, completion and total token counts, with the
 * given usage object passed through as the fourth constructor argument
 */
private DefaultUsage getDefaultUsage(ZhiPuAiApi.Usage usage) {
    var promptTokens = usage.promptTokens();
    var completionTokens = usage.completionTokens();
    var totalTokens = usage.totalTokens();
    return new DefaultUsage(promptTokens, completionTokens, totalTokens, usage);
}
Expand All @@ -231,10 +220,6 @@ EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) {
return new EmbeddingRequest(embeddingRequest.getInstructions(), requestOptions);
}

/**
 * Builds a ZhiPuAI API request for a single text.
 * @param text the text to embed
 * @param requestOptions the options from which the model and dimensions are read
 * @return the API-level embedding request for that one text
 */
private ZhiPuAiApi.EmbeddingRequest<String> createEmbeddingRequest(String text, EmbeddingOptions requestOptions) {
    var model = requestOptions.getModel();
    var dimensions = requestOptions.getDimensions();
    return new ZhiPuAiApi.EmbeddingRequest<>(text, model, dimensions);
}

/**
 * Replaces the observation convention stored on this model.
 * @param observationConvention the convention to store; it is assigned as given,
 * without validation
 */
public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) {
    this.observationConvention = observationConvention;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ public void zhiPuAiChatStreamNonTransientError() {
public void zhiPuAiEmbeddingTransientError() {

EmbeddingList<Embedding> expectedEmbeddings = new EmbeddingList<>("list",
List.of(new Embedding(0, new float[] { 9.9f, 8.8f })), "model", new ZhiPuAiApi.Usage(10, 10, 10));
List.of(new Embedding(0, new float[] { 9.9f, 8.8f }), new Embedding(0, new float[] { 9.9f, 8.8f })),
"model", new ZhiPuAiApi.Usage(10, 10, 10));

given(this.zhiPuAiApi.embeddings(isA(EmbeddingRequest.class)))
.willThrow(new TransientAiException("Transient Error 1"))
Expand All @@ -169,9 +170,11 @@ public void zhiPuAiEmbeddingTransientError() {
var result = this.embeddingModel
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), options));

assertThat(result.getResults().size()).isEqualTo(2);
assertThat(result).isNotNull();
// choose the first result
assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f });
assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(0);
assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2);
assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,20 @@ void embeddingV3WithCustomDimension() {
void batchEmbedding() {
    assertThat(this.embeddingModel).isNotNull();

    // Three inputs go through one embedForResponse call; one embedding is expected per input.
    EmbeddingResponse embeddingResponse = this.embeddingModel
        .embedForResponse(List.of("Hello world", "How are you?", "How is the weather today?"));

    assertThat(embeddingResponse.getResults()).hasSize(3);

    // Every returned embedding must be present and have 1024 components.
    for (int i = 0; i < 3; i++) {
        assertThat(embeddingResponse.getResults().get(i)).isNotNull();
        assertThat(embeddingResponse.getResults().get(i).getOutput()).hasSize(1024);
    }

    assertThat(this.embeddingModel.dimensions()).isEqualTo(1024);
}

Expand Down