Skip to content

Commit

Permalink
Avoid crash when converting dict with circular reference
Browse files Browse the repository at this point in the history
Fixes #73
  • Loading branch information
dbaston committed Dec 9, 2024
1 parent 482c9a2 commit 904f66b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 3 deletions.
31 changes: 28 additions & 3 deletions include/pybind11_json/pybind11_json.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef PYBIND11_JSON_HPP
#define PYBIND11_JSON_HPP

#include <set>
#include <string>
#include <vector>

Expand Down Expand Up @@ -67,7 +68,7 @@ namespace pyjson
}
}

inline nl::json to_json(const py::handle& obj)
inline nl::json to_json(const py::handle& obj, std::set<const PyObject*>& refs)
{
if (obj.ptr() == nullptr || obj.is_none())
{
Expand Down Expand Up @@ -118,24 +119,48 @@ namespace pyjson
}
if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj))
{
auto insert_ret = refs.insert(obj.ptr());
if (!insert_ret.second) {
throw std::runtime_error("Circular reference detected");
}

auto out = nl::json::array();
for (const py::handle value : obj)
{
out.push_back(to_json(value));
out.push_back(to_json(value, refs));
}

refs.erase(insert_ret.first);

return out;
}
if (py::isinstance<py::dict>(obj))
{
auto insert_ret = refs.insert(obj.ptr());
if (!insert_ret.second) {
throw std::runtime_error("Circular reference detected");
}

auto out = nl::json::object();
for (const py::handle key : obj)
{
out[py::str(key).cast<std::string>()] = to_json(obj[key]);
out[py::str(key).cast<std::string>()] = to_json(obj[key], refs);
}

refs.erase(insert_ret.first);

return out;
}

throw std::runtime_error("to_json not implemented for this type of object: " + py::repr(obj).cast<std::string>());
}

inline nl::json to_json(const py::handle& obj)
{
std::set<const PyObject*> refs;
return to_json(obj, refs);
}

}

// nlohmann_json serializers
Expand Down
19 changes: 19 additions & 0 deletions test/test_pybind11_json.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,3 +481,22 @@ TEST(pybind11_caster_fromjson, dict)
ASSERT_EQ(j["number"].cast<int>(), 1234);
ASSERT_EQ(j["hello"].cast<std::string>(), "world");
}

TEST(pybind11_caster_tojson, recursive_dict)
{
py::scoped_interpreter guard;
py::module m = create_module("test");

m.def("to_json", &test_fromtojson);

// Simulate calling this binding from Python with a dictionary as argument
py::dict obj_inner("number"_a=1234, "hello"_a="world");
py::dict obj;
obj["first"] = obj_inner;
obj["second"] = obj_inner;

ASSERT_NO_THROW(m.attr("to_json")(obj));

obj["second"]["recur"] = obj_inner;
ASSERT_ANY_THROW(m.attr("to_json")(obj));
}

0 comments on commit 904f66b

Please sign in to comment.