Skip to content

Commit d7784a7

Browse files
committed
Implement inspect_binsparse.cpp, implement visitor pattern for binsparse
types
1 parent 09e262a commit d7784a7

File tree

6 files changed

+267
-20
lines changed

6 files changed

+267
-20
lines changed

examples/Makefile

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,28 @@
11

2-
CXX = g++-12
2+
# = = = = = = = = = = = = = = = = = = = = = = = = = = = #
3+
# Determine the correct value for these variables for
4+
# your system and export them to your environment or
5+
# define them here.
6+
7+
CXX ?= clang++
8+
9+
# Binsparse (this should be fine unless you do an
10+
# out-of-source build)
11+
BINSPARSE_DIR ?= ../include
12+
13+
# HDF5 library location
14+
HDF5_CXXFLAGS ?= -I/opt/homebrew/Cellar/hdf5/1.14.1/include
15+
HDF5_LIBRARY_FLAGS ?= -L/opt/homebrew/Cellar/hdf5/1.14.1/lib -lhdf5_hl_cpp -lhdf5_cpp -lhdf5_hl -lhdf5
16+
17+
# = = = = = = = = = = = = = = = = = = = = = = = = = = = #
318

419
SOURCES += $(wildcard *.cpp)
520
TARGETS := $(patsubst %.cpp, %, $(SOURCES))
621

7-
BINSPARSE_DIR=../include
8-
922
CXXFLAGS = -std=c++20 -O3 -I$(BINSPARSE_DIR)
1023

11-
# Update HDF5 Flags for your HDF5 installation
12-
HDF5_CXXFLAGS ?= -I/opt/homebrew/Cellar/hdf5/1.12.2_2/include
13-
HDF5_LD_FLAGS ?= -L/opt/homebrew/Cellar/hdf5/1.12.2_2/lib -lhdf5_hl_cpp -lhdf5_cpp -lhdf5_hl -lhdf5
14-
1524
CXXFLAGS += $(HDF5_CXXFLAGS)
16-
LD_FLAGS += $(HDF5_LD_FLAGS)
25+
LD_FLAGS += $(HDF5_LIBRARY_FLAGS)
1726

1827
all: $(TARGETS)
1928

examples/convert_binsparse.cpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
#include <binsparse/binsparse.hpp>
2-
#include <grb/grb.hpp>
32
#include <iostream>
43
#include <concepts>
54
#include <complex>
65

76
template <typename T, typename I>
8-
void convert(std::string input_file, std::string output_file, std::string format, std::string comment) {
7+
void convert_to_binsparse(std::string input_file, std::string output_file, std::string format, std::string comment) {
98
if (format == "CSR") {
109
std::cout << "Reading in " << input_file << "...\n";
1110
auto x = binsparse::__detail::mmread<T, I, binsparse::__detail::csr_matrix_owning<T, I>>(input_file);
@@ -21,17 +20,17 @@ void convert(std::string input_file, std::string output_file, std::string format
2120
}
2221

2322
template <typename I>
24-
void convert(std::string input_file, std::string output_file, std::string type,
23+
void convert_to_binsparse(std::string input_file, std::string output_file, std::string type,
2524
std::string format, std::string comment) {
2625
if (type == "real") {
27-
convert<float, I>(input_file, output_file, format, comment);
26+
convert_to_binsparse<float, I>(input_file, output_file, format, comment);
2827
} else if (type == "complex") {
2928
assert(false);
30-
// convert<std::complex<float>, I>(input_file, output_file, format, comment);
29+
// convert_to_binsparse<std::complex<float>, I>(input_file, output_file, format, comment);
3130
} else if (type == "integer") {
32-
convert<int64_t, I>(input_file, output_file, format, comment);
31+
convert_to_binsparse<int64_t, I>(input_file, output_file, format, comment);
3332
} else if (type == "pattern") {
34-
convert<uint8_t, I>(input_file, output_file, format, comment);
33+
convert_to_binsparse<uint8_t, I>(input_file, output_file, format, comment);
3534
}
3635
}
3736

@@ -55,7 +54,7 @@ int main(int argc, char** argv) {
5554
c = std::toupper(c);
5655
}
5756
} else {
58-
format = "CSR";
57+
format = "COO";
5958
}
6059

6160
auto [m, n, nnz, type, comment] = binsparse::mmread_metadata(input_file);
@@ -70,13 +69,13 @@ int main(int argc, char** argv) {
7069
auto max_size = std::max({m, n, nnz});
7170

7271
if (max_size + 1 <= std::numeric_limits<uint8_t>::max()) {
73-
convert<uint8_t>(input_file, output_file, type, format, comment);
72+
convert_to_binsparse<uint8_t>(input_file, output_file, type, format, comment);
7473
} else if (max_size + 1 <= std::numeric_limits<uint16_t>::max()) {
75-
convert<uint16_t>(input_file, output_file, type, format, comment);
74+
convert_to_binsparse<uint16_t>(input_file, output_file, type, format, comment);
7675
} else if (max_size + 1 <= std::numeric_limits<uint32_t>::max()) {
77-
convert<uint32_t>(input_file, output_file, type, format, comment);
76+
convert_to_binsparse<uint32_t>(input_file, output_file, type, format, comment);
7877
} else if (max_size + 1 <= std::numeric_limits<uint64_t>::max()) {
79-
convert<uint64_t>(input_file, output_file, type, format, comment);
78+
convert_to_binsparse<uint64_t>(input_file, output_file, type, format, comment);
8079
} else {
8180
throw std::runtime_error("Error! Matrix dimensions or NNZ too large to handle.");
8281
}

examples/convert_matrixmarket.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#include <binsparse/binsparse.hpp>
2+
#include <iostream>
3+
#include <concepts>
4+
#include <complex>
5+
6+
int main(int argc, char** argv) {
7+
8+
if (argc < 2) {
9+
std::cout << "usage: ./inspect_binsparse [input_file.mtx]\n";
10+
return 1;
11+
}
12+
13+
std::string input_file(argv[1]);
14+
15+
auto metadata = binsparse::inspect(input_file);
16+
17+
std::cout << "Inspecting Binsparse v" << metadata["version"] << " file...\n";
18+
std::cout << metadata["format"] << " format matrix of dimension "
19+
<< metadata["shape"] << " with " << metadata["nnz"] << " nonzeros\n";
20+
21+
if (metadata["format"] == "COO") {
22+
auto i0 = metadata["data_types"]["indices_0"];
23+
auto i1 = metadata["data_types"]["indices_1"];
24+
auto t = metadata["data_types"]["values"];
25+
26+
binsparse::visit_label({i0, i1, t},
27+
[&]<typename I1, typename I2, typename T>(I1 i, I2 j, T v) {
28+
using I = std::conditional_t<std::numeric_limits<I1>::max() < std::numeric_limits<I2>::max(),
29+
I2, I1>;
30+
std::cout << "Reading binsparse with index and value types: "
31+
<< binsparse::type_info<I>::label() << " "
32+
<< binsparse::type_info<T>::label() << "\n";
33+
34+
auto m = binsparse::read_coo_matrix<T, I>(input_file);
35+
});
36+
} else if (metadata["format"] == "CSR") {
37+
auto i0 = metadata["data_types"]["pointers_to_1"];
38+
auto i1 = metadata["data_types"]["indices_1"];
39+
auto t = metadata["data_types"]["values"];
40+
41+
binsparse::visit_label({i0, i1, t},
42+
[&]<typename I1, typename I2, typename T>(I1 i, I2 j, T v) {
43+
using I = std::conditional_t<std::numeric_limits<I1>::max() < std::numeric_limits<I2>::max(),
44+
I2, I1>;
45+
std::cout << "Reading binsparse with index and value types: "
46+
<< binsparse::type_info<I>::label() << " "
47+
<< binsparse::type_info<T>::label() << "\n";
48+
49+
auto m = binsparse::read_csr_matrix<T, I>(input_file);
50+
});
51+
} else {
52+
assert(false);
53+
}
54+
55+
return 0;
56+
}

examples/inspect_binsparse.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#include <binsparse/binsparse.hpp>
2+
#include <iostream>
3+
#include <concepts>
4+
#include <complex>
5+
6+
int main(int argc, char** argv) {
7+
8+
if (argc < 2) {
9+
std::cout << "usage: ./inspect_binsparse [input_file.mtx]\n";
10+
return 1;
11+
}
12+
13+
std::string input_file(argv[1]);
14+
15+
auto j = binsparse::inspect(input_file);
16+
auto metadata = j["binsparse"];
17+
18+
std::cout << "Inspecting Binsparse v" << metadata["version"] << " file...\n";
19+
std::cout << metadata["format"] << " format matrix of dimension "
20+
<< metadata["shape"] << " with " << metadata["nnz"] << " nonzeros\n";
21+
22+
if (metadata["format"] == "COO") {
23+
auto i0 = metadata["data_types"]["indices_0"];
24+
auto i1 = metadata["data_types"]["indices_1"];
25+
auto t = metadata["data_types"]["values"];
26+
27+
std::cout << "Stored using index types: " << i0 << " " << i1 << std::endl;
28+
std::cout << "Value type: " << t << std::endl;
29+
} else if (metadata["format"] == "CSR") {
30+
auto i0 = metadata["data_types"]["pointers_to_1"];
31+
auto i1 = metadata["data_types"]["indices_1"];
32+
auto t = metadata["data_types"]["values"];
33+
std::cout << "Stored using index types: " << i0 << " " << i1 << std::endl;
34+
std::cout << "Value type: " << t << std::endl;
35+
} else {
36+
assert(false);
37+
}
38+
39+
std::cout << "Raw JSON:\n";
40+
std::cout << j.dump(2) << std::endl;
41+
42+
return 0;
43+
}

include/binsparse/binsparse.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,21 @@ coo_matrix<T, I> read_coo_matrix(std::string fname) {
141141
return read_coo_matrix<T, I>(fname, std::allocator<T>{});
142142
}
143143

144+
inline auto inspect(std::string fname) {
145+
H5::H5File f(fname.c_str(), H5F_ACC_RDWR);
146+
147+
auto metadata = hdf5_tools::read_dataset<char>(f, "metadata");
148+
149+
using json = nlohmann::json;
150+
auto data = json::parse(metadata);
151+
152+
if (data["binsparse"]["version"] >= 0.1) {
153+
return data;
154+
} else {
155+
assert(false);
156+
}
157+
}
158+
144159
} // end binsparse
145160

146161
#include <binsparse/c_bindings/bc_read_matrix.hpp>

include/binsparse/type_info.hpp

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <cassert>
4+
#include <type_traits>
45

56
namespace binsparse {
67

@@ -99,4 +100,128 @@ struct type_info<bool> {
99100
}
100101
};
101102

103+
namespace __detail {
104+
105+
template <typename Fn, typename... Args>
106+
requires(std::is_invocable_v<Fn, Args...>)
107+
void invoke_if_able(Fn&& fn, Args&&... args)
108+
{
109+
std::invoke(std::forward<Fn>(fn), std::forward<Args>(args)...);
110+
}
111+
112+
template <typename Fn, typename... Args>
113+
void invoke_if_able(Fn&& fn, Args&&... args) {}
114+
115+
template <typename Fn, typename... Args>
116+
void invoke_visit_fn_impl_(std::vector<std::string> type_labels, Fn&& fn, Args&&... args) {
117+
if constexpr(sizeof...(Args) <= 3) {
118+
if (type_labels.size() == 1) {
119+
auto type_label = type_labels.front();
120+
if (type_label == "uint8") {
121+
invoke_if_able(std::forward<Fn>(fn), std::uint8_t(), std::forward<Args>(args)...);
122+
} else if (type_label == "uint16") {
123+
invoke_if_able(std::forward<Fn>(fn), std::uint16_t(), std::forward<Args>(args)...);
124+
} else if (type_label == "uint32") {
125+
invoke_if_able(std::forward<Fn>(fn), std::uint32_t(), std::forward<Args>(args)...);
126+
} else if (type_label == "uint64") {
127+
invoke_if_able(std::forward<Fn>(fn), std::uint64_t(), std::forward<Args>(args)...);
128+
} else if (type_label == "int8") {
129+
invoke_if_able(std::forward<Fn>(fn), std::int8_t(), std::forward<Args>(args)...);
130+
} else if (type_label == "int16") {
131+
invoke_if_able(std::forward<Fn>(fn), std::int16_t(), std::forward<Args>(args)...);
132+
} else if (type_label == "int32") {
133+
invoke_if_able(std::forward<Fn>(fn), std::int32_t(), std::forward<Args>(args)...);
134+
} else if (type_label == "int64") {
135+
invoke_if_able(std::forward<Fn>(fn), std::int64_t(), std::forward<Args>(args)...);
136+
} else if (type_label == "float32") {
137+
invoke_if_able(std::forward<Fn>(fn), float(), std::forward<Args>(args)...);
138+
} else if (type_label == "float64") {
139+
invoke_if_able(std::forward<Fn>(fn), double(), std::forward<Args>(args)...);
140+
} else if (type_label == "bint8") {
141+
invoke_if_able(std::forward<Fn>(fn), bool(), std::forward<Args>(args)...);
142+
} else {
143+
assert(false);
144+
}
145+
} else {
146+
auto type_label = type_labels.back();
147+
type_labels.pop_back();
148+
if (type_label == "uint8") {
149+
invoke_visit_fn_impl_(type_labels, std::forward<Fn>(fn), std::uint8_t(), std::forward<Args>(args)...);
150+
} else if (type_label == "uint16") {
151+
invoke_visit_fn_impl_(type_labels, std::forward<Fn>(fn), std::uint16_t(), std::forward<Args>(args)...);
152+
} else if (type_label == "uint32") {
153+
invoke_visit_fn_impl_(type_labels, std::forward<Fn>(fn), std::uint32_t(), std::forward<Args>(args)...);
154+
} else if (type_label == "uint64") {
155+
invoke_visit_fn_impl_(type_labels, std::forward<Fn>(fn), std::uint64_t(), std::forward<Args>(args)...);
156+
} else if (type_label == "int8") {
157+
invoke_visit_fn_impl_(type_labels, std::forward<Fn>(fn), std::int8_t(), std::forward<Args>(args)...);
158+
} else if (type_label == "int16") {
159+
invoke_visit_fn_impl_(type_labels, std::forward<Fn>(fn), std::int16_t(), std::forward<Args>(args)...);
160+
} else if (type_label == "int32") {
161+
invoke_visit_fn_impl_(type_labels, std::forward<Fn>(fn), std::int32_t(), std::forward<Args>(args)...);
162+
} else if (type_label == "int64") {
163+
invoke_visit_fn_impl_(type_labels, std::forward<Fn>(fn), std::int64_t(), std::forward<Args>(args)...);
164+
} else if (type_label == "float32") {
165+
invoke_visit_fn_impl_(type_labels, std::forward<Fn>(fn), float(), std::forward<Args>(args)...);
166+
} else if (type_label == "float64") {
167+
invoke_visit_fn_impl_(type_labels, std::forward<Fn>(fn), double(), std::forward<Args>(args)...);
168+
} else if (type_label == "bint8") {
169+
invoke_visit_fn_impl_(type_labels, std::forward<Fn>(fn), bool(), std::forward<Args>(args)...);
170+
} else {
171+
assert(false);
172+
}
173+
}
174+
}
175+
}
176+
177+
/*
178+
template <typename Fn, typename... Args>
179+
void invoke_visit_fn_impl_(std::vector<std::string> type_labels, Fn&& fn, Args&&... args) {
180+
if constexpr(sizeof...(Args) < 10) {
181+
if (type_labels.size() == 1) {
182+
auto type_label = type_labels.front();
183+
if (type_label == "uint8") {
184+
invoke_if_able(std::forward<Fn>(fn), std::uint8_t(), std::forward<Args>(args)...);
185+
} else if (type_label == "uint16") {
186+
invoke_if_able(std::forward<Fn>(fn), std::uint16_t(), std::forward<Args>(args)...);
187+
} else if (type_label == "uint32") {
188+
invoke_if_able(std::forward<Fn>(fn), std::uint32_t(), std::forward<Args>(args)...);
189+
} else if (type_label == "uint64") {
190+
invoke_if_able(std::forward<Fn>(fn), std::uint64_t(), std::forward<Args>(args)...);
191+
} else if (type_label == "int8") {
192+
invoke_if_able(std::forward<Fn>(fn), std::int8_t(), std::forward<Args>(args)...);
193+
} else if (type_label == "int16") {
194+
invoke_if_able(std::forward<Fn>(fn), std::int16_t(), std::forward<Args>(args)...);
195+
} else if (type_label == "int32") {
196+
invoke_if_able(std::forward<Fn>(fn), std::int32_t(), std::forward<Args>(args)...);
197+
} else if (type_label == "int64") {
198+
invoke_if_able(std::forward<Fn>(fn), std::int64_t(), std::forward<Args>(args)...);
199+
} else if (type_label == "float32") {
200+
invoke_if_able(std::forward<Fn>(fn), float(), std::forward<Args>(args)...);
201+
} else if (type_label == "float64") {
202+
invoke_if_able(std::forward<Fn>(fn), double(), std::forward<Args>(args)...);
203+
} else if (type_label == "bint8") {
204+
invoke_if_able(std::forward<Fn>(fn), bool(), std::forward<Args>(args)...);
205+
} else {
206+
assert(false);
207+
}
208+
} else {
209+
auto label = type_labels.back();
210+
type_labels.pop_back();
211+
invoke_visit_fn_impl_({label},
212+
[=](auto&& v) {
213+
invoke_visit_fn_impl_(type_labels, fn, v, args...);
214+
});
215+
}
216+
}
217+
}
218+
*/
219+
220+
} // end __detail
221+
222+
template <typename Fn>
223+
inline void visit_label(const std::vector<std::string>& type_labels, Fn&& fn) {
224+
__detail::invoke_visit_fn_impl_(type_labels, fn);
225+
}
226+
102227
} // end binsparse

0 commit comments

Comments
 (0)