Skip to content

Commit cd7c13a

Browse files
Merge pull request #91 from wavefunction91/feature/rad_s2_generators
Add Radial and S2 Generators
2 parents f8e42bb + a71ec70 commit cd7c13a

17 files changed

+536
-187
lines changed
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
#pragma once
2+
#include <integratorxx/quadratures/radial.hpp>
3+
4+
namespace IntegratorXX {
5+
6+
/// High-level specification of radial quadratures
7+
enum class RadialQuad : uint32_t {
8+
Becke = 0x0010,
9+
MurrayHandyLaming = 0x0020,
10+
MuraKnowles = 0x0030,
11+
TreutlerAhlrichs = 0x0040
12+
};
13+
14+
template <typename RadQuadType>
15+
RadialQuad radial_from_type() {
16+
if constexpr (detail::is_becke_v<RadQuadType>) return RadialQuad::Becke;
17+
if constexpr (detail::is_mk_v<RadQuadType> ) return RadialQuad::MuraKnowles;
18+
if constexpr (detail::is_mhl_v<RadQuadType>) return RadialQuad::MurrayHandyLaming;
19+
if constexpr (detail::is_ta_v<RadQuadType>) return RadialQuad::TreutlerAhlrichs;
20+
21+
throw std::runtime_error("Unrecognized Radial Quadrature");
22+
};
23+
24+
RadialQuad radial_from_string(std::string name);
25+
26+
namespace detail {
27+
28+
template <typename RadialTraitsType, typename... Args>
29+
std::unique_ptr<RadialTraits> make_radial_traits(Args&&... args) {
30+
using traits_type = RadialTraitsType;
31+
if constexpr (std::is_constructible_v<traits_type,Args...>)
32+
return std::make_unique<traits_type>(std::forward<Args>(args)...);
33+
else return nullptr;
34+
}
35+
36+
}
37+
38+
template <typename... Args>
39+
std::unique_ptr<RadialTraits> make_radial_traits(RadialQuad rq, Args&&... args) {
40+
std::unique_ptr<RadialTraits> ptr;
41+
switch(rq) {
42+
case RadialQuad::Becke:
43+
ptr =
44+
detail::make_radial_traits<BeckeRadialTraits>(std::forward<Args>(args)...);
45+
break;
46+
case RadialQuad::MurrayHandyLaming:
47+
ptr =
48+
detail::make_radial_traits<MurrayHandyLamingRadialTraits<2>>(std::forward<Args>(args)...);
49+
break;
50+
case RadialQuad::MuraKnowles:
51+
ptr =
52+
detail::make_radial_traits<MuraKnowlesRadialTraits>(std::forward<Args>(args)...);
53+
break;
54+
case RadialQuad::TreutlerAhlrichs:
55+
ptr =
56+
detail::make_radial_traits<TreutlerAhlrichsRadialTraits>(std::forward<Args>(args)...);
57+
break;
58+
}
59+
60+
if(!ptr) throw std::runtime_error("RadialTraits Construction Failed");
61+
return ptr;
62+
}
63+
64+
65+
struct RadialFactory {
66+
67+
using radial_grid_ptr = std::shared_ptr<
68+
QuadratureBase<
69+
std::vector<double>,
70+
std::vector<double>
71+
>
72+
>;
73+
74+
static radial_grid_ptr generate(RadialQuad rq, const RadialTraits& traits);
75+
76+
};
77+
78+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#pragma once
2+
#include <integratorxx/quadratures/s2.hpp>
3+
4+
#include <memory>
5+
6+
namespace IntegratorXX {
7+
8+
/// High-level specification of angular quadratures
9+
enum class AngularQuad : uint32_t {
10+
AhrensBeylkin = 0x0100,
11+
Delley = 0x0200,
12+
LebedevLaikov = 0x0300,
13+
Womersley = 0x0400
14+
};
15+
16+
template <typename AngQuadType>
17+
AngularQuad angular_from_type() {
18+
if constexpr (detail::is_ahrens_beyklin_v<AngQuadType>) return AngularQuad::AhrensBeylkin;
19+
if constexpr (detail::is_delley_v<AngQuadType> ) return AngularQuad::Delley;
20+
if constexpr (detail::is_lebedev_laikov_v<AngQuadType>) return AngularQuad::LebedevLaikov;
21+
if constexpr (detail::is_womersley_v<AngQuadType>) return AngularQuad::Womersley;
22+
23+
throw std::runtime_error("Unrecognized Angular Quadrature");
24+
};
25+
26+
AngularQuad angular_from_string(std::string name);
27+
28+
struct S2Factory {
29+
30+
using s2_grid_ptr = std::shared_ptr<
31+
QuadratureBase<
32+
std::vector<std::array<double,3>>,
33+
std::vector<double>
34+
>
35+
>;
36+
37+
static s2_grid_ptr generate(AngularQuad aq, size_t npts);
38+
39+
};
40+
41+
}

include/integratorxx/generators/spherical_factory.hpp

Lines changed: 42 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,53 +2,13 @@
22
#include <memory>
33
#include <vector>
44
#include <array>
5-
#include <integratorxx/quadratures/radial.hpp>
6-
#include <integratorxx/quadratures/s2.hpp>
5+
#include <integratorxx/generators/radial_factory.hpp>
6+
#include <integratorxx/generators/s2_factory.hpp>
77
#include <integratorxx/composite_quadratures/spherical_quadrature.hpp>
88
#include <integratorxx/composite_quadratures/pruned_spherical_quadrature.hpp>
99

1010
namespace IntegratorXX {
1111

12-
/// High-level specification of radial quadratures
13-
enum class RadialQuad : uint32_t {
14-
Becke = 0x0010,
15-
MurrayHandyLaming = 0x0020,
16-
MuraKnowles = 0x0030,
17-
TreutlerAhlrichs = 0x0040
18-
};
19-
20-
template <typename RadQuadType>
21-
RadialQuad radial_from_type() {
22-
if constexpr (detail::is_becke_v<RadQuadType>) return RadialQuad::Becke;
23-
if constexpr (detail::is_mk_v<RadQuadType> ) return RadialQuad::MuraKnowles;
24-
if constexpr (detail::is_mhl_v<RadQuadType>) return RadialQuad::MurrayHandyLaming;
25-
if constexpr (detail::is_ta_v<RadQuadType>) return RadialQuad::TreutlerAhlrichs;
26-
27-
throw std::runtime_error("Unrecognized Radial Quadrature");
28-
};
29-
30-
RadialQuad radial_from_string(std::string name);
31-
32-
/// High-level specification of angular quadratures
33-
enum class AngularQuad : uint32_t {
34-
AhrensBeylkin = 0x0100,
35-
Delley = 0x0200,
36-
LebedevLaikov = 0x0300,
37-
Womersley = 0x0400
38-
};
39-
40-
template <typename AngQuadType>
41-
AngularQuad angular_from_type() {
42-
if constexpr (detail::is_ahrens_beyklin_v<AngQuadType>) return AngularQuad::AhrensBeylkin;
43-
if constexpr (detail::is_delley_v<AngQuadType> ) return AngularQuad::Delley;
44-
if constexpr (detail::is_lebedev_laikov_v<AngQuadType>) return AngularQuad::LebedevLaikov;
45-
if constexpr (detail::is_womersley_v<AngQuadType>) return AngularQuad::Womersley;
46-
47-
throw std::runtime_error("Unrecognized Angular Quadrature");
48-
};
49-
50-
AngularQuad angular_from_string(std::string name);
51-
5212
/// High-level specification of pruning schemes for spherical quadratures
5313
enum class PruningScheme {
5414
Unpruned, /// Unpruned quadrature
@@ -57,22 +17,29 @@ enum class PruningScheme {
5717
};
5818

5919
// TODO: Make these strong (non-convertible) types
60-
using RadialScale = double;
61-
using RadialSize = size_t;
20+
//using RadialScale = double;
21+
//using RadialSize = size_t;
6222
using AngularSize = size_t;
63-
23+
using radial_traits_ptr = std::unique_ptr<RadialTraits>;
6424

6525
/// Generic specification of an unpruned spherical quadrature
6626
struct UnprunedSphericalGridSpecification {
67-
RadialQuad radial_quad; ///< Radial quadrature specification
68-
RadialSize radial_size; ///< Number of radial quadrature points
69-
RadialScale radial_scale; ///< Radial scaling factor
27+
RadialQuad radial_quad; ///< Radial quadrature specification
28+
radial_traits_ptr radial_traits; ///< Radial traits (order, scaling factors, etc)
7029

7130
AngularQuad angular_quad; /// Angular quadrature specification
7231
AngularSize angular_size; /// Number of angular quadrature points
32+
33+
UnprunedSphericalGridSpecification(RadialQuad, const RadialTraits&, AngularQuad,
34+
AngularSize);
35+
36+
UnprunedSphericalGridSpecification(const UnprunedSphericalGridSpecification& other) :
37+
radial_quad(other.radial_quad), radial_traits(other.radial_traits ? other.radial_traits->clone() : nullptr),
38+
angular_quad(other.angular_quad), angular_size(other.angular_size) {};
7339
};
7440

7541

42+
7643
/// Specification of a pruned region of an spherical quadrature
7744
struct PruningRegion {
7845
size_t idx_st; ///< Starting radial index for pruned region
@@ -91,15 +58,34 @@ struct PruningRegion {
9158

9259
struct PrunedSphericalGridSpecification {
9360
RadialQuad radial_quad; ///< Radial quadrature specification
94-
RadialSize radial_size; ///< Number of radial quadrature points
95-
RadialScale radial_scale; ///< Radial scaling factor
61+
radial_traits_ptr radial_traits; ///< Radial traits (order, scaling factors, etc)
9662

9763
std::vector<PruningRegion> pruning_regions; ///< List of pruning regions over the radial quadrature
64+
65+
PrunedSphericalGridSpecification() = default;
66+
67+
template <typename... Arg>
68+
PrunedSphericalGridSpecification(RadialQuad rq, radial_traits_ptr&& traits, Arg&&... arg) :
69+
radial_quad(rq), radial_traits(std::move(traits)), pruning_regions(std::forward<Arg>(arg)...) { }
70+
template <typename... Arg>
71+
PrunedSphericalGridSpecification(RadialQuad rq, const RadialTraits& traits, Arg&&... arg) :
72+
PrunedSphericalGridSpecification(rq, traits.clone(), std::forward<Arg>(arg)...) { }
73+
74+
PrunedSphericalGridSpecification(const PrunedSphericalGridSpecification& other) :
75+
PrunedSphericalGridSpecification(other.radial_quad,
76+
other.radial_traits ? other.radial_traits->clone() : nullptr,
77+
other.pruning_regions) { }
78+
79+
PrunedSphericalGridSpecification& operator=(const PrunedSphericalGridSpecification& other) {
80+
radial_quad = other.radial_quad;
81+
radial_traits = other.radial_traits ? other.radial_traits->clone() : nullptr;
82+
pruning_regions = other.pruning_regions;
83+
return *this;
84+
}
9885

9986
inline bool operator==(const PrunedSphericalGridSpecification& other) const noexcept {
10087
return radial_quad == other.radial_quad and
101-
radial_size == other.radial_size and
102-
radial_scale == other.radial_scale and
88+
(radial_traits ? (other.radial_traits and radial_traits->compare(*other.radial_traits)) : !other.radial_traits) and
10389
pruning_regions == other.pruning_regions;
10490
}
10591
};
@@ -176,10 +162,10 @@ struct SphericalGridFactory {
176162
}
177163

178164

179-
static spherical_grid_ptr generate_unpruned_grid( RadialQuad, RadialSize,
180-
RadialScale, AngularQuad, AngularSize );
181-
static spherical_grid_ptr generate_pruned_grid( RadialQuad, RadialSize,
182-
RadialScale, const std::vector<PruningRegion>&);
165+
static spherical_grid_ptr generate_unpruned_grid( RadialQuad, const RadialTraits&,
166+
AngularQuad, AngularSize );
167+
static spherical_grid_ptr generate_pruned_grid( RadialQuad, const RadialTraits&,
168+
const std::vector<PruningRegion>&);
183169

184170

185171
static spherical_grid_ptr generate_grid(UnprunedSphericalGridSpecification gs);

include/integratorxx/quadratures/radial/becke.hpp

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ namespace IntegratorXX {
1313
* J. Chem. Phys. 88, 2547 (1988)
1414
* DOI: https://doi.org/10.1063/1.454033
1515
*/
16-
class BeckeRadialTraits {
17-
double R_;
16+
class BeckeRadialTraits : public RadialTraits {
17+
18+
size_t npts_; ///< Number of grid points
19+
double R_; ///< Radial scaling factor
1820

1921
public:
2022
/**
@@ -24,7 +26,22 @@ class BeckeRadialTraits {
2426
*
2527
* @param[in] R Radial scaling factor
2628
*/
27-
BeckeRadialTraits(double R = 1.0) : R_(R) {}
29+
BeckeRadialTraits(size_t npts, double R = 1.0) : npts_(npts), R_(R) {}
30+
31+
size_t npts() const noexcept { return npts_; }
32+
33+
std::unique_ptr<RadialTraits> clone() const {
34+
return std::make_unique<BeckeRadialTraits>(*this);
35+
}
36+
37+
bool compare(const RadialTraits& other) const noexcept {
38+
auto ptr = dynamic_cast<const BeckeRadialTraits*>(&other);
39+
return ptr ? *this == *ptr : false;
40+
}
41+
42+
bool operator==(const BeckeRadialTraits& other) const noexcept {
43+
return npts_ == other.npts_ and R_ == other.R_;
44+
}
2845

2946
/**
3047
* @brief Transformation rule for the Becke radial quadratures

include/integratorxx/quadratures/radial/mhl.hpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,30 @@ namespace IntegratorXX {
1717
* Typically taken to be 2.
1818
*/
1919
template <size_t M>
20-
class MurrayHandyLamingRadialTraits {
20+
class MurrayHandyLamingRadialTraits : public RadialTraits {
2121

22+
size_t npts_; ///< Number of grid points
2223
double R_; ///< Radial scaling factor
24+
2325

2426
public:
2527

26-
MurrayHandyLamingRadialTraits(double R = 1.0) : R_(R) {}
28+
MurrayHandyLamingRadialTraits(size_t npts, double R = 1.0) : npts_(npts), R_(R) {}
29+
30+
size_t npts() const noexcept { return npts_; }
31+
32+
std::unique_ptr<RadialTraits> clone() const {
33+
return std::make_unique<MurrayHandyLamingRadialTraits>(*this);
34+
}
35+
36+
bool compare(const RadialTraits& other) const noexcept {
37+
auto ptr = dynamic_cast<const MurrayHandyLamingRadialTraits*>(&other);
38+
return ptr ? *this == *ptr : false;
39+
}
40+
41+
bool operator==(const MurrayHandyLamingRadialTraits& other) const noexcept {
42+
return npts_ == other.npts_ and R_ == other.R_;
43+
}
2744

2845
/**
2946
* @brief Transformation rule for the MHL radial quadrature

include/integratorxx/quadratures/radial/muraknowles.hpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,29 @@ struct quadrature_traits<
136136

137137
#else
138138

139-
class MuraKnowlesRadialTraits {
139+
class MuraKnowlesRadialTraits : public RadialTraits {
140140

141+
size_t npts_; ///< Number of grid points
141142
double R_; ///< Radial scaling factor
142143

143144
public:
144145

145-
MuraKnowlesRadialTraits(double R = 1.0) : R_(R) { }
146+
MuraKnowlesRadialTraits(size_t npts, double R = 1.0) : npts_(npts), R_(R) { }
147+
148+
size_t npts() const noexcept { return npts_; }
149+
150+
std::unique_ptr<RadialTraits> clone() const {
151+
return std::make_unique<MuraKnowlesRadialTraits>(*this);
152+
}
153+
154+
bool compare(const RadialTraits& other) const noexcept {
155+
auto ptr = dynamic_cast<const MuraKnowlesRadialTraits*>(&other);
156+
return ptr ? *this == *ptr : false;
157+
}
158+
159+
bool operator==(const MuraKnowlesRadialTraits& other) const noexcept {
160+
return npts_ == other.npts_ and R_ == other.R_;
161+
}
146162

147163
template <typename PointType>
148164
inline auto radial_transform(PointType x) const noexcept {

0 commit comments

Comments
 (0)