Skip to content

Commit b00c272

Browse files
committed
fix batch gathering for multi threading +publish
1 parent 31a195c commit b00c272

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

cpp/deglib/include/builder.h

+9-9
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ class EvenRegularGraphBuilder {
358358
/**
359359
* Extend the graph with a new vertex. Find good existing vertex to which this new vertex gets connected.
360360
*/
361-
void extendGraph(std::vector<BuilderAddTask>& add_tasks) {
361+
void extendGraph(const std::vector<BuilderAddTask>& add_tasks) {
362362
auto& graph = this->graph_;
363363

364364
// for computing distances to neighbors not in the result queue
@@ -399,14 +399,15 @@ class EvenRegularGraphBuilder {
399399
} else {
400400
const auto remaining_add_tasks = std::vector<BuilderAddTask>(add_tasks.begin() + index, add_tasks.end());
401401

402-
auto batchExtendGraphKnownLID = [&](const std::vector<BuilderAddTask>& add_tasks, size_t task_index) {
402+
auto batchExtendGraphKnownLID = [&](const std::vector<BuilderAddTask>& tasks, size_t task_index) {
403403
const auto start_index = task_index * this->extend_thread_task_size;
404-
const auto end_index = std::min(add_tasks.size(), (task_index+1) * this->extend_thread_task_size);
405-
for (size_t i = start_index; i < end_index; i++)
406-
extendGraphKnownLID(add_tasks[i]);
404+
const auto end_index = std::min(tasks.size(), (task_index+1) * this->extend_thread_task_size);
405+
for (size_t i = start_index; i < end_index; i++)
406+
extendGraphKnownLID(tasks[i]);
407407
};
408408

409-
parallel_for(0, remaining_add_tasks.size(), this->extend_thread_count, [&] (size_t task_index) {
409+
size_t task_count = (remaining_add_tasks.size() / extend_thread_task_size) + ((remaining_add_tasks.size() % extend_thread_task_size != 0) ? 1 : 0); // +1, if n_queries % batch_size != 0
410+
parallel_for(0, task_count, extend_thread_count, [&] (size_t task_index) {
410411
batchExtendGraphKnownLID(remaining_add_tasks, task_index);
411412
});
412413
}
@@ -1370,7 +1371,7 @@ class EvenRegularGraphBuilder {
13701371
// create batches
13711372
auto batch = std::vector<BuilderAddTask>();
13721373
batch.reserve(this->extend_batch_size);
1373-
while(batch.size() < this->extend_batch_size && this->new_entry_queue_.front().manipulation_index < del_task_manipulation_index) {
1374+
while(this->new_entry_queue_.size() > 0 && batch.size() < this->extend_batch_size && this->new_entry_queue_.front().manipulation_index < del_task_manipulation_index) {
13741375
batch.push_back(std::move(this->new_entry_queue_.front()));
13751376
this->new_entry_queue_.pop_front();
13761377
}
@@ -1475,8 +1476,7 @@ void optimze_edges(deglib::graph::MutableGraph& graph, const uint8_t k_opt, cons
14751476
auto connected = deglib::analysis::check_graph_connectivity(graph);
14761477

14771478
auto duration = duration_ms / 1000;
1478-
// fmt::print("{:7} step, {:5}s, AEW: {:4.2f}, {} connected, {}\n", status.step, duration, avg_edge_weight, connected ? "" : "not", valid_weights ? "valid" : "invalid");
1479-
std::cout << std::setw(7) << status.step << " step, " << std::setw(5) << duration << "s, AEW: " << std::fixed << std::setprecision(2) << std::setw(4) << avg_edge_weight << ", " << (connected ? "" : "not") << " connected, " << (valid_weights ? "valid" : "invalid") << "\n";
1479+
std::cout << std::setw(7) << status.step << " step, " << std::setw(5) << duration << "s, AEW: " << std::fixed << std::setprecision(2) << std::setw(4) << avg_edge_weight << ", " << (connected ? "" : "not") << " connected, " << (valid_weights ? "valid" : "invalid") << "\n";
14801480
start = std::chrono::steady_clock::now();
14811481
}
14821482

python/src/deg_cpp/deglib_cpp.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ std::tuple<py::array_t<uint32_t>, py::array_t<float>> graph_search_wrapper(
149149
search_one_query(graph, query_index, query_info, entry_vertex_indices, k, max_distance_computation_count, eps, result_indices_ptr, result_distances_ptr);
150150
}
151151
} else {
152-
size_t n_batches = (n_queries / batch_size) + (n_queries % batch_size); // +1, if n_queries % batch_size != 0
152+
size_t n_batches = (n_queries / batch_size) + ((n_queries % batch_size != 0) ? 1 : 0); // +1, if n_queries % batch_size != 0
153153
parallel_for(0, n_batches, threads, [&] (size_t batch_index, size_t thread_id) {
154154
// search_one_query(graph, query_index, query_info, entry_vertex_indices, k, max_distance_computation_count, eps, result_indices_ptr, result_distances_ptr);
155155
search_batch_of_queries(graph, batch_index, batch_size, query_info, entry_vertex_indices, k, max_distance_computation_count, eps, result_indices_ptr, result_distances_ptr);

0 commit comments

Comments
 (0)