diff --git a/src/mods/ScriptRunner.cpp b/src/mods/ScriptRunner.cpp index f795bbaf3..ad12ce21b 100644 --- a/src/mods/ScriptRunner.cpp +++ b/src/mods/ScriptRunner.cpp @@ -52,6 +52,29 @@ void debug(const char* str) { } } +namespace api::thread { +size_t get_hash() { + const auto id = std::this_thread::get_id(); + return std::hash{}(id); +} + +uint32_t get_id() { + return std::this_thread::get_id()._Get_underlying_id(); +} + +sol::object get_hook_storage(sol::this_state s, size_t hash) { + auto sol_state = sol::state_view{s}; + auto state = sol_state.registry()["state"].get(); + auto result = state->get_hook_storage(get_hash()); + + if (!result.has_value()) { + return sol::make_object(s, sol::lua_nil); + } + + return sol::make_object(s, result.value()); +} +} + ScriptState::ScriptState(const ScriptState::GarbageCollectionData& gc_data,bool is_main_state) { std::scoped_lock _{ m_execution_mutex }; m_is_main_state = is_main_state; @@ -88,6 +111,11 @@ ScriptState::ScriptState(const ScriptState::GarbageCollectionData& gc_data,bool re["on_config_save"] = [this](sol::function fn) { m_on_config_save_fns.emplace_back(fn); }; m_lua["re"] = re; + auto thread = m_lua.create_table(); + thread["get_hash"] = api::thread::get_hash; + thread["get_id"] = api::thread::get_id; + thread["get_hook_storage"] = api::thread::get_hook_storage; + m_lua["thread"] = thread; auto log = m_lua.create_table(); log["info"] = api::log::info; @@ -545,6 +573,8 @@ void ScriptState::install_hooks() { } try { + state->push_hook_storage(std::hash{}(std::this_thread::get_id())); + if (pre_cb.is()) { return result; } @@ -591,6 +621,7 @@ void ScriptState::install_hooks() { try { if (post_cb.is()) { + state->pop_hook_storage(std::hash{}(std::this_thread::get_id())); return; } @@ -601,6 +632,7 @@ void ScriptState::install_hooks() { } ret_val = (uintptr_t)script_result.get(); + state->pop_hook_storage(std::hash{}(std::this_thread::get_id())); } catch (const std::exception& e) { ScriptRunner::get()->spew_error(e.what()); } catch (...) { diff --git a/src/mods/ScriptRunner.hpp b/src/mods/ScriptRunner.hpp index 83592ab90..719022056 100644 --- a/src/mods/ScriptRunner.hpp +++ b/src/mods/ScriptRunner.hpp @@ -154,6 +154,44 @@ class ScriptState { void gc_data_changed(GarbageCollectionData data); + /*sol::table get_thread_storage(size_t hash) { + auto it = m_thread_storage.find(hash); + if (it == m_thread_storage.end()) { + it = m_thread_storage.emplace(hash, m_lua.create_table()).first; + } + + return it->second; + }*/ + + void push_hook_storage(size_t thread_hash) { + auto it = m_hook_storage.find(thread_hash); + if (it == m_hook_storage.end()) { + it = m_hook_storage.emplace(thread_hash, std::stack{}).first; + } + + it->second.push(m_lua.create_table()); + } + + void pop_hook_storage(size_t thread_hash) { + auto it = m_hook_storage.find(thread_hash); + if (it != m_hook_storage.end()) { + if (!it->second.empty()) { + it->second.pop(); + } + } + } + + std::optional get_hook_storage(size_t thread_hash) { + auto it = m_hook_storage.find(thread_hash); + if (it != m_hook_storage.end()) { + if (!it->second.empty()) { + return it->second.top(); + } + } + + return std::nullopt; + } + private: sol::state m_lua{}; @@ -182,6 +220,8 @@ class ScriptState { std::deque m_hooks_to_add{}; std::unordered_map> m_hooks{}; + + std::unordered_map> m_hook_storage{}; }; class ScriptRunner : public Mod {