Skip to content

Commit 0bed02e

Browse files
committed
Improve extensibility of SimpleVectoreStore
Additionally: * Throw IllegalArgumentException instead of NullPointerException from constructor for null EmbeddingModel. * Delegate Documentation embedding generation to embed(:Document) method called by doAdd(:Document). * Break doSimilaritySearch(:SearchRequest) method into discrete, overridable operations. * Implement load(:File) in terms of load(:Resource) using a FileSystemResource. * Hide implementation details of EmbeddedMath dotProduct(..) and norm(..) methods. * Use more meaningful and descriptive variable names. * Add whitespace to improve readability. * Fix compiler warnings.
1 parent 08a007f commit 0bed02e

File tree

1 file changed

+74
-54
lines changed

1 file changed

+74
-54
lines changed

spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java

Lines changed: 74 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,14 @@
2626
import java.nio.file.FileAlreadyExistsException;
2727
import java.nio.file.Files;
2828
import java.util.Comparator;
29-
import java.util.HashMap;
3029
import java.util.List;
3130
import java.util.Map;
32-
import java.util.Objects;
3331
import java.util.Optional;
3432
import java.util.concurrent.ConcurrentHashMap;
3533

3634
import com.fasterxml.jackson.core.JsonProcessingException;
3735
import com.fasterxml.jackson.core.type.TypeReference;
3836
import com.fasterxml.jackson.databind.ObjectMapper;
39-
import com.fasterxml.jackson.databind.ObjectWriter;
4037
import com.fasterxml.jackson.databind.json.JsonMapper;
4138
import io.micrometer.observation.ObservationRegistry;
4239
import org.slf4j.Logger;
@@ -50,14 +47,16 @@
5047
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
5148
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
5249
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
50+
import org.springframework.core.io.FileSystemResource;
5351
import org.springframework.core.io.Resource;
52+
import org.springframework.util.Assert;
5453

5554
/**
56-
* SimpleVectorStore is a simple implementation of the VectorStore interface.
57-
*
55+
* Simple, in-memory implementation of the {@link VectorStore} interface.
56+
* <p/>
5857
* It also provides methods to save the current state of the vectors to a file, and to
5958
* load vectors from a file.
60-
*
59+
* <p/>
6160
* For a deeper understanding of the mathematical concepts and computations involved in
6261
* calculating similarity scores among vectors, refer to this
6362
* [resource](https://docs.spring.io/spring-ai/reference/api/vectordbs.html#_understanding_vectors).
@@ -67,6 +66,8 @@
6766
* @author Mark Pollack
6867
* @author Christian Tzolov
6968
* @author Sebastien Deleuze
69+
* @author John Blum
70+
* @see VectorStore
7071
*/
7172
public class SimpleVectorStore extends AbstractObservationVectorStore {
7273

@@ -87,54 +88,72 @@ public SimpleVectorStore(EmbeddingModel embeddingModel, ObservationRegistry obse
8788

8889
super(observationRegistry, customObservationConvention);
8990

90-
Objects.requireNonNull(embeddingModel, "EmbeddingModel must not be null");
91+
Assert.notNull(embeddingModel, "EmbeddingModel must not be null");
92+
9193
this.embeddingModel = embeddingModel;
9294
this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build();
9395
}
9496

9597
@Override
9698
public void doAdd(List<Document> documents) {
9799
for (Document document : documents) {
98-
logger.info("Calling EmbeddingModel for document id = {}", document.getId());
99-
float[] embedding = this.embeddingModel.embed(document);
100-
document.setEmbedding(embedding);
100+
logger.info("Calling EmbeddingModel for Document id = {}", document.getId());
101+
document = embed(document);
101102
this.store.put(document.getId(), document);
102103
}
103104
}
104105

106+
protected Document embed(Document document) {
107+
float[] documentEmbedding = this.embeddingModel.embed(document);
108+
document.setEmbedding(documentEmbedding);
109+
return document;
110+
}
111+
105112
@Override
106113
public Optional<Boolean> doDelete(List<String> idList) {
107-
for (String id : idList) {
108-
this.store.remove(id);
109-
}
114+
idList.forEach(this.store::remove);
110115
return Optional.of(true);
111116
}
112117

113118
@Override
114119
public List<Document> doSimilaritySearch(SearchRequest request) {
120+
115121
if (request.getFilterExpression() != null) {
116122
throw new UnsupportedOperationException(
117-
"The [" + this.getClass() + "] doesn't support metadata filtering!");
123+
"[%s] doesn't support metadata filtering".formatted(getClass().getName()));
118124
}
119125

120-
float[] userQueryEmbedding = getUserQueryEmbedding(request.getQuery());
121-
return this.store.values()
122-
.stream()
123-
.map(entry -> new Similarity(entry.getId(),
124-
EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding())))
125-
.filter(s -> s.score >= request.getSimilarityThreshold())
126-
.sorted(Comparator.<Similarity>comparingDouble(s -> s.score).reversed())
126+
// @formatter:off
127+
return this.store.values().stream()
128+
.map(document -> computeSimilarity(request, document))
129+
.filter(similarity -> similarity.score >= request.getSimilarityThreshold())
130+
.sorted(Comparator.<Similarity>comparingDouble(similarity -> similarity.score).reversed())
127131
.limit(request.getTopK())
128-
.map(s -> this.store.get(s.key))
132+
.map(similarity -> this.store.get(similarity.key))
129133
.toList();
134+
// @formatter:on
135+
}
136+
137+
protected Similarity computeSimilarity(SearchRequest request, Document document) {
138+
139+
float[] userQueryEmbedding = getUserQueryEmbedding(request);
140+
float[] documentEmbedding = document.getEmbedding();
141+
142+
double score = computeCosineSimilarity(userQueryEmbedding, documentEmbedding);
143+
144+
return new Similarity(document.getId(), score);
145+
}
146+
147+
protected double computeCosineSimilarity(float[] userQueryEmbedding, float[] storedDocumentEmbedding) {
148+
return EmbeddingMath.cosineSimilarity(userQueryEmbedding, storedDocumentEmbedding);
130149
}
131150

132151
/**
133152
* Serialize the vector store content into a file in JSON format.
134153
* @param file the file to save the vector store content
135154
*/
136155
public void save(File file) {
137-
String json = getVectorDbAsJson();
156+
138157
try {
139158
if (!file.exists()) {
140159
logger.info("Creating new vector store file: {}", file);
@@ -145,28 +164,30 @@ public void save(File file) {
145164
throw new RuntimeException("File already exists: " + file, e);
146165
}
147166
catch (IOException e) {
148-
throw new RuntimeException("Failed to create new file: " + file + ". Reason: " + e.getMessage(), e);
167+
throw new RuntimeException("Failed to create new file: " + file + "; Reason: " + e.getMessage(), e);
149168
}
150169
}
151170
else {
152171
logger.info("Overwriting existing vector store file: {}", file);
153172
}
173+
154174
try (OutputStream stream = new FileOutputStream(file);
155175
Writer writer = new OutputStreamWriter(stream, StandardCharsets.UTF_8)) {
176+
String json = getVectorDbAsJson();
156177
writer.write(json);
157178
writer.flush();
158179
}
159180
}
160181
catch (IOException ex) {
161-
logger.error("IOException occurred while saving vector store file.", ex);
182+
logger.error("IOException occurred while saving vector store file", ex);
162183
throw new RuntimeException(ex);
163184
}
164185
catch (SecurityException ex) {
165-
logger.error("SecurityException occurred while saving vector store file.", ex);
186+
logger.error("SecurityException occurred while saving vector store file", ex);
166187
throw new RuntimeException(ex);
167188
}
168189
catch (NullPointerException ex) {
169-
logger.error("NullPointerException occurred while saving vector store file.", ex);
190+
logger.error("NullPointerException occurred while saving vector store file", ex);
170191
throw new RuntimeException(ex);
171192
}
172193
}
@@ -176,45 +197,40 @@ public void save(File file) {
176197
* @param file the file to load the vector store content
177198
*/
178199
public void load(File file) {
179-
TypeReference<HashMap<String, Document>> typeRef = new TypeReference<>() {
180-
181-
};
182-
try {
183-
Map<String, Document> deserializedMap = this.objectMapper.readValue(file, typeRef);
184-
this.store = deserializedMap;
185-
}
186-
catch (IOException ex) {
187-
throw new RuntimeException(ex);
188-
}
200+
load(new FileSystemResource(file));
189201
}
190202

191203
/**
192204
* Deserialize the vector store content from a resource in JSON format into memory.
193205
* @param resource the resource to load the vector store content
194206
*/
195207
public void load(Resource resource) {
196-
TypeReference<HashMap<String, Document>> typeRef = new TypeReference<>() {
197208

198-
};
199209
try {
200-
Map<String, Document> deserializedMap = this.objectMapper.readValue(resource.getInputStream(), typeRef);
201-
this.store = deserializedMap;
210+
this.store = this.objectMapper.readValue(resource.getInputStream(), documentMapTypeRef());
202211
}
203212
catch (IOException ex) {
204213
throw new RuntimeException(ex);
205214
}
206215
}
207216

217+
private TypeReference<Map<String, Document>> documentMapTypeRef() {
218+
return new TypeReference<>() {
219+
};
220+
}
221+
208222
private String getVectorDbAsJson() {
209-
ObjectWriter objectWriter = this.objectMapper.writerWithDefaultPrettyPrinter();
210-
String json;
223+
211224
try {
212-
json = objectWriter.writeValueAsString(this.store);
225+
return this.objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(this.store);
213226
}
214227
catch (JsonProcessingException e) {
215-
throw new RuntimeException("Error serializing documentMap to JSON.", e);
228+
throw new RuntimeException("Error serializing Map of Documents to JSON", e);
216229
}
217-
return json;
230+
}
231+
232+
private float[] getUserQueryEmbedding(SearchRequest request) {
233+
return getUserQueryEmbedding(request.getQuery());
218234
}
219235

220236
private float[] getUserQueryEmbedding(String query) {
@@ -232,9 +248,9 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str
232248

233249
public static class Similarity {
234250

235-
private String key;
251+
private final String key;
236252

237-
private double score;
253+
private final double score;
238254

239255
public Similarity(String key, double score) {
240256
this.key = key;
@@ -243,16 +259,18 @@ public Similarity(String key, double score) {
243259

244260
}
245261

246-
public final class EmbeddingMath {
262+
public static final class EmbeddingMath {
247263

248264
private EmbeddingMath() {
249265
throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
250266
}
251267

252268
public static double cosineSimilarity(float[] vectorX, float[] vectorY) {
269+
253270
if (vectorX == null || vectorY == null) {
254-
throw new RuntimeException("Vectors must not be null");
271+
throw new IllegalArgumentException("Vectors must not be null");
255272
}
273+
256274
if (vectorX.length != vectorY.length) {
257275
throw new IllegalArgumentException("Vectors lengths must be equal");
258276
}
@@ -268,20 +286,22 @@ public static double cosineSimilarity(float[] vectorX, float[] vectorY) {
268286
return dotProduct / (Math.sqrt(normX) * Math.sqrt(normY));
269287
}
270288

271-
public static float dotProduct(float[] vectorX, float[] vectorY) {
289+
private static float dotProduct(float[] vectorX, float[] vectorY) {
290+
272291
if (vectorX.length != vectorY.length) {
273292
throw new IllegalArgumentException("Vectors lengths must be equal");
274293
}
275294

276295
float result = 0;
277-
for (int i = 0; i < vectorX.length; ++i) {
278-
result += vectorX[i] * vectorY[i];
296+
297+
for (int index = 0; index < vectorX.length; ++index) {
298+
result += vectorX[index] * vectorY[index];
279299
}
280300

281301
return result;
282302
}
283303

284-
public static float norm(float[] vector) {
304+
private static float norm(float[] vector) {
285305
return dotProduct(vector, vector);
286306
}
287307

0 commit comments

Comments
 (0)