6
6
from urllib .parse import quote_plus
7
7
8
8
import sqlalchemy
9
-
10
9
from llama_index .core .bridge .pydantic import PrivateAttr
11
10
from llama_index .core .schema import BaseNode , MetadataMode
12
11
from llama_index .core .vector_stores .types import (
@@ -52,6 +51,8 @@ class MariaDBVectorStore(BasePydanticVectorStore):
52
51
password="password",
53
52
database="vectordb",
54
53
table_name="llama_index_vectorstore",
54
+ default_m=6,
55
+ ef_search=20,
55
56
embed_dim=1536 # OpenAI embedding dimension
56
57
)
57
58
```
@@ -65,6 +66,8 @@ class MariaDBVectorStore(BasePydanticVectorStore):
65
66
table_name : str
66
67
schema_name : str
67
68
embed_dim : int
69
+ default_m : int
70
+ ef_search : int
68
71
perform_setup : bool
69
72
debug : bool
70
73
@@ -78,6 +81,8 @@ def __init__(
78
81
table_name : str ,
79
82
schema_name : str ,
80
83
embed_dim : int = 1536 ,
84
+ default_m : int = 6 ,
85
+ ef_search : int = 20 ,
81
86
perform_setup : bool = True ,
82
87
debug : bool = False ,
83
88
) -> None :
@@ -89,6 +94,8 @@ def __init__(
89
94
table_name (str): Table name.
90
95
schema_name (str): Schema name.
91
96
embed_dim (int, optional): Embedding dimensions. Defaults to 1536.
97
+ default_m (int, optional): Default M value for the vector index. Defaults to 6.
98
+ ef_search (int, optional): EF search value for the vector index. Defaults to 20.
92
99
perform_setup (bool, optional): If DB should be set up. Defaults to True.
93
100
debug (bool, optional): Debug mode. Defaults to False.
94
101
"""
@@ -98,15 +105,20 @@ def __init__(
98
105
table_name = table_name ,
99
106
schema_name = schema_name ,
100
107
embed_dim = embed_dim ,
108
+ default_m = default_m ,
109
+ ef_search = ef_search ,
101
110
perform_setup = perform_setup ,
102
111
debug = debug ,
103
112
)
104
113
114
+ self ._initialize ()
115
+
105
116
def close (self ) -> None :
106
117
if not self ._is_initialized :
107
118
return
108
119
109
120
self ._engine .dispose ()
121
+ self ._is_initialized = False
110
122
111
123
@classmethod
112
124
def class_name (cls ) -> str :
@@ -125,6 +137,8 @@ def from_params(
125
137
connection_string : Optional [Union [str , sqlalchemy .engine .URL ]] = None ,
126
138
connection_args : Optional [Dict [str , Any ]] = None ,
127
139
embed_dim : int = 1536 ,
140
+ default_m : int = 6 ,
141
+ ef_search : int = 20 ,
128
142
perform_setup : bool = True ,
129
143
debug : bool = False ,
130
144
) -> "MariaDBVectorStore" :
@@ -141,6 +155,8 @@ def from_params(
141
155
connection_string (Union[str, sqlalchemy.engine.URL]): Connection string to MariaDB DB.
142
156
connection_args (Dict[str, Any], optional): A dictionary of connection options.
143
157
embed_dim (int, optional): Embedding dimensions. Defaults to 1536.
158
+ default_m (int, optional): Default M value for the vector index. Defaults to 6.
159
+ ef_search (int, optional): EF search value for the vector index. Defaults to 20.
144
160
perform_setup (bool, optional): If DB should be set up. Defaults to True.
145
161
debug (bool, optional): Debug mode. Defaults to False.
146
162
@@ -162,6 +178,8 @@ def from_params(
162
178
table_name = table_name ,
163
179
schema_name = schema_name ,
164
180
embed_dim = embed_dim ,
181
+ default_m = default_m ,
182
+ ef_search = ef_search ,
165
183
perform_setup = perform_setup ,
166
184
debug = debug ,
167
185
)
@@ -200,8 +218,8 @@ def _create_table_if_not_exists(self) -> None:
200
218
text TEXT,
201
219
metadata JSON,
202
220
embedding VECTOR({ self .embed_dim } ) NOT NULL,
203
- INDEX ` { self . table_name } _node_id_idx` (`node_id`),
204
- VECTOR INDEX (embedding) DISTANCE=cosine
221
+ INDEX (`node_id`),
222
+ VECTOR INDEX (embedding) M= { self . default_m } DISTANCE=cosine
205
223
)
206
224
"""
207
225
connection .execute (sqlalchemy .text (stmt ))
@@ -378,6 +396,7 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
378
396
self ._initialize ()
379
397
380
398
stmt = f"""
399
+ SET STATEMENT mhnsw_ef_search={ self .ef_search } FOR
381
400
SELECT
382
401
node_id,
383
402
text,
@@ -435,6 +454,26 @@ def delete_nodes(
435
454
436
455
connection .commit ()
437
456
457
+ def count (self ) -> int :
458
+ self ._initialize ()
459
+
460
+ with self ._engine .connect () as connection :
461
+ stmt = f"""SELECT COUNT(*) FROM `{ self .table_name } `"""
462
+ result = connection .execute (sqlalchemy .text (stmt ))
463
+
464
+ return result .scalar () or 0
465
+
466
+ def drop (self ) -> None :
467
+ self ._initialize ()
468
+
469
+ with self ._engine .connect () as connection :
470
+ stmt = f"""DROP TABLE IF EXISTS `{ self .table_name } `"""
471
+ connection .execute (sqlalchemy .text (stmt ))
472
+
473
+ connection .commit ()
474
+
475
+ self .close ()
476
+
438
477
def clear (self ) -> None :
439
478
self ._initialize ()
440
479
0 commit comments