@@ -630,7 +630,8 @@ auto dist_constraint_general(size_t nlevels, size_t norb, size_t ns_othr,
630630template <typename WfnType, typename ContainerType>
631631auto gen_constraints_general (size_t nlevels, size_t norb, size_t ns_othr,
632632 size_t nd_othr, const ContainerType& unique_alpha,
633- int world_size, size_t nlevel_min = 0 ) {
633+ int world_size, size_t nlevel_min = 0 ,
634+ int64_t nrec_min = -1 ) {
634635 using wfn_traits = wavefunction_traits<WfnType>;
635636 using constraint_type = alpha_constraint<wfn_traits>;
636637 using string_type = typename constraint_type::constraint_type;
@@ -671,7 +672,9 @@ auto gen_constraints_general(size_t nlevels, size_t norb, size_t ns_othr,
671672 auto constraint = constraint_type::make_triplet (t_i, t_j, t_k);
672673 constraint_sizes.emplace_back (constraint, 0ul );
673674 }
675+
674676 // Build up higher-order constraints as base if requested
677+ if (nrec_min < 0 or nrec_min >= constraint_sizes.size ()) // nrec_min < 0 implies that you want all the constraints upfront
675678 for (size_t ilevel = 0 ; ilevel < nlevel_min; ++ilevel) {
676679 decltype (constraint_sizes) cur_constraints;
677680 cur_constraints.reserve (constraint_sizes.size () * norb);
@@ -703,14 +706,15 @@ auto gen_constraints_general(size_t nlevels, size_t norb, size_t ns_othr,
703706 // Compute histogram
704707 const auto ntrip_full = constraint_sizes.size ();
705708 std::vector<atomic_wrapper> constraint_work (ntrip_full, 0ul );
709+ {
706710 global_atomic<size_t > nxtval (MPI_COMM_WORLD);
707711 #pragma omp parallel
708712 {
709713 size_t i_trip = 0 ;
710714 while (i_trip < ntrip_full) {
711715 i_trip = nxtval.fetch_and_add (1 );
712716 if (i_trip >= ntrip_full) break ;
713- if (!(i_trip%1000 )) printf (" cgen %lu / %lu\n " , i_trip, ntrip_full);
717+ // if(!(i_trip%1000)) printf("cgen %lu / %lu\n", i_trip, ntrip_full);
714718 auto & [constraint, __nw] = constraint_sizes[i_trip];
715719 auto & c_nw = constraint_work[i_trip];
716720 size_t nw = 0 ;
@@ -725,6 +729,7 @@ auto gen_constraints_general(size_t nlevels, size_t norb, size_t ns_othr,
725729 if (nw) c_nw.value .fetch_add (nw);
726730 }
727731 }
732+ } // Scope nxtval
728733
729734 std::vector<size_t > constraint_work_bare (ntrip_full);
730735 for (auto i_trip = 0 ; i_trip < ntrip_full; ++i_trip) {
@@ -749,6 +754,99 @@ auto gen_constraints_general(size_t nlevels, size_t norb, size_t ns_othr,
749754 0ul , [](auto s, const auto & p){ return s + p.second ; });
750755 size_t local_average = total_work / world_size;
751756
757+
758+ // Manual refinement of top configurations
759+ if (nrec_min > 0 and nrec_min < constraint_sizes.size ()) {
760+
761+ const size_t nleave = constraint_sizes.size () - nrec_min;
762+ std::vector<std::pair<constraint_type, size_t >> constraint_to_refine,
763+ constraint_to_leave;
764+ constraint_to_refine.reserve (nrec_min);
765+ constraint_to_refine.reserve (nleave);
766+
767+ std::copy_n (constraint_sizes.begin (), nrec_min, std::back_inserter (constraint_to_refine));
768+ std::copy_n (constraint_sizes.begin () + nrec_min, nleave,
769+ std::back_inserter (constraint_to_leave));
770+
771+ // Deallocate original array
772+ decltype (constraint_sizes)().swap (constraint_sizes);
773+
774+ // Generate refined constraints
775+ for (size_t ilevel = 0 ; ilevel < nlevel_min; ++ilevel) {
776+ decltype (constraint_sizes) cur_constraints;
777+ cur_constraints.reserve (constraint_to_refine.size () * norb);
778+ for (auto [c,nw] : constraint_to_refine) {
779+ const auto C_min = c.C_min ();
780+ for (auto q_l = 0 ; q_l < C_min; ++q_l) {
781+ // Generate masks / counts
782+ string_type cn_C = c.C ();
783+ cn_C.flip (q_l);
784+ string_type cn_B = c.B () >> (C_min - q_l);
785+ constraint_type c_next (cn_C, cn_B, q_l);
786+ cur_constraints.emplace_back (c_next, 0ul );
787+ }
788+ }
789+ constraint_to_refine = std::move (cur_constraints);
790+ }
791+
792+ const size_t nrefine = constraint_to_refine.size ();
793+
794+ global_atomic<size_t > nxtval (MPI_COMM_WORLD);
795+ std::vector<atomic_wrapper>().swap (constraint_work);
796+ std::vector<size_t >().swap (constraint_work_bare);
797+ constraint_work.resize (nrefine, 0ul );
798+ #pragma omp parallel
799+ {
800+ size_t i_ref = 0 ;
801+ while (i_ref < nrefine) {
802+ i_ref = nxtval.fetch_and_add (1 );
803+ if (i_ref >= nrefine) break ;
804+ // if(!(i_ref%1000)) printf("cgen %lu / %lu\n", i_ref, nrefine);
805+ auto & [constraint, __nw] = constraint_to_refine[i_ref];
806+ auto & c_nw = constraint_work[i_ref];
807+ size_t nw = 0 ;
808+ for (const auto & alpha : unique_alpha) {
809+ if constexpr (flat_container)
810+ nw += constraint_histogram (wfn_traits::alpha_string (alpha), ns_othr,
811+ nd_othr, constraint);
812+ else
813+ nw += alpha.second * constraint_histogram (alpha.first , ns_othr,
814+ nd_othr, constraint);
815+ }
816+ if (nw) c_nw.value .fetch_add (nw);
817+ } // constraint "loop"
818+ } // OpenMP Context
819+
820+ constraint_work_bare.resize (nrefine);
821+ for (auto i_ref = 0 ; i_ref < nrefine; ++i_ref) {
822+ constraint_work_bare[i_ref] = constraint_work[i_ref].value .load ();
823+ }
824+ allreduce (constraint_work_bare.data (), nrefine, MPI_SUM, MPI_COMM_WORLD);
825+
826+ // Copy over constraint work
827+ for (auto i_ref = 0 ; i_ref < nrefine; ++i_ref) {
828+ constraint_to_refine[i_ref].second = constraint_work_bare[i_ref];
829+ }
830+
831+ // Remove zeros
832+ {
833+ auto it = std::partition (constraint_to_refine.begin (), constraint_to_refine.end (),
834+ [](const auto & p) { return p.second > 0 ; });
835+ constraint_to_refine.erase (it, constraint_to_refine.end ());
836+ }
837+
838+ // Concatenate the arrays
839+ constraint_sizes.reserve (nrefine + nleave);
840+ std::copy_n (constraint_to_refine.begin (), nrefine, std::back_inserter (constraint_sizes));
841+ std::copy_n (constraint_to_leave.begin (), nleave, std::back_inserter (constraint_sizes));
842+
843+ size_t tmp = std::accumulate (constraint_sizes.begin (), constraint_sizes.end (),
844+ 0ul , [](auto s, const auto & p){ return s + p.second ; });
845+ if (tmp != total_work) throw std::runtime_error (" Incorrect Refinement" );
846+ } // Selective refinement logic
847+
848+
849+
752850 #endif
753851
754852 for (size_t ilevel = 0 ; ilevel < nlevels; ++ilevel) {
0 commit comments