Commit 6030cda

Add local, Transformers EmbeddingClient
- EmbeddingClient implementation that computes sentence embeddings locally with SBERT transformers.
- Uses pre-trained transformer models, serialized into the Open Neural Network Exchange (ONNX) format.
- The Deep Java Library and the Microsoft ONNX Java Runtime are used to run the ONNX models and compute the embeddings efficiently.
- Add default tokenizer.json and model.onnx for sentence-transformers/all-MiniLM-L6-v2.
- Add a configurable resource caching service that caches remote (http/https) resources on the local file system.
- README.md provides information on how to serialize ONNX models.
- Add Git LFS configuration for large ONNX model files.
1 parent e68bdeb commit 6030cda

13 files changed

+31652
-0
lines changed

.gitattributes

+1
@@ -0,0 +1 @@
*.onnx filter=lfs diff=lfs merge=lfs -text
@@ -0,0 +1,71 @@
# Local Transformers Embedding Client

The `TransformersEmbeddingClient` is an `EmbeddingClient` implementation that locally computes [sentence embeddings](https://www.sbert.net/examples/applications/computing-embeddings/README.html#sentence-embeddings-with-transformers) using a selected [sentence transformer](https://www.sbert.net/).

It uses [pre-trained](https://www.sbert.net/docs/pretrained_models.html) transformer models, serialized into the [Open Neural Network Exchange (ONNX)](https://onnx.ai/) format.

The [Deep Java Library](https://djl.ai/) and the Microsoft [ONNX Java Runtime](https://onnxruntime.ai/docs/get-started/with-java.html) are used to run the ONNX models and compute the embeddings in Java.

## Serialize the Tokenizer and the Transformer Model

To run the model in Java, first export the tokenizer (as `tokenizer.json`) and the transformer model (as `model.onnx`).

### Serialize with optimum-cli

One quick way to achieve this is to use the [optimum-cli](https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model#exporting-a-model-to-onnx-using-the-cli) command-line tool.

The following snippet creates a Python virtual environment, installs the required packages, and runs `optimum-cli` to serialize (i.e. export) the model:

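```bash
# Create and activate an isolated Python environment
python3 -m venv venv
source ./venv/bin/activate

# Install the export tooling inside the virtual environment
pip install --upgrade pip
pip install optimum onnx onnxruntime

# Export the model and its tokenizer into onnx-output-folder
optimum-cli export onnx --model sentence-transformers/all-MiniLM-L6-v2 onnx-output-folder
```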

The `optimum-cli` command exports the [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) transformer into the `onnx-output-folder` folder. This folder contains the `tokenizer.json` and `model.onnx` files used by the embedding client.

## Apply the ONNX model

Use the `setTokenizerResource(tokenizerJsonUri)` and `setModelResource(modelOnnxUri)` methods to set the URI locations of the exported `tokenizer.json` and `model.onnx` files.
The `classpath:`, `file:` and `https:` URI schemes are supported.
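
The `ResourceCacheService` added in this commit excludes the `file` and `classpath` schemes from caching by default, so only remote locations are copied to the local file system. The following is a minimal sketch of that decision using plain `java.net.URI` on location strings. The class `UriSchemeSketch`, the helper `isCacheable` and the `example.com` URL are hypothetical and exist only for illustration. The real service inspects the URI reported by the Spring `Resource`, which may differ from the location string.

```java
import java.net.URI;
import java.util.List;

class UriSchemeSketch {

    // Same default exclusion list as the ResourceCacheService in this commit.
    static final List<String> EXCLUDED_SCHEMES = List.of("file", "classpath");

    // Hypothetical helper: true when a location with this scheme would be copied to the cache.
    // Locations without a scheme are treated as not cacheable in this sketch.
    static boolean isCacheable(String location) {
        String scheme = URI.create(location).getScheme();
        return scheme != null && !EXCLUDED_SCHEMES.contains(scheme);
    }

    public static void main(String[] args) {
        List<String> locations = List.of(
                "classpath:/onnx/all-MiniLM-L6-v2/model.onnx",
                "file:/tmp/onnx-output-folder/model.onnx",
                "https://example.com/models/model.onnx"); // placeholder URL
        for (String location : locations) {
            System.out.println(location + " -> cacheable: " + isCacheable(location));
        }
    }
}
```

With the default list, only the `https:` location in this sketch is reported as cacheable.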

If no other model is explicitly set, the `TransformersEmbeddingClient` defaults to the [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) model:

| Property | Value |
| -------- | ------- |
| Dimensions | 384 |
| Avg. performance | 58.80 |
| Speed | 14200 sentences/sec |
| Size | 80MB |

The following snippet illustrates how to use the `TransformersEmbeddingClient`:

```java
TransformersEmbeddingClient embeddingClient = new TransformersEmbeddingClient();

// (optional) defaults to classpath:/onnx/all-MiniLM-L6-v2/tokenizer.json
embeddingClient.setTokenizerResource("classpath:/onnx/all-MiniLM-L6-v2/tokenizer.json");
// (optional) defaults to classpath:/onnx/all-MiniLM-L6-v2/model.onnx
embeddingClient.setModelResource("classpath:/onnx/all-MiniLM-L6-v2/model.onnx");

// (optional) defaults to ${java.io.tmpdir}/spring-ai-onnx-model
// Only the http/https resources are cached by default.
embeddingClient.setResourceCacheDirectory("/tmp/onnx-zoo");

// Complete the initialization once all properties are set.
embeddingClient.afterPropertiesSet();

List<List<Double>> embeddings =
	embeddingClient.embed(List.of("Hello world", "World is big"));
```
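
A common next step with the returned `List<List<Double>>` is to compare two embeddings. The sketch below computes cosine similarity over such lists. It is not part of this commit: the class `CosineSimilaritySketch` and the three-dimensional stand-in vectors are assumptions made for illustration, while real embeddings from the default model have 384 dimensions.

```java
import java.util.List;

class CosineSimilaritySketch {

    // Cosine similarity of two equally sized vectors: dot(a, b) / (|a| * |b|).
    // Returns NaN if either vector has zero length.
    static double cosineSimilarity(List<Double> a, List<Double> b) {
        if (a.size() != b.size()) {
            throw new IllegalArgumentException("Vectors must have the same dimension");
        }
        double dot = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < a.size(); i++) {
            dot += a.get(i) * b.get(i);
            normA += a.get(i) * a.get(i);
            normB += b.get(i) * b.get(i);
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    public static void main(String[] args) {
        // Stand-in values shaped like the result of embeddingClient.embed(...).
        List<List<Double>> embeddings = List.of(
                List.of(1.0, 0.0, 1.0),
                List.of(0.5, 0.5, 1.0));
        System.out.println(cosineSimilarity(embeddings.get(0), embeddings.get(1)));
    }
}
```

Values close to 1 indicate vectors pointing in nearly the same direction, and values near 0 indicate unrelated directions.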
@@ -0,0 +1,86 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
	xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
	<modelVersion>4.0.0</modelVersion>
	<parent>
		<groupId>org.springframework.experimental.ai</groupId>
		<artifactId>spring-ai</artifactId>
		<version>0.7.0-SNAPSHOT</version>
		<relativePath>../../pom.xml</relativePath>
	</parent>
	<artifactId>transformers-embedding</artifactId>
	<packaging>jar</packaging>
	<name>Spring AI Embedding Client - Sentence Transformers Embeddings</name>
	<description>Spring AI Sentence Transformers Embedding Client</description>
	<url>https://github.com/spring-projects-experimental/spring-ai</url>

	<scm>
		<url>https://github.com/spring-projects-experimental/spring-ai</url>
		<connection>git://github.com/spring-projects-experimental/spring-ai.git</connection>
		<developerConnection>git@github.com:spring-projects-experimental/spring-ai.git</developerConnection>
	</scm>

	<properties>
		<djl.version>0.24.0</djl.version>
		<onnxruntime.version>1.16.1</onnxruntime.version>
	</properties>
	<dependencies>
		<dependency>
			<groupId>org.springframework.experimental.ai</groupId>
			<artifactId>spring-ai-core</artifactId>
			<version>${project.parent.version}</version>
		</dependency>

		<dependency>
			<groupId>com.microsoft.onnxruntime</groupId>
			<artifactId>onnxruntime</artifactId>
			<version>${onnxruntime.version}</version>
		</dependency>

		<dependency>
			<groupId>ai.djl.pytorch</groupId>
			<artifactId>pytorch-engine</artifactId>
			<version>${djl.version}</version>
		</dependency>

		<dependency>
			<groupId>ai.djl</groupId>
			<artifactId>api</artifactId>
			<version>${djl.version}</version>
		</dependency>

		<dependency>
			<groupId>ai.djl</groupId>
			<artifactId>model-zoo</artifactId>
			<version>${djl.version}</version>
		</dependency>

		<dependency>
			<groupId>ai.djl.huggingface</groupId>
			<artifactId>tokenizers</artifactId>
			<version>${djl.version}</version>
		</dependency>

		<!-- TESTING -->
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-test</artifactId>
			<scope>test</scope>
		</dependency>

		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-testcontainers</artifactId>
			<scope>test</scope>
		</dependency>

		<dependency>
			<groupId>org.testcontainers</groupId>
			<artifactId>junit-jupiter</artifactId>
			<scope>test</scope>
		</dependency>

	</dependencies>

</project>
@@ -0,0 +1,149 @@
/*
 * Copyright 2023-2023 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.embedding;

import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.FileUrlResource;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;
import org.springframework.util.FileCopyUtils;
import org.springframework.util.StreamUtils;
import org.springframework.util.StringUtils;

/**
 * Service that helps cache remote {@link Resource}s on the local file system.
 *
 * @author Christian Tzolov
 */
public class ResourceCacheService {

	private static final Log logger = LogFactory.getLog(ResourceCacheService.class);

	/**
	 * The parent folder that contains all cached resources.
	 */
	private final File cacheDirectory;

	/**
	 * Resources whose URI scheme is listed in excludedUriSchemas are not cached. By
	 * default the file and classpath resources are not cached as they are already in the
	 * local file system.
	 */
	private List<String> excludedUriSchemas = new ArrayList<>(List.of("file", "classpath"));

	public ResourceCacheService() {
		this(new File(System.getProperty("java.io.tmpdir"), "spring-ai-onnx-model").getAbsolutePath());
	}

	public ResourceCacheService(String rootCacheDirectory) {
		this(new File(rootCacheDirectory));
	}

	public ResourceCacheService(File rootCacheDirectory) {
		Assert.notNull(rootCacheDirectory, "Cache directory can not be null.");
		this.cacheDirectory = rootCacheDirectory;
		if (!this.cacheDirectory.exists()) {
			logger.info("Create cache root directory: " + this.cacheDirectory.getAbsolutePath());
			this.cacheDirectory.mkdirs();
		}
		Assert.isTrue(this.cacheDirectory.isDirectory(), "The cache folder must be a directory");
	}

	/**
	 * Overrides the excluded URI schemas list.
	 * @param excludedUriSchemas new list of URI schemas to be excluded from caching.
	 */
	public void setExcludedUriSchemas(List<String> excludedUriSchemas) {
		Assert.notNull(excludedUriSchemas, "The excluded URI schemas list can not be null");
		this.excludedUriSchemas = excludedUriSchemas;
	}

	/**
	 * Get {@link Resource} representing the cached copy of the original resource.
	 * @param originalResourceUri Resource to be cached.
	 * @return Returns a cached resource. If the original resource's URI schema is within
	 * the excluded schema list the original resource is returned.
	 */
	public Resource getCachedResource(String originalResourceUri) {
		return this.getCachedResource(new DefaultResourceLoader().getResource(originalResourceUri));
	}

	/**
	 * Get {@link Resource} representing the cached copy of the original resource.
	 * @param originalResource Resource to be cached.
	 * @return Returns a cached resource. If the original resource's URI schema is within
	 * the excluded schema list the original resource is returned.
	 */
	public Resource getCachedResource(Resource originalResource) {
		try {
			if (this.excludedUriSchemas.contains(originalResource.getURI().getScheme())) {
				logger.info("The " + originalResource.toString() + " resource with URI schema ["
						+ originalResource.getURI().getScheme() + "] is excluded from caching");
				return originalResource;
			}

			File cachedFile = getCachedFile(originalResource);
			if (!cachedFile.exists()) {
				// StreamUtils leaves the stream open, so close it explicitly.
				try (var inputStream = originalResource.getInputStream()) {
					FileCopyUtils.copy(StreamUtils.copyToByteArray(inputStream), cachedFile);
				}
				logger.info("Caching the " + originalResource.toString() + " resource to: " + cachedFile);
			}
			return new FileUrlResource(cachedFile.getAbsolutePath());
		}
		catch (Exception e) {
			throw new IllegalStateException("Failed to cache the resource: " + originalResource.getDescription(), e);
		}
	}

	private File getCachedFile(Resource originalResource) throws IOException {
		// Resources that share a parent path are grouped in one folder named after a
		// name-based UUID of that parent path.
		var resourceParentFolder = new File(this.cacheDirectory,
				UUID.nameUUIDFromBytes(pathWithoutLastSegment(originalResource.getURI())).toString());
		resourceParentFolder.mkdirs();
		String newFileName = getCacheName(originalResource);
		return new File(resourceParentFolder, newFileName);
	}

	private byte[] pathWithoutLastSegment(URI uri) {
		String path = uri.toASCIIString();
		var pathBeforeLastSegment = path.substring(0, path.lastIndexOf('/') + 1);
		// Use an explicit charset so the folder name does not depend on the platform default.
		return pathBeforeLastSegment.getBytes(java.nio.charset.StandardCharsets.UTF_8);
	}

	private String getCacheName(Resource originalResource) throws IOException {
		String fileName = originalResource.getFilename();
		String fragment = originalResource.getURI().getFragment();
		return !StringUtils.hasText(fragment) ? fileName : fileName + "_" + fragment;
	}

	public void deleteCacheFolder() {
		if (this.cacheDirectory.exists()) {
			logger.info("Empty Model Cache at: " + this.cacheDirectory.getAbsolutePath());
			// File#delete() cannot remove a non-empty directory, so delete the tree recursively.
			org.springframework.util.FileSystemUtils.deleteRecursively(this.cacheDirectory);
			this.cacheDirectory.mkdirs();
		}
	}

}
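
The cache layout used by `getCachedFile` can be illustrated outside of Spring. The following standalone sketch approximates it with JDK classes only: the folder name is a name-based UUID of the URI up to its last `/`, and the file name is the last path segment with any URI fragment appended after an underscore. The class `CachePathSketch`, the helper `cachedFileFor` and the `example.com` URLs are hypothetical. Taking the file name from `URI.getPath()` is an assumption that stands in for `Resource.getFilename()`.

```java
import java.io.File;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.UUID;

class CachePathSketch {

    // Hypothetical helper that mirrors the naming scheme of ResourceCacheService.
    static File cachedFileFor(File cacheRoot, URI uri) {
        String full = uri.toASCIIString();
        // Everything up to and including the last '/' identifies the parent location.
        String parent = full.substring(0, full.lastIndexOf('/') + 1);
        String folder = UUID.nameUUIDFromBytes(parent.getBytes(StandardCharsets.UTF_8)).toString();

        // Assumption: the last path segment stands in for Resource.getFilename().
        String path = uri.getPath();
        String fileName = path.substring(path.lastIndexOf('/') + 1);

        String fragment = uri.getFragment();
        String cacheName = (fragment == null || fragment.isBlank()) ? fileName : fileName + "_" + fragment;
        return new File(new File(cacheRoot, folder), cacheName);
    }

    public static void main(String[] args) {
        File root = new File(System.getProperty("java.io.tmpdir"), "spring-ai-onnx-model");
        // Placeholder URLs: both files live under the same remote folder.
        System.out.println(cachedFileFor(root, URI.create("https://example.com/models/model.onnx")));
        System.out.println(cachedFileFor(root, URI.create("https://example.com/models/tokenizer.json")));
    }
}
```

Because the UUID is derived from the parent location, files from the same remote folder land in the same local folder, while identically named files from different locations do not collide.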
