How to thunk a function in x86 and x64? (Like std::bind in C++, but dynamic)

前端 未结 2 729
春和景丽
春和景丽 2021-01-06 03:00

How do I thunk an arbitrary function with an arbitrary (fixed) number of arguments, on x86 and x64?

(I don't need floating-point, SSE, or the like. The arguments ar… [the rest of the question is truncated in the source])

2条回答
  •  野趣味
    野趣味 (楼主)
    2021-01-06 03:35

    Here's my generic implementation.

    I initially made it with AsmJit, then modified it by hand to remove the dependency.

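    • It works for both x86 and x64!
      On x64 it follows the Microsoft x64 calling convention: the first four arguments go in RCX, RDX, R8 and R9, backed by 32 bytes of home space on the stack. It therefore targets Windows and does not implement the System V ABI used on x64 Linux and macOS.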

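    • It works for both cdecl and stdcall targets!
      After the call the thunk restores the stack pointer from its frame pointer, so it does not matter whether the target pops its own arguments. The thunk itself always returns with a plain `ret`, so whoever calls the thunk cleans up its arguments, cdecl-style.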
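      It should also work for "thiscall", both on VC++ and GCC, but I haven't tested it.
      On 32-bit x86, pass thiscall = true for VC++, which expects 'this' in ECX: the thunk pops the first argument into ECX right before the call. Pass false for GCC, which passes 'this' as an ordinary first argument. On x64 the flag is ignored, because 'this' is simply the first argument there.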

    • It can bind an arbitrary number of arguments at any position in the parameter list!

    Just beware:

    • It does not work for variadic functions, like printf.
      Supporting them would require either passing the argument count at call time (which is painful) or storing the return address somewhere other than the stack (which is complicated).

    • It was not designed for ultra-high performance, but it should still be fast enough.
      Each call costs O(total parameter count), not O(bound parameter count), because the thunk copies every argument into a new argument area.

    The assembly for each emitted instruction is given in the comment at the right end of its line.

    #include 
    
    // Appends the low `width` bytes of `value` at `p`, least-significant byte first
    // (the byte order of x86/x64 immediates and displacements), and returns the
    // advanced pointer. Emitting byte by byte needs no aligned or type-punned
    // stores into the code buffer.
    static unsigned char *vbind_put_le(unsigned char *p, unsigned long long value, unsigned int width)
    {
        for (unsigned int k = 0; k < width; k++) { *p++ = (unsigned char)(value >> (8 * k)); }
        return p;
    }

    // Writes into `buffer` a machine-code thunk that forwards its arguments to `f`,
    // inserting the n pointer-sized values bound[0..n) at parameter position i.
    //
    //   f           target function; every parameter must be pointer-sized
    //               (integers/pointers only; no floating-point, no variadics).
    //   param_count total number of parameters of f, bound ones included
    //               (including 'this' for a member function). Must satisfy
    //               i + n <= param_count, otherwise the copy count wraps around.
    //   buffer      receives the code; the caller must make it executable.
    //               The thunk takes at most 151 + 17 * n bytes on x64 and
    //               71 + 11 * n bytes on x86 (the bound in the signature covers both).
    //   i           index in f's parameter list of the first bound value.
    //   bound, n    the values to insert and how many there are.
    //   thiscall    x86 only: pop the first argument into ECX before the call
    //               (MSVC-style thiscall). Ignored on x64.
    //
    // The generated thunk takes param_count - n arguments and returns with a plain
    // `ret`, so its caller cleans up (cdecl-style). The x64 path uses RCX, RDX, R8,
    // R9 plus a 32-byte home space, i.e. the Microsoft x64 convention, not System V.
    //
    // Returns the number of bytes written.
    size_t vbind(
        void *(/* cdecl, stdcall, or thiscall */ *f)(), size_t param_count,
        unsigned char buffer[/* >= 160 + n * (5 + sizeof(int) + sizeof(void*)) */],
        size_t const i, void *const bound[], unsigned int const n, bool const thiscall)
    {
        unsigned char *p = buffer;
        unsigned char s = sizeof(void *);
        // 4-byte pointers => x86 code, 8-byte pointers => x64 code (REX.W prefixes).
        unsigned char b = sizeof(int) == sizeof(void *) ? 2 : 3;  // log2(sizeof(void *))
        // Pointer-sized stack slots reserved for the outgoing argument list.
        size_t slots = param_count;
        if (b > 2)
        {
            // The x64 callee may spill RCX..R9 into the 4 slots just above its return
            // address. With fewer slots that spill would overwrite this thunk's saved
            // RBP and return address.
            if (slots < 4) { slots = 4; }
            // An even slot count keeps RSP 16-byte aligned at the call, provided the
            // thunk itself was entered with the ABI-mandated alignment.
            slots += slots & 1;
        }
        *p++ = 0x55;                                                                          // push     rbp
        if (b > 2) { *p++ = 0x48; } *p++ = 0x8B; *p++ = 0xEC;                                 // mov      rbp, rsp
        if (b > 2)
        {
            // Spill the register arguments into the caller-provided home space so that
            // all incoming arguments are contiguous in memory starting at [rbp + 2 * s].
            *p++ = 0x48; *p++ = 0x89; *p++ = 0x4C; *p++ = 0x24; *p++ = 2 * s;                 // mov      [rsp + 2 * s], rcx
            *p++ = 0x48; *p++ = 0x89; *p++ = 0x54; *p++ = 0x24; *p++ = 3 * s;                 // mov      [rsp + 3 * s], rdx
            *p++ = 0x4C; *p++ = 0x89; *p++ = 0x44; *p++ = 0x24; *p++ = 4 * s;                 // mov      [rsp + 4 * s], r8
            *p++ = 0x4C; *p++ = 0x89; *p++ = 0x4C; *p++ = 0x24; *p++ = 5 * s;                 // mov      [rsp + 5 * s], r9
        }
        if (b > 2) { *p++ = 0x48; } *p++ = 0xBA; p = vbind_put_le(p, param_count, s);        // mov      rdx, <param_count>
        if (b > 2) { *p++ = 0x48; } *p++ = 0xB8; p = vbind_put_le(p, slots, s);              // mov      rax, <slots>
        if (b > 2) { *p++ = 0x48; } *p++ = 0xC1; *p++ = 0xE0; *p++ = b;                       // shl      rax, log2(sizeof(void *))
        if (b > 2) { *p++ = 0x48; } *p++ = 0x2B; *p++ = 0xE0;                                 // sub      rsp, rax
        *p++ = 0x57;                                                                          // push     rdi
        *p++ = 0x56;                                                                          // push     rsi
        *p++ = 0x51;                                                                          // push     rcx
        *p++ = 0x9C;                                                                          // pushfq   (saves DF, changed by cld below)
        if (b > 2) { *p++ = 0x48; } *p++ = 0xF7; *p++ = 0xD8;                                 // neg      rax
        if (b > 2) { *p++ = 0x48; } *p++ = 0x8D; *p++ = 0x7C; *p++ = 0x05; *p++ = 0x00;       // lea      rdi, [rbp + rax]   (bottom of the new argument area)
        if (b > 2) { *p++ = 0x48; } *p++ = 0x8D; *p++ = 0x75; *p++ = 2 * s;                   // lea      rsi, [rbp + 2 * s] (first incoming argument)
        if (b > 2) { *p++ = 0x48; } *p++ = 0xB9; p = vbind_put_le(p, i, s);                   // mov      rcx, <i>
        if (b > 2) { *p++ = 0x48; } *p++ = 0x2B; *p++ = 0xD1;                                 // sub      rdx, rcx
        *p++ = 0xFC;                                                                          // cld
        *p++ = 0xF3; if (b > 2) { *p++ = 0x48; } *p++ = 0xA5;                                 // rep movs [rdi], [rsi]       (arguments before position i)
        for (unsigned int j = 0; j < n; j++)
        {
            // rdi now points just past the copied prefix; store bound[j] at slot j from there.
            unsigned int const o = j * s;
            if (b > 2) { *p++ = 0x48; } *p++ = 0xB8; p = vbind_put_le(p, (size_t)bound[j], s); // mov      rax, <bound[j]>
            if (b > 2) { *p++ = 0x48; } *p++ = 0x89; *p++ = 0x87; p = vbind_put_le(p, o, 4);  // mov      [rdi + <o>], rax
        }
        if (b > 2) { *p++ = 0x48; } *p++ = 0xB8; p = vbind_put_le(p, n, s);                   // mov      rax, <n>
        if (b > 2) { *p++ = 0x48; } *p++ = 0x2B; *p++ = 0xD0;                                 // sub      rdx, rax
        if (b > 2) { *p++ = 0x48; } *p++ = 0xC1; *p++ = 0xE0; *p++ = b;                       // shl      rax, log2(sizeof(void *))
        if (b > 2) { *p++ = 0x48; } *p++ = 0x03; *p++ = 0xF8;                                 // add      rdi, rax
        if (b > 2) { *p++ = 0x48; } *p++ = 0x8B; *p++ = 0xCA;                                 // mov      rcx, rdx
        *p++ = 0xF3; if (b > 2) { *p++ = 0x48; } *p++ = 0xA5;                                 // rep movs [rdi], [rsi]       (remaining arguments)
        *p++ = 0x9D;                                                                          // popfq
        *p++ = 0x59;                                                                          // pop      rcx
        *p++ = 0x5E;                                                                          // pop      rsi
        *p++ = 0x5F;                                                                          // pop      rdi
        if (b > 2)
        {
            // Load the first four rearranged arguments into their registers; the
            // memory copies stay in place as the callee's home space.
            *p++ = 0x48; *p++ = 0x8B; *p++ = 0x4C; *p++ = 0x24; *p++ = 0 * s;                 // mov      rcx, [rsp + 0 * s]
            *p++ = 0x48; *p++ = 0x8B; *p++ = 0x54; *p++ = 0x24; *p++ = 1 * s;                 // mov      rdx, [rsp + 1 * s]
            *p++ = 0x4C; *p++ = 0x8B; *p++ = 0x44; *p++ = 0x24; *p++ = 2 * s;                 // mov      r8 , [rsp + 2 * s]
            *p++ = 0x4C; *p++ = 0x8B; *p++ = 0x4C; *p++ = 0x24; *p++ = 3 * s;                 // mov      r9 , [rsp + 3 * s]
            *p++ = 0x48; *p++ = 0xB8; p = vbind_put_le(p, (size_t)f, s);                      // mov      rax, <f>
            *p++ = 0xFF; *p++ = 0xD0;                                                         // call     rax
        }
        else
        {
            if (thiscall) { *p++ = 0x59; }                                                    // pop      ecx  (MSVC thiscall: 'this' in ECX)
            *p++ = 0xE8;                                                                      // call     <f>
            // rel32 is measured from the end of its own 4-byte field, i.e. the next
            // instruction. Computing it from an explicit `next` keeps the result
            // independent of evaluation order.
            unsigned char const *const next = p + 4;
            p = vbind_put_le(p, (size_t)f - (size_t)next, 4);
        }
        if (b > 2) { *p++ = 0x48; } *p++ = 0x8B; *p++ = 0xE5;                                 // mov      rsp, rbp
        *p++ = 0x5D;                                                                          // pop      rbp
        *p++ = 0xC3;                                                                          // ret
        return p - &buffer[0];
    }
    

    Example (for Windows):

    #include <assert.h>
    #include <stdio.h>
    #include <windows.h>
    // Demo target with six pointer-sized parameters. `u` doubles as a countdown:
    // while it is non-zero the function first calls itself with u - 1, so one
    // call prints u + 1 lines (innermost first). Returns `value` unchanged.
    void *__cdecl test(void *value, void *x, void *y, void *z, void *w, void *u)
    {
        // An ordered comparison of a pointer with 0 (`u > 0`) is ill-formed in
        // standard C++, so the counter is compared as an integer instead.
        size_t const remaining = (size_t)u;
        if (remaining > 0) { test(value, x, y, z, w, (void *)(remaining - 1)); }
        printf("Test called! %p %p %p %p %p %p\n", value, x, y, z, w, u);
        return value;
    }
    // Demo functor whose operator() serves as a member-function target; `local`
    // is printed with the arguments so the output shows which object `this` was.
    struct Test
    {
        void *local;
        // Six pointer-sized parameters. `u` doubles as a countdown: while it is
        // non-zero the operator first calls itself with u - 1, so one call prints
        // u + 1 lines (innermost first). Returns `value` unchanged.
        void *operator()(void *value, void *x, void *y, void *z, void *w, void *u)
        {
            // An ordered comparison of a pointer with 0 (`u > 0`) is ill-formed in
            // standard C++, so the counter is compared as an integer instead.
            size_t const remaining = (size_t)u;
            if (remaining > 0) { (*this)(value, x, y, z, w, (void *)(remaining - 1)); }
            printf("Test::operator() called! %p %p %p %p %p %p %p\n", local, value, x, y, z, w, u);
            return value;
        }
    };
    int main()
    {
        unsigned char thunk[1024]; unsigned long old;
        VirtualProtect(&thunk, sizeof(thunk), PAGE_EXECUTE_READWRITE, &old);
        void *args[] = { (void *)0xBAADF00DBAADF001, (void *)0xBAADF00DBAADF002 };
        void *(Test::*f)(void *value, void *x, void *y, void *z, void *w, void *u) = &Test::operator();
        Test obj = { (void *)0x1234 };
        assert(sizeof(f) == sizeof(void (*)()));  // virtual function are too big, they're not supported :(
        vbind(*(void *(**)())&f, 1 + 6, thunk, 1 + 1, args, sizeof(args) / sizeof(*args), true);
        ((void *(*)(void *, int, int, int, int))&thunk)(&obj, 3, 4, 5, 6);
        vbind((void *(*)())test, 6, thunk, 1, args, sizeof(args) / sizeof(*args), false);
        ((void *(*)(int, int, int, int))&thunk)(3, 4, 5, 6);
    }
    

提交回复
热议问题