From f11fb87bd1777c4cb8a222c9f10a70503e05b3c5 Mon Sep 17 00:00:00 2001 From: YuJie Wan <31400063+eeaters@users.noreply.github.com> Date: Thu, 4 Sep 2025 15:02:45 +0800 Subject: [PATCH 1/2] optimize ZhiPu Embedding to support batch embedding. - support batch embedding - Make test adjustments based on the official demo Signed-off-by: YuJie Wan <31400063+eeaters@users.noreply.github.com> --- .../ai/zhipuai/ZhiPuAiEmbeddingModel.java | 69 ++++++++----------- .../ai/zhipuai/api/ZhiPuAiRetryTests.java | 7 +- .../ai/zhipuai/embedding/EmbeddingIT.java | 8 ++- 3 files changed, 38 insertions(+), 46 deletions(-) diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java index f8c3f620529..6a22ee4a475 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java @@ -16,15 +16,15 @@ package org.springframework.ai.zhipuai; -import java.util.ArrayList; import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.chat.metadata.DefaultUsage; +import org.springframework.ai.chat.metadata.EmptyUsage; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -43,6 +43,7 @@ import org.springframework.ai.zhipuai.api.ZhiPuApiConstants; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; /** @@ -50,6 +51,7 @@ * * @author Geng Rong * @author Soby 
Chacko + * @author YuJie Wan * @since 1.0.0 */ public class ZhiPuAiEmbeddingModel extends AbstractEmbeddingModel { @@ -150,12 +152,9 @@ public float[] embed(Document document) { @Override public EmbeddingResponse call(EmbeddingRequest request) { Assert.notEmpty(request.getInstructions(), "At least one text is required!"); - if (request.getInstructions().size() != 1) { - logger.warn( - "ZhiPu Embedding does not support batch embedding. Will make multiple API calls to embed(Document)"); - } EmbeddingRequest embeddingRequest = buildEmbeddingRequest(request); + var zhipuEmbeddingRequest = zhipuEmbeddingRequest(embeddingRequest); var observationContext = EmbeddingModelObservationContext.builder() .embeddingRequest(embeddingRequest) @@ -166,47 +165,37 @@ public EmbeddingResponse call(EmbeddingRequest request) { .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) .observe(() -> { - List<float[]> embeddingList = new ArrayList<>(); - - var totalUsage = new ZhiPuAiApi.Usage(0, 0, 0); - - for (String inputContent : request.getInstructions()) { - var apiRequest = createEmbeddingRequest(inputContent, embeddingRequest.getOptions()); - - ZhiPuAiApi.EmbeddingList<ZhiPuAiApi.Embedding> response = this.retryTemplate - .execute(ctx -> this.zhiPuAiApi.embeddings(apiRequest).getBody()); - if (response == null || response.data() == null || response.data().isEmpty()) { - logger.warn("No embeddings returned for input: {}", inputContent); - embeddingList.add(new float[0]); - } - else { - int completionTokens = totalUsage.completionTokens() + response.usage().completionTokens(); - int promptTokens = totalUsage.promptTokens() + response.usage().promptTokens(); - int totalTokens = totalUsage.totalTokens() + response.usage().totalTokens(); - totalUsage = new ZhiPuAiApi.Usage(completionTokens, promptTokens, totalTokens); - embeddingList.add(response.data().get(0).embedding()); - } - } + var embeddingResponse = this.retryTemplate + .execute(ctx -> 
this.zhiPuAiApi.embeddings(zhipuEmbeddingRequest)); - String model = (request.getOptions() != null && request.getOptions().getModel() != null) - ? request.getOptions().getModel() : "unknown"; + if (embeddingResponse == null || embeddingResponse.getBody() == null + || CollectionUtils.isEmpty(embeddingResponse.getBody().data())) { + logger.warn("No embeddings returned for request: {}", request); + return new EmbeddingResponse(List.of()); + } - var metadata = new EmbeddingResponseMetadata(model, getDefaultUsage(totalUsage)); + ZhiPuAiApi.Usage usage = embeddingResponse.getBody().usage(); + Usage usageResponse = usage != null ? getDefaultUsage(usage) : new EmptyUsage(); - var indexCounter = new AtomicInteger(0); + var metadata = new EmbeddingResponseMetadata(embeddingResponse.getBody().model(), usageResponse); - List<Embedding> embeddings = embeddingList.stream() - .map(e -> new Embedding(e, indexCounter.getAndIncrement())) + List<Embedding> embeddings = embeddingResponse.getBody() + .data() + .stream() + .map(e -> new Embedding(e.embedding(), e.index())) .toList(); - EmbeddingResponse embeddingResponse = new EmbeddingResponse(embeddings, metadata); - - observationContext.setResponse(embeddingResponse); - - return embeddingResponse; + EmbeddingResponse response = new EmbeddingResponse(embeddings, metadata); + observationContext.setResponse(response); + return response; }); } + private ZhiPuAiApi.EmbeddingRequest<List<String>> zhipuEmbeddingRequest(EmbeddingRequest embeddingRequest) { + return new ZhiPuAiApi.EmbeddingRequest<>(embeddingRequest.getInstructions(), + embeddingRequest.getOptions().getModel(), embeddingRequest.getOptions().getDimensions()); + } + private DefaultUsage getDefaultUsage(ZhiPuAiApi.Usage usage) { return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage); } @@ -231,10 +220,6 @@ EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) { return new EmbeddingRequest(embeddingRequest.getInstructions(), requestOptions); } - 
private ZhiPuAiApi.EmbeddingRequest<String> createEmbeddingRequest(String text, EmbeddingOptions requestOptions) { - return new ZhiPuAiApi.EmbeddingRequest<>(text, requestOptions.getModel(), requestOptions.getDimensions()); - } - public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) { this.observationConvention = observationConvention; } diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java index b78db162096..b1b37bd37ba 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java @@ -159,7 +159,8 @@ public void zhiPuAiChatStreamNonTransientError() { public void zhiPuAiEmbeddingTransientError() { EmbeddingList<Embedding> expectedEmbeddings = new EmbeddingList<>("list", - List.of(new Embedding(0, new float[] { 9.9f, 8.8f })), "model", new ZhiPuAiApi.Usage(10, 10, 10)); + List.of(new Embedding(0, new float[] { 9.9f, 8.8f }), new Embedding(0, new float[] { 9.9f, 8.8f })), + "model", new ZhiPuAiApi.Usage(10, 10, 10)); given(this.zhiPuAiApi.embeddings(isA(EmbeddingRequest.class))) .willThrow(new TransientAiException("Transient Error 1")) @@ -169,9 +170,11 @@ public void zhiPuAiEmbeddingTransientError() { var result = this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), options)); + assertThat(result.getResults().size()).isEqualTo(2); assertThat(result).isNotNull(); + // choose the first result assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); - assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(0); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); } diff --git 
a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/EmbeddingIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/EmbeddingIT.java index 15d2474bf6b..ec1c69a8029 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/EmbeddingIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/EmbeddingIT.java @@ -85,9 +85,10 @@ void embeddingV3WithCustomDimension() { void batchEmbedding() { assertThat(this.embeddingModel).isNotNull(); - EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World", "HI")); + EmbeddingResponse embeddingResponse = this.embeddingModel + .embedForResponse(List.of("Hello world", "How are you?", "How is the weather today?")); - assertThat(embeddingResponse.getResults()).hasSize(2); + assertThat(embeddingResponse.getResults()).hasSize(3); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024); @@ -95,6 +96,9 @@ void batchEmbedding() { assertThat(embeddingResponse.getResults().get(1)).isNotNull(); assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(1024); + assertThat(embeddingResponse.getResults().get(2)).isNotNull(); + assertThat(embeddingResponse.getResults().get(2).getOutput()).hasSize(1024); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); } From 9dcbe594f0d4e4ac08b6039c7e76ba36c36bfbcc Mon Sep 17 00:00:00 2001 From: YuJie Wan <31400063+eeaters@users.noreply.github.com> Date: Thu, 11 Sep 2025 17:51:24 +0800 Subject: [PATCH 2/2] refactor: Adjust the order of the code to improve its readability. 
Signed-off-by: YuJie Wan <31400063+eeaters@users.noreply.github.com> --- .../org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java index 6a22ee4a475..c1ec94262e1 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java @@ -153,14 +153,14 @@ public float[] embed(Document document) { public EmbeddingResponse call(EmbeddingRequest request) { Assert.notEmpty(request.getInstructions(), "At least one text is required!"); - EmbeddingRequest embeddingRequest = buildEmbeddingRequest(request); - var zhipuEmbeddingRequest = zhipuEmbeddingRequest(embeddingRequest); + var embeddingRequest = buildEmbeddingRequest(request); var observationContext = EmbeddingModelObservationContext.builder() .embeddingRequest(embeddingRequest) .provider(ZhiPuApiConstants.PROVIDER_NAME) .build(); + var zhipuEmbeddingRequest = zhipuEmbeddingRequest(embeddingRequest); return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry)