Skip to content

Commit 2299f88

Browse files
rfsaliev authored and github-actions[bot] committed
[SVS] Add TieredSVSIndex flow tests. (#707)
* Draft SVS Tiered flow test
* Make tests pass with the new tiered flow
* Code review s1e1
* Code review s1e2

(cherry picked from commit 4135eff)
1 parent 48ac2ce commit 2299f88

File tree

3 files changed

+703
-3
lines changed

3 files changed

+703
-3
lines changed

src/python_bindings/bindings.cpp

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -594,10 +594,15 @@ class PyTiered_SVSIndex : public PyTieredIndex {
594594
explicit PyTiered_SVSIndex(const SVSParams &svs_params,
595595
const TieredSVSParams &tiered_svs_params, size_t buffer_limit) {
596596

597-
// Create primaryIndexParams and specific params for hnsw tiered index.
597+
// Create primaryIndexParams and specific params for svs tiered index.
598598
VecSimParams primary_index_params = {.algo = VecSimAlgo_SVS,
599599
.algoParams = {.svsParams = svs_params}};
600600

601+
if (primary_index_params.algoParams.svsParams.num_threads == 0) {
602+
primary_index_params.algoParams.svsParams.num_threads =
603+
this->mock_thread_pool.thread_pool_size; // Use the mock thread pool size as default
604+
}
605+
601606
auto tiered_params = this->getTieredIndexParams(buffer_limit);
602607
tiered_params.primaryIndexParams = &primary_index_params;
603608
tiered_params.specificParams.tieredSVSParams = tiered_svs_params;
@@ -611,6 +616,10 @@ class PyTiered_SVSIndex : public PyTieredIndex {
611616
// Set the created tiered index in the index external context.
612617
this->mock_thread_pool.ctx->index_strong_ref = this->index;
613618
}
619+
620+
size_t SVSLabelCount() {
621+
return this->index->debugInfo().tieredInfo.backendCommonInfo.indexLabelCount;
622+
}
614623
};
615624
#endif
616625

@@ -672,10 +681,13 @@ PYBIND11_MODULE(VecSim, m) {
672681

673682
py::enum_<VecSimSvsQuantBits>(m, "VecSimSvsQuantBits")
674683
.value("VecSimSvsQuant_NONE", VecSimSvsQuant_NONE)
675-
.value("VecSimSvsQuant_8", VecSimSvsQuant_8)
684+
.value("VecSimSvsQuant_Scalar", VecSimSvsQuant_Scalar)
676685
.value("VecSimSvsQuant_4", VecSimSvsQuant_4)
686+
.value("VecSimSvsQuant_8", VecSimSvsQuant_8)
677687
.value("VecSimSvsQuant_4x4", VecSimSvsQuant_4x4)
678688
.value("VecSimSvsQuant_4x8", VecSimSvsQuant_4x8)
689+
.value("VecSimSvsQuant_4x8_LeanVec", VecSimSvsQuant_4x8_LeanVec)
690+
.value("VecSimSvsQuant_8x8_LeanVec", VecSimSvsQuant_8x8_LeanVec)
679691
.export_values();
680692

681693
py::class_<SVSParams>(m, "SVSParams")
@@ -705,6 +717,7 @@ PYBIND11_MODULE(VecSim, m) {
705717
py::class_<TieredSVSParams>(m, "TieredSVSParams")
706718
.def(py::init())
707719
.def_readwrite("trainingTriggerThreshold", &TieredSVSParams::trainingTriggerThreshold)
720+
.def_readwrite("updateTriggerThreshold", &TieredSVSParams::updateTriggerThreshold)
708721
.def_readwrite("updateJobWaitTime", &TieredSVSParams::updateJobWaitTime);
709722

710723
py::class_<AlgoParams>(m, "AlgoParams")
@@ -799,7 +812,9 @@ PYBIND11_MODULE(VecSim, m) {
799812
size_t flat_buffer_size = DEFAULT_BLOCK_SIZE) {
800813
return new PyTiered_SVSIndex(svs_params, tiered_svs_params, flat_buffer_size);
801814
}),
802-
py::arg("svs_params"), py::arg("tiered_svs_params"), py::arg("flat_buffer_size"));
815+
py::arg("svs_params"), py::arg("tiered_svs_params"),
816+
py::arg("flat_buffer_size") = DEFAULT_BLOCK_SIZE)
817+
.def("svs_label_count", &PyTiered_SVSIndex::SVSLabelCount);
803818
#endif
804819

805820
py::class_<PyBatchIterator>(m, "BatchIterator")

tests/flow/common.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,28 @@
1313
import math
1414
from ml_dtypes import bfloat16
1515

16+
# alpha = 0 means the default value, which is 1.2 for L2 and 0.9 for other metrics.
17+
def create_svs_params (dim, num_elements, data_type, metric, quantBits = VecSimSvsQuant_NONE, alpha = 0,
18+
graph_max_degree = 32, construction_window_size = 200, search_window_size = 10,
19+
max_candidate_pool_size = 0, prune_to = 0, epsilon = 0.01, num_threads = 0):
20+
svs_params = SVSParams()
21+
22+
svs_params.dim = dim
23+
svs_params.type = data_type
24+
svs_params.metric = metric
25+
svs_params.quantBits = quantBits
26+
svs_params.alpha = alpha
27+
svs_params.graph_max_degree = graph_max_degree
28+
svs_params.construction_window_size = construction_window_size
29+
svs_params.max_candidate_pool_size = max_candidate_pool_size
30+
svs_params.prune_to = prune_to
31+
svs_params.use_search_history = VecSimOption_AUTO # VecSimOption_AUTO means use the default value
32+
svs_params.search_window_size = search_window_size
33+
svs_params.epsilon = epsilon
34+
svs_params.num_threads = num_threads
35+
36+
return svs_params
37+
1638
def create_hnsw_params(dim, num_elements, metric, data_type, ef_construction=200, m=16, ef_runtime=10, epsilon=0.01,
1739
is_multi=False):
1840
hnsw_params = HNSWParams()

0 commit comments

Comments
 (0)