diff --git a/include/bout/fft.hxx b/include/bout/fft.hxx index fdec8b7bec..b7d36d2166 100644 --- a/include/bout/fft.hxx +++ b/include/bout/fft.hxx @@ -28,10 +28,15 @@ #ifndef BOUT_FFT_H #define BOUT_FFT_H -#include "bout/dcomplex.hxx" +#include "bout/build_defines.hxx" + #include #include +#include + +#include +class Mesh; class Options; BOUT_ENUM_CLASS(FFT_MEASUREMENT_FLAG, estimate, measure, exhaustive); @@ -111,6 +116,16 @@ Array rfft(const Array& in); /// Expects that `in.size() == (length / 2) + 1` Array irfft(const Array& in, int length); +/// Check simulation is using 1 processor in Z, throw exception if not +/// +/// Generally, FFTs must be done over the full Z domain. Currently, most +/// methods using FFTs don't handle parallelising in Z +#if BOUT_CHECK_LEVEL > 0 +void assertZSerial(const Mesh& mesh, std::string_view name); +#else +inline void assertZSerial([[maybe_unused]] const Mesh& mesh, + [[maybe_unused]] std::string_view name) {} +#endif } // namespace fft } // namespace bout diff --git a/include/bout/mesh.hxx b/include/bout/mesh.hxx index 2b24a0e24a..502b0c5668 100644 --- a/include/bout/mesh.hxx +++ b/include/bout/mesh.hxx @@ -354,10 +354,12 @@ public: // non-local communications - virtual int getNXPE() = 0; ///< The number of processors in the X direction - virtual int getNYPE() = 0; ///< The number of processors in the Y direction - virtual int getXProcIndex() = 0; ///< This processor's index in X direction - virtual int getYProcIndex() = 0; ///< This processor's index in Y direction + virtual int getNXPE() const = 0; ///< The number of processors in the X direction + virtual int getNYPE() const = 0; ///< The number of processors in the Y direction + virtual int getNZPE() const = 0; ///< The number of processors in the Z direction + virtual int getXProcIndex() const = 0; ///< This processor's index in X direction + virtual int getYProcIndex() const = 0; ///< This processor's index in Y direction + virtual int getZProcIndex() const = 0; ///< This processor's index in Z direction // X communications virtual bool firstX() @@ -368,8 +370,6 @@ public: /// Domain is periodic in X? bool periodicX{false}; - int NXPE, PE_XIND; ///< Number of processors in X, and X processor index - /// Send a buffer of data to processor at X index +1 /// /// @param[in] buffer The data to send. Must be at least length \p size @@ -507,8 +507,10 @@ public: virtual BoutReal GlobalX(int jx) const = 0; ///< Continuous X index between 0 and 1 virtual BoutReal GlobalY(int jy) const = 0; ///< Continuous Y index (0 -> 1) + virtual BoutReal GlobalZ(int jz) const = 0; ///< Continuous Z index (0 -> 1) virtual BoutReal GlobalX(BoutReal jx) const = 0; ///< Continuous X index between 0 and 1 virtual BoutReal GlobalY(BoutReal jy) const = 0; ///< Continuous Y index (0 -> 1) + virtual BoutReal GlobalZ(BoutReal jz) const = 0; ///< Continuous Z index (0 -> 1) ////////////////////////////////////////////////////////// diff --git a/src/field/field3d.cxx b/src/field/field3d.cxx index 9821c638f7..b3448238f5 100644 --- a/src/field/field3d.cxx +++ b/src/field/field3d.cxx @@ -604,6 +604,7 @@ FieldPerp pow(const Field3D& lhs, const FieldPerp& rhs, const std::string& rgn) Field3D filter(const Field3D& var, int N0, const std::string& rgn) { + bout::fft::assertZSerial(*var.getMesh(), "`filter`"); checkData(var); int ncz = var.getNz(); @@ -649,6 +650,7 @@ Field3D filter(const Field3D& var, int N0, const std::string& rgn) { // Fourier filter in z with zmin Field3D lowPass(const Field3D& var, int zmax, bool keep_zonal, const std::string& rgn) { + bout::fft::assertZSerial(*var.getMesh(), "`lowPass`"); checkData(var); int ncz = var.getNz(); @@ -697,7 +699,7 @@ Field3D lowPass(const Field3D& var, int zmax, bool keep_zonal, const std::string * Use FFT to shift by an angle in the Z direction */ void shiftZ(Field3D& var, int jx, int jy, double zangle) { - + bout::fft::assertZSerial(*var.getMesh(), "`shiftZ`"); checkData(var); var.allocate(); // Ensure that var is unique Mesh* localmesh = var.getMesh(); diff --git a/src/invert/fft_fftw.cxx b/src/invert/fft_fftw.cxx index 05977e5f3f..2c633da6d5 100644 --- a/src/invert/fft_fftw.cxx +++ b/src/invert/fft_fftw.cxx @@ -27,8 +27,10 @@ #include "bout/build_defines.hxx" +#include #include #include +#include #include #include @@ -36,7 +38,6 @@ #include #include -#include #include #if BOUT_USE_OPENMP @@ -46,6 +47,12 @@ #include #endif // BOUT_HAS_FFTW +#if BOUT_CHECK_LEVEL > 0 +#include + +#include +#endif + namespace bout { namespace fft { @@ -527,5 +534,14 @@ Array irfft(const Array& in, int length) { return out; } +#if BOUT_CHECK_LEVEL > 0 +void assertZSerial(const Mesh& mesh, std::string_view name) { + if (mesh.getNZPE() != 1) { + throw BoutException("{} uses FFTs which are currently incompatible with multiple " + "processors in Z (using {})", + name, mesh.getNZPE()); + } +} +#endif } // namespace fft } // namespace bout diff --git a/src/invert/laplace/impls/cyclic/cyclic_laplace.cxx b/src/invert/laplace/impls/cyclic/cyclic_laplace.cxx index 99aacea4f8..80b22e73f7 100644 --- a/src/invert/laplace/impls/cyclic/cyclic_laplace.cxx +++ b/src/invert/laplace/impls/cyclic/cyclic_laplace.cxx @@ -57,6 +57,8 @@ LaplaceCyclic::LaplaceCyclic(Options* opt, const CELL_LOC loc, Mesh* mesh_in, Solver* UNUSED(solver)) : Laplacian(opt, loc, mesh_in), Acoef(0.0), C1coef(1.0), C2coef(1.0), Dcoef(1.0) { + bout::fft::assertZSerial(*localmesh, "`cyclic` inversion"); + Acoef.setLocation(location); C1coef.setLocation(location); C2coef.setLocation(location); diff --git a/src/invert/laplace/impls/iterative_parallel_tri/iterative_parallel_tri.cxx b/src/invert/laplace/impls/iterative_parallel_tri/iterative_parallel_tri.cxx index 926facd9e6..e9219091e2 100644 --- a/src/invert/laplace/impls/iterative_parallel_tri/iterative_parallel_tri.cxx +++ b/src/invert/laplace/impls/iterative_parallel_tri/iterative_parallel_tri.cxx @@ -67,12 +67,14 @@ LaplaceIPT::LaplaceIPT(Options* opt, CELL_LOC loc, Mesh* mesh_in, Solver* UNUSED au(ny, nmode), bu(ny, nmode), rl(nmode), ru(nmode), r1(ny, nmode), r2(ny, nmode), first_call(ny), x0saved(ny, 4, nmode), converged(nmode), fine_error(4, nmode) { + bout::fft::assertZSerial(*localmesh, "`ipt` inversion"); + A.setLocation(location); C.setLocation(location); D.setLocation(location); // Number of procs must be a factor of 2 - const int n = localmesh->NXPE; + const int n = localmesh->getNXPE(); if (!is_pow2(n)) { throw BoutException("LaplaceIPT error: NXPE must be a power of 2"); } diff --git a/src/invert/laplace/impls/pcr/pcr.cxx b/src/invert/laplace/impls/pcr/pcr.cxx index 60169c6eb4..a33bd7eef2 100644 --- a/src/invert/laplace/impls/pcr/pcr.cxx +++ b/src/invert/laplace/impls/pcr/pcr.cxx @@ -67,6 +67,8 @@ LaplacePCR::LaplacePCR(Options* opt, CELL_LOC loc, Mesh* mesh_in, Solver* UNUSED ncx(localmesh->LocalNx), ny(localmesh->LocalNy), avec(ny, nmode, ncx), bvec(ny, nmode, ncx), cvec(ny, nmode, ncx) { + bout::fft::assertZSerial(*localmesh, "`pcr` inversion"); + Acoef.setLocation(location); C1coef.setLocation(location); C2coef.setLocation(location); diff --git a/src/invert/laplace/impls/pcr_thomas/pcr_thomas.cxx b/src/invert/laplace/impls/pcr_thomas/pcr_thomas.cxx index 600166aa6b..d471775990 100644 --- a/src/invert/laplace/impls/pcr_thomas/pcr_thomas.cxx +++ b/src/invert/laplace/impls/pcr_thomas/pcr_thomas.cxx @@ -65,6 +65,8 @@ LaplacePCR_THOMAS::LaplacePCR_THOMAS(Options* opt, CELL_LOC loc, Mesh* mesh_in, ncx(localmesh->LocalNx), ny(localmesh->LocalNy), avec(ny, nmode, ncx), bvec(ny, nmode, ncx), cvec(ny, nmode, ncx) { + bout::fft::assertZSerial(*localmesh, "`pcr_thomas` inversion"); + Acoef.setLocation(location); C1coef.setLocation(location); C2coef.setLocation(location); diff --git a/src/invert/laplace/impls/serial_band/serial_band.cxx b/src/invert/laplace/impls/serial_band/serial_band.cxx index d7b4ac5d3b..0cf8d7259d 100644 --- a/src/invert/laplace/impls/serial_band/serial_band.cxx +++ b/src/invert/laplace/impls/serial_band/serial_band.cxx @@ -45,6 +45,9 @@ LaplaceSerialBand::LaplaceSerialBand(Options* opt, const CELL_LOC loc, Mesh* mesh_in, Solver* UNUSED(solver)) : Laplacian(opt, loc, mesh_in), Acoef(0.0), Ccoef(1.0), Dcoef(1.0) { + + bout::fft::assertZSerial(*localmesh, "`band` inversion"); + Acoef.setLocation(location); Ccoef.setLocation(location); Dcoef.setLocation(location); diff --git a/src/invert/laplace/impls/serial_tri/serial_tri.cxx b/src/invert/laplace/impls/serial_tri/serial_tri.cxx index 0fb9294d76..a14e0e4a26 100644 --- a/src/invert/laplace/impls/serial_tri/serial_tri.cxx +++ b/src/invert/laplace/impls/serial_tri/serial_tri.cxx @@ -33,14 +33,15 @@ #include #include #include -#include -#include - #include +#include LaplaceSerialTri::LaplaceSerialTri(Options* opt, CELL_LOC loc, Mesh* mesh_in, Solver* UNUSED(solver)) : Laplacian(opt, loc, mesh_in), A(0.0), C(1.0), D(1.0) { + + bout::fft::assertZSerial(*localmesh, "`tri` inversion"); + A.setLocation(location); C.setLocation(location); D.setLocation(location); diff --git a/src/invert/laplace/impls/spt/spt.cxx b/src/invert/laplace/impls/spt/spt.cxx index e39ca7e89f..cd24ee1acf 100644 --- a/src/invert/laplace/impls/spt/spt.cxx +++ b/src/invert/laplace/impls/spt/spt.cxx @@ -45,6 +45,9 @@ LaplaceSPT::LaplaceSPT(Options* opt, const CELL_LOC loc, Mesh* mesh_in, Solver* UNUSED(solver)) : Laplacian(opt, loc, mesh_in), Acoef(0.0), Ccoef(1.0), Dcoef(1.0) { + + bout::fft::assertZSerial(*localmesh, "`spt` inversion"); + Acoef.setLocation(location); Ccoef.setLocation(location); Dcoef.setLocation(location); @@ -341,14 +344,14 @@ int LaplaceSPT::start(const FieldPerp& b, SPT_data& data) { // Send data localmesh->sendXOut(std::begin(data.buffer), 4 * (maxmode + 1), data.comm_tag); - } else if (localmesh->PE_XIND == 1) { + } else if (localmesh->getXProcIndex() == 1) { // Post a receive data.recv_handle = localmesh->irecvXIn(std::begin(data.buffer), 4 * (maxmode + 1), data.comm_tag); } data.proc++; // Now moved onto the next processor - if (localmesh->NXPE == 2) { + if (localmesh->getNXPE() == 2) { data.dir = -1; // Special case. Otherwise reversal handled in spt_continue } @@ -366,7 +369,7 @@ int LaplaceSPT::next(SPT_data& data) { return 1; } - if (localmesh->PE_XIND == data.proc) { + if (localmesh->getXProcIndex() == data.proc) { /// This processor's turn to do inversion // Wait for data to arrive @@ -450,7 +453,7 @@ int LaplaceSPT::next(SPT_data& data) { } } - if (localmesh->PE_XIND != 0) { // If not finished yet + if (localmesh->getXProcIndex() != 0) { // If not finished yet /// Send data if (data.dir > 0) { @@ -460,7 +463,7 @@ int LaplaceSPT::next(SPT_data& data) { } } - } else if (localmesh->PE_XIND == data.proc + data.dir) { + } else if (localmesh->getXProcIndex() == data.proc + data.dir) { // This processor is next, post receive if (data.dir > 0) { @@ -474,7 +477,7 @@ int LaplaceSPT::next(SPT_data& data) { data.proc += data.dir; - if (data.proc == localmesh->NXPE - 1) { + if (data.proc == localmesh->getNXPE() - 1) { data.dir = -1; // Reverses direction at the end } diff --git a/src/invert/laplacexz/impls/cyclic/laplacexz-cyclic.cxx b/src/invert/laplacexz/impls/cyclic/laplacexz-cyclic.cxx index dc987edc75..b3e619df0c 100644 --- a/src/invert/laplacexz/impls/cyclic/laplacexz-cyclic.cxx +++ b/src/invert/laplacexz/impls/cyclic/laplacexz-cyclic.cxx @@ -13,6 +13,7 @@ LaplaceXZcyclic::LaplaceXZcyclic(Mesh* m, Options* options, const CELL_LOC loc) : LaplaceXZ(m, options, loc) { // Note: `m` may be nullptr, but localmesh is set in LaplaceXZ base constructor + bout::fft::assertZSerial(*localmesh, "`cyclic` X-Z inversion"); // Number of Z Fourier modes, including DC nmode = (localmesh->LocalNz) / 2 + 1; diff --git a/src/invert/parderiv/impls/cyclic/cyclic.cxx b/src/invert/parderiv/impls/cyclic/cyclic.cxx index 61d23be823..c32c3d4b2d 100644 --- a/src/invert/parderiv/impls/cyclic/cyclic.cxx +++ b/src/invert/parderiv/impls/cyclic/cyclic.cxx @@ -53,6 +53,8 @@ InvertParCR::InvertParCR(Options* opt, CELL_LOC location, Mesh* mesh_in) : InvertPar(opt, location, mesh_in), A(1.0), B(0.0), C(0.0), D(0.0), E(0.0) { + + bout::fft::assertZSerial(*localmesh, "InvertParCR"); // Number of k equations to solve for each x location nsys = 1 + (localmesh->LocalNz) / 2; diff --git a/src/invert/pardiv/impls/cyclic/pardiv_cyclic.cxx b/src/invert/pardiv/impls/cyclic/pardiv_cyclic.cxx index d4e773f017..aad01c5f2f 100644 --- a/src/invert/pardiv/impls/cyclic/pardiv_cyclic.cxx +++ b/src/invert/pardiv/impls/cyclic/pardiv_cyclic.cxx @@ -53,6 +53,7 @@ InvertParDivCR::InvertParDivCR(Options* opt, CELL_LOC location, Mesh* mesh_in) : InvertParDiv(opt, location, mesh_in) { + bout::fft::assertZSerial(*localmesh, "InvertParDivCR"); // Number of k equations to solve for each x location nsys = 1 + (localmesh->LocalNz) / 2; } diff --git a/src/mesh/boundary_standard.cxx b/src/mesh/boundary_standard.cxx index 6fd659b375..39a1ea3641 100644 --- a/src/mesh/boundary_standard.cxx +++ b/src/mesh/boundary_standard.cxx @@ -2631,6 +2631,7 @@ void BoundaryNeumann_NonOrthogonal::apply(Field3D& f) { #if not(BOUT_USE_METRIC_3D) Mesh* mesh = bndry->localmesh; ASSERT1(mesh == f.getMesh()); + bout::fft::assertZSerial(*mesh, "Zero Laplace on Field3D"); int ncz = mesh->LocalNz; Coordinates* metric = f.getCoordinates(); @@ -2734,6 +2735,7 @@ void BoundaryNeumann_NonOrthogonal::apply(Field3D& f) { #if not(BOUT_USE_METRIC_3D) Mesh* mesh = bndry->localmesh; ASSERT1(mesh == f.getMesh()); + bout::fft::assertZSerial(*mesh, "Zero Laplace on Field3D"); const int ncz = mesh->LocalNz; ASSERT0(ncz % 2 == 0); // Allocation assumes even number @@ -2843,6 +2845,8 @@ void BoundaryNeumann_NonOrthogonal::apply(Field3D& f) { Mesh* mesh = bndry->localmesh; ASSERT1(mesh == f.getMesh()); + bout::fft::assertZSerial(*mesh, "Zero Laplace on Field3D"); + Coordinates* metric = f.getCoordinates(); int ncz = mesh->LocalNz; diff --git a/src/mesh/coordinates.cxx b/src/mesh/coordinates.cxx index 086fa2e23e..91bbddfd56 100644 --- a/src/mesh/coordinates.cxx +++ b/src/mesh/coordinates.cxx @@ -1659,7 +1659,7 @@ Field3D Coordinates::Delp2(const Field3D& f, CELL_LOC outloc, bool useFFT) { Field3D result{emptyFrom(f).setLocation(outloc)}; - if (useFFT and not bout::build::use_metric_3d) { + if (useFFT and not bout::build::use_metric_3d and localmesh->getNZPE() == 1) { int ncz = localmesh->LocalNz; // Allocate memory @@ -1727,7 +1727,7 @@ FieldPerp Coordinates::Delp2(const FieldPerp& f, CELL_LOC outloc, bool useFFT) { int jy = f.getIndex(); result.setIndex(jy); - if (useFFT) { + if (useFFT and localmesh->getNZPE() == 1) { int ncz = localmesh->LocalNz; // Allocate memory diff --git a/src/mesh/data/gridfromoptions.cxx b/src/mesh/data/gridfromoptions.cxx index 379e279de0..662329c526 100644 --- a/src/mesh/data/gridfromoptions.cxx +++ b/src/mesh/data/gridfromoptions.cxx @@ -146,8 +146,7 @@ bool GridFromOptions::get(Mesh* m, std::vector& var, const std::string } case GridDataSource::Z: { for (int z = 0; z < len; z++) { - pos.set("z", - (TWOPI * (z - m->OffsetZ + offset)) / static_cast(m->LocalNz)); + pos.set("z", TWOPI * m->GlobalZ(z - m->OffsetZ + offset)); var[z] = gen->generate(pos); } break; diff --git a/src/mesh/impls/bout/boutmesh.cxx b/src/mesh/impls/bout/boutmesh.cxx index 4be01d4637..4aaa760c04 100644 --- a/src/mesh/impls/bout/boutmesh.cxx +++ b/src/mesh/impls/bout/boutmesh.cxx @@ -57,6 +57,8 @@ If you want the old setting, you have to specify mesh:symmetricGlobalY=false in << optionfile << "\n"; } OPTION(options, symmetricGlobalY, true); + OPTION(options, symmetricGlobalZ, false); // The default should be updated to true but + // this breaks backwards compatibility comm_x = MPI_COMM_NULL; comm_inner = MPI_COMM_NULL; @@ -516,9 +518,10 @@ int BoutMesh::load() { findProcessorSplit(); } - // Get X and Y processor indices + // Get X, Y, Z processor indices PE_YIND = MYPE / NXPE; PE_XIND = MYPE % NXPE; + PE_ZIND = 0; // Set the other grid sizes from nx, ny, nz setDerivedGridSizes(); @@ -1519,13 +1522,17 @@ int BoutMesh::wait(comm_handle handle) { * Non-Local Communications ***************************************************************/ -int BoutMesh::getNXPE() { return NXPE; } +int BoutMesh::getNXPE() const { return NXPE; } -int BoutMesh::getNYPE() { return NYPE; } +int BoutMesh::getNYPE() const { return NYPE; } -int BoutMesh::getXProcIndex() { return PE_XIND; } +int BoutMesh::getNZPE() const { return NZPE; } -int BoutMesh::getYProcIndex() { return PE_YIND; } +int BoutMesh::getXProcIndex() const { return PE_XIND; } + +int BoutMesh::getYProcIndex() const { return PE_YIND; } + +int BoutMesh::getZProcIndex() const { return PE_ZIND; } /**************************************************************** * X COMMUNICATIONS @@ -1689,35 +1696,32 @@ int BoutMesh::PROC_NUM(int xind, int yind) const { return -1; } - return yind * NXPE + xind; + return (yind * NXPE) + xind; } -/// Returns the global X index given a local index -int BoutMesh::XGLOBAL(BoutReal xloc, BoutReal& xglo) const { - xglo = xloc + PE_XIND * MXSUB; - return static_cast(xglo); +BoutReal BoutMesh::getGlobalXIndex(BoutReal xloc) const { + return xloc + (PE_XIND * MXSUB); } -int BoutMesh::getGlobalXIndex(int xlocal) const { return xlocal + PE_XIND * MXSUB; } +int BoutMesh::getGlobalXIndex(int xlocal) const { return xlocal + (PE_XIND * MXSUB); } int BoutMesh::getGlobalXIndexNoBoundaries(int xlocal) const { - return xlocal + PE_XIND * MXSUB - MXG; + return xlocal + (PE_XIND * MXSUB) - MXG; } -int BoutMesh::getLocalXIndex(int xglobal) const { return xglobal - PE_XIND * MXSUB; } +int BoutMesh::getLocalXIndex(int xglobal) const { return xglobal - (PE_XIND * MXSUB); } int BoutMesh::getLocalXIndexNoBoundaries(int xglobal) const { - return xglobal - PE_XIND * MXSUB + MXG; + return xglobal - (PE_XIND * MXSUB) + MXG; } -int BoutMesh::YGLOBAL(BoutReal yloc, BoutReal& yglo) const { - yglo = yloc + PE_YIND * MYSUB - MYG; - return static_cast(yglo); +BoutReal BoutMesh::getGlobalYIndex(BoutReal yloc) const { + return yloc + (PE_YIND * MYSUB) - MYG; } int BoutMesh::getGlobalYIndex(int ylocal) const { - int yglobal = ylocal + PE_YIND * MYSUB; - if (jyseps1_2 > jyseps2_1 and PE_YIND * MYSUB + 2 * MYG + 1 > ny_inner) { + int yglobal = ylocal + (PE_YIND * MYSUB); + if (jyseps1_2 > jyseps2_1 and (PE_YIND * MYSUB) + (2 * MYG) + 1 > ny_inner) { // Double null, and we are past the upper target yglobal += 2 * MYG; } @@ -1725,12 +1729,12 @@ int BoutMesh::getGlobalYIndex(int ylocal) const { } int BoutMesh::getGlobalYIndexNoBoundaries(int ylocal) const { - return ylocal + PE_YIND * MYSUB - MYG; + return ylocal + (PE_YIND * MYSUB) - MYG; } int BoutMesh::getLocalYIndex(int yglobal) const { - int ylocal = yglobal - PE_YIND * MYSUB; - if (jyseps1_2 > jyseps2_1 and PE_YIND * MYSUB + 2 * MYG + 1 > ny_inner) { + int ylocal = yglobal - (PE_YIND * MYSUB); + if (jyseps1_2 > jyseps2_1 and (PE_YIND * MYSUB) + (2 * MYG) + 1 > ny_inner) { // Double null, and we are past the upper target ylocal -= 2 * MYG; } @@ -1738,19 +1742,25 @@ int BoutMesh::getLocalYIndex(int yglobal) const { } int BoutMesh::getLocalYIndexNoBoundaries(int yglobal) const { - return yglobal - PE_YIND * MYSUB + MYG; + return yglobal - (PE_YIND * MYSUB) + MYG; } -int BoutMesh::YGLOBAL(int yloc, int yproc) const { return yloc + yproc * MYSUB - MYG; } +int BoutMesh::YGLOBAL(int yloc, int yproc) const { return yloc + (yproc * MYSUB) - MYG; } -int BoutMesh::YLOCAL(int yglo, int yproc) const { return yglo - yproc * MYSUB + MYG; } +int BoutMesh::YLOCAL(int yglo, int yproc) const { return yglo - (yproc * MYSUB) + MYG; } -int BoutMesh::getGlobalZIndex(int zlocal) const { return zlocal; } +int BoutMesh::getGlobalZIndex(int zlocal) const { return zlocal + (PE_ZIND * MZSUB); } -int BoutMesh::getGlobalZIndexNoBoundaries(int zlocal) const { return zlocal; } +int BoutMesh::getGlobalZIndexNoBoundaries(int zlocal) const { + return zlocal + (PE_ZIND * MZSUB) - MZG; +} int BoutMesh::getLocalZIndex(int zglobal) const { return zglobal; } +BoutReal BoutMesh::getGlobalZIndex(BoutReal zloc) const { + return zloc + (PE_ZIND * MZSUB); +} + int BoutMesh::getLocalZIndexNoBoundaries(int zglobal) const { return zglobal; } int BoutMesh::YPROC(int yind) const { @@ -1818,16 +1828,15 @@ BoutMesh::BoutMesh(int input_nx, int input_ny, int input_nz, int mxg, int myg, i BoutMesh::BoutMesh(int input_nx, int input_ny, int input_nz, int mxg, int myg, int nxpe, int nype, int pe_xind, int pe_yind, bool symmetric_X, bool symmetric_Y, - bool periodicX_, int ixseps1_, int ixseps2_, int jyseps1_1_, + bool periodic_X_, int ixseps1_, int ixseps2_, int jyseps1_1_, int jyseps2_1_, int jyseps1_2_, int jyseps2_2_, int ny_inner_, bool create_regions) : nx(input_nx), ny(input_ny), nz(input_nz), NPES(nxpe * nype), - MYPE(nxpe * pe_yind + pe_xind), PE_YIND(pe_yind), NYPE(nype), NZPE(1), - ixseps1(ixseps1_), ixseps2(ixseps2_), symmetricGlobalX(symmetric_X), + MYPE((nxpe * pe_yind) + pe_xind), PE_XIND(pe_xind), NXPE(nxpe), PE_YIND(pe_yind), + NYPE(nype), ixseps1(ixseps1_), ixseps2(ixseps2_), symmetricGlobalX(symmetric_X), symmetricGlobalY(symmetric_Y), MXG(mxg), MYG(myg), MZG(0) { - NXPE = nxpe; - PE_XIND = pe_xind; - periodicX = periodicX_; + + periodicX = periodic_X_; setYDecompositionIndices(jyseps1_1_, jyseps2_1_, jyseps1_2_, jyseps2_2_, ny_inner_); setDerivedGridSizes(); topology(); @@ -3124,8 +3133,7 @@ BoutReal BoutMesh::GlobalX(int jx) const { BoutReal BoutMesh::GlobalX(BoutReal jx) const { // Get global X index as a BoutReal - BoutReal xglo; - XGLOBAL(jx, xglo); + const BoutReal xglo = getGlobalXIndex(jx); if (symmetricGlobalX) { // With this definition the boundary sits dx/2 away form the first/last inner points @@ -3179,8 +3187,7 @@ BoutReal BoutMesh::GlobalY(int jy) const { BoutReal BoutMesh::GlobalY(BoutReal jy) const { // Get global Y index as a BoutReal - BoutReal yglo; - YGLOBAL(jy, yglo); + BoutReal yglo = getGlobalYIndex(jy); if (symmetricGlobalY) { BoutReal yi = yglo; @@ -3222,6 +3229,28 @@ BoutReal BoutMesh::GlobalY(BoutReal jy) const { return yglo / static_cast(nycore); } +BoutReal BoutMesh::GlobalZ(int jz) const { + if (symmetricGlobalZ) { + // With this definition the boundary sits dz/2 away form the first/last inner points + return (0.5 + getGlobalZIndexNoBoundaries(jz) - (nz - MZ) * 0.5) + / static_cast(MZ); + } + return static_cast(getGlobalZIndexNoBoundaries(jz)) + / static_cast(MZ); +} + +BoutReal BoutMesh::GlobalZ(BoutReal jz) const { + + // Get global Z index as a BoutReal + const BoutReal zglo = getGlobalZIndex(jz); + + if (symmetricGlobalZ) { + // With this definition the boundary sits dz/2 away form the first/last inner points + return (0.5 + zglo - (nz - MZ) * 0.5) / static_cast(MZ); + } + return zglo / static_cast(MZ); +} + void BoutMesh::outputVars(Options& output_options) { Timer time("io"); output_options["zperiod"].force(zperiod, "BoutMesh"); diff --git a/src/mesh/impls/bout/boutmesh.hxx b/src/mesh/impls/bout/boutmesh.hxx index 876edab1da..3923c34511 100644 --- a/src/mesh/impls/bout/boutmesh.hxx +++ b/src/mesh/impls/bout/boutmesh.hxx @@ -7,7 +7,6 @@ #include "bout/unused.hxx" #include -#include #include #include #include @@ -58,10 +57,12 @@ public: ///////////////////////////////////////////// // non-local communications - int getNXPE() override; ///< The number of processors in the X direction - int getNYPE() override; ///< The number of processors in the Y direction - int getXProcIndex() override; ///< This processor's index in X direction - int getYProcIndex() override; ///< This processor's index in Y direction + int getNXPE() const override; ///< The number of processors in the X direction + int getNYPE() const override; ///< The number of processors in the Y direction + int getNZPE() const override; ///< The number of processors in the Z direction + int getXProcIndex() const override; ///< This processor's index in X direction + int getYProcIndex() const override; ///< This processor's index in Y direction + int getZProcIndex() const override; ///< This processor's index in Z direction ///////////////////////////////////////////// // X communications @@ -172,8 +173,10 @@ public: BoutReal GlobalX(int jx) const override; BoutReal GlobalY(int jy) const override; + BoutReal GlobalZ(int jz) const override; BoutReal GlobalX(BoutReal jx) const override; BoutReal GlobalY(BoutReal jy) const override; + BoutReal GlobalZ(BoutReal jz) const override; BoutReal getIxseps1() const { return ixseps1; } BoutReal getIxseps2() const { return ixseps2; } @@ -206,7 +209,7 @@ protected: /// `getPossibleBoundaries`. \p create_regions controls whether or /// not the various `Region`s are created on the new mesh BoutMesh(int input_nx, int input_ny, int input_nz, int mxg, int myg, int nxpe, int nype, - int pe_xind, int pe_yind, bool symmetric_X, bool symmetric_Y, bool periodic_X, + int pe_xind, int pe_yind, bool symmetric_X, bool symmetric_Y, bool periodic_X_, int ixseps1_, int ixseps2_, int jyseps1_1_, int jyseps2_1_, int jyseps1_2_, int jyseps2_2_, int ny_inner_, bool create_regions = true); @@ -295,16 +298,24 @@ private: int NPES; ///< Number of processors int MYPE; ///< Rank of this processor + int PE_XIND; ///< X index of this processor + int NXPE; ///< Number of processors in the X direction + int PE_YIND; ///< Y index of this processor - int NYPE; // Number of processors in the Y direction + int NYPE; ///< Number of processors in the Y direction - int NZPE; + int PE_ZIND{0}; ///< Z index of this processor + int NZPE{1}; ///< Number of processors in the Z direction /// Is this processor in the core region? bool MYPE_IN_CORE{false}; - int XGLOBAL(BoutReal xloc, BoutReal& xglo) const; - int YGLOBAL(BoutReal yloc, BoutReal& yglo) const; + /// Returns the global X index given a local index + BoutReal getGlobalXIndex(BoutReal xloc) const; + /// Returns the global Y index given a local index + BoutReal getGlobalYIndex(BoutReal yloc) const; + /// Returns the global Z index given a local index + BoutReal getGlobalZIndex(BoutReal zloc) const; // Topology int ixseps1, ixseps2, jyseps1_1, jyseps2_1, jyseps1_2, jyseps2_2; @@ -355,8 +366,9 @@ private: // Settings bool TwistShift; // Use a twist-shift condition in core? - bool symmetricGlobalX; ///< Use a symmetric definition in GlobalX() function - bool symmetricGlobalY; + bool symmetricGlobalX; ///< Use a symmetric definition in `GlobalX()` function + bool symmetricGlobalY; ///< Use a symmetric definition in `GlobalY()` function + bool symmetricGlobalZ{false}; ///< Use a symmetric definition in `GlobalZ()` function int zperiod; BoutReal ZMIN, ZMAX; // Range of the Z domain (in fractions of 2pi) diff --git a/src/mesh/index_derivs.cxx b/src/mesh/index_derivs.cxx index 48284c3536..70fc47b538 100644 --- a/src/mesh/index_derivs.cxx +++ b/src/mesh/index_derivs.cxx @@ -427,6 +427,7 @@ class FFTDerivativeType { ASSERT2(bout::utils::is_Field3D_v); // Should never need to call this with Field2D auto* theMesh = var.getMesh(); + ASSERT2(theMesh->getNZPE() == 1); // Only works if serial in Z for FFTs // Calculate how many Z wavenumbers will be removed const int ncz = theMesh->getNpoints(direction); @@ -493,6 +494,7 @@ class FFT2ndDerivativeType { ASSERT2(bout::utils::is_Field3D_v); // Should never need to call this with Field2D auto* theMesh = var.getMesh(); + ASSERT2(theMesh->getNZPE() == 1); // Only works if serial in Z for FFTs // Calculate how many Z wavenumbers will be removed const int ncz = theMesh->getNpoints(direction); diff --git a/src/mesh/parallel/shiftedmetric.cxx b/src/mesh/parallel/shiftedmetric.cxx index 382052047d..64c6d9a2ce 100644 --- a/src/mesh/parallel/shiftedmetric.cxx +++ b/src/mesh/parallel/shiftedmetric.cxx @@ -24,6 +24,7 @@ ShiftedMetric::ShiftedMetric(Mesh& m, CELL_LOC location_in, Field2D zShift_, ASSERT1(zShift.getLocation() == location); // check the coordinate system used for the grid data source ShiftedMetric::checkInputGrid(); + bout::fft::assertZSerial(m, "ShiftedMetric"); cachePhases(); } diff --git a/src/mesh/parallel/shiftedmetricinterp.cxx b/src/mesh/parallel/shiftedmetricinterp.cxx index 4543c4c3fb..c71618ab19 100644 --- a/src/mesh/parallel/shiftedmetricinterp.cxx +++ b/src/mesh/parallel/shiftedmetricinterp.cxx @@ -27,8 +27,13 @@ * **************************************************************************/ +#include + #include "shiftedmetricinterp.hxx" + +#include "bout/boutexception.hxx" #include "bout/constants.hxx" +#include "bout/field3d.hxx" #include "bout/parallel_boundary_region.hxx" ShiftedMetricInterp::ShiftedMetricInterp(Mesh& mesh, CELL_LOC location_in, @@ -36,12 +41,19 @@ ShiftedMetricInterp::ShiftedMetricInterp(Mesh& mesh, CELL_LOC location_in, Options* opt) : ParallelTransform(mesh, opt), location(location_in), zShift(std::move(zShift_in)), zlength(zlength_in), ydown_index(mesh.ystart) { + + if (mesh.getNZPE() > 1) { + throw BoutException("ShiftedMetricInterp only works with 1 processor in Z"); + } + // check the coordinate system used for the grid data source ShiftedMetricInterp::checkInputGrid(); // Allocate space for interpolator cache: y-guard cells in each direction parallel_slice_interpolators.resize(mesh.ystart * 2); + const BoutReal z_factor = static_cast(mesh.GlobalNzNoBoundaries) / zlength; + // Create the Interpolation objects and set whether they go up or down the // magnetic field auto& interp_options = options["zinterpolation"]; @@ -60,16 +72,16 @@ ShiftedMetricInterp::ShiftedMetricInterp(Mesh& mesh, CELL_LOC location_in, // Find the index positions where the magnetic field line intersects the x-z plane // y_offset points up - Field3D zt_prime_up(&mesh), zt_prime_down(&mesh); + Field3D zt_prime_up(&mesh); + Field3D zt_prime_down(&mesh); zt_prime_up.allocate(); zt_prime_down.allocate(); for (const auto& i : zt_prime_up.getRegion(RGN_NOY)) { // Field line moves in z by an angle zShift(i,j+1)-zShift(i,j) when going // from j to j+1, but we want the shift in index-space - zt_prime_up[i] = static_cast(i.z()) - + (zShift[i.yp(y_offset + 1)] - zShift[i]) - * static_cast(mesh.GlobalNz) / zlength; + zt_prime_up[i] = static_cast(mesh.getGlobalZIndexNoBoundaries(i.z())) + + ((zShift[i.yp(y_offset + 1)] - zShift[i]) * z_factor); } parallel_slice_interpolators[yup_index + y_offset]->calcWeights(zt_prime_up); @@ -77,9 +89,8 @@ ShiftedMetricInterp::ShiftedMetricInterp(Mesh& mesh, CELL_LOC location_in, for (const auto& i : zt_prime_down.getRegion(RGN_NOY)) { // Field line moves in z by an angle -(zShift(i,j)-zShift(i,j-1)) when going // from j to j-1, but we want the shift in index-space - zt_prime_down[i] = static_cast(i.z()) - - (zShift[i] - zShift[i.ym(y_offset + 1)]) - * static_cast(mesh.GlobalNz) / zlength; + zt_prime_down[i] = static_cast(mesh.getGlobalZIndexNoBoundaries(i.z())) + - ((zShift[i] - zShift[i.ym(y_offset + 1)]) * z_factor); } parallel_slice_interpolators[ydown_index + y_offset]->calcWeights(zt_prime_down); @@ -91,15 +102,16 @@ ShiftedMetricInterp::ShiftedMetricInterp(Mesh& mesh, CELL_LOC location_in, interp_from_aligned = ZInterpolationFactory::getInstance().create(&interp_options, 0, &mesh); - Field3D zt_prime_to(&mesh), zt_prime_from(&mesh); + Field3D zt_prime_to(&mesh); + Field3D zt_prime_from(&mesh); zt_prime_to.allocate(); zt_prime_from.allocate(); for (const auto& i : zt_prime_to) { // Field line moves in z by an angle zShift(i,j) when going // from y0 to y(j), but we want the shift in index-space - zt_prime_to[i] = static_cast(i.z()) - + zShift[i] * static_cast(mesh.GlobalNz) / zlength; + zt_prime_to[i] = static_cast(mesh.getGlobalZIndexNoBoundaries(i.z())) + + (zShift[i] * z_factor); } interp_to_aligned->calcWeights(zt_prime_to); @@ -108,17 +120,16 @@ ShiftedMetricInterp::ShiftedMetricInterp(Mesh& mesh, CELL_LOC location_in, // Field line moves in z by an angle zShift(i,j) when going // from y0 to y(j), but we want the shift in index-space. // Here we reverse the shift, so subtract zShift - zt_prime_from[i] = static_cast(i.z()) - - zShift[i] * static_cast(mesh.GlobalNz) / zlength; + zt_prime_from[i] = static_cast(mesh.getGlobalZIndexNoBoundaries(i.z())) + - (zShift[i] * z_factor); } interp_from_aligned->calcWeights(zt_prime_from); - int yvalid = mesh.LocalNy - 2 * mesh.ystart; - // avoid overflow - no stencil need more than 5 points - if (yvalid > 20) { - yvalid = 20; - } + // avoid overflow - no stencil needs more than 5 points + const auto yvalid = + static_cast(std::min(mesh.LocalNy - (2 * mesh.ystart), 20)); + // Create regions for parallel boundary conditions Field2D dy; mesh.get(dy, "dy", 1.); @@ -128,10 +139,10 @@ ShiftedMetricInterp::ShiftedMetricInterp(Mesh& mesh, CELL_LOC location_in, for (int z = mesh.zstart; z <= mesh.zend; z++) { forward_boundary_xin->add_point( it.ind, mesh.yend, z, - mesh.GlobalX(it.ind), // x - 2. * PI * mesh.GlobalY(mesh.yend + 0.5), // y - zlength * BoutReal(z) / BoutReal(mesh.GlobalNz) // z - + 0.5 * (zShift(it.ind, mesh.yend + 1) - zShift(it.ind, mesh.yend)), + mesh.GlobalX(it.ind), // x + 2. * PI * mesh.GlobalY(mesh.yend + 0.5), // y + (zlength * mesh.GlobalZ(z)) // z + + (0.5 * (zShift(it.ind, mesh.yend + 1) - zShift(it.ind, mesh.yend))), 0.25 * (1 // dy/2 + dy(it.ind, mesh.yend + 1) / dy(it.ind, mesh.yend)), // length @@ -144,10 +155,10 @@ ShiftedMetricInterp::ShiftedMetricInterp(Mesh& mesh, CELL_LOC location_in, for (int z = mesh.zstart; z <= mesh.zend; z++) { backward_boundary_xin->add_point( it.ind, mesh.ystart, z, - mesh.GlobalX(it.ind), // x - 2. * PI * mesh.GlobalY(mesh.ystart - 0.5), // y - zlength * BoutReal(z) / BoutReal(mesh.GlobalNz) // z - + 0.5 * (zShift(it.ind, mesh.ystart) - zShift(it.ind, mesh.ystart - 1)), + mesh.GlobalX(it.ind), // x + 2. * PI * mesh.GlobalY(mesh.ystart - 0.5), // y + (zlength * mesh.GlobalZ(z)) // z + + (0.5 * (zShift(it.ind, mesh.ystart) - zShift(it.ind, mesh.ystart - 1))), 0.25 * (1 // dy/2 + dy(it.ind, mesh.ystart - 1) / dy(it.ind, mesh.ystart)), @@ -161,10 +172,10 @@ ShiftedMetricInterp::ShiftedMetricInterp(Mesh& mesh, CELL_LOC location_in, for (int z = mesh.zstart; z <= mesh.zend; z++) { forward_boundary_xout->add_point( it.ind, mesh.yend, z, - mesh.GlobalX(it.ind), // x - 2. * PI * mesh.GlobalY(mesh.yend + 0.5), // y - zlength * BoutReal(z) / BoutReal(mesh.GlobalNz) // z - + 0.5 * (zShift(it.ind, mesh.yend + 1) - zShift(it.ind, mesh.yend)), + mesh.GlobalX(it.ind), // x + 2. * PI * mesh.GlobalY(mesh.yend + 0.5), // y + (zlength * mesh.GlobalZ(z)) // z + + (0.5 * (zShift(it.ind, mesh.yend + 1) - zShift(it.ind, mesh.yend))), 0.25 * (1 // dy/2 + dy(it.ind, mesh.yend + 1) / dy(it.ind, mesh.yend)), @@ -177,10 +188,10 @@ ShiftedMetricInterp::ShiftedMetricInterp(Mesh& mesh, CELL_LOC location_in, for (int z = mesh.zstart; z <= mesh.zend; z++) { backward_boundary_xout->add_point( it.ind, mesh.ystart, z, - mesh.GlobalX(it.ind), // x - 2. * PI * mesh.GlobalY(mesh.ystart - 0.5), // y - zlength * BoutReal(z) / BoutReal(mesh.GlobalNz) // z - + 0.5 * (zShift(it.ind, mesh.ystart) - zShift(it.ind, mesh.ystart - 1)), + mesh.GlobalX(it.ind), // x + 2. * PI * mesh.GlobalY(mesh.ystart - 0.5), // y + (zlength * mesh.GlobalZ(z)) // z + + (0.5 * (zShift(it.ind, mesh.ystart) - zShift(it.ind, mesh.ystart - 1))), 0.25 * (dy(it.ind, mesh.ystart - 1) / dy(it.ind, mesh.ystart) // dy/2 + 1), diff --git a/src/sys/generator_context.cxx b/src/sys/generator_context.cxx index 25274e8107..31a5662378 100644 --- a/src/sys/generator_context.cxx +++ b/src/sys/generator_context.cxx @@ -1,5 +1,7 @@ #include "bout/sys/generator_context.hxx" + #include "bout/boundary_region.hxx" +#include "bout/bout_types.hxx" #include "bout/constants.hxx" #include "bout/mesh.hxx" @@ -15,9 +17,8 @@ Context::Context(int ix, int iy, int iz, CELL_LOC loc, Mesh* msh, BoutReal t) parameters["y"] = (loc == CELL_YLOW) ? PI * (msh->GlobalY(iy) + msh->GlobalY(iy - 1)) : TWOPI * msh->GlobalY(iy); - parameters["z"] = (loc == CELL_ZLOW) - ? TWOPI * (iz - 0.5) / static_cast(msh->LocalNz) - : TWOPI * iz / static_cast(msh->LocalNz); + parameters["z"] = (loc == CELL_ZLOW) ? PI * (msh->GlobalZ(iz) + msh->GlobalZ(iz - 1)) + : TWOPI * msh->GlobalZ(iz); parameters["t"] = t; } @@ -26,21 +27,20 @@ Context::Context(const BoundaryRegion* bndry, int iz, CELL_LOC loc, BoutReal t, : localmesh(msh) { // Add one to X index if boundary is in -x direction, so that XLOW is on the boundary - int ix = (bndry->bx < 0) ? bndry->x + 1 : bndry->x; + const int ix = (bndry->bx < 0) ? bndry->x + 1 : bndry->x; parameters["x"] = ((loc == CELL_XLOW) || (bndry->bx != 0)) ? 0.5 * (msh->GlobalX(ix) + msh->GlobalX(ix - 1)) : msh->GlobalX(ix); - int iy = (bndry->by < 0) ? bndry->y + 1 : bndry->y; + const int iy = (bndry->by < 0) ? bndry->y + 1 : bndry->y; - parameters["y"] = ((loc == CELL_YLOW) || bndry->by) + parameters["y"] = ((loc == CELL_YLOW) || (bndry->by != 0)) ? PI * (msh->GlobalY(iy) + msh->GlobalY(iy - 1)) : TWOPI * msh->GlobalY(iy); - parameters["z"] = (loc == CELL_ZLOW) - ? TWOPI * (iz - 0.5) / static_cast(msh->LocalNz) - : TWOPI * iz / static_cast(msh->LocalNz); + parameters["z"] = (loc == CELL_ZLOW) ? PI * (msh->GlobalZ(iz) + msh->GlobalZ(iz - 1)) + : TWOPI * msh->GlobalZ(iz); parameters["t"] = t; } diff --git a/tests/MMS/laplace/laplace.cxx b/tests/MMS/laplace/laplace.cxx index fbcdee355c..214c022cd5 100644 --- a/tests/MMS/laplace/laplace.cxx +++ b/tests/MMS/laplace/laplace.cxx @@ -26,9 +26,8 @@ int main(int argc, char** argv) { meshoptions->get("Lx", Lx, 1.0); /*this assumes equidistant grid*/ - int nguard = mesh->xstart; - mesh->getCoordinates()->dx = Lx / (mesh->GlobalNx - 2 * nguard); - mesh->getCoordinates()->dz = TWOPI * Lx / (mesh->LocalNz); + mesh->getCoordinates()->dx = Lx / (mesh->GlobalNx - 2 * mesh->xstart); + mesh->getCoordinates()->dz = TWOPI * Lx / (mesh->GlobalNz - 2 * mesh->zstart); ///// // Create a Laplacian inversion solver diff --git a/tests/integrated/test-communications/test-communications.cxx b/tests/integrated/test-communications/test-communications.cxx index 54c266ece2..5acbe3ee1a 100644 --- a/tests/integrated/test-communications/test-communications.cxx +++ b/tests/integrated/test-communications/test-communications.cxx @@ -11,10 +11,10 @@ int main(int argc, char** argv) { // interior cells BOUT_FOR(i, f.getRegion("RGN_NOBNDRY")) { - f[i] = mesh->GlobalNzNoBoundaries - * (mesh->GlobalNyNoBoundaries * mesh->getGlobalXIndexNoBoundaries(i.x()) - + mesh->getGlobalYIndexNoBoundaries(i.y())) - + i.z(); + f[i] = (mesh->GlobalNzNoBoundaries + * (mesh->GlobalNyNoBoundaries * mesh->getGlobalXIndexNoBoundaries(i.x()) + + mesh->getGlobalYIndexNoBoundaries(i.y()))) + + mesh->getGlobalZIndexNoBoundaries(i.z()); } // lower x-boundary cells @@ -25,10 +25,10 @@ int main(int argc, char** argv) { for (int y = mesh->ystart; y <= mesh->yend; y++) { for (int z = mesh->zstart; z <= mesh->zend; z++) { f(x, y, z) = startind - + mesh->GlobalNzNoBoundaries - * (mesh->GlobalNyNoBoundaries * x - + mesh->getGlobalYIndexNoBoundaries(y)) - + z; + + (mesh->GlobalNzNoBoundaries + * (mesh->GlobalNyNoBoundaries * x + + mesh->getGlobalYIndexNoBoundaries(y))) + + mesh->getGlobalZIndexNoBoundaries(z); } } } @@ -41,10 +41,10 @@ int main(int argc, char** argv) { for (int y = mesh->ystart; y <= mesh->yend; y++) { for (int z = mesh->zstart; z <= mesh->zend; z++) { f(mesh->xend + 1 + x, y, z) = startind - + mesh->GlobalNzNoBoundaries - * (mesh->GlobalNyNoBoundaries * x - + mesh->getGlobalYIndexNoBoundaries(y)) - + z; + + (mesh->GlobalNzNoBoundaries + * (mesh->GlobalNyNoBoundaries * x + + mesh->getGlobalYIndexNoBoundaries(y))) + + mesh->getGlobalZIndexNoBoundaries(z); } } } @@ -56,8 +56,9 @@ int main(int argc, char** argv) { int x = it.ind; for (int y = 0; y < mesh->ystart; y++) { for (int z = mesh->zstart; z <= mesh->zend; z++) { - f(x, y, z) = - startind + mesh->GlobalNzNoBoundaries * (mesh->getGlobalXIndex(x) + y) + z; + f(x, y, z) = startind + + (mesh->GlobalNzNoBoundaries * (mesh->getGlobalXIndex(x) + y)) + + mesh->getGlobalZIndexNoBoundaries(z); } } } @@ -69,7 +70,8 @@ int main(int argc, char** argv) { for (int y = 0; y < mesh->ystart; y++) { for (int z = mesh->zstart; z <= mesh->zend; z++) { f(x, mesh->yend + 1 + y, z) = - startind + mesh->GlobalNzNoBoundaries * (mesh->getGlobalXIndex(x) + y) + z; + startind + (mesh->GlobalNzNoBoundaries * (mesh->getGlobalXIndex(x) + y)) + + mesh->getGlobalZIndexNoBoundaries(z); } } } diff --git a/tests/integrated/test-interpolate-z/test_interpolate.cxx b/tests/integrated/test-interpolate-z/test_interpolate.cxx index 3ef6e50425..8ee96d2897 100644 --- a/tests/integrated/test-interpolate-z/test_interpolate.cxx +++ b/tests/integrated/test-interpolate-z/test_interpolate.cxx @@ -71,8 +71,7 @@ int main(int argc, char** argv) { deltaz[index] = dz; // Get the global indices bout::generator::Context pos{index, CELL_CENTRE, deltaz.getMesh(), 0.0}; - pos.set("x", mesh->GlobalX(index.x()), "z", - TWOPI * static_cast(dz) / static_cast(mesh->LocalNz)); + pos.set("x", mesh->GlobalX(index.x()), "z", TWOPI * mesh->GlobalZ(dz)); // Generate the analytic solution at the displacements a_solution[index] = a_gen->generate(pos); b_solution[index] = b_gen->generate(pos); diff --git a/tests/integrated/test-interpolate/test_interpolate.cxx b/tests/integrated/test-interpolate/test_interpolate.cxx index 33963dbb9e..98ecd68bc5 100644 --- a/tests/integrated/test-interpolate/test_interpolate.cxx +++ b/tests/integrated/test-interpolate/test_interpolate.cxx @@ -79,8 +79,7 @@ int main(int argc, char** argv) { deltaz[index] = dz; // Get the global indices bout::generator::Context pos{index, CELL_CENTRE, deltax.getMesh(), 0.0}; - pos.set("x", mesh->GlobalX(dx), "z", - TWOPI * static_cast(dz) / static_cast(mesh->LocalNz)); + pos.set("x", mesh->GlobalX(dx), "z", TWOPI * mesh->GlobalZ(dz)); // Generate the analytic solution at the displacements a_solution[index] = a_gen->generate(pos); b_solution[index] = b_gen->generate(pos); diff --git a/tests/unit/fake_mesh.hxx b/tests/unit/fake_mesh.hxx index 2feb43826d..7d7326c149 100644 --- a/tests/unit/fake_mesh.hxx +++ b/tests/unit/fake_mesh.hxx @@ -66,15 +66,13 @@ public: xend = nx - 2; ystart = 1; yend = ny - 2; - zstart = 0; + zstart = 0; // no guards zend = nz - 1; StaggerGrids = false; // Unused variables periodicX = false; - NXPE = 1; - PE_XIND = 0; IncIntShear = false; maxregionblocksize = MAXREGIONBLOCKSIZE; @@ -109,10 +107,12 @@ public: return nullptr; } int wait(comm_handle UNUSED(handle)) override { return 0; } - int getNXPE() override { return 1; } - int getNYPE() override { return 1; } - int getXProcIndex() override { return 1; } - int getYProcIndex() override { return 1; } + int getNXPE() const override { return 1; } + int getNYPE() const override { return 1; } + int getNZPE() const override { return 1; } + int getXProcIndex() const override { return 1; } + int getYProcIndex() const override { return 1; } + int getZProcIndex() const override { return 1; } bool firstX() const override { return true; } bool lastX() const override { return true; } int sendXOut(BoutReal* UNUSED(buffer), int UNUSED(size), int UNUSED(tag)) override { @@ -170,20 +170,22 @@ public: } BoutReal GlobalX(int jx) const override { return jx; } BoutReal GlobalY(int jy) const override { return jy; } + BoutReal GlobalZ(int jz) const override { return jz; } BoutReal GlobalX(BoutReal jx) const override { return jx; } BoutReal GlobalY(BoutReal jy) const override { return jy; } + BoutReal GlobalZ(BoutReal jz) const override { return jz; } int getGlobalXIndex(int) const override { return 0; } int getGlobalXIndexNoBoundaries(int) const override { return 0; } int getGlobalYIndex(int y) const override { return y; } int getGlobalYIndexNoBoundaries(int y) const override { return y; } - int getGlobalZIndex(int) const override { return 0; } - int getGlobalZIndexNoBoundaries(int) const override { return 0; } + int getGlobalZIndex(int z) const override { return z; } + int getGlobalZIndexNoBoundaries(int z) const override { return z; } int getLocalXIndex(int) const override { return 0; } int getLocalXIndexNoBoundaries(int) const override { return 0; } int getLocalYIndex(int y) const override { return y; } int getLocalYIndexNoBoundaries(int y) const override { return y; } - int getLocalZIndex(int) const override { return 0; } - int getLocalZIndexNoBoundaries(int) const override { return 0; } + int getLocalZIndex(int z) const override { return z; } + int getLocalZIndexNoBoundaries(int z) const override { return z; } void initDerivs(Options* opt) { StaggerGrids = true; diff --git a/tests/unit/field/test_field_factory.cxx b/tests/unit/field/test_field_factory.cxx index 26dc1e2990..b45206b979 100644 --- a/tests/unit/field/test_field_factory.cxx +++ b/tests/unit/field/test_field_factory.cxx @@ -161,9 +161,7 @@ TYPED_TEST(FieldFactoryCreationTest, CreateZ) { auto output = this->create("z"); auto expected = makeField( - [](typename TypeParam::ind_type& index) -> BoutReal { - return TWOPI * index.z() / FieldFactoryCreationTest::nz; - }, + [](typename TypeParam::ind_type& index) -> BoutReal { return TWOPI * index.z(); }, mesh); EXPECT_TRUE(IsFieldEqual(output, expected)); @@ -209,7 +207,7 @@ TYPED_TEST(FieldFactoryCreationTest, CreateZStaggered) { offset = 0.5; } - return TWOPI * (index.z() - offset) / FieldFactoryCreationTest::nz; + return TWOPI * (index.z() - offset); }, mesh); diff --git a/tests/unit/include/bout/test_region.cxx b/tests/unit/include/bout/test_region.cxx index 3b21700412..00137c1ce7 100644 --- a/tests/unit/include/bout/test_region.cxx +++ b/tests/unit/include/bout/test_region.cxx @@ -211,8 +211,8 @@ TEST_F(RegionTest, regionLoopNoBndry) { BOUT_FOR(i, region) { a[i] = 1.0; } const int nmesh = RegionTest::nx * RegionTest::ny * RegionTest::nz; - const int ninner = - (mesh->LocalNz * (1 + mesh->xend - mesh->xstart) * (1 + mesh->yend - mesh->ystart)); + const int ninner = ((1 + mesh->zend - mesh->zstart) * (1 + mesh->xend - mesh->xstart) + * (1 + mesh->yend - mesh->ystart)); int numExpectNotMatching = nmesh - ninner; int numNotMatching = 0; @@ -249,8 +249,8 @@ TEST_F(RegionTest, regionLoopNoBndrySerial) { int count = 0; BOUT_FOR_SERIAL(i, region) { ++count; } - const int ninner = - (mesh->LocalNz * (1 + mesh->xend - mesh->xstart) * (1 + mesh->yend - mesh->ystart)); + const int ninner = ((1 + mesh->zend - mesh->zstart) * (1 + mesh->xend - mesh->xstart) + * (1 + mesh->yend - mesh->ystart)); EXPECT_EQ(count, ninner); } @@ -300,8 +300,8 @@ TEST_F(RegionTest, regionLoopNoBndrySection) { } } - const int ninner = - (mesh->LocalNz * (1 + mesh->xend - mesh->xstart) * (1 + mesh->yend - mesh->ystart)); + const int ninner = ((1 + mesh->zend - mesh->zstart) * (1 + mesh->xend - mesh->xstart) + * (1 + mesh->yend - mesh->ystart)); EXPECT_EQ(count, ninner); } @@ -334,8 +334,8 @@ TEST_F(RegionTest, regionLoopNoBndryInner) { } const int nmesh = RegionTest::nx * RegionTest::ny * RegionTest::nz; - const int ninner = - (mesh->LocalNz * (1 + mesh->xend - mesh->xstart) * (1 + mesh->yend - mesh->ystart)); + const int ninner = ((1 + mesh->zend - mesh->zstart) * (1 + mesh->xend - mesh->xstart) + * (1 + mesh->yend - mesh->ystart)); int numExpectNotMatching = nmesh - ninner; int numNotMatching = 0; diff --git a/tests/unit/mesh/data/test_gridfromoptions.cxx b/tests/unit/mesh/data/test_gridfromoptions.cxx index b821d2e225..41dd40fcc3 100644 --- a/tests/unit/mesh/data/test_gridfromoptions.cxx +++ b/tests/unit/mesh/data/test_gridfromoptions.cxx @@ -57,13 +57,13 @@ class GridFromOptionsTest : public ::testing::Test { expected_2d = makeField( [](Field2D::ind_type& index) { - return index.x() + (TWOPI * index.y()) + (TWOPI * index.z() / nz) + 3; + return index.x() + (TWOPI * index.y()) + (TWOPI * index.z()) + 3; }, &mesh_from_options); expected_3d = makeField( [](Field3D::ind_type& index) { - return index.x() + (TWOPI * index.y()) + (TWOPI * index.z() / nz) + 3; + return index.x() + (TWOPI * index.y()) + (TWOPI * index.z()) + 3; }, &mesh_from_options); expected_metric = @@ -316,8 +316,8 @@ TEST_F(GridFromOptionsTest, GetVectorBoutRealYNone) { TEST_F(GridFromOptionsTest, GetVectorBoutRealZ) { std::vector result{}; - std::vector expected{3., 3. + (1. * TWOPI / nz), 3. + (2. * TWOPI / nz), - 3. + (3. * TWOPI / nz), 3. + (4. * TWOPI / nz)}; + std::vector expected{3., 3. + (1. * TWOPI), 3. + (2. * TWOPI), + 3. + (3. * TWOPI), 3. + (4. * TWOPI)}; EXPECT_TRUE(griddata->get(&mesh_from_options, result, "f", nz, 0, GridDataSource::Direction::Z)); @@ -326,9 +326,8 @@ TEST_F(GridFromOptionsTest, GetVectorBoutRealZ) { TEST_F(GridFromOptionsTest, GetVectorBoutRealZOffset) { std::vector result{}; - std::vector expected{3. + (1. * TWOPI / nz), 3. + (2. * TWOPI / nz), - 3. + (3. * TWOPI / nz), 3. + (4. * TWOPI / nz), - 3. + (5. * TWOPI / nz)}; + std::vector expected{3. + (1. * TWOPI), 3. + (2. * TWOPI), 3. + (3. * TWOPI), + 3. + (4. * TWOPI), 3. + (5. * TWOPI)}; EXPECT_TRUE(griddata->get(&mesh_from_options, result, "f", nz, 1, GridDataSource::Direction::Z)); @@ -337,8 +336,8 @@ TEST_F(GridFromOptionsTest, GetVectorBoutRealZOffset) { TEST_F(GridFromOptionsTest, GetVectorBoutRealZMeshOffset) { std::vector result{}; - std::vector expected{3. + (-1. * TWOPI / nz), 3., 3. + (1. * TWOPI / nz), - 3. + (2. * TWOPI / nz), 3. + (3. * TWOPI / nz)}; + std::vector expected{3. + (-1. * TWOPI), 3., 3. + (1. * TWOPI), + 3. + (2. * TWOPI), 3. + (3. * TWOPI)}; mesh_from_options.OffsetX = 100; mesh_from_options.OffsetY = 100; @@ -397,7 +396,7 @@ TEST_F(GridFromOptionsTest, CoordinatesXlowInterp) { Coordinates::FieldMetric expected_xlow = makeField( [](Coordinates::FieldMetric::ind_type& index) { - return index.x() - 0.5 + (TWOPI * index.y()) + (TWOPI * index.z() / nz) + 3; + return index.x() - 0.5 + (TWOPI * index.y()) + (TWOPI * index.z()) + 3; }, &mesh_from_options); @@ -438,8 +437,7 @@ TEST_F(GridFromOptionsTest, CoordinatesXlowRead) { Field2D expected_xlow = makeField( [](Field2D::ind_type& index) { - return (nx - index.x() + 0.5) + (TWOPI * index.y()) + (TWOPI * index.z() / nz) - + 3; + return (nx - index.x() + 0.5) + (TWOPI * index.y()) + (TWOPI * index.z()) + 3; }, &mesh_from_options); @@ -471,7 +469,7 @@ TEST_F(GridFromOptionsTest, CoordinatesYlowInterp) { Field2D expected_ylow = makeField( [](Field2D::ind_type& index) { - return index.x() + (TWOPI * (index.y() - 0.5)) + (TWOPI * index.z() / nz) + 3; + return index.x() + (TWOPI * (index.y() - 0.5)) + (TWOPI * index.z()) + 3; }, &mesh_from_options); @@ -521,8 +519,7 @@ TEST_F(GridFromOptionsTest, CoordinatesYlowRead) { Field2D expected_ylow = makeField( [](Field2D::ind_type& index) { - return index.x() + (TWOPI * (ny - index.y() + 0.5)) + (TWOPI * index.z() / nz) - + 3; + return index.x() + (TWOPI * (ny - index.y() + 0.5)) + (TWOPI * index.z()) + 3; }, &mesh_from_options);