Skip to content

Commit

Permalink
Clean up in SpectralFieldData for multi-dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
dpgrote committed Sep 10, 2024
1 parent 50e1326 commit 63eef4f
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 100 deletions.
9 changes: 4 additions & 5 deletions Source/FieldSolver/SpectralSolver/SpectralFieldData.H
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,10 @@ class SpectralFieldData
ablastr::math::anyfft::FFTplans forward_plan, backward_plan;
// Correcting "shift" factors when performing FFT from/to
// a cell-centered grid in real space, instead of a nodal grid
SpectralShiftFactor xshift_FFTfromCell, xshift_FFTtoCell,
zshift_FFTfromCell, zshift_FFTtoCell;
#if defined(WARPX_DIM_3D)
SpectralShiftFactor yshift_FFTfromCell, yshift_FFTtoCell;
#endif
// (0,1,2) is the dimension number
SpectralShiftFactor shift0_FFTfromCell, shift0_FFTtoCell,
shift1_FFTfromCell, shift1_FFTtoCell,
shift2_FFTfromCell, shift2_FFTtoCell;

bool m_periodic_single_box;
};
Expand Down
143 changes: 48 additions & 95 deletions Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,24 +142,21 @@ SpectralFieldData::SpectralFieldData( const int lev,
// By default, we assume the FFT is done from/to a nodal grid in real space
// If the FFT is performed from/to a cell-centered grid in real space,
// a correcting "shift" factor must be applied in spectral space.
xshift_FFTfromCell = k_space.getSpectralShiftFactor(dm, 0,
shift0_FFTfromCell = k_space.getSpectralShiftFactor(dm, 0,
ShiftType::TransformFromCellCentered);
xshift_FFTtoCell = k_space.getSpectralShiftFactor(dm, 0,
shift0_FFTtoCell = k_space.getSpectralShiftFactor(dm, 0,
ShiftType::TransformToCellCentered);
#if defined(WARPX_DIM_3D)
yshift_FFTfromCell = k_space.getSpectralShiftFactor(dm, 1,
#if AMREX_SPACEDIM > 1
shift1_FFTfromCell = k_space.getSpectralShiftFactor(dm, 1,
ShiftType::TransformFromCellCentered);
yshift_FFTtoCell = k_space.getSpectralShiftFactor(dm, 1,
shift1_FFTtoCell = k_space.getSpectralShiftFactor(dm, 1,
ShiftType::TransformToCellCentered);
zshift_FFTfromCell = k_space.getSpectralShiftFactor(dm, 2,
#if AMREX_SPACEDIM > 2
shift2_FFTfromCell = k_space.getSpectralShiftFactor(dm, 2,
ShiftType::TransformFromCellCentered);
zshift_FFTtoCell = k_space.getSpectralShiftFactor(dm, 2,
ShiftType::TransformToCellCentered);
#else
zshift_FFTfromCell = k_space.getSpectralShiftFactor(dm, 1,
ShiftType::TransformFromCellCentered);
zshift_FFTtoCell = k_space.getSpectralShiftFactor(dm, 1,
shift2_FFTtoCell = k_space.getSpectralShiftFactor(dm, 2,
ShiftType::TransformToCellCentered);
#endif
#endif

// Allocate and initialize the FFT plans
Expand Down Expand Up @@ -221,16 +218,12 @@ SpectralFieldData::ForwardTransform (const int lev,
const bool do_costs = WarpXUtilLoadBalance::doCosts(cost, mf.boxArray(), mf.DistributionMap());

// Check field index type, in order to apply proper shift in spectral space
#if (AMREX_SPACEDIM >= 2)
const bool is_nodal_x = mf.is_nodal(0);
const bool is_nodal_0 = mf.is_nodal(0);
#if AMREX_SPACEDIM > 1
const bool is_nodal_1 = mf.is_nodal(1);
#if AMREX_SPACEDIM > 2
const bool is_nodal_2 = mf.is_nodal(2);
#endif
#if defined(WARPX_DIM_3D)
const bool is_nodal_y = mf.is_nodal(1);
const bool is_nodal_z = mf.is_nodal(2);
#elif defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ)
const bool is_nodal_z = mf.is_nodal(1);
#elif defined(WARPX_DIM_1D_Z)
const bool is_nodal_z = mf.is_nodal(0);
#endif

// Loop over boxes
Expand Down Expand Up @@ -275,30 +268,27 @@ SpectralFieldData::ForwardTransform (const int lev,
{
const Array4<Complex> fields_arr = SpectralFieldData::fields[mfi].array();
const Array4<const Complex> tmp_arr = tmpSpectralField[mfi].array();
#if (AMREX_SPACEDIM >= 2)
const Complex* xshift_arr = xshift_FFTfromCell[mfi].dataPtr();

const Complex* shift0_arr = shift0_FFTfromCell[mfi].dataPtr();
#if AMREX_SPACEDIM > 1
const Complex* shift1_arr = shift1_FFTfromCell[mfi].dataPtr();
#if AMREX_SPACEDIM > 2
const Complex* shift2_arr = shift2_FFTfromCell[mfi].dataPtr();
#endif
#if defined(WARPX_DIM_3D)
const Complex* yshift_arr = yshift_FFTfromCell[mfi].dataPtr();
#endif
const Complex* zshift_arr = zshift_FFTfromCell[mfi].dataPtr();
// Loop over indices within one box
const Box spectralspace_bx = tmpSpectralField[mfi].box();

ParallelFor( spectralspace_bx,
[=] AMREX_GPU_DEVICE(int i, int j, int k) noexcept {
Complex spectral_field_value = tmp_arr(i,j,k);
// Apply proper shift in each dimension
#if (AMREX_SPACEDIM >= 2)
if (!is_nodal_x) { spectral_field_value *= xshift_arr[i]; }
if (!is_nodal_0) { spectral_field_value *= shift0_arr[i]; }
#if AMREX_SPACEDIM > 1
if (!is_nodal_1) { spectral_field_value *= shift1_arr[j]; }
#if AMREX_SPACEDIM > 2
if (!is_nodal_2) { spectral_field_value *= shift2_arr[k]; }
#endif
#if defined(WARPX_DIM_3D)
if (!is_nodal_y) { spectral_field_value *= yshift_arr[j]; }
if (!is_nodal_z) { spectral_field_value *= zshift_arr[k]; }
#elif defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ)
if (!is_nodal_z) { spectral_field_value *= zshift_arr[j]; }
#elif defined(WARPX_DIM_1D_Z)
if (!is_nodal_z) { spectral_field_value *= zshift_arr[i]; }
#endif
// Copy field into the right index
fields_arr(i,j,k,field_index) = spectral_field_value;
Expand Down Expand Up @@ -328,32 +318,9 @@ SpectralFieldData::BackwardTransform (const int lev,
const bool do_costs = WarpXUtilLoadBalance::doCosts(cost, mf.boxArray(), mf.DistributionMap());

// Check field index type, in order to apply proper shift in spectral space
#if (AMREX_SPACEDIM >= 2)
const bool is_nodal_x = mf.is_nodal(0);
#endif
#if defined(WARPX_DIM_3D)
const bool is_nodal_y = mf.is_nodal(1);
const bool is_nodal_z = mf.is_nodal(2);
#elif defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ)
const bool is_nodal_z = mf.is_nodal(1);
#elif defined(WARPX_DIM_1D_Z)
const bool is_nodal_z = mf.is_nodal(0);
#endif

#if (AMREX_SPACEDIM >= 2)
const int si = (is_nodal_x) ? 1 : 0;
#endif
#if defined(WARPX_DIM_1D_Z)
const int si = (is_nodal_z) ? 1 : 0;
const int sj = 0;
const int sk = 0;
#elif defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ)
const int sj = (is_nodal_z) ? 1 : 0;
const int sk = 0;
#elif defined(WARPX_DIM_3D)
const int sj = (is_nodal_y) ? 1 : 0;
const int sk = (is_nodal_z) ? 1 : 0;
#endif
const bool is_nodal_0 = mf.is_nodal(0);
const bool is_nodal_1 = (AMREX_SPACEDIM > 1 ? mf.is_nodal(1) : 0);
const bool is_nodal_2 = (AMREX_SPACEDIM > 2 ? mf.is_nodal(2) : 0);

// Numbers of guard cells
const amrex::IntVect& mf_ng = mf.nGrowVect();
Expand All @@ -375,30 +342,26 @@ SpectralFieldData::BackwardTransform (const int lev,
{
const Array4<const Complex> field_arr = SpectralFieldData::fields[mfi].array();
const Array4<Complex> tmp_arr = tmpSpectralField[mfi].array();
#if (AMREX_SPACEDIM >= 2)
const Complex* xshift_arr = xshift_FFTtoCell[mfi].dataPtr();
const Complex* shift0_arr = shift0_FFTtoCell[mfi].dataPtr();
#if AMREX_SPACEDIM > 1
const Complex* shift1_arr = shift1_FFTtoCell[mfi].dataPtr();
#if AMREX_SPACEDIM > 2
const Complex* shift2_arr = shift2_FFTtoCell[mfi].dataPtr();
#endif
#if defined(WARPX_DIM_3D)
const Complex* yshift_arr = yshift_FFTtoCell[mfi].dataPtr();
#endif
const Complex* zshift_arr = zshift_FFTtoCell[mfi].dataPtr();
// Loop over indices within one box
const Box spectralspace_bx = tmpSpectralField[mfi].box();

ParallelFor( spectralspace_bx,
[=] AMREX_GPU_DEVICE(int i, int j, int k) noexcept {
Complex spectral_field_value = field_arr(i,j,k,field_index);
// Apply proper shift in each dimension
#if (AMREX_SPACEDIM >= 2)
if (!is_nodal_x) { spectral_field_value *= xshift_arr[i]; }
if (!is_nodal_0) { spectral_field_value *= shift0_arr[i]; }
#if AMREX_SPACEDIM > 1
if (!is_nodal_1) { spectral_field_value *= shift1_arr[j]; }
#if AMREX_SPACEDIM > 2
if (!is_nodal_2) { spectral_field_value *= shift2_arr[k]; }
#endif
#if defined(WARPX_DIM_3D)
if (!is_nodal_y) { spectral_field_value *= yshift_arr[j]; }
if (!is_nodal_z) { spectral_field_value *= zshift_arr[k]; }
#elif defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ)
if (!is_nodal_z) { spectral_field_value *= zshift_arr[j]; }
#elif defined(WARPX_DIM_1D_Z)
if (!is_nodal_z) { spectral_field_value *= zshift_arr[i]; }
#endif
// Copy field into temporary array
tmp_arr(i,j,k) = spectral_field_value;
Expand All @@ -419,28 +382,18 @@ SpectralFieldData::BackwardTransform (const int lev,

// Total number of cells, including ghost cells (nj represents ny in 3D and nz in 2D)
const int ni = mf_box.length(0);
#if defined(WARPX_DIM_1D_Z)
constexpr int nj = 1;
constexpr int nk = 1;
#elif defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ)
const int nj = mf_box.length(1);
constexpr int nk = 1;
#elif defined(WARPX_DIM_3D)
const int nj = mf_box.length(1);
const int nk = mf_box.length(2);
#endif
const int nj = (AMREX_SPACEDIM > 1 ? mf_box.length(1) : 1);
const int nk = (AMREX_SPACEDIM > 2 ? mf_box.length(2) : 1);

const int si = (is_nodal_0) ? 1 : 0;
const int sj = (is_nodal_1) ? 1 : 0;
const int sk = (is_nodal_2) ? 1 : 0;

// Lower bound of the box (lo_j represents lo_y in 3D and lo_z in 2D)
const int lo_i = amrex::lbound(mf_box).x;
#if defined(WARPX_DIM_1D_Z)
constexpr int lo_j = 0;
constexpr int lo_k = 0;
#elif defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ)
const int lo_j = amrex::lbound(mf_box).y;
constexpr int lo_k = 0;
#elif defined(WARPX_DIM_3D)
const int lo_j = amrex::lbound(mf_box).y;
const int lo_k = amrex::lbound(mf_box).z;
#endif
const int lo_j = (AMREX_SPACEDIM > 1 ? amrex::lbound(mf_box).y : 0);
const int lo_k = (AMREX_SPACEDIM > 2 ? amrex::lbound(mf_box).z : 0);

// If necessary, do not fill the guard cells
// (shrink box by passing negative number of cells)
if (!m_periodic_single_box)
Expand Down

0 comments on commit 63eef4f

Please sign in to comment.