Skip to content

Commit

Permalink
Make ChrDrift fast
Browse files Browse the repository at this point in the history
  • Loading branch information
ax3l committed Feb 14, 2025
1 parent 86c0c2b commit 5913d3d
Showing 1 changed file with 41 additions and 18 deletions.
59 changes: 41 additions & 18 deletions src/elements/ChrDrift.H
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,29 @@ namespace impactx::elements
/** Push all particles */
using BeamOptic::operator();

/** Compute and cache the constants for the push.
*
* In particular, used to pre-compute and cache variables that are
* independent of the individually tracked particle.
*
* @param refpart reference particle
*/
void compute_constants (RefPart const & refpart)
{
using namespace amrex::literals; // for _rt and _prt

Alignment::compute_constants(refpart);

// length of the current slice
m_slice_ds = m_ds / nslice();

// access reference particle values to find beta, gamma
m_beta = refpart.beta();
m_gamma = refpart.gamma();

m_const1 = 1_prt / (2_prt * std::pow(m_beta,3) * std::pow(m_gamma, 2));
}

/** This is a chrdrift functor, so that a variable of this type can be used like a chrdrift function.
*
* @param x particle position in x
Expand All @@ -83,7 +106,7 @@ namespace impactx::elements
* @param py particle momentum in y
* @param pt particle momentum in t
* @param idcpu particle global index
* @param refpart reference particle
* @param refpart reference particle (unused)
*/
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
void operator() (
Expand All @@ -94,7 +117,7 @@ namespace impactx::elements
amrex::ParticleReal & AMREX_RESTRICT py,
amrex::ParticleReal & AMREX_RESTRICT pt,
uint64_t & AMREX_RESTRICT idcpu,
RefPart const & refpart
[[maybe_unused]] RefPart const & refpart
) const
{
using namespace amrex::literals; // for _rt and _prt
Expand All @@ -112,30 +135,22 @@ namespace impactx::elements
amrex::ParticleReal const pyout = py;
amrex::ParticleReal const ptout = pt;

// length of the current slice
amrex::ParticleReal const slice_ds = m_ds / nslice();

// access reference particle values to find beta, gamma
amrex::ParticleReal const bet = refpart.beta();
amrex::ParticleReal const gam = refpart.gamma();

// compute particle momentum deviation delta + 1
amrex::ParticleReal delta1;
delta1 = std::sqrt(1_prt - 2_prt*pt/bet + std::pow(pt,2));
amrex::ParticleReal const idelta1 = 1_prt / std::sqrt(1_prt - 2_prt*pt/m_beta + std::pow(pt,2));

// advance transverse position and momentum (drift)
x = xout + slice_ds * px / delta1;
x = xout + m_slice_ds * px * idelta1;
// pxout = px;
y = yout + slice_ds * py / delta1;
y = yout + m_slice_ds * py * idelta1;
// pyout = py;

// the corresponding symplectic update to t
amrex::ParticleReal term = 2_prt * std::pow(pt,2) + std::pow(px,2) + std::pow(py,2);
term = 2_prt - 4_prt*bet*pt + std::pow(bet,2)*term;
term = -2_prt + std::pow(gam,2)*term;
term = (-1_prt+bet*pt)*term;
term = term/(2_prt * std::pow(bet,3) * std::pow(gam,2));
t = tout - slice_ds * (1_prt / bet + term /std::pow(delta1, 3));
term = 2_prt - 4_prt*m_beta*pt + std::pow(m_beta,2)*term;
term = -2_prt + std::pow(m_gamma, 2)*term;
term = (-1_prt+m_beta*pt)*term;
term = term * m_const1;
t = tout - m_slice_ds * (1_prt / m_beta + term * std::pow(idelta1, 3));
// ptout = pt;

// assign updated momenta
Expand Down Expand Up @@ -201,6 +216,14 @@ namespace impactx::elements
throw std::runtime_error(std::string(type) + ": Envelope tracking is not yet implemented!");
return Map6x6::Identity();
}

private:
// constants that are independent of the individually tracked particle,
// see: compute_constants() to refresh
amrex::ParticleReal m_slice_ds; //! m_ds / nslice();
amrex::ParticleReal m_beta;
amrex::ParticleReal m_gamma;
amrex::ParticleReal m_const1;
};

} // namespace impactx
Expand Down

0 comments on commit 5913d3d

Please sign in to comment.