Skip to content

Commit 8b18297

Browse files
Merge pull request #125 from EXP-code/tfinalFix
Tfinal fix
2 parents ee04186 + 804d665 commit 8b18297

File tree

3 files changed

+69
-20
lines changed

3 files changed

+69
-20
lines changed

expui/BiorthBasis.H

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1189,7 +1189,7 @@ namespace BasisClasses
11891189
std::tuple<Eigen::VectorXd, Eigen::Tensor<float, 3>>
11901190
IntegrateOrbits (double tinit, double tfinal, double h,
11911191
Eigen::MatrixXd ps, std::vector<BasisCoef> bfe,
1192-
AccelFunctor F, int nout=std::numeric_limits<int>::max());
1192+
AccelFunctor F, int nout=0);
11931193

11941194
using BiorthBasisPtr = std::shared_ptr<BiorthBasis>;
11951195
}

expui/BiorthBasis.cc

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3885,24 +3885,61 @@ namespace BasisClasses
38853885

38863886
// Sanity check
38873887
//
3888+
if (tfinal == tinit) {
3889+
throw std::runtime_error
3890+
("BasisClasses::IntegrateOrbits: tinit cannot be equal to tfinal");
3891+
}
3892+
3893+
if (h < 0.0 and tfinal > tinit) {
3894+
throw std::runtime_error
3895+
("BasisClasses::IntegrateOrbits: tfinal must be smaller than tinit "
3896+
"when step size is negative");
3897+
}
3898+
3899+
if (h > 0.0 and tfinal < tinit) {
3900+
throw std::runtime_error
3901+
("BasisClasses::IntegrateOrbits: tfinal must be larger than "
3902+
"tinit when step size is positive");
3903+
}
3904+
38883905
if ( (tfinal - tinit)/h >
38893906
static_cast<double>(std::numeric_limits<int>::max()) )
38903907
{
3891-
std::cout << "BasicFactor::IntegrateOrbits: step size is too small or "
3892-
<< "time interval is too large.\n";
3908+
std::cout << "BasisClasses::IntegrateOrbits: step size is too small or "
3909+
<< "time interval is too large." << std::endl;
38933910
// Return empty data
38943911
//
38953912
return {Eigen::VectorXd(), Eigen::Tensor<float, 3>()};
38963913
}
38973914

38983915
// Number of steps
38993916
//
3900-
int numT = floor( (tfinal - tinit)/h );
3917+
int numT = std::ceil( (tfinal - tinit)/h + 0.5);
3918+
3919+
// Want both end points in the output at minimum
3920+
//
3921+
numT = std::max(2, numT);
39013922

3902-
// Compute output step
3923+
// Number of output steps
39033924
//
3904-
nout = std::min<int>(numT, nout);
3905-
double H = (tfinal - tinit)/nout;
3925+
int stride = 1; // Default stride
3926+
if (nout>0) { // User has specified output count...
3927+
nout = std::max(2, nout);
3928+
stride = std::ceil(static_cast<double>(numT)/static_cast<double>(nout));
3929+
numT = (nout-1) * stride + 1;
3930+
} else { // Otherwise, use the default output number
3931+
nout = numT; // with the default stride
3932+
}
3933+
3934+
// Compute the interval-matching step
3935+
//
3936+
h = (tfinal - tinit)/(numT-1);
3937+
3938+
// DEBUG
3939+
if (false)
3940+
std::cout << "BasisClasses::IntegrateOrbits: choosing nout=" << nout
3941+
<< " numT=" << numT << " h=" << h << " stride=" << stride
3942+
<< std::endl;
39063943

39073944
// Return data
39083945
//
@@ -3912,10 +3949,10 @@ namespace BasisClasses
39123949
ret.resize(rows, 6, nout);
39133950
}
39143951
catch (const std::bad_alloc& e) {
3915-
std::cout << "BasicFactor::IntegrateOrbits: memory allocation failed: "
3952+
std::cout << "BasisClasses::IntegrateOrbits: memory allocation failed: "
39163953
<< e.what() << std::endl
39173954
<< "Your requested number of orbits and time steps requires "
3918-
<< floor(8.0*rows*6*nout/1e9)+1 << " GB free memory"
3955+
<< std::floor(4.0*rows*6*nout/1e9)+1 << " GB free memory"
39193956
<< std::endl;
39203957

39213958
// Return empty data
@@ -3927,27 +3964,37 @@ namespace BasisClasses
39273964
//
39283965
Eigen::VectorXd times(nout);
39293966

3930-
// Do the work
3967+
// Assign the initial point
39313968
//
39323969
times(0) = tinit;
39333970
for (int n=0; n<rows; n++)
39343971
for (int k=0; k<6; k++) ret(n, k, 0) = ps(n, k);
39353972

3973+
// Sign of h
3974+
int sgn = (0 < h) - (h < 0);
3975+
3976+
// Set the counters
39363977
double tnow = tinit;
3937-
for (int s=1, cnt=1; s<numT; s++) {
3978+
int s = 0, cnt = 1;
3979+
3980+
// Do the integration using stride for output
3981+
while (s++ < numT) {
3982+
if ( (tfinal - tnow)*sgn < h*sgn) h = tfinal - tnow;
39383983
std::tie(tnow, ps) = OneStep(tnow, h, ps, accel, bfe, F);
3939-
if (tnow >= H*cnt-h*1.0e-8) {
3984+
if (cnt < nout and s % stride == 0) {
39403985
times(cnt) = tnow;
39413986
for (int n=0; n<rows; n++)
39423987
for (int k=0; k<6; k++) ret(n, k, cnt) = ps(n, k);
39433988
cnt += 1;
39443989
}
39453990
}
39463991

3992+
// Corrects round off at end point
3993+
//
39473994
times(nout-1) = tnow;
39483995
for (int n=0; n<rows; n++)
39493996
for (int k=0; k<6; k++) ret(n, k, nout-1) = ps(n, k);
3950-
3997+
39513998
return {times, ret};
39523999
}
39534000

pyEXP/BasisWrappers.cc

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2194,11 +2194,13 @@ void BasisFactoryClasses(py::module &m)
21942194
R"(
21952195
Compute particle orbits in gravitational field from the bases
21962196
2197-
Integrate a list of initial conditions from tinit to tfinal with a
2198-
step size of h using the list of basis and coefficient pairs. Every
2199-
step will be included in return unless you provide an explicit
2200-
value for 'nout', the number of desired output steps. This will
2201-
choose the 'nout' points closed to the desired time.
2197+
Integrate a list of initial conditions from 'tinit' to 'tfinal' with
2198+
a step size of 'h' using the list of basis and coefficient pairs. The
2199+
step size will be adjusted to provide uniform sampling. Every
2200+
step will be returned unless you provide an explicit value for 'nout',
2201+
the number of desired output steps. In this case, the code will
2202+
choose new set step size equal or smaller to the supplied step size
2203+
with a stride to provide exactly 'nout' output times.
22022204
22032205
Parameters
22042206
----------
@@ -2216,7 +2218,7 @@ void BasisFactoryClasses(py::module &m)
22162218
func : AccelFunctor
22172219
the force function
22182220
nout : int
2219-
the number of output intervals
2221+
the number of output points, if specified
22202222
22212223
Returns
22222224
-------
@@ -2225,5 +2227,5 @@ void BasisFactoryClasses(py::module &m)
22252227
)",
22262228
py::arg("tinit"), py::arg("tfinal"), py::arg("h"),
22272229
py::arg("ps"), py::arg("basiscoef"), py::arg("func"),
2228-
py::arg("nout")=std::numeric_limits<int>::max());
2230+
py::arg("nout")=0);
22292231
}

0 commit comments

Comments
 (0)