Skip to content

Commit

Permalink
NdArrayView
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipDeegan committed Jul 17, 2020
1 parent 68d926c commit 00da87f
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 48 deletions.
138 changes: 100 additions & 38 deletions src/core/data/ndarray/ndarray_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,103 @@

namespace PHARE::core
{
template<std::size_t dim, typename DataType = double>
struct NdArrayViewer
{
template<typename NCells, typename... Indexes>
static DataType const& at(DataType const* data, NCells const& nCells, Indexes const&... indexes)
{
auto params = std::forward_as_tuple(indexes...);
static_assert(sizeof...(Indexes) == dim);
// static_assert((... && std::is_unsigned_v<decltype(indexes)>)); TODO : manage later if
// this test should be included

if constexpr (dim == 1)
{
auto i = std::get<0>(params);

return data[i];
}

if constexpr (dim == 2)
{
auto i = std::get<0>(params);
auto j = std::get<1>(params);

return data[j + i * nCells[1]];
}

if constexpr (dim == 3)
{
auto i = std::get<0>(params);
auto j = std::get<1>(params);
auto k = std::get<2>(params);

return data[k + j * nCells[2] + i * nCells[1] * nCells[2]];
}
}

template<typename NCells, typename Index>
static DataType const& at(DataType const* data, NCells const& nCells,
std::array<Index, dim> const& indexes)

{
if constexpr (dim == 1)
return data[indexes[0]];

else if constexpr (dim == 2)
return data[indexes[1] + indexes[0] * nCells[1]];

else if constexpr (dim == 3)
return data[indexes[2] + indexes[1] * nCells[2] + indexes[0] * nCells[1] * nCells[2]];
}
};


template<std::size_t dim, typename DataType = double, typename Pointer = DataType const*>
class NdArrayView : NdArrayViewer<dim, DataType>
{
public:
explicit NdArrayView(Pointer ptr, std::array<uint32_t, dim> const& nCells)
: ptr_{ptr}
, nCells_{nCells}
{
}

explicit NdArrayView(std::vector<DataType> const& v, std::array<uint32_t, dim> const& nbCell)
: NdArrayView{v.data(), nbCell}
{
}

template<typename... Indexes>
DataType const& operator()(Indexes... indexes) const
{
return NdArrayViewer<dim, DataType>::at(ptr_, nCells_, indexes...);
}

template<typename... Indexes>
DataType& operator()(Indexes... indexes)
{
return const_cast<DataType&>(static_cast<NdArrayView const&>(*this)(indexes...));
}

template<typename Index>
DataType const& operator()(std::array<Index, dim> const& indexes) const
{
return NdArrayViewer<dim, DataType>::at(ptr_, nCells_, indexes);
}

template<typename Index>
DataType& operator()(std::array<Index, dim> const& indexes)
{
return const_cast<DataType&>(static_cast<NdArrayView const&>(*this)(indexes));
}

private:
Pointer ptr_ = nullptr;
std::array<std::uint32_t, dim> nCells_;
};

template<std::size_t dim, typename DataType = double>
class NdArrayVector
{
Expand Down Expand Up @@ -74,34 +171,7 @@ class NdArrayVector
template<typename... Indexes>
DataType const& operator()(Indexes... indexes) const
{
auto params = std::tuple<Indexes...>{indexes...};
static_assert(sizeof...(Indexes) == dim);
// static_assert((... && std::is_unsigned_v<decltype(indexes)>)); TODO : manage later if
// this test should be included

if constexpr (dim == 1)
{
auto i = std::get<0>(params);

return this->data_[i];
}

if constexpr (dim == 2)
{
auto i = std::get<0>(params);
auto j = std::get<1>(params);

return this->data_[j + i * nCells_[1]];
}

if constexpr (dim == 3)
{
auto i = std::get<0>(params);
auto j = std::get<1>(params);
auto k = std::get<2>(params);

return this->data_[k + j * nCells_[2] + i * nCells_[1] * nCells_[2]];
}
return NdArrayViewer<dim, DataType>::at(data_.data(), nCells_, indexes...);
}

template<typename... Indexes>
Expand All @@ -113,15 +183,7 @@ class NdArrayVector
template<typename Index>
DataType const& operator()(std::array<Index, dim> const& indexes) const
{
if constexpr (dim == 1)
return this->data_[indexes[0]];

else if constexpr (dim == 2)
return this->data_[indexes[1] + indexes[0] * nCells_[1]];

else if constexpr (dim == 3)
return this->data_[indexes[2] + indexes[1] * nCells_[2]
+ indexes[0] * nCells_[1] * nCells_[2]];
return NdArrayViewer<dim, DataType>::at(data_.data(), nCells_, indexes);
}

template<typename Index>
Expand All @@ -141,4 +203,4 @@ class NdArrayVector

} // namespace PHARE::core

#endif // PHARE_CORE_DATA_NDARRAY_NDARRAY_VECTOR_H
#endif // PHARE_CORE_DATA_NDARRAY_NDARRAY_VECTOR_H
2 changes: 0 additions & 2 deletions tests/diagnostic/test-diagnostics_2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
static std::string const job_file = "job_2d";
static std::string const out_dir = "phare_outputs/diags_2d/";

// blocked by https://github.com/PHAREHUB/PHARE/pull/230

TYPED_TEST(Simulator2dTest, fluid)
{
fluid_test(TypeParam{job_file}, out_dir);
Expand Down
17 changes: 9 additions & 8 deletions tests/diagnostic/test_diagnostics.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,21 @@ constexpr unsigned NEW_HI5_FILE
= HighFive::File::ReadWrite | HighFive::File::Create | HighFive::File::Truncate;

template<typename FieldFilter, typename F, typename GridLayout>
std::array<size_t, GridLayout::dimension> fieldIndices(FieldFilter ff, F&& func, GridLayout& layout)
std::array<uint32_t, GridLayout::dimension> fieldIndices(FieldFilter ff, F&& func,
GridLayout& layout)
{
constexpr auto dim = GridLayout::dimension;
static_assert(dim >= 1 and dim <= 3, "Invalid dimension.");

auto direction = [&](auto direction) { return ((ff).*(func))(layout, direction); };
auto get = [&](auto dir) { return static_cast<uint32_t>(((ff).*(func))(layout, dir)); };

if constexpr (dim == 1)
return {direction(PHARE::core::Direction::X)};
return {get(PHARE::core::Direction::X)};
if constexpr (dim == 2)
return {direction(PHARE::core::Direction::X), direction(PHARE::core::Direction::Y)};
return {get(PHARE::core::Direction::X), get(PHARE::core::Direction::Y)};
if constexpr (dim == 3)
return {direction(PHARE::core::Direction::X), direction(PHARE::core::Direction::Y),
direction(PHARE::core::Direction::Z)};
return {get(PHARE::core::Direction::X), get(PHARE::core::Direction::Y),
get(PHARE::core::Direction::Z)};
}

template<typename GridLayout, typename Field, typename FieldFilter = PHARE::FieldNullFilter>
Expand All @@ -52,7 +53,7 @@ void checkField(HighFive::File& file, GridLayout& layout, Field& field, std::str

if constexpr (dim == 1)
{
PHARE::core::NdArrayVector1DView<float> view{fieldV, siz};
PHARE::core::NdArrayView<1, float> view{fieldV, siz};
for (size_t i = beg[0]; i < end[0]; i++)
{
if (std::isnan(view(i)) || std::isnan(field(i)))
Expand All @@ -62,7 +63,7 @@ void checkField(HighFive::File& file, GridLayout& layout, Field& field, std::str
}
else if constexpr (dim == 2)
{
PHARE::core::NdArrayVector2DView<float> view{fieldV, siz};
PHARE::core::NdArrayView<2, float> view{fieldV, siz};
for (size_t i = beg[0]; i < end[0]; i++)
{
for (size_t j = beg[1]; j < end[1]; j++)
Expand Down

0 comments on commit 00da87f

Please sign in to comment.