Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.Objects;

import com.google.genai.Client;
import com.google.genai.types.ContentEmbedding;
Expand All @@ -43,6 +42,7 @@
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
import org.springframework.ai.google.genai.GoogleGenAiEmbeddingConnectionDetails;
import org.springframework.ai.model.EmbeddingModelDescription;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.retry.RetryUtils;
Expand All @@ -64,10 +64,8 @@ public class GoogleGenAiTextEmbeddingModel extends AbstractEmbeddingModel {

private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();

private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Stream
.of(GoogleGenAiTextEmbeddingModelName.values())
.collect(Collectors.toMap(GoogleGenAiTextEmbeddingModelName::getName,
GoogleGenAiTextEmbeddingModelName::getDimensions));
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = EmbeddingModelDescription
.calculateKnownEmbeddingDimensions(GoogleGenAiTextEmbeddingModelName.class);

public final GoogleGenAiTextEmbeddingOptions defaultOptions;

Expand Down Expand Up @@ -257,9 +255,14 @@ private DefaultUsage getDefaultUsage(Integer totalTokens) {
return new DefaultUsage(0, 0, totalTokens);
}

@Override
public Map<String, Integer> knownEmbeddingDimensions() {
return KNOWN_EMBEDDING_DIMENSIONS;
}

@Override
public int dimensions() {
return KNOWN_EMBEDDING_DIMENSIONS.computeIfAbsent(this.defaultOptions.getModel(), model -> super.dimensions());
return dimensions(this, Objects.requireNonNull(this.defaultOptions.getModel()));
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,8 +20,6 @@
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictRequest;
Expand All @@ -43,6 +41,7 @@
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.embedding.EmbeddingResultMetadata;
import org.springframework.ai.embedding.EmbeddingResultMetadata.ModalityType;
import org.springframework.ai.model.EmbeddingModelDescription;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils;
Expand Down Expand Up @@ -75,10 +74,8 @@ public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel
private static final List<MimeType> SUPPORTED_IMAGE_MIME_SUB_TYPES = List.of(MimeTypeUtils.IMAGE_JPEG,
MimeTypeUtils.IMAGE_GIF, MimeTypeUtils.IMAGE_PNG, MimeTypeUtils.parseMimeType("image/bmp"));

private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Stream
.of(VertexAiMultimodalEmbeddingModelName.values())
.collect(Collectors.toMap(VertexAiMultimodalEmbeddingModelName::getName,
VertexAiMultimodalEmbeddingModelName::getDimensions));
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = EmbeddingModelDescription
.calculateKnownEmbeddingDimensions(VertexAiMultimodalEmbeddingModelName.class);

public final VertexAiMultimodalEmbeddingOptions defaultOptions;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.Objects;

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictRequest;
Expand All @@ -43,6 +42,7 @@
import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
import org.springframework.ai.model.EmbeddingModelDescription;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.retry.RetryUtils;
Expand All @@ -67,10 +67,8 @@ public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel {

private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();

private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Stream
.of(VertexAiTextEmbeddingModelName.values())
.collect(Collectors.toMap(VertexAiTextEmbeddingModelName::getName,
VertexAiTextEmbeddingModelName::getDimensions));
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = EmbeddingModelDescription
.calculateKnownEmbeddingDimensions(VertexAiTextEmbeddingModelName.class);

public final VertexAiTextEmbeddingOptions defaultOptions;

Expand Down Expand Up @@ -242,9 +240,14 @@ private DefaultUsage getDefaultUsage(Integer totalTokens) {
return new DefaultUsage(0, 0, totalTokens);
}

@Override
public Map<String, Integer> knownEmbeddingDimensions() {
return KNOWN_EMBEDDING_DIMENSIONS;
}

@Override
public int dimensions() {
return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), super.dimensions());
return dimensions(this, Objects.requireNonNull(this.defaultOptions.getModel()));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,24 @@ public static int dimensions(EmbeddingModel embeddingModel, String modelName, St
}
}

/**
* Return the dimension of the requested embedding generative name. Uses the embedding
* model to retrieve its default dimensions if the generative name is unknown.
* @param embeddingModel Embedding model client to determine its known embedding
* dimensions and default dimensions.
* @param modelName Embedding generative name to retrieve the dimensions for.
* @return Returns the embedding dimensions for the model name.
*/
public static int dimensions(AbstractEmbeddingModel embeddingModel, String modelName) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure it is the best solution to avoid source code duplication in embedding models overriding dimensions method.

var dimensions = embeddingModel.knownEmbeddingDimensions().get(modelName);

if (dimensions == null) {
dimensions = embeddingModel.defaultDimensions();
}

return dimensions;
}

private static Map<String, Integer> loadKnownModelDimensions() {
try {
var resource = EMBEDDING_MODEL_DIMENSIONS_PROPERTIES;
Expand All @@ -82,21 +100,35 @@ private static Map<String, Integer> loadKnownModelDimensions() {
}
return properties.entrySet()
.stream()
.collect(Collectors.toMap(e -> e.getKey().toString(), e -> Integer.parseInt(e.getValue().toString())));
.collect(Collectors.collectingAndThen(
Collectors.toMap(e -> e.getKey().toString(), e -> Integer.parseInt(e.getValue().toString())),
Map::copyOf));
}
catch (IOException e) {
throw new RuntimeException(e);
}
}

@Override
public int dimensions() {
private int defaultDimensions() {
if (this.embeddingDimensions.get() < 0) {
this.embeddingDimensions.set(dimensions(this, "Test", "Hello World"));
}
return this.embeddingDimensions.get();
}

/**
* Retrieve all the known embedding dimensions.
* @return The map containing the known embedding dimensions by model name
*/
public Map<String, Integer> knownEmbeddingDimensions() {
return KNOWN_EMBEDDING_DIMENSIONS;
}

@Override
public int dimensions() {
return defaultDimensions();
}

static class Hints implements RuntimeHintsRegistrar {

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,15 +16,28 @@

package org.springframework.ai.model;

import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* Description of an embedding model.
*
* @author Christian Tzolov
* @author Nicolas Krier
*/
public interface EmbeddingModelDescription extends ModelDescription {

default int getDimensions() {
return -1;
}

static <E extends Enum<E> & EmbeddingModelDescription> Map<String, Integer> calculateKnownEmbeddingDimensions(
Class<E> embeddingModelClass) {
return Stream.of(embeddingModelClass.getEnumConstants())
.collect(Collectors.collectingAndThen(
Collectors.toMap(ModelDescription::getName, EmbeddingModelDescription::getDimensions),
Map::copyOf));
}

}