fix(lua): fix dynamic hooks/calls: (#3629)

- dynamic hooks were never actually applied
- rbx was getting trashed inside the JIT hook function
- dynamic hooks are now cleaned up on Lua script unload
- dynamic calls are now supported
This commit is contained in:
Quentin 2024-08-28 23:39:01 +02:00 committed by GitHub
parent a63279dde9
commit 3ad9306594
10 changed files with 602 additions and 465 deletions

View File

@ -0,0 +1,111 @@
#pragma once
#include "asmjit_helper.hpp"
namespace lua::memory
{
// Reports whether type_id names an integer type (8 to 64 bits, signed or
// unsigned) or a pointer-sized integer. Every other type id yields false.
bool is_general_register(const asmjit::TypeId type_id)
{
	using asmjit::TypeId;

	const bool is_fixed_width_integer = type_id == TypeId::kInt8 || type_id == TypeId::kUInt8
	    || type_id == TypeId::kInt16 || type_id == TypeId::kUInt16
	    || type_id == TypeId::kInt32 || type_id == TypeId::kUInt32
	    || type_id == TypeId::kInt64 || type_id == TypeId::kUInt64;

	const bool is_pointer_sized_integer = type_id == TypeId::kIntPtr || type_id == TypeId::kUIntPtr;

	return is_fixed_width_integer || is_pointer_sized_integer;
}
// Reports whether type_id is one of the two scalar floating point types
// (32-bit float or 64-bit double). Every other type id yields false.
bool is_XMM_register(const asmjit::TypeId type_id)
{
	const bool is_float  = type_id == asmjit::TypeId::kFloat32;
	const bool is_double = type_id == asmjit::TypeId::kFloat64;
	return is_float || is_double;
}
// Translates a calling convention name into the asmjit identifier.
// Recognised names are "cdecl", "stdcall" and "fastcall" (exact match);
// any other string, including the empty one, selects the host convention.
asmjit::CallConvId get_call_convention(const std::string& conv)
{
	struct named_convention
	{
		const char* m_name;
		asmjit::CallConvId m_id;
	};

	static constexpr named_convention known_conventions[] = {
	    {"cdecl", asmjit::CallConvId::kCDecl},
	    {"stdcall", asmjit::CallConvId::kStdCall},
	    {"fastcall", asmjit::CallConvId::kFastCall},
	};

	for (const auto& candidate : known_conventions)
	{
		if (conv == candidate.m_name)
		{
			return candidate.m_id;
		}
	}

	return asmjit::CallConvId::kHost;
}
// Translates a C/C++ type name given as text into the matching asmjit type id.
//
// - Any string containing '*' is treated as a pointer and maps to kUIntPtr.
// - Otherwise the string must match one of the spellings below exactly.
// - "intptr_t" / "uintptr_t" map to the pointer-sized integer ids.
// - Unrecognised names fall back to kVoid.
asmjit::TypeId get_type_id(const std::string& type)
{
	if (type.find('*') != std::string::npos)
	{
		return asmjit::TypeId::kUIntPtr;
	}

// Both helpers compare `var` against the spelling of T and, on a match, return
// the asmjit type id registered for T. They are removed again at the end of
// this function so they do not leak into the rest of the translation unit.
#define TYPEID_MATCH_STR_IF(var, T) \
	if (var == #T) \
	{ \
		return asmjit::TypeId(asmjit::TypeUtils::TypeIdOfT<T>::kTypeId); \
	}
#define TYPEID_MATCH_STR_ELSEIF(var, T) \
	else if (var == #T) \
	{ \
		return asmjit::TypeId(asmjit::TypeUtils::TypeIdOfT<T>::kTypeId); \
	}

	TYPEID_MATCH_STR_IF(type, signed char)
	TYPEID_MATCH_STR_ELSEIF(type, unsigned char)
	TYPEID_MATCH_STR_ELSEIF(type, short)
	TYPEID_MATCH_STR_ELSEIF(type, unsigned short)
	TYPEID_MATCH_STR_ELSEIF(type, int)
	TYPEID_MATCH_STR_ELSEIF(type, unsigned int)
	TYPEID_MATCH_STR_ELSEIF(type, long)
	TYPEID_MATCH_STR_ELSEIF(type, unsigned long)
// __int64 is a compiler extension; accept it whenever the compiler provides it,
// not only when the PolyHook platform macro happens to be defined.
#if defined(POLYHOOK2_OS_WINDOWS) || defined(_MSC_VER)
	TYPEID_MATCH_STR_ELSEIF(type, __int64)
	TYPEID_MATCH_STR_ELSEIF(type, unsigned __int64)
#endif
	TYPEID_MATCH_STR_ELSEIF(type, long long)
	TYPEID_MATCH_STR_ELSEIF(type, unsigned long long)
	TYPEID_MATCH_STR_ELSEIF(type, char)
	TYPEID_MATCH_STR_ELSEIF(type, char16_t)
	TYPEID_MATCH_STR_ELSEIF(type, char32_t)
	TYPEID_MATCH_STR_ELSEIF(type, wchar_t)
	TYPEID_MATCH_STR_ELSEIF(type, uint8_t)
	TYPEID_MATCH_STR_ELSEIF(type, int8_t)
	TYPEID_MATCH_STR_ELSEIF(type, uint16_t)
	TYPEID_MATCH_STR_ELSEIF(type, int16_t)
	TYPEID_MATCH_STR_ELSEIF(type, int32_t)
	TYPEID_MATCH_STR_ELSEIF(type, uint32_t)
	TYPEID_MATCH_STR_ELSEIF(type, uint64_t)
	TYPEID_MATCH_STR_ELSEIF(type, int64_t)
	TYPEID_MATCH_STR_ELSEIF(type, float)
	TYPEID_MATCH_STR_ELSEIF(type, double)
	TYPEID_MATCH_STR_ELSEIF(type, bool)
	TYPEID_MATCH_STR_ELSEIF(type, void)
	else if (type == "intptr_t")
	{
		return asmjit::TypeId::kIntPtr;
	}
	else if (type == "uintptr_t")
	{
		return asmjit::TypeId::kUIntPtr;
	}

#undef TYPEID_MATCH_STR_ELSEIF
#undef TYPEID_MATCH_STR_IF

	return asmjit::TypeId::kVoid;
}
}

View File

@ -0,0 +1,15 @@
#pragma once
#include <asmjit/asmjit.h>
// Helpers that translate textual type / calling convention names into asmjit
// identifiers and classify asmjit type ids by register class.
namespace lua::memory
{
// Does a given type fit in a general purpose register, i.e. is it an integer
// type (8 to 64 bits, signed or unsigned) or a pointer-sized integer.
bool is_general_register(const asmjit::TypeId type_id);
// True for float and double only.
// NOTE(review): this comment used to list simd128 as well, but the definition
// accepts no SIMD type id - confirm whether that omission is intended.
bool is_XMM_register(const asmjit::TypeId type_id);
// Maps "cdecl", "stdcall" or "fastcall" to the matching asmjit calling
// convention; any other string selects the host convention.
asmjit::CallConvId get_call_convention(const std::string& conv);
// Maps a C/C++ type name to an asmjit type id. Strings containing '*' are
// treated as pointers; names that are not recognised map to the void type id.
asmjit::TypeId get_type_id(const std::string& type);
}

View File

@ -194,18 +194,16 @@ namespace lua::memory
}
}
static std::unordered_map<uintptr_t, std::unique_ptr<runtime_func_t>> target_func_ptr_to_hook;
static bool pre_callback(const runtime_func_t::parameters_t* params, const uint8_t param_count, runtime_func_t::return_value_t* return_value, const uintptr_t target_func_ptr)
{
const auto& dyn_hook = target_func_ptr_to_hook[target_func_ptr];
const auto& dyn_hook = big::g_lua_manager->m_target_func_ptr_to_dynamic_hook[target_func_ptr];
return big::g_lua_manager
->dynamic_hook_pre_callbacks(target_func_ptr, dyn_hook->m_return_type, return_value, dyn_hook->m_param_types, params, param_count);
}
static void post_callback(const runtime_func_t::parameters_t* params, const uint8_t param_count, runtime_func_t::return_value_t* return_value, const uintptr_t target_func_ptr)
{
const auto& dyn_hook = target_func_ptr_to_hook[target_func_ptr];
const auto& dyn_hook = big::g_lua_manager->m_target_func_ptr_to_dynamic_hook[target_func_ptr];
big::g_lua_manager->dynamic_hook_post_callbacks(target_func_ptr, dyn_hook->m_return_type, return_value, dyn_hook->m_param_types, params, param_count);
}
@ -240,25 +238,9 @@ namespace lua::memory
// ```
static void dynamic_hook(const std::string& hook_name, const std::string& return_type, sol::table param_types_table, lua::memory::pointer& target_func_ptr_obj, sol::protected_function pre_lua_callback, sol::protected_function post_lua_callback, sol::this_state state_)
{
const auto target_func_ptr = target_func_ptr_obj.get_address();
if (!target_func_ptr_to_hook.contains(target_func_ptr))
if (!target_func_ptr_obj.is_valid())
{
std::vector<std::string> param_types;
for (const auto& [k, v] : param_types_table)
{
if (v.is<const char*>())
{
param_types.push_back(v.as<const char*>());
}
}
std::unique_ptr<runtime_func_t> runtime_func = std::make_unique<runtime_func_t>();
const auto jitted_func = runtime_func->make_jit_func(return_type, param_types, asmjit::Arch::kHost, pre_callback, post_callback, target_func_ptr);
target_func_ptr_to_hook.emplace(target_func_ptr, std::move(runtime_func));
// TODO: The detour_hook is never cleaned up on menu unload.
target_func_ptr_to_hook[target_func_ptr]->create_and_enable_hook(hook_name, target_func_ptr, jitted_func);
return;
}
big::lua_module* module = sol::state_view(state_)["!this"];
@ -266,18 +248,56 @@ namespace lua::memory
{
return;
}
const auto target_func_ptr = target_func_ptr_obj.get_address();
bool need_hook = false;
if (pre_lua_callback.valid())
{
module->m_dynamic_hook_pre_callbacks[target_func_ptr].push_back(pre_lua_callback);
need_hook = true;
}
if (post_lua_callback.valid())
{
module->m_dynamic_hook_post_callbacks[target_func_ptr].push_back(post_lua_callback);
need_hook = true;
}
if (need_hook)
{
std::shared_ptr<runtime_func_t> runtime_func;
if (!big::g_lua_manager->m_target_func_ptr_to_dynamic_hook.contains(target_func_ptr))
{
std::vector<std::string> param_types;
for (const auto& [k, v] : param_types_table)
{
if (v.is<const char*>())
{
param_types.push_back(v.as<const char*>());
}
}
runtime_func = std::make_shared<runtime_func_t>();
const auto jitted_func = runtime_func->make_jit_func(return_type, param_types, asmjit::Arch::kHost, pre_callback, post_callback, target_func_ptr);
big::g_lua_manager->m_target_func_ptr_to_dynamic_hook[target_func_ptr] = runtime_func.get();
big::g_lua_manager->m_target_func_ptr_to_dynamic_hook[target_func_ptr]->create_and_enable_hook(hook_name, target_func_ptr, jitted_func);
}
else
{
// Lua modules own and share the runtime_func_t object, so that once no module references it anymore the hook detour gets cleaned up.
runtime_func = big::g_lua_manager->get_existing_dynamic_hook(target_func_ptr);
}
if (runtime_func)
{
module->m_dynamic_hooks.push_back(runtime_func);
}
}
}
static std::unordered_map<uintptr_t, std::vector<uint8_t>> jitted_binded_funcs;
static std::string get_jitted_lua_func_global_name(uintptr_t function_to_call_ptr)
{
return std::format("__dynamic_call_{}", function_to_call_ptr);
@ -292,14 +312,8 @@ namespace lua::memory
}
};
static void jit_lua_binded_func(uintptr_t function_to_call_ptr, const asmjit::FuncSignature& function_to_call_sig, const asmjit::Arch& arch, std::vector<type_info_t> param_types, type_info_t return_type, lua_State* lua_state, const std::string& jitted_lua_func_global_name)
static std::unique_ptr<uint8_t[]> jit_lua_binded_func(uintptr_t function_to_call_ptr, const asmjit::FuncSignature& function_to_call_sig, const asmjit::Arch& arch, std::vector<type_info_t> param_types, type_info_t return_type, lua_State* lua_state, const std::string& jitted_lua_func_global_name)
{
const auto it = jitted_binded_funcs.find(function_to_call_ptr);
if (it != jitted_binded_funcs.end())
{
return;
}
asmjit::CodeHolder code;
auto env = asmjit::Environment::host();
env.setArch(arch);
@ -325,9 +339,6 @@ namespace lua::memory
asmjit_error_handler_t asmjit_error_handler;
code.setErrorHandler(&asmjit_error_handler);
// too small to really need it
func->frame().resetPreservedFP();
// map argument slots to registers, following abi.
std::vector<asmjit::x86::Reg> arg_registers;
for (uint8_t arg_index = 0; arg_index < function_to_call_sig.argCount(); arg_index++)
@ -444,7 +455,7 @@ namespace lua::memory
else
{
LOG(FATAL) << "Return val wider than 64bits not supported";
return;
return nullptr;
}
function_to_call_invoke_node->setRet(0, function_to_call_return_val_reg);
@ -529,9 +540,9 @@ namespace lua::memory
size_t size = code.codeSize();
// Allocate a virtual memory (executable).
static std::vector<uint8_t> jit_function_buffer(size);
auto jit_function_buffer = std::make_unique<uint8_t[]>(size);
DWORD old_protect;
VirtualProtect(jit_function_buffer.data(), size, PAGE_EXECUTE_READWRITE, &old_protect);
VirtualProtect(jit_function_buffer.get(), size, PAGE_EXECUTE_READWRITE, &old_protect);
// if multiple sections, resolve linkage (1 atm)
if (code.hasUnresolvedLinks())
@ -540,13 +551,15 @@ namespace lua::memory
}
// Relocate to the base-address of the allocated memory.
code.relocateToBase((uintptr_t)jit_function_buffer.data());
code.copyFlattenedData(jit_function_buffer.data(), size);
code.relocateToBase((uintptr_t)jit_function_buffer.get());
code.copyFlattenedData(jit_function_buffer.get(), size);
LOG(VERBOSE) << "JIT Stub: " << log.data();
lua_pushcfunction(lua_state, (lua_CFunction)jit_function_buffer.data());
lua_pushcfunction(lua_state, (lua_CFunction)jit_function_buffer.get());
lua_setglobal(lua_state, jitted_lua_func_global_name.c_str());
return jit_function_buffer;
}
// Lua API: Function
@ -581,6 +594,17 @@ namespace lua::memory
// ```
static std::string dynamic_call(const std::string& return_type, sol::table param_types_table, lua::memory::pointer& target_func_ptr_obj, sol::this_state state_)
{
big::lua_module* module = sol::state_view(state_)["!this"];
if (!module)
{
return "";
}
if (!target_func_ptr_obj.is_valid())
{
return "";
}
const auto target_func_ptr = target_func_ptr_obj.get_address();
const auto jitted_lua_func_global_name = get_jitted_lua_func_global_name(target_func_ptr);
@ -610,13 +634,25 @@ namespace lua::memory
param_types.push_back(get_type_info_from_string(s));
}
jit_lua_binded_func(target_func_ptr,
sig,
asmjit::Arch::kHost,
param_types,
get_type_info_from_string(return_type),
state_.L,
jitted_lua_func_global_name);
if (!module->m_dynamic_call_jit_functions.contains(target_func_ptr))
{
auto jitted_func = jit_lua_binded_func(target_func_ptr,
sig,
asmjit::Arch::kHost,
param_types,
get_type_info_from_string(return_type),
state_.L,
jitted_lua_func_global_name);
if (jitted_func)
{
module->m_dynamic_call_jit_functions.emplace(target_func_ptr, std::move(jitted_func));
}
else
{
return "";
}
}
return jitted_lua_func_global_name;
}

View File

@ -0,0 +1,312 @@
#pragma once
#include "runtime_func_t.hpp"
#include "lua/lua_manager.hpp"
#include <MinHook.h>
namespace lua::memory
{
// Returns the address of argument slot `idx`. Slots start at m_arguments and
// are spaced sizeof(uintptr_t) bytes apart; no bounds check is performed.
char* runtime_func_t::parameters_t::get_arg_ptr(const uint8_t idx) const
{
	// A char* is handed out on purpose so callers can read the slot back as any type.
	char* const first_slot   = const_cast<char*>(reinterpret_cast<const volatile char*>(&m_arguments));
	const size_t byte_offset = sizeof(uintptr_t) * idx;
	return first_slot + byte_offset;
}
// Returns a writable byte pointer to the storage that holds the return value.
unsigned char* runtime_func_t::return_value_t::get() const
{
	const auto* const storage = reinterpret_cast<const unsigned char*>(&m_return_value);
	return const_cast<unsigned char*>(storage);
}
// Creates the (not yet installed) detour object and starts with no return type.
runtime_func_t::runtime_func_t() :
    m_detour(std::make_unique<big::detour_hook>()),
    m_return_type(type_info_t::none_)
{
}
// Removes this hook's entry from the lua manager's target -> dynamic hook map.
runtime_func_t::~runtime_func_t()
{
	// The manager is a global that may already be gone when the last owner of
	// this object lets go of it; in that case there is no map left to update.
	if (big::g_lua_manager)
	{
		big::g_lua_manager->m_target_func_ptr_to_dynamic_hook.erase(m_target_func_ptr);
	}
}
// Logs the numeric asmjit type id of every argument in `sig`, one per log line.
void runtime_func_t::debug_print_args(const asmjit::FuncSignature& sig)
{
	uint8_t position = 0;
	while (position < sig.argCount())
	{
		LOG(VERBOSE) << static_cast<int>(sig.args()[position]);
		++position;
	}
}
// Generates the machine code stub that is installed as the detour for target_func_ptr.
//
// The emitted stub:
//  1. copies every incoming argument into a stack array (one uintptr_t sized slot each),
//  2. calls pre_callback(args, arg count, return buffer, target_func_ptr),
//  3. if the pre callback returned non-zero, reloads the (possibly modified) arguments,
//     calls the original function and stores its result in the return buffer,
//  4. calls post_callback with the same four arguments,
//  5. returns whatever the return buffer holds (when the signature has a return value).
//
// Only integer/pointer and float/double arguments are supported.
//
// Returns the address of the generated code, or 0 when code generation fails
// (unsupported argument type, asmjit error, or the buffer could not be made executable).
uintptr_t runtime_func_t::make_jit_func(const asmjit::FuncSignature& sig, const asmjit::Arch arch, const user_pre_callback_t pre_callback, const user_post_callback_t post_callback, const uintptr_t target_func_ptr)
{
	asmjit::CodeHolder code;
	auto env = asmjit::Environment::host();
	env.setArch(arch);
	code.init(env);

	// initialize function
	asmjit::x86::Compiler cc(&code);
	asmjit::FuncNode* func = cc.addFunc(sig);

	asmjit::StringLogger log;
	// clang-format off
	const auto format_flags =
		asmjit::FormatFlags::kMachineCode | asmjit::FormatFlags::kExplainImms | asmjit::FormatFlags::kRegCasts |
		asmjit::FormatFlags::kHexImms | asmjit::FormatFlags::kHexOffsets | asmjit::FormatFlags::kPositions;
	// clang-format on
	log.addFlags(format_flags);
	code.setLogger(&log);

	// map argument slots to registers, following abi.
	std::vector<asmjit::x86::Reg> arg_registers;
	for (uint8_t arg_index = 0; arg_index < sig.argCount(); arg_index++)
	{
		const auto arg_type = sig.args()[arg_index];

		asmjit::x86::Reg arg;
		if (is_general_register(arg_type))
		{
			arg = cc.newUIntPtr();
		}
		else if (is_XMM_register(arg_type))
		{
			arg = cc.newXmm();
		}
		else
		{
			// the index is cast to int so it is logged as a number, not as a character
			LOG(FATAL) << "Parameters wider than 64bits not supported, index: " << static_cast<int>(arg_index) << " | " << (int)arg_type;
			debug_print_args(sig);
			return 0;
		}

		func->setArg(arg_index, arg);
		arg_registers.push_back(arg);
	}

	// setup the stack structure to hold arguments for user callback
	uint32_t stack_size = (uint32_t)(sizeof(uintptr_t) * sig.argCount());
	m_args_stack        = cc.newStack(stack_size, 16);
	asmjit::x86::Mem args_stack_index(m_args_stack);

	// assigns some register as index reg
	asmjit::x86::Gp i = cc.newUIntPtr();

	// stack_index <- stack[i].
	args_stack_index.setIndex(i);

	// r/w are sizeof(uintptr_t) width now
	args_stack_index.setSize(sizeof(uintptr_t));

	// set i = 0
	cc.mov(i, 0);
	// mov from arguments registers into the stack structure
	for (uint8_t argIdx = 0; argIdx < sig.argCount(); argIdx++)
	{
		const auto argType = sig.args()[argIdx];

		// have to cast back to explicit register types to gen right mov type
		if (is_general_register(argType))
		{
			cc.mov(args_stack_index, arg_registers.at(argIdx).as<asmjit::x86::Gp>());
		}
		else if (is_XMM_register(argType))
		{
			cc.movq(args_stack_index, arg_registers.at(argIdx).as<asmjit::x86::Xmm>());
		}
		else
		{
			LOG(FATAL) << "Parameters wider than 64bits not supported, index: " << static_cast<int>(argIdx) << " | " << (int)argType;
			debug_print_args(sig);
			return 0;
		}

		// next structure slot (+= sizeof(uintptr_t))
		cc.add(i, sizeof(uintptr_t));
	}

	// get pointer to stack structure and pass it to the user pre callback
	asmjit::x86::Gp arg_struct = cc.newUIntPtr("arg_struct");
	cc.lea(arg_struct, m_args_stack);

	// fill reg to pass struct arg count to callback
	asmjit::x86::Gp arg_param_count = cc.newUInt8();
	cc.mov(arg_param_count, (uint8_t)sig.argCount());

	// create buffer for ret val
	asmjit::x86::Mem return_stack = cc.newStack(sizeof(uintptr_t), 16);
	asmjit::x86::Gp return_struct = cc.newUIntPtr("return_struct");
	cc.lea(return_struct, return_stack);

	// fill reg to pass target function pointer to callback
	asmjit::x86::Gp target_func_ptr_reg = cc.newUIntPtr();
	cc.mov(target_func_ptr_reg, target_func_ptr);

	asmjit::Label original_invoke_label      = cc.newLabel();
	asmjit::Label skip_original_invoke_label = cc.newLabel();

	// invoke the user pre callback
	asmjit::InvokeNode* pre_callback_invoke_node;
	cc.invoke(&pre_callback_invoke_node, (uintptr_t)pre_callback, asmjit::FuncSignatureT<bool, parameters_t*, uint8_t, return_value_t*, uintptr_t>());

	// call to user provided function (use ABI of host compiler)
	pre_callback_invoke_node->setArg(0, arg_struct);
	pre_callback_invoke_node->setArg(1, arg_param_count);
	pre_callback_invoke_node->setArg(2, return_struct);
	pre_callback_invoke_node->setArg(3, target_func_ptr_reg);

	// create a register for the user pre callback's return value
	// Note: the size of the register is important for the test instruction. newUInt8 since the pre callback returns a bool.
	asmjit::x86::Gp pre_callback_return_val = cc.newUInt8("pre_callback_return_val");
	// store the callback return value
	pre_callback_invoke_node->setRet(0, pre_callback_return_val);

	// if the callback return value is zero, skip orig.
	cc.test(pre_callback_return_val, pre_callback_return_val);
	cc.jz(skip_original_invoke_label);

	// label to invoke the original function
	cc.bind(original_invoke_label);

	// mov from arguments stack structure into regs
	cc.mov(i, 0); // reset idx
	for (uint8_t arg_idx = 0; arg_idx < sig.argCount(); arg_idx++)
	{
		const auto argType = sig.args()[arg_idx];

		if (is_general_register(argType))
		{
			cc.mov(arg_registers.at(arg_idx).as<asmjit::x86::Gp>(), args_stack_index);
		}
		else if (is_XMM_register(argType))
		{
			cc.movq(arg_registers.at(arg_idx).as<asmjit::x86::Xmm>(), args_stack_index);
		}
		else
		{
			LOG(FATAL) << "Parameters wider than 64bits not supported, index: " << static_cast<int>(arg_idx) << " | " << (int)argType;
			debug_print_args(sig);
			return 0;
		}

		// next structure slot; same stride as the loop that filled the structure
		cc.add(i, sizeof(uintptr_t));
	}

	// Load the address handed out by the detour and dereference it when the stub runs,
	// so the original-function pointer is read at call time (the holder must outlive the stub).
	asmjit::x86::Gp original_ptr = cc.newUIntPtr();
	cc.mov(original_ptr, m_detour->get_original_ptr());
	cc.mov(original_ptr, asmjit::x86::ptr(original_ptr));

	asmjit::InvokeNode* original_invoke_node;
	cc.invoke(&original_invoke_node, original_ptr, sig);
	for (uint8_t arg_index = 0; arg_index < sig.argCount(); arg_index++)
	{
		original_invoke_node->setArg(arg_index, arg_registers.at(arg_index));
	}

	if (sig.hasRet())
	{
		if (is_general_register(sig.ret()))
		{
			asmjit::x86::Gp tmp = cc.newUIntPtr();
			original_invoke_node->setRet(0, tmp);
			cc.mov(return_stack, tmp);
		}
		else
		{
			asmjit::x86::Xmm tmp = cc.newXmm();
			original_invoke_node->setRet(0, tmp);
			cc.movq(return_stack, tmp);
		}
	}

	cc.bind(skip_original_invoke_label);

	asmjit::InvokeNode* post_callback_invoke_node;
	cc.invoke(&post_callback_invoke_node, (uintptr_t)post_callback, asmjit::FuncSignatureT<void, parameters_t*, uint8_t, return_value_t*, uintptr_t>());

	// Set arguments for the post callback
	post_callback_invoke_node->setArg(0, arg_struct);
	post_callback_invoke_node->setArg(1, arg_param_count);
	post_callback_invoke_node->setArg(2, return_struct);
	post_callback_invoke_node->setArg(3, target_func_ptr_reg);

	if (sig.hasRet())
	{
		asmjit::x86::Mem return_stack_index(return_stack);
		return_stack_index.setSize(sizeof(uintptr_t));
		if (is_general_register(sig.ret()))
		{
			asmjit::x86::Gp tmp2 = cc.newUIntPtr();
			cc.mov(tmp2, return_stack_index);
			cc.ret(tmp2);
		}
		else
		{
			asmjit::x86::Xmm tmp2 = cc.newXmm();
			cc.movq(tmp2, return_stack_index);
			cc.ret(tmp2);
		}
	}

	cc.endFunc();

	// write to buffer; a failed finalize leaves nothing usable to install
	const asmjit::Error finalize_error = cc.finalize();
	if (finalize_error != asmjit::kErrorOk)
	{
		LOG(FATAL) << "Failed to finalize the JIT hook function: " << asmjit::DebugUtils::errorAsString(finalize_error);
		return 0;
	}

	// worst case, overestimates for case trampolines needed
	code.flatten();
	size_t size = code.codeSize();

	// Size the buffer for real (resize, not reserve): the code is written through data(),
	// which only covers elements that exist.
	m_jit_function_buffer.resize(size);

	// Make the buffer executable.
	DWORD old_protect;
	if (!VirtualProtect(m_jit_function_buffer.data(), size, PAGE_EXECUTE_READWRITE, &old_protect))
	{
		LOG(FATAL) << "Failed to make the JIT hook function buffer executable";
		return 0;
	}

	// if multiple sections, resolve linkage (1 atm)
	if (code.hasUnresolvedLinks())
	{
		code.resolveUnresolvedLinks();
	}

	// Relocate to the base-address of the allocated memory.
	code.relocateToBase((uintptr_t)m_jit_function_buffer.data());
	code.copyFlattenedData(m_jit_function_buffer.data(), size);

	LOG(VERBOSE) << "JIT Stub: " << log.data();

	return (uintptr_t)m_jit_function_buffer.data();
}
// Builds the asmjit signature from textual type names, records the matching
// type_info_t values in m_return_type / m_param_types, then generates the hook stub.
//
// return_type / param_types: C/C++ type names.
// call_convention: "cdecl", "stdcall", "fastcall"; anything else selects the host convention.
//
// Returns the address of the generated code, or 0 when generation fails.
uintptr_t runtime_func_t::make_jit_func(const std::string& return_type, const std::vector<std::string>& param_types, const asmjit::Arch arch, const user_pre_callback_t pre_callback, const user_post_callback_t post_callback, const uintptr_t target_func_ptr, std::string call_convention)
{
	m_return_type = get_type_info_from_string(return_type);

	asmjit::FuncSignature sig(get_call_convention(call_convention), asmjit::FuncSignature::kNoVarArgs, get_type_id(return_type));

	// The recorded parameter types must describe this signature only, so drop
	// anything left over from an earlier call before filling the list again.
	m_param_types.clear();
	m_param_types.reserve(param_types.size());

	for (const std::string& s : param_types)
	{
		sig.addArg(get_type_id(s));
		m_param_types.push_back(get_type_info_from_string(s));
	}

	return make_jit_func(sig, arch, pre_callback, post_callback, target_func_ptr);
}
void runtime_func_t::create_and_enable_hook(const std::string& hook_name, uintptr_t target_func_ptr, uintptr_t jitted_func_ptr)
{
m_target_func_ptr = target_func_ptr;
m_detour->set_instance(hook_name, (void*)target_func_ptr, (void*)jitted_func_ptr);
m_detour->enable();
MH_ApplyQueued();
}
}

View File

@ -1,4 +1,5 @@
#pragma once
#include "asmjit_helper.hpp"
#include "hooking/detour_hook.hpp"
#include "lua/bindings/type_info_t.hpp"
@ -6,147 +7,6 @@
namespace lua::memory
{
// Reports whether type_id names an integer type (8 to 64 bits, signed or
// unsigned) or a pointer-sized integer. Every other type id yields false.
inline bool is_general_register(const asmjit::TypeId type_id)
{
	using asmjit::TypeId;

	const bool is_fixed_width_integer = type_id == TypeId::kInt8 || type_id == TypeId::kUInt8
	    || type_id == TypeId::kInt16 || type_id == TypeId::kUInt16
	    || type_id == TypeId::kInt32 || type_id == TypeId::kUInt32
	    || type_id == TypeId::kInt64 || type_id == TypeId::kUInt64;

	const bool is_pointer_sized_integer = type_id == TypeId::kIntPtr || type_id == TypeId::kUIntPtr;

	return is_fixed_width_integer || is_pointer_sized_integer;
}
// Reports whether type_id is one of the two scalar floating point types
// (32-bit float or 64-bit double). Every other type id yields false.
inline bool is_XMM_register(const asmjit::TypeId type_id)
{
	const bool is_float  = type_id == asmjit::TypeId::kFloat32;
	const bool is_double = type_id == asmjit::TypeId::kFloat64;
	return is_float || is_double;
}
// Translates a calling convention name into the asmjit identifier.
// Recognised names are "cdecl", "stdcall" and "fastcall" (exact match);
// any other string, including the empty one, selects the host convention.
inline asmjit::CallConvId get_call_convention(const std::string& conv)
{
	struct named_convention
	{
		const char* m_name;
		asmjit::CallConvId m_id;
	};

	static constexpr named_convention known_conventions[] = {
	    {"cdecl", asmjit::CallConvId::kCDecl},
	    {"stdcall", asmjit::CallConvId::kStdCall},
	    {"fastcall", asmjit::CallConvId::kFastCall},
	};

	for (const auto& candidate : known_conventions)
	{
		if (conv == candidate.m_name)
		{
			return candidate.m_id;
		}
	}

	return asmjit::CallConvId::kHost;
}
// Translates a C/C++ type name given as text into the matching asmjit type id.
//
// - Any string containing '*' is treated as a pointer and maps to kUIntPtr.
// - Otherwise the string must match one of the spellings below exactly.
// - "intptr_t" / "uintptr_t" map to the pointer-sized integer ids.
// - Unrecognised names fall back to kVoid.
inline asmjit::TypeId get_type_id(const std::string& type)
{
	if (type.find('*') != std::string::npos)
	{
		return asmjit::TypeId::kUIntPtr;
	}

// Both helpers compare `var` against the spelling of T and, on a match, return
// the asmjit type id registered for T. They are removed again at the end of
// this function so they do not leak into every file that includes this header.
#define TYPEID_MATCH_STR_IF(var, T) \
	if (var == #T) \
	{ \
		return asmjit::TypeId(asmjit::TypeUtils::TypeIdOfT<T>::kTypeId); \
	}
#define TYPEID_MATCH_STR_ELSEIF(var, T) \
	else if (var == #T) \
	{ \
		return asmjit::TypeId(asmjit::TypeUtils::TypeIdOfT<T>::kTypeId); \
	}

	TYPEID_MATCH_STR_IF(type, signed char)
	TYPEID_MATCH_STR_ELSEIF(type, unsigned char)
	TYPEID_MATCH_STR_ELSEIF(type, short)
	TYPEID_MATCH_STR_ELSEIF(type, unsigned short)
	TYPEID_MATCH_STR_ELSEIF(type, int)
	TYPEID_MATCH_STR_ELSEIF(type, unsigned int)
	TYPEID_MATCH_STR_ELSEIF(type, long)
	TYPEID_MATCH_STR_ELSEIF(type, unsigned long)
// __int64 is a compiler extension; accept it whenever the compiler provides it,
// not only when the PolyHook platform macro happens to be defined.
#if defined(POLYHOOK2_OS_WINDOWS) || defined(_MSC_VER)
	TYPEID_MATCH_STR_ELSEIF(type, __int64)
	TYPEID_MATCH_STR_ELSEIF(type, unsigned __int64)
#endif
	TYPEID_MATCH_STR_ELSEIF(type, long long)
	TYPEID_MATCH_STR_ELSEIF(type, unsigned long long)
	TYPEID_MATCH_STR_ELSEIF(type, char)
	TYPEID_MATCH_STR_ELSEIF(type, char16_t)
	TYPEID_MATCH_STR_ELSEIF(type, char32_t)
	TYPEID_MATCH_STR_ELSEIF(type, wchar_t)
	TYPEID_MATCH_STR_ELSEIF(type, uint8_t)
	TYPEID_MATCH_STR_ELSEIF(type, int8_t)
	TYPEID_MATCH_STR_ELSEIF(type, uint16_t)
	TYPEID_MATCH_STR_ELSEIF(type, int16_t)
	TYPEID_MATCH_STR_ELSEIF(type, int32_t)
	TYPEID_MATCH_STR_ELSEIF(type, uint32_t)
	TYPEID_MATCH_STR_ELSEIF(type, uint64_t)
	TYPEID_MATCH_STR_ELSEIF(type, int64_t)
	TYPEID_MATCH_STR_ELSEIF(type, float)
	TYPEID_MATCH_STR_ELSEIF(type, double)
	TYPEID_MATCH_STR_ELSEIF(type, bool)
	TYPEID_MATCH_STR_ELSEIF(type, void)
	else if (type == "intptr_t")
	{
		return asmjit::TypeId::kIntPtr;
	}
	else if (type == "uintptr_t")
	{
		return asmjit::TypeId::kUIntPtr;
	}

#undef TYPEID_MATCH_STR_ELSEIF
#undef TYPEID_MATCH_STR_IF

	return asmjit::TypeId::kVoid;
}
// Classifies a textual type name by substring search. The checks run in a
// fixed order and the first hit wins:
//   "const" + "char" + "*", or "string"  -> string_
//   "bool"                               -> boolean_
//   "ptr", "pointer" or "*"              -> ptr_
//   "float"                              -> float_
//   "double"                             -> double_
//   "vector3"                            -> vector3_
//   anything else                        -> integer_
static type_info_t get_type_info_from_string(const std::string& s)
{
	const bool has_star = s.contains("*");

	const bool is_c_string = s.contains("const") && s.contains("char") && has_star;
	if (is_c_string || s.contains("string"))
	{
		return type_info_t::string_;
	}

	if (s.contains("bool"))
	{
		return type_info_t::boolean_;
	}

	if (s.contains("ptr") || s.contains("pointer") || has_star)
	{
		// passing lua::memory::pointer
		return type_info_t::ptr_;
	}

	if (s.contains("float"))
	{
		return type_info_t::float_;
	}

	if (s.contains("double"))
	{
		return type_info_t::double_;
	}

	if (s.contains("vector3"))
	{
		return type_info_t::vector3_;
	}

	return type_info_t::integer_;
}
class runtime_func_t
{
std::vector<uint8_t> m_jit_function_buffer;
@ -154,6 +14,8 @@ namespace lua::memory
std::unique_ptr<big::detour_hook> m_detour;
uintptr_t m_target_func_ptr{};
public:
type_info_t m_return_type;
std::vector<type_info_t> m_param_types;
@ -177,10 +39,7 @@ namespace lua::memory
volatile uintptr_t m_arguments;
// must be char* for aliasing rules to work when reading back out
char* get_arg_ptr(const uint8_t idx) const
{
return ((char*)&m_arguments) + sizeof(uintptr_t) * idx;
}
char* get_arg_ptr(const uint8_t idx) const;
};
class return_value_t
@ -188,294 +47,33 @@ namespace lua::memory
uintptr_t m_return_value;
public:
unsigned char* get() const
{
return (unsigned char*)&m_return_value;
}
unsigned char* get() const;
};
typedef bool (*user_pre_callback_t)(const parameters_t* params, const uint8_t parameters_count, return_value_t* return_value, const uintptr_t target_func_ptr);
typedef void (*user_post_callback_t)(const parameters_t* params, const uint8_t parameters_count, return_value_t* return_value, const uintptr_t target_func_ptr);
runtime_func_t()
runtime_func_t();
~runtime_func_t();
uintptr_t get_target_func_ptr() const
{
m_detour = std::make_unique<big::detour_hook>();
m_return_type = type_info_t::none_;
return m_target_func_ptr;
}
~runtime_func_t()
{
}
void debug_print_args(const asmjit::FuncSignature& sig);
// Construct a callback given the raw signature at runtime. 'Callback' param is the C stub to transfer to,
// where parameters can be modified through a structure which is written back to the parameter slots depending
// on calling convention.
uintptr_t make_jit_func(const asmjit::FuncSignature& sig, const asmjit::Arch arch, const user_pre_callback_t pre_callback, const user_post_callback_t post_callback, const uintptr_t target_func_ptr)
{
asmjit::CodeHolder code;
auto env = asmjit::Environment::host();
env.setArch(arch);
code.init(env);
// initialize function
asmjit::x86::Compiler cc(&code);
asmjit::FuncNode* func = cc.addFunc(sig);
asmjit::StringLogger log;
// clang-format off
const auto format_flags =
asmjit::FormatFlags::kMachineCode | asmjit::FormatFlags::kExplainImms | asmjit::FormatFlags::kRegCasts |
asmjit::FormatFlags::kHexImms | asmjit::FormatFlags::kHexOffsets | asmjit::FormatFlags::kPositions;
// clang-format on
log.addFlags(format_flags);
code.setLogger(&log);
// too small to really need it
func->frame().resetPreservedFP();
// map argument slots to registers, following abi.
std::vector<asmjit::x86::Reg> arg_registers;
for (uint8_t arg_index = 0; arg_index < sig.argCount(); arg_index++)
{
const auto arg_type = sig.args()[arg_index];
asmjit::x86::Reg arg;
if (is_general_register(arg_type))
{
arg = cc.newUIntPtr();
}
else if (is_XMM_register(arg_type))
{
arg = cc.newXmm();
}
else
{
LOG(FATAL) << "Parameters wider than 64bits not supported";
return 0;
}
func->setArg(arg_index, arg);
arg_registers.push_back(arg);
}
// setup the stack structure to hold arguments for user callback
uint32_t stack_size = (uint32_t)(sizeof(uintptr_t) * sig.argCount());
m_args_stack = cc.newStack(stack_size, 16);
asmjit::x86::Mem args_stack_index(m_args_stack);
// assigns some register as index reg
asmjit::x86::Gp i = cc.newUIntPtr();
// stack_index <- stack[i].
args_stack_index.setIndex(i);
// r/w are sizeof(uintptr_t) width now
args_stack_index.setSize(sizeof(uintptr_t));
// set i = 0
cc.mov(i, 0);
// mov from arguments registers into the stack structure
for (uint8_t argIdx = 0; argIdx < sig.argCount(); argIdx++)
{
const auto argType = sig.args()[argIdx];
// have to cast back to explicit register types to gen right mov type
if (is_general_register(argType))
{
cc.mov(args_stack_index, arg_registers.at(argIdx).as<asmjit::x86::Gp>());
}
else if (is_XMM_register(argType))
{
cc.movq(args_stack_index, arg_registers.at(argIdx).as<asmjit::x86::Xmm>());
}
else
{
LOG(FATAL) << "Parameters wider than 64bits not supported";
return 0;
}
// next structure slot (+= sizeof(uintptr_t))
cc.add(i, sizeof(uintptr_t));
}
// get pointer to stack structure and pass it to the user pre callback
asmjit::x86::Gp arg_struct = cc.newUIntPtr("arg_struct");
cc.lea(arg_struct, m_args_stack);
// fill reg to pass struct arg count to callback
asmjit::x86::Gp arg_param_count = cc.newUInt8();
cc.mov(arg_param_count, (uint8_t)sig.argCount());
// create buffer for ret val
asmjit::x86::Mem return_stack = cc.newStack(sizeof(uintptr_t), 16);
asmjit::x86::Gp return_struct = cc.newUIntPtr("return_struct");
cc.lea(return_struct, return_stack);
// fill reg to pass target function pointer to callback
asmjit::x86::Gp target_func_ptr_reg = cc.newUIntPtr();
cc.mov(target_func_ptr_reg, target_func_ptr);
asmjit::Label original_invoke_label = cc.newLabel();
asmjit::Label skip_original_invoke_label = cc.newLabel();
// invoke the user pre callback
asmjit::InvokeNode* pre_callback_invoke_node;
cc.invoke(&pre_callback_invoke_node, (uintptr_t)pre_callback, asmjit::FuncSignatureT<bool, parameters_t*, uint8_t, return_value_t*, uintptr_t>());
// call to user provided function (use ABI of host compiler)
pre_callback_invoke_node->setArg(0, arg_struct);
pre_callback_invoke_node->setArg(1, arg_param_count);
pre_callback_invoke_node->setArg(2, return_struct);
pre_callback_invoke_node->setArg(3, target_func_ptr_reg);
// create a register for the user pre callback's return value
// Note: the size of the register is important for the test instruction. newUInt8 since the pre callback returns a bool.
asmjit::x86::Gp pre_callback_return_val = cc.newUInt8("pre_callback_return_val");
// store the callback return value
pre_callback_invoke_node->setRet(0, pre_callback_return_val);
// if the callback return value is zero, skip orig.
cc.test(pre_callback_return_val, pre_callback_return_val);
cc.jz(skip_original_invoke_label);
// label to invoke the original function
cc.bind(original_invoke_label);
// mov from arguments stack structure into regs
cc.mov(i, 0); // reset idx
for (uint8_t arg_idx = 0; arg_idx < sig.argCount(); arg_idx++)
{
const auto argType = sig.args()[arg_idx];
if (is_general_register(argType))
{
cc.mov(arg_registers.at(arg_idx).as<asmjit::x86::Gp>(), args_stack_index);
}
else if (is_XMM_register(argType))
{
cc.movq(arg_registers.at(arg_idx).as<asmjit::x86::Xmm>(), args_stack_index);
}
else
{
LOG(FATAL) << "Parameters wider than 64bits not supported";
return 0;
}
// next structure slot (+= sizeof(uint64_t))
cc.add(i, sizeof(uint64_t));
}
// deref the trampoline ptr (holder must live longer, must be concrete reg since push later)
asmjit::x86::Gp original_ptr = cc.zbx();
cc.mov(original_ptr, m_detour->get_original_ptr());
cc.mov(original_ptr, asmjit::x86::ptr(original_ptr));
asmjit::InvokeNode* original_invoke_node;
cc.invoke(&original_invoke_node, original_ptr, sig);
for (uint8_t arg_index = 0; arg_index < sig.argCount(); arg_index++)
{
original_invoke_node->setArg(arg_index, arg_registers.at(arg_index));
}
if (sig.hasRet())
{
if (is_general_register(sig.ret()))
{
asmjit::x86::Gp tmp = cc.newUIntPtr();
original_invoke_node->setRet(0, tmp);
cc.mov(return_stack, tmp);
}
else
{
asmjit::x86::Xmm tmp = cc.newXmm();
original_invoke_node->setRet(0, tmp);
cc.movq(return_stack, tmp);
}
}
cc.bind(skip_original_invoke_label);
asmjit::InvokeNode* post_callback_invoke_node;
cc.invoke(&post_callback_invoke_node, (uintptr_t)post_callback, asmjit::FuncSignatureT<void, parameters_t*, uint8_t, return_value_t*, uintptr_t>());
// Set arguments for the post callback
post_callback_invoke_node->setArg(0, arg_struct);
post_callback_invoke_node->setArg(1, arg_param_count);
post_callback_invoke_node->setArg(2, return_struct);
post_callback_invoke_node->setArg(3, target_func_ptr_reg);
if (sig.hasRet())
{
asmjit::x86::Mem return_stack_index(return_stack);
return_stack_index.setSize(sizeof(uintptr_t));
if (is_general_register(sig.ret()))
{
asmjit::x86::Gp tmp2 = cc.newUIntPtr();
cc.mov(tmp2, return_stack_index);
cc.ret(tmp2);
}
else
{
asmjit::x86::Xmm tmp2 = cc.newXmm();
cc.movq(tmp2, return_stack_index);
cc.ret(tmp2);
}
}
cc.endFunc();
// write to buffer
cc.finalize();
// worst case, overestimates for case trampolines needed
code.flatten();
size_t size = code.codeSize();
// Allocate a virtual memory (executable).
m_jit_function_buffer.reserve(size);
DWORD old_protect;
VirtualProtect(m_jit_function_buffer.data(), size, PAGE_EXECUTE_READWRITE, &old_protect);
// if multiple sections, resolve linkage (1 atm)
if (code.hasUnresolvedLinks())
{
code.resolveUnresolvedLinks();
}
// Relocate to the base-address of the allocated memory.
code.relocateToBase((uintptr_t)m_jit_function_buffer.data());
code.copyFlattenedData(m_jit_function_buffer.data(), size);
LOG(VERBOSE) << "JIT Stub: " << log.data();
return (uintptr_t)m_jit_function_buffer.data();
}
uintptr_t make_jit_func(const asmjit::FuncSignature& sig, const asmjit::Arch arch, const user_pre_callback_t pre_callback, const user_post_callback_t post_callback, const uintptr_t target_func_ptr);
// Construct a callback given the typedef as a string. Types are any valid C/C++ data type (basic types), and pointers to
// anything are just a uintptr_t. Calling convention is defaulted to whatever is typical for the compiler you use, you can override with
// stdcall, fastcall, or cdecl (cdecl is default on x86). On x64 those map to the same thing.
//
// return_type      textual return type; also recorded in m_return_type via get_type_info_from_string.
// param_types      textual parameter types, in call order; each one is appended to m_param_types.
// arch, pre_callback, post_callback, target_func_ptr
//                  forwarded untouched to the FuncSignature-based overload.
// call_convention  "cdecl", "stdcall" or "fastcall"; any other value (including the empty default)
//                  resolves to the host calling convention through get_call_convention.
//
// Returns whatever the FuncSignature-based make_jit_func overload returns.
uintptr_t make_jit_func(const std::string& return_type, const std::vector<std::string>& param_types, const asmjit::Arch arch, const user_pre_callback_t pre_callback, const user_post_callback_t post_callback, const uintptr_t target_func_ptr, std::string call_convention = "")
{
	m_return_type = get_type_info_from_string(return_type);

	// No block-scope redeclaration of make_jit_func may live in this body: it would hide the member
	// overloads and the forwarding call below could no longer reach the FuncSignature overload.
	asmjit::FuncSignature sig(get_call_convention(call_convention), asmjit::FuncSignature::kNoVarArgs, get_type_id(return_type));
	for (const std::string& s : param_types)
	{
		// Keep the asmjit signature and the Lua-side type info in the same order.
		sig.addArg(get_type_id(s));
		m_param_types.push_back(get_type_info_from_string(s));
	}

	return make_jit_func(sig, arch, pre_callback, post_callback, target_func_ptr);
}
// Points m_detour at target_func_ptr with jitted_func_ptr as its replacement, registered under
// hook_name, and then enables the detour.
void create_and_enable_hook(const std::string& hook_name, uintptr_t target_func_ptr, uintptr_t jitted_func_ptr)
{
	// Both addresses arrive as integers; the detour API takes raw pointers.
	m_detour->set_instance(hook_name,
	    reinterpret_cast<void*>(target_func_ptr),
	    reinterpret_cast<void*>(jitted_func_ptr));

	m_detour->enable();
}
void create_and_enable_hook(const std::string& hook_name, uintptr_t target_func_ptr, uintptr_t jitted_func_ptr);
};
}

View File

@ -0,0 +1,37 @@
#include "type_info_t.hpp"
namespace lua::memory
{
// Classifies a textual C/C++ type name into the type_info_t category used on the Lua side.
// Matching is substring based and the checks run in a fixed order, so the first rule that
// matches wins; anything that matches no rule is treated as an integer.
//
// NOTE(review): "bool" is tested before the pointer markers, so a name such as "bool*"
// yields boolean_ rather than ptr_ - confirm this is intended.
type_info_t get_type_info_from_string(const std::string& s)
{
	const bool has_star = s.contains("*");

	// "const char*" style names and anything mentioning "string" are strings.
	const bool looks_like_c_string = s.contains("const") && s.contains("char") && has_star;
	if (looks_like_c_string || s.contains("string"))
	{
		return type_info_t::string_;
	}

	if (s.contains("bool"))
	{
		return type_info_t::boolean_;
	}

	// passing lua::memory::pointer
	if (s.contains("ptr") || s.contains("pointer") || has_star)
	{
		return type_info_t::ptr_;
	}

	if (s.contains("float"))
	{
		return type_info_t::float_;
	}

	if (s.contains("double"))
	{
		return type_info_t::double_;
	}

	if (s.contains("vector3"))
	{
		return type_info_t::vector3_;
	}

	// Fallback: every remaining type name is handled as an integer.
	return type_info_t::integer_;
}
}

View File

@ -13,4 +13,6 @@ namespace lua::memory
double_,
vector3_
};
type_info_t get_type_info_from_string(const std::string& s);
}

View File

@ -414,4 +414,20 @@ namespace big
LOG(FATAL) << state["!module_name"].get<std::string_view>() << ": " << error.what();
Logger::FlushQueue();
}
// Looks through the dynamic hooks of every module in m_modules for one whose target function
// address equals target_func_ptr. Returns a shared owner of the first match, or nullptr when
// no module has hooked that address.
//
// NOTE(review): m_modules is walked here without taking m_module_lock (which for_each_module
// does take) - confirm callers already hold it or run where that is safe.
std::shared_ptr<lua::memory::runtime_func_t> lua_manager::get_existing_dynamic_hook(const uintptr_t target_func_ptr)
{
	for (const auto& loaded_module : m_modules)
	{
		for (const auto& hook : loaded_module->m_dynamic_hooks)
		{
			if (hook->get_target_func_ptr() != target_func_ptr)
			{
				continue;
			}

			return hook;
		}
	}

	return nullptr;
}
}

View File

@ -1,4 +1,5 @@
#pragma once
#include "bindings/runtime_func_t.hpp"
#include "lua_module.hpp"
namespace big
@ -43,6 +44,9 @@ namespace big
return m_scripts_config_folder;
}
// non owning map
std::unordered_map<uintptr_t, lua::memory::runtime_func_t*> m_target_func_ptr_to_dynamic_hook;
std::weak_ptr<lua_module> get_module(rage::joaat_t module_id);
std::weak_ptr<lua_module> get_disabled_module(rage::joaat_t module_id);
@ -101,6 +105,8 @@ namespace big
return std::nullopt;
}
std::shared_ptr<lua::memory::runtime_func_t> get_existing_dynamic_hook(const uintptr_t target_func_ptr);
inline void for_each_module(auto func)
{
std::lock_guard guard(m_module_lock);

View File

@ -3,8 +3,8 @@
#include "bindings/gui/gui_element.hpp"
#include "core/data/menu_event.hpp"
#include "lua/bindings/runtime_func_t.hpp"
#include "lua/bindings/type_info_t.hpp"
#include "lua/bindings/scr_patch.hpp"
#include "lua/bindings/type_info_t.hpp"
#include "lua_patch.hpp"
#include "services/gui/gui_service.hpp"
@ -41,9 +41,13 @@ namespace big
std::unordered_map<menu_event, std::vector<sol::protected_function>> m_event_callbacks;
std::vector<void*> m_allocated_memory;
// Lua modules share ownership of the runtime_func_t objects, so that once no module references one anymore, its hook detour gets cleaned up.
std::vector<std::shared_ptr<lua::memory::runtime_func_t>> m_dynamic_hooks;
std::unordered_map<uintptr_t, std::vector<sol::protected_function>> m_dynamic_hook_pre_callbacks;
std::unordered_map<uintptr_t, std::vector<sol::protected_function>> m_dynamic_hook_post_callbacks;
std::unordered_map<uintptr_t, std::unique_ptr<uint8_t[]>> m_dynamic_call_jit_functions;
lua_module(const std::filesystem::path& module_path, folder& scripts_folder, bool disabled = false);
~lua_module();