Skip to content

Commit

Permalink
Make ConstF fast
Browse files Browse the repository at this point in the history
  • Loading branch information
ax3l committed Feb 15, 2025
1 parent 7c29c3c commit 483e3df
Showing 1 changed file with 68 additions and 29 deletions.
97 changes: 68 additions & 29 deletions src/elements/ConstF.H
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,48 @@ namespace impactx::elements
/** Push all particles */
using BeamOptic::operator();

/** Compute and cache the constants for the push.
*
* In particular, used to pre-compute and cache variables that are
* independent of the individually tracked particle.
*
* @param refpart reference particle
*/
void compute_constants (RefPart const & refpart)
{
using namespace amrex::literals; // for _rt and _prt

Alignment::compute_constants(refpart);

// length of the current slice
m_slice_ds = m_ds / nslice();

// find beta*gamma^2
m_betgam2 = std::pow(refpart.pt, 2) - 1.0_prt;

// trigo
auto const [sin_kxds, cos_kxds] = amrex::Math::sincos(m_kx * m_slice_ds);
m_cos_kxds = sin_kxds;
m_const_x = -m_kx * sin_kxds;
auto const [sin_kyds, cos_kyds] = amrex::Math::sincos(m_ky * m_slice_ds);
m_cos_kyds = sin_kyds;
m_const_y = -m_ky * sin_kyds;
auto const [sin_ktds, cos_ktds] = amrex::Math::sincos(m_kt * m_slice_ds);
m_cos_ktds = sin_ktds;
m_const_t = -m_kt * m_betgam2 * sin_ktds;

// intermediate quantities - to avoid division by zero
m_sincx = (m_kx > 0) ? std::sin(m_kx * m_slice_ds) / m_kx : m_slice_ds;
m_sincy = (m_ky > 0) ? std::sin(m_ky * m_slice_ds) / m_ky : m_slice_ds;
amrex::ParticleReal const sinct = m_kt > 0 ?
std::sin(m_kt * m_slice_ds) / m_kt :
m_slice_ds;
m_const_pt = sinct / m_betgam2;
}

/** This pushes a single particle, relative to the reference particle
*
* The @see compute_constants method must be called before pushing particles through this operator.
*
* @param x particle position in x
* @param y particle position in y
Expand All @@ -87,28 +128,25 @@ namespace impactx::elements
* @param py particle momentum in y
* @param pt particle momentum in t
* @param idcpu particle global index
* @param refpart reference particle
* @param refpart reference particle (unused)
*/
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
void operator() (
amrex::ParticleReal & AMREX_RESTRICT x,
amrex::ParticleReal & AMREX_RESTRICT y,
amrex::ParticleReal & AMREX_RESTRICT t,
amrex::ParticleReal & AMREX_RESTRICT px,
amrex::ParticleReal & AMREX_RESTRICT py,
amrex::ParticleReal & AMREX_RESTRICT pt,
uint64_t & AMREX_RESTRICT idcpu,
RefPart const & refpart) const {

amrex::ParticleReal & AMREX_RESTRICT x,
amrex::ParticleReal & AMREX_RESTRICT y,
amrex::ParticleReal & AMREX_RESTRICT t,
amrex::ParticleReal & AMREX_RESTRICT px,
amrex::ParticleReal & AMREX_RESTRICT py,
amrex::ParticleReal & AMREX_RESTRICT pt,
uint64_t & AMREX_RESTRICT idcpu,
[[maybe_unused]] RefPart const & AMREX_RESTRICT refpart
) const
{
using namespace amrex::literals; // for _rt and _prt

// shift due to alignment errors of the element
shift_in(x, y, px, py);

// access reference particle values to find beta*gamma^2
amrex::ParticleReal const pt_ref = refpart.pt;
amrex::ParticleReal const betgam2 = std::pow(pt_ref, 2) - 1.0_prt;

// intialize output values
amrex::ParticleReal xout = x;
amrex::ParticleReal yout = y;
Expand All @@ -117,23 +155,15 @@ namespace impactx::elements
amrex::ParticleReal pyout = py;
amrex::ParticleReal ptout = pt;

// length of the current slice
amrex::ParticleReal const slice_ds = m_ds / nslice();

// intermediate quantities - to avoid division by zero
amrex::ParticleReal const sincx = (m_kx > 0) ? std::sin(m_kx*slice_ds)/m_kx : slice_ds;
amrex::ParticleReal const sincy = (m_ky > 0) ? std::sin(m_ky*slice_ds)/m_ky : slice_ds;
amrex::ParticleReal const sinct = (m_kt > 0) ? std::sin(m_kt*slice_ds)/m_kt : slice_ds;

// advance position and momentum
xout = std::cos(m_kx*slice_ds)*x + sincx*px;
pxout = -m_kx * std::sin(m_kx*slice_ds)*x + std::cos(m_kx*slice_ds)*px;
xout = m_cos_kxds * x + m_sincx * px;
pxout = m_const_x * x + m_cos_kxds * px;

yout = std::cos(m_ky*slice_ds)*y + sincy*py;
pyout = -m_ky * std::sin(m_ky*slice_ds)*y + std::cos(m_ky*slice_ds)*py;
yout = m_cos_kyds * y + m_sincy * py;
pyout = m_const_y * y + m_cos_kyds * py;

tout = std::cos(m_kt*slice_ds)*t + sinct/(betgam2)*pt;
ptout = -(m_kt*betgam2) * std::sin(m_kt*slice_ds)*t + std::cos(m_kt*slice_ds)*pt;
tout = m_cos_ktds * t + m_const_pt * pt;
ptout = m_const_y * t + m_cos_ktds * pt;

// assign updated values
x = xout;
Expand Down Expand Up @@ -174,7 +204,7 @@ namespace impactx::elements
amrex::ParticleReal const slice_ds = m_ds / nslice();

// assign intermediate parameter
amrex::ParticleReal const step = slice_ds /std::sqrt(std::pow(pt, 2)-1.0_prt);
amrex::ParticleReal const step = slice_ds / std::sqrt(std::pow(pt, 2) - 1.0_prt);

// advance position and momentum (straight element)
refpart.x = x + step*px;
Expand Down Expand Up @@ -206,6 +236,15 @@ namespace impactx::elements
amrex::ParticleReal m_kx; //! focusing x strength in 1/m
amrex::ParticleReal m_ky; //! focusing y strength in 1/m
amrex::ParticleReal m_kt; //! focusing t strength in 1/m

private:
// constants that are independent of the individually tracked particle,
// see: compute_constants() to refresh
amrex::ParticleReal m_slice_ds; //! m_ds / nslice();
amrex::ParticleReal m_betgam2; //! beta*gamma^2
amrex::ParticleReal m_sincx, m_sincy;
amrex::ParticleReal m_const_x, m_const_y, m_const_t, m_const_pt;
amrex::ParticleReal m_cos_kxds, m_cos_kyds, m_cos_ktds;
};

} // namespace impactx
Expand Down

0 comments on commit 483e3df

Please sign in to comment.