Skip to content

Commit

Permalink
~~
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipDeegan committed Sep 17, 2024
1 parent 9049ef5 commit ca1f3af
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 56 deletions.
60 changes: 37 additions & 23 deletions pyphare/pyphare/pharesee/hierarchy/hierarchy_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from copy import deepcopy
import numpy as np

from typing import Any
from typing import Any, List, Tuple

from .hierarchy import PatchHierarchy, format_timestamp

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
pyphare.pharesee.hierarchy.hierarchy
begins an import cycle.
from .patchdata import FieldData, ParticleData
Expand All @@ -12,6 +12,7 @@
from ...core.gridlayout import GridLayout
from ...core.phare_utilities import listify
from ...core.phare_utilities import refinement_ratio
from pyphare.core import phare_utilities as phut


field_qties = {
Expand Down Expand Up @@ -561,41 +562,53 @@ def _compute_scalardiv(patch_datas, **kwargs):

@dataclass
class EqualityReport:
ok: bool
reason: str
ref: Any = None
cmp: Any = None
failed: List[Tuple[str, Any, Any]] = field(default_factory=lambda: [])

def __bool__(self):
return self.ok
return not self.failed

def __repr__(self):
return self.reason
for msg, ref, cmp in self:
print(msg)
try:
if type(ref) is FieldData:
phut.assert_fp_any_all_close(ref[:], cmp[:], atol=1e-16)
except AssertionError as e:
print(e)
return self.failed[0][0]

def __post_init__(self):
not_nones = [a is not None for a in [self.ref, self.cmp]]
if all(not_nones):
assert id(self.ref) != id(self.cmp)
else:
assert not any(not_nones)
def __call__(self, reason, ref=None, cmp=None):
self.failed.append((reason, ref, cmp))
return self

def __getitem__(self, idx):
return (self.failed[idx][1], self.failed[idx][2])

def __iter__(self):
return self.failed.__iter__()

def __reversed__(self):
return reversed(self.failed)


def hierarchy_compare(this, that, atol=1e-16):
eqr = EqualityReport()

if not isinstance(this, PatchHierarchy) or not isinstance(that, PatchHierarchy):
return EqualityReport(False, "class type mismatch")
return eqr("class type mismatch")

if this.ndim != that.ndim or this.domain_box != that.domain_box:
return EqualityReport(False, "dimensional mismatch")
return eqr("dimensional mismatch")

if this.time_hier.keys() != that.time_hier.keys():
return EqualityReport(False, "timesteps mismatch")
return eqr("timesteps mismatch")

for tidx in this.times():
patch_levels_ref = this.time_hier[tidx]
patch_levels_cmp = that.time_hier[tidx]

if patch_levels_ref.keys() != patch_levels_cmp.keys():
return EqualityReport(False, "levels mismatch")
return eqr("levels mismatch")

for level_idx in patch_levels_cmp.keys():
patch_level_ref = patch_levels_ref[level_idx]
Expand All @@ -606,19 +619,20 @@ def hierarchy_compare(this, that, atol=1e-16):
patch_cmp = patch_level_cmp.patches[patch_idx]

if patch_ref.patch_datas.keys() != patch_cmp.patch_datas.keys():
return EqualityReport(False, "data keys mismatch")
return eqr("data keys mismatch")

for patch_data_key in patch_ref.patch_datas.keys():
patch_data_ref = patch_ref.patch_datas[patch_data_key]
patch_data_cmp = patch_cmp.patch_datas[patch_data_key]

if not patch_data_cmp.compare(patch_data_ref, atol=atol):
msg = f"data mismatch: {type(patch_data_ref).__name__} {patch_data_key}"
return EqualityReport(
False, msg, patch_data_cmp, patch_data_ref
)
eqr(msg, patch_data_cmp, patch_data_ref)

if not eqr:
return eqr

return EqualityReport(True, "OK")
return eqr


def single_patch_for_LO(hier, qties=None, skip=None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ template<typename Field_t, typename GridLayout>
void SamraiHDF5FieldInitializer<Field_t, GridLayout>::load(Field_t& field,
GridLayout const& layout) const
{
auto const local_cell
= [&](auto const& box, auto const& point) { return layout.AMRToLocal(point, box); };

auto const& dest_box = layout.AMRBox();
auto const& centering = layout.centering(field.physicalQuantity());
auto const& overlaps = SamraiH5Interface<GridLayout>::INSTANCE().box_intersections(dest_box);
Expand All @@ -44,17 +41,22 @@ void SamraiHDF5FieldInitializer<Field_t, GridLayout>::load(Field_t& field,
auto const src_box = pdata.box;
auto const data = h5File.template read_data_set_flat<double>(
pdata.base_path + "/" + field.name() + "##default/field_" + field.name());
core::Box<std::uint32_t, GridLayout::dimension> const lcl_src_box{
core::Box<std::uint32_t, GridLayout::dimension> const lcl_src_gbox{
core::Point{core::ConstArray<std::uint32_t, GridLayout::dimension>()},
core::Point{
core::for_N<GridLayout::dimension, core::for_N_R_mode::make_array>([&](auto i) {
return static_cast<std::uint32_t>(
src_box.upper[i] - src_box.lower[i] + (GridLayout::nbrGhosts() * 2)
+ (centering[i] == core::QtyCentering::primal ? 1 : 0));
})}};
auto data_view = core::make_array_view(data.data(), *lcl_src_box.shape());
for (auto const& point : overlap_box)
field(local_cell(dest_box, point)) = data_view(local_cell(src_box, point));
auto const data_view = core::make_array_view(data.data(), *lcl_src_gbox.shape());
auto const overlap_gb = grow(overlap_box, GridLayout::nbrGhosts());
auto const lcl_src_box = layout.AMRToLocal(overlap_gb, src_box);
auto const lcl_dst_box = layout.AMRToLocal(overlap_gb, dest_box);
auto src_it = lcl_src_box.begin();
auto dst_it = lcl_dst_box.begin();
for (; src_it != lcl_src_box.end(); ++src_it, ++dst_it)
field(*dst_it) = data_view(*src_it);
}
}

Expand Down
14 changes: 8 additions & 6 deletions src/amr/level_initializer/hybrid_level_initializer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
#include "amr/level_initializer/level_initializer.hpp"
#include "amr/messengers/hybrid_messenger.hpp"
#include "amr/messengers/messenger.hpp"
#include "amr/physical_models/hybrid_model.hpp"
#include "amr/physical_models/physical_model.hpp"
#include "amr/resources_manager/amr_utils.hpp"
#include "core/data/grid/gridlayout_utils.hpp"
#include "core/data/ions/ions.hpp"
#include "core/numerics/ampere/ampere.hpp"
#include "core/numerics/interpolator/interpolator.hpp"
#include "core/numerics/moments/moments.hpp"
Expand Down Expand Up @@ -43,10 +41,12 @@ namespace solver
: ohm_{dict["algo"]["ohm"]}
{
}
virtual void initialize(std::shared_ptr<hierarchy_t> const& hierarchy, int levelNumber,
std::shared_ptr<level_t> const& oldLevel, IPhysicalModelT& model,
amr::IMessenger<IPhysicalModelT>& messenger, double initDataTime,
bool isRegridding) override


void initialize(std::shared_ptr<hierarchy_t> const& hierarchy, int levelNumber,
std::shared_ptr<level_t> const& oldLevel, IPhysicalModelT& model,
amr::IMessenger<IPhysicalModelT>& messenger, double initDataTime,
bool isRegridding) override
{
core::Interpolator<dimension, interp_order> interpolate_;
auto& hybridModel = static_cast<HybridModel&>(model);
Expand Down Expand Up @@ -163,6 +163,8 @@ namespace solver
hybMessenger.prepareStep(hybridModel, level, initDataTime);
}
};


} // namespace solver
} // namespace PHARE

Expand Down
46 changes: 39 additions & 7 deletions src/core/data/grid/gridlayout.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -832,17 +832,19 @@ namespace core
* This method only deals with **cell** indexes.
*/
template<typename T>
NO_DISCARD auto AMRToLocal(Box<T, dimension> const& AMRBox) const
NO_DISCARD auto AMRToLocal(Box<T, dimension> const& AMRBox,
Box<int, dimension> const& localbox) const
{
static_assert(std::is_integral_v<T>, "Error, must be MeshIndex (integral Point)");
auto localBox = Box<std::uint32_t, dimension>{};

localBox.lower = AMRToLocal(AMRBox.lower);
localBox.upper = AMRToLocal(AMRBox.upper);

return localBox;
return Box<std::uint32_t, dimension>{AMRToLocal(AMRBox.lower, localbox),
AMRToLocal(AMRBox.upper, localbox)};
}

template<typename T>
NO_DISCARD auto AMRToLocal(Box<T, dimension> const& AMRBox) const
{
return AMRToLocal(AMRBox, AMRBox_);
}


template<typename Field, std::size_t nbr_points>
Expand Down Expand Up @@ -1171,6 +1173,22 @@ namespace core
evalOnBox_(field, fn, indices);
}

template<typename Field>
auto domainBoxFor(Field const& field) const
{
return _BoxFor(field, [&](auto const& centering, auto const direction) {
return this->physicalStartToEnd(centering, direction);
});
}

template<typename Field>
auto ghostBoxFor(Field const& field) const
{
return _BoxFor(field, [&](auto const& centering, auto const direction) {
return this->ghostStartToEnd(centering, direction);
});
}


private:
template<typename Field, typename IndicesFn, typename Fn>
Expand Down Expand Up @@ -1206,6 +1224,20 @@ namespace core
}


template<typename Field, typename Fn>
auto _BoxFor(Field const& field, Fn startToEnd) const
{
constexpr auto directions = std::array{Direction::X, Direction::Y, Direction::Z};
std::array<std::uint32_t, dimension> lower, upper;
core::for_N<dimension>([&](auto i) {
auto const [i0, i1] = startToEnd(field, directions[i]);
lower[i] = i0;
upper[i] = i1;
});
return Box<std::uint32_t, dimension>{lower, upper};
}


template<typename Centering, typename StartToEnd>
auto StartToEndIndices_(Centering const& centering, StartToEnd const&& startToEnd,
bool const includeEnd = false) const
Expand Down
19 changes: 6 additions & 13 deletions tests/simulator/test_init_from_restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pyphare.pharein as ph

from pyphare.core import phare_utilities as phut

from pyphare.simulator.simulator import Simulator
from pyphare.pharesee.hierarchy.patchdata import FieldData, ParticleData
from pyphare.pharesee.hierarchy.fromh5 import get_all_available_quantities_from_h5
Expand All @@ -24,11 +24,9 @@
cells = 200
first_out = "phare_outputs/reinit/first"
secnd_out = "phare_outputs/reinit/secnd"
# timestamps = [0,time_step]
timestamps = np.arange(0, final_time + time_step, time_step)
restart_idx = Z = 2
simInitArgs = dict(
largest_patch_size=100,
time_step_nbr=time_step_nbr,
time_step=time_step,
cells=cells,
Expand All @@ -41,7 +39,7 @@
def setup_model(sim):
model = ph.MaxwellianFluidModel(
protons={"mass": 1, "charge": 1, "nbr_part_per_cell": ppc},
alpha={"mass": 4.0, "charge": 1, "nbr_part_per_cell": ppc},
alpha={"mass": 4, "charge": 1, "nbr_part_per_cell": ppc},
)
ph.ElectronModel(closure="isothermal", Te=0.12)
dump_all_diags(model.populations, timestamps=timestamps)
Expand All @@ -65,23 +63,18 @@ def test_reinit(self):
sim = ph.Simulation(**copy.deepcopy(simInitArgs))
setup_model(sim)
Simulator(sim).run().reset()
fidx, sidx = 2, 0
fidx, sidx = 4, 2
datahier0 = get_all_available_quantities_from_h5(first_out, timestamps[fidx])
datahier0.time_hier = { # swap times
format_timestamp(timestamps[sidx]): datahier0.time_hier[
format_timestamp(timestamps[fidx])
]
}
datahier1 = get_all_available_quantities_from_h5(secnd_out, timestamps[sidx])
qties = ["protons_domain", "alpha_domain", "Bx", "By", "Bz"]
skip = None # ["protons_patchGhost", "alpha_patchGhost"]
qties = None
skip = ["protons_patchGhost", "alpha_patchGhost"]
ds = [single_patch_for_LO(d, qties, skip) for d in [datahier0, datahier1]]
eq = hierarchy_compare(*ds, atol=1e-14)
if not eq:
print(eq)
if type(eq.ref) == FieldData:
phut.assert_fp_any_all_close(eq.ref[:], eq.cmp[:], atol=1e-16)
self.assertTrue(eq)
self.assertTrue(hierarchy_compare(*ds, atol=1e-12))


def run_first_sim():
Expand Down

0 comments on commit ca1f3af

Please sign in to comment.