Compute (a*b)%n FAST for 64-bit unsigned arguments in C(++) on x86-64 platforms?

前端 未结 5 648
走了就别回头了
走了就别回头了 2021-01-15 12:35

I'm looking for a fast method to efficiently compute (a*b) modulo n (in the mathematical sense of that) for 64-bit unsigned arguments a, b and n in C(++) on x86-64 platforms.

5条回答
  •  余生分开走
    2021-01-15 13:02

    Not having inline assembly on x64 is unfortunate. The function-call overhead, however, is very small: the parameters are passed in volatile registers and no stack cleanup is needed.

    I don't have an assembler, and MSVC does not support inline `__asm` when targeting x64, so I had no choice but to "assemble" the function from opcodes by hand.

    Obviously this is Windows-specific: it depends on `VirtualProtect` from `<windows.h>`. I'm using MPIR (a fork of GMP) as a reference to show that the function produces correct results.

    
    #include "stdafx.h"
    
    // mulmod64(a, b, m) == (a * b) % m, computed exactly from the 128-bit
    // product a * b, for any 64-bit a and b.
    // Microsoft x64 calling convention: a in RCX, b in RDX, m in R8, result in
    // RAX (__cdecl is accepted but ignored on x64).
    // Precondition: m != 0 -- dividing by zero raises a #DE hardware exception.
    typedef uint64_t(__cdecl *mulmod64_fnptr_t)(uint64_t a, uint64_t b, uint64_t m);

    // Hand-assembled machine code for mulmod64.
    //
    // A single `div` raises #DE when the quotient does not fit in 64 bits,
    // i.e. whenever the high half of a*b is >= m (e.g. operands not already
    // reduced mod m). So when hi >= m, hi is first reduced modulo m:
    // ((hi % m) * 2^64 + lo) % m == (hi * 2^64 + lo) % m, and hi % m < m
    // guarantees the final `div` cannot overflow.
    // When hi < m (always true if a < m or b < m), exactly one `div` runs.
    uint8_t mulmod64_opcodes[] = {
        0x48, 0x89, 0xC8,   // mov  rax, rcx     ; rax = a
        0x48, 0xF7, 0xE2,   // mul  rdx          ; rdx:rax = a * b
        0x4C, 0x39, 0xC2,   // cmp  rdx, r8
        0x72, 0x0E,         // jb   +14          ; hi < m: quotient fits, skip
        0x49, 0x89, 0xC1,   // mov  r9, rax      ; save lo (r9 is volatile)
        0x48, 0x89, 0xD0,   // mov  rax, rdx
        0x31, 0xD2,         // xor  edx, edx     ; rdx:rax = 0:hi
        0x49, 0xF7, 0xF0,   // div  r8           ; rdx = hi % m
        0x4C, 0x89, 0xC8,   // mov  rax, r9      ; rdx:rax = (hi % m):lo
                            // jb target:
        0x49, 0xF7, 0xF0,   // div  r8           ; rdx = (a * b) % m
        0x48, 0x89, 0xD0,   // mov  rax, rdx     ; return the remainder
        0xC3                // ret
    };

    // Callable entry point for mulmod64_opcodes; assigned by init().
    mulmod64_fnptr_t mulmod64_fnptr;
    
    // Make mulmod64_opcodes executable and publish it through mulmod64_fnptr.
    //
    // VirtualProtect changes protection for every page that contains any byte
    // of the range. Other writable globals may share those pages, so the pages
    // are made PAGE_EXECUTE_READWRITE rather than execute-only.
    //
    // If VirtualProtect fails, mulmod64_fnptr is left untouched (NULL unless it
    // was set earlier, since it is a zero-initialized global). Callers can then
    // detect the failure instead of jumping into non-executable memory.
    void init() {
        DWORD dwOldProtect;
        if (!VirtualProtect(mulmod64_opcodes,
                            sizeof(mulmod64_opcodes),
                            PAGE_EXECUTE_READWRITE,
                            &dwOldProtect)) {
            return;
        }
        // Reinterpret the byte array as a function pointer. Converting an
        // object pointer to a function pointer is only conditionally supported
        // by ISO C/C++, but MSVC on Windows x64 supports it.
        mulmod64_fnptr = (mulmod64_fnptr_t)(void*)mulmod64_opcodes;
    }
    
    int main() {
        init();
    
        uint64_t a64 = 2139018971924123ull;
        uint64_t b64 = 1239485798578921ull;
        uint64_t m64 = 8975489368910167ull;
    
        // reference code
        mpz_t a, b, c, m, r;
        mpz_inits(a, b, c, m, r, NULL);
        mpz_set_ui(a, a64);
        mpz_set_ui(b, b64);
        mpz_set_ui(m, m64);
        mpz_mul(c, a, b);
        mpz_mod(r, c, m);
    
        gmp_printf("(%Zd * %Zd) mod %Zd = %Zd\n", a, b, m, r);
    
        // using mulmod64
        uint64_t r64 = mulmod64_fnptr(a64, b64, m64);
        printf("(%llu * %llu) mod %llu = %llu\n", a64, b64, m64, r64);
        return 0;
    }
    
    

提交回复
热议问题