Skip to content

Commit b2c788c

Browse files
committed
Disentangle optimize implementation
1 parent 028c4fa commit b2c788c

6 files changed

Lines changed: 463 additions & 511 deletions

File tree

CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,11 +422,14 @@ set(SeQuant_eval_src
422422

423423
# Optimize sources (depends on eval)
424424
set(SeQuant_optimize_src
425-
SeQuant/core/optimize/optimize.hpp
426425
SeQuant/core/optimize/common_subexpression_elimination.hpp
427426
SeQuant/core/optimize/fusion.cpp
428427
SeQuant/core/optimize/fusion.hpp
429428
SeQuant/core/optimize/optimize.cpp
429+
SeQuant/core/optimize/optimize.hpp
430+
SeQuant/core/optimize/single_term.hpp
431+
SeQuant/core/optimize/sum.cpp
432+
SeQuant/core/optimize/sum.hpp
430433
)
431434

432435
# Export sources (depends on eval)

SeQuant/core/optimize/optimize.cpp

Lines changed: 56 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include <SeQuant/core/expr.hpp>
66
#include <SeQuant/core/hash.hpp>
77
#include <SeQuant/core/optimize/optimize.hpp>
8+
#include <SeQuant/core/optimize/single_term.hpp>
9+
#include <SeQuant/core/optimize/sum.hpp>
810
#include <SeQuant/core/utility/indices.hpp>
911
#include <SeQuant/core/utility/macros.hpp>
1012

@@ -14,146 +16,73 @@
1416
#include <range/v3/view/transform.hpp>
1517
#include <range/v3/view/view.hpp>
1618

17-
#include <algorithm>
1819
#include <cstddef>
19-
#include <memory>
20-
#include <stack>
20+
#include <type_traits>
2121
#include <utility>
22-
#include <vector>
2322

2423
namespace sequant {
2524

26-
class Tensor;
27-
2825
namespace opt {
2926

30-
ExprPtr tail_factor(ExprPtr const& expr) noexcept {
31-
if (expr->is<Tensor>())
32-
return expr->clone();
33-
34-
else if (expr->is<Product>()) {
35-
auto scalar = expr->as<Product>().scalar();
36-
if (scalar == 1 && expr->size() == 2) {
37-
// product with
38-
// -single factor that is a tensor
39-
// -scalar is just 1
40-
// will not be formed because of this block
41-
return expr->at(1);
42-
}
43-
auto facs = ranges::views::tail(*expr);
44-
return ex<Product>(Product{scalar, ranges::begin(facs), ranges::end(facs)});
45-
} else {
46-
// sum
47-
auto summands = *expr | ranges::views::transform(
48-
[](auto const& x) { return tail_factor(x); });
49-
return ex<Sum>(Sum{ranges::begin(summands), ranges::end(summands)});
50-
}
51-
}
52-
53-
void pull_scalar(ExprPtr expr) noexcept {
54-
if (!expr->is<Product>()) return;
55-
auto& prod = expr->as<Product>();
56-
57-
auto scal = prod.scalar();
58-
for (auto&& x : *expr)
59-
if (x->is<Product>()) {
60-
auto& p = x->as<Product>();
61-
scal *= p.scalar();
62-
p.scale(1 / p.scalar());
63-
}
64-
65-
prod.scale(1 / prod.scalar());
66-
prod.scale(scal);
67-
}
68-
69-
bool has_only_single_atom(const ExprPtr& term) {
70-
if (term->is_atom()) {
71-
return true;
72-
}
73-
74-
// Recursively check that all elements in the expression tree have only a
75-
// single element in them. At this point this means checking for Sum or
76-
// Product objects that only have a single addend or factor respectively.
77-
return term->size() == 1 && has_only_single_atom(*term->begin());
78-
}
79-
80-
container::vector<container::vector<size_t>> clusters(Sum const& expr) {
81-
using ranges::views::tail;
27+
///
28+
/// \param expr Expression to be optimized.
29+
/// \param idxsz An invocable object that maps an Index object to size.
30+
/// \param reorder_sum If true, the summands are reordered so that terms with
31+
/// common sub-expressions appear closer to each other.
32+
/// \return Optimized expression for lower evaluation cost.
33+
template <typename IdxToSize, typename = std::enable_if_t<std::is_invocable_r_v<
34+
size_t, IdxToSize, const Index&>>>
35+
ExprPtr optimize(ExprPtr const& expr, IdxToSize const& idx2size,
36+
bool reorder_sum) {
8237
using ranges::views::transform;
83-
using hash_t = size_t;
84-
using pos_t = size_t;
85-
using stack_t = std::stack<pos_t, container::vector<pos_t>>;
86-
87-
container::map<hash_t, container::set<pos_t>> positions;
88-
{
89-
pos_t pos = 0;
90-
auto visitor = [&positions, &pos](auto const& n) {
91-
auto h = hash::value(*n);
92-
if (auto&& found = positions.find(h); found != positions.end()) {
93-
found->second.emplace(pos);
94-
} else {
95-
positions.emplace(h, decltype(positions)::mapped_type{pos});
96-
}
97-
};
98-
99-
for (auto const& term : expr) {
100-
auto const node = binarize(term);
101-
if (has_only_single_atom(term)) {
102-
visitor(node);
103-
} else {
104-
node.visit_internal(visitor);
105-
}
106-
++pos;
107-
}
108-
}
109-
110-
container::map<pos_t, container::vector<pos_t>> connections;
111-
{
112-
for (auto const& [_, v] : positions) {
113-
auto const v0 = ranges::front(v);
114-
auto const v_ = ranges::views::tail(v) |
115-
ranges::to<decltype(connections)::mapped_type>;
116-
if (auto&& found = connections.find(v0); found != connections.end())
117-
for (auto p : v_) found->second.push_back(p);
118-
else
119-
connections.emplace(v0, v_);
120-
}
121-
}
122-
positions.clear();
123-
124-
container::vector<container::vector<pos_t>> result;
125-
{
126-
container::set<pos_t> visited;
127-
for (auto k : connections | ranges::views::keys)
128-
if (!visited.contains(k)) {
129-
stack_t dfs_stack;
130-
dfs_stack.push(k);
131-
container::vector<pos_t> clstr;
132-
while (!dfs_stack.empty()) {
133-
auto p = dfs_stack.top();
134-
dfs_stack.pop();
135-
if (!visited.contains(p)) {
136-
clstr.push_back(p);
137-
visited.emplace(p);
138-
}
139-
if (auto&& found = connections.find(p); found != connections.end())
140-
for (auto p_ : ranges::views::reverse(found->second))
141-
dfs_stack.push(p_);
38+
if (expr->is<Product>()) {
39+
if (ranges::all_of(*expr, [](auto&& x) {
40+
return x->template is<Tensor>() || x->template is<Variable>();
41+
}))
42+
return opt::single_term_opt(expr->as<Product>(), idx2size);
43+
else {
44+
auto const& prod = expr->as<Product>();
45+
46+
container::svector<ExprPtr> non_tensors(prod.size());
47+
container::svector<ExprPtr> new_factors;
48+
49+
for (auto i = 0; i < prod.size(); ++i) {
50+
auto&& f = prod.factor(i);
51+
if (f.is<Tensor>() || f.is<Variable>())
52+
new_factors.emplace_back(f);
53+
else {
54+
non_tensors[i] = f;
55+
auto target_idxs = get_unique_indices(f);
56+
new_factors.emplace_back(
57+
ex<Tensor>(L"I_" + std::to_wstring(i), bra(target_idxs.bra),
58+
ket(target_idxs.ket), aux(target_idxs.aux)));
14259
}
143-
result.emplace_back(std::move(clstr));
14460
}
145-
}
146-
return result;
147-
}
14861

149-
Sum reorder(Sum const& sum) {
150-
Sum result;
62+
auto result = opt::single_term_opt(
63+
Product(prod.scalar(), new_factors, Product::Flatten::No), idx2size);
15164

152-
for (auto const& clstr : clusters(sum))
153-
for (auto p : clstr) result.append(sum.at(p));
65+
auto replacer = [&non_tensors](ExprPtr& out) {
66+
if (!out->is<Tensor>()) return;
67+
auto const& tnsr = out->as<Tensor>();
68+
auto&& label = tnsr.label();
69+
if (label.at(0) == L'I' && label.at(1) == L'_') {
70+
size_t suffix = std::stoi(std::wstring(label.data() + 2));
71+
out = non_tensors[suffix].clone();
72+
}
73+
};
15474

155-
SEQUANT_ASSERT(result.size() == sum.size());
156-
return result;
75+
result->visit(replacer, /* atoms_only = */ true);
76+
return result;
77+
}
78+
} else if (expr->is<Sum>()) {
79+
auto smands = *expr | transform([&idx2size](auto&& s) {
80+
return optimize(s, idx2size, /* reorder_sum= */ false);
81+
}) | ranges::to_vector;
82+
auto sum = Sum{smands.begin(), smands.end()};
83+
return reorder_sum ? ex<Sum>(opt::reorder(sum)) : ex<Sum>(std::move(sum));
84+
} else
85+
return expr->clone();
15786
}
15887

15988
} // namespace opt

0 commit comments

Comments (0)