From 6af0a54bc15b34db26e262dcea82b1bc94132c9c Mon Sep 17 00:00:00 2001 From: Axel Huebl Date: Fri, 14 Feb 2025 14:33:15 -0800 Subject: [PATCH] Make `Buncher` fast --- src/elements/Buncher.H | 57 +++++++++++++++++++++++++++++------------- 1 file changed, 40 insertions(+), 17 deletions(-) diff --git a/src/elements/Buncher.H b/src/elements/Buncher.H index a06e357ff..6f904320d 100644 --- a/src/elements/Buncher.H +++ b/src/elements/Buncher.H @@ -64,9 +64,30 @@ namespace impactx::elements /** Push all particles */ using BeamOptic::operator(); + /** Compute and cache the constants for the push. + * + * In particular, used to pre-compute and cache variables that are + * independent of the individually tracked particle. + * + * @param refpart reference particle + */ + void compute_constants (RefPart const & refpart) + { + using namespace amrex::literals; // for _rt and _prt + + Alignment::compute_constants(refpart); + + // find beta*gamma^2 + amrex::ParticleReal const betgam2 = 2.0_prt * std::pow(refpart.pt, 2) - 1.0_prt; + + m_kV_r2bg2 = m_k * m_V / (2.0_prt * betgam2); + } + /** This is a buncher functor, so that a variable of this type can be used like a * buncher function. * + * The @see compute_constants method must be called before pushing particles through this operator. + * * @param x particle position in x * @param y particle position in y * @param t particle position in t @@ -74,37 +95,34 @@ namespace impactx::elements * @param py particle momentum in y * @param pt particle momentum in t * @param idcpu particle global index (unused) - * @param refpart reference particle + * @param refpart reference particle (unused) */ AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE void operator() ( - amrex::ParticleReal & AMREX_RESTRICT x, - amrex::ParticleReal & AMREX_RESTRICT y, - amrex::ParticleReal & AMREX_RESTRICT t, - amrex::ParticleReal & AMREX_RESTRICT px, - amrex::ParticleReal & AMREX_RESTRICT py, - amrex::ParticleReal & AMREX_RESTRICT pt, - [[maybe_unused]] uint64_t & AMREX_RESTRICT idcpu, - RefPart const & refpart) const { - + amrex::ParticleReal & AMREX_RESTRICT x, + amrex::ParticleReal & AMREX_RESTRICT y, + amrex::ParticleReal & AMREX_RESTRICT t, + amrex::ParticleReal & AMREX_RESTRICT px, + amrex::ParticleReal & AMREX_RESTRICT py, + amrex::ParticleReal & AMREX_RESTRICT pt, + [[maybe_unused]] uint64_t & AMREX_RESTRICT idcpu, + [[maybe_unused]] RefPart const & AMREX_RESTRICT refpart + ) const + { using namespace amrex::literals; // for _rt and _prt // shift due to alignment errors of the element shift_in(x, y, px, py); - // access reference particle values to find (beta*gamma)^2 - amrex::ParticleReal const pt_ref = refpart.pt; - amrex::ParticleReal const betgam2 = std::pow(pt_ref, 2) - 1.0_prt; - // intialize output values of momenta amrex::ParticleReal pxout = px; amrex::ParticleReal pyout = py; amrex::ParticleReal ptout = pt; // advance position and momentum - pxout = px + m_k*m_V/(2.0_prt*betgam2)*x; - pyout = py + m_k*m_V/(2.0_prt*betgam2)*y; - ptout = pt - m_k*m_V*t; + pxout = px + m_kV_r2bg2 * x; + pyout = py + m_kV_r2bg2 * y; + ptout = pt - m_k * m_V * t; // assign updated momenta px = pxout; @@ -136,6 +154,11 @@ namespace impactx::elements amrex::ParticleReal m_V; //! normalized (max) RF voltage drop. amrex::ParticleReal m_k; //! RF wavenumber in 1/m. + + private: + // constants that are independent of the individually tracked particle, + // see: compute_constants() to refresh + amrex::ParticleReal m_kV_r2bg2; //! m_k*m_V/(2.0_prt*betgam2) }; } // namespace impactx