Skip to content

Commit add42fb

Browse files
fixup
1 parent 7ff9337 commit add42fb

File tree

4 files changed

+41 additions, -27 deletions (lines changed)

src/htool/local_operator/local_operator.hpp

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,9 @@ class RestrictedGlobalToLocalOperatorPython : public htool::RestrictedGlobalToLo
3232
// virtual void local_add_vector_product_symmetric(char trans, CoefficientPrecision alpha, const std::vector<CoefficientPrecision> &in, CoefficientPrecision beta, std::vector<CoefficientPrecision> &out) const = 0; // LCOV_EXCL_LINE
3333

3434
virtual void add_matrix_product_row_major(char trans, CoefficientPrecision alpha, const py::array_t<CoefficientPrecision, py::array::c_style> &in, CoefficientPrecision beta, py::array_t<CoefficientPrecision, py::array::c_style> &out) const = 0; // LCOV_EXCL_LINE
35+
36+
LocalRenumbering get_local_target_renumbering()const {return this->m_local_target_renumbering;}
37+
LocalRenumbering get_local_source_renumbering()const {return this->m_local_source_renumbering;}
3538
};
3639

3740
template <typename CoefficientPrecision>
@@ -79,6 +82,8 @@ void declare_global_to_local_operator(py::module &m, const std::string &class_na
7982
py_class.def(py::init<LocalRenumbering, LocalRenumbering, bool, bool>());
8083
py_class.def("add_vector_product", &Class::add_vector_product, py::arg("trans"), py::arg("alpha"), py::arg("in").noconvert(true), py::arg("beta"), py::arg("out").noconvert(true));
8184
py_class.def("add_matrix_product_row_major", &Class::add_matrix_product_row_major);
85+
py_class.def_property_readonly("local_target_renumbering",&Class::get_local_target_renumbering);
86+
py_class.def_property_readonly("local_source_renumbering",&Class::get_local_source_renumbering);
8287
}
8388

8489
#endif

src/htool/local_operator/local_renumbering.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ void declare_local_renumbering(py::module &m, const std::string &className) {
1515
py_class.def(py::init<const Cluster<CoordinatePrecision> &>());
1616
py_class.def_property_readonly("offset", &Class::get_offset);
1717
py_class.def_property_readonly("size", &Class::get_size);
18+
py_class.def_property_readonly("global_size", &Class::get_global_size);
19+
py_class.def_property_readonly("is_stable", &Class::is_stable);
1820
py_class.def_property_readonly("permutation", [](const Class &self) { return py::array_t<int>(std::array<long int, 1>{self.get_global_size()}, self.get_permutation(), py::capsule(self.get_permutation())); });
1921
}
2022

src/htool/local_operator/virtual_local_to_local_operator.hpp

Lines changed: 33 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -6,42 +6,49 @@
66

77
template <typename CoefficientPrecision, typename CoordinatePrecision = htool::underlying_type<CoefficientPrecision>>
88
class VirtualLocalToLocalOperatorPython : public htool::VirtualLocalToLocalOperator<CoefficientPrecision> {
9-
const Cluster<CoordinatePrecision> &m_target_cluster;
10-
const Cluster<CoordinatePrecision> &m_source_cluster;
9+
LocalRenumbering m_local_target_renumbering;
10+
LocalRenumbering m_local_source_renumbering;
1111

1212
public:
13-
VirtualLocalToLocalOperatorPython(const Cluster<CoordinatePrecision> &target_cluster, const Cluster<CoordinatePrecision> &source_cluster) : m_target_cluster(target_cluster), m_source_cluster(source_cluster) {}
13+
VirtualLocalToLocalOperatorPython(LocalRenumbering local_target_renumbering, LocalRenumbering local_source_renumbering) : m_local_target_renumbering(local_target_renumbering), m_local_source_renumbering(local_source_renumbering) {}
1414

1515
void add_vector_product(char trans, CoefficientPrecision alpha, const CoefficientPrecision *const in, CoefficientPrecision beta, CoefficientPrecision *const out) const override {
16-
py::array_t<CoefficientPrecision> input(std::array<long int, 1>{trans == 'N' ? m_source_cluster.get_size() : m_target_cluster.get_size()}, in, py::capsule(in));
17-
py::array_t<CoefficientPrecision> output(std::array<long int, 1>{trans == 'N' ? m_target_cluster.get_size() : m_source_cluster.get_size()}, out, py::capsule(out));
16+
py::array_t<CoefficientPrecision> input(std::array<long int, 1>{trans == 'N' ? m_local_source_renumbering.get_size() : m_local_target_renumbering.get_size()}, in, py::capsule(in));
17+
py::array_t<CoefficientPrecision> output(std::array<long int, 1>{trans == 'N' ? m_local_target_renumbering.get_size() : m_local_source_renumbering.get_size()}, out, py::capsule(out));
1818

1919
local_add_vector_product(trans, alpha, input, beta, output);
2020
}
2121
void add_matrix_product_row_major(char trans, CoefficientPrecision alpha, const CoefficientPrecision *const in, CoefficientPrecision beta, CoefficientPrecision *const out, int mu) const override {
22-
py::array_t<CoefficientPrecision, py::array::c_style> input(std::array<long int, 2>{trans == 'N' ? m_source_cluster.get_size() : m_target_cluster.get_size(), mu}, in, py::capsule(in));
23-
py::array_t<CoefficientPrecision, py::array::c_style> output(std::array<long int, 2>{trans == 'N' ? m_target_cluster.get_size() : m_source_cluster.get_size(), mu}, out, py::capsule(out));
22+
py::array_t<CoefficientPrecision, py::array::c_style> input(std::array<long int, 2>{trans == 'N' ? m_local_source_renumbering.get_size() : m_local_target_renumbering.get_size(), mu}, in, py::capsule(in));
23+
py::array_t<CoefficientPrecision, py::array::c_style> output(std::array<long int, 2>{trans == 'N' ? m_local_target_renumbering.get_size() : m_local_source_renumbering.get_size(), mu}, out, py::capsule(out));
2424

2525
local_add_matrix_product_row_major(trans, alpha, input, beta, output);
2626
}
2727

2828
void add_sub_matrix_product_to_local(const CoefficientPrecision *const in, CoefficientPrecision *const out, int mu, int offset, int size) const override {
29-
int source_offset = m_source_cluster.get_offset();
30-
int source_size = m_source_cluster.get_size();
31-
bool is_output_null = ((offset + size) < source_offset) || (source_offset + source_size < offset);
32-
if (!is_output_null) {
33-
int temp_offset = std::max(offset, source_offset);
34-
const CoefficientPrecision *const temp_in = (offset < source_offset) ? in + source_offset - offset : in;
35-
int temp_size = (size + offset <= source_size + source_offset) ? size - std::max(source_offset - offset, 0) : size - std::max(source_offset - offset, 0) - (size + offset - source_offset - source_size);
36-
37-
if (temp_offset == source_offset && temp_size == source_size)
38-
add_matrix_product_row_major('N', 1, temp_in, 1, out, mu);
39-
else {
40-
std::vector<CoefficientPrecision> extension_by_zero(source_size * mu);
29+
int source_offset = m_local_source_renumbering.get_offset();
30+
int source_size = m_local_source_renumbering.get_size();
31+
32+
int source_end = source_size+source_offset;
33+
int end = size+offset;
34+
35+
int temp_offset = std::max(offset,source_offset);
36+
int temp_end = std::min(source_end,end);
37+
38+
bool is_output_null = temp_end-temp_offset<=0 ? true:false;
39+
if (offset == source_offset && temp_end == source_end){
40+
add_matrix_product_row_major('N', 1, in, 1, out, mu);
41+
}
42+
else {
43+
std::vector<CoefficientPrecision> extension_by_zero(source_size * mu);
44+
if (!is_output_null){
45+
const CoefficientPrecision *const temp_in = in + temp_offset-offset;
46+
int temp_size = temp_end-temp_offset;
4147
std::copy_n(temp_in, temp_size * mu, extension_by_zero.data());
42-
add_matrix_product_row_major('N', 1, extension_by_zero.data(), 1, out, mu);
4348
}
49+
add_matrix_product_row_major('N', 1, extension_by_zero.data(), 1, out, mu);
4450
}
51+
4552
}
4653

4754
virtual void local_add_vector_product(char trans, CoefficientPrecision alpha, const py::array_t<CoefficientPrecision> &in, CoefficientPrecision beta, py::array_t<CoefficientPrecision> &out) const = 0; // LCOV_EXCL_LINE
@@ -59,7 +66,7 @@ class PyVirtualLocalToLocalOperator : public VirtualLocalToLocalOperatorPython<C
5966
PYBIND11_OVERRIDE_PURE(
6067
void, /* Return type */
6168
VirtualLocalToLocalOperatorPython<CoefficientPrecision>, /* Parent class */
62-
add_vector_product, /* Name of function in C++ (must match Python name) */
69+
local_add_vector_product, /* Name of function in C++ (must match Python name) */
6370
trans,
6471
alpha,
6572
in,
@@ -71,7 +78,7 @@ class PyVirtualLocalToLocalOperator : public VirtualLocalToLocalOperatorPython<C
7178
PYBIND11_OVERRIDE_PURE(
7279
void, /* Return type */
7380
VirtualLocalToLocalOperatorPython<CoefficientPrecision>, /* Parent class */
74-
add_matrix_product_row_major, /* Name of function in C++ (must match Python name) */
81+
local_add_matrix_product_row_major, /* Name of function in C++ (must match Python name) */
7582
trans,
7683
alpha,
7784
in,
@@ -87,10 +94,10 @@ void declare_virtual_local_to_local_operator(py::module &m, const std::string &c
8794
py::class_<BaseClass>(m, (base_class_name).c_str());
8895

8996
using Class = VirtualLocalToLocalOperatorPython<CoefficientPrecision>;
90-
py::class_<Class, PyVirtualLocalToLocalOperator<CoefficientPrecision>, BaseClass> py_class(m, className.c_str());
91-
py_class.def(py::init<const Cluster<CoordinatePrecision> &, const Cluster<CoordinatePrecision> &>());
92-
py_class.def("local_add_vector_product", &Class::add_vector_product, py::arg("trans"), py::arg("alpha"), py::arg("in").noconvert(true), py::arg("beta"), py::arg("out").noconvert(true));
93-
py_class.def("local_add_matrix_product_row_major", &Class::add_matrix_product_row_major);
97+
py::class_<Class, BaseClass, PyVirtualLocalToLocalOperator<CoefficientPrecision>> py_class(m, className.c_str());
98+
py_class.def(py::init<LocalRenumbering, LocalRenumbering>());
99+
py_class.def("local_add_vector_product", &Class::local_add_vector_product, py::arg("trans"), py::arg("alpha"), py::arg("in").noconvert(true), py::arg("beta"), py::arg("out").noconvert(true));
100+
py_class.def("local_add_matrix_product_row_major", &Class::local_add_matrix_product_row_major);
94101
}
95102

96103
#endif

0 commit comments

Comments
 (0)