Skip to content

Commit f4c259b

Browse files
Chong LuoChong Luo
Chong Luo
authored and
Chong Luo
committed
Add Voyage AI embedding API for Anthropic.
Signed-off-by: Chong Luo <[email protected]>
1 parent 7492681 commit f4c259b

File tree

6 files changed

+240
-2
lines changed

6 files changed

+240
-2
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -138,4 +138,5 @@ dmypy.json
138138
**/example.db
139139
**/.chroma
140140
docs/references/*
141-
!docs/references/index.rst
141+
!docs/references/index.rst
142+
.vscode/

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ This module is created to extract embeddings from requests for similarity search
331331
- [x] Support [fastText](https://fasttext.cc) embedding.
332332
- [x] Support [SentenceTransformers](https://www.sbert.net) embedding.
333333
- [x] Support [Timm](https://timm.fast.ai/) models for image embedding.
334+
- [x] Support [VoyageAI](https://www.voyageai.com/) embedding API for Anthropic.
334335
- [ ] Support other embedding APIs.
335336
- **Cache Storage**:
336337
**Cache Storage** is where the response from LLMs, such as ChatGPT, is stored. Cached responses are retrieved to assist in evaluating similarity and are returned to the requester if there is a good semantic match. At present, GPTCache supports SQLite and offers a universally accessible interface for extension of this module.

gptcache/embedding/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"Rwkv",
1313
"PaddleNLP",
1414
"UForm",
15+
"VoyageAI",
1516
]
1617

1718

@@ -31,7 +32,7 @@
3132
paddlenlp = LazyImport("paddlenlp", globals(), "gptcache.embedding.paddlenlp")
3233
uform = LazyImport("uform", globals(), "gptcache.embedding.uform")
3334
nomic = LazyImport("nomic", globals(), "gptcache.embedding.nomic")
34-
35+
voyageai = LazyImport("voyageai", globals(), "gptcache.embedding.voyageai")
3536

3637
def Nomic(model: str = "nomic-embed-text-v1.5",
3738
api_key: str = None,
@@ -90,3 +91,6 @@ def PaddleNLP(model="ernie-3.0-medium-zh"):
9091

9192
def UForm(model="unum-cloud/uform-vl-multilingual", embedding_type="text"):
    """Build a ``uform.UForm`` embedding object.

    :param model: model name passed through to ``uform.UForm``.
    :param embedding_type: embedding type passed through to ``uform.UForm``.
    :return: the object returned by ``uform.UForm(model, embedding_type)``.
    """
    encoder = uform.UForm(model, embedding_type)
    return encoder
94+
95+
def VoyageAI(
    model: str = "voyage-3",
    api_key: str = None,
    api_key_path: str = None,
    input_type: str = None,
    truncation: bool = True,
):
    """Build a ``voyageai.VoyageAI`` embedding object.

    Every argument is forwarded by keyword, unchanged.

    :param model: model name, ``"voyage-3"`` by default.
    :param api_key: API key forwarded to the embedding class.
    :param api_key_path: path of a key file forwarded to the embedding class.
    :param input_type: input type forwarded to the embedding class.
    :param truncation: truncation flag forwarded to the embedding class.
    :return: the object returned by ``voyageai.VoyageAI(...)``.
    """
    forwarded = {
        "model": model,
        "api_key": api_key,
        "api_key_path": api_key_path,
        "input_type": input_type,
        "truncation": truncation,
    }
    return voyageai.VoyageAI(**forwarded)

gptcache/embedding/voyageai.py

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import numpy as np
2+
3+
from gptcache.utils import import_voyageai
4+
from gptcache.embedding.base import BaseEmbedding
5+
6+
import_voyageai()
7+
8+
import voyageai
9+
10+
11+
class VoyageAI(BaseEmbedding):
    """Generate text embedding for given text using VoyageAI.

    :param model: The model name to use for generating embeddings. Defaults to 'voyage-3'.
    :type model: str
    :param api_key_path: The path to the VoyageAI API key file.
    :type api_key_path: str
    :param api_key: The VoyageAI API key. If it is None, the client will search for the API key in the following order:
        1. api_key_path, path to the file containing the key;
        2. environment variable VOYAGE_API_KEY_PATH, which can be set to the path to the file containing the key;
        3. environment variable VOYAGE_API_KEY.
        This behavior is defined by the VoyageAI Python SDK.
    :type api_key: str
    :param input_type: The type of input data. Defaults to None. Other options: query, document.
        More details can be found at https://docs.voyageai.com/docs/embeddings
    :type input_type: str
    :param truncation: Whether to truncate the input data. Defaults to True.
    :type truncation: bool

    Example:
        .. code-block:: python

            from gptcache.embedding import VoyageAI

            test_sentence = 'Hello, world.'
            encoder = VoyageAI(model='voyage-3', api_key='your_voyageai_key')
            embed = encoder.to_embeddings(test_sentence)
    """

    def __init__(self, model: str = "voyage-3", api_key_path: str = None, api_key: str = None, input_type: str = None, truncation: bool = True):
        # NOTE(review): these two assignments set attributes on the imported
        # ``voyageai`` module itself, so the values are shared process-wide
        # rather than being private to this instance.
        voyageai.api_key_path = api_key_path
        voyageai.api_key = api_key

        self._vo = voyageai.Client()
        self._model = model
        self._input_type = input_type
        self._truncation = truncation

        # Models listed in ``dim_dict`` get their dimension up front; for any
        # other model this is None and ``dimension`` probes it lazily.
        self.__dimension = self.dim_dict().get(model)

    def to_embeddings(self, data, **_):
        """
        Generate embedding for the given text input.

        :param data: The input text.
        :type data: str or list[str]

        :return: The text embedding in the shape of (dim,) for a single text;
            when the response holds several embeddings they are returned
            stacked, one row per embedding.
        :rtype: numpy.ndarray
        """
        if not isinstance(data, list):
            data = [data]
        result = self._vo.embed(texts=data, model=self._model, input_type=self._input_type, truncation=self._truncation)
        vectors = np.array(result.embeddings).astype("float32")
        # Drop the leading axis only when there is exactly one embedding;
        # squeezing unconditionally raises ValueError for multi-text input.
        if vectors.shape[0] == 1:
            return vectors.squeeze(0)
        return vectors

    @property
    def dimension(self):
        """Embedding dimension.

        For a model that is not listed in ``dim_dict`` the value is obtained
        once by embedding the probe text ``"foo"`` and then cached.

        :return: embedding dimension
        """
        if self.__dimension is None:
            foo_emb = self.to_embeddings("foo")
            self.__dimension = len(foo_emb)
        return self.__dimension

    @staticmethod
    def dim_dict():
        """Return the mapping from known model names to embedding dimensions."""
        return {"voyage-3": 1024,
                "voyage-3-lite": 512,
                "voyage-finance-2": 1024,
                "voyage-multilingual-2": 1024,
                "voyage-law-2": 1024,
                "voyage-code-2": 1536}

gptcache/utils/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"import_sbert",
55
"import_cohere",
66
"import_nomic",
7+
"import_voyageai",
78
"import_fasttext",
89
"import_huggingface",
910
"import_uform",
@@ -85,6 +86,9 @@ def import_cohere():
8586
def import_nomic():
    """Run the shared ``_check_library`` check for the ``nomic`` library."""
    library = "nomic"
    _check_library(library)
8788

89+
def import_voyageai():
    """Run the shared ``_check_library`` check for the ``voyageai`` library."""
    library = "voyageai"
    _check_library(library)
91+
8892

8993
def import_fasttext():
9094
_check_library("fasttext", package="fasttext==0.9.2")
+140
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import os
2+
import types
3+
import pytest
4+
import mock
5+
from gptcache.utils import import_voyageai
6+
from gptcache.embedding import VoyageAI
7+
8+
import_voyageai()
9+
10+
11+
12+
@mock.patch.dict(os.environ, {"VOYAGE_API_KEY": "API_KEY", "VOYAGE_API_KEY_PATH": "API_KEY_FILE_PATH_ENV"})
@mock.patch("builtins.open", new_callable=mock.mock_open, read_data="API_KEY")
@mock.patch("voyageai.Client.embed", return_value=types.SimpleNamespace(embeddings=[[0] * 1024]))
def test_voageai_without_api_key(mock_created, mock_file):
    """No explicit key: the file named by VOYAGE_API_KEY_PATH is expected to be opened."""
    expected_dim = 1024
    encoder = VoyageAI()

    assert encoder.dimension == expected_dim
    vector = encoder.to_embeddings("foo")
    assert len(vector) == expected_dim

    # The key file comes from the patched environment variable.
    mock_file.assert_called_once_with("API_KEY_FILE_PATH_ENV", "rt")
    mock_created.assert_called_once_with(
        texts=["foo"], model="voyage-3", input_type=None, truncation=True
    )
24+
25+
26+
@mock.patch.dict(os.environ, {"VOYAGE_API_KEY": "API_KEY", "VOYAGE_API_KEY_PATH": "API_KEY_FILE_PATH_ENV"})
@mock.patch("builtins.open", new_callable=mock.mock_open, read_data="API_KEY")
@mock.patch("voyageai.Client.embed", return_value=types.SimpleNamespace(embeddings=[[0] * 1024]))
def test_voageai_with_api_key_path(mock_create, mock_file):
    """An explicit ``api_key_path`` is expected to be opened instead of the environment's path."""
    expected_dim = 1024
    encoder = VoyageAI(api_key_path="API_KEY_FILE_PATH")

    assert encoder.dimension == expected_dim
    vector = encoder.to_embeddings("foo")
    assert len(vector) == expected_dim

    mock_file.assert_called_once_with("API_KEY_FILE_PATH", "rt")
    mock_create.assert_called_once_with(
        texts=["foo"], model="voyage-3", input_type=None, truncation=True
    )
38+
39+
40+
@mock.patch.dict(os.environ, {"VOYAGE_API_KEY": "API_KEY"})
@mock.patch("builtins.open", new_callable=mock.mock_open, read_data="API_KEY")
@mock.patch("voyageai.Client.embed", return_value=types.SimpleNamespace(embeddings=[[0] * 1024]))
def test_voageai_with_api_key_in_envrion(mock_create, mock_file):
    """With only VOYAGE_API_KEY patched into the environment, no file is expected to be opened."""
    expected_dim = 1024
    encoder = VoyageAI()

    assert encoder.dimension == expected_dim
    vector = encoder.to_embeddings("foo")
    assert len(vector) == expected_dim

    mock_file.assert_not_called()
    mock_create.assert_called_once_with(
        texts=["foo"], model="voyage-3", input_type=None, truncation=True
    )
51+
52+
53+
@mock.patch("voyageai.Client.embed", return_value=types.SimpleNamespace(embeddings=[[0] * 1024]))
def test_voageai_with_api_key(mock_create):
    """An explicit ``api_key`` with the default model yields 1024-dimensional embeddings."""
    expected_dim = 1024
    encoder = VoyageAI(api_key="API_KEY")

    assert encoder.dimension == expected_dim
    vector = encoder.to_embeddings("foo")
    assert len(vector) == expected_dim

    mock_create.assert_called_once_with(
        texts=["foo"], model="voyage-3", input_type=None, truncation=True
    )
61+
62+
63+
@mock.patch.dict(os.environ)
@mock.patch("builtins.open", new_callable=mock.mock_open, read_data="API_KEY")
def test_voageai_without_api_key_or_api_key_file_path(mock_file):
    """Constructing ``VoyageAI`` with no key source at all is expected to raise."""
    # ``patch.dict`` restores ``os.environ`` when the test ends, so dropping
    # any real credentials here keeps the test independent of the developer's
    # environment without leaking the change.
    os.environ.pop("VOYAGE_API_KEY", None)
    os.environ.pop("VOYAGE_API_KEY_PATH", None)

    with pytest.raises(Exception):
        VoyageAI()

    # Checked after the ``raises`` block: a statement placed after the raising
    # call inside the block would never execute.
    mock_file.assert_not_called()
69+
70+
71+
@mock.patch("voyageai.Client.embed", return_value=types.SimpleNamespace(embeddings=[[0] * 512]))
def test_voageai_with_model_voyage_3_lite(mock_create):
    """``voyage-3-lite`` reports and returns 512-dimensional embeddings."""
    model_name = "voyage-3-lite"
    expected_dim = 512
    encoder = VoyageAI(api_key="API_KEY", model=model_name)

    assert encoder.dimension == expected_dim
    vector = encoder.to_embeddings("foo")
    assert len(vector) == expected_dim

    mock_create.assert_called_once_with(
        texts=["foo"], model=model_name, input_type=None, truncation=True
    )
80+
81+
82+
@mock.patch("voyageai.Client.embed", return_value=types.SimpleNamespace(embeddings=[[0] * 1024]))
def test_voageai_with_model_voyage_finance_2(mock_create):
    """``voyage-finance-2`` reports and returns 1024-dimensional embeddings."""
    model_name = "voyage-finance-2"
    expected_dim = 1024
    encoder = VoyageAI(api_key="API_KEY", model=model_name)

    assert encoder.dimension == expected_dim
    vector = encoder.to_embeddings("foo")
    assert len(vector) == expected_dim

    mock_create.assert_called_once_with(
        texts=["foo"], model=model_name, input_type=None, truncation=True
    )
91+
92+
93+
@mock.patch("voyageai.Client.embed", return_value=types.SimpleNamespace(embeddings=[[0] * 1024]))
def test_voageai_with_model_voyage_multilingual_2(mock_create):
    """``voyage-multilingual-2`` reports and returns 1024-dimensional embeddings."""
    model_name = "voyage-multilingual-2"
    expected_dim = 1024
    encoder = VoyageAI(api_key="API_KEY", model=model_name)

    assert encoder.dimension == expected_dim
    vector = encoder.to_embeddings("foo")
    assert len(vector) == expected_dim

    mock_create.assert_called_once_with(
        texts=["foo"], model=model_name, input_type=None, truncation=True
    )
102+
103+
104+
@mock.patch("voyageai.Client.embed", return_value=types.SimpleNamespace(embeddings=[[0] * 1024]))
def test_voageai_with_model_voyage_law_2(mock_create):
    """``voyage-law-2`` reports and returns 1024-dimensional embeddings."""
    model_name = "voyage-law-2"
    expected_dim = 1024
    encoder = VoyageAI(api_key="API_KEY", model=model_name)

    assert encoder.dimension == expected_dim
    vector = encoder.to_embeddings("foo")
    assert len(vector) == expected_dim

    mock_create.assert_called_once_with(
        texts=["foo"], model=model_name, input_type=None, truncation=True
    )
113+
114+
115+
@mock.patch("voyageai.Client.embed", return_value=types.SimpleNamespace(embeddings=[[0] * 1536]))
def test_voageai_with_model_voyage_code_2(mock_create):
    """``voyage-code-2`` reports and returns 1536-dimensional embeddings."""
    model_name = "voyage-code-2"
    expected_dim = 1536
    encoder = VoyageAI(api_key="API_KEY", model=model_name)

    assert encoder.dimension == expected_dim
    vector = encoder.to_embeddings("foo")
    assert len(vector) == expected_dim

    mock_create.assert_called_once_with(
        texts=["foo"], model=model_name, input_type=None, truncation=True
    )
124+
125+
126+
@mock.patch("voyageai.Client.embed", return_value=types.SimpleNamespace(embeddings=[[0] * 1536]))
def test_voageai_with_general_parameters(mock_create):
    """Non-default ``input_type`` and ``truncation`` are forwarded to ``Client.embed``."""
    dimension = 1536
    settings = {
        "model": "voyage-code-2",
        "api_key": "API_KEY",
        "input_type": "query",
        "truncation": False,
    }

    # Re-assigns a stub response with the same shape as the decorator's.
    mock_create.return_value = types.SimpleNamespace(embeddings=[[0] * dimension])

    encoder = VoyageAI(**settings)
    assert encoder.dimension == dimension
    vector = encoder.to_embeddings(["foo"])
    assert len(vector) == dimension

    mock_create.assert_called_once_with(
        texts=["foo"],
        model=settings["model"],
        input_type=settings["input_type"],
        truncation=settings["truncation"],
    )

0 commit comments

Comments
 (0)