Ways to do modulo multiplication with primitive types

野趣味 2020-12-05 08:13

Is there a way to compute e.g. (853467 * 21660421200929) % 100000000000007 without BigInteger libraries (note that each number fits into a 64-bit integer but the …

7 Answers
  • 2020-12-05 08:42

    You should use Russian Peasant multiplication. It uses repeated doubling to compute all the values (b*2^i)%m, and adds them in if the ith bit of a is set.

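    uint64_t mulmod(uint64_t a, uint64_t b, uint64_t m) {
        uint64_t res = 0;  // unsigned, to match the arguments and the return type
        b %= m;            // keeps res + b and b << 1 below 2^64 when m fits in 63 bits
        while (a != 0) {
            if (a & 1) res = (res + b) % m;  // bit i of a is set: add (b * 2^i) % m
            a >>= 1;
            b = (b << 1) % m;                // b becomes (b * 2^(i+1)) % m
        }
        return res;
    }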
    

    It improves upon your algorithm because it takes O(log(a)) time, not O(a) time.

    Caveats: all arguments must be unsigned, and m must fit in 63 bits so that (with b < m) res + b and b << 1 cannot overflow.
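
    The 63-bit restriction comes only from the two places where an intermediate sum can exceed 2^64. A minimal sketch of one way to lift it, assuming a helper that adds modulo m by comparing before adding; the names addmod and mulmod_full are hypothetical and not part of the answer above:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: (x + y) % m for x, y < m, without overflowing 64 bits.
static uint64_t addmod(uint64_t x, uint64_t y, uint64_t m) {
    // x + y >= m is equivalent to x >= m - y, and m - y cannot wrap because y < m.
    return (x >= m - y) ? x - (m - y) : x + y;
}

// Russian Peasant multiplication for any modulus m in [1, 2^64 - 1].
uint64_t mulmod_full(uint64_t a, uint64_t b, uint64_t m) {
    assert(m != 0);
    uint64_t res = 0;
    b %= m;  // addmod requires both operands to be below m
    while (a != 0) {
        if (a & 1) res = addmod(res, b, m);
        a >>= 1;
        b = addmod(b, b, m);  // doubling, again without overflow
    }
    return res;
}
```

    With the numbers from the question, mulmod_full(853467, 21660421200929, 100000000000007) evaluates to 54701091976795. The cost is one comparison per addition instead of a % operation.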

  • 2020-12-05 08:42

    I can suggest an improvement for your algorithm.

    You are effectively computing a * b by adding b to the result a times, reducing modulo m after each addition. It is better to add b * x each time, where x is chosen as large as possible such that b * x does not overflow.

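    // Assumes a >= 0, b >= 0 and m > 0. INT64_MAX comes from <stdint.h>.
    int64_t mulmod(int64_t a, int64_t b, int64_t m)
    {
        a %= m;
        b %= m;
        if (b == 0)
            return 0; // also avoids an O(a) loop below
    
        int64_t x = 1;
        int64_t bx = b; // invariant: bx == b * x, computed without overflow
    
        while (x < a)
        {
            if (bx > INT64_MAX / 2)
                break; // doubling would overflow (signed overflow is undefined behavior)
    
            x *= 2;
            bx *= 2;
        }
    
        int64_t step = bx % m;
        int64_t ans = 0;
    
        // add b * x (mod m) per iteration; the comparison avoids overflow in ans + step
        for (; x < a; a -= x)
            ans = (ans >= m - step) ? ans - (m - step) : ans + step;
    
        // here a <= x, so a * b <= bx cannot overflow
        int64_t rest = (a * b) % m;
        return (ans >= m - rest) ? ans - (m - rest) : ans + rest;
    }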
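
    To make the choice of x concrete, the following sketch computes the largest power of two x for which b * x still fits in int64_t. The helper name largest_safe_multiplier is hypothetical, and unlike the loop in mulmod it does not stop early once x reaches a:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: largest power of two x such that b * x <= INT64_MAX.
int64_t largest_safe_multiplier(int64_t b) {
    assert(b > 0);
    int64_t x = 1;
    int64_t bx = b;  // invariant: bx == b * x
    while (bx <= INT64_MAX / 2) {  // doubling is still safe
        bx *= 2;
        x *= 2;
    }
    return x;
}
```

    For b = 21660421200929 this gives x = 262144 (2^18), so with a = 853467 the addition loop runs three times (a drops to 591323, 329179, then 67035) before the remaining a * b term is small enough to multiply directly.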
    
  • 2020-12-05 08:44

    a * b % m equals a * b - (a * b / m) * m, where / is integer division.

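    Use floating point arithmetic to approximate a * b / m. The approximation leaves a value small enough for normal 64-bit integer operations, for m up to 63 bits. Because unsigned arithmetic wraps modulo 2^64, the difference is still computed exactly even when a * b itself does not fit in 64 bits.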

    This method is limited by the significand of a double, which is usually 53 bits (52 of them stored explicitly).

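    uint64_t mod_mul_52(uint64_t a, uint64_t b, uint64_t m) {
        double q = (double)a * b / m;                // approximate quotient
        uint64_t c = q < 1 ? 0 : (uint64_t)(q - 1);  // underestimate; never convert a negative value
        uint64_t d = a * b - c * m;                  // wraps modulo 2^64, but the difference is exact
    
        return d % m;
    }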
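
    To see why the wraparound is harmless, here is a small trace of the intermediate values. It is only a sketch: the Trace struct and trace_mod_mul are hypothetical names, and the helper assumes the approximate quotient is at least 1:

```cpp
#include <cassert>
#include <cstdint>

struct Trace {
    uint64_t c;  // underestimated quotient
    uint64_t d;  // a * b - c * m: the remainder plus a small multiple of m
};

// Hypothetical helper mirroring mod_mul_52, but returning the intermediates.
Trace trace_mod_mul(uint64_t a, uint64_t b, uint64_t m) {
    assert(m != 0);
    double q = (double)a * (double)b / (double)m;
    assert(q >= 1.0);  // keeps q - 1 non-negative for the conversion below
    Trace t;
    t.c = (uint64_t)(q - 1.0);
    t.d = a * b - t.c * m;  // both products may wrap modulo 2^64; the wraps cancel
    return t;
}
```

    For a = 853467, b = 21660421200929 and m = 100000000000007 the true product is 18486454701093270843, which exceeds 2^64. The trace gives c = 184863 and d = 154701091976802, that is m + 54701091976795, so the final d % m yields 54701091976795.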
    

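    This method is limited by the significand of a long double, which is 64 bits on x86 with GCC or Clang; on platforms where long double is the same type as double (MSVC, for example) it is no better than mod_mul_52. The integer arithmetic is limited to 63 bits.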

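    uint64_t mod_mul_63(uint64_t a, uint64_t b, uint64_t m) {
        long double q = (long double)a * b / m;      // approximate quotient
        uint64_t c = q < 1 ? 0 : (uint64_t)(q - 1);  // underestimate; never convert a negative value
        uint64_t d = a * b - c * m;                  // wraps modulo 2^64, but the difference is exact
    
        return d % m;
    }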
    

    These methods require that a and b be less than m. To handle arbitrary a and b, add these lines before c is computed.

    a = a % m;
    b = b % m;
    

    In both methods, the final % operation could be made conditional.

    return d >= m ? d % m : d;
    
  • 2020-12-05 08:48

    You could try something that breaks the multiplication up into additions:

    // compute (a * b) % m; assumes m <= UINT_MAX / 2 + 1 so that result + a and a * 2 cannot overflow
    
    unsigned int multmod(unsigned int a, unsigned int b, unsigned int m)
    {
        unsigned int result = 0;
    
        a %= m;
        b %= m;
    
        while (b)
        {
            if (b % 2 != 0)
            {
                result = (result + a) % m;
            }
    
            a = (a * 2) % m;
            b /= 2;
        }
    
        return result;
    }
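
    If the operands really are 32 bits wide, as unsigned int is on most current platforms, the loop is not needed at all: the product of two 32-bit values always fits in 64 bits. A minimal sketch of that alternative (the name multmod32 is hypothetical, and this is a different technique from the addition loop above):

```cpp
#include <cassert>
#include <cstdint>

// (a * b) % m for 32-bit operands, using a 64-bit intermediate product.
uint32_t multmod32(uint32_t a, uint32_t b, uint32_t m) {
    assert(m != 0);
    uint64_t wide = (uint64_t)a * b;  // at most (2^32 - 1)^2, which is below 2^64
    return (uint32_t)(wide % m);      // the remainder is below m, so it fits in 32 bits
}
```

    This works for every non-zero 32-bit modulus. The addition loop remains useful when the operands already occupy the widest available integer type.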
    
  • 2020-12-05 08:49

    An improvement to the repeated doubling algorithm is to check how many bits can be processed at once without an overflow. An early exit check can be done for both arguments -- speeding up the (unlikely?) event of n not being prime.

    E.g. 100000000000007 == 0x00005af3107a4007 has 17 leading zero bits, which allows 17 bits to be calculated per iteration. With the example, the actual number of iterations will be 3.

    // just a conceptual routine; returns 64 for n == 0
    int get_leading_zeroes(uint64_t n)
    {
       int a=0;
       while (a < 64 && (n & 0x8000000000000000ULL) == 0) { a++; n<<=1; }
       return a;
    }
    
    // requires 0 < n < 2^63, so that 1 <= N <= 63
    uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n)
    {
         uint64_t result = 0;
         int N = get_leading_zeroes(n);
         uint64_t mask = ((uint64_t)1 << N) - 1;  // 64-bit shift; a plain 1<<N is an int shift
         a %= n;
         b %= n;  // Make sure all values are originally in the proper range?
         // n is not necessarily a prime -- so both a & b can end up being zero
         while (a>0 && b>0)
         {
             result = (result + (b & mask) * a) % n;  // no overflow: the sum stays below 2^N * n <= 2^64
             b>>=N;
             a = (a << N) % n;
         }
         return result;
    }
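
    The bit counts quoted for the example can be checked with a short sketch. The helper names leading_zeros and chunk_count are hypothetical; chunk_count only counts how many N-bit chunks of b are consumed and ignores the early exit on a:

```cpp
#include <cassert>
#include <cstdint>

// Number of leading zero bits of n; 64 when n == 0.
int leading_zeros(uint64_t n) {
    int count = 0;
    for (uint64_t bit = UINT64_C(1) << 63; bit != 0 && (n & bit) == 0; bit >>= 1)
        ++count;
    return count;
}

// How many N-bit chunks of b the chunked loop would process for modulus n.
int chunk_count(uint64_t b, uint64_t n) {
    assert(n != 0);
    int N = leading_zeros(n);
    assert(N >= 1);  // n must be below 2^63, otherwise no bits can be consumed
    int iterations = 0;
    for (b %= n; b != 0; b >>= N)
        ++iterations;
    return iterations;
}
```

    For n = 100000000000007 the leading zero count is 17, and b = 21660421200929 (a 45-bit value) is consumed in 3 chunks, compared with 45 iterations of one-bit doubling.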
    
  • 2020-12-05 09:06

    Both methods work for me. The first one is the same as yours, but I changed your numbers to explicit ULL literals. The second one uses inline assembly (x86-64, GCC syntax), which should be faster. There are also algorithms used in cryptography (mostly RSA and RSA-based schemes, I guess), such as the already mentioned Montgomery reduction, but I think they would take time to implement.

    #include <algorithm>
    #include <iostream>
    
    __uint64_t mulmod1(__uint64_t a, __uint64_t b, __uint64_t m) {
      if (b < a)
        std::swap(a, b);
      __uint64_t res = 0;
      for (__uint64_t i = 0; i < a; i++) {
        res += b;
        res %= m;
      }
      return res;
    }
    
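    // x86-64 only. mulq leaves the 128-bit product a * b in rdx:rax; divq divides
    // it by m, leaving the quotient in rax and the remainder in rdx.
    // divq faults if the quotient does not fit in 64 bits, so make sure that
    // a < m or b < m (reduce one of them modulo m first if necessary).
    __uint64_t mulmod2(__uint64_t a, __uint64_t b, __uint64_t m) {
      __uint64_t r;
      __asm__
      ( "mulq %2\n\t"
          "divq %3"
          : "=&d" (r), "+%a" (a)
          : "rm" (b), "rm" (m)
          : "cc"
      );
      return r;
    }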
    
    int main() {
      using namespace std;
      __uint64_t a = 853467ULL;
      __uint64_t b = 21660421200929ULL;
      __uint64_t c = 100000000000007ULL;
    
      cout << mulmod1(a, b, c) << endl;
      cout << mulmod2(a, b, c) << endl;
      return 0;
    }
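
    On 64-bit targets with GCC or Clang, the same 128-bit product can be obtained without inline assembly through the unsigned __int128 extension. This is an alternative to the code above rather than part of it, the name mulmod_u128 is hypothetical, and the type is a compiler extension rather than standard C++:

```cpp
#include <cassert>
#include <cstdint>

// (a * b) % m via a 128-bit intermediate; valid for any non-zero 64-bit m.
uint64_t mulmod_u128(uint64_t a, uint64_t b, uint64_t m) {
    assert(m != 0);
    unsigned __int128 wide = (unsigned __int128)a * b;  // full product, no overflow
    return (uint64_t)(wide % m);                        // remainder is below m
}
```

    Unlike the mulq/divq pair, this form does not require a or b to be reduced beforehand, since the compiler emits a full 128-by-64-bit remainder. For the numbers in main above it returns 54701091976795.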
    