From 63eef4f728535f0515c397b63cab9aedbb0135c2 Mon Sep 17 00:00:00 2001 From: David Grote Date: Tue, 10 Sep 2024 10:55:07 -0700 Subject: [PATCH] Clean up in SpectralFieldData for multi-dimensions --- .../SpectralSolver/SpectralFieldData.H | 9 +- .../SpectralSolver/SpectralFieldData.cpp | 143 ++++++------------ 2 files changed, 52 insertions(+), 100 deletions(-) diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldData.H b/Source/FieldSolver/SpectralSolver/SpectralFieldData.H index 8ced43ced94..d6c4916bdac 100644 --- a/Source/FieldSolver/SpectralSolver/SpectralFieldData.H +++ b/Source/FieldSolver/SpectralSolver/SpectralFieldData.H @@ -176,11 +176,10 @@ class SpectralFieldData ablastr::math::anyfft::FFTplans forward_plan, backward_plan; // Correcting "shift" factors when performing FFT from/to // a cell-centered grid in real space, instead of a nodal grid - SpectralShiftFactor xshift_FFTfromCell, xshift_FFTtoCell, - zshift_FFTfromCell, zshift_FFTtoCell; -#if defined(WARPX_DIM_3D) - SpectralShiftFactor yshift_FFTfromCell, yshift_FFTtoCell; -#endif + // (0,1,2) is the dimension number + SpectralShiftFactor shift0_FFTfromCell, shift0_FFTtoCell, + shift1_FFTfromCell, shift1_FFTtoCell, + shift2_FFTfromCell, shift2_FFTtoCell; bool m_periodic_single_box; }; diff --git a/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp b/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp index 20c97f4b5d4..8e7b9ed9ae4 100644 --- a/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp +++ b/Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp @@ -142,24 +142,21 @@ SpectralFieldData::SpectralFieldData( const int lev, // By default, we assume the FFT is done from/to a nodal grid in real space // If the FFT is performed from/to a cell-centered grid in real space, // a correcting "shift" factor must be applied in spectral space. - xshift_FFTfromCell = k_space.getSpectralShiftFactor(dm, 0, + shift0_FFTfromCell = k_space.getSpectralShiftFactor(dm, 0, ShiftType::TransformFromCellCentered); - xshift_FFTtoCell = k_space.getSpectralShiftFactor(dm, 0, + shift0_FFTtoCell = k_space.getSpectralShiftFactor(dm, 0, ShiftType::TransformToCellCentered); -#if defined(WARPX_DIM_3D) - yshift_FFTfromCell = k_space.getSpectralShiftFactor(dm, 1, +#if AMREX_SPACEDIM > 1 + shift1_FFTfromCell = k_space.getSpectralShiftFactor(dm, 1, ShiftType::TransformFromCellCentered); - yshift_FFTtoCell = k_space.getSpectralShiftFactor(dm, 1, + shift1_FFTtoCell = k_space.getSpectralShiftFactor(dm, 1, ShiftType::TransformToCellCentered); - zshift_FFTfromCell = k_space.getSpectralShiftFactor(dm, 2, +#if AMREX_SPACEDIM > 2 + shift2_FFTfromCell = k_space.getSpectralShiftFactor(dm, 2, ShiftType::TransformFromCellCentered); - zshift_FFTtoCell = k_space.getSpectralShiftFactor(dm, 2, - ShiftType::TransformToCellCentered); -#else - zshift_FFTfromCell = k_space.getSpectralShiftFactor(dm, 1, - ShiftType::TransformFromCellCentered); - zshift_FFTtoCell = k_space.getSpectralShiftFactor(dm, 1, + shift2_FFTtoCell = k_space.getSpectralShiftFactor(dm, 2, ShiftType::TransformToCellCentered); +#endif #endif // Allocate and initialize the FFT plans @@ -221,16 +218,12 @@ SpectralFieldData::ForwardTransform (const int lev, const bool do_costs = WarpXUtilLoadBalance::doCosts(cost, mf.boxArray(), mf.DistributionMap()); // Check field index type, in order to apply proper shift in spectral space -#if (AMREX_SPACEDIM >= 2) - const bool is_nodal_x = mf.is_nodal(0); + const bool is_nodal_0 = mf.is_nodal(0); +#if AMREX_SPACEDIM > 1 + const bool is_nodal_1 = mf.is_nodal(1); +#if AMREX_SPACEDIM > 2 + const bool is_nodal_2 = mf.is_nodal(2); #endif -#if defined(WARPX_DIM_3D) - const bool is_nodal_y = mf.is_nodal(1); - const bool is_nodal_z = mf.is_nodal(2); -#elif defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ) - const bool is_nodal_z = mf.is_nodal(1); -#elif defined(WARPX_DIM_1D_Z) - const bool is_nodal_z = mf.is_nodal(0); #endif // Loop over boxes @@ -275,13 +268,14 @@ SpectralFieldData::ForwardTransform (const int lev, { const Array4 fields_arr = SpectralFieldData::fields[mfi].array(); const Array4 tmp_arr = tmpSpectralField[mfi].array(); -#if (AMREX_SPACEDIM >= 2) - const Complex* xshift_arr = xshift_FFTfromCell[mfi].dataPtr(); + + const Complex* shift0_arr = shift0_FFTfromCell[mfi].dataPtr(); +#if AMREX_SPACEDIM > 1 + const Complex* shift1_arr = shift1_FFTfromCell[mfi].dataPtr(); +#if AMREX_SPACEDIM > 2 + const Complex* shift2_arr = shift2_FFTfromCell[mfi].dataPtr(); #endif -#if defined(WARPX_DIM_3D) - const Complex* yshift_arr = yshift_FFTfromCell[mfi].dataPtr(); #endif - const Complex* zshift_arr = zshift_FFTfromCell[mfi].dataPtr(); // Loop over indices within one box const Box spectralspace_bx = tmpSpectralField[mfi].box(); @@ -289,16 +283,12 @@ SpectralFieldData::ForwardTransform (const int lev, [=] AMREX_GPU_DEVICE(int i, int j, int k) noexcept { Complex spectral_field_value = tmp_arr(i,j,k); // Apply proper shift in each dimension -#if (AMREX_SPACEDIM >= 2) - if (!is_nodal_x) { spectral_field_value *= xshift_arr[i]; } + if (!is_nodal_0) { spectral_field_value *= shift0_arr[i]; } +#if AMREX_SPACEDIM > 1 + if (!is_nodal_1) { spectral_field_value *= shift1_arr[j]; } +#if AMREX_SPACEDIM > 2 + if (!is_nodal_2) { spectral_field_value *= shift2_arr[k]; } #endif -#if defined(WARPX_DIM_3D) - if (!is_nodal_y) { spectral_field_value *= yshift_arr[j]; } - if (!is_nodal_z) { spectral_field_value *= zshift_arr[k]; } -#elif defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ) - if (!is_nodal_z) { spectral_field_value *= zshift_arr[j]; } -#elif defined(WARPX_DIM_1D_Z) - if (!is_nodal_z) { spectral_field_value *= zshift_arr[i]; } #endif // Copy field into the right index fields_arr(i,j,k,field_index) = spectral_field_value; @@ -328,32 +318,9 @@ SpectralFieldData::BackwardTransform (const int lev, const bool do_costs = WarpXUtilLoadBalance::doCosts(cost, mf.boxArray(), mf.DistributionMap()); // Check field index type, in order to apply proper shift in spectral space -#if (AMREX_SPACEDIM >= 2) - const bool is_nodal_x = mf.is_nodal(0); -#endif -#if defined(WARPX_DIM_3D) - const bool is_nodal_y = mf.is_nodal(1); - const bool is_nodal_z = mf.is_nodal(2); -#elif defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ) - const bool is_nodal_z = mf.is_nodal(1); -#elif defined(WARPX_DIM_1D_Z) - const bool is_nodal_z = mf.is_nodal(0); -#endif - -#if (AMREX_SPACEDIM >= 2) - const int si = (is_nodal_x) ? 1 : 0; -#endif -#if defined(WARPX_DIM_1D_Z) - const int si = (is_nodal_z) ? 1 : 0; - const int sj = 0; - const int sk = 0; -#elif defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ) - const int sj = (is_nodal_z) ? 1 : 0; - const int sk = 0; -#elif defined(WARPX_DIM_3D) - const int sj = (is_nodal_y) ? 1 : 0; - const int sk = (is_nodal_z) ? 1 : 0; -#endif + const bool is_nodal_0 = mf.is_nodal(0); + const bool is_nodal_1 = (AMREX_SPACEDIM > 1 ? mf.is_nodal(1) : 0); + const bool is_nodal_2 = (AMREX_SPACEDIM > 2 ? mf.is_nodal(2) : 0); // Numbers of guard cells const amrex::IntVect& mf_ng = mf.nGrowVect(); @@ -375,13 +342,13 @@ SpectralFieldData::BackwardTransform (const int lev, { const Array4 field_arr = SpectralFieldData::fields[mfi].array(); const Array4 tmp_arr = tmpSpectralField[mfi].array(); -#if (AMREX_SPACEDIM >= 2) - const Complex* xshift_arr = xshift_FFTtoCell[mfi].dataPtr(); + const Complex* shift0_arr = shift0_FFTtoCell[mfi].dataPtr(); +#if AMREX_SPACEDIM > 1 + const Complex* shift1_arr = shift1_FFTtoCell[mfi].dataPtr(); +#if AMREX_SPACEDIM > 2 + const Complex* shift2_arr = shift2_FFTtoCell[mfi].dataPtr(); #endif -#if defined(WARPX_DIM_3D) - const Complex* yshift_arr = yshift_FFTtoCell[mfi].dataPtr(); #endif - const Complex* zshift_arr = zshift_FFTtoCell[mfi].dataPtr(); // Loop over indices within one box const Box spectralspace_bx = tmpSpectralField[mfi].box(); @@ -389,16 +356,12 @@ SpectralFieldData::BackwardTransform (const int lev, [=] AMREX_GPU_DEVICE(int i, int j, int k) noexcept { Complex spectral_field_value = field_arr(i,j,k,field_index); // Apply proper shift in each dimension -#if (AMREX_SPACEDIM >= 2) - if (!is_nodal_x) { spectral_field_value *= xshift_arr[i]; } + if (!is_nodal_0) { spectral_field_value *= shift0_arr[i]; } +#if AMREX_SPACEDIM > 1 + if (!is_nodal_1) { spectral_field_value *= shift1_arr[j]; } +#if AMREX_SPACEDIM > 2 + if (!is_nodal_2) { spectral_field_value *= shift2_arr[k]; } #endif -#if defined(WARPX_DIM_3D) - if (!is_nodal_y) { spectral_field_value *= yshift_arr[j]; } - if (!is_nodal_z) { spectral_field_value *= zshift_arr[k]; } -#elif defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ) - if (!is_nodal_z) { spectral_field_value *= zshift_arr[j]; } -#elif defined(WARPX_DIM_1D_Z) - if (!is_nodal_z) { spectral_field_value *= zshift_arr[i]; } #endif // Copy field into temporary array tmp_arr(i,j,k) = spectral_field_value; @@ -419,28 +382,18 @@ SpectralFieldData::BackwardTransform (const int lev, // Total number of cells, including ghost cells (nj represents ny in 3D and nz in 2D) const int ni = mf_box.length(0); -#if defined(WARPX_DIM_1D_Z) - constexpr int nj = 1; - constexpr int nk = 1; -#elif defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ) - const int nj = mf_box.length(1); - constexpr int nk = 1; -#elif defined(WARPX_DIM_3D) - const int nj = mf_box.length(1); - const int nk = mf_box.length(2); -#endif + const int nj = (AMREX_SPACEDIM > 1 ? mf_box.length(1) : 1); + const int nk = (AMREX_SPACEDIM > 2 ? mf_box.length(2) : 1); + + const int si = (is_nodal_0) ? 1 : 0; + const int sj = (is_nodal_1) ? 1 : 0; + const int sk = (is_nodal_2) ? 1 : 0; + // Lower bound of the box (lo_j represents lo_y in 3D and lo_z in 2D) const int lo_i = amrex::lbound(mf_box).x; -#if defined(WARPX_DIM_1D_Z) - constexpr int lo_j = 0; - constexpr int lo_k = 0; -#elif defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ) - const int lo_j = amrex::lbound(mf_box).y; - constexpr int lo_k = 0; -#elif defined(WARPX_DIM_3D) - const int lo_j = amrex::lbound(mf_box).y; - const int lo_k = amrex::lbound(mf_box).z; -#endif + const int lo_j = (AMREX_SPACEDIM > 1 ? amrex::lbound(mf_box).y : 0); + const int lo_k = (AMREX_SPACEDIM > 2 ? amrex::lbound(mf_box).z : 0); + // If necessary, do not fill the guard cells // (shrink box by passing negative number of cells) if (!m_periodic_single_box)