26
26
import java .nio .file .FileAlreadyExistsException ;
27
27
import java .nio .file .Files ;
28
28
import java .util .Comparator ;
29
- import java .util .HashMap ;
30
29
import java .util .List ;
31
30
import java .util .Map ;
32
- import java .util .Objects ;
33
31
import java .util .Optional ;
34
32
import java .util .concurrent .ConcurrentHashMap ;
35
33
36
34
import com .fasterxml .jackson .core .JsonProcessingException ;
37
35
import com .fasterxml .jackson .core .type .TypeReference ;
38
36
import com .fasterxml .jackson .databind .ObjectMapper ;
39
- import com .fasterxml .jackson .databind .ObjectWriter ;
40
37
import com .fasterxml .jackson .databind .json .JsonMapper ;
41
38
import io .micrometer .observation .ObservationRegistry ;
42
39
import org .slf4j .Logger ;
50
47
import org .springframework .ai .vectorstore .observation .AbstractObservationVectorStore ;
51
48
import org .springframework .ai .vectorstore .observation .VectorStoreObservationContext ;
52
49
import org .springframework .ai .vectorstore .observation .VectorStoreObservationConvention ;
50
+ import org .springframework .core .io .FileSystemResource ;
53
51
import org .springframework .core .io .Resource ;
52
+ import org .springframework .util .Assert ;
54
53
55
54
/**
56
- * SimpleVectorStore is a simple implementation of the VectorStore interface.
57
- *
55
+ * Simple, in-memory implementation of the {@link VectorStore} interface.
56
+ * <p/>
58
57
* It also provides methods to save the current state of the vectors to a file, and to
59
58
* load vectors from a file.
60
- *
59
+ * <p/>
61
60
* For a deeper understanding of the mathematical concepts and computations involved in
62
61
* calculating similarity scores among vectors, refer to this
63
62
* [resource](https://docs.spring.io/spring-ai/reference/api/vectordbs.html#_understanding_vectors).
67
66
* @author Mark Pollack
68
67
* @author Christian Tzolov
69
68
* @author Sebastien Deleuze
69
+ * @author John Blum
70
+ * @see VectorStore
70
71
*/
71
72
public class SimpleVectorStore extends AbstractObservationVectorStore {
72
73
@@ -87,54 +88,72 @@ public SimpleVectorStore(EmbeddingModel embeddingModel, ObservationRegistry obse
87
88
88
89
super (observationRegistry , customObservationConvention );
89
90
90
- Objects .requireNonNull (embeddingModel , "EmbeddingModel must not be null" );
91
+ Assert .notNull (embeddingModel , "EmbeddingModel must not be null" );
92
+
91
93
this .embeddingModel = embeddingModel ;
92
94
this .objectMapper = JsonMapper .builder ().addModules (JacksonUtils .instantiateAvailableModules ()).build ();
93
95
}
94
96
95
97
@ Override
96
98
public void doAdd (List <Document > documents ) {
97
99
for (Document document : documents ) {
98
- logger .info ("Calling EmbeddingModel for document id = {}" , document .getId ());
99
- float [] embedding = this .embeddingModel .embed (document );
100
- document .setEmbedding (embedding );
100
+ logger .info ("Calling EmbeddingModel for Document id = {}" , document .getId ());
101
+ document = embed (document );
101
102
this .store .put (document .getId (), document );
102
103
}
103
104
}
104
105
106
+ protected Document embed (Document document ) {
107
+ float [] documentEmbedding = this .embeddingModel .embed (document );
108
+ document .setEmbedding (documentEmbedding );
109
+ return document ;
110
+ }
111
+
105
112
@ Override
106
113
public Optional <Boolean > doDelete (List <String > idList ) {
107
- for (String id : idList ) {
108
- this .store .remove (id );
109
- }
114
+ idList .forEach (this .store ::remove );
110
115
return Optional .of (true );
111
116
}
112
117
113
118
@ Override
114
119
public List <Document > doSimilaritySearch (SearchRequest request ) {
120
+
115
121
if (request .getFilterExpression () != null ) {
116
122
throw new UnsupportedOperationException (
117
- "The [" + this . getClass () + " ] doesn't support metadata filtering!" );
123
+ "[%s ] doesn't support metadata filtering" . formatted ( getClass (). getName ()) );
118
124
}
119
125
120
- float [] userQueryEmbedding = getUserQueryEmbedding (request .getQuery ());
121
- return this .store .values ()
122
- .stream ()
123
- .map (entry -> new Similarity (entry .getId (),
124
- EmbeddingMath .cosineSimilarity (userQueryEmbedding , entry .getEmbedding ())))
125
- .filter (s -> s .score >= request .getSimilarityThreshold ())
126
- .sorted (Comparator .<Similarity >comparingDouble (s -> s .score ).reversed ())
126
+ // @formatter:off
127
+ return this .store .values ().stream ()
128
+ .map (document -> computeSimilarity (request , document ))
129
+ .filter (similarity -> similarity .score >= request .getSimilarityThreshold ())
130
+ .sorted (Comparator .<Similarity >comparingDouble (similarity -> similarity .score ).reversed ())
127
131
.limit (request .getTopK ())
128
- .map (s -> this .store .get (s .key ))
132
+ .map (similarity -> this .store .get (similarity .key ))
129
133
.toList ();
134
+ // @formatter:on
135
+ }
136
+
137
+ protected Similarity computeSimilarity (SearchRequest request , Document document ) {
138
+
139
+ float [] userQueryEmbedding = getUserQueryEmbedding (request );
140
+ float [] documentEmbedding = document .getEmbedding ();
141
+
142
+ double score = computeCosineSimilarity (userQueryEmbedding , documentEmbedding );
143
+
144
+ return new Similarity (document .getId (), score );
145
+ }
146
+
147
+ protected double computeCosineSimilarity (float [] userQueryEmbedding , float [] storedDocumentEmbedding ) {
148
+ return EmbeddingMath .cosineSimilarity (userQueryEmbedding , storedDocumentEmbedding );
130
149
}
131
150
132
151
/**
133
152
* Serialize the vector store content into a file in JSON format.
134
153
* @param file the file to save the vector store content
135
154
*/
136
155
public void save (File file ) {
137
- String json = getVectorDbAsJson ();
156
+
138
157
try {
139
158
if (!file .exists ()) {
140
159
logger .info ("Creating new vector store file: {}" , file );
@@ -145,28 +164,30 @@ public void save(File file) {
145
164
throw new RuntimeException ("File already exists: " + file , e );
146
165
}
147
166
catch (IOException e ) {
148
- throw new RuntimeException ("Failed to create new file: " + file + ". Reason: " + e .getMessage (), e );
167
+ throw new RuntimeException ("Failed to create new file: " + file + "; Reason: " + e .getMessage (), e );
149
168
}
150
169
}
151
170
else {
152
171
logger .info ("Overwriting existing vector store file: {}" , file );
153
172
}
173
+
154
174
try (OutputStream stream = new FileOutputStream (file );
155
175
Writer writer = new OutputStreamWriter (stream , StandardCharsets .UTF_8 )) {
176
+ String json = getVectorDbAsJson ();
156
177
writer .write (json );
157
178
writer .flush ();
158
179
}
159
180
}
160
181
catch (IOException ex ) {
161
- logger .error ("IOException occurred while saving vector store file. " , ex );
182
+ logger .error ("IOException occurred while saving vector store file" , ex );
162
183
throw new RuntimeException (ex );
163
184
}
164
185
catch (SecurityException ex ) {
165
- logger .error ("SecurityException occurred while saving vector store file. " , ex );
186
+ logger .error ("SecurityException occurred while saving vector store file" , ex );
166
187
throw new RuntimeException (ex );
167
188
}
168
189
catch (NullPointerException ex ) {
169
- logger .error ("NullPointerException occurred while saving vector store file. " , ex );
190
+ logger .error ("NullPointerException occurred while saving vector store file" , ex );
170
191
throw new RuntimeException (ex );
171
192
}
172
193
}
@@ -176,45 +197,40 @@ public void save(File file) {
176
197
* @param file the file to load the vector store content
177
198
*/
178
199
public void load (File file ) {
179
- TypeReference <HashMap <String , Document >> typeRef = new TypeReference <>() {
180
-
181
- };
182
- try {
183
- Map <String , Document > deserializedMap = this .objectMapper .readValue (file , typeRef );
184
- this .store = deserializedMap ;
185
- }
186
- catch (IOException ex ) {
187
- throw new RuntimeException (ex );
188
- }
200
+ load (new FileSystemResource (file ));
189
201
}
190
202
191
203
/**
192
204
* Deserialize the vector store content from a resource in JSON format into memory.
193
205
* @param resource the resource to load the vector store content
194
206
*/
195
207
public void load (Resource resource ) {
196
- TypeReference <HashMap <String , Document >> typeRef = new TypeReference <>() {
197
208
198
- };
199
209
try {
200
- Map <String , Document > deserializedMap = this .objectMapper .readValue (resource .getInputStream (), typeRef );
201
- this .store = deserializedMap ;
210
+ this .store = this .objectMapper .readValue (resource .getInputStream (), documentMapTypeRef ());
202
211
}
203
212
catch (IOException ex ) {
204
213
throw new RuntimeException (ex );
205
214
}
206
215
}
207
216
217
+ private TypeReference <Map <String , Document >> documentMapTypeRef () {
218
+ return new TypeReference <>() {
219
+ };
220
+ }
221
+
208
222
private String getVectorDbAsJson () {
209
- ObjectWriter objectWriter = this .objectMapper .writerWithDefaultPrettyPrinter ();
210
- String json ;
223
+
211
224
try {
212
- json = objectWriter .writeValueAsString (this .store );
225
+ return this . objectMapper . writerWithDefaultPrettyPrinter () .writeValueAsString (this .store );
213
226
}
214
227
catch (JsonProcessingException e ) {
215
- throw new RuntimeException ("Error serializing documentMap to JSON. " , e );
228
+ throw new RuntimeException ("Error serializing Map of Documents to JSON" , e );
216
229
}
217
- return json ;
230
+ }
231
+
232
+ private float [] getUserQueryEmbedding (SearchRequest request ) {
233
+ return getUserQueryEmbedding (request .getQuery ());
218
234
}
219
235
220
236
private float [] getUserQueryEmbedding (String query ) {
@@ -232,9 +248,9 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str
232
248
233
249
public static class Similarity {
234
250
235
- private String key ;
251
+ private final String key ;
236
252
237
- private double score ;
253
+ private final double score ;
238
254
239
255
public Similarity (String key , double score ) {
240
256
this .key = key ;
@@ -243,16 +259,18 @@ public Similarity(String key, double score) {
243
259
244
260
}
245
261
246
- public final class EmbeddingMath {
262
+ public static final class EmbeddingMath {
247
263
248
264
private EmbeddingMath () {
249
265
throw new UnsupportedOperationException ("This is a utility class and cannot be instantiated" );
250
266
}
251
267
252
268
public static double cosineSimilarity (float [] vectorX , float [] vectorY ) {
269
+
253
270
if (vectorX == null || vectorY == null ) {
254
- throw new RuntimeException ("Vectors must not be null" );
271
+ throw new IllegalArgumentException ("Vectors must not be null" );
255
272
}
273
+
256
274
if (vectorX .length != vectorY .length ) {
257
275
throw new IllegalArgumentException ("Vectors lengths must be equal" );
258
276
}
@@ -268,20 +286,22 @@ public static double cosineSimilarity(float[] vectorX, float[] vectorY) {
268
286
return dotProduct / (Math .sqrt (normX ) * Math .sqrt (normY ));
269
287
}
270
288
271
- public static float dotProduct (float [] vectorX , float [] vectorY ) {
289
+ private static float dotProduct (float [] vectorX , float [] vectorY ) {
290
+
272
291
if (vectorX .length != vectorY .length ) {
273
292
throw new IllegalArgumentException ("Vectors lengths must be equal" );
274
293
}
275
294
276
295
float result = 0 ;
277
- for (int i = 0 ; i < vectorX .length ; ++i ) {
278
- result += vectorX [i ] * vectorY [i ];
296
+
297
+ for (int index = 0 ; index < vectorX .length ; ++index ) {
298
+ result += vectorX [index ] * vectorY [index ];
279
299
}
280
300
281
301
return result ;
282
302
}
283
303
284
- public static float norm (float [] vector ) {
304
+ private static float norm (float [] vector ) {
285
305
return dotProduct (vector , vector );
286
306
}
287
307
0 commit comments