Calculating pow(a,b) mod n

执念已碎 2020-11-22 16:25

I want to calculate a^b mod n for use in RSA decryption. My code (below) returns incorrect answers. What is wrong with it?

unsigned long i         


        
  • 2020-11-22 17:11

    Are you trying to calculate (a^b)%n, or a^(b%n)?

    If you want the first one, then your code only works when b is even, because of that b/2. The "if b%n==1" test is incorrect: you don't care about b%n here, but about b%2.

    If you want the second one, then the loop is wrong because you're looping b/2 times instead of (b%n)/2 times.

    Either way, your function is unnecessarily complex. Why do you loop until b/2 and try to multiply in two a's each time? Why not just loop until b and multiply in one a each time? That would eliminate a lot of unnecessary complexity and thus potential errors. Are you thinking that you'll make the program faster by cutting the number of passes through the loop in half? Frankly, that's a bad programming practice: micro-optimization. It doesn't really help much: you still multiply by a the same number of times; all you do is cut down on the number of loop tests. If b is typically small (one or two digits), it's not worth the trouble. If b is large (in the millions, say), then this is insufficient and you need a much more radical optimization.
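A minimal sketch of that simpler loop (one factor of a per iteration), assuming unsigned 32-bit inputs and n != 0. `modpow_linear` is a hypothetical name, not the OP's function. It still reduces modulo n inside the loop so that the running product stays below n and the next multiplication cannot overflow.

```c
#include <stdint.h>

/* Hypothetical helper: (a^b) % n using b multiplications, one factor of a
 * per iteration. Assumes n != 0; since n fits in uint32_t, the product of
 * two values below n always fits in uint64_t. */
uint32_t modpow_linear(uint32_t a, uint32_t b, uint32_t n)
{
    uint64_t result = 1 % n;   /* 1 % n gives 0 when n == 1 */
    uint64_t base = a % n;

    for (uint32_t i = 0; i < b; i++) {
        /* Reduce on every step so result stays below n. */
        result = (result * base) % n;
    }
    return (uint32_t) result;
}
```

This is still O(b), so it only suits small exponents. For RSA-sized exponents the O(log b) square-and-multiply method is the usual choice.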

    Also, why do the %n each time through the loop? Why not just do it once at the end?

  • 2020-11-22 17:12

    I'm using this function:

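    #include <math.h>

    /* Only exact while base^exp fits in an int (and is therefore exactly
       representable in a double); RSA-sized values overflow long before that. */
    int CalculateMod(int base, int exp, int mod){
        int result;
        result = (int) lround(pow(base, exp)); /* pow is not guaranteed to be exact, so round */
        result = result % mod;
        return result;
    }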
    

    I cast the result to int because pow gives you back a double, and the % operator needs integer operands. In any case, in RSA decryption you should use only integer arithmetic.

  • 2020-11-22 17:14

    The only actual logic error that I see is this line:

    if (b % n == 1)
    

    which should be this:

    if (b % 2 == 1)
    

    But your overall design is problematic: your function performs O(b) multiplications and modulus operations, but your use of b / 2 and a * a implies that you were aiming to perform O(log b) operations (which is usually how modular exponentiation is done).
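For comparison, here is a minimal sketch of the O(log b) method the OP appears to have been aiming for (right-to-left square-and-multiply), testing b % 2 rather than b % n. It assumes unsigned 32-bit inputs with n != 0. `modpow` is a hypothetical name.

```c
#include <stdint.h>

/* Hypothetical helper: (a^b) % n in O(log b) multiplications.
 * Assumes n != 0 and n <= UINT32_MAX, so the product of two residues
 * fits in uint64_t. */
uint32_t modpow(uint32_t a, uint32_t b, uint32_t n)
{
    uint64_t result = 1 % n;
    uint64_t base = a % n;

    while (b > 0) {
        if (b % 2 == 1)                 /* low bit of b set: multiply it in */
            result = (result * base) % n;
        base = (base * base) % n;       /* a^(2^k) -> a^(2^(k+1)) mod n */
        b /= 2;                         /* move on to the next bit of b */
    }
    return (uint32_t) result;
}
```

For example, modpow(13, 7, 33) returns 7 after three passes through the loop instead of seven.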

  • 2020-11-22 17:18

    Here is another way. Remember that the modular multiplicative inverse of a modulo m exists only when a and m are coprime. We can use the extended Euclidean algorithm to compute it.
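In RSA, the private exponent d is the inverse of the public exponent e modulo φ(n), so this is where the inverse comes in. A minimal sketch of the extended Euclidean algorithm follows, assuming the values fit in int64_t and m > 1. `mod_inverse` is a hypothetical name, and returning -1 when no inverse exists is a convention chosen for this sketch.

```c
#include <stdint.h>

/* Hypothetical helper: returns x in [0, m) with (a * x) % m == 1,
 * or -1 when gcd(a, m) != 1 (no inverse exists). Assumes m > 1. */
int64_t mod_inverse(int64_t a, int64_t m)
{
    int64_t old_r = a % m, r = m;   /* remainder sequence */
    int64_t old_s = 1, s = 0;       /* coefficients: old_s * a == old_r (mod m) */

    if (old_r < 0)
        old_r += m;

    while (r != 0) {
        int64_t q = old_r / r;
        int64_t tmp;

        tmp = old_r - q * r; old_r = r; r = tmp;
        tmp = old_s - q * s; old_s = s; s = tmp;
    }
    /* old_r is now gcd(a, m). */
    if (old_r != 1)
        return -1;
    return (old_s % m + m) % m;     /* normalize into [0, m) */
}
```

For example, mod_inverse(3, 20) returns 7, since 3 × 7 = 21 ≡ 1 (mod 20), and mod_inverse(2, 20) returns -1 because 2 and 20 are not coprime.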

    Computing a^b mod m is tricky when a and b can have more than 10^5 digits, because they no longer fit in any built-in integer type.

    The code below does the computation:

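    #include <iostream>
    #include <string>
    using namespace std;
    /*
    *   May this code live long.
    */
    long long pow_mod(const string&, const string&, long long);
    long long pow_mod(long long, long long, long long);

    int main() {
        string _num, _pow;
        long long _mod;   // assumed prime (so _mod >= 2), see the note below
        cin >> _num >> _pow >> _mod;
        cout << pow_mod(_num, _pow, _mod) << endl;
        return 0;
    }

    // Reduce the decimal strings digit by digit: the base modulo mod and the
    // exponent modulo (mod - 1). The exponent reduction relies on Fermat's
    // little theorem, so it is only valid when mod is prime and the base is
    // not a multiple of mod (a base of 0 is handled below).
    long long pow_mod(const string& n, const string& p, long long mod) {
        long long num = 0, _pow = 0;
        for (char c : n) {
            num = (num * 10 + (c - '0')) % mod;
        }
        for (char c : p) {
            _pow = (_pow * 10 + (c - '0')) % (mod - 1);
        }
        return pow_mod(num, _pow, mod);
    }

    // Binary exponentiation. a * a and res * a must fit in long long,
    // so mod should stay below about 3 * 10^9.
    long long pow_mod(long long a, long long p, long long mod) {
        long long res = 1 % mod;
        if (a == 0) return 0;
        while (p > 0) {
            if ((p & 1) == 0) {
                p /= 2;
                a = (a * a) % mod;
            } else {
                p--;
                res = (res * a) % mod;
            }
        }
        return res;
    }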
     
    

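    This code works because, when m is prime and a is not a multiple of m, Fermat's little theorem gives a^b mod m = (a mod m)^(b mod (m-1)) mod m. For a composite modulus, reducing the exponent modulo m-1 can give wrong results.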
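The sketch below places the exponent reduction next to plain square-and-multiply so the two can be compared. It assumes unsigned 32-bit moduli and 64-bit exponents. `modpow` and `modpow_fermat` are hypothetical names.

```c
#include <stdint.h>

/* Hypothetical helper: plain square-and-multiply, assuming 0 < n <= UINT32_MAX. */
static uint32_t modpow(uint32_t a, uint64_t b, uint32_t n)
{
    uint64_t result = 1 % n, base = a % n;
    while (b > 0) {
        if (b & 1)
            result = (result * base) % n;
        base = (base * base) % n;
        b >>= 1;
    }
    return (uint32_t) result;
}

/* Hypothetical helper: reduce the exponent modulo p - 1 first (Fermat).
 * Only valid when p is prime and a % p != 0; for composite p the
 * result can be wrong. */
uint32_t modpow_fermat(uint32_t a, uint64_t b, uint32_t p)
{
    return modpow(a, b % (p - 1), p);
}
```

With the prime 7, modpow_fermat(3, 100, 7) and modpow(3, 100, 7) both give 4. With the composite modulus 15, modpow_fermat(2, 14, 15) returns 1 while 2^14 mod 15 is 4, which is why the primality condition matters.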

    Hope it helped :)

  • 2020-11-22 17:21

    Computing the full power first is very costly, so you can apply the following logic to simplify the decryption.

    From here,

    Now say we want to encrypt the message m = 7 with the public key n = 33, e = 3:
    c = m^e mod n = 7^3 mod 33 = 343 mod 33 = 13.
    Hence the ciphertext c = 13.

    To check decryption with the private exponent d = 7, we compute
    m' = c^d mod n = 13^7 mod 33 = 7.
    Note that we don't have to calculate the full value of 13 to the power 7 here. We can make use of the fact that
    a = b × c mod n = ((b mod n) × (c mod n)) mod n
    so we can break down a potentially large number into its components and combine the results of easier, smaller calculations to calculate the final value.

    One way of calculating m' is as follows.
    Note that any non-negative integer can be expressed as a sum of powers of 2. So first compute the values of
    13^2, 13^4, 13^8, ... by repeatedly squaring successive values modulo 33: 13^2 = 169 ≡ 4, 13^4 ≡ 4 × 4 = 16, 13^8 ≡ 16 × 16 = 256 ≡ 25.
    Then, since 7 = 4 + 2 + 1, we have m' = 13^7 mod 33 = 13^(4+2+1) mod 33 ≡ 13^4 × 13^2 × 13^1
    ≡ 16 × 4 × 13 = 832 ≡ 7 (mod 33)
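A minimal sketch that replays this example in code, assuming the same toy key (n = 33, e = 3, d = 7) and moduli below 65536, so that 32-bit products cannot overflow. `modpow_small` and `toy_rsa_roundtrip_ok` are hypothetical names. Besides the two values computed by hand, it checks that decryption recovers every message m < n.

```c
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical helper: square-and-multiply, assuming 0 < n < 65536 so the
 * product of two residues fits in uint32_t. */
static uint32_t modpow_small(uint32_t base, uint32_t exp, uint32_t n)
{
    uint32_t result = 1 % n;
    base %= n;
    while (exp > 0) {
        if (exp & 1)                    /* this power of 2 occurs in exp */
            result = (result * base) % n;
        base = (base * base) % n;       /* 13 -> 13^2 -> 13^4 -> ... mod n */
        exp >>= 1;
    }
    return result;
}

/* Hypothetical helper: true if c = m^e mod n followed by c^d mod n
 * gives back m for every message 0 <= m < n. */
bool toy_rsa_roundtrip_ok(uint32_t n, uint32_t e, uint32_t d)
{
    for (uint32_t m = 0; m < n; m++) {
        uint32_t c = modpow_small(m, e, n);   /* encrypt */
        if (modpow_small(c, d, n) != m)       /* decrypt */
            return false;
    }
    return true;
}
```

modpow_small(7, 3, 33) returns 13 and modpow_small(13, 7, 33) returns 7, matching the hand calculation. toy_rsa_roundtrip_ok(33, 3, 7) returns true.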

  • 2020-11-22 17:23


    1. A key problem with OP's code is a * a. This is int overflow (undefined behavior) when a is large enough. The type of res is irrelevant in the multiplication of a * a.

      The solution is to ensure either:

      • the multiplication is done with 2x wide math or
      • with modulus n, n*n <= type_MAX + 1
    2. There is no reason to return a wider type than the type of the modulus, as the result is always representable in that type.

      // unsigned long int decrypt2(int a,int b,int n)
      int decrypt2(int a,int b,int n)
      
    3. Using unsigned math is certainly more suitable for OP's RSA goals.


    Also see Modular exponentiation without range restriction

    #include <limits.h>

    // (a^b)%n
    // n != 0
    
    // Test if unsigned long long has at least 2x the value bits of unsigned
    #if ULLONG_MAX/UINT_MAX  - 1 > UINT_MAX
    unsigned decrypt2(unsigned a, unsigned b, unsigned n) {
      unsigned long long result = 1u % n;  // Ensure result < n, even when n==1
      while (b > 0) {
        if (b & 1) result = (result * a) % n;
        a = (1ULL * a * a) %n;
        b >>= 1;
      }
      return (unsigned) result;
    }
    
    #else
    unsigned decrypt2(unsigned a, unsigned b, unsigned n) {
      // Detect if  UINT_MAX + 1 < n*n
      if (UINT_MAX/n < n-1) {
        return TBD_code_with_wider_math(a,b,n);
      }
      a %= n;
      unsigned result = 1u % n;
      while (b > 0) {
        if (b & 1) result = (result * a) % n;
        a = (a * a) % n;
        b >>= 1;
      }
      return result;
    }
    
    #endif
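The answer leaves `TBD_code_with_wider_math` as a placeholder. One possible way to cover that case, shown here as a sketch rather than the answerer's code, is to skip the wide product altogether and compute x*y mod n by doubling and adding, keeping every intermediate value below n. `mulmod` and `modpow_nowide` are hypothetical names, and the sketch assumes n != 0.

```c
/* Hypothetical helper: (x * y) % n without any wider type, by doubling
 * and adding. Assumes n != 0 and x < n; every intermediate stays below n. */
static unsigned mulmod(unsigned x, unsigned y, unsigned n)
{
    unsigned result = 0;
    while (y > 0) {
        if (y & 1)   /* result = (result + x) % n, written to avoid overflow */
            result = (result >= n - x) ? result - (n - x) : result + x;
        x = (x >= n - x) ? x - (n - x) : x + x;   /* x = (2 * x) % n */
        y >>= 1;
    }
    return result;
}

/* Hypothetical helper: square-and-multiply built on mulmod; assumes n != 0. */
unsigned modpow_nowide(unsigned a, unsigned b, unsigned n)
{
    unsigned result = 1u % n;
    a %= n;
    while (b > 0) {
        if (b & 1)
            result = mulmod(result, a, n);
        a = mulmod(a, a, n);
        b >>= 1;
    }
    return result;
}
```

Each mulmod call costs O(log y) additions, so this is slower than the wide-math branch. It is only worth using when no double-width type is available.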
    