Skip to content

Commit 62c4738

Browse files
committed
Optimize MistralAiEmbeddingModel dimensions method
- Calculate and cache values for unknown models only if necessary - Make known embedding dimensions a mutable map attribute - Fix warnings in MistralAiEmbeddingModelTests unit tests Signed-off-by: Nicolas Krier <[email protected]>
1 parent 4532f64 commit 62c4738

File tree

2 files changed

+22
-15
lines changed

2 files changed

+22
-15
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.ai.mistralai;
1818

19+
import java.util.HashMap;
1920
import java.util.List;
2021
import java.util.Map;
2122

@@ -56,16 +57,14 @@ public class MistralAiEmbeddingModel extends AbstractEmbeddingModel {
5657

5758
private static final Logger logger = LoggerFactory.getLogger(MistralAiEmbeddingModel.class);
5859

60+
private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
61+
5962
/**
6063
* Known embedding dimensions for Mistral AI models. Maps model names to their
6164
* respective embedding vector dimensions. This allows the dimensions() method to
6265
* return the correct value without making an API call.
6366
*/
64-
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Map.of(
65-
MistralAiApi.EmbeddingModel.EMBED.getValue(), 1024, MistralAiApi.EmbeddingModel.CODESTRAL_EMBED.getValue(),
66-
1536);
67-
68-
private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
67+
private final Map<String, Integer> knownEmbeddingDimensions = createKnownEmbeddingDimensions();
6968

7069
private final MistralAiEmbeddingOptions defaultOptions;
7170

@@ -85,6 +84,14 @@ public class MistralAiEmbeddingModel extends AbstractEmbeddingModel {
8584
*/
8685
private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
8786

87+
private static Map<String, Integer> createKnownEmbeddingDimensions() {
88+
Map<String, Integer> knownEmbeddingDimensions = new HashMap<>();
89+
knownEmbeddingDimensions.put(MistralAiApi.EmbeddingModel.EMBED.getValue(), 1024);
90+
knownEmbeddingDimensions.put(MistralAiApi.EmbeddingModel.CODESTRAL_EMBED.getValue(), 1536);
91+
92+
return knownEmbeddingDimensions;
93+
}
94+
8895
@Deprecated
8996
public MistralAiEmbeddingModel(MistralAiApi mistralAiApi) {
9097
this(mistralAiApi, MetadataMode.EMBED);
@@ -197,7 +204,8 @@ public float[] embed(Document document) {
197204

198205
@Override
199206
public int dimensions() {
200-
return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), super.dimensions());
207+
return this.knownEmbeddingDimensions.computeIfAbsent(this.defaultOptions.getModel(),
208+
model -> super.dimensions());
201209
}
202210

203211
/**

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelTests.java

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616

1717
package org.springframework.ai.mistralai;
1818

19+
import java.util.Arrays;
1920
import java.util.List;
2021

22+
import io.micrometer.observation.ObservationRegistry;
2123
import org.junit.jupiter.api.Test;
2224
import org.mockito.Mockito;
2325

@@ -46,7 +48,7 @@ void testDimensionsForMistralEmbedModel() {
4648
.build();
4749

4850
MistralAiEmbeddingModel model = new MistralAiEmbeddingModel(mockApi, MetadataMode.EMBED, options,
49-
RetryUtils.DEFAULT_RETRY_TEMPLATE);
51+
RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP);
5052

5153
assertThat(model.dimensions()).isEqualTo(1024);
5254
}
@@ -60,7 +62,7 @@ void testDimensionsForCodestralEmbedModel() {
6062
.build();
6163

6264
MistralAiEmbeddingModel model = new MistralAiEmbeddingModel(mockApi, MetadataMode.EMBED, options,
63-
RetryUtils.DEFAULT_RETRY_TEMPLATE);
65+
RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP);
6466

6567
assertThat(model.dimensions()).isEqualTo(1536);
6668
}
@@ -73,7 +75,7 @@ void testDimensionsFallbackForUnknownModel() {
7375
MistralAiEmbeddingOptions options = MistralAiEmbeddingOptions.builder().withModel("unknown-model").build();
7476

7577
MistralAiEmbeddingModel model = new MistralAiEmbeddingModel(mockApi, MetadataMode.EMBED, options,
76-
RetryUtils.DEFAULT_RETRY_TEMPLATE);
78+
RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP);
7779

7880
// Should fall back to super.dimensions() which detects dimensions from the API
7981
// response
@@ -94,7 +96,7 @@ void testAllEmbeddingModelsHaveDimensionMapping() {
9496
.build();
9597

9698
MistralAiEmbeddingModel model = new MistralAiEmbeddingModel(mockApi, MetadataMode.EMBED, options,
97-
RetryUtils.DEFAULT_RETRY_TEMPLATE);
99+
RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP);
98100

99101
// Each model should have a valid dimension (not the fallback -1)
100102
assertThat(model.dimensions()).as("Model %s should have a dimension mapping", embeddingModel.getValue())
@@ -122,16 +124,13 @@ private MistralAiApi createMockApiWithEmbeddingResponse(int dimensions) {
122124

123125
// Create a mock embedding response with the specified dimensions
124126
float[] embedding = new float[dimensions];
125-
for (int i = 0; i < dimensions; i++) {
126-
embedding[i] = 0.1f;
127-
}
127+
Arrays.fill(embedding, 0.1f);
128128

129129
MistralAiApi.Embedding embeddingData = new MistralAiApi.Embedding(0, embedding, "embedding");
130130

131131
MistralAiApi.Usage usage = new MistralAiApi.Usage(10, 0, 10);
132132

133-
MistralAiApi.EmbeddingList embeddingList = new MistralAiApi.EmbeddingList("object", List.of(embeddingData),
134-
"model", usage);
133+
var embeddingList = new MistralAiApi.EmbeddingList<>("object", List.of(embeddingData), "model", usage);
135134

136135
when(mockApi.embeddings(any())).thenReturn(ResponseEntity.ok(embeddingList));
137136

0 commit comments

Comments
 (0)