5 | 5 | #include <SeQuant/core/expr.hpp> |
6 | 6 | #include <SeQuant/core/hash.hpp> |
7 | 7 | #include <SeQuant/core/optimize/optimize.hpp> |
| 8 | +#include <SeQuant/core/optimize/single_term.hpp> |
| 9 | +#include <SeQuant/core/optimize/sum.hpp> |
8 | 10 | #include <SeQuant/core/utility/indices.hpp> |
9 | 11 | #include <SeQuant/core/utility/macros.hpp> |
10 | 12 |
14 | 16 | #include <range/v3/view/transform.hpp> |
15 | 17 | #include <range/v3/view/view.hpp> |
16 | 18 |
17 | | -#include <algorithm> |
18 | 19 | #include <cstddef> |
19 | | -#include <memory> |
20 | | -#include <stack> |
| 20 | +#include <type_traits> |
21 | 21 | #include <utility> |
22 | | -#include <vector> |
23 | 22 |
24 | 23 | namespace sequant { |
25 | 24 |
26 | | -class Tensor; |
27 | | - |
28 | 25 | namespace opt { |
29 | 26 |
30 | | -ExprPtr tail_factor(ExprPtr const& expr) noexcept { |
31 | | - if (expr->is<Tensor>()) |
32 | | - return expr->clone(); |
33 | | - |
34 | | - else if (expr->is<Product>()) { |
35 | | - auto scalar = expr->as<Product>().scalar(); |
36 | | - if (scalar == 1 && expr->size() == 2) { |
37 | | - // product with |
38 | | - // -single factor that is a tensor |
39 | | - // -scalar is just 1 |
40 | | - // will not be formed because of this block |
41 | | - return expr->at(1); |
42 | | - } |
43 | | - auto facs = ranges::views::tail(*expr); |
44 | | - return ex<Product>(Product{scalar, ranges::begin(facs), ranges::end(facs)}); |
45 | | - } else { |
46 | | - // sum |
47 | | - auto summands = *expr | ranges::views::transform( |
48 | | - [](auto const& x) { return tail_factor(x); }); |
49 | | - return ex<Sum>(Sum{ranges::begin(summands), ranges::end(summands)}); |
50 | | - } |
51 | | -} |
52 | | - |
53 | | -void pull_scalar(ExprPtr expr) noexcept { |
54 | | - if (!expr->is<Product>()) return; |
55 | | - auto& prod = expr->as<Product>(); |
56 | | - |
57 | | - auto scal = prod.scalar(); |
58 | | - for (auto&& x : *expr) |
59 | | - if (x->is<Product>()) { |
60 | | - auto& p = x->as<Product>(); |
61 | | - scal *= p.scalar(); |
62 | | - p.scale(1 / p.scalar()); |
63 | | - } |
64 | | - |
65 | | - prod.scale(1 / prod.scalar()); |
66 | | - prod.scale(scal); |
67 | | -} |
68 | | - |
69 | | -bool has_only_single_atom(const ExprPtr& term) { |
70 | | - if (term->is_atom()) { |
71 | | - return true; |
72 | | - } |
73 | | - |
74 | | - // Recursively check that all elements in the expression tree have only a |
75 | | - // single element in them. At this point this means checking for Sum or |
76 | | - // Product objects that only have a single addend or factor respectively. |
77 | | - return term->size() == 1 && has_only_single_atom(*term->begin()); |
78 | | -} |
79 | | - |
80 | | -container::vector<container::vector<size_t>> clusters(Sum const& expr) { |
81 | | - using ranges::views::tail; |
| 27 | +/// Optimizes an expression for lower evaluation cost.
| 28 | +/// \param expr Expression to be optimized. |
| 29 | +/// \param idx2size An invocable object that maps an Index object to its size.
| 30 | +/// \param reorder_sum If true, the summands are reordered so that terms with |
| 31 | +/// common sub-expressions appear closer to each other. |
| 32 | +/// \return Optimized expression for lower evaluation cost. |
| 33 | +template <typename IdxToSize, typename = std::enable_if_t<std::is_invocable_r_v< |
| 34 | + size_t, IdxToSize, const Index&>>> |
| 35 | +ExprPtr optimize(ExprPtr const& expr, IdxToSize const& idx2size, |
| 36 | + bool reorder_sum) { |
82 | 37 | using ranges::views::transform; |
83 | | - using hash_t = size_t; |
84 | | - using pos_t = size_t; |
85 | | - using stack_t = std::stack<pos_t, container::vector<pos_t>>; |
86 | | - |
87 | | - container::map<hash_t, container::set<pos_t>> positions; |
88 | | - { |
89 | | - pos_t pos = 0; |
90 | | - auto visitor = [&positions, &pos](auto const& n) { |
91 | | - auto h = hash::value(*n); |
92 | | - if (auto&& found = positions.find(h); found != positions.end()) { |
93 | | - found->second.emplace(pos); |
94 | | - } else { |
95 | | - positions.emplace(h, decltype(positions)::mapped_type{pos}); |
96 | | - } |
97 | | - }; |
98 | | - |
99 | | - for (auto const& term : expr) { |
100 | | - auto const node = binarize(term); |
101 | | - if (has_only_single_atom(term)) { |
102 | | - visitor(node); |
103 | | - } else { |
104 | | - node.visit_internal(visitor); |
105 | | - } |
106 | | - ++pos; |
107 | | - } |
108 | | - } |
109 | | - |
110 | | - container::map<pos_t, container::vector<pos_t>> connections; |
111 | | - { |
112 | | - for (auto const& [_, v] : positions) { |
113 | | - auto const v0 = ranges::front(v); |
114 | | - auto const v_ = ranges::views::tail(v) | |
115 | | - ranges::to<decltype(connections)::mapped_type>; |
116 | | - if (auto&& found = connections.find(v0); found != connections.end()) |
117 | | - for (auto p : v_) found->second.push_back(p); |
118 | | - else |
119 | | - connections.emplace(v0, v_); |
120 | | - } |
121 | | - } |
122 | | - positions.clear(); |
123 | | - |
124 | | - container::vector<container::vector<pos_t>> result; |
125 | | - { |
126 | | - container::set<pos_t> visited; |
127 | | - for (auto k : connections | ranges::views::keys) |
128 | | - if (!visited.contains(k)) { |
129 | | - stack_t dfs_stack; |
130 | | - dfs_stack.push(k); |
131 | | - container::vector<pos_t> clstr; |
132 | | - while (!dfs_stack.empty()) { |
133 | | - auto p = dfs_stack.top(); |
134 | | - dfs_stack.pop(); |
135 | | - if (!visited.contains(p)) { |
136 | | - clstr.push_back(p); |
137 | | - visited.emplace(p); |
138 | | - } |
139 | | - if (auto&& found = connections.find(p); found != connections.end()) |
140 | | - for (auto p_ : ranges::views::reverse(found->second)) |
141 | | - dfs_stack.push(p_); |
| 38 | + if (expr->is<Product>()) { |
| 39 | + if (ranges::all_of(*expr, [](auto&& x) { |
| 40 | + return x->template is<Tensor>() || x->template is<Variable>(); |
| 41 | + })) |
| 42 | + return opt::single_term_opt(expr->as<Product>(), idx2size); |
| 43 | + else { |
| 44 | + auto const& prod = expr->as<Product>(); |
| 45 | + |
| 46 | + container::svector<ExprPtr> non_tensors(prod.size()); |
| 47 | + container::svector<ExprPtr> new_factors; |
| 48 | + |
| 49 | + for (auto i = 0; i < prod.size(); ++i) { |
| 50 | + auto&& f = prod.factor(i); |
| 51 | + if (f.is<Tensor>() || f.is<Variable>()) |
| 52 | + new_factors.emplace_back(f); |
| 53 | + else { |
| 54 | + non_tensors[i] = f; |
| 55 | + auto target_idxs = get_unique_indices(f); |
| 56 | + new_factors.emplace_back( |
| 57 | + ex<Tensor>(L"I_" + std::to_wstring(i), bra(target_idxs.bra), |
| 58 | + ket(target_idxs.ket), aux(target_idxs.aux))); |
142 | 59 | } |
143 | | - result.emplace_back(std::move(clstr)); |
144 | 60 | } |
145 | | - } |
146 | | - return result; |
147 | | -} |
148 | 61 |
149 | | -Sum reorder(Sum const& sum) { |
150 | | - Sum result; |
| 62 | + auto result = opt::single_term_opt( |
| 63 | + Product(prod.scalar(), new_factors, Product::Flatten::No), idx2size); |
151 | 64 |
152 | | - for (auto const& clstr : clusters(sum)) |
153 | | - for (auto p : clstr) result.append(sum.at(p)); |
| 65 | + auto replacer = [&non_tensors](ExprPtr& out) { |
| 66 | + if (!out->is<Tensor>()) return; |
| 67 | + auto const& tnsr = out->as<Tensor>(); |
| 68 | + auto&& label = tnsr.label(); |
| 69 | + if (label.at(0) == L'I' && label.at(1) == L'_') { |
| 70 | + size_t suffix = std::stoi(std::wstring(label.data() + 2)); |
| 71 | + out = non_tensors[suffix].clone(); |
| 72 | + } |
| 73 | + }; |
154 | 74 |
155 | | - SEQUANT_ASSERT(result.size() == sum.size()); |
156 | | - return result; |
| 75 | + result->visit(replacer, /* atoms_only = */ true); |
| 76 | + return result; |
| 77 | + } |
| 78 | + } else if (expr->is<Sum>()) { |
| 79 | + auto smands = *expr | transform([&idx2size](auto&& s) { |
| 80 | + return optimize(s, idx2size, /* reorder_sum= */ false); |
| 81 | + }) | ranges::to_vector; |
| 82 | + auto sum = Sum{smands.begin(), smands.end()}; |
| 83 | + return reorder_sum ? ex<Sum>(opt::reorder(sum)) : ex<Sum>(std::move(sum)); |
| 84 | + } else |
| 85 | + return expr->clone(); |
157 | 86 | } |
158 | 87 |
159 | 88 | } // namespace opt |
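
For readers of this change: a minimal calling sketch for the new `optimize` overload added above. The wrapper name, the flat index extent, and the include paths are illustrative assumptions, not part of this commit; the expression itself would come from elsewhere (e.g. a parsed or derived working equation).

```cpp
#include <SeQuant/core/index.hpp>
#include <SeQuant/core/optimize/optimize.hpp>  // path per the includes in this diff; adjust as needed

#include <cstddef>

// Hypothetical driver: give every index the same extent so that only the
// contraction topology decides the factorization order.
sequant::ExprPtr optimize_flat(sequant::ExprPtr const& expr) {
  auto idx2size = [](sequant::Index const&) -> std::size_t { return 10; };
  // reorder_sum = true additionally groups summands that share common
  // sub-expressions (see opt::reorder).
  return sequant::opt::optimize(expr, idx2size, /* reorder_sum = */ true);
}
```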
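The Product branch above handles factors that are not plain tensors or variables by standing them in with positional placeholder tensors labeled `I_<position>`, running single-term optimization on the flattened product, and then splicing the original sub-expressions back in via the replacer visitor. Below is a self-contained toy sketch of that pattern, with `std::string` standing in for `ExprPtr`; none of it is SeQuant API.

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // A "product" whose middle factor is not a plain tensor (here: a
  // parenthesized sum), analogous to the non-Tensor/non-Variable factors
  // handled in the branch above.
  std::vector<std::string> factors = {"t{a1;i1}", "(f{i1;i2} + g{i1;i2})",
                                      "t{a2;i2}"};

  // Step 1: swap non-atomic factors for positional placeholders, keeping the
  // originals keyed by their position.
  std::vector<std::string> originals(factors.size());
  std::vector<std::string> flat;
  for (std::size_t i = 0; i < factors.size(); ++i) {
    if (factors[i].front() == '(') {
      originals[i] = factors[i];
      flat.push_back("I_" + std::to_string(i));
    } else {
      flat.push_back(factors[i]);
    }
  }

  // Step 2: "optimize" the flattened product; a plain sort stands in for
  // single_term_opt's reordering of contractions.
  std::sort(flat.begin(), flat.end());

  // Step 3: splice the original sub-expressions back in by reading the
  // placeholder's positional suffix (the role of the replacer visitor above).
  for (auto& f : flat)
    if (f.rfind("I_", 0) == 0) f = originals[std::stoul(f.substr(2))];

  for (auto const& f : flat) std::cout << f << "  ";
  std::cout << '\n';  // prints: (f{i1;i2} + g{i1;i2})  t{a1;i1}  t{a2;i2}
}
```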
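The summand reordering delegated here to `opt::reorder` (previously implemented by the `clusters`/`reorder` helpers removed in this commit) groups terms that share a common sub-expression so they end up adjacent in the sum. Below is a rough, self-contained illustration of that clustering idea, with strings standing in for the hashed binary-tree nodes; it mirrors the DFS over first-occurrence connections, not the exact SeQuant implementation.

```cpp
#include <cstddef>
#include <iostream>
#include <map>
#include <set>
#include <stack>
#include <string>
#include <vector>

int main() {
  // Fingerprints of the sub-expressions appearing in each summand
  // (index = position of the summand in the sum).
  std::vector<std::set<std::string>> terms = {
      {"t1*f"}, {"t2*g"}, {"t1*f", "t2*v"}, {"t3*w"}, {"t2*g", "t3*w"}};

  // fingerprint -> positions of the summands that contain it.
  std::map<std::string, std::set<std::size_t>> positions;
  for (std::size_t i = 0; i < terms.size(); ++i)
    for (auto const& fp : terms[i]) positions[fp].insert(i);

  // first occurrence -> later occurrences sharing a fingerprint.
  std::map<std::size_t, std::vector<std::size_t>> connections;
  for (auto const& entry : positions) {
    auto const& pos = entry.second;
    auto it = pos.begin();
    std::size_t first = *it++;
    for (; it != pos.end(); ++it) connections[first].push_back(*it);
  }

  // Depth-first walk over the connections emits each cluster contiguously.
  std::set<std::size_t> visited;
  std::vector<std::size_t> order;
  for (std::size_t i = 0; i < terms.size(); ++i) {
    std::stack<std::size_t> dfs;
    dfs.push(i);
    while (!dfs.empty()) {
      std::size_t p = dfs.top();
      dfs.pop();
      if (!visited.insert(p).second) continue;
      order.push_back(p);
      if (auto found = connections.find(p); found != connections.end())
        for (auto q : found->second) dfs.push(q);
    }
  }

  for (auto p : order) std::cout << p << ' ';
  std::cout << '\n';  // prints: 0 2 1 4 3
}
```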