From 904f66b0cb0e9080d182c1b3fa94b92441824ea2 Mon Sep 17 00:00:00 2001 From: Daniel Baston Date: Fri, 6 Dec 2024 11:19:25 -0500 Subject: [PATCH] Avoid crash when converting dict with circular reference Fixes https://github.com/pybind/pybind11_json/issues/73 --- include/pybind11_json/pybind11_json.hpp | 31 ++++++++++++++++++++++--- test/test_pybind11_json.cpp | 19 +++++++++++++++ 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/include/pybind11_json/pybind11_json.hpp b/include/pybind11_json/pybind11_json.hpp index 2b7d9b2..f0344dd 100644 --- a/include/pybind11_json/pybind11_json.hpp +++ b/include/pybind11_json/pybind11_json.hpp @@ -9,6 +9,7 @@ #ifndef PYBIND11_JSON_HPP #define PYBIND11_JSON_HPP +#include #include #include @@ -67,7 +68,7 @@ namespace pyjson } } - inline nl::json to_json(const py::handle& obj) + inline nl::json to_json(const py::handle& obj, std::set& refs) { if (obj.ptr() == nullptr || obj.is_none()) { @@ -118,24 +119,48 @@ namespace pyjson } if (py::isinstance(obj) || py::isinstance(obj)) { + auto insert_ret = refs.insert(obj.ptr()); + if (!insert_ret.second) { + throw std::runtime_error("Circular reference detected"); + } + auto out = nl::json::array(); for (const py::handle value : obj) { - out.push_back(to_json(value)); + out.push_back(to_json(value, refs)); } + + refs.erase(insert_ret.first); + return out; } if (py::isinstance(obj)) { + auto insert_ret = refs.insert(obj.ptr()); + if (!insert_ret.second) { + throw std::runtime_error("Circular reference detected"); + } + auto out = nl::json::object(); for (const py::handle key : obj) { - out[py::str(key).cast()] = to_json(obj[key]); + out[py::str(key).cast()] = to_json(obj[key], refs); } + + refs.erase(insert_ret.first); + return out; } + throw std::runtime_error("to_json not implemented for this type of object: " + py::repr(obj).cast()); } + + inline nl::json to_json(const py::handle& obj) + { + std::set refs; + return to_json(obj, refs); + } + } // nlohmann_json serializers diff --git a/test/test_pybind11_json.cpp b/test/test_pybind11_json.cpp index 2d9cc82..3eb02d1 100644 --- a/test/test_pybind11_json.cpp +++ b/test/test_pybind11_json.cpp @@ -481,3 +481,22 @@ TEST(pybind11_caster_fromjson, dict) ASSERT_EQ(j["number"].cast(), 1234); ASSERT_EQ(j["hello"].cast(), "world"); } + +TEST(pybind11_caster_tojson, recursive_dict) +{ + py::scoped_interpreter guard; + py::module m = create_module("test"); + + m.def("to_json", &test_fromtojson); + + // Simulate calling this binding from Python with a dictionary as argument + py::dict obj_inner("number"_a=1234, "hello"_a="world"); + py::dict obj; + obj["first"] = obj_inner; + obj["second"] = obj_inner; + + ASSERT_NO_THROW(m.attr("to_json")(obj)); + + obj["second"]["recur"] = obj_inner; + ASSERT_ANY_THROW(m.attr("to_json")(obj)); +}