diff --git a/pyphare/pyphare/pharesee/hierarchy/hierarchy_utils.py b/pyphare/pyphare/pharesee/hierarchy/hierarchy_utils.py index 56fbf94d3..9788f3fb4 100644 --- a/pyphare/pyphare/pharesee/hierarchy/hierarchy_utils.py +++ b/pyphare/pyphare/pharesee/hierarchy/hierarchy_utils.py @@ -1,8 +1,8 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from copy import deepcopy import numpy as np -from typing import Any +from typing import Any, List, Tuple from .hierarchy import PatchHierarchy, format_timestamp from .patchdata import FieldData, ParticleData @@ -12,6 +12,7 @@ from ...core.gridlayout import GridLayout from ...core.phare_utilities import listify from ...core.phare_utilities import refinement_ratio +from pyphare.core import phare_utilities as phut field_qties = { @@ -561,41 +562,53 @@ def _compute_scalardiv(patch_datas, **kwargs): @dataclass class EqualityReport: - ok: bool - reason: str - ref: Any = None - cmp: Any = None + failed: List[Tuple[str, Any, Any]] = field(default_factory=lambda: []) def __bool__(self): - return self.ok + return not self.failed def __repr__(self): - return self.reason + for msg, ref, cmp in self: + print(msg) + try: + if type(ref) is FieldData: + phut.assert_fp_any_all_close(ref[:], cmp[:], atol=1e-16) + except AssertionError as e: + print(e) + return self.failed[0][0] - def __post_init__(self): - not_nones = [a is not None for a in [self.ref, self.cmp]] - if all(not_nones): - assert id(self.ref) != id(self.cmp) - else: - assert not any(not_nones) + def __call__(self, reason, ref=None, cmp=None): + self.failed.append((reason, ref, cmp)) + return self + + def __getitem__(self, idx): + return (self.failed[idx][1], self.failed[idx][2]) + + def __iter__(self): + return self.failed.__iter__() + + def __reversed__(self): + return reversed(self.failed) def hierarchy_compare(this, that, atol=1e-16): + eqr = EqualityReport() + if not isinstance(this, PatchHierarchy) or not isinstance(that, PatchHierarchy): - return EqualityReport(False, "class type mismatch") + return eqr("class type mismatch") if this.ndim != that.ndim or this.domain_box != that.domain_box: - return EqualityReport(False, "dimensional mismatch") + return eqr("dimensional mismatch") if this.time_hier.keys() != that.time_hier.keys(): - return EqualityReport(False, "timesteps mismatch") + return eqr("timesteps mismatch") for tidx in this.times(): patch_levels_ref = this.time_hier[tidx] patch_levels_cmp = that.time_hier[tidx] if patch_levels_ref.keys() != patch_levels_cmp.keys(): - return EqualityReport(False, "levels mismatch") + return eqr("levels mismatch") for level_idx in patch_levels_cmp.keys(): patch_level_ref = patch_levels_ref[level_idx] @@ -606,7 +619,7 @@ def hierarchy_compare(this, that, atol=1e-16): patch_cmp = patch_level_cmp.patches[patch_idx] if patch_ref.patch_datas.keys() != patch_cmp.patch_datas.keys(): - return EqualityReport(False, "data keys mismatch") + return eqr("data keys mismatch") for patch_data_key in patch_ref.patch_datas.keys(): patch_data_ref = patch_ref.patch_datas[patch_data_key] @@ -614,11 +627,12 @@ def hierarchy_compare(this, that, atol=1e-16): if not patch_data_cmp.compare(patch_data_ref, atol=atol): msg = f"data mismatch: {type(patch_data_ref).__name__} {patch_data_key}" - return EqualityReport( - False, msg, patch_data_cmp, patch_data_ref - ) + eqr(msg, patch_data_cmp, patch_data_ref) + + if not eqr: + return eqr - return EqualityReport(True, "OK") + return eqr def single_patch_for_LO(hier, qties=None, skip=None): diff --git a/src/amr/data/field/initializers/samrai_hdf5_field_initializer.hpp b/src/amr/data/field/initializers/samrai_hdf5_field_initializer.hpp index bb0184a74..ffce2ab5c 100644 --- a/src/amr/data/field/initializers/samrai_hdf5_field_initializer.hpp +++ b/src/amr/data/field/initializers/samrai_hdf5_field_initializer.hpp @@ -31,9 +31,6 @@ template void SamraiHDF5FieldInitializer::load(Field_t& field, GridLayout const& layout) const { - auto const local_cell - = [&](auto const& box, auto const& point) { return layout.AMRToLocal(point, box); }; - auto const& dest_box = layout.AMRBox(); auto const& centering = layout.centering(field.physicalQuantity()); auto const& overlaps = SamraiH5Interface::INSTANCE().box_intersections(dest_box); @@ -44,7 +41,7 @@ void SamraiHDF5FieldInitializer::load(Field_t& field, auto const src_box = pdata.box; auto const data = h5File.template read_data_set_flat( pdata.base_path + "/" + field.name() + "##default/field_" + field.name()); - core::Box const lcl_src_box{ + core::Box const lcl_src_gbox{ core::Point{core::ConstArray()}, core::Point{ core::for_N([&](auto i) { @@ -52,9 +49,14 @@ void SamraiHDF5FieldInitializer::load(Field_t& field, src_box.upper[i] - src_box.lower[i] + (GridLayout::nbrGhosts() * 2) + (centering[i] == core::QtyCentering::primal ? 1 : 0)); })}}; - auto data_view = core::make_array_view(data.data(), *lcl_src_box.shape()); - for (auto const& point : overlap_box) - field(local_cell(dest_box, point)) = data_view(local_cell(src_box, point)); + auto const data_view = core::make_array_view(data.data(), *lcl_src_gbox.shape()); + auto const overlap_gb = grow(overlap_box, GridLayout::nbrGhosts()); + auto const lcl_src_box = layout.AMRToLocal(overlap_gb, src_box); + auto const lcl_dst_box = layout.AMRToLocal(overlap_gb, dest_box); + auto src_it = lcl_src_box.begin(); + auto dst_it = lcl_dst_box.begin(); + for (; src_it != lcl_src_box.end(); ++src_it, ++dst_it) + field(*dst_it) = data_view(*src_it); } } diff --git a/src/amr/level_initializer/hybrid_level_initializer.hpp b/src/amr/level_initializer/hybrid_level_initializer.hpp index e0ef8386a..3acdf4fcd 100644 --- a/src/amr/level_initializer/hybrid_level_initializer.hpp +++ b/src/amr/level_initializer/hybrid_level_initializer.hpp @@ -4,11 +4,9 @@ #include "amr/level_initializer/level_initializer.hpp" #include "amr/messengers/hybrid_messenger.hpp" #include "amr/messengers/messenger.hpp" -#include "amr/physical_models/hybrid_model.hpp" #include "amr/physical_models/physical_model.hpp" #include "amr/resources_manager/amr_utils.hpp" #include "core/data/grid/gridlayout_utils.hpp" -#include "core/data/ions/ions.hpp" #include "core/numerics/ampere/ampere.hpp" #include "core/numerics/interpolator/interpolator.hpp" #include "core/numerics/moments/moments.hpp" @@ -43,10 +41,12 @@ namespace solver : ohm_{dict["algo"]["ohm"]} { } - virtual void initialize(std::shared_ptr const& hierarchy, int levelNumber, - std::shared_ptr const& oldLevel, IPhysicalModelT& model, - amr::IMessenger& messenger, double initDataTime, - bool isRegridding) override + + + void initialize(std::shared_ptr const& hierarchy, int levelNumber, + std::shared_ptr const& oldLevel, IPhysicalModelT& model, + amr::IMessenger& messenger, double initDataTime, + bool isRegridding) override { core::Interpolator interpolate_; auto& hybridModel = static_cast(model); @@ -163,6 +163,8 @@ namespace solver hybMessenger.prepareStep(hybridModel, level, initDataTime); } }; + + } // namespace solver } // namespace PHARE diff --git a/src/core/data/grid/gridlayout.hpp b/src/core/data/grid/gridlayout.hpp index 564bf909c..48c06aa0f 100644 --- a/src/core/data/grid/gridlayout.hpp +++ b/src/core/data/grid/gridlayout.hpp @@ -832,17 +832,19 @@ namespace core * This method only deals with **cell** indexes. */ template - NO_DISCARD auto AMRToLocal(Box const& AMRBox) const + NO_DISCARD auto AMRToLocal(Box const& AMRBox, + Box const& localbox) const { static_assert(std::is_integral_v, "Error, must be MeshIndex (integral Point)"); - auto localBox = Box{}; - - localBox.lower = AMRToLocal(AMRBox.lower); - localBox.upper = AMRToLocal(AMRBox.upper); - - return localBox; + return Box{AMRToLocal(AMRBox.lower, localbox), + AMRToLocal(AMRBox.upper, localbox)}; } + template + NO_DISCARD auto AMRToLocal(Box const& AMRBox) const + { + return AMRToLocal(AMRBox, AMRBox_); + } template @@ -1171,6 +1173,22 @@ namespace core evalOnBox_(field, fn, indices); } + template + auto domainBoxFor(Field const& field) const + { + return _BoxFor(field, [&](auto const& centering, auto const direction) { + return this->physicalStartToEnd(centering, direction); + }); + } + + template + auto ghostBoxFor(Field const& field) const + { + return _BoxFor(field, [&](auto const& centering, auto const direction) { + return this->ghostStartToEnd(centering, direction); + }); + } + private: template @@ -1206,6 +1224,20 @@ namespace core } + template + auto _BoxFor(Field const& field, Fn startToEnd) const + { + constexpr auto directions = std::array{Direction::X, Direction::Y, Direction::Z}; + std::array lower, upper; + core::for_N([&](auto i) { + auto const [i0, i1] = startToEnd(field, directions[i]); + lower[i] = i0; + upper[i] = i1; + }); + return Box{lower, upper}; + } + + template auto StartToEndIndices_(Centering const& centering, StartToEnd const&& startToEnd, bool const includeEnd = false) const diff --git a/tests/simulator/test_init_from_restart.py b/tests/simulator/test_init_from_restart.py index ec7d232ba..8287ab089 100644 --- a/tests/simulator/test_init_from_restart.py +++ b/tests/simulator/test_init_from_restart.py @@ -5,7 +5,7 @@ import numpy as np import pyphare.pharein as ph -from pyphare.core import phare_utilities as phut + from pyphare.simulator.simulator import Simulator from pyphare.pharesee.hierarchy.patchdata import FieldData, ParticleData from pyphare.pharesee.hierarchy.fromh5 import get_all_available_quantities_from_h5 @@ -24,11 +24,9 @@ cells = 200 first_out = "phare_outputs/reinit/first" secnd_out = "phare_outputs/reinit/secnd" -# timestamps = [0,time_step] timestamps = np.arange(0, final_time + time_step, time_step) restart_idx = Z = 2 simInitArgs = dict( - largest_patch_size=100, time_step_nbr=time_step_nbr, time_step=time_step, cells=cells, @@ -41,7 +39,7 @@ def setup_model(sim): model = ph.MaxwellianFluidModel( protons={"mass": 1, "charge": 1, "nbr_part_per_cell": ppc}, - alpha={"mass": 4.0, "charge": 1, "nbr_part_per_cell": ppc}, + alpha={"mass": 4, "charge": 1, "nbr_part_per_cell": ppc}, ) ph.ElectronModel(closure="isothermal", Te=0.12) dump_all_diags(model.populations, timestamps=timestamps) @@ -65,7 +63,7 @@ def test_reinit(self): sim = ph.Simulation(**copy.deepcopy(simInitArgs)) setup_model(sim) Simulator(sim).run().reset() - fidx, sidx = 2, 0 + fidx, sidx = 4, 2 datahier0 = get_all_available_quantities_from_h5(first_out, timestamps[fidx]) datahier0.time_hier = { # swap times format_timestamp(timestamps[sidx]): datahier0.time_hier[ @@ -73,15 +71,10 @@ def test_reinit(self): ] } datahier1 = get_all_available_quantities_from_h5(secnd_out, timestamps[sidx]) - qties = ["protons_domain", "alpha_domain", "Bx", "By", "Bz"] - skip = None # ["protons_patchGhost", "alpha_patchGhost"] + qties = None + skip = ["protons_patchGhost", "alpha_patchGhost"] ds = [single_patch_for_LO(d, qties, skip) for d in [datahier0, datahier1]] - eq = hierarchy_compare(*ds, atol=1e-14) - if not eq: - print(eq) - if type(eq.ref) == FieldData: - phut.assert_fp_any_all_close(eq.ref[:], eq.cmp[:], atol=1e-16) - self.assertTrue(eq) + self.assertTrue(hierarchy_compare(*ds, atol=1e-12)) def run_first_sim():