Skip to content

Commit

Permalink
Merge branch 'master' into csharp-api
Browse files Browse the repository at this point in the history
  • Loading branch information
praydog committed Jun 4, 2024
2 parents 5dbc037 + 5c2eea8 commit 459c5fd
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 31 deletions.
97 changes: 71 additions & 26 deletions src/D3D12Hook.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <utility/String.hpp>
#include <utility/RTTI.hpp>
#include <utility/Scan.hpp>
#include <utility/ScopeGuard.hpp>

#include "REFramework.hpp"

Expand All @@ -18,19 +19,28 @@
#include "D3D12Hook.hpp"

static D3D12Hook* g_d3d12_hook = nullptr;
thread_local bool g_inside_d3d12_hook = false;

D3D12Hook::~D3D12Hook() {
unhook();
}

void* D3D12Hook::Streamline::link_swapchain_to_cmd_queue(void* rcx, void* rdx, void* r8, void* r9) {
if (g_inside_d3d12_hook) {
spdlog::info("[Streamline] linkSwapchainToCmdQueue: {:x} (inside D3D12 hook)", (uintptr_t)_ReturnAddress());

auto& hook = D3D12Hook::s_streamline.link_swapchain_to_cmd_queue_hook;
return hook->get_original<decltype(link_swapchain_to_cmd_queue)>()(rcx, rdx, r8, r9);
}

std::scoped_lock _{g_framework->get_hook_monitor_mutex()};

spdlog::info("[Streamline] linkSwapchainToCmdQueue: {:x}", (uintptr_t)_ReturnAddress());

g_framework->on_reset(); // Needed to prevent a crash due to resources hanging around
bool hook_was_nullptr = g_d3d12_hook == nullptr;

if (g_d3d12_hook != nullptr) {
g_framework->on_reset(); // Needed to prevent a crash due to resources hanging around
g_d3d12_hook->unhook(); // Removes all vtable hooks
}

Expand All @@ -40,19 +50,58 @@ void* D3D12Hook::Streamline::link_swapchain_to_cmd_queue(void* rcx, void* rdx, v
// Re-hooks present after the above function creates the swapchain
// This allows the hook to immediately still function
// rather than waiting on the hook monitor to notice the hook isn't working
g_framework->hook_d3d12();
if (!hook_was_nullptr) {
g_framework->hook_d3d12();
}

return result;
}

HRESULT WINAPI D3D12Hook::create_swapchain(IDXGIFactory4* factory, IUnknown* device, HWND hwnd, const DXGI_SWAP_CHAIN_DESC* desc, const DXGI_SWAP_CHAIN_FULLSCREEN_DESC* p_fullscreen_desc, IDXGIOutput* p_restrict_to_output, IDXGISwapChain** swap_chain) {
auto create_swap_chain_fn = s_create_swapchain_hook->get_original<decltype(D3D12Hook::create_swapchain)*>();

if (g_inside_d3d12_hook) {
spdlog::info("create_swapchain (inside D3D12 hook)");
return create_swap_chain_fn(factory, device, hwnd, desc, p_fullscreen_desc, p_restrict_to_output, swap_chain);
}

spdlog::info("create_swapchain called");

std::scoped_lock _{g_framework->get_hook_monitor_mutex()};

bool hook_was_nullptr = g_d3d12_hook == nullptr;

if (g_d3d12_hook != nullptr && g_framework->get_d3d12_hook() != nullptr) {
g_framework->on_reset(); // Needed to prevent a crash due to resources hanging around
g_d3d12_hook->unhook(); // Removes all vtable hooks
}

const auto result = create_swap_chain_fn(factory, device, hwnd, desc, p_fullscreen_desc, p_restrict_to_output, swap_chain);

// rather than waiting on the hook monitor to notice the hook isn't working
if (!hook_was_nullptr) {
g_framework->hook_d3d12();
}

return result;
}

void D3D12Hook::hook_streamline() {
void D3D12Hook::hook_streamline(HMODULE dlssg_module) try {
if (D3D12Hook::s_streamline.setup) {
return;
}

std::scoped_lock _{D3D12Hook::s_streamline.hook_mutex};

if (D3D12Hook::s_streamline.setup) {
return;
}

spdlog::info("[Streamline] Hooking Streamline");

const auto dlssg_module = GetModuleHandleW(L"sl.dlss_g.dll");
if (dlssg_module == nullptr) {
dlssg_module = GetModuleHandleW(L"sl.dlss_g.dll");
}

if (dlssg_module == nullptr) {
spdlog::error("[Streamline] Failed to get sl.dlss_g.dll module handle");
Expand Down Expand Up @@ -89,13 +138,21 @@ void D3D12Hook::hook_streamline() {
}

D3D12Hook::s_streamline.setup = true;
} catch(...) {
spdlog::error("[Streamline] Failed to hook Streamline");
}

bool D3D12Hook::hook() {
spdlog::info("Hooking D3D12");

g_d3d12_hook = this;

g_inside_d3d12_hook = true;

utility::ScopeGuard guard{[]() {
g_inside_d3d12_hook = false;
}};

IDXGISwapChain1* swap_chain1{ nullptr };
IDXGISwapChain3* swap_chain{ nullptr };
ID3D12Device* device{ nullptr };
Expand Down Expand Up @@ -318,6 +375,8 @@ bool D3D12Hook::hook() {

spdlog::info("Finding command queue offset");

m_command_queue_offset = 0;

// Find the command queue offset in the swapchain
for (auto i = 0; i < 512 * sizeof(void*); i += sizeof(void*)) {
const auto base = (uintptr_t)swap_chain1 + i;
Expand Down Expand Up @@ -413,8 +472,13 @@ bool D3D12Hook::hook() {
m_is_phase_1 = true;

auto& present_fn = (*(void***)target_swapchain)[8]; // Present
m_present_hook = std::make_unique<FunctionHook>((uintptr_t)present_fn, (uintptr_t)&D3D12Hook::present);
m_present_hook->create();
m_present_hook = std::make_unique<PointerHook>(&present_fn, &D3D12Hook::present);

if (s_create_swapchain_hook == nullptr) {
auto& create_swapchain_fn = (*(void***)factory)[15]; // CreateSwapChainForHwnd
s_create_swapchain_hook = std::make_unique<PointerHook>(&create_swapchain_fn, &D3D12Hook::create_swapchain);
}

m_hooked = true;
} catch (const std::exception& e) {
spdlog::error("Failed to initialize hooks: {}", e.what());
Expand Down Expand Up @@ -468,8 +532,7 @@ HRESULT WINAPI D3D12Hook::present(IDXGISwapChain3* swap_chain, uint64_t sync_int
decltype(D3D12Hook::present)* present_fn{nullptr};

if (d3d12->m_is_phase_1) {
//present_fn = d3d12->m_present_hook->get_original<decltype(D3D12Hook::present)*>();
present_fn = d3d12->m_present_hook->get_original<decltype(D3D12Hook::present)>();
present_fn = d3d12->m_present_hook->get_original<decltype(D3D12Hook::present)*>();
} else {
present_fn = d3d12->m_swapchain_hook->get_method<decltype(D3D12Hook::present)*>(8);
}
Expand Down Expand Up @@ -760,21 +823,3 @@ HRESULT WINAPI D3D12Hook::resize_target(IDXGISwapChain3* swap_chain, const DXGI_

return result;
}

/*HRESULT WINAPI D3D12Hook::create_swap_chain(IDXGIFactory4* factory, IUnknown* device, HWND hwnd, const DXGI_SWAP_CHAIN_DESC* desc, const DXGI_SWAP_CHAIN_FULLSCREEN_DESC* p_fullscreen_desc, IDXGIOutput* p_restrict_to_output, IDXGISwapChain** swap_chain)
{
spdlog::info("D3D12 create swapchain called");
auto d3d12 = g_d3d12_hook;
d3d12->m_command_queue = (ID3D12CommandQueue*)device;
if (d3d12->m_on_create_swap_chain) {
d3d12->m_on_create_swap_chain(*d3d12);
}
auto create_swap_chain_fn = d3d12->m_create_swap_chain_hook->get_original<decltype(D3D12Hook::create_swap_chain)>();
return create_swap_chain_fn(factory, device, hwnd, desc, p_fullscreen_desc, p_restrict_to_output, swap_chain);
}*/

12 changes: 7 additions & 5 deletions src/D3D12Hook.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class D3D12Hook
m_ignore_next_present = true;
}

void hook_streamline();
static void hook_streamline(HMODULE dlssg_module = nullptr);

protected:
ID3D12Device4* m_device{ nullptr };
Expand All @@ -121,19 +121,21 @@ class D3D12Hook
bool m_inside_present{false};
bool m_ignore_next_present{false};

std::unique_ptr<FunctionHook> m_present_hook{};
//std::unique_ptr<PointerHook> m_release_hook{};
std::unique_ptr<PointerHook> m_present_hook{};
std::unique_ptr<VtableHook> m_swapchain_hook{};
//std::unique_ptr<FunctionHook> m_create_swap_chain_hook{};

struct Streamline {
static void* link_swapchain_to_cmd_queue(void* rcx, void* rdx, void* r8, void* r9);

std::unique_ptr<FunctionHook> link_swapchain_to_cmd_queue_hook{};
std::mutex hook_mutex{};
bool setup{ false };
};

static inline Streamline s_streamline{};

// This is static because unhooking it seems to cause a crash sometimes
static inline std::unique_ptr<PointerHook> s_create_swapchain_hook{};

OnPresentFn m_on_present{ nullptr };
OnPresentFn m_on_post_present{ nullptr };
Expand All @@ -144,6 +146,6 @@ class D3D12Hook
static HRESULT WINAPI present(IDXGISwapChain3* swap_chain, uint64_t sync_interval, uint64_t flags, void* r9);
static HRESULT WINAPI resize_buffers(IDXGISwapChain3* swap_chain, UINT buffer_count, UINT width, UINT height, DXGI_FORMAT new_format, UINT swap_chain_flags);
static HRESULT WINAPI resize_target(IDXGISwapChain3* swap_chain, const DXGI_MODE_DESC* new_target_parameters);
//static HRESULT WINAPI create_swap_chain(IDXGIFactory4* factory, IUnknown* device, HWND hwnd, const DXGI_SWAP_CHAIN_DESC* desc, const DXGI_SWAP_CHAIN_FULLSCREEN_DESC* p_fullscreen_desc, IDXGIOutput* p_restrict_to_output, IDXGISwapChain** swap_chain);
static HRESULT WINAPI create_swapchain(IDXGIFactory4* factory, IUnknown* device, HWND hwnd, const DXGI_SWAP_CHAIN_DESC* desc, const DXGI_SWAP_CHAIN_FULLSCREEN_DESC* p_fullscreen_desc, IDXGIOutput* p_restrict_to_output, IDXGISwapChain** swap_chain);
};

6 changes: 6 additions & 0 deletions src/REFramework.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,12 @@ try {
if (NotificationData->Loaded.BaseDllName != nullptr && NotificationData->Loaded.BaseDllName->Buffer != nullptr) {
std::wstring base_dll_name = NotificationData->Loaded.BaseDllName->Buffer;
spdlog::info("LdrRegisterDllNotification: Loaded: {}", utility::narrow(base_dll_name));

if (base_dll_name.find(L"sl.dlss_g.dll") != std::wstring::npos) {
spdlog::info("LdrRegisterDllNotification: Detected DLSS DLL loaded");

D3D12Hook::hook_streamline((HMODULE)NotificationData->Loaded.DllBase);
}
}

if (g_current_game_path && NotificationData->Loaded.FullDllName != nullptr && NotificationData->Loaded.FullDllName->Buffer != nullptr) {
Expand Down

0 comments on commit 459c5fd

Please sign in to comment.