Skip to content

Commit e5f1435

Browse files
authored
MariaDB Vector Store Integration: Add tuning parameters and utility functions (#17791)
1 parent b4a4716 commit e5f1435

File tree

5 files changed

+89
-24
lines changed

5 files changed

+89
-24
lines changed

llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/README.md

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ vector_store = MariaDBVectorStore.from_params(
2727
database="vectordb",
2828
table_name="llama_index_vectorstore",
2929
embed_dim=1536, # OpenAI embedding dimension
30+
default_m=6, # MariaDB Vector system parameter
31+
ef_search=20, # MariaDB Vector system parameter
3032
)
3133
```
3234

llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/llama_index/vector_stores/mariadb/base.py

+42-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from urllib.parse import quote_plus
77

88
import sqlalchemy
9-
109
from llama_index.core.bridge.pydantic import PrivateAttr
1110
from llama_index.core.schema import BaseNode, MetadataMode
1211
from llama_index.core.vector_stores.types import (
@@ -52,6 +51,8 @@ class MariaDBVectorStore(BasePydanticVectorStore):
5251
password="password",
5352
database="vectordb",
5453
table_name="llama_index_vectorstore",
54+
default_m=6,
55+
ef_search=20,
5556
embed_dim=1536 # OpenAI embedding dimension
5657
)
5758
```
@@ -65,6 +66,8 @@ class MariaDBVectorStore(BasePydanticVectorStore):
6566
table_name: str
6667
schema_name: str
6768
embed_dim: int
69+
default_m: int
70+
ef_search: int
6871
perform_setup: bool
6972
debug: bool
7073

@@ -78,6 +81,8 @@ def __init__(
7881
table_name: str,
7982
schema_name: str,
8083
embed_dim: int = 1536,
84+
default_m: int = 6,
85+
ef_search: int = 20,
8186
perform_setup: bool = True,
8287
debug: bool = False,
8388
) -> None:
@@ -89,6 +94,8 @@ def __init__(
8994
table_name (str): Table name.
9095
schema_name (str): Schema name.
9196
embed_dim (int, optional): Embedding dimensions. Defaults to 1536.
97+
default_m (int, optional): Default M value for the vector index. Defaults to 6.
98+
ef_search (int, optional): EF search value for the vector index. Defaults to 20.
9299
perform_setup (bool, optional): If DB should be set up. Defaults to True.
93100
debug (bool, optional): Debug mode. Defaults to False.
94101
"""
@@ -98,15 +105,20 @@ def __init__(
98105
table_name=table_name,
99106
schema_name=schema_name,
100107
embed_dim=embed_dim,
108+
default_m=default_m,
109+
ef_search=ef_search,
101110
perform_setup=perform_setup,
102111
debug=debug,
103112
)
104113

114+
self._initialize()
115+
105116
def close(self) -> None:
106117
if not self._is_initialized:
107118
return
108119

109120
self._engine.dispose()
121+
self._is_initialized = False
110122

111123
@classmethod
112124
def class_name(cls) -> str:
@@ -125,6 +137,8 @@ def from_params(
125137
connection_string: Optional[Union[str, sqlalchemy.engine.URL]] = None,
126138
connection_args: Optional[Dict[str, Any]] = None,
127139
embed_dim: int = 1536,
140+
default_m: int = 6,
141+
ef_search: int = 20,
128142
perform_setup: bool = True,
129143
debug: bool = False,
130144
) -> "MariaDBVectorStore":
@@ -141,6 +155,8 @@ def from_params(
141155
connection_string (Union[str, sqlalchemy.engine.URL]): Connection string to MariaDB DB.
142156
connection_args (Dict[str, Any], optional): A dictionary of connection options.
143157
embed_dim (int, optional): Embedding dimensions. Defaults to 1536.
158+
default_m (int, optional): Default M value for the vector index. Defaults to 6.
159+
ef_search (int, optional): EF search value for the vector index. Defaults to 20.
144160
perform_setup (bool, optional): If DB should be set up. Defaults to True.
145161
debug (bool, optional): Debug mode. Defaults to False.
146162
@@ -162,6 +178,8 @@ def from_params(
162178
table_name=table_name,
163179
schema_name=schema_name,
164180
embed_dim=embed_dim,
181+
default_m=default_m,
182+
ef_search=ef_search,
165183
perform_setup=perform_setup,
166184
debug=debug,
167185
)
@@ -200,8 +218,8 @@ def _create_table_if_not_exists(self) -> None:
200218
text TEXT,
201219
metadata JSON,
202220
embedding VECTOR({self.embed_dim}) NOT NULL,
203-
INDEX `{self.table_name}_node_id_idx` (`node_id`),
204-
VECTOR INDEX (embedding) DISTANCE=cosine
221+
INDEX (`node_id`),
222+
VECTOR INDEX (embedding) M={self.default_m} DISTANCE=cosine
205223
)
206224
"""
207225
connection.execute(sqlalchemy.text(stmt))
@@ -378,6 +396,7 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
378396
self._initialize()
379397

380398
stmt = f"""
399+
SET STATEMENT mhnsw_ef_search={self.ef_search} FOR
381400
SELECT
382401
node_id,
383402
text,
@@ -435,6 +454,26 @@ def delete_nodes(
435454

436455
connection.commit()
437456

457+
def count(self) -> int:
458+
self._initialize()
459+
460+
with self._engine.connect() as connection:
461+
stmt = f"""SELECT COUNT(*) FROM `{self.table_name}`"""
462+
result = connection.execute(sqlalchemy.text(stmt))
463+
464+
return result.scalar() or 0
465+
466+
def drop(self) -> None:
467+
self._initialize()
468+
469+
with self._engine.connect() as connection:
470+
stmt = f"""DROP TABLE IF EXISTS `{self.table_name}`"""
471+
connection.execute(sqlalchemy.text(stmt))
472+
473+
connection.commit()
474+
475+
self.close()
476+
438477
def clear(self) -> None:
439478
self._initialize()
440479

llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ ignore_missing_imports = true
2121
python_version = "3.8"
2222

2323
[tool.poetry]
24-
authors = ["Your Name <you@example.com>"]
24+
authors = ["Kalin Arsov <[email protected]>", "Vishal Rao <vishal@skysql.com>"]
2525
description = "llama-index vector_stores mariadb integration"
2626
exclude = ["**/BUILD"]
2727
license = "MIT"
2828
name = "llama-index-vector-stores-mariadb"
2929
readme = "README.md"
30-
version = "0.3.0"
30+
version = "0.3.1"
3131

3232
[tool.poetry.dependencies]
3333
python = ">=3.9,<4.0"
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
python_tests()
1+
python_tests(
2+
dependencies=["llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb:poetry#pymysql"]
3+
)

llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/tests/test_mariadb.py

+40-18
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import pytest
66
import sqlalchemy
7-
87
from llama_index.core.schema import NodeRelationship, RelatedNodeInfo, TextNode
98
from llama_index.core.vector_stores.types import (
109
FilterCondition,
@@ -13,6 +12,7 @@
1312
MetadataFilters,
1413
VectorStoreQuery,
1514
)
15+
1616
from llama_index.vector_stores.mariadb import MariaDBVectorStore
1717
from llama_index.vector_stores.mariadb.base import _meets_min_server_version
1818

@@ -49,18 +49,19 @@
4949
),
5050
]
5151

52-
vector_store = MariaDBVectorStore.from_params(
53-
database="test",
54-
table_name="vector_store_test",
55-
embed_dim=3,
56-
host="127.0.0.1",
57-
user="root",
58-
password="test",
59-
port="3306",
60-
)
61-
6252

53+
vector_store = None
6354
try:
55+
vector_store = MariaDBVectorStore.from_params(
56+
database="test",
57+
table_name="vector_store_test",
58+
embed_dim=3,
59+
host="127.0.0.1",
60+
user="root",
61+
password="test",
62+
port="3306",
63+
)
64+
6465
# If you want to run the integration tests you need to do:
6566
# docker-compose up
6667

@@ -84,7 +85,8 @@ def teardown(request: pytest.FixtureRequest) -> Generator:
8485
if "noautousefixtures" in request.keywords:
8586
return
8687

87-
vector_store.clear()
88+
if vector_store is not None:
89+
vector_store.clear()
8890

8991

9092
@pytest.fixture(scope="session", autouse=True)
@@ -95,7 +97,8 @@ def close_db_connection(request: pytest.FixtureRequest) -> Generator:
9597
if "noautousefixtures" in request.keywords:
9698
return
9799

98-
vector_store.close()
100+
if vector_store is not None:
101+
vector_store.close()
99102

100103

101104
@pytest.mark.parametrize(
@@ -117,7 +120,7 @@ def test_meets_min_server_version(version: str, supported: bool) -> None:
117120

118121

119122
@pytest.mark.skipif(
120-
run_integration_tests is False,
123+
not run_integration_tests,
121124
reason="MariaDB instance required for integration tests",
122125
)
123126
def test_query() -> None:
@@ -131,7 +134,7 @@ def test_query() -> None:
131134

132135

133136
@pytest.mark.skipif(
134-
run_integration_tests is False,
137+
not run_integration_tests,
135138
reason="MariaDB instance required for integration tests",
136139
)
137140
def test_query_with_metadatafilters() -> None:
@@ -168,7 +171,7 @@ def test_query_with_metadatafilters() -> None:
168171

169172

170173
@pytest.mark.skipif(
171-
run_integration_tests is False,
174+
not run_integration_tests,
172175
reason="MariaDB instance required for integration tests",
173176
)
174177
def test_delete() -> None:
@@ -188,7 +191,7 @@ def test_delete() -> None:
188191

189192

190193
@pytest.mark.skipif(
191-
run_integration_tests is False,
194+
not run_integration_tests,
192195
reason="MariaDB instance required for integration tests",
193196
)
194197
def test_delete_nodes() -> None:
@@ -212,7 +215,26 @@ def test_delete_nodes() -> None:
212215

213216

214217
@pytest.mark.skipif(
215-
run_integration_tests is False,
218+
not run_integration_tests,
219+
reason="MariaDB instance required for integration tests",
220+
)
221+
def test_count() -> None:
222+
vector_store.add(TEST_NODES)
223+
assert vector_store.count() == 3
224+
225+
226+
@pytest.mark.skipif(
227+
not run_integration_tests,
228+
reason="MariaDB instance required for integration tests",
229+
)
230+
def test_drop() -> None:
231+
vector_store.add(TEST_NODES)
232+
vector_store.drop()
233+
assert vector_store.count() == 0
234+
235+
236+
@pytest.mark.skipif(
237+
not run_integration_tests,
216238
reason="MariaDB instance required for integration tests",
217239
)
218240
def test_clear() -> None:

0 commit comments

Comments
 (0)