Skip to content

Commit fe6948b

Browse files
Add option to selectively force refinement of top constraints
1 parent 9f409f7 commit fe6948b

File tree

4 files changed

+104
-3
lines changed

4 files changed

+104
-3
lines changed

include/macis/asci/determinant_search.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ struct ASCISettings {
6363
int constraint_level = 2; // Up To Quints
6464
int pt2_max_constraint_level = 5;
6565
int pt2_min_constraint_level = 0;
66+
int64_t pt2_constraint_refine_force = 0;
6667
};
6768

6869
template <size_t N>

include/macis/asci/mask_constraints.hpp

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,8 @@ auto dist_constraint_general(size_t nlevels, size_t norb, size_t ns_othr,
630630
template <typename WfnType, typename ContainerType>
631631
auto gen_constraints_general(size_t nlevels, size_t norb, size_t ns_othr,
632632
size_t nd_othr, const ContainerType& unique_alpha,
633-
int world_size, size_t nlevel_min = 0) {
633+
int world_size, size_t nlevel_min = 0,
634+
int64_t nrec_min = -1) {
634635
using wfn_traits = wavefunction_traits<WfnType>;
635636
using constraint_type = alpha_constraint<wfn_traits>;
636637
using string_type = typename constraint_type::constraint_type;
@@ -671,7 +672,9 @@ auto gen_constraints_general(size_t nlevels, size_t norb, size_t ns_othr,
671672
auto constraint = constraint_type::make_triplet(t_i, t_j, t_k);
672673
constraint_sizes.emplace_back(constraint, 0ul);
673674
}
675+
674676
// Build up higher-order constraints as base if requested
677+
if(nrec_min < 0 or nrec_min >= constraint_sizes.size()) // nrec_min < 0 implies that you want all the constraints upfront
675678
for(size_t ilevel = 0; ilevel < nlevel_min; ++ilevel) {
676679
decltype(constraint_sizes) cur_constraints;
677680
cur_constraints.reserve(constraint_sizes.size() * norb);
@@ -703,14 +706,15 @@ auto gen_constraints_general(size_t nlevels, size_t norb, size_t ns_othr,
703706
// Compute histogram
704707
const auto ntrip_full = constraint_sizes.size();
705708
std::vector<atomic_wrapper> constraint_work(ntrip_full, 0ul);
709+
{
706710
global_atomic<size_t> nxtval(MPI_COMM_WORLD);
707711
#pragma omp parallel
708712
{
709713
size_t i_trip = 0;
710714
while(i_trip < ntrip_full) {
711715
i_trip = nxtval.fetch_and_add(1);
712716
if(i_trip >= ntrip_full) break;
713-
if(!(i_trip%1000)) printf("cgen %lu / %lu\n", i_trip, ntrip_full);
717+
//if(!(i_trip%1000)) printf("cgen %lu / %lu\n", i_trip, ntrip_full);
714718
auto& [constraint, __nw] = constraint_sizes[i_trip];
715719
auto& c_nw = constraint_work[i_trip];
716720
size_t nw = 0;
@@ -725,6 +729,7 @@ auto gen_constraints_general(size_t nlevels, size_t norb, size_t ns_othr,
725729
if(nw) c_nw.value.fetch_add(nw);
726730
}
727731
}
732+
} // Scope nxtval
728733

729734
std::vector<size_t> constraint_work_bare(ntrip_full);
730735
for(auto i_trip = 0; i_trip < ntrip_full; ++i_trip) {
@@ -749,6 +754,99 @@ auto gen_constraints_general(size_t nlevels, size_t norb, size_t ns_othr,
749754
0ul, [](auto s, const auto& p){ return s + p.second; });
750755
size_t local_average = total_work / world_size;
751756

757+
758+
// Manual refinement of top configurations
759+
if(nrec_min > 0 and nrec_min < constraint_sizes.size()) {
760+
761+
const size_t nleave = constraint_sizes.size() - nrec_min;
762+
std::vector<std::pair<constraint_type, size_t>> constraint_to_refine,
763+
constraint_to_leave;
764+
constraint_to_refine.reserve(nrec_min);
765+
constraint_to_refine.reserve(nleave);
766+
767+
std::copy_n(constraint_sizes.begin(), nrec_min, std::back_inserter(constraint_to_refine));
768+
std::copy_n(constraint_sizes.begin() + nrec_min, nleave,
769+
std::back_inserter(constraint_to_leave));
770+
771+
// Deallocate original array
772+
decltype(constraint_sizes)().swap(constraint_sizes);
773+
774+
// Generate refined constraints
775+
for(size_t ilevel = 0; ilevel < nlevel_min; ++ilevel) {
776+
decltype(constraint_sizes) cur_constraints;
777+
cur_constraints.reserve(constraint_to_refine.size() * norb);
778+
for(auto [c,nw] : constraint_to_refine) {
779+
const auto C_min = c.C_min();
780+
for(auto q_l = 0; q_l < C_min; ++q_l) {
781+
// Generate masks / counts
782+
string_type cn_C = c.C();
783+
cn_C.flip(q_l);
784+
string_type cn_B = c.B() >> (C_min - q_l);
785+
constraint_type c_next(cn_C, cn_B, q_l);
786+
cur_constraints.emplace_back(c_next, 0ul);
787+
}
788+
}
789+
constraint_to_refine = std::move(cur_constraints);
790+
}
791+
792+
const size_t nrefine = constraint_to_refine.size();
793+
794+
global_atomic<size_t> nxtval(MPI_COMM_WORLD);
795+
std::vector<atomic_wrapper>().swap(constraint_work);
796+
std::vector<size_t>().swap(constraint_work_bare);
797+
constraint_work.resize(nrefine, 0ul);
798+
#pragma omp parallel
799+
{
800+
size_t i_ref = 0;
801+
while(i_ref < nrefine) {
802+
i_ref = nxtval.fetch_and_add(1);
803+
if(i_ref >= nrefine) break;
804+
//if(!(i_ref%1000)) printf("cgen %lu / %lu\n", i_ref, nrefine);
805+
auto& [constraint, __nw] = constraint_to_refine[i_ref];
806+
auto& c_nw = constraint_work[i_ref];
807+
size_t nw = 0;
808+
for(const auto& alpha : unique_alpha) {
809+
if constexpr(flat_container)
810+
nw += constraint_histogram(wfn_traits::alpha_string(alpha), ns_othr,
811+
nd_othr, constraint);
812+
else
813+
nw += alpha.second * constraint_histogram(alpha.first, ns_othr,
814+
nd_othr, constraint);
815+
}
816+
if(nw) c_nw.value.fetch_add(nw);
817+
} // constraint "loop"
818+
} // OpenMP Context
819+
820+
constraint_work_bare.resize(nrefine);
821+
for(auto i_ref = 0; i_ref < nrefine; ++i_ref) {
822+
constraint_work_bare[i_ref] = constraint_work[i_ref].value.load();
823+
}
824+
allreduce(constraint_work_bare.data(), nrefine, MPI_SUM, MPI_COMM_WORLD);
825+
826+
// Copy over constraint work
827+
for(auto i_ref = 0; i_ref < nrefine; ++i_ref) {
828+
constraint_to_refine[i_ref].second = constraint_work_bare[i_ref];
829+
}
830+
831+
// Remove zeros
832+
{
833+
auto it = std::partition(constraint_to_refine.begin(), constraint_to_refine.end(),
834+
[](const auto& p) { return p.second > 0; });
835+
constraint_to_refine.erase(it, constraint_to_refine.end());
836+
}
837+
838+
// Concatenate the arrays
839+
constraint_sizes.reserve(nrefine + nleave);
840+
// Copy only the surviving refined constraints: zero-work entries were erased
// above, so the pre-erase count (nrefine) may exceed the current size and
// would read past the end of the vector.
std::copy_n(constraint_to_refine.begin(), constraint_to_refine.size(), std::back_inserter(constraint_sizes));
841+
std::copy_n(constraint_to_leave.begin(), nleave, std::back_inserter(constraint_sizes));
842+
843+
size_t tmp = std::accumulate(constraint_sizes.begin(), constraint_sizes.end(),
844+
0ul, [](auto s, const auto& p){ return s + p.second; });
845+
if(tmp != total_work) throw std::runtime_error("Incorrect Refinement");
846+
} // Selective refinement logic
847+
848+
849+
752850
#endif
753851

754852
for(size_t ilevel = 0; ilevel < nlevels; ++ilevel) {

include/macis/asci/pt2.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ double asci_pt2_constraint(ASCISettings asci_settings,
4848
logger->info(" * PT2_RESERVE_COUNT = {}", asci_settings.pt2_reserve_count);
4949
logger->info(" * PT2_CONSTRAINT_LVL_MAX = {}", asci_settings.pt2_max_constraint_level);
5050
logger->info(" * PT2_CONSTRAINT_LVL_MIN = {}", asci_settings.pt2_min_constraint_level);
51+
logger->info(" * PT2_CNSTRNT_RFNE_FORCE = {}", asci_settings.pt2_constraint_refine_force);
5152
logger->info(" * PT2_PRUNE = {}", asci_settings.pt2_prune);
5253
logger->info(" * PT2_PRECOMP_EPS = {}", asci_settings.pt2_precompute_eps);
5354
logger->info(" * PT2_BIGCON_THRESH = {}", asci_settings.pt2_bigcon_thresh);
@@ -158,7 +159,7 @@ double asci_pt2_constraint(ASCISettings asci_settings,
158159
auto constraints = gen_constraints_general<wfn_t<N>>(
159160
asci_settings.pt2_max_constraint_level, norb, n_sing_beta,
160161
n_doub_beta, uniq_alpha, world_size * omp_get_max_threads(),
161-
asci_settings.pt2_min_constraint_level);
162+
asci_settings.pt2_min_constraint_level, asci_settings.pt2_constraint_refine_force );
162163
auto gen_c_en = clock_type::now();
163164
duration_type gen_c_dur = gen_c_en - gen_c_st;
164165
logger->info(" * GEN_DUR = {:.2e} ms", gen_c_dur.count());

tests/standalone_driver.cxx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ int main(int argc, char** argv) {
224224
OPT_KEYWORD("ASCI.PT2_RESERVE_COUNT", asci_settings.pt2_reserve_count, size_t);
225225
OPT_KEYWORD("ASCI.PT2_CONSTRAINT_LVL_MAX", asci_settings.pt2_max_constraint_level, int);
226226
OPT_KEYWORD("ASCI.PT2_CONSTRAINT_LVL_MIN", asci_settings.pt2_min_constraint_level, int);
227+
OPT_KEYWORD("ASCI.PT2_CNSTRNT_RFNE_FORCE", asci_settings.pt2_constraint_refine_force,int64_t);
227228
OPT_KEYWORD("ASCI.PT2_PRUNE", asci_settings.pt2_prune, bool);
228229
OPT_KEYWORD("ASCI.PT2_PRECOMPUTE_EPS", asci_settings.pt2_precompute_eps, bool);
229230
OPT_KEYWORD("ASCI.PT2_PRECOMPUTE_IDX", asci_settings.pt2_precompute_idx, bool);

0 commit comments

Comments
 (0)