Skip to content

Commit

Permalink
Add filtering to python for ivf_flat (#664)
Browse files Browse the repository at this point in the history
Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #664
  • Loading branch information
benfred authored Feb 7, 2025
1 parent 904051e commit ab597ec
Show file tree
Hide file tree
Showing 11 changed files with 182 additions and 109 deletions.
7 changes: 6 additions & 1 deletion cpp/include/cuvs/neighbors/ivf_flat.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <cuvs/core/c_api.h>
#include <cuvs/distance/distance.h>
#include <cuvs/neighbors/common.h>
#include <dlpack/dlpack.h>
#include <stdbool.h>
#include <stdint.h>
Expand Down Expand Up @@ -267,13 +268,17 @@ cuvsError_t cuvsIvfFlatBuild(cuvsResources_t res,
* @param[in] queries DLManagedTensor* queries dataset to search
* @param[out] neighbors DLManagedTensor* output `k` neighbors for queries
* @param[out] distances DLManagedTensor* output `k` distances for queries
* @param[in] filter cuvsFilter input filter that can be used
to filter queries and neighbors based on the given bitset.
*/
cuvsError_t cuvsIvfFlatSearch(cuvsResources_t res,
cuvsIvfFlatSearchParams_t search_params,
cuvsIvfFlatIndex_t index,
DLManagedTensor* queries,
DLManagedTensor* neighbors,
DLManagedTensor* distances);
DLManagedTensor* distances,
cuvsFilter filter);

/**
* @}
*/
Expand Down
36 changes: 29 additions & 7 deletions cpp/src/neighbors/ivf_flat_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ void _search(cuvsResources_t res,
cuvsIvfFlatIndex index,
DLManagedTensor* queries_tensor,
DLManagedTensor* neighbors_tensor,
DLManagedTensor* distances_tensor)
DLManagedTensor* distances_tensor,
cuvsFilter filter)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index_ptr = reinterpret_cast<cuvs::neighbors::ivf_flat::index<T, IdxT>*>(index.addr);
Expand All @@ -82,8 +83,27 @@ void _search(cuvsResources_t res,
auto neighbors_mds = cuvs::core::from_dlpack<neighbors_mdspan_type>(neighbors_tensor);
auto distances_mds = cuvs::core::from_dlpack<distances_mdspan_type>(distances_tensor);

cuvs::neighbors::ivf_flat::search(
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds);
if (filter.type == NO_FILTER) {
cuvs::neighbors::ivf_flat::search(
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds);
} else if (filter.type == BITSET) {
using filter_mdspan_type = raft::device_vector_view<std::uint32_t, int64_t, raft::row_major>;
auto removed_indices_tensor = reinterpret_cast<DLManagedTensor*>(filter.addr);
auto removed_indices = cuvs::core::from_dlpack<filter_mdspan_type>(removed_indices_tensor);
cuvs::core::bitset_view<std::uint32_t, int64_t> removed_indices_bitset(removed_indices,
index_ptr->size());
auto bitset_filter_obj = cuvs::neighbors::filtering::bitset_filter(removed_indices_bitset);
cuvs::neighbors::ivf_flat::search(*res_ptr,
search_params,
*index_ptr,
queries_mds,
neighbors_mds,
distances_mds,
bitset_filter_obj);

} else {
RAFT_FAIL("Unsupported filter type: BITMAP");
}
}

template <typename T, typename IdxT>
Expand Down Expand Up @@ -179,7 +199,9 @@ extern "C" cuvsError_t cuvsIvfFlatSearch(cuvsResources_t res,
cuvsIvfFlatIndex_t index_c_ptr,
DLManagedTensor* queries_tensor,
DLManagedTensor* neighbors_tensor,
DLManagedTensor* distances_tensor)
DLManagedTensor* distances_tensor,
cuvsFilter filter)

{
return cuvs::core::translate_exceptions([=] {
auto queries = queries_tensor->dl_tensor;
Expand All @@ -203,13 +225,13 @@ extern "C" cuvsError_t cuvsIvfFlatSearch(cuvsResources_t res,

if (queries.dtype.code == kDLFloat && queries.dtype.bits == 32) {
_search<float, int64_t>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter);
} else if (queries.dtype.code == kDLInt && queries.dtype.bits == 8) {
_search<int8_t, int64_t>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter);
} else if (queries.dtype.code == kDLUInt && queries.dtype.bits == 8) {
_search<uint8_t, int64_t>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter);
} else {
RAFT_FAIL("Unsupported queries DLtensor dtype: %d and bits: %d",
queries.dtype.code,
Expand Down
6 changes: 5 additions & 1 deletion cpp/tests/neighbors/run_ivf_flat_c.c
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,16 @@ void run_ivf_flat(int64_t n_rows,
distances_tensor.dl_tensor.shape = distances_shape;
distances_tensor.dl_tensor.strides = NULL;

cuvsFilter filter;
filter.type = NO_FILTER;
filter.addr = (uintptr_t)NULL;

// search index
cuvsIvfFlatSearchParams_t search_params;
cuvsIvfFlatSearchParamsCreate(&search_params);
search_params->n_probes = n_probes;
cuvsIvfFlatSearch(
res, search_params, index, &queries_tensor, &neighbors_tensor, &distances_tensor);
res, search_params, index, &queries_tensor, &neighbors_tensor, &distances_tensor, filter);

// de-allocate index and res
cuvsIvfFlatSearchParamsDestroy(search_params);
Expand Down
11 changes: 9 additions & 2 deletions examples/c/src/ivf_flat_c_example.c
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,12 @@ void ivf_flat_build_search_simple(cuvsResources_t *res, DLManagedTensor * datase
search_params->n_probes = 50;

// Search the `index` built using `ivfFlatBuild`
cuvsFilter filter;
filter.type = NO_FILTER;
filter.addr = (uintptr_t)NULL;

cuvsError_t search_status = cuvsIvfFlatSearch(*res, search_params, index,
queries_tensor, &neighbors_tensor, &distances_tensor);
queries_tensor, &neighbors_tensor, &distances_tensor, filter);
if (build_status != CUVS_SUCCESS) {
printf("%s.\n", cuvsGetLastErrorText());
}
Expand Down Expand Up @@ -165,8 +169,11 @@ void ivf_flat_build_extend_search(cuvsResources_t *res, DLManagedTensor * trains
search_params->n_probes = 10;

// Search the `index` built using `ivfFlatBuild`
cuvsFilter filter;
filter.type = NO_FILTER;
filter.addr = (uintptr_t)NULL;
cuvsError_t search_status = cuvsIvfFlatSearch(*res, search_params, index,
queries_tensor, &neighbors_tensor, &distances_tensor);
queries_tensor, &neighbors_tensor, &distances_tensor, filter);
if (search_status != CUVS_SUCCESS) {
printf("%s.\n", cuvsGetLastErrorText());
exit(-1);
Expand Down
6 changes: 5 additions & 1 deletion go/ivf_flat/ivf_flat.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ func SearchIndex[T any](Resources cuvs.Resource, params *SearchParams, index *Iv
if !index.trained {
return errors.New("index needs to be built before calling search")
}
prefilter := C.cuvsFilter{
addr: 0,
_type: C.NO_FILTER,
}

return cuvs.CheckCuvs(cuvs.CuvsError(C.cuvsIvfFlatSearch(C.cuvsResources_t(Resources.Resource), params.params, index.index, (*C.DLManagedTensor)(unsafe.Pointer(queries.C_tensor)), (*C.DLManagedTensor)(unsafe.Pointer(neighbors.C_tensor)), (*C.DLManagedTensor)(unsafe.Pointer(distances.C_tensor)))))
return cuvs.CheckCuvs(cuvs.CuvsError(C.cuvsIvfFlatSearch(C.cuvsResources_t(Resources.Resource), params.params, index.index, (*C.DLManagedTensor)(unsafe.Pointer(queries.C_tensor)), (*C.DLManagedTensor)(unsafe.Pointer(neighbors.C_tensor)), (*C.DLManagedTensor)(unsafe.Pointer(distances.C_tensor)), prefilter)))
}
4 changes: 3 additions & 1 deletion python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ from libcpp cimport bool
from cuvs.common.c_api cimport cuvsError_t, cuvsResources_t
from cuvs.common.cydlpack cimport DLDataType, DLManagedTensor
from cuvs.distance_type cimport cuvsDistanceType
from cuvs.neighbors.filters.filters cimport cuvsFilter


cdef extern from "cuvs/neighbors/ivf_flat.h" nogil:
Expand Down Expand Up @@ -71,7 +72,8 @@ cdef extern from "cuvs/neighbors/ivf_flat.h" nogil:
cuvsIvfFlatIndex_t index,
DLManagedTensor* queries,
DLManagedTensor* neighbors,
DLManagedTensor* distances) except +
DLManagedTensor* distances,
cuvsFilter filter) except +

cuvsError_t cuvsIvfFlatSerialize(cuvsResources_t res,
const char * filename,
Expand Down
12 changes: 10 additions & 2 deletions python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ from pylibraft.common.interruptible import cuda_interruptible

from cuvs.distance import DISTANCE_TYPES
from cuvs.neighbors.common import _check_input_array
from cuvs.neighbors.filters import no_filter

from libc.stdint cimport (
int8_t,
Expand Down Expand Up @@ -274,7 +275,8 @@ def search(SearchParams search_params,
k,
neighbors=None,
distances=None,
resources=None):
resources=None,
filter=None):
"""
Find the k nearest neighbors for each query.
Expand All @@ -293,6 +295,8 @@ def search(SearchParams search_params,
distances : Optional CUDA array interface compliant matrix shape
(n_queries, k) If supplied, the distances to the
neighbors will be written here in-place. (default None)
filter: Optional cuvs.neighbors.cuvsFilter can be used to filter
neighbors based on a given bitset. (default None)
{resources_docstring}
Examples
Expand Down Expand Up @@ -339,6 +343,9 @@ def search(SearchParams search_params,
_check_input_array(distances_cai, [np.dtype('float32')],
exp_rows=n_queries, exp_cols=k)

if filter is None:
filter = no_filter()

cdef cuvsIvfFlatSearchParams* params = search_params.params
cdef cuvsError_t search_status
cdef cydlpack.DLManagedTensor* queries_dlpack = \
Expand All @@ -356,7 +363,8 @@ def search(SearchParams search_params,
index.index,
queries_dlpack,
neighbors_dlpack,
distances_dlpack
distances_dlpack,
filter.prefilter
))

return (distances, neighbors)
Expand Down
92 changes: 91 additions & 1 deletion python/cuvs/cuvs/tests/ann_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Copyright (c) 2023-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,10 @@
# limitations under the License.

import numpy as np
from pylibraft.common import device_ndarray
from sklearn.neighbors import NearestNeighbors

from cuvs.neighbors import filters


def generate_data(shape, dtype):
Expand All @@ -33,3 +37,89 @@ def calc_recall(ann_idx, true_nn_idx):
n += np.intersect1d(ann_idx[i, :], true_nn_idx[i, :]).size
recall = n / ann_idx.size
return recall


def create_sparse_bitset(n_size, sparsity):
bits_per_uint32 = 32
num_bits = n_size
num_uint32s = (num_bits + bits_per_uint32 - 1) // bits_per_uint32
num_ones = int(num_bits * sparsity)

array = np.zeros(num_uint32s, dtype=np.uint32)
indices = np.random.choice(num_bits, num_ones, replace=False)

for index in indices:
i = index // bits_per_uint32
bit_position = index % bits_per_uint32
array[i] |= 1 << bit_position

return array


def run_filtered_search_test(
search_module,
sparsity,
n_rows=10000,
n_cols=10,
n_queries=10,
k=10,
):
dataset = generate_data((n_rows, n_cols), np.float32)
queries = generate_data((n_queries, n_cols), np.float32)

bitset = create_sparse_bitset(n_rows, sparsity)

dataset_device = device_ndarray(dataset)
queries_device = device_ndarray(queries)
bitset_device = device_ndarray(bitset)

build_params = search_module.IndexParams()
index = search_module.build(build_params, dataset_device)

filter_ = filters.from_bitset(bitset_device)

search_params = search_module.SearchParams()
ret_distances, ret_indices = search_module.search(
search_params,
index,
queries_device,
k,
filter=filter_,
)

# Convert bitset to bool array for validation
bitset_as_uint8 = bitset.view(np.uint8)
bool_filter = np.unpackbits(bitset_as_uint8)
bool_filter = bool_filter.reshape(-1, 4, 8)
bool_filter = np.flip(bool_filter, axis=2)
bool_filter = bool_filter.reshape(-1)[:n_rows]
bool_filter = np.logical_not(bool_filter) # Flip so True means filtered

# Get filtered dataset for reference calculation
non_filtered_mask = ~bool_filter
filtered_dataset = dataset[non_filtered_mask]

nn_skl = NearestNeighbors(
n_neighbors=k, algorithm="brute", metric="euclidean"
)
nn_skl.fit(filtered_dataset)
skl_idx = nn_skl.kneighbors(queries, return_distance=False)

actual_indices = ret_indices.copy_to_host()

filtered_idx_map = (
np.cumsum(~bool_filter) - 1
) # -1 because cumsum starts at 1

# Map ANN indices to filtered space
mapped_actual_indices = np.take(
filtered_idx_map, actual_indices, mode="clip"
)

filtered_indices = np.where(bool_filter)[0]
for i in range(n_queries):
assert not np.intersect1d(filtered_indices, actual_indices[i]).size

recall = calc_recall(mapped_actual_indices, skl_idx)

assert recall > 0.7
Loading

0 comments on commit ab597ec

Please sign in to comment.