From 3ad93065948168038cff33826b6904f0c2378eb4 Mon Sep 17 00:00:00 2001 From: Quentin <837334+xiaoxiao921@users.noreply.github.com> Date: Wed, 28 Aug 2024 23:39:01 +0200 Subject: [PATCH] fix(lua): fix dynamic hooks/calls: (#3629) - dynamic hook were never actually applied, - rbx was getting trashed inside jit hook func - dynamic hook cleanup on lua script unload - support dynamic call --- src/lua/bindings/asmjit_helper.cpp | 111 +++++++ src/lua/bindings/asmjit_helper.hpp | 15 + src/lua/bindings/memory.cpp | 130 ++++++--- src/lua/bindings/runtime_func_t.cpp | 312 ++++++++++++++++++++ src/lua/bindings/runtime_func_t.hpp | 432 +--------------------------- src/lua/bindings/type_info_t.cpp | 37 +++ src/lua/bindings/type_info_t.hpp | 2 + src/lua/lua_manager.cpp | 16 ++ src/lua/lua_manager.hpp | 6 + src/lua/lua_module.hpp | 6 +- 10 files changed, 602 insertions(+), 465 deletions(-) create mode 100644 src/lua/bindings/asmjit_helper.cpp create mode 100644 src/lua/bindings/asmjit_helper.hpp create mode 100644 src/lua/bindings/runtime_func_t.cpp create mode 100644 src/lua/bindings/type_info_t.cpp diff --git a/src/lua/bindings/asmjit_helper.cpp b/src/lua/bindings/asmjit_helper.cpp new file mode 100644 index 00000000..3ce59e4c --- /dev/null +++ b/src/lua/bindings/asmjit_helper.cpp @@ -0,0 +1,111 @@ +#pragma once +#include "asmjit_helper.hpp" + +namespace lua::memory +{ + bool is_general_register(const asmjit::TypeId type_id) + { + switch (type_id) + { + case asmjit::TypeId::kInt8: + case asmjit::TypeId::kUInt8: + case asmjit::TypeId::kInt16: + case asmjit::TypeId::kUInt16: + case asmjit::TypeId::kInt32: + case asmjit::TypeId::kUInt32: + case asmjit::TypeId::kInt64: + case asmjit::TypeId::kUInt64: + case asmjit::TypeId::kIntPtr: + case asmjit::TypeId::kUIntPtr: return true; + default: return false; + } + } + + bool is_XMM_register(const asmjit::TypeId type_id) + { + switch (type_id) + { + case asmjit::TypeId::kFloat32: + case asmjit::TypeId::kFloat64: return true; + default: return false; + } + } + + asmjit::CallConvId get_call_convention(const std::string& conv) + { + if (conv == "cdecl") + { + return asmjit::CallConvId::kCDecl; + } + else if (conv == "stdcall") + { + return asmjit::CallConvId::kStdCall; + } + else if (conv == "fastcall") + { + return asmjit::CallConvId::kFastCall; + } + + return asmjit::CallConvId::kHost; + } + + asmjit::TypeId get_type_id(const std::string& type) + { + if (type.find('*') != std::string::npos) + { + return asmjit::TypeId::kUIntPtr; + } + +#define TYPEID_MATCH_STR_IF(var, T) \ + if (var == #T) \ + { \ + return asmjit::TypeId(asmjit::TypeUtils::TypeIdOfT::kTypeId); \ + } +#define TYPEID_MATCH_STR_ELSEIF(var, T) \ + else if (var == #T) \ + { \ + return asmjit::TypeId(asmjit::TypeUtils::TypeIdOfT::kTypeId); \ + } + + TYPEID_MATCH_STR_IF(type, signed char) + TYPEID_MATCH_STR_ELSEIF(type, unsigned char) + TYPEID_MATCH_STR_ELSEIF(type, short) + TYPEID_MATCH_STR_ELSEIF(type, unsigned short) + TYPEID_MATCH_STR_ELSEIF(type, int) + TYPEID_MATCH_STR_ELSEIF(type, unsigned int) + TYPEID_MATCH_STR_ELSEIF(type, long) + TYPEID_MATCH_STR_ELSEIF(type, unsigned long) +#ifdef POLYHOOK2_OS_WINDOWS + TYPEID_MATCH_STR_ELSEIF(type, __int64) + TYPEID_MATCH_STR_ELSEIF(type, unsigned __int64) +#endif + TYPEID_MATCH_STR_ELSEIF(type, long long) + TYPEID_MATCH_STR_ELSEIF(type, unsigned long long) + TYPEID_MATCH_STR_ELSEIF(type, char) + TYPEID_MATCH_STR_ELSEIF(type, char16_t) + TYPEID_MATCH_STR_ELSEIF(type, char32_t) + TYPEID_MATCH_STR_ELSEIF(type, wchar_t) + TYPEID_MATCH_STR_ELSEIF(type, uint8_t) + TYPEID_MATCH_STR_ELSEIF(type, int8_t) + TYPEID_MATCH_STR_ELSEIF(type, uint16_t) + TYPEID_MATCH_STR_ELSEIF(type, int16_t) + TYPEID_MATCH_STR_ELSEIF(type, int32_t) + TYPEID_MATCH_STR_ELSEIF(type, uint32_t) + TYPEID_MATCH_STR_ELSEIF(type, uint64_t) + TYPEID_MATCH_STR_ELSEIF(type, int64_t) + TYPEID_MATCH_STR_ELSEIF(type, float) + TYPEID_MATCH_STR_ELSEIF(type, double) + TYPEID_MATCH_STR_ELSEIF(type, bool) + TYPEID_MATCH_STR_ELSEIF(type, void) + else if (type == "intptr_t") + { + return asmjit::TypeId::kIntPtr; + } + else if (type == "uintptr_t") + { + return asmjit::TypeId::kUIntPtr; + } + + return asmjit::TypeId::kVoid; + } +} \ No newline at end of file diff --git a/src/lua/bindings/asmjit_helper.hpp b/src/lua/bindings/asmjit_helper.hpp new file mode 100644 index 00000000..6718e30f --- /dev/null +++ b/src/lua/bindings/asmjit_helper.hpp @@ -0,0 +1,15 @@ +#pragma once +#include + +namespace lua::memory +{ + // does a given type fit in a general purpose register (i.e. is it integer type) + bool is_general_register(const asmjit::TypeId type_id); + + // float, double, simd128 + bool is_XMM_register(const asmjit::TypeId type_id); + + asmjit::CallConvId get_call_convention(const std::string& conv); + + asmjit::TypeId get_type_id(const std::string& type); +} \ No newline at end of file diff --git a/src/lua/bindings/memory.cpp b/src/lua/bindings/memory.cpp index 0c25b7e2..da48c8a4 100644 --- a/src/lua/bindings/memory.cpp +++ b/src/lua/bindings/memory.cpp @@ -194,18 +194,16 @@ namespace lua::memory } } - static std::unordered_map> target_func_ptr_to_hook; - static bool pre_callback(const runtime_func_t::parameters_t* params, const uint8_t param_count, runtime_func_t::return_value_t* return_value, const uintptr_t target_func_ptr) { - const auto& dyn_hook = target_func_ptr_to_hook[target_func_ptr]; + const auto& dyn_hook = big::g_lua_manager->m_target_func_ptr_to_dynamic_hook[target_func_ptr]; return big::g_lua_manager ->dynamic_hook_pre_callbacks(target_func_ptr, dyn_hook->m_return_type, return_value, dyn_hook->m_param_types, params, param_count); } static void post_callback(const runtime_func_t::parameters_t* params, const uint8_t param_count, runtime_func_t::return_value_t* return_value, const uintptr_t target_func_ptr) { - const auto& dyn_hook = target_func_ptr_to_hook[target_func_ptr]; + const auto& dyn_hook = big::g_lua_manager->m_target_func_ptr_to_dynamic_hook[target_func_ptr]; big::g_lua_manager->dynamic_hook_post_callbacks(target_func_ptr, dyn_hook->m_return_type, return_value, dyn_hook->m_param_types, params, param_count); } @@ -240,25 +238,9 @@ namespace lua::memory // ``` static void dynamic_hook(const std::string& hook_name, const std::string& return_type, sol::table param_types_table, lua::memory::pointer& target_func_ptr_obj, sol::protected_function pre_lua_callback, sol::protected_function post_lua_callback, sol::this_state state_) { - const auto target_func_ptr = target_func_ptr_obj.get_address(); - if (!target_func_ptr_to_hook.contains(target_func_ptr)) + if (!target_func_ptr_obj.is_valid()) { - std::vector param_types; - for (const auto& [k, v] : param_types_table) - { - if (v.is()) - { - param_types.push_back(v.as()); - } - } - - std::unique_ptr runtime_func = std::make_unique(); - const auto jitted_func = runtime_func->make_jit_func(return_type, param_types, asmjit::Arch::kHost, pre_callback, post_callback, target_func_ptr); - - target_func_ptr_to_hook.emplace(target_func_ptr, std::move(runtime_func)); - - // TODO: The detour_hook is never cleaned up on menu unload. - target_func_ptr_to_hook[target_func_ptr]->create_and_enable_hook(hook_name, target_func_ptr, jitted_func); + return; } big::lua_module* module = sol::state_view(state_)["!this"]; @@ -266,18 +248,56 @@ namespace lua::memory { return; } + + const auto target_func_ptr = target_func_ptr_obj.get_address(); + + bool need_hook = false; if (pre_lua_callback.valid()) { module->m_dynamic_hook_pre_callbacks[target_func_ptr].push_back(pre_lua_callback); + need_hook = true; } if (post_lua_callback.valid()) { module->m_dynamic_hook_post_callbacks[target_func_ptr].push_back(post_lua_callback); + need_hook = true; + } + + if (need_hook) + { + std::shared_ptr runtime_func; + + if (!big::g_lua_manager->m_target_func_ptr_to_dynamic_hook.contains(target_func_ptr)) + { + std::vector param_types; + for (const auto& [k, v] : param_types_table) + { + if (v.is()) + { + param_types.push_back(v.as()); + } + } + + runtime_func = std::make_shared(); + const auto jitted_func = runtime_func->make_jit_func(return_type, param_types, asmjit::Arch::kHost, pre_callback, post_callback, target_func_ptr); + + big::g_lua_manager->m_target_func_ptr_to_dynamic_hook[target_func_ptr] = runtime_func.get(); + + big::g_lua_manager->m_target_func_ptr_to_dynamic_hook[target_func_ptr]->create_and_enable_hook(hook_name, target_func_ptr, jitted_func); + } + else + { + // lua modules own and share the runtime_func_t object, such as when no module reference it anymore the hook detour get cleaned up. + runtime_func = big::g_lua_manager->get_existing_dynamic_hook(target_func_ptr); + } + + if (runtime_func) + { + module->m_dynamic_hooks.push_back(runtime_func); + } } } - static std::unordered_map> jitted_binded_funcs; - static std::string get_jitted_lua_func_global_name(uintptr_t function_to_call_ptr) { return std::format("__dynamic_call_{}", function_to_call_ptr); @@ -292,14 +312,8 @@ namespace lua::memory } }; - static void jit_lua_binded_func(uintptr_t function_to_call_ptr, const asmjit::FuncSignature& function_to_call_sig, const asmjit::Arch& arch, std::vector param_types, type_info_t return_type, lua_State* lua_state, const std::string& jitted_lua_func_global_name) + static std::unique_ptr jit_lua_binded_func(uintptr_t function_to_call_ptr, const asmjit::FuncSignature& function_to_call_sig, const asmjit::Arch& arch, std::vector param_types, type_info_t return_type, lua_State* lua_state, const std::string& jitted_lua_func_global_name) { - const auto it = jitted_binded_funcs.find(function_to_call_ptr); - if (it != jitted_binded_funcs.end()) - { - return; - } - asmjit::CodeHolder code; auto env = asmjit::Environment::host(); env.setArch(arch); @@ -325,9 +339,6 @@ namespace lua::memory asmjit_error_handler_t asmjit_error_handler; code.setErrorHandler(&asmjit_error_handler); - // too small to really need it - func->frame().resetPreservedFP(); - // map argument slots to registers, following abi. std::vector arg_registers; for (uint8_t arg_index = 0; arg_index < function_to_call_sig.argCount(); arg_index++) @@ -444,7 +455,7 @@ namespace lua::memory else { LOG(FATAL) << "Return val wider than 64bits not supported"; - return; + return nullptr; } function_to_call_invoke_node->setRet(0, function_to_call_return_val_reg); @@ -529,9 +540,9 @@ namespace lua::memory size_t size = code.codeSize(); // Allocate a virtual memory (executable). - static std::vector jit_function_buffer(size); + auto jit_function_buffer = std::make_unique(size); DWORD old_protect; - VirtualProtect(jit_function_buffer.data(), size, PAGE_EXECUTE_READWRITE, &old_protect); + VirtualProtect(jit_function_buffer.get(), size, PAGE_EXECUTE_READWRITE, &old_protect); // if multiple sections, resolve linkage (1 atm) if (code.hasUnresolvedLinks()) @@ -540,13 +551,15 @@ namespace lua::memory } // Relocate to the base-address of the allocated memory. - code.relocateToBase((uintptr_t)jit_function_buffer.data()); - code.copyFlattenedData(jit_function_buffer.data(), size); + code.relocateToBase((uintptr_t)jit_function_buffer.get()); + code.copyFlattenedData(jit_function_buffer.get(), size); LOG(VERBOSE) << "JIT Stub: " << log.data(); - lua_pushcfunction(lua_state, (lua_CFunction)jit_function_buffer.data()); + lua_pushcfunction(lua_state, (lua_CFunction)jit_function_buffer.get()); lua_setglobal(lua_state, jitted_lua_func_global_name.c_str()); + + return jit_function_buffer; } // Lua API: Function @@ -581,6 +594,17 @@ namespace lua::memory // ``` static std::string dynamic_call(const std::string& return_type, sol::table param_types_table, lua::memory::pointer& target_func_ptr_obj, sol::this_state state_) { + big::lua_module* module = sol::state_view(state_)["!this"]; + if (!module) + { + return ""; + } + + if (!target_func_ptr_obj.is_valid()) + { + return ""; + } + const auto target_func_ptr = target_func_ptr_obj.get_address(); const auto jitted_lua_func_global_name = get_jitted_lua_func_global_name(target_func_ptr); @@ -610,13 +634,25 @@ namespace lua::memory param_types.push_back(get_type_info_from_string(s)); } - jit_lua_binded_func(target_func_ptr, - sig, - asmjit::Arch::kHost, - param_types, - get_type_info_from_string(return_type), - state_.L, - jitted_lua_func_global_name); + if (!module->m_dynamic_call_jit_functions.contains(target_func_ptr)) + { + auto jitted_func = jit_lua_binded_func(target_func_ptr, + sig, + asmjit::Arch::kHost, + param_types, + get_type_info_from_string(return_type), + state_.L, + jitted_lua_func_global_name); + + if (jitted_func) + { + module->m_dynamic_call_jit_functions.emplace(target_func_ptr, std::move(jitted_func)); + } + else + { + return ""; + } + } return jitted_lua_func_global_name; } diff --git a/src/lua/bindings/runtime_func_t.cpp b/src/lua/bindings/runtime_func_t.cpp new file mode 100644 index 00000000..59e4d811 --- /dev/null +++ b/src/lua/bindings/runtime_func_t.cpp @@ -0,0 +1,312 @@ +#pragma once +#include "runtime_func_t.hpp" + +#include "lua/lua_manager.hpp" + +#include + +namespace lua::memory +{ + char* runtime_func_t::parameters_t::get_arg_ptr(const uint8_t idx) const + { + return ((char*)&m_arguments) + sizeof(uintptr_t) * idx; + } + + unsigned char* runtime_func_t::return_value_t::get() const + { + return (unsigned char*)&m_return_value; + } + + runtime_func_t::runtime_func_t() + { + m_detour = std::make_unique(); + m_return_type = type_info_t::none_; + } + + runtime_func_t::~runtime_func_t() + { + big::g_lua_manager->m_target_func_ptr_to_dynamic_hook.erase(m_target_func_ptr); + } + + void runtime_func_t::debug_print_args(const asmjit::FuncSignature& sig) + { + for (uint8_t arg_index_debug = 0; arg_index_debug < sig.argCount(); arg_index_debug++) + { + const auto arg_type_debug = sig.args()[arg_index_debug]; + + LOG(VERBOSE) << (int)arg_type_debug; + } + } + + uintptr_t runtime_func_t::make_jit_func(const asmjit::FuncSignature& sig, const asmjit::Arch arch, const user_pre_callback_t pre_callback, const user_post_callback_t post_callback, const uintptr_t target_func_ptr) + { + asmjit::CodeHolder code; + auto env = asmjit::Environment::host(); + env.setArch(arch); + code.init(env); + + // initialize function + asmjit::x86::Compiler cc(&code); + asmjit::FuncNode* func = cc.addFunc(sig); + + asmjit::StringLogger log; + // clang-format off + const auto format_flags = + asmjit::FormatFlags::kMachineCode | asmjit::FormatFlags::kExplainImms | asmjit::FormatFlags::kRegCasts | + asmjit::FormatFlags::kHexImms | asmjit::FormatFlags::kHexOffsets | asmjit::FormatFlags::kPositions; + // clang-format on + + log.addFlags(format_flags); + code.setLogger(&log); + + // map argument slots to registers, following abi. + std::vector arg_registers; + for (uint8_t arg_index = 0; arg_index < sig.argCount(); arg_index++) + { + const auto arg_type = sig.args()[arg_index]; + + asmjit::x86::Reg arg; + if (is_general_register(arg_type)) + { + arg = cc.newUIntPtr(); + } + else if (is_XMM_register(arg_type)) + { + arg = cc.newXmm(); + } + else + { + LOG(FATAL) << "Parameters wider than 64bits not supported, index: " << arg_index << " | " << (int)arg_type; + debug_print_args(sig); + + return 0; + } + + func->setArg(arg_index, arg); + arg_registers.push_back(arg); + } + + // setup the stack structure to hold arguments for user callback + uint32_t stack_size = (uint32_t)(sizeof(uintptr_t) * sig.argCount()); + m_args_stack = cc.newStack(stack_size, 16); + asmjit::x86::Mem args_stack_index(m_args_stack); + + // assigns some register as index reg + asmjit::x86::Gp i = cc.newUIntPtr(); + + // stack_index <- stack[i]. + args_stack_index.setIndex(i); + + // r/w are sizeof(uintptr_t) width now + args_stack_index.setSize(sizeof(uintptr_t)); + + // set i = 0 + cc.mov(i, 0); + // mov from arguments registers into the stack structure + for (uint8_t argIdx = 0; argIdx < sig.argCount(); argIdx++) + { + const auto argType = sig.args()[argIdx]; + + // have to cast back to explicit register types to gen right mov type + if (is_general_register(argType)) + { + cc.mov(args_stack_index, arg_registers.at(argIdx).as()); + } + else if (is_XMM_register(argType)) + { + cc.movq(args_stack_index, arg_registers.at(argIdx).as()); + } + else + { + LOG(FATAL) << "Parameters wider than 64bits not supported, index: " << argIdx << " | " << (int)argType; + debug_print_args(sig); + + return 0; + } + + // next structure slot (+= sizeof(uintptr_t)) + cc.add(i, sizeof(uintptr_t)); + } + + // get pointer to stack structure and pass it to the user pre callback + asmjit::x86::Gp arg_struct = cc.newUIntPtr("arg_struct"); + cc.lea(arg_struct, m_args_stack); + + // fill reg to pass struct arg count to callback + asmjit::x86::Gp arg_param_count = cc.newUInt8(); + cc.mov(arg_param_count, (uint8_t)sig.argCount()); + + // create buffer for ret val + asmjit::x86::Mem return_stack = cc.newStack(sizeof(uintptr_t), 16); + asmjit::x86::Gp return_struct = cc.newUIntPtr("return_struct"); + cc.lea(return_struct, return_stack); + + // fill reg to pass target function pointer to callback + asmjit::x86::Gp target_func_ptr_reg = cc.newUIntPtr(); + cc.mov(target_func_ptr_reg, target_func_ptr); + + asmjit::Label original_invoke_label = cc.newLabel(); + asmjit::Label skip_original_invoke_label = cc.newLabel(); + + // invoke the user pre callback + asmjit::InvokeNode* pre_callback_invoke_node; + cc.invoke(&pre_callback_invoke_node, (uintptr_t)pre_callback, asmjit::FuncSignatureT()); + + // call to user provided function (use ABI of host compiler) + pre_callback_invoke_node->setArg(0, arg_struct); + pre_callback_invoke_node->setArg(1, arg_param_count); + pre_callback_invoke_node->setArg(2, return_struct); + pre_callback_invoke_node->setArg(3, target_func_ptr_reg); + + // create a register for the user pre callback's return value + // Note: the size of the register is important for the test instruction. newUInt8 since the pre callback returns a bool. + asmjit::x86::Gp pre_callback_return_val = cc.newUInt8("pre_callback_return_val"); + // store the callback return value + pre_callback_invoke_node->setRet(0, pre_callback_return_val); + + // if the callback return value is zero, skip orig. + cc.test(pre_callback_return_val, pre_callback_return_val); + cc.jz(skip_original_invoke_label); + + // label to invoke the original function + cc.bind(original_invoke_label); + + // mov from arguments stack structure into regs + cc.mov(i, 0); // reset idx + for (uint8_t arg_idx = 0; arg_idx < sig.argCount(); arg_idx++) + { + const auto argType = sig.args()[arg_idx]; + + if (is_general_register(argType)) + { + cc.mov(arg_registers.at(arg_idx).as(), args_stack_index); + } + else if (is_XMM_register(argType)) + { + cc.movq(arg_registers.at(arg_idx).as(), args_stack_index); + } + else + { + LOG(FATAL) << "Parameters wider than 64bits not supported, index: " << arg_idx << " | " << (int)argType; + debug_print_args(sig); + + return 0; + } + + // next structure slot (+= sizeof(uint64_t)) + cc.add(i, sizeof(uint64_t)); + } + + // deref the trampoline ptr (holder must live longer, must be concrete reg since push later) + asmjit::x86::Gp original_ptr = cc.newUIntPtr(); + cc.mov(original_ptr, m_detour->get_original_ptr()); + cc.mov(original_ptr, asmjit::x86::ptr(original_ptr)); + + asmjit::InvokeNode* original_invoke_node; + cc.invoke(&original_invoke_node, original_ptr, sig); + for (uint8_t arg_index = 0; arg_index < sig.argCount(); arg_index++) + { + original_invoke_node->setArg(arg_index, arg_registers.at(arg_index)); + } + + if (sig.hasRet()) + { + if (is_general_register(sig.ret())) + { + asmjit::x86::Gp tmp = cc.newUIntPtr(); + original_invoke_node->setRet(0, tmp); + cc.mov(return_stack, tmp); + } + else + { + asmjit::x86::Xmm tmp = cc.newXmm(); + original_invoke_node->setRet(0, tmp); + cc.movq(return_stack, tmp); + } + } + + cc.bind(skip_original_invoke_label); + + asmjit::InvokeNode* post_callback_invoke_node; + cc.invoke(&post_callback_invoke_node, (uintptr_t)post_callback, asmjit::FuncSignatureT()); + + // Set arguments for the post callback + post_callback_invoke_node->setArg(0, arg_struct); + post_callback_invoke_node->setArg(1, arg_param_count); + post_callback_invoke_node->setArg(2, return_struct); + post_callback_invoke_node->setArg(3, target_func_ptr_reg); + + if (sig.hasRet()) + { + asmjit::x86::Mem return_stack_index(return_stack); + return_stack_index.setSize(sizeof(uintptr_t)); + if (is_general_register(sig.ret())) + { + asmjit::x86::Gp tmp2 = cc.newUIntPtr(); + cc.mov(tmp2, return_stack_index); + cc.ret(tmp2); + } + else + { + asmjit::x86::Xmm tmp2 = cc.newXmm(); + cc.movq(tmp2, return_stack_index); + cc.ret(tmp2); + } + } + + cc.endFunc(); + + // write to buffer + cc.finalize(); + + // worst case, overestimates for case trampolines needed + code.flatten(); + size_t size = code.codeSize(); + + // Allocate a virtual memory (executable). + m_jit_function_buffer.reserve(size); + + DWORD old_protect; + VirtualProtect(m_jit_function_buffer.data(), size, PAGE_EXECUTE_READWRITE, &old_protect); + + // if multiple sections, resolve linkage (1 atm) + if (code.hasUnresolvedLinks()) + { + code.resolveUnresolvedLinks(); + } + + // Relocate to the base-address of the allocated memory. + code.relocateToBase((uintptr_t)m_jit_function_buffer.data()); + code.copyFlattenedData(m_jit_function_buffer.data(), size); + + LOG(VERBOSE) << "JIT Stub: " << log.data(); + + return (uintptr_t)m_jit_function_buffer.data(); + } + + uintptr_t runtime_func_t::make_jit_func(const std::string& return_type, const std::vector& param_types, const asmjit::Arch arch, const user_pre_callback_t pre_callback, const user_post_callback_t post_callback, const uintptr_t target_func_ptr, std::string call_convention) + { + m_return_type = get_type_info_from_string(return_type); + + asmjit::FuncSignature sig(get_call_convention(call_convention), asmjit::FuncSignature::kNoVarArgs, get_type_id(return_type)); + + for (const std::string& s : param_types) + { + sig.addArg(get_type_id(s)); + m_param_types.push_back(get_type_info_from_string(s)); + } + + return make_jit_func(sig, arch, pre_callback, post_callback, target_func_ptr); + } + + void runtime_func_t::create_and_enable_hook(const std::string& hook_name, uintptr_t target_func_ptr, uintptr_t jitted_func_ptr) + { + m_target_func_ptr = target_func_ptr; + + m_detour->set_instance(hook_name, (void*)target_func_ptr, (void*)jitted_func_ptr); + + m_detour->enable(); + + MH_ApplyQueued(); + } +} diff --git a/src/lua/bindings/runtime_func_t.hpp b/src/lua/bindings/runtime_func_t.hpp index fbbee167..a6346aa7 100644 --- a/src/lua/bindings/runtime_func_t.hpp +++ b/src/lua/bindings/runtime_func_t.hpp @@ -1,4 +1,5 @@ #pragma once +#include "asmjit_helper.hpp" #include "hooking/detour_hook.hpp" #include "lua/bindings/type_info_t.hpp" @@ -6,147 +7,6 @@ namespace lua::memory { - // does a given type fit in a general purpose register (i.e. is it integer type) - inline bool is_general_register(const asmjit::TypeId type_id) - { - switch (type_id) - { - case asmjit::TypeId::kInt8: - case asmjit::TypeId::kUInt8: - case asmjit::TypeId::kInt16: - case asmjit::TypeId::kUInt16: - case asmjit::TypeId::kInt32: - case asmjit::TypeId::kUInt32: - case asmjit::TypeId::kInt64: - case asmjit::TypeId::kUInt64: - case asmjit::TypeId::kIntPtr: - case asmjit::TypeId::kUIntPtr: return true; - default: return false; - } - } - - // float, double, simd128 - inline bool is_XMM_register(const asmjit::TypeId type_id) - { - switch (type_id) - { - case asmjit::TypeId::kFloat32: - case asmjit::TypeId::kFloat64: return true; - default: return false; - } - } - - inline asmjit::CallConvId get_call_convention(const std::string& conv) - { - if (conv == "cdecl") - { - return asmjit::CallConvId::kCDecl; - } - else if (conv == "stdcall") - { - return asmjit::CallConvId::kStdCall; - } - else if (conv == "fastcall") - { - return asmjit::CallConvId::kFastCall; - } - - return asmjit::CallConvId::kHost; - } - - inline asmjit::TypeId get_type_id(const std::string& type) - { - if (type.find('*') != std::string::npos) - { - return asmjit::TypeId::kUIntPtr; - } - -#define TYPEID_MATCH_STR_IF(var, T) \ - if (var == #T) \ - { \ - return asmjit::TypeId(asmjit::TypeUtils::TypeIdOfT::kTypeId); \ - } -#define TYPEID_MATCH_STR_ELSEIF(var, T) \ - else if (var == #T) \ - { \ - return asmjit::TypeId(asmjit::TypeUtils::TypeIdOfT::kTypeId); \ - } - - TYPEID_MATCH_STR_IF(type, signed char) - TYPEID_MATCH_STR_ELSEIF(type, unsigned char) - TYPEID_MATCH_STR_ELSEIF(type, short) - TYPEID_MATCH_STR_ELSEIF(type, unsigned short) - TYPEID_MATCH_STR_ELSEIF(type, int) - TYPEID_MATCH_STR_ELSEIF(type, unsigned int) - TYPEID_MATCH_STR_ELSEIF(type, long) - TYPEID_MATCH_STR_ELSEIF(type, unsigned long) -#ifdef POLYHOOK2_OS_WINDOWS - TYPEID_MATCH_STR_ELSEIF(type, __int64) - TYPEID_MATCH_STR_ELSEIF(type, unsigned __int64) -#endif - TYPEID_MATCH_STR_ELSEIF(type, long long) - TYPEID_MATCH_STR_ELSEIF(type, unsigned long long) - TYPEID_MATCH_STR_ELSEIF(type, char) - TYPEID_MATCH_STR_ELSEIF(type, char16_t) - TYPEID_MATCH_STR_ELSEIF(type, char32_t) - TYPEID_MATCH_STR_ELSEIF(type, wchar_t) - TYPEID_MATCH_STR_ELSEIF(type, uint8_t) - TYPEID_MATCH_STR_ELSEIF(type, int8_t) - TYPEID_MATCH_STR_ELSEIF(type, uint16_t) - TYPEID_MATCH_STR_ELSEIF(type, int16_t) - TYPEID_MATCH_STR_ELSEIF(type, int32_t) - TYPEID_MATCH_STR_ELSEIF(type, uint32_t) - TYPEID_MATCH_STR_ELSEIF(type, uint64_t) - TYPEID_MATCH_STR_ELSEIF(type, int64_t) - TYPEID_MATCH_STR_ELSEIF(type, float) - TYPEID_MATCH_STR_ELSEIF(type, double) - TYPEID_MATCH_STR_ELSEIF(type, bool) - TYPEID_MATCH_STR_ELSEIF(type, void) - else if (type == "intptr_t") - { - return asmjit::TypeId::kIntPtr; - } - else if (type == "uintptr_t") - { - return asmjit::TypeId::kUIntPtr; - } - - return asmjit::TypeId::kVoid; - } - - static type_info_t get_type_info_from_string(const std::string& s) - { - if ((s.contains("const") && s.contains("char") && s.contains("*")) || s.contains("string")) - { - return type_info_t::string_; - } - else if (s.contains("bool")) - { - return type_info_t::boolean_; - } - else if (s.contains("ptr") || s.contains("pointer") || s.contains("*")) - { - // passing lua::memory::pointer - return type_info_t::ptr_; - } - else if (s.contains("float")) - { - return type_info_t::float_; - } - else if (s.contains("double")) - { - return type_info_t::double_; - } - else if (s.contains("vector3")) - { - return type_info_t::vector3_; - } - else - { - return type_info_t::integer_; - } - } - class runtime_func_t { std::vector m_jit_function_buffer; @@ -154,6 +14,8 @@ namespace lua::memory std::unique_ptr m_detour; + uintptr_t m_target_func_ptr{}; + public: type_info_t m_return_type; std::vector m_param_types; @@ -177,10 +39,7 @@ namespace lua::memory volatile uintptr_t m_arguments; // must be char* for aliasing rules to work when reading back out - char* get_arg_ptr(const uint8_t idx) const - { - return ((char*)&m_arguments) + sizeof(uintptr_t) * idx; - } + char* get_arg_ptr(const uint8_t idx) const; }; class return_value_t @@ -188,294 +47,33 @@ namespace lua::memory uintptr_t m_return_value; public: - unsigned char* get() const - { - return (unsigned char*)&m_return_value; - } + unsigned char* get() const; }; typedef bool (*user_pre_callback_t)(const parameters_t* params, const uint8_t parameters_count, return_value_t* return_value, const uintptr_t target_func_ptr); typedef void (*user_post_callback_t)(const parameters_t* params, const uint8_t parameters_count, return_value_t* return_value, const uintptr_t target_func_ptr); - runtime_func_t() + runtime_func_t(); + + ~runtime_func_t(); + + uintptr_t get_target_func_ptr() const { - m_detour = std::make_unique(); - m_return_type = type_info_t::none_; + return m_target_func_ptr; } - ~runtime_func_t() - { - } + void debug_print_args(const asmjit::FuncSignature& sig); // Construct a callback given the raw signature at runtime. 'Callback' param is the C stub to transfer to, // where parameters can be modified through a structure which is written back to the parameter slots depending // on calling convention. - uintptr_t make_jit_func(const asmjit::FuncSignature& sig, const asmjit::Arch arch, const user_pre_callback_t pre_callback, const user_post_callback_t post_callback, const uintptr_t target_func_ptr) - { - asmjit::CodeHolder code; - auto env = asmjit::Environment::host(); - env.setArch(arch); - code.init(env); - - // initialize function - asmjit::x86::Compiler cc(&code); - asmjit::FuncNode* func = cc.addFunc(sig); - - asmjit::StringLogger log; - // clang-format off - const auto format_flags = - asmjit::FormatFlags::kMachineCode | asmjit::FormatFlags::kExplainImms | asmjit::FormatFlags::kRegCasts | - asmjit::FormatFlags::kHexImms | asmjit::FormatFlags::kHexOffsets | asmjit::FormatFlags::kPositions; - // clang-format on - - log.addFlags(format_flags); - code.setLogger(&log); - - // too small to really need it - func->frame().resetPreservedFP(); - - // map argument slots to registers, following abi. - std::vector arg_registers; - for (uint8_t arg_index = 0; arg_index < sig.argCount(); arg_index++) - { - const auto arg_type = sig.args()[arg_index]; - - asmjit::x86::Reg arg; - if (is_general_register(arg_type)) - { - arg = cc.newUIntPtr(); - } - else if (is_XMM_register(arg_type)) - { - arg = cc.newXmm(); - } - else - { - LOG(FATAL) << "Parameters wider than 64bits not supported"; - return 0; - } - - func->setArg(arg_index, arg); - arg_registers.push_back(arg); - } - - // setup the stack structure to hold arguments for user callback - uint32_t stack_size = (uint32_t)(sizeof(uintptr_t) * sig.argCount()); - m_args_stack = cc.newStack(stack_size, 16); - asmjit::x86::Mem args_stack_index(m_args_stack); - - // assigns some register as index reg - asmjit::x86::Gp i = cc.newUIntPtr(); - - // stack_index <- stack[i]. - args_stack_index.setIndex(i); - - // r/w are sizeof(uintptr_t) width now - args_stack_index.setSize(sizeof(uintptr_t)); - - // set i = 0 - cc.mov(i, 0); - // mov from arguments registers into the stack structure - for (uint8_t argIdx = 0; argIdx < sig.argCount(); argIdx++) - { - const auto argType = sig.args()[argIdx]; - - // have to cast back to explicit register types to gen right mov type - if (is_general_register(argType)) - { - cc.mov(args_stack_index, arg_registers.at(argIdx).as()); - } - else if (is_XMM_register(argType)) - { - cc.movq(args_stack_index, arg_registers.at(argIdx).as()); - } - else - { - LOG(FATAL) << "Parameters wider than 64bits not supported"; - return 0; - } - - // next structure slot (+= sizeof(uintptr_t)) - cc.add(i, sizeof(uintptr_t)); - } - - // get pointer to stack structure and pass it to the user pre callback - asmjit::x86::Gp arg_struct = cc.newUIntPtr("arg_struct"); - cc.lea(arg_struct, m_args_stack); - - // fill reg to pass struct arg count to callback - asmjit::x86::Gp arg_param_count = cc.newUInt8(); - cc.mov(arg_param_count, (uint8_t)sig.argCount()); - - // create buffer for ret val - asmjit::x86::Mem return_stack = cc.newStack(sizeof(uintptr_t), 16); - asmjit::x86::Gp return_struct = cc.newUIntPtr("return_struct"); - cc.lea(return_struct, return_stack); - - // fill reg to pass target function pointer to callback - asmjit::x86::Gp target_func_ptr_reg = cc.newUIntPtr(); - cc.mov(target_func_ptr_reg, target_func_ptr); - - asmjit::Label original_invoke_label = cc.newLabel(); - asmjit::Label skip_original_invoke_label = cc.newLabel(); - - // invoke the user pre callback - asmjit::InvokeNode* pre_callback_invoke_node; - cc.invoke(&pre_callback_invoke_node, (uintptr_t)pre_callback, asmjit::FuncSignatureT()); - - // call to user provided function (use ABI of host compiler) - pre_callback_invoke_node->setArg(0, arg_struct); - pre_callback_invoke_node->setArg(1, arg_param_count); - pre_callback_invoke_node->setArg(2, return_struct); - pre_callback_invoke_node->setArg(3, target_func_ptr_reg); - - // create a register for the user pre callback's return value - // Note: the size of the register is important for the test instruction. newUInt8 since the pre callback returns a bool. - asmjit::x86::Gp pre_callback_return_val = cc.newUInt8("pre_callback_return_val"); - // store the callback return value - pre_callback_invoke_node->setRet(0, pre_callback_return_val); - - // if the callback return value is zero, skip orig. - cc.test(pre_callback_return_val, pre_callback_return_val); - cc.jz(skip_original_invoke_label); - - // label to invoke the original function - cc.bind(original_invoke_label); - - // mov from arguments stack structure into regs - cc.mov(i, 0); // reset idx - for (uint8_t arg_idx = 0; arg_idx < sig.argCount(); arg_idx++) - { - const auto argType = sig.args()[arg_idx]; - - if (is_general_register(argType)) - { - cc.mov(arg_registers.at(arg_idx).as(), args_stack_index); - } - else if (is_XMM_register(argType)) - { - cc.movq(arg_registers.at(arg_idx).as(), args_stack_index); - } - else - { - LOG(FATAL) << "Parameters wider than 64bits not supported"; - return 0; - } - - // next structure slot (+= sizeof(uint64_t)) - cc.add(i, sizeof(uint64_t)); - } - - // deref the trampoline ptr (holder must live longer, must be concrete reg since push later) - asmjit::x86::Gp original_ptr = cc.zbx(); - cc.mov(original_ptr, m_detour->get_original_ptr()); - cc.mov(original_ptr, asmjit::x86::ptr(original_ptr)); - - asmjit::InvokeNode* original_invoke_node; - cc.invoke(&original_invoke_node, original_ptr, sig); - for (uint8_t arg_index = 0; arg_index < sig.argCount(); arg_index++) - { - original_invoke_node->setArg(arg_index, arg_registers.at(arg_index)); - } - - if (sig.hasRet()) - { - if (is_general_register(sig.ret())) - { - asmjit::x86::Gp tmp = cc.newUIntPtr(); - original_invoke_node->setRet(0, tmp); - cc.mov(return_stack, tmp); - } - else - { - asmjit::x86::Xmm tmp = cc.newXmm(); - original_invoke_node->setRet(0, tmp); - cc.movq(return_stack, tmp); - } - } - - cc.bind(skip_original_invoke_label); - - asmjit::InvokeNode* post_callback_invoke_node; - cc.invoke(&post_callback_invoke_node, (uintptr_t)post_callback, asmjit::FuncSignatureT()); - - // Set arguments for the post callback - post_callback_invoke_node->setArg(0, arg_struct); - post_callback_invoke_node->setArg(1, arg_param_count); - post_callback_invoke_node->setArg(2, return_struct); - post_callback_invoke_node->setArg(3, target_func_ptr_reg); - - if (sig.hasRet()) - { - asmjit::x86::Mem return_stack_index(return_stack); - return_stack_index.setSize(sizeof(uintptr_t)); - if (is_general_register(sig.ret())) - { - asmjit::x86::Gp tmp2 = cc.newUIntPtr(); - cc.mov(tmp2, return_stack_index); - cc.ret(tmp2); - } - else - { - asmjit::x86::Xmm tmp2 = cc.newXmm(); - cc.movq(tmp2, return_stack_index); - cc.ret(tmp2); - } - } - - cc.endFunc(); - - // write to buffer - cc.finalize(); - - // worst case, overestimates for case trampolines needed - code.flatten(); - size_t size = code.codeSize(); - - // Allocate a virtual memory (executable). - m_jit_function_buffer.reserve(size); - - DWORD old_protect; - VirtualProtect(m_jit_function_buffer.data(), size, PAGE_EXECUTE_READWRITE, &old_protect); - - // if multiple sections, resolve linkage (1 atm) - if (code.hasUnresolvedLinks()) - { - code.resolveUnresolvedLinks(); - } - - // Relocate to the base-address of the allocated memory. - code.relocateToBase((uintptr_t)m_jit_function_buffer.data()); - code.copyFlattenedData(m_jit_function_buffer.data(), size); - - LOG(VERBOSE) << "JIT Stub: " << log.data(); - - return (uintptr_t)m_jit_function_buffer.data(); - } + uintptr_t make_jit_func(const asmjit::FuncSignature& sig, const asmjit::Arch arch, const user_pre_callback_t pre_callback, const user_post_callback_t post_callback, const uintptr_t target_func_ptr); // Construct a callback given the typedef as a string. Types are any valid C/C++ data type (basic types), and pointers to // anything are just a uintptr_t. Calling convention is defaulted to whatever is typical for the compiler you use, you can override with // stdcall, fastcall, or cdecl (cdecl is default on x86). On x64 those map to the same thing. - uintptr_t make_jit_func(const std::string& return_type, const std::vector& param_types, const asmjit::Arch arch, const user_pre_callback_t pre_callback, const user_post_callback_t post_callback, const uintptr_t target_func_ptr, std::string call_convention = "") - { - m_return_type = get_type_info_from_string(return_type); + uintptr_t make_jit_func(const std::string& return_type, const std::vector& param_types, const asmjit::Arch arch, const user_pre_callback_t pre_callback, const user_post_callback_t post_callback, const uintptr_t target_func_ptr, std::string call_convention = ""); - asmjit::FuncSignature sig(get_call_convention(call_convention), asmjit::FuncSignature::kNoVarArgs, get_type_id(return_type)); - - for (const std::string& s : param_types) - { - sig.addArg(get_type_id(s)); - m_param_types.push_back(get_type_info_from_string(s)); - } - - return make_jit_func(sig, arch, pre_callback, post_callback, target_func_ptr); - } - - void create_and_enable_hook(const std::string& hook_name, uintptr_t target_func_ptr, uintptr_t jitted_func_ptr) - { - m_detour->set_instance(hook_name, (void*)target_func_ptr, (void*)jitted_func_ptr); - - m_detour->enable(); - } + void create_and_enable_hook(const std::string& hook_name, uintptr_t target_func_ptr, uintptr_t jitted_func_ptr); }; } \ No newline at end of file diff --git a/src/lua/bindings/type_info_t.cpp b/src/lua/bindings/type_info_t.cpp new file mode 100644 index 00000000..42808dd5 --- /dev/null +++ b/src/lua/bindings/type_info_t.cpp @@ -0,0 +1,37 @@ +#include "type_info_t.hpp" + +namespace lua::memory +{ + type_info_t get_type_info_from_string(const std::string& s) + { + if ((s.contains("const") && s.contains("char") && s.contains("*")) || s.contains("string")) + { + return type_info_t::string_; + } + else if (s.contains("bool")) + { + return type_info_t::boolean_; + } + else if (s.contains("ptr") || s.contains("pointer") || s.contains("*")) + { + // passing lua::memory::pointer + return type_info_t::ptr_; + } + else if (s.contains("float")) + { + return type_info_t::float_; + } + else if (s.contains("double")) + { + return type_info_t::double_; + } + else if (s.contains("vector3")) + { + return type_info_t::vector3_; + } + else + { + return type_info_t::integer_; + } + } +} \ No newline at end of file diff --git a/src/lua/bindings/type_info_t.hpp b/src/lua/bindings/type_info_t.hpp index b710f52f..62173c1e 100644 --- a/src/lua/bindings/type_info_t.hpp +++ b/src/lua/bindings/type_info_t.hpp @@ -13,4 +13,6 @@ namespace lua::memory double_, vector3_ }; + + type_info_t get_type_info_from_string(const std::string& s); } \ No newline at end of file diff --git a/src/lua/lua_manager.cpp b/src/lua/lua_manager.cpp index 756afd93..52b21161 100644 --- a/src/lua/lua_manager.cpp +++ b/src/lua/lua_manager.cpp @@ -414,4 +414,20 @@ namespace big LOG(FATAL) << state["!module_name"].get() << ": " << error.what(); Logger::FlushQueue(); } + + std::shared_ptr lua_manager::get_existing_dynamic_hook(const uintptr_t target_func_ptr) + { + for (const auto& mod : m_modules) + { + for (const auto& dyn_hook : mod->m_dynamic_hooks) + { + if (dyn_hook->get_target_func_ptr() == target_func_ptr) + { + return dyn_hook; + } + } + } + + return nullptr; + } } \ No newline at end of file diff --git a/src/lua/lua_manager.hpp b/src/lua/lua_manager.hpp index 998aa651..8269c72f 100644 --- a/src/lua/lua_manager.hpp +++ b/src/lua/lua_manager.hpp @@ -1,4 +1,5 @@ #pragma once +#include "bindings/runtime_func_t.hpp" #include "lua_module.hpp" namespace big @@ -43,6 +44,9 @@ namespace big return m_scripts_config_folder; } + // non owning map + std::unordered_map m_target_func_ptr_to_dynamic_hook; + std::weak_ptr get_module(rage::joaat_t module_id); std::weak_ptr get_disabled_module(rage::joaat_t module_id); @@ -101,6 +105,8 @@ namespace big return std::nullopt; } + std::shared_ptr get_existing_dynamic_hook(const uintptr_t target_func_ptr); + inline void for_each_module(auto func) { std::lock_guard guard(m_module_lock); diff --git a/src/lua/lua_module.hpp b/src/lua/lua_module.hpp index 177a5f43..6eef53ab 100644 --- a/src/lua/lua_module.hpp +++ b/src/lua/lua_module.hpp @@ -3,8 +3,8 @@ #include "bindings/gui/gui_element.hpp" #include "core/data/menu_event.hpp" #include "lua/bindings/runtime_func_t.hpp" -#include "lua/bindings/type_info_t.hpp" #include "lua/bindings/scr_patch.hpp" +#include "lua/bindings/type_info_t.hpp" #include "lua_patch.hpp" #include "services/gui/gui_service.hpp" @@ -41,9 +41,13 @@ namespace big std::unordered_map> m_event_callbacks; std::vector m_allocated_memory; + // lua modules own and share the runtime_func_t object, such as when no module reference it anymore the hook detour get cleaned up. + std::vector> m_dynamic_hooks; std::unordered_map> m_dynamic_hook_pre_callbacks; std::unordered_map> m_dynamic_hook_post_callbacks; + std::unordered_map> m_dynamic_call_jit_functions; + lua_module(const std::filesystem::path& module_path, folder& scripts_folder, bool disabled = false); ~lua_module();