Skip to content

Commit

Permalink
Merge pull request #490 from aprokop/dbscan_quality_of_life_improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop authored Mar 13, 2021
2 parents c4fe03c + f16863e commit 41dff19
Showing 1 changed file with 47 additions and 2 deletions.
49 changes: 47 additions & 2 deletions examples/dbscan/dbscan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,39 @@ std::vector<ArborX::Point> parsePoints(std::string const &filename,
return v;
}

// Returns a random subset of `data` containing `num_samples` points, kept in
// the same relative order as they appear in `data`.
//
// The subset is drawn with selection sampling (Knuth, TAOCP vol. 2,
// Algorithm S): the input is scanned once, and each point is accepted with
// probability (remaining samples needed) / (remaining points to scan).
//
// `num_samples` is clamped to [0, data.size()], so the result never contains
// padding points that were not present in the input, and a negative request
// yields an empty vector instead of an allocation failure.
//
// Randomness comes from rand(); this function does not seed the generator.
std::vector<ArborX::Point> sampleData(std::vector<ArborX::Point> const &data,
                                      int num_samples)
{
  auto const N = static_cast<int>(data.size());
  // Clamp the request so the output is sized to what can actually be drawn.
  auto const M = (num_samples < 0 ? 0 : (num_samples > N ? N : num_samples));

  std::vector<ArborX::Point> sampled_data(M);

  // Knuth algorithm
  for (int in = 0, im = 0; in < N && im < M; ++in)
  {
    int rn = N - in; // points not yet examined (including the current one)
    int rm = M - im; // samples still to be selected
    // Once rn == rm every remaining point is accepted, which guarantees that
    // exactly M points are selected by the end of the scan.
    if (rand() % rn < rm)
      sampled_data[im++] = data[in];
  }
  return sampled_data;
}

// Dumps the cluster labels to `filename` in a raw binary layout:
//   - one `int` holding the number of labels,
//   - followed by that many `int` label values.
//
// The labels are first copied into host memory so they can be written
// regardless of the memory space they live in. The output stream is checked
// with ARBORX_ASSERT right after it is opened.
template <typename MemorySpace>
void writeLabelsData(std::string const &filename,
                     Kokkos::View<int *, MemorySpace> labels)
{
  std::ofstream file(filename, std::ofstream::binary);
  ARBORX_ASSERT(file.good());

  // Host-accessible copy of the labels.
  auto host_labels =
      Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace{}, labels);

  int const num_labels = static_cast<int>(host_labels.size());

  // Header: number of entries.
  file.write(reinterpret_cast<char const *>(&num_labels), sizeof(int));
  // Payload: the label values themselves.
  file.write(reinterpret_cast<char const *>(host_labels.data()),
             sizeof(int) * num_labels);
}

template <typename... P, typename T>
auto vec2view(std::vector<T> const &in, std::string const &label = "")
{
Expand Down Expand Up @@ -298,6 +331,8 @@ int main(int argc, char *argv[])
int cluster_min_size;
int core_min_size;
int max_num_points;
int num_samples;
std::string filename_labels;

bpo::options_description desc("Allowed options");
// clang-format off
Expand All @@ -310,6 +345,8 @@ int main(int argc, char *argv[])
( "cluster-min-size", bpo::value<int>(&cluster_min_size)->default_value(2), "minimum cluster size")
( "core-min-size", bpo::value<int>(&core_min_size)->default_value(2), "DBSCAN min_pts")
( "verify", bpo::bool_switch(&verify)->default_value(false), "verify connected components")
( "samples", bpo::value<int>(&num_samples)->default_value(-1), "number of samples" )
( "labels", bpo::value<std::string>(&filename_labels)->default_value(""), "clutering results output" )
( "print-dbscan-timers", bpo::bool_switch(&print_dbscan_timers)->default_value(false), "print dbscan timers")
( "output-sizes-and-centers", bpo::bool_switch(&print_sizes_centers)->default_value(false), "print cluster sizes and centers")
;
Expand All @@ -330,13 +367,18 @@ int main(int argc, char *argv[])
printf("cluster min size : %d\n", cluster_min_size);
printf("filename : %s [%s, max_pts = %d]\n", filename.c_str(),
(binary ? "binary" : "text"), max_num_points);
printf("filename [labels] : %s [binary]\n", filename_labels.c_str());
printf("samples : %d\n", num_samples);
printf("verify : %s\n", (verify ? "true" : "false"));
printf("print timers : %s\n", (print_dbscan_timers ? "true" : "false"));
printf("output centers : %s\n", (print_sizes_centers ? "true" : "false"));

// read in data
auto const primitives = vec2view<MemorySpace>(
parsePoints(filename, binary, max_num_points), "primitives");
std::vector<ArborX::Point> data =
parsePoints(filename, binary, max_num_points);
if (num_samples > 0 && num_samples < (int)data.size())
data = sampleData(data, num_samples);
auto const primitives = vec2view<MemorySpace>(data, "primitives");

ExecutionSpace exec_space;

Expand Down Expand Up @@ -381,6 +423,9 @@ int main(int argc, char *argv[])
printf("Verification %s\n", (passed ? "passed" : "failed"));
}

if (!filename_labels.empty())
writeLabelsData(filename_labels, labels);

if (print_sizes_centers)
printClusterSizesAndCenters(exec_space, primitives, cluster_indices,
cluster_offset);
Expand Down

0 comments on commit 41dff19

Please sign in to comment.