How to thunk a function in x86 and x64? (Like std::bind in C++, but dynamic)

前端 未结 2 730
春和景丽
春和景丽 2021-01-06 03:00

How do I thunk an arbitrary function with an arbitrary (fixed) number of arguments, on x86 and x64?

(I don't need floating-point, SSE, or the like. The arguments ar… [text truncated in the source])

相关标签:
2条回答
  • 2021-01-06 03:32

    Here is a modification for thiscall functions

    The vbind() stub generator above is meant to be used for C++ member functions as well, although it is not clear how to proceed. Here's what I've come up with:

    // experimental x64 thiscall thunking
    //
    // TestHook generates, inside its own m_thunk buffer, a stub that can be
    // called as a plain HOOKPROC (nCode, wParam, lParam) and forwards to the
    // member function CBTHook_stub with `this` bound as argument 0.
    class TestHook {
    public:
        typedef void (TestHook::*TMFP)();
    
        explicit TestHook(DWORD num)
        {
            m_context = num;
    
            // Reinterpret the member-function pointer as a plain code address.
            // NOTE(review): relies on this member pointer having the same size as a
            // function pointer (non-virtual member, single inheritance, MSVC) and on
            // union type punning -- confirm if the class hierarchy changes.
            union { void* (*func)(); TMFP method; } addr;
            addr.method = reinterpret_cast<TMFP>(&TestHook::CBTHook_stub);
    
            // pass "this" as the first fixed argument; 4 = this + nCode + wParam + lParam.
            // On x64 `this` is just argument 0, so vbind's x86-only thiscall flag is false.
            void *args[] = { this };
            size_t thunk_size = vbind(addr.func, 4, m_thunk, 0, args, 1, false);
            // This runs after vbind has already written the code, so it can only
            // detect (not prevent) an overflow; m_thunk is sized generously for one
            // bound argument.
            ATLASSERT(thunk_size != 0 && thunk_size <= sizeof(m_thunk));
    
            // m_thunk is a data member, so this changes the protection of the
            // page(s) holding the TestHook object itself.
            DWORD old;
            VirtualProtect(m_thunk, thunk_size, PAGE_EXECUTE_READWRITE, &old);
            FlushInstructionCache(GetCurrentProcess(), m_thunk, thunk_size);
        }
    
        // The generated code embeds `this`, so a copy's thunk would still call back
        // into the original object; copying is therefore disabled.
        TestHook(const TestHook&) = delete;
        TestHook& operator=(const TestHook&) = delete;
    
        // Returns the generated stub as a callable code pointer.
        FARPROC GetThunk() const { return (FARPROC)(void*)m_thunk; }
    
    protected:
        // test thiscall: one integer and two 8-byte arguments
        LRESULT CBTHook_stub(int nCode, WPARAM wParam, LPARAM lParam)
        {
            // WPARAM/LPARAM are pointer-sized on x64, so they need the `I` size prefix.
            ATLTRACE(_T("this=%p, code=%d, wp=%Ix, lp=%Ix, context=%x\n"), this, nCode, wParam, lParam, m_context);
            return lParam;
        }
    
        DWORD m_context;             // value given to the constructor, echoed in the trace
        unsigned char m_thunk[1024]; // fixed size; a single bound argument needs far less than this
    };
    
    #ifndef _WIN64
    #error does not work for win32
    #endif
    void main(void)
    {
        TestHook tmp(0xDeadBeef);
    
        HOOKPROC proc = (HOOKPROC)tmp.GetThunk();
        ATLTRACE(_T("object %p return value=%d\n"), &tmp, proc(1, 2, 3));
    }
    

    I am not an assembly guru, but this code correctly stubs into the member function for 64-bit code. There are some implicit assumptions (I'm not 100% sure they are valid, please correct me if I'm wrong):

    1. In x64 (AMD64, Microsoft Visual Studio), every function argument occupies an 8-byte register or stack slot. So although vbind was written for pointer-type arguments, it can also thunk other function prototypes (e.g. HOOKPROC takes one int and two __int64-sized arguments, WPARAM and LPARAM). Narrower types such as int only use the low bytes of their slot.

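    2. In x64 there is no separate thiscall convention. The "this" pointer is passed as the implicit first argument, i.e. in RCX, like any other first integer argument; the vbind stub spills it to the stack together with the other register arguments. On x86 it would instead arrive specially in ECX. I used the bound argument to pass the "this" pointer and provide context to the C++ object.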

    0 讨论(0)
  • 2021-01-06 03:35

    Here's my generic implementation.

    I initially made it with AsmJit, then modified it by hand to remove the dependency.

    • It works for both x86 and x64!

    • It works for both cdecl and stdcall!
      It should also work for "thiscall", both on VC++ and GCC, but I haven't tested it.
      (VC++ would probably not touch the 'this' pointer, whereas GCC would treat it as the first argument.)

    • It can bind an arbitrary number of arguments at any position in the parameter list!

    Just beware:

    • It does not work for variadic functions, like printf.
      Doing so would either require you to provide the number of arguments dynamically (which is painful) or would require you to store the return-pointers somewhere other than the stack, which is complicated.

    • It was not designed for ultra-high performance, but it should still be fast enough.
      The speed is O(total parameter count), not O(bound parameter count).

    Scroll to the right to see the assembly code.

    #include <stddef.h>
    
    // Appends the low `width` bytes of `value` at `p` in little-endian order (the
    // byte order x86/x64 uses for immediates and displacements) and returns the
    // advanced write cursor. Writing byte-by-byte avoids misaligned, type-punned
    // stores through `*(size_t **)&p`, which are undefined behavior in C++.
    // `width` must not exceed sizeof(size_t).
    static unsigned char *vbind_emit(unsigned char *p, size_t value, size_t width)
    {
        for (size_t k = 0; k < width; ++k)
        {
            *p++ = static_cast<unsigned char>(value >> (8 * k));
        }
        return p;
    }
    
    // Generates into `buffer` a machine-code thunk that forwards its own arguments
    // to `f`, with the `n` pointer-sized values in `bound` spliced in at position
    // `i` of f's parameter list.
    //
    //   f           - target function (cdecl/stdcall/thiscall on x86; Microsoft x64
    //                 convention on x64, where the register arguments are rcx, rdx,
    //                 r8, r9).
    //   param_count - total number of pointer-sized parameters of `f`, including
    //                 the bound ones (and `this` for a member function).
    //   buffer      - receives the code; it must hold at least
    //                 160 + n * (5 + sizeof(int) + sizeof(void *)) bytes, and the
    //                 caller must make it executable before calling the thunk.
    //   i           - index in f's parameter list of the first bound argument.
    //   bound, n    - the values to bind and how many there are.
    //   thiscall    - x86 only: pop the first argument into ECX before the call.
    //                 Ignored on x64, where `this` is simply argument 0.
    //
    // Returns the number of bytes written, or 0 (writing nothing) when
    // i + n > param_count.
    //
    // Caveats: the thunk returns with a plain `ret`, so on x86 it must be called as
    // cdecl even when `f` is stdcall. It always copies param_count - n arguments
    // from its caller's stack, so it cannot forward variadic calls. No unwind
    // information is registered for the generated code.
    size_t vbind(
        void *(/* cdecl, stdcall, or thiscall */ *f)(), size_t param_count,
        unsigned char buffer[/* >= 160 + n * (5 + sizeof(int) + sizeof(void*)) */],
        size_t const i, void *const bound[], unsigned int const n, bool const thiscall = false)
    {
        // Reject splices that do not fit in the parameter list; otherwise the
        // generated `rep movs` counts would run past the frame or wrap around.
        if (i > param_count || n > param_count - i) { return 0; }
    
        unsigned char *p = buffer;
        unsigned char const s = sizeof(void *);
        unsigned char const b = sizeof(int) == sizeof(void *) ? 2 : 3;  // log2(sizeof(void *))
    
        // Stack slots reserved for f's arguments. The x64 ABI requires the caller to
        // provide at least 4 slots (32 bytes of home space) and a 16-byte-aligned RSP
        // at the call. RSP is 16-byte aligned right after `push rbp`, so an even slot
        // count keeps it aligned (the pushes/pops below are balanced).
        size_t frame_slots = param_count;
        if (b > 2)
        {
            if (frame_slots < 4) { frame_slots = 4; }
            frame_slots += frame_slots & 1;
        }
    
        *p++ = 0x55;                                                                          // push     rbp
        if (b > 2) { *p++ = 0x48; } *p++ = 0x8B; *p++ = 0xEC;                                 // mov      rbp, rsp
        if (b > 2)
        {
            // x64: spill the register arguments into the caller's home space so all
            // incoming arguments are contiguous starting at [rbp + 2 * s].
            *p++ = 0x48; *p++ = 0x89; *p++ = 0x4C; *p++ = 0x24; *p++ = 2 * s;                 // mov      [rsp + 2 * s], rcx
            *p++ = 0x48; *p++ = 0x89; *p++ = 0x54; *p++ = 0x24; *p++ = 3 * s;                 // mov      [rsp + 3 * s], rdx
            *p++ = 0x4C; *p++ = 0x89; *p++ = 0x44; *p++ = 0x24; *p++ = 4 * s;                 // mov      [rsp + 4 * s], r8
            *p++ = 0x4C; *p++ = 0x89; *p++ = 0x4C; *p++ = 0x24; *p++ = 5 * s;                 // mov      [rsp + 5 * s], r9
        }
        if (b > 2) { *p++ = 0x48; } *p++ = 0xBA; p = vbind_emit(p, param_count, s);          // mov      rdx, <param_count>
        if (b > 2) { *p++ = 0x48; } *p++ = 0xB8; p = vbind_emit(p, frame_slots, s);          // mov      rax, <frame_slots>
        if (b > 2) { *p++ = 0x48; } *p++ = 0xC1; *p++ = 0xE0; *p++ = b;                       // shl      rax, log2(sizeof(void *))
        if (b > 2) { *p++ = 0x48; } *p++ = 0x2B; *p++ = 0xE0;                                 // sub      rsp, rax
        *p++ = 0x57;                                                                          // push     rdi
        *p++ = 0x56;                                                                          // push     rsi
        *p++ = 0x51;                                                                          // push     rcx
        *p++ = 0x9C;                                                                          // pushfq
        if (b > 2) { *p++ = 0x48; } *p++ = 0xF7; *p++ = 0xD8;                                 // neg      rax
        if (b > 2) { *p++ = 0x48; } *p++ = 0x8D; *p++ = 0x7C; *p++ = 0x05; *p++ = 0x00;       // lea      rdi, [rbp + rax]   ; new argument area
        if (b > 2) { *p++ = 0x48; } *p++ = 0x8D; *p++ = 0x75; *p++ = 2 * s;                   // lea      rsi, [rbp + 2 * s] ; incoming arguments
        if (b > 2) { *p++ = 0x48; } *p++ = 0xB9; p = vbind_emit(p, i, s);                    // mov      rcx, <i>
        if (b > 2) { *p++ = 0x48; } *p++ = 0x2B; *p++ = 0xD1;                                 // sub      rdx, rcx
        *p++ = 0xFC;                                                                          // cld
        *p++ = 0xF3; if (b > 2) { *p++ = 0x48; } *p++ = 0xA5;                                 // rep movs [rdi], [rsi]       ; args before the bound ones
        for (unsigned int j = 0; j < n; j++)
        {
            unsigned int const o = j * s;  // byte offset of bound[j] from rdi
            if (b > 2) { *p++ = 0x48; } *p++ = 0xB8;                                          // mov      rax, <arg>
            p = vbind_emit(p, reinterpret_cast<size_t>(bound[j]), s);
            if (b > 2) { *p++ = 0x48; } *p++ = 0x89; *p++ = 0x87;                             // mov      [rdi + <o>], rax
            p = vbind_emit(p, o, sizeof(int));
        }
        if (b > 2) { *p++ = 0x48; } *p++ = 0xB8; p = vbind_emit(p, n, s);                    // mov      rax, <count>
        if (b > 2) { *p++ = 0x48; } *p++ = 0x2B; *p++ = 0xD0;                                 // sub      rdx, rax           ; param_count - i - n
        if (b > 2) { *p++ = 0x48; } *p++ = 0xC1; *p++ = 0xE0; *p++ = b;                       // shl      rax, log2(sizeof(void *))
        if (b > 2) { *p++ = 0x48; } *p++ = 0x03; *p++ = 0xF8;                                 // add      rdi, rax           ; skip the bound slots
        if (b > 2) { *p++ = 0x48; } *p++ = 0x8B; *p++ = 0xCA;                                 // mov      rcx, rdx
        *p++ = 0xF3; if (b > 2) { *p++ = 0x48; } *p++ = 0xA5;                                 // rep movs [rdi], [rsi]       ; remaining args
        *p++ = 0x9D;                                                                          // popfq
        *p++ = 0x59;                                                                          // pop      rcx
        *p++ = 0x5E;                                                                          // pop      rsi
        *p++ = 0x5F;                                                                          // pop      rdi
        if (b > 2)
        {
            // x64: load the first four arguments into registers; their stack copies
            // remain in place as the callee's home space.
            *p++ = 0x48; *p++ = 0x8B; *p++ = 0x4C; *p++ = 0x24; *p++ = 0 * s;                 // mov      rcx, [rsp + 0 * s]
            *p++ = 0x48; *p++ = 0x8B; *p++ = 0x54; *p++ = 0x24; *p++ = 1 * s;                 // mov      rdx, [rsp + 1 * s]
            *p++ = 0x4C; *p++ = 0x8B; *p++ = 0x44; *p++ = 0x24; *p++ = 2 * s;                 // mov      r8 , [rsp + 2 * s]
            *p++ = 0x4C; *p++ = 0x8B; *p++ = 0x4C; *p++ = 0x24; *p++ = 3 * s;                 // mov      r9 , [rsp + 3 * s]
            *p++ = 0x48; *p++ = 0xB8; p = vbind_emit(p, reinterpret_cast<size_t>(f), s);      // mov      rax, <target_ptr>
            *p++ = 0xFF; *p++ = 0xD0;                                                         // call     rax
        }
        else
        {
            if (thiscall) { *p++ = 0x59; }                                                    // pop      ecx  ; `this`
            *p++ = 0xE8;                                                                      // call     <fn_rel>
            // rel32 is relative to the end of the call instruction, i.e. the byte just
            // past the 4-byte displacement. It is computed before `p` advances, so the
            // result no longer depends on compiler evaluation order.
            size_t const next = reinterpret_cast<size_t>(p + sizeof(int));
            p = vbind_emit(p, reinterpret_cast<size_t>(f) - next, sizeof(int));
        }
        if (b > 2) { *p++ = 0x48; } *p++ = 0x8B; *p++ = 0xE5;                                 // mov      rsp, rbp
        *p++ = 0x5D;                                                                          // pop      rbp
        *p++ = 0xC3;                                                                          // ret
        return static_cast<size_t>(p - buffer);
    }
    

    Example (for Windows):

    #include <assert.h>
    #include <stdio.h>
    #include <Windows.h>
    // Sample cdecl target: first recurses with `u` decremented (so a small `u`
    // acts as a recursion depth), then prints all six pointer-sized arguments.
    // Returns `value`.
    void *__cdecl test(void *value, void *x, void *y, void *z, void *w, void *u)
    {
        // An ordered comparison of a pointer with 0 (`u > 0`) is ill-formed in
        // standard C++; with MSVC's unsigned pointer ordering it meant "non-null".
        if (u != nullptr) { test(value, x, y, z, w, (void *)((size_t)u - 1)); }
        printf("Test called! %p %p %p %p %p %p\n", value, x, y, z, w, u);
        return value;
    }
    // Sample functor target for vbind's member-function (thiscall) path.
    struct Test
    {
        void *local;  // per-instance marker, printed first to identify the object called
    
        // Recurses while `u` is non-null (decrementing it each time), then prints
        // `local` and all six arguments. Returns `value`.
        void *operator()(void *value, void *x, void *y, void *z, void *w, void *u)
        {
            // `u > 0` (ordered pointer-vs-0 comparison) is ill-formed in standard
            // C++; on flat address spaces it is equivalent to a non-null test.
            if (u != nullptr) { (*this)(value, x, y, z, w, (void *)((size_t)u - 1)); }
            printf("Test::operator() called! %p %p %p %p %p %p %p\n", local, value, x, y, z, w, u);
            return value;
        }
    };
    int main()
    {
        unsigned char thunk[1024]; unsigned long old;
        VirtualProtect(&thunk, sizeof(thunk), PAGE_EXECUTE_READWRITE, &old);
        void *args[] = { (void *)0xBAADF00DBAADF001, (void *)0xBAADF00DBAADF002 };
        void *(Test::*f)(void *value, void *x, void *y, void *z, void *w, void *u) = &Test::operator();
        Test obj = { (void *)0x1234 };
        assert(sizeof(f) == sizeof(void (*)()));  // virtual function are too big, they're not supported :(
        vbind(*(void *(**)())&f, 1 + 6, thunk, 1 + 1, args, sizeof(args) / sizeof(*args), true);
        ((void *(*)(void *, int, int, int, int))&thunk)(&obj, 3, 4, 5, 6);
        vbind((void *(*)())test, 6, thunk, 1, args, sizeof(args) / sizeof(*args), false);
        ((void *(*)(int, int, int, int))&thunk)(3, 4, 5, 6);
    }
    
    0 讨论(0)
提交回复
热议问题