6
6
7
7
template <typename CoefficientPrecision, typename CoordinatePrecision = htool::underlying_type<CoefficientPrecision>>
8
8
class VirtualLocalToLocalOperatorPython : public htool ::VirtualLocalToLocalOperator<CoefficientPrecision> {
9
- const Cluster<CoordinatePrecision> &m_target_cluster ;
10
- const Cluster<CoordinatePrecision> &m_source_cluster ;
9
+ LocalRenumbering m_local_target_renumbering ;
10
+ LocalRenumbering m_local_source_renumbering ;
11
11
12
12
public:
13
- VirtualLocalToLocalOperatorPython (const Cluster<CoordinatePrecision> &target_cluster, const Cluster<CoordinatePrecision> &source_cluster ) : m_target_cluster(target_cluster ), m_source_cluster(source_cluster ) {}
13
+ VirtualLocalToLocalOperatorPython (LocalRenumbering local_target_renumbering, LocalRenumbering local_source_renumbering ) : m_local_target_renumbering(local_target_renumbering ), m_local_source_renumbering(local_source_renumbering ) {}
14
14
15
15
void add_vector_product (char trans, CoefficientPrecision alpha, const CoefficientPrecision *const in, CoefficientPrecision beta, CoefficientPrecision *const out) const override {
16
- py::array_t <CoefficientPrecision> input (std::array<long int , 1 >{trans == ' N' ? m_source_cluster .get_size () : m_target_cluster .get_size ()}, in, py::capsule (in));
17
- py::array_t <CoefficientPrecision> output (std::array<long int , 1 >{trans == ' N' ? m_target_cluster .get_size () : m_source_cluster .get_size ()}, out, py::capsule (out));
16
+ py::array_t <CoefficientPrecision> input (std::array<long int , 1 >{trans == ' N' ? m_local_source_renumbering .get_size () : m_local_target_renumbering .get_size ()}, in, py::capsule (in));
17
+ py::array_t <CoefficientPrecision> output (std::array<long int , 1 >{trans == ' N' ? m_local_target_renumbering .get_size () : m_local_source_renumbering .get_size ()}, out, py::capsule (out));
18
18
19
19
local_add_vector_product (trans, alpha, input, beta, output);
20
20
}
21
21
void add_matrix_product_row_major (char trans, CoefficientPrecision alpha, const CoefficientPrecision *const in, CoefficientPrecision beta, CoefficientPrecision *const out, int mu) const override {
22
- py::array_t <CoefficientPrecision, py::array::c_style> input (std::array<long int , 2 >{trans == ' N' ? m_source_cluster .get_size () : m_target_cluster .get_size (), mu}, in, py::capsule (in));
23
- py::array_t <CoefficientPrecision, py::array::c_style> output (std::array<long int , 2 >{trans == ' N' ? m_target_cluster .get_size () : m_source_cluster .get_size (), mu}, out, py::capsule (out));
22
+ py::array_t <CoefficientPrecision, py::array::c_style> input (std::array<long int , 2 >{trans == ' N' ? m_local_source_renumbering .get_size () : m_local_target_renumbering .get_size (), mu}, in, py::capsule (in));
23
+ py::array_t <CoefficientPrecision, py::array::c_style> output (std::array<long int , 2 >{trans == ' N' ? m_local_target_renumbering .get_size () : m_local_source_renumbering .get_size (), mu}, out, py::capsule (out));
24
24
25
25
local_add_matrix_product_row_major (trans, alpha, input, beta, output);
26
26
}
27
27
28
28
void add_sub_matrix_product_to_local (const CoefficientPrecision *const in, CoefficientPrecision *const out, int mu, int offset, int size) const override {
29
- int source_offset = m_source_cluster.get_offset ();
30
- int source_size = m_source_cluster.get_size ();
31
- bool is_output_null = ((offset + size) < source_offset) || (source_offset + source_size < offset);
32
- if (!is_output_null) {
33
- int temp_offset = std::max (offset, source_offset);
34
- const CoefficientPrecision *const temp_in = (offset < source_offset) ? in + source_offset - offset : in;
35
- int temp_size = (size + offset <= source_size + source_offset) ? size - std::max (source_offset - offset, 0 ) : size - std::max (source_offset - offset, 0 ) - (size + offset - source_offset - source_size);
36
-
37
- if (temp_offset == source_offset && temp_size == source_size)
38
- add_matrix_product_row_major (' N' , 1 , temp_in, 1 , out, mu);
39
- else {
40
- std::vector<CoefficientPrecision> extension_by_zero (source_size * mu);
29
+ int source_offset = m_local_source_renumbering.get_offset ();
30
+ int source_size = m_local_source_renumbering.get_size ();
31
+
32
+ int source_end = source_size+source_offset;
33
+ int end = size+offset;
34
+
35
+ int temp_offset = std::max (offset,source_offset);
36
+ int temp_end = std::min (source_end,end);
37
+
38
+ bool is_output_null = temp_end-temp_offset<=0 ? true :false ;
39
+ if (offset == source_offset && temp_end == source_end){
40
+ add_matrix_product_row_major (' N' , 1 , in, 1 , out, mu);
41
+ }
42
+ else {
43
+ std::vector<CoefficientPrecision> extension_by_zero (source_size * mu);
44
+ if (!is_output_null){
45
+ const CoefficientPrecision *const temp_in = in + temp_offset-offset;
46
+ int temp_size = temp_end-temp_offset;
41
47
std::copy_n (temp_in, temp_size * mu, extension_by_zero.data ());
42
- add_matrix_product_row_major (' N' , 1 , extension_by_zero.data (), 1 , out, mu);
43
48
}
49
+ add_matrix_product_row_major (' N' , 1 , extension_by_zero.data (), 1 , out, mu);
44
50
}
51
+
45
52
}
46
53
47
54
// Customization point called by add_vector_product with numpy views of the raw buffers.
// Pure virtual: every concrete subclass has to provide it.
// NOTE(review): presumably expected to update `out` in place as alpha * op(A) * in + beta * out,
// with op selected by `trans` — confirm against the htool base-class contract.
virtual void local_add_vector_product (char trans, CoefficientPrecision alpha, const py::array_t <CoefficientPrecision> &in, CoefficientPrecision beta, py::array_t <CoefficientPrecision> &out) const = 0; // LCOV_EXCL_LINE
@@ -59,7 +66,7 @@ class PyVirtualLocalToLocalOperator : public VirtualLocalToLocalOperatorPython<C
59
66
PYBIND11_OVERRIDE_PURE (
60
67
void , /* Return type */
61
68
VirtualLocalToLocalOperatorPython<CoefficientPrecision>, /* Parent class */
62
- add_vector_product , /* Name of function in C++ (must match Python name) */
69
+ local_add_vector_product , /* Name of function in C++ (must match Python name) */
63
70
trans,
64
71
alpha,
65
72
in,
@@ -71,7 +78,7 @@ class PyVirtualLocalToLocalOperator : public VirtualLocalToLocalOperatorPython<C
71
78
PYBIND11_OVERRIDE_PURE (
72
79
void , /* Return type */
73
80
VirtualLocalToLocalOperatorPython<CoefficientPrecision>, /* Parent class */
74
- add_matrix_product_row_major , /* Name of function in C++ (must match Python name) */
81
+ local_add_matrix_product_row_major , /* Name of function in C++ (must match Python name) */
75
82
trans,
76
83
alpha,
77
84
in,
@@ -87,10 +94,10 @@ void declare_virtual_local_to_local_operator(py::module &m, const std::string &c
87
94
py::class_<BaseClass>(m, (base_class_name).c_str ());
88
95
89
96
using Class = VirtualLocalToLocalOperatorPython<CoefficientPrecision>;
90
- py::class_<Class, PyVirtualLocalToLocalOperator<CoefficientPrecision>, BaseClass > py_class (m, className.c_str ());
91
- py_class.def (py::init<const Cluster<CoordinatePrecision> &, const Cluster<CoordinatePrecision> & >());
92
- py_class.def (" local_add_vector_product" , &Class::add_vector_product , py::arg (" trans" ), py::arg (" alpha" ), py::arg (" in" ).noconvert (true ), py::arg (" beta" ), py::arg (" out" ).noconvert (true ));
93
- py_class.def (" local_add_matrix_product_row_major" , &Class::add_matrix_product_row_major );
97
+ py::class_<Class, BaseClass, PyVirtualLocalToLocalOperator<CoefficientPrecision>> py_class (m, className.c_str ());
98
+ py_class.def (py::init<LocalRenumbering, LocalRenumbering >());
99
+ py_class.def (" local_add_vector_product" , &Class::local_add_vector_product , py::arg (" trans" ), py::arg (" alpha" ), py::arg (" in" ).noconvert (true ), py::arg (" beta" ), py::arg (" out" ).noconvert (true ));
100
+ py_class.def (" local_add_matrix_product_row_major" , &Class::local_add_matrix_product_row_major );
94
101
}
95
102
96
103
#endif
0 commit comments