Compute (a*b)%n FAST for 64-bit unsigned arguments in C(++) on x86-64 platforms?

前端 未结 5 643
走了就别回头了
走了就别回头了 2021-01-15 12:35

I'm looking for a fast method to efficiently compute (a·b) modulo n (in the mathematical sense of that) for 64-bit unsigned integers a, b and n.

相关标签:
5条回答
  • 2021-01-15 12:55

    Ok, how about this (not tested)

    ; modmul: returns (a * b) mod n for unsigned 64-bit operands.
    ; The product is formed as a full 128-bit value in rdx:rax, so the
    ; multiplication itself cannot overflow. The division raises a divide
    ; error if n is 0 or if the quotient (a * b) / n does not fit in 64 bits.
    modmul:
    ; rcx = a
    ; rdx = b
    ; r8 = n
    mov rax, rdx    ; rax = b (mul takes one factor implicitly from rax)
    mul rcx         ; rdx:rax = rax * rcx = b * a (128-bit product)
    div r8          ; rax = rdx:rax / n (quotient), rdx = rdx:rax % n (remainder)
    mov rax, rdx    ; hand the remainder back in rax
    ret
    

    The precondition is that a * b / n <= ~0ULL, otherwise there will be a divide error. That's a slightly less strict condition than a < n && b < n: one of them can be bigger than n as long as the other is small enough.

    Unfortunately it has to be assembled and linked in separately, because MSVC doesn't support inline asm for 64bit targets.

    It's also still slow, the real problem is that 64bit div, which can take nearly a hundred cycles (seriously, up to 90 cycles on Nehalem for example).

    0 讨论(0)
  • 2021-01-15 12:58

    You could do it the old-fashioned way with shift/add/subtract. The below code assumes a < n and
    n < 2^63 (so things don't overflow):

    /*
     * Computes (a * b) mod n by binary "double and add": for every set bit of
     * b, the current multiple of a is added into the accumulator, and a is
     * doubled after each bit. Every intermediate is kept below n with a
     * single conditional subtraction.
     *
     * Preconditions (not checked): a < n and n < 2^63, so that neither
     * acc + a nor a + a can wrap around 64 bits.
     */
    uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
        uint64_t acc = 0;

        for (; b != 0; b >>= 1) {
            if (b & 1) {
                acc += a;
                if (acc >= n) {
                    acc -= n;
                }
            }

            /* a = 2a mod n, ready for the next (more significant) bit of b */
            a += a;
            if (a >= n) {
                a -= n;
            }
        }

        return acc;
    }
    

    You could use while (a && b) for the loop instead to short-circuit things if it's likely that a will be a factor of n. Will be slightly slower (more comparisons and likely correctly predicted branches) if a is not a factor of n.

    If you really, absolutely, need that last bit (allowing n up to 2^64-1), you can use:

    /*
     * Computes (a * b) mod n for the full 64-bit range of n (n up to 2^64-1)
     * using binary "double and add".
     *
     * a and b may be any 64-bit values: a is reduced modulo n up front, which
     * the wrap-around checks below rely on (they need a < n). n must be
     * non-zero for the result to be meaningful; for n == 0 the reduction is
     * skipped and the loop behaves as it did before this check existed.
     */
    uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
        uint64_t rv = 0;

        /* Only pay for the division when the caller did not pre-reduce a. */
        if (n != 0 && a >= n) {
            a %= n;
        }

        while (b) {
            if (b & 1) {
                rv += a;
                /* rv < a means the addition wrapped past 2^64, so the true
                 * sum exceeds n; the wrapping subtraction yields sum - n. */
                if (rv < a || rv >= n) {
                    rv -= n;
                }
            }

            /* a = 2a mod n, with the same wrap-around detection */
            uint64_t t = a;
            a += a;
            if (a < t || a >= n) {
                a -= n;
            }

            b >>= 1;
        }

        return rv;
    }
    

    Alternately, just use GCC inline assembly to access the underlying x64 instructions:

    /*
     * Computes (a * b) mod n through the 128-bit product.
     *
     * On x86-64 this issues MUL (rdx:rax = a * b) followed by DIV
     * (rdx = remainder). DIV raises a divide error when n == 0 or when the
     * quotient (a * b) / n does not fit in 64 bits, so callers must ensure
     * n > 0 and a * b < n * 2^64 (always true when a < n or b < n).
     * Other targets use the compiler's 128-bit integer type instead.
     *
     * `static inline` is used because a plain `inline` definition in C99/C11
     * provides no external definition and fails to link when the call is not
     * inlined. `__asm__` is used because `asm` is not a keyword in strict
     * ISO modes such as -std=c11.
     */
    static inline uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
    #if defined(__x86_64__)
        uint64_t rv;
        /* rdx:rax = rax * b; a is tied to rax, the high half lands in rv */
        __asm__ ("mul %3" : "=d"(rv), "=a"(a) : "1"(a), "r"(b) : "cc");
        /* rdx:rax / n -> quotient in rax (discarded), remainder in rdx */
        __asm__ ("div %4" : "=d"(rv), "=a"(a) : "0"(rv), "1"(a), "r"(n) : "cc");
        return rv;
    #else
        return (uint64_t)(((unsigned __int128)a * b) % n);
    #endif
    }
    

    The 64-bit div instruction is really slow, however, so the loop might actually be faster. You'd need to profile to be sure.

    0 讨论(0)
  • 2021-01-15 13:02

    This intrinsic is named _umul128 (the signed variant is _mul128).

    typedef unsigned long long BIG;

    /*
     * Returns (v * 2^by) mod n by doubling v modulo n `by` times.
     *
     * Handles only the "hard" case when the high bit of n is set: the single
     * initial subtraction fully reduces v only if v < 2n, which always holds
     * for such n. n must be non-zero. A non-positive `by` performs no
     * doubling.
     *
     * The comparisons are inclusive (>=) so that v == n and 2v == n reduce to
     * 0 instead of leaving the unreduced value n in v.
     */
    BIG shl_mod( BIG v, BIG n, int by )
    {
        if (v >= n) {
            v -= n;
        }
        while (by-- > 0) {
            /* 2v >= n, written as v >= n - v so nothing overflows */
            if (v >= (n - v)) {
                v -= n - v;      /* 2v - n */
            } else {
                v <<= 1;         /* 2v < n: plain doubling */
            }
        }
        return v;
    }
    

    Now you can use shl_mod(B, n, 64)

    0 讨论(0)
  • 2021-01-15 13:02

    Having no inline assembly kind of sucks. Anyway, the function call overhead is actually extremely small. Parameters are passed in volatile registers and no cleanup is needed.

    I don't have an assembler, and x64 targets don't support __asm, so I had no choice but to "assemble" my function from opcodes myself.

    Obviously it depends on the Windows API (VirtualProtect). I'm using mpir (gmp) as a reference to show the function produces correct results.

    
    #include "stdafx.h"
    
    // mulmod64(a, b, m) == (a * b) % m
    // Pointer type used to call the hand-assembled routine below: three
    // unsigned 64-bit arguments, unsigned 64-bit result, __cdecl.
    typedef uint64_t(__cdecl *mulmod64_fnptr_t)(uint64_t a, uint64_t b, uint64_t m);
    
    // Raw x86-64 machine code for the routine, one instruction per row with
    // its mnemonic alongside. It reads its operands from rcx, rdx and r8 and
    // leaves the remainder in rax. The array lives in ordinary data memory,
    // so it cannot be executed until init() changes the page protection.
    uint8_t mulmod64_opcodes[] = {
        0x48, 0x89, 0xC8, // mov rax, rcx
        0x48, 0xF7, 0xE2, // mul rdx
        0x4C, 0x89, 0xC1, // mov rcx, r8
        0x48, 0xF7, 0xF1, // div rcx
        0x48, 0x89, 0xD0, // mov rax,rdx
        0xC3              // ret
    };
    
    // Callable entry point into mulmod64_opcodes; assigned by init().
    mulmod64_fnptr_t mulmod64_fnptr;
    
    // Makes the bytes of mulmod64_opcodes executable and publishes them
    // through mulmod64_fnptr. Must run before the first call through
    // mulmod64_fnptr.
    //
    // VirtualProtect returns zero on failure. In that case the array is still
    // non-executable and calling into it would fault, so the error is
    // reported and the process exits instead of continuing.
    void init() {
        DWORD dwOldProtect;
        if (!VirtualProtect(
                &mulmod64_opcodes,
                sizeof(mulmod64_opcodes),
                PAGE_EXECUTE_READWRITE,
                &dwOldProtect)) {
            fprintf(stderr, "VirtualProtect failed (error %lu)\n",
                    (unsigned long)GetLastError());
            ExitProcess(1);
        }
        // NOTE: reinterpret byte array as a function pointer
        mulmod64_fnptr = (mulmod64_fnptr_t)(void*)mulmod64_opcodes;
    }
    
    // Demo driver: computes (lhs * rhs) mod modulus once with the
    // arbitrary-precision library as a reference and once through
    // mulmod64_fnptr, printing both lines for comparison.
    int main() {
        init();

        const uint64_t lhs = 2139018971924123ull;
        const uint64_t rhs = 1239485798578921ull;
        const uint64_t modulus = 8975489368910167ull;

        // reference result via big integers
        mpz_t big_lhs, big_rhs, big_prod, big_mod, big_rem;
        mpz_inits(big_lhs, big_rhs, big_prod, big_mod, big_rem, NULL);
        mpz_set_ui(big_mod, modulus);
        mpz_set_ui(big_lhs, lhs);
        mpz_set_ui(big_rhs, rhs);
        mpz_mul(big_prod, big_lhs, big_rhs);
        mpz_mod(big_rem, big_prod, big_mod);

        gmp_printf("(%Zd * %Zd) mod %Zd = %Zd\n", big_lhs, big_rhs, big_mod, big_rem);

        // same computation through the hand-assembled routine
        const uint64_t fast_rem = mulmod64_fnptr(lhs, rhs, modulus);
        printf("(%llu * %llu) mod %llu = %llu\n", lhs, rhs, modulus, fast_rem);
        return 0;
    }
    
    
    0 讨论(0)
  • 2021-01-15 13:16

    7 years later, I got a solution working in Visual Studio 2019

    #include <stdint.h>
    #include <intrin.h>
    #pragma intrinsic(_umul128)
    #pragma intrinsic(_udiv128)
    
    // compute (a*b)%n with 128-bit intermediary result
    // assumes n>0  and  a*b < n * 2**64 (always the case when a<=n || b<=n )
    //
    // With MSVC this uses _umul128 (full 128-bit product) followed by
    // _udiv128 (128-by-64 division, remainder stored through the last
    // argument). Other compilers use their 128-bit integer type; that path
    // does not need the a*b < n * 2**64 precondition.
    // Declared `static inline` so that it also links when compiled as C,
    // where a plain `inline` definition emits no external definition.
    static inline uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
    #if defined(_MSC_VER)
      uint64_t r, s = _umul128(a, b, &r);   // r = high half, s = low half
      (void)_udiv128(r, s, n, &r);          // quotient discarded, r = remainder
      return r;
    #else
      return (uint64_t)(((unsigned __int128)a * b) % n);
    #endif
    }
    
    // compute (a*b)%n with 128-bit intermediary result
    // assumes n>0, works including if a*b >= n * 2**64
    //
    // With MSVC, a is reduced modulo n first so the quotient of the 128-by-64
    // division always fits in 64 bits, whatever b is. Other compilers use
    // their 128-bit integer type, which needs no pre-reduction.
    // Declared `static inline` so that it also links when compiled as C,
    // where a plain `inline` definition emits no external definition.
    static inline uint64_t mulmod1(uint64_t a, uint64_t b, uint64_t n) {
    #if defined(_MSC_VER)
      uint64_t r, s = _umul128(a % n, b, &r);   // r = high half, s = low half
      (void)_udiv128(r, s, n, &r);              // quotient discarded, r = remainder
      return r;
    #else
      return (uint64_t)(((unsigned __int128)a * b) % n);
    #endif
    }
    
    0 讨论(0)
提交回复
热议问题