diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index d74fe1cc5..3265b6eb5 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -11,4 +11,5 @@ target_sources(pyImpactX ReferenceParticle.cpp transformation.cpp WakeConvolution.cpp + SmallMatrix.cpp ) diff --git a/src/python/SmallMatrix.cpp b/src/python/SmallMatrix.cpp new file mode 100644 index 000000000..052533551 --- /dev/null +++ b/src/python/SmallMatrix.cpp @@ -0,0 +1,72 @@ +/* Copyright 2021-2023 The ImpactX Community + * + * Authors: Ryan Sandberg, Axel Huebl + * License: BSD-3-Clause-LBNL + */ +#include "pyImpactX.H" +#include + +namespace py = pybind11; + +namespace pybind11 { +namespace detail { + +template +struct pybind11::detail::type_caster> { +public: + PYBIND11_TYPE_CASTER(amrex::SmallMatrix, + _("SmallMatrix[") + py::detail::make_caster::name() + _("]")); + + // Conversion from Python to C++ + bool load(handle src, bool) { + // Ensure we have a numpy array + py::array_t arr = py::cast>(src); + py::buffer_info buf = arr.request(); + + // Check dimensions and shape + if (buf.ndim != 2) { + throw std::runtime_error("SmallMatrix requires a 2D array."); + } + if (buf.shape[0] != NRows || buf.shape[1] != NCols) { + throw std::runtime_error("SmallMatrix array shape must match NRows x NCols."); + } + + // Create a SmallMatrix and copy data + amrex::SmallMatrix mat; + T* ptr = static_cast(buf.ptr); + for (int i = 0; i < NRows * NCols; ++i) { + mat.m_mat[i] = ptr[i]; + } + + value = mat; + return true; + } + + // Conversion from C++ to Python + static handle cast(const amrex::SmallMatrix& src, + return_value_policy /* policy */, handle /* parent */) { + py::array_t arr({NRows, NCols}); + py::buffer_info buf = arr.request(); + T* ptr = static_cast(buf.ptr); + for (int i = 0; i < NRows * NCols; ++i) { + ptr[i] = src.m_mat[i]; + } + return arr.release(); + } +}; + +} // namespace detail +} // namespace pybind11 + + +PYBIND11_MODULE(example, m) { + // You can now just bind constructors and methods normally without defining conversion code: + py::class_>(m, "SmallMatrix6x6") + .def(py::init<>()) // Default init + .def("as_array", [](const amrex::SmallMatrix& mat) { + return mat; // Will use type_caster to return a numpy array + }); + + // Now Python functions expecting a SmallMatrix can pass a numpy array directly: + // def some_func(mat: SmallMatrix6x6): ... +} diff --git a/src/python/pyImpactX.H b/src/python/pyImpactX.H index 4401cf81a..748c846bc 100644 --- a/src/python/pyImpactX.H +++ b/src/python/pyImpactX.H @@ -13,6 +13,7 @@ #include #include #include +#include #include