Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ANN_BENCH enhanced dataset support #624

Merged
merged 12 commits into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion cpp/bench/ann/src/common/ann_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ enum class MemoryType {
kHostMmap,
kHostPinned,
kDevice,
kManaged,
};

/** Request 2MB huge pages support for an allocation */
enum class HugePages {
/** Avoid huge pages whenever possible. */
kDisable = 0,
/** Enable huge pages if possible, ignore otherwise. */
kAsk = 1,
/** Enable huge pages if possible, warn the user otherwise. */
kRequire = 2,
/** Force enable huge pages, throw an exception if not possible. */
kDemand = 3
tfeher marked this conversation as resolved.
Show resolved Hide resolved
};

enum class Metric {
Expand Down Expand Up @@ -65,6 +78,8 @@ inline auto parse_memory_type(const std::string& memory_type) -> MemoryType
return MemoryType::kHostPinned;
} else if (memory_type == "device") {
return MemoryType::kDevice;
} else if (memory_type == "managed") {
return MemoryType::kManaged;
} else {
throw std::runtime_error("invalid memory type: '" + memory_type + "'");
}
Expand Down Expand Up @@ -130,7 +145,7 @@ class algo : public algo_base {

virtual void build(const T* dataset, size_t nrow) = 0;

virtual void set_search_param(const search_param& param) = 0;
virtual void set_search_param(const search_param& param, const void* filter_bitset) = 0;
// TODO(snanditale): this assumes that an algorithm can always return k results.
// This is not always possible.
virtual void search(const T* queries,
Expand Down
59 changes: 45 additions & 14 deletions cpp/bench/ann/src/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ void bench_search(::benchmark::State& state,
}
}
try {
a->set_search_param(*search_param);
a->set_search_param(*search_param,
dataset->filter_bitset(current_algo_props->dataset_memory_type));
} catch (const std::exception& ex) {
state.SkipWithError("An error occurred setting search parameters: " + std::string(ex.what()));
return;
Expand Down Expand Up @@ -359,13 +360,19 @@ void bench_search(::benchmark::State& state,
// Each thread calculates recall on their partition of queries.
// evaluate recall
if (dataset->max_k() >= k) {
const std::int32_t* gt = dataset->gt_set();
const std::int32_t* gt = dataset->gt_set();
const std::uint32_t* filter_bitset = dataset->filter_bitset(MemoryType::kHostMmap);
auto filter = [filter_bitset](std::int32_t i) -> bool {
if (filter_bitset == nullptr) { return true; }
auto word = filter_bitset[i >> 5];
return word & (1 << (i & 31));
};
const std::uint32_t max_k = dataset->max_k();
result_buf.transfer_data(MemoryType::kHost, current_algo_props->query_memory_type);
auto* neighbors_host = reinterpret_cast<index_type*>(result_buf.data(MemoryType::kHost));
std::size_t rows = std::min(queries_processed, query_set_size);
std::size_t match_count = 0;
std::size_t total_count = rows * static_cast<size_t>(k);
std::size_t total_count = 0;

// We go through the groundtruth with same stride as the benchmark loop.
size_t out_offset = 0;
Expand All @@ -375,22 +382,44 @@ void bench_search(::benchmark::State& state,
size_t i_orig_idx = batch_offset + i;
size_t i_out_idx = out_offset + i;
if (i_out_idx < rows) {
for (std::uint32_t j = 0; j < k; j++) {
auto act_idx = static_cast<std::int32_t>(neighbors_host[i_out_idx * k + j]);
for (std::uint32_t l = 0; l < k; l++) {
auto exp_idx = gt[i_orig_idx * max_k + l];
/* NOTE: recall correctness & filtering

In the loop below, we filter the ground truth values on-the-fly.
We need enough ground truth values to compute recall correctly though.
But the ground truth file only contains `max_k` values per row; if there are fewer valid
values than k among them, we overestimate the recall. Essentially, we compare the first
`filter_pass_count` values of the algorithm output, and this counter can be less than `k`.
In the extreme case of very high filtering rate, we may be bypassing entire rows of
results. However, this is still better than no recall estimate at all.

TODO: consider generating the filtered ground truth on-the-fly
*/
uint32_t filter_pass_count = 0;
for (std::uint32_t l = 0; l < max_k && filter_pass_count < k; l++) {
auto exp_idx = gt[i_orig_idx * max_k + l];
if (!filter(exp_idx)) { continue; }
filter_pass_count++;
for (std::uint32_t j = 0; j < k; j++) {
auto act_idx = static_cast<std::int32_t>(neighbors_host[i_out_idx * k + j]);
if (act_idx == exp_idx) {
match_count++;
break;
}
}
}
total_count += filter_pass_count;
}
}
out_offset += n_queries;
batch_offset = (batch_offset + queries_stride) % query_set_size;
}
double actual_recall = static_cast<double>(match_count) / static_cast<double>(total_count);
/* NOTE: recall in the throughput mode & filtering

When filtering is enabled, `total_count` may vary between individual threads, but we still take
the simple average across in-thread recalls. Strictly speaking, this is incorrect, but it's good
enough under the assumption that the filtering is more-or-less uniform.
*/
state.counters.insert({"Recall", {actual_recall, benchmark::Counter::kAvgThreads}});
}
}
Expand Down Expand Up @@ -515,13 +544,15 @@ void dispatch_benchmark(std::string cmdline,
auto query_file = combine_path(data_prefix, dataset_conf.query_file);
auto gt_file = dataset_conf.groundtruth_neighbors_file;
if (gt_file.has_value()) { gt_file.emplace(combine_path(data_prefix, gt_file.value())); }
auto dataset = std::make_shared<bin_dataset<T>>(dataset_conf.name,
base_file,
dataset_conf.subset_first_row,
dataset_conf.subset_size,
query_file,
dataset_conf.distance,
gt_file);
auto dataset =
std::make_shared<bench::dataset<T>>(dataset_conf.name,
base_file,
dataset_conf.subset_first_row,
dataset_conf.subset_size,
query_file,
dataset_conf.distance,
gt_file,
search_mode ? dataset_conf.filtering_rate : std::nullopt);
::benchmark::AddCustomContext("dataset", dataset_conf.name);
::benchmark::AddCustomContext("distance", dataset_conf.distance);
std::vector<configuration::index> indices = conf.get_indices();
Expand Down
9 changes: 7 additions & 2 deletions cpp/bench/ann/src/common/conf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,16 @@ class configuration {
// the range of rows is [subset_first_row, subset_first_row + subset_size)
// however, subset_size = 0 means using all rows after subset_first_row
// that is, the subset is [subset_first_row, #rows in base_file)
size_t subset_first_row{0};
size_t subset_size{0};
uint32_t subset_first_row{0};
uint32_t subset_size{0};
tfeher marked this conversation as resolved.
Show resolved Hide resolved
std::string query_file;
std::string distance;
std::optional<std::string> groundtruth_neighbors_file{std::nullopt};

// data type of input dataset, possible values ["float", "int8", "uint8"]
std::string dtype;

std::optional<double> filtering_rate{std::nullopt};
};

explicit inline configuration(std::istream& conf_stream)
Expand All @@ -74,6 +76,9 @@ class configuration {
dataset_conf_.base_file = conf.at("base_file");
dataset_conf_.query_file = conf.at("query_file");
dataset_conf_.distance = conf.at("distance");
if (conf.contains("filtering_rate")) {
dataset_conf_.filtering_rate.emplace(conf.at("filtering_rate"));
}

if (conf.contains("groundtruth_neighbors_file")) {
dataset_conf_.groundtruth_neighbors_file = conf.at("groundtruth_neighbors_file");
Expand Down
Loading
Loading