Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion include/bout/fft.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,15 @@
#ifndef BOUT_FFT_H
#define BOUT_FFT_H

#include "bout/dcomplex.hxx"
#include "bout/build_defines.hxx"

#include <bout/array.hxx>
#include <bout/bout_enum_class.hxx>
#include <bout/dcomplex.hxx>

#include <string_view>

class Mesh;
class Options;

BOUT_ENUM_CLASS(FFT_MEASUREMENT_FLAG, estimate, measure, exhaustive);
Expand Down Expand Up @@ -111,6 +116,16 @@ Array<dcomplex> rfft(const Array<BoutReal>& in);
/// Expects that `in.size() == (length / 2) + 1`
Array<BoutReal> irfft(const Array<dcomplex>& in, int length);

/// Check simulation is using 1 processor in Z, throw exception if not
///
/// Generally, FFTs must be done over the full Z domain. Currently, most
/// methods using FFTs don't handle parallelising in Z
#if BOUT_CHECK_LEVEL > 0
void assertZSerial(const Mesh& mesh, std::string_view name);
#else
inline void assertZSerial([[maybe_unused]] const Mesh& mesh,
[[maybe_unused]] std::string_view name) {}
#endif
} // namespace fft
} // namespace bout

Expand Down
14 changes: 8 additions & 6 deletions include/bout/mesh.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -354,10 +354,12 @@ public:

// non-local communications

virtual int getNXPE() = 0; ///< The number of processors in the X direction
virtual int getNYPE() = 0; ///< The number of processors in the Y direction
virtual int getXProcIndex() = 0; ///< This processor's index in X direction
virtual int getYProcIndex() = 0; ///< This processor's index in Y direction
virtual int getNXPE() const = 0; ///< The number of processors in the X direction
virtual int getNYPE() const = 0; ///< The number of processors in the Y direction
virtual int getNZPE() const = 0; ///< The number of processors in the Z direction
virtual int getXProcIndex() const = 0; ///< This processor's index in X direction
virtual int getYProcIndex() const = 0; ///< This processor's index in Y direction
virtual int getZProcIndex() const = 0; ///< This processor's index in Z direction

// X communications
virtual bool firstX()
Expand All @@ -368,8 +370,6 @@ public:
/// Domain is periodic in X?
bool periodicX{false};

int NXPE, PE_XIND; ///< Number of processors in X, and X processor index

/// Send a buffer of data to processor at X index +1
///
/// @param[in] buffer The data to send. Must be at least length \p size
Expand Down Expand Up @@ -507,8 +507,10 @@ public:

virtual BoutReal GlobalX(int jx) const = 0; ///< Continuous X index between 0 and 1
virtual BoutReal GlobalY(int jy) const = 0; ///< Continuous Y index (0 -> 1)
virtual BoutReal GlobalZ(int jz) const = 0; ///< Continuous Z index (0 -> 1)
virtual BoutReal GlobalX(BoutReal jx) const = 0; ///< Continuous X index between 0 and 1
virtual BoutReal GlobalY(BoutReal jy) const = 0; ///< Continuous Y index (0 -> 1)
virtual BoutReal GlobalZ(BoutReal jz) const = 0; ///< Continuous Z index (0 -> 1)

//////////////////////////////////////////////////////////

Expand Down
4 changes: 3 additions & 1 deletion src/field/field3d.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,7 @@ FieldPerp pow(const Field3D& lhs, const FieldPerp& rhs, const std::string& rgn)

Field3D filter(const Field3D& var, int N0, const std::string& rgn) {

bout::fft::assertZSerial(*var.getMesh(), "`filter`");
checkData(var);

int ncz = var.getNz();
Expand Down Expand Up @@ -649,6 +650,7 @@ Field3D filter(const Field3D& var, int N0, const std::string& rgn) {
// Fourier filter in z with zmin
Field3D lowPass(const Field3D& var, int zmax, bool keep_zonal, const std::string& rgn) {

bout::fft::assertZSerial(*var.getMesh(), "`lowPass`");
checkData(var);
int ncz = var.getNz();

Expand Down Expand Up @@ -697,7 +699,7 @@ Field3D lowPass(const Field3D& var, int zmax, bool keep_zonal, const std::string
* Use FFT to shift by an angle in the Z direction
*/
void shiftZ(Field3D& var, int jx, int jy, double zangle) {

bout::fft::assertZSerial(*var.getMesh(), "`shiftZ`");
checkData(var);
var.allocate(); // Ensure that var is unique
Mesh* localmesh = var.getMesh();
Expand Down
18 changes: 17 additions & 1 deletion src/invert/fft_fftw.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,17 @@

#include "bout/build_defines.hxx"

#include <bout/coordinates.hxx>
#include <bout/fft.hxx>
#include <bout/globals.hxx>
#include <bout/mesh.hxx>
#include <bout/options.hxx>
#include <bout/unused.hxx>

#if BOUT_HAS_FFTW
#include <bout/constants.hxx>
#include <bout/openmpwrap.hxx>

#include <cmath>
#include <fftw3.h>

#if BOUT_USE_OPENMP
Expand All @@ -46,6 +47,12 @@
#include <bout/boutexception.hxx>
#endif // BOUT_HAS_FFTW

#if BOUT_CHECK_LEVEL > 0
#include <bout/boutexception.hxx>

#include <string_view>
#endif

namespace bout {
namespace fft {

Expand Down Expand Up @@ -527,5 +534,14 @@ Array<BoutReal> irfft(const Array<dcomplex>& in, int length) {
return out;
}

#if BOUT_CHECK_LEVEL > 0
void assertZSerial(const Mesh& mesh, std::string_view name) {
if (mesh.getNZPE() != 1) {
throw BoutException("{} uses FFTs which are currently incompatible with multiple "
"processors in Z (using {})",
name, mesh.getNZPE());
}
}
#endif
} // namespace fft
} // namespace bout
2 changes: 2 additions & 0 deletions src/invert/laplace/impls/cyclic/cyclic_laplace.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ LaplaceCyclic::LaplaceCyclic(Options* opt, const CELL_LOC loc, Mesh* mesh_in,
Solver* UNUSED(solver))
: Laplacian(opt, loc, mesh_in), Acoef(0.0), C1coef(1.0), C2coef(1.0), Dcoef(1.0) {

bout::fft::assertZSerial(*localmesh, "`cyclic` inversion");

Acoef.setLocation(location);
C1coef.setLocation(location);
C2coef.setLocation(location);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,14 @@ LaplaceIPT::LaplaceIPT(Options* opt, CELL_LOC loc, Mesh* mesh_in, Solver* UNUSED
au(ny, nmode), bu(ny, nmode), rl(nmode), ru(nmode), r1(ny, nmode), r2(ny, nmode),
first_call(ny), x0saved(ny, 4, nmode), converged(nmode), fine_error(4, nmode) {

bout::fft::assertZSerial(*localmesh, "`ipt` inversion");

A.setLocation(location);
C.setLocation(location);
D.setLocation(location);

// Number of procs must be a factor of 2
const int n = localmesh->NXPE;
const int n = localmesh->getNXPE();
if (!is_pow2(n)) {
throw BoutException("LaplaceIPT error: NXPE must be a power of 2");
}
Expand Down
2 changes: 2 additions & 0 deletions src/invert/laplace/impls/pcr/pcr.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ LaplacePCR::LaplacePCR(Options* opt, CELL_LOC loc, Mesh* mesh_in, Solver* UNUSED
ncx(localmesh->LocalNx), ny(localmesh->LocalNy), avec(ny, nmode, ncx),
bvec(ny, nmode, ncx), cvec(ny, nmode, ncx) {

bout::fft::assertZSerial(*localmesh, "`pcr` inversion");

Acoef.setLocation(location);
C1coef.setLocation(location);
C2coef.setLocation(location);
Expand Down
2 changes: 2 additions & 0 deletions src/invert/laplace/impls/pcr_thomas/pcr_thomas.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ LaplacePCR_THOMAS::LaplacePCR_THOMAS(Options* opt, CELL_LOC loc, Mesh* mesh_in,
ncx(localmesh->LocalNx), ny(localmesh->LocalNy), avec(ny, nmode, ncx),
bvec(ny, nmode, ncx), cvec(ny, nmode, ncx) {

bout::fft::assertZSerial(*localmesh, "`pcr_thomas` inversion");

Acoef.setLocation(location);
C1coef.setLocation(location);
C2coef.setLocation(location);
Expand Down
3 changes: 3 additions & 0 deletions src/invert/laplace/impls/serial_band/serial_band.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@
LaplaceSerialBand::LaplaceSerialBand(Options* opt, const CELL_LOC loc, Mesh* mesh_in,
Solver* UNUSED(solver))
: Laplacian(opt, loc, mesh_in), Acoef(0.0), Ccoef(1.0), Dcoef(1.0) {

bout::fft::assertZSerial(*localmesh, "`band` inversion");

Acoef.setLocation(location);
Ccoef.setLocation(location);
Dcoef.setLocation(location);
Expand Down
7 changes: 4 additions & 3 deletions src/invert/laplace/impls/serial_tri/serial_tri.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,15 @@
#include <bout/lapack_routines.hxx>
#include <bout/mesh.hxx>
#include <bout/openmpwrap.hxx>
#include <bout/utils.hxx>
#include <cmath>

#include <bout/output.hxx>
#include <bout/utils.hxx>

LaplaceSerialTri::LaplaceSerialTri(Options* opt, CELL_LOC loc, Mesh* mesh_in,
Solver* UNUSED(solver))
: Laplacian(opt, loc, mesh_in), A(0.0), C(1.0), D(1.0) {

bout::fft::assertZSerial(*localmesh, "`tri` inversion");

A.setLocation(location);
C.setLocation(location);
D.setLocation(location);
Expand Down
15 changes: 9 additions & 6 deletions src/invert/laplace/impls/spt/spt.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@
LaplaceSPT::LaplaceSPT(Options* opt, const CELL_LOC loc, Mesh* mesh_in,
Solver* UNUSED(solver))
: Laplacian(opt, loc, mesh_in), Acoef(0.0), Ccoef(1.0), Dcoef(1.0) {

bout::fft::assertZSerial(*localmesh, "`spt` inversion");

Acoef.setLocation(location);
Ccoef.setLocation(location);
Dcoef.setLocation(location);
Expand Down Expand Up @@ -341,14 +344,14 @@ int LaplaceSPT::start(const FieldPerp& b, SPT_data& data) {
// Send data
localmesh->sendXOut(std::begin(data.buffer), 4 * (maxmode + 1), data.comm_tag);

} else if (localmesh->PE_XIND == 1) {
} else if (localmesh->getXProcIndex() == 1) {
// Post a receive
data.recv_handle =
localmesh->irecvXIn(std::begin(data.buffer), 4 * (maxmode + 1), data.comm_tag);
}

data.proc++; // Now moved onto the next processor
if (localmesh->NXPE == 2) {
if (localmesh->getNXPE() == 2) {
data.dir = -1; // Special case. Otherwise reversal handled in spt_continue
}

Expand All @@ -366,7 +369,7 @@ int LaplaceSPT::next(SPT_data& data) {
return 1;
}

if (localmesh->PE_XIND == data.proc) {
if (localmesh->getXProcIndex() == data.proc) {
/// This processor's turn to do inversion

// Wait for data to arrive
Expand Down Expand Up @@ -450,7 +453,7 @@ int LaplaceSPT::next(SPT_data& data) {
}
}

if (localmesh->PE_XIND != 0) { // If not finished yet
if (localmesh->getXProcIndex() != 0) { // If not finished yet
/// Send data

if (data.dir > 0) {
Expand All @@ -460,7 +463,7 @@ int LaplaceSPT::next(SPT_data& data) {
}
}

} else if (localmesh->PE_XIND == data.proc + data.dir) {
} else if (localmesh->getXProcIndex() == data.proc + data.dir) {
// This processor is next, post receive

if (data.dir > 0) {
Expand All @@ -474,7 +477,7 @@ int LaplaceSPT::next(SPT_data& data) {

data.proc += data.dir;

if (data.proc == localmesh->NXPE - 1) {
if (data.proc == localmesh->getNXPE() - 1) {
data.dir = -1; // Reverses direction at the end
}

Expand Down
1 change: 1 addition & 0 deletions src/invert/laplacexz/impls/cyclic/laplacexz-cyclic.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
LaplaceXZcyclic::LaplaceXZcyclic(Mesh* m, Options* options, const CELL_LOC loc)
: LaplaceXZ(m, options, loc) {
// Note: `m` may be nullptr, but localmesh is set in LaplaceXZ base constructor
bout::fft::assertZSerial(*localmesh, "`cyclic` X-Z inversion");

// Number of Z Fourier modes, including DC
nmode = (localmesh->LocalNz) / 2 + 1;
Expand Down
2 changes: 2 additions & 0 deletions src/invert/parderiv/impls/cyclic/cyclic.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@

InvertParCR::InvertParCR(Options* opt, CELL_LOC location, Mesh* mesh_in)
: InvertPar(opt, location, mesh_in), A(1.0), B(0.0), C(0.0), D(0.0), E(0.0) {

bout::fft::assertZSerial(*localmesh, "InvertParCR");
// Number of k equations to solve for each x location
nsys = 1 + (localmesh->LocalNz) / 2;

Expand Down
1 change: 1 addition & 0 deletions src/invert/pardiv/impls/cyclic/pardiv_cyclic.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@

InvertParDivCR::InvertParDivCR(Options* opt, CELL_LOC location, Mesh* mesh_in)
: InvertParDiv(opt, location, mesh_in) {
bout::fft::assertZSerial(*localmesh, "InvertParDivCR");
// Number of k equations to solve for each x location
nsys = 1 + (localmesh->LocalNz) / 2;
}
Expand Down
4 changes: 4 additions & 0 deletions src/mesh/boundary_standard.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -2631,6 +2631,7 @@ void BoundaryNeumann_NonOrthogonal::apply(Field3D& f) {
#if not(BOUT_USE_METRIC_3D)
Mesh* mesh = bndry->localmesh;
ASSERT1(mesh == f.getMesh());
bout::fft::assertZSerial(*mesh, "Zero Laplace on Field3D");
int ncz = mesh->LocalNz;

Coordinates* metric = f.getCoordinates();
Expand Down Expand Up @@ -2734,6 +2735,7 @@ void BoundaryNeumann_NonOrthogonal::apply(Field3D& f) {
#if not(BOUT_USE_METRIC_3D)
Mesh* mesh = bndry->localmesh;
ASSERT1(mesh == f.getMesh());
bout::fft::assertZSerial(*mesh, "Zero Laplace on Field3D");
const int ncz = mesh->LocalNz;

ASSERT0(ncz % 2 == 0); // Allocation assumes even number
Expand Down Expand Up @@ -2843,6 +2845,8 @@ void BoundaryNeumann_NonOrthogonal::apply(Field3D& f) {

Mesh* mesh = bndry->localmesh;
ASSERT1(mesh == f.getMesh());
bout::fft::assertZSerial(*mesh, "Zero Laplace on Field3D");

Coordinates* metric = f.getCoordinates();

int ncz = mesh->LocalNz;
Expand Down
4 changes: 2 additions & 2 deletions src/mesh/coordinates.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1659,7 +1659,7 @@ Field3D Coordinates::Delp2(const Field3D& f, CELL_LOC outloc, bool useFFT) {

Field3D result{emptyFrom(f).setLocation(outloc)};

if (useFFT and not bout::build::use_metric_3d) {
if (useFFT and not bout::build::use_metric_3d and localmesh->getNZPE() == 1) {
int ncz = localmesh->LocalNz;

// Allocate memory
Expand Down Expand Up @@ -1727,7 +1727,7 @@ FieldPerp Coordinates::Delp2(const FieldPerp& f, CELL_LOC outloc, bool useFFT) {
int jy = f.getIndex();
result.setIndex(jy);

if (useFFT) {
if (useFFT and localmesh->getNZPE() == 1) {
int ncz = localmesh->LocalNz;

// Allocate memory
Expand Down
3 changes: 1 addition & 2 deletions src/mesh/data/gridfromoptions.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ bool GridFromOptions::get(Mesh* m, std::vector<BoutReal>& var, const std::string
}
case GridDataSource::Z: {
for (int z = 0; z < len; z++) {
pos.set("z",
(TWOPI * (z - m->OffsetZ + offset)) / static_cast<BoutReal>(m->LocalNz));
pos.set("z", TWOPI * m->GlobalZ(z - m->OffsetZ + offset));
var[z] = gen->generate(pos);
}
break;
Expand Down
Loading
Loading