diff --git a/models/spring-ai-google-genai-embedding/src/main/java/org/springframework/ai/google/genai/text/GoogleGenAiTextEmbeddingModel.java b/models/spring-ai-google-genai-embedding/src/main/java/org/springframework/ai/google/genai/text/GoogleGenAiTextEmbeddingModel.java index 4b87b05066f..8327bf1cee0 100644 --- a/models/spring-ai-google-genai-embedding/src/main/java/org/springframework/ai/google/genai/text/GoogleGenAiTextEmbeddingModel.java +++ b/models/spring-ai-google-genai-embedding/src/main/java/org/springframework/ai/google/genai/text/GoogleGenAiTextEmbeddingModel.java @@ -19,8 +19,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; -import java.util.stream.Stream; +import java.util.Objects; import com.google.genai.Client; import com.google.genai.types.ContentEmbedding; @@ -43,6 +42,7 @@ import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation; import org.springframework.ai.google.genai.GoogleGenAiEmbeddingConnectionDetails; +import org.springframework.ai.model.EmbeddingModelDescription; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.retry.RetryUtils; @@ -64,10 +64,8 @@ public class GoogleGenAiTextEmbeddingModel extends AbstractEmbeddingModel { private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); - private static final Map KNOWN_EMBEDDING_DIMENSIONS = Stream - .of(GoogleGenAiTextEmbeddingModelName.values()) - .collect(Collectors.toMap(GoogleGenAiTextEmbeddingModelName::getName, - GoogleGenAiTextEmbeddingModelName::getDimensions)); + private static final Map KNOWN_EMBEDDING_DIMENSIONS = EmbeddingModelDescription + .calculateKnownEmbeddingDimensions(GoogleGenAiTextEmbeddingModelName.class); public final GoogleGenAiTextEmbeddingOptions defaultOptions; @@ -257,9 +255,14 @@ private DefaultUsage getDefaultUsage(Integer totalTokens) { return new DefaultUsage(0, 0, totalTokens); } + @Override + public Map knownEmbeddingDimensions() { + return KNOWN_EMBEDDING_DIMENSIONS; + } + @Override public int dimensions() { - return KNOWN_EMBEDDING_DIMENSIONS.computeIfAbsent(this.defaultOptions.getModel(), model -> super.dimensions()); + return dimensions(this, Objects.requireNonNull(this.defaultOptions.getModel())); } /** diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java index 8d25c4e4d36..0bb4f3fd2f2 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,8 +20,6 @@ import java.util.EnumMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; -import java.util.stream.Stream; import com.google.cloud.aiplatform.v1.EndpointName; import com.google.cloud.aiplatform.v1.PredictRequest; @@ -43,6 +41,7 @@ import org.springframework.ai.embedding.EmbeddingResponseMetadata; import org.springframework.ai.embedding.EmbeddingResultMetadata; import org.springframework.ai.embedding.EmbeddingResultMetadata.ModalityType; +import org.springframework.ai.model.EmbeddingModelDescription; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils; @@ -75,10 +74,8 @@ public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel private static final List SUPPORTED_IMAGE_MIME_SUB_TYPES = List.of(MimeTypeUtils.IMAGE_JPEG, MimeTypeUtils.IMAGE_GIF, MimeTypeUtils.IMAGE_PNG, MimeTypeUtils.parseMimeType("image/bmp")); - private static final Map KNOWN_EMBEDDING_DIMENSIONS = Stream - .of(VertexAiMultimodalEmbeddingModelName.values()) - .collect(Collectors.toMap(VertexAiMultimodalEmbeddingModelName::getName, - VertexAiMultimodalEmbeddingModelName::getDimensions)); + private static final Map KNOWN_EMBEDDING_DIMENSIONS = EmbeddingModelDescription + .calculateKnownEmbeddingDimensions(VertexAiMultimodalEmbeddingModelName.class); public final VertexAiMultimodalEmbeddingOptions defaultOptions; diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java index 4bef9d1145b..51770c73453 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java @@ -20,8 +20,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; -import java.util.stream.Stream; +import java.util.Objects; import com.google.cloud.aiplatform.v1.EndpointName; import com.google.cloud.aiplatform.v1.PredictRequest; @@ -43,6 +42,7 @@ import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation; +import org.springframework.ai.model.EmbeddingModelDescription; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.retry.RetryUtils; @@ -67,10 +67,8 @@ public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel { private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); - private static final Map KNOWN_EMBEDDING_DIMENSIONS = Stream - .of(VertexAiTextEmbeddingModelName.values()) - .collect(Collectors.toMap(VertexAiTextEmbeddingModelName::getName, - VertexAiTextEmbeddingModelName::getDimensions)); + private static final Map KNOWN_EMBEDDING_DIMENSIONS = EmbeddingModelDescription + .calculateKnownEmbeddingDimensions(VertexAiTextEmbeddingModelName.class); public final VertexAiTextEmbeddingOptions defaultOptions; @@ -242,9 +240,14 @@ private DefaultUsage getDefaultUsage(Integer totalTokens) { return new DefaultUsage(0, 0, totalTokens); } + @Override + public Map knownEmbeddingDimensions() { + return KNOWN_EMBEDDING_DIMENSIONS; + } + @Override public int dimensions() { - return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), super.dimensions()); + return dimensions(this, Objects.requireNonNull(this.defaultOptions.getModel())); } /** diff --git a/spring-ai-model/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java b/spring-ai-model/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java index 7cfcf32d382..90e7f678166 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java @@ -71,6 +71,24 @@ public static int dimensions(EmbeddingModel embeddingModel, String modelName, St } } + /** + * Return the dimension of the requested embedding generative name. Uses the embedding + * model to retrieve its default dimensions if the generative name is unknown. + * @param embeddingModel Embedding model client to determine its known embedding + * dimensions and default dimensions. + * @param modelName Embedding generative name to retrieve the dimensions for. + * @return Returns the embedding dimensions for the model name. + */ + public static int dimensions(AbstractEmbeddingModel embeddingModel, String modelName) { + var dimensions = embeddingModel.knownEmbeddingDimensions().get(modelName); + + if (dimensions == null) { + dimensions = embeddingModel.defaultDimensions(); + } + + return dimensions; + } + private static Map loadKnownModelDimensions() { try { var resource = EMBEDDING_MODEL_DIMENSIONS_PROPERTIES; @@ -82,21 +100,35 @@ private static Map loadKnownModelDimensions() { } return properties.entrySet() .stream() - .collect(Collectors.toMap(e -> e.getKey().toString(), e -> Integer.parseInt(e.getValue().toString()))); + .collect(Collectors.collectingAndThen( + Collectors.toMap(e -> e.getKey().toString(), e -> Integer.parseInt(e.getValue().toString())), + Map::copyOf)); } catch (IOException e) { throw new RuntimeException(e); } } - @Override - public int dimensions() { + private int defaultDimensions() { if (this.embeddingDimensions.get() < 0) { this.embeddingDimensions.set(dimensions(this, "Test", "Hello World")); } return this.embeddingDimensions.get(); } + /** + * Retrieve all the known embedding dimensions. + * @return The map containing the known embedding dimensions by model name + */ + public Map knownEmbeddingDimensions() { + return KNOWN_EMBEDDING_DIMENSIONS; + } + + @Override + public int dimensions() { + return defaultDimensions(); + } + static class Hints implements RuntimeHintsRegistrar { @Override diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/EmbeddingModelDescription.java b/spring-ai-model/src/main/java/org/springframework/ai/model/EmbeddingModelDescription.java index 73104e060aa..f67cfaff72e 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/EmbeddingModelDescription.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/EmbeddingModelDescription.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,15 @@ package org.springframework.ai.model; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + /** * Description of an embedding model. * * @author Christian Tzolov + * @author Nicolas Krier */ public interface EmbeddingModelDescription extends ModelDescription { @@ -27,4 +32,12 @@ default int getDimensions() { return -1; } + static & EmbeddingModelDescription> Map calculateKnownEmbeddingDimensions( + Class embeddingModelClass) { + return Stream.of(embeddingModelClass.getEnumConstants()) + .collect(Collectors.collectingAndThen( + Collectors.toMap(ModelDescription::getName, EmbeddingModelDescription::getDimensions), + Map::copyOf)); + } + }