26
26
import java .nio .file .FileAlreadyExistsException ;
27
27
import java .nio .file .Files ;
28
28
import java .util .Comparator ;
29
- import java .util .HashMap ;
30
29
import java .util .List ;
31
30
import java .util .Map ;
32
- import java .util .Objects ;
33
31
import java .util .Optional ;
34
32
import java .util .concurrent .ConcurrentHashMap ;
35
33
36
34
import com .fasterxml .jackson .core .JsonProcessingException ;
37
35
import com .fasterxml .jackson .core .type .TypeReference ;
38
36
import com .fasterxml .jackson .databind .ObjectMapper ;
39
- import com .fasterxml .jackson .databind .ObjectWriter ;
40
37
import com .fasterxml .jackson .databind .json .JsonMapper ;
41
38
import io .micrometer .observation .ObservationRegistry ;
42
39
import org .slf4j .Logger ;
50
47
import org .springframework .ai .vectorstore .observation .AbstractObservationVectorStore ;
51
48
import org .springframework .ai .vectorstore .observation .VectorStoreObservationContext ;
52
49
import org .springframework .ai .vectorstore .observation .VectorStoreObservationConvention ;
50
+ import org .springframework .core .io .FileSystemResource ;
53
51
import org .springframework .core .io .Resource ;
52
+ import org .springframework .util .Assert ;
54
53
55
54
/**
56
- * SimpleVectorStore is a simple implementation of the VectorStore interface.
57
- *
55
+ * Simple, in-memory implementation of the {@link VectorStore} interface.
56
+ * <p>
58
57
* It also provides methods to save the current state of the vectors to a file, and to
59
58
* load vectors from a file.
60
- *
59
+ * <p>
61
60
* For a deeper understanding of the mathematical concepts and computations involved in
62
61
* calculating similarity scores among vectors, refer to this
63
62
* <a href="https://docs.spring.io/spring-ai/reference/api/vectordbs.html#_understanding_vectors">resource</a>.
67
66
* @author Mark Pollack
68
67
* @author Christian Tzolov
69
68
* @author Sebastien Deleuze
69
+ * @author John Blum
70
+ * @see VectorStore
70
71
*/
71
72
public class SimpleVectorStore extends AbstractObservationVectorStore {
72
73
@@ -87,54 +88,72 @@ public SimpleVectorStore(EmbeddingModel embeddingModel, ObservationRegistry obse
87
88
88
89
super (observationRegistry , customObservationConvention );
89
90
90
- Objects .requireNonNull (embeddingModel , "EmbeddingModel must not be null" );
91
+ Assert .notNull (embeddingModel , "EmbeddingModel must not be null" );
92
+
91
93
this .embeddingModel = embeddingModel ;
92
94
this .objectMapper = JsonMapper .builder ().addModules (JacksonUtils .instantiateAvailableModules ()).build ();
93
95
}
94
96
95
97
@ Override
96
98
public void doAdd (List <Document > documents ) {
97
99
for (Document document : documents ) {
98
- logger .info ("Calling EmbeddingModel for document id = {}" , document .getId ());
99
- float [] embedding = this .embeddingModel .embed (document );
100
- document .setEmbedding (embedding );
100
+ logger .info ("Calling EmbeddingModel for Document id = {}" , document .getId ());
101
+ document = embed (document );
101
102
this .store .put (document .getId (), document );
102
103
}
103
104
}
104
105
106
+ protected Document embed (Document document ) {
107
+ float [] documentEmbedding = this .embeddingModel .embed (document );
108
+ document .setEmbedding (documentEmbedding );
109
+ return document ;
110
+ }
111
+
105
112
@ Override
106
113
public Optional <Boolean > doDelete (List <String > idList ) {
107
- for (String id : idList ) {
108
- this .store .remove (id );
109
- }
114
+ idList .forEach (this .store ::remove );
110
115
return Optional .of (true );
111
116
}
112
117
113
118
@ Override
114
119
public List <Document > doSimilaritySearch (SearchRequest request ) {
120
+
115
121
if (request .getFilterExpression () != null ) {
116
122
throw new UnsupportedOperationException (
117
- "The [" + this . getClass () + " ] doesn't support metadata filtering!" );
123
+ "[%s ] doesn't support metadata filtering" . formatted ( getClass (). getName ()) );
118
124
}
119
125
120
- float [] userQueryEmbedding = getUserQueryEmbedding (request .getQuery ());
121
- return this .store .values ()
122
- .stream ()
123
- .map (entry -> new Similarity (entry .getId (),
124
- EmbeddingMath .cosineSimilarity (userQueryEmbedding , entry .getEmbedding ())))
125
- .filter (s -> s .score >= request .getSimilarityThreshold ())
126
- .sorted (Comparator .<Similarity >comparingDouble (s -> s .score ).reversed ())
126
+ // @formatter:off
127
+ return this .store .values ().stream ()
128
+ .map (document -> computeSimilarity (request , document ))
129
+ .filter (similarity -> similarity .score >= request .getSimilarityThreshold ())
130
+ .sorted (Comparator .<Similarity >comparingDouble (similarity -> similarity .score ).reversed ())
127
131
.limit (request .getTopK ())
128
- .map (s -> this .store .get (s .key ))
132
+ .map (similarity -> this .store .get (similarity .key ))
129
133
.toList ();
134
+ // @formatter:on
135
+ }
136
+
137
+ protected Similarity computeSimilarity (SearchRequest request , Document document ) {
138
+
139
+ float [] userQueryEmbedding = getUserQueryEmbedding (request );
140
+ float [] documentEmbedding = document .getEmbedding ();
141
+
142
+ double score = computeCosineSimilarity (userQueryEmbedding , documentEmbedding );
143
+
144
+ return new Similarity (document .getId (), score );
145
+ }
146
+
147
+ protected double computeCosineSimilarity (float [] userQueryEmbedding , float [] storedDocumentEmbedding ) {
148
+ return EmbeddingMath .cosineSimilarity (userQueryEmbedding , storedDocumentEmbedding );
130
149
}
131
150
132
151
/**
133
152
* Serialize the vector store content into a file in JSON format.
134
153
* @param file the file to save the vector store content to
135
154
*/
136
155
public void save (File file ) {
137
- String json = getVectorDbAsJson ();
156
+
138
157
try {
139
158
if (!file .exists ()) {
140
159
logger .info ("Creating new vector store file: {}" , file );
@@ -145,28 +164,22 @@ public void save(File file) {
145
164
throw new RuntimeException ("File already exists: " + file , e );
146
165
}
147
166
catch (IOException e ) {
148
- throw new RuntimeException ("Failed to create new file: " + file + ". Reason: " + e .getMessage (), e );
167
+ throw new RuntimeException ("Failed to create new file: " + file + "; Reason: " + e .getMessage (), e );
149
168
}
150
169
}
151
170
else {
152
171
logger .info ("Overwriting existing vector store file: {}" , file );
153
172
}
173
+
154
174
try (OutputStream stream = new FileOutputStream (file );
155
175
Writer writer = new OutputStreamWriter (stream , StandardCharsets .UTF_8 )) {
176
+ String json = getVectorDbAsJson ();
156
177
writer .write (json );
157
178
writer .flush ();
158
179
}
159
180
}
160
- catch (IOException ex ) {
161
- logger .error ("IOException occurred while saving vector store file." , ex );
162
- throw new RuntimeException (ex );
163
- }
164
- catch (SecurityException ex ) {
165
- logger .error ("SecurityException occurred while saving vector store file." , ex );
166
- throw new RuntimeException (ex );
167
- }
168
- catch (NullPointerException ex ) {
169
- logger .error ("NullPointerException occurred while saving vector store file." , ex );
181
+ catch (IOException | NullPointerException | SecurityException ex ) {
182
+ logger .error ("%s occurred while saving vector store file" .formatted (ex .getClass ().getSimpleName ()), ex );
170
183
throw new RuntimeException (ex );
171
184
}
172
185
}
@@ -176,45 +189,40 @@ public void save(File file) {
176
189
* @param file the file to load the vector store content from
177
190
*/
178
191
public void load (File file ) {
179
- TypeReference <HashMap <String , Document >> typeRef = new TypeReference <>() {
180
-
181
- };
182
- try {
183
- Map <String , Document > deserializedMap = this .objectMapper .readValue (file , typeRef );
184
- this .store = deserializedMap ;
185
- }
186
- catch (IOException ex ) {
187
- throw new RuntimeException (ex );
188
- }
192
+ load (new FileSystemResource (file ));
189
193
}
190
194
191
195
/**
192
196
* Deserialize the vector store content from a resource in JSON format into memory.
193
197
* @param resource the resource to load the vector store content from
194
198
*/
195
199
public void load (Resource resource ) {
196
- TypeReference <HashMap <String , Document >> typeRef = new TypeReference <>() {
197
200
198
- };
199
201
try {
200
- Map <String , Document > deserializedMap = this .objectMapper .readValue (resource .getInputStream (), typeRef );
201
- this .store = deserializedMap ;
202
+ this .store = this .objectMapper .readValue (resource .getInputStream (), documentMapTypeRef ());
202
203
}
203
204
catch (IOException ex ) {
204
205
throw new RuntimeException (ex );
205
206
}
206
207
}
207
208
209
+ private TypeReference <Map <String , Document >> documentMapTypeRef () {
210
+ return new TypeReference <>() {
211
+ };
212
+ }
213
+
208
214
private String getVectorDbAsJson () {
209
- ObjectWriter objectWriter = this .objectMapper .writerWithDefaultPrettyPrinter ();
210
- String json ;
215
+
211
216
try {
212
- json = objectWriter .writeValueAsString (this .store );
217
+ return this . objectMapper . writerWithDefaultPrettyPrinter () .writeValueAsString (this .store );
213
218
}
214
219
catch (JsonProcessingException e ) {
215
- throw new RuntimeException ("Error serializing documentMap to JSON. " , e );
220
+ throw new RuntimeException ("Error serializing Map of Documents to JSON" , e );
216
221
}
217
- return json ;
222
+ }
223
+
224
+ private float [] getUserQueryEmbedding (SearchRequest request ) {
225
+ return getUserQueryEmbedding (request .getQuery ());
218
226
}
219
227
220
228
private float [] getUserQueryEmbedding (String query ) {
@@ -232,9 +240,9 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str
232
240
233
241
public static class Similarity {
234
242
235
- private String key ;
243
+ private final String key ;
236
244
237
- private double score ;
245
+ private final double score ;
238
246
239
247
public Similarity (String key , double score ) {
240
248
this .key = key ;
@@ -243,16 +251,18 @@ public Similarity(String key, double score) {
243
251
244
252
}
245
253
246
- public final class EmbeddingMath {
254
+ public static final class EmbeddingMath {
247
255
248
256
private EmbeddingMath () {
249
257
throw new UnsupportedOperationException ("This is a utility class and cannot be instantiated" );
250
258
}
251
259
252
260
public static double cosineSimilarity (float [] vectorX , float [] vectorY ) {
261
+
253
262
if (vectorX == null || vectorY == null ) {
254
- throw new RuntimeException ("Vectors must not be null" );
263
+ throw new IllegalArgumentException ("Vectors must not be null" );
255
264
}
265
+
256
266
if (vectorX .length != vectorY .length ) {
257
267
throw new IllegalArgumentException ("Vectors lengths must be equal" );
258
268
}
@@ -268,20 +278,22 @@ public static double cosineSimilarity(float[] vectorX, float[] vectorY) {
268
278
return dotProduct / (Math .sqrt (normX ) * Math .sqrt (normY ));
269
279
}
270
280
271
- public static float dotProduct (float [] vectorX , float [] vectorY ) {
281
+ private static float dotProduct (float [] vectorX , float [] vectorY ) {
282
+
272
283
if (vectorX .length != vectorY .length ) {
273
284
throw new IllegalArgumentException ("Vectors lengths must be equal" );
274
285
}
275
286
276
287
float result = 0 ;
277
- for (int i = 0 ; i < vectorX .length ; ++i ) {
278
- result += vectorX [i ] * vectorY [i ];
288
+
289
+ for (int index = 0 ; index < vectorX .length ; ++index ) {
290
+ result += vectorX [index ] * vectorY [index ];
279
291
}
280
292
281
293
return result ;
282
294
}
283
295
284
- public static float norm (float [] vector ) {
296
+ private static float norm (float [] vector ) {
285
297
return dotProduct (vector , vector );
286
298
}
287
299
0 commit comments